diff --git a/sglang/.claude/skills/add-jit-kernel/SKILL.md b/sglang/.claude/skills/add-jit-kernel/SKILL.md new file mode 100644 index 0000000000000000000000000000000000000000..d7f944bf5ac00128514c024bac21aa6ae3064d44 --- /dev/null +++ b/sglang/.claude/skills/add-jit-kernel/SKILL.md @@ -0,0 +1,553 @@ +--- +name: add-jit-kernel +description: Step-by-step tutorial for adding a new lightweight JIT CUDA kernel to sglang's jit_kernel module +--- + +# Tutorial: Adding a New JIT Kernel to SGLang + +This tutorial walks through adding a simple element-wise scale operation as a JIT kernel. We'll implement `scale(x, factor) = x * factor` to demonstrate the complete workflow. + +## Goal + +Add a new operation that scales each element of a tensor by a scalar factor: + +- Input: tensor `x` (CUDA) and scalar `factor` (float, passed as C++ template argument) +- Output: `x * factor` (element-wise), allocated internally +- Supported dtypes: **FP16 (`torch.float16`), BF16 (`torch.bfloat16`), FP32 (`torch.float32`)** + +## When to use JIT vs AOT (`sgl-kernel`) + +- **JIT (`jit_kernel`)**: lightweight, few dependencies, rapid iteration, compiled on first use +- **AOT (`sgl-kernel`)**: depends on CUTLASS / FlashInfer / DeepGEMM, needs pre-built wheel + +--- + +## Common Abstractions in `python/sglang/jit_kernel/include/sgl_kernel/` + +**Always prefer these abstractions over raw CUDA primitives.** They provide safety, readability, and consistency with the rest of the codebase. + +### `utils.h` — Host-side utilities + +```cpp +#include +``` + +- **`host::RuntimeCheck(cond, args...)`** — Assert a condition at runtime; throws `PanicError` with file/line info on failure. Prefer this over bare `assert`. +- **`host::Panic(args...)`** — Unconditionally throw a `PanicError` with a descriptive message. +- **`host::div_ceil(a, b)`** — Integer ceiling division `(a + b - 1) / b`. +- **`host::irange(n)`** / **`host::irange(start, end)`** — Range views for cleaner loops. 
+- **`host::pointer::offset(ptr, offsets...)`** — Byte-safe pointer arithmetic on `void*`. Use this instead of raw casts. + +### `utils.cuh` — Device-side utilities + `LaunchKernel` + +```cpp +#include <sgl_kernel/utils.cuh> +``` + +- **Type aliases**: `fp16_t`, `bf16_t`, `fp32_t`, `fp8_e4m3_t`, `fp8_e5m2_t` and their packed variants `fp16x2_t`, `bf16x2_t`, `fp32x2_t`, etc. +- **`SGL_DEVICE`** — Expands to `__forceinline__ __device__`. Use on all device functions. +- **`device::kWarpThreads`** — Constant `32`. +- **`device::load_as<T>(ptr, offset)`** / **`device::store_as<T>(ptr, val, offset)`** — Type-safe loads/stores from `void*`. +- **`device::pointer::offset(ptr, offsets...)`** — Pointer arithmetic on device. +- **`host::LaunchKernel(grid, block, device_or_stream [, smem])`** — RAII kernel launcher that: + - Resolves the CUDA stream from a `DLDevice` via TVM-FFI automatically. + - Checks the CUDA error with file/line info after launch via `operator()(kernel, args...)`. + - Supports `.enable_pdl(bool)` for PDL (Programmatic Dependent Launch, SM90+). +- **`host::RuntimeDeviceCheck(cudaError_t)`** — Check a CUDA error; throw on failure. + +### `tensor.h` — Tensor validation (`TensorMatcher`, Symbolic types) + +```cpp +#include <sgl_kernel/tensor.h> +``` + +This is the **primary validation API** for all kernel launchers. Use it to validate every `tvm::ffi::TensorView` argument. + +- **`host::SymbolicSize{"name"}`** — A named symbolic dimension. Call `.set_value(n)` to pin it, `.unwrap()` to extract after verification. +- **`host::SymbolicDType`** — Symbolic dtype. Use `.set_options()` to restrict allowed types. +- **`host::SymbolicDevice`** — Symbolic device. Use `.set_options()` to restrict to CUDA. +- **`host::TensorMatcher({dims...})`** — Fluent builder for tensor validation: + - `.with_dtype<T>()` — require a specific C++ type (e.g.
`fp16_t`) + - `.with_dtype<T1, T2, ...>()` — allow a set of types + - `.with_device(device_sym)` — require CUDA, bind device to symbol + - `.with_strides({strides...})` — validate strides (omit to require contiguous) + - `.verify(tensor_view)` — execute the check; throws `PanicError` with full context on failure; **chainable** (`verify(a).verify(b)` to check multiple tensors with the same shape) + +**Typical pattern:** +```cpp +auto N = SymbolicSize{"num_elements"}; +auto device = SymbolicDevice{}; +device.set_options(); +TensorMatcher({N}) // + .with_dtype() + .with_device(device) + .verify(dst) + .verify(src); // same shape, dtype, device as dst +const size_t n = N.unwrap(); +const DLDevice dev = device.unwrap(); +``` + +### `type.cuh` — `dtype_trait<T>` and `packed_t<T>` + +```cpp +#include <sgl_kernel/type.cuh> +``` + +- **`dtype_trait<T>`** — Static trait struct for each scalar type. Provides: + - `dtype_trait<T>::from(value)` — convert from another type (e.g. `fp32_t` → `fp16_t`) + - `dtype_trait<T>::abs/sqrt/rsqrt/max/min(x)` — type-dispatched math (for `fp32_t`) +- **`packed_t<T>`** — Two-element packed alias: `packed_t<fp16_t>` = `fp16x2_t`, `packed_t<bf16_t>` = `bf16x2_t`, `packed_t<fp32_t>` = `fp32x2_t`. Use for vectorized loads/stores. +- **`device::cast(value)`** — Type-safe cast using `dtype_trait`, e.g. `cast(v)`. + +### `vec.cuh` — Vectorized memory access (`AlignedVector`) + +```cpp +#include <sgl_kernel/vec.cuh> +``` + +- **`device::AlignedVector<T, N>`** — Aligned storage for N elements of type T. N must be a power of two, `sizeof(T)*N <= 32`. Enables 128-bit vector loads/stores for bandwidth efficiency. + - `.load(ptr, offset)` — vectorized load from `ptr[offset]` + - `.store(ptr, offset)` — vectorized store to `ptr[offset]` + - `.fill(value)` — fill all lanes + - `operator[](i)` — element access + +### `tile.cuh` — `tile::Memory` (strided memory access pattern) + +```cpp +#include <sgl_kernel/tile.cuh> +``` + +- **`device::tile::Memory::cta(blockDim.x)`** — Creates a tile accessor where each thread handles `tid = threadIdx.x` with stride `blockDim.x`.
Common for loops over a 1D array. +- **`.load(ptr, offset)`** — loads `ptr[tid + offset * blockDim.x]` +- **`.store(ptr, val, offset)`** — stores to `ptr[tid + offset * blockDim.x]` +- **`.in_bound(n, offset)`** — boundary check + +### `math.cuh` — Device math (`device::math::`) + +```cpp +#include <sgl_kernel/math.cuh> +``` + +- `device::math::max/min(a, b)` and `device::math::abs/sqrt/rsqrt(x)` — type-dispatched math via `dtype_trait` +- `device::math::exp/sin/cos(float)` — fast float math wrappers + +### `warp.cuh` — Warp-level primitives + +```cpp +#include <sgl_kernel/warp.cuh> +``` + +- `device::warp::reduce_sum(value)` — warp-level sum reduction via `__shfl_xor_sync` +- `device::warp::reduce_max(value)` — warp-level max reduction + +### `cta.cuh` — CTA-level primitives + +```cpp +#include <sgl_kernel/cta.cuh> +``` + +- `device::cta::reduce_max(value, smem, min_value)` — CTA-wide max using shared memory + warp reduction. The caller must issue a `__syncthreads()` afterwards if it needs to read the result from `smem[0]`. + +### `atomic.cuh` — Atomic operations + +```cpp +#include <sgl_kernel/atomic.cuh> +``` + +- `device::atomic::max(float* addr, float value)` — float atomic max (handles negative values correctly via bit tricks).
+ +### `runtime.cuh` — Occupancy and device info + +```cpp +#include +``` + +- `host::runtime::get_blocks_per_sm(kernel, block_dim)` — max active blocks per SM (occupancy) +- `host::runtime::get_sm_count(device_id)` — number of SMs on the device +- `host::runtime::get_cc_major(device_id)` — compute capability major version + +**Persistent kernel pattern** (cap blocks to SM count × occupancy): +```cpp +static const uint32_t max_occ = runtime::get_blocks_per_sm(kernel, kBlockSize); +static const uint32_t num_sm = runtime::get_sm_count(device.unwrap().device_id); +const auto num_blocks = std::min(num_sm * max_occ, div_ceil(n, kBlockSize)); +LaunchKernel(num_blocks, kBlockSize, device.unwrap())(kernel, params); +``` + +--- + +## Step 0 (optional): Generate a `.clangd` config for better IDE support + +```bash +python -m sglang.jit_kernel +``` + +--- + +## Step 1: Implement the CUDA kernel in `jit_kernel/csrc/` + +Create `python/sglang/jit_kernel/csrc/elementwise/scale.cuh`. + +The implementation fully uses the project abstractions described above: + +```cpp +#include // TensorMatcher, SymbolicSize, SymbolicDevice +#include // dtype_trait, fp16_t, bf16_t, fp32_t +#include // RuntimeCheck, div_ceil +#include // LaunchKernel, SGL_DEVICE +#include // AlignedVector + +#include +#include + +namespace { + +// ---------------------------------------------------------------- +// Kernel: element-wise scale using vectorized 128-bit loads/stores +// T = fp16_t | bf16_t | fp32_t +// kVecN = number of elements per vector load (e.g. 
8 for fp16) +// kFactor = scale factor encoded as kFactorNumer / kFactorDenom +// ---------------------------------------------------------------- +template +__global__ void scale_kernel(T* __restrict__ dst, + const T* __restrict__ src, + uint32_t n_vecs, + uint32_t n_remainder, + uint32_t n_total) { + constexpr float kFactor = static_cast(kFactorNumer) + / static_cast(kFactorDenom); + + using vec_t = device::AlignedVector; + + // --- vectorised body --- + const uint32_t vec_stride = blockDim.x * gridDim.x; + for (uint32_t vi = blockIdx.x * blockDim.x + threadIdx.x; + vi < n_vecs; + vi += vec_stride) { + vec_t v; + v.load(src, vi); +#pragma unroll + for (int i = 0; i < kVecN; ++i) { + v[i] = static_cast(static_cast(v[i]) * kFactor); + } + v.store(dst, vi); + } + + // --- scalar tail --- + const uint32_t base = n_vecs * kVecN; + const uint32_t scalar_stride = blockDim.x * gridDim.x; + for (uint32_t i = blockIdx.x * blockDim.x + threadIdx.x; + i < n_remainder; + i += scalar_stride) { + dst[base + i] = static_cast(static_cast(src[base + i]) * kFactor); + } +} + +// ---------------------------------------------------------------- +// Launcher: validates tensors, selects vector width, launches kernel +// ---------------------------------------------------------------- +template +void scale(tvm::ffi::TensorView dst, tvm::ffi::TensorView src) { + using namespace host; + + // 1. Validate input tensors with TensorMatcher + SymbolicSize N = {"num_elements"}; + SymbolicDevice device_; + device_.set_options(); + + TensorMatcher({N}) // + .with_dtype() + .with_device(device_) + .verify(dst) + .verify(src); // same shape / dtype / device as dst + + const uint32_t n = static_cast(N.unwrap()); + const DLDevice device = device_.unwrap(); + + RuntimeCheck(n > 0, "scale: num_elements must be > 0, got ", n); + + // 2. 
Choose vector width for 128-bit loads (16 bytes) + // fp16/bf16: 8 elements × 2 bytes = 16 bytes + // fp32: 4 elements × 4 bytes = 16 bytes + constexpr int kVecN = 16 / sizeof(T); + const uint32_t n_vecs = n / kVecN; + const uint32_t n_remainder = n % kVecN; + + // 3. Launch + constexpr uint32_t kBlockSize = 256; + const uint32_t grid = div_ceil(std::max(n_vecs, n_remainder), kBlockSize); + + LaunchKernel(grid, kBlockSize, device)( + scale_kernel, + static_cast(dst.data_ptr()), + static_cast(src.data_ptr()), + n_vecs, + n_remainder, + n); +} + +} // namespace +``` + +**Key points:** + +- Include headers from `sgl_kernel/` — **not** raw CUDA headers for anything already covered +- Use `TensorMatcher` for all tensor validation; never manually check shape/dtype/device +- Use `AlignedVector` for vectorised 128-bit loads/stores — significant bandwidth win +- Use `LaunchKernel` — it resolves the stream and checks errors automatically +- Use `RuntimeCheck` for runtime assertions with useful error messages +- `fp16_t` / `bf16_t` / `fp32_t` are the project's type aliases (from `utils.cuh`) +- `device::cast` or `dtype_trait::from(val)` for cross-type conversions +- `device::math::` functions for device math instead of bare `__` intrinsics + +--- + +## Step 2: Add the Python wrapper in `jit_kernel/` + +Create `python/sglang/jit_kernel/scale.py`: + +```python +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch + +from sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args + +if TYPE_CHECKING: + from tvm_ffi.module import Module + + +@cache_once +def _jit_scale_module(dtype: torch.dtype, factor_numer: int, factor_denom: int) -> Module: + """Compile and cache the JIT scale module for a given dtype and factor.""" + args = make_cpp_args(dtype, factor_numer, factor_denom) + return load_jit( + "scale", + *args, + cuda_files=["elementwise/scale.cuh"], + cuda_wrappers=[("scale", f"scale<{args}>")], + ) + + +def scale(src: torch.Tensor, 
factor: float, out: torch.Tensor | None = None) -> torch.Tensor: + """ + Element-wise scale: dst = src * factor. + + Supported dtypes: torch.float16, torch.bfloat16, torch.float32. + + Parameters + ---------- + src : CUDA tensor (FP16 / BF16 / FP32) + factor : scale factor + out : optional pre-allocated output tensor (same shape/dtype as src) + + Returns + ------- + Scaled tensor (dst = src * factor). + """ + assert src.is_cuda, "src must be a CUDA tensor" + assert src.dtype in (torch.float16, torch.bfloat16, torch.float32), ( + f"Unsupported dtype {src.dtype}. Supported: float16, bfloat16, float32" + ) + if out is None: + out = torch.empty_like(src) + else: + assert out.shape == src.shape, "out shape must match src" + assert out.dtype == src.dtype, "out dtype must match src" + + # Encode factor as integer ratio; denom=1000 gives 3 decimal places of precision + factor_denom = 1000 + factor_numer = round(factor * factor_denom) + + module = _jit_scale_module(src.dtype, factor_numer, factor_denom) + module.scale(out, src) + return out +``` + +**Key points:** + +- Use `cache_once` — **not** `functools.lru_cache` (incompatible with `torch.compile`) +- `load_jit` first arg(s) form the unique build marker; same marker = same cached binary +- `cuda_wrappers`: `(export_name, kernel_symbol)` — `export_name` is called from Python +- `make_cpp_args(dtype, ...)` converts `torch.dtype` to C++ type alias: + +| `torch.dtype` | C++ type | +|--------------------|------------| +| `torch.float16` | `fp16_t` | +| `torch.bfloat16` | `bf16_t` | +| `torch.float32` | `fp32_t` | + +--- + +## Step 3 (optional): Tune JIT build flags + +```python +return load_jit( + "scale", + *args, + cuda_files=["elementwise/scale.cuh"], + cuda_wrappers=[("scale", f"scale<{args}>")], + extra_cuda_cflags=["-O3", "--use_fast_math"], +) +``` + +If your kernel requires SM90+, raise a clear Python error before calling `load_jit`: + +```python +if torch.cuda.get_device_capability()[0] < 9: + raise 
RuntimeError("This kernel requires SM90 (Hopper) or later") +``` + +--- + +## Step 4: Write tests (required) + +Create `python/sglang/jit_kernel/tests/test_scale.py`: + +```python +import pytest +import torch +from sglang.jit_kernel.scale import scale + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("size", [1, 127, 128, 1024, 4097]) # cover tail remainder +@pytest.mark.parametrize("factor", [0.5, 1.0, 2.0, 3.0]) +def test_scale_correctness(dtype, size, factor): + src = torch.randn(size, dtype=dtype, device="cuda") + out = scale(src, factor) + expected = src * factor + + rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-2, 1e-2) + torch.testing.assert_close(out, expected, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +def test_scale_out_param(dtype): + src = torch.randn(1024, dtype=dtype, device="cuda") + out = torch.empty_like(src) + result = scale(src, 2.0, out=out) + assert result is out + torch.testing.assert_close(out, src * 2.0, rtol=1e-2, atol=1e-2) + + +def test_scale_cpu_error(): + src = torch.randn(128, dtype=torch.float16) # CPU tensor + with pytest.raises(AssertionError, match="CUDA"): + scale(src, 2.0) + + +def test_scale_unsupported_dtype(): + src = torch.randint(0, 10, (128,), dtype=torch.int32, device="cuda") + with pytest.raises(AssertionError, match="Unsupported dtype"): + scale(src, 2.0) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) +``` + +--- + +## Step 5: Add a benchmark (required) + +Create `python/sglang/jit_kernel/benchmark/bench_scale.py`: + +```python +import itertools + +import torch +import triton +import triton.testing + +from sglang.jit_kernel.benchmark.utils import ( + DEFAULT_DEVICE, + DEFAULT_DTYPE, + get_benchmark_range, + run_benchmark, +) +from sglang.jit_kernel.scale import scale as jit_scale + + +SIZE_LIST = get_benchmark_range( + full_range=[2**n for n in range(10, 
20)], # 1K … 512K elements + ci_range=[4096, 65536], +) + +configs = list(itertools.product(SIZE_LIST)) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["size"], + x_vals=configs, + line_arg="provider", + line_vals=["jit", "torch"], + line_names=["SGL JIT Kernel", "PyTorch"], + styles=[("blue", "-"), ("red", "--")], + ylabel="us", + plot_name="scale-performance", + args={}, + ) +) +def benchmark(size: int, provider: str): + src = torch.randn(size, dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE) + factor = 2.0 + + if provider == "jit": + fn = lambda: jit_scale(src, factor) + else: + fn = lambda: src * factor + + return run_benchmark(fn) + + +if __name__ == "__main__": + benchmark.run(print_data=True) +``` + +Run: + +```bash +python python/sglang/jit_kernel/benchmark/bench_scale.py +``` + +--- + +## Troubleshooting + +- **JIT compilation fails**: ensure the `.cuh` file is under `python/sglang/jit_kernel/csrc/`; reduce template argument combinations +- **CUDA crash / illegal memory access**: `CUDA_LAUNCH_BLOCKING=1`; `compute-sanitizer --tool memcheck python ...` +- **Unstable benchmark results**: `run_benchmark` uses CUDA-graph-based timing by default + +--- + +## References + +- `docs/developer_guide/development_jit_kernel_guide.md` +- `python/sglang/jit_kernel/utils.py` — `cache_once`, `load_jit`, `make_cpp_args` +- `python/sglang/jit_kernel/include/sgl_kernel/tensor.h` — `TensorMatcher`, `SymbolicSize/DType/Device` +- `python/sglang/jit_kernel/include/sgl_kernel/utils.cuh` — type aliases, `LaunchKernel`, `SGL_DEVICE` +- `python/sglang/jit_kernel/include/sgl_kernel/vec.cuh` — `AlignedVector` +- `python/sglang/jit_kernel/include/sgl_kernel/tile.cuh` — `tile::Memory` +- `python/sglang/jit_kernel/include/sgl_kernel/type.cuh` — `dtype_trait`, `packed_t`, `device::cast` +- `python/sglang/jit_kernel/include/sgl_kernel/math.cuh` — `device::math::` +- `python/sglang/jit_kernel/include/sgl_kernel/warp.cuh` — `warp::reduce_sum/max` +- 
`python/sglang/jit_kernel/include/sgl_kernel/cta.cuh` — `cta::reduce_max` +- `python/sglang/jit_kernel/include/sgl_kernel/atomic.cuh` — `atomic::max` +- `python/sglang/jit_kernel/include/sgl_kernel/runtime.cuh` — occupancy / SM count helpers +- `python/sglang/jit_kernel/csrc/add_constant.cuh` — minimal runnable reference +- `python/sglang/jit_kernel/csrc/elementwise/rmsnorm.cuh` — real example using `TensorMatcher` + `LaunchKernel` + `tile::Memory` +- `python/sglang/jit_kernel/csrc/elementwise/qknorm.cuh` — real example using `runtime::get_blocks_per_sm` + persistent kernel pattern +- `python/sglang/jit_kernel/benchmark/utils.py` — benchmark helpers + +## Summary of Files Created + +``` +python/sglang/jit_kernel/csrc/elementwise/scale.cuh # NEW: CUDA kernel +python/sglang/jit_kernel/scale.py # NEW: Python wrapper +python/sglang/jit_kernel/tests/test_scale.py # NEW: Tests +python/sglang/jit_kernel/benchmark/bench_scale.py # NEW: Benchmark +``` diff --git a/sglang/.claude/skills/add-sgl-kernel/SKILL.md b/sglang/.claude/skills/add-sgl-kernel/SKILL.md new file mode 100644 index 0000000000000000000000000000000000000000..b31187613cb15bfbeb70e442fd6abb716af68cb4 --- /dev/null +++ b/sglang/.claude/skills/add-sgl-kernel/SKILL.md @@ -0,0 +1,358 @@ +--- +name: add-sgl-kernel +description: Step-by-step tutorial for adding a heavyweight AOT CUDA/C++ kernel to sgl-kernel (including tests & benchmarks) +--- + +# Tutorial: Adding a New Kernel to `sgl-kernel` (AOT / Heavyweight) + +This tutorial walks through adding a simple element-wise scale operation as an AOT kernel. We'll implement `scale(x, factor) = x * factor` to demonstrate the complete workflow. 
+ +## Goal + +Add a new operation that scales each element of a tensor by a scalar factor: + +- Input: tensor `x` (CUDA) and scalar `factor` (float) +- Output: `x * factor` (element-wise, in-place or into pre-allocated `out`) +- Supported dtypes: **FP16 (`torch.float16`), BF16 (`torch.bfloat16`), FP32 (`torch.float32`)** + - Dispatched via `DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16` macro (defined in `sgl-kernel/include/utils.h`) + +## Two rules of thumb (must follow) + +1. **Heavyweight kernels go to `sgl-kernel`.** If it depends on CUTLASS / FlashInfer / DeepGEMM (or similarly heavy stacks), implement it in `sgl-kernel/`. +2. **Lightweight kernels go to `python/sglang/jit_kernel`.** If it is small, has few dependencies, and benefits from rapid iteration, implement it as a JIT kernel instead. + +In addition, every new kernel must ship with: + +- **Tests** (pytest) +- **A benchmark script** (triton.testing) + +--- + +## Repository integration map + +You will typically touch these files/areas: + +- Implementation: `sgl-kernel/csrc/elementwise/scale.cu` (pick the right subdirectory) +- Public declarations: `sgl-kernel/include/sgl_kernel_ops.h` +- Torch extension registration: `sgl-kernel/csrc/common_extension.cc` +- Build: `sgl-kernel/CMakeLists.txt` (`set(SOURCES ...)`) +- Python API: `sgl-kernel/python/sgl_kernel/` and `sgl-kernel/python/sgl_kernel/__init__.py` +- Tests: `sgl-kernel/tests/test_scale.py` +- Benchmarks: `sgl-kernel/benchmark/bench_scale.py` + +--- + +## Step 1: Implement the kernel in `csrc/` + +Pick the right subdirectory: + +- `csrc/elementwise/` — for element-wise ops (our example) +- `csrc/gemm/`, `csrc/attention/`, `csrc/moe/` — for other categories + +Create `sgl-kernel/csrc/elementwise/scale.cu`: + +```cpp +#include +#include +#include + +#include "utils.h" // DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16 + +// scale_kernel: out[i] = input[i] * factor +// Supports float, half (__half), __nv_bfloat16 via template T +template +__global__ void 
scale_kernel(T* __restrict__ out, + const T* __restrict__ input, + float factor, + int64_t n) { + int64_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx < n) { + out[idx] = static_cast(static_cast(input[idx]) * factor); + } +} + +void scale(at::Tensor& out, const at::Tensor& input, double factor) { + TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor"); + TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); + TORCH_CHECK(out.is_cuda(), "out must be a CUDA tensor"); + TORCH_CHECK(out.is_contiguous(), "out must be contiguous"); + TORCH_CHECK(out.sizes() == input.sizes(), "out and input must have the same shape"); + TORCH_CHECK(out.scalar_type() == input.scalar_type(), + "out and input must have the same dtype"); + + const int64_t n = input.numel(); + const int threads = 256; + const int blocks = (n + threads - 1) / threads; + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + + // Dispatches over float, float16, bfloat16 + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] { + scale_kernel<<>>( + static_cast(out.data_ptr()), + static_cast(input.data_ptr()), + static_cast(factor), + n); + cudaError_t status = cudaGetLastError(); + TORCH_CHECK(status == cudaSuccess, + "scale_kernel launch failed: ", cudaGetErrorString(status)); + return true; + }); +} +``` + +**Key points:** + +- Use `at::Tensor` (PyTorch tensors), `TORCH_CHECK` for validation, `at::cuda::getCurrentCUDAStream()` for stream +- `DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16` covers `float`, `half` (FP16), `__nv_bfloat16` (BF16) +- Add device error checking after every kernel launch +- If a kernel only works on certain architectures, enforce that with `TORCH_CHECK` and skip logic in tests + +--- + +## Step 2: Add a C++ declaration in `include/sgl_kernel_ops.h` + +Edit `sgl-kernel/include/sgl_kernel_ops.h`, add to the elementwise section: + +```cpp +void 
scale(at::Tensor& out, const at::Tensor& input, double factor); +``` + +--- + +## Step 3: Register the op in `csrc/common_extension.cc` + +Edit `sgl-kernel/csrc/common_extension.cc`, inside `TORCH_LIBRARY_FRAGMENT(sgl_kernel, m)`: + +```cpp +// From csrc/elementwise +m.def("scale(Tensor! out, Tensor input, float factor) -> ()"); +m.impl("scale", torch::kCUDA, &scale); +``` + +**Key points:** + +- `Tensor!` means in-place / mutable output argument +- The schema is important for `torch.compile` and for consistent call signatures +- If your underlying C++ API uses `float` but PyTorch bindings expect `double`, the implicit cast is fine for scalars; use shims if needed for other types + +--- + +## Step 4: Add the new source file to `CMakeLists.txt` + +Edit `sgl-kernel/CMakeLists.txt`, add to `set(SOURCES ...)`: + +```cmake +csrc/elementwise/scale.cu +``` + +**Key points:** + +- Keep the list **alphabetically sorted** (the file explicitly requires this) +- If the kernel has arch constraints, reflect that in tests/benchmarks via skip logic + +--- + +## Step 5: Expose a Python API under `sgl-kernel/python/sgl_kernel/` + +In `sgl-kernel/python/sgl_kernel/__init__.py`, add: + +```python +from torch.ops import sgl_kernel as _ops + +def scale(out: torch.Tensor, input: torch.Tensor, factor: float) -> None: + """ + Element-wise scale: out = input * factor (in-place into out). + + Supported dtypes: torch.float16, torch.bfloat16, torch.float32. + + Parameters + ---------- + out : pre-allocated CUDA output tensor (same shape/dtype as input) + input : CUDA input tensor + factor : scale factor (float) + """ + _ops.scale(out, input, factor) +``` + +Or export it from the existing module organisation — follow the pattern already used by similar ops in `__init__.py`. 
+ +--- + +## Step 6: Write tests (required) + +Create `sgl-kernel/tests/test_scale.py`: + +```python +import pytest +import torch +import sgl_kernel + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("size", [128, 1024, 4096, 65536]) +@pytest.mark.parametrize("factor", [0.5, 1.0, 2.0]) +def test_scale_correctness(dtype, size, factor): + input = torch.randn(size, dtype=dtype, device="cuda") + out = torch.empty_like(input) + + sgl_kernel.scale(out, input, factor) + + expected = input * factor + rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-2, 1e-2) + torch.testing.assert_close(out, expected, rtol=rtol, atol=atol) + + +def test_scale_shape_mismatch(): + input = torch.randn(128, dtype=torch.float16, device="cuda") + out = torch.empty(256, dtype=torch.float16, device="cuda") + with pytest.raises(RuntimeError, match="same shape"): + sgl_kernel.scale(out, input, 2.0) + + +def test_scale_cpu_input(): + input = torch.randn(128, dtype=torch.float16) # CPU + out = torch.empty_like(input) + with pytest.raises(RuntimeError, match="CUDA"): + sgl_kernel.scale(out, input, 2.0) + + +if __name__ == "__main__": + pytest.main([__file__, "-q"]) +``` + +Run: + +```bash +pytest sgl-kernel/tests/test_scale.py -q +``` + +--- + +## Step 7: Add a benchmark (required) + +Create `sgl-kernel/benchmark/bench_scale.py`: + +```python +import itertools +import os + +import torch +import triton +import triton.testing + +import sgl_kernel + +IS_CI = ( + os.getenv("CI", "false").lower() == "true" + or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" +) + +dtypes = [torch.float16] if IS_CI else [torch.float16, torch.bfloat16, torch.float32] +sizes = [4096] if IS_CI else [2**n for n in range(10, 20)] # 1K … 512K +factors = [2.0] + +configs = list(itertools.product(dtypes, sizes)) + + +def torch_scale(input: torch.Tensor, factor: float) -> torch.Tensor: + return input * factor + + +@triton.testing.perf_report( + 
triton.testing.Benchmark( + x_names=["dtype", "size"], + x_vals=configs, + line_arg="provider", + line_vals=["sglang", "torch"], + line_names=["SGL Kernel", "PyTorch"], + styles=[("green", "-"), ("red", "--")], + ylabel="µs (median)", + plot_name="scale-performance", + args={}, + ) +) +def benchmark(dtype, size, provider): + input = torch.randn(size, dtype=dtype, device="cuda") + out = torch.empty_like(input) + factor = 2.0 + + if provider == "sglang": + fn = lambda: sgl_kernel.scale(out, input, factor) + else: + fn = lambda: torch_scale(input, factor) + + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + fn, quantiles=[0.5, 0.2, 0.8] + ) + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +if __name__ == "__main__": + benchmark.run(print_data=True) +``` + +Run: + +```bash +python sgl-kernel/benchmark/bench_scale.py +``` + +--- + +## Step 8: Build and validate + +Build: + +```bash +cd sgl-kernel +make build -j16 +``` + +If you need to limit host resource usage: + +```bash +cd sgl-kernel +make build -j1 MAX_JOBS=2 CMAKE_ARGS="-DSGL_KERNEL_COMPILE_THREADS=1" +``` + +Validate: + +```bash +pytest sgl-kernel/tests/test_scale.py -q +python sgl-kernel/benchmark/bench_scale.py +``` + +--- + +## Troubleshooting + +- **Async CUDA errors**: `CUDA_LAUNCH_BLOCKING=1` +- **Memory errors**: `compute-sanitizer --tool memcheck python ...` +- **Build is too slow / OOM**: reduce `MAX_JOBS` and `SGL_KERNEL_COMPILE_THREADS` +- **Binary bloat**: use `sgl-kernel/analyze_whl_kernel_sizes.py` +- **CMake sources list**: if your `.cu` file is missing from `SOURCES`, the symbol will be undefined at link time + +--- + +## References + +- `sgl-kernel/README.md` +- `sgl-kernel/include/sgl_kernel_ops.h` +- `sgl-kernel/csrc/common_extension.cc` +- `sgl-kernel/CMakeLists.txt` +- `sgl-kernel/include/utils.h` — `DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16` macro and friends +- `sgl-kernel/csrc/elementwise/activation.cu` — reference for the FP16/BF16/FP32 dispatch pattern + +## Summary of 
Files Created/Modified + +``` +sgl-kernel/csrc/elementwise/scale.cu # NEW: CUDA kernel + launcher +sgl-kernel/include/sgl_kernel_ops.h # MODIFIED: C++ declaration +sgl-kernel/csrc/common_extension.cc # MODIFIED: schema + dispatch registration +sgl-kernel/CMakeLists.txt # MODIFIED: add source file (alphabetical) +sgl-kernel/python/sgl_kernel/__init__.py # MODIFIED: export Python API +sgl-kernel/tests/test_scale.py # NEW: tests +sgl-kernel/benchmark/bench_scale.py # NEW: benchmark +``` diff --git a/sglang/.claude/skills/sglang-bisect-ci-regression/SKILL.md b/sglang/.claude/skills/sglang-bisect-ci-regression/SKILL.md new file mode 100644 index 0000000000000000000000000000000000000000..4eb39227c16e91df9a8fd198f3a90c0e4cfec438 --- /dev/null +++ b/sglang/.claude/skills/sglang-bisect-ci-regression/SKILL.md @@ -0,0 +1,219 @@ +# SGLang Bisect CI Regression + +Investigate a consistently failing CI test to find the root cause - whether it's a code regression from a specific PR, a hardware/runner-specific issue, or an environment change. Optionally reproduce the failure on a remote GPU server. + +## Slash Command + +`/sglang-bisect-ci-regression <test_name_or_url> [ssh_target] [docker_container]` + +## When to Use This Skill + +- A CI test is failing consistently on main (scheduled runs) +- You need to find which PR introduced a regression +- You suspect a runner-specific or GPU-specific issue +- You want to reproduce a CI failure on a remote server + +## Arguments + +- **First argument (required)**: Test file name (e.g. `test_lora_tp.py`) or a GitHub Actions job URL +- **Second argument (optional)**: SSH target for remote reproduction (e.g. `user@host`) +- **Third argument (optional)**: Docker container name on the SSH target (e.g. `sglang_dev`) + +If SSH target and docker container are not provided, the skill will only perform the CI log analysis and bisection, without remote reproduction. **Ask the user** for these if reproduction is needed and they weren't provided.
+ +## Background: Scheduled CI Runs + +SGLang uses the `pr-test.yml` workflow with **scheduled runs** (cron-triggered) to periodically test the `main` branch. These runs are the primary data source for detecting regressions: + +- **Workflow**: `pr-test.yml` with `event: schedule` +- **Branch**: `main` +- **Dashboard**: https://github.com/sgl-project/sglang/actions/workflows/pr-test.yml?query=event%3Aschedule +- **Frequency**: Runs multiple times daily, each pinned to the HEAD of `main` at trigger time +- **Purpose**: Catches regressions that slip through PR-level CI (e.g., interaction bugs between merged PRs, hardware-specific issues) + +Always use these scheduled runs (not PR-triggered runs) when bisecting regressions on `main`. The `--event schedule` filter in `gh run list` ensures you only see these periodic main-branch runs. + +## Workflow + +### Phase 1: Extract the Failure Signature + +1. **Get the failing test details from CI logs.** If given a URL, fetch logs directly. If given a test name, find recent scheduled runs of `pr-test.yml` on `main` that failed: + +```bash +# List recent scheduled runs targeting main (the primary source of truth for regressions) +# These are cron-triggered runs visible at: +# https://github.com/sgl-project/sglang/actions/workflows/pr-test.yml?query=event%3Aschedule +gh run list --repo sgl-project/sglang --workflow="pr-test.yml" --event schedule --branch main --limit 20 --json databaseId,conclusion,createdAt,headSha + +# Find the job containing the test +gh run view {RUN_ID} --repo sgl-project/sglang --json jobs --jq '.jobs[] | select(.conclusion == "failure") | {name, conclusion, databaseId}' + +# Get the failure details +gh run view {RUN_ID} --repo sgl-project/sglang --job {JOB_ID} --log 2>&1 | grep -E -B 5 -A 30 "AssertionError|FAIL|Error|{TEST_NAME}" +``` + +2. 
**Record the failure signature:** + - Exact error message and assertion + - Affected test method name + - Model/config involved + - Numeric values (e.g., tolerance diffs, scores) + - Whether the failure is deterministic (same values across runs) + +### Phase 2: Temporal Bisection + +3. **Find the boundary between passing and failing runs.** Walk through the scheduled run history (from the `pr-test.yml` schedule runs on `main`) to identify: + - Last known PASSING run (sha + date) + - First known FAILING run (sha + date) + +```bash +# For each scheduled run, check the specific partition/job status +gh run view {RUN_ID} --repo sgl-project/sglang --json jobs --jq '.jobs[] | select(.name == "{JOB_NAME}") | {conclusion, databaseId}' + +# Verify a specific test passed or failed in a run +gh run view {RUN_ID} --repo sgl-project/sglang --job {JOB_ID} --log 2>&1 | grep -E "{TEST_NAME}|PASSED|FAILED|logprobs mismatch" | head -10 +``` + +4. **List commits between the boundary:** + +```bash +git log --oneline {LAST_PASS_SHA}..{FIRST_FAIL_SHA} +``` + +5. **Filter for relevant commits** that touch files related to the failing test (model layers, kernels, test utilities, etc.): + +```bash +git log --oneline {LAST_PASS_SHA}..{FIRST_FAIL_SHA} -- {relevant_paths} +``` + +### Phase 3: Runner/Hardware Analysis + +6. **Check if the failure is runner-specific.** Extract the runner identity from each failing and passing run: + +```bash +# Get runner name and machine +gh run view {RUN_ID} --repo sgl-project/sglang --job {JOB_ID} --log 2>&1 | grep -E "Runner name|Machine name" | head -5 + +# Get GPU/driver info +gh run view {RUN_ID} --repo sgl-project/sglang --job {JOB_ID} --log 2>&1 | grep -i -E "NVIDIA-SMI|Driver Version|CUDA Version" | head -5 + +# Get package versions +gh run view {RUN_ID} --repo sgl-project/sglang --job {JOB_ID} --log 2>&1 | grep -E "sgl.kernel.*==|flashinfer.*==" | head -5 +``` + +7. 
**Correlate runners with pass/fail outcomes.** Build a table: + +| Run ID | Date | Runner | GPU Type | Driver | Result | +|--------|------|--------|----------|--------|--------| + +If all failures map to a specific runner type/GPU and all passes map to another, the issue is **hardware-specific**, not a code regression. + +### Phase 4: Code Analysis + +8. **If a code regression is suspected** (failures not runner-specific), examine the candidate commits: + - Read the changed files + - Understand how the changes could affect the failing test + - Look for prefill-vs-decode differences, TP-specific paths, kernel changes + +9. **If a hardware issue is suspected**, analyze: + - Kernel compatibility (CUDA compute capability) + - Driver version differences + - All-reduce / NCCL behavior differences + - CUDA graph capture differences across GPU architectures + +### Phase 5: Remote Reproduction (Optional) + +Only if SSH target and docker container were provided. + +10. **Verify the remote environment:** + +```bash +ssh {SSH_TARGET} "docker exec {CONTAINER} nvidia-smi --query-gpu=name,driver_version --format=csv" +ssh {SSH_TARGET} "docker exec {CONTAINER} pip show sgl-kernel sglang flashinfer-python 2>&1 | grep -E 'Name:|Version:'" +``` + +11. 
**Ensure latest code is installed.** If the container is stale, update: + +```bash +# Try fetching latest main +ssh {SSH_TARGET} "docker exec {CONTAINER} bash -c 'cd /path/to/sglang && git fetch origin main && git checkout origin/main'" +# Or download and install from tarball if git auth fails +ssh {SSH_TARGET} "docker exec {CONTAINER} bash -c 'cd /tmp && curl -L https://github.com/sgl-project/sglang/archive/refs/heads/main.tar.gz | tar xz && cd sglang-main && pip install -e \"python[all]\"'" +# Reinstall (after git fetch) +ssh {SSH_TARGET} "docker exec {CONTAINER} bash -c 'cd /path/to/sglang && pip install -e \"python[all]\"'" +# Install test dependencies if needed +ssh {SSH_TARGET} "docker exec {CONTAINER} pip install peft rouge-score" +``` + +12. **Create a minimal reproduction script** that: + - Uses `if __name__ == '__main__'` with `mp.set_start_method("spawn")` + - Runs the specific failing test configuration + - Prints key metrics (diffs, scores, outputs) + - Exits with code 1 on failure + +13. **Copy and run the reproduction script:** + +```bash +scp /tmp/repro_script.py {SSH_TARGET}:/tmp/ +ssh {SSH_TARGET} "docker cp /tmp/repro_script.py {CONTAINER}:/tmp/" +ssh {SSH_TARGET} "docker exec -e CUDA_VISIBLE_DEVICES=0,1 {CONTAINER} python3 /tmp/repro_script.py" +``` + +14. **Run control experiments** to isolate the variable: + - If suspecting TP issue: run with TP=1 as control + - If suspecting GPU issue: compare same code on different GPU + - If suspecting a specific commit: test before/after that commit + +### Phase 6: Report + +15. 
**Produce a structured report:** + +```markdown +## CI Regression Bisection Report + +### Failure Signature +- **Test**: {test_file}::{test_method} +- **Error**: {exact error message} +- **Key metrics**: {numeric values} +- **Deterministic**: Yes/No + +### Root Cause Classification +One of: +- **Code Regression**: PR #{number} introduced the bug +- **Hardware-Specific**: Fails on {GPU_TYPE}, passes on others +- **Environment Change**: New runner/driver/package version +- **Pre-existing Flakiness**: Intermittent, not a new regression + +### Evidence +| Condition | Result | +|-----------|--------| +| {condition1} | PASS/FAIL | +| {condition2} | PASS/FAIL | + +### Timeline +- {date}: Last known pass ({sha}, {runner}) +- {date}: First known fail ({sha}, {runner}) +- {date}: Confirmed reproduction on {server} + +### Recommended Fix +- **Short-term**: {workaround} +- **Long-term**: {proper fix} +``` + +## Key Patterns to Recognize + +| Pattern | Diagnosis | +|---------|-----------| +| Same SHA passes on runner A, fails on runner B | Hardware/runner-specific | +| All runners fail after commit X | Code regression from commit X | +| Intermittent - same runner sometimes passes/fails | Flaky test or race condition | +| Prefill OK but decode fails | TP/all-reduce issue in decode path | +| Works with TP=1, fails with TP>1 | Tensor parallelism bug | +| Exact same numeric diff every time | Deterministic bug, not flakiness | + +## Important Notes + +- **Always check runner identity** before concluding it's a code regression. Many "consistent" failures are actually runner-specific. +- **Test partition assignments change over time** as tests are added/removed. A test may move between partitions, landing on different runner types. +- **H200 runners** use `/root/actions-runner/` path and machine names like `gpu-h200-worker-*`. Non-H200 runners use `/public_sglang_ci/runner-*` paths. 
+- When running remote reproduction, use `run_in_background` for long-running tests and check output with `TaskOutput`. +- Container environments may be stale - always verify package versions match CI before drawing conclusions. diff --git a/sglang/.claude/skills/write-sglang-test/SKILL.md b/sglang/.claude/skills/write-sglang-test/SKILL.md new file mode 100644 index 0000000000000000000000000000000000000000..030b451a988a5e0f0949c897343ae9dbe557aa70 --- /dev/null +++ b/sglang/.claude/skills/write-sglang-test/SKILL.md @@ -0,0 +1,248 @@ +--- +name: write-sglang-test +description: Guide for writing SGLang CI/UT tests following project conventions. Covers CustomTestCase, CI registration, server fixtures, model selection, and test placement. Use when creating new tests, adding CI test cases, writing unit tests, or when the user asks to add tests for SGLang features. +--- + +# Writing SGLang CI / UT Tests + +## Core Rules + +1. **Always use `CustomTestCase`** — never raw `unittest.TestCase` +2. **Place tests in `test/registered/<category>/`** — only use `test/manual/` for debugging / non-CI tests +3. **Reuse server fixtures** — inherit from `DefaultServerBase` or write `setUpClass`/`tearDownClass` with `popen_launch_server` +4. **Smallest model for model-agnostic functionality** — use `DEFAULT_SMALL_MODEL_NAME_FOR_TEST` (Llama-3.2-1B-Instruct) for basic features that don't depend on model size +5. **8B for general performance** — use `DEFAULT_MODEL_NAME_FOR_TEST` (Llama-3.1-8B-Instruct, single-node) for performance tests that don't involve spec / DP / parallelism +6. **Bigger features → discuss case by case** — spec, DP attention, tensor/pipeline parallelism etc. 
may need multi-GPU suites and specific models + +--- + +## Test File Template + +### Functional correctness test (small model) + +```python +import unittest + +import requests + +from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.test_utils import ( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + +register_cuda_ci(est_time=60, suite="stage-b-test-small-1-gpu") + + +class TestMyFeature(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--arg1", "value1"], # feature-specific args + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_basic_functionality(self): + response = requests.post( + self.base_url + "/generate", + json={"text": "Hello", "sampling_params": {"max_new_tokens": 32}}, + ) + self.assertEqual(response.status_code, 200) + + +if __name__ == "__main__": + unittest.main(verbosity=3) +``` + +### General performance test (8B model, single node, no spec/DP/parallelism) + +```python +import time +import unittest + +import requests + +from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + +register_cuda_ci(est_time=300, suite="stage-b-test-large-1-gpu") + + +class TestMyFeaturePerf(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, 
+ ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_latency(self): + start = time.perf_counter() + response = requests.post( + self.base_url + "/generate", + json={"text": "Hello", "sampling_params": {"max_new_tokens": 128}}, + ) + elapsed = time.perf_counter() - start + self.assertEqual(response.status_code, 200) + self.assertLess(elapsed, 5.0, "Latency exceeded threshold") + + +if __name__ == "__main__": + unittest.main(verbosity=3) +``` + +--- + +## Server Fixture Reuse + +For tests that only need a standard server, inherit from `DefaultServerBase` and override class attributes: + +```python +from sglang.test.server_fixtures.default_fixture import DefaultServerBase + +class TestMyFeature(DefaultServerBase): + model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + other_args = ["--enable-my-feature"] + + def test_something(self): + ... +``` + +Available fixtures in `python/sglang/test/server_fixtures/`: + +| Fixture | Use case | +|---------|----------| +| `DefaultServerBase` | Standard single-server tests | +| `EagleServerBase` | EAGLE speculative decoding | +| `PDDisaggregationServerBase` | Disaggregated prefill/decode | +| `MMMUServerBase` | Multimodal VLM tests | + +--- + +## CI Registration + +Every test file in `test/registered/` **must** call a registration function at module level: + +```python +from sglang.test.ci.ci_register import register_cuda_ci, register_amd_ci + +register_cuda_ci(est_time=60, suite="stage-b-test-small-1-gpu") +register_amd_ci(est_time=60, suite="stage-b-test-small-1-gpu-amd") # optional +``` + +Parameters: +- `est_time`: estimated runtime in seconds (used for CI partitioning) +- `suite`: which CI suite to run in (see below) +- `nightly=True`: for nightly-only tests (default `False` = per-commit) +- `disabled="reason"`: temporarily disable with explanation + +### Suite selection guide + +**Default cases (1 GPU):** + +| Scenario | Model | Suite | +|----------|-------|-------| +| Model-agnostic basic 
functionality | 1B (smallest) | `stage-b-test-small-1-gpu` | +| General performance (no spec/DP/parallelism) | 8B | `stage-b-test-large-1-gpu` | + +**Bigger features (case by case):** + +| Scenario | Suite | +|----------|-------| +| 2 GPU (e.g. TP=2) | `stage-b-test-large-2-gpu` | +| 4 GPU (H100) | `stage-c-test-4-gpu-h100` | +| 8 GPU (H200) | `stage-c-test-8-gpu-h200` | +| Nightly, 1 GPU | `nightly-1-gpu` | +| Nightly, 8 GPU | `nightly-8-gpu` | + +For spec, DP attention, parallelism, disaggregation, etc., discuss with the team to determine the appropriate suite and GPU configuration. + +--- + +## Model Constants + +All defined in `python/sglang/test/test_utils.py`: + +| Constant | Model | When to use | +|----------|-------|-------------| +| `DEFAULT_SMALL_MODEL_NAME_FOR_TEST` | Llama-3.2-1B-Instruct | Model-agnostic basic functionality | +| `DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE` | Llama-3.2-1B | Base (non-instruct) model tests | +| `DEFAULT_MODEL_NAME_FOR_TEST` | Llama-3.1-8B-Instruct | General performance (single node) | +| `DEFAULT_MOE_MODEL_NAME_FOR_TEST` | Mixtral-8x7B-Instruct | MoE-specific tests | +| `DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST` | — | Embedding tests | +| `DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST` | — | Vision-language tests | + +--- + +## Test Placement + +``` +test/ +├── registered/ # CI tests (auto-discovered by run_suite.py) +│ ├── sampling/ # test_penalty.py, test_sampling_params.py ... +│ ├── sessions/ # test_session_control.py ... +│ ├── openai_server/ # basic/, features/, validation/ ... +│ ├── spec/ # eagle/, utils/ ... +│ ├── models/ # model-specific accuracy tests +│ ├── perf/ # performance benchmarks +│ └── <category>/ # create new category if needed +├── manual/ # Non-CI: debugging, one-off, manual verification +└── run_suite.py # CI runner (scans registered/ only) +``` + +**Decision rule**: if the test should run in CI → `registered/`. If it's for local debugging or requires special hardware not in CI → `manual/`. 
+ +--- + +## Key Utilities + +```python +from sglang.test.test_utils import ( + CustomTestCase, # base class with retry logic + popen_launch_server, # launch server subprocess + DEFAULT_URL_FOR_TEST, # auto-configured base URL + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, # 600s default + run_bench_serving, # benchmark helper (launch + bench) +) +from sglang.srt.utils import kill_process_tree # cleanup server +``` + +--- + +## Checklist + +Before submitting a test: + +- [ ] Inherits from `CustomTestCase` (not `unittest.TestCase`) +- [ ] Has `register_*_ci(...)` call at module level +- [ ] Placed in `test/registered/<category>/` +- [ ] Model selection: smallest for model-agnostic features, 8B for general perf, case-by-case for other complex features +- [ ] `setUpClass` launches server, `tearDownClass` kills it +- [ ] Has `if __name__ == "__main__": unittest.main(verbosity=3)` +- [ ] `est_time` is reasonable (measure locally) diff --git a/sglang/benchmark/json_jump_forward/README.md b/sglang/benchmark/json_jump_forward/README.md new file mode 100644 index 0000000000000000000000000000000000000000..38fb67e89bd5cb6182f8748f7e2b09e450ad15d3 --- /dev/null +++ b/sglang/benchmark/json_jump_forward/README.md @@ -0,0 +1,88 @@ +## Run benchmark + +### Dependencies + +``` +llama_cpp_python 0.2.38 +guidance 0.1.10 +vllm 0.2.7 +outlines 0.0.25 +``` + +### Build dataset + +When benchmarking long document information retrieval, run the following command to build the dataset: + +```bash +pip install wikipedia +python3 build_dataset.py +``` + +### Benchmark sglang + +Run Llama-7B + +```bash +python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 +``` + +Benchmark Character Generation + +```bash +python3 bench_sglang.py --mode character +``` + +Benchmark City Information Retrieval + +```bash +python3 bench_sglang.py --mode city +``` + + +### Benchmark Outlines + vLLM + +Run Llama-7B + +```bash +python3 -m outlines.serve.serve --tokenizer-mode auto --model 
meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000 +``` + +Benchmark Character Generation + +```bash +python3 bench_other.py --mode character --backend outlines +``` + +Benchmark City Information Retrieval + +```bash +python3 bench_other.py --mode city --backend outlines +``` + +### Benchmark guidance + +Run Llama-7B and benchmark character generation + +```bash +python3 bench_other.py --mode character --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf +``` + +Run Llama-7B and benchmark city information retrieval + +```bash +python3 bench_other.py --mode city --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf +``` + +### Benchmark lmql + +Run Llama-7B and benchmark character generation + +``` +python3 bench_other.py --mode character --backend lmql --parallel 1 +``` + +Run Llama-7B and benchmark city information retrieval + +``` +python3 bench_other.py --mode city --backend lmql --parallel 1 +``` diff --git a/sglang/benchmark/json_jump_forward/bench_other.py b/sglang/benchmark/json_jump_forward/bench_other.py new file mode 100644 index 0000000000000000000000000000000000000000..a64e950d7c6e6159e899583a44d925adf592b4d5 --- /dev/null +++ b/sglang/benchmark/json_jump_forward/bench_other.py @@ -0,0 +1,288 @@ +import argparse +import json +import time +from concurrent.futures import ThreadPoolExecutor +from functools import partial + +import guidance +from tqdm import tqdm + +from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate +from sglang.utils import dump_state_text, read_jsonl + +# there are some FSM bugs with json regex converted from pydantic model +# here use a string regex instead +# regex_string = build_regex_from_object(HarryPoterRole) +character_regex = ( + r"""\{\n""" + + r""" "name": "[\w\d\s]{1,16}",\n""" + + r""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n""" + + r""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n""" + + r""" "occupation": 
"(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n""" + + r""" "wand": \{\n""" + + r""" "wood": "[\w\d\s]{1,16}",\n""" + + r""" "core": "[\w\d\s]{1,16}",\n""" + + r""" "length": [0-9]{1,2}\.[0-9]{0,2}\n""" + + r""" \},\n""" + + r""" "alive": "(Alive|Deceased)",\n""" + + r""" "patronus": "[\w\d\s]{1,16}",\n""" + + r""" "bogart": "[\w\d\s]{1,16}"\n""" + + r"""\}""" +) + +city_regex = ( + r"""\{\n""" + + r""" "name": "[\w\d\s]{1,16}",\n""" + + r""" "country": "[\w\d\s]{1,16}",\n""" + + r""" "latitude": [-+]?[0-9]*\.?[0-9]{0,2},\n""" + + r""" "population": [-+]?[0-9]{1,9},\n""" + + r""" "top 3 landmarks": \["[\w\d\s]{1,16}", "[\w\d\s]{1,16}", "[\w\d\s]{1,16}"\]\n""" + + r"""\}""" +) + +# fmt: off +def character_gen(name, generate): + s = name + " is a character in Harry Potter. Please fill in the following information about this character.\n" + s += generate(s, max_tokens=256, regex=character_regex) + return s +# fmt: on + +# fmt: off +def city_gen(document, generate): + s = "Please extract the information of a city from the following wikipedia page.\n" + s += "Page begin.\n" + document + "Page end.\n" + s += "Here is the name, country, and symbol of the city in JSON format.\n" + s += generate(s, max_tokens=256, regex=city_regex) + return s +# fmt: on + + +@guidance +def character_maker(lm, name): + regex_str_no_quote = r"[\w\d\s]+" + regex_float = r"[0-9]+\.[0-9]+" + lm += f"""\ + {name} is a character in Harry Potter. Please fill in the following information about this character. 
+ {{ + "name": "{guidance.gen("name", max_tokens=16, regex=regex_str_no_quote)}", + "house": "{guidance.select(options=['Gryffindor', 'Slytherin', 'Ravenclaw', 'Hufflepuff'], name='house')}", + "blood status": "{guidance.select(options=['Pure-blood', 'Half-blood', 'Muggle-born'], name='blood status')}", + "occupation": "{guidance.select(options=['student', 'teacher', 'auror', 'ministry of magic', 'death eater', 'order of the phoenix'], name='occupation')}", + "wand": {{ + "wood": "{guidance.gen("wood", max_tokens=16, regex=regex_str_no_quote)}", + "core": "{guidance.gen('core', max_tokens=16, regex=regex_str_no_quote)}", + "length": {guidance.gen('length', max_tokens=10, regex=regex_float)} + }}, + "alive": "{guidance.select(options=['Alive', 'Deceased'], name='alive')}", + "patronus": "{guidance.gen('patronus', max_tokens=16, regex=regex_str_no_quote)}", + "bogart": "{guidance.gen('bogart', max_tokens=16, regex=regex_str_no_quote)}" + }} + """ + + return lm + + +async def call_generate_lmql( + prompt, temperature, max_tokens, regex, max_len=4096, model=None, **kwargs +): + assert model is not None + import lmql + + @lmql.query(model=model) + async def program(question, max_tokens, regex): + '''lmql + """{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens and REGEX(ANSWER, regex) + return ANSWER + ''' + + return await program( + question=prompt, + temperature=temperature, + max_tokens=max_tokens, + max_len=max_len, + regex=regex, + **kwargs, + ) + + +@guidance +def city_maker(lm, document): + regex_str_no_quote = r"[\w\d\s]+" + regex_float = r"[0-9]+\.[0-9]+" + lm += f"""\ + Please extract the information of a city from the following wikipedia page. + Page begin. + {document} + Page end. + Here is the name, country, and symbol of the city in JSON format. 
+ {{ + "name": "{guidance.gen("name", max_tokens=16, regex=regex_str_no_quote)}", + "country": "{guidance.gen("country", max_tokens=16, regex=regex_str_no_quote)}", + "latitude": {guidance.gen("latitude", max_tokens=10, regex=regex_float)}, + "population": {guidance.gen("population", max_tokens=10, regex=r"[0-9]+")}, + "top 3 landmarks": [ + "{guidance.gen("landmark1", max_tokens=16, regex=regex_str_no_quote)}", "{guidance.gen("landmark2", max_tokens=16, regex=regex_str_no_quote)}", "{guidance.gen("landmark3", max_tokens=16, regex=regex_str_no_quote)}" + ] + }} + """ + + return lm + + +def bench_character(args): + arguments = [] + with open(args.data_path, "r") as f: + for line in f: + arguments.append({"name": line.strip()}) + arguments = arguments[: args.num_jsons] + + states = [None] * len(arguments) + + # Select backend + if args.backend == "outlines": + call_generate = partial(get_call_generate(args), temperature=0) + + def get_one_answer(i): + states[i] = character_gen(**arguments[i], generate=call_generate) + + elif args.backend == "guidance": + model = guidance.models.LlamaCpp( + args.model_path, + n_gpu_layers=-1, + n_ctx=args.n_ctx, + ) + + def get_one_answer(i): + lm = model + character_maker(**arguments[i]) + states[i] = lm + + elif args.backend == "lmql": + import asyncio + + import lmql + + model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}") + call_generate = partial( + call_generate_lmql, + model=model, + max_tokens=256, + regex=character_regex, + ) + + async def get_one_answer_async(i): + states[i] = await call_generate(prompt=arguments[i]["name"], temperature=0) + + else: + raise ValueError(f"Invalid backend: {args.backend}") + + tic = time.perf_counter() + + if args.backend != "lmql": + if args.parallel == 1: + for i in tqdm(range(len(arguments))): + get_one_answer(i) + else: + with ThreadPoolExecutor(args.parallel) as executor: + rets = list( + tqdm( + executor.map(get_one_answer, list(range(len(arguments)))), + 
total=len(arguments), + ) + ) + for _ in rets: + pass + else: + batches = [] + for i in range(0, len(arguments), args.parallel): + batches.append(list(range(i, min(i + args.parallel, len(arguments))))) + loop = asyncio.get_event_loop() + + for bt in tqdm(batches): + loop.run_until_complete( + asyncio.gather(*[get_one_answer_async(i) for i in bt]) + ) + + latency = time.perf_counter() - tic + + return states, latency + + +def bench_city_doc(args): + arguments = [] + for line in read_jsonl(args.data_path): + arguments.append({"document": line["document"]}) + arguments = arguments[: args.num_jsons] + + states = [None] * len(arguments) + + # Select backend + if args.backend == "outlines": + call_generate = partial(get_call_generate(args), temperature=0) + + def get_one_answer(i): + states[i] = city_gen(**arguments[i], generate=call_generate) + + elif args.backend == "guidance": + model = guidance.models.LlamaCpp( + args.model_path, + n_gpu_layers=-1, + n_ctx=args.n_ctx, + ) + + def get_one_answer(i): + lm = model + city_maker(**arguments[i]) + states[i] = lm + + else: + raise ValueError(f"Invalid backend: {args.backend}") + + tic = time.perf_counter() + if args.parallel == 1: + for i in tqdm(range(len(arguments))): + get_one_answer(i) + else: + with ThreadPoolExecutor(args.parallel) as executor: + rets = executor.map(get_one_answer, list(range(len(arguments)))) + for _ in rets: + pass + + latency = time.perf_counter() - tic + + return states, latency + + +def main(args): + if args.mode == "character": + args.data_path = "dataset.txt" + states, latency = bench_character(args) + elif args.mode == "city": + args.data_path = "questions.jsonl" + states, latency = bench_city_doc(args) + + # Compute accuracy + print(f"Latency: {latency:.3f}") + + # Write results + dump_state_text(f"tmp_output_{args.backend}_{args.mode}.txt", states) + + with open(args.result_file, "a") as fout: + value = { + "task": "json_jump_forward", + "backend": args.backend, + "latency": round(latency, 
3), + "num_jsons": args.num_jsons, + "mode": args.mode, + "parallel": args.parallel, + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str) + parser.add_argument("--num-jsons", type=int, default=50) + parser.add_argument( + "--mode", type=str, default="character", choices=["character", "city"] + ) + args = add_common_other_args_and_parse(parser) + main(args) diff --git a/sglang/benchmark/json_jump_forward/bench_sglang.py b/sglang/benchmark/json_jump_forward/bench_sglang.py new file mode 100644 index 0000000000000000000000000000000000000000..29f635f75ac49c82a089f5c42e5994765281b963 --- /dev/null +++ b/sglang/benchmark/json_jump_forward/bench_sglang.py @@ -0,0 +1,143 @@ +import argparse +import json +import time + +import sglang as sgl +from sglang.test.test_utils import ( + add_common_sglang_args_and_parse, + select_sglang_backend, +) +from sglang.utils import dump_state_text, read_jsonl + +# there are some FSM bugs with json regex converted from pydantic model +# here use a string regex instead +# regex_string = build_regex_from_object(HarryPoterRole) +character_regex = ( + r"""\{\n""" + + r""" "name": "[\w\d\s]{1,16}",\n""" + + r""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n""" + + r""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n""" + + r""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n""" + + r""" "wand": \{\n""" + + r""" "wood": "[\w\d\s]{1,16}",\n""" + + r""" "core": "[\w\d\s]{1,16}",\n""" + + r""" "length": [0-9]{1,2}\.[0-9]{0,2}\n""" + + r""" \},\n""" + + r""" "alive": "(Alive|Deceased)",\n""" + + r""" "patronus": "[\w\d\s]{1,16}",\n""" + + r""" "bogart": "[\w\d\s]{1,16}"\n""" + + r"""\}""" +) + +city_regex = ( + r"""\{\n""" + + r""" "name": "[\w\d\s]{1,16}",\n""" + + r""" "country": "[\w\d\s]{1,16}",\n""" + + r""" "latitude": [-+]?[0-9]*\.?[0-9]{0,2},\n""" + + r""" "population": 
[-+]?[0-9]{1,9},\n""" + + r""" "top 3 landmarks": \["[\w\d\s]{1,16}", "[\w\d\s]{1,16}", "[\w\d\s]{1,16}"\]\n""" + + r"""\}""" +) + +# fmt: off +@sgl.function +def character_gen(s, name): + s += name + " is a character in Harry Potter. Please fill in the following information about this character.\n" + s += sgl.gen("json_output", max_tokens=256, regex=character_regex) +# fmt: on + +# fmt: off +@sgl.function +def city_gen(s, document): + s += "Please extract the information of a city from the following wikipedia page.\n" + s += "Page begin.\n" + document + "Page end.\n" + s += "Here is the name, country, and symbol of the city in JSON format.\n" + s += sgl.gen("json_output",max_tokens=256, regex=city_regex) +# fmt: on + + +def bench_city_doc(args): + arguments = [] + for line in read_jsonl(args.data_path): + arguments.append({"document": line["document"]}) + arguments = arguments[: args.num_jsons] + + # Select backend + backend = select_sglang_backend(args) + sgl.set_default_backend(backend) + + # Run requests + tic = time.perf_counter() + states = city_gen.run_batch( + arguments, + temperature=0, + num_threads=args.parallel, + progress_bar=True, + ) + latency = time.perf_counter() - tic + + return states, latency + + +def bench_character(args): + arguments = [] + with open(args.data_path, "r") as f: + for line in f: + arguments.append({"name": line.strip()}) + arguments = arguments[: args.num_jsons] + + # Select backend + backend = select_sglang_backend(args) + sgl.set_default_backend(backend) + + # Run requests + tic = time.perf_counter() + states = character_gen.run_batch( + arguments, + temperature=0, + num_threads=args.parallel, + progress_bar=True, + ) + latency = time.perf_counter() - tic + + return states, latency + + +def main(args): + if args.mode == "character": + args.data_path = "dataset.txt" + states, latency = bench_character(args) + elif args.mode == "city": + args.data_path = "questions.jsonl" + states, latency = bench_city_doc(args) + + # Compute 
accuracy + print(f"Latency: {latency:.3f}") + + # Write results + dump_state_text(f"tmp_output_{args.backend}_{args.mode}.txt", states) + with open(f"{args.backend}_{args.mode}.json", "w") as fout: + for state in states: + fout.write(state["json_output"] + "\n") + + with open(args.result_file, "a") as fout: + value = { + "task": "json_jump_forward", + "backend": args.backend, + "latency": round(latency, 3), + "num_jsons": args.num_jsons, + "mode": args.mode, + "parallel": args.parallel, + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str) + parser.add_argument("--num-jsons", type=int, default=50) + parser.add_argument( + "--mode", type=str, default="character", choices=["character", "city"] + ) + args = add_common_sglang_args_and_parse(parser) + main(args) diff --git a/sglang/benchmark/json_jump_forward/build_dataset.py b/sglang/benchmark/json_jump_forward/build_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..1396e5edeccc34b4e7b1692428c24241cae8a925 --- /dev/null +++ b/sglang/benchmark/json_jump_forward/build_dataset.py @@ -0,0 +1,58 @@ +import json + +import transformers +import wikipedia + +model_path = "meta-llama/Llama-2-7b-chat-hf" +t = transformers.AutoTokenizer.from_pretrained(model_path) +city_names = [ + "los angles", + "london", + "tokyo", + "beijing", + "singapore", + "paris", + "dubai", + "sydney", + "moscow", + "rome", + "toronto", + "rio de janeiro", + "istanbul", + "berlin", + "auckland", + "buenos aires", + "mexico city", + "mumbai", + "seoul", + "bangkok", + "cairo", + "athens", + "jerusalem", +] + + +def get_content(city_name): + content = str(wikipedia.page(city_name).content) + content = content.replace("\n\n", "\n") + + tokens = t.encode(content) + + expected_tokens = 3000 + truncate_len = int((expected_tokens / len(tokens)) * len(content)) + truncate_content = content[:truncate_len] + truncate_tokens = 
t.encode(truncate_content) + + # Count token + print( + f"city_name: {city_name}, #tokens: {len(tokens)}, #truncate tokens: {len(truncate_tokens)}" + ) + + return truncate_content + + +if __name__ == "__main__": + with open("questions.jsonl", "w") as fout: + for city_name in city_names: + truncate_content = get_content(city_name) + fout.write(json.dumps({"document": truncate_content}) + "\n") diff --git a/sglang/benchmark/json_jump_forward/dataset.txt b/sglang/benchmark/json_jump_forward/dataset.txt new file mode 100644 index 0000000000000000000000000000000000000000..c12421e5d355c83c96ec9bfe42c615a2801e53f3 --- /dev/null +++ b/sglang/benchmark/json_jump_forward/dataset.txt @@ -0,0 +1,50 @@ +Harry Potter +Hermione Granger +Ron Weasley +Albus Dumbledore +Severus Snape +Rubeus Hagrid +Draco Malfoy +Ginny Weasley +Fred Weasley +George Weasley +Percy Weasley +Sirius Black +Remus Lupin +Neville Longbottom +Luna Lovegood +Cedric Diggory +Cho Chang +Lord Voldemort +Minerva McGonagall +Filius Flitwick +Dolores Umbridge +Bellatrix Lestrange +Lucius Malfoy +Molly Weasley +Arthur Weasley +Nymphadora Tonks +Dobby +Moaning Myrtle +Peter Pettigrew +Alastor 'Mad-Eye' Moody +Horace Slughorn +Vernon Dursley +Petunia Dursley +Dudley Dursley +Argus Filch +Sybill Trelawney +Gilderoy Lockhart +Fleur Delacour +Viktor Krum +Bill Weasley +Oliver Wood +Cornelius Fudge +Barty Crouch Sr. +Barty Crouch Jr. 
+Kingsley Shacklebolt +Quirinus Quirrell +Nearly Headless Nick +Aunt Marge +Griphook +Ludo Bagman diff --git a/sglang/benchmark/multi_turn_chat/bench_other.py b/sglang/benchmark/multi_turn_chat/bench_other.py new file mode 100644 index 0000000000000000000000000000000000000000..9189af5beca469bb58c05389c5ad7264734e8ad0 --- /dev/null +++ b/sglang/benchmark/multi_turn_chat/bench_other.py @@ -0,0 +1,93 @@ +import json +import time +from argparse import ArgumentParser +from concurrent.futures import ThreadPoolExecutor +from functools import partial + +from data_gen import gen_arguments +from tqdm import tqdm +from vllm.transformers_utils.tokenizer import get_tokenizer + +from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate +from sglang.utils import dump_state_text + + +def multi_turns(generate, qas): + s = "" + for qa in qas: + s += qa["prompt"] + s += generate(s, max_tokens=qa["new_tokens"]) + + return s + + +def main(args): + print(args) + + tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code) + + multi_qas = gen_arguments(args, tokenizer) + + states = [None] * args.num_qa + + call_generate = partial(get_call_generate(args), temperature=0) + + def get_one_answer(i): + states[i] = multi_turns(generate=call_generate, **multi_qas[i]) + + tic = time.perf_counter() + if args.parallel == 1: + for i in tqdm(range(len(multi_qas))): + get_one_answer(i) + else: + with ThreadPoolExecutor(args.parallel) as executor: + rets = list( + tqdm( + executor.map(get_one_answer, list(range(len(multi_qas)))), + total=len(multi_qas), + ) + ) + for _ in rets: + pass + + latency = time.perf_counter() - tic + + # Compute accuracy + print(f"Latency: {latency:.3f}") + + dump_state_text(f"tmp_output_{args.backend}.txt", states) + + with open(args.result_file, "a") as fout: + value = { + "task": "multi_turn_chat", + "backend": args.backend, + "num_gpus": 1, + "latency": round(latency, 3), + "num_requests": args.num_qa, + "num_turns": 
args.turns, + "other": { + "parallel": args.parallel, + "output_mode": "long" if args.long else "short", + }, + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--turns", type=int, default=4) + parser.add_argument("--num-qa", type=int, default=20) + parser.add_argument("--min-len-q", type=int, default=256) + parser.add_argument("--max-len-q", type=int, default=512) + parser.add_argument("--min-len-a", type=int, default=4) + parser.add_argument("--max-len-a", type=int, default=8) + parser.add_argument("--tokenizer", type=str, required=True) + parser.add_argument("--trust-remote-code", action="store_true") + parser.add_argument("--long", action="store_true") + args = add_common_other_args_and_parse(parser) + + if args.long: + args.min_len_a = 256 + args.max_len_a = 512 + args.num_qa = 20 + main(args) diff --git a/sglang/benchmark/multi_turn_chat/data_gen.py b/sglang/benchmark/multi_turn_chat/data_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..043c07a76aeaa2a3153ac803c573ca9490bae6bd --- /dev/null +++ b/sglang/benchmark/multi_turn_chat/data_gen.py @@ -0,0 +1,29 @@ +import random +import string + +random.seed(42) + + +def gen_prompt(tokenizer, token_num): + cha_set = string.ascii_letters + string.digits + ret = "".join(random.choices(cha_set, k=token_num)) + while len(tokenizer(ret).input_ids) < token_num: + ret += random.choice(cha_set) + return ret + + +def gen_arguments(args, tokenizer): + multi_qas = [{"qas": []} for _ in range(args.num_qa)] + for i in range(args.num_qa): + qas = multi_qas[i]["qas"] + for _ in range(args.turns): + prompt_len = random.randint(args.min_len_q, args.max_len_q) + new_tokens = random.randint(args.min_len_a, args.max_len_a) + qas.append( + { + "prompt": gen_prompt(tokenizer, prompt_len), + "new_tokens": new_tokens, + } + ) + + return multi_qas diff --git a/sglang/benchmark/tree_of_thought_deep/README.md 
b/sglang/benchmark/tree_of_thought_deep/README.md new file mode 100644 index 0000000000000000000000000000000000000000..bf5ab16387d664ebe2d782347c1025856da32aef --- /dev/null +++ b/sglang/benchmark/tree_of_thought_deep/README.md @@ -0,0 +1,51 @@ +## Download data +``` +wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl +``` + +## Run benchmark + +NOTE: This is an implementation for throughput/latency benchmark purposes. The prompts are not tuned to achieve good accuracy on the GSM-8K tasks. + +### Benchmark sglang +``` +python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 +``` + +``` +python3 bench_sglang.py --num-questions 32 +python3 bench_sglang.py --num-questions 16 --parallel 1 +``` + + +### Benchmark vllm +``` +python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000 +``` + +``` +python3 bench_other.py --num-questions 32 --backend vllm +``` + + +### Benchmark lightllm +``` +# A10G +python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000 +``` + +``` +python3 bench_other.py --num-questions 32 --backend lightllm +``` + + +### Benchmark guidance +``` +python3 bench_other.py --num-questions 8 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf +``` + +### Benchmark lmql + +``` +python3 bench_other.py --num-questions 8 --backend lmql --parallel 1 +``` diff --git a/sglang/benchmark/tree_of_thought_deep/bench_other.py b/sglang/benchmark/tree_of_thought_deep/bench_other.py new file mode 100644 index 0000000000000000000000000000000000000000..0ef8c636078e54edc51ec04da87ceee34121a71a --- /dev/null +++ b/sglang/benchmark/tree_of_thought_deep/bench_other.py @@ -0,0 +1,222 @@ +import argparse +import ast +import json +import re +import time +from collections import Counter +from concurrent.futures import 
ThreadPoolExecutor + +import numpy as np +from tqdm import tqdm + +from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate +from sglang.utils import dump_state_text, read_jsonl + +INVALID = -9999999 + + +def get_answer_value(answer_str): + answer_str = answer_str.replace(",", "") + numbers = re.findall(r"\d+", answer_str) + if len(numbers) < 1: + return INVALID + try: + return ast.literal_eval(numbers[-1]) + except SyntaxError: + return INVALID + + +def most_frequent_number(numbers): + if not numbers: + return None + + frequency = Counter(numbers) + most_frequent = max(frequency, key=frequency.get) + return most_frequent + + +USER_PREFIX = "[INST] " +USER_SUFFIX = " [/INST]" +ASSISTANT_PREFIX = "" +ASSISTANT_SUFFIX = " " + +# Use a low temp to make the results more deterministic and the comparison more fair. +temp = 0.001 + + +def propose_plan(s, question, num_branches, call_generate): + s += ( + USER_PREFIX + + """Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """ + + question + + USER_SUFFIX + ) + + s += ASSISTANT_PREFIX + comps = call_generate( + s, max_tokens=256, temperature=temp, stop=None, n=num_branches + ) + return [s + comp + ASSISTANT_SUFFIX for comp in comps] + + +def execute_plan(s, num_branches, call_generate): + s += ( + USER_PREFIX + + """The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. 
Make your response short.""" + + USER_SUFFIX + ) + s += ASSISTANT_PREFIX + comps = call_generate( + s, max_tokens=256, temperature=temp, stop=None, n=num_branches + ) + return [s + comp + ASSISTANT_SUFFIX for comp in comps] + + +def reflect_solution(s, num_branches, call_generate): + s += ( + USER_PREFIX + + """Okay. Now, evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness.""" + + USER_SUFFIX + ) + s += ASSISTANT_PREFIX + comps = call_generate( + s, max_tokens=256, temperature=temp, stop=None, n=num_branches + ) + return [s + comp + ASSISTANT_SUFFIX for comp in comps] + + +def get_final_answer(s, num_branches, call_generate): + s += ( + USER_PREFIX + + """Based on your reflection, do you change your mind? Now, give me the final answer after careful consideration.""" + + USER_SUFFIX + ) + s += ASSISTANT_PREFIX + comps = call_generate( + s, max_tokens=256, temperature=temp, stop=None, n=num_branches + ) + return [s + comp + ASSISTANT_SUFFIX for comp in comps] + + +def tree_search(question, num_branches, call_generate): + plan_forks = propose_plan("", question, num_branches, call_generate) + + sol_states = [] + for plan in plan_forks: + forks = execute_plan(plan, num_branches, call_generate) + sol_states.extend(forks) + + ref_states = [] + for sol in sol_states: + forks = reflect_solution(sol, num_branches, call_generate) + ref_states.extend(forks) + + solutions = [] + for sol in ref_states: + ans = get_final_answer(sol, num_branches, call_generate) + solutions.append(ans) + + return solutions + + +def main(args): + lines = read_jsonl(args.data_path) + + # Construct prompts + num_branches = 2 + questions = [] + labels = [] + for i in range(len(lines[: args.num_questions])): + questions.append(lines[i]["question"]) + labels.append(get_answer_value(lines[i]["answer"])) + assert all(l != INVALID for l in labels) + arguments = [{"question": q, "num_branches": num_branches} for q in questions] + + # Select 
backend + call_generate = get_call_generate(args) + + # Run requests + states = [None] * len(questions) + + tic = time.perf_counter() + if args.backend != "lmql": + + def get_one_answer(i): + states[i] = tree_search(**arguments[i], call_generate=call_generate) + + if args.parallel == 1: + for i in tqdm(range(len(questions))): + get_one_answer(i) + else: + with ThreadPoolExecutor(args.parallel) as executor: + list( + tqdm( + executor.map(get_one_answer, list(range(len(questions)))), + total=len(questions), + ) + ) + + else: + import asyncio + + from lmql_funcs import tree_search_async + + async def get_one_answer_async(i): + states[i] = await tree_search_async( + **arguments[i], call_generate=call_generate + ) + + batches = [ + [] for _ in range((len(questions) + args.parallel - 1) // args.parallel) + ] + for i in range(len(questions)): + batches[i // args.parallel].append(i) + + loop = asyncio.get_event_loop() + for bt in tqdm(batches): + tasks = [get_one_answer_async(k) for k in bt] + loop.run_until_complete(asyncio.gather(*tasks)) + + latency = time.perf_counter() - tic + + answers_text = [] + for s in states: + answers_text.append([x for xs in s for x in xs]) + + preds = [] + for i in range(len(states)): + answers = [get_answer_value(v) for v in answers_text[i]] + preds.append(most_frequent_number(answers)) + + # Compute accuracy + acc = np.mean(np.array(preds) == np.array(labels)) + invalid = np.mean(np.array(preds) == INVALID) + print(f"Latency: {latency:.3f}") + print(f"Invalid: {invalid:.3f}") + print(f"Accuracy: {acc:.3f}") + + # Write results + dump_state_text(f"tmp_output_{args.backend}.txt", answers_text) + + with open(args.result_file, "a") as fout: + value = { + "task": "tree_of_thought_gsm8k", + "backend": args.backend, + "num_gpus": 1, + "latency": round(latency, 3), + "accuracy": round(acc, 3), + "num_requests": args.num_questions, + "other": { + "num_questions": args.num_questions, + "parallel": args.parallel, + }, + } + 
fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, default="test.jsonl") + parser.add_argument("--num-questions", type=int, default=200) + args = add_common_other_args_and_parse(parser) + main(args) diff --git a/sglang/benchmark/tree_of_thought_deep/bench_sglang.py b/sglang/benchmark/tree_of_thought_deep/bench_sglang.py new file mode 100644 index 0000000000000000000000000000000000000000..bcdb6e54d2283ab8dea6aa9424a8db387177c4a3 --- /dev/null +++ b/sglang/benchmark/tree_of_thought_deep/bench_sglang.py @@ -0,0 +1,171 @@ +import argparse +import ast +import json +import re +import time +from collections import Counter + +import numpy as np + +import sglang as sgl +from sglang.test.test_utils import ( + add_common_sglang_args_and_parse, + select_sglang_backend, +) +from sglang.utils import dump_state_text, read_jsonl + +INVALID = -9999999 + + +def get_answer_value(answer_str): + answer_str = answer_str.replace(",", "") + numbers = re.findall(r"\d+", answer_str) + if len(numbers) < 1: + return INVALID + try: + return ast.literal_eval(numbers[-1]) + except SyntaxError: + return INVALID + + +def most_frequent_number(numbers): + if not numbers: + return None + + frequency = Counter(numbers) + most_frequent = max(frequency, key=frequency.get) + return most_frequent + + +# Use a low temp to make the results more deterministic and the comparison more fair. +temp = 0.001 + + +def propose_plan(s, question, num_branches): + s += sgl.user( + """Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. 
Question: """ + + question + ) + forks = s.fork(num_branches) + forks += sgl.assistant(sgl.gen("plan", max_tokens=256, temperature=temp)) + return forks + + +def execute_plan(s, num_branches): + s += sgl.user( + """The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short.""" + ) + forks = s.fork(num_branches) + forks += sgl.assistant(sgl.gen("answer", max_tokens=256, temperature=temp)) + return forks + + +def reflect_solution(s, num_branches): + s += sgl.user( + """Okay. Now, evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness.""" + ) + forks = s.fork(num_branches) + forks += sgl.assistant(sgl.gen("score", max_tokens=256, temperature=temp)) + return forks + + +def get_final_answer(s, num_branches): + s += sgl.user( + """Based on your reflection, do you change your mind? Now, give me the final answer after careful consideration.""" + ) + forks = s.fork(num_branches) + forks += sgl.assistant(sgl.gen("final_answer", max_tokens=256, temperature=temp)) + return forks + + +@sgl.function +def tree_search(s, question, num_branches): + plan_forks = propose_plan(s, question, num_branches) + + sol_states = [] + for plan in plan_forks: + forks = execute_plan(plan, num_branches) + sol_states.extend(forks) + + ref_states = [] + for sol in sol_states: + forks = reflect_solution(sol, num_branches) + ref_states.extend(forks) + + solutions = [] + for sol in ref_states: + forks = get_final_answer(sol, num_branches) + solutions.append(forks) + solutions = [[s.text() for s in forks] for forks in solutions] + + return solutions + + +def main(args): + lines = read_jsonl(args.data_path) + lines = list(lines) + + # Construct prompts + num_branches = 2 + questions = [] + labels = [] + for i in range(len(lines[: args.num_questions])): + questions.append(lines[i]["question"]) + 
labels.append(get_answer_value(lines[i]["answer"])) + assert all(l != INVALID for l in labels) + arguments = [{"question": q, "num_branches": num_branches} for q in questions] + + # Select backend + backend = select_sglang_backend(args) + + # Run requests + tic = time.perf_counter() + states = tree_search.run_batch( + arguments, + temperature=0, + backend=backend, + num_threads=args.parallel, + progress_bar=True, + ) + latency = time.perf_counter() - tic + answers_text = [] + for s in states: + answers_text.append([x for xs in s.ret_value for x in xs]) + + preds = [] + for i in range(len(states)): + answers = [get_answer_value(v) for v in answers_text[i]] + preds.append(most_frequent_number(answers)) + + # Compute accuracy + acc = np.mean(np.array(preds) == np.array(labels)) + invalid = np.mean(np.array(preds) == INVALID) + print(f"Latency: {latency:.3f}") + print(f"Invalid: {invalid:.3f}") + print(f"Accuracy: {acc:.3f}") + + # Write results + dump_state_text(f"tmp_output_{args.backend}.txt", answers_text) + + with open(args.result_file, "a") as fout: + value = { + "task": "tree_of_thought_gsm8k", + "backend": args.backend, + "num_gpus": 1, + "latency": round(latency, 3), + "accuracy": round(acc, 3), + "num_requests": args.num_questions, + "other": { + "num_questions": args.num_questions, + "parallel": args.parallel, + }, + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, default="test.jsonl") + parser.add_argument("--num-questions", type=int, default=200) + args = add_common_sglang_args_and_parse(parser) + main(args) diff --git a/sglang/docker/configs/.zshrc b/sglang/docker/configs/.zshrc new file mode 100644 index 0000000000000000000000000000000000000000..5c7113e051017d040895f9692d26e9ef8af33a7a --- /dev/null +++ b/sglang/docker/configs/.zshrc @@ -0,0 +1,27 @@ +export ZSH="/root/.oh-my-zsh" + +# Theme +ZSH_THEME="robbyrussell" + +# Plugins +plugins=( + git 
+ z + zsh-autosuggestions + zsh-syntax-highlighting +) + +source $ZSH/oh-my-zsh.sh + +# Aliases +alias ll='ls -alF' +alias la='ls -A' +alias l='ls -CF' +alias vi='vim' + +# Enhanced history +HISTSIZE=10000 +SAVEHIST=10000 +setopt HIST_IGNORE_ALL_DUPS +setopt HIST_FIND_NO_DUPS +setopt INC_APPEND_HISTORY diff --git a/sglang/docker/configs/opt/.gitconfig b/sglang/docker/configs/opt/.gitconfig new file mode 100644 index 0000000000000000000000000000000000000000..8150e40d8c6d3c654f774181630d03e953436740 --- /dev/null +++ b/sglang/docker/configs/opt/.gitconfig @@ -0,0 +1,30 @@ +[core] + editor = vim + whitespace = fix,-indent-with-non-tab,trailing-space,cr-at-eol + pager = diff-so-fancy | less --tabs=4 -RFX + +[color] + ui = true + +[color "diff-highlight"] + oldNormal = red bold + oldHighlight = red bold 52 + newNormal = green bold + newHighlight = green bold 22 + +[color "diff"] + meta = 11 + frag = magenta bold + commit = yellow bold + old = red bold + new = green bold + whitespace = red reverse + +[alias] + lg = log --color --graph --pretty=format:'%Cred%h%Creset - %s %Cgreen(%cr) %C(bold blue)<%an>%Creset%C(auto)%d%Creset' --abbrev-commit -- + +[http] + sslVerify = false + +[pull] + rebase = true diff --git a/sglang/docker/configs/opt/.tmux.conf b/sglang/docker/configs/opt/.tmux.conf new file mode 100644 index 0000000000000000000000000000000000000000..89f20064e3cdc0320ad2ecddfcc29166f31e0613 --- /dev/null +++ b/sglang/docker/configs/opt/.tmux.conf @@ -0,0 +1,27 @@ +# Pane border styling +set -g pane-border-style fg='#742727',bg=black +set -g pane-active-border-style fg=red,bg=black + +# Status bar styling +set -g status-style bg='#0C8A92',fg=black + +# Change prefix key to backtick +set-option -g prefix ` +unbind C-b +bind-key ` send-prefix + +# Split panes using - and = with current path +unbind '"' +bind - splitw -v -c '#{pane_current_path}' +unbind '%' +bind = splitw -h -c '#{pane_current_path}' + +# Vi mode settings +bind-key -T copy-mode-vi Y send-keys -X 
copy-pipe 'yank > #{pane_tty}' +set-window-option -g mode-keys vi + +# Other settings +set-option -g escape-time 0 +set-option -g base-index 1 +set-window-option -g mouse on +set -g history-limit 100000 diff --git a/sglang/docker/configs/opt/.vimrc b/sglang/docker/configs/opt/.vimrc new file mode 100644 index 0000000000000000000000000000000000000000..d4414000baa5170270e6409fb92ec0148f49391e --- /dev/null +++ b/sglang/docker/configs/opt/.vimrc @@ -0,0 +1,45 @@ +function! Yank(text) abort + let escape = system('yank', a:text) + if v:shell_error + echoerr escape + else + call writefile([escape], '/dev/tty', 'b') + endif +endfunction + +noremap y y:call Yank(@0) + +" automatically run yank(1) whenever yanking in Vim +function! CopyYank() abort + call Yank(join(v:event.regcontents, "\n")) +endfunction + +autocmd TextYankPost * call CopyYank() + +" Basic settings +set number +syntax on +set mouse=a +filetype indent on + +" Indentation +set autoindent nosmartindent +set smarttab +set expandtab +set shiftwidth=4 +set softtabstop=4 + +" Visual guides +set colorcolumn=120 +highlight ColorColumn ctermbg=5 + +" Status line +set laststatus=2 +set statusline=%<%f\ %h%m%r%=%{\"[\".(&fenc==\"\"?&enc:&fenc).((exists(\"+bomb\")\ &&\ &bomb)?\",B\":\"\").\"]\ \"}%k\ %-14.(%l,%c%V%)\ %P + +" Backspace behavior +set backspace=2 + +" Encoding +set encoding=utf-8 +set fileencoding=utf-8 diff --git a/sglang/docker/configs/yank b/sglang/docker/configs/yank new file mode 100644 index 0000000000000000000000000000000000000000..c9de641bca69cf9e195b83f26b622df002e6a911 --- /dev/null +++ b/sglang/docker/configs/yank @@ -0,0 +1,12 @@ +#!/bin/bash +put() { + esc=$1 + test -n "$TMUX" -o -z "${TERM##screen*}" && esc="\033Ptmux;\033$esc\033\\" + printf "$esc" +} +put "\033]52;c;!\a" +buf=$( cat "$@" ) +len=$( printf %s "$buf" | wc -c ) max=74994 +test $len -gt $max && echo "$0: input is $(( len - max )) bytes too long" >&2 +put "\033]52;c;$( printf %s "$buf" | head -c $max | base64 | tr -d '\r\n' )\a" 
+test -n "$TMUX" && tmux set-buffer "$buf" ||: diff --git a/sglang/python/sglang.egg-info/PKG-INFO b/sglang/python/sglang.egg-info/PKG-INFO new file mode 100644 index 0000000000000000000000000000000000000000..d1dcfb3f777c8c840548fc583c5c4a78c5eab6e6 --- /dev/null +++ b/sglang/python/sglang.egg-info/PKG-INFO @@ -0,0 +1,120 @@ +Metadata-Version: 2.4 +Name: sglang +Version: 0.5.9 +Summary: SGLang is a fast serving framework for large language models and vision language models. +Project-URL: Homepage, https://github.com/sgl-project/sglang +Project-URL: Bug Tracker, https://github.com/sgl-project/sglang/issues +Classifier: Programming Language :: Python :: 3 +Classifier: License :: OSI Approved :: Apache Software License +Requires-Python: >=3.10 +Description-Content-Type: text/markdown +Requires-Dist: IPython +Requires-Dist: aiohttp +Requires-Dist: apache-tvm-ffi<0.2,>=0.1.5 +Requires-Dist: anthropic>=0.20.0 +Requires-Dist: blobfile==3.0.0 +Requires-Dist: build +Requires-Dist: compressed-tensors +Requires-Dist: cuda-python==12.9 +Requires-Dist: decord2 +Requires-Dist: datasets +Requires-Dist: einops +Requires-Dist: fastapi +Requires-Dist: flashinfer_python==0.6.4 +Requires-Dist: flashinfer_cubin==0.6.4 +Requires-Dist: gguf +Requires-Dist: hf_transfer +Requires-Dist: huggingface_hub +Requires-Dist: interegular +Requires-Dist: llguidance<0.8.0,>=0.7.11 +Requires-Dist: modelscope +Requires-Dist: msgspec +Requires-Dist: ninja +Requires-Dist: numpy +Requires-Dist: nvidia-cutlass-dsl>=4.3.4 +Requires-Dist: nvidia-ml-py +Requires-Dist: openai-harmony==0.0.4 +Requires-Dist: openai==2.6.1 +Requires-Dist: orjson +Requires-Dist: outlines==0.1.11 +Requires-Dist: packaging +Requires-Dist: partial_json_parser +Requires-Dist: pillow +Requires-Dist: prometheus-client>=0.20.0 +Requires-Dist: psutil +Requires-Dist: py-spy +Requires-Dist: pybase64 +Requires-Dist: pydantic +Requires-Dist: python-multipart +Requires-Dist: pyzmq>=25.1.2 +Requires-Dist: quack-kernels==0.2.4 +Requires-Dist: 
requests +Requires-Dist: scipy +Requires-Dist: sentencepiece +Requires-Dist: setproctitle +Requires-Dist: sgl-fa4==4.0.3 +Requires-Dist: sgl-kernel==0.3.21 +Requires-Dist: soundfile==0.13.1 +Requires-Dist: tiktoken +Requires-Dist: timm==1.0.16 +Requires-Dist: torch_memory_saver==0.0.9 +Requires-Dist: torch==2.9.1 +Requires-Dist: torchao==0.9.0 +Requires-Dist: torchaudio==2.9.1 +Requires-Dist: torchcodec==0.8.0; sys_platform != "linux" or (sys_platform == "linux" and platform_machine != "aarch64" and platform_machine != "arm64" and platform_machine != "armv7l") +Requires-Dist: torchvision +Requires-Dist: tqdm +Requires-Dist: transformers==4.57.1 +Requires-Dist: uvicorn +Requires-Dist: uvloop +Requires-Dist: watchfiles +Requires-Dist: xgrammar==0.1.27 +Requires-Dist: smg-grpc-proto>=0.4.1 +Requires-Dist: grpcio>=1.78.0 +Requires-Dist: grpcio-reflection>=1.78.0 +Requires-Dist: grpcio-health-checking>=1.78.0 +Provides-Extra: checkpoint-engine +Requires-Dist: checkpoint-engine==0.1.2; extra == "checkpoint-engine" +Provides-Extra: diffusion +Requires-Dist: PyYAML==6.0.1; extra == "diffusion" +Requires-Dist: cloudpickle==3.1.2; extra == "diffusion" +Requires-Dist: diffusers==0.36.0; extra == "diffusion" +Requires-Dist: imageio==2.36.0; extra == "diffusion" +Requires-Dist: imageio-ffmpeg==0.5.1; extra == "diffusion" +Requires-Dist: moviepy>=2.0.0; extra == "diffusion" +Requires-Dist: opencv-python-headless==4.10.0.84; extra == "diffusion" +Requires-Dist: remote-pdb==2.1.0; extra == "diffusion" +Requires-Dist: st_attn==0.0.7; (platform_machine != "aarch64" and platform_machine != "arm64") and extra == "diffusion" +Requires-Dist: vsa==0.0.4; (platform_machine != "aarch64" and platform_machine != "arm64") and extra == "diffusion" +Requires-Dist: runai_model_streamer>=0.15.5; extra == "diffusion" +Requires-Dist: cache-dit==1.2.3; extra == "diffusion" +Requires-Dist: addict==2.4.0; extra == "diffusion" +Requires-Dist: av==16.1.0; extra == "diffusion" +Requires-Dist: 
scikit-image==0.25.2; extra == "diffusion" +Requires-Dist: trimesh>=4.0.0; extra == "diffusion" +Requires-Dist: xatlas; extra == "diffusion" +Provides-Extra: ray +Requires-Dist: ray[default]>=2.54.0; extra == "ray" +Provides-Extra: tracing +Requires-Dist: opentelemetry-api; extra == "tracing" +Requires-Dist: opentelemetry-exporter-otlp; extra == "tracing" +Requires-Dist: opentelemetry-exporter-otlp-proto-grpc; extra == "tracing" +Requires-Dist: opentelemetry-sdk; extra == "tracing" +Provides-Extra: test +Requires-Dist: accelerate; extra == "test" +Requires-Dist: bitsandbytes; extra == "test" +Requires-Dist: expecttest; extra == "test" +Requires-Dist: jsonlines; extra == "test" +Requires-Dist: lm-eval[api]>=0.4.9.2; extra == "test" +Requires-Dist: matplotlib; extra == "test" +Requires-Dist: pandas; extra == "test" +Requires-Dist: parameterized; extra == "test" +Requires-Dist: peft; extra == "test" +Requires-Dist: pytest; extra == "test" +Requires-Dist: sentence_transformers; extra == "test" +Requires-Dist: tabulate; extra == "test" +Provides-Extra: dev +Requires-Dist: sglang[test]; extra == "dev" +Provides-Extra: all +Requires-Dist: sglang[diffusion]; extra == "all" +Requires-Dist: sglang[tracing]; extra == "all" diff --git a/sglang/python/sglang.egg-info/SOURCES.txt b/sglang/python/sglang.egg-info/SOURCES.txt new file mode 100644 index 0000000000000000000000000000000000000000..62ed7ba8db3b27ef271cbaf8737aa4a24072af6e --- /dev/null +++ b/sglang/python/sglang.egg-info/SOURCES.txt @@ -0,0 +1,2123 @@ +pyproject.toml +pyproject_cpu.toml +pyproject_npu.toml +pyproject_other.toml +pyproject_xpu.toml +sglang/README.md +sglang/__init__.py +sglang/_version.py +sglang/bench_offline_throughput.py +sglang/bench_one_batch.py +sglang/bench_one_batch_server.py +sglang/bench_serving.py +sglang/check_env.py +sglang/compile_deep_gemm.py +sglang/global_config.py +sglang/launch_server.py +sglang/profiler.py +sglang/utils.py +sglang/version.py +sglang.egg-info/PKG-INFO 
+sglang.egg-info/SOURCES.txt +sglang.egg-info/dependency_links.txt +sglang.egg-info/entry_points.txt +sglang.egg-info/requires.txt +sglang.egg-info/top_level.txt +sglang/benchmark/__init__.py +sglang/benchmark/utils.py +sglang/benchmark/datasets/__init__.py +sglang/benchmark/datasets/common.py +sglang/benchmark/datasets/custom.py +sglang/benchmark/datasets/generated_shared_prefix.py +sglang/benchmark/datasets/image.py +sglang/benchmark/datasets/mmmu.py +sglang/benchmark/datasets/mooncake.py +sglang/benchmark/datasets/openai_dataset.py +sglang/benchmark/datasets/random.py +sglang/benchmark/datasets/sharegpt.py +sglang/cli/__init__.py +sglang/cli/generate.py +sglang/cli/main.py +sglang/cli/serve.py +sglang/cli/utils.py +sglang/eval/llama3_eval.py +sglang/eval/loogle_eval.py +sglang/jit_kernel/.clang-format +sglang/jit_kernel/__init__.py +sglang/jit_kernel/__main__.py +sglang/jit_kernel/add_constant.py +sglang/jit_kernel/awq_dequantize.py +sglang/jit_kernel/awq_marlin_repack.py +sglang/jit_kernel/concat_mla.py +sglang/jit_kernel/cutedsl_gdn.py +sglang/jit_kernel/flash_attention_v4.py +sglang/jit_kernel/fused_metadata_copy.py +sglang/jit_kernel/fused_store_index_cache.py +sglang/jit_kernel/gptq_marlin.py +sglang/jit_kernel/gptq_marlin_repack.py +sglang/jit_kernel/hadamard.py +sglang/jit_kernel/hicache.py +sglang/jit_kernel/kvcache.py +sglang/jit_kernel/moe_wna16_marlin.py +sglang/jit_kernel/norm.py +sglang/jit_kernel/nvfp4.py +sglang/jit_kernel/per_tensor_quant_fp8.py +sglang/jit_kernel/per_token_group_quant_8bit.py +sglang/jit_kernel/pos_enc.py +sglang/jit_kernel/rope.py +sglang/jit_kernel/timestep_embedding.py +sglang/jit_kernel/utils.py +sglang/jit_kernel/benchmark/bench_awq_dequantize.py +sglang/jit_kernel/benchmark/bench_awq_marlin_moe_repack.py +sglang/jit_kernel/benchmark/bench_awq_marlin_repack.py +sglang/jit_kernel/benchmark/bench_concat_mla.py +sglang/jit_kernel/benchmark/bench_fused_add_rmsnorm.py +sglang/jit_kernel/benchmark/bench_fused_norm_scale_shift.py 
+sglang/jit_kernel/benchmark/bench_gptq_marlin.py +sglang/jit_kernel/benchmark/bench_gptq_marlin_repack.py +sglang/jit_kernel/benchmark/bench_hicache.py +sglang/jit_kernel/benchmark/bench_moe_wna16_marlin.py +sglang/jit_kernel/benchmark/bench_nvfp4_blockwise_moe.py +sglang/jit_kernel/benchmark/bench_nvfp4_quant.py +sglang/jit_kernel/benchmark/bench_nvfp4_scaled_mm.py +sglang/jit_kernel/benchmark/bench_per_tensor_quant_fp8.py +sglang/jit_kernel/benchmark/bench_per_token_group_quant_8bit.py +sglang/jit_kernel/benchmark/bench_qknorm.py +sglang/jit_kernel/benchmark/bench_qknorm_across_heads.py +sglang/jit_kernel/benchmark/bench_renorm.py +sglang/jit_kernel/benchmark/bench_rmsnorm.py +sglang/jit_kernel/benchmark/bench_rope.py +sglang/jit_kernel/benchmark/bench_store_cache.py +sglang/jit_kernel/benchmark/utils.py +sglang/jit_kernel/csrc/add_constant.cuh +sglang/jit_kernel/csrc/hicache.cuh +sglang/jit_kernel/csrc/diffusion/timestep_embedding.cuh +sglang/jit_kernel/csrc/elementwise/concat_mla.cuh +sglang/jit_kernel/csrc/elementwise/fused_add_rmsnorm.cuh +sglang/jit_kernel/csrc/elementwise/fused_metadata_copy.cuh +sglang/jit_kernel/csrc/elementwise/kvcache.cuh +sglang/jit_kernel/csrc/elementwise/pos_enc.cuh +sglang/jit_kernel/csrc/elementwise/qknorm.cuh +sglang/jit_kernel/csrc/elementwise/qknorm_across_heads.cuh +sglang/jit_kernel/csrc/elementwise/rmsnorm.cuh +sglang/jit_kernel/csrc/elementwise/rope.cuh +sglang/jit_kernel/csrc/fast-hadamard-transform/code_gen.py +sglang/jit_kernel/csrc/fast-hadamard-transform/fast_hadamard_transform.h +sglang/jit_kernel/csrc/fast-hadamard-transform/fast_hadamard_transform_common.h +sglang/jit_kernel/csrc/fast-hadamard-transform/fast_hadamard_transform_special.h +sglang/jit_kernel/csrc/fast-hadamard-transform/hadamard_jit.cuh +sglang/jit_kernel/csrc/fast-hadamard-transform/static_switch.h +sglang/jit_kernel/csrc/gemm/awq_dequantize.cuh +sglang/jit_kernel/csrc/gemm/per_tensor_quant_fp8.cuh 
+sglang/jit_kernel/csrc/gemm/per_token_group_quant_8bit.cuh +sglang/jit_kernel/csrc/gemm/marlin/awq_marlin_repack.cuh +sglang/jit_kernel/csrc/gemm/marlin/dequant.h +sglang/jit_kernel/csrc/gemm/marlin/gptq_marlin.cuh +sglang/jit_kernel/csrc/gemm/marlin/gptq_marlin_repack.cuh +sglang/jit_kernel/csrc/gemm/marlin/kernel.h +sglang/jit_kernel/csrc/gemm/marlin/marlin.cuh +sglang/jit_kernel/csrc/gemm/marlin/marlin_dtypes.cuh +sglang/jit_kernel/csrc/gemm/marlin/marlin_template.h +sglang/jit_kernel/csrc/gemm/marlin_moe/kernel.h +sglang/jit_kernel/csrc/gemm/marlin_moe/marlin_template.h +sglang/jit_kernel/csrc/gemm/marlin_moe/moe_wna16_marlin.cuh +sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_expert_quant.cuh +sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_quant.cuh +sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_quant_entry.cuh +sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_quant_kernels.cuh +sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_entry.cuh +sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_kernels.cuh +sglang/jit_kernel/csrc/moe/nvfp4_blockwise_moe.cuh +sglang/jit_kernel/csrc/nsa/fused_store_index_cache.cuh +sglang/jit_kernel/diffusion/cutedsl/scale_residual_norm_scale_shift.py +sglang/jit_kernel/diffusion/cutedsl/utils.py +sglang/jit_kernel/diffusion/cutedsl/common/norm_fusion.py +sglang/jit_kernel/diffusion/cutedsl/common/reduce.py +sglang/jit_kernel/diffusion/triton/norm.py +sglang/jit_kernel/diffusion/triton/npu_fallback.py +sglang/jit_kernel/diffusion/triton/rmsnorm_onepass.py +sglang/jit_kernel/diffusion/triton/rotary.py +sglang/jit_kernel/diffusion/triton/scale_shift.py +sglang/jit_kernel/include/sgl_kernel/atomic.cuh +sglang/jit_kernel/include/sgl_kernel/cta.cuh +sglang/jit_kernel/include/sgl_kernel/math.cuh +sglang/jit_kernel/include/sgl_kernel/runtime.cuh +sglang/jit_kernel/include/sgl_kernel/scalar_type.hpp +sglang/jit_kernel/include/sgl_kernel/source_location.h +sglang/jit_kernel/include/sgl_kernel/tensor.h +sglang/jit_kernel/include/sgl_kernel/tile.cuh 
+sglang/jit_kernel/include/sgl_kernel/type.cuh +sglang/jit_kernel/include/sgl_kernel/utils.cuh +sglang/jit_kernel/include/sgl_kernel/utils.h +sglang/jit_kernel/include/sgl_kernel/vec.cuh +sglang/jit_kernel/include/sgl_kernel/warp.cuh +sglang/jit_kernel/include/sgl_kernel/impl/norm.cuh +sglang/jit_kernel/tests/test_add_constant.py +sglang/jit_kernel/tests/test_awq_dequantize.py +sglang/jit_kernel/tests/test_awq_marlin_moe_repack.py +sglang/jit_kernel/tests/test_awq_marlin_repack.py +sglang/jit_kernel/tests/test_concat_mla.py +sglang/jit_kernel/tests/test_cutedsl_gdn.py +sglang/jit_kernel/tests/test_flash_attention_4.py +sglang/jit_kernel/tests/test_fused_add_rmsnorm.py +sglang/jit_kernel/tests/test_fused_metadata_copy.py +sglang/jit_kernel/tests/test_fused_norm_scale_shift.py +sglang/jit_kernel/tests/test_fused_store_index_cache.py +sglang/jit_kernel/tests/test_fused_verify_triton_gdn.py +sglang/jit_kernel/tests/test_gptq_marlin.py +sglang/jit_kernel/tests/test_gptq_marlin_repack.py +sglang/jit_kernel/tests/test_moe_wna16_marlin.py +sglang/jit_kernel/tests/test_nvfp4_blockwise_moe.py +sglang/jit_kernel/tests/test_nvfp4_gemm.py +sglang/jit_kernel/tests/test_nvfp4_quant.py +sglang/jit_kernel/tests/test_per_tensor_quant_fp8.py +sglang/jit_kernel/tests/test_per_token_group_quant_8bit.py +sglang/jit_kernel/tests/test_pos_enc.py +sglang/jit_kernel/tests/test_qknorm.py +sglang/jit_kernel/tests/test_qknorm_across_heads.py +sglang/jit_kernel/tests/test_renorm.py +sglang/jit_kernel/tests/test_rmsnorm.py +sglang/jit_kernel/tests/test_rope.py +sglang/jit_kernel/tests/test_store_cache.py +sglang/jit_kernel/tests/test_timestep_embedding.py +sglang/lang/api.py +sglang/lang/chat_template.py +sglang/lang/choices.py +sglang/lang/interpreter.py +sglang/lang/ir.py +sglang/lang/tracer.py +sglang/lang/backend/anthropic.py +sglang/lang/backend/base_backend.py +sglang/lang/backend/litellm.py +sglang/lang/backend/openai.py +sglang/lang/backend/runtime_endpoint.py 
+sglang/lang/backend/vertexai.py +sglang/multimodal_gen/README.md +sglang/multimodal_gen/__init__.py +sglang/multimodal_gen/envs.py +sglang/multimodal_gen/registry.py +sglang/multimodal_gen/utils.py +sglang/multimodal_gen/.claude/CLAUDE.md +sglang/multimodal_gen/.claude/skills/diffusion-kernel/SKILL.md +sglang/multimodal_gen/.claude/skills/diffusion-kernel/add-cuda-kernel.md +sglang/multimodal_gen/.claude/skills/diffusion-kernel/add-triton-kernel.md +sglang/multimodal_gen/.claude/skills/diffusion-kernel/diffusion-benchmark-and-profile.md +sglang/multimodal_gen/.claude/skills/diffusion-kernel/nsight-profiler.md +sglang/multimodal_gen/.claude/skills/diffusion-kernel/use-efficient-diffusion-kernels.md +sglang/multimodal_gen/.claude/skills/diffusion-kernel/references/a100-optimization-guide.md +sglang/multimodal_gen/.claude/skills/diffusion-kernel/references/h100-optimization-guide.md +sglang/multimodal_gen/.claude/skills/diffusion-kernel/references/kernel-templates.md +sglang/multimodal_gen/.claude/skills/diffusion-kernel/references/t4-optimization-guide.md +sglang/multimodal_gen/.claude/skills/diffusion-kernel/references/troubleshooting.md +sglang/multimodal_gen/.claude/skills/diffusion-kernel/scripts/bench_diffusion_denoise.py +sglang/multimodal_gen/.claude/skills/diffusion-kernel/scripts/bench_diffusion_rmsnorm.py +sglang/multimodal_gen/.claude/skills/diffusion-perf/SKILL.md +sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/README.md +sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/__init__.py +sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/nodes.py +sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/utils.py +sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/core/__init__.py +sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/core/generator.py +sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/core/model_patcher.py +sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/core/server_api.py +sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/executors/__init__.py 
+sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/executors/base.py +sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/executors/flux.py +sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/executors/qwen_image.py +sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/executors/zimage.py +sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/test/README.md +sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/test/__init__.py +sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/test/test_flux_pipeline.py +sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/test/test_qwen_image_edit_pipeline.py +sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/test/test_qwen_image_pipeline.py +sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/test/test_zimage_pipeline.py +sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/workflows/flux_sgld_sp.json +sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/workflows/qwen_image_sgld.json +sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/workflows/sgld_image2video.json +sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/workflows/sgld_text2img.json +sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/workflows/z-image_sgld.json +sglang/multimodal_gen/apps/webui/README.md +sglang/multimodal_gen/apps/webui/__init__.py +sglang/multimodal_gen/apps/webui/main.py +sglang/multimodal_gen/benchmarks/bench_offline_throughput.py +sglang/multimodal_gen/benchmarks/bench_serving.py +sglang/multimodal_gen/benchmarks/compare_perf.py +sglang/multimodal_gen/benchmarks/datasets.py +sglang/multimodal_gen/configs/__init__.py +sglang/multimodal_gen/configs/quantization.py +sglang/multimodal_gen/configs/utils.py +sglang/multimodal_gen/configs/backend/vmoba/wan_1.3B_77_448_832.json +sglang/multimodal_gen/configs/backend/vmoba/wan_1.3B_77_480_832.json +sglang/multimodal_gen/configs/models/__init__.py +sglang/multimodal_gen/configs/models/base.py +sglang/multimodal_gen/configs/models/adapter/base.py +sglang/multimodal_gen/configs/models/adapter/ltx_2_connector.py 
+sglang/multimodal_gen/configs/models/bridges/__init__.py +sglang/multimodal_gen/configs/models/bridges/mova_dual_tower.py +sglang/multimodal_gen/configs/models/dits/__init__.py +sglang/multimodal_gen/configs/models/dits/base.py +sglang/multimodal_gen/configs/models/dits/flux.py +sglang/multimodal_gen/configs/models/dits/glmimage.py +sglang/multimodal_gen/configs/models/dits/helios.py +sglang/multimodal_gen/configs/models/dits/hunyuan3d.py +sglang/multimodal_gen/configs/models/dits/hunyuanvideo.py +sglang/multimodal_gen/configs/models/dits/ltx_2.py +sglang/multimodal_gen/configs/models/dits/mova_audio.py +sglang/multimodal_gen/configs/models/dits/mova_video.py +sglang/multimodal_gen/configs/models/dits/qwenimage.py +sglang/multimodal_gen/configs/models/dits/wanvideo.py +sglang/multimodal_gen/configs/models/dits/zimage.py +sglang/multimodal_gen/configs/models/encoders/__init__.py +sglang/multimodal_gen/configs/models/encoders/base.py +sglang/multimodal_gen/configs/models/encoders/clip.py +sglang/multimodal_gen/configs/models/encoders/gemma_3.py +sglang/multimodal_gen/configs/models/encoders/llama.py +sglang/multimodal_gen/configs/models/encoders/qwen3.py +sglang/multimodal_gen/configs/models/encoders/qwen_image.py +sglang/multimodal_gen/configs/models/encoders/t5.py +sglang/multimodal_gen/configs/models/vaes/__init__.py +sglang/multimodal_gen/configs/models/vaes/base.py +sglang/multimodal_gen/configs/models/vaes/dac.py +sglang/multimodal_gen/configs/models/vaes/flux.py +sglang/multimodal_gen/configs/models/vaes/glmimage.py +sglang/multimodal_gen/configs/models/vaes/hunyuan3d.py +sglang/multimodal_gen/configs/models/vaes/hunyuanvae.py +sglang/multimodal_gen/configs/models/vaes/ltx_audio.py +sglang/multimodal_gen/configs/models/vaes/ltx_video.py +sglang/multimodal_gen/configs/models/vaes/qwenimage.py +sglang/multimodal_gen/configs/models/vaes/wanvae.py +sglang/multimodal_gen/configs/models/vocoder/__init__.py +sglang/multimodal_gen/configs/models/vocoder/base.py 
+sglang/multimodal_gen/configs/models/vocoder/ltx_vocoder.py +sglang/multimodal_gen/configs/pipeline_configs/__init__.py +sglang/multimodal_gen/configs/pipeline_configs/base.py +sglang/multimodal_gen/configs/pipeline_configs/diffusers_generic.py +sglang/multimodal_gen/configs/pipeline_configs/flux.py +sglang/multimodal_gen/configs/pipeline_configs/flux_finetuned.py +sglang/multimodal_gen/configs/pipeline_configs/glm_image.py +sglang/multimodal_gen/configs/pipeline_configs/helios.py +sglang/multimodal_gen/configs/pipeline_configs/hunyuan.py +sglang/multimodal_gen/configs/pipeline_configs/hunyuan3d.py +sglang/multimodal_gen/configs/pipeline_configs/ltx_2.py +sglang/multimodal_gen/configs/pipeline_configs/mova.py +sglang/multimodal_gen/configs/pipeline_configs/qwen_image.py +sglang/multimodal_gen/configs/pipeline_configs/wan.py +sglang/multimodal_gen/configs/pipeline_configs/zimage.py +sglang/multimodal_gen/configs/sample/__init__.py +sglang/multimodal_gen/configs/sample/diffusers_generic.py +sglang/multimodal_gen/configs/sample/flux.py +sglang/multimodal_gen/configs/sample/glmimage.py +sglang/multimodal_gen/configs/sample/helios.py +sglang/multimodal_gen/configs/sample/hunyuan.py +sglang/multimodal_gen/configs/sample/hunyuan3d.py +sglang/multimodal_gen/configs/sample/ltx_2.py +sglang/multimodal_gen/configs/sample/mova.py +sglang/multimodal_gen/configs/sample/qwenimage.py +sglang/multimodal_gen/configs/sample/sampling_params.py +sglang/multimodal_gen/configs/sample/teacache.py +sglang/multimodal_gen/configs/sample/wan.py +sglang/multimodal_gen/configs/sample/zimage.py +sglang/multimodal_gen/csrc/attn/vmoba_attn/README.md +sglang/multimodal_gen/csrc/attn/vmoba_attn/setup.py +sglang/multimodal_gen/csrc/attn/vmoba_attn/tests/test_vmoba_attn.py +sglang/multimodal_gen/csrc/attn/vmoba_attn/vmoba/__init__.py +sglang/multimodal_gen/csrc/attn/vmoba_attn/vmoba/vmoba.py +sglang/multimodal_gen/csrc/render/hunyuan3d_rasterizer/__init__.py 
+sglang/multimodal_gen/csrc/render/hunyuan3d_rasterizer/rasterizer.cpp +sglang/multimodal_gen/csrc/render/hunyuan3d_rasterizer/rasterizer.h +sglang/multimodal_gen/csrc/render/hunyuan3d_rasterizer/rasterizer_gpu.cu +sglang/multimodal_gen/csrc/render/mesh_processor/__init__.py +sglang/multimodal_gen/csrc/render/mesh_processor/mesh_processor.cpp +sglang/multimodal_gen/docs/quantization.md +sglang/multimodal_gen/runtime/launch_server.py +sglang/multimodal_gen/runtime/scheduler_client.py +sglang/multimodal_gen/runtime/server_args.py +sglang/multimodal_gen/runtime/cache/__init__.py +sglang/multimodal_gen/runtime/cache/cache_dit_integration.py +sglang/multimodal_gen/runtime/cache/teacache.py +sglang/multimodal_gen/runtime/distributed/__init__.py +sglang/multimodal_gen/runtime/distributed/communication_op.py +sglang/multimodal_gen/runtime/distributed/group_coordinator.py +sglang/multimodal_gen/runtime/distributed/parallel_groups.py +sglang/multimodal_gen/runtime/distributed/parallel_state.py +sglang/multimodal_gen/runtime/distributed/utils.py +sglang/multimodal_gen/runtime/distributed/device_communicators/__init__.py +sglang/multimodal_gen/runtime/distributed/device_communicators/base_device_communicator.py +sglang/multimodal_gen/runtime/distributed/device_communicators/cpu_communicator.py +sglang/multimodal_gen/runtime/distributed/device_communicators/cuda_communicator.py +sglang/multimodal_gen/runtime/distributed/device_communicators/pynccl.py +sglang/multimodal_gen/runtime/distributed/device_communicators/pynccl_wrapper.py +sglang/multimodal_gen/runtime/entrypoints/__init__.py +sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py +sglang/multimodal_gen/runtime/entrypoints/http_server.py +sglang/multimodal_gen/runtime/entrypoints/utils.py +sglang/multimodal_gen/runtime/entrypoints/cli/__init__.py +sglang/multimodal_gen/runtime/entrypoints/cli/cli_types.py +sglang/multimodal_gen/runtime/entrypoints/cli/generate.py 
+sglang/multimodal_gen/runtime/entrypoints/cli/main.py +sglang/multimodal_gen/runtime/entrypoints/cli/serve.py +sglang/multimodal_gen/runtime/entrypoints/cli/utils.py +sglang/multimodal_gen/runtime/entrypoints/openai/common_api.py +sglang/multimodal_gen/runtime/entrypoints/openai/image_api.py +sglang/multimodal_gen/runtime/entrypoints/openai/mesh_api.py +sglang/multimodal_gen/runtime/entrypoints/openai/protocol.py +sglang/multimodal_gen/runtime/entrypoints/openai/storage.py +sglang/multimodal_gen/runtime/entrypoints/openai/stores.py +sglang/multimodal_gen/runtime/entrypoints/openai/utils.py +sglang/multimodal_gen/runtime/entrypoints/openai/video_api.py +sglang/multimodal_gen/runtime/entrypoints/post_training/__init__.py +sglang/multimodal_gen/runtime/entrypoints/post_training/io_struct.py +sglang/multimodal_gen/runtime/entrypoints/post_training/weights_api.py +sglang/multimodal_gen/runtime/layers/__init__.py +sglang/multimodal_gen/runtime/layers/activation.py +sglang/multimodal_gen/runtime/layers/custom_op.py +sglang/multimodal_gen/runtime/layers/elementwise.py +sglang/multimodal_gen/runtime/layers/layernorm.py +sglang/multimodal_gen/runtime/layers/linear.py +sglang/multimodal_gen/runtime/layers/mlp.py +sglang/multimodal_gen/runtime/layers/usp.py +sglang/multimodal_gen/runtime/layers/utils.py +sglang/multimodal_gen/runtime/layers/visual_embedding.py +sglang/multimodal_gen/runtime/layers/vocab_parallel_embedding.py +sglang/multimodal_gen/runtime/layers/attention/STA_configuration.py +sglang/multimodal_gen/runtime/layers/attention/__init__.py +sglang/multimodal_gen/runtime/layers/attention/layer.py +sglang/multimodal_gen/runtime/layers/attention/selector.py +sglang/multimodal_gen/runtime/layers/attention/turbo_layer.py +sglang/multimodal_gen/runtime/layers/attention/backends/__init__.py +sglang/multimodal_gen/runtime/layers/attention/backends/aiter.py +sglang/multimodal_gen/runtime/layers/attention/backends/attention_backend.py 
+sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py +sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn_2.py +sglang/multimodal_gen/runtime/layers/attention/backends/sage_attn.py +sglang/multimodal_gen/runtime/layers/attention/backends/sage_attn3.py +sglang/multimodal_gen/runtime/layers/attention/backends/sdpa.py +sglang/multimodal_gen/runtime/layers/attention/backends/sliding_tile_attn.py +sglang/multimodal_gen/runtime/layers/attention/backends/sparse_linear_attn.py +sglang/multimodal_gen/runtime/layers/attention/backends/sparse_video_gen_2_attn.py +sglang/multimodal_gen/runtime/layers/attention/backends/video_sparse_attn.py +sglang/multimodal_gen/runtime/layers/attention/backends/vmoba.py +sglang/multimodal_gen/runtime/layers/lora/linear.py +sglang/multimodal_gen/runtime/layers/quantization/__init__.py +sglang/multimodal_gen/runtime/layers/quantization/fp8.py +sglang/multimodal_gen/runtime/layers/quantization/nunchaku_linear.py +sglang/multimodal_gen/runtime/layers/quantization/configs/base_config.py +sglang/multimodal_gen/runtime/layers/quantization/configs/nunchaku_config.py +sglang/multimodal_gen/runtime/layers/rotary_embedding/__init__.py +sglang/multimodal_gen/runtime/layers/rotary_embedding/base.py +sglang/multimodal_gen/runtime/layers/rotary_embedding/factory.py +sglang/multimodal_gen/runtime/layers/rotary_embedding/mrope.py +sglang/multimodal_gen/runtime/layers/rotary_embedding/utils.py +sglang/multimodal_gen/runtime/loader/fsdp_load.py +sglang/multimodal_gen/runtime/loader/utils.py +sglang/multimodal_gen/runtime/loader/weight_utils.py +sglang/multimodal_gen/runtime/loader/weights_updater.py +sglang/multimodal_gen/runtime/loader/component_loaders/adapter_loader.py +sglang/multimodal_gen/runtime/loader/component_loaders/bridge_loader.py +sglang/multimodal_gen/runtime/loader/component_loaders/component_loader.py +sglang/multimodal_gen/runtime/loader/component_loaders/image_encoder_loader.py 
+sglang/multimodal_gen/runtime/loader/component_loaders/scheduler_loader.py +sglang/multimodal_gen/runtime/loader/component_loaders/text_encoder_loader.py +sglang/multimodal_gen/runtime/loader/component_loaders/transformer_loader.py +sglang/multimodal_gen/runtime/loader/component_loaders/vae_loader.py +sglang/multimodal_gen/runtime/loader/component_loaders/vl_encoder_loader.py +sglang/multimodal_gen/runtime/loader/component_loaders/vocoder_loader.py +sglang/multimodal_gen/runtime/managers/forward_context.py +sglang/multimodal_gen/runtime/managers/gpu_worker.py +sglang/multimodal_gen/runtime/managers/scheduler.py +sglang/multimodal_gen/runtime/models/__init__.py +sglang/multimodal_gen/runtime/models/parameter.py +sglang/multimodal_gen/runtime/models/registry.py +sglang/multimodal_gen/runtime/models/utils.py +sglang/multimodal_gen/runtime/models/vision_utils.py +sglang/multimodal_gen/runtime/models/adapter/ltx_2_connector.py +sglang/multimodal_gen/runtime/models/bridges/__init__.py +sglang/multimodal_gen/runtime/models/bridges/mova_dual_tower.py +sglang/multimodal_gen/runtime/models/dits/base.py +sglang/multimodal_gen/runtime/models/dits/causal_wanvideo.py +sglang/multimodal_gen/runtime/models/dits/flux.py +sglang/multimodal_gen/runtime/models/dits/flux_2.py +sglang/multimodal_gen/runtime/models/dits/glm_image.py +sglang/multimodal_gen/runtime/models/dits/helios.py +sglang/multimodal_gen/runtime/models/dits/hunyuan3d.py +sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py +sglang/multimodal_gen/runtime/models/dits/ltx_2.py +sglang/multimodal_gen/runtime/models/dits/mova_audio_dit.py +sglang/multimodal_gen/runtime/models/dits/mova_video_dit.py +sglang/multimodal_gen/runtime/models/dits/qwen_image.py +sglang/multimodal_gen/runtime/models/dits/wanvideo.py +sglang/multimodal_gen/runtime/models/dits/zimage.py +sglang/multimodal_gen/runtime/models/encoders/base.py +sglang/multimodal_gen/runtime/models/encoders/bert.py 
+sglang/multimodal_gen/runtime/models/encoders/clip.py +sglang/multimodal_gen/runtime/models/encoders/gemma_3.py +sglang/multimodal_gen/runtime/models/encoders/hunyuan3d.py +sglang/multimodal_gen/runtime/models/encoders/llama.py +sglang/multimodal_gen/runtime/models/encoders/mistral_3.py +sglang/multimodal_gen/runtime/models/encoders/qwen2_5vl.py +sglang/multimodal_gen/runtime/models/encoders/qwen3.py +sglang/multimodal_gen/runtime/models/encoders/t5.py +sglang/multimodal_gen/runtime/models/encoders/vision.py +sglang/multimodal_gen/runtime/models/schedulers/base.py +sglang/multimodal_gen/runtime/models/schedulers/flow_match_pair.py +sglang/multimodal_gen/runtime/models/schedulers/hunyuan3d_scheduler.py +sglang/multimodal_gen/runtime/models/schedulers/scheduling_comfyui_passthrough.py +sglang/multimodal_gen/runtime/models/schedulers/scheduling_flow_match_euler_discrete.py +sglang/multimodal_gen/runtime/models/schedulers/scheduling_flow_unipc_multistep.py +sglang/multimodal_gen/runtime/models/schedulers/scheduling_helios.py +sglang/multimodal_gen/runtime/models/schedulers/scheduling_self_forcing_flow_match.py +sglang/multimodal_gen/runtime/models/schedulers/scheduling_unipc_multistep.py +sglang/multimodal_gen/runtime/models/vaes/autoencoder.py +sglang/multimodal_gen/runtime/models/vaes/autoencoder_kl_flux2.py +sglang/multimodal_gen/runtime/models/vaes/autoencoder_kl_qwenimage.py +sglang/multimodal_gen/runtime/models/vaes/common.py +sglang/multimodal_gen/runtime/models/vaes/dac.py +sglang/multimodal_gen/runtime/models/vaes/hunyuan3d_vae.py +sglang/multimodal_gen/runtime/models/vaes/hunyuanvae.py +sglang/multimodal_gen/runtime/models/vaes/ltx_2_audio.py +sglang/multimodal_gen/runtime/models/vaes/ltx_2_vae.py +sglang/multimodal_gen/runtime/models/vaes/wanvae.py +sglang/multimodal_gen/runtime/models/vaes/parallel/wan_common_utils.py +sglang/multimodal_gen/runtime/models/vaes/parallel/wan_dist_utils.py +sglang/multimodal_gen/runtime/models/vocoder/ltx_2_vocoder.py 
+sglang/multimodal_gen/runtime/pipelines/__init__.py +sglang/multimodal_gen/runtime/pipelines/comfyui_flux_pipeline.py +sglang/multimodal_gen/runtime/pipelines/comfyui_qwen_image_pipeline.py +sglang/multimodal_gen/runtime/pipelines/comfyui_zimage_pipeline.py +sglang/multimodal_gen/runtime/pipelines/diffusers_pipeline.py +sglang/multimodal_gen/runtime/pipelines/flux.py +sglang/multimodal_gen/runtime/pipelines/flux_2.py +sglang/multimodal_gen/runtime/pipelines/flux_2_klein.py +sglang/multimodal_gen/runtime/pipelines/glm_image.py +sglang/multimodal_gen/runtime/pipelines/helios_pipeline.py +sglang/multimodal_gen/runtime/pipelines/hunyuan3d_pipeline.py +sglang/multimodal_gen/runtime/pipelines/hunyuan_pipeline.py +sglang/multimodal_gen/runtime/pipelines/ltx_2_pipeline.py +sglang/multimodal_gen/runtime/pipelines/mova_pipeline.py +sglang/multimodal_gen/runtime/pipelines/qwen_image.py +sglang/multimodal_gen/runtime/pipelines/wan_causal_dmd_pipeline.py +sglang/multimodal_gen/runtime/pipelines/wan_dmd_pipeline.py +sglang/multimodal_gen/runtime/pipelines/wan_i2v_dmd_pipeline.py +sglang/multimodal_gen/runtime/pipelines/wan_i2v_pipeline.py +sglang/multimodal_gen/runtime/pipelines/wan_pipeline.py +sglang/multimodal_gen/runtime/pipelines/zimage_pipeline.py +sglang/multimodal_gen/runtime/pipelines_core/__init__.py +sglang/multimodal_gen/runtime/pipelines_core/composed_pipeline_base.py +sglang/multimodal_gen/runtime/pipelines_core/lora_format_adapter.py +sglang/multimodal_gen/runtime/pipelines_core/lora_pipeline.py +sglang/multimodal_gen/runtime/pipelines_core/schedule_batch.py +sglang/multimodal_gen/runtime/pipelines_core/executors/parallel_executor.py +sglang/multimodal_gen/runtime/pipelines_core/executors/pipeline_executor.py +sglang/multimodal_gen/runtime/pipelines_core/executors/sync_executor.py +sglang/multimodal_gen/runtime/pipelines_core/stages/__init__.py +sglang/multimodal_gen/runtime/pipelines_core/stages/base.py 
+sglang/multimodal_gen/runtime/pipelines_core/stages/causal_denoising.py +sglang/multimodal_gen/runtime/pipelines_core/stages/comfyui_latent_preparation.py +sglang/multimodal_gen/runtime/pipelines_core/stages/decoding.py +sglang/multimodal_gen/runtime/pipelines_core/stages/decoding_av.py +sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py +sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_av.py +sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_dmd.py +sglang/multimodal_gen/runtime/pipelines_core/stages/encoding.py +sglang/multimodal_gen/runtime/pipelines_core/stages/hunyuan3d_paint.py +sglang/multimodal_gen/runtime/pipelines_core/stages/hunyuan3d_shape.py +sglang/multimodal_gen/runtime/pipelines_core/stages/image_encoding.py +sglang/multimodal_gen/runtime/pipelines_core/stages/input_validation.py +sglang/multimodal_gen/runtime/pipelines_core/stages/latent_preparation.py +sglang/multimodal_gen/runtime/pipelines_core/stages/latent_preparation_av.py +sglang/multimodal_gen/runtime/pipelines_core/stages/text_connector.py +sglang/multimodal_gen/runtime/pipelines_core/stages/text_encoding.py +sglang/multimodal_gen/runtime/pipelines_core/stages/timestep_preparation.py +sglang/multimodal_gen/runtime/pipelines_core/stages/validators.py +sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/glm_image.py +sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/helios_denoising.py +sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/mova.py +sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/qwen_image_layered.py +sglang/multimodal_gen/runtime/platforms/__init__.py +sglang/multimodal_gen/runtime/platforms/cpu.py +sglang/multimodal_gen/runtime/platforms/cuda.py +sglang/multimodal_gen/runtime/platforms/interface.py +sglang/multimodal_gen/runtime/platforms/mps.py +sglang/multimodal_gen/runtime/platforms/musa.py +sglang/multimodal_gen/runtime/platforms/npu.py 
+sglang/multimodal_gen/runtime/platforms/rocm.py +sglang/multimodal_gen/runtime/postprocess/__init__.py +sglang/multimodal_gen/runtime/postprocess/rife_interpolator.py +sglang/multimodal_gen/runtime/utils/common.py +sglang/multimodal_gen/runtime/utils/distributed.py +sglang/multimodal_gen/runtime/utils/hf_diffusers_utils.py +sglang/multimodal_gen/runtime/utils/layerwise_offload.py +sglang/multimodal_gen/runtime/utils/logging_utils.py +sglang/multimodal_gen/runtime/utils/mesh3d_utils.py +sglang/multimodal_gen/runtime/utils/perf_logger.py +sglang/multimodal_gen/runtime/utils/profiler.py +sglang/multimodal_gen/test/__init__.py +sglang/multimodal_gen/test/run_suite.py +sglang/multimodal_gen/test/slack_utils.py +sglang/multimodal_gen/test/test_utils.py +sglang/multimodal_gen/test/cli/test_generate_common.py +sglang/multimodal_gen/test/cli/test_generate_t2i_perf.py +sglang/multimodal_gen/test/scripts/gen_diffusion_ci_outputs.py +sglang/multimodal_gen/test/scripts/gen_perf_baselines.py +sglang/multimodal_gen/test/server/conftest.py +sglang/multimodal_gen/test/server/perf_baselines.json +sglang/multimodal_gen/test/server/test_server_2_gpu_a.py +sglang/multimodal_gen/test/server/test_server_2_gpu_b.py +sglang/multimodal_gen/test/server/test_server_a.py +sglang/multimodal_gen/test/server/test_server_b.py +sglang/multimodal_gen/test/server/test_server_common.py +sglang/multimodal_gen/test/server/test_server_utils.py +sglang/multimodal_gen/test/server/test_update_weights_from_disk.py +sglang/multimodal_gen/test/server/testcase_configs.py +sglang/multimodal_gen/test/server/ascend/perf_baselines_npu.json +sglang/multimodal_gen/test/server/ascend/test_server_1_npu.py +sglang/multimodal_gen/test/server/ascend/test_server_2_npu.py +sglang/multimodal_gen/test/server/ascend/testcase_configs_npu.py +sglang/multimodal_gen/test/test_files/launch_flux.json +sglang/multimodal_gen/test/test_files/launch_wan.json +sglang/multimodal_gen/test/unit/test_lora_format_adapter.py 
+sglang/multimodal_gen/test/unit/test_sampling_params_validate.py +sglang/multimodal_gen/test/unit/test_server_args_unit.py +sglang/multimodal_gen/test/unit/test_storage.py +sglang/multimodal_gen/third_party/__init__.py +sglang/multimodal_gen/third_party/pynvml.py +sglang/multimodal_gen/tools/convert_hf_to_fp8.py +sglang/srt/constants.py +sglang/srt/environ.py +sglang/srt/server_args.py +sglang/srt/server_args_config_parser.py +sglang/srt/batch_invariant_ops/__init__.py +sglang/srt/batch_invariant_ops/batch_invariant_ops.py +sglang/srt/batch_overlap/operations.py +sglang/srt/batch_overlap/operations_strategy.py +sglang/srt/batch_overlap/single_batch_overlap.py +sglang/srt/batch_overlap/two_batch_overlap.py +sglang/srt/checkpoint_engine/__init__.py +sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +sglang/srt/checkpoint_engine/update.py +sglang/srt/compilation/backend.py +sglang/srt/compilation/compilation_config.py +sglang/srt/compilation/compilation_counter.py +sglang/srt/compilation/compile.py +sglang/srt/compilation/compiler_interface.py +sglang/srt/compilation/cuda_piecewise_backend.py +sglang/srt/compilation/fix_functionalization.py +sglang/srt/compilation/fx_utils.py +sglang/srt/compilation/inductor_pass.py +sglang/srt/compilation/npu_piecewise_backend.py +sglang/srt/compilation/pass_manager.py +sglang/srt/compilation/piecewise_context_manager.py +sglang/srt/compilation/weak_ref_tensor.py +sglang/srt/configs/__init__.py +sglang/srt/configs/afmoe.py +sglang/srt/configs/bailing_hybrid.py +sglang/srt/configs/chatglm.py +sglang/srt/configs/dbrx.py +sglang/srt/configs/deepseek_ocr.py +sglang/srt/configs/deepseekvl2.py +sglang/srt/configs/device_config.py +sglang/srt/configs/dots_ocr.py +sglang/srt/configs/dots_vlm.py +sglang/srt/configs/exaone.py +sglang/srt/configs/falcon_h1.py +sglang/srt/configs/granitemoehybrid.py +sglang/srt/configs/internvl.py +sglang/srt/configs/janus_pro.py +sglang/srt/configs/jet_nemotron.py +sglang/srt/configs/jet_vlm.py 
+sglang/srt/configs/kimi_k25.py +sglang/srt/configs/kimi_linear.py +sglang/srt/configs/kimi_vl.py +sglang/srt/configs/kimi_vl_moonvit.py +sglang/srt/configs/lfm2.py +sglang/srt/configs/lfm2_moe.py +sglang/srt/configs/load_config.py +sglang/srt/configs/longcat_flash.py +sglang/srt/configs/mamba_utils.py +sglang/srt/configs/model_config.py +sglang/srt/configs/modelopt_config.py +sglang/srt/configs/nano_nemotron_vl.py +sglang/srt/configs/nemotron_h.py +sglang/srt/configs/olmo3.py +sglang/srt/configs/points_v15_chat.py +sglang/srt/configs/qwen3_5.py +sglang/srt/configs/qwen3_next.py +sglang/srt/configs/qwen3_omni.py +sglang/srt/configs/qwen3_vl.py +sglang/srt/configs/radio.py +sglang/srt/configs/step3_vl.py +sglang/srt/configs/step3p5.py +sglang/srt/configs/update_config.py +sglang/srt/configs/utils.py +sglang/srt/connector/__init__.py +sglang/srt/connector/base_connector.py +sglang/srt/connector/redis.py +sglang/srt/connector/remote_instance.py +sglang/srt/connector/s3.py +sglang/srt/connector/utils.py +sglang/srt/connector/serde/__init__.py +sglang/srt/connector/serde/safe_serde.py +sglang/srt/connector/serde/serde.py +sglang/srt/constrained/base_grammar_backend.py +sglang/srt/constrained/grammar_manager.py +sglang/srt/constrained/llguidance_backend.py +sglang/srt/constrained/outlines_backend.py +sglang/srt/constrained/outlines_jump_forward.py +sglang/srt/constrained/reasoner_grammar_backend.py +sglang/srt/constrained/utils.py +sglang/srt/constrained/xgrammar_backend.py +sglang/srt/constrained/triton_ops/bitmask_ops.py +sglang/srt/debug_utils/__init__.py +sglang/srt/debug_utils/cuda_coredump.py +sglang/srt/debug_utils/dump_comparator.py +sglang/srt/debug_utils/dump_loader.py +sglang/srt/debug_utils/dumper.py +sglang/srt/debug_utils/log_parser.py +sglang/srt/debug_utils/model_truncator.py +sglang/srt/debug_utils/tensor_dump_forward_hook.py +sglang/srt/debug_utils/text_comparator.py +sglang/srt/debug_utils/comparator/__init__.py 
+sglang/srt/debug_utils/comparator/__main__.py +sglang/srt/debug_utils/comparator/bundle_comparator.py +sglang/srt/debug_utils/comparator/bundle_matcher.py +sglang/srt/debug_utils/comparator/display.py +sglang/srt/debug_utils/comparator/dp_utils.py +sglang/srt/debug_utils/comparator/entrypoint.py +sglang/srt/debug_utils/comparator/log_sink.py +sglang/srt/debug_utils/comparator/meta_overrider.py +sglang/srt/debug_utils/comparator/output_formatter.py +sglang/srt/debug_utils/comparator/output_types.py +sglang/srt/debug_utils/comparator/per_token_visualizer.py +sglang/srt/debug_utils/comparator/preset.py +sglang/srt/debug_utils/comparator/report_sink.py +sglang/srt/debug_utils/comparator/utils.py +sglang/srt/debug_utils/comparator/aligner/__init__.py +sglang/srt/debug_utils/comparator/aligner/axis_aligner.py +sglang/srt/debug_utils/comparator/aligner/entrypoint/__init__.py +sglang/srt/debug_utils/comparator/aligner/entrypoint/executor.py +sglang/srt/debug_utils/comparator/aligner/entrypoint/planner.py +sglang/srt/debug_utils/comparator/aligner/entrypoint/traced_types.py +sglang/srt/debug_utils/comparator/aligner/entrypoint/types.py +sglang/srt/debug_utils/comparator/aligner/reorderer/__init__.py +sglang/srt/debug_utils/comparator/aligner/reorderer/executor.py +sglang/srt/debug_utils/comparator/aligner/reorderer/planner.py +sglang/srt/debug_utils/comparator/aligner/reorderer/types.py +sglang/srt/debug_utils/comparator/aligner/token_aligner/__init__.py +sglang/srt/debug_utils/comparator/aligner/token_aligner/entrypoint.py +sglang/srt/debug_utils/comparator/aligner/token_aligner/concat_steps/__init__.py +sglang/srt/debug_utils/comparator/aligner/token_aligner/concat_steps/executor.py +sglang/srt/debug_utils/comparator/aligner/token_aligner/concat_steps/thd_seq_lens_loader.py +sglang/srt/debug_utils/comparator/aligner/token_aligner/smart/__init__.py +sglang/srt/debug_utils/comparator/aligner/token_aligner/smart/aux_loader.py 
+sglang/srt/debug_utils/comparator/aligner/token_aligner/smart/aux_plugins.py +sglang/srt/debug_utils/comparator/aligner/token_aligner/smart/executor.py +sglang/srt/debug_utils/comparator/aligner/token_aligner/smart/planner.py +sglang/srt/debug_utils/comparator/aligner/token_aligner/smart/seq_info_builder.py +sglang/srt/debug_utils/comparator/aligner/token_aligner/smart/types.py +sglang/srt/debug_utils/comparator/aligner/unsharder/__init__.py +sglang/srt/debug_utils/comparator/aligner/unsharder/executor.py +sglang/srt/debug_utils/comparator/aligner/unsharder/parallel_info.py +sglang/srt/debug_utils/comparator/aligner/unsharder/planner.py +sglang/srt/debug_utils/comparator/aligner/unsharder/types.py +sglang/srt/debug_utils/comparator/dims_spec/__init__.py +sglang/srt/debug_utils/comparator/dims_spec/comment_parser.py +sglang/srt/debug_utils/comparator/dims_spec/dim_parser.py +sglang/srt/debug_utils/comparator/dims_spec/dims_parser.py +sglang/srt/debug_utils/comparator/dims_spec/modifier_parser.py +sglang/srt/debug_utils/comparator/dims_spec/tensor_naming.py +sglang/srt/debug_utils/comparator/dims_spec/types.py +sglang/srt/debug_utils/comparator/tensor_comparator/__init__.py +sglang/srt/debug_utils/comparator/tensor_comparator/comparator.py +sglang/srt/debug_utils/comparator/tensor_comparator/formatter.py +sglang/srt/debug_utils/comparator/tensor_comparator/types.py +sglang/srt/debug_utils/comparator/visualizer/__init__.py +sglang/srt/debug_utils/comparator/visualizer/figure.py +sglang/srt/debug_utils/comparator/visualizer/panels.py +sglang/srt/debug_utils/comparator/visualizer/preprocessing.py +sglang/srt/debug_utils/schedule_simulator/__init__.py +sglang/srt/debug_utils/schedule_simulator/__main__.py +sglang/srt/debug_utils/schedule_simulator/entrypoint.py +sglang/srt/debug_utils/schedule_simulator/gpu_state.py +sglang/srt/debug_utils/schedule_simulator/metrics.py +sglang/srt/debug_utils/schedule_simulator/request.py 
+sglang/srt/debug_utils/schedule_simulator/simulator.py +sglang/srt/debug_utils/schedule_simulator/data_source/__init__.py +sglang/srt/debug_utils/schedule_simulator/data_source/data_loader.py +sglang/srt/debug_utils/schedule_simulator/data_source/data_synthesis.py +sglang/srt/debug_utils/schedule_simulator/routers/__init__.py +sglang/srt/debug_utils/schedule_simulator/routers/base.py +sglang/srt/debug_utils/schedule_simulator/routers/random_router.py +sglang/srt/debug_utils/schedule_simulator/routers/round_robin_router.py +sglang/srt/debug_utils/schedule_simulator/routers/sticky_router.py +sglang/srt/debug_utils/schedule_simulator/schedulers/__init__.py +sglang/srt/debug_utils/schedule_simulator/schedulers/base.py +sglang/srt/debug_utils/schedule_simulator/schedulers/fifo_scheduler.py +sglang/srt/debug_utils/source_patcher/__init__.py +sglang/srt/debug_utils/source_patcher/code_patcher.py +sglang/srt/debug_utils/source_patcher/source_editor.py +sglang/srt/debug_utils/source_patcher/types.py +sglang/srt/disaggregation/decode.py +sglang/srt/disaggregation/decode_kvcache_offload_manager.py +sglang/srt/disaggregation/decode_schedule_batch_mixin.py +sglang/srt/disaggregation/encode_grpc_server.py +sglang/srt/disaggregation/encode_receiver.py +sglang/srt/disaggregation/encode_server.py +sglang/srt/disaggregation/kv_events.py +sglang/srt/disaggregation/prefill.py +sglang/srt/disaggregation/utils.py +sglang/srt/disaggregation/ascend/__init__.py +sglang/srt/disaggregation/ascend/conn.py +sglang/srt/disaggregation/ascend/transfer_engine.py +sglang/srt/disaggregation/base/__init__.py +sglang/srt/disaggregation/base/conn.py +sglang/srt/disaggregation/common/__init__.py +sglang/srt/disaggregation/common/conn.py +sglang/srt/disaggregation/common/utils.py +sglang/srt/disaggregation/fake/__init__.py +sglang/srt/disaggregation/fake/conn.py +sglang/srt/disaggregation/mooncake/__init__.py +sglang/srt/disaggregation/mooncake/conn.py +sglang/srt/disaggregation/mooncake/utils.py 
+sglang/srt/disaggregation/mori/__init__.py +sglang/srt/disaggregation/mori/conn.py +sglang/srt/disaggregation/nixl/__init__.py +sglang/srt/disaggregation/nixl/conn.py +sglang/srt/distributed/__init__.py +sglang/srt/distributed/communication_op.py +sglang/srt/distributed/naive_distributed.py +sglang/srt/distributed/parallel_state.py +sglang/srt/distributed/utils.py +sglang/srt/distributed/device_communicators/all_reduce_utils.py +sglang/srt/distributed/device_communicators/cuda_wrapper.py +sglang/srt/distributed/device_communicators/custom_all_reduce.py +sglang/srt/distributed/device_communicators/custom_all_reduce_ops.py +sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +sglang/srt/distributed/device_communicators/hpu_communicator.py +sglang/srt/distributed/device_communicators/mooncake_transfer_engine.py +sglang/srt/distributed/device_communicators/npu_communicator.py +sglang/srt/distributed/device_communicators/pymscclpp.py +sglang/srt/distributed/device_communicators/pynccl.py +sglang/srt/distributed/device_communicators/pynccl_allocator.py +sglang/srt/distributed/device_communicators/pynccl_wrapper.py +sglang/srt/distributed/device_communicators/quick_all_reduce.py +sglang/srt/distributed/device_communicators/shm_broadcast.py +sglang/srt/distributed/device_communicators/torch_symm_mem.py +sglang/srt/distributed/device_communicators/xpu_communicator.py +sglang/srt/dllm/config.py +sglang/srt/dllm/algorithm/__init__.py +sglang/srt/dllm/algorithm/base.py +sglang/srt/dllm/algorithm/joint_threshold.py +sglang/srt/dllm/algorithm/low_confidence.py +sglang/srt/dllm/mixin/req.py +sglang/srt/dllm/mixin/scheduler.py +sglang/srt/elastic_ep/elastic_ep.py +sglang/srt/elastic_ep/expert_backup_client.py +sglang/srt/elastic_ep/expert_backup_manager.py +sglang/srt/entrypoints/EngineBase.py +sglang/srt/entrypoints/context.py +sglang/srt/entrypoints/engine.py +sglang/srt/entrypoints/grpc_server.py +sglang/srt/entrypoints/harmony_utils.py 
+sglang/srt/entrypoints/http_server.py +sglang/srt/entrypoints/http_server_engine.py +sglang/srt/entrypoints/ssl_utils.py +sglang/srt/entrypoints/tool.py +sglang/srt/entrypoints/v1_loads.py +sglang/srt/entrypoints/warmup.py +sglang/srt/entrypoints/anthropic/__init__.py +sglang/srt/entrypoints/anthropic/protocol.py +sglang/srt/entrypoints/anthropic/serving.py +sglang/srt/entrypoints/ollama/README.md +sglang/srt/entrypoints/ollama/__init__.py +sglang/srt/entrypoints/ollama/protocol.py +sglang/srt/entrypoints/ollama/serving.py +sglang/srt/entrypoints/ollama/smart_router.py +sglang/srt/entrypoints/openai/__init__.py +sglang/srt/entrypoints/openai/encoding_dsv32.py +sglang/srt/entrypoints/openai/protocol.py +sglang/srt/entrypoints/openai/serving_base.py +sglang/srt/entrypoints/openai/serving_chat.py +sglang/srt/entrypoints/openai/serving_classify.py +sglang/srt/entrypoints/openai/serving_completions.py +sglang/srt/entrypoints/openai/serving_embedding.py +sglang/srt/entrypoints/openai/serving_rerank.py +sglang/srt/entrypoints/openai/serving_responses.py +sglang/srt/entrypoints/openai/serving_score.py +sglang/srt/entrypoints/openai/serving_tokenize.py +sglang/srt/entrypoints/openai/serving_transcription.py +sglang/srt/entrypoints/openai/tool_server.py +sglang/srt/entrypoints/openai/usage_processor.py +sglang/srt/entrypoints/openai/utils.py +sglang/srt/eplb/__init__.py +sglang/srt/eplb/eplb_manager.py +sglang/srt/eplb/expert_distribution.py +sglang/srt/eplb/expert_location.py +sglang/srt/eplb/expert_location_dispatch.py +sglang/srt/eplb/expert_location_updater.py +sglang/srt/eplb/eplb_algorithms/__init__.py +sglang/srt/eplb/eplb_algorithms/deepseek.py +sglang/srt/eplb/eplb_algorithms/deepseek_vec.py +sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +sglang/srt/eplb/eplb_simulator/__init__.py +sglang/srt/eplb/eplb_simulator/reader.py +sglang/srt/function_call/base_format_detector.py +sglang/srt/function_call/core_types.py +sglang/srt/function_call/deepseekv31_detector.py 
+sglang/srt/function_call/deepseekv32_detector.py +sglang/srt/function_call/deepseekv3_detector.py +sglang/srt/function_call/function_call_parser.py +sglang/srt/function_call/gigachat3_detector.py +sglang/srt/function_call/glm47_moe_detector.py +sglang/srt/function_call/glm4_moe_detector.py +sglang/srt/function_call/gpt_oss_detector.py +sglang/srt/function_call/hermes_detector.py +sglang/srt/function_call/internlm_detector.py +sglang/srt/function_call/json_array_parser.py +sglang/srt/function_call/kimik2_detector.py +sglang/srt/function_call/lfm2_detector.py +sglang/srt/function_call/llama32_detector.py +sglang/srt/function_call/mimo_detector.py +sglang/srt/function_call/minimax_m2.py +sglang/srt/function_call/mistral_detector.py +sglang/srt/function_call/pythonic_detector.py +sglang/srt/function_call/qwen25_detector.py +sglang/srt/function_call/qwen3_coder_detector.py +sglang/srt/function_call/step3_detector.py +sglang/srt/function_call/trinity_detector.py +sglang/srt/function_call/utils.py +sglang/srt/grpc/__init__.py +sglang/srt/grpc/grpc_request_manager.py +sglang/srt/grpc/health_servicer.py +sglang/srt/grpc/scheduler_launcher.py +sglang/srt/grpc/utils.py +sglang/srt/hardware_backend/npu/allocator_npu.py +sglang/srt/hardware_backend/npu/cmo.py +sglang/srt/hardware_backend/npu/memory_pool_npu.py +sglang/srt/hardware_backend/npu/utils.py +sglang/srt/hardware_backend/npu/attention/ascend_backend.py +sglang/srt/hardware_backend/npu/attention/ascend_torch_native_backend.py +sglang/srt/hardware_backend/npu/attention/mla_preprocess.py +sglang/srt/hardware_backend/npu/graph_runner/eagle_draft_extend_npu_graph_runner.py +sglang/srt/hardware_backend/npu/graph_runner/eagle_draft_npu_graph_runner.py +sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py +sglang/srt/hardware_backend/npu/modules/deepseek_v2_attention_mla_npu.py +sglang/srt/hardware_backend/npu/moe/topk.py +sglang/srt/hardware_backend/npu/quantization/fused_moe_method_npu.py 
+sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py +sglang/srt/layers/activation.py +sglang/srt/layers/amx_utils.py +sglang/srt/layers/communicator.py +sglang/srt/layers/communicator_nsa_cp.py +sglang/srt/layers/dp_attention.py +sglang/srt/layers/elementwise.py +sglang/srt/layers/flashinfer_comm_fusion.py +sglang/srt/layers/int4fp8_utils.py +sglang/srt/layers/layernorm.py +sglang/srt/layers/linear.py +sglang/srt/layers/logits_processor.py +sglang/srt/layers/model_parallel.py +sglang/srt/layers/modelopt_utils.py +sglang/srt/layers/multimodal.py +sglang/srt/layers/parameter.py +sglang/srt/layers/pooler.py +sglang/srt/layers/radix_attention.py +sglang/srt/layers/radix_linear_attention.py +sglang/srt/layers/rocm_linear_utils.py +sglang/srt/layers/sampler.py +sglang/srt/layers/sparse_pooler.py +sglang/srt/layers/torchao_utils.py +sglang/srt/layers/vocab_parallel_embedding.py +sglang/srt/layers/attention/aiter_backend.py +sglang/srt/layers/attention/attention_registry.py +sglang/srt/layers/attention/base_attn_backend.py +sglang/srt/layers/attention/cutlass_mla_backend.py +sglang/srt/layers/attention/double_sparsity_backend.py +sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +sglang/srt/layers/attention/flashattention_backend.py +sglang/srt/layers/attention/flashinfer_backend.py +sglang/srt/layers/attention/flashinfer_mla_backend.py +sglang/srt/layers/attention/flashmla_backend.py +sglang/srt/layers/attention/hybrid_attn_backend.py +sglang/srt/layers/attention/hybrid_linear_attn_backend.py +sglang/srt/layers/attention/intel_amx_backend.py +sglang/srt/layers/attention/merge_state.py +sglang/srt/layers/attention/nsa_backend.py +sglang/srt/layers/attention/tbo_backend.py +sglang/srt/layers/attention/torch_flex_backend.py +sglang/srt/layers/attention/torch_native_backend.py +sglang/srt/layers/attention/triton_backend.py +sglang/srt/layers/attention/trtllm_mha_backend.py +sglang/srt/layers/attention/trtllm_mla_backend.py 
+sglang/srt/layers/attention/utils.py +sglang/srt/layers/attention/vision.py +sglang/srt/layers/attention/vision_utils.py +sglang/srt/layers/attention/wave_backend.py +sglang/srt/layers/attention/xpu_backend.py +sglang/srt/layers/attention/fla/chunk.py +sglang/srt/layers/attention/fla/chunk_delta_h.py +sglang/srt/layers/attention/fla/chunk_o.py +sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +sglang/srt/layers/attention/fla/cumsum.py +sglang/srt/layers/attention/fla/fused_gdn_gating.py +sglang/srt/layers/attention/fla/fused_norm_gate.py +sglang/srt/layers/attention/fla/fused_recurrent.py +sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +sglang/srt/layers/attention/fla/index.py +sglang/srt/layers/attention/fla/kda.py +sglang/srt/layers/attention/fla/l2norm.py +sglang/srt/layers/attention/fla/layernorm_gated.py +sglang/srt/layers/attention/fla/op.py +sglang/srt/layers/attention/fla/solve_tril.py +sglang/srt/layers/attention/fla/utils.py +sglang/srt/layers/attention/fla/wy_fast.py +sglang/srt/layers/attention/linear/__init__.py +sglang/srt/layers/attention/linear/gdn_backend.py +sglang/srt/layers/attention/linear/kda_backend.py +sglang/srt/layers/attention/linear/lightning_attn.py +sglang/srt/layers/attention/linear/lightning_backend.py +sglang/srt/layers/attention/linear/linear_metadata.py +sglang/srt/layers/attention/linear/seg_la.py +sglang/srt/layers/attention/linear/utils.py +sglang/srt/layers/attention/linear/kernels/__init__.py +sglang/srt/layers/attention/linear/kernels/gdn_cutedsl.py +sglang/srt/layers/attention/linear/kernels/gdn_flashinfer.py +sglang/srt/layers/attention/linear/kernels/gdn_triton.py +sglang/srt/layers/attention/linear/kernels/kda_triton.py +sglang/srt/layers/attention/linear/kernels/kernel_backend.py +sglang/srt/layers/attention/mamba/causal_conv1d.py +sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +sglang/srt/layers/attention/mamba/mamba.py +sglang/srt/layers/attention/mamba/mamba2_metadata.py 
+sglang/srt/layers/attention/mamba/mamba_state_scatter_triton.py +sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +sglang/srt/layers/attention/mamba/ops/__init__.py +sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +sglang/srt/layers/attention/mamba/ops/ssd_combined.py +sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +sglang/srt/layers/attention/mamba/ops/ssu_dispatch.py +sglang/srt/layers/attention/nsa/dequant_k_cache.py +sglang/srt/layers/attention/nsa/index_buf_accessor.py +sglang/srt/layers/attention/nsa/nsa_backend_mtp_precompute.py +sglang/srt/layers/attention/nsa/nsa_indexer.py +sglang/srt/layers/attention/nsa/nsa_mtp_verification.py +sglang/srt/layers/attention/nsa/quant_k_cache.py +sglang/srt/layers/attention/nsa/tilelang_kernel.py +sglang/srt/layers/attention/nsa/transform_index.py +sglang/srt/layers/attention/nsa/triton_kernel.py +sglang/srt/layers/attention/nsa/utils.py +sglang/srt/layers/attention/triton_ops/decode_attention.py +sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +sglang/srt/layers/attention/triton_ops/extend_attention.py +sglang/srt/layers/attention/triton_ops/merge_state.py +sglang/srt/layers/attention/triton_ops/prefill_attention.py +sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +sglang/srt/layers/attention/triton_ops/trtllm_fp8_kv_kernel.py +sglang/srt/layers/attention/wave_ops/decode_attention.py +sglang/srt/layers/attention/wave_ops/extend_attention.py +sglang/srt/layers/attention/wave_ops/prefill_attention.py +sglang/srt/layers/deep_gemm_wrapper/__init__.py +sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +sglang/srt/layers/deep_gemm_wrapper/configurer.py +sglang/srt/layers/deep_gemm_wrapper/entrypoint.py +sglang/srt/layers/moe/__init__.py 
+sglang/srt/layers/moe/cutlass_moe.py +sglang/srt/layers/moe/cutlass_moe_params.py +sglang/srt/layers/moe/cutlass_w4a8_moe.py +sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +sglang/srt/layers/moe/fused_moe_native.py +sglang/srt/layers/moe/kt_ep_wrapper.py +sglang/srt/layers/moe/rocm_moe_utils.py +sglang/srt/layers/moe/routed_experts_capturer.py +sglang/srt/layers/moe/router.py +sglang/srt/layers/moe/topk.py +sglang/srt/layers/moe/utils.py +sglang/srt/layers/moe/ep_moe/__init__.py +sglang/srt/layers/moe/ep_moe/kernels.py +sglang/srt/layers/moe/ep_moe/layer.py +sglang/srt/layers/moe/fused_moe_triton/__init__.py +sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py +sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +sglang/srt/layers/moe/fused_moe_triton/layer.py +sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +sglang/srt/layers/moe/fused_moe_triton/configs/README.md +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json 
+sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=1024,device_name=NVIDIA_H200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json 
+sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json 
+sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json 
+sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=320,device_name=NVIDIA_H200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json 
+sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=640,device_name=NVIDIA_H200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json 
+sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json 
+sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=NVIDIA_L40S.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json 
+sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=192,device_name=NVIDIA_H20.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=192,device_name=NVIDIA_H200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=384,device_name=NVIDIA_H20.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=384,device_name=NVIDIA_H200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json 
+sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=768,device_name=NVIDIA_H20.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=768,device_name=NVIDIA_H200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=96,device_name=NVIDIA_H20.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json 
+sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json 
+sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json 
+sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=1856,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=1856,device_name=NVIDIA_L40S.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=928,device_name=NVIDIA_L40S.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=160,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8,per_channel_quant=True.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json 
+sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_B200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=128,device_name=,dtype=int4_w4a16.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=128,device_name=,dtype=int4_w4a16_down.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json 
+sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=1344,device_name=NVIDIA_B200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=1856,device_name=NVIDIA_B200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=1856,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=232,device_name=NVIDIA_B200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=232,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=2688,device_name=NVIDIA_B200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=464,device_name=NVIDIA_B200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=464,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=928,device_name=NVIDIA_B200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=928,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=16,N=1856,device_name=NVIDIA_B200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=16,N=1856,device_name=NVIDIA_H100_80GB_HBM3.json 
+sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=16,N=2048,device_name=NVIDIA_B200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=161,N=192,device_name=NVIDIA_B200,dtype=fp8_w8a8,per_channel_quant=True.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=161,N=192,device_name=NVIDIA_H20,dtype=fp8_w8a8,per_channel_quant=True.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=161,N=192,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,per_channel_quant=True.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=161,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8,per_channel_quant=True.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=161,N=192,device_name=NVIDIA_H200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=161,N=192,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,per_channel_quant=True.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=161,N=192,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=161,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,per_channel_quant=True.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=161,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,per_channel_quant=True.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=20,N=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,per_channel_quant=True.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=20,N=1536,device_name=NVIDIA_H200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=256,N=1344,device_name=NVIDIA_B200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=256,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=256,N=2688,device_name=NVIDIA_B200.json 
+sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=256,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=256,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=256,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=256,N=672,device_name=NVIDIA_B200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=256,N=672,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=32,N=1856,device_name=NVIDIA_B200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=32,N=1856,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=32,N=928,device_name=NVIDIA_B200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=32,N=928,device_name=NVIDIA_H100_80GB_HBM3.json 
+sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=40,N=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,per_channel_quant=True.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=512,N=128,device_name=NVIDIA_B200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=512,N=128,device_name=NVIDIA_H200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=512,N=1344,device_name=NVIDIA_B200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=512,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=512,N=256,device_name=NVIDIA_B200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=512,N=256,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=512,N=256,device_name=NVIDIA_H200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=512,N=2688,device_name=NVIDIA_B200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=512,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=512,N=336,device_name=NVIDIA_B200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=512,N=336,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=512,N=672,device_name=NVIDIA_B200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=512,N=672,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=64,N=1856,device_name=NVIDIA_B200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=64,N=1856,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=64,N=2688,device_name=NVIDIA_B200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=64,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json 
+sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=64,N=464,device_name=NVIDIA_B200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=64,N=464,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=64,N=928,device_name=NVIDIA_B200.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=64,N=928,device_name=NVIDIA_H100_80GB_HBM3.json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=80,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=80,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +sglang/srt/layers/moe/moe_runner/__init__.py +sglang/srt/layers/moe/moe_runner/base.py +sglang/srt/layers/moe/moe_runner/deep_gemm.py +sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py +sglang/srt/layers/moe/moe_runner/marlin.py +sglang/srt/layers/moe/moe_runner/runner.py +sglang/srt/layers/moe/moe_runner/triton.py +sglang/srt/layers/moe/moe_runner/triton_kernels.py +sglang/srt/layers/moe/token_dispatcher/__init__.py +sglang/srt/layers/moe/token_dispatcher/base.py +sglang/srt/layers/moe/token_dispatcher/deepep.py +sglang/srt/layers/moe/token_dispatcher/flashinfer.py +sglang/srt/layers/moe/token_dispatcher/flashinfer_utils.py +sglang/srt/layers/moe/token_dispatcher/fuseep.py +sglang/srt/layers/moe/token_dispatcher/mooncake.py +sglang/srt/layers/moe/token_dispatcher/moriep.py +sglang/srt/layers/moe/token_dispatcher/standard.py +sglang/srt/layers/quantization/__init__.py +sglang/srt/layers/quantization/auto_round.py +sglang/srt/layers/quantization/awq.py +sglang/srt/layers/quantization/awq_triton.py +sglang/srt/layers/quantization/base_config.py +sglang/srt/layers/quantization/base_scheme.py +sglang/srt/layers/quantization/bitsandbytes.py +sglang/srt/layers/quantization/blockwise_int8.py +sglang/srt/layers/quantization/fp4_utils.py +sglang/srt/layers/quantization/fp8.py 
+sglang/srt/layers/quantization/fp8_kernel.py +sglang/srt/layers/quantization/fp8_utils.py +sglang/srt/layers/quantization/fpgemm_fp8.py +sglang/srt/layers/quantization/gguf.py +sglang/srt/layers/quantization/gptq.py +sglang/srt/layers/quantization/int8_kernel.py +sglang/srt/layers/quantization/int8_utils.py +sglang/srt/layers/quantization/kv_cache.py +sglang/srt/layers/quantization/kvfp4_tensor.py +sglang/srt/layers/quantization/marlin_utils.py +sglang/srt/layers/quantization/marlin_utils_fp8.py +sglang/srt/layers/quantization/modelopt_quant.py +sglang/srt/layers/quantization/moe_wna16.py +sglang/srt/layers/quantization/mxfp4.py +sglang/srt/layers/quantization/mxfp4_tensor.py +sglang/srt/layers/quantization/petit.py +sglang/srt/layers/quantization/petit_utils.py +sglang/srt/layers/quantization/qoq.py +sglang/srt/layers/quantization/quark_int4fp8_moe.py +sglang/srt/layers/quantization/rocm_mxfp4_utils.py +sglang/srt/layers/quantization/unquant.py +sglang/srt/layers/quantization/utils.py +sglang/srt/layers/quantization/w4afp8.py +sglang/srt/layers/quantization/w8a8_fp8.py +sglang/srt/layers/quantization/w8a8_int8.py +sglang/srt/layers/quantization/compressed_tensors/README.md +sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +sglang/srt/layers/quantization/compressed_tensors/utils.py +sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_mxint4_moe.py +sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py +sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_int8_moe.py +sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py 
+sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8_moe.py +sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8_moe.py +sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py +sglang/srt/layers/quantization/configs/N=1280,K=5120,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json 
+sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 
128].json +sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json 
+sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json 
+sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json 
+sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=5120,K=1024,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=5120,K=3200,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json 
+sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=6400,K=5120,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 
128].json +sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json 
+sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json 
+sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json 
+sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +sglang/srt/layers/quantization/configs/README.md +sglang/srt/layers/quantization/modelslim/README.md +sglang/srt/layers/quantization/modelslim/modelslim.py +sglang/srt/layers/quantization/modelslim/schemes/__init__.py +sglang/srt/layers/quantization/modelslim/schemes/modelslim_scheme.py +sglang/srt/layers/quantization/modelslim/schemes/modelslim_w4a4_int4.py +sglang/srt/layers/quantization/modelslim/schemes/modelslim_w4a8_int8_moe.py +sglang/srt/layers/quantization/modelslim/schemes/modelslim_w8a8_int8.py +sglang/srt/layers/quantization/modelslim/schemes/modelslim_w8a8_int8_moe.py +sglang/srt/layers/quantization/quark/__init__.py +sglang/srt/layers/quantization/quark/quark.py +sglang/srt/layers/quantization/quark/utils.py +sglang/srt/layers/quantization/quark/schemes/__init__.py +sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4_moe.py +sglang/srt/layers/quantization/quark/schemes/quark_w8a8_fp8.py +sglang/srt/layers/quantization/quark/schemes/quark_w8a8_fp8_moe.py +sglang/srt/layers/rotary_embedding/__init__.py +sglang/srt/layers/rotary_embedding/base.py 
+sglang/srt/layers/rotary_embedding/factory.py +sglang/srt/layers/rotary_embedding/mrope.py +sglang/srt/layers/rotary_embedding/mrope_rope_index.py +sglang/srt/layers/rotary_embedding/rope_variant.py +sglang/srt/layers/rotary_embedding/triton_kernels.py +sglang/srt/layers/rotary_embedding/utils.py +sglang/srt/layers/rotary_embedding/yarn.py +sglang/srt/layers/utils/__init__.py +sglang/srt/layers/utils/common.py +sglang/srt/layers/utils/hash.py +sglang/srt/layers/utils/logprob.py +sglang/srt/layers/utils/multi_platform.py +sglang/srt/lora/eviction_policy.py +sglang/srt/lora/layers.py +sglang/srt/lora/lora.py +sglang/srt/lora/lora_config.py +sglang/srt/lora/lora_manager.py +sglang/srt/lora/lora_overlap_loader.py +sglang/srt/lora/lora_registry.py +sglang/srt/lora/mem_pool.py +sglang/srt/lora/utils.py +sglang/srt/lora/backend/ascend_backend.py +sglang/srt/lora/backend/base_backend.py +sglang/srt/lora/backend/chunked_backend.py +sglang/srt/lora/backend/lora_registry.py +sglang/srt/lora/backend/torch_backend.py +sglang/srt/lora/backend/triton_backend.py +sglang/srt/lora/torch_ops/__init__.py +sglang/srt/lora/torch_ops/lora_ops.py +sglang/srt/lora/triton_ops/__init__.py +sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +sglang/srt/lora/triton_ops/embedding_lora_a.py +sglang/srt/lora/triton_ops/gate_up_lora_b.py +sglang/srt/lora/triton_ops/qkv_lora_b.py +sglang/srt/lora/triton_ops/sgemm_lora_a.py +sglang/srt/lora/triton_ops/sgemm_lora_b.py +sglang/srt/managers/async_dynamic_batch_tokenizer.py +sglang/srt/managers/async_mm_data_processor.py +sglang/srt/managers/cache_controller.py +sglang/srt/managers/configure_logging.py +sglang/srt/managers/data_parallel_controller.py +sglang/srt/managers/detokenizer_manager.py +sglang/srt/managers/disagg_service.py +sglang/srt/managers/io_struct.py +sglang/srt/managers/mm_utils.py +sglang/srt/managers/multi_tokenizer_mixin.py +sglang/srt/managers/multimodal_processor.py 
+sglang/srt/managers/overlap_utils.py +sglang/srt/managers/prefill_delayer.py +sglang/srt/managers/schedule_batch.py +sglang/srt/managers/schedule_policy.py +sglang/srt/managers/scheduler.py +sglang/srt/managers/scheduler_dp_attn_mixin.py +sglang/srt/managers/scheduler_input_blocker.py +sglang/srt/managers/scheduler_output_processor_mixin.py +sglang/srt/managers/scheduler_pp_mixin.py +sglang/srt/managers/scheduler_profiler_mixin.py +sglang/srt/managers/scheduler_recv_skipper.py +sglang/srt/managers/scheduler_runtime_checker_mixin.py +sglang/srt/managers/scheduler_update_weights_mixin.py +sglang/srt/managers/session_controller.py +sglang/srt/managers/template_manager.py +sglang/srt/managers/tokenizer_communicator_mixin.py +sglang/srt/managers/tokenizer_manager.py +sglang/srt/managers/tokenizer_manager_multiitem_mixin.py +sglang/srt/managers/tp_worker.py +sglang/srt/managers/utils.py +sglang/srt/mem_cache/allocator.py +sglang/srt/mem_cache/base_prefix_cache.py +sglang/srt/mem_cache/cache_init_params.py +sglang/srt/mem_cache/chunk_cache.py +sglang/srt/mem_cache/common.py +sglang/srt/mem_cache/evict_policy.py +sglang/srt/mem_cache/flush_cache.py +sglang/srt/mem_cache/hicache_storage.py +sglang/srt/mem_cache/hiradix_cache.py +sglang/srt/mem_cache/mamba_radix_cache.py +sglang/srt/mem_cache/memory_pool.py +sglang/srt/mem_cache/memory_pool_host.py +sglang/srt/mem_cache/multimodal_cache.py +sglang/srt/mem_cache/radix_cache.py +sglang/srt/mem_cache/radix_cache_cpp.py +sglang/srt/mem_cache/session_aware_cache.py +sglang/srt/mem_cache/swa_memory_pool.py +sglang/srt/mem_cache/swa_radix_cache.py +sglang/srt/mem_cache/utils.py +sglang/srt/mem_cache/cpp_radix_tree/.clang-format +sglang/srt/mem_cache/cpp_radix_tree/common.h +sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +sglang/srt/mem_cache/cpp_radix_tree/tree_v2.cpp +sglang/srt/mem_cache/cpp_radix_tree/tree_v2.h +sglang/srt/mem_cache/cpp_radix_tree/tree_v2_binding.cpp +sglang/srt/mem_cache/cpp_radix_tree/tree_v2_debug.cpp 
+sglang/srt/mem_cache/cpp_radix_tree/tree_v2_impl.h +sglang/srt/mem_cache/cpp_radix_tree/tree_v2_node.h +sglang/srt/mem_cache/sparsity/__init__.py +sglang/srt/mem_cache/sparsity/factory.py +sglang/srt/mem_cache/sparsity/algorithms/__init__.py +sglang/srt/mem_cache/sparsity/algorithms/base_algorithm.py +sglang/srt/mem_cache/sparsity/algorithms/deepseek_nsa.py +sglang/srt/mem_cache/sparsity/algorithms/quest_algorithm.py +sglang/srt/mem_cache/sparsity/backend/__init__.py +sglang/srt/mem_cache/sparsity/backend/backend_adaptor.py +sglang/srt/mem_cache/sparsity/core/__init__.py +sglang/srt/mem_cache/sparsity/core/sparse_coordinator.py +sglang/srt/mem_cache/storage/__init__.py +sglang/srt/mem_cache/storage/backend_factory.py +sglang/srt/mem_cache/storage/aibrix_kvcache/README.md +sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +sglang/srt/mem_cache/storage/eic/README.md +sglang/srt/mem_cache/storage/eic/eic_storage.py +sglang/srt/mem_cache/storage/eic/test_unit.py +sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +sglang/srt/mem_cache/storage/hf3fs/hf3fs_usrbio_client.py +sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +sglang/srt/mem_cache/storage/hf3fs/docs/README.md +sglang/srt/mem_cache/storage/hf3fs/docs/deploy_sglang_3fs_multinode.md +sglang/srt/mem_cache/storage/hf3fs/docs/setup_usrbio_client.md +sglang/srt/mem_cache/storage/lmcache/README.md +sglang/srt/mem_cache/storage/lmcache/example_config.yaml +sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +sglang/srt/mem_cache/storage/lmcache/unit_test.py +sglang/srt/mem_cache/storage/mooncake_store/README.md +sglang/srt/mem_cache/storage/mooncake_store/embedding_cache_controller.py +sglang/srt/mem_cache/storage/mooncake_store/mooncake_embedding_store.py 
+sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +sglang/srt/mem_cache/storage/nixl/README.md +sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +sglang/srt/mem_cache/storage/nixl/nixl.config.toml.sample +sglang/srt/mem_cache/storage/nixl/nixl_utils.py +sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +sglang/srt/model_executor/cpu_graph_runner.py +sglang/srt/model_executor/cuda_graph_runner.py +sglang/srt/model_executor/forward_batch_deepseek_mha_mixin.py +sglang/srt/model_executor/forward_batch_info.py +sglang/srt/model_executor/hook_manager.py +sglang/srt/model_executor/input_buffers.py +sglang/srt/model_executor/mindspore_runner.py +sglang/srt/model_executor/model_runner.py +sglang/srt/model_executor/model_runner_kv_cache_mixin.py +sglang/srt/model_executor/piecewise_cuda_graph_runner.py +sglang/srt/model_loader/__init__.py +sglang/srt/model_loader/ci_weight_validation.py +sglang/srt/model_loader/loader.py +sglang/srt/model_loader/remote_instance_weight_loader_utils.py +sglang/srt/model_loader/utils.py +sglang/srt/model_loader/weight_utils.py +sglang/srt/models/afmoe.py +sglang/srt/models/apertus.py +sglang/srt/models/arcee.py +sglang/srt/models/baichuan.py +sglang/srt/models/bailing_moe.py +sglang/srt/models/bailing_moe_linear.py +sglang/srt/models/bailing_moe_nextn.py +sglang/srt/models/bert.py +sglang/srt/models/chatglm.py +sglang/srt/models/clip.py +sglang/srt/models/commandr.py +sglang/srt/models/dbrx.py +sglang/srt/models/deepseek.py +sglang/srt/models/deepseek_janus_pro.py +sglang/srt/models/deepseek_nextn.py +sglang/srt/models/deepseek_ocr.py +sglang/srt/models/deepseek_v2.py +sglang/srt/models/deepseek_vl2.py +sglang/srt/models/dots_ocr.py +sglang/srt/models/dots_vlm.py +sglang/srt/models/dots_vlm_vit.py +sglang/srt/models/ernie4.py +sglang/srt/models/ernie45_moe_vl.py +sglang/srt/models/ernie45_vl.py +sglang/srt/models/ernie4_eagle.py 
+sglang/srt/models/exaone.py +sglang/srt/models/exaone4.py +sglang/srt/models/exaone_moe.py +sglang/srt/models/exaone_moe_mtp.py +sglang/srt/models/falcon_h1.py +sglang/srt/models/gemma.py +sglang/srt/models/gemma2.py +sglang/srt/models/gemma2_reward.py +sglang/srt/models/gemma3_causal.py +sglang/srt/models/gemma3_mm.py +sglang/srt/models/gemma3n_audio.py +sglang/srt/models/gemma3n_causal.py +sglang/srt/models/gemma3n_mm.py +sglang/srt/models/glm4.py +sglang/srt/models/glm4_moe.py +sglang/srt/models/glm4_moe_lite.py +sglang/srt/models/glm4_moe_nextn.py +sglang/srt/models/glm4v.py +sglang/srt/models/glm4v_moe.py +sglang/srt/models/glm_ocr.py +sglang/srt/models/glm_ocr_nextn.py +sglang/srt/models/glmasr.py +sglang/srt/models/gpt2.py +sglang/srt/models/gpt_bigcode.py +sglang/srt/models/gpt_j.py +sglang/srt/models/gpt_oss.py +sglang/srt/models/granite.py +sglang/srt/models/granitemoe.py +sglang/srt/models/granitemoehybrid.py +sglang/srt/models/grok.py +sglang/srt/models/hunyuan.py +sglang/srt/models/idefics2.py +sglang/srt/models/internlm2.py +sglang/srt/models/internlm2_reward.py +sglang/srt/models/interns1.py +sglang/srt/models/interns1pro.py +sglang/srt/models/internvl.py +sglang/srt/models/iquest_loopcoder.py +sglang/srt/models/jet_nemotron.py +sglang/srt/models/jet_vlm.py +sglang/srt/models/kimi_k25.py +sglang/srt/models/kimi_linear.py +sglang/srt/models/kimi_vl.py +sglang/srt/models/kimi_vl_moonvit.py +sglang/srt/models/lfm2.py +sglang/srt/models/lfm2_moe.py +sglang/srt/models/lightonocr.py +sglang/srt/models/llada2.py +sglang/srt/models/llama.py +sglang/srt/models/llama4.py +sglang/srt/models/llama_classification.py +sglang/srt/models/llama_eagle.py +sglang/srt/models/llama_eagle3.py +sglang/srt/models/llama_embedding.py +sglang/srt/models/llama_reward.py +sglang/srt/models/llava.py +sglang/srt/models/llavavid.py +sglang/srt/models/longcat_flash.py +sglang/srt/models/longcat_flash_nextn.py +sglang/srt/models/midashenglm.py +sglang/srt/models/mimo.py 
+sglang/srt/models/mimo_mtp.py +sglang/srt/models/mimo_v2_flash.py +sglang/srt/models/mimo_v2_flash_nextn.py +sglang/srt/models/mindspore.py +sglang/srt/models/minicpm.py +sglang/srt/models/minicpm3.py +sglang/srt/models/minicpmo.py +sglang/srt/models/minicpmv.py +sglang/srt/models/minimax_m2.py +sglang/srt/models/ministral3.py +sglang/srt/models/mistral.py +sglang/srt/models/mistral_large_3.py +sglang/srt/models/mistral_large_3_eagle.py +sglang/srt/models/mixtral.py +sglang/srt/models/mixtral_quant.py +sglang/srt/models/mllama.py +sglang/srt/models/mllama4.py +sglang/srt/models/nano_nemotron_vl.py +sglang/srt/models/nemotron_h.py +sglang/srt/models/nemotron_h_mtp.py +sglang/srt/models/nemotron_nas.py +sglang/srt/models/nvila.py +sglang/srt/models/nvila_lite.py +sglang/srt/models/olmo.py +sglang/srt/models/olmo2.py +sglang/srt/models/olmoe.py +sglang/srt/models/opt.py +sglang/srt/models/orion.py +sglang/srt/models/paddleocr_vl.py +sglang/srt/models/persimmon.py +sglang/srt/models/phi.py +sglang/srt/models/phi3_small.py +sglang/srt/models/phi4mm.py +sglang/srt/models/phi4mm_audio.py +sglang/srt/models/phi4mm_utils.py +sglang/srt/models/phimoe.py +sglang/srt/models/pixtral.py +sglang/srt/models/points_v15_chat.py +sglang/srt/models/qwen.py +sglang/srt/models/qwen2.py +sglang/srt/models/qwen2_5_vl.py +sglang/srt/models/qwen2_audio.py +sglang/srt/models/qwen2_classification.py +sglang/srt/models/qwen2_eagle.py +sglang/srt/models/qwen2_moe.py +sglang/srt/models/qwen2_rm.py +sglang/srt/models/qwen2_vl.py +sglang/srt/models/qwen3.py +sglang/srt/models/qwen3_5.py +sglang/srt/models/qwen3_5_mtp.py +sglang/srt/models/qwen3_classification.py +sglang/srt/models/qwen3_moe.py +sglang/srt/models/qwen3_next.py +sglang/srt/models/qwen3_next_mtp.py +sglang/srt/models/qwen3_omni_moe.py +sglang/srt/models/qwen3_rm.py +sglang/srt/models/qwen3_vl.py +sglang/srt/models/qwen3_vl_moe.py +sglang/srt/models/radio.py +sglang/srt/models/registry.py +sglang/srt/models/roberta.py 
+sglang/srt/models/sarashina2_vision.py +sglang/srt/models/sarvam_moe.py +sglang/srt/models/sdar.py +sglang/srt/models/sdar_moe.py +sglang/srt/models/siglip.py +sglang/srt/models/solar.py +sglang/srt/models/stablelm.py +sglang/srt/models/starcoder2.py +sglang/srt/models/step3_vl.py +sglang/srt/models/step3_vl_10b.py +sglang/srt/models/step3p5.py +sglang/srt/models/step3p5_mtp.py +sglang/srt/models/teleflm.py +sglang/srt/models/torch_native_llama.py +sglang/srt/models/transformers.py +sglang/srt/models/utils.py +sglang/srt/models/whisper.py +sglang/srt/models/xverse.py +sglang/srt/models/xverse_moe.py +sglang/srt/models/yivl.py +sglang/srt/models/deepseek_common/__init__.py +sglang/srt/models/deepseek_common/attention_backend_handler.py +sglang/srt/models/deepseek_common/deepseek_weight_loader.py +sglang/srt/models/deepseek_common/utils.py +sglang/srt/models/deepseek_common/attention_forward_methods/__init__.py +sglang/srt/models/deepseek_common/attention_forward_methods/forward_methods.py +sglang/srt/models/deepseek_common/attention_forward_methods/forward_mha.py +sglang/srt/models/deepseek_common/attention_forward_methods/forward_mla.py +sglang/srt/models/deepseek_common/attention_forward_methods/forward_mla_fused_rope_cpu.py +sglang/srt/models/deepseek_common/attention_forward_methods/forward_mla_fused_rope_rocm.py +sglang/srt/multimodal/customized_mm_processor_utils.py +sglang/srt/multimodal/internvl_utils.py +sglang/srt/multimodal/internvl_vit_cuda_graph_runner.py +sglang/srt/multimodal/mm_utils.py +sglang/srt/multimodal/vit_cuda_graph_runner.py +sglang/srt/multimodal/evs/README.md +sglang/srt/multimodal/evs/__init__.py +sglang/srt/multimodal/evs/evs_core.py +sglang/srt/multimodal/evs/evs_module.py +sglang/srt/multimodal/evs/evs_processor.py +sglang/srt/multimodal/processors/base_processor.py +sglang/srt/multimodal/processors/clip.py +sglang/srt/multimodal/processors/deepseek_ocr.py +sglang/srt/multimodal/processors/deepseek_vl_v2.py 
+sglang/srt/multimodal/processors/dots_vlm.py +sglang/srt/multimodal/processors/ernie45_vl.py +sglang/srt/multimodal/processors/gemma3.py +sglang/srt/multimodal/processors/gemma3n.py +sglang/srt/multimodal/processors/glm4v.py +sglang/srt/multimodal/processors/glmasr.py +sglang/srt/multimodal/processors/interns1pro.py +sglang/srt/multimodal/processors/internvl.py +sglang/srt/multimodal/processors/janus_pro.py +sglang/srt/multimodal/processors/kimi_k25.py +sglang/srt/multimodal/processors/kimi_vl.py +sglang/srt/multimodal/processors/lightonocr.py +sglang/srt/multimodal/processors/llava.py +sglang/srt/multimodal/processors/midashenglm.py +sglang/srt/multimodal/processors/minicpm.py +sglang/srt/multimodal/processors/mlama.py +sglang/srt/multimodal/processors/mllama4.py +sglang/srt/multimodal/processors/nano_nemotron_vl.py +sglang/srt/multimodal/processors/nvila.py +sglang/srt/multimodal/processors/paddleocr_vlm.py +sglang/srt/multimodal/processors/phi4mm.py +sglang/srt/multimodal/processors/pixtral.py +sglang/srt/multimodal/processors/points_v15_chat.py +sglang/srt/multimodal/processors/qwen_audio.py +sglang/srt/multimodal/processors/qwen_vl.py +sglang/srt/multimodal/processors/sarashina2_vision.py +sglang/srt/multimodal/processors/step3_vl.py +sglang/srt/multimodal/processors/whisper.py +sglang/srt/multiplex/multiplexing_mixin.py +sglang/srt/multiplex/pdmux_context.py +sglang/srt/observability/cpu_monitor.py +sglang/srt/observability/func_timer.py +sglang/srt/observability/label_transform.py +sglang/srt/observability/metrics_collector.py +sglang/srt/observability/req_time_stats.py +sglang/srt/observability/request_metrics_exporter.py +sglang/srt/observability/scheduler_metrics_mixin.py +sglang/srt/observability/startup_func_log_and_timer.py +sglang/srt/observability/trace.py +sglang/srt/observability/utils.py +sglang/srt/parser/code_completion_parser.py +sglang/srt/parser/conversation.py +sglang/srt/parser/harmony_parser.py +sglang/srt/parser/jinja_template_utils.py 
+sglang/srt/parser/reasoning_parser.py +sglang/srt/ray/__init__.py +sglang/srt/ray/engine.py +sglang/srt/ray/http_server.py +sglang/srt/ray/scheduler_actor.py +sglang/srt/sampling/custom_logit_processor.py +sglang/srt/sampling/sampling_batch_info.py +sglang/srt/sampling/sampling_params.py +sglang/srt/sampling/penaltylib/__init__.py +sglang/srt/sampling/penaltylib/frequency_penalty.py +sglang/srt/sampling/penaltylib/min_new_tokens.py +sglang/srt/sampling/penaltylib/orchestrator.py +sglang/srt/sampling/penaltylib/presence_penalty.py +sglang/srt/speculative/base_spec_worker.py +sglang/srt/speculative/draft_utils.py +sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +sglang/srt/speculative/eagle_info.py +sglang/srt/speculative/eagle_info_v2.py +sglang/srt/speculative/eagle_utils.py +sglang/srt/speculative/eagle_worker.py +sglang/srt/speculative/eagle_worker_v2.py +sglang/srt/speculative/multi_layer_eagle_draft_extend_cuda_graph_runner.py +sglang/srt/speculative/multi_layer_eagle_utils.py +sglang/srt/speculative/multi_layer_eagle_worker.py +sglang/srt/speculative/multi_layer_eagle_worker_v2.py +sglang/srt/speculative/ngram_info.py +sglang/srt/speculative/ngram_worker.py +sglang/srt/speculative/spec_info.py +sglang/srt/speculative/spec_utils.py +sglang/srt/speculative/standalone_worker.py +sglang/srt/speculative/standalone_worker_v2.py +sglang/srt/speculative/cpp_ngram/.clang-format +sglang/srt/speculative/cpp_ngram/ngram.cpp +sglang/srt/speculative/cpp_ngram/ngram.h +sglang/srt/speculative/cpp_ngram/ngram_cache.py +sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +sglang/srt/speculative/cpp_ngram/param.h +sglang/srt/speculative/cpp_ngram/queue.h +sglang/srt/tokenizer/tiktoken_tokenizer.py +sglang/srt/utils/__init__.py +sglang/srt/utils/aio_rwlock.py +sglang/srt/utils/auth.py +sglang/srt/utils/bench_utils.py +sglang/srt/utils/common.py +sglang/srt/utils/cuda_ipc_transport_utils.py 
+sglang/srt/utils/custom_op.py +sglang/srt/utils/device_timer.py +sglang/srt/utils/gauge_histogram.py +sglang/srt/utils/hf_transformers_utils.py +sglang/srt/utils/host_shared_memory.py +sglang/srt/utils/log_utils.py +sglang/srt/utils/mistral_utils.py +sglang/srt/utils/model_file_verifier.py +sglang/srt/utils/multi_stream_utils.py +sglang/srt/utils/numa_utils.py +sglang/srt/utils/nvtx_pytorch_hooks.py +sglang/srt/utils/offloader.py +sglang/srt/utils/patch_tokenizer.py +sglang/srt/utils/patch_torch.py +sglang/srt/utils/poll_based_barrier.py +sglang/srt/utils/profile_merger.py +sglang/srt/utils/profile_utils.py +sglang/srt/utils/request_logger.py +sglang/srt/utils/rpd_utils.py +sglang/srt/utils/scheduler_status_logger.py +sglang/srt/utils/slow_rank_detector.py +sglang/srt/utils/torch_memory_saver_adapter.py +sglang/srt/utils/watchdog.py +sglang/srt/utils/weight_checker.py +sglang/srt/weight_sync/tensor_bucket.py +sglang/srt/weight_sync/utils.py +sglang/test/__init__.py +sglang/test/accuracy_test_runner.py +sglang/test/bench_one_batch_server_internal.py +sglang/test/doc_patch.py +sglang/test/few_shot_gsm8k.py +sglang/test/few_shot_gsm8k_engine.py +sglang/test/get_logits_ut.py +sglang/test/gpt_oss_common.py +sglang/test/kl_test_utils.py +sglang/test/long_prompt.txt +sglang/test/lora_utils.py +sglang/test/nightly_bench_utils.py +sglang/test/nightly_utils.py +sglang/test/performance_test_runner.py +sglang/test/run_combined_tests.py +sglang/test/run_eval.py +sglang/test/runners.py +sglang/test/send_one.py +sglang/test/simple_eval_aime25.py +sglang/test/simple_eval_common.py +sglang/test/simple_eval_gpqa.py +sglang/test/simple_eval_gsm8k.py +sglang/test/simple_eval_humaneval.py +sglang/test/simple_eval_longbench_v2.py +sglang/test/simple_eval_math.py +sglang/test/simple_eval_mgsm.py +sglang/test/simple_eval_mmlu.py +sglang/test/simple_eval_mmmu_vlm.py +sglang/test/test_activation.py +sglang/test/test_block_fp8.py +sglang/test/test_block_fp8_deep_gemm_blackwell.py 
+sglang/test/test_custom_ops.py +sglang/test/test_cutlass_moe.py +sglang/test/test_cutlass_w16a16_moe.py +sglang/test/test_cutlass_w4a8_moe.py +sglang/test/test_deepep_utils.py +sglang/test/test_deterministic.py +sglang/test/test_deterministic_utils.py +sglang/test/test_dump_metric.py +sglang/test/test_dynamic_grad_mode.py +sglang/test/test_flashinfer_dispatcher.py +sglang/test/test_http_server_auth.py +sglang/test/test_kvfp4_quant_dequant.py +sglang/test/test_layernorm.py +sglang/test/test_marlin_utils.py +sglang/test/test_programs.py +sglang/test/test_utils.py +sglang/test/tool_call_test_runner.py +sglang/test/vlm_utils.py +sglang/test/ascend/__init__.py +sglang/test/ascend/gsm8k_ascend_mixin.py +sglang/test/ascend/test_ascend_utils.py +sglang/test/ascend/vlm_utils.py +sglang/test/attention/__init__.py +sglang/test/attention/test_flashattn_backend.py +sglang/test/attention/test_flashattn_mla_backend.py +sglang/test/attention/test_prefix_chunk_info.py +sglang/test/attention/test_trtllm_mla_backend.py +sglang/test/ci/__init__.py +sglang/test/ci/ci_register.py +sglang/test/ci/ci_stress_utils.py +sglang/test/ci/ci_utils.py +sglang/test/ci/run_with_retry.py +sglang/test/external_models/custom_qwen2_vl.py +sglang/test/kits/abort_timeout_kit.py +sglang/test/kits/cache_hit_kit.py +sglang/test/kits/ebnf_constrained_kit.py +sglang/test/kits/gsm8k_accuracy_kit.py +sglang/test/kits/json_constrained_kit.py +sglang/test/kits/lm_eval_kit.py +sglang/test/kits/matched_stop_kit.py +sglang/test/kits/mmmu_vlm_kit.py +sglang/test/kits/radix_cache_server_kit.py +sglang/test/kits/regex_constrained_kit.py +sglang/test/kits/spec_decoding_kit.py +sglang/test/longbench_v2/__init__.py +sglang/test/longbench_v2/longbench_v2_evaluation.md +sglang/test/longbench_v2/test_longbench_v2_eval.py +sglang/test/longbench_v2/validate_longbench_v2.py +sglang/test/longbench_v2/validate_longbench_v2_standalone.py +sglang/test/server_fixtures/default_fixture.py 
+sglang/test/server_fixtures/disaggregation_fixture.py +sglang/test/server_fixtures/eagle_fixture.py +sglang/test/server_fixtures/mmmu_fixture.py +sglang/test/speculative/test_spec_utils.py \ No newline at end of file diff --git a/sglang/python/sglang.egg-info/dependency_links.txt b/sglang/python/sglang.egg-info/dependency_links.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/sglang/python/sglang.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/sglang/python/sglang.egg-info/entry_points.txt b/sglang/python/sglang.egg-info/entry_points.txt new file mode 100644 index 0000000000000000000000000000000000000000..8300958a6dbffae58172e6d14c8fe355e571672a --- /dev/null +++ b/sglang/python/sglang.egg-info/entry_points.txt @@ -0,0 +1,2 @@ +[console_scripts] +sglang = sglang.cli.main:main diff --git a/sglang/python/sglang.egg-info/requires.txt b/sglang/python/sglang.egg-info/requires.txt new file mode 100644 index 0000000000000000000000000000000000000000..0be8458dbbc92560154137cce559d7f86ae908d1 --- /dev/null +++ b/sglang/python/sglang.egg-info/requires.txt @@ -0,0 +1,121 @@ +IPython +aiohttp +apache-tvm-ffi<0.2,>=0.1.5 +anthropic>=0.20.0 +blobfile==3.0.0 +build +compressed-tensors +cuda-python==12.9 +decord2 +datasets +einops +fastapi +flashinfer_python==0.6.4 +flashinfer_cubin==0.6.4 +gguf +hf_transfer +huggingface_hub +interegular +llguidance<0.8.0,>=0.7.11 +modelscope +msgspec +ninja +numpy +nvidia-cutlass-dsl>=4.3.4 +nvidia-ml-py +openai-harmony==0.0.4 +openai==2.6.1 +orjson +outlines==0.1.11 +packaging +partial_json_parser +pillow +prometheus-client>=0.20.0 +psutil +py-spy +pybase64 +pydantic +python-multipart +pyzmq>=25.1.2 +quack-kernels==0.2.4 +requests +scipy +sentencepiece +setproctitle +sgl-fa4==4.0.3 +sgl-kernel==0.3.21 +soundfile==0.13.1 +tiktoken +timm==1.0.16 +torch_memory_saver==0.0.9 +torch==2.9.1 +torchao==0.9.0 +torchaudio==2.9.1 +torchvision +tqdm 
+transformers==4.57.1 +uvicorn +uvloop +watchfiles +xgrammar==0.1.27 +smg-grpc-proto>=0.4.1 +grpcio>=1.78.0 +grpcio-reflection>=1.78.0 +grpcio-health-checking>=1.78.0 + +[:sys_platform != "linux" or (sys_platform == "linux" and platform_machine != "aarch64" and platform_machine != "arm64" and platform_machine != "armv7l")] +torchcodec==0.8.0 + +[all] +sglang[diffusion] +sglang[tracing] + +[checkpoint-engine] +checkpoint-engine==0.1.2 + +[dev] +sglang[test] + +[diffusion] +PyYAML==6.0.1 +cloudpickle==3.1.2 +diffusers==0.36.0 +imageio==2.36.0 +imageio-ffmpeg==0.5.1 +moviepy>=2.0.0 +opencv-python-headless==4.10.0.84 +remote-pdb==2.1.0 +runai_model_streamer>=0.15.5 +cache-dit==1.2.3 +addict==2.4.0 +av==16.1.0 +scikit-image==0.25.2 +trimesh>=4.0.0 +xatlas + +[diffusion:platform_machine != "aarch64" and platform_machine != "arm64"] +st_attn==0.0.7 +vsa==0.0.4 + +[ray] +ray[default]>=2.54.0 + +[test] +accelerate +bitsandbytes +expecttest +jsonlines +lm-eval[api]>=0.4.9.2 +matplotlib +pandas +parameterized +peft +pytest +sentence_transformers +tabulate + +[tracing] +opentelemetry-api +opentelemetry-exporter-otlp +opentelemetry-exporter-otlp-proto-grpc +opentelemetry-sdk diff --git a/sglang/python/sglang.egg-info/top_level.txt b/sglang/python/sglang.egg-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..bf882dac88fad51de249de66a232d6e6b704d21d --- /dev/null +++ b/sglang/python/sglang.egg-info/top_level.txt @@ -0,0 +1 @@ +sglang diff --git a/sglang/python/sglang/README.md b/sglang/python/sglang/README.md new file mode 100644 index 0000000000000000000000000000000000000000..de0a7189f528e4ba387f4a4aa105a1a46ee12df0 --- /dev/null +++ b/sglang/python/sglang/README.md @@ -0,0 +1,18 @@ +# Code Structure + +- `eval`: The evaluation utilities. +- `lang`: The frontend language. +- `multimodal_gen`: Inference framework for accelerated image/video generation. +- `srt`: The backend engine for running local models. (SRT = SGLang Runtime). 
+- `test`: The test utilities. +- `api.py`: The public APIs. +- `bench_offline_throughput.py`: Benchmark the performance in the offline mode. +- `bench_one_batch.py`: Benchmark the latency of running a single static batch without a server. +- `bench_one_batch_server.py`: Benchmark the latency of running a single batch with a server. +- `bench_serving.py`: Benchmark online serving with dynamic requests. +- `check_env.py`: Check the environment variables and dependencies. +- `global_config.py`: The global configs and constants. +- `launch_server.py`: The entry point for launching a local server. +- `profiler.py`: The profiling entry point to send profile requests. +- `utils.py`: Common utilities. +- `version.py`: Version info. diff --git a/sglang/python/sglang/__init__.py b/sglang/python/sglang/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..509b145a9b306eb48cc098d905a7acdb1dfa4b04 --- /dev/null +++ b/sglang/python/sglang/__init__.py @@ -0,0 +1,83 @@ +# SGLang public APIs + +# Frontend Language APIs +from sglang.global_config import global_config +from sglang.lang.api import ( + Engine, + Runtime, + assistant, + assistant_begin, + assistant_end, + flush_cache, + function, + gen, + gen_int, + gen_string, + get_server_info, + image, + select, + separate_reasoning, + set_default_backend, + system, + system_begin, + system_end, + user, + user_begin, + user_end, + video, +) +from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint +from sglang.lang.choices import ( + greedy_token_selection, + token_length_normalized, + unconditional_likelihood_normalized, +) + +# Lazy import some libraries +from sglang.utils import LazyImport +from sglang.version import __version__ + +Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic") +LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM") +OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI") +VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI") + +# 
Runtime Engine APIs +ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs") +Engine = LazyImport("sglang.srt.entrypoints.engine", "Engine") + +__all__ = [ + "Engine", + "Runtime", + "assistant", + "assistant_begin", + "assistant_end", + "flush_cache", + "function", + "gen", + "gen_int", + "gen_string", + "get_server_info", + "image", + "select", + "separate_reasoning", + "set_default_backend", + "system", + "system_begin", + "system_end", + "user", + "user_begin", + "user_end", + "video", + "RuntimeEndpoint", + "greedy_token_selection", + "token_length_normalized", + "unconditional_likelihood_normalized", + "ServerArgs", + "Anthropic", + "LiteLLM", + "OpenAI", + "VertexAI", + "global_config", + "__version__", +] diff --git a/sglang/python/sglang/__pycache__/__init__.cpython-311.pyc b/sglang/python/sglang/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af7e70c9b85a8c0e417625776dc882f4356536b6 Binary files /dev/null and b/sglang/python/sglang/__pycache__/__init__.cpython-311.pyc differ diff --git a/sglang/python/sglang/__pycache__/_version.cpython-311.pyc b/sglang/python/sglang/__pycache__/_version.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3988f1676c85502a5487a1979e35ef9d5a4c6afe Binary files /dev/null and b/sglang/python/sglang/__pycache__/_version.cpython-311.pyc differ diff --git a/sglang/python/sglang/__pycache__/bench_serving.cpython-311.pyc b/sglang/python/sglang/__pycache__/bench_serving.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06a2e42cfdf5dd55c1ba69023cbba419c9abf8e3 Binary files /dev/null and b/sglang/python/sglang/__pycache__/bench_serving.cpython-311.pyc differ diff --git a/sglang/python/sglang/__pycache__/check_env.cpython-311.pyc b/sglang/python/sglang/__pycache__/check_env.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65c095c14f66409c52c0cf814cda391d4fe50014 
Binary files /dev/null and b/sglang/python/sglang/__pycache__/check_env.cpython-311.pyc differ diff --git a/sglang/python/sglang/__pycache__/global_config.cpython-311.pyc b/sglang/python/sglang/__pycache__/global_config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c680aae4846235517d5824766baac768c647cab7 Binary files /dev/null and b/sglang/python/sglang/__pycache__/global_config.cpython-311.pyc differ diff --git a/sglang/python/sglang/__pycache__/launch_server.cpython-311.pyc b/sglang/python/sglang/__pycache__/launch_server.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2944dcff0ad36b61488b03e7a8d7f572c13d7c15 Binary files /dev/null and b/sglang/python/sglang/__pycache__/launch_server.cpython-311.pyc differ diff --git a/sglang/python/sglang/__pycache__/utils.cpython-311.pyc b/sglang/python/sglang/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b174baf67e75f7f8baf330b4fb3bc63a9a25bbaf Binary files /dev/null and b/sglang/python/sglang/__pycache__/utils.cpython-311.pyc differ diff --git a/sglang/python/sglang/__pycache__/version.cpython-311.pyc b/sglang/python/sglang/__pycache__/version.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..860f220544e856892fb725461cca705f18b5b68e Binary files /dev/null and b/sglang/python/sglang/__pycache__/version.cpython-311.pyc differ diff --git a/sglang/python/sglang/_version.py b/sglang/python/sglang/_version.py new file mode 100644 index 0000000000000000000000000000000000000000..4aba6f9b6b2cdd8b6020136a738f1b4074e8c8be --- /dev/null +++ b/sglang/python/sglang/_version.py @@ -0,0 +1,34 @@ +# file generated by setuptools-scm +# don't change, don't track in version control + +__all__ = [ + "__version__", + "__version_tuple__", + "version", + "version_tuple", + "__commit_id__", + "commit_id", +] + +TYPE_CHECKING = False +if TYPE_CHECKING: + from typing import Tuple + from 
typing import Union + + VERSION_TUPLE = Tuple[Union[int, str], ...] + COMMIT_ID = Union[str, None] +else: + VERSION_TUPLE = object + COMMIT_ID = object + +version: str +__version__: str +__version_tuple__: VERSION_TUPLE +version_tuple: VERSION_TUPLE +commit_id: COMMIT_ID +__commit_id__: COMMIT_ID + +__version__ = version = '0.5.9' +__version_tuple__ = version_tuple = (0, 5, 9) + +__commit_id__ = commit_id = 'gbbe9c7eeb' diff --git a/sglang/python/sglang/bench_offline_throughput.py b/sglang/python/sglang/bench_offline_throughput.py new file mode 100644 index 0000000000000000000000000000000000000000..0943acd8dc54b1c2452b6006b6e23b471f822afa --- /dev/null +++ b/sglang/python/sglang/bench_offline_throughput.py @@ -0,0 +1,543 @@ +""" +Benchmark the throughput in the offline mode. +It accepts server arguments (the same as launch_server.py) and benchmark arguments (the same as bench_serving.py). + +# Usage +## Sharegpt dataset with default args +python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --num-prompts 10 + +## Random dataset with default args +python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 +""" + +import argparse +import asyncio +import dataclasses +import inspect +import json +import logging +import os +import random +import time +from typing import Dict, List, Optional + +import numpy as np + +from sglang.benchmark.datasets import DatasetRow, get_dataset +from sglang.benchmark.datasets.random import sample_random_requests +from sglang.benchmark.utils import get_tokenizer, set_ulimit +from sglang.lang.backend.runtime_endpoint import Runtime +from sglang.srt.entrypoints.engine import Engine +from sglang.srt.server_args import ServerArgs + + +@dataclasses.dataclass +class BenchArgs: + backend: str = "engine" + result_filename: str = "" + dataset_name: str = "sharegpt" + dataset_path: str = "" + num_prompts: int = 
1000 + sharegpt_output_len: Optional[int] = None + sharegpt_context_len: Optional[int] = None + random_input_len: int = 1024 + random_output_len: int = 1024 + random_range_ratio: float = 0.0 + gsp_num_groups: int = 64 + gsp_prompts_per_group: int = 16 + gsp_system_prompt_len: int = 2048 + gsp_question_len: int = 128 + gsp_output_len: int = 256 + seed: int = 1 + disable_ignore_eos: bool = False + extra_request_body: Optional[str] = None + apply_chat_template: bool = False + profile: bool = False + skip_warmup: bool = False + do_not_exit: bool = False + prompt_suffix: str = "" + return_logprob: bool = False + logprob_start_len: int = -1 + + @staticmethod + def add_cli_args(parser: argparse.ArgumentParser): + parser.add_argument("--backend", type=str, default=BenchArgs.backend) + parser.add_argument( + "--result-filename", type=str, default=BenchArgs.result_filename + ) + parser.add_argument( + "--dataset-name", + type=str, + default="sharegpt", + choices=["sharegpt", "random", "generated-shared-prefix"], + help="Name of the dataset to benchmark on.", + ) + parser.add_argument( + "--dataset-path", type=str, default="", help="Path to the dataset." + ) + parser.add_argument( + "--num-prompts", + type=int, + default=BenchArgs.num_prompts, + help="Number of prompts to process. Default is 1000.", + ) + parser.add_argument( + "--sharegpt-output-len", + type=int, + default=BenchArgs.sharegpt_output_len, + help="Output length for each request. Overrides the output length from the ShareGPT dataset.", + ) + parser.add_argument( + "--sharegpt-context-len", + type=int, + default=BenchArgs.sharegpt_context_len, + help="The context length of the model for the ShareGPT dataset. 
Requests longer than the context length will be dropped.", + ) + parser.add_argument( + "--random-input-len", + type=int, + default=BenchArgs.random_input_len, + help="Number of input tokens per request, used only for random dataset.", + ) + parser.add_argument( + "--random-output-len", + type=int, + default=BenchArgs.random_output_len, + help="Number of output tokens per request, used only for random dataset.", + ) + parser.add_argument( + "--random-range-ratio", + type=float, + default=BenchArgs.random_range_ratio, + help="Range of sampled ratio of input/output length, " + "used only for random dataset.", + ) + parser.add_argument( + "--gsp-num-groups", + type=int, + default=BenchArgs.gsp_num_groups, + help="Number of groups with shared prefix, used " + "only for generated-shared-prefix", + ) + parser.add_argument( + "--gsp-prompts-per-group", + type=int, + default=BenchArgs.gsp_prompts_per_group, + help="Number of prompts per group of shared prefix, used " + "only for generated-shared-prefix", + ) + parser.add_argument( + "--gsp-system-prompt-len", + type=int, + default=BenchArgs.gsp_system_prompt_len, + help="System prompt length, used " "only for generated-shared-prefix", + ) + parser.add_argument( + "--gsp-question-len", + type=int, + default=BenchArgs.gsp_question_len, + help="Question length, used " "only for generated-shared-prefix", + ) + parser.add_argument( + "--gsp-output-len", + type=int, + default=BenchArgs.gsp_output_len, + help="Target length in tokens for outputs in generated-shared-prefix dataset", + ) + parser.add_argument("--seed", type=int, default=1, help="The random seed.") + parser.add_argument( + "--disable-ignore-eos", + action="store_true", + help="Disable ignore EOS token", + ) + parser.add_argument( + "--extra-request-body", + metavar='{"key1": "value1", "key2": "value2"}', + type=str, + default=BenchArgs.extra_request_body, + help="Append given JSON object to the request payload.
You can use this to specify " + "additional generate params like sampling params.", + ) + parser.add_argument( + "--apply-chat-template", + action="store_true", + help="Apply chat template", + ) + parser.add_argument( + "--profile", + action="store_true", + help="Use Torch Profiler. The endpoint must be launched with " + "SGLANG_TORCH_PROFILER_DIR to enable profiler.", + ) + parser.add_argument( + "--skip-warmup", + action="store_true", + help="Skip the warmup batches.", + ) + parser.add_argument( + "--do-not-exit", + action="store_true", + help="Do not exit the program. This is useful for nsys profile with --duration and --delay.", + ) + parser.add_argument( + "--prompt-suffix", + type=str, + default="", + help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.", + ) + parser.add_argument( + "--return-logprob", + action="store_true", + help="Enable returning log probabilities.", + ) + parser.add_argument( + "--logprob-start-len", + type=int, + default=-1, + help="Start length for logprob. -1 means only return logprobs for output tokens (default).
0 means return logprobs for all tokens including input.", + ) + + @classmethod + def from_cli_args(cls, args: argparse.Namespace): + attrs = [attr.name for attr in dataclasses.fields(cls)] + return cls(**{attr: getattr(args, attr) for attr in attrs}) + + +def throughput_test_once( + backend_name: str, + backend, + reqs: List[DatasetRow], + ignore_eos: bool, + extra_request_body: Dict, + profile: bool, + return_logprob: bool = False, + logprob_start_len: int = -1, +): + measurement_results = { + "backend": backend_name, + "successful_requests": len(reqs), + "total_latency": -1, + "total_input_tokens": sum(r.prompt_len for r in reqs), + "total_output_tokens": -1, + "request_throughput": -1, + "input_throughput": -1, + "output_throughput": -1, + "total_throughput": -1, + } + + prompt = [r.prompt for r in reqs] + sampling_params = [ + { + "temperature": 0, + "max_new_tokens": r.output_len, + "ignore_eos": ignore_eos, + **extra_request_body, + } + for r in reqs + ] + + if profile: + assert ( + "SGLANG_TORCH_PROFILER_DIR" in os.environ + ), "Please set SGLANG_TORCH_PROFILER_DIR." 
+ os.makedirs(os.environ["SGLANG_TORCH_PROFILER_DIR"], exist_ok=True) + backend.start_profile() + + st = time.perf_counter() + gen_out = backend.generate( + prompt=prompt, + sampling_params=sampling_params, + return_logprob=return_logprob, + logprob_start_len=logprob_start_len, + ) + latency = time.perf_counter() - st + + if profile: + dir = os.getenv("SGLANG_TORCH_PROFILER_DIR") + known_files = set(os.listdir(dir)) + backend.stop_profile() + monitor_trace_file(known_files, dir) + + if backend_name == "runtime": + gen_out = json.loads(gen_out) + + server_info = backend.get_server_info() + + measurement_results["total_latency"] = latency + measurement_results["total_output_tokens"] = sum( + o["meta_info"]["completion_tokens"] for o in gen_out + ) + measurement_results["request_throughput"] = ( + measurement_results["successful_requests"] / latency + ) + measurement_results["input_throughput"] = ( + measurement_results["total_input_tokens"] / latency + ) + measurement_results["output_throughput"] = ( + measurement_results["total_output_tokens"] / latency + ) + measurement_results["total_throughput"] = ( + measurement_results["total_input_tokens"] + + measurement_results["total_output_tokens"] + ) / latency + + if inspect.isawaitable(server_info): + server_info = asyncio.run(server_info) + + measurement_results["last_gen_throughput"] = server_info["internal_states"][0][ + "last_gen_throughput" + ] + + return measurement_results + + +def monitor_trace_file(known_files, directory, interval=1): + print(f"Monitoring {directory} for new trace files...") + + while True: + flag = False + time.sleep(interval) + current_files = set(os.listdir(directory)) + + new_files = current_files - known_files + for new_file in new_files: + new_file_path = os.path.join(directory, new_file) + print(f"New file detected: {new_file}") + + previous_size = 0 + while True: + try: + current_size = os.path.getsize(new_file_path) + except FileNotFoundError: + print(f"File {new_file} is no longer 
accessible.") + break + + if current_size > previous_size: + previous_size = current_size + else: + flag = True + break + + time.sleep(interval) + if flag: + break + + +def _create_ray_engine_backend(server_args: ServerArgs): + """Create a RayEngine inside a Ray actor on a placement group. + + RayEngine requires a placement group, so we launch it inside a Ray actor + and return a lightweight proxy that forwards calls via ray.get(). + """ + import ray + from ray.runtime_env import RuntimeEnv + from ray.util.placement_group import placement_group + from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + + env_vars = {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"} + if os.environ.get("HF_TOKEN"): + env_vars["HF_TOKEN"] = os.environ["HF_TOKEN"] + if not ray.is_initialized(): + ray.init(runtime_env=RuntimeEnv(env_vars=env_vars)) + + total_gpus = server_args.tp_size * server_args.pp_size + pg = placement_group([{"CPU": 1, "GPU": total_gpus}], strategy="STRICT_PACK") + ray.get(pg.ready()) + + @ray.remote + class _EngineActor: + def __init__(self, **kwargs): + from sglang.srt.ray.engine import RayEngine + + self.engine = RayEngine(**kwargs) + + def call(self, method, **kwargs): + return getattr(self.engine, method)(**kwargs) + + actor = _EngineActor.options( + num_cpus=1, + num_gpus=0, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=0, + ), + ).remote(**dataclasses.asdict(server_args)) + + class _Proxy: + """Forwards method calls to the remote RayEngine actor.""" + + def generate(self, **kwargs): + return ray.get(actor.call.remote("generate", **kwargs)) + + def get_server_info(self, **kwargs): + return ray.get(actor.call.remote("get_server_info", **kwargs)) + + def start_profile(self, **kwargs): + return ray.get(actor.call.remote("start_profile", **kwargs)) + + def stop_profile(self, **kwargs): + return ray.get(actor.call.remote("stop_profile", **kwargs)) + + def shutdown(self): + try: 
+ ray.get(actor.call.remote("shutdown"), timeout=60) + except Exception: + pass + try: + ray.util.remove_placement_group(pg) + except Exception: + pass + + return _Proxy() + + +def throughput_test( + server_args: ServerArgs, + bench_args: BenchArgs, +): + if bench_args.backend == "engine": + if server_args.use_ray: + backend = _create_ray_engine_backend(server_args) + else: + backend = Engine(**dataclasses.asdict(server_args)) + if not backend: + raise ValueError("Please provide valid engine arguments") + elif bench_args.backend == "runtime": + backend = Runtime(**dataclasses.asdict(server_args)) + else: + raise ValueError('Please set backend to either "engine" or "runtime"') + + tokenizer_id = server_args.tokenizer_path or server_args.model_path + tokenizer = get_tokenizer(tokenizer_id) + + # Set global environments + set_ulimit() + random.seed(bench_args.seed) + np.random.seed(bench_args.seed) + + # Parse args (read from bench_args; the CLI namespace `args` only exists under __main__) + extra_request_body = {} + if bench_args.extra_request_body: + extra_request_body = json.loads(bench_args.extra_request_body) + + # Read dataset + input_requests = get_dataset(bench_args, tokenizer) + + warmup_requests = sample_random_requests( + input_len=256, + output_len=16, + num_prompts=min(bench_args.num_prompts, 16), + range_ratio=1.0, + tokenizer=tokenizer, + dataset_path=bench_args.dataset_path, + ) + + # Warm up + if not bench_args.skip_warmup: + logging.info("\nWarmup...") + throughput_test_once( + backend_name=bench_args.backend, + backend=backend, + reqs=warmup_requests, + ignore_eos=not bench_args.disable_ignore_eos, + extra_request_body=extra_request_body, + profile=False, + return_logprob=bench_args.return_logprob, + logprob_start_len=bench_args.logprob_start_len, + ) + time.sleep(0.5) + + logging.info("\nBenchmark...") + result = throughput_test_once( + backend_name=bench_args.backend, + backend=backend, + reqs=input_requests, + ignore_eos=not bench_args.disable_ignore_eos, + extra_request_body=extra_request_body, + profile=bench_args.profile, +
return_logprob=bench_args.return_logprob, + logprob_start_len=bench_args.logprob_start_len, + ) + backend.shutdown() + + if bench_args.result_filename: + with open(bench_args.result_filename, "a") as fout: + fout.write(json.dumps(result) + "\n") + + print( + "\n{s:{c}^{n}}".format(s=" Offline Throughput Benchmark Result ", n=50, c="=") + ) + print("{:<40} {:<10}".format("Backend:", result["backend"])) + print("{:<40} {:<10}".format("Successful requests:", result["successful_requests"])) + print("{:<40} {:<10.2f}".format("Benchmark duration (s):", result["total_latency"])) + print("{:<40} {:<10}".format("Total input tokens:", result["total_input_tokens"])) + print( + "{:<40} {:<10}".format("Total generated tokens:", result["total_output_tokens"]) + ) + print( + "{:<40} {:<10.2f}".format( + "Last generation throughput (tok/s):", result["last_gen_throughput"] + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Request throughput (req/s):", result["request_throughput"] + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Input token throughput (tok/s):", result["input_throughput"] + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Output token throughput (tok/s):", result["output_throughput"] + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Total token throughput (tok/s):", result["total_throughput"] + ) + ) + print("=" * 50) + + return result + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + ServerArgs.add_cli_args(parser) + BenchArgs.add_cli_args(parser) + args = parser.parse_args() + + # handling ModelScope model downloads + if os.getenv("SGLANG_USE_MODELSCOPE", "false").lower() in ("true", "1"): + if os.path.exists(args.model_path): + print(f"Using local model path: {args.model_path}") + else: + try: + from modelscope import snapshot_download + + print(f"Using ModelScope to download model: {args.model_path}") + + # download the model and replace args.model_path + args.model_path = snapshot_download( + args.model_path, + ) + print(f"Model downloaded to: 
{args.model_path}") + except Exception as e: + print(f"ModelScope download failed: {str(e)}") + raise e + + server_args = ServerArgs.from_cli_args(args) + bench_args = BenchArgs.from_cli_args(args) + + logging.basicConfig( + level=getattr(logging, server_args.log_level.upper()), + format="%(message)s", + ) + + throughput_test(server_args, bench_args) + + while bench_args.do_not_exit: + time.sleep(1) diff --git a/sglang/python/sglang/bench_one_batch.py b/sglang/python/sglang/bench_one_batch.py new file mode 100644 index 0000000000000000000000000000000000000000..8cf0aee1ba22a1bc2ee82c43028e7c5033073039 --- /dev/null +++ b/sglang/python/sglang/bench_one_batch.py @@ -0,0 +1,837 @@ +""" +Benchmark the latency of running a single static batch without a server. + +This script does not launch a server and uses the low-level APIs. +It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths). + +# Usage (latency test) +## with dummy weights: +python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy +## sweep through multiple data points and store (append) the results in a jsonl file: +python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --run-name test_run +## run with profiling: +python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --profile +## run with profiling to custom directory: +export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log +python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 --input-len 256 --profile +## run with CUDA profiler (nsys): +nsys profile --force-overwrite=true -o bench_one_batch python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 --input-len 256 --profile --profile-activities CUDA_PROFILER +# Usage (correctness test): +python
-m sglang.bench_one_batch --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct + +## Reference output (of the correctness test above, can be gpu dependent): +input_ids=[[1, 450, 7483, 310, 3444, 338], [1, 450, 7483, 310, 278, 3303, 13187, 290, 338], [1, 20628, 338, 263, 6575, 1460, 2462, 322, 306, 763]] + +prefill logits (first half): tensor([[-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633], + [-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633], + [ -9.1875, -10.2500, 2.7129, ..., -4.3359, -4.0664, -4.1328]], + device='cuda:0') + +prefill logits (final): tensor([[-8.3125, -7.1172, 3.3457, ..., -4.9570, -4.1328, -3.4141], + [-8.9141, -9.0156, 4.1445, ..., -4.9922, -4.4961, -4.0781], + [-9.6328, -9.0547, 4.0195, ..., -5.3047, -4.7148, -4.4570]], + device='cuda:0') + +========== Prompt 0 ========== + The capital of France is Paris. +The capital of the United States is Washington, D.C. + + +========== Prompt 1 ========== + The capital of the United Kindom is London. +The capital of the United Kingdom is London. +The capital of the + +========== Prompt 2 ========== + Today is a sunny day and I like to go for a walk in the park. 
+I'm going to the park +""" + +import argparse +import copy +import dataclasses +import itertools +import json +import logging +import multiprocessing +import os +import time +from types import SimpleNamespace +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.distributed as dist + +from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.distributed.parallel_state import destroy_distributed_environment +from sglang.srt.entrypoints.engine import _set_envs_and_config +from sglang.srt.layers.moe import initialize_moe_config +from sglang.srt.layers.quantization.fp4_utils import initialize_fp4_gemm_config +from sglang.srt.layers.quantization.fp8_utils import initialize_fp8_gemm_config +from sglang.srt.managers.schedule_batch import Req, ScheduleBatch +from sglang.srt.managers.scheduler_dp_attn_mixin import prepare_mlp_sync_batch_raw +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_executor.model_runner import ModelRunner +from sglang.srt.sampling.sampling_params import SamplingParams +from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.speculative.spec_info import SpeculativeAlgorithm +from sglang.srt.utils import ( + configure_logger, + get_bool_env_var, + kill_process_tree, + maybe_reindex_device_id, + require_mlp_sync, + require_mlp_tp_gather, + set_gpu_proc_affinity, + suppress_other_loggers, +) +from sglang.srt.utils.hf_transformers_utils import get_tokenizer + + +def start_profile(profile_activities, profile_record_shapes=False, rank_print=print): + """ + Abstracted function to start profiling based on profile_activities. + Returns profiler object (or None). 
+ """ + if "CUDA_PROFILER" in profile_activities: + try: + torch.cuda.cudart().cudaProfilerStart() + rank_print("CUDA Profiler started (nsys will begin capturing)") + except Exception as e: + rank_print(f"Failed to start CUDA profiler: {e}") + return None + else: + activities = [] + if "CPU" in profile_activities: + activities.append(torch.profiler.ProfilerActivity.CPU) + if "GPU" in profile_activities: + activities.append(torch.profiler.ProfilerActivity.CUDA) + if "XPU" in profile_activities: + activities.append(torch.profiler.ProfilerActivity.XPU) + if activities: + profiler = torch.profiler.profile( + activities=activities, + with_stack=True, + record_shapes=profile_record_shapes, + ) + profiler.start() + return profiler + return None + + +def stop_profile( + profiler, + profile_activities, + rank_print=print, + save_trace=False, + trace_filename=None, + stage=None, +): + """ + Abstracted function to stop profiling based on profile_activities. + Optionally saves trace results and prints completion messages. 
+ """ + if "CUDA_PROFILER" in profile_activities: + try: + torch.cuda.cudart().cudaProfilerStop() + rank_print("CUDA Profiler stopped (nsys should dump traces)") + except Exception as e: + rank_print(f"Failed to stop CUDA profiler: {e}") + elif profiler is not None: + profiler.stop() + + if save_trace: + if profiler is not None: + if trace_filename: + _save_profile_trace_results(profiler, trace_filename) + stage_desc = f"for {stage}" if stage else "" + rank_print( + f"torch profiler chrome trace {stage_desc} saved to {trace_filename}" + ) + if "CUDA_PROFILER" in profile_activities: + rank_print(f"CUDA profiler trace for {stage} completed") + + +@dataclasses.dataclass +class BenchArgs: + run_name: str = "default" + batch_size: Tuple[int] = (1,) + input_len: Tuple[int] = (1024,) + output_len: Tuple[int] = (16,) + prompt_filename: str = "" + result_filename: str = "result.jsonl" + correctness_test: bool = False + # This is only used for correctness test + cut_len: int = 4 + log_decode_step: int = 0 + profile: bool = False + profile_record_shapes: bool = False + profile_activities: Tuple[str] = ("CPU", "GPU") + profile_stage: str = "all" + profile_filename_prefix: str = "profile" + profile_start_step: Optional[int] = None + profile_steps: Optional[int] = None + + @staticmethod + def add_cli_args(parser: argparse.ArgumentParser): + parser.add_argument("--run-name", type=str, default=BenchArgs.run_name) + parser.add_argument( + "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size + ) + parser.add_argument( + "--input-len", type=int, nargs="+", default=BenchArgs.input_len + ) + parser.add_argument( + "--output-len", type=int, nargs="+", default=BenchArgs.output_len + ) + parser.add_argument( + "--prompt-filename", type=str, default=BenchArgs.prompt_filename + ) + parser.add_argument( + "--result-filename", type=str, default=BenchArgs.result_filename + ) + parser.add_argument("--correctness-test", action="store_true") + parser.add_argument("--cut-len", 
type=int, default=BenchArgs.cut_len) + parser.add_argument( + "--log-decode-step", + type=int, + default=BenchArgs.log_decode_step, + help="Log decode latency by step, default is set to zero to disable.", + ) + parser.add_argument("--profile", action="store_true", help="Enable profiling.") + parser.add_argument( + "--profile-record-shapes", + action="store_true", + help="Record tensor shapes in profiling results.", + ) + parser.add_argument( + "--profile-activities", + type=str, + nargs="+", + default=["CPU", "GPU"], + choices=["CPU", "GPU", "CUDA_PROFILER", "XPU"], + help="Profiler activities: CPU, GPU, XPU, CUDA_PROFILER. If CPU/GPU/XPU, use torch profiler. If CUDA_PROFILER, use CUDA profiler.", + ) + parser.add_argument( + "--profile-stage", + type=str, + default=BenchArgs.profile_stage, + choices=["all", "prefill", "decode"], + help="Which stage to profile: all, prefill, or decode only.", + ) + parser.add_argument( + "--profile-filename-prefix", + type=str, + default=BenchArgs.profile_filename_prefix, + help="Prefix of the profiling file names. The full profiling result file(s) be " + '"[profile_filename_prefix]_batch[batch_size]_input[input_len]_output[output_len].trace.json.gz"', + ) + parser.add_argument( + "--profile-start-step", + type=int, + default=None, + help="Decode step at which to start profiling (0-indexed). If not specified, defaults to output_len // 2.", + ) + parser.add_argument( + "--profile-steps", + type=int, + default=None, + help="Number of decode steps to profile starting from profile-start-step. If not specified, profiles only one step.", + ) + + @classmethod + def from_cli_args(cls, args: argparse.Namespace): + # use the default value's type to cast the args into correct types. 
+ attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)] + result = {} + for attr, attr_type in attrs: + value = getattr(args, attr) + # Handle None values - don't try to cast them + if value is None or attr_type == type(None): + result[attr] = value + else: + result[attr] = attr_type(value) + return cls(**result) + + +def load_model(server_args, port_args, gpu_id, tp_rank): + suppress_other_loggers() + rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None + moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size) + + model_config = ModelConfig.from_server_args(server_args) + model_runner = ModelRunner( + model_config=model_config, + mem_fraction_static=server_args.mem_fraction_static, + gpu_id=gpu_id, + tp_rank=tp_rank, + tp_size=server_args.tp_size, + moe_ep_rank=moe_ep_rank, + moe_ep_size=server_args.ep_size, + pp_rank=0, + pp_size=1, + nccl_port=port_args.nccl_port, + server_args=server_args, + ) + rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}") + tokenizer = get_tokenizer( + server_args.tokenizer_path, + tokenizer_mode=server_args.tokenizer_mode, + trust_remote_code=server_args.trust_remote_code, + ) + if server_args.tp_size > 1: + dist.barrier() + return model_runner, tokenizer + + +def prepare_inputs_for_correctness_test(bench_args, tokenizer, custom_prompts): + prompts = ( + custom_prompts + if custom_prompts + else [ + "The capital of France is", + "The capital of the United Kindom is", + "Today is a sunny day and I like", + ] + ) + input_ids = [tokenizer.encode(p) for p in prompts] + sampling_params = SamplingParams( + temperature=0, + max_new_tokens=BenchArgs.output_len, + ) + + reqs = [] + for i in range(len(prompts)): + assert len(input_ids[i]) > bench_args.cut_len + + tmp_input_ids = input_ids[i][: bench_args.cut_len] + req = Req( + rid=i, + origin_input_text=prompts[i], + origin_input_ids=tmp_input_ids, + sampling_params=sampling_params, + ) + req.fill_ids = 
class TreeCacheNamespace(SimpleNamespace):
    """Attribute bag used as a dummy tree cache for the benchmarks.

    Arbitrary attributes (e.g. ``page_size``, ``device``) are supplied as
    keyword arguments; the capability queries below all answer with fixed
    values.
    """

    def supports_swa(self) -> bool:
        """Always ``False``."""
        return False

    def supports_mamba(self) -> bool:
        """Always ``False``."""
        return False

    def is_chunk_cache(self) -> bool:
        """Always ``False``."""
        return False

    def is_tree_cache(self) -> bool:
        """Logical negation of :meth:`is_chunk_cache`."""
        if self.is_chunk_cache():
            return False
        return True
token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator, + tree_cache=dummy_tree_cache, + model_config=model_runner.model_config, + enable_overlap=False, + spec_algorithm=SpeculativeAlgorithm.NONE, + ) + batch.prepare_for_extend() + _maybe_prepare_mlp_sync_batch(batch, model_runner) + model_worker_batch = batch.get_model_worker_batch() + forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner) + logits_output = model_runner.forward(forward_batch).logits_output + next_token_ids = model_runner.sample(logits_output, forward_batch) + return next_token_ids, logits_output.next_token_logits, batch + + +@torch.no_grad +def decode(input_token_ids, batch, model_runner): + batch.output_ids = input_token_ids + batch.prepare_for_decode() + _maybe_prepare_mlp_sync_batch(batch, model_runner) + model_worker_batch = batch.get_model_worker_batch() + forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner) + logits_output = model_runner.forward(forward_batch).logits_output + next_token_ids = model_runner.sample(logits_output, forward_batch) + return next_token_ids, logits_output.next_token_logits + + +def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner): + if require_mlp_sync(model_runner.server_args): + prepare_mlp_sync_batch_raw( + batch, + dp_size=model_runner.server_args.dp_size, + attn_tp_size=1, + tp_group=model_runner.tp_group, + get_idle_batch=None, + disable_cuda_graph=model_runner.server_args.disable_cuda_graph, + require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args), + disable_overlap_schedule=model_runner.server_args.disable_overlap_schedule, + offload_tags=set(), + ) + + +def _read_prompts_from_file(prompt_file, rank_print): + """Read custom prompts from the file specified by `--prompt-filename`.""" + if not prompt_file: + return [] + if not os.path.exists(prompt_file): + rank_print( + f"Custom prompt file {prompt_file} not found. Using default inputs..." 
+ ) + return [] + with open(prompt_file, "r") as pf: + return pf.readlines() + + +def _get_torch_profiler_output_dir(): + return os.environ.get("SGLANG_TORCH_PROFILER_DIR", "/tmp") + + +def _create_torch_profiler_filename( + profile_filename_prefix, batch_size, input_len, output_len, stage +): + output_dir = _get_torch_profiler_output_dir() + filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_{stage}.trace.json.gz" + return os.path.join(output_dir, filename) + + +def _save_profile_trace_results(profiler, filename): + parent_dir = os.path.dirname(os.path.abspath(filename)) + os.makedirs(parent_dir, exist_ok=True) + profiler.export_chrome_trace(filename) + print( + profiler.key_averages(group_by_input_shape=True).table( + sort_by="self_cpu_time_total" + ) + ) + + +def correctness_test( + server_args, + port_args, + bench_args, + gpu_id, + tp_rank, +): + # Configure the logger + configure_logger(server_args, prefix=f" TP{tp_rank}") + rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None + + # Load the model + model_runner, tokenizer = load_model(server_args, port_args, gpu_id, tp_rank) + + # Prepare inputs + custom_prompts = _read_prompts_from_file(bench_args.prompt_filename, rank_print) + input_ids, reqs = prepare_inputs_for_correctness_test( + bench_args, tokenizer, custom_prompts + ) + rank_print(f"\n{input_ids=}\n") + + if bench_args.cut_len > 0: + # Prefill + next_token_ids, next_token_logits, batch = extend(reqs, model_runner) + rank_print(f"prefill logits (first half): {next_token_logits} \n") + + # Prepare extend inputs + reqs = prepare_extend_inputs_for_correctness_test( + bench_args, input_ids, reqs, model_runner + ) + + # Extend (prefill w/ KV cache) + next_token_ids, next_token_logits, batch = extend(reqs, model_runner) + rank_print(f"prefill logits (final): {next_token_logits} \n") + + # Decode + output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))] + for _ in 
def latency_test_run_once(
    run_name,
    model_runner,
    rank_print,
    reqs,
    batch_size,
    input_len,
    output_len,
    device,
    log_decode_step,
    profile,
    profile_record_shapes,
    profile_activities,
    profile_filename_prefix,
    profile_stage,
    tp_rank,
    profile_start_step=None,
    profile_steps=None,
):
    """Run one prefill followed by ``output_len - 1`` decode steps and time them.

    Args:
        run_name: Label stored in the result record.
        model_runner: Runner whose pools are cleared and which executes the batch.
        rank_print: Print function (a no-op on non-zero ranks).
        reqs: Requests forming the batch.
        batch_size, input_len, output_len: Shape of the run; also used for
            throughput computation and the capacity check.
        device: Device passed to ``synchronize`` around every timed region.
        log_decode_step: When > 0, log every ``log_decode_step``-th decode step
            (the first five steps are always logged).
        profile: Truthy to enable profiling.
        profile_record_shapes, profile_activities: Forwarded to ``start_profile``.
        profile_filename_prefix: Prefix for the trace file name.
        profile_stage: ``"all"``, ``"prefill"`` or ``"decode"``.
        tp_rank: Tensor-parallel rank (not used inside this function).
        profile_start_step: Decode step at which profiling starts; defaults to
            ``output_len // 2``.
        profile_steps: Number of decode steps to profile; defaults to 1.

    Returns:
        A dict of measurements, or ``None`` when the batch does not fit into
        ``model_runner.max_total_num_tokens`` and the run is skipped.
    """
    max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
    if batch_size > max_batch_size:
        rank_print(
            f"skipping ({batch_size}, {input_len}, {output_len}) due to max batch size limit"
        )
        return

    # Start from empty pools so earlier runs do not affect this one.
    model_runner.req_to_token_pool.clear()
    model_runner.token_to_kv_pool_allocator.clear()

    measurement_results = {
        "run_name": run_name,
        "batch_size": batch_size,
        "input_len": input_len,
        "output_len": output_len,
    }

    tot_latency = 0

    profiler = None
    enable_profile_prefill = profile and profile_stage in ["all", "prefill"]
    if enable_profile_prefill:
        profiler = start_profile(
            profile_activities,
            profile_record_shapes=profile_record_shapes,
            rank_print=rank_print,
        )

    # Synchronize before and after so the timer covers only the device work.
    synchronize(device)
    tic = time.perf_counter()
    next_token_ids, _, batch = extend(reqs, model_runner)
    synchronize(device)
    prefill_latency = time.perf_counter() - tic

    if enable_profile_prefill:
        trace_filename = _create_torch_profiler_filename(
            profile_filename_prefix, batch_size, input_len, output_len, "prefill"
        )
        stop_profile(
            profiler,
            profile_activities,
            rank_print=rank_print,
            save_trace=True,
            trace_filename=trace_filename,
            stage="prefill",
        )

    tot_latency += prefill_latency
    throughput = input_len * batch_size / prefill_latency
    rank_print(
        f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
    )
    measurement_results["prefill_latency"] = prefill_latency
    measurement_results["prefill_throughput"] = throughput

    decode_latencies = []
    # Determine profiling start step and end step
    profile_start = (
        profile_start_step if profile_start_step is not None else (output_len // 2)
    )
    profile_end = profile_start + (profile_steps if profile_steps is not None else 1)
    enable_profile_decode = profile and profile_stage in ["all", "decode"]
    profiler = None
    for i in range(output_len - 1):
        synchronize(device)
        # Start profiler at the specified step
        if enable_profile_decode and i == profile_start:
            profiler = start_profile(
                profile_activities,
                profile_record_shapes=profile_record_shapes,
                rank_print=rank_print,
            )

        tic = time.perf_counter()
        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
        synchronize(device)
        latency = time.perf_counter() - tic

        # Stop profiler after the specified number of steps
        if enable_profile_decode and profiler is not None and i >= profile_end - 1:
            trace_filename = _create_torch_profiler_filename(
                profile_filename_prefix, batch_size, input_len, output_len, "decode"
            )
            stop_profile(
                profiler,
                profile_activities,
                rank_print=rank_print,
                save_trace=True,
                trace_filename=trace_filename,
                stage="decode",
            )
            profiler = None

        tot_latency += latency
        throughput = batch_size / latency
        decode_latencies.append(latency)
        if i < 5 or (log_decode_step > 0 and i % log_decode_step == 0):
            rank_print(
                f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
            )

    # The loop can end before step `profile_end - 1` is reached (the requested
    # profiling window extends past the last decode step). Stop the profiler
    # here so the trace is still saved and it is not left running.
    if profiler is not None:
        trace_filename = _create_torch_profiler_filename(
            profile_filename_prefix, batch_size, input_len, output_len, "decode"
        )
        stop_profile(
            profiler,
            profile_activities,
            rank_print=rank_print,
            save_trace=True,
            trace_filename=trace_filename,
            stage="decode",
        )
        profiler = None

    # Record decode timing from 2nd output
    if output_len > 1:
        med_decode_latency = np.median(decode_latencies)
        med_decode_throughput = batch_size / med_decode_latency
        rank_print(
            f"Decode.  median latency: {med_decode_latency:6.5f} s, median throughput: {med_decode_throughput:9.2f} token/s"
        )
        measurement_results["median_decode_latency"] = med_decode_latency
        measurement_results["median_decode_throughput"] = med_decode_throughput

    throughput = (input_len + output_len) * batch_size / tot_latency
    rank_print(
        f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
    )
    measurement_results["total_latency"] = tot_latency
    measurement_results["overall_throughput"] = throughput
    return measurement_results
profile_activities=("CPU", "GPU"), + profile_filename_prefix="", + profile_stage="all", + tp_rank=tp_rank, + profile_start_step=None, + profile_steps=None, + ) + + rank_print("Benchmark ...") + + custom_inputs = _read_prompts_from_file(bench_args.prompt_filename, rank_print) + custom_inputs = [tokenizer.encode(p.strip()) for p in custom_inputs] + custom_input_len = len(custom_inputs) + + # Run the sweep + result_list = [] + for bs, il, ol in itertools.product( + bench_args.batch_size, bench_args.input_len, bench_args.output_len + ): + bs_aligned_inputs = [] + if custom_inputs: + if custom_input_len == bs: + bs_aligned_inputs = custom_inputs + elif custom_input_len > bs: + rank_print( + f"Custom input size ({custom_input_len}) is larger than batch_size ({bs}). " + f"Using the first {bs} prompts." + ) + bs_aligned_inputs = copy.deepcopy(custom_inputs[:bs]) + else: + rank_print( + f"Custom input size ({custom_input_len}) is smaller than batch_size ({bs}). " + f"Pad to the desired batch_size with the last prompt." + ) + bs_aligned_inputs = copy.deepcopy(custom_inputs) + bs_aligned_inputs.extend( + [bs_aligned_inputs[-1]] * (bs - custom_input_len) + ) + + reqs = prepare_synthetic_inputs_for_latency_test(bs, il, bs_aligned_inputs) + ret = latency_test_run_once( + bench_args.run_name, + model_runner, + rank_print, + reqs, + bs, + il, + ol, + server_args.device, + bench_args.log_decode_step, + bench_args.profile if tp_rank == 0 else None, + bench_args.profile_record_shapes if tp_rank == 0 else None, + bench_args.profile_activities, + bench_args.profile_filename_prefix, + bench_args.profile_stage, + tp_rank, + bench_args.profile_start_step, + bench_args.profile_steps, + ) + if ret is not None: + result_list.append(ret) + + # Write results in jsonlines format on rank 0. 
def main(server_args, bench_args):
    """Dispatch the correctness or latency test, one process per TP rank.

    With ``tp_size == 1`` the test runs in the current process. Otherwise one
    ``multiprocessing.Process`` is started per tensor-parallel rank and all of
    them are joined.

    Args:
        server_args: Server configuration; ``cuda_graph_max_bs`` is overwritten
            with the largest benchmarked batch size.
        bench_args: Benchmark configuration.

    Raises:
        ValueError: If ``server_args.model_path`` is not set.
    """
    server_args.cuda_graph_max_bs = max(bench_args.batch_size)

    _set_envs_and_config(server_args)

    if server_args.model_path:
        if bench_args.correctness_test:
            work_func = correctness_test
        else:
            work_func = latency_test
    else:
        raise ValueError(
            "Provide --model-path for running the tests or "
            "provide --result-filename for plotting the results"
        )

    port_args = PortArgs.init_new(server_args)

    if server_args.tp_size == 1:
        work_func(server_args, port_args, bench_args, 0, 0)
    else:
        workers = []
        try:
            for tp_rank in range(server_args.tp_size):
                with maybe_reindex_device_id(tp_rank) as gpu_id:
                    proc = multiprocessing.Process(
                        target=work_func,
                        args=(
                            server_args,
                            port_args,
                            bench_args,
                            gpu_id,
                            tp_rank,
                        ),
                    )
                    proc.start()
                    workers.append(proc)

            for proc in workers:
                proc.join()
        finally:
            # Terminate every worker that is still running (e.g. after an
            # exception or Ctrl-C), not just the last one that was started.
            for proc in workers:
                if proc.is_alive():
                    proc.terminate()
def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
    """Run the server benchmark and optionally persist the results.

    Delegates the measurement to ``run_benchmark_internal``. When
    ``bench_args.pydantic_result_filename`` is set, the results are also saved
    through ``save_results_as_pydantic_models``.

    Returns:
        The ``(results, server_info)`` pair produced by the internal runner.
    """
    results, server_info = run_benchmark_internal(server_args, bench_args)

    output_path = bench_args.pydantic_result_filename
    if not output_path:
        # Nothing to persist.
        return results, server_info

    # Save results as pydantic models in the JSON format
    save_results_as_pydantic_models(
        results,
        pydantic_result_filename=output_path,
        model_path=server_args.model_path,
        server_args=bench_args.server_args_for_metrics,
    )
    return results, server_info
0000000000000000000000000000000000000000..dec3109c22eb3350006809a1cda242a4de60683c --- /dev/null +++ b/sglang/python/sglang/bench_serving.py @@ -0,0 +1,2238 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/backend_request_func.py +# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/benchmark_serving.py + +""" +Benchmark online serving with dynamic requests. + +Usage: +python3 -m sglang.bench_serving --backend sglang --num-prompt 10 + +python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input 1024 --random-output 1024 --random-range-ratio 0.5 +""" + +import argparse +import asyncio +import copy +import importlib.util +import json +import os +import random +import shutil +import sys +import time +import traceback +import uuid +import warnings +from argparse import ArgumentParser +from copy import deepcopy +from dataclasses import dataclass, field, replace +from datetime import datetime +from pathlib import Path +from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Tuple, Union + +import aiohttp +import numpy as np +import requests +from tqdm.asyncio import tqdm +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from sglang.benchmark.datasets import DatasetRow, get_dataset +from sglang.benchmark.datasets.mooncake import get_mooncake_request_over_time +from sglang.benchmark.utils import ( + get_tokenizer, + parse_custom_headers, + remove_prefix, + set_ulimit, +) + +_ROUTING_KEY_HEADER = "X-SMG-Routing-Key" + +TERM_PLOTLIB_AVAILABLE = (importlib.util.find_spec("termplotlib") is not None) and ( + shutil.which("gnuplot") is not None +) + +global args + + +# don't want to import sglang package here +def _get_bool_env_var(name: str, default: str = "false") -> bool: + value = os.getenv(name, default) + return value.lower() in ("true", "1") + + +def 
def get_auth_headers() -> Dict[str, str]:
    """Build the ``Authorization`` header from the environment.

    ``OPENAI_API_KEY`` takes precedence and is sent as a ``Bearer`` token.
    Otherwise ``API_KEY`` is sent verbatim. With neither variable set (or both
    empty), no header is produced.

    Returns:
        A dict with at most one ``"Authorization"`` entry.
    """
    bearer_token = os.environ.get("OPENAI_API_KEY")
    if bearer_token:
        return {"Authorization": f"Bearer {bearer_token}"}

    raw_credential = os.environ.get("API_KEY")
    return {"Authorization": f"{raw_credential}"} if raw_credential else {}
# trt llm does not support ignore_eos
# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505
async def async_request_trt_llm(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to a TensorRT-LLM ``generate_stream`` endpoint.

    Args:
        request_func_input: Request description (URL, prompt, output length,
            extra body fields).
        pbar: Optional progress bar, advanced by one when the request finishes.

    Returns:
        A ``RequestFuncOutput`` holding the generated text, TTFT, inter-token
        latencies and total latency on success, or the error text on failure.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with _create_bench_client_session() as session:
        payload = {
            "accumulate_tokens": True,
            "text_input": request_func_input.prompt,
            "temperature": 0.000001,
            "top_p": 1.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
            # min_length == max_tokens together with end_id stands in for
            # ignore_eos (see the note above the function); both keys are
            # removed again below when ignore_eos is disabled.
            "min_length": request_func_input.output_len,
            "end_id": 1048576,
            **request_func_input.extra_request_body,
        }
        if args.disable_ignore_eos:
            del payload["min_length"]
            del payload["end_id"]
        output = RequestFuncOutput.init_new(request_func_input)

        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        # Skip blank lines in the stream.
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data:")

                        data = json.loads(chunk)
                        output.generated_text += data["text_output"]
                        timestamp = time.perf_counter()
                        # First token
                        if ttft == 0.0:
                            ttft = timestamp - st
                            output.ttft = ttft

                        # Decoding phase
                        else:
                            output.itl.append(timestamp - most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    # Latency is measured up to the last received chunk.
                    output.latency = most_recent_timestamp - st
                    output.success = True
                    # The requested length is reported, not a server-side count.
                    output.output_len = request_func_input.output_len

                else:
                    output.error = (
                        (response.reason or "") + ": " + (await response.text())
                    )
                    output.success = False
        except Exception:
            # Any failure is recorded on the output instead of being raised.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output
+ if request_func_input.lora_name: + payload["model"] = request_func_input.lora_name + payload["lora_path"] = request_func_input.lora_name + + if request_func_input.image_data: + payload.update({"image_data": request_func_input.image_data}) + + headers = get_request_headers() + if request_func_input.routing_key: + headers[_ROUTING_KEY_HEADER] = request_func_input.routing_key + + output = RequestFuncOutput.init_new(request_func_input) + + generated_text = "" + output_len = request_func_input.output_len + ttft = 0.0 + st = time.perf_counter() + output.start_time = st + most_recent_timestamp = st + try: + async with session.post( + url=api_url, json=payload, headers=headers + ) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ") + latency = time.perf_counter() - st + if chunk == "[DONE]": + pass + else: + data = json.loads(chunk) + + # NOTE: Some completion API might have a last + # usage summary response without a token so we + # want to check a token was generated + if data["choices"][0]["text"]: + timestamp = time.perf_counter() + # First token + if ttft == 0.0: + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + else: + output.text_chunks.append( + data["choices"][0]["text"] + ) + output.itl.append(timestamp - most_recent_timestamp) + + most_recent_timestamp = timestamp + generated_text += data["choices"][0]["text"] + output_len = (data.get("usage") or {}).get( + "completion_tokens", output_len + ) + + output.generated_text = generated_text + output.success = True + output.latency = latency + output.output_len = output_len + else: + output.error = ( + (response.reason or "") + ": " + (await response.text()) + ) + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + 
async def async_request_openai_chat_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Makes a request to the OpenAI Chat Completions API.

    Handles both streaming and non-streaming responses, including support
    for image data in messages. Calculates and returns various performance
    metrics.

    Chunks that carry no choices (for example a trailing usage-only chunk with
    ``"choices": []``) and responses whose ``usage`` or ``content`` is ``null``
    are tolerated instead of failing the whole request.

    Args:
        request_func_input: Input parameters for the request.
        pbar: Optional tqdm progress bar to update.

    Returns:
        RequestFuncOutput: Output of the request, including generated text,
            latency, TTFT, ITL, and success status.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        "chat/completions"
    ), "OpenAI Chat Completions API URL must end with 'chat/completions'."

    # TODO put it to other functions when `pbar` logic is refactored
    if getattr(args, "print_requests", False):
        rid = str(uuid.uuid4())
        input_partial = deepcopy(request_func_input)
        input_partial.prompt = "..."
        request_start_time = time.time()
        print(
            f'rid={rid} time={request_start_time} message="request start" request_func_input="{str(input_partial)}"'
        )

    if isinstance(request_func_input.prompt, list):
        # The prompt is already a list of chat messages.
        messages = request_func_input.prompt
    elif request_func_input.image_data:
        # Build multi-image content: a list of image_url entries followed by the text
        content_items = [
            {
                "type": "image_url",
                "image_url": {"url": img_url},
            }
            for img_url in request_func_input.image_data
        ]
        content_items.append({"type": "text", "text": request_func_input.prompt})
        messages = [
            {
                "role": "user",
                "content": content_items,
            },
        ]
    else:
        messages = [{"role": "user", "content": request_func_input.prompt}]

    async with _create_bench_client_session() as session:
        # Build payload with defaults that can be overridden by extra_request_body
        payload = {
            "model": request_func_input.model,
            "messages": messages,
            "max_completion_tokens": request_func_input.output_len,
            "stream": not args.disable_stream,
        }

        # Add temperature default only if not specified in extra_request_body
        if "temperature" not in request_func_input.extra_request_body:
            payload["temperature"] = 0.0

        # Add ignore_eos default only if not specified in extra_request_body.
        # EOS is ignored unless --disable-ignore-eos is given.
        if "ignore_eos" not in request_func_input.extra_request_body:
            payload["ignore_eos"] = not args.disable_ignore_eos

        # Merge in extra parameters (tools, temperature, top_p, etc.)
        # These will override defaults if present
        payload.update(request_func_input.extra_request_body)

        # hack to accommodate different LoRA conventions between SGLang and vLLM.
        if request_func_input.lora_name:
            payload["model"] = request_func_input.lora_name
            payload["lora_path"] = request_func_input.lora_name

        headers = get_request_headers()
        if request_func_input.routing_key:
            headers[_ROUTING_KEY_HEADER] = request_func_input.routing_key

        output = RequestFuncOutput.init_new(request_func_input)

        generated_text = ""
        output_len = request_func_input.output_len
        ttft = 0.0
        st = time.perf_counter()
        output.start_time = st
        most_recent_timestamp = st
        try:
            async with session.post(
                url=api_url, json=payload, headers=headers
            ) as response:
                if response.status == 200:
                    if args.disable_stream:
                        # Non-streaming response
                        response_json = await response.json()
                        # `content` may be null (e.g. a tool-call-only reply);
                        # keep generated_text a string in that case.
                        output.generated_text = (
                            response_json["choices"][0]["message"]["content"] or ""
                        )
                        output.success = True
                        output.latency = time.perf_counter() - st
                        output.ttft = (
                            output.latency
                        )  # For non-streaming, TTFT = total latency
                        # `usage` may be missing or null; fall back to the
                        # requested length, as the streaming path does.
                        output.output_len = (response_json.get("usage") or {}).get(
                            "completion_tokens", output_len
                        )
                    else:
                        # Streaming response
                        async for chunk_bytes in response.content:
                            chunk_bytes = chunk_bytes.strip()
                            if not chunk_bytes:
                                continue

                            chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
                            latency = time.perf_counter() - st
                            if chunk == "[DONE]":
                                pass
                            else:
                                data = json.loads(chunk)

                                # Check if this chunk contains content. A chunk
                                # may have an empty `choices` list (usage-only
                                # chunk) or a null delta/content; treat all of
                                # those as "no content" rather than raising.
                                choices = data.get("choices") or [{}]
                                delta = choices[0].get("delta") or {}
                                content = delta.get("content") or ""

                                if content:
                                    timestamp = time.perf_counter()
                                    # First token
                                    if ttft == 0.0:
                                        ttft = timestamp - st
                                        output.ttft = ttft

                                    # Decoding phase
                                    else:
                                        output.text_chunks.append(content)
                                        output.itl.append(
                                            timestamp - most_recent_timestamp
                                        )

                                    most_recent_timestamp = timestamp
                                    generated_text += content

                                # Check for usage info in final chunk
                                output_len = (data.get("usage") or {}).get(
                                    "completion_tokens", output_len
                                )

                        output.generated_text = generated_text
                        output.success = True
                        output.latency = latency
                        output.output_len = output_len
                else:
                    output.error = (
                        (response.reason or "") + ": " + (await response.text())
                    )
                    output.success = False
        except Exception:
            # Any failure is recorded on the output instead of being raised.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    # TODO put it to other functions when `pbar` logic is refactored
    if getattr(args, "print_requests", False):
        curr_t = time.time()
        output_partial = deepcopy(output)
        output_partial.generated_text = "..."
        print(
            f'rid={rid} time={curr_t} time_delta={curr_t - request_start_time} message="request end" output="{str(output_partial)}"'
        )

    if pbar:
        pbar.update(1)
    return output
token was generated + if data["choices"][0]["text"]: + timestamp = time.perf_counter() + # First token + if ttft == 0.0: + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - most_recent_timestamp) + + most_recent_timestamp = timestamp + generated_text += data["choices"][0]["text"] + + output.generated_text = generated_text + output.success = True + output.latency = latency + output.output_len = request_func_input.output_len + else: + output.error = ( + (response.reason or "") + ": " + (await response.text()) + ) + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +async def async_request_sglang_generate( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + prompt = request_func_input.prompt + + async with _create_bench_client_session() as session: + payload = { + ("text" if isinstance(prompt, str) else "input_ids"): prompt, + "sampling_params": { + "temperature": 0.0, + "max_new_tokens": request_func_input.output_len, + "ignore_eos": not args.disable_ignore_eos, + }, + "stream": not args.disable_stream, + "lora_path": request_func_input.lora_name, + "return_logprob": args.return_logprob, + "return_routed_experts": args.return_routed_experts, + "logprob_start_len": -1, + **request_func_input.extra_request_body, + } + + # Add image data if available (list of image urls/base64) + if request_func_input.image_data: + payload["image_data"] = request_func_input.image_data + + headers = get_request_headers() + if request_func_input.routing_key: + headers[_ROUTING_KEY_HEADER] = request_func_input.routing_key + + output = RequestFuncOutput.init_new(request_func_input) + + generated_text = "" + output_len = request_func_input.output_len + ttft = 0.0 + st = 
async def async_request_profile(api_url: str) -> RequestFuncOutput:
    """POST a profiler start/stop request and report whether it succeeded.

    For URLs ending in ``/start_profile`` the request body is built from the
    optional ``profile_*`` attributes of the global ``args`` and a fresh,
    timestamped output directory is created locally. Any other URL gets an
    empty base body.

    Args:
        api_url: Full URL of the profiling endpoint.

    Returns:
        A ``RequestFuncOutput`` whose ``success`` flag reflects an HTTP 200
        response; on failure ``error`` holds the response text or traceback.
        Exceptions are captured, not raised.
    """
    async with _create_bench_client_session() as session:
        output = RequestFuncOutput()
        try:
            if api_url.endswith("/start_profile"):
                num_steps = getattr(args, "profile_num_steps", None)
                profile_by_stage = getattr(args, "profile_by_stage", None)
                if profile_by_stage and num_steps is None:
                    # Per-stage profiling falls back to 5 steps when none given.
                    num_steps = 5

                output_dir = getattr(args, "profile_output_dir", None)
                if output_dir is None:
                    output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
                # Each run gets its own sub-directory named after the current time.
                output_dir = Path(os.path.abspath(os.path.normpath(output_dir))) / str(
                    time.time()
                )
                output_dir.mkdir(exist_ok=True, parents=True)
                output_dir = str(output_dir)

                body = {
                    "activities": getattr(args, "profile_activities", []),
                    "num_steps": num_steps,
                    "profile_by_stage": profile_by_stage,
                    "profile_stages": getattr(args, "profile_stages", None),
                    "output_dir": output_dir,
                    "profile_prefix": getattr(args, "profile_prefix", None),
                }
            else:
                # stop_profile doesn't need any parameters
                body = {}

            # Add optional profiling parameters if provided. As before, these
            # are applied whichever endpoint is targeted.
            start_step = getattr(args, "profile_start_step", None)
            if start_step is not None:
                body["start_step"] = str(start_step)
            profile_steps = getattr(args, "profile_steps", None)
            if profile_steps is not None:
                body["num_steps"] = str(profile_steps)

            # Log after the overrides so the message shows the payload that is
            # actually sent.
            print(f"async_request_profile {api_url=} {body=}")

            async with session.post(url=api_url, json=body) as response:
                if response.status == 200:
                    output.success = True
                else:
                    output.error = (
                        (response.reason or "") + ": " + (await response.text())
                    )
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    return output
async def _call_profile_pd(profile_urls: List[Tuple[str, str]], mode: str) -> None:
    """Start or stop the profiler on each PD-separated worker, one at a time.

    Args:
        profile_urls: List of (worker_type, url) tuples.
        mode: "start" selects the start endpoint; any other value is treated
            as a stop request.
    """
    if mode == "start":
        endpoint, action, action_past = "/start_profile", "Starting", "started"
    else:
        endpoint, action, action_past = "/stop_profile", "Stopping", "stopped"

    print(f"{action} profiler...")

    # Workers are contacted sequentially; a failure on one is reported and
    # does not prevent the remaining workers from being contacted.
    for worker_type, url in profile_urls:
        outcome = await async_request_profile(api_url=url + endpoint)
        if not outcome.success:
            print(
                f"Failed to {mode} profiler for {worker_type} worker at {url}: {outcome.error}"
            )
            continue
        print(f"Profiler {action_past} for {worker_type} worker at {url}")
total_throughput: float + total_throughput_retokenized: float + mean_ttft_ms: float + median_ttft_ms: float + std_ttft_ms: float + p99_ttft_ms: float + mean_tpot_ms: float + median_tpot_ms: float + std_tpot_ms: float + p99_tpot_ms: float + mean_itl_ms: float + median_itl_ms: float + std_itl_ms: float + p95_itl_ms: float + p99_itl_ms: float + max_itl_ms: float + mean_e2e_latency_ms: float + median_e2e_latency_ms: float + std_e2e_latency_ms: float + p90_e2e_latency_ms: float + p99_e2e_latency_ms: float + concurrency: float + max_output_tokens_per_s: float = 0.0 + max_concurrent_requests: int = 0 + + +async def get_request( + input_requests: List[DatasetRow], + request_rate: float, + use_trace_timestamps: bool = False, + slowdown_factor: float = 1.0, +) -> AsyncGenerator[DatasetRow, None]: + if use_trace_timestamps: + print( + f"Using trace timestamps for request generation with slowdown factor {slowdown_factor}." + ) + # Sort requests by timestamp for correct replay + input_requests.sort(key=lambda r: r.timestamp) + + start_time = time.perf_counter() + trace_start_time_ms = input_requests[0].timestamp if input_requests else 0 + + for request in input_requests: + trace_time_s = (request.timestamp - trace_start_time_ms) / 1000.0 + target_arrival_time = start_time + (trace_time_s * slowdown_factor) + + sleep_duration = target_arrival_time - time.perf_counter() + if sleep_duration > 0: + await asyncio.sleep(sleep_duration) + + yield request + else: + input_requests_iter = iter(input_requests) + for request in input_requests_iter: + yield request + + if request_rate == float("inf"): + # If the request rate is infinity, then we don't need to wait. + continue + + # Sample the request interval from the exponential distribution. + interval = np.random.exponential(1.0 / request_rate) + # The next request will be sent after the interval. 
def calculate_metrics(
    input_requests: Optional[List[DatasetRow]],
    outputs: List[RequestFuncOutput],
    dur_s: float,
    tokenizer: PreTrainedTokenizerBase,
    backend: str,
    accept_length: Optional[float] = None,
    plot_throughput: bool = False,
) -> Tuple[BenchmarkMetrics, List[int]]:
    """Aggregate per-request outputs into a ``BenchmarkMetrics`` summary.

    Args:
        input_requests: Dataset rows aligned index-by-index with ``outputs``,
            or ``None`` to skip input-token accounting.
        outputs: Per-request results; only entries with ``success`` set
            contribute to latency and throughput statistics.
        dur_s: Wall-clock benchmark duration in seconds, used as the
            denominator of every throughput figure.
        tokenizer: Used to re-tokenize generated text (and, when retokenized
            ITL is active, each streamed text chunk).
        backend: Backend name; retokenized ITL is only used for
            "sglang-oai" and "sglang-oai-chat".
        accept_length: When positive, together with a matching ``backend``,
            enables retokenized ITL.
        plot_throughput: If true, plot per-second series with termplotlib
            when it is available.

    Returns:
        ``(metrics, output_lens)`` where ``output_lens`` has one entry per
        element of ``outputs`` (0 for failed requests).
    """
    output_lens: List[int] = []
    retokenized_output_lens: List[int] = []
    total_input = 0
    total_input_text = 0
    total_input_vision = 0
    completed = 0
    itls: List[float] = []
    tpots: List[float] = []
    ttfts: List[float] = []
    e2e_latencies: List[float] = []
    retokenized_itls: List[float] = []

    use_retokenized_itl = (
        accept_length is not None
        and accept_length > 0
        and backend in ("sglang-oai", "sglang-oai-chat")
    )

    for i, out in enumerate(outputs):
        if not out.success:
            output_lens.append(0)
            retokenized_output_lens.append(0)
            continue

        output_len = out.output_len
        output_lens.append(output_len)
        retokenized_output_len = len(
            tokenizer.encode(out.generated_text, add_special_tokens=False)
        )
        retokenized_output_lens.append(retokenized_output_len)
        if input_requests is not None:
            total_input += input_requests[i].prompt_len
            total_input_text += input_requests[i].text_prompt_len
            total_input_vision += input_requests[i].vision_prompt_len
        if output_len > 1:
            # Time per output token excludes the first token.
            tpots.append((out.latency - out.ttft) / (output_len - 1))
        if use_retokenized_itl:
            for k, itl in enumerate(out.itl):
                num_tokens = len(
                    tokenizer.encode(out.text_chunks[k], add_special_tokens=False)
                )
                # A chunk that re-tokenizes to zero tokens would divide by
                # zero; count it as a single interval instead.
                num_tokens = max(num_tokens, 1)
                adjusted_itl = itl / num_tokens
                retokenized_itls.extend([adjusted_itl] * num_tokens)
        else:
            itls += out.itl
        ttfts.append(out.ttft)

        e2e_latencies.append(out.latency)

        completed += 1

    if completed == 0:
        warnings.warn(
            "All requests failed. This is likely due to a misconfiguration "
            "on the benchmark arguments.",
            stacklevel=2,
        )

    max_output_tokens_per_s = 0.0
    max_concurrent_requests = 0

    successful_outputs = [output for output in outputs if output.success]
    if successful_outputs:
        min_start_time = min(output.start_time for output in successful_outputs)
        max_end_time = max(
            output.start_time + output.latency for output in successful_outputs
        )

        # One bucket per elapsed second, plus one so the last instant fits.
        duration_seconds = int(np.ceil(max_end_time - min_start_time)) + 1
        tokens_per_second = np.zeros(duration_seconds)
        concurrent_requests_per_second = np.zeros(duration_seconds)

        for output in successful_outputs:
            # Rebuild token arrival times from TTFT plus cumulative ITLs.
            token_times = [output.start_time + output.ttft]
            current_time = token_times[0]
            for itl_value in output.itl:
                current_time += itl_value
                token_times.append(current_time)

            for token_time in token_times:
                second_bucket = int(token_time - min_start_time)
                if 0 <= second_bucket < duration_seconds:
                    tokens_per_second[second_bucket] += 1

            request_start_second = int(output.start_time - min_start_time)
            request_end_second = int(
                (output.start_time + output.latency) - min_start_time
            )

            # A request counts as concurrent in every second it overlaps.
            for second in range(
                request_start_second, min(request_end_second + 1, duration_seconds)
            ):
                concurrent_requests_per_second[second] += 1

        if len(tokens_per_second) > 0:
            max_output_tokens_per_s = float(np.max(tokens_per_second))
            max_concurrent_requests = int(np.max(concurrent_requests_per_second))

        if plot_throughput:
            if TERM_PLOTLIB_AVAILABLE:
                import termplotlib as tpl

                fig = tpl.figure()
                fig.plot(
                    np.arange(len(tokens_per_second)),
                    tokens_per_second,
                    title="Output tokens per second",
                    xlabel="Time (s)",
                )
                fig.plot(
                    np.arange(len(concurrent_requests_per_second)),
                    concurrent_requests_per_second,
                    title="Concurrent requests per second",
                    xlabel="Time (s)",
                )
                fig.show()
            else:
                print("tip: install termplotlib and gnuplot to plot the metrics")

    itls = retokenized_itls if use_retokenized_itl else itls
    total_output = sum(output_lens)
    total_output_retokenized = sum(retokenized_output_lens)
    # `xs or 0` substitutes a scalar 0 for an empty list so that the numpy
    # reductions return 0 instead of NaN or an error. The E2E statistics now
    # use the same fallback as the others, which matters when every request
    # failed.
    metrics = BenchmarkMetrics(
        completed=completed,
        total_input=total_input,
        total_input_text=total_input_text,
        total_input_vision=total_input_vision,
        total_output=total_output,
        total_output_retokenized=total_output_retokenized,
        request_throughput=completed / dur_s,
        input_throughput=total_input / dur_s,
        output_throughput=total_output / dur_s,
        output_throughput_retokenized=total_output_retokenized / dur_s,
        total_throughput=(total_input + total_output) / dur_s,
        total_throughput_retokenized=(total_input + total_output_retokenized)
        / dur_s,
        mean_ttft_ms=np.mean(ttfts or 0)
        * 1000,  # ttfts is empty if streaming is not supported by backend
        median_ttft_ms=np.median(ttfts or 0) * 1000,
        std_ttft_ms=np.std(ttfts or 0) * 1000,
        p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
        mean_tpot_ms=np.mean(tpots or 0) * 1000,
        median_tpot_ms=np.median(tpots or 0) * 1000,
        std_tpot_ms=np.std(tpots or 0) * 1000,
        p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
        mean_itl_ms=np.mean(itls or 0) * 1000,
        median_itl_ms=np.median(itls or 0) * 1000,
        std_itl_ms=np.std(itls or 0) * 1000,
        p95_itl_ms=np.percentile(itls or 0, 95) * 1000,
        p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
        max_itl_ms=np.max(itls or 0) * 1000,
        mean_e2e_latency_ms=np.mean(e2e_latencies or 0) * 1000,
        median_e2e_latency_ms=np.median(e2e_latencies or 0) * 1000,
        std_e2e_latency_ms=np.std(e2e_latencies or 0) * 1000,
        p90_e2e_latency_ms=np.percentile(e2e_latencies or 0, 90) * 1000,
        p99_e2e_latency_ms=np.percentile(e2e_latencies or 0, 99) * 1000,
        concurrency=np.sum(e2e_latencies) / dur_s,
        max_output_tokens_per_s=max_output_tokens_per_s,
        max_concurrent_requests=max_concurrent_requests,
    )

    return metrics, output_lens
async def benchmark(
    backend: str,
    api_url: str,
    base_url: str,
    model_id: str,
    tokenizer: PreTrainedTokenizerBase,
    input_requests: List[DatasetRow],
    request_rate: float,
    max_concurrency: Optional[int],
    disable_tqdm: bool,
    lora_names: List[str],
    lora_request_distribution: Optional[str],
    lora_zipf_alpha: Optional[float],
    extra_request_body: Dict[str, Any],
    profile: bool,
    pd_separated: bool = False,
    flush_cache: bool = False,
    warmup_requests: int = 1,
    use_trace_timestamps: bool = False,
    mooncake_slowdown_factor=1.0,
    mooncake_num_rounds=1,
    profile_prefill_url: Optional[List[str]] = None,
    profile_decode_url: Optional[List[str]] = None,
):
    """Run warmup and the main benchmark, print a report and persist results.

    The function warms up the server, optionally flushes its cache and starts
    the profiler, dispatches every request (optionally bounded by a
    semaphore), computes metrics, prints a summary, appends one JSON line to
    the output file and returns the result dictionary.

    Args:
        backend: Key into ``ASYNC_REQUEST_FUNCS``.
        api_url: Endpoint that receives generation requests.
        base_url: Server root used for cache flush, profiling and server info.
        model_id: Model name forwarded with each request.
        tokenizer: Tokenizer used for the mooncake warmup prompt and metrics.
        input_requests: Requests to replay.
        request_rate: Request rate passed to ``get_request``.
        max_concurrency: Maximum in-flight requests; falsy means unbounded.
        disable_tqdm: Disable the progress bar.
        lora_names: Candidate LoRA adapter names (may be empty or ``None``).
        lora_request_distribution: "uniform", "distinct" or "skewed".
        lora_zipf_alpha: Base of the weights used by the "skewed" distribution.
        extra_request_body: Extra payload keys merged into every request;
            per-request extras take precedence.
        profile: Whether to start and stop the profiler around the run.
        pd_separated: Profile prefill/decode workers through their own URLs.
        flush_cache: Flush the server cache after warmup.
        warmup_requests: Number of warmup requests to send.
        use_trace_timestamps: Within this function it only changes the
            reported request rate to "trace".
        mooncake_slowdown_factor: Forwarded to the mooncake request generator.
        mooncake_num_rounds: Rounds per mooncake session; scales the progress
            bar total.
        profile_prefill_url: Prefill worker URLs for PD-separated profiling.
        profile_decode_url: Decode worker URLs for PD-separated profiling.

    Returns:
        The summary result dictionary merged with per-request details.

    Raises:
        ValueError: If ``backend`` is unknown or every warmup request failed.
    """
    if backend in ASYNC_REQUEST_FUNCS:
        request_func = ASYNC_REQUEST_FUNCS[backend]
    else:
        raise ValueError(f"Unknown backend: {backend}")

    # Check for multi-turn: prompt is a list of strings (not OpenAI messages dicts)
    # Multi-turn format: ["turn1", "turn2", ...] - list of strings
    # OpenAI format: [{"role": "user", "content": "..."}, ...] - list of dicts
    # getattr is used because mooncake records are handled as dicts below and
    # have no `prompt` attribute; they are simply not multi-turn.
    first_prompt = getattr(input_requests[0], "prompt", None)
    is_multi_turn = (
        isinstance(first_prompt, list)
        and len(first_prompt) > 0
        and isinstance(first_prompt[0], str)
    )
    if is_multi_turn:
        request_func = wrap_multi_turn_request_func(request_func, backend=backend)

    # Limit concurrency
    # From https://github.com/vllm-project/vllm/pull/9390
    semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None

    async def limited_request_func(request_func_input, pbar):
        if semaphore is None:
            return await request_func(request_func_input=request_func_input, pbar=pbar)
        async with semaphore:
            return await request_func(request_func_input=request_func_input, pbar=pbar)

    # Warmup
    print(f"Starting warmup with {warmup_requests} sequences...")

    # Handle the data structure difference for the warmup request
    if args.dataset_name == "mooncake":
        # For mooncake, input_requests is a list of dicts.
        # We need to build a temporary DatasetRow for the warmup phase.
        warmup_record = input_requests[0]

        # Build prompt from hash_ids, just like in the async generator
        hash_ids = warmup_record.get("hash_ids", [])
        prompt_text = ""
        for hash_id in hash_ids:
            prompt_text += f"{hash_id}" + " ".join(["hi"] * 512)
        prompt_text += "Can you tell me a detailed story in 1000 words?"

        output_len = warmup_record.get("output_length", 32)
        prompt_len = len(tokenizer.encode(prompt_text))

        # Create a temporary DatasetRow object for warmup
        test_request = DatasetRow(
            prompt=prompt_text,
            prompt_len=prompt_len,
            output_len=output_len,
            image_data=None,  # Mooncake doesn't have image data
        )
    else:
        # For all other datasets, input_requests is a list of DatasetRow objects
        test_request = input_requests[0]

    if lora_names is not None and len(lora_names) != 0:
        lora_name = lora_names[0]
    else:
        lora_name = None

    # Create the test input once
    test_input = RequestFuncInput(
        model=model_id,
        prompt=test_request.prompt,
        api_url=api_url,
        prompt_len=test_request.prompt_len,
        output_len=min(test_request.output_len, 32),
        lora_name=lora_name,
        image_data=test_request.image_data,
        extra_request_body=extra_request_body,
    )

    # Run warmup requests
    warmup_tasks = []
    for _ in range(warmup_requests):
        warmup_tasks.append(
            asyncio.create_task(request_func(request_func_input=test_input))
        )

    warmup_outputs = await asyncio.gather(*warmup_tasks)
    if is_multi_turn:
        # The multi-turn wrapper returns one list of outputs per request.
        warmup_outputs = [x for output in warmup_outputs for x in output]

    # Check if at least one warmup request succeeded
    if warmup_requests > 0 and not any(output.success for output in warmup_outputs):
        raise ValueError(
            "Warmup failed - Please make sure benchmark arguments "
            f"are correctly specified. Error: {warmup_outputs[0].error}"
        )
    else:
        # Report the count that was actually used (the parameter), not the
        # global args value.
        print(
            f"Warmup completed with {warmup_requests} sequences. Starting main benchmark run..."
        )

    # Flush cache
    if ("sglang" in backend and _get_bool_env_var("SGLANG_IS_IN_CI")) or flush_cache:
        requests.post(base_url + "/flush_cache", headers=get_auth_headers())

    time.sleep(1.0)

    # Build profile URLs for PD separated mode (do this once at the beginning)
    pd_profile_urls = []
    if profile and pd_separated:
        pd_profile_urls = _build_profile_urls(profile_prefill_url, profile_decode_url)
        if not pd_profile_urls:
            print(
                "Warning: PD separated mode requires --profile-prefill-url or --profile-decode-url"
            )
            print("Skipping profiler start. Please specify worker URLs for profiling.")

    # Start profiler
    if profile:
        if pd_separated:
            if pd_profile_urls:
                await _call_profile_pd(pd_profile_urls, "start")
        else:
            print("Starting profiler...")
            profile_output = await async_request_profile(
                api_url=base_url + "/start_profile"
            )
            if profile_output.success:
                print("Profiler started")

    # Run all requests
    benchmark_start_time = time.perf_counter()
    tasks: List[asyncio.Task] = []
    pbar_total = len(input_requests)
    if (
        backend == "sglang" and args.dataset_name == "mooncake"
    ):  # Assuming mooncake is mainly for sglang or similar backends
        print("Using time-based Mooncake request scheduler, ignoring --request-rate.")
        request_generator = get_mooncake_request_over_time(
            input_requests, tokenizer, mooncake_slowdown_factor, mooncake_num_rounds
        )
        print(
            f"Starting Mooncake trace replay. Sessions: {len(input_requests)}, Rounds per session: {mooncake_num_rounds}. Slowdown factor: {mooncake_slowdown_factor}"
        )
        # Scale by the same round count that was given to the generator.
        pbar_total *= mooncake_num_rounds
    else:
        request_generator = get_request(input_requests, request_rate)

    # Prepare LoRA request distribution parameters
    if lora_request_distribution == "distinct":
        lora_idx = 0
    elif lora_request_distribution == "skewed":
        weights = np.array([lora_zipf_alpha**-i for i in range(len(lora_names))])
        lora_probs = weights / np.sum(weights)
    else:
        lora_idx = None
        lora_probs = None

    pbar = None if disable_tqdm else tqdm(total=pbar_total)
    async for request in request_generator:
        if lora_names is not None and len(lora_names) != 0:
            if lora_request_distribution == "uniform":
                lora_name = random.choice(lora_names)
            elif lora_request_distribution == "distinct":
                # Round-robin over the adapters.
                lora_name = lora_names[lora_idx]
                lora_idx = (lora_idx + 1) % len(lora_names)
            else:
                assert (
                    lora_request_distribution == "skewed"
                ), f"Unexpected lora_request_distribution: {lora_request_distribution}. Expected 'skewed'."

                lora_name = np.random.choice(lora_names, p=lora_probs)
        else:
            lora_name = None

        # Merge global extra_request_body with per-request extras
        # Per-request parameters take precedence over global ones
        merged_extra_body = {**extra_request_body, **request.extra_request_body}

        request_func_input = RequestFuncInput(
            model=model_id,
            prompt=request.prompt,
            api_url=api_url,
            prompt_len=request.prompt_len,
            output_len=request.output_len,
            lora_name=lora_name,
            image_data=request.image_data,
            extra_request_body=merged_extra_body,
            timestamp=request.timestamp,
            routing_key=request.routing_key,
        )

        tasks.append(
            asyncio.create_task(
                limited_request_func(request_func_input=request_func_input, pbar=pbar)
            )
        )
    outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
    if is_multi_turn:
        outputs = [x for output in outputs for x in output]

    # Stop profiler (only if profile_steps was not provided, as it auto-stops)
    if profile and not (
        hasattr(args, "profile_steps") and args.profile_steps is not None
    ):
        if pd_separated:
            if pd_profile_urls:
                await _call_profile_pd(pd_profile_urls, "stop")
        else:
            if getattr(args, "profile_num_steps", None) is None:
                print("Stopping profiler...")
                profile_output = await async_request_profile(
                    api_url=base_url + "/stop_profile"
                )
                if profile_output.success:
                    print("Profiler stopped")

    if pbar is not None:
        pbar.close()

    if "sglang" in backend:
        server_info = requests.get(
            base_url + "/get_server_info", headers=get_auth_headers()
        )
        if server_info.status_code == 200:
            server_info_json = server_info.json()
            if "decode" in server_info_json:
                server_info_json = server_info_json["decode"][0]
            if (
                "internal_states" in server_info_json
                and server_info_json["internal_states"]
            ):
                accept_length = server_info_json["internal_states"][0].get(
                    "avg_spec_accept_length", None
                )
            else:
                accept_length = None
        else:
            accept_length = None
    else:
        accept_length = None

    # Compute metrics and print results
    benchmark_duration = time.perf_counter() - benchmark_start_time
    metrics, output_lens = calculate_metrics(
        input_requests=None if is_multi_turn else input_requests,
        outputs=outputs,
        dur_s=benchmark_duration,
        tokenizer=tokenizer,
        backend=backend,
        accept_length=accept_length,
        plot_throughput=args.plot_throughput,
    )

    print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
    print("{:<40} {:<10}".format("Backend:", backend))
    print(
        "{:<40} {:<10}".format(
            "Traffic request rate:", "trace" if use_trace_timestamps else request_rate
        )
    )
    print(
        "{:<40} {:<10}".format(
            "Max request concurrency:",
            max_concurrency if max_concurrency else "not set",
        )
    )
    print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
    print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
    print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
    print("{:<40} {:<10}".format("Total input text tokens:", metrics.total_input_text))
    if args.dataset_name in ["image", "mmmu"]:
        print(
            "{:<40} {:<10}".format(
                "Total input vision tokens:", metrics.total_input_vision
            )
        )
    print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
    print(
        "{:<40} {:<10}".format(
            "Total generated tokens (retokenized):", metrics.total_output_retokenized
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Request throughput (req/s):", metrics.request_throughput
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Input token throughput (tok/s):", metrics.input_throughput
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Output token throughput (tok/s):", metrics.output_throughput
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Peak output token throughput (tok/s):", metrics.max_output_tokens_per_s
        )
    )
    print(
        "{:<40} {:<10}".format(
            "Peak concurrent requests:", metrics.max_concurrent_requests
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Total token throughput (tok/s):", metrics.total_throughput
        )
    )
    print("{:<40} {:<10.2f}".format("Concurrency:", metrics.concurrency))
    if accept_length:
        print("{:<40} {:<10.2f}".format("Accept length:", accept_length))
    print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
    print(
        "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Median E2E Latency (ms):", metrics.median_e2e_latency_ms
        )
    )
    print(
        "{:<40} {:<10.2f}".format("P90 E2E Latency (ms):", metrics.p90_e2e_latency_ms)
    )
    print(
        "{:<40} {:<10.2f}".format("P99 E2E Latency (ms):", metrics.p99_e2e_latency_ms)
    )
    print("{s:{c}^{n}}".format(s="Time to First Token", n=50, c="-"))
    print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
    print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
    print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
    print(
        "{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-")
    )
    print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
    print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms))
    print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
    print("{s:{c}^{n}}".format(s="Inter-Token Latency", n=50, c="-"))
    print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
    print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
    print("{:<40} {:<10.2f}".format("P95 ITL (ms):", metrics.p95_itl_ms))
    print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
    print("{:<40} {:<10.2f}".format("Max ITL (ms):", metrics.max_itl_ms))
    print("=" * 50)

    resp = requests.get(base_url + "/get_server_info", headers=get_auth_headers())
    server_info = resp.json() if resp.status_code == 200 else None

    if (
        metrics.median_ttft_ms is not None
        and metrics.mean_itl_ms is not None
        and metrics.output_throughput is not None
    ):
        result = {
            # Arguments
            "tag": getattr(args, "tag", None),
            "backend": args.backend,
            "dataset_name": args.dataset_name,
            "request_rate": "trace" if use_trace_timestamps else request_rate,
            "max_concurrency": max_concurrency,
            "sharegpt_output_len": args.sharegpt_output_len,
            "random_input_len": args.random_input_len,
            "random_output_len": args.random_output_len,
            "random_range_ratio": args.random_range_ratio,
            # Information
            "server_info": server_info,
            # Results
            "duration": benchmark_duration,
            "completed": metrics.completed,
            "total_input_tokens": metrics.total_input,
            "total_input_text_tokens": metrics.total_input_text,
            "total_input_vision_tokens": metrics.total_input_vision,
            "total_output_tokens": metrics.total_output,
            "total_output_tokens_retokenized": metrics.total_output_retokenized,
            "request_throughput": metrics.request_throughput,
            "input_throughput": metrics.input_throughput,
            "output_throughput": metrics.output_throughput,
            "total_throughput": metrics.total_throughput,
            "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
            "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
            "std_e2e_latency_ms": metrics.std_e2e_latency_ms,
            "p90_e2e_latency_ms": metrics.p90_e2e_latency_ms,
            "p99_e2e_latency_ms": metrics.p99_e2e_latency_ms,
            "mean_ttft_ms": metrics.mean_ttft_ms,
            "median_ttft_ms": metrics.median_ttft_ms,
            "std_ttft_ms": metrics.std_ttft_ms,
            "p99_ttft_ms": metrics.p99_ttft_ms,
            "mean_tpot_ms": metrics.mean_tpot_ms,
            "median_tpot_ms": metrics.median_tpot_ms,
            "std_tpot_ms": metrics.std_tpot_ms,
            "p99_tpot_ms": metrics.p99_tpot_ms,
            "mean_itl_ms": metrics.mean_itl_ms,
            "median_itl_ms": metrics.median_itl_ms,
            "std_itl_ms": metrics.std_itl_ms,
            "p95_itl_ms": metrics.p95_itl_ms,
            "p99_itl_ms": metrics.p99_itl_ms,
            "concurrency": metrics.concurrency,
            "accept_length": accept_length,
            "max_output_tokens_per_s": metrics.max_output_tokens_per_s,
            "max_concurrent_requests": metrics.max_concurrent_requests,
        }
    else:
        print(f"Error running benchmark for request rate: {request_rate}")
        print("-" * 30)
        # Keep `result` bound so the dump and return below do not raise on
        # this path.
        result = {}

    # Determine output file name
    if args.output_file:
        output_file_name = args.output_file
    else:
        now = datetime.now().strftime("%m%d")
        if args.dataset_name == "image":
            output_file_name = (
                f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_"
                f"{args.random_output_len}_{args.image_count}imgs_"
                f"{args.image_resolution}.jsonl"
            )
        elif args.dataset_name.startswith("random"):
            output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"
        else:
            output_file_name = (
                f"{args.backend}_{now}_{args.num_prompts}_{args.dataset_name}.jsonl"
            )

    result_details = {
        "input_lens": [output.prompt_len for output in outputs],
        "output_lens": output_lens,
        "ttfts": [output.ttft for output in outputs],
        "itls": [output.itl for output in outputs],
        "generated_texts": [output.generated_text for output in outputs],
        "errors": [output.error for output in outputs],
    }

    # Append results to a JSONL file
    with open(output_file_name, "a") as file:
        if args.output_details:
            result_for_dump = result | result_details
        else:
            result_for_dump = result
        file.write(json.dumps(result_for_dump) + "\n")

    return result | result_details
args.extra_request_body: + extra_request_body = json.loads(args.extra_request_body) + + if args.tokenize_prompt: + assert ( + args.backend == "sglang" + ), "`--tokenize-prompt` only compatible with `--backend sglang` currently" + + # Set url + if args.port is None: + args.port = { + "sglang": 30000, + "sglang-native": 30000, + "sglang-oai": 30000, + "lmdeploy": 23333, + "vllm": 8000, + "trt": 8000, + "gserver": 9988, + "truss": 8080, + }.get(args.backend, 30000) + + model_url = ( + f"{args.base_url}/v1/models" + if args.base_url + else f"http://{args.host}:{args.port}/v1/models" + ) + + if args.backend in ["sglang", "sglang-native"]: + api_url = ( + f"{args.base_url}/generate" + if args.base_url + else f"http://{args.host}:{args.port}/generate" + ) + elif args.backend in ["sglang-oai", "vllm", "lmdeploy"]: + api_url = ( + f"{args.base_url}/v1/completions" + if args.base_url + else f"http://{args.host}:{args.port}/v1/completions" + ) + elif args.backend in ["sglang-oai-chat", "vllm-chat", "lmdeploy-chat"]: + api_url = ( + f"{args.base_url}/v1/chat/completions" + if args.base_url + else f"http://{args.host}:{args.port}/v1/chat/completions" + ) + elif args.backend == "trt": + api_url = ( + f"{args.base_url}/v2/models/ensemble/generate_stream" + if args.base_url + else f"http://{args.host}:{args.port}/v2/models/ensemble/generate_stream" + ) + if args.model is None: + print("Please provide a model using `--model` when using `trt` backend.") + sys.exit(1) + elif args.backend == "gserver": + api_url = args.base_url if args.base_url else f"{args.host}:{args.port}" + args.model = args.model or "default" + elif args.backend == "truss": + api_url = ( + f"{args.base_url}/v1/models/model:predict" + if args.base_url + else f"http://{args.host}:{args.port}/v1/models/model:predict" + ) + base_url = ( + f"http://{args.host}:{args.port}" if args.base_url is None else args.base_url + ) + + # Wait for server to be ready + if args.ready_check_timeout_sec > 0: + health_url = model_url 
if args.backend not in ("trt", "gserver") else base_url + if not wait_for_endpoint(health_url, args.ready_check_timeout_sec): + print(f"Server at {health_url} is not ready. Exiting.") + sys.exit(1) + + # Get model name + if args.model is None: + if args.backend == "truss": + print( + "Please provide a model with `--model` when using truss backend. e.g. --model meta-llama/Llama-3.1-8B-Instruct" + ) + sys.exit(1) + try: + response = requests.get(model_url, headers=get_auth_headers()) + model_list = response.json().get("data", []) + args.model = model_list[0]["id"] if model_list else None + except Exception as e: + print(f"Failed to fetch model from {model_url}. Error: {e}") + print( + "Please specify the correct host and port using `--host` and `--port`." + ) + sys.exit(1) + + if args.model is None: + print("No model specified or found. Please provide a model using `--model`.") + sys.exit(1) + + if not check_chat_template(args.model): + print( + "\nWARNING It is recommended to use the `Chat` or `Instruct` model for benchmarking.\n" + "Because when the tokenizer counts the output tokens, if there is gibberish, it might count incorrectly.\n" + ) + + if args.dataset_name in ["image", "mmmu"]: + args.apply_chat_template = True + assert ( + not args.tokenize_prompt + ), "`--tokenize-prompt` not compatible with image dataset" + + if args.lora_request_distribution in ["distinct", "skewed"]: + assert ( + args.lora_name is not None and len(args.lora_name) > 1 + ), "More than 1 LoRA adapter must be specified via --lora-name to use 'distinct' or 'skewed' request distribution." + + assert ( + args.lora_zipf_alpha > 1 + ), f"Got invalid value for --lora-zipf-alpha of {args.lora_zipf_alpha}. It must be greater than 1." 
+ + print(f"{args}\n") + + # Read dataset + backend = args.backend + model_id = args.served_model_name or args.model + tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model + tokenizer = get_tokenizer(tokenizer_id) + input_requests = get_dataset(args, tokenizer, model_id) + + # compatible with SimpleNamespace + if not hasattr(args, "flush_cache"): + args.flush_cache = False + + # Prepare LoRA arguments + lora_request_distribution = ( + args.lora_request_distribution if args.lora_name is not None else None + ) + + lora_zipf_alpha = ( + args.lora_zipf_alpha + if args.lora_name is not None and args.lora_request_distribution == "skewed" + else None + ) + + return asyncio.run( + benchmark( + backend=backend, + api_url=api_url, + base_url=base_url, + model_id=model_id, + tokenizer=tokenizer, + input_requests=input_requests, + request_rate=args.request_rate, + max_concurrency=args.max_concurrency, + disable_tqdm=args.disable_tqdm, + lora_names=args.lora_name, + lora_request_distribution=lora_request_distribution, + lora_zipf_alpha=lora_zipf_alpha, + extra_request_body=extra_request_body, + profile=args.profile, + pd_separated=args.pd_separated, + flush_cache=args.flush_cache, + warmup_requests=args.warmup_requests, + use_trace_timestamps=args.use_trace_timestamps, + mooncake_slowdown_factor=args.mooncake_slowdown_factor, + mooncake_num_rounds=args.mooncake_num_rounds, + profile_prefill_url=getattr(args, "profile_prefill_url", None), + profile_decode_url=getattr(args, "profile_decode_url", None), + ) + ) + + +class LoRAPathAction(argparse.Action): + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, []) + for lora_name in values: + getattr(namespace, self.dest).append(lora_name) + + +if __name__ == "__main__": + parser = ArgumentParser(description="Benchmark the online serving throughput.") + parser.add_argument( + "--backend", + type=str, + choices=list(ASYNC_REQUEST_FUNCS.keys()), + default="sglang", + 
help="Must specify a backend, depending on the LLM Inference Engine.", + ) + parser.add_argument( + "--base-url", + type=str, + default=None, + help="Server or API base url if not using http host and port.", + ) + parser.add_argument( + "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0." + ) + parser.add_argument( + "--port", + type=int, + help="If not set, the default port is configured according to its default value for different LLM Inference Engines.", + ) + parser.add_argument( + "--ready-check-timeout-sec", + type=int, + default=60, + help="Maximum time in seconds to wait for the server to be ready before benchmarking. Set to 0 to skip. Default: 60.", + ) + parser.add_argument( + "--dataset-name", + type=str, + default="sharegpt", + choices=[ + "sharegpt", + "custom", + "openai", + "random", + "random-ids", + "generated-shared-prefix", + "mmmu", + "image", + "mooncake", + ], + help="Name of the dataset to benchmark on.", + ) + parser.add_argument( + "--dataset-path", type=str, default="", help="Path to the dataset." + ) + parser.add_argument( + "--model", + type=str, + help="Name or path of the model. If not set, the default model will request /v1/models for conf.", + ) + parser.add_argument( + "--served-model-name", + type=str, + help="The name of the model as served by the serving service. If not set, this defaults to the value of --model.", + ) + parser.add_argument( + "--tokenizer", + type=str, + help="Name or path of the tokenizer. If not set, using the model conf.", + ) + parser.add_argument( + "--num-prompts", + type=int, + default=1000, + help="Number of prompts to process. Default is 1000.", + ) + parser.add_argument( + "--sharegpt-output-len", + type=int, + default=None, + help="Output length for each request. Overrides the output length from the ShareGPT dataset.", + ) + parser.add_argument( + "--sharegpt-context-len", + type=int, + default=None, + help="The context length of the model for the ShareGPT dataset. 
Requests longer than the context length will be dropped.", + ) + parser.add_argument( + "--random-input-len", + type=int, + default=1024, + help="Number of input tokens per request, used only for random and image dataset.", + ) + parser.add_argument( + "--random-output-len", + default=1024, + type=int, + help="Number of output tokens per request, used only for random and image dataset.", + ) + parser.add_argument( + "--random-range-ratio", + type=float, + default=0.0, + help="Range of sampled ratio of input/output length, " + "used only for random and image dataset.", + ) + # image dataset args + parser.add_argument( + "--image-count", + type=int, + default=1, + help="Number of images per request (only available with the image dataset)", + ) + parser.add_argument( + "--image-resolution", + type=str, + default="1080p", + help=( + "Resolution of images for image dataset. " + "Supports presets 4k/1080p/720p/360p or custom 'heightxwidth' (e.g., 1080x1920)." + ), + ) + parser.add_argument( + "--random-image-count", + action="store_true", + help="Enable Random Image Count", + ) + parser.add_argument( + "--image-format", + type=str, + default="jpeg", + help=("Format of images for image dataset. " "Supports jpeg and png."), + ) + parser.add_argument( + "--image-content", + type=str, + default="random", + help=("Content for images for image dataset. " "Supports random and blank."), + ) + parser.add_argument( + "--request-rate", + type=float, + default=float("inf"), + help="Number of requests per second. If this is inf, then all the requests are sent at time 0. " + "Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.", + ) + parser.add_argument( + "--use-trace-timestamps", + action="store_true", + help="Use timestamps from the trace file for request scheduling. Only valid for 'mooncake' dataset.", + ) + parser.add_argument( + "--max-concurrency", + type=int, + default=None, + help="Maximum number of concurrent requests. 
This can be used " + "to help simulate an environment where a higher level component " + "is enforcing a maximum number of concurrent requests. While the " + "--request-rate argument controls the rate at which requests are " + "initiated, this argument will control how many are actually allowed " + "to execute at a time. This means that when used in combination, the " + "actual request rate may be lower than specified with --request-rate, " + "if the server is not processing requests fast enough to keep up.", + ) + parser.add_argument("--output-file", type=str, help="Output JSONL file name.") + parser.add_argument( + "--output-details", action="store_true", help="Output details of benchmarking." + ) + parser.add_argument( + "--print-requests", + action="store_true", + help="Print requests immediately during benchmarking. Useful to quickly realize issues.", + ) + parser.add_argument( + "--disable-tqdm", + action="store_true", + help="Specify to disable tqdm progress bar.", + ) + parser.add_argument( + "--disable-stream", + action="store_true", + help="Disable streaming mode.", + ) + parser.add_argument( + "--return-logprob", + action="store_true", + help="Return logprob.", + ) + parser.add_argument( + "--return-routed-experts", + action="store_true", + help="Return routed experts.", + ) + parser.add_argument("--seed", type=int, default=1, help="The random seed.") + parser.add_argument( + "--disable-ignore-eos", + action="store_true", + help="Disable ignoring EOS.", + ) + parser.add_argument( + "--extra-request-body", + metavar='{"key1": "value1", "key2": "value2"}', + type=str, + help="Append given JSON object to the request payload. You can use this to specify" + "additional generate params like sampling params.", + ) + parser.add_argument( + "--apply-chat-template", + action="store_true", + help="Apply chat template", + ) + parser.add_argument( + "--profile", + action="store_true", + help="Use Torch Profiler. 
The endpoint must be launched with " + "SGLANG_TORCH_PROFILER_DIR to enable profiler.", + ) + parser.add_argument( + "--plot-throughput", + action="store_true", + help="Plot throughput and concurrent requests over time. Requires termplotlib and gnuplot.", + ) + # TODO unify all these + parser.add_argument( + "--profile-activities", + type=str, + nargs="+", + default=["CPU", "GPU"], + choices=["CPU", "GPU", "CUDA_PROFILER", "XPU"], + help="Profiler activities to capture: CPU, GPU, XPU, CUDA_PROFILER.", + ) + parser.add_argument( + "--profile-start-step", + type=int, + default=None, + help="Start profiling after this many forward steps. Useful for warmup.", + ) + parser.add_argument( + "--profile-steps", + type=int, + default=None, + help="Number of steps to profile. If specified, profiling stops automatically after this many steps.", + ) + parser.add_argument("--profile-num-steps", type=int, default=None) + parser.add_argument("--profile-by-stage", action="store_true", default=False) + parser.add_argument("--profile-stages", nargs="+", default=None) + parser.add_argument( + "--profile-output-dir", + type=str, + default=None, + help="Output directory for profile traces.", + ) + parser.add_argument( + "--profile-prefix", + type=str, + default=None, + help="Prefix for profile trace filenames.", + ) + parser.add_argument( + "--lora-name", + type=str, + nargs="*", + default=None, + action=LoRAPathAction, + help="The names of LoRA adapters. You can provide a list of names in the format {name} {name} {name}...", + ) + parser.add_argument( + "--lora-request-distribution", + type=str, + default="uniform", + choices=[ + "uniform", + "distinct", + "skewed", + ], + help="What distribution to sample the LoRA adapters specified in --lora-name. Borrowed from the Punica paper. " + "'distinct' distribution means selecting a new LoRA adapter for every request. 
" + "'skewed' distribution follows the Zipf distribution, where the number of requests " + "to model i specified in --lora-name is α times the number of requests for model i+1, " + "where α > 1.", + ) + parser.add_argument( + "--lora-zipf-alpha", + type=float, + default=1.5, + help="The parameter to use for the Zipf distribution when --lora-request-distribution='skewed'.", + ) + parser.add_argument( + "--prompt-suffix", + type=str, + default="", + help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.", + ) + parser.add_argument( + "--pd-separated", + action="store_true", + help="Benchmark PD disaggregation server", + ) + + # Create a mutually exclusive group for profiling URLs + # In PD separated mode, prefill and decode workers must be profiled separately + profile_url_group = parser.add_mutually_exclusive_group() + profile_url_group.add_argument( + "--profile-prefill-url", + type=str, + nargs="*", + default=None, + help="URL(s) of the prefill worker(s) for profiling in PD separated mode. " + "Can specify multiple URLs: --profile-prefill-url http://localhost:30000 http://localhost:30001. " + "NOTE: Cannot be used together with --profile-decode-url. " + "In PD separated mode, prefill and decode workers must be profiled separately.", + ) + profile_url_group.add_argument( + "--profile-decode-url", + type=str, + nargs="*", + default=None, + help="URL(s) of the decode worker(s) for profiling in PD separated mode. " + "Can specify multiple URLs: --profile-decode-url http://localhost:30010 http://localhost:30011. " + "NOTE: Cannot be used together with --profile-prefill-url. 
" + "In PD separated mode, prefill and decode workers must be profiled separately.", + ) + parser.add_argument( + "--flush-cache", + action="store_true", + help="Flush the cache before running the benchmark", + ) + parser.add_argument( + "--warmup-requests", + type=int, + default=1, + help="Number of warmup requests to run before the benchmark", + ) + parser.add_argument( + "--tokenize-prompt", + action="store_true", + help="Use integer ids instead of string for inputs. Useful to control prompt lengths accurately", + ) + + group = parser.add_argument_group("generated-shared-prefix dataset arguments") + group.add_argument( + "--gsp-num-groups", + type=int, + default=64, + help="Number of system prompt groups for generated-shared-prefix dataset", + ) + group.add_argument( + "--gsp-prompts-per-group", + type=int, + default=16, + help="Number of prompts per system prompt group for generated-shared-prefix dataset", + ) + group.add_argument( + "--gsp-system-prompt-len", + type=int, + default=2048, + help="Target length in tokens for system prompts in generated-shared-prefix dataset", + ) + group.add_argument( + "--gsp-question-len", + type=int, + default=128, + help="Target length in tokens for questions in generated-shared-prefix dataset", + ) + group.add_argument( + "--gsp-output-len", + type=int, + default=256, + help="Target length in tokens for outputs in generated-shared-prefix dataset", + ) + parser.add_argument( + "--gsp-range-ratio", + type=float, + # WARN: The default 1.0 is for backward compatibility, and is different from the default 0.0 for random dataset + default=1.0, + help="Range of sampled ratio of input/output length, used only for gsp dataset.", + ) + group.add_argument( + "--gsp-fast-prepare", + action="store_true", + help="Speedup preparing by removing statistics computation, which will make some output statistics inaccurate but suitable for pressure tests.", + ) + group.add_argument( + "--gsp-send-routing-key", + action="store_true", + help="Send 
routing key in requests via X-SMG-Routing-Key header. Requests with the same prefix share the same routing key.", + ) + group.add_argument( + "--gsp-num-turns", + type=int, + default=1, + help="Number of turns for multi-turn conversations. If > 1, each prompt becomes a list of questions sharing the same system prefix.", + ) + group.add_argument( + "--gsp-ordered", + action="store_true", + help="Keep requests in order without shuffling. By default, requests are shuffled randomly.", + ) + mooncake_group = parser.add_argument_group("mooncake dataset arguments") + mooncake_group.add_argument( + "--mooncake-slowdown-factor", + type=float, + default=1.0, + help="Slowdown factor for replaying the mooncake trace. " + "A value of 2.0 means the replay is twice as slow. " + "NOTE: --request-rate is IGNORED in mooncake mode.", + ) + mooncake_group.add_argument( + "--mooncake-num-rounds", + type=int, + default=1, + help="Number of conversation rounds for each session in the mooncake dataset. " + "A value > 1 will enable true multi-turn session benchmarking.", + ) + mooncake_group.add_argument( + "--mooncake-workload", + type=str, + default="conversation", + choices=[ + "mooncake", + "conversation", + "synthetic", + "toolagent", + ], + help="Underlying workload for the mooncake dataset.", + ) + parser.add_argument( + "--tag", type=str, default=None, help="The tag to be dumped to output." + ) + parser.add_argument( + "--header", + type=str, + nargs="+", + default=None, + help="Custom HTTP headers in Key=Value format. 
Example: --header MyHeader=MY_VALUE MyAnotherHeader=myanothervalue", + ) + args = parser.parse_args() + run_benchmark(args) diff --git a/sglang/python/sglang/benchmark/__init__.py b/sglang/python/sglang/benchmark/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/sglang/python/sglang/benchmark/__pycache__/__init__.cpython-311.pyc b/sglang/python/sglang/benchmark/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9934013c2a06802649e5a4d0c71b6ccad002ca34 Binary files /dev/null and b/sglang/python/sglang/benchmark/__pycache__/__init__.cpython-311.pyc differ diff --git a/sglang/python/sglang/benchmark/__pycache__/utils.cpython-311.pyc b/sglang/python/sglang/benchmark/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3564b757173de53b245740a4a6f4c0fc9c636afa Binary files /dev/null and b/sglang/python/sglang/benchmark/__pycache__/utils.cpython-311.pyc differ diff --git a/sglang/python/sglang/benchmark/datasets/__init__.py b/sglang/python/sglang/benchmark/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..63612d52e414040dd6cc366d747ce090aa6bfeec --- /dev/null +++ b/sglang/python/sglang/benchmark/datasets/__init__.py @@ -0,0 +1,47 @@ +from typing import Dict, Type + +from sglang.benchmark.datasets.common import BaseDataset, DatasetRow +from sglang.benchmark.datasets.custom import CustomDataset +from sglang.benchmark.datasets.generated_shared_prefix import ( + GeneratedSharedPrefixDataset, +) +from sglang.benchmark.datasets.image import ImageDataset +from sglang.benchmark.datasets.mmmu import MMMUDataset +from sglang.benchmark.datasets.mooncake import MooncakeDataset +from sglang.benchmark.datasets.openai_dataset import OpenAIDataset +from sglang.benchmark.datasets.random import RandomDataset +from sglang.benchmark.datasets.sharegpt import 
ShareGPTDataset + +DATASET_MAPPING: Dict[str, Type[BaseDataset]] = { + "sharegpt": ShareGPTDataset, + "custom": CustomDataset, + "openai": OpenAIDataset, + # TODO: "random" vs "random-ids" should be a flag (e.g. --random-source=sharegpt|integers), + # not two separate dataset names sharing the same class. + "random": RandomDataset, + "random-ids": RandomDataset, + "generated-shared-prefix": GeneratedSharedPrefixDataset, + "mmmu": MMMUDataset, + "image": ImageDataset, + "mooncake": MooncakeDataset, +} + + +def get_dataset(args, tokenizer, model_id=None): + dataset_name = args.dataset_name + if dataset_name.startswith("random") and dataset_name not in DATASET_MAPPING: + dataset_name = "random-ids" + + if dataset_name not in DATASET_MAPPING: + raise ValueError(f"Unknown dataset: {args.dataset_name}") + + dataset_cls = DATASET_MAPPING[dataset_name] + dataset = dataset_cls.from_args(args) + return dataset.load(tokenizer=tokenizer, model_id=model_id) + + +__all__ = [ + "DATASET_MAPPING", + "DatasetRow", + "get_dataset", +] diff --git a/sglang/python/sglang/benchmark/datasets/__pycache__/__init__.cpython-311.pyc b/sglang/python/sglang/benchmark/datasets/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92a394c8829a8350b74876d9421ac658ff791dc0 Binary files /dev/null and b/sglang/python/sglang/benchmark/datasets/__pycache__/__init__.cpython-311.pyc differ diff --git a/sglang/python/sglang/benchmark/datasets/__pycache__/common.cpython-311.pyc b/sglang/python/sglang/benchmark/datasets/__pycache__/common.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db7faeed510e477959385da5016109c4c5b8527b Binary files /dev/null and b/sglang/python/sglang/benchmark/datasets/__pycache__/common.cpython-311.pyc differ diff --git a/sglang/python/sglang/benchmark/datasets/__pycache__/custom.cpython-311.pyc b/sglang/python/sglang/benchmark/datasets/__pycache__/custom.cpython-311.pyc new file mode 100644 
index 0000000000000000000000000000000000000000..7fe45b4c034c67dc78f1f44e0409feea84b3324c Binary files /dev/null and b/sglang/python/sglang/benchmark/datasets/__pycache__/custom.cpython-311.pyc differ diff --git a/sglang/python/sglang/benchmark/datasets/__pycache__/generated_shared_prefix.cpython-311.pyc b/sglang/python/sglang/benchmark/datasets/__pycache__/generated_shared_prefix.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3e31f16df535025042009584cdf90da871ea065 Binary files /dev/null and b/sglang/python/sglang/benchmark/datasets/__pycache__/generated_shared_prefix.cpython-311.pyc differ diff --git a/sglang/python/sglang/benchmark/datasets/__pycache__/image.cpython-311.pyc b/sglang/python/sglang/benchmark/datasets/__pycache__/image.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff2589f22a464fcf8da7a1179ec30471c5e6531e Binary files /dev/null and b/sglang/python/sglang/benchmark/datasets/__pycache__/image.cpython-311.pyc differ diff --git a/sglang/python/sglang/benchmark/datasets/__pycache__/mmmu.cpython-311.pyc b/sglang/python/sglang/benchmark/datasets/__pycache__/mmmu.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..563b6a982f942eb49f25a97ca215067ade016850 Binary files /dev/null and b/sglang/python/sglang/benchmark/datasets/__pycache__/mmmu.cpython-311.pyc differ diff --git a/sglang/python/sglang/benchmark/datasets/__pycache__/mooncake.cpython-311.pyc b/sglang/python/sglang/benchmark/datasets/__pycache__/mooncake.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9c444a30035a5db548a153d6b1bef1308bcf33f Binary files /dev/null and b/sglang/python/sglang/benchmark/datasets/__pycache__/mooncake.cpython-311.pyc differ diff --git a/sglang/python/sglang/benchmark/datasets/__pycache__/openai_dataset.cpython-311.pyc b/sglang/python/sglang/benchmark/datasets/__pycache__/openai_dataset.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..d0e14145521239a893042ae4f8f408768d1aa1c8 Binary files /dev/null and b/sglang/python/sglang/benchmark/datasets/__pycache__/openai_dataset.cpython-311.pyc differ diff --git a/sglang/python/sglang/benchmark/datasets/__pycache__/random.cpython-311.pyc b/sglang/python/sglang/benchmark/datasets/__pycache__/random.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4aea3897b78efaac98b385ef5e8a66f056dcee2 Binary files /dev/null and b/sglang/python/sglang/benchmark/datasets/__pycache__/random.cpython-311.pyc differ diff --git a/sglang/python/sglang/benchmark/datasets/__pycache__/sharegpt.cpython-311.pyc b/sglang/python/sglang/benchmark/datasets/__pycache__/sharegpt.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..680333b1fb80f00d7930f16541e68bc33a68344d Binary files /dev/null and b/sglang/python/sglang/benchmark/datasets/__pycache__/sharegpt.cpython-311.pyc differ diff --git a/sglang/python/sglang/benchmark/datasets/common.py b/sglang/python/sglang/benchmark/datasets/common.py new file mode 100644 index 0000000000000000000000000000000000000000..f5a204a7e11617c7f0c548ebd15498573534dc45 --- /dev/null +++ b/sglang/python/sglang/benchmark/datasets/common.py @@ -0,0 +1,83 @@ +import random +from abc import ABC, abstractmethod +from argparse import Namespace +from dataclasses import dataclass +from functools import lru_cache +from typing import Any, Dict, List, Optional + +import numpy as np + +ASSISTANT_SUFFIX = "Assistant:" +SHAREGPT_REPO_ID = "anon8231489123/ShareGPT_Vicuna_unfiltered" +SHAREGPT_FILENAME = "ShareGPT_V3_unfiltered_cleaned_split.json" +MOONCAKE_DATASET_URL = { + "mooncake": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/arxiv-trace/mooncake_trace.jsonl", + "conversation": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/traces/conversation_trace.jsonl", + "synthetic": 
"https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/traces/synthetic_trace.jsonl", + "toolagent": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/traces/toolagent_trace.jsonl", +} + + +@dataclass +class DatasetRow: + prompt: Any + prompt_len: int + output_len: int + text_prompt_len: Optional[int] = None + vision_prompt_len: Optional[int] = None + image_data: Optional[List[str]] = None + timestamp: Optional[float] = None + routing_key: Optional[str] = None + extra_request_body: Optional[Dict[str, Any]] = None # Per-request API parameters + + def __post_init__(self): + if self.text_prompt_len is None: + self.text_prompt_len = self.prompt_len + if self.vision_prompt_len is None: + self.vision_prompt_len = 0 + if self.extra_request_body is None: + self.extra_request_body = {} + + +@dataclass +class BaseDataset(ABC): + @classmethod + @abstractmethod + def from_args(cls, args: Namespace) -> "BaseDataset": ... + + @abstractmethod + def load( + self, + tokenizer: Any, + model_id: Optional[str] = None, + ) -> List[DatasetRow]: ... 
+ + +def compute_random_lens(full_len: int, range_ratio: float, num: int) -> List[int]: + return np.random.randint( + max(int(full_len * range_ratio), 1), + full_len + 1, + size=num, + ).tolist() + + +@lru_cache(maxsize=1) +def get_available_tokens(tokenizer): + """Get all available token ids from the tokenizer vocabulary.""" + return list(tokenizer.get_vocab().values()) + + +def gen_prompt(tokenizer, token_num): + """Generate a random prompt of specified token length using tokenizer vocabulary.""" + all_available_tokens = get_available_tokens(tokenizer) + selected_tokens = random.choices(all_available_tokens, k=token_num) + return tokenizer.decode(selected_tokens) + + +def gen_mm_prompt(tokenizer, image_pad_id, token_num): + """Generate a random prompt of specified token length using tokenizer vocabulary.""" + all_available_tokens = list(tokenizer.get_vocab().values()) + if image_pad_id: + all_available_tokens.remove(image_pad_id) + selected_tokens = random.choices(all_available_tokens, k=token_num) + return tokenizer.decode(selected_tokens) diff --git a/sglang/python/sglang/benchmark/datasets/custom.py b/sglang/python/sglang/benchmark/datasets/custom.py new file mode 100644 index 0000000000000000000000000000000000000000..452c7db74c0f1aefd2792c6308f3b9f9a3f28e93 --- /dev/null +++ b/sglang/python/sglang/benchmark/datasets/custom.py @@ -0,0 +1,147 @@ +import json +import os +import random +from argparse import Namespace +from dataclasses import dataclass +from typing import List, Optional + +import numpy as np +from transformers import PreTrainedTokenizerBase + +from sglang.benchmark.datasets.common import ( + ASSISTANT_SUFFIX, + BaseDataset, + DatasetRow, +) +from sglang.benchmark.utils import remove_suffix + + +@dataclass +class CustomDataset(BaseDataset): + dataset_path: str + num_requests: int + fixed_output_len: Optional[int] + context_len: Optional[int] + prompt_suffix: str + apply_chat_template: bool + + @classmethod + def from_args(cls, args: Namespace) -> 
def sample_custom_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int] = None,
    context_len: Optional[int] = None,
    prompt_suffix: Optional[str] = "",
    apply_chat_template=False,
) -> List[DatasetRow]:
    """
    Sample requests from a custom JSONL dataset: supports 'content'/'value' as conversation keys.

    Each non-empty line is parsed as JSON. A record contributes one request built
    from the first two turns found under "conversations" (or "conversation"):
    turn 0 is the prompt, turn 1 is the reference completion.

    Args:
        dataset_path: Path to the JSONL file.
        num_requests: Maximum number of rows to return.
        tokenizer: Tokenizer used to measure prompt/completion lengths and,
            optionally, to apply the chat template.
        fixed_output_len: If given, used as the output length of every row
            instead of the tokenized completion length. Must be >= 4.
        context_len: If truthy, rows whose prompt_len + output_len exceed it are dropped.
        prompt_suffix: Optional text inserted before ASSISTANT_SUFFIX at the end of the prompt.
        apply_chat_template: Wrap the prompt as a single user message via the
            tokenizer's chat template (BOS token text is removed afterwards).

    Returns:
        Up to ``num_requests`` DatasetRow objects, in shuffled order.

    Raises:
        ValueError: ``fixed_output_len`` is smaller than 4.
        FileNotFoundError: ``dataset_path`` is not a file.
    """
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    if not os.path.isfile(dataset_path):
        raise FileNotFoundError(f"Dataset not found at {dataset_path}")

    # Load the dataset, tolerating malformed lines.
    records = []
    with open(dataset_path, "r", encoding="utf-8") as f:
        for raw_line in f:
            raw_line = raw_line.strip()
            if not raw_line:
                continue  # skip empty lines
            try:
                record = json.loads(raw_line)
            except json.JSONDecodeError:
                continue  # skip lines with JSON errors
            # A line can be valid JSON without being an object (e.g. `123`);
            # such lines carry no conversation and are skipped like broken ones.
            if isinstance(record, dict):
                records.append(record)

    def _turn_text(turn):
        # Turns may store their text under either "content" or "value".
        return turn.get("content", turn.get("value", ""))

    # Keep only conversations with at least 2 turns, as (user, assistant) pairs.
    pairs = []
    for record in records:
        convs = record.get("conversations", record.get("conversation", []))
        if len(convs) >= 2:
            pairs.append((_turn_text(convs[0]), _turn_text(convs[1])))
    random.shuffle(pairs)

    # Filter out sequences that are too long or too short.
    filtered_dataset: List[DatasetRow] = []
    for user_text, assistant_text in pairs:
        if len(filtered_dataset) == num_requests:
            break

        prompt = user_text
        if prompt_suffix:
            prompt = (
                remove_suffix(prompt, ASSISTANT_SUFFIX)
                + prompt_suffix
                + ASSISTANT_SUFFIX
            )

        if apply_chat_template:
            prompt = tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt}],
                add_generation_prompt=True,
                tokenize=False,
                return_dict=False,
            )
            if tokenizer.bos_token:
                prompt = prompt.replace(tokenizer.bos_token, "")

        prompt_len = len(tokenizer.encode(prompt))
        if fixed_output_len is None:
            output_len = len(tokenizer.encode(assistant_text))
        else:
            # The completion length is irrelevant here, so skip tokenizing it.
            output_len = fixed_output_len

        if prompt_len < 2 or output_len < 2:
            # Prune too short sequences.
            continue

        if context_len and prompt_len + output_len > context_len:
            # Prune too long sequences.
            continue

        filtered_dataset.append(
            DatasetRow(
                prompt=prompt,
                prompt_len=prompt_len,
                output_len=output_len,
            )
        )

    print(f"#Input tokens: {np.sum([x.prompt_len for x in filtered_dataset])}")
    print(f"#Output tokens: {np.sum([x.output_len for x in filtered_dataset])}")
    return filtered_dataset
def get_gen_prefix_cache_path(
    seed: int,
    num_groups: int,
    prompts_per_group: int,
    system_prompt_len: int,
    question_len: int,
    output_len: int,
    tokenizer,
    fast_prepare: bool = False,
    ordered: bool = False,
):
    """Build the cache file path under ~/.cache/sglang/benchmark.

    The file name encodes every argument that changes the generated rows.
    ``fast_prepare`` (rows carry ``prompt_len=1``) and ``ordered`` (rows are not
    shuffled) only add a suffix when enabled, so the name for the default
    configuration is unchanged from earlier versions.
    """
    cache_dir = Path.home() / ".cache" / "sglang" / "benchmark"

    variant = ("_fast" if fast_prepare else "") + ("_ordered" if ordered else "")
    cache_key = (
        f"gen_shared_prefix_{seed}_{num_groups}_{prompts_per_group}_"
        f"{system_prompt_len}_{question_len}_{output_len}_"
        f"{tokenizer.__class__.__name__}{variant}.pkl"
    )
    return cache_dir / cache_key


def sample_generated_shared_prefix_requests(
    num_groups: int,
    prompts_per_group: int,
    system_prompt_len: int,
    question_len: int,
    output_len: int,
    range_ratio: float,
    tokenizer: PreTrainedTokenizerBase,
    seed: int,
    send_routing_key: bool = False,
    num_turns: int = 1,
    fast_prepare: bool = False,
    ordered: bool = False,
) -> List[DatasetRow]:
    """Generate benchmark requests with shared system prompts using random tokens and caching.

    ``num_groups`` system prompts are generated; each is shared by
    ``prompts_per_group`` requests, and each request has ``num_turns`` questions.
    The first turn is the system prompt followed by the first question.

    Args:
        num_groups: Number of distinct system prompts.
        prompts_per_group: Requests generated per system prompt.
        system_prompt_len / question_len / output_len: Target lengths passed to
            ``compute_random_lens`` together with ``range_ratio``.
        range_ratio: Length range ratio; results are only cached when it equals 1.
        tokenizer: Tokenizer used to generate prompts and measure their length.
        seed: Only used as part of the cache key here.
        send_routing_key: Attach a per-group routing key (unique per run) to each row.
        num_turns: Questions per request; with more than one turn the row's
            prompt is a list of turn prompts instead of a single string.
        fast_prepare: Skip tokenization and record ``prompt_len=1`` for every row.
        ordered: Keep rows in generation order instead of shuffling them.

    Returns:
        The generated (or cached) list of DatasetRow objects.
    """
    cache_path = get_gen_prefix_cache_path(
        seed,
        num_groups,
        prompts_per_group,
        system_prompt_len,
        question_len,
        output_len,
        tokenizer,
        fast_prepare=fast_prepare,
        ordered=ordered,
    )
    # Routing keys are unique per run and multi-turn / ranged lengths are not
    # captured by the cache key, so those configurations are never cached.
    should_cache = (range_ratio == 1) and not send_routing_key and num_turns == 1

    # Try to load from cache first.
    # NOTE(review): pickle.load executes arbitrary code from the file; this is
    # only acceptable because the cache lives in the user's own home directory.
    if should_cache and cache_path.exists():
        print(f"\nLoading cached generated input data from {cache_path}")
        with open(cache_path, "rb") as f:
            return pickle.load(f)

    print(
        f"\nGenerating new input data... "
        f"({num_groups=}, {prompts_per_group=}, {system_prompt_len=}, {question_len=}, {output_len=}, {range_ratio=}, {num_turns=})"
    )

    run_random_str = uuid.uuid4().hex[:8]
    run_start_timestamp = datetime.now().strftime("%Y%m%d%H%M%S")

    system_prompt_lens = compute_random_lens(
        full_len=system_prompt_len,
        range_ratio=range_ratio,
        num=num_groups,
    )
    question_lens = np.array(
        compute_random_lens(
            full_len=question_len,
            range_ratio=range_ratio,
            num=num_groups * prompts_per_group * num_turns,
        )
    ).reshape(num_groups, prompts_per_group, num_turns)
    output_lens = np.array(
        compute_random_lens(
            full_len=output_len,
            range_ratio=range_ratio,
            num=num_groups * prompts_per_group,
        )
    ).reshape(num_groups, prompts_per_group)
    # From here on only the sampled per-item lengths must be used.
    del system_prompt_len, question_len, output_len

    # Generate system prompts for each group
    system_prompts = [
        gen_prompt(tokenizer, system_prompt_lens[i]) for i in range(num_groups)
    ]

    # Generate questions: shape (num_groups, prompts_per_group, num_turns)
    questions = [
        [
            [
                gen_prompt(tokenizer, int(question_lens[g, p, t]))
                for t in range(num_turns)
            ]
            for p in range(prompts_per_group)
        ]
        for g in range(num_groups)
    ]

    # Combine system prompts with questions
    input_requests = []
    total_input_tokens = 0
    total_output_tokens = 0

    for group_idx in tqdm(range(num_groups), desc="Generating system prompt"):
        system_prompt = system_prompts[group_idx]
        routing_key = (
            f"{run_random_str}_{run_start_timestamp}_{group_idx}"
            if send_routing_key
            else None
        )
        for prompt_idx in tqdm(
            range(prompts_per_group), desc="Generating questions", leave=False
        ):
            turn_questions = questions[group_idx][prompt_idx]
            # Only the first turn carries the shared system prompt.
            first_turn = f"{system_prompt}\n\n{turn_questions[0]}"
            turn_prompts = [first_turn] + turn_questions[1:]
            full_prompt = first_turn if num_turns == 1 else turn_prompts
            prompt_len = 1 if fast_prepare else len(tokenizer.encode(first_turn))
            output_len_val = int(output_lens[group_idx, prompt_idx])

            input_requests.append(
                DatasetRow(
                    prompt=full_prompt,
                    prompt_len=prompt_len,
                    output_len=output_len_val,
                    routing_key=routing_key,
                )
            )
            total_input_tokens += prompt_len
            total_output_tokens += output_len_val

    if not ordered:
        random.shuffle(input_requests)

    # Print statistics
    print(f"\nGenerated shared prefix dataset statistics:")
    print(f"Number of groups: {num_groups}")
    print(f"Prompts per group: {prompts_per_group}")
    print(f"Number of turns: {num_turns}")
    print(f"Total prompts: {len(input_requests)}")
    if not fast_prepare:
        print(f"Total input tokens: {total_input_tokens}")
        print(f"Total output tokens: {total_output_tokens}")
        # Averages are undefined (and would divide by zero) for empty inputs.
        if system_prompts:
            print(
                f"Average system prompt length: {sum(len(tokenizer.encode(sp)) for sp in system_prompts) / len(system_prompts):.1f} tokens"
            )
        all_questions = [q for group in questions for conv in group for q in conv]
        if all_questions:
            print(
                f"Average question length: {sum(len(tokenizer.encode(q)) for q in all_questions) / len(all_questions):.1f} tokens\n"
            )

    # Save to cache
    if should_cache:
        cache_path.parent.mkdir(parents=True, exist_ok=True)
        print(f"Caching generated input data to {cache_path}")
        with open(cache_path, "wb") as f:
            pickle.dump(input_requests, f)

    return input_requests
def parse_image_resolution(image_resolution: str) -> Tuple[int, int]:
    """Parse image resolution into (width, height).

    Supports presets '4k', '1080p', '720p', '360p' and custom 'heightxwidth' format
    (e.g., '1080x1920' means height=1080, width=1920). Matching ignores
    surrounding whitespace and letter case for presets and custom sizes alike.

    Args:
        image_resolution: Preset name or 'heightxwidth' string.

    Returns:
        ``(width, height)`` in pixels, both strictly positive.

    Raises:
        ValueError: The string is neither a known preset nor a valid
            positive 'heightxwidth' pair.
    """
    resolution_to_size = {
        "4k": (3840, 2160),
        "1080p": (1920, 1080),
        "720p": (1280, 720),
        "360p": (640, 360),
    }

    # Normalize once so presets and custom sizes are treated the same way.
    res = image_resolution.strip().lower()
    if res in resolution_to_size:
        return resolution_to_size[res]

    parts = res.split("x")
    # isdecimal() (not isdigit()) so characters int() cannot parse, such as
    # superscript digits, are rejected here with the message below.
    if len(parts) == 2 and all(part.isdecimal() for part in parts):
        height, width = int(parts[0]), int(parts[1])
        if height > 0 and width > 0:
            return (width, height)

    raise ValueError(
        f"Unsupported image resolution: {image_resolution}. "
        "Choose from 4k, 1080p, 720p, 360p, or provide custom 'heightxwidth' (e.g., 1080x1920)."
    )
def sample_image_requests(
    num_requests: int,
    image_count: int,
    input_len: int,
    output_len: int,
    range_ratio: float,
    processor: AutoProcessor,
    image_content: str,
    image_format: str,
    image_resolution: str,
    backend: str,
    random_image_count: bool = False,
) -> List[DatasetRow]:
    """Generate requests with images.

    - If ``random_image_count`` is True, each request includes a random number of images between 1 and ``image_count``.
    - If ``random_image_count`` is False, each request includes exactly ``image_count`` images.
    - Supported resolutions: 4k (3840x2160), 1080p (1920x1080), 720p (1280x720), 360p (640x360),
      or custom 'heightxwidth' (e.g., 1080x1920).
    - Text lengths follow the 'random' dataset sampling rule.
    - Rows are built by ``create_mm_data_row``: ``prompt_len`` is the total
      token count (text + vision), with the split available in
      ``text_prompt_len`` and ``vision_prompt_len``.

    Args:
        num_requests: Number of rows to generate.
        image_count: Images per request (or the upper bound when random).
        input_len / output_len / range_ratio: Passed to ``compute_random_lens``.
        processor: Multimodal processor; its tokenizer generates the text prompt.
        image_content: "blank" for white images, anything else for random noise.
        image_format: Image format name handed to PIL and used in the data URI.
        image_resolution: Preset or 'heightxwidth' string.
        backend: Backend name forwarded to ``create_mm_data_row``.
        random_image_count: Whether to randomize the per-request image count.

    Returns:
        One DatasetRow per request.
    """

    # Parse resolution (supports presets and 'heightxwidth')
    width, height = parse_image_resolution(image_resolution)

    # Determine image counts for each request
    if random_image_count:
        # Random number of images per request
        image_counts = np.random.randint(1, image_count + 1, size=num_requests)
        total_images = np.sum(image_counts)
    else:
        # Fixed number of images per request
        image_counts = np.full(num_requests, image_count)
        total_images = image_count * num_requests

    # Check for potentially problematic combinations and warn user
    if width * height >= 1920 * 1080 and total_images >= 100:
        warnings.warn(
            f"High resolution ({width}x{height}) with {total_images} total images "
            f"may take a long time. Consider reducing resolution or image count.",
            UserWarning,
            stacklevel=2,
        )

    # Sample text lengths
    input_lens = compute_random_lens(
        full_len=input_len,
        range_ratio=range_ratio,
        num=num_requests,
    )
    output_lens = compute_random_lens(
        full_len=output_len,
        range_ratio=range_ratio,
        num=num_requests,
    )

    def _gen_random_image_data_uri(
        width: int = width, height: int = height
    ) -> Tuple[Image.Image, str, int]:
        """Create one image; return (PIL image, base64 data URI, URI size in bytes)."""
        if image_content == "blank":
            # Generate blank white image
            arr = np.full((height, width, 3), 255, dtype=np.uint8)
        else:
            # Generate random colored image
            arr = (np.random.rand(height, width, 3) * 255).astype(np.uint8)
        img = Image.fromarray(arr)
        buf = io.BytesIO()
        img.save(buf, format=image_format, quality=85)
        encoded = pybase64.b64encode(buf.getvalue()).decode("utf-8")
        image_data = f"data:image/{image_format};base64,{encoded}"
        image_bytes = len(image_data.encode("utf-8"))
        return img, image_data, image_bytes

    # Not every processor exposes an image token id.
    image_token_id = getattr(processor, "image_token_id", None)

    dataset: List[DatasetRow] = []
    total_image_bytes = 0
    for i in range(num_requests):
        # Get the number of images for this request
        request_image_count = int(image_counts[i])

        # Generate text prompt
        text_prompt = gen_mm_prompt(
            processor.tokenizer,
            image_token_id,
            int(input_lens[i]),
        )

        # Generate image list
        images, images_base64, images_bytes = zip(
            *[_gen_random_image_data_uri() for _ in range(request_image_count)]
        )
        total_image_bytes += sum(images_bytes)

        data_row = create_mm_data_row(
            text_prompt,
            list(images),
            list(images_base64),
            int(output_lens[i]),
            processor,
            backend,
        )
        dataset.append(data_row)

    # Print statistics
    print(f"#Input tokens: {np.sum([x.prompt_len for x in dataset])}")
    print(f"#Output tokens: {np.sum([x.output_len for x in dataset])}")
    print(f"#Total images: {total_images}")

    if random_image_count:
        # np.min / np.max raise on an empty array, so skip them for zero requests.
        if len(image_counts) > 0:
            print(
                f"#Images per request: min={np.min(image_counts)}, max={np.max(image_counts)}, mean={np.mean(image_counts):.2f}"
            )
    else:
        print(f"#Images per request: {image_count} (fixed)")

    # Avoid dividing by zero when no requests were asked for.
    avg_image_bytes = total_image_bytes // num_requests if num_requests else 0
    print(
        f"\nCreated {len(dataset)} {image_content} {image_format} images with average {avg_image_bytes} bytes per request"
    )
    return dataset
+ fixed_output_len: If provided, use this fixed output length for all requests. + random_sample: Whether to randomly sample or take the first N. + + Returns: + List of tuples (prompt, prompt_token_len, output_token_len). + """ + print("Loading MMMU dataset from HuggingFace...") + + try: + print("Attempting to load MMMU Math dataset...") + mmmu_dataset = load_dataset("MMMU/MMMU", "Math", split="test") + print( + f"Successfully loaded MMMU Math dataset from HuggingFace with {len(mmmu_dataset)} examples" + ) + except Exception as e: + print(f"Failed to load MMMU Math dataset: {e}") + raise ValueError(f"Failed to load MMMU dataset: {e}") + + # Sample from the dataset + if len(mmmu_dataset) > num_requests: + if random_sample: + # Random sample + indices = random.sample(range(len(mmmu_dataset)), num_requests) + sample_dataset = mmmu_dataset.select(indices) + else: + # Take first N + sample_dataset = mmmu_dataset.select( + range(min(num_requests, len(mmmu_dataset))) + ) + else: + print(f"Dataset has less than {num_requests} examples, using all examples") + sample_dataset = mmmu_dataset + + print(f"Selected {len(sample_dataset)} examples for benchmarking") + + # Create prompts + filtered_dataset = [] + + for i, example in enumerate(sample_dataset): + try: + # Extract image_1 + image = example.get("image_1") + + if image is not None: + if hasattr(image, "save"): + # Convert RGBA images to RGB before encoding + if image.mode == "RGBA": + image = image.convert("RGB") + + # Encode image to base64 (save as PNG to support palette/alpha modes) + buffered = io.BytesIO() + image.save(buffered, format="PNG") + img_str = pybase64.b64encode(buffered.getvalue()).decode("utf-8") + image_data = f"data:image/png;base64,{img_str}" + else: + continue + + # Extract the question + question = example.get("question") + + # Construct the prompt + text_prompt = f"Question: {question}\n\nAnswer: " + output_len = fixed_output_len if fixed_output_len is not None else 256 + data_row = 
async def get_mooncake_request_over_time(
    input_requests: List[Dict],
    tokenizer: PreTrainedTokenizerBase,
    slowdown_factor: float,
    num_rounds: int,
) -> AsyncGenerator[DatasetRow, None]:
    """
    An async generator that yields requests based on the timestamps in the Mooncake trace file,
    with support for multi-round sessions.

    Records are replayed in timestamp order. Before each record the generator
    sleeps until ``(timestamp - first_timestamp) / 1000 * slowdown_factor``
    seconds have elapsed since the generator started, then yields
    ``num_rounds`` rows for that record back to back.

    Args:
        input_requests: Trace records; each needs a "timestamp" (milliseconds)
            and may carry "hash_ids" and "output_length". The list is not modified.
        tokenizer: Used to render the chat history and count prompt tokens.
        slowdown_factor: Multiplier applied to the trace's relative arrival times.
        num_rounds: Number of conversation rounds generated per record.

    Yields:
        One DatasetRow per round, whose prompt is the full chat history so far.
    """
    if not input_requests:
        return

    # Work on a sorted copy so the caller's list keeps its original order.
    ordered_requests = sorted(input_requests, key=lambda r: r["timestamp"])

    start_time = time.perf_counter()
    trace_start_time_ms = ordered_requests[0]["timestamp"]

    # Filler text appended after every hash id (shorter for multi-round).
    filler = " ".join(["hi"] * 128)

    for record in ordered_requests:
        # Calculate when this entire session should start
        relative_arrival_time_s = (record["timestamp"] - trace_start_time_ms) / 1000.0
        target_arrival_time_s = relative_arrival_time_s * slowdown_factor

        current_elapsed_time_s = time.perf_counter() - start_time
        sleep_duration_s = target_arrival_time_s - current_elapsed_time_s
        if sleep_duration_s > 0:
            await asyncio.sleep(sleep_duration_s)

        # Once the session starts, generate all rounds for it as a burst
        # This simulates a user engaging in a multi-turn conversation

        # Base user query constructed from hash_ids
        hash_ids = record.get("hash_ids", [])
        user_query_base = "".join(f"{hash_id}" + filler for hash_id in hash_ids)
        user_query_base += "Tell me a story based on this context."

        output_len_per_round = record.get("output_length", 256)
        # Placeholder assistant reply for the next round's context; the real
        # response is unknown. Identical for every round, so build it once.
        placeholder_response = " ".join(["story"] * output_len_per_round)
        chat_history = []

        for i in range(num_rounds):
            # Add user query for the current round
            chat_history.append(
                {"role": "user", "content": f"Round {i + 1}: {user_query_base}"}
            )

            # Form the full prompt from history
            try:
                full_prompt_text = tokenizer.apply_chat_template(
                    chat_history,
                    tokenize=False,
                    add_generation_prompt=True,
                    return_dict=False,
                )
            except Exception:
                # Tokenizers without a usable chat template get a plain transcript.
                full_prompt_text = "\n".join(
                    [f"{msg['role']}: {msg['content']}" for msg in chat_history]
                )

            prompt_len = len(tokenizer.encode(full_prompt_text))

            yield DatasetRow(
                prompt=full_prompt_text,
                prompt_len=prompt_len,
                output_len=output_len_per_round,
            )

            chat_history.append({"role": "assistant", "content": placeholder_response})
def sample_openai_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int] = None,
) -> List[DatasetRow]:
    """
    Load OpenAI-compatible chat completion requests from a JSONL file.

    Each line should be a JSON object with:
    - "messages": list of {"role": str, "content": str}
    - "max_tokens" (or "max_completion_tokens"): int, used as output_len if
      fixed_output_len is not set; 256 is used when neither is present
    - "tools": optional list of tool definitions
    - "temperature": optional temperature value
    - "top_p": optional top_p value
    - Other OpenAI API parameters are also extracted and passed through

    Args:
        dataset_path: Path to the JSONL file.
        num_requests: Maximum number of records to read; values <= 0 read the whole file.
        tokenizer: Tokenizer used to estimate the prompt length.
        fixed_output_len: If truthy, overrides the per-request output length.

    Returns:
        One DatasetRow per record that has a non-empty "messages" list. The
        row's prompt is the messages list itself.
    """
    records = []
    with open(dataset_path, "r", encoding="utf-8") as f:
        for line in f:
            if num_requests > 0 and len(records) >= num_requests:
                break
            if not line.strip():
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                # Skip invalid JSON lines
                continue
            # Valid JSON that is not an object cannot describe a request.
            if isinstance(record, dict):
                records.append(record)

    # Fields that should NOT be passed through extra_request_body
    # These are either handled separately or are metadata
    # max_tokens is excluded because it's handled via output_len -> max_completion_tokens
    # max_completion_tokens is also excluded to avoid conflicts
    EXCLUDED_FIELDS = {"messages", "max_tokens", "max_completion_tokens", "model"}

    filtered_dataset: List[DatasetRow] = []
    for data in records:
        messages = data.get("messages", [])
        if not messages:
            continue

        # fixed_output_len wins; otherwise use the request's own limit. A
        # missing or null limit falls through to the next candidate, so
        # output_len is never None.
        output_len = (
            fixed_output_len
            or data.get("max_tokens")
            or data.get("max_completion_tokens")
            or 256
        )

        # Extract extra request body parameters (tools, temperature, top_p, etc.)
        extra_body = {k: v for k, v in data.items() if k not in EXCLUDED_FIELDS}

        # Calculate prompt length by applying chat template
        # This includes the messages but not the tools.
        # return_dict=False (as in the other chat-template calls in this file)
        # so the result is the token id list rather than a mapping.
        prompt_len = len(
            tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_dict=False,
            )
        )

        # If tools are present, we need to add their token count
        # Tools are sent as part of the request and count toward input tokens
        if "tools" in extra_body:
            # Encode tools as JSON string to estimate token count
            tools_str = json.dumps(extra_body["tools"])
            prompt_len += len(tokenizer.encode(tools_str))

        # Pass messages list directly - bench_serving handles List[Dict] prompts
        filtered_dataset.append(
            DatasetRow(
                prompt=messages,
                prompt_len=prompt_len,
                output_len=output_len,
                extra_request_body=extra_body,  # Store per-request parameters
            )
        )

    print(f"Loaded {len(filtered_dataset)} OpenAI-format requests")
    print(f"#Input tokens: {np.sum([x.prompt_len for x in filtered_dataset])}")
    print(f"#Output tokens: {np.sum([x.output_len for x in filtered_dataset])}")
    return filtered_dataset
def sample_random_requests(
    input_len: int,
    output_len: int,
    num_prompts: int,
    range_ratio: float,
    tokenizer: PreTrainedTokenizerBase,
    dataset_path: str,
    random_sample: bool = True,
    return_text: bool = True,
) -> List[DatasetRow]:
    """Build requests with randomly sampled input and output lengths.

    Args:
        input_len / output_len / range_ratio: Passed to ``compute_random_lens``
            to draw one input and one output length per request.
        num_prompts: Number of requests to generate.
        tokenizer: Tokenizer used to encode source text and decode token ids.
        dataset_path: ShareGPT-style JSON file; downloaded when it is not a
            valid JSON file. Only used when ``random_sample`` is True.
        random_sample: If True, token ids come from shuffled ShareGPT prompts,
            repeated or truncated to the target length. If False, token ids
            are consecutive integers (mod vocab size) from a random offset.
        return_text: If True, prompts are decoded to text and the target input
            length is reduced by the tokenizer's special-token count; if
            False, prompts are lists of token ids.

    Returns:
        The generated rows. With ``random_sample`` this can be fewer than
        ``num_prompts`` when the dataset has too few non-empty prompts.
    """
    input_lens = compute_random_lens(
        full_len=input_len,
        range_ratio=range_ratio,
        num=num_prompts,
    )
    output_lens = compute_random_lens(
        full_len=output_len,
        range_ratio=range_ratio,
        num=num_prompts,
    )

    if return_text:
        # Need to truncate input_len as server encode will add special token.
        num_special_tokens = int(tokenizer.num_special_tokens_to_add())
        for i in range(num_prompts):
            input_lens[i] = max(0, input_lens[i] - num_special_tokens)

    input_requests: List[DatasetRow] = []

    if random_sample:
        # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens

        # Download sharegpt if necessary
        if not is_file_valid_json(dataset_path):
            dataset_path = download_and_cache_hf_file(
                repo_id=SHAREGPT_REPO_ID,
                filename=SHAREGPT_FILENAME,
            )

        # Load the dataset.
        with open(dataset_path, encoding="utf-8") as f:
            dataset = json.load(f)

        # Keep the first two turns of every conversation with at least 2 turns.
        turn_pairs = []
        for data in dataset:
            convs = data.get("conversations", data.get("conversation", []))
            if len(convs) >= 2:
                turn_pairs.append((convs[0]["value"], convs[1]["value"]))

        # Shuffle the dataset.
        random.shuffle(turn_pairs)

        for prompt, _completion in turn_pairs:
            i = len(input_requests)
            if i == num_prompts:
                break

            prompt_token_ids = tokenizer.encode(prompt)
            prompt_len = len(prompt_token_ids)

            # Skip empty prompt
            if prompt_len == 0:
                continue

            # Plain int so numpy integers never reach DatasetRow.
            target_len = int(input_lens[i])
            if prompt_len > target_len:
                input_ids = prompt_token_ids[:target_len]
            else:
                # Repeat the prompt enough times to reach the target, then cut.
                repeats = (target_len + prompt_len - 1) // prompt_len
                input_ids = (prompt_token_ids * repeats)[:target_len]

            input_content = tokenizer.decode(input_ids) if return_text else input_ids
            input_requests.append(
                DatasetRow(
                    prompt=input_content,
                    prompt_len=target_len,
                    output_len=int(output_lens[i]),
                )
            )
    else:
        # Sample token ids from random integers. This can cause some NaN issues.
        vocab_size = tokenizer.vocab_size
        offsets = np.random.randint(0, vocab_size, size=num_prompts)
        for i in range(num_prompts):
            target_len = int(input_lens[i])
            # Use int() to convert numpy.int64 to native Python int for JSON serialization
            input_content = [
                int((offsets[i] + i + j) % vocab_size) for j in range(target_len)
            ]
            if return_text:
                input_content = tokenizer.decode(input_content)
            input_requests.append(
                DatasetRow(
                    prompt=input_content,
                    prompt_len=target_len,
                    output_len=int(output_lens[i]),
                )
            )

    # Report totals for the rows actually produced; with a short dataset this
    # is fewer than the num_prompts lengths that were sampled.
    print(f"#Input tokens: {sum(row.prompt_len for row in input_requests)}")
    print(f"#Output tokens: {sum(row.output_len for row in input_requests)}")
    return input_requests
+ def load( + self, tokenizer: PreTrainedTokenizerBase, model_id=None + ) -> List[DatasetRow]: + return sample_sharegpt_requests( + dataset_path=self.dataset_path, + num_requests=self.num_requests, + tokenizer=tokenizer, + fixed_output_len=self.fixed_output_len, + context_len=self.context_len, + prompt_suffix=self.prompt_suffix, + apply_chat_template=self.apply_chat_template, + ) + + +def sample_sharegpt_requests( + dataset_path: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + fixed_output_len: Optional[int] = None, + context_len: Optional[int] = None, + prompt_suffix: Optional[str] = "", + apply_chat_template=False, +) -> List[DatasetRow]: + if fixed_output_len is not None and fixed_output_len < 4: + raise ValueError("output_len too small") + + # Download sharegpt if necessary + if not is_file_valid_json(dataset_path) and dataset_path == "": + dataset_path = download_and_cache_hf_file( + repo_id=SHAREGPT_REPO_ID, + filename=SHAREGPT_FILENAME, + ) + + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + + # Filter out the conversations with less than 2 turns. + dataset = [ + data + for data in dataset + if len(data.get("conversations", data.get("conversation", []))) >= 2 + ] + # Only keep the first two turns of each conversation. + dataset = [ + ( + data.get("conversations", data.get("conversation", []))[0]["value"], + data.get("conversations", data.get("conversation", []))[1]["value"], + ) + for data in dataset + ] + + # Shuffle the dataset. + random.shuffle(dataset) + + # Filter out sequences that are too long or too short + filtered_dataset: List[DatasetRow] = [] + for i in range(len(dataset)): + if len(filtered_dataset) == num_requests: + break + + # Tokenize the prompts and completions. 
+ prompt = dataset[i][0] + if prompt_suffix: + prompt = ( + remove_suffix(prompt, ASSISTANT_SUFFIX) + + prompt_suffix + + ASSISTANT_SUFFIX + ) + + if apply_chat_template: + prompt = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + add_generation_prompt=True, + tokenize=False, + return_dict=False, + ) + if tokenizer.bos_token: + prompt = prompt.replace(tokenizer.bos_token, "") + + prompt_token_ids = tokenizer.encode(prompt) + completion = dataset[i][1] + completion_token_ids = tokenizer.encode(completion) + prompt_len = len(prompt_token_ids) + output_len = ( + len(completion_token_ids) if fixed_output_len is None else fixed_output_len + ) + + if prompt_len < 2 or output_len < 2: + # Prune too short sequences. + continue + + if context_len and prompt_len + output_len > context_len: + # Prune too long sequences. + continue + + filtered_dataset.append( + DatasetRow( + prompt=prompt, + prompt_len=prompt_len, + output_len=output_len, + ) + ) + + print(f"#Input tokens: {np.sum([x.prompt_len for x in filtered_dataset])}") + print(f"#Output tokens: {np.sum([x.output_len for x in filtered_dataset])}") + return filtered_dataset diff --git a/sglang/python/sglang/benchmark/utils.py b/sglang/python/sglang/benchmark/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7bf6494b5df13bbaeec267ccd8e3777a5f2df22a --- /dev/null +++ b/sglang/python/sglang/benchmark/utils.py @@ -0,0 +1,159 @@ +import json +import os +import resource +from json import JSONDecodeError +from typing import Dict, List, Optional, Union + +import requests +from tqdm.asyncio import tqdm +from transformers import ( + AutoProcessor, + AutoTokenizer, + PreTrainedTokenizer, + PreTrainedTokenizerFast, +) + + +def remove_prefix(text: str, prefix: str) -> str: + return text[len(prefix) :] if text.startswith(prefix) else text + + +def remove_suffix(text: str, suffix: str) -> str: + return text[: -len(suffix)] if text.endswith(suffix) else text + + +def 
parse_custom_headers(header_list: List[str]) -> Dict[str, str]: + return {k: v for h in header_list for k, _, v in [h.partition("=")] if k and v} + + +def get_model(pretrained_model_name_or_path: str) -> str: + if os.getenv("SGLANG_USE_MODELSCOPE", "false").lower() == "true": + import huggingface_hub.constants + from modelscope import snapshot_download + + model_path = snapshot_download( + model_id=pretrained_model_name_or_path, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"], + ) + + return model_path + return pretrained_model_name_or_path + + +def get_tokenizer( + pretrained_model_name_or_path: str, +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + assert ( + pretrained_model_name_or_path is not None + and pretrained_model_name_or_path != "" + ) + if pretrained_model_name_or_path.endswith( + ".json" + ) or pretrained_model_name_or_path.endswith(".model"): + from sglang.srt.utils.hf_transformers_utils import get_tokenizer + + return get_tokenizer(pretrained_model_name_or_path) + + if pretrained_model_name_or_path is not None and not os.path.exists( + pretrained_model_name_or_path + ): + pretrained_model_name_or_path = get_model(pretrained_model_name_or_path) + return AutoTokenizer.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=True + ) + + +def get_processor( + pretrained_model_name_or_path: str, +) -> AutoProcessor: + assert ( + pretrained_model_name_or_path is not None + and pretrained_model_name_or_path != "" + ) + if pretrained_model_name_or_path.endswith( + ".json" + ) or pretrained_model_name_or_path.endswith(".model"): + from sglang.srt.utils.hf_transformers_utils import get_processor + + return get_processor(pretrained_model_name_or_path) + + if pretrained_model_name_or_path is not None and not os.path.exists( + pretrained_model_name_or_path + ): + pretrained_model_name_or_path = get_model(pretrained_model_name_or_path) + return 
AutoProcessor.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=True + ) + + +def download_and_cache_hf_file( + repo_id: str, + filename: str, + repo_type: str = "dataset", +): + """Download a file from Hugging Face and cache it locally.""" + from huggingface_hub import hf_hub_download + + return hf_hub_download(repo_id=repo_id, filename=filename, repo_type=repo_type) + + +def download_and_cache_file(url: str, filename: Optional[str] = None): + """Read and cache a file from a url.""" + if filename is None: + filename = os.path.join("/tmp", url.split("/")[-1]) + + # Check if the cache file already exists + if is_file_valid_json(filename): + return filename + + print(f"Downloading from {url} to {filename}") + + # Stream the response to show the progress bar + response = requests.get(url, stream=True) + response.raise_for_status() # Check for request errors + + # Total size of the file in bytes + total_size = int(response.headers.get("content-length", 0)) + chunk_size = 1024 # Download in chunks of 1KB + + # Use tqdm to display the progress bar + with open(filename, "wb") as f, tqdm( + desc=filename, + total=total_size, + unit="B", + unit_scale=True, + unit_divisor=1024, + ) as bar: + for chunk in response.iter_content(chunk_size=chunk_size): + f.write(chunk) + bar.update(len(chunk)) + + return filename + + +def is_file_valid_json(path): + if not os.path.isfile(path): + return False + + # TODO can fuse into the real file open later + try: + with open(path) as f: + json.load(f) + return True + except JSONDecodeError as e: + print( + f"{path} exists but json loading fails ({e=}), thus treat as invalid file" + ) + return False + + +def set_ulimit(target_soft_limit=65535): + resource_type = resource.RLIMIT_NOFILE + current_soft, current_hard = resource.getrlimit(resource_type) + + if current_soft < target_soft_limit: + try: + resource.setrlimit(resource_type, (target_soft_limit, current_hard)) + except ValueError as e: + print(f"Fail to set 
RLIMIT_NOFILE: {e}") diff --git a/sglang/python/sglang/check_env.py b/sglang/python/sglang/check_env.py new file mode 100644 index 0000000000000000000000000000000000000000..8a312c560990caa3f7d10f685502d98121cbe43c --- /dev/null +++ b/sglang/python/sglang/check_env.py @@ -0,0 +1,525 @@ +"""Check environment configurations and dependency versions.""" + +import importlib.metadata +import os +import resource +import subprocess +import sys +from abc import abstractmethod +from collections import OrderedDict, defaultdict + +import torch + +from sglang.srt.utils import is_hip, is_musa, is_npu + + +def is_cuda_v2(): + return torch.version.cuda is not None + + +# List of packages to check versions +PACKAGE_LIST = [ + "sglang", + "sgl_kernel", + "flashinfer_python", + "flashinfer_cubin", + "flashinfer_jit_cache", + "triton", + "transformers", + "torchao", + "numpy", + "aiohttp", + "fastapi", + "hf_transfer", + "huggingface_hub", + "interegular", + "modelscope", + "orjson", + "outlines", + "packaging", + "psutil", + "pydantic", + "python-multipart", + "pyzmq", + "torchao", + "uvicorn", + "uvloop", + "vllm", + "xgrammar", + "openai", + "tiktoken", + "anthropic", + "litellm", + "decord2", +] + + +class BaseEnv: + """Base class for environment check""" + + def __init__(self): + self.package_list = PACKAGE_LIST + + @abstractmethod + def get_info(self) -> dict: + """ + Get CUDA-related information if available. + """ + raise NotImplementedError + + @abstractmethod + def get_topology(self) -> dict: + raise NotImplementedError + + def get_package_versions(self) -> dict: + """ + Get versions of specified packages. 
+ """ + versions = {} + for package in self.package_list: + package_name = package.split("==")[0].split(">=")[0].split("<=")[0] + try: + version = importlib.metadata.version(package_name) + versions[package_name] = version + except ModuleNotFoundError: + versions[package_name] = "Module Not Found" + return versions + + def get_device_info(self): + """ + Get information about available GPU devices. + """ + devices = defaultdict(list) + capabilities = defaultdict(list) + for k in range(torch.cuda.device_count()): + devices[torch.cuda.get_device_name(k)].append(str(k)) + capability = torch.cuda.get_device_capability(k) + capabilities[f"{capability[0]}.{capability[1]}"].append(str(k)) + + gpu_info = {} + for name, device_ids in devices.items(): + gpu_info[f"GPU {','.join(device_ids)}"] = name + + if len(capabilities) == 1: + # All GPUs have the same compute capability + cap, gpu_ids = list(capabilities.items())[0] + gpu_info[f"GPU {','.join(gpu_ids)} Compute Capability"] = cap + else: + # GPUs have different compute capabilities + for cap, gpu_ids in capabilities.items(): + gpu_info[f"GPU {','.join(gpu_ids)} Compute Capability"] = cap + + return gpu_info + + def get_hypervisor_vendor(self) -> dict: + try: + output = subprocess.check_output(["lscpu"], text=True) + for line in output.split("\n"): + if "Hypervisor vendor:" in line: + return {"Hypervisor vendor:": line.split(":")[1].strip()} + return {} + except: + return {} + + def get_ulimit_soft(self) -> dict: + ulimit_soft, _ = resource.getrlimit(resource.RLIMIT_NOFILE) + return {"ulimit soft": ulimit_soft} + + def check_env(self): + """ + Check and print environment information. 
+ """ + env_info = OrderedDict() + env_info["Python"] = sys.version.replace("\n", "") + env_info.update(self.get_info()) + env_info["PyTorch"] = torch.__version__ + env_info.update(self.get_package_versions()) + env_info.update(self.get_topology()) + env_info.update(self.get_hypervisor_vendor()) + env_info.update(self.get_ulimit_soft()) + + for k, v in env_info.items(): + print(f"{k}: {v}") + + +class GPUEnv(BaseEnv): + """Environment checker for Nvidia GPU""" + + def get_info(self): + cuda_info = {"CUDA available": torch.cuda.is_available()} + + if cuda_info["CUDA available"]: + cuda_info.update(self.get_device_info()) + cuda_info.update(self._get_cuda_version_info()) + + return cuda_info + + def _get_cuda_version_info(self): + """ + Get CUDA version information. + """ + from torch.utils.cpp_extension import CUDA_HOME + + cuda_info = {"CUDA_HOME": CUDA_HOME} + + if CUDA_HOME and os.path.isdir(CUDA_HOME): + cuda_info.update(self._get_nvcc_info()) + cuda_info.update(self._get_cuda_driver_version()) + + return cuda_info + + def _get_nvcc_info(self): + """ + Get NVCC version information. + """ + from torch.utils.cpp_extension import CUDA_HOME + + try: + nvcc = os.path.join(CUDA_HOME, "bin/nvcc") + nvcc_output = ( + subprocess.check_output(f'"{nvcc}" -V', shell=True) + .decode("utf-8") + .strip() + ) + return { + "NVCC": nvcc_output[ + nvcc_output.rfind("Cuda compilation tools") : nvcc_output.rfind( + "Build" + ) + ].strip() + } + except subprocess.SubprocessError: + return {"NVCC": "Not Available"} + + def _get_cuda_driver_version(self): + """ + Get CUDA driver version. 
+ """ + versions = set() + try: + output = subprocess.check_output( + [ + "nvidia-smi", + "--query-gpu=driver_version", + "--format=csv,noheader,nounits", + ] + ) + versions = set(output.decode().strip().split("\n")) + if len(versions) == 1: + return {"CUDA Driver Version": versions.pop()} + else: + return {"CUDA Driver Versions": ", ".join(sorted(versions))} + except subprocess.SubprocessError: + return {"CUDA Driver Version": "Not Available"} + + def get_topology(self): + """ + Get GPU topology information. + """ + try: + result = subprocess.run( + ["nvidia-smi", "topo", "-m"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + check=True, + ) + return { + "NVIDIA Topology": ( + "\n" + result.stdout if result.returncode == 0 else None + ) + } + except subprocess.SubprocessError: + return {} + + +class HIPEnv(BaseEnv): + """Environment checker for ROCm/HIP""" + + def get_info(self): + cuda_info = {"ROCM available": torch.cuda.is_available()} + + if cuda_info["ROCM available"]: + cuda_info.update(self.get_device_info()) + cuda_info.update(self._get_cuda_version_info()) + + return cuda_info + + def _get_cuda_version_info(self): + from torch.utils.cpp_extension import ROCM_HOME as ROCM_HOME + + cuda_info = {"ROCM_HOME": ROCM_HOME} + + if ROCM_HOME and os.path.isdir(ROCM_HOME): + cuda_info.update(self._get_hipcc_info()) + cuda_info.update(self._get_rocm_driver_version()) + + return cuda_info + + def _get_hipcc_info(self): + from torch.utils.cpp_extension import ROCM_HOME + + try: + hipcc = os.path.join(ROCM_HOME, "bin/hipcc") + hipcc_output = ( + subprocess.check_output(f'"{hipcc}" --version', shell=True) + .decode("utf-8") + .strip() + ) + return { + "HIPCC": hipcc_output[ + hipcc_output.rfind("HIP version") : hipcc_output.rfind("AMD clang") + ].strip() + } + except subprocess.SubprocessError: + return {"HIPCC": "Not Available"} + + def _get_rocm_driver_version(self): + try: + output = subprocess.check_output( + [ + "rocm-smi", + 
"--showdriverversion", + "--csv", + ] + ) + versions = set(output.decode().strip().split("\n")) + versions.discard("name, value") + ver = versions.pop() + ver = ver.replace('"Driver version", ', "").replace('"', "") + + return {"ROCM Driver Version": ver} + except subprocess.SubprocessError: + return {"ROCM Driver Version": "Not Available"} + + def get_topology(self): + try: + result = subprocess.run( + ["rocm-smi", "--showtopotype"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + check=True, + ) + return { + "AMD Topology": "\n" + result.stdout if result.returncode == 0 else None + } + except subprocess.SubprocessError: + return {} + + +class NPUEnv(BaseEnv): + """Environment checker for Ascend NPU""" + + EXTRA_PACKAGE_LIST = [ + "torch_npu", + "sgl-kernel-npu", + "deep_ep", + ] + + def __init__(self): + super().__init__() + self.package_list.extend(NPUEnv.EXTRA_PACKAGE_LIST) + + def get_info(self): + cuda_info = {"NPU available": torch.npu.is_available()} + if cuda_info["NPU available"]: + cuda_info.update(self.get_device_info()) + cuda_info.update(self._get_cann_version_info()) + + return cuda_info + + def get_device_info(self): + """ + Get information about available NPUs. + Need to override due to torch_npu interface differences. 
+ """ + devices = defaultdict(list) + for k in range(torch.npu.device_count()): + devices[torch.npu.get_device_name(k)].append(str(k)) + + npu_info = {} + for name, device_ids in devices.items(): + npu_info[f"NPU {','.join(device_ids)}"] = name + + return npu_info + + def _get_cann_version_info(self): + cann_envs = ["ASCEND_TOOLKIT_HOME", "ASCEND_INSTALL_PATH"] + for var in cann_envs: + path = os.environ.get(var) + if path and os.path.exists(path): + CANN_HOME = path + break + else: + default_path = "/usr/local/Ascend/ascend-toolkit/latest" + CANN_HOME = default_path if os.path.exists(default_path) else None + + if CANN_HOME: + npu_info = {"CANN_HOME": CANN_HOME} + npu_info.update(self._get_cann_info(CANN_HOME)) + npu_info.update(self._get_ascend_driver_version()) + return npu_info + else: + return {"CANN_HOME": "Not found"} + + def _get_cann_info(self, CANN_HOME: str): + cann_info = {} + cann_version_file = os.path.join(CANN_HOME, "version.cfg") + if os.path.exists(cann_version_file): + with open(cann_version_file, "r", encoding="utf-8") as f: + f.readline() # discard first line comment in version.cfg + cann_info["CANN"] = f.readline().split("[")[1].split("]")[0] + else: + cann_info["CANN"] = "Not Available" + try: + bisheng = os.path.join(CANN_HOME, "compiler/ccec_compiler/bin/bisheng") + bisheng_output = ( + subprocess.check_output([bisheng, "--version"]).decode("utf-8").strip() + ) + cann_info["BiSheng"] = bisheng_output.split("\n")[0].strip() + except subprocess.SubprocessError: + cann_info["BiSheng"] = "Not Available" + return cann_info + + def _get_ascend_driver_version(self): + try: + output = subprocess.check_output( + [ + "npu-smi", + "info", + "-t", + "board", + "-i", + "0", + ] + ) + for line in output.decode().strip().split("\n"): + if "Software Version" in line: + version = line.split(":")[-1].strip() + break + else: + version = "Not Available" + + return {"Ascend Driver Version": version} + except subprocess.SubprocessError: + return {"Ascend Driver 
Version": "Not Available"} + + def get_topology(self): + try: + result = subprocess.run( + ["npu-smi", "info", "-t", "topo"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + check=True, + ) + return { + "Ascend Topology": ( + "\n" + result.stdout if result.returncode == 0 else None + ) + } + except subprocess.SubprocessError: + return {} + + +class MUSAEnv(BaseEnv): + """Environment checker for MThreads GPU""" + + def get_info(self): + musa_info = {"MUSA available": torch.musa.is_available()} + + if musa_info["MUSA available"]: + musa_info.update(self.get_device_info()) + musa_info.update(self._get_musa_version_info()) + + return musa_info + + def _get_musa_version_info(self): + """ + Get MUSA version information. + """ + from torch_musa.utils.musa_extension import MUSA_HOME + + musa_info = {"MUSA_HOME": MUSA_HOME} + + if MUSA_HOME and os.path.isdir(MUSA_HOME): + musa_info.update(self._get_mcc_info()) + musa_info.update(self._get_musa_driver_version()) + + return musa_info + + def _get_mcc_info(self): + """ + Get MCC version information. + """ + from torch_musa.utils.musa_extension import MUSA_HOME + + try: + mcc = os.path.join(MUSA_HOME, "bin/mcc") + mcc_output = ( + subprocess.check_output(f'"{mcc}" --version', shell=True) + .decode("utf-8") + .strip() + ) + return { + "MCC": mcc_output[ + mcc_output.rfind("mcc version") : mcc_output.rfind("Target") + ].strip() + } + except subprocess.SubprocessError: + return {"MCC": "Not Available"} + + def _get_musa_driver_version(self): + """ + Get MUSA driver version. 
+ """ + try: + output = subprocess.check_output( + [ + "mthreads-gmi", + "-q", + ], + text=True, + ) + driver_version = None + for line in output.splitlines(): + if "Driver Version" in line: + driver_version = line.split(":", 1)[1].strip() + break + + return {"MUSA Driver Version": driver_version} + except subprocess.SubprocessError: + return {"MUSA Driver Version": "Not Available"} + + def get_topology(self): + """ + Get GPU topology information. + """ + try: + result = subprocess.run( + ["mthreads-gmi", "topo", "-m"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + check=True, + ) + return { + "MTHREADS Topology": ( + "\n" + result.stdout if result.returncode == 0 else None + ) + } + except subprocess.SubprocessError: + return {} + + +if __name__ == "__main__": + if is_cuda_v2(): + env = GPUEnv() + elif is_hip(): + env = HIPEnv() + elif is_npu(): + env = NPUEnv() + elif is_musa(): + env = MUSAEnv() + env.check_env() diff --git a/sglang/python/sglang/cli/__init__.py b/sglang/python/sglang/cli/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/sglang/python/sglang/cli/generate.py b/sglang/python/sglang/cli/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..56ae1ddfe3564effe93673a02fef9c0c36efec31 --- /dev/null +++ b/sglang/python/sglang/cli/generate.py @@ -0,0 +1,33 @@ +import argparse + +from sglang.cli.utils import get_is_diffusion_model, get_model_path + + +def generate(args, extra_argv): + # If help is requested, show generate subcommand help without requiring --model-path + if any(h in extra_argv for h in ("-h", "--help")): + from sglang.multimodal_gen.runtime.entrypoints.cli.generate import ( + add_multimodal_gen_generate_args, + ) + + parser = argparse.ArgumentParser(description="SGLang Multimodal Generation") + add_multimodal_gen_generate_args(parser) + parser.parse_args(extra_argv) + return + + model_path = 
get_model_path(extra_argv) + is_diffusion_model = get_is_diffusion_model(model_path) + if is_diffusion_model: + from sglang.multimodal_gen.runtime.entrypoints.cli.generate import ( + add_multimodal_gen_generate_args, + generate_cmd, + ) + + parser = argparse.ArgumentParser(description="SGLang Multimodal Generation") + add_multimodal_gen_generate_args(parser) + parsed_args, unknown_args = parser.parse_known_args(extra_argv) + generate_cmd(parsed_args, unknown_args) + else: + raise Exception( + f"Generate subcommand is not yet supported for model: {model_path}" + ) diff --git a/sglang/python/sglang/cli/main.py b/sglang/python/sglang/cli/main.py new file mode 100644 index 0000000000000000000000000000000000000000..36ed5ae5ea9a0e821c84fe3aeeabd130067bc5b7 --- /dev/null +++ b/sglang/python/sglang/cli/main.py @@ -0,0 +1,48 @@ +import argparse + +from sglang.cli.utils import get_git_commit_hash +from sglang.version import __version__ + + +def version(args, extra_argv): + print(f"sglang version: {__version__}") + print(f"git revision: {get_git_commit_hash()[:7]}") + + +def main(): + parser = argparse.ArgumentParser() + + # complex sub commands + subparsers = parser.add_subparsers(dest="subcommand", required=True) + + subparsers.add_parser( + "serve", + help="Launch the SGLang server.", + add_help=False, + ) + + subparsers.add_parser( + "generate", + help="Run inference on a multimodal model.", + add_help=False, + ) + + # simple commands + version_parser = subparsers.add_parser( + "version", + help="Show the version information.", + ) + version_parser.set_defaults(func=version) + + args, extra_argv = parser.parse_known_args() + + if args.subcommand == "serve": + from sglang.cli.serve import serve + + serve(args, extra_argv) + elif args.subcommand == "generate": + from sglang.cli.generate import generate + + generate(args, extra_argv) + elif args.subcommand == "version": + version(args, extra_argv) diff --git a/sglang/python/sglang/cli/serve.py 
b/sglang/python/sglang/cli/serve.py new file mode 100644 index 0000000000000000000000000000000000000000..03dd6a42298da081f46047ed4c5956c9609d9cb2 --- /dev/null +++ b/sglang/python/sglang/cli/serve.py @@ -0,0 +1,118 @@ +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import logging +import os + +from sglang.cli.utils import get_is_diffusion_model, get_model_path +from sglang.srt.utils import kill_process_tree + +logger = logging.getLogger(__name__) + + +def _extract_model_type_override(extra_argv): + """Extract and remove --model-type override from argv.""" + model_type = "auto" + filtered_argv = [] + i = 0 + while i < len(extra_argv): + arg = extra_argv[i] + if arg == "--model-type": + if i + 1 >= len(extra_argv): + raise Exception( + "Error: --model-type requires a value. " + "Valid values are: auto, llm, diffusion." + ) + model_type = extra_argv[i + 1] + i += 2 + continue + + if arg.startswith("--model-type="): + model_type = arg.split("=", 1)[1] + i += 1 + continue + + filtered_argv.append(arg) + i += 1 + + if model_type not in ("auto", "llm", "diffusion"): + raise Exception( + f"Error: invalid --model-type '{model_type}'. " + "Valid values are: auto, llm, diffusion." + ) + return model_type, filtered_argv + + +def serve(args, extra_argv): + if any(h in extra_argv for h in ("-h", "--help")): + # Since the server type is determined by the model, and we don't have a model path, + # we can't show the exact help. Instead, we show a general help message and then + # the help for both possible server types. + print( + "Usage: sglang serve --model-path [additional-arguments]\n" + ) + print( + "This command can launch either a standard language model server or a diffusion model server." 
+ ) + print("The server type is determined by the model path.\n") + print( + "Optional override: --model-type {auto,llm,diffusion} " + "(default: auto, fallback to LLM on detection failure).\n" + ) + print("For specific arguments, please provide a model_path.") + print("\n--- Help for Standard Language Model Server ---") + from sglang.srt.server_args import prepare_server_args + + try: + prepare_server_args(["--help"]) + except SystemExit: + pass # argparse --help calls sys.exit + + print("\n--- Help for Diffusion Model Server ---") + from sglang.multimodal_gen.runtime.entrypoints.cli.serve import ( + add_multimodal_gen_serve_args, + ) + + parser = argparse.ArgumentParser(description="SGLang Diffusion Model Serving") + add_multimodal_gen_serve_args(parser) + parser.print_help() + return + + model_type, dispatch_argv = _extract_model_type_override(extra_argv) + model_path = get_model_path(dispatch_argv) + try: + if model_type == "auto": + is_diffusion_model = get_is_diffusion_model(model_path) + if is_diffusion_model: + logger.info("Diffusion model detected") + else: + is_diffusion_model = model_type == "diffusion" + logger.info( + "Dispatch override enabled: --model-type=%s " "(skip auto detection)", + model_type, + ) + + if is_diffusion_model: + # Logic for Diffusion Models + from sglang.multimodal_gen.runtime.entrypoints.cli.serve import ( + add_multimodal_gen_serve_args, + execute_serve_cmd, + ) + + parser = argparse.ArgumentParser( + description="SGLang Diffusion Model Serving" + ) + add_multimodal_gen_serve_args(parser) + parsed_args, remaining_argv = parser.parse_known_args(dispatch_argv) + + execute_serve_cmd(parsed_args, remaining_argv) + else: + # Logic for Standard Language Models + from sglang.launch_server import run_server + from sglang.srt.server_args import prepare_server_args + + server_args = prepare_server_args(dispatch_argv) + + run_server(server_args) + finally: + kill_process_tree(os.getpid(), include_parent=False) diff --git 
a/sglang/python/sglang/cli/utils.py b/sglang/python/sglang/cli/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e3e4d20327ecc46bd1290b80eaa83d654f04be51 --- /dev/null +++ b/sglang/python/sglang/cli/utils.py @@ -0,0 +1,110 @@ +import json +import logging +import os +import subprocess +from functools import lru_cache + +from sglang.srt.environ import envs + +logger = logging.getLogger(__name__) + + +def _is_diffusers_model_dir(model_dir: str) -> bool: + """Check if a local directory contains a valid diffusers model_index.json.""" + config_path = os.path.join(model_dir, "model_index.json") + if not os.path.exists(config_path): + return False + + with open(config_path) as f: + config = json.load(f) + + return "_diffusers_version" in config + + +def get_is_diffusion_model(model_path: str) -> bool: + """Detect whether model_path points to a diffusion model. + + For local directories, checks the filesystem directly. + For HF/ModelScope model IDs, attempts to fetch only model_index.json. + Returns False on any failure (network error, 404, offline mode, etc.) + so that the caller falls through to the standard LLM server path. 
+ """ + try: + from sglang.multimodal_gen.registry import ( + is_known_non_diffusers_multimodal_model, + ) + except ImportError: + is_known_non_diffusers_multimodal_model = lambda _: False + + if os.path.isdir(model_path): + if _is_diffusers_model_dir(model_path): + return True + return is_known_non_diffusers_multimodal_model(model_path) + + if is_known_non_diffusers_multimodal_model(model_path): + return True + + try: + if envs.SGLANG_USE_MODELSCOPE.get(): + from modelscope import model_file_download + + file_path = model_file_download( + model_id=model_path, file_path="model_index.json" + ) + else: + from huggingface_hub import hf_hub_download + + file_path = hf_hub_download(repo_id=model_path, filename="model_index.json") + + return _is_diffusers_model_dir(os.path.dirname(file_path)) + except Exception as e: + logger.debug("Failed to auto-detect diffusion model for %s: %s", model_path, e) + return False + + +def get_model_path(extra_argv): + # Find the model_path argument + model_path = None + for i, arg in enumerate(extra_argv): + if arg == "--model-path": + if i + 1 < len(extra_argv): + model_path = extra_argv[i + 1] + break + elif arg.startswith("--model-path="): + model_path = arg.split("=", 1)[1] + break + + if model_path is None: + # Fallback for --help or other cases where model-path is not provided + if any(h in extra_argv for h in ["-h", "--help"]): + raise Exception( + "Usage: sglang serve --model-path [additional-arguments]\n\n" + "This command can launch either a standard language model server or a diffusion model server.\n" + "The server type is determined by the model path.\n" + "For specific arguments, please provide a model_path." + ) + else: + raise Exception( + "Error: --model-path is required. " + "Please provide the path to the model." 
@lru_cache(maxsize=1)
def get_git_commit_hash() -> str:
    """Return the git commit hash identifying this build.

    The ``SGLANG_GIT_COMMIT`` environment variable takes precedence when it is
    set to a non-empty value; otherwise ``git rev-parse HEAD`` is run in the
    current working directory.

    Returns:
        The commit hash, or ``"N/A"`` when git is unavailable, cannot be
        executed, or reports an error. The result is memoized by ``lru_cache``,
        so the lookup happens at most once per process.
    """
    commit_hash = os.environ.get("SGLANG_GIT_COMMIT")
    if commit_hash:
        return commit_hash

    try:
        return (
            subprocess.check_output(
                ["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL
            )
            .strip()
            .decode("utf-8")
        )
    except (subprocess.CalledProcessError, OSError):
        # OSError includes FileNotFoundError (git not installed) as well as
        # e.g. PermissionError when the git binary is not executable.
        return "N/A"
def launch_server_process_and_send_one_request(
    server_args: ServerArgs, compile_args: CompileArgs
):
    """Start the server in a child process and drive it through one request.

    The server is polled until it answers its readiness endpoint. On the rank-0
    node a single ``/generate`` request is then sent to sync with the other
    nodes; on every other node this function blocks until the server process
    exits.

    Args:
        server_args: Server configuration; ``host``, ``port``, ``node_rank`` and
            ``disaggregation_mode`` are read here.
        compile_args: Supplies ``timeout`` (seconds), used both for the
            readiness polling and for the non-rank-0 wait.

    Returns:
        The ``multiprocessing.Process`` running the server.

    Raises:
        RuntimeError: The server process exited before becoming ready, or the
            rank-0 sync request returned a non-200 status.
        TimeoutError: The server did not become ready within the timeout, or a
            non-rank-0 node waited longer than the timeout for the process to exit.
    """
    proc = multiprocessing.Process(target=launch_server_internal, args=(server_args,))
    proc.start()
    base_url = f"http://{server_args.host}:{server_args.port}"
    timeout = compile_args.timeout
    is_rank0 = server_args.node_rank == 0

    headers = {
        "Content-Type": "application/json; charset=utf-8",
    }
    # The /health http api is created by launch_dummy_health_check_server for
    # none-rank0 node.
    probe_url = f"{base_url}/v1/models" if is_rank0 else f"{base_url}/health"
    # Bound every readiness probe: without a timeout a single stuck request
    # could block past the overall deadline. requests raises a RequestException
    # subclass on timeout, so a slow probe is simply retried below.
    probe_timeout = 10

    start_time = time.perf_counter()
    while time.perf_counter() - start_time < timeout:
        # A dead server can never become ready; fail now instead of polling
        # until the (potentially hour-long) timeout expires.
        if not proc.is_alive():
            raise RuntimeError(
                f"Server process exited with code {proc.exitcode} "
                "before it became ready."
            )
        try:
            response = requests.get(probe_url, headers=headers, timeout=probe_timeout)
            if response.status_code == 200:
                # Rank-0 node send a request to sync with other node and then return.
                if is_rank0:
                    payload = {
                        "input_ids": [0, 1, 2, 3],
                        "sampling_params": {
                            "max_new_tokens": 8,
                            "temperature": 0,
                        },
                    }
                    # In PD mode, include fake bootstrap fields so workers don't assert
                    if server_args.disaggregation_mode != "null":
                        payload["bootstrap_host"] = FAKE_BOOTSTRAP_HOST
                        payload["bootstrap_room"] = 0

                    response = requests.post(
                        f"{base_url}/generate",
                        json=payload,
                        timeout=600,
                    )
                    if response.status_code != 200:
                        error = response.json()
                        raise RuntimeError(f"Sync request failed: {error}")
                # Other nodes should wait for the exit signal from Rank-0 node.
                else:
                    start_time_waiting = time.perf_counter()
                    while proc.is_alive():
                        if time.perf_counter() - start_time_waiting < timeout:
                            time.sleep(10)
                        else:
                            raise TimeoutError("Waiting for main node timeout!")
                return proc
        except requests.RequestException:
            # Server not reachable yet (or the request timed out): retry.
            pass
        time.sleep(10)
    raise TimeoutError(
        "DeepGEMM Kernels compilation timeout."
        "\n\nFeel free and please restart the command."
    )
+ kill_process_tree(proc.pid) + else: + try: + kill_process_tree(proc.pid) + except Exception: + pass + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + ServerArgs.add_cli_args(parser) + CompileArgs.add_cli_args(parser) + args = parser.parse_args() + server_args = ServerArgs.from_cli_args(args) + compile_args = CompileArgs.from_cli_args(args) + + refine_server_args(server_args, compile_args) + + run_compile(server_args, compile_args) diff --git a/sglang/python/sglang/eval/llama3_eval.py b/sglang/python/sglang/eval/llama3_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..4a3c736de8a9aea92ffdc2eba5bbf69a0836d3e7 --- /dev/null +++ b/sglang/python/sglang/eval/llama3_eval.py @@ -0,0 +1,315 @@ +# Adapt from https://github.com/fw-ai/llm_eval_meta + +import argparse +import asyncio +import os +import pickle +import re +import shutil +from collections import defaultdict +from dataclasses import dataclass + +import httpx +import numpy as np +import openai +from datasets import load_dataset +from openai import AsyncOpenAI +from tqdm import tqdm + +# Mapping providers to their clients and models +provider_to_models = { + "b10": { + "8b": "meta-llama/Llama-3.1-8B-Instruct", + "70b": "meta-llama/Llama-3.1-70B-Instruct", + "405b": "meta-llama/Llama-3.1-405B-Instruct", + }, + "oai": { + "8b": "meta-llama/Llama-3.1-8B-Instruct", + "70b": "meta-llama/Llama-3.1-70B-Instruct", + "405b": "meta-llama/Llama-3.1-405B-Instruct", + }, + "sgl": { + "8b": "meta-llama/Llama-3.1-8B-Instruct", + "70b": "meta-llama/Llama-3.1-70B-Instruct", + "405b": "meta-llama/Llama-3.1-405B-Instruct", + }, +} + + +async def fetch_responses( + client, prompt, semaphore, index, provider, model_size, output_dir, max_tokens +): + output_file = os.path.join(output_dir, f"response_{index}.pkl") + if os.path.exists(output_file): + print(f"File {output_file} already exists, skipping.") + return + + async with semaphore: + response = await client.completions.create( + 
def get_client(provider):
    """Create the async OpenAI-compatible client for ``provider``.

    Args:
        provider: One of ``"oai"``, ``"b10"`` or ``"sgl"``.

    Returns:
        An ``AsyncOpenAI`` client pointed at the provider's endpoint. Only the
        requested client is constructed.

    Raises:
        KeyError: ``provider`` is not one of the supported names.

    Side effects:
        For every provider other than ``"b10"``, ``OPENAI_API_KEY`` is set to
        ``"EMPTY"`` in the process environment when it is not already defined.
    """
    # Exact comparison: `provider not in "b10"` was a substring test and treated
    # values such as "b" or "10" like the Baseten provider.
    if provider != "b10":
        os.environ.setdefault("OPENAI_API_KEY", "EMPTY")

    # Factories instead of instances, so unused clients (and the custom HTTP
    # client used for "b10") are never created.
    client_factories = {
        "oai": lambda: AsyncOpenAI(base_url="http://127.0.0.1:8000/v1/"),
        "b10": lambda: AsyncOpenAI(
            api_key=f"Api-Key {os.getenv('OPENAI_API_KEY')}",
            base_url=f"https://model-{os.getenv('MODEL_ID')}.api.baseten.co/development/predict",
            http_client=CustomAsyncHTTPXClient(),
        ),
        "sgl": lambda: AsyncOpenAI(base_url="http://127.0.0.1:30000/v1/"),
    }
    return client_factories[provider]()
def get_mmlu_cot_answer(response):
    """Extract the final answer from a chain-of-thought MMLU completion.

    The text of the first choice is searched (case-sensitively) for each
    lead-in phrase below, in order; the first phrase that matches wins. The
    captured remainder of that line has every "." removed, and for the first
    phrase only, every "*" as well.

    Returns:
        The extracted answer string, or None when no phrase matches.
    """
    completion_text = response.choices[0].text

    # (lead-in phrase, whether asterisks are stripped from the captured answer)
    lead_ins = (
        ("The best answer is", True),
        ("the best answer is", False),
        ("The correct answer is", False),
        ("the correct answer is", False),
    )

    for phrase, drop_asterisks in lead_ins:
        found = re.search(phrase + r" (.+)\.?", completion_text)
        if found is None:
            continue
        answer = found.group(1).replace(".", "")
        if drop_asterisks:
            answer = answer.replace("*", "")
        return answer

    return None
def analyze(task, response_path, model_size):
    """Score cached responses against the reference dataset and print averages.

    For every dataset row ``i`` the pickled response ``response_{i}.pkl`` in
    ``response_path`` is loaded, the task's answer extractor is applied, and the
    result is compared with the row's ``input_correct_responses``. Rows without
    a response file (e.g. when only the first N examples were run) are skipped
    and reported, rather than aborting the analysis.

    Args:
        task: Eval-set key used to select the dataset and the answer extractor.
        response_path: Directory holding the ``response_{i}.pkl`` files.
        model_size: Model size string forwarded to the dataset lookup.

    Prints macro/micro accuracy for the evaluated model next to the reference
    ("Meta") accuracy taken from each row's ``is_correct`` field.
    """
    ds = get_dataset_from_task(task, response_path, model_size)

    @dataclass
    class Stats:
        correct: int = 0
        total: int = 0
        meta_correct: int = 0

        average: float = None
        # Declared explicitly; it used to be attached to instances ad hoc.
        meta_average: float = None

    subtask_name_to_stats = defaultdict(Stats)
    extract_answer = TASK_TO_ANSWER_EXTRACTOR[task]
    num_missing = 0

    for i, ds_row in enumerate(ds["latest"]):
        response_file = os.path.join(response_path, f"response_{i}.pkl")
        if not os.path.exists(response_file):
            num_missing += 1
            continue

        # NOTE: pickle executes arbitrary code on load; only point this at
        # response directories produced by this script.
        with open(response_file, "rb") as f:
            response = pickle.load(f)

        model_answer = extract_answer(response)
        stats = subtask_name_to_stats[ds_row["subtask_name"]]

        if model_answer in ds_row["input_correct_responses"]:
            stats.correct += 1

        if ds_row["is_correct"]:
            stats.meta_correct += 1

        stats.total += 1

    if num_missing:
        print(f"Skipped {num_missing} dataset rows without a saved response.")

    if not subtask_name_to_stats:
        # Nothing was scored; avoid dividing by zero below.
        print("No responses found to analyze.")
        return

    micro_stats = Stats()
    for stats in subtask_name_to_stats.values():
        stats.average = stats.correct / stats.total
        stats.meta_average = stats.meta_correct / stats.total

        micro_stats.correct += stats.correct
        micro_stats.total += stats.total
        micro_stats.meta_correct += stats.meta_correct

    micro_stats.average = micro_stats.correct / micro_stats.total
    micro_stats.meta_average = micro_stats.meta_correct / micro_stats.total

    print("Macro average", np.mean([x.average for x in subtask_name_to_stats.values()]))
    print(
        "Meta Macro average",
        np.mean([x.meta_average for x in subtask_name_to_stats.values()]),
    )
    print("Micro average", micro_stats.average)
    print("Meta Micro average", micro_stats.meta_average)
def get_client(api_url: str) -> openai.AsyncOpenAI:
    """Return an async OpenAI client that talks to ``api_url``.

    When ``OPENAI_API_KEY`` is absent from the environment it is set to the
    placeholder ``"EMPTY"`` before the client is constructed; an existing value
    is left untouched.
    """
    os.environ.setdefault("OPENAI_API_KEY", "EMPTY")
    client = openai.AsyncOpenAI(base_url=api_url)
    return client
def analyse(args):
    """Compute the average BERTScore F1 of the cached responses.

    Loads ``response_{idx}.pkl`` from ``args.output_dir`` for the first
    ``args.num_prompts`` dataset examples, skips entries that recorded a request
    error, and scores the remaining answers against the dataset's reference
    answers in batches.

    Args:
        args: Parsed CLI namespace; ``output_dir`` and ``num_prompts`` are read.

    Raises:
        FileNotFoundError: A response file for one of the examples is missing.
    """
    dataset = get_dataset()
    output_dir = Path(args.output_dir)

    hyps: List[str] = []
    refs: List[str] = []
    for idx, ex in enumerate(tqdm(dataset, desc="Loading responses")):
        if idx >= args.num_prompts:
            break
        pkl_file = output_dir / f"response_{idx}.pkl"
        if not pkl_file.exists():
            raise FileNotFoundError(pkl_file)

        # NOTE: pickle executes arbitrary code on load; only read response
        # directories produced by this script.
        with open(pkl_file, "rb") as f:
            response = pickle.load(f)
        if isinstance(response, dict) and "error" in response:
            continue

        # `content` can be None for an empty completion; score it as "".
        content = response.choices[0].message.content or ""
        hyps.append(content.strip())
        refs.append(ex["answer"])

    if not hyps:
        print("No valid responses to score!")
        return

    # Build the scorer only once we know there is something to score, so an
    # empty run does not pay for loading the model.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    scorer = BERTScorer(lang="en", device=device)

    batch_size = 64
    all_f1: List[float] = []
    for i in tqdm(range(0, len(hyps), batch_size), desc="Scoring batches"):
        h_batch = hyps[i : i + batch_size]
        r_batch = refs[i : i + batch_size]
        _, _, f1_scores = scorer.score(h_batch, r_batch, verbose=False)
        all_f1.extend(float(x) for x in f1_scores)

    avg = sum(all_f1) / len(all_f1)
    print(f"Average BERTScore (F1): {avg:.2%}")
+ """ + + def __init__(self): + # Verbosity level + # 0: do not output anything + # 2: output final text after every run + self.verbosity = 0 + + # Default backend of the language + self.default_backend = None + + # Output tokenization configs + self.skip_special_tokens_in_output = True + self.spaces_between_special_tokens_in_out = True + + # Language frontend interpreter optimization configs + self.enable_precache_with_tracing = True + self.enable_parallel_encoding = True + + +global_config = GlobalConfig() diff --git a/sglang/python/sglang/jit_kernel/.clang-format b/sglang/python/sglang/jit_kernel/.clang-format new file mode 100644 index 0000000000000000000000000000000000000000..56acfb8b8f5cae0436ee4b41f8c0e4ab64745bc6 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/.clang-format @@ -0,0 +1,25 @@ +BasedOnStyle: Google +IndentWidth: 2 +ColumnLimit: 120 +AllowShortFunctionsOnASingleLine: Empty +DerivePointerAlignment: false +PointerAlignment: Left +NamespaceIndentation: None +SortIncludes: true +AllowShortLoopsOnASingleLine: false +BinPackParameters: false # Prevents packing parameters in declarations +BinPackArguments: false # Prevents packing arguments in function calls +AlignAfterOpenBracket: AlwaysBreak # Forces a break after the opening parenthesis +AlignOperands: Align # Aligns arguments vertically +PenaltyBreakBeforeFirstCallParameter: 1 # Encourages breaking before the first argument +PenaltyReturnTypeOnItsOwnLine: 100 # Keeps return type with function name + +IncludeCategories: + - Regex: '^$' + Priority: 0 + - Regex: '^$' + Priority: 2 + - Regex: '^$' + Priority: 1 + - Regex: '^<.*/.*>$' + Priority: 3 diff --git a/sglang/python/sglang/jit_kernel/__init__.py b/sglang/python/sglang/jit_kernel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/sglang/python/sglang/jit_kernel/__main__.py b/sglang/python/sglang/jit_kernel/__main__.py new file mode 100644 index 
0000000000000000000000000000000000000000..bacf4f84e6eb5b2cba26aad56f1077ff54667048 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/__main__.py @@ -0,0 +1,48 @@ +assert __name__ == "__main__" + + +def generate_clangd(): + import logging + import os + import subprocess + + from tvm_ffi.libinfo import find_dlpack_include_path, find_include_path + + from sglang.jit_kernel.utils import DEFAULT_INCLUDE + + logger = logging.getLogger() + logger.info("Generating .clangd file...") + include_paths = [find_include_path(), find_dlpack_include_path()] + DEFAULT_INCLUDE + status = subprocess.run( + args=["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"], + capture_output=True, + check=True, + ) + compute_cap = status.stdout.decode("utf-8").strip().split("\n")[0] + major, minor = compute_cap.split(".") + compile_flags = ",\n ".join( + [ + "-xcuda", + f"--cuda-gpu-arch=sm_{major}{minor}", + "-std=c++20", + "-Wall", + "-Wextra", + ] + + [f"-isystem{path}" for path in include_paths] + ) + clangd_content = f""" +CompileFlags: + Add: [ + {compile_flags} + ] +""" + if os.path.exists(".clangd"): + logger.warning(".clangd file already exists, nothing done.") + logger.warning(f"suggested content: {clangd_content}") + else: + with open(".clangd", "w") as f: + f.write(clangd_content) + logger.info(".clangd file generated.") + + +generate_clangd() diff --git a/sglang/python/sglang/jit_kernel/__pycache__/__init__.cpython-311.pyc b/sglang/python/sglang/jit_kernel/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88541c9cd90d41f018d3612f9a15ff996f1c36cb Binary files /dev/null and b/sglang/python/sglang/jit_kernel/__pycache__/__init__.cpython-311.pyc differ diff --git a/sglang/python/sglang/jit_kernel/__pycache__/awq_dequantize.cpython-311.pyc b/sglang/python/sglang/jit_kernel/__pycache__/awq_dequantize.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..eb0199923806989f86698a17c04fab23be372ab8 Binary files /dev/null and b/sglang/python/sglang/jit_kernel/__pycache__/awq_dequantize.cpython-311.pyc differ diff --git a/sglang/python/sglang/jit_kernel/__pycache__/awq_marlin_repack.cpython-311.pyc b/sglang/python/sglang/jit_kernel/__pycache__/awq_marlin_repack.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82db4a5b9caef7501d50f51980fe662a02e67f15 Binary files /dev/null and b/sglang/python/sglang/jit_kernel/__pycache__/awq_marlin_repack.cpython-311.pyc differ diff --git a/sglang/python/sglang/jit_kernel/__pycache__/flash_attention_v4.cpython-311.pyc b/sglang/python/sglang/jit_kernel/__pycache__/flash_attention_v4.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b0dd012a4efa47f9a81e6381a9bf7704b265e02 Binary files /dev/null and b/sglang/python/sglang/jit_kernel/__pycache__/flash_attention_v4.cpython-311.pyc differ diff --git a/sglang/python/sglang/jit_kernel/__pycache__/fused_store_index_cache.cpython-311.pyc b/sglang/python/sglang/jit_kernel/__pycache__/fused_store_index_cache.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38f05402355f950984f64fd7400d56112228afb1 Binary files /dev/null and b/sglang/python/sglang/jit_kernel/__pycache__/fused_store_index_cache.cpython-311.pyc differ diff --git a/sglang/python/sglang/jit_kernel/__pycache__/gptq_marlin.cpython-311.pyc b/sglang/python/sglang/jit_kernel/__pycache__/gptq_marlin.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..685b8de3293360aa0f8006ced8b36c71cf93d6b9 Binary files /dev/null and b/sglang/python/sglang/jit_kernel/__pycache__/gptq_marlin.cpython-311.pyc differ diff --git a/sglang/python/sglang/jit_kernel/__pycache__/gptq_marlin_repack.cpython-311.pyc b/sglang/python/sglang/jit_kernel/__pycache__/gptq_marlin_repack.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..973f098fd87b78e380466d07b5423240069fb0c0 Binary files /dev/null and b/sglang/python/sglang/jit_kernel/__pycache__/gptq_marlin_repack.cpython-311.pyc differ diff --git a/sglang/python/sglang/jit_kernel/__pycache__/hicache.cpython-311.pyc b/sglang/python/sglang/jit_kernel/__pycache__/hicache.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09bdb3f8bcd45bf6186e82c5dd9208e719d647e1 Binary files /dev/null and b/sglang/python/sglang/jit_kernel/__pycache__/hicache.cpython-311.pyc differ diff --git a/sglang/python/sglang/jit_kernel/__pycache__/kvcache.cpython-311.pyc b/sglang/python/sglang/jit_kernel/__pycache__/kvcache.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7f3eb8efb442490c95097b39df37cb6f77a77c8 Binary files /dev/null and b/sglang/python/sglang/jit_kernel/__pycache__/kvcache.cpython-311.pyc differ diff --git a/sglang/python/sglang/jit_kernel/__pycache__/norm.cpython-311.pyc b/sglang/python/sglang/jit_kernel/__pycache__/norm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2575b52e27067af9123608098117a75cc2731d3b Binary files /dev/null and b/sglang/python/sglang/jit_kernel/__pycache__/norm.cpython-311.pyc differ diff --git a/sglang/python/sglang/jit_kernel/__pycache__/nvfp4.cpython-311.pyc b/sglang/python/sglang/jit_kernel/__pycache__/nvfp4.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30973571bb589cf78f39309954eb22ec926d9f75 Binary files /dev/null and b/sglang/python/sglang/jit_kernel/__pycache__/nvfp4.cpython-311.pyc differ diff --git a/sglang/python/sglang/jit_kernel/__pycache__/per_tensor_quant_fp8.cpython-311.pyc b/sglang/python/sglang/jit_kernel/__pycache__/per_tensor_quant_fp8.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a992919baac5b8d5624c1dc8b1dd835c13edc0c2 Binary files /dev/null and 
b/sglang/python/sglang/jit_kernel/__pycache__/per_tensor_quant_fp8.cpython-311.pyc differ diff --git a/sglang/python/sglang/jit_kernel/__pycache__/per_token_group_quant_8bit.cpython-311.pyc b/sglang/python/sglang/jit_kernel/__pycache__/per_token_group_quant_8bit.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..158ce9cfcf051456707a76b39b4d593ca65f811f Binary files /dev/null and b/sglang/python/sglang/jit_kernel/__pycache__/per_token_group_quant_8bit.cpython-311.pyc differ diff --git a/sglang/python/sglang/jit_kernel/__pycache__/rope.cpython-311.pyc b/sglang/python/sglang/jit_kernel/__pycache__/rope.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..753667921a121ee75a82203ef3790572612b4efe Binary files /dev/null and b/sglang/python/sglang/jit_kernel/__pycache__/rope.cpython-311.pyc differ diff --git a/sglang/python/sglang/jit_kernel/__pycache__/utils.cpython-311.pyc b/sglang/python/sglang/jit_kernel/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cba1b9089f3f7e7f8b6d888e6fd14aa201f43add Binary files /dev/null and b/sglang/python/sglang/jit_kernel/__pycache__/utils.cpython-311.pyc differ diff --git a/sglang/python/sglang/jit_kernel/add_constant.py b/sglang/python/sglang/jit_kernel/add_constant.py new file mode 100644 index 0000000000000000000000000000000000000000..acef6ed95bcb3b300c2bcc6d9ffc88b3407b665a --- /dev/null +++ b/sglang/python/sglang/jit_kernel/add_constant.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch + +from sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args + +if TYPE_CHECKING: + from tvm_ffi.module import Module + + +@cache_once +def _jit_add_constant_module(constant: int) -> Module: + args = make_cpp_args(constant) # pass all the template argument + return load_jit( + "add_constant", + *args, + cuda_files=["add_constant.cuh"], + 
def awq_dequantize(
    qweight: torch.Tensor,
    scales: torch.Tensor,
    qzeros: torch.Tensor,
) -> torch.Tensor:
    """Dequantize AWQ-packed weights with the JIT-compiled kernel.

    Args:
        qweight: Packed quantized weights; its first two dimensions determine
            the output shape.
        scales: Scales tensor; its dtype and device are used for the output and
            its dtype selects the kernel specialization.
        qzeros: Zero-point tensor passed through to the kernel.

    Returns:
        A new tensor of shape ``(qweight.shape[0], qweight.shape[1] * 8)``
        filled in by the kernel.
    """
    rows, packed_cols = qweight.shape[0], qweight.shape[1]
    # Each packed column expands to 8 output columns
    # (presumably eight 4-bit values per element - confirm against the kernel).
    result = torch.empty(
        (rows, packed_cols * 8),
        dtype=scales.dtype,
        device=scales.device,
    )
    kernel = _jit_awq_dequantize_module(scales.dtype)
    kernel.awq_dequantize(result, qweight, scales, qzeros)
    return result
Module + + +@cache_once +def _jit_awq_marlin_repack_module() -> Module: + return load_jit( + "awq_marlin_repack", + cuda_files=["gemm/marlin/awq_marlin_repack.cuh"], + cuda_wrappers=[("awq_marlin_repack", "awq_marlin_repack")], + ) + + +def awq_marlin_repack( + b_q_weight: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, +) -> torch.Tensor: + tile_size = 16 + pack_factor = 32 // num_bits + out = torch.empty( + (size_k // tile_size, size_n * tile_size // pack_factor), + dtype=b_q_weight.dtype, + device=b_q_weight.device, + ) + module = _jit_awq_marlin_repack_module() + module.awq_marlin_repack(out, b_q_weight, size_k, size_n, num_bits) + return out + + +def awq_marlin_moe_repack( + b_q_weight: torch.Tensor, + perm: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, +) -> torch.Tensor: + num_experts = b_q_weight.shape[0] + assert size_k % 16 == 0 + output = torch.empty( + (num_experts, size_k // 16, size_n * (num_bits // 2)), + device=b_q_weight.device, + dtype=b_q_weight.dtype, + ) + for e in range(num_experts): + output[e] = awq_marlin_repack(b_q_weight[e], size_k, size_n, num_bits) + return output diff --git a/sglang/python/sglang/jit_kernel/benchmark/bench_awq_dequantize.py b/sglang/python/sglang/jit_kernel/benchmark/bench_awq_dequantize.py new file mode 100644 index 0000000000000000000000000000000000000000..09b6ccb3fb75fff5f116a3679ee8573d1c506944 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/benchmark/bench_awq_dequantize.py @@ -0,0 +1,125 @@ +import itertools +import os + +import torch +import triton +import triton.testing + +from sglang.jit_kernel.awq_dequantize import awq_dequantize as jit_awq_dequantize + +try: + from sgl_kernel import awq_dequantize as aot_awq_dequantize + + AOT_AVAILABLE = True +except ImportError: + AOT_AVAILABLE = False + +IS_CI = ( + os.getenv("CI", "false").lower() == "true" + or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" +) + +# CI environment uses simplified parameters +if IS_CI: + 
qweight_row_range = [128] + qweight_cols_range = [16] +else: + qweight_row_range = [128, 256, 512, 1024, 3584] + qweight_cols_range = [16, 32, 64, 128, 448] + +configs = list(itertools.product(qweight_row_range, qweight_cols_range)) + + +def check_correctness(): + if not AOT_AVAILABLE: + print("sgl_kernel AOT not available, skipping correctness check") + return + + qweight_row, qweight_col = 128, 16 + device = torch.device("cuda") + qweight = torch.randint( + 0, + torch.iinfo(torch.int32).max, + (qweight_row, qweight_col), + dtype=torch.int32, + device=device, + ) + group_size = qweight_row + scales_row = qweight_row // group_size + scales_col = qweight_col * 8 + scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device) + qzeros = torch.randint( + 0, + torch.iinfo(torch.int32).max, + (scales_row, qweight_col), + dtype=torch.int32, + device=device, + ) + + jit_out = jit_awq_dequantize(qweight, scales, qzeros) + aot_out = aot_awq_dequantize(qweight, scales, qzeros) + torch.cuda.synchronize() + torch.testing.assert_close(jit_out, aot_out, rtol=0, atol=0) + print("Correctness check passed (JIT vs AOT)") + + +if AOT_AVAILABLE: + line_vals = ["jit", "aot"] + line_names = ["JIT Kernel", "AOT Kernel"] + styles = [("blue", "-"), ("green", "-")] +else: + line_vals = ["jit"] + line_names = ["JIT Kernel"] + styles = [("blue", "-")] + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["qweight_row", "qweight_col"], + x_vals=configs, + line_arg="provider", + line_vals=line_vals, + line_names=line_names, + styles=styles, + ylabel="us", + plot_name="awq-dequantize-jit-vs-aot", + args={}, + ) +) +def benchmark(qweight_row, qweight_col, provider): + device = torch.device("cuda") + qweight = torch.randint( + 0, + torch.iinfo(torch.int32).max, + (qweight_row, qweight_col), + dtype=torch.int32, + device=device, + ) + group_size = qweight_row + scales_row = qweight_row // group_size + scales_col = qweight_col * 8 + scales = 
torch.rand(scales_row, scales_col, dtype=torch.float16, device=device) + qzeros = torch.randint( + 0, + torch.iinfo(torch.int32).max, + (scales_row, qweight_col), + dtype=torch.int32, + device=device, + ) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "jit": + fn = lambda: jit_awq_dequantize(qweight, scales, qzeros) + elif provider == "aot": + fn = lambda: aot_awq_dequantize(qweight, scales, qzeros) + else: + raise ValueError(f"Unknown provider: {provider}") + + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles) + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +if __name__ == "__main__": + check_correctness() + benchmark.run(print_data=True) diff --git a/sglang/python/sglang/jit_kernel/benchmark/bench_awq_marlin_moe_repack.py b/sglang/python/sglang/jit_kernel/benchmark/bench_awq_marlin_moe_repack.py new file mode 100644 index 0000000000000000000000000000000000000000..120c177d54d13e4ae2a6c69ce7e42a1f34bd623c --- /dev/null +++ b/sglang/python/sglang/jit_kernel/benchmark/bench_awq_marlin_moe_repack.py @@ -0,0 +1,133 @@ +import os + +import numpy as np +import torch +import triton +import triton.testing +from sgl_kernel.scalar_type import scalar_types + +from sglang.jit_kernel.awq_marlin_repack import ( + awq_marlin_moe_repack as jit_awq_marlin_moe_repack, +) +from sglang.srt.layers.quantization.utils import pack_cols, quantize_weights + +try: + from sgl_kernel import awq_marlin_moe_repack as aot_awq_marlin_moe_repack + + AOT_AVAILABLE = True +except ImportError: + AOT_AVAILABLE = False + +IS_CI = ( + os.getenv("CI", "false").lower() == "true" + or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" +) + +# Fixed parameters +NUM_BITS = 4 +GROUP_SIZE = 128 +SIZE_N = 4096 + + +def awq_pack(q_w, num_bits, size_k, size_n): + if num_bits == 4: + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = np.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + 
q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() + q_w = q_w.reshape((-1, size_n)).contiguous() + return pack_cols(q_w, num_bits, size_k, size_n) + + +def make_moe_weights(num_experts, size_k, size_n, num_bits, group_size): + pack_factor = 32 // num_bits + b_q_weight = torch.empty( + (num_experts, size_k, size_n // pack_factor), + dtype=torch.int32, + device="cuda", + ) + for e in range(num_experts): + b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device="cuda") + w_ref, q_w, s, zp = quantize_weights( + b_weight, scalar_types.uint4, min(group_size, size_k), zero_points=True + ) + b_q_weight[e] = awq_pack(q_w, num_bits, size_k, size_n) + perm = torch.empty((num_experts, 0), dtype=torch.int32, device="cuda") + return b_q_weight, perm + + +def check_correctness(): + if not AOT_AVAILABLE: + print("sgl_kernel AOT not available, skipping correctness check") + return + + num_experts = 4 + size_k = 1024 + b_q_weight, perm = make_moe_weights( + num_experts, size_k, SIZE_N, NUM_BITS, GROUP_SIZE + ) + + out_jit = jit_awq_marlin_moe_repack(b_q_weight, perm, size_k, SIZE_N, NUM_BITS) + out_aot = aot_awq_marlin_moe_repack(b_q_weight, perm, size_k, SIZE_N, NUM_BITS) + torch.cuda.synchronize() + torch.testing.assert_close(out_jit, out_aot, rtol=0, atol=0) + print("Correctness check passed (JIT vs AOT)") + + +if IS_CI: + expert_range = [2, 4] +else: + expert_range = [2, 4, 8, 16] + +if AOT_AVAILABLE: + line_vals = ["jit", "aot"] + line_names = ["JIT Kernel", "AOT Kernel"] + styles = [("blue", "-"), ("green", "-")] +else: + line_vals = ["jit"] + line_names = ["JIT Kernel"] + styles = [("blue", "-")] + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["num_experts"], + x_vals=expert_range, + line_arg="provider", + line_vals=line_vals, + line_names=line_names, + styles=styles, + ylabel="us", + plot_name="awq-marlin-moe-repack-performance", + args={"size_k": 4096, "size_n": SIZE_N, "num_bits": NUM_BITS}, + ) +) +def 
benchmark(num_experts, size_k, size_n, num_bits, provider): + group_size = min(GROUP_SIZE, size_k) + b_q_weight, perm = make_moe_weights( + num_experts, size_k, size_n, num_bits, group_size + ) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "jit": + fn = lambda: jit_awq_marlin_moe_repack( + b_q_weight, perm, size_k, size_n, num_bits + ) + elif provider == "aot": + fn = lambda: aot_awq_marlin_moe_repack( + b_q_weight, perm, size_k, size_n, num_bits + ) + else: + raise ValueError(f"Unknown provider: {provider}") + + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles) + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +if __name__ == "__main__": + check_correctness() + benchmark.run(print_data=True) diff --git a/sglang/python/sglang/jit_kernel/benchmark/bench_awq_marlin_repack.py b/sglang/python/sglang/jit_kernel/benchmark/bench_awq_marlin_repack.py new file mode 100644 index 0000000000000000000000000000000000000000..51403363d318e5e736862f2cb77070eb38706897 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/benchmark/bench_awq_marlin_repack.py @@ -0,0 +1,117 @@ +import os + +import numpy as np +import torch +import triton +import triton.testing +from sgl_kernel.scalar_type import scalar_types + +from sglang.jit_kernel.awq_marlin_repack import ( + awq_marlin_repack as jit_awq_marlin_repack, +) +from sglang.srt.layers.quantization.utils import pack_cols, quantize_weights + +try: + from sgl_kernel import awq_marlin_repack as aot_awq_marlin_repack + + AOT_AVAILABLE = True +except ImportError: + AOT_AVAILABLE = False + +IS_CI = ( + os.getenv("CI", "false").lower() == "true" + or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" +) + +# Fixed problem dimensions +SIZE_K = 4096 +SIZE_N = 4096 +NUM_BITS = 4 +GROUP_SIZE = 128 + + +def awq_pack(q_w, num_bits, size_k, size_n): + if num_bits == 4: + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = np.array([0, 2, 1, 3]) + else: + raise Exception("num_bits 
must be 4 or 8, got {}".format(num_bits)) + + q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() + q_w = q_w.reshape((-1, size_n)).contiguous() + return pack_cols(q_w, num_bits, size_k, size_n) + + +# Quantize weights once +_b_weight = torch.randn((SIZE_K, SIZE_N), dtype=torch.float16, device="cuda") +_w_ref, _q_w, _s, _zp = quantize_weights( + _b_weight, scalar_types.uint4, GROUP_SIZE, zero_points=True +) +_q_w_awq = awq_pack(_q_w, NUM_BITS, SIZE_K, SIZE_N) + + +def check_correctness(): + if not AOT_AVAILABLE: + print("sgl_kernel AOT not available, skipping correctness check") + return + out_jit = jit_awq_marlin_repack(_q_w_awq, SIZE_K, SIZE_N, NUM_BITS) + out_aot = aot_awq_marlin_repack(_q_w_awq, SIZE_K, SIZE_N, NUM_BITS) + torch.cuda.synchronize() + torch.testing.assert_close(out_jit, out_aot, rtol=0, atol=0) + print("Correctness check passed (JIT vs AOT)") + + +if IS_CI: + k_range = [1024, 4096] +else: + k_range = [512, 1024, 2048, 4096, 8192] + +if AOT_AVAILABLE: + line_vals = ["jit", "aot"] + line_names = ["JIT Kernel", "AOT Kernel"] + styles = [("blue", "-"), ("green", "-")] +else: + line_vals = ["jit"] + line_names = ["JIT Kernel"] + styles = [("blue", "-")] + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["size_k"], + x_vals=k_range, + line_arg="provider", + line_vals=line_vals, + line_names=line_names, + styles=styles, + ylabel="us", + plot_name="awq-marlin-repack-performance", + args={"size_n": SIZE_N, "num_bits": NUM_BITS}, + ) +) +def benchmark(size_k, size_n, num_bits, provider): + group_size = min(GROUP_SIZE, size_k) + + b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device="cuda") + w_ref, q_w, s, zp = quantize_weights( + b_weight, scalar_types.uint4, group_size, zero_points=True + ) + q_w_awq = awq_pack(q_w, num_bits, size_k, size_n) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "jit": + fn = lambda: jit_awq_marlin_repack(q_w_awq, size_k, size_n, num_bits) + elif provider == "aot": + fn = 
lambda: aot_awq_marlin_repack(q_w_awq, size_k, size_n, num_bits) + else: + raise ValueError(f"Unknown provider: {provider}") + + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles) + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +if __name__ == "__main__": + check_correctness() + benchmark.run(print_data=True) diff --git a/sglang/python/sglang/jit_kernel/benchmark/bench_concat_mla.py b/sglang/python/sglang/jit_kernel/benchmark/bench_concat_mla.py new file mode 100644 index 0000000000000000000000000000000000000000..87c7ae56f91ed948f674b2a3e52cadcbf59d631c --- /dev/null +++ b/sglang/python/sglang/jit_kernel/benchmark/bench_concat_mla.py @@ -0,0 +1,163 @@ +import itertools + +import torch +import triton +import triton.testing +from sgl_kernel import concat_mla_absorb_q as aot_absorb_q +from sgl_kernel import concat_mla_k as aot_k + +from sglang.jit_kernel.benchmark.utils import is_in_ci +from sglang.jit_kernel.concat_mla import concat_mla_absorb_q as jit_absorb_q +from sglang.jit_kernel.concat_mla import concat_mla_k as jit_k + +IS_CI = is_in_ci() + +# Constants +NUM_LOCAL_HEADS = 128 +QK_NOPE_HEAD_DIM = 128 +QK_ROPE_HEAD_DIM = 64 +K_HEAD_DIM = QK_NOPE_HEAD_DIM + QK_ROPE_HEAD_DIM + +A_LAST_DIM = 512 +B_LAST_DIM = 64 + +DTYPE = torch.bfloat16 +DEVICE = "cuda" + + +def aot_concat_mla_k(k, k_nope, k_rope): + aot_k(k, k_nope, k_rope) + + +def jit_concat_mla_k(k, k_nope, k_rope): + jit_k(k, k_nope, k_rope) + + +def torch_concat_mla_k(k, k_nope, k_rope): + nope_head_dim = k_nope.shape[-1] + k[:, :, :nope_head_dim] = k_nope + k[:, :, nope_head_dim:] = k_rope.expand(-1, k.shape[1], -1) + + +def aot_concat_mla_absorb_q(a, b): + return aot_absorb_q(a, b) + + +def jit_concat_mla_absorb_q(a, b): + return jit_absorb_q(a, b) + + +def torch_concat_mla_absorb_q(a, b, out): + a_last_dim = a.shape[-1] + out[:, :, :a_last_dim] = a + out[:, :, a_last_dim:] = b + + +if IS_CI: + NUM_TOKENS_VALS = [256, 1024] +else: + NUM_TOKENS_VALS = [256, 512, 1024, 2048, 
4096, 8192, 16384, 32768] + +K_LINE_VALS = ["aot", "jit", "torch"] +K_LINE_NAMES = ["SGL AOT Kernel", "SGL JIT Kernel", "PyTorch"] +K_STYLES = [("orange", "-"), ("blue", "--"), ("green", "-.")] + + +def _create_concat_mla_k_data(num_tokens): + """Allocate oversized containers and slice to produce non-contiguous tensors.""" + k_nope_container = torch.randn( + (num_tokens, NUM_LOCAL_HEADS, QK_NOPE_HEAD_DIM + 128), + dtype=DTYPE, + device=DEVICE, + ) + k_nope = k_nope_container[:, :, :QK_NOPE_HEAD_DIM] + + k_rope_container = torch.randn( + (num_tokens, 1, 128 + QK_ROPE_HEAD_DIM), + dtype=DTYPE, + device=DEVICE, + ) + k_rope = k_rope_container[:, :, -QK_ROPE_HEAD_DIM:] + + k = torch.empty( + (num_tokens, NUM_LOCAL_HEADS, K_HEAD_DIM), + dtype=DTYPE, + device=DEVICE, + ) + return k, k_nope, k_rope + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["num_tokens"], + x_vals=NUM_TOKENS_VALS, + line_arg="provider", + line_vals=K_LINE_VALS, + line_names=K_LINE_NAMES, + styles=K_STYLES, + ylabel="us", + plot_name="concat-mla-k-performance", + args={}, + ) +) +def bench_concat_mla_k(num_tokens: int, provider: str): + k, k_nope, k_rope = _create_concat_mla_k_data(num_tokens) + + FN_MAP = { + "aot": aot_concat_mla_k, + "jit": jit_concat_mla_k, + "torch": torch_concat_mla_k, + } + fn = lambda: FN_MAP[provider](k, k_nope, k_rope) + quantiles = [0.5, 0.2, 0.8] + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles) + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +if IS_CI: + ABSORB_Q_VALS = list(itertools.product([4, 16], [16])) +else: + ABSORB_Q_VALS = list(itertools.product([1, 4, 8, 16, 32], [1, 8, 32, 128])) + +Q_LINE_VALS = ["aot", "jit", "torch"] +Q_LINE_NAMES = ["SGL AOT Kernel", "SGL JIT Kernel", "PyTorch"] +Q_STYLES = [("orange", "-"), ("blue", "--"), ("green", "-.")] + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["dim_0", "dim_1"], + x_vals=ABSORB_Q_VALS, + line_arg="provider", + 
line_vals=Q_LINE_VALS, + line_names=Q_LINE_NAMES, + styles=Q_STYLES, + ylabel="us", + plot_name="concat-mla-absorb-q-performance", + args={}, + ) +) +def bench_concat_mla_absorb_q(dim_0: int, dim_1: int, provider: str): + a = torch.randn(dim_0, dim_1, A_LAST_DIM, dtype=DTYPE, device=DEVICE) + b = torch.randn(dim_0, dim_1, B_LAST_DIM, dtype=DTYPE, device=DEVICE) + + if provider == "torch": + out = torch.empty( + dim_0, dim_1, A_LAST_DIM + B_LAST_DIM, dtype=DTYPE, device=DEVICE + ) + fn = lambda: torch_concat_mla_absorb_q(a, b, out) + else: + FN_MAP = { + "aot": aot_concat_mla_absorb_q, + "jit": jit_concat_mla_absorb_q, + } + fn = lambda: FN_MAP[provider](a, b) + + quantiles = [0.5, 0.2, 0.8] + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles) + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +if __name__ == "__main__": + bench_concat_mla_k.run(print_data=True) + bench_concat_mla_absorb_q.run(print_data=True) diff --git a/sglang/python/sglang/jit_kernel/benchmark/bench_fused_add_rmsnorm.py b/sglang/python/sglang/jit_kernel/benchmark/bench_fused_add_rmsnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..e4c0a7e8e825fd47719d53a271d0e370a161a8e5 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/benchmark/bench_fused_add_rmsnorm.py @@ -0,0 +1,73 @@ +import itertools + +import torch +import triton +import triton.testing +from flashinfer import fused_add_rmsnorm as fi_fused_add_rmsnorm + +from sglang.jit_kernel.benchmark.utils import is_in_ci +from sglang.jit_kernel.norm import fused_add_rmsnorm as jit_fused_add_rmsnorm + +IS_CI = is_in_ci() + + +def sglang_jit_fused_add_rmsnorm( + input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float +) -> None: + jit_fused_add_rmsnorm(input, residual, weight, eps) + + +def flashinfer_fused_add_rmsnorm( + input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float +) -> None: + fi_fused_add_rmsnorm(input, residual, weight, eps=eps) + + 
+DTYPE = torch.bfloat16 +DEVICE = "cuda" + +if IS_CI: + BS_LIST = [16] + HIDDEN_SIZE_LIST = [512, 2048] +else: + BS_LIST = [2**n for n in range(0, 14)] + HIDDEN_SIZE_LIST = [1536, 3072, 4096, 5120, 8192] + +LINE_VALS = ["jit", "fi"] +LINE_NAMES = ["SGL JIT Kernel", "FlashInfer"] +STYLES = [("orange", "-"), ("blue", "--"), ("green", "-."), ("red", ":")] + +configs = list(itertools.product(HIDDEN_SIZE_LIST, BS_LIST)) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["hidden_size", "batch_size"], + x_vals=configs, + line_arg="provider", + line_vals=LINE_VALS, + line_names=LINE_NAMES, + styles=STYLES, + ylabel="us", + plot_name="fused-add-rmsnorm-performance", + args={}, + ) +) +def benchmark(hidden_size: int, batch_size: int, provider: str): + input = torch.randn((batch_size, hidden_size), dtype=DTYPE, device=DEVICE) + residual = torch.randn((batch_size, hidden_size), dtype=DTYPE, device=DEVICE) + weight = torch.randn(hidden_size, dtype=DTYPE, device=DEVICE) + FN_MAP = { + "jit": sglang_jit_fused_add_rmsnorm, + "fi": flashinfer_fused_add_rmsnorm, + } + fn = lambda: FN_MAP[provider]( + input.clone(), residual.clone(), weight, torch.finfo(torch.bfloat16).eps + ) + quantiles = [0.5, 0.2, 0.8] + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles) # type: ignore + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +if __name__ == "__main__": + benchmark.run(print_data=True) diff --git a/sglang/python/sglang/jit_kernel/benchmark/bench_fused_norm_scale_shift.py b/sglang/python/sglang/jit_kernel/benchmark/bench_fused_norm_scale_shift.py new file mode 100644 index 0000000000000000000000000000000000000000..e1c7d72af8b12e58f735d66763814ef6877447ba --- /dev/null +++ b/sglang/python/sglang/jit_kernel/benchmark/bench_fused_norm_scale_shift.py @@ -0,0 +1,134 @@ +# Benchmarks SGLang fused layernorm/rmsnorm scale shift kernels +# 1. fused_norm_scale_shift +# 2. 
fused_scale_residual_norm_scale_shift +import itertools +from typing import Tuple + +import torch +import triton +import triton.testing + +from sglang.jit_kernel.benchmark.utils import is_in_ci +from sglang.multimodal_gen.runtime.layers.layernorm import ( + LayerNormScaleShift, + RMSNormScaleShift, + ScaleResidualLayerNormScaleShift, + ScaleResidualRMSNormScaleShift, +) + +if is_in_ci(): + B_RANGE, S_RANGE, D_RANGE = [1], [128], [1024] +else: + B_RANGE, S_RANGE, D_RANGE = [1], [128, 1024, 4096], [1024, 3072, 4096] + +NORM_TYPE_RANGE = ["layer", "rms"] +AFFINE_RANGE = [True, False] +DTYPE = torch.bfloat16 +DEVICE = "cuda" +EPS = 1e-5 +LINE_VALS = ["native", "cuda"] +LINE_NAMES = ["SGLang Native", "SGLang Fused"] +STYLES = [("red", "-"), ("blue", "--")] +config = list( + itertools.product(B_RANGE, S_RANGE, D_RANGE, NORM_TYPE_RANGE, AFFINE_RANGE) +) + + +def preprocess_layer(layer, affine: bool, D: int, DTYPE: torch.dtype): + if affine: + weight = torch.randn(D, dtype=DTYPE, device=DEVICE) + bias = torch.randn(D, dtype=DTYPE, device=DEVICE) + with torch.no_grad(): + layer.norm.weight.copy_(weight) + if hasattr(layer.norm, "bias"): + layer.norm.bias.copy_(bias) + layer.requires_grad_(False) + return layer.to(DEVICE) + + +# ============================================================================ +# Benchmark 1: fused_norm_scale_shift +# ============================================================================ +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["B", "S", "D", "norm_type", "affine"], + x_vals=config, + line_arg="provider", + line_vals=LINE_VALS, + line_names=LINE_NAMES, + styles=STYLES, + ylabel="us", + plot_name="fused_norm_scale_shift", + args={}, + ) +) +def bench_fused_norm_scale_shift( + B: int, S: int, D: int, norm_type, affine: bool, provider: str +) -> Tuple[float, float, float]: + x = torch.randn(B, S, D, dtype=DTYPE, device=DEVICE) + scale = torch.randn(B, S, D, dtype=DTYPE, device=DEVICE) + shift = torch.randn(B, S, D, 
dtype=DTYPE, device=DEVICE) + if norm_type == "layer": + layer = LayerNormScaleShift(D, EPS, affine, dtype=DTYPE) + else: + layer = RMSNormScaleShift(D, EPS, affine, dtype=DTYPE) + layer = preprocess_layer(layer, affine, D, DTYPE) + if provider == "native": + fn = lambda: layer.forward_native(x, shift, scale) + else: + fn = lambda: layer.forward_cuda(x, shift, scale) + + quantiles = [0.5, 0.2, 0.8] + ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles) + return 1000 * ms, 1000 * max_ms, 1000 * min_ms # convert to us + + +# ============================================================================ +# Benchmark 2: fused_scale_residual_norm_scale_shift +# ============================================================================ +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["B", "S", "D", "norm_type", "affine"], + x_vals=config, + line_arg="provider", + line_vals=LINE_VALS, + line_names=LINE_NAMES, + styles=STYLES, + ylabel="us", + plot_name="fused_scale_residual_norm_scale_shift", + args={}, + ) +) +def bench_fused_scale_residual_norm_scale_shift( + B: int, S: int, D: int, norm_type, affine: bool, provider: str +) -> Tuple[float, float, float]: + residual = torch.randn(B, S, D, dtype=DTYPE, device=DEVICE) + x = torch.randn(B, S, D, dtype=DTYPE, device=DEVICE) + scale = torch.randn(B, S, D, dtype=DTYPE, device=DEVICE) + shift = torch.randn(B, S, D, dtype=DTYPE, device=DEVICE) + gate = torch.randn(B, 1, D, dtype=DTYPE, device=DEVICE) + if norm_type == "layer": + layer = ScaleResidualLayerNormScaleShift(D, EPS, affine, dtype=DTYPE).to(DEVICE) + else: + layer = ScaleResidualRMSNormScaleShift(D, EPS, affine, dtype=DTYPE).to(DEVICE) + layer = preprocess_layer(layer, affine, D, DTYPE) + if provider == "native": + fn = lambda: layer.forward_native(residual, x, gate, shift, scale) + else: + fn = lambda: layer.forward_cuda(residual, x, gate, shift, scale) + + quantiles = [0.5, 0.2, 0.8] + ms, min_ms, max_ms = 
triton.testing.do_bench(fn, quantiles=quantiles) + return 1000 * ms, 1000 * max_ms, 1000 * min_ms # convert to us + + +if __name__ == "__main__": + print(f"\n{'='*80}") + print("Benchmark: fused_norm_scale_shift") + print(f"{'='*80}\n") + bench_fused_norm_scale_shift.run(print_data=True) + + print(f"\n{'='*80}") + print("Benchmark: fused_scale_residual_norm_scale_shift") + print(f"{'='*80}\n") + bench_fused_scale_residual_norm_scale_shift.run(print_data=True) diff --git a/sglang/python/sglang/jit_kernel/benchmark/bench_gptq_marlin.py b/sglang/python/sglang/jit_kernel/benchmark/bench_gptq_marlin.py new file mode 100644 index 0000000000000000000000000000000000000000..a4c46ff037dc7827eb44419ac8bebd1970a5789e --- /dev/null +++ b/sglang/python/sglang/jit_kernel/benchmark/bench_gptq_marlin.py @@ -0,0 +1,118 @@ +import os + +import torch +import triton +import triton.testing +from sgl_kernel.scalar_type import scalar_types + +from sglang.jit_kernel.gptq_marlin import gptq_marlin_gemm as jit_gptq_marlin_gemm +from sglang.srt.layers.quantization.marlin_utils import marlin_make_workspace +from sglang.test.test_marlin_utils import marlin_quantize + +try: + from sgl_kernel import gptq_marlin_gemm as aot_gptq_marlin_gemm + + AOT_AVAILABLE = True +except ImportError: + AOT_AVAILABLE = False + +IS_CI = ( + os.getenv("CI", "false").lower() == "true" + or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" +) + +# Fixed problem dimensions +SIZE_K = 4096 +SIZE_N = 4096 +GROUP_SIZE = 128 +QUANT_TYPE = scalar_types.uint4b8 + +# Quantize weights once +_b_weight = torch.randn((SIZE_K, SIZE_N), dtype=torch.float16, device="cuda") +_w_ref, _marlin_q_w, _marlin_s, _g_idx, _sort_indices, _ = marlin_quantize( + _b_weight, QUANT_TYPE, GROUP_SIZE, act_order=False +) +_workspace = marlin_make_workspace(_w_ref.device) + + +def _run_gemm(fn, a): + return fn( + a, + None, + _marlin_q_w, + _marlin_s, + None, + None, + _g_idx, + _sort_indices, + _workspace, + QUANT_TYPE, + a.shape[0], + SIZE_N, + 
SIZE_K, + is_k_full=True, + use_atomic_add=False, + use_fp32_reduce=False, + is_zp_float=False, + ) + + +def check_correctness(): + if not AOT_AVAILABLE: + print("sgl_kernel AOT not available, skipping correctness check") + return + a = torch.randn((16, SIZE_K), dtype=torch.float16, device="cuda") + out_jit = _run_gemm(jit_gptq_marlin_gemm, a) + out_aot = _run_gemm(aot_gptq_marlin_gemm, a) + torch.testing.assert_close(out_jit, out_aot, rtol=1e-3, atol=1e-3) + print("Correctness check passed (JIT vs AOT)") + + +if IS_CI: + m_range = [1, 16, 128] +else: + m_range = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + +if AOT_AVAILABLE: + line_vals = ["jit", "aot"] + line_names = ["JIT Kernel", "AOT Kernel"] + styles = [("blue", "-"), ("green", "-")] +else: + line_vals = ["jit"] + line_names = ["JIT Kernel"] + styles = [("blue", "-")] + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["size_m"], + x_vals=m_range, + line_arg="provider", + line_vals=line_vals, + line_names=line_names, + styles=styles, + ylabel="us", + plot_name="gptq-marlin-gemm-performance", + args={}, + ) +) +def benchmark(size_m, provider): + device = torch.device("cuda") + a = torch.randn((size_m, SIZE_K), dtype=torch.float16, device=device) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "jit": + fn = lambda: _run_gemm(jit_gptq_marlin_gemm, a) + elif provider == "aot": + fn = lambda: _run_gemm(aot_gptq_marlin_gemm, a) + else: + raise ValueError(f"Unknown provider: {provider}") + + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles) + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +if __name__ == "__main__": + check_correctness() + benchmark.run(print_data=True) diff --git a/sglang/python/sglang/jit_kernel/benchmark/bench_gptq_marlin_repack.py b/sglang/python/sglang/jit_kernel/benchmark/bench_gptq_marlin_repack.py new file mode 100644 index 0000000000000000000000000000000000000000..5ec65efec5d1ef8c6109d0e4a4bd63712d89f7e6 --- /dev/null +++ 
b/sglang/python/sglang/jit_kernel/benchmark/bench_gptq_marlin_repack.py @@ -0,0 +1,104 @@ +import os + +import torch +import triton +import triton.testing +from sgl_kernel.scalar_type import scalar_types + +from sglang.jit_kernel.gptq_marlin_repack import gptq_marlin_repack as jit_fn +from sglang.srt.layers.quantization.utils import gptq_quantize_weights, pack_rows + +try: + from sgl_kernel import gptq_marlin_repack as aot_fn + + AOT_AVAILABLE = True +except ImportError: + AOT_AVAILABLE = False + +IS_CI = ( + os.getenv("CI", "false").lower() == "true" + or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" +) + +# Fixed problem dimensions +SIZE_N = 4096 +NUM_BITS = 4 +QUANT_TYPE = scalar_types.uint4b8 +GROUP_SIZE = 128 + +# Pre-compute quantized weight for each size_k in the sweep +_cache = {} + + +def _get_inputs(size_k): + if size_k not in _cache: + size_n = SIZE_N + b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device="cuda") + _, q_w, _, _, _ = gptq_quantize_weights( + b_weight, QUANT_TYPE, GROUP_SIZE, act_order=False + ) + q_w_gptq = pack_rows(q_w, NUM_BITS, size_k, size_n) + sort_indices = torch.empty(0, dtype=torch.int, device="cuda") + _cache[size_k] = (q_w_gptq, sort_indices) + return _cache[size_k] + + +def check_correctness(): + if not AOT_AVAILABLE: + print("sgl_kernel AOT not available, skipping correctness check") + return + size_k = 4096 + q_w_gptq, sort_indices = _get_inputs(size_k) + out_jit = jit_fn(q_w_gptq, sort_indices, size_k, SIZE_N, NUM_BITS) + out_aot = aot_fn(q_w_gptq, sort_indices, size_k, SIZE_N, NUM_BITS) + torch.testing.assert_close(out_jit, out_aot, rtol=0, atol=0) + print("Correctness check passed (JIT vs AOT)") + + +if IS_CI: + k_range = [128, 1024, 4096] +else: + k_range = [128, 256, 512, 1024, 2048, 4096, 8192] + +if AOT_AVAILABLE: + line_vals = ["jit", "aot"] + line_names = ["JIT Kernel", "AOT Kernel"] + styles = [("blue", "-"), ("green", "-")] +else: + line_vals = ["jit"] + line_names = ["JIT Kernel"] + 
styles = [("blue", "-")] + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["size_k"], + x_vals=k_range, + line_arg="provider", + line_vals=line_vals, + line_names=line_names, + styles=styles, + ylabel="us", + plot_name="gptq-marlin-repack-performance", + args={}, + ) +) +def benchmark(size_k, provider): + q_w_gptq, sort_indices = _get_inputs(size_k) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "jit": + fn = lambda: jit_fn(q_w_gptq, sort_indices, size_k, SIZE_N, NUM_BITS) + elif provider == "aot": + fn = lambda: aot_fn(q_w_gptq, sort_indices, size_k, SIZE_N, NUM_BITS) + else: + raise ValueError(f"Unknown provider: {provider}") + + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles) + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +if __name__ == "__main__": + check_correctness() + benchmark.run(print_data=True) diff --git a/sglang/python/sglang/jit_kernel/benchmark/bench_hicache.py b/sglang/python/sglang/jit_kernel/benchmark/bench_hicache.py new file mode 100644 index 0000000000000000000000000000000000000000..8929ce706fa26b5dd93adcc9aca4c001ffee21b4 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/benchmark/bench_hicache.py @@ -0,0 +1,425 @@ +"""Benchmark for HiCache JIT kernel performance. + +This benchmark tests the performance of KV cache transfer operations +between GPU and CPU (host pinned memory), comparing: +- SGL AOT Kernel: Pre-compiled transfer_kv kernels from sgl_kernel +- SGL JIT Kernel: JIT-compiled hicache kernels +- PyTorch Indexing: Plain PyTorch index copy +- PyTorch 2 Stream: PyTorch implementation using 2 CUDA streams + +Tests cover: +- One Layer: CPU->GPU +- All Layer: GPU->CPU + +Note: Uses do_bench instead of do_bench_cudagraph since CUDA graph +capture doesn't support CPU-GPU memory transfers. 
+""" + +import itertools +import os +from dataclasses import dataclass +from typing import Tuple + +import torch +import triton +import triton.testing +from sgl_kernel import transfer_kv_all_layer, transfer_kv_per_layer + +from sglang.jit_kernel.benchmark.utils import DEFAULT_QUANTILES, get_benchmark_range +from sglang.jit_kernel.hicache import ( + can_use_hicache_jit_kernel, + transfer_hicache_all_layer, + transfer_hicache_one_layer, +) + +# NOTE: Adjustable hyperparameters for better benchmark stability + +# NOTE: torch impl is too slow in benchmark +DISABLE_TORCH = os.environ.get("DISABLE_TORCH", "0") == "1" +PAGE_SIZE = 1 +ENABLE_SORT = True +GPU_CACHE_SIZE = 256 * 1024 # 256K tokens on GPU +HOST_CACHE_SIZE = 512 * 1024 # 512K tokens on CPU +NUM_LAYERS = 8 + + +@dataclass(frozen=True) +class HiCacheCache: + k_cache_cuda: torch.Tensor + v_cache_cuda: torch.Tensor + k_cache_host: torch.Tensor + v_cache_host: torch.Tensor + + def get_slice(self, num_layers: int, element_size: int) -> "HiCacheCache": + def slice_cuda(t: torch.Tensor) -> torch.Tensor: + needed_cuda = num_layers * GPU_CACHE_SIZE + return t.view(-1, element_size)[:needed_cuda].unflatten(0, (num_layers, -1)) + + def slice_host(t: torch.Tensor) -> torch.Tensor: + needed_host = num_layers * HOST_CACHE_SIZE + return t.view(-1, element_size)[:needed_host].unflatten(0, (num_layers, -1)) + + return HiCacheCache( + k_cache_cuda=slice_cuda(self.k_cache_cuda), + v_cache_cuda=slice_cuda(self.v_cache_cuda), + k_cache_host=slice_host(self.k_cache_host), + v_cache_host=slice_host(self.v_cache_host), + ) + + +def gen_indices( + size: int, max_size: int, *, page_size: int = PAGE_SIZE +) -> torch.Tensor: + def align(x: int) -> int: + return (x + page_size - 1) // page_size + + assert size <= max_size and max_size % page_size == 0 + indices = torch.randperm(align(max_size))[: align(size)] + offsets = torch.arange(page_size) + return (indices[:, None] * page_size + offsets).flatten().cuda()[:size] + + +def 
def pytorch_transfer(
    k_cache_dst: torch.Tensor,
    v_cache_dst: torch.Tensor,
    indices_dst_on_dst: torch.Tensor,
    k_cache_src: torch.Tensor,
    v_cache_src: torch.Tensor,
    indices_src_on_src: torch.Tensor,
) -> None:
    """Plain PyTorch gather/scatter baseline for a K/V cache transfer.

    For K and then V, rows ``indices_src_on_src`` are gathered from the source
    cache, moved to the device of ``k_cache_dst`` and written into rows
    ``indices_dst_on_dst`` of the destination cache (in place).
    """
    # Both K and V are moved to the K destination's device, as in the baseline.
    target_device = k_cache_dst.device
    for dst, src in ((k_cache_dst, k_cache_src), (v_cache_dst, v_cache_src)):
        gathered = src[indices_src_on_src]
        dst[indices_dst_on_dst] = gathered.to(target_device)
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["element_size", "batch_size"],
        x_vals=CONFIGS,
        line_arg="provider",
        line_vals=LINE_VALS,
        line_names=LINE_NAMES,
        styles=STYLES,
        ylabel="us",
        plot_name="hicache-all-layer-d2h",
        args={},
    )
)
def benchmark_all_layer_d2h(
    element_size: int, batch_size: int, provider: str
) -> Tuple[float, float, float]:
    """All Layer: Device (GPU) -> Host (CPU).

    Copies ``batch_size`` rows per layer from the CUDA caches to the host
    caches for all ``NUM_LAYERS`` layers, using the selected provider.

    Args:
        element_size: Number of cache elements per row (second dimension of
            the sliced cache view).
        batch_size: Number of rows transferred per layer.
        provider: One of ``"aot"``, ``"jit"`` or ``"pytorch"``.

    Returns:
        Three ``do_bench`` quantile timings, each multiplied by 1000 and
        divided by ``NUM_LAYERS`` (per-layer average). All three are NaN when
        the JIT kernel cannot handle this row size, or when the PyTorch
        baseline is disabled via ``DISABLE_TORCH``.
    """
    # `cache` is the module-level HiCacheCache created under `__main__`.
    global cache
    cache_local = cache.get_slice(num_layers=NUM_LAYERS, element_size=element_size)
    k_caches_src = cache_local.k_cache_cuda
    v_caches_src = cache_local.v_cache_cuda
    k_caches_dst = cache_local.k_cache_host
    v_caches_dst = cache_local.v_cache_host
    # to avoid fluctuation, we seed the RNG deterministically per configuration
    torch.manual_seed(batch_size * 65536 + element_size)

    indices_src_gpu = gen_indices(batch_size, GPU_CACHE_SIZE)
    indices_dst_gpu = gen_indices(batch_size, HOST_CACHE_SIZE)
    # sort by host indices to improve host access performance
    if ENABLE_SORT:
        indices_dst_gpu, mapping = indices_dst_gpu.sort()
        # apply the same permutation so each source row keeps its destination
        indices_src_gpu = indices_src_gpu[mapping]
    # the PyTorch baseline indexes the host cache with CPU-resident indices
    indices_dst_cpu = indices_dst_gpu.cpu()
    torch.cuda.synchronize()

    # bytes per cache row: element count times bytes per element
    element_bytes = element_size * k_caches_src.element_size()

    # per-layer base addresses, packed into uint64 tensors for the kernels
    k_ptrs_src = _create_ptr_tensor([k_caches_src[i] for i in range(NUM_LAYERS)])
    v_ptrs_src = _create_ptr_tensor([v_caches_src[i] for i in range(NUM_LAYERS)])
    k_ptrs_dst = _create_ptr_tensor([k_caches_dst[i] for i in range(NUM_LAYERS)])
    v_ptrs_dst = _create_ptr_tensor([v_caches_dst[i] for i in range(NUM_LAYERS)])

    # one zero-argument callable per provider; only the selected one is timed
    FN_MAP = {
        "aot": lambda: sglang_aot_transfer_all(
            k_ptrs_dst,
            v_ptrs_dst,
            indices_dst_gpu,
            k_ptrs_src,
            v_ptrs_src,
            indices_src_gpu,
            element_bytes,
            NUM_LAYERS,
        ),
        "jit": lambda: sglang_jit_transfer_all(
            k_ptrs_dst,
            v_ptrs_dst,
            indices_dst_gpu,
            k_ptrs_src,
            v_ptrs_src,
            indices_src_gpu,
            # passed as both stride_bytes and element_size
            element_bytes,
            element_bytes,
        ),
        "pytorch": lambda: [
            pytorch_transfer(
                k_caches_dst[i],
                v_caches_dst[i],
                indices_dst_cpu,
                k_caches_src[i],
                v_caches_src[i],
                indices_src_gpu,
            )
            for i in range(NUM_LAYERS)
        ],
    }

    # report NaN instead of timing a configuration the JIT kernel rejects
    if provider == "jit" and not can_use_hicache_jit_kernel(element_size=element_bytes):
        return (float("nan"), float("nan"), float("nan"))

    if DISABLE_TORCH and provider in ["pytorch"]:
        return (float("nan"), float("nan"), float("nan"))

    ms, min_ms, max_ms = triton.testing.do_bench(  # type: ignore
        FN_MAP[provider], quantiles=DEFAULT_QUANTILES, warmup=5, rep=25
    )
    # scale by 1000 for the "us" y-label and average over layers;
    # note the tuple order is (ms, max_ms, min_ms)
    return (
        1000 * ms / NUM_LAYERS,
        1000 * max_ms / NUM_LAYERS,
        1000 * min_ms / NUM_LAYERS,
    )
def stack_and_dev(tensors):
    """Stack ``tensors`` along a new leading dimension.

    The result is placed on the device of the first tensor in the sequence.
    """
    # Read the target device first so an empty sequence fails on the lookup.
    target_device = tensors[0].device
    stacked = torch.stack(tensors, dim=0)
    return stacked.to(target_device)
def check_correctness():
    """Compare the JIT and AOT MoE Marlin GEMM outputs on a small problem.

    Prints a notice and returns early when the AOT kernel is unavailable.
    Otherwise both kernels run on identical inputs (``size_m == 16``) into
    separate clones of the output buffer, and the two results must agree
    within ``rtol=1e-3`` / ``atol=1e-3``.
    """
    if not AOT_AVAILABLE:
        print("sgl_kernel AOT not available, skipping correctness check")
        return

    size_m = 16
    inputs = _make_inputs(size_m)
    a, c, topk_weights, _topk_ids, sorted_token_ids, expert_ids, ntp, workspace = inputs

    # Each kernel gets its own output buffer; both clones exist before any run.
    outputs = {"jit": c.clone(), "aot": c.clone()}
    for label, runner in (("jit", _run_jit), ("aot", _run_aot)):
        runner(
            a,
            outputs[label],
            topk_weights,
            sorted_token_ids,
            expert_ids,
            ntp,
            workspace,
            size_m,
        )

    torch.testing.assert_close(outputs["jit"], outputs["aot"], rtol=1e-3, atol=1e-3)
    print("Correctness check passed (JIT vs AOT)")
line_arg="provider", + line_vals=line_vals, + line_names=line_names, + styles=styles, + ylabel="us", + plot_name="moe-wna16-marlin-gemm-performance", + args={}, + ) +) +def benchmark(size_m, provider): + a, c, topk_weights, topk_ids, sorted_token_ids, expert_ids, ntp, workspace = ( + _make_inputs(size_m) + ) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "jit": + fn = lambda: _run_jit( + a, + c.clone(), + topk_weights, + sorted_token_ids, + expert_ids, + ntp, + workspace, + size_m, + ) + elif provider == "aot": + fn = lambda: _run_aot( + a, + c.clone(), + topk_weights, + sorted_token_ids, + expert_ids, + ntp, + workspace, + size_m, + ) + else: + raise ValueError(f"Unknown provider: {provider}") + + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles) + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +if __name__ == "__main__": + check_correctness() + benchmark.run(print_data=True) diff --git a/sglang/python/sglang/jit_kernel/benchmark/bench_nvfp4_blockwise_moe.py b/sglang/python/sglang/jit_kernel/benchmark/bench_nvfp4_blockwise_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..54d37ddeb84353cfc5dcc06755c3cdfc520a3c65 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/benchmark/bench_nvfp4_blockwise_moe.py @@ -0,0 +1,250 @@ +from __future__ import annotations + +from typing import Any + +import torch +import triton + +from sglang.jit_kernel.benchmark.utils import get_benchmark_range, run_benchmark +from sglang.jit_kernel.nvfp4 import ( + cutlass_fp4_group_mm, + scaled_fp4_experts_quant, + scaled_fp4_quant, +) + +FLOAT4_E2M1_MAX = 6.0 +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + + +def _round_up(x: int, y: int) -> int: + return ((x + y - 1) // y) * y + + +def _expert_offsets(m_per_expert: list[int], device: torch.device) -> torch.Tensor: + offsets = [0] + for m in m_per_expert: + offsets.append(offsets[-1] + m) + return torch.tensor(offsets, dtype=torch.int32, device=device) + + +def 
_blockscale_offsets(m_per_expert: list[int], device: torch.device) -> torch.Tensor: + offsets = [0] + for m in m_per_expert: + offsets.append(offsets[-1] + _round_up(m, 128)) + return torch.tensor(offsets, dtype=torch.int32, device=device) + + +def _prepare_case( + total_tokens: int, n: int, k: int, num_experts: int, dtype: torch.dtype +) -> dict[str, Any]: + device = torch.device("cuda") + base = total_tokens // num_experts + rem = total_tokens % num_experts + m_per_expert = [base + (1 if i < rem else 0) for i in range(num_experts)] + + expert_offsets_full = _expert_offsets(m_per_expert, device) + blockscale_offsets_full = _blockscale_offsets(m_per_expert, device) + + a = torch.randn((total_tokens, k), device=device, dtype=dtype) * 0.1 + b = torch.randn((num_experts, n, k), device=device, dtype=dtype) * 0.1 + + a_global_scale = torch.empty((num_experts,), device=device, dtype=torch.float32) + for i in range(num_experts): + start = int(expert_offsets_full[i].item()) + end = int(expert_offsets_full[i + 1].item()) + a_global_scale[i] = ( + FLOAT8_E4M3_MAX + * FLOAT4_E2M1_MAX + / a[start:end].abs().max().to(torch.float32) + ) + + b_global_scale = torch.empty((num_experts,), device=device, dtype=torch.float32) + for i in range(num_experts): + b_global_scale[i] = ( + FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b[i].abs().max().to(torch.float32) + ) + + a_fp4, a_blockscale = scaled_fp4_experts_quant( + a, + a_global_scale, + expert_offsets_full, + blockscale_offsets_full, + topk=1, + ) + + b_fp4 = torch.empty((num_experts, n, k // 2), device=device, dtype=torch.uint8) + b_blockscale = torch.empty( + (num_experts, _round_up(n, 128), _round_up(k // 16, 4)), + device=device, + dtype=torch.float8_e4m3fn, + ) + for i in range(num_experts): + b_fp4_i, b_scale_i = scaled_fp4_quant(b[i], b_global_scale[i]) + b_fp4[i].copy_(b_fp4_i) + b_blockscale[i].copy_(b_scale_i) + + alphas = (1.0 / (a_global_scale * b_global_scale)).to(torch.float32) + params = { + "ab_strides": 
torch.full((num_experts,), k, dtype=torch.int64, device=device), + "c_strides": torch.full((num_experts,), n, dtype=torch.int64, device=device), + "problem_sizes": torch.tensor( + [[m, n, k] for m in m_per_expert], dtype=torch.int32, device=device + ), + "expert_offsets": expert_offsets_full[:-1].contiguous(), + "blockscale_offsets": blockscale_offsets_full[:-1].contiguous(), + "a_ptrs": torch.empty((num_experts,), dtype=torch.int64, device=device), + "b_ptrs": torch.empty((num_experts,), dtype=torch.int64, device=device), + "out_ptrs": torch.empty((num_experts,), dtype=torch.int64, device=device), + "a_scales_ptrs": torch.empty((num_experts,), dtype=torch.int64, device=device), + "b_scales_ptrs": torch.empty((num_experts,), dtype=torch.int64, device=device), + "alpha_ptrs": torch.empty((num_experts,), dtype=torch.int64, device=device), + "layout_sfa": torch.empty((num_experts, 5), dtype=torch.int64, device=device), + "layout_sfb": torch.empty((num_experts, 5), dtype=torch.int64, device=device), + } + + expert_ranges: list[tuple[int, int]] = [] + start = 0 + for m in m_per_expert: + end = start + m + expert_ranges.append((start, end)) + start = end + + return { + "a": a, + "b": b, + "a_fp4": a_fp4, + "b_fp4": b_fp4, + "a_blockscale": a_blockscale, + "b_blockscale": b_blockscale, + "alphas": alphas, + "params": params, + "expert_offsets_full": expert_offsets_full, + "expert_ranges": expert_ranges, + "dtype": dtype, + } + + +def _torch_ref_group_mm(case: dict[str, Any]) -> torch.Tensor: + a = case["a"] + b = case["b"] + dtype = case["dtype"] + expert_ranges = case["expert_ranges"] + total_tokens = a.shape[0] + n = b.shape[1] + out = torch.empty((total_tokens, n), device=a.device, dtype=dtype) + for i, (start, end) in enumerate(expert_ranges): + out[start:end] = torch.matmul(a[start:end], b[i].t()) + return out + + +def _aot_cutlass_fp4_group_mm(case: dict[str, Any]) -> torch.Tensor: + a_fp4 = case["a_fp4"] + b_fp4 = case["b_fp4"] + a_blockscale = case["a_blockscale"] 
+ b_blockscale = case["b_blockscale"] + alphas = case["alphas"] + params = case["params"] + out_dtype = case["dtype"] + + out = torch.empty( + (a_fp4.shape[0], b_fp4.shape[1]), device=a_fp4.device, dtype=out_dtype + ) + torch.ops.sgl_kernel.cutlass_fp4_group_mm.default( + out, + a_fp4, + b_fp4, + a_blockscale, + b_blockscale, + alphas, + params["ab_strides"], + params["c_strides"], + params["problem_sizes"], + params["expert_offsets"], + params["blockscale_offsets"], + ) + return out + + +def _probe_legacy_aot_group_mm() -> tuple[bool, str]: + if not torch.cuda.is_available(): + return False, "CUDA is not available." + try: + import sgl_kernel # noqa: F401 + except Exception as e: + return False, f"import sgl_kernel failed: {e}" + if not hasattr(torch.ops, "sgl_kernel"): + return False, "torch.ops.sgl_kernel is not registered." + op = getattr(torch.ops.sgl_kernel, "cutlass_fp4_group_mm", None) + if op is None or not hasattr(op, "default"): + return False, "torch.ops.sgl_kernel.cutlass_fp4_group_mm.default is missing." 
+ try: + case = _prepare_case(64, 256, 128, 4, torch.bfloat16) + _aot_cutlass_fp4_group_mm(case) + torch.cuda.synchronize() + except Exception as e: + return False, f"calling AOT grouped_mm op failed: {e}" + return True, "" + + +_AOT_GROUP_MM_AVAILABLE, _AOT_GROUP_MM_REASON = _probe_legacy_aot_group_mm() + +shape_range = get_benchmark_range( + full_range=[(128, 256, 128, 4), (256, 512, 128, 8), (512, 512, 256, 8)], + ci_range=[(128, 256, 128, 4)], +) + +line_vals = ["jit"] +line_names = ["JIT NVFP4 MoE GroupMM"] +styles = [("green", "-")] +if _AOT_GROUP_MM_AVAILABLE: + line_vals.append("aot_sgl_kernel") + line_names.append("AOT NVFP4 MoE GroupMM") + styles.append(("orange", "-")) +line_vals.append("torch_ref") +line_names.append("Torch Ref") +styles.append(("blue", "-")) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["total_tokens", "n", "k", "num_experts"], + x_vals=shape_range, + x_log=False, + line_arg="provider", + line_vals=line_vals, + line_names=line_names, + styles=styles, + ylabel="us", + plot_name="nvfp4-blockwise-moe-groupmm-performance", + args={}, + ) +) +def benchmark(total_tokens, n, k, num_experts, provider): + case = _prepare_case(total_tokens, n, k, num_experts, torch.bfloat16) + + if provider == "jit": + fn = lambda: cutlass_fp4_group_mm( + case["a_fp4"], + case["b_fp4"], + case["a_blockscale"], + case["b_blockscale"], + case["alphas"], + case["dtype"], + case["params"], + ) + elif provider == "aot_sgl_kernel": + fn = lambda: _aot_cutlass_fp4_group_mm(case) + elif provider == "torch_ref": + fn = lambda: _torch_ref_group_mm(case) + else: + raise ValueError(f"Unknown provider: {provider}") + + return run_benchmark(fn) + + +if __name__ == "__main__": + if not _AOT_GROUP_MM_AVAILABLE: + print( + f"[info] legacy AOT grouped_mm baseline unavailable: {_AOT_GROUP_MM_REASON}" + ) + benchmark.run(print_data=True) diff --git a/sglang/python/sglang/jit_kernel/benchmark/bench_nvfp4_quant.py 
b/sglang/python/sglang/jit_kernel/benchmark/bench_nvfp4_quant.py new file mode 100644 index 0000000000000000000000000000000000000000..b1a7a24e69d0e101c7f73050feea322ec23e2592 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/benchmark/bench_nvfp4_quant.py @@ -0,0 +1,181 @@ +from __future__ import annotations + +import torch +import triton + +from sglang.jit_kernel.benchmark.utils import get_benchmark_range, run_benchmark +from sglang.jit_kernel.nvfp4 import scaled_fp4_quant + +FLOAT4_E2M1_MAX = 6.0 +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max +BLOCK_SIZE = 16 + +try: + from flashinfer import fp4_quantize as flashinfer_fp4_quantize +except Exception: + flashinfer_fp4_quantize = None + + +def _torch_ref_quant(input: torch.Tensor, input_global_scale: torch.Tensor): + m, n = input.shape + x = input.view(m, n // BLOCK_SIZE, BLOCK_SIZE) + vec_max = torch.max(torch.abs(x), dim=-1, keepdim=True)[0].to(torch.float32) + scale = input_global_scale * (vec_max / FLOAT4_E2M1_MAX) + scale = scale.to(torch.float8_e4m3fn).to(torch.float32) + output_scale = torch.where(scale == 0, torch.zeros_like(scale), 1.0 / scale) + + scaled_x = x.to(torch.float32) * output_scale + clipped = torch.clamp(scaled_x, -6.0, 6.0).reshape(m, n) + + rounded = clipped.clone() + rounded[(rounded >= 0.0) & (rounded <= 0.25)] = 0.0 + rounded[(rounded > 0.25) & (rounded < 0.75)] = 0.5 + rounded[(rounded >= 0.75) & (rounded <= 1.25)] = 1.0 + rounded[(rounded > 1.25) & (rounded < 1.75)] = 1.5 + rounded[(rounded >= 1.75) & (rounded <= 2.5)] = 2.0 + rounded[(rounded > 2.5) & (rounded < 3.5)] = 3.0 + rounded[(rounded >= 3.5) & (rounded <= 5.0)] = 4.0 + rounded[rounded > 5.0] = 6.0 + + # This baseline intentionally keeps work on GPU but does not pack to uint8. 
+ return rounded, scale + + +def _aot_scaled_fp4_quant(input: torch.Tensor, input_global_scale: torch.Tensor): + m, n = input.shape + output = torch.empty((m, n // 2), device=input.device, dtype=torch.uint8) + rounded_m = ((m + 128 - 1) // 128) * 128 + scale_n = n // BLOCK_SIZE + rounded_n = ((scale_n + 4 - 1) // 4) * 4 + output_scale = torch.empty( + (rounded_m, rounded_n // 4), device=input.device, dtype=torch.int32 + ) + torch.ops.sgl_kernel.scaled_fp4_quant.default( + output, input, output_scale, input_global_scale + ) + return output, output_scale.view(torch.float8_e4m3fn) + + +def _probe_legacy_aot_quant() -> tuple[bool, str]: + if not torch.cuda.is_available(): + return False, "CUDA is not available." + try: + import sgl_kernel # noqa: F401 + except Exception as e: + return False, f"import sgl_kernel failed: {e}" + if not hasattr(torch.ops, "sgl_kernel"): + return False, "torch.ops.sgl_kernel is not registered." + op = getattr(torch.ops.sgl_kernel, "scaled_fp4_quant", None) + if op is None or not hasattr(op, "default"): + return False, "torch.ops.sgl_kernel.scaled_fp4_quant.default is missing." + try: + x = torch.randn((16, 64), dtype=torch.bfloat16, device="cuda") + global_scale = ( + FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / torch.abs(x).max().to(torch.float32) + ) + _aot_scaled_fp4_quant(x, global_scale) + torch.cuda.synchronize() + except Exception as e: + return False, f"calling AOT quant op failed: {e}" + return True, "" + + +_AOT_QUANT_AVAILABLE, _AOT_QUANT_REASON = _probe_legacy_aot_quant() + + +def _probe_flashinfer_quant() -> tuple[bool, str]: + if flashinfer_fp4_quantize is None: + return False, "import flashinfer.fp4_quantize failed." + if not torch.cuda.is_available(): + return False, "CUDA is not available." 
+ try: + x = torch.randn((16, 64), dtype=torch.bfloat16, device="cuda") + global_scale = ( + FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / torch.abs(x).max().to(torch.float32) + ) + flashinfer_fp4_quantize( + x, + global_scale, + BLOCK_SIZE, # sf_vec_size + False, # use_ue8m0 + True, # is_sf_swizzled_layout + ) + torch.cuda.synchronize() + except Exception as e: + return False, f"calling flashinfer.fp4_quantize failed: {e}" + return True, "" + + +_FLASHINFER_QUANT_AVAILABLE, _FLASHINFER_QUANT_REASON = _probe_flashinfer_quant() + +shape_range = get_benchmark_range( + full_range=[(128, 2048), (512, 4096), (1024, 4096), (2048, 8192)], + ci_range=[(128, 2048)], +) + +line_vals = [] +line_names = [] +styles = [] +if _FLASHINFER_QUANT_AVAILABLE: + line_vals.append("flashinfer") + line_names.append("FlashInfer FP4 Quant") + styles.append(("purple", "-")) +line_vals.append("jit") +line_names.append("JIT NVFP4 Quant") +styles.append(("green", "-")) +if _AOT_QUANT_AVAILABLE: + line_vals.append("aot_sgl_kernel") + line_names.append("AOT NVFP4 Quant") + styles.append(("orange", "-")) +line_vals.append("torch_ref") +line_names.append("Torch Ref") +styles.append(("blue", "-")) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["m", "n"], + x_vals=shape_range, + x_log=False, + line_arg="provider", + line_vals=line_vals, + line_names=line_names, + styles=styles, + ylabel="us", + plot_name="nvfp4-quant-performance", + args={}, + ) +) +def benchmark(m, n, provider): + x = torch.randn((m, n), dtype=torch.bfloat16, device="cuda") + tensor_amax = torch.abs(x).max().to(torch.float32) + global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax + + if provider == "jit": + fn = lambda: scaled_fp4_quant(x, global_scale) + elif provider == "flashinfer": + fn = lambda: flashinfer_fp4_quantize( + x, + global_scale, + BLOCK_SIZE, # sf_vec_size + False, # use_ue8m0 + True, # is_sf_swizzled_layout + ) + elif provider == "aot_sgl_kernel": + fn = lambda: _aot_scaled_fp4_quant(x, 
global_scale) + elif provider == "torch_ref": + fn = lambda: _torch_ref_quant(x, global_scale) + else: + raise ValueError(f"Unknown provider: {provider}") + + return run_benchmark(fn) + + +if __name__ == "__main__": + if not _FLASHINFER_QUANT_AVAILABLE: + print( + f"[info] flashinfer quant baseline unavailable: {_FLASHINFER_QUANT_REASON}" + ) + if not _AOT_QUANT_AVAILABLE: + print(f"[info] legacy AOT quant baseline unavailable: {_AOT_QUANT_REASON}") + benchmark.run(print_data=True) diff --git a/sglang/python/sglang/jit_kernel/benchmark/bench_nvfp4_scaled_mm.py b/sglang/python/sglang/jit_kernel/benchmark/bench_nvfp4_scaled_mm.py new file mode 100644 index 0000000000000000000000000000000000000000..8c0bdc32203d95fccab7232ce8fff0078a20a036 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/benchmark/bench_nvfp4_scaled_mm.py @@ -0,0 +1,175 @@ +from __future__ import annotations + +import torch +import triton + +from sglang.jit_kernel.benchmark.utils import get_benchmark_range, run_benchmark +from sglang.jit_kernel.nvfp4 import cutlass_scaled_fp4_mm, scaled_fp4_quant + +FLOAT4_E2M1_MAX = 6.0 +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max +BLOCK_SIZE = 16 + +K_E2M1_TO_FLOAT = [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + 0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, +] + + +def _dequantize_to_fp16( + tensor_fp4: torch.Tensor, tensor_sf: torch.Tensor, global_scale: torch.Tensor +): + m, packed_k = tensor_fp4.shape + k = packed_k * 2 + flat = tensor_fp4.flatten() + high = (flat & 0xF0) >> 4 + low = flat & 0x0F + f_h = torch.tensor([K_E2M1_TO_FLOAT[x] for x in high], device=tensor_fp4.device) + f_l = torch.tensor([K_E2M1_TO_FLOAT[x] for x in low], device=tensor_fp4.device) + val = torch.stack((f_l, f_h), dim=-1).reshape(m, k) + + rounded_m = ((m + 128 - 1) // 128) * 128 + scale_n = k // BLOCK_SIZE + rounded_n = ((scale_n + 4 - 1) // 4) * 4 + sf = tensor_sf.view(torch.float8_e4m3fn) + tmp = torch.reshape(sf, (1, rounded_m // 128, 
rounded_n // 4, 32, 4, 4)) + tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5)) + scale = torch.reshape(tmp, (rounded_m, rounded_n))[:m, :scale_n].to(torch.float32) + scale = scale / global_scale + + return (val.view(m, scale_n, BLOCK_SIZE) * scale.unsqueeze(-1)).reshape(m, k) + + +def _aot_cutlass_scaled_fp4_mm( + a: torch.Tensor, + b: torch.Tensor, + block_scale_a: torch.Tensor, + block_scale_b: torch.Tensor, + alpha: torch.Tensor, + out_dtype: torch.dtype, +) -> torch.Tensor: + out = torch.empty((a.shape[0], b.shape[0]), dtype=out_dtype, device=a.device) + torch.ops.sgl_kernel.cutlass_scaled_fp4_mm.default( + out, a, b, block_scale_a, block_scale_b, alpha + ) + return out + + +def _probe_legacy_aot_scaled_mm() -> tuple[bool, str]: + if not torch.cuda.is_available(): + return False, "CUDA is not available." + try: + import sgl_kernel # noqa: F401 + except Exception as e: + return False, f"import sgl_kernel failed: {e}" + if not hasattr(torch.ops, "sgl_kernel"): + return False, "torch.ops.sgl_kernel is not registered." + op = getattr(torch.ops.sgl_kernel, "cutlass_scaled_fp4_mm", None) + if op is None or not hasattr(op, "default"): + return False, "torch.ops.sgl_kernel.cutlass_scaled_fp4_mm.default is missing." 
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["m", "n", "k"],
        x_vals=shape_range,
        x_log=False,
        line_arg="provider",
        line_vals=line_vals,
        line_names=line_names,
        styles=styles,
        ylabel="us",
        plot_name="nvfp4-scaled-mm-performance",
        args={},
    )
)
def benchmark(m, n, k, provider):
    """Time an ``(m, k) x (n, k)^T`` NVFP4 GEMM for one provider.

    Random bf16 operands are quantized once with ``scaled_fp4_quant``; only
    the matmul itself is handed to ``run_benchmark``.

    Args:
        m, n, k: Problem shape; operand A is ``(m, k)``, operand B is ``(n, k)``.
        provider: ``"jit"``, ``"aot_sgl_kernel"`` or ``"torch_ref"``; anything
            else raises ``ValueError`` (after the operands were prepared).

    Returns:
        Whatever ``run_benchmark`` returns for the selected callable.
    """

    def global_scale_of(dense):
        # NOTE(review): amax is taken over signed values, not magnitudes;
        # the quant benchmark uses abs().max() -- confirm this is intended.
        peak = torch.amax(dense.flatten(), dim=-1)
        return (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / peak).to(torch.float32)

    a = torch.randn((m, k), dtype=torch.bfloat16, device="cuda")
    b = torch.randn((n, k), dtype=torch.bfloat16, device="cuda")

    a_global_scale = global_scale_of(a)
    b_global_scale = global_scale_of(b)
    alpha = 1.0 / (a_global_scale * b_global_scale)

    a_fp4, a_sf = scaled_fp4_quant(a, a_global_scale)
    b_fp4, b_sf = scaled_fp4_quant(b, b_global_scale)

    if provider == "torch_ref":
        # Dequantize outside the timed region; only the dense matmul is timed.
        a_ref = _dequantize_to_fp16(a_fp4, a_sf, a_global_scale)
        b_ref = _dequantize_to_fp16(b_fp4, b_sf, b_global_scale)

        def fn():
            return torch.matmul(a_ref, b_ref.t())

    elif provider in ("jit", "aot_sgl_kernel"):
        if provider == "jit":
            gemm = cutlass_scaled_fp4_mm
        else:
            gemm = _aot_cutlass_scaled_fp4_mm

        def fn():
            return gemm(a_fp4, b_fp4, a_sf, b_sf, alpha, torch.bfloat16)

    else:
        raise ValueError(f"Unknown provider: {provider}")

    return run_benchmark(fn)
return ops.scaled_fp8_quant(input, scale) + + +def sglang_scaled_fp8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + fp8_type_: torch.dtype = torch.float8_e4m3fn + output = torch.empty_like(input, device=input.device, dtype=fp8_type_) + is_static = True + if scale is None: + scale = torch.zeros(1, device=input.device, dtype=torch.float32) + is_static = False + per_tensor_quant_fp8(input, output, scale, is_static) + + return output, scale + + +def calculate_diff(batch_size: int, seq_len: int): + device = torch.device("cuda") + x = torch.rand((batch_size, seq_len), dtype=torch.bfloat16, device=device) + + if not VLLM_AVAILABLE: + print("vLLM not available, skipping comparison") + return + + vllm_out, vllm_scale = vllm_scaled_fp8_quant(x) + sglang_out, sglang_scale = sglang_scaled_fp8_quant(x) + + vllm_out = vllm_out.to(torch.float32) + sglang_out = sglang_out.to(torch.float32) + + triton.testing.assert_close(vllm_out, sglang_out, rtol=1e-3, atol=1e-3) + triton.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-3) + + +# Benchmark configuration +element_range = get_benchmark_range( + full_range=[2**n for n in range(10, 20)], + ci_range=[16384], +) + +if VLLM_AVAILABLE: + line_vals = ["vllm", "sglang"] + line_names = ["VLLM", "SGL Kernel"] + styles = [("blue", "-"), ("green", "-")] +else: + line_vals = ["sglang"] + line_names = ["SGL Kernel"] + styles = [("green", "-")] + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["element_count"], + x_vals=element_range, + line_arg="provider", + line_vals=line_vals, + line_names=line_names, + styles=styles, + ylabel="us", + plot_name="per-tensor-quant-fp8-performance", + args={}, + ) +) +def benchmark(element_count, provider): + dtype = torch.float16 + device = torch.device("cuda") + + x = torch.randn(element_count, 4096, device=device, dtype=dtype) + + if provider == "vllm": + fn = lambda: vllm_scaled_fp8_quant(x.clone()) + 
elif provider == "sglang": + fn = lambda: sglang_scaled_fp8_quant(x.clone()) + else: + raise ValueError(f"Unknown provider: {provider}") + + return run_benchmark(fn) + + +if __name__ == "__main__": + calculate_diff(batch_size=4, seq_len=4096) + benchmark.run(print_data=True) diff --git a/sglang/python/sglang/jit_kernel/benchmark/bench_per_token_group_quant_8bit.py b/sglang/python/sglang/jit_kernel/benchmark/bench_per_token_group_quant_8bit.py new file mode 100644 index 0000000000000000000000000000000000000000..7edcf82c2b6f0631d02735eea300893388a9af3c --- /dev/null +++ b/sglang/python/sglang/jit_kernel/benchmark/bench_per_token_group_quant_8bit.py @@ -0,0 +1,290 @@ +import itertools +import os +from typing import Any, Dict, List + +import torch +import triton +from sgl_kernel.test_utils import create_per_token_group_quant_test_data + +from sglang.jit_kernel.benchmark.utils import ( + get_benchmark_range, +) +from sglang.jit_kernel.per_token_group_quant_8bit import ( + per_token_group_quant_8bit as sglang_per_token_group_quant_8bit, +) +from sglang.srt.layers.quantization.fp8_kernel import ( + create_per_token_group_quant_fp8_output_scale, +) +from sglang.srt.layers.quantization.fp8_kernel import ( + per_token_group_quant_8bit as triton_per_token_group_quant_8bit, +) +from sglang.srt.utils import is_hip +from sglang.srt.utils.bench_utils import bench_kineto + +# CI environment detection +IS_CI = ( + os.getenv("CI", "false").lower() == "true" + or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" +) + +_is_hip = is_hip() +fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn + +NUM_TESTS = 300 if IS_CI else 30 + +GROUP_SIZE_RANGE = [128] +DST_DTYPE_RANGE = [fp8_type_] + +# ---- GEMM-like branch (num_ranks=None) ---- +NUM_TOKENS_RANGE_GEMM = get_benchmark_range( + full_range=[1, 4, 16, 64, 256, 768, 2048, 8192, 16384], + ci_range=[768], +) +HIDDEN_DIM_RANGE_GEMM = [1536, 7168, 16384] +NUM_RANKS_RANGE_GEMM = [None] + + +FLAGS_GEMM_FULL: 
List[Dict[str, Any]] = [ + dict( + column_major_scales=False, + scale_tma_aligned=False, + scale_ue8m0=False, + fuse_silu_and_mul=False, + masked_layout_mode=None, + ), + dict( + column_major_scales=True, + scale_tma_aligned=False, + scale_ue8m0=False, + fuse_silu_and_mul=False, + masked_layout_mode=None, + ), + dict( + column_major_scales=True, + scale_tma_aligned=True, + scale_ue8m0=False, + fuse_silu_and_mul=False, + masked_layout_mode=None, + ), + dict( + column_major_scales=True, + scale_tma_aligned=True, + scale_ue8m0=True, + fuse_silu_and_mul=False, + masked_layout_mode=None, + ), +] +FLAGS_GEMM_CI: List[Dict[str, Any]] = [ + dict( + column_major_scales=True, + scale_tma_aligned=True, + scale_ue8m0=True, + fuse_silu_and_mul=False, + masked_layout_mode=None, + ), +] +FLAGS_RANGE_GEMM = get_benchmark_range( + full_range=FLAGS_GEMM_FULL, ci_range=FLAGS_GEMM_CI +) + +CONFIGS_GEMM = list( + itertools.product( + NUM_TOKENS_RANGE_GEMM, + HIDDEN_DIM_RANGE_GEMM, + GROUP_SIZE_RANGE, + NUM_RANKS_RANGE_GEMM, + DST_DTYPE_RANGE, + FLAGS_RANGE_GEMM, + ) +) + +# ---- MoE-like / multi-rank branch (hidden_dim=2048, num_ranks in {8,16,32,48}) ---- +NUM_TOKENS_RANGE_MOE = get_benchmark_range( + full_range=[1 * 8, 4 * 8, 64 * 8, 256 * 8, 768 * 8], + ci_range=[768 * 8], +) +HIDDEN_DIM_RANGE_MOE = [2048] +NUM_RANKS_RANGE_MOE = get_benchmark_range( + full_range=[8, 16, 32, 48], + ci_range=[48], +) + +FLAGS_MOE: List[Dict[str, Any]] = [ + dict( + column_major_scales=True, + scale_tma_aligned=True, + scale_ue8m0=True, + fuse_silu_and_mul=True, + masked_layout_mode=None, + ), + dict( + column_major_scales=True, + scale_tma_aligned=True, + scale_ue8m0=True, + fuse_silu_and_mul=True, + masked_layout_mode="balanced", + ), + dict( + column_major_scales=True, + scale_tma_aligned=True, + scale_ue8m0=True, + fuse_silu_and_mul=True, + masked_layout_mode="imbalanced", + ), + dict( + column_major_scales=True, + scale_tma_aligned=True, + scale_ue8m0=True, + fuse_silu_and_mul=True, + 
masked_layout_mode="extreme", + ), +] +FLAGS_RANGE_MOE = get_benchmark_range(full_range=FLAGS_MOE, ci_range=FLAGS_MOE) + +CONFIGS_MOE = list( + itertools.product( + NUM_TOKENS_RANGE_MOE, + HIDDEN_DIM_RANGE_MOE, + GROUP_SIZE_RANGE, + NUM_RANKS_RANGE_MOE, + DST_DTYPE_RANGE, + FLAGS_RANGE_MOE, + ) +) + +# ---- Final configs ---- +CONFIGS = CONFIGS_GEMM + CONFIGS_MOE + +LINE_VALS = ["triton", "sglang"] +LINE_NAMES = ["Triton (Inaccurate)", "SGL Kernel"] +STYLES = [("blue", "-"), ("green", "-")] + + +def _flatten_to_2d(t: torch.Tensor) -> torch.Tensor: + """Reshape a tensor with 3+ dims to 2D by merging all leading dims.""" + if t.ndim <= 2: + return t + return t.reshape(-1, t.shape[-1]) + + +def _make_sglang_bench_fn( + x: torch.Tensor, + group_size: int, + dst_dtype: torch.dtype, + flags: dict, +): + """ + Adapter that pre-allocates output tensors and returns a zero-arg callable + matching the JIT kernel's signature. + + The JIT kernel does not support fuse_silu_and_mul, so when enabled we + pre-compute silu+mul on the input. bench_kineto only times the kernel + matching the given name, so the pre-processing is not included. + + The JIT kernel expects 2D tensors, so any higher-dimensional inputs + (e.g. from masked_layout_mode) are flattened to 2D. 
+ """ + fuse_silu_and_mul = flags.get("fuse_silu_and_mul", False) + column_major_scales = flags.get("column_major_scales", False) + scale_tma_aligned = flags.get("scale_tma_aligned", False) + scale_ue8m0 = flags.get("scale_ue8m0", False) + + # JIT kernel does not support fuse_silu_and_mul; pre-compute it + if fuse_silu_and_mul: + half = x.shape[-1] // 2 + x_input = torch.nn.functional.silu(x[..., :half]) * x[..., half:] + else: + x_input = x + + # JIT kernel expects 2D (num_tokens, hidden_dim); flatten if needed + x_input = _flatten_to_2d(x_input.contiguous()) + + out_shape = x_input.shape + output_q = torch.empty(out_shape, device=x.device, dtype=dst_dtype) + + fp8_max = torch.finfo(dst_dtype).max + fp8_min = -fp8_max + + output_s = create_per_token_group_quant_fp8_output_scale( + x_shape=out_shape, + device=x.device, + group_size=group_size, + column_major_scales=column_major_scales, + scale_tma_aligned=scale_tma_aligned, + scale_ue8m0=scale_ue8m0, + ) + + def _run(): + sglang_per_token_group_quant_8bit( + input=x_input, + output_q=output_q, + output_s=output_s, + group_size=group_size, + eps=1e-10, + fp8_min=fp8_min, + fp8_max=fp8_max, + scale_ue8m0=scale_ue8m0, + ) + + return _run + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=[ + "num_tokens", + "hidden_dim", + "group_size", + "num_ranks", + "dst_dtype", + "flags", + ], + x_vals=CONFIGS, + line_arg="provider", + line_vals=LINE_VALS, + # Triton has multi kernels and we only report the time for the core one + line_names=LINE_NAMES, + styles=STYLES, + ylabel="us", + plot_name="per-token-group-quant-8bit-performance", + args={}, + ) +) +def benchmark( + num_tokens, hidden_dim, group_size, num_ranks, dst_dtype, flags, provider +): + print( + f"Testing: {num_tokens=} {hidden_dim=} {group_size=} {num_ranks=} {dst_dtype=} {flags=} {provider=}" + ) + + x, masked_m = create_per_token_group_quant_test_data( + num_tokens=num_tokens, hidden_dim=hidden_dim, num_ranks=num_ranks, flags=flags + ) + + 
if provider == "triton": + fn = triton_per_token_group_quant_8bit + kernel_names = "_per_token_group_quant_8bit|_silu_and_mul_post_quant_kernel" + bench_fn = lambda: fn( + x=x, + masked_m=masked_m, + group_size=group_size, + dst_dtype=dst_dtype, + **{k: v for k, v in flags.items() if k not in ["masked_layout_mode"]}, + ) + elif provider == "sglang": + kernel_names = "per_token_group_quant_8bit_kernel" + bench_fn = _make_sglang_bench_fn( + x=x, + group_size=group_size, + dst_dtype=dst_dtype, + flags=flags, + ) + else: + raise ValueError(f"Unknown provider: {provider}") + + time_s = bench_kineto(bench_fn, kernel_names=kernel_names, num_tests=NUM_TESTS) + return time_s * 1e6 + + +if __name__ == "__main__": + benchmark.run(print_data=True) diff --git a/sglang/python/sglang/jit_kernel/benchmark/bench_qknorm.py b/sglang/python/sglang/jit_kernel/benchmark/bench_qknorm.py new file mode 100644 index 0000000000000000000000000000000000000000..76e217f0938d206342edeff374e0b9447a6fb901 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/benchmark/bench_qknorm.py @@ -0,0 +1,137 @@ +import itertools + +import torch +import triton +import triton.testing +from sgl_kernel import rmsnorm + +from sglang.jit_kernel.benchmark.utils import ( + DEFAULT_DEVICE, + DEFAULT_DTYPE, + get_benchmark_range, + run_benchmark, +) +from sglang.jit_kernel.norm import fused_inplace_qknorm +from sglang.srt.utils import get_current_device_stream_fast + +alt_stream = torch.cuda.Stream() + + +def sglang_aot_qknorm( + q: torch.Tensor, + k: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, +) -> None: + + head_dim = q.shape[-1] + q = q.view(-1, head_dim) + k = k.view(-1, head_dim) + + current_stream = get_current_device_stream_fast() + alt_stream.wait_stream(current_stream) + rmsnorm(q, q_weight, out=q) + with torch.cuda.stream(alt_stream): + rmsnorm(k, k_weight, out=k) + current_stream.wait_stream(alt_stream) + + +def sglang_jit_qknorm( + q: torch.Tensor, + k: torch.Tensor, + q_weight: 
torch.Tensor, + k_weight: torch.Tensor, +) -> None: + + fused_inplace_qknorm(q, k, q_weight, k_weight) + + +def flashinfer_qknorm( + q: torch.Tensor, + k: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, +) -> None: + from flashinfer import rmsnorm + + rmsnorm(q, q_weight, out=q) + rmsnorm(k, k_weight, out=k) + + +@torch.compile() +def torch_impl_qknorm( + q: torch.Tensor, + k: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + eps: float = 1e-6, +) -> None: + q_mean = q.float().pow(2).mean(dim=-1, keepdim=True) + k_mean = k.float().pow(2).mean(dim=-1, keepdim=True) + q_norm = (q_mean + eps).rsqrt() + k_norm = (k_mean + eps).rsqrt() + q.copy_(q.float() * q_norm * q_weight.float()) + k.copy_(k.float() * k_norm * k_weight.float()) + + +BS_RANGE = get_benchmark_range( + full_range=[2**n for n in range(0, 14)], + ci_range=[16], +) +GQA_RANGE = get_benchmark_range( + full_range=[4, 8], + ci_range=[4], +) +KV_HEAD_RANGE = get_benchmark_range( + full_range=[1, 2, 4, 8], + ci_range=[1], +) +HEAD_DIM_RANGE = get_benchmark_range( + full_range=[128, 256, 512, 1024], + ci_range=[128], +) + +LINE_VALS = ["aot", "jit", "fi", "torch"] +LINE_NAMES = ["SGL AOT Kernel", "SGL JIT Kernel", "FlashInfer", "PyTorch"] +STYLES = [("orange", "-"), ("blue", "--"), ("green", "-."), ("red", ":")] + +configs = list(itertools.product(HEAD_DIM_RANGE, GQA_RANGE, KV_HEAD_RANGE, BS_RANGE)) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["head_dim", "GQA", "num_kv_heads", "batch_size"], + x_vals=configs, + line_arg="provider", + line_vals=LINE_VALS, + line_names=LINE_NAMES, + styles=STYLES, + ylabel="us", + plot_name="qknorm-performance", + args={}, + ) +) +def benchmark( + head_dim: int, GQA: int, num_kv_heads: int, batch_size: int, provider: str +): + num_qo_heads = GQA * num_kv_heads + q = torch.randn( + (batch_size, num_qo_heads, head_dim), dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE + ) + k = torch.randn( + (batch_size, num_kv_heads, 
head_dim), dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE + ) + q_weight = torch.randn(head_dim, dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE) + k_weight = torch.randn(head_dim, dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE) + FN_MAP = { + "aot": sglang_aot_qknorm, + "jit": sglang_jit_qknorm, + "fi": flashinfer_qknorm, + "torch": torch_impl_qknorm, + } + fn = lambda: FN_MAP[provider](q, k, q_weight, k_weight) + return run_benchmark(fn) + + +if __name__ == "__main__": + benchmark.run(print_data=True) diff --git a/sglang/python/sglang/jit_kernel/benchmark/bench_qknorm_across_heads.py b/sglang/python/sglang/jit_kernel/benchmark/bench_qknorm_across_heads.py new file mode 100644 index 0000000000000000000000000000000000000000..64d6bd921f724b01f36c0bdbc1b2bacb749dd9c9 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/benchmark/bench_qknorm_across_heads.py @@ -0,0 +1,121 @@ +import itertools +from typing import Tuple + +import torch +import triton +import triton.testing +from sgl_kernel import rmsnorm + +from sglang.jit_kernel.benchmark.utils import is_in_ci +from sglang.jit_kernel.norm import fused_inplace_qknorm_across_heads +from sglang.srt.utils import get_current_device_stream_fast + +IS_CI = is_in_ci() + +alt_stream = torch.cuda.Stream() + + +def sglang_jit_qknorm_across_heads( + q: torch.Tensor, + k: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, +) -> None: + + fused_inplace_qknorm_across_heads(q, k, q_weight, k_weight) + + +def sglang_aot_qknorm_across_heads( + q: torch.Tensor, + k: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, +) -> None: + + current_stream = get_current_device_stream_fast() + alt_stream.wait_stream(current_stream) + rmsnorm(q, q_weight, out=q) + with torch.cuda.stream(alt_stream): + rmsnorm(k, k_weight, out=k) + current_stream.wait_stream(alt_stream) + + +def flashinfer_qknorm_across_heads( + q: torch.Tensor, + k: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, +) -> None: + from flashinfer 
import rmsnorm + + rmsnorm(q, q_weight, out=q) + rmsnorm(k, k_weight, out=k) + + +@torch.compile() +def torch_impl_qknorm_across_heads( + q: torch.Tensor, + k: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + eps: float = 1e-6, +) -> None: + q_mean = q.float().pow(2).mean(dim=-1, keepdim=True) + k_mean = k.float().pow(2).mean(dim=-1, keepdim=True) + q_norm = (q_mean + eps).rsqrt() + k_norm = (k_mean + eps).rsqrt() + q.copy_(q.float() * q_norm * q_weight.float()) + k.copy_(k.float() * k_norm * k_weight.float()) + + +DTYPE = torch.bfloat16 +DEVICE = "cuda" + +if IS_CI: + BS_RANGE = [16] + HIDDEN_DIM_RANGE = [1024] +else: + BS_RANGE = [2**n for n in range(0, 14)] + HIDDEN_DIM_RANGE = [512, 1024, 2048, 4096, 8192] + +LINE_VALS = ["jit", "aot", "fi", "torch"] +LINE_NAMES = ["SGL JIT Kernel", "SGL AOT Kernel", "FlashInfer", "PyTorch"] +STYLES = [("blue", "-"), ("orange", "--"), ("green", "-."), ("red", ":")] + +configs = list(itertools.product(BS_RANGE, HIDDEN_DIM_RANGE)) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size", "hidden_dim"], + x_vals=configs, + line_arg="provider", + line_vals=LINE_VALS, + line_names=LINE_NAMES, + styles=STYLES, + ylabel="us", + plot_name="qknorm-across-heads-performance", + args={}, + ) +) +def benchmark( + batch_size: int, hidden_dim: int, provider: str +) -> Tuple[float, float, float]: + q = torch.randn((batch_size, hidden_dim), dtype=DTYPE, device=DEVICE) + k = torch.randn((batch_size, hidden_dim), dtype=DTYPE, device=DEVICE) + q_weight = torch.randn(hidden_dim, dtype=DTYPE, device=DEVICE) + k_weight = torch.randn(hidden_dim, dtype=DTYPE, device=DEVICE) + FN_MAP = { + "jit": sglang_jit_qknorm_across_heads, + "aot": sglang_aot_qknorm_across_heads, + "fi": flashinfer_qknorm_across_heads, + "torch": torch_impl_qknorm_across_heads, + } + fn = lambda: FN_MAP[provider](q, k, q_weight, k_weight) + quantiles = [0.5, 0.2, 0.8] + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, 
quantiles=quantiles) # type: ignore + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +if __name__ == "__main__": + benchmark.run(print_data=True) diff --git a/sglang/python/sglang/jit_kernel/benchmark/bench_renorm.py b/sglang/python/sglang/jit_kernel/benchmark/bench_renorm.py new file mode 100644 index 0000000000000000000000000000000000000000..d8b67fb839bab71b0e428720cc7da477e50701b4 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/benchmark/bench_renorm.py @@ -0,0 +1,321 @@ +import itertools + +import sgl_kernel +import torch +import triton +import triton.testing + +from sglang.jit_kernel.benchmark.utils import is_in_ci + + +def torch_top_k_renorm_probs(probs, top_k): + """Vectorized PyTorch implementation of top-k renormalization.""" + batch_size, vocab_size = probs.shape + + # Handle scalar or tensor k + if isinstance(top_k, int): + k_val = min(max(top_k, 1), vocab_size) + # Get top-k indices for all batches at once + _, topk_indices = torch.topk(probs, k_val, dim=1, largest=True) + + # Create mask: batch_size x vocab_size + mask = torch.zeros_like(probs) + mask.scatter_(1, topk_indices, 1.0) + + # Vectorized renormalization + masked_probs = probs * mask + renorm_probs = masked_probs / (masked_probs.sum(dim=1, keepdim=True) + 1e-10) + return renorm_probs + else: + # Variable k per batch - need to handle separately + renorm_probs = torch.zeros_like(probs) + for i in range(batch_size): + k_val = min(max(top_k[i].item(), 1), vocab_size) + _, topk_indices = torch.topk(probs[i], k_val, largest=True) + mask = torch.zeros_like(probs[i]) + mask[topk_indices] = 1.0 + masked_probs = probs[i] * mask + renorm_probs[i] = masked_probs / (masked_probs.sum() + 1e-10) + return renorm_probs + + +def torch_top_p_renorm_probs(probs, top_p, eps=1e-5): + """Vectorized PyTorch implementation of top-p renormalization.""" + batch_size, vocab_size = probs.shape + + # Handle scalar or tensor p + if isinstance(top_p, float): + p_val = top_p + # Vectorized implementation for uniform 
top_p + # Sort probs in descending order + sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1) + cumsum_probs = torch.cumsum(sorted_probs, dim=1) + + # Find cutoff: where cumsum exceeds top_p + cutoff_mask = cumsum_probs <= p_val + # Keep at least one token (the highest prob) + cutoff_mask[:, 0] = True + + # Create mask in original order + mask = torch.zeros_like(probs) + mask.scatter_(1, sorted_indices, cutoff_mask.float()) + + # Vectorized renormalization + masked_probs = probs * mask + renorm_probs = masked_probs / (masked_probs.sum(dim=1, keepdim=True) + eps) + return renorm_probs + else: + # Variable p per batch - need to handle separately + renorm_probs = torch.zeros_like(probs) + for i in range(batch_size): + p_val = top_p[i].item() + sorted_prob, indices = torch.sort(probs[i], descending=False) + cdf = torch.cumsum(sorted_prob, dim=-1) + mask = torch.zeros(vocab_size, dtype=torch.float32, device=probs.device) + mask.scatter_(0, indices, (cdf >= (1 - p_val) - eps).float()) + masked_probs = probs[i] * mask + renorm_probs[i] = masked_probs / (masked_probs.sum() + eps) + return renorm_probs + + +def torch_top_k_mask_logits(logits, top_k): + """Vectorized PyTorch implementation of top-k logits masking.""" + batch_size, vocab_size = logits.shape + + # Handle scalar or tensor k + if isinstance(top_k, int): + k_val = min(max(top_k, 1), vocab_size) + # Get top-k indices for all batches at once + _, topk_indices = torch.topk(logits, k_val, dim=1, largest=True) + + # Create masked logits: start with -inf everywhere + masked_logits = torch.full_like(logits, float("-inf")) + # Scatter the top-k values back + masked_logits.scatter_(1, topk_indices, logits.gather(1, topk_indices)) + else: + # Variable k per batch - need to handle separately + masked_logits = torch.full_like(logits, float("-inf")) + for i in range(batch_size): + k_val = min(max(top_k[i].item(), 1), vocab_size) + _, topk_indices = torch.topk(logits[i], k_val, largest=True) + 
masked_logits[i, topk_indices] = logits[i, topk_indices] + + return masked_logits + + +def calculate_diff_top_k_renorm(batch_size, vocab_size, k): + """Compare Torch reference and SGLang kernel for top-k renorm correctness.""" + torch.manual_seed(42) + device = torch.device("cuda") + + pre_norm_prob = torch.rand(batch_size, vocab_size, device=device) + probs = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + + top_k_tensor = torch.full((batch_size,), k, device=device, dtype=torch.int32) + + torch_output = torch_top_k_renorm_probs(probs, top_k_tensor) + sglang_output = sgl_kernel.top_k_renorm_prob(probs, top_k_tensor) + + torch.testing.assert_close(torch_output, sglang_output, rtol=1e-3, atol=1e-3) + + +def calculate_diff_top_p_renorm(batch_size, vocab_size, p): + """Compare Torch reference and SGLang kernel for top-p renorm correctness.""" + torch.manual_seed(42) + device = torch.device("cuda") + + pre_norm_prob = torch.rand(batch_size, vocab_size, device=device) + probs = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + + top_p_tensor = torch.full((batch_size,), p, device=device, dtype=torch.float32) + + torch_output = torch_top_p_renorm_probs(probs, top_p_tensor) + sglang_output = sgl_kernel.top_p_renorm_prob(probs, top_p_tensor) + + torch.testing.assert_close(torch_output, sglang_output, rtol=1e-3, atol=1e-3) + + +def calculate_diff_top_k_mask(batch_size, vocab_size, k): + """Compare Torch reference and SGLang kernel for top-k mask correctness.""" + torch.manual_seed(42) + device = torch.device("cuda") + + logits = torch.randn(batch_size, vocab_size, device=device) * 5 + top_k_tensor = torch.full((batch_size,), k, device=device, dtype=torch.int32) + + torch_output = torch_top_k_mask_logits(logits, top_k_tensor) + sglang_output = sgl_kernel.top_k_mask_logits(logits, top_k_tensor) + + torch.testing.assert_close(torch_output, sglang_output, rtol=1e-3, atol=1e-3) + + +# Parameter space - simplified for CI +if is_in_ci(): + batch_size_range = [16] + 
vocab_size_range = [111] + k_range = [10] + p_range = [0.5] +else: + batch_size_range = [16, 64, 128] + vocab_size_range = [111, 32000, 128256] + k_range = [10, 100, 500] + p_range = [0.1, 0.5, 0.9] + +configs_k = list(itertools.product(batch_size_range, vocab_size_range, k_range)) +configs_p = list(itertools.product(batch_size_range, vocab_size_range, p_range)) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size", "vocab_size", "k"], + x_vals=configs_k, + line_arg="provider", + line_vals=["torch", "sglang"], + line_names=["Torch Reference", "SGL Kernel (FlashInfer)"], + styles=[("red", "-"), ("green", "-")], + ylabel="us", + plot_name="top-k-renorm-probs-performance", + args={}, + ) +) +def benchmark_top_k_renorm(batch_size, vocab_size, k, provider): + # Skip invalid configurations + if k >= vocab_size: + return float("nan"), float("nan"), float("nan") + + torch.manual_seed(42) + device = torch.device("cuda") + + pre_norm_prob = torch.rand(batch_size, vocab_size, device=device) + probs = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + top_k_tensor = torch.full((batch_size,), k, device=device, dtype=torch.int32) + + if provider == "torch": + fn = lambda: torch_top_k_renorm_probs(probs.clone(), top_k_tensor) + elif provider == "sglang": + fn = lambda: sgl_kernel.top_k_renorm_prob(probs.clone(), top_k_tensor) + + ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8]) + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size", "vocab_size", "p"], + x_vals=configs_p, + line_arg="provider", + line_vals=["torch", "sglang"], + line_names=["Torch Reference", "SGL Kernel (FlashInfer)"], + styles=[("red", "-"), ("blue", "-")], + ylabel="us", + plot_name="top-p-renorm-probs-performance", + args={}, + ) +) +def benchmark_top_p_renorm(batch_size, vocab_size, p, provider): + torch.manual_seed(42) + device = torch.device("cuda") + + 
pre_norm_prob = torch.rand(batch_size, vocab_size, device=device) + probs = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + top_p_tensor = torch.full((batch_size,), p, device=device, dtype=torch.float32) + + if provider == "torch": + fn = lambda: torch_top_p_renorm_probs(probs.clone(), top_p_tensor) + elif provider == "sglang": + fn = lambda: sgl_kernel.top_p_renorm_prob(probs.clone(), top_p_tensor) + + ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8]) + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size", "vocab_size", "k"], + x_vals=configs_k, + line_arg="provider", + line_vals=["torch", "sglang"], + line_names=["Torch Reference", "SGL Kernel (FlashInfer)"], + styles=[("red", "-"), ("orange", "-")], + ylabel="us", + plot_name="top-k-mask-logits-performance", + args={}, + ) +) +def benchmark_top_k_mask(batch_size, vocab_size, k, provider): + # Skip invalid configurations + if k >= vocab_size: + return float("nan"), float("nan"), float("nan") + + torch.manual_seed(42) + device = torch.device("cuda") + + logits = torch.randn(batch_size, vocab_size, device=device) * 5 + top_k_tensor = torch.full((batch_size,), k, device=device, dtype=torch.int32) + + if provider == "torch": + fn = lambda: torch_top_k_mask_logits(logits.clone(), top_k_tensor) + elif provider == "sglang": + fn = lambda: sgl_kernel.top_k_mask_logits(logits.clone(), top_k_tensor) + + ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8]) + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +if __name__ == "__main__": + print("=" * 60) + print("Running correctness checks...") + print("=" * 60) + + # Correctness checks - simplified for CI + if is_in_ci(): + test_configs_k = [configs_k[0]] if configs_k else [(16, 111, 10)] + test_configs_p = [configs_p[0]] if configs_p else [(16, 111, 0.5)] + else: + test_configs_k = configs_k[:3] # Test first 3 configs + test_configs_p = 
configs_p[:3] + + print("\n1. Testing top_k_renorm_probs...") + for cfg in test_configs_k: + batch_size, vocab_size, k = cfg + if k < vocab_size: # Skip invalid configs + calculate_diff_top_k_renorm(batch_size, vocab_size, k) + print( + f" ✓ Passed: batch_size={batch_size}, vocab_size={vocab_size}, k={k}" + ) + + print("\n2. Testing top_p_renorm_probs...") + for cfg in test_configs_p: + calculate_diff_top_p_renorm(*cfg) + batch_size, vocab_size, p = cfg + print(f" ✓ Passed: batch_size={batch_size}, vocab_size={vocab_size}, p={p}") + + print("\n3. Testing top_k_mask_logits...") + for cfg in test_configs_k: + batch_size, vocab_size, k = cfg + if k < vocab_size: # Skip invalid configs + calculate_diff_top_k_mask(batch_size, vocab_size, k) + print( + f" ✓ Passed: batch_size={batch_size}, vocab_size={vocab_size}, k={k}" + ) + + print("\n" + "=" * 60) + print("All correctness checks passed!") + print("=" * 60) + + print("\n" + "=" * 60) + print("Starting performance benchmarks...") + print("=" * 60) + + print("\n1. Benchmarking top_k_renorm_probs...") + benchmark_top_k_renorm.run(print_data=True) + + print("\n2. Benchmarking top_p_renorm_probs...") + benchmark_top_p_renorm.run(print_data=True) + + print("\n3. 
Benchmarking top_k_mask_logits...") + benchmark_top_k_mask.run(print_data=True) + + print("\n" + "=" * 60) + print("Benchmarking complete!") + print("=" * 60) diff --git a/sglang/python/sglang/jit_kernel/benchmark/bench_rmsnorm.py b/sglang/python/sglang/jit_kernel/benchmark/bench_rmsnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..74b929da60e68c9c65959b60eeb538d9d75fc0e3 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/benchmark/bench_rmsnorm.py @@ -0,0 +1,95 @@ +import itertools + +import torch +import triton +import triton.testing +from flashinfer import rmsnorm as fi_rmsnorm +from sgl_kernel import rmsnorm + +from sglang.jit_kernel.benchmark.utils import ( + DEFAULT_DEVICE, + DEFAULT_DTYPE, + get_benchmark_range, + run_benchmark, +) +from sglang.jit_kernel.norm import rmsnorm as jit_rmsnorm + + +def sglang_aot_rmsnorm( + input: torch.Tensor, + weight: torch.Tensor, +) -> None: + rmsnorm(input, weight, out=input) + + +def sglang_jit_rmsnorm( + input: torch.Tensor, + weight: torch.Tensor, +) -> None: + jit_rmsnorm(input, weight, output=input) + + +def flashinfer_rmsnorm( + input: torch.Tensor, + weight: torch.Tensor, +) -> None: + fi_rmsnorm(input, weight, out=input) + + +@torch.compile() +def torch_impl_rmsnorm( + input: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, +) -> None: + mean = input.float().pow(2).mean(dim=-1, keepdim=True) + norm = (mean + eps).rsqrt() + input.copy_(input.float() * norm * weight.float()) + + +BS_LIST = get_benchmark_range( + full_range=[2**n for n in range(0, 14)], + ci_range=[16], +) +HIDDEN_SIZE_LIST = get_benchmark_range( + full_range=[1536, 3072, 4096, 5120, 8192], + ci_range=[512, 2048], +) + +LINE_VALS = ["aot", "jit", "fi", "torch"] +LINE_NAMES = ["SGL AOT Kernel", "SGL JIT Kernel", "FlashInfer", "PyTorch"] +STYLES = [("orange", "-"), ("blue", "--"), ("green", "-."), ("red", ":")] + +configs = list(itertools.product(HIDDEN_SIZE_LIST, BS_LIST)) + + +@triton.testing.perf_report( + 
triton.testing.Benchmark( + x_names=["hidden_size", "batch_size"], + x_vals=configs, + line_arg="provider", + line_vals=LINE_VALS, + line_names=LINE_NAMES, + styles=STYLES, + ylabel="us", + plot_name="rmsnorm-performance", + args={}, + ) +) +def benchmark(hidden_size: int, batch_size: int, provider: str): + input = torch.randn( + (batch_size, hidden_size), dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE + ) + weight = torch.randn(hidden_size, dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE) + FN_MAP = { + "aot": sglang_aot_rmsnorm, + "jit": sglang_jit_rmsnorm, + "fi": flashinfer_rmsnorm, + "torch": torch_impl_rmsnorm, + } + fn = lambda: FN_MAP[provider](input.clone(), weight) + return run_benchmark(fn) + + +if __name__ == "__main__": + benchmark.run(print_data=True) diff --git a/sglang/python/sglang/jit_kernel/benchmark/bench_rope.py b/sglang/python/sglang/jit_kernel/benchmark/bench_rope.py new file mode 100644 index 0000000000000000000000000000000000000000..8d5cdae27ae4b58061b1727d4d2a7b84af18bdbb --- /dev/null +++ b/sglang/python/sglang/jit_kernel/benchmark/bench_rope.py @@ -0,0 +1,350 @@ +import itertools + +import torch +import triton +import triton.testing + +from sglang.jit_kernel.benchmark.utils import ( + DEFAULT_DEVICE, + DEFAULT_DTYPE, + get_benchmark_range, + run_benchmark, +) + +MAX_SEQ_LEN = 131072 +ROPE_BASE = 10000.0 +ROPE_DIM = 128 +CACHE_SIZE = 1024 * 1024 + + +def create_cos_sin_cache( + rotary_dim: int = ROPE_DIM, + max_position: int = MAX_SEQ_LEN, + base: float = ROPE_BASE, +) -> torch.Tensor: + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, rotary_dim, 2, dtype=torch.float32, device=DEFAULT_DEVICE) + / rotary_dim + ) + ) + t = torch.arange(max_position, dtype=torch.float32, device=DEFAULT_DEVICE) + freqs = torch.einsum("i,j->ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + return torch.cat((cos, sin), dim=-1) + + +# Pre-build the cache once +COS_SIN_CACHE = create_cos_sin_cache() + + +# 
--------------------------------------------------------------------------- +# RoPE-only provider implementations +# --------------------------------------------------------------------------- + + +def flashinfer_rope( + q: torch.Tensor, + k: torch.Tensor, + positions: torch.Tensor, + is_neox: bool, +) -> None: + from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace + + head_size = q.shape[-1] + apply_rope_with_cos_sin_cache_inplace( + positions=positions, + query=q.view(q.shape[0], -1), + key=k.view(k.shape[0], -1), + head_size=head_size, + cos_sin_cache=COS_SIN_CACHE, + is_neox=is_neox, + ) + + +def sglang_rope_v0( + q: torch.Tensor, + k: torch.Tensor, + positions: torch.Tensor, + is_neox: bool, +) -> None: + from sglang.jit_kernel.pos_enc import rotary_embedding_with_key + + head_size = q.shape[-1] + rotary_embedding_with_key( + positions=positions, + query=q.view(q.shape[0], -1), + key=k.view(k.shape[0], -1), + head_size=head_size, + cos_sin_cache=COS_SIN_CACHE, + is_neox=is_neox, + ) + + +def sglang_rope_v1( + q: torch.Tensor, + k: torch.Tensor, + positions: torch.Tensor, + is_neox: bool, +) -> None: + from sgl_kernel import apply_rope_with_cos_sin_cache_inplace + + head_size = q.shape[-1] + apply_rope_with_cos_sin_cache_inplace( + positions=positions, + query=q.view(q.shape[0], -1), + key=k.view(k.shape[0], -1), + head_size=head_size, + cos_sin_cache=COS_SIN_CACHE, + is_neox=is_neox, + ) + + +def sglang_rope_v2( + q: torch.Tensor, + k: torch.Tensor, + positions: torch.Tensor, + is_neox: bool, +) -> None: + from sglang.jit_kernel.rope import apply_rope_inplace + + apply_rope_inplace(q, k, COS_SIN_CACHE, positions, is_neox=is_neox) + + +# --------------------------------------------------------------------------- +# RoPE + KV cache store provider implementations +# --------------------------------------------------------------------------- + + +def rope_v0_store( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + 
v_cache: torch.Tensor, + positions: torch.Tensor, + out_loc: torch.Tensor, + is_neox: bool, +) -> None: + from sglang.jit_kernel.kvcache import store_cache + from sglang.jit_kernel.rope import apply_rope_inplace + + head_size = q.shape[-1] + row_dim = k.shape[-2] * head_size + apply_rope_inplace( + positions=positions, + q=q, + k=k, + rope_dim=head_size, + cos_sin_cache=COS_SIN_CACHE, + is_neox=is_neox, + ) + store_cache( + k.view(-1, row_dim), + v.view(-1, row_dim), + k_cache, + v_cache, + out_loc, + ) + + +def rope_v1_store( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + positions: torch.Tensor, + out_loc: torch.Tensor, + is_neox: bool, +) -> None: + from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace + + head_size = q.shape[-1] + apply_rope_with_cos_sin_cache_inplace( + positions=positions, + query=q.view(q.shape[0], -1), + key=k.view(k.shape[0], -1), + head_size=head_size, + cos_sin_cache=COS_SIN_CACHE, + is_neox=is_neox, + fused_set_kv_buffer_arg=FusedSetKVBufferArg( + value=v.view(v.shape[0], -1), + k_buffer=k_cache, + v_buffer=v_cache, + k_scale=None, + v_scale=None, + cache_loc=out_loc, + ), + ) + + +def rope_v2_store( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + positions: torch.Tensor, + out_loc: torch.Tensor, + is_neox: bool, +) -> None: + from sglang.jit_kernel.rope import apply_rope_inplace_with_kvcache + + apply_rope_inplace_with_kvcache( + q, k, v, k_cache, v_cache, COS_SIN_CACHE, positions, out_loc, is_neox=is_neox + ) + + +# --------------------------------------------------------------------------- +# Benchmark configuration (shared) +# --------------------------------------------------------------------------- + +BS_RANGE = get_benchmark_range( + full_range=[2**n for n in range(0, 16)], + ci_range=[16], +) +QK_HEAD_RANGE = get_benchmark_range( + full_range=[(8, 1), (16, 2), (32, 8)], + 
ci_range=[(16, 2)], +) +QK_HEAD_RANGE = [f"{q},{k}" for q, k in QK_HEAD_RANGE] +IS_NEOX_RANGE = get_benchmark_range( + full_range=[True, False], + ci_range=[True], +) + + +# --------------------------------------------------------------------------- +# Benchmark 1: RoPE only +# --------------------------------------------------------------------------- + +ROPE_LINE_VALS = ["fi", "rope_v0", "rope_v1", "rope_v2"] +ROPE_LINE_NAMES = ["FlashInfer", "SGL RoPE v0", "SGL RoPE v1", "SGL RoPE v2"] +ROPE_STYLES = [("green", "-."), ("red", "-"), ("orange", "-"), ("blue", "--")] + +rope_configs = list(itertools.product(QK_HEAD_RANGE, IS_NEOX_RANGE, BS_RANGE)) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["num_q_k_heads", "is_neox", "batch_size"], + x_vals=rope_configs, + line_arg="provider", + line_vals=ROPE_LINE_VALS, + line_names=ROPE_LINE_NAMES, + styles=ROPE_STYLES, + ylabel="us", + plot_name="rope-performance", + args={}, + ) +) +def benchmark(batch_size: int, num_q_k_heads: str, is_neox: bool, provider: str): + qo, kv = num_q_k_heads.split(",") + num_qo_heads = int(qo) + num_kv_heads = int(kv) + q = torch.randn( + (batch_size, num_qo_heads, ROPE_DIM), + dtype=DEFAULT_DTYPE, + device=DEFAULT_DEVICE, + ) + k = torch.randn( + (batch_size, num_kv_heads, ROPE_DIM), + dtype=DEFAULT_DTYPE, + device=DEFAULT_DEVICE, + ) + seed = batch_size << 16 | num_qo_heads << 8 | num_kv_heads << 4 | is_neox + torch.random.manual_seed(seed) + positions = torch.randint( + MAX_SEQ_LEN, (batch_size,), device=DEFAULT_DEVICE, dtype=torch.int64 + ) + torch.cuda.synchronize() + + FN_MAP = { + "fi": flashinfer_rope, + "rope_v0": sglang_rope_v0, + "rope_v1": sglang_rope_v1, + "rope_v2": sglang_rope_v2, + } + fn = lambda: FN_MAP[provider](q, k, positions, is_neox) + return run_benchmark(fn) + + +# --------------------------------------------------------------------------- +# Benchmark 2: RoPE + KV cache store +# 
--------------------------------------------------------------------------- + +STORE_LINE_VALS = ["rope_v0_store", "rope_v1_store", "rope_v2_store"] +STORE_LINE_NAMES = ["SGL RoPE v0 + Store", "SGL RoPE v1 + Store", "SGL RoPE v2 + Store"] +STORE_STYLES = [("red", "-"), ("orange", "-"), ("blue", "--")] + +store_configs = list(itertools.product(QK_HEAD_RANGE, IS_NEOX_RANGE, BS_RANGE)) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["num_q_k_heads", "is_neox", "batch_size"], + x_vals=store_configs, + line_arg="provider", + line_vals=STORE_LINE_VALS, + line_names=STORE_LINE_NAMES, + styles=STORE_STYLES, + ylabel="us", + plot_name="rope-store-performance", + args={}, + ) +) +def benchmark_store(batch_size: int, num_q_k_heads: str, is_neox: bool, provider: str): + qo, kv = num_q_k_heads.split(",") + num_qo_heads = int(qo) + num_kv_heads = int(kv) + q = torch.randn( + (batch_size, num_qo_heads, ROPE_DIM), + dtype=DEFAULT_DTYPE, + device=DEFAULT_DEVICE, + ) + k = torch.randn( + (batch_size, num_kv_heads, ROPE_DIM), + dtype=DEFAULT_DTYPE, + device=DEFAULT_DEVICE, + ) + v = torch.randn( + (batch_size, num_kv_heads, ROPE_DIM), + dtype=DEFAULT_DTYPE, + device=DEFAULT_DEVICE, + ) + row_size = num_kv_heads * ROPE_DIM + k_cache = torch.zeros( + CACHE_SIZE, row_size, dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE + ) + v_cache = torch.zeros( + CACHE_SIZE, row_size, dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE + ) + out_loc = torch.randperm(CACHE_SIZE, device=DEFAULT_DEVICE, dtype=torch.int64)[ + :batch_size + ] + seed = batch_size << 16 | num_qo_heads << 8 | num_kv_heads << 4 | is_neox + torch.random.manual_seed(seed) + positions = torch.randint( + MAX_SEQ_LEN, (batch_size,), device=DEFAULT_DEVICE, dtype=torch.int64 + ) + torch.cuda.synchronize() + + FN_MAP = { + "rope_v0_store": rope_v0_store, + "rope_v1_store": rope_v1_store, + "rope_v2_store": rope_v2_store, + } + fn = lambda: FN_MAP[provider]( + q, k, v, k_cache, v_cache, positions, out_loc, is_neox + ) + 
return run_benchmark(fn) + + +if __name__ == "__main__": + print("Running RoPE performance benchmark...") + benchmark.run(print_data=True) + print("\nRunning RoPE + KV cache store performance benchmark...") + benchmark_store.run(print_data=True) diff --git a/sglang/python/sglang/jit_kernel/benchmark/bench_store_cache.py b/sglang/python/sglang/jit_kernel/benchmark/bench_store_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..1f179242e9e7620bb0caacca243d821b1c0bb58e --- /dev/null +++ b/sglang/python/sglang/jit_kernel/benchmark/bench_store_cache.py @@ -0,0 +1,151 @@ +import itertools +from typing import Tuple + +import torch +import triton +import triton.testing +from sgl_kernel import set_kv_buffer_kernel + +from sglang.jit_kernel.benchmark.utils import ( + DEFAULT_DEVICE, + DEFAULT_DTYPE, + DEFAULT_QUANTILES, + get_benchmark_range, +) +from sglang.jit_kernel.kvcache import store_cache + +_is_hip = bool(torch.version.hip) +HAS_AOT_STORE_CACHE = hasattr(torch.ops.sgl_kernel, "store_kv_cache") + + +def sglang_aot_store_cache( + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + indices: torch.Tensor, +) -> None: + set_kv_buffer_kernel(k_cache, v_cache, indices, k, v) + + +def sglang_jit_store_cache( + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + indices: torch.Tensor, +) -> None: + store_cache(k, v, k_cache, v_cache, indices) + + +@torch.compile() +def torch_compile_store_cache( + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + indices: torch.Tensor, +) -> None: + k_cache[indices] = k + v_cache[indices] = v + + +alt_stream = torch.cuda.Stream() + + +def torch_streams_store_cache( + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + indices: torch.Tensor, +) -> None: + current_stream = torch.cuda.current_stream() + alt_stream.wait_stream(current_stream) + k_cache[indices] = k + with 
torch.cuda.stream(alt_stream): + v_cache[indices] = v + current_stream.wait_stream(alt_stream) + + +NUM_LAYERS = 8 +CACHE_SIZE = 2 * 1024 * 1024 // NUM_LAYERS + +BS_RANGE = get_benchmark_range( + full_range=[2**n for n in range(0, 15)], + ci_range=[16], +) +ITEM_SIZE = get_benchmark_range( + full_range=[64, 128, 256, 512, 1024], + ci_range=[1024], +) + +LINE_VALS = ["jit", "torch_compile", "torch_streams"] +LINE_NAMES = ["SGL JIT Kernel", "PyTorch Compile", "PyTorch 2 Stream"] +STYLES = [("blue", "--"), ("red", ":"), ("green", "-.")] +# Keep non-HIP benchmark lines unchanged; only HIP tolerates missing AOT op. +if (not _is_hip) or HAS_AOT_STORE_CACHE: + LINE_VALS = ["aot"] + LINE_VALS + LINE_NAMES = ["SGL AOT Kernel"] + LINE_NAMES + STYLES = [("orange", "-")] + STYLES +X_NAMES = ["item_size", "batch_size"] +CONFIGS = list(itertools.product(ITEM_SIZE, BS_RANGE)) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=X_NAMES, + x_vals=CONFIGS, + line_arg="provider", + line_vals=LINE_VALS, + line_names=LINE_NAMES, + styles=STYLES, + ylabel="us", + plot_name="store-kvcache-performance", + args={}, + ) +) +def benchmark( + batch_size: int, item_size: int, provider: str +) -> Tuple[float, float, float]: + k = torch.randn( + (NUM_LAYERS, batch_size, item_size), dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE + ) + v = torch.randn( + (NUM_LAYERS, batch_size, item_size), dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE + ) + k_cache = torch.randn( + (NUM_LAYERS, CACHE_SIZE, item_size), dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE + ) + v_cache = torch.randn( + (NUM_LAYERS, CACHE_SIZE, item_size), dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE + ) + indices = torch.randperm(CACHE_SIZE, device=DEFAULT_DEVICE)[:batch_size] + torch.cuda.synchronize() + + FN_MAP = { + "jit": sglang_jit_store_cache, + "torch_compile": torch_compile_store_cache, + "torch_streams": torch_streams_store_cache, + } + if (not _is_hip) or HAS_AOT_STORE_CACHE: + FN_MAP["aot"] = sglang_aot_store_cache + + 
def fn(): + impl = FN_MAP[provider] + for i in range(NUM_LAYERS): + impl(k[i], v[i], k_cache[i], v_cache[i], indices) + + # Custom time calculation: divide by NUM_LAYERS + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + fn, quantiles=DEFAULT_QUANTILES + ) + return ( + 1000 * ms / NUM_LAYERS, + 1000 * max_ms / NUM_LAYERS, + 1000 * min_ms / NUM_LAYERS, + ) + + +if __name__ == "__main__": + benchmark.run(print_data=True) diff --git a/sglang/python/sglang/jit_kernel/benchmark/utils.py b/sglang/python/sglang/jit_kernel/benchmark/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a4d3f277378758306cde30f82af99e01ccd31992 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/benchmark/utils.py @@ -0,0 +1,42 @@ +"""Common utilities for jit_kernel benchmark files.""" + +import os +from typing import Callable, List, Tuple + +import torch +import triton.testing + +# Common constants +DEFAULT_DTYPE = torch.bfloat16 +DEFAULT_DEVICE = "cuda" +DEFAULT_QUANTILES = [0.5, 0.2, 0.8] + + +def is_in_ci() -> bool: + """Check if running in CI environment.""" + return ( + os.getenv("CI", "false").lower() == "true" + or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" + ) + + +def get_benchmark_range(full_range: List, ci_range: List) -> List: + """Return appropriate benchmark range based on CI environment.""" + return ci_range if is_in_ci() else full_range + + +def run_benchmark( + fn: Callable, quantiles: List[float] = None +) -> Tuple[float, float, float]: + """Execute benchmark using CUDA graph and return times in microseconds. 
+ + Args: + fn: Function to benchmark + quantiles: Quantiles for timing measurements [median, min, max] + + Returns: + Tuple of (median_us, max_us, min_us) + """ + quantiles = quantiles or DEFAULT_QUANTILES + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles) + return 1000 * ms, 1000 * max_ms, 1000 * min_ms diff --git a/sglang/python/sglang/jit_kernel/concat_mla.py b/sglang/python/sglang/jit_kernel/concat_mla.py new file mode 100644 index 0000000000000000000000000000000000000000..4945b73bc27fbd1607cef83a65bb8148e750acbb --- /dev/null +++ b/sglang/python/sglang/jit_kernel/concat_mla.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch + +from sglang.jit_kernel.utils import cache_once, load_jit + +if TYPE_CHECKING: + from tvm_ffi.module import Module + + +@cache_once +def _jit_concat_mla_k_module() -> Module: + return load_jit( + "concat_mla_k", + cuda_files=["elementwise/concat_mla.cuh"], + cuda_wrappers=[("concat_mla_k", "ConcatMlaKKernel::run")], + ) + + +@cache_once +def _jit_concat_mla_absorb_q_module() -> Module: + return load_jit( + "concat_mla_absorb_q", + cuda_files=["elementwise/concat_mla.cuh"], + cuda_wrappers=[("concat_mla_absorb_q", "ConcatMlaAbsorbQKernel::run")], + ) + + +def concat_mla_k(k: torch.Tensor, k_nope: torch.Tensor, k_rope: torch.Tensor) -> None: + """ + Concatenate k_nope and k_rope into k for MLA (Multi-head Latent Attention). + + This kernel efficiently broadcasts k_rope across all heads while copying + k_nope values directly. 
+ + Args: + k: Output tensor of shape [num_tokens, num_heads=128, k_head_dim=192], dtype=bfloat16 + k_nope: Input tensor of shape [num_tokens, num_heads=128, nope_head_dim=128], dtype=bfloat16 + k_rope: Input tensor of shape [num_tokens, 1, rope_head_dim=64], dtype=bfloat16 + """ + module = _jit_concat_mla_k_module() + module.concat_mla_k(k, k_nope, k_rope) + + +def concat_mla_absorb_q(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """ + Concatenate tensors a and b for MLA absorbed Q computation. + + Args: + a: Input tensor of shape [dim_0, dim_1, a_last_dim], dtype=bfloat16 + b: Input tensor of shape [dim_0, dim_1, b_last_dim], dtype=bfloat16 + + Returns: + Output tensor of shape [dim_0, dim_1, a_last_dim + b_last_dim], dtype=bfloat16 + """ + out = torch.empty( + (*a.shape[:-1], a.shape[-1] + b.shape[-1]), + dtype=a.dtype, + device=a.device, + ) + module = _jit_concat_mla_absorb_q_module() + module.concat_mla_absorb_q(a, b, out) + return out diff --git a/sglang/python/sglang/jit_kernel/csrc/add_constant.cuh b/sglang/python/sglang/jit_kernel/csrc/add_constant.cuh new file mode 100644 index 0000000000000000000000000000000000000000..6c723a761db2e6ac4b8671feb8508b6c48b099d0 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/add_constant.cuh @@ -0,0 +1,59 @@ +#include // For TensorMatcher, SymbolicSize, SymbolicDevice +#include // For div_ceil, RuntimeCheck + +#include // For LaunchKernel + +#include +#include + +#include +#include + +namespace { + +template +__global__ void add_constant_kernel(int32_t* dst, const int32_t* src, size_t length) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < length) { + dst[idx] = src[idx] + kConstant; + } +} + +constexpr size_t kBlockSize = 256; + +// You can also use struct with static method as an alternative +template +void add_constant(tvm::ffi::TensorView dst, tvm::ffi::TensorView src) { + using namespace host; + + // 1. 
Validate input tensors + SymbolicSize N = {"num_elements"}; + SymbolicDevice device_; + TensorMatcher({N}) // 1D tensor, must be contiguous + .with_dtype() // must be int32 + .with_device(device_) // must be on CUDA device + .verify(dst) // check tensor dst + .verify(src); // check tensor src + + // 2. Extract required parameters, prepare for kernel launch + const size_t num_elements = N.unwrap(); + const size_t grid_size = div_ceil(num_elements, kBlockSize); + const DLDevice device = device_.unwrap(); + [[maybe_unused]] // optional, can be omitted + const size_t dynamic_smem = 0; + [[maybe_unused]] // optional, LaunchKernel can auto determine stream from device + const cudaStream_t stream = LaunchKernel::resolve_device(device); + // some extra runtime checks using host::RuntimeCheck + RuntimeCheck(num_elements > 0, "We only support non-empty tensors, got num_elements = ", num_elements); + + // 3. Launch the kernel. Error code will be automatically checked. + LaunchKernel(grid_size, kBlockSize, device /*, dynamic_smem*/)( + // kernel function + add_constant_kernel, + // kernel arguments + static_cast(dst.data_ptr()), + static_cast(src.data_ptr()), + num_elements); +} + +} // namespace diff --git a/sglang/python/sglang/jit_kernel/csrc/diffusion/timestep_embedding.cuh b/sglang/python/sglang/jit_kernel/csrc/diffusion/timestep_embedding.cuh new file mode 100644 index 0000000000000000000000000000000000000000..2d29da50bf431b4de3203d6a3203b18a8c764678 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/diffusion/timestep_embedding.cuh @@ -0,0 +1,150 @@ +#include +#include + +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include + +namespace { + +template +__global__ void timestep_embedding_kernel( + const TIn* __restrict__ t_ptr, + float* __restrict__ output_ptr, + int dim, + float neg_log_max_period, + float scale, + int batch_size) { + int row_idx = static_cast(blockIdx.x * blockDim.y + threadIdx.y); + if (row_idx >= 
batch_size) { + return; + } + + float t_val = device::cast(t_ptr[row_idx]); + float* output_batch_base_ptr = output_ptr + row_idx * dim; + + int half_dim = dim / 2; + int thread_offset = static_cast(threadIdx.x); + while (thread_offset * 4 < half_dim) { + float4* top_half; + float4* bottom_half; + if constexpr (!kFlipSinToCos) { + bottom_half = reinterpret_cast(output_batch_base_ptr + thread_offset * 4); + top_half = reinterpret_cast(output_batch_base_ptr + half_dim + thread_offset * 4); + } else { + top_half = reinterpret_cast(output_batch_base_ptr + thread_offset * 4); + bottom_half = reinterpret_cast(output_batch_base_ptr + half_dim + thread_offset * 4); + } + + float4 vals; + vals.x = scale * t_val * device::math::exp(neg_log_max_period * __int2float_rn(thread_offset * 4 + 0)); + vals.y = scale * t_val * device::math::exp(neg_log_max_period * __int2float_rn(thread_offset * 4 + 1)); + vals.z = scale * t_val * device::math::exp(neg_log_max_period * __int2float_rn(thread_offset * 4 + 2)); + vals.w = scale * t_val * device::math::exp(neg_log_max_period * __int2float_rn(thread_offset * 4 + 3)); + + float4 cos_vals; + cos_vals.x = device::math::cos(vals.x); + cos_vals.y = device::math::cos(vals.y); + cos_vals.z = device::math::cos(vals.z); + cos_vals.w = device::math::cos(vals.w); + *top_half = cos_vals; + + float4 sin_vals; + sin_vals.x = device::math::sin(vals.x); + sin_vals.y = device::math::sin(vals.y); + sin_vals.z = device::math::sin(vals.z); + sin_vals.w = device::math::sin(vals.w); + *bottom_half = sin_vals; + + thread_offset += static_cast(blockDim.x); + } +} + +template +inline void launch_timestep_embedding( + const tvm::ffi::TensorView t, + const tvm::ffi::TensorView output, + int dim, + bool flip_sin_to_cos, + float downscale_freq_shift, + float scale, + int max_period) { + using namespace host; + + const int batch_size = static_cast(t.shape()[0]); + const int half_dim = dim / 2; + + constexpr int kMaxThreadsPerBlock = 1024; + constexpr int 
kMinThreadsPerBlock = 128; + + const int num_threads_per_row = std::min(kMaxThreadsPerBlock, half_dim / 4); + const int num_rows = (kMinThreadsPerBlock + num_threads_per_row - 1) / num_threads_per_row; + + dim3 grid((batch_size + num_rows - 1) / num_rows); + dim3 block(num_threads_per_row, num_rows); + + const float neg_log_max_period = + std::log(static_cast(max_period)) * (-1.0f) / (static_cast(half_dim) - downscale_freq_shift); + + const DLDevice device = output.device(); + + if (flip_sin_to_cos) { + LaunchKernel(grid, block, device)( + timestep_embedding_kernel, + static_cast(t.data_ptr()), + static_cast(output.data_ptr()), + dim, + neg_log_max_period, + scale, + batch_size); + } else { + LaunchKernel(grid, block, device)( + timestep_embedding_kernel, + static_cast(t.data_ptr()), + static_cast(output.data_ptr()), + dim, + neg_log_max_period, + scale, + batch_size); + } +} + +template +void timestep_embedding( + tvm::ffi::TensorView input, + tvm::ffi::TensorView output, + int dim, + bool flip_sin_to_cos, + float downscale_freq_shift, + float scale, + int max_period) { + using namespace host; + + auto B = SymbolicSize{"batch_size"}; + auto D = SymbolicSize{"dim"}; + auto device = SymbolicDevice{}; + + TensorMatcher({B}) // input + .with_strides({1}) + .with_dtype() + .template with_device(device) + .verify(input); + + TensorMatcher({B, D}).with_strides({D, 1}).with_dtype().template with_device(device).verify(output); + + RuntimeCheck(D.unwrap() == dim, "Output dim mismatch: ", D.unwrap(), " vs ", dim); + RuntimeCheck(dim % 8 == 0, "dim must align to 8, got ", dim); + + launch_timestep_embedding(input, output, dim, flip_sin_to_cos, downscale_freq_shift, scale, max_period); +} + +} // namespace diff --git a/sglang/python/sglang/jit_kernel/csrc/elementwise/concat_mla.cuh b/sglang/python/sglang/jit_kernel/csrc/elementwise/concat_mla.cuh new file mode 100644 index 0000000000000000000000000000000000000000..eee33318fc83c1ec929597e408cc58ce64b362cf --- /dev/null +++ 
b/sglang/python/sglang/jit_kernel/csrc/elementwise/concat_mla.cuh @@ -0,0 +1,325 @@ +#include +#include + +#include + +#include + +#include +#include + +namespace { + +// ======================= Memory Utilities ======================= +// Adapted from DeepEP: https://github.com/deepseek-ai/DeepEP/blob/main/csrc/kernels/utils.cuh + +SGL_DEVICE int get_lane_id() { + int lane_id; + asm("mov.s32 %0, %laneid;" : "=r"(lane_id)); + return lane_id; +} + +SGL_DEVICE void st_na_global_v1(const int* ptr, int v) { + asm volatile("st.global.L1::no_allocate.s32 [%0], %1;" ::"l"(ptr), "r"(v) : "memory"); +} + +SGL_DEVICE void st_na_global_v2(const int2* ptr, const int2& v) { + asm volatile("st.global.L1::no_allocate.v2.s32 [%0], {%1, %2};" ::"l"(ptr), "r"(v.x), "r"(v.y) : "memory"); +} + +SGL_DEVICE int ld_na_global_v1(const int* ptr) { + int r; + asm volatile("ld.global.nc.L1::no_allocate.s32 %0, [%1];" : "=r"(r) : "l"(ptr)); + return r; +} + +SGL_DEVICE int2 ld_na_global_v2(const int2* ptr) { + int2 r; + asm volatile("ld.global.nc.L1::no_allocate.v2.s32 {%0, %1}, [%2];" : "=r"(r.x), "=r"(r.y) : "l"(ptr)); + return r; +} + +SGL_DEVICE void prefetch_L2(const void* p) { +#if defined(ENABLE_L2_PREFETCH) + asm volatile("prefetch.global.L2 [%0];" ::"l"(p)); +#endif +} + +// ======================= concat_mla_k Kernel ======================= + +constexpr int NUM_LOCAL_HEADS = 128; +constexpr int QK_NOPE_HEAD_DIM = 128; +constexpr int QK_ROPE_HEAD_DIM = 64; +constexpr int K_HEAD_DIM = QK_NOPE_HEAD_DIM + QK_ROPE_HEAD_DIM; + +constexpr int HEAD_CHUNK_SIZE = 16; +constexpr int NUM_HEAD_CHUNKS = NUM_LOCAL_HEADS / HEAD_CHUNK_SIZE; + +__global__ void concat_mla_k_kernel( + bf16_t* __restrict__ k, + const bf16_t* __restrict__ k_nope, + const bf16_t* __restrict__ k_rope, + const int num_tokens, + const int64_t k_stride_0, + const int k_stride_1, + const int64_t k_nope_stride_0, + const int k_nope_stride_1, + const int64_t k_rope_stride_0) { + const int flat_warp_id = (blockIdx.x * blockDim.x 
+ threadIdx.x) / 32; + const int token_id = flat_warp_id / NUM_HEAD_CHUNKS; + const int head_chunk_id = flat_warp_id % NUM_HEAD_CHUNKS; + const int lane_id = get_lane_id(); + if (token_id >= num_tokens) return; + + using NopeVec = int2; // 8B/thread, 32 threads = 256B/row + using RopeVec = int; // 4B/thread, 32 threads = 128B/row + static_assert(sizeof(NopeVec) * 32 == QK_NOPE_HEAD_DIM * sizeof(bf16_t), "nope vec mismatch"); + static_assert(sizeof(RopeVec) * 32 == QK_ROPE_HEAD_DIM * sizeof(bf16_t), "rope vec mismatch"); + + const int head_row0 = head_chunk_id * HEAD_CHUNK_SIZE; + + const int2* __restrict__ nope_src = + reinterpret_cast(k_nope + token_id * k_nope_stride_0 + head_row0 * k_nope_stride_1) + lane_id; + + int2* __restrict__ nope_dst = reinterpret_cast(k + token_id * k_stride_0 + head_row0 * k_stride_1) + lane_id; + + int* __restrict__ rope_dst = + reinterpret_cast(k + token_id * k_stride_0 + head_row0 * k_stride_1 + QK_NOPE_HEAD_DIM) + lane_id; + + const int nope_src_stride_v = (k_nope_stride_1 >> 2); // int2 covers 4 bf16 + const int nope_dst_stride_v = (k_stride_1 >> 2); + const int rope_dst_stride_v = (k_stride_1 >> 1); // int covers 2 bf16 + + const int* rope_base = reinterpret_cast(k_rope + token_id * k_rope_stride_0); + const RopeVec rope_val = ld_na_global_v1(rope_base + lane_id); + + prefetch_L2(nope_src); + NopeVec cur = ld_na_global_v2(nope_src); + +#pragma unroll + for (int i = 0; i < HEAD_CHUNK_SIZE; ++i) { + NopeVec next; + if (i + 1 < HEAD_CHUNK_SIZE) { + const int2* next_src = nope_src + nope_src_stride_v; + prefetch_L2(next_src); + next = ld_na_global_v2(next_src); + } + + st_na_global_v2(nope_dst, cur); + st_na_global_v1(rope_dst, rope_val); + + nope_src += nope_src_stride_v; + nope_dst += nope_dst_stride_v; + rope_dst += rope_dst_stride_v; + + cur = next; + } +} + +struct ConcatMlaKKernel { + static void run(tvm::ffi::TensorView k, tvm::ffi::TensorView k_nope, tvm::ffi::TensorView k_rope) { + using namespace host; + + auto N = 
SymbolicSize{"num_tokens"}; + auto H = SymbolicSize{"num_heads"}; + auto D = SymbolicSize{"k_head_dim"}; + auto D_nope = SymbolicSize{"nope_head_dim"}; + auto D_rope = SymbolicSize{"rope_head_dim"}; + auto S0_k = SymbolicSize{"k_stride_0"}; + auto S1_k = SymbolicSize{"k_stride_1"}; + auto S0_k_nope = SymbolicSize{"k_nope_stride_0"}; + auto S1_k_nope = SymbolicSize{"k_nope_stride_1"}; + auto S0_k_rope = SymbolicSize{"k_rope_stride_0"}; + auto device = SymbolicDevice{}; + + // Set known fixed values + H.set_value(NUM_LOCAL_HEADS); + D.set_value(K_HEAD_DIM); + D_nope.set_value(QK_NOPE_HEAD_DIM); + D_rope.set_value(QK_ROPE_HEAD_DIM); + + // Verify k: [num_tokens, num_heads, k_head_dim] + TensorMatcher({N, H, D}).with_strides({S0_k, S1_k, 1}).with_dtype().with_device(device).verify(k); + + // Verify k_nope: [num_tokens, num_heads, nope_head_dim] + TensorMatcher({N, H, D_nope}) + .with_strides({S0_k_nope, S1_k_nope, 1}) + .with_dtype() + .with_device(device) + .verify(k_nope); + + // Verify k_rope: [num_tokens, 1, rope_head_dim] + TensorMatcher({N, 1, D_rope}) + .with_strides({S0_k_rope, -1, 1}) + .with_dtype() + .with_device(device) + .verify(k_rope); + + // Check alignment + RuntimeCheck(reinterpret_cast(k.data_ptr()) % 16 == 0, "Tensor k must be 16-byte aligned"); + RuntimeCheck(reinterpret_cast(k_nope.data_ptr()) % 16 == 0, "Tensor k_nope must be 16-byte aligned"); + RuntimeCheck(reinterpret_cast(k_rope.data_ptr()) % 16 == 0, "Tensor k_rope must be 16-byte aligned"); + + const int num_tokens = static_cast(N.unwrap()); + + constexpr int num_warps_per_block = 32; + const int grid_size = div_ceil(num_tokens * NUM_HEAD_CHUNKS, num_warps_per_block); + const int block_size = num_warps_per_block * 32; + + LaunchKernel(grid_size, block_size, device.unwrap())( + concat_mla_k_kernel, + static_cast(k.data_ptr()), + static_cast(k_nope.data_ptr()), + static_cast(k_rope.data_ptr()), + num_tokens, + S0_k.unwrap(), + static_cast(S1_k.unwrap()), + S0_k_nope.unwrap(), + 
static_cast(S1_k_nope.unwrap()), + S0_k_rope.unwrap()); + } +}; + +// ======================= concat_mla_absorb_q Kernel ======================= + +constexpr int A_LAST_DIM = 512; +constexpr int B_LAST_DIM = 64; +constexpr int OUT_LAST_DIM = A_LAST_DIM + B_LAST_DIM; + +__global__ void concat_mla_absorb_q_kernel( + bf16_t* a, + bf16_t* b, + bf16_t* out, + const int num_items, + const int dim_1, + const int64_t a_stride_0, + const int a_stride_1, + const int64_t b_stride_0, + const int b_stride_1, + const int64_t out_stride_0, + const int out_stride_1) { + const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / 32; + const int lane_id = get_lane_id(); + + const int idx_0 = flat_warp_id / dim_1; + const int idx_1 = flat_warp_id % dim_1; + + if (flat_warp_id >= num_items) { + return; + } + + using ABufType = int4; + constexpr int A_NUM_UNROLL = 2; + static_assert(sizeof(ABufType) * A_NUM_UNROLL == A_LAST_DIM * sizeof(a[0]) / 32); + ABufType a_buf[A_NUM_UNROLL]; + + using BBufType = int; + constexpr int B_NUM_UNROLL = 1; + static_assert(sizeof(BBufType) * B_NUM_UNROLL == B_LAST_DIM * sizeof(b[0]) / 32); + BBufType b_buf; + + { + const BBufType* base_addr = reinterpret_cast(b + idx_0 * b_stride_0 + idx_1 * b_stride_1); + b_buf = *(base_addr + lane_id); + } + +#pragma unroll + for (int i = 0; i < A_NUM_UNROLL; ++i) { + const ABufType* base_addr = reinterpret_cast(a + idx_0 * a_stride_0 + idx_1 * a_stride_1); + a_buf[i] = *(base_addr + i * 32 + lane_id); + } + + { + BBufType* base_addr = reinterpret_cast(out + idx_0 * out_stride_0 + idx_1 * out_stride_1 + A_LAST_DIM); + *(base_addr + lane_id) = b_buf; + } + +#pragma unroll + for (int i = 0; i < A_NUM_UNROLL; ++i) { + ABufType* base_addr = reinterpret_cast(out + idx_0 * out_stride_0 + idx_1 * out_stride_1); + *(base_addr + i * 32 + lane_id) = a_buf[i]; + } +} + +struct ConcatMlaAbsorbQKernel { + static void run(tvm::ffi::TensorView a, tvm::ffi::TensorView b, tvm::ffi::TensorView out) { + using namespace host; + + 
auto N0_a = SymbolicSize{"a_dim_0"}; + auto N1_a = SymbolicSize{"a_dim_1"}; + auto D_a = SymbolicSize{"a_last_dim"}; + auto N0_b = SymbolicSize{"b_dim_0"}; + auto N1_b = SymbolicSize{"b_dim_1"}; + auto D_b = SymbolicSize{"b_last_dim"}; + auto N0_out = SymbolicSize{"out_dim_0"}; + auto N1_out = SymbolicSize{"out_dim_1"}; + auto D_out = SymbolicSize{"out_last_dim"}; + auto S0_a = SymbolicSize{"a_stride_0"}; + auto S1_a = SymbolicSize{"a_stride_1"}; + auto S0_b = SymbolicSize{"b_stride_0"}; + auto S1_b = SymbolicSize{"b_stride_1"}; + auto S0_out = SymbolicSize{"out_stride_0"}; + auto S1_out = SymbolicSize{"out_stride_1"}; + auto device = SymbolicDevice{}; + + // Set known fixed values + D_a.set_value(A_LAST_DIM); + D_b.set_value(B_LAST_DIM); + D_out.set_value(OUT_LAST_DIM); + + // Verify a: [dim_0, dim_1, A_LAST_DIM] + TensorMatcher({N0_a, N1_a, D_a}) + .with_strides({S0_a, S1_a, 1}) + .with_dtype() + .with_device(device) + .verify(a); + + // Verify b: [dim_0, dim_1, B_LAST_DIM] + TensorMatcher({N0_b, N1_b, D_b}) + .with_strides({S0_b, S1_b, 1}) + .with_dtype() + .with_device(device) + .verify(b); + + // Verify out: [dim_0, dim_1, OUT_LAST_DIM] + TensorMatcher({N0_out, N1_out, D_out}) + .with_strides({S0_out, S1_out, 1}) + .with_dtype() + .with_device(device) + .verify(out); + + // Check alignment + RuntimeCheck(reinterpret_cast(a.data_ptr()) % 16 == 0, "Tensor a must be 16-byte aligned"); + RuntimeCheck(reinterpret_cast(b.data_ptr()) % 16 == 0, "Tensor b must be 16-byte aligned"); + RuntimeCheck(reinterpret_cast(out.data_ptr()) % 16 == 0, "Tensor out must be 16-byte aligned"); + + // Verify dimensions match: a.size(0) * a.size(1) == b.size(0) * b.size(1) + RuntimeCheck( + N0_a.unwrap() * N1_a.unwrap() == N0_b.unwrap() * N1_b.unwrap(), + "Dimension mismatch: a.size(0) * a.size(1) must equal b.size(0) * b.size(1)"); + RuntimeCheck(N1_a.unwrap() == N1_b.unwrap(), "Dimension mismatch: a.size(1) must equal b.size(1)"); + + const int num_items = static_cast(N0_a.unwrap() * 
N1_a.unwrap()); + const int dim_1 = static_cast(N1_a.unwrap()); + + constexpr int num_warps_per_block = 32; + const int grid_size = div_ceil(num_items, num_warps_per_block); + const int block_size = num_warps_per_block * 32; + + LaunchKernel(grid_size, block_size, device.unwrap())( + concat_mla_absorb_q_kernel, + static_cast(a.data_ptr()), + static_cast(b.data_ptr()), + static_cast(out.data_ptr()), + num_items, + dim_1, + S0_a.unwrap(), + static_cast(S1_a.unwrap()), + S0_b.unwrap(), + static_cast(S1_b.unwrap()), + S0_out.unwrap(), + static_cast(S1_out.unwrap())); + } +}; + +} // namespace diff --git a/sglang/python/sglang/jit_kernel/csrc/elementwise/fused_add_rmsnorm.cuh b/sglang/python/sglang/jit_kernel/csrc/elementwise/fused_add_rmsnorm.cuh new file mode 100644 index 0000000000000000000000000000000000000000..5455796afcf4fb16fd9707188bac9cf8c7a0ed32 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/elementwise/fused_add_rmsnorm.cuh @@ -0,0 +1,186 @@ +#include +#include + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + +namespace { + +template +struct VecTypeTrait; + +template <> +struct VecTypeTrait { + using packed_t = packed_t; + using vec_t = device::AlignedVector; +}; + +template <> +struct VecTypeTrait { + using packed_t = packed_t; + using vec_t = device::AlignedVector; +}; + +template <> +struct VecTypeTrait { + using packed_t = packed_t; + using vec_t = device::AlignedVector; +}; + +template <> +struct VecTypeTrait { + using packed_t = packed_t; + using vec_t = device::AlignedVector; +}; + +template +SGL_DEVICE packed_t rms(packed_t& val, packed_t& weight, float rsqrt_square_sum) { + float2 valf = device::cast(val); + float2 weightf = device::cast(weight); + return device::cast( + make_float2(valf.x * weightf.x * rsqrt_square_sum, valf.y * weightf.y * rsqrt_square_sum)); +} + +template +__global__ void fused_add_rmsnorm_reg_kernel( + T* __restrict__ input, T* __restrict__ residual, const T* 
__restrict__ weight, int vec_hidden_size, float eps) { + constexpr int inner_loop = VEC_SIZE_IN_BYTE == 16 ? 4 : 8; + + __shared__ float shared_memory[32]; // Used for CTA reduce + + using vec_t = typename VecTypeTrait::vec_t; + using packed_t = typename VecTypeTrait::packed_t; + vec_t v; // Save input + vec_t v_res; // Save residual + vec_t v_weight; // Save weight + vec_t v_out; // Save output + + auto token_id = blockIdx.x; + float2 acc_square = make_float2(0.0f, 0.0f); // Sum of squares for each thread + + if (threadIdx.x < vec_hidden_size) { + // Compute address + vec_t* p = reinterpret_cast(input) + token_id * vec_hidden_size; + vec_t* p_res = reinterpret_cast(residual) + token_id * vec_hidden_size; + const vec_t* p_weight = reinterpret_cast(weight); + + // Load data + v = p[threadIdx.x]; + v_res = p_res[threadIdx.x]; + v_weight = p_weight[threadIdx.x]; + + for (int i = 0; i < inner_loop; i++) { + float2 val = device::cast(v[i]); + float2 res = device::cast(v_res[i]); + float2 inp_res = make_float2(val.x + res.x, val.y + res.y); + acc_square.x += inp_res.x * inp_res.x; + acc_square.y += inp_res.y * inp_res.y; + v[i] = device::cast(inp_res); + } + + // Store inp+res to residual + p_res[threadIdx.x] = v; + } + + // CTA Reduce + // Step 0: Warp Reduce + auto cg_warp = cooperative_groups::tiled_partition<32>(cooperative_groups::this_thread_block()); + float warp_sum = cooperative_groups::reduce(cg_warp, acc_square.x + acc_square.y, cooperative_groups::plus()); + + float* buffer = shared_memory; + if (threadIdx.x % 32 == 0) { + buffer[threadIdx.x / 32] = warp_sum; // Write warp_sum to buffer + } + + // Step 1: CTA Reduce + __syncthreads(); + if (threadIdx.x < 32) { + float cta_sum = cooperative_groups::reduce( + cg_warp, (threadIdx.x < blockDim.x / 32) ? 
buffer[threadIdx.x] : 0.0f, cooperative_groups::plus()); + buffer[threadIdx.x] = + rsqrtf(eps + cta_sum * (1.0f / static_cast(vec_hidden_size * (VEC_SIZE_IN_BYTE / sizeof(T))))); + } + __syncthreads(); + + // Compute RMSNorm + if (threadIdx.x < vec_hidden_size) { + float rsqrt_square_sum = buffer[threadIdx.x / 32]; // Read rsqrt from Shared Memory(Broadcast) + for (int i = 0; i < inner_loop; i++) { + v_out[i] = rms(v[i], v_weight[i], rsqrt_square_sum); + } + vec_t* p_out = reinterpret_cast(input) + token_id * vec_hidden_size; + p_out[threadIdx.x] = v_out; + } +} + +template +struct FusedAddRMSNormKernel { + static void + run(const tvm::ffi::TensorView input, + const tvm::ffi::TensorView residual, + const tvm::ffi::TensorView weight, + float eps) { + using namespace host; + auto N = SymbolicSize{"num_tokens"}; + auto D = SymbolicSize{"hidden_size"}; + auto device = SymbolicDevice{}; + device.set_options(); + + TensorMatcher({N, D}) // input + .with_strides({D, 1}) + .with_dtype() + .with_device(device) + .verify(input); + TensorMatcher({D}) // weight + .with_dtype() + .with_device(device) + .verify(weight); + TensorMatcher({N, D}) // residual + .with_strides({D, 1}) + .with_dtype() + .with_device(device) + .verify(residual); + + auto cc_major = host::runtime::get_cc_major(device.unwrap().device_id); + int hidden_size = static_cast(D.unwrap()); + if ((cc_major <= 9 && hidden_size <= 8192) || (cc_major >= 10 && hidden_size <= 12288)) { + int max_vec_size_byte = cc_major >= 10 ? 32 : 16; + int elements_in_vec = max_vec_size_byte / sizeof(DType); + int vec_hidden_size = hidden_size / elements_in_vec; + uint threads = (vec_hidden_size + 31) / 32 * 32; + + // Runtime check + host::RuntimeCheck( + hidden_size % elements_in_vec == 0, + "hidden_size", + hidden_size, + " can not align to elements_in_vec ", + elements_in_vec); + + // Launch kernel + auto kernel = + max_vec_size_byte == 32 ? 
fused_add_rmsnorm_reg_kernel : fused_add_rmsnorm_reg_kernel; + LaunchKernel(static_cast(N.unwrap()), threads, device.unwrap()) + .enable_pdl(false)( + kernel, + reinterpret_cast(input.data_ptr()), + reinterpret_cast(residual.data_ptr()), + reinterpret_cast(weight.data_ptr()), + vec_hidden_size, + eps); + } else { + host::RuntimeCheck(false, "Large hidden_sizes are not supported for now."); + } + } +}; + +} // namespace diff --git a/sglang/python/sglang/jit_kernel/csrc/elementwise/fused_metadata_copy.cuh b/sglang/python/sglang/jit_kernel/csrc/elementwise/fused_metadata_copy.cuh new file mode 100644 index 0000000000000000000000000000000000000000..c996f6f1b861228c12a507e0d8c91525d22f0e72 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/elementwise/fused_metadata_copy.cuh @@ -0,0 +1,722 @@ +/* + * Fused metadata copy kernel for NSA backend CUDA graph replay. + * JIT-compiled version for python/sglang/jit_kernel. + * + * OVERVIEW: + * This kernel fuses multiple tensor copy operations (cache_seqlens, cu_seqlens_k, + * page_table, nsa metadata, and optional FlashMLA metadata) into single kernel + * launches, significantly reducing kernel launch overhead and improving CUDA + * graph replay performance during inference. + * + * PERFORMANCE BENEFITS: + * - Single kernel launch vs. multiple separate copies (3-10x faster) + * - Optimized memory coalescing and SM utilization + * - __grid_constant__ parameter passing via constant memory + * - Especially beneficial in CUDA graph replay scenarios + * + * DESIGN: + * - Unified kernel supporting all forward modes (DECODE, TARGET_VERIFY, DRAFT_EXTEND) + * - Structured parameter passing (SourcePointers/DestinationPointers) for clarity + * - Template parameters (HAS_REAL_PAGE_TABLE, HAS_FLASHMLA) for compile-time optimization + * - Multi-backend variant copies to 3 destinations in one kernel (for speculative decoding) + * + * USAGE: + * This header is included by JIT compilation system. 
The FusedMetadataCopyKernel + * and FusedMetadataCopyMultiKernel wrapper structs provide the Python-accessible interface. + */ + +#pragma once + +#include +#include + +#include + +#include + +#include // for std::min +#include + +// Forward mode enum (must match Python ForwardMode in sglang/srt/layers/attention/nsa_backend.py) +enum ForwardModeEnum { DECODE = 0, TARGET_VERIFY = 1, DRAFT_EXTEND = 2 }; + +/** + * Source pointers for metadata copy operations. + * Groups all source tensor pointers for cleaner parameter passing. + * Some pointers may be nullptr depending on forward mode and feature flags. + */ +struct SourcePointers { + const int32_t* __restrict__ cache_seqlens; // [bs] sequence lengths in cache + const int32_t* __restrict__ cu_seqlens_k; // [bs+1] cumulative sequence lengths + const int32_t* __restrict__ page_indices; // page table indices + const int32_t* __restrict__ nsa_cache_seqlens; // NSA-specific cache lengths + const int32_t* __restrict__ seqlens_expanded; // expanded sequence lengths (TARGET_VERIFY/DRAFT_EXTEND only) + const int32_t* __restrict__ nsa_cu_seqlens_k; // NSA cumulative sequence lengths + const int32_t* __restrict__ real_page_table; // optional real page table + const int32_t* __restrict__ flashmla_num_splits; // optional FlashMLA split counts + const int32_t* __restrict__ flashmla_metadata; // optional FlashMLA metadata +}; + +/** + * Destination pointers for metadata copy operations. + * Groups all destination tensor pointers for cleaner parameter passing. + * Layout matches SourcePointers for consistency. 
+ */ +struct DestinationPointers { + int32_t* __restrict__ cache_seqlens; // [bs] sequence lengths in cache + int32_t* __restrict__ cu_seqlens_k; // [bs+1] cumulative sequence lengths + int32_t* __restrict__ page_table_1; // page table (note: different name from source) + int32_t* __restrict__ nsa_cache_seqlens; // NSA-specific cache lengths + int32_t* __restrict__ seqlens_expanded; // expanded sequence lengths (TARGET_VERIFY/DRAFT_EXTEND only) + int32_t* __restrict__ nsa_cu_seqlens_k; // NSA cumulative sequence lengths + int32_t* __restrict__ real_page_table; // optional real page table + int32_t* __restrict__ flashmla_num_splits; // optional FlashMLA split counts + int32_t* __restrict__ flashmla_metadata; // optional FlashMLA metadata +}; + +/** + * Parameter structure for single-backend fused metadata copy kernel. + * Passed via __grid_constant__ for efficient constant memory access. + */ +struct FusedMetadataCopyParams { + SourcePointers src; // Source tensor pointers + DestinationPointers dst; // Destination tensor pointers + + // Kernel parameters + int forward_mode; // 0=DECODE, 1=TARGET_VERIFY, 2=DRAFT_EXTEND + int bs; // Batch size + int max_len; // Max length for DECODE mode + int max_seqlen_k; // Max sequence length for TARGET_VERIFY/DRAFT_EXTEND + int seqlens_expanded_size; // Size of expanded sequence lengths + int page_indices_rows; // Number of rows in page_indices + int page_table_1_stride; // Stride for page_table_1 + int real_page_table_cols; // Columns in real_page_table + int real_page_table_dst_stride; // Stride for destination real_page_table + int flashmla_metadata_size; // Size of FlashMLA metadata +}; + +/** + * Parameter structure for multi-backend fused metadata copy kernel. + * Enables copying from one source to three destinations in a single kernel launch. + * Used for speculative decoding with multiple draft backends. 
+ */ +struct FusedMetadataCopyMultiParams { + SourcePointers src; // Source pointers (shared across all backends) + DestinationPointers dst0; // Backend 0 destination pointers + DestinationPointers dst1; // Backend 1 destination pointers + DestinationPointers dst2; // Backend 2 destination pointers + + // Kernel parameters + int bs; // Batch size + int max_len; // Max length (DECODE mode only) + int seqlens_expanded_size; // Size of expanded sequence lengths + int page_table_1_stride; // Stride for page_table_1 + int real_page_table_cols; // Columns in real_page_table + int real_page_table_dst_stride; // Stride for destination real_page_table + int flashmla_metadata_size; // Size of FlashMLA metadata +}; + +/** + * Unified kernel for all forward modes (DECODE, TARGET_VERIFY, DRAFT_EXTEND). + * Uses runtime branches for mode selection, with template parameters for + * compile-time optimization of optional features. + * + * DESIGN: + * - Runtime branches (forward_mode) handle mode-specific logic + * - Template parameters (HAS_*) eliminate unused feature code at compile time + * - Structured parameters (SourcePointers/DestinationPointers) passed via constant memory + * + * Used by FusedMetadataCopyKernel for single-backend metadata copy. 
+ * + * @tparam HAS_REAL_PAGE_TABLE Compile-time flag for real_page_table support + * @tparam HAS_FLASHMLA Compile-time flag for FlashMLA metadata support + */ +template +__global__ void fused_metadata_copy_kernel(const FusedMetadataCopyParams __grid_constant__ params) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int total_threads = gridDim.x * blockDim.x; + + // Unpack parameters for readability + const auto& src = params.src; + const auto& dst = params.dst; + const int forward_mode = params.forward_mode; + const int bs = params.bs; + const int max_len = params.max_len; + const int max_seqlen_k = params.max_seqlen_k; + const int seqlens_expanded_size = params.seqlens_expanded_size; + const int page_indices_rows = params.page_indices_rows; + const int page_table_1_stride = params.page_table_1_stride; + const int real_page_table_cols = params.real_page_table_cols; + const int real_page_table_dst_stride = params.real_page_table_dst_stride; + const int flashmla_metadata_size = params.flashmla_metadata_size; + + // Copy cache_seqlens (bs elements) - common to all modes +#pragma unroll 8 + for (int i = tid; i < bs; i += total_threads) { + dst.cache_seqlens[i] = src.cache_seqlens[i]; + } + + // Copy cu_seqlens_k (skip first element) - common to all modes +#pragma unroll 8 + for (int i = tid; i < bs; i += total_threads) { + dst.cu_seqlens_k[i + 1] = src.cu_seqlens_k[i + 1]; + } + + // Branch 1: page_table copy (different dimensions per mode) + if (forward_mode == 0) { // DECODE + int page_table_elements = bs * max_len; +#pragma unroll 4 + for (int i = tid; i < page_table_elements; i += total_threads) { + int row = i / max_len; + int col = i % max_len; + dst.page_table_1[row * page_table_1_stride + col] = src.page_indices[i]; + } + } else { // TARGET_VERIFY or DRAFT_EXTEND + int page_table_elements = page_indices_rows * max_seqlen_k; +#pragma unroll 4 + for (int i = tid; i < page_table_elements; i += total_threads) { + int row = i / max_seqlen_k; + int col = i % 
max_seqlen_k; + dst.page_table_1[row * page_table_1_stride + col] = src.page_indices[i]; + } + } + + // Branch 2: seqlens_expanded copy (only for TARGET_VERIFY/DRAFT_EXTEND) + if (forward_mode != 0) { // TARGET_VERIFY or DRAFT_EXTEND +#pragma unroll 4 + for (int i = tid; i < seqlens_expanded_size; i += total_threads) { + dst.seqlens_expanded[i] = src.seqlens_expanded[i]; + } + } + + // Branch 3: NSA metadata copy (different loop sizes per mode) + if (forward_mode == 0) { // DECODE +#pragma unroll 8 + for (int i = tid; i < bs; i += total_threads) { + dst.nsa_cache_seqlens[i] = src.nsa_cache_seqlens[i]; + } + +#pragma unroll 8 + for (int i = tid; i < bs; i += total_threads) { + dst.nsa_cu_seqlens_k[i + 1] = src.nsa_cu_seqlens_k[i + 1]; + } + } else { // TARGET_VERIFY or DRAFT_EXTEND +#pragma unroll 4 + for (int i = tid; i < seqlens_expanded_size; i += total_threads) { + dst.nsa_cache_seqlens[i] = src.nsa_cache_seqlens[i]; + } + +#pragma unroll 4 + for (int i = tid; i < seqlens_expanded_size; i += total_threads) { + dst.nsa_cu_seqlens_k[i + 1] = src.nsa_cu_seqlens_k[i + 1]; + } + } + + // Copy real page table - compile-time branch + if constexpr (HAS_REAL_PAGE_TABLE) { + int real_table_elements = (forward_mode == 0 ? bs : page_indices_rows) * real_page_table_cols; +#pragma unroll 2 + for (int i = tid; i < real_table_elements; i += total_threads) { + int row = i / real_page_table_cols; + int col = i % real_page_table_cols; + dst.real_page_table[row * real_page_table_dst_stride + col] = + src.real_page_table[row * real_page_table_cols + col]; + } + } + + // Branch 4: FlashMLA metadata copy (different sizes per mode) + if constexpr (HAS_FLASHMLA) { + int flashmla_size = (forward_mode == 0) ? 
(bs + 1) : (seqlens_expanded_size + 1); + + if (forward_mode == 0) { +#pragma unroll 8 + for (int i = tid; i < flashmla_size; i += total_threads) { + dst.flashmla_num_splits[i] = src.flashmla_num_splits[i]; + } + } else { +#pragma unroll 4 + for (int i = tid; i < flashmla_size; i += total_threads) { + dst.flashmla_num_splits[i] = src.flashmla_num_splits[i]; + } + } + +#pragma unroll 2 + for (int i = tid; i < flashmla_metadata_size; i += total_threads) { + dst.flashmla_metadata[i] = src.flashmla_metadata[i]; + } + } +} + +/** + * Multi-backend kernel for DECODE mode. + * Copies from one source to THREE destinations in a single kernel launch. + * + * PERFORMANCE: 3x faster than three separate kernel launches due to: + * - Reduced kernel launch overhead (1 launch instead of 3) + * - Improved memory coalescing (source read once, written to 3 destinations) + * - Better instruction-level parallelism + * + * Used by FusedMetadataCopyMultiKernel for speculative decoding scenarios. + * + * @tparam HAS_REAL_PAGE_TABLE Compile-time flag for real_page_table support + * @tparam HAS_FLASHMLA Compile-time flag for FlashMLA metadata support + */ +template +__global__ void fused_metadata_copy_multi_kernel(const FusedMetadataCopyMultiParams __grid_constant__ params) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int total_threads = gridDim.x * blockDim.x; + + // Unpack parameters for readability + const auto& src = params.src; + const auto& dst0 = params.dst0; + const auto& dst1 = params.dst1; + const auto& dst2 = params.dst2; + const int bs = params.bs; + const int max_len = params.max_len; + const int seqlens_expanded_size = params.seqlens_expanded_size; + const int page_table_1_stride = params.page_table_1_stride; + const int real_page_table_cols = params.real_page_table_cols; + const int real_page_table_dst_stride = params.real_page_table_dst_stride; + const int flashmla_metadata_size = params.flashmla_metadata_size; + + // Copy cache_seqlens to all 3 backends +#pragma 
unroll 8 + for (int i = tid; i < bs; i += total_threads) { + int32_t val = src.cache_seqlens[i]; + dst0.cache_seqlens[i] = val; + dst1.cache_seqlens[i] = val; + dst2.cache_seqlens[i] = val; + } + + // Copy cu_seqlens_k to all 3 backends (skip first element) +#pragma unroll 8 + for (int i = tid; i < bs; i += total_threads) { + int32_t val = src.cu_seqlens_k[i + 1]; + dst0.cu_seqlens_k[i + 1] = val; + dst1.cu_seqlens_k[i + 1] = val; + dst2.cu_seqlens_k[i + 1] = val; + } + + // DECODE mode: copy page_table_1 to all 3 backends + int page_table_elements = bs * max_len; +#pragma unroll 4 + for (int i = tid; i < page_table_elements; i += total_threads) { + int row = i / max_len; + int col = i % max_len; + int32_t val = src.page_indices[i]; + dst0.page_table_1[row * page_table_1_stride + col] = val; + dst1.page_table_1[row * page_table_1_stride + col] = val; + dst2.page_table_1[row * page_table_1_stride + col] = val; + } + + // Copy nsa_cache_seqlens to all 3 backends +#pragma unroll 8 + for (int i = tid; i < bs; i += total_threads) { + int32_t val = src.nsa_cache_seqlens[i]; + dst0.nsa_cache_seqlens[i] = val; + dst1.nsa_cache_seqlens[i] = val; + dst2.nsa_cache_seqlens[i] = val; + } + + // Copy NSA cu_seqlens to all 3 backends +#pragma unroll 8 + for (int i = tid; i < bs; i += total_threads) { + int32_t val = src.nsa_cu_seqlens_k[i + 1]; + dst0.nsa_cu_seqlens_k[i + 1] = val; + dst1.nsa_cu_seqlens_k[i + 1] = val; + dst2.nsa_cu_seqlens_k[i + 1] = val; + } + + // Copy real page table to all 3 backends + if (src.real_page_table != nullptr && dst0.real_page_table != nullptr) { + int real_table_elements = bs * real_page_table_cols; +#pragma unroll 2 + for (int i = tid; i < real_table_elements; i += total_threads) { + int row = i / real_page_table_cols; + int col = i % real_page_table_cols; + int src_idx = row * real_page_table_cols + col; + int dst_idx = row * real_page_table_dst_stride + col; + int32_t val = src.real_page_table[src_idx]; + dst0.real_page_table[dst_idx] = val; + 
dst1.real_page_table[dst_idx] = val; + dst2.real_page_table[dst_idx] = val; + } + } + + // Copy FlashMLA metadata to all 3 backends + if constexpr (HAS_FLASHMLA) { + int flashmla_size = bs + 1; +#pragma unroll 8 + for (int i = tid; i < flashmla_size; i += total_threads) { + int32_t val = src.flashmla_num_splits[i]; + dst0.flashmla_num_splits[i] = val; + dst1.flashmla_num_splits[i] = val; + dst2.flashmla_num_splits[i] = val; + } + +#pragma unroll 2 + for (int i = tid; i < flashmla_metadata_size; i += total_threads) { + int32_t val = src.flashmla_metadata[i]; + dst0.flashmla_metadata[i] = val; + dst1.flashmla_metadata[i] = val; + dst2.flashmla_metadata[i] = val; + } + } +} + +// ============================================================================ +// Host-side launcher wrappers for JIT compilation +// ============================================================================ + +namespace { + +// Launch configuration constants +constexpr int THREADS_PER_BLOCK = 256; +constexpr int MAX_GRID_SIZE = 1024; // Limit to prevent excessive resource usage + +/** + * Helper function to extract a typed data pointer from a TensorView. + * Performs runtime type checking and returns the properly cast pointer. + * + * @tparam T The expected element type (e.g., int32_t) + * @param tensor The TensorView to extract the pointer from + * @param name The name of the tensor (for error reporting) + * @return Typed pointer to the tensor data + */ +template +inline const T* unwrap_data_ptr(const tvm::ffi::TensorView& tensor, const char* name) { + using namespace host; + if (tensor.data_ptr()) { + RuntimeCheck(is_type(tensor.dtype()), "Tensor ", name, " must have dtype int32"); + } + return static_cast(tensor.data_ptr()); +} + +/** + * Helper function to extract a typed mutable data pointer from a TensorView. + * Performs runtime type checking and returns the properly cast pointer. 
+ * + * @tparam T The expected element type (e.g., int32_t) + * @param tensor The TensorView to extract the pointer from + * @param name The name of the tensor (for error reporting) + * @return Typed mutable pointer to the tensor data + */ +template +inline T* unwrap_data_ptr_mut(const tvm::ffi::TensorView& tensor, const char* name) { + using namespace host; + if (tensor.data_ptr()) { + RuntimeCheck(is_type(tensor.dtype()), "Tensor ", name, " must have dtype int32"); + } + return static_cast(tensor.data_ptr()); +} + +/** + * Helper function to extract a typed data pointer from an Optional TensorView. + * Returns nullptr if the optional has no value, otherwise performs type checking. + * + * @tparam T The expected element type (e.g., int32_t) + * @param optional_tensor The Optional TensorView to extract the pointer from + * @param name The name of the tensor (for error reporting) + * @return Typed pointer to the tensor data, or nullptr if optional has no value + */ +template +inline const T* +unwrap_optional_data_ptr(const tvm::ffi::Optional& optional_tensor, const char* name) { + using namespace host; + if (!optional_tensor.has_value()) { + return nullptr; + } + const auto& tensor = optional_tensor.value(); + RuntimeCheck(is_type(tensor.dtype()), "Tensor ", name, " must have dtype int32"); + return static_cast(tensor.data_ptr()); +} + +/** + * Helper function to extract a typed mutable data pointer from an Optional TensorView. + * Returns nullptr if the optional has no value, otherwise performs type checking. 
+ * + * @tparam T The expected element type (e.g., int32_t) + * @param optional_tensor The Optional TensorView to extract the pointer from + * @param name The name of the tensor (for error reporting) + * @return Typed mutable pointer to the tensor data, or nullptr if optional has no value + */ +template +inline T* +unwrap_optional_data_ptr_mut(const tvm::ffi::Optional& optional_tensor, const char* name) { + using namespace host; + if (!optional_tensor.has_value()) { + return nullptr; + } + const auto& tensor = optional_tensor.value(); + RuntimeCheck(is_type(tensor.dtype()), "Tensor ", name, " must have dtype int32"); + return static_cast(tensor.data_ptr()); +} + +/** + * Calculate kernel launch configuration. + * + * @param total_work Total number of work items + * @param threads_per_block Threads per block (default: THREADS_PER_BLOCK) + * @return Grid dimension for kernel launch + */ +inline dim3 get_launch_config(int total_work, int threads_per_block = THREADS_PER_BLOCK) { + int num_blocks = (total_work + threads_per_block - 1) / threads_per_block; + // Limit grid size to prevent excessive resource usage while ensuring coverage + num_blocks = std::min(num_blocks, MAX_GRID_SIZE); + return dim3(num_blocks); +} + +/** + * JIT wrapper for single-backend fused metadata copy kernel. + * + * This struct provides a unified interface for launching the fused metadata copy + * kernel with different forward modes. It constructs the parameter struct and + * launches the unified kernel. 
+ * + * IMPLEMENTATION: + * - Extracts raw pointers from TensorView objects + * - Constructs FusedMetadataCopyParams with nested SourcePointers/DestinationPointers + * - Calculates grid configuration based on maximum work size + * - Launches fused_metadata_copy_kernel with __grid_constant__ parameters + * + * @tparam FORWARD_MODE Forward mode: 0=DECODE, 1=TARGET_VERIFY, 2=DRAFT_EXTEND + * @tparam HAS_REAL_PAGE_TABLE Whether real_page_table tensors are present + * @tparam HAS_FLASHMLA Whether FlashMLA metadata tensors are present + */ +template +struct FusedMetadataCopyKernel { + static_assert( + FORWARD_MODE >= 0 && FORWARD_MODE <= 2, + "FORWARD_MODE must be 0 (DECODE), 1 (TARGET_VERIFY), or 2 (DRAFT_EXTEND)"); + + static void + run(const tvm::ffi::TensorView cache_seqlens_src, + const tvm::ffi::TensorView cu_seqlens_k_src, + const tvm::ffi::TensorView page_indices_src, + const tvm::ffi::TensorView nsa_cache_seqlens_src, + const tvm::ffi::Optional seqlens_expanded_src, + const tvm::ffi::TensorView nsa_cu_seqlens_k_src, + const tvm::ffi::Optional real_page_table_src, + const tvm::ffi::Optional flashmla_num_splits_src, + const tvm::ffi::Optional flashmla_metadata_src, + const tvm::ffi::TensorView cache_seqlens_dst, + const tvm::ffi::TensorView cu_seqlens_k_dst, + const tvm::ffi::TensorView page_table_1_dst, + const tvm::ffi::TensorView nsa_cache_seqlens_dst, + const tvm::ffi::Optional seqlens_expanded_dst, + const tvm::ffi::TensorView nsa_cu_seqlens_k_dst, + const tvm::ffi::Optional real_page_table_dst, + const tvm::ffi::Optional flashmla_num_splits_dst, + const tvm::ffi::Optional flashmla_metadata_dst, + int bs, + int max_len, + int max_seqlen_k, + int seqlens_expanded_size) { + using namespace host; + + // Build parameter struct with nested source/destination pointers + // unwrap_data_ptr and unwrap_optional_data_ptr perform dtype validation + const auto params = FusedMetadataCopyParams{ + .src = + { + .cache_seqlens = unwrap_data_ptr(cache_seqlens_src, 
"cache_seqlens_src"), + .cu_seqlens_k = unwrap_data_ptr(cu_seqlens_k_src, "cu_seqlens_k_src"), + .page_indices = unwrap_data_ptr(page_indices_src, "page_indices_src"), + .nsa_cache_seqlens = unwrap_data_ptr(nsa_cache_seqlens_src, "nsa_cache_seqlens_src"), + .seqlens_expanded = unwrap_optional_data_ptr(seqlens_expanded_src, "seqlens_expanded_src"), + .nsa_cu_seqlens_k = unwrap_data_ptr(nsa_cu_seqlens_k_src, "nsa_cu_seqlens_k_src"), + .real_page_table = unwrap_optional_data_ptr(real_page_table_src, "real_page_table_src"), + .flashmla_num_splits = + unwrap_optional_data_ptr(flashmla_num_splits_src, "flashmla_num_splits_src"), + .flashmla_metadata = unwrap_optional_data_ptr(flashmla_metadata_src, "flashmla_metadata_src"), + }, + .dst = + { + .cache_seqlens = unwrap_data_ptr_mut(cache_seqlens_dst, "cache_seqlens_dst"), + .cu_seqlens_k = unwrap_data_ptr_mut(cu_seqlens_k_dst, "cu_seqlens_k_dst"), + .page_table_1 = unwrap_data_ptr_mut(page_table_1_dst, "page_table_1_dst"), + .nsa_cache_seqlens = unwrap_data_ptr_mut(nsa_cache_seqlens_dst, "nsa_cache_seqlens_dst"), + .seqlens_expanded = unwrap_optional_data_ptr_mut(seqlens_expanded_dst, "seqlens_expanded_dst"), + .nsa_cu_seqlens_k = unwrap_data_ptr_mut(nsa_cu_seqlens_k_dst, "nsa_cu_seqlens_k_dst"), + .real_page_table = unwrap_optional_data_ptr_mut(real_page_table_dst, "real_page_table_dst"), + .flashmla_num_splits = + unwrap_optional_data_ptr_mut(flashmla_num_splits_dst, "flashmla_num_splits_dst"), + .flashmla_metadata = + unwrap_optional_data_ptr_mut(flashmla_metadata_dst, "flashmla_metadata_dst"), + }, + .forward_mode = FORWARD_MODE, + .bs = bs, + .max_len = max_len, + .max_seqlen_k = max_seqlen_k, + .seqlens_expanded_size = seqlens_expanded_size, + .page_indices_rows = static_cast(page_indices_src.shape()[0]), + .page_table_1_stride = static_cast(page_table_1_dst.shape()[1]), + .real_page_table_cols = + real_page_table_src.has_value() ? 
static_cast(real_page_table_src.value().shape()[1]) : 0, + .real_page_table_dst_stride = + real_page_table_dst.has_value() ? static_cast(real_page_table_dst.value().stride(0)) : 0, + .flashmla_metadata_size = + flashmla_metadata_src.has_value() ? static_cast(flashmla_metadata_src.value().numel()) : 0, + }; + + // Calculate grid configuration + int max_elements = std::max( + {bs, + params.page_indices_rows * max_seqlen_k, + seqlens_expanded_size, + HAS_FLASHMLA ? (seqlens_expanded_size + 1) : 0, + HAS_FLASHMLA ? params.flashmla_metadata_size : 0}); + + dim3 grid = get_launch_config(max_elements); + dim3 block(THREADS_PER_BLOCK); + DLDevice device = cache_seqlens_src.device(); + + // Launch unified kernel with params struct + host::LaunchKernel(grid, block, device)(fused_metadata_copy_kernel, params); + } +}; + +/** + * JIT wrapper for multi-backend fused metadata copy kernel. + * + * This kernel optimizes the common case where metadata needs to be copied from + * one source to THREE destination backends in a single kernel launch. This is + * 3x faster than launching three separate kernels due to: + * - Reduced kernel launch overhead (1 launch instead of 3) + * - Improved memory coalescing (source read once, written to 3 destinations) + * - Better GPU occupancy and instruction-level parallelism + * + * USAGE: Primarily for speculative decoding with multiple draft models, where + * the same source metadata needs to be replicated to multiple backend contexts. + * + * LIMITATION: Currently only supports DECODE mode, which is the most frequently + * used mode in speculative decoding scenarios. 
+ * + * IMPLEMENTATION: + * - Constructs FusedMetadataCopyMultiParams with 1 SourcePointers + 3 DestinationPointers + * - Launches fused_metadata_copy_multi_kernel with __grid_constant__ parameters + * + * @tparam HAS_REAL_PAGE_TABLE Whether real_page_table tensors are present + * @tparam HAS_FLASHMLA Whether FlashMLA metadata tensors are present + */ +template +struct FusedMetadataCopyMultiKernel { + static void + run(const tvm::ffi::TensorView cache_seqlens_src, + const tvm::ffi::TensorView cu_seqlens_k_src, + const tvm::ffi::TensorView page_indices_src, + const tvm::ffi::TensorView nsa_cache_seqlens_src, + const tvm::ffi::TensorView nsa_cu_seqlens_k_src, + const tvm::ffi::Optional real_page_table_src, + const tvm::ffi::Optional flashmla_num_splits_src, + const tvm::ffi::Optional flashmla_metadata_src, + const tvm::ffi::TensorView cache_seqlens_dst0, + const tvm::ffi::TensorView cu_seqlens_k_dst0, + const tvm::ffi::TensorView page_table_1_dst0, + const tvm::ffi::TensorView nsa_cache_seqlens_dst0, + const tvm::ffi::TensorView nsa_cu_seqlens_k_dst0, + const tvm::ffi::Optional real_page_table_dst0, + const tvm::ffi::Optional flashmla_num_splits_dst0, + const tvm::ffi::Optional flashmla_metadata_dst0, + const tvm::ffi::TensorView cache_seqlens_dst1, + const tvm::ffi::TensorView cu_seqlens_k_dst1, + const tvm::ffi::TensorView page_table_1_dst1, + const tvm::ffi::TensorView nsa_cache_seqlens_dst1, + const tvm::ffi::TensorView nsa_cu_seqlens_k_dst1, + const tvm::ffi::Optional real_page_table_dst1, + const tvm::ffi::Optional flashmla_num_splits_dst1, + const tvm::ffi::Optional flashmla_metadata_dst1, + const tvm::ffi::TensorView cache_seqlens_dst2, + const tvm::ffi::TensorView cu_seqlens_k_dst2, + const tvm::ffi::TensorView page_table_1_dst2, + const tvm::ffi::TensorView nsa_cache_seqlens_dst2, + const tvm::ffi::TensorView nsa_cu_seqlens_k_dst2, + const tvm::ffi::Optional real_page_table_dst2, + const tvm::ffi::Optional flashmla_num_splits_dst2, + const 
tvm::ffi::Optional flashmla_metadata_dst2, + int bs, + int max_len, + int seqlens_expanded_size) { + using namespace host; + + // Build parameter struct with nested source/destination pointers + // unwrap_data_ptr and unwrap_optional_data_ptr perform dtype validation + const auto params = FusedMetadataCopyMultiParams{ + .src = + { + .cache_seqlens = unwrap_data_ptr(cache_seqlens_src, "cache_seqlens_src"), + .cu_seqlens_k = unwrap_data_ptr(cu_seqlens_k_src, "cu_seqlens_k_src"), + .page_indices = unwrap_data_ptr(page_indices_src, "page_indices_src"), + .nsa_cache_seqlens = unwrap_data_ptr(nsa_cache_seqlens_src, "nsa_cache_seqlens_src"), + .seqlens_expanded = nullptr, // Not used in multi-backend DECODE mode + .nsa_cu_seqlens_k = unwrap_data_ptr(nsa_cu_seqlens_k_src, "nsa_cu_seqlens_k_src"), + .real_page_table = unwrap_optional_data_ptr(real_page_table_src, "real_page_table_src"), + .flashmla_num_splits = + unwrap_optional_data_ptr(flashmla_num_splits_src, "flashmla_num_splits_src"), + .flashmla_metadata = unwrap_optional_data_ptr(flashmla_metadata_src, "flashmla_metadata_src"), + }, + .dst0 = + { + .cache_seqlens = unwrap_data_ptr_mut(cache_seqlens_dst0, "cache_seqlens_dst0"), + .cu_seqlens_k = unwrap_data_ptr_mut(cu_seqlens_k_dst0, "cu_seqlens_k_dst0"), + .page_table_1 = unwrap_data_ptr_mut(page_table_1_dst0, "page_table_1_dst0"), + .nsa_cache_seqlens = unwrap_data_ptr_mut(nsa_cache_seqlens_dst0, "nsa_cache_seqlens_dst0"), + .seqlens_expanded = nullptr, + .nsa_cu_seqlens_k = unwrap_data_ptr_mut(nsa_cu_seqlens_k_dst0, "nsa_cu_seqlens_k_dst0"), + .real_page_table = unwrap_optional_data_ptr_mut(real_page_table_dst0, "real_page_table_dst0"), + .flashmla_num_splits = + unwrap_optional_data_ptr_mut(flashmla_num_splits_dst0, "flashmla_num_splits_dst0"), + .flashmla_metadata = + unwrap_optional_data_ptr_mut(flashmla_metadata_dst0, "flashmla_metadata_dst0"), + }, + .dst1 = + { + .cache_seqlens = unwrap_data_ptr_mut(cache_seqlens_dst1, "cache_seqlens_dst1"), + .cu_seqlens_k = 
unwrap_data_ptr_mut(cu_seqlens_k_dst1, "cu_seqlens_k_dst1"), + .page_table_1 = unwrap_data_ptr_mut(page_table_1_dst1, "page_table_1_dst1"), + .nsa_cache_seqlens = unwrap_data_ptr_mut(nsa_cache_seqlens_dst1, "nsa_cache_seqlens_dst1"), + .seqlens_expanded = nullptr, + .nsa_cu_seqlens_k = unwrap_data_ptr_mut(nsa_cu_seqlens_k_dst1, "nsa_cu_seqlens_k_dst1"), + .real_page_table = unwrap_optional_data_ptr_mut(real_page_table_dst1, "real_page_table_dst1"), + .flashmla_num_splits = + unwrap_optional_data_ptr_mut(flashmla_num_splits_dst1, "flashmla_num_splits_dst1"), + .flashmla_metadata = + unwrap_optional_data_ptr_mut(flashmla_metadata_dst1, "flashmla_metadata_dst1"), + }, + .dst2 = + { + .cache_seqlens = unwrap_data_ptr_mut(cache_seqlens_dst2, "cache_seqlens_dst2"), + .cu_seqlens_k = unwrap_data_ptr_mut(cu_seqlens_k_dst2, "cu_seqlens_k_dst2"), + .page_table_1 = unwrap_data_ptr_mut(page_table_1_dst2, "page_table_1_dst2"), + .nsa_cache_seqlens = unwrap_data_ptr_mut(nsa_cache_seqlens_dst2, "nsa_cache_seqlens_dst2"), + .seqlens_expanded = nullptr, + .nsa_cu_seqlens_k = unwrap_data_ptr_mut(nsa_cu_seqlens_k_dst2, "nsa_cu_seqlens_k_dst2"), + .real_page_table = unwrap_optional_data_ptr_mut(real_page_table_dst2, "real_page_table_dst2"), + .flashmla_num_splits = + unwrap_optional_data_ptr_mut(flashmla_num_splits_dst2, "flashmla_num_splits_dst2"), + .flashmla_metadata = + unwrap_optional_data_ptr_mut(flashmla_metadata_dst2, "flashmla_metadata_dst2"), + }, + .bs = bs, + .max_len = max_len, + .seqlens_expanded_size = seqlens_expanded_size, + .page_table_1_stride = static_cast(page_table_1_dst0.shape()[1]), + .real_page_table_cols = + real_page_table_src.has_value() ? static_cast(real_page_table_src.value().shape()[1]) : 0, + .real_page_table_dst_stride = + real_page_table_dst0.has_value() ? static_cast(real_page_table_dst0.value().stride(0)) : 0, + .flashmla_metadata_size = + flashmla_metadata_src.has_value() ? 
static_cast(flashmla_metadata_src.value().numel()) : 0, + }; + + dim3 grid = get_launch_config(bs * max_len); + dim3 block(THREADS_PER_BLOCK); + DLDevice device = cache_seqlens_src.device(); + + // Launch multi-backend kernel with params struct + host::LaunchKernel(grid, block, device)( + fused_metadata_copy_multi_kernel, params); + } +}; + +} // namespace diff --git a/sglang/python/sglang/jit_kernel/csrc/elementwise/kvcache.cuh b/sglang/python/sglang/jit_kernel/csrc/elementwise/kvcache.cuh new file mode 100644 index 0000000000000000000000000000000000000000..fa17cbf8894dff5aa089496e99823dab336b7a12 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/elementwise/kvcache.cuh @@ -0,0 +1,201 @@ +#include +#include + +#include +#include +#include + +#include +#include + +#include + +namespace { + +struct StoreKVCacheParams { + const void* __restrict__ k; + const void* __restrict__ v; + void* __restrict__ k_cache; + void* __restrict__ v_cache; + const void* __restrict__ indices; + int64_t stride_k_bytes; + int64_t stride_v_bytes; + int64_t stride_cache_bytes; + int64_t stride_indices; + uint32_t batch_size; +}; + +constexpr uint32_t kNumWarps = 4; +constexpr uint32_t kThreadsPerBlock = kNumWarps * device::kWarpThreads; + +/** + * \brief Use a single warp to copy key and value data from source to destination. + * Each thread in the warp copies a portion of the data in a coalesced manner. + * \tparam kElementBytes The size of each key/value element in bytes. + * \param k_src Pointer to the source key data. + * \param v_src Pointer to the source value data. + * \param k_dst Pointer to the destination key data. + * \param v_dst Pointer to the destination value data. + */ +template +SGL_DEVICE void copy_kv_warp( + const void* __restrict__ k_src, + const void* __restrict__ v_src, + void* __restrict__ k_dst, + void* __restrict__ v_dst) { + using namespace device; + constexpr int64_t kAlignment = (kElementBytes % (16 * kWarpThreads) == 0) ? 
16 + : kElementBytes % (8 * kWarpThreads) == 0 ? 8 + : kElementBytes % (4 * kWarpThreads) == 0 ? 4 + : kElementBytes % 4 == 0 ? 4 + : 0; + + static_assert(kAlignment > 0, "Element size must be multiple of 4 bytes"); + + using vec_t = AlignedStorage; + constexpr auto kLoopBytes = sizeof(vec_t) * kWarpThreads; + constexpr auto kLoopCount = kElementBytes / kLoopBytes; + + const auto gmem = tile::Memory::warp(); + +#pragma unroll kLoopCount + for (int64_t i = 0; i < kLoopCount; ++i) { + const auto k = gmem.load(k_src, i); + const auto v = gmem.load(v_src, i); + gmem.store(k_dst, k, i); + gmem.store(v_dst, v, i); + } + + // handle the epilogue if any + if constexpr (kLoopCount * kLoopBytes < kElementBytes) { + if (gmem.in_bound(kElementBytes / sizeof(vec_t), kLoopCount)) { + const auto k = gmem.load(k_src, kLoopCount); + const auto v = gmem.load(v_src, kLoopCount); + gmem.store(k_dst, k, kLoopCount); + gmem.store(v_dst, v, kLoopCount); + } + } +} + +/** + * \brief Kernel to store key-value pairs into the KV cache. + * Each element is split into multiple parts to allow parallel memory copy. + * \tparam kElementBytes The size of each key/value element in bytes. + * \tparam kSplit The number of warps that handle each element. + * \tparam kUsePDL Whether to use PDL feature. + * \tparam T The data type of the indices (`int32_t` or `int64_t`). 
+ */ +template +__global__ void store_kvcache(const __grid_constant__ StoreKVCacheParams params) { + using namespace device; + constexpr auto kSplitSize = kElementBytes / kSplit; + const uint32_t warp_id = blockIdx.x * kNumWarps + threadIdx.x / kWarpThreads; + const uint32_t item_id = warp_id / kSplit; + const uint32_t split_id = warp_id % kSplit; + const auto& [ + k_input, v_input, k_cache, v_cache, indices, // ptr + stride_k, stride_v, stride_cache, stride_indices, batch_size // size + ] = params; + if (item_id >= batch_size) return; + + const auto index_ptr = static_cast(indices) + item_id * stride_indices; + PDLWaitPrimary(); + + const auto index = *index_ptr; + const auto k_src = pointer::offset(k_input, item_id * stride_k, split_id * kSplitSize); + const auto v_src = pointer::offset(v_input, item_id * stride_v, split_id * kSplitSize); + const auto k_dst = pointer::offset(k_cache, index * stride_cache, split_id * kSplitSize); + const auto v_dst = pointer::offset(v_cache, index * stride_cache, split_id * kSplitSize); + + copy_kv_warp(k_src, v_src, k_dst, v_dst); + PDLTriggerSecondary(); +} + +template +struct StoreKVCacheKernel { + static_assert(kElementBytes > 0 && kElementBytes % 4 == 0); + + template + static constexpr auto store_kernel = store_kvcache; + + template + static auto get_kernel(const int num_split) { + using namespace host; + // only apply split optimization when element size is aligned + if constexpr (kElementBytes % (4 * 128) == 0) { + if (num_split == 4) return store_kernel<4, T>; + } + if constexpr (kElementBytes % (2 * 128) == 0) { + if (num_split == 2) return store_kernel<2, T>; + } + if (num_split == 1) return store_kernel<1, T>; + Panic("Unsupported num_split {} for element size {}", num_split, kElementBytes); + } + + static void + run(const tvm::ffi::TensorView k, + const tvm::ffi::TensorView v, + const tvm::ffi::TensorView k_cache, + const tvm::ffi::TensorView v_cache, + const tvm::ffi::TensorView indices, + const int num_split) { + 
using namespace host; + auto B = SymbolicSize{"batch_size"}; + auto D = SymbolicSize{"element_size"}; + auto KS = SymbolicSize{"k_stride"}; + auto VS = SymbolicSize{"v_stride"}; + auto S = SymbolicSize{"cache_stride"}; + auto I = SymbolicSize{"indices_stride"}; + auto dtype = SymbolicDType{}; + auto device = SymbolicDevice{}; + auto indice_dtype = SymbolicDType{}; + device.set_options(); + + TensorMatcher({B, D}) // + .with_strides({KS, 1}) + .with_dtype(dtype) + .with_device(device) + .verify(k); + TensorMatcher({B, D}) // + .with_strides({VS, 1}) + .with_dtype(dtype) + .with_device(device) + .verify(v); + TensorMatcher({-1, D}) // + .with_strides({S, 1}) + .with_dtype(dtype) + .with_device(device) + .verify(k_cache) + .verify(v_cache); + TensorMatcher({B}) // + .with_strides({I}) + .with_dtype(indice_dtype) + .with_device(device) + .verify(indices); + + const int64_t dtype_size = dtype_bytes(dtype.unwrap()); + const uint32_t num_elements = static_cast(B.unwrap()); + RuntimeCheck(kElementBytes == dtype_size * D.unwrap()); + + const auto params = StoreKVCacheParams{ + .k = k.data_ptr(), + .v = v.data_ptr(), + .k_cache = k_cache.data_ptr(), + .v_cache = v_cache.data_ptr(), + .indices = indices.data_ptr(), + .stride_k_bytes = KS.unwrap() * dtype_size, + .stride_v_bytes = VS.unwrap() * dtype_size, + .stride_cache_bytes = S.unwrap() * dtype_size, + .stride_indices = I.unwrap(), + .batch_size = static_cast(B.unwrap()), + }; + // select kernel and update num_split if needed + const auto use_int32 = indice_dtype.is_type(); + const auto kernel = use_int32 ? 
get_kernel(num_split) : get_kernel(num_split); + const auto num_blocks = div_ceil(num_elements * num_split, kNumWarps); + LaunchKernel(num_blocks, kThreadsPerBlock, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/sglang/python/sglang/jit_kernel/csrc/elementwise/pos_enc.cuh b/sglang/python/sglang/jit_kernel/csrc/elementwise/pos_enc.cuh new file mode 100644 index 0000000000000000000000000000000000000000..9272e6248243f71415a0ce5dc91561a49f8d0e4e --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/elementwise/pos_enc.cuh @@ -0,0 +1,313 @@ +// Adapted from +// https://github.com/vllm-project/vllm/blob/014ece97c7aa49084a1119dca792af081a18dbc1/csrc/pos_encoding_kernels.cu + +#include +#include + +#include + +#include + +#include +#include + +namespace { + +template +inline __device__ void apply_token_rotary_embedding( + scalar_t* __restrict__ arr, + const scalar_t* __restrict__ cos_ptr, + const scalar_t* __restrict__ sin_ptr, + int rot_offset, + int embed_dim) { + int x_index, y_index; + scalar_t cos, sin; + if (IS_NEOX) { + // GPT-NeoX style rotary embedding. + x_index = rot_offset; + y_index = embed_dim + rot_offset; + cos = SGLANG_LDG(cos_ptr + x_index); + sin = SGLANG_LDG(sin_ptr + x_index); + } else { + // GPT-J style rotary embedding. 
+ x_index = 2 * rot_offset; + y_index = 2 * rot_offset + 1; + cos = SGLANG_LDG(cos_ptr + x_index / 2); + sin = SGLANG_LDG(sin_ptr + x_index / 2); + } + + const scalar_t x = arr[x_index]; + const scalar_t y = arr[y_index]; + arr[x_index] = x * cos - y * sin; + arr[y_index] = y * cos + x * sin; +} + +template +inline __device__ void apply_rotary_embedding( + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, + // head_size] or [num_tokens, num_heads, + // head_size] + scalar_t* __restrict__ key, // nullptr or + // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* cache_ptr, + const int head_size, + const int num_heads, + const int num_kv_heads, + const int rot_dim, + const int token_idx, + const int64_t query_stride, + const int64_t key_stride, + const int64_t head_stride) { + const int embed_dim = rot_dim / 2; + const scalar_t* cos_ptr = cache_ptr; + const scalar_t* sin_ptr = cache_ptr + embed_dim; + + const int nq = num_heads * embed_dim; + for (int i = threadIdx.x; i < nq; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int64_t token_head = token_idx * query_stride + head_idx * head_stride; + const int rot_offset = i % embed_dim; + apply_token_rotary_embedding(query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); + } + + if (key != nullptr) { + const int nk = num_kv_heads * embed_dim; + for (int i = threadIdx.x; i < nk; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int64_t token_head = token_idx * key_stride + head_idx * head_stride; + const int rot_offset = i % embed_dim; + apply_token_rotary_embedding(key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); + } + } +} + +template +__global__ void rotary_embedding_kernel( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or + // [num_tokens] + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, + // head_size] or [num_tokens, num_heads, + // head_size] + 
scalar_t* __restrict__ key, // nullptr or + // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // + // 2] + const int rot_dim, + const int64_t query_stride, + const int64_t key_stride, + const int64_t head_stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { + // Each thread block is responsible for one token. + const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; + + apply_rotary_embedding( + query, + key, + cache_ptr, + head_size, + num_heads, + num_kv_heads, + rot_dim, + token_idx, + query_stride, + key_stride, + head_stride); +} + +// Helper struct to launch kernel +template +void launch_kernel( + const int64_t* positions_data_ptr, + void* query_ptr, + void* key_ptr, + const void* cos_sin_cache_ptr, + int rot_dim, + int64_t query_stride, + int64_t key_stride, + int64_t head_stride, + int num_heads, + int num_kv_heads, + int head_size, + dim3 grid, + dim3 block, + const cudaStream_t stream) { + rotary_embedding_kernel<<>>( + positions_data_ptr, + static_cast(query_ptr), + static_cast(key_ptr), + static_cast(cos_sin_cache_ptr), + rot_dim, + query_stride, + key_stride, + head_stride, + num_heads, + num_kv_heads, + head_size); +}; + +// Helper macro to reduce repetition +#define DISPATCH_DTYPE(DTYPE_CODE, DTYPE_BITS, IS_NEOX, ...) \ + if (DTYPE_CODE == kDLFloat && DTYPE_BITS == 32) { \ + launch_kernel(__VA_ARGS__); \ + } else if (DTYPE_CODE == kDLFloat && DTYPE_BITS == 16) { \ + launch_kernel(__VA_ARGS__); \ + } else if (DTYPE_CODE == kDLBfloat && DTYPE_BITS == 16) { \ + launch_kernel(__VA_ARGS__); \ + } else { \ + RuntimeCheck( \ + false, "Unsupported data type for rotary embedding. 
Only float32, float16, and bfloat16 are supported."); \ + } + +// Helper function to dispatch based on data type +template +void dispatch_by_dtype( + const int64_t* positions_data_ptr, + DLDataType query_dtype, + void* query_ptr, + void* key_ptr, + void* cos_sin_cache_ptr, + int rot_dim, + int64_t query_stride, + int64_t key_stride, + int64_t head_stride, + int num_heads, + int num_kv_heads, + int head_size, + dim3 grid, + dim3 block, + const cudaStream_t stream) { + using namespace host; + DISPATCH_DTYPE( + query_dtype.code, + query_dtype.bits, + IS_NEOX, + positions_data_ptr, + query_ptr, + key_ptr, + cos_sin_cache_ptr, + rot_dim, + query_stride, + key_stride, + head_stride, + num_heads, + num_kv_heads, + head_size, + grid, + block, + stream); +} + +struct RotaryEmbeddingKernel { + static void + run(tvm::ffi::TensorView positions, // [batch_size, seq_len] or [num_tokens] + tvm::ffi::TensorView query, // [batch_size, seq_len, num_heads * head_size] or + // [num_tokens, num_heads * head_size] or + // [batch_size, seq_len, num_heads, head_size] or + // [num_tokens, num_heads, head_size] + tvm::ffi::Optional key, + // null or + // [batch_size, seq_len, num_kv_heads * head_size] or + // [num_tokens, num_kv_heads * head_size] or + // [batch_size, seq_len, num_heads, head_size] or + // [num_tokens, num_heads, head_size] + int64_t head_size, + tvm::ffi::TensorView cos_sin_cache, // [max_position, rot_dim] + bool is_neox) { + using namespace host; + + // num_tokens = batch_size * seq_len + int64_t num_tokens = positions.numel(); + int32_t positions_ndim = positions.ndim(); + + // Make sure num_tokens dim is consistent across positions, query, and key + RuntimeCheck( + positions_ndim == 1 || positions_ndim == 2, "positions must have shape [num_tokens] or [batch_size, seq_len]"); + if (positions_ndim == 1) { + RuntimeCheck( + query.size(0) == positions.size(0) && (!key.has_value() || key.value().size(0) == positions.size(0)), + "query, key and positions must have the same 
number of tokens"); + } + if (positions_ndim == 2) { + RuntimeCheck( + query.size(0) == positions.size(0) && (!key.has_value() || key.value().size(0) == positions.size(0)) && + query.size(1) == positions.size(1) && (!key.has_value() || key.value().size(1) == positions.size(1)), + "query, key and positions must have the same batch_size and seq_len"); + } + + // Make sure head_size is valid for query and key + // hidden_size = num_heads * head_size + int query_hidden_size = query.numel() / num_tokens; + int key_hidden_size = key.has_value() ? key.value().numel() / num_tokens : 0; + RuntimeCheck(query_hidden_size % head_size == 0); + RuntimeCheck(key_hidden_size % head_size == 0); + + // Make sure query and key have consistent number of heads + int num_heads = query_hidden_size / head_size; + int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads; + RuntimeCheck(num_heads % num_kv_heads == 0); + + int rot_dim = cos_sin_cache.size(1); + int seq_dim_idx = positions_ndim - 1; + int64_t query_stride = query.stride(seq_dim_idx); + int64_t key_stride = key.has_value() ? key.value().stride(seq_dim_idx) : 0; + // Determine head stride: for [*, heads, head_size] use stride of last dim; + // for flat [*, heads*head_size], heads blocks are contiguous of size + // head_size + int query_ndim = query.dim(); + int64_t head_stride = (query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size; + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * rot_dim / 2, 512)); + + auto device = query.device(); + const cudaStream_t stream = LaunchKernel::resolve_device(device); + + auto positions_data_ptr = static_cast(positions.data_ptr()); + + if (is_neox) { + dispatch_by_dtype( + positions_data_ptr, + query.dtype(), + query.data_ptr(), + key.has_value() ? 
key.value().data_ptr() : nullptr, + cos_sin_cache.data_ptr(), + rot_dim, + query_stride, + key_stride, + head_stride, + num_heads, + num_kv_heads, + head_size, + grid, + block, + stream); + } else { + dispatch_by_dtype( + positions_data_ptr, + query.dtype(), + query.data_ptr(), + key.has_value() ? key.value().data_ptr() : nullptr, + cos_sin_cache.data_ptr(), + rot_dim, + query_stride, + key_stride, + head_stride, + num_heads, + num_kv_heads, + head_size, + grid, + block, + stream); + } + } +}; + +} // namespace diff --git a/sglang/python/sglang/jit_kernel/csrc/elementwise/qknorm.cuh b/sglang/python/sglang/jit_kernel/csrc/elementwise/qknorm.cuh new file mode 100644 index 0000000000000000000000000000000000000000..789e796b05045a0df1469f0b5def257d78b472ff --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/elementwise/qknorm.cuh @@ -0,0 +1,257 @@ +#include +#include + +#include +#include +#include + +#include + +#include +#include + +#include +#include +#include +#include + +namespace { + +struct QKNormParams { + void* __restrict__ q; + void* __restrict__ k; // k is offset by (-num_qo_heads * head_dim) elements + int64_t q_stride; + int64_t k_stride; + uint32_t num_qo_heads; + uint32_t num_kv_heads; + float eps; + const void* __restrict__ q_weight; + const void* __restrict__ k_weight; + uint32_t num_tokens; +}; + +constexpr uint32_t kWarpsPerBlock = 4; +constexpr uint32_t kThreadsPerBlock = kWarpsPerBlock * device::kWarpThreads; + +// Warp-level kernel for head_dim <= 256 +template +__global__ void fused_qknorm_warp(const QKNormParams __grid_constant__ params) { + using namespace device; + using Storage = norm::StorageType; + + static_assert(sizeof(Float) == 2, "Only support FP16/BF16"); + const auto& [q, k, q_stride, k_stride, num_qo_heads, num_kv_heads, eps, q_weight, k_weight, num_tokens] = params; + + const auto num_blks = gridDim.x; + const auto num_workers = num_blks * kWarpsPerBlock; + const auto num_q_and_k_heads = num_qo_heads + num_kv_heads; + const auto 
num_works = num_q_and_k_heads * num_tokens; + const auto start_worker_id = blockIdx.x * kWarpsPerBlock + threadIdx.x / kWarpThreads; + const auto gmem = tile::Memory::warp(); + + PDLWaitPrimary(); // wait for primary kernel + + for (auto idx = start_worker_id; idx < num_works; idx += num_workers) { + const int64_t token_id = idx / num_q_and_k_heads; + const int64_t head_id = idx % num_q_and_k_heads; + const auto load_q = head_id < num_qo_heads; + const auto input = load_q ? pointer::offset(q, 2 * (token_id * q_stride + head_id * kHeadDim)) + : pointer::offset(k, 2 * (token_id * k_stride + head_id * kHeadDim)); + const auto weight = load_q ? q_weight : k_weight; + const auto input_vec = gmem.load(input); + const auto weight_vec = gmem.load(weight); + const auto output_vec = norm::apply_norm_warp(input_vec, weight_vec, eps); + gmem.store(input, output_vec); + } + + PDLTriggerSecondary(); // launch secondary kernel +} + +// For CTA level, used for head_dim > 256 (512,1024) +template +__global__ void fused_qknorm_cta(const QKNormParams __grid_constant__ params) { + using namespace device; + using Storage = norm::StorageType; + + constexpr auto kNumThreads = host::norm::get_cta_threads(); + constexpr auto kNumWarps = kNumThreads / kWarpThreads; + + static_assert(sizeof(Float) == 2, "Only support FP16/BF16"); + const auto& [q, k, q_stride, k_stride, num_qo_heads, num_kv_heads, eps, q_weight, k_weight, num_tokens] = params; + + const auto num_q_and_k_heads = num_qo_heads + num_kv_heads; + const auto num_works = num_q_and_k_heads * num_tokens; + const auto gmem = tile::Memory::cta(kNumThreads); + __shared__ float smem[norm::kSmemBufferSize]; + + PDLWaitPrimary(); // wait for primary kernel + + for (auto idx = blockIdx.x; idx < num_works; idx += gridDim.x) { + const int64_t token_id = idx / num_q_and_k_heads; + const int64_t head_id = idx % num_q_and_k_heads; + const auto load_q = head_id < num_qo_heads; + const auto input = load_q ? 
pointer::offset(q, 2 * (token_id * q_stride + head_id * kHeadDim)) + : pointer::offset(k, 2 * (token_id * k_stride + head_id * kHeadDim)); + const auto weight = load_q ? q_weight : k_weight; + const auto input_vec = gmem.load(input); + const auto weight_vec = gmem.load(weight); + const auto output_vec = norm::apply_norm_cta(input_vec, weight_vec, eps, smem, kNumWarps); + gmem.store(input, output_vec); + } + + PDLTriggerSecondary(); // launch secondary kernel +} + +// Warp-level kernel struct for head_dim <= 256 +template +struct QKNormKernelWarp { + static_assert(std::is_same_v || std::is_same_v); + static_assert(!host::norm::should_use_cta(), "Use QKNormKernelCTA for head_dim > 256"); + static constexpr auto kernel = fused_qknorm_warp; + + static void + run(const tvm::ffi::TensorView q, + const tvm::ffi::TensorView k, + const tvm::ffi::TensorView q_weight, + const tvm::ffi::TensorView k_weight, + float eps) { + using namespace host; + + auto N = SymbolicSize{"num_tokens"}; + auto Q = SymbolicSize{"num_qo_heads"}; + auto K = SymbolicSize{"num_kv_heads"}; + auto D = SymbolicSize{"head_dim"}; + auto Sq = SymbolicSize{"q_stride"}; + auto Sk = SymbolicSize{"k_stride"}; + auto device = SymbolicDevice{}; + D.set_value(kHeadDim); + device.set_options(); + + TensorMatcher({N, Q, D}) // q input + .with_strides({Sq, D, 1}) + .with_dtype() + .with_device(device) + .verify(q); + TensorMatcher({N, K, D}) // k input + .with_strides({Sk, D, 1}) + .with_dtype() + .with_device(device) + .verify(k); + TensorMatcher({D}) // weight + .with_dtype() + .with_device(device) + .verify(q_weight) + .verify(k_weight); + + const auto num_tokens = static_cast(N.unwrap()); + const auto num_qo_heads = static_cast(Q.unwrap()); + const auto num_kv_heads = static_cast(K.unwrap()); + + // NOTE: we offset the k here to reduce computation cost in the kernel + const auto params = QKNormParams{ + .q = q.data_ptr(), + .k = pointer::offset(k.data_ptr(), -2 * static_cast(num_qo_heads) * kHeadDim), + 
.q_stride = static_cast(Sq.unwrap()), + .k_stride = static_cast(Sk.unwrap()), + .num_qo_heads = num_qo_heads, + .num_kv_heads = num_kv_heads, + .eps = eps, + .q_weight = q_weight.data_ptr(), + .k_weight = k_weight.data_ptr(), + .num_tokens = num_tokens, + }; + + static const uint32_t max_occupancy = runtime::get_blocks_per_sm(kernel, kThreadsPerBlock); + static const uint32_t kNumSM = runtime::get_sm_count(device.unwrap().device_id); + + // choose kernel based on dtype + const auto num_works = (num_qo_heads + num_kv_heads) * num_tokens; + const auto needed_blocks = div_ceil(num_works, kWarpsPerBlock); + + // we use persistent kernel, which limit the number of blocks to reduce overhead + const auto num_blocks = std::min(kNumSM * max_occupancy, needed_blocks); + LaunchKernel(num_blocks, kThreadsPerBlock, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +// This goes with fused_qknorm_cta +template +struct QKNormKernelCTA { + static_assert(std::is_same_v || std::is_same_v); + static_assert(host::norm::should_use_cta(), "Use QKNormKernelWarp for head_dim <= 256"); + static constexpr auto kernel = fused_qknorm_cta; + static constexpr auto kNumThreads = host::norm::get_cta_threads(); + + static void + run(const tvm::ffi::TensorView q, + const tvm::ffi::TensorView k, + const tvm::ffi::TensorView q_weight, + const tvm::ffi::TensorView k_weight, + float eps) { + using namespace host; + + auto N = SymbolicSize{"num_tokens"}; + auto Q = SymbolicSize{"num_qo_heads"}; + auto K = SymbolicSize{"num_kv_heads"}; + auto D = SymbolicSize{"head_dim"}; + auto Sq = SymbolicSize{"q_stride"}; + auto Sk = SymbolicSize{"k_stride"}; + auto device = SymbolicDevice{}; + D.set_value(kHeadDim); + device.set_options(); + + TensorMatcher({N, Q, D}) // q input + .with_strides({Sq, D, 1}) + .with_dtype() + .with_device(device) + .verify(q); + TensorMatcher({N, K, D}) // k input + .with_strides({Sk, D, 1}) + .with_dtype() + .with_device(device) + .verify(k); + TensorMatcher({D}) 
// weight + .with_dtype() + .with_device(device) + .verify(q_weight) + .verify(k_weight); + + const auto num_tokens = static_cast(N.unwrap()); + const auto num_qo_heads = static_cast(Q.unwrap()); + const auto num_kv_heads = static_cast(K.unwrap()); + + // NOTE: we offset the k here to reduce computation cost in the kernel + const auto params = QKNormParams{ + .q = q.data_ptr(), + .k = pointer::offset(k.data_ptr(), -2 * static_cast(num_qo_heads) * kHeadDim), + .q_stride = static_cast(Sq.unwrap()), + .k_stride = static_cast(Sk.unwrap()), + .num_qo_heads = num_qo_heads, + .num_kv_heads = num_kv_heads, + .eps = eps, + .q_weight = q_weight.data_ptr(), + .k_weight = k_weight.data_ptr(), + .num_tokens = num_tokens, + }; + + static const uint32_t max_occupancy = runtime::get_blocks_per_sm(kernel, kNumThreads); + static const uint32_t kNumSM = runtime::get_sm_count(device.unwrap().device_id); + + const auto num_works = (num_qo_heads + num_kv_heads) * num_tokens; + + // we use persistent kernel, which limit the number of blocks to reduce overhead + const auto num_blocks = std::min(num_works, max_occupancy * kNumSM); + LaunchKernel(num_blocks, kNumThreads, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +// Unified dispatch: select warp or CTA kernel based on head_dim +template +using QKNormKernel = std::conditional_t< + host::norm::should_use_cta(), + QKNormKernelCTA, + QKNormKernelWarp>; + +} // namespace diff --git a/sglang/python/sglang/jit_kernel/csrc/elementwise/qknorm_across_heads.cuh b/sglang/python/sglang/jit_kernel/csrc/elementwise/qknorm_across_heads.cuh new file mode 100644 index 0000000000000000000000000000000000000000..1c231390bf218d11c654986542367037add3fd50 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/elementwise/qknorm_across_heads.cuh @@ -0,0 +1,232 @@ +#include +#include + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + +namespace { + +template +struct VecTypeTrait; + 
+template <> +struct VecTypeTrait { + using packed_t = packed_t; + using vec_t = device::AlignedVector; +}; + +template <> +struct VecTypeTrait { + using packed_t = packed_t; + using vec_t = device::AlignedVector; +}; + +template <> +struct VecTypeTrait { + using packed_t = packed_t; + using vec_t = device::AlignedVector; +}; + +template <> +struct VecTypeTrait { + using packed_t = packed_t; + using vec_t = device::AlignedVector; +}; + +template +SGL_DEVICE packed_t rms(packed_t& val, packed_t& weight, float rsqrt_square_sum) { + float2 valf = device::cast(val); + float2 weightf = device::cast(weight); + return device::cast( + make_float2(valf.x * weightf.x * rsqrt_square_sum, valf.y * weightf.y * rsqrt_square_sum)); +} + +template +__global__ void qknorm_across_heads_reg_kernel( + T* __restrict__ q, + T* __restrict__ k, + const T* __restrict__ q_weight, + const T* __restrict__ k_weight, + int vec_hidden_size, + float eps) { + constexpr int inner_loop = VEC_SIZE_IN_BYTE == 16 ? 4 : 8; + + __shared__ float shared_memory[64]; // Used for CTA reduce, store both Q and K rsqrt + + using vec_t = typename VecTypeTrait::vec_t; + using packed_t = typename VecTypeTrait::packed_t; + vec_t v_q; // Save q + vec_t v_k; // Save k + vec_t v_q_weight; // Save q_weight + vec_t v_k_weight; // Save k_weight + vec_t v_q_out; // Save q output + vec_t v_k_out; // Save k output + + auto token_id = blockIdx.x; + float2 acc_square_q = make_float2(0.0f, 0.0f); // Sum of squares for q + float2 acc_square_k = make_float2(0.0f, 0.0f); // Sum of squares for k + + if (threadIdx.x < vec_hidden_size) { + // Compute address for q and k + vec_t* p_q = reinterpret_cast(q) + token_id * vec_hidden_size; + vec_t* p_k = reinterpret_cast(k) + token_id * vec_hidden_size; + const vec_t* p_q_weight = reinterpret_cast(q_weight); + const vec_t* p_k_weight = reinterpret_cast(k_weight); + + // Load data + v_q = p_q[threadIdx.x]; + v_k = p_k[threadIdx.x]; + v_q_weight = p_q_weight[threadIdx.x]; + v_k_weight = 
p_k_weight[threadIdx.x]; + + // Compute sum of squares for q + for (int i = 0; i < inner_loop; i++) { + float2 val = device::cast(v_q[i]); + acc_square_q.x += val.x * val.x; + acc_square_q.y += val.y * val.y; + } + + // Compute sum of squares for k + for (int i = 0; i < inner_loop; i++) { + float2 val = device::cast(v_k[i]); + acc_square_k.x += val.x * val.x; + acc_square_k.y += val.y * val.y; + } + } + + auto cg_warp = cooperative_groups::tiled_partition<32>(cooperative_groups::this_thread_block()); + float* buffer_q = shared_memory; // [0, 31] for Q + float* buffer_k = shared_memory + 32; // [32, 63] for K + + // ========== Reduction phase: Compute rsqrt for both Q and K ========== + + // Step 0: Warp Reduce for Q + float warp_sum_q = + cooperative_groups::reduce(cg_warp, acc_square_q.x + acc_square_q.y, cooperative_groups::plus()); + if (threadIdx.x % 32 == 0) { + buffer_q[threadIdx.x / 32] = warp_sum_q; + } + + // Step 0: Warp Reduce for K + float warp_sum_k = + cooperative_groups::reduce(cg_warp, acc_square_k.x + acc_square_k.y, cooperative_groups::plus()); + if (threadIdx.x % 32 == 0) { + buffer_k[threadIdx.x / 32] = warp_sum_k; + } + + // Step 1: CTA Reduce for both Q and K + __syncthreads(); + if (threadIdx.x < 32) { + // CTA Reduce for Q + float cta_sum_q = cooperative_groups::reduce( + cg_warp, (threadIdx.x < blockDim.x / 32) ? buffer_q[threadIdx.x] : 0.0f, cooperative_groups::plus()); + buffer_q[threadIdx.x] = + rsqrtf(eps + cta_sum_q * (1.0f / static_cast(vec_hidden_size * (VEC_SIZE_IN_BYTE / sizeof(T))))); + + // CTA Reduce for K + float cta_sum_k = cooperative_groups::reduce( + cg_warp, (threadIdx.x < blockDim.x / 32) ? 
buffer_k[threadIdx.x] : 0.0f, cooperative_groups::plus()); + buffer_k[threadIdx.x] = + rsqrtf(eps + cta_sum_k * (1.0f / static_cast(vec_hidden_size * (VEC_SIZE_IN_BYTE / sizeof(T))))); + } + __syncthreads(); + + // ========== Apply normalization phase: Compute and write back Q and K ========== + + if (threadIdx.x < vec_hidden_size) { + // Apply RMSNorm for Q + float rsqrt_q = buffer_q[threadIdx.x / 32]; + for (int i = 0; i < inner_loop; i++) { + v_q_out[i] = rms(v_q[i], v_q_weight[i], rsqrt_q); + } + vec_t* p_q_out = reinterpret_cast(q) + token_id * vec_hidden_size; + p_q_out[threadIdx.x] = v_q_out; + + // Apply RMSNorm for K + float rsqrt_k = buffer_k[threadIdx.x / 32]; + for (int i = 0; i < inner_loop; i++) { + v_k_out[i] = rms(v_k[i], v_k_weight[i], rsqrt_k); + } + vec_t* p_k_out = reinterpret_cast(k) + token_id * vec_hidden_size; + p_k_out[threadIdx.x] = v_k_out; + } +} + +template +struct QKNormAcrossHeadsKernel { + static void + run(const tvm::ffi::TensorView q, + const tvm::ffi::TensorView k, + const tvm::ffi::TensorView q_weight, + const tvm::ffi::TensorView k_weight, + float eps) { + using namespace host; + auto N = SymbolicSize{"num_tokens"}; + auto D = SymbolicSize{"hidden_size"}; + auto device = SymbolicDevice{}; + device.set_options(); + + TensorMatcher({N, D}) // q + .with_strides({D, 1}) + .with_dtype() + .with_device(device) + .verify(q); + TensorMatcher({N, D}) // k + .with_strides({D, 1}) + .with_dtype() + .with_device(device) + .verify(k); + TensorMatcher({D}) // q_weight + .with_dtype() + .with_device(device) + .verify(q_weight); + TensorMatcher({D}) // k_weight + .with_dtype() + .with_device(device) + .verify(k_weight); + + auto cc_major = host::runtime::get_cc_major(device.unwrap().device_id); + int hidden_size = static_cast(D.unwrap()); + if ((cc_major <= 9 && hidden_size <= 8192) || (cc_major >= 10 && hidden_size <= 12288)) { + int max_vec_size_byte = cc_major >= 10 ? 
32 : 16; + int elements_in_vec = max_vec_size_byte / sizeof(DType); + int vec_hidden_size = hidden_size / elements_in_vec; + uint threads = (vec_hidden_size + 31) / 32 * 32; + + // Runtime check + host::RuntimeCheck( + hidden_size % elements_in_vec == 0, + "hidden_size", + hidden_size, + " can not align to elements_in_vec ", + elements_in_vec); + + // Launch single kernel for both q and k + auto kernel = max_vec_size_byte == 32 ? qknorm_across_heads_reg_kernel + : qknorm_across_heads_reg_kernel; + + LaunchKernel(static_cast(N.unwrap()), threads, device.unwrap()) + .enable_pdl(false)( + kernel, + reinterpret_cast(q.data_ptr()), + reinterpret_cast(k.data_ptr()), + reinterpret_cast(q_weight.data_ptr()), + reinterpret_cast(k_weight.data_ptr()), + vec_hidden_size, + eps); + } else { + host::RuntimeCheck(false, "Large hidden_sizes are not supported for now."); + } + } +}; + +} // namespace diff --git a/sglang/python/sglang/jit_kernel/csrc/elementwise/rmsnorm.cuh b/sglang/python/sglang/jit_kernel/csrc/elementwise/rmsnorm.cuh new file mode 100644 index 0000000000000000000000000000000000000000..aadcc495f51e1b80b98cf835139ac0f85b60fab2 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/elementwise/rmsnorm.cuh @@ -0,0 +1,109 @@ +#include +#include + +#include +#include +#include + +#include + +#include + +namespace { + +struct RMSNormParams { + const void* input; + const void* __restrict__ weight; + void* output; + int64_t input_stride; + int64_t output_stride; + uint32_t num_tokens; + float eps; +}; + +template +__global__ void rmsnorm_cta(const RMSNormParams __grid_constant__ params) { + using namespace device; + using Storage = norm::StorageType; + + constexpr auto kNumThreads = host::norm::get_cta_threads(); + constexpr auto kNumWarps = kNumThreads / kWarpThreads; + + const auto& [input, weight_ptr, output, input_stride, output_stride, num_tokens, eps] = params; + const auto gmem = tile::Memory::cta(kNumThreads); + __shared__ float smem[norm::kSmemBufferSize]; + + 
PDLWaitPrimary(); // wait for primary kernel + + void* output_ptr = nullptr; + Storage output_vec; + for (uint32_t i = blockIdx.x; i < num_tokens; i += gridDim.x) { + const auto input_ptr = pointer::offset(input, i * input_stride); + const auto input_vec = gmem.load(input_ptr); + const auto weight_vec = gmem.load(weight_ptr); + if (output_ptr != nullptr) { + gmem.store(output_ptr, output_vec); + } + output_ptr = pointer::offset(output, i * output_stride); + output_vec = norm::apply_norm_cta(input_vec, weight_vec, eps, smem, kNumWarps); + } + gmem.store(output_ptr, output_vec); + + PDLTriggerSecondary(); // launch secondary kernel +} + +template +struct RMSNormKernel { + static_assert(host::norm::should_use_cta(), "Hidden size invalid for RMSNorm"); + static constexpr auto kernel = rmsnorm_cta; + + static void + run(const tvm::ffi::TensorView input, + const tvm::ffi::TensorView weight, + const tvm::ffi::TensorView output, + float eps) { + using namespace host; + auto N = SymbolicSize{"num_tokens"}; + auto D = SymbolicSize{"hidden_size"}; + auto SI = SymbolicSize{"input_stride"}; + auto SO = SymbolicSize{"output_stride"}; + auto device = SymbolicDevice{}; + D.set_value(kDim); + device.set_options(); + + TensorMatcher({N, D}) // input + .with_strides({SI, 1}) + .with_dtype() + .with_device(device) + .verify(input); + TensorMatcher({D}) // weight + .with_dtype() + .with_device(device) + .verify(weight); + TensorMatcher({N, D}) // output + .with_strides({SO, 1}) + .with_dtype() + .with_device(device) + .verify(output); + + const auto num_tokens = static_cast(N.unwrap()); + const auto params = RMSNormParams{ + .input = input.data_ptr(), + .weight = weight.data_ptr(), + .output = output.data_ptr(), + .input_stride = SI.unwrap(), + .output_stride = SO.unwrap(), + .num_tokens = num_tokens, + .eps = eps, + }; + + static constexpr auto kNumThreads = norm::get_cta_threads(); + static const uint32_t max_occupancy = runtime::get_blocks_per_sm(kernel, kNumThreads); + static const 
uint32_t kNumSM = runtime::get_sm_count(device.unwrap().device_id); + const auto num_blocks = std::min(num_tokens, max_occupancy * kNumSM); + LaunchKernel(num_blocks, kNumThreads, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/sglang/python/sglang/jit_kernel/csrc/elementwise/rope.cuh b/sglang/python/sglang/jit_kernel/csrc/elementwise/rope.cuh new file mode 100644 index 0000000000000000000000000000000000000000..27b4e7ec83ff93760b1de3e8c187efebb65c62c2 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/elementwise/rope.cuh @@ -0,0 +1,464 @@ +#include +#include + +#include +#include +#include +#include +#include + +#include + +#include + +namespace { + +struct FusedRopeParams { + void* __restrict__ q_ptr; + void* __restrict__ k_ptr; // NOTE: this k is pre-offset in host code to reduce computation in kernel + const void* __restrict__ cos_sin_cache_ptr; + const void* __restrict__ positions; + int64_t q_stride_bytes; + int64_t k_stride_bytes; + int64_t head_stride_bytes; + uint32_t num_qo_heads; + uint32_t num_kv_heads; + uint32_t num_tokens; +}; + +struct FusedRopeStoreParams { + FusedRopeParams base_params; + void* v_ptr; + void* __restrict__ k_cache; + void* __restrict__ v_cache; + const void* __restrict__ out_loc; + int64_t v_stride_bytes; + int64_t cache_stride_bytes; +}; + +constexpr uint32_t kBlockSize = 128; + +[[maybe_unused]] +constexpr auto next_pow2(uint32_t target, uint32_t factor = 1) { + uint32_t power = 1; + while (power * factor < target) + power *= 2; + return power; +} + +template +__global__ void fused_rope_kernel(const __grid_constant__ FusedRopeParams params) { + using namespace device; + + constexpr int64_t kCosSinStrideBytes = kRopeDim * sizeof(float); + constexpr int64_t kVecSize = next_pow2(kRopeDim, (2 * kWorkThreads * (1 + kIsNeox))); + using DType2 = packed_t; + using InputStorage = AlignedVector; + constexpr int64_t kDimPerThread = kVecSize * 2 * (1 + kIsNeox); + constexpr uint32_t 
kLaneCount = kRopeDim / kDimPerThread; + static_assert(kRopeDim % kDimPerThread == 0 && kLaneCount <= kWorkThreads); + + const auto &[ + q, k, cos_sin_cache_ptr, positions, // pointers + q_stride_bytes, k_stride_bytes, head_stride_bytes, // strides + num_qo_heads, num_kv_heads, num_tokens // dimensions + ] = params; + + const auto num_blks = gridDim.x; + constexpr auto kWorkersPerBlock = kBlockSize / kWorkThreads; + const auto num_workers = num_blks * kWorkersPerBlock; + const auto num_q_and_k_heads = num_qo_heads + num_kv_heads; + const auto num_works = num_q_and_k_heads * num_tokens; + const auto start_worker_id = (blockIdx.x * kBlockSize + threadIdx.x) / kWorkThreads; + const auto cos_cache_ptr = cos_sin_cache_ptr; + const auto sin_cache_ptr = pointer::offset(cos_sin_cache_ptr, kCosSinStrideBytes / 2); + + uint32_t lane_id = threadIdx.x % kWorkThreads; + if constexpr (kLaneCount < kWorkThreads) { + if (lane_id >= kLaneCount) return; + } + + PDLWaitPrimary(); + + for (auto idx = start_worker_id; idx < num_works; idx += num_workers) { + const int64_t token_id = idx / num_q_and_k_heads; + const int64_t head_id = idx % num_q_and_k_heads; + const auto pos = static_cast(positions)[token_id]; + const auto load_q = head_id < num_qo_heads; + const auto input_ = load_q ? 
pointer::offset(q, token_id * q_stride_bytes) // + : pointer::offset(k, token_id * k_stride_bytes); + const auto input = pointer::offset(input_, head_id * head_stride_bytes); + const auto cos_ptr = pointer::offset(cos_cache_ptr, pos * kCosSinStrideBytes); + const auto sin_ptr = pointer::offset(sin_cache_ptr, pos * kCosSinStrideBytes); + if constexpr (kIsNeox) { + using CacheStorage = AlignedVector; + const auto input_x = input; + const auto input_y = pointer::offset(input, (kRopeDim / 2) * sizeof(DType)); + auto input_vec_x = load_as(input_x, lane_id); + auto input_vec_y = load_as(input_y, lane_id); + const auto cos_pair = load_as(cos_ptr, lane_id); + const auto sin_pair = load_as(sin_ptr, lane_id); +#pragma unroll + for (int64_t j = 0; j < kVecSize; ++j) { + const auto [x0, x1] = cast(input_vec_x[j]); + const auto [y0, y1] = cast(input_vec_y[j]); + const auto [cos_0, cos_1] = cos_pair[j]; + const auto [sin_0, sin_1] = sin_pair[j]; + const auto out_x0 = x0 * cos_0 - y0 * sin_0; + const auto out_y0 = x0 * sin_0 + y0 * cos_0; + const auto out_x1 = x1 * cos_1 - y1 * sin_1; + const auto out_y1 = x1 * sin_1 + y1 * cos_1; + input_vec_x[j] = cast({out_x0, out_x1}); + input_vec_y[j] = cast({out_y0, out_y1}); + } + store_as(input_x, input_vec_x, lane_id); + store_as(input_y, input_vec_y, lane_id); + } else { + using CacheStorage = AlignedVector; + auto input_vec = load_as(input, lane_id); + const auto cos_vec = load_as(cos_ptr, lane_id); + const auto sin_vec = load_as(sin_ptr, lane_id); +#pragma unroll + for (int64_t j = 0; j < kVecSize; ++j) { + const auto [x, y] = cast(input_vec[j]); + const auto cos = cos_vec[j]; + const auto sin = sin_vec[j]; + const auto out_x = x * cos - y * sin; + const auto out_y = x * sin + y * cos; + input_vec[j] = cast({out_x, out_y}); + } + store_as(input, input_vec, lane_id); + } + } + + PDLTriggerSecondary(); +} + +template +__global__ void fused_rope_store_kernel(const __grid_constant__ FusedRopeStoreParams params) { + using namespace device; 
+ + constexpr int64_t kCosSinStrideBytes = kRopeDim * sizeof(float); + constexpr int64_t kVecSize = kRopeDim / (2 * kWorkThreads * (1 + kIsNeox)); + using DType2 = packed_t; + using InputStorage = AlignedVector; + constexpr int64_t kDimPerThread = kVecSize * 2 * (1 + kIsNeox); + static_assert(kRopeDim == kDimPerThread * kWorkThreads); + + const auto& [base_params, v_ptr, k_cache, v_cache, out_loc, v_stride_bytes, cache_stride_bytes] = params; + const auto &[ + q, k, cos_sin_cache_ptr, positions, // pointers + q_stride_bytes, k_stride_bytes, head_stride_bytes, // strides + num_qo_heads, num_kv_heads, num_tokens // dimensions + ] = base_params; + + const auto num_blks = gridDim.x; + constexpr auto kWorkersPerBlock = kBlockSize / kWorkThreads; + const auto num_workers = num_blks * kWorkersPerBlock; + const auto num_q_and_k_heads = num_qo_heads + num_kv_heads; + const auto num_works = num_q_and_k_heads * num_tokens; + const auto num_extra_works = num_kv_heads * num_tokens; // rope works + v store works + const auto start_worker_id = (blockIdx.x * kBlockSize + threadIdx.x) / kWorkThreads; + const auto lane_id = threadIdx.x % kWorkThreads; + const auto cos_cache_ptr = cos_sin_cache_ptr; + const auto sin_cache_ptr = pointer::offset(cos_sin_cache_ptr, kCosSinStrideBytes / 2); + + auto idx = start_worker_id; + + PDLWaitPrimary(); + // in this case, head_dim = rope_dim must be true + __builtin_assume(head_stride_bytes == kRopeDim * sizeof(DType)); + + for (; idx < num_works; idx += num_workers) { + const int64_t token_id = idx / num_q_and_k_heads; + const int64_t head_id = idx % num_q_and_k_heads; + const auto pos = static_cast(positions)[token_id]; + const auto loc = static_cast(out_loc)[token_id]; + const auto load_q = head_id < num_qo_heads; + const auto input_ = load_q ? 
pointer::offset(q, token_id * q_stride_bytes) // + : pointer::offset(k, token_id * k_stride_bytes); + const auto input = pointer::offset(input_, head_id * head_stride_bytes); + const auto cos_ptr = pointer::offset(cos_cache_ptr, pos * kCosSinStrideBytes); + const auto sin_ptr = pointer::offset(sin_cache_ptr, pos * kCosSinStrideBytes); + if constexpr (kIsNeox) { + using CacheStorage = AlignedVector; + const auto input_x = input; + const auto input_y = pointer::offset(input, (kRopeDim / 2) * sizeof(DType)); + auto input_vec_x = load_as(input_x, lane_id); + auto input_vec_y = load_as(input_y, lane_id); + const auto cos_pair = load_as(cos_ptr, lane_id); + const auto sin_pair = load_as(sin_ptr, lane_id); +#pragma unroll + for (int64_t j = 0; j < kVecSize; ++j) { + const auto [x0, x1] = cast(input_vec_x[j]); + const auto [y0, y1] = cast(input_vec_y[j]); + const auto [cos_0, cos_1] = cos_pair[j]; + const auto [sin_0, sin_1] = sin_pair[j]; + const auto out_x0 = x0 * cos_0 - y0 * sin_0; + const auto out_y0 = x0 * sin_0 + y0 * cos_0; + const auto out_x1 = x1 * cos_1 - y1 * sin_1; + const auto out_y1 = x1 * sin_1 + y1 * cos_1; + input_vec_x[j] = cast({out_x0, out_x1}); + input_vec_y[j] = cast({out_y0, out_y1}); + } + const auto k_out = pointer::offset(k_cache, loc * cache_stride_bytes, head_id * head_stride_bytes); + const auto output_x = load_q ? 
input : k_out; + store_as(output_x, input_vec_x, lane_id); + const auto output_y = pointer::offset(output_x, (kRopeDim / 2) * sizeof(DType)); + store_as(output_y, input_vec_y, lane_id); + } else { + using CacheStorage = AlignedVector; + auto input_vec = load_as(input, lane_id); + const auto cos_vec = load_as(cos_ptr, lane_id); + const auto sin_vec = load_as(sin_ptr, lane_id); +#pragma unroll + for (int64_t j = 0; j < kVecSize; ++j) { + const auto [x, y] = cast(input_vec[j]); + const auto cos = cos_vec[j]; + const auto sin = sin_vec[j]; + const auto out_x = x * cos - y * sin; + const auto out_y = x * sin + y * cos; + input_vec[j] = cast({out_x, out_y}); + } + const auto k_out = pointer::offset(k_cache, loc * cache_stride_bytes, head_id * head_stride_bytes); + const auto output = load_q ? input : k_out; + store_as(output, input_vec, lane_id); + } + } + + __syncwarp(); // to avoid warp divergence + idx -= num_works; + for (; idx < num_extra_works; idx += num_workers) { + using VStorage = AlignedVector; + const int64_t token_id = idx / num_kv_heads; + const int64_t head_id = idx % num_kv_heads; + const auto loc = static_cast(out_loc)[token_id]; + const auto input = pointer::offset(v_ptr, token_id * v_stride_bytes, head_id * head_stride_bytes); + const auto input_vec = load_as(input, lane_id); + const auto output = pointer::offset(v_cache, loc * cache_stride_bytes, head_id * head_stride_bytes); + store_as(output, input_vec, lane_id); + } + PDLTriggerSecondary(); +} + +template +struct FusedRopeKernel { + static constexpr uint32_t kDimPerThread = std::gcd(16 / sizeof(DType), kRopeDim); + static constexpr uint32_t kWorkThreads = next_pow2(kRopeDim, kDimPerThread); + static constexpr bool kSupportFused = kWorkThreads * kDimPerThread == kRopeDim; + static_assert(kRopeDim % kDimPerThread == 0); + static_assert(kBlockSize % kWorkThreads == 0); + + template + static constexpr auto _kernel_0 = fused_rope_kernel; + template + static constexpr auto _kernel_1 = 
fused_rope_store_kernel; + + static auto get_num_sm(DLDevice device) { + static const auto kNumSM = host::runtime::get_sm_count(device.device_id); + return kNumSM; + } + + static void + run(const tvm::ffi::TensorView q, + const tvm::ffi::TensorView k, + const tvm::ffi::TensorView cos_sin_cache, + const tvm::ffi::TensorView positions) { + using namespace host; + auto N = SymbolicSize{"num_tokens"}; + auto Q = SymbolicSize{"num_qo_heads"}; + auto K = SymbolicSize{"num_kv_heads"}; + auto D = SymbolicSize{"rope_dim"}; + auto Dq = SymbolicSize{"q_stride"}; + auto Dk = SymbolicSize{"k_stride"}; + auto Dd = SymbolicSize{"head_stride"}; + auto device = SymbolicDevice{}; + auto id_type = SymbolicDType{}; + D.set_value(kRopeDim); + device.set_options(); + TensorMatcher({N, Q, D}) // q input + .with_strides({Dq, Dd, 1}) + .with_dtype() + .with_device(device) + .verify(q); + TensorMatcher({N, K, D}) // k input + .with_strides({Dk, Dd, 1}) + .with_dtype() + .with_device(device) + .verify(k); + TensorMatcher({-1, D}) // cos_sin_cache + .with_dtype() + .with_device(device) + .verify(cos_sin_cache); + TensorMatcher({N}) // positions + .with_dtype(id_type) + .with_device(device) + .verify(positions); + + const auto num_tokens = static_cast(N.unwrap()); + const auto num_qo_heads = static_cast(Q.unwrap()); + const auto num_kv_heads = static_cast(K.unwrap()); + const auto q_stride_bytes = static_cast(Dq.unwrap() * sizeof(DType)); + const auto k_stride_bytes = static_cast(Dk.unwrap() * sizeof(DType)); + const auto head_stride_bytes = static_cast(Dd.unwrap() * sizeof(DType)); + + // NOTE: we offset the k here to reduce computation cost in the kernel + const int64_t k_offset = static_cast(num_qo_heads) * head_stride_bytes; + const auto params = FusedRopeParams{ + .q_ptr = q.data_ptr(), + .k_ptr = pointer::offset(k.data_ptr(), -k_offset), + .cos_sin_cache_ptr = cos_sin_cache.data_ptr(), + .positions = positions.data_ptr(), + .q_stride_bytes = q_stride_bytes, + .k_stride_bytes = 
k_stride_bytes, + .head_stride_bytes = head_stride_bytes, + .num_qo_heads = num_qo_heads, + .num_kv_heads = num_kv_heads, + .num_tokens = num_tokens, + }; + + const auto is_int32 = id_type.is_type(); + const auto kernel = is_int32 ? _kernel_0 : _kernel_0; + const uint32_t kNumSM = get_num_sm(device.unwrap()); + static const uint32_t kOccupancyTable[2] = { + runtime::get_blocks_per_sm(_kernel_0, kBlockSize), + runtime::get_blocks_per_sm(_kernel_0, kBlockSize), + }; + const auto max_blocks = kOccupancyTable[is_int32 ? 0 : 1] * kNumSM; + const auto num_works = (num_qo_heads + num_kv_heads) * num_tokens; + const auto needed_blocks = div_ceil(num_works, (kBlockSize / kWorkThreads)); + const auto num_blocks = std::min(max_blocks, needed_blocks); + LaunchKernel(num_blocks, kBlockSize, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } + + static void run_fused( + const tvm::ffi::TensorView q, + const tvm::ffi::TensorView k, + const tvm::ffi::TensorView v, + const tvm::ffi::TensorView k_cache, + const tvm::ffi::TensorView v_cache, + const tvm::ffi::TensorView cos_sin_cache, + const tvm::ffi::TensorView positions, + const tvm::ffi::TensorView out_loc) { + if constexpr (kSupportFused) { + return _run_fused_impl(q, k, v, k_cache, v_cache, cos_sin_cache, positions, out_loc); + } else { + host::Panic("Fused rope + store is not supported for rope_dim ", kRopeDim); + } + } + + static void _run_fused_impl( + const tvm::ffi::TensorView q, + const tvm::ffi::TensorView k, + const tvm::ffi::TensorView v, + const tvm::ffi::TensorView k_cache, + const tvm::ffi::TensorView v_cache, + const tvm::ffi::TensorView cos_sin_cache, + const tvm::ffi::TensorView positions, + const tvm::ffi::TensorView out_loc) { + using namespace host; + + auto N = SymbolicSize{"num_tokens"}; + auto Q = SymbolicSize{"num_qo_heads"}; + auto K = SymbolicSize{"num_kv_heads"}; + auto D = SymbolicSize{"rope_dim"}; + auto R = SymbolicSize{"row_size"}; + auto Dq = SymbolicSize{"q_stride"}; + auto Dk = 
SymbolicSize{"k_stride"}; + auto Dv = SymbolicSize{"v_stride"}; + auto Dd = SymbolicSize{"head_stride"}; + auto Dc = SymbolicSize{"cache_stride"}; + auto device = SymbolicDevice{}; + auto id_type = SymbolicDType{}; + D.set_value(kRopeDim); + device.set_options(); + + TensorMatcher({N, Q, D}) // q input + .with_strides({Dq, Dd, 1}) + .with_dtype() + .with_device(device) + .verify(q); + TensorMatcher({N, K, D}) // k input + .with_strides({Dk, Dd, 1}) + .with_dtype() + .with_device(device) + .verify(k); + TensorMatcher({N, K, D}) // v input + .with_strides({Dv, Dd, 1}) + .with_dtype() + .with_device(device) + .verify(v); + TensorMatcher({-1, D}) // cos_sin_cache + .with_dtype() + .with_device(device) + .verify(cos_sin_cache); + TensorMatcher({N}) // positions, out_loc + .with_dtype(id_type) + .with_device(device) + .verify(positions) + .verify(out_loc); + TensorMatcher({-1, R}) // k_cache + .with_strides({Dc, 1}) + .with_dtype() + .with_device(device) + .verify(k_cache) + .verify(v_cache); + + const auto num_tokens = static_cast(N.unwrap()); + const auto num_qo_heads = static_cast(Q.unwrap()); + const auto num_kv_heads = static_cast(K.unwrap()); + const auto q_stride_bytes = static_cast(Dq.unwrap() * sizeof(DType)); + const auto k_stride_bytes = static_cast(Dk.unwrap() * sizeof(DType)); + const auto head_stride = Dd.unwrap(); + const auto row_dim = R.unwrap(); + const auto head_stride_bytes = static_cast(Dd.unwrap() * sizeof(DType)); + + RuntimeCheck(kRopeDim == head_stride, "rope_dim ", kRopeDim, " should = head_stride ", head_stride); + RuntimeCheck(num_kv_heads * kRopeDim == row_dim, "invalid kvcache"); + + // NOTE: we offset the k here to reduce computation cost in the kernel + const int64_t k_offset = static_cast(num_qo_heads) * head_stride_bytes; + const auto params = FusedRopeParams{ + .q_ptr = q.data_ptr(), + .k_ptr = pointer::offset(k.data_ptr(), -k_offset), + .cos_sin_cache_ptr = cos_sin_cache.data_ptr(), + .positions = positions.data_ptr(), + 
.q_stride_bytes = q_stride_bytes, + .k_stride_bytes = k_stride_bytes, + .head_stride_bytes = head_stride_bytes, + .num_qo_heads = num_qo_heads, + .num_kv_heads = num_kv_heads, + .num_tokens = num_tokens, + }; + + const auto v_stride_bytes = static_cast(Dv.unwrap() * sizeof(DType)); + const auto cache_stride_bytes = static_cast(Dc.unwrap() * sizeof(DType)); + const auto store_params = FusedRopeStoreParams{ + .base_params = params, + .v_ptr = v.data_ptr(), + .k_cache = pointer::offset(k_cache.data_ptr(), -k_offset), + .v_cache = v_cache.data_ptr(), + .out_loc = out_loc.data_ptr(), + .v_stride_bytes = v_stride_bytes, + .cache_stride_bytes = cache_stride_bytes, + }; + + const auto is_int32 = id_type.is_type(); + const auto kernel = is_int32 ? _kernel_1 : _kernel_1; + const uint32_t kNumSM = get_num_sm(device.unwrap()); + static const uint32_t kOccupancyTable[2] = { + runtime::get_blocks_per_sm(_kernel_1, kBlockSize), + runtime::get_blocks_per_sm(_kernel_1, kBlockSize), + }; + const auto max_blocks = kOccupancyTable[is_int32 ? 
0 : 1] * kNumSM; + // rope works for q+k heads, plus v store works for kv heads + const auto num_total_works = (num_qo_heads + 2 * num_kv_heads) * num_tokens; + const auto needed_blocks = div_ceil(num_total_works, (kBlockSize / kWorkThreads)); + const auto num_blocks = std::min(max_blocks, needed_blocks); + LaunchKernel(num_blocks, kBlockSize, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, store_params); + } +}; + +} // namespace diff --git a/sglang/python/sglang/jit_kernel/csrc/fast-hadamard-transform/code_gen.py b/sglang/python/sglang/jit_kernel/csrc/fast-hadamard-transform/code_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..b19a8ba698defc9eb3b6db9b25d78775f0a896c0 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/fast-hadamard-transform/code_gen.py @@ -0,0 +1,197 @@ +from pathlib import Path + +import numpy as np + +# From https://en.wikipedia.org/wiki/Paley_construction (construction II for q = 5) + +had_12_paley = """ ++-++++++++++ +--+-+-+-+-+- ++++-++----++ ++---+--+-++- ++++++-++---- ++-+---+--+-+ +++--+++-++-- ++--++---+--+ +++----+++-++ ++--+-++---+- +++++----+++- ++-+--+-++--- +""" + +# From http://neilsloane.com/hadamard/ + +had_12 = """ ++----------- +++-+---+++-+ ++++-+---+++- ++-++-+---+++ +++-++-+---++ ++++-++-+---+ +++++-++-+--- ++-+++-++-+-- ++--+++-++-+- ++---+++-++-+ +++---+++-++- ++-+---+++-++ +""" + +had_20_will = """ ++----+----++--++-++- +-+----+---+++---+-++ +--+----+---+++-+-+-+ +---+----+---+++++-+- +----+----++--++-++-+ +-+++++-----+--+++--+ ++-+++-+---+-+--+++-- +++-++--+---+-+--+++- ++++-+---+---+-+--+++ +++++-----++--+-+--++ +--++-+-++-+-----++++ +---++-+-++-+---+-+++ ++---++-+-+--+--++-++ +++---++-+----+-+++-+ +-++---++-+----+++++- +-+--+--++-+----+---- ++-+-----++-+----+--- +-+-+-+---+--+----+-- +--+-+++------+----+- ++--+--++------+----+ +""" + + +had_28_will = """ ++------++----++-+--+-+--++-- +-+-----+++-----+-+--+-+--++- +--+-----+++---+-+-+----+--++ +---+-----+++---+-+-+-+--+--+ 
+----+-----+++---+-+-+++--+-- +-----+-----++++--+-+--++--+- +------++----++-+--+-+--++--+ +--++++-+-------++--+++-+--+- +---++++-+-----+-++--+-+-+--+ ++---+++--+----++-++--+-+-+-- +++---++---+----++-++--+-+-+- ++++---+----+----++-++--+-+-+ +++++--------+-+--++-++--+-+- +-++++--------+++--++--+--+-+ +-+-++-++--++--+--------++++- ++-+-++--+--++--+--------++++ +-+-+-++--+--++--+----+---+++ ++-+-+-++--+--+---+---++---++ +++-+-+-++--+------+--+++---+ +-++-+-+-++--+------+-++++--- ++-++-+---++--+------+-++++-- +-++--++-+-++-+++----++------ ++-++--++-+-++-+++-----+----- +++-++---+-+-++-+++-----+---- +-++-++-+-+-+-+--+++-----+--- +--++-++++-+-+----+++-----+-- ++--++-+-++-+-+----+++-----+- +++--++-+-++-+-+----++------+ +""" + + +had_40_tpal = """ ++-------------------+------------------- +++-++----+-+-++++--+++-++----+-+-++++--+ ++++-++----+-+-++++--+++-++----+-+-++++-- ++-++-++----+-+-++++-+-++-++----+-+-++++- ++--++-++----+-+-+++++--++-++----+-+-++++ +++--++-++----+-+-+++++--++-++----+-+-+++ ++++--++-++----+-+-+++++--++-++----+-+-++ +++++--++-++----+-+-+++++--++-++----+-+-+ ++++++--++-++----+-+-+++++--++-++----+-+- ++-++++--++-++----+-++-++++--++-++----+-+ +++-++++--++-++----+-++-++++--++-++----+- ++-+-++++--++-++----++-+-++++--++-++----+ +++-+-++++--++-++----++-+-++++--++-++---- ++-+-+-++++--++-++---+-+-+-++++--++-++--- ++--+-+-++++--++-++--+--+-+-++++--++-++-- ++---+-+-++++--++-++-+---+-+-++++--++-++- ++----+-+-++++--++-+++----+-+-++++--++-++ +++----+-+-++++--++-+++----+-+-++++--++-+ ++++----+-+-++++--++-+++----+-+-++++--++- ++-++----+-+-++++--+++-++----+-+-++++--++ ++--------------------+++++++++++++++++++ +++-++----+-+-++++--+--+--++++-+-+----++- ++++-++----+-+-++++-----+--++++-+-+----++ ++-++-++----+-+-++++--+--+--++++-+-+----+ ++--++-++----+-+-++++-++--+--++++-+-+---- +++--++-++----+-+-+++--++--+--++++-+-+--- ++++--++-++----+-+-++---++--+--++++-+-+-- +++++--++-++----+-+-+----++--+--++++-+-+- ++++++--++-++----+-+------++--+--++++-+-+ 
++-++++--++-++----+-+-+----++--+--++++-+- +++-++++--++-++----+---+----++--+--++++-+ ++-+-++++--++-++----+-+-+----++--+--++++- +++-+-++++--++-++------+-+----++--+--++++ ++-+-+-++++--++-++----+-+-+----++--+--+++ ++--+-+-++++--++-++---++-+-+----++--+--++ ++---+-+-++++--++-++--+++-+-+----++--+--+ ++----+-+-++++--++-++-++++-+-+----++--+-- +++----+-+-++++--++-+--++++-+-+----++--+- ++++----+-+-++++--++----++++-+-+----++--+ ++-++----+-+-++++--++-+--++++-+-+----++-- +""" + + +header = """ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// This file is auto-generated. See "code_gen.py"\n + +#pragma once + +""" + +template = """ +__device__ __forceinline__ void hadamard_mult_thread_{N}(float x[{N}]) { + float out[{N}]; + {code} + #pragma unroll + for (int i = 0; i < {N}; i++) { x[i] = out[i]; } +} + +""" + + +def string_to_array(string): + # Convert strings of + and - to bool arrays + string = string.strip().replace("+", "1").replace("-", "-1").split() + return np.stack( + [ + np.fromstring(" ".join(string[i]), dtype=np.int32, sep=" ") + for i in range(len(string)) + ] + ) + + +def array_code_gen(arr): + N = arr.shape[0] + assert arr.shape[0] == arr.shape[1] + out = [] + for i in range(N): + out.append( + f"out[{i}] = " + + " ".join([f"{'+' if arr[i, j] == 1 else '-'} x[{j}]" for j in range(N)]) + + ";" + ) + return template.replace("{N}", str(N)).replace("{code}", "\n ".join(out)) + + +def main(): + output_dir = Path(__file__).parent / "fast_hadamard_transform_special.h" + output_dir.write_text( + header + + array_code_gen(string_to_array(had_12_paley)) + + array_code_gen(string_to_array(had_20_will)) + + array_code_gen(string_to_array(had_28_will)) + + array_code_gen(string_to_array(had_40_tpal)) + ) + + +if __name__ == "__main__": + main() diff --git 
a/sglang/python/sglang/jit_kernel/csrc/fast-hadamard-transform/fast_hadamard_transform.h b/sglang/python/sglang/jit_kernel/csrc/fast-hadamard-transform/fast_hadamard_transform.h new file mode 100644 index 0000000000000000000000000000000000000000..1dda51c3e29bd9e88b7d0c61f92554c6e2e91eee --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/fast-hadamard-transform/fast_hadamard_transform.h @@ -0,0 +1,24 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Copied from https://github.com/sgl-project/fast-hadamard-transform + +#pragma once + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct HadamardParamsBase { + using index_t = int64_t; + + int batch, dim, log_N; + + index_t x_batch_stride; + index_t out_batch_stride; + + float scale; + + // Common data pointers. + void* __restrict__ x_ptr; + void* __restrict__ out_ptr; +}; diff --git a/sglang/python/sglang/jit_kernel/csrc/fast-hadamard-transform/fast_hadamard_transform_common.h b/sglang/python/sglang/jit_kernel/csrc/fast-hadamard-transform/fast_hadamard_transform_common.h new file mode 100644 index 0000000000000000000000000000000000000000..f6e6117d5372d6687d0736ae3e5bb685d864c4a1 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/fast-hadamard-transform/fast_hadamard_transform_common.h @@ -0,0 +1,214 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ + +// Copied from https://github.com/sgl-project/fast-hadamard-transform + +#pragma once + +#include +#include + +#define FULL_MASK 0xffffffff + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct uint8 { + uint4 u; + uint4 v; +}; + +template +struct BytesToType {}; + +template <> +struct BytesToType<32> { + using Type = uint8; + static_assert(sizeof(Type) == 32); +}; + +template <> +struct BytesToType<16> { + using Type = uint4; + static_assert(sizeof(Type) == 16); +}; + +template <> +struct BytesToType<8> { + using Type = uint64_t; + static_assert(sizeof(Type) == 8); +}; + +template <> +struct BytesToType<4> { + using Type = uint32_t; + static_assert(sizeof(Type) == 4); +}; + +template <> +struct BytesToType<2> { + using Type = uint16_t; + static_assert(sizeof(Type) == 2); +}; + +template <> +struct BytesToType<1> { + using Type = uint8_t; + static_assert(sizeof(Type) == 1); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { + __device__ inline T operator()(T const& x, T const& y) { + return x + y; + } +}; + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ inline T run(T x, Operator& op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +template <> +struct Allreduce<2> { + template + static __device__ inline T run(T x, Operator& op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// https://stackoverflow.com/questions/35311711/whats-the-right-way-to-compute-integral-base-2-logarithms-at-compile-time +constexpr int cilog2(int val) { + 
return val > 0 ? 1 + cilog2(val >> 1) : -1; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ __forceinline__ void hadamard_mult_thread(float x[kNChunks][1 << kLogN]) { + constexpr int N = 1 << kLogN; +#pragma unroll + for (int i = 0; i < kLogN; ++i) { + const int stride = 1 << i; +#pragma unroll + for (int j = 0; j < N / 2; ++j) { + const int lo = j & (stride - 1); + const int idx = (j - lo) * 2 + lo; +#pragma unroll + for (int c = 0; c < kNChunks; ++c) { + const float a = x[c][idx]; + const float b = x[c][idx + stride]; + x[c][idx] = a + b; + x[c][idx + stride] = a - b; + } + } + } +} + +template +__device__ __forceinline__ void hadamard_mult_warp(float x[kNChunks][kNItems]) { + constexpr int N = 1 << kLogWarpSize; + int lane_id = threadIdx.x % N; +#pragma unroll + for (int step = kStepStart; step < kLogWarpSize; ++step) { + const int lane_mask = 1 << step; + const float sign = (lane_id & lane_mask) ? -1.f : 1.f; +#pragma unroll + for (int c = 0; c < kNChunks; ++c) { +#pragma unroll + for (int i = 0; i < kNItems; ++i) { + float x_val_other = __shfl_xor_sync(FULL_MASK, x[c][i], lane_mask); + x[c][i] = sign * x[c][i] + x_val_other; + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void load_input(input_t* x, float x_vals[kNChunks][kNElts], int dim) { + using vec_t = typename BytesToType::Type; + input_t x_vals_load[kNChunks][kNElts] = {0}; +#pragma unroll + for (int c = 0; c < kNChunks; ++c) { + if ((c * blockDim.x + threadIdx.x) * kNElts < dim) { + reinterpret_cast(x_vals_load)[c] = reinterpret_cast(x)[c * blockDim.x + threadIdx.x]; + } + } +#pragma unroll + for (int c = 0; c < kNChunks; ++c) { +#pragma unroll + for (int i = 0; i < kNElts; ++i) { + x_vals[c][i] = float(x_vals_load[c][i]); + } + } +} + +template +inline __device__ void store_output(output_t* out, float 
out_vals[kNChunks][kNElts], int dim, float scale = 1.f) { + using vec_t = typename BytesToType::Type; + output_t out_vals_store[kNChunks][kNElts]; +#pragma unroll + for (int c = 0; c < kNChunks; ++c) { +#pragma unroll + for (int i = 0; i < kNElts; ++i) { + out_vals_store[c][i] = out_vals[c][i] * scale; + } + } +#pragma unroll + for (int c = 0; c < kNChunks; ++c) { + if ((c * blockDim.x + threadIdx.x) * kNElts < dim) { + reinterpret_cast(out)[c * blockDim.x + threadIdx.x] = reinterpret_cast(out_vals_store)[c]; + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Pre=true means the exchange before the hadamard_mult_warp, Pre=false means after. +template +inline __device__ void exchange_smem_pre(float x_vals[kNChunks][kNElts], vec_t* smem) { + constexpr int kNThreads = kWarpSize * kNWarps; + constexpr int kNExchangePerVec = kNElts / (sizeof(vec_t) / sizeof(float)); + const int warp_id = threadIdx.x / kWarpSize; + const int lane_id = threadIdx.x % kWarpSize; + const int row_t = threadIdx.x % kNWarps; + const int col_t = threadIdx.x / kNWarps; +// We use the XOR swizzle trick (new_col = col ^ row) to avoid / reduce smem bank conflicts. +#pragma unroll + for (int c0 = 0; c0 < kNChunks / kChunksPerExchange; ++c0) { + __syncthreads(); +#pragma unroll + for (int c1 = 0; c1 < kChunksPerExchange; ++c1) { +#pragma unroll + for (int r = 0; r < kNExchangePerVec; ++r) { + smem + [(c1 * kNExchangePerVec + r) * kNThreads + + (Pre ? warp_id * kWarpSize + lane_id ^ warp_id : row_t * kWarpSize + col_t ^ row_t)] = + reinterpret_cast(x_vals[c0 * kChunksPerExchange + c1])[r]; + } + } + __syncthreads(); +#pragma unroll + for (int c1 = 0; c1 < kChunksPerExchange; ++c1) { +#pragma unroll + for (int r = 0; r < kNExchangePerVec; ++r) { + reinterpret_cast(x_vals[c0 * kChunksPerExchange + c1])[r] = smem + [(c1 * kNExchangePerVec + r) * kNThreads + + (Pre ? 
row_t * kWarpSize + col_t ^ row_t : warp_id * kWarpSize + lane_id ^ warp_id)]; + } + } + } +} diff --git a/sglang/python/sglang/jit_kernel/csrc/fast-hadamard-transform/fast_hadamard_transform_special.h b/sglang/python/sglang/jit_kernel/csrc/fast-hadamard-transform/fast_hadamard_transform_special.h new file mode 100644 index 0000000000000000000000000000000000000000..b9f92f597099deab1b587291f570da8b90f233af --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/fast-hadamard-transform/fast_hadamard_transform_special.h @@ -0,0 +1,298 @@ + +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Copied from https://github.com/sgl-project/fast-hadamard-transform + +// This file is auto-generated. See "code_gen.py" + +#pragma once + +__device__ __forceinline__ void hadamard_mult_thread_12(float x[12]) { + float out[12]; + out[0] = +x[0] - x[1] + x[2] + x[3] + x[4] + x[5] + x[6] + x[7] + x[8] + x[9] + x[10] + x[11]; + out[1] = -x[0] - x[1] + x[2] - x[3] + x[4] - x[5] + x[6] - x[7] + x[8] - x[9] + x[10] - x[11]; + out[2] = +x[0] + x[1] + x[2] - x[3] + x[4] + x[5] - x[6] - x[7] - x[8] - x[9] + x[10] + x[11]; + out[3] = +x[0] - x[1] - x[2] - x[3] + x[4] - x[5] - x[6] + x[7] - x[8] + x[9] + x[10] - x[11]; + out[4] = +x[0] + x[1] + x[2] + x[3] + x[4] - x[5] + x[6] + x[7] - x[8] - x[9] - x[10] - x[11]; + out[5] = +x[0] - x[1] + x[2] - x[3] - x[4] - x[5] + x[6] - x[7] - x[8] + x[9] - x[10] + x[11]; + out[6] = +x[0] + x[1] - x[2] - x[3] + x[4] + x[5] + x[6] - x[7] + x[8] + x[9] - x[10] - x[11]; + out[7] = +x[0] - x[1] - x[2] + x[3] + x[4] - x[5] - x[6] - x[7] + x[8] - x[9] - x[10] + x[11]; + out[8] = +x[0] + x[1] - x[2] - x[3] - x[4] - x[5] + x[6] + x[7] + x[8] - x[9] + x[10] + x[11]; + out[9] = +x[0] - x[1] - x[2] + x[3] - x[4] + x[5] + x[6] - x[7] - x[8] - x[9] + x[10] - x[11]; + out[10] = +x[0] + x[1] + x[2] + x[3] - x[4] - x[5] - 
x[6] - x[7] + x[8] + x[9] + x[10] - x[11]; + out[11] = +x[0] - x[1] + x[2] - x[3] - x[4] + x[5] - x[6] + x[7] + x[8] - x[9] - x[10] - x[11]; +#pragma unroll + for (int i = 0; i < 12; i++) { + x[i] = out[i]; + } +} + +__device__ __forceinline__ void hadamard_mult_thread_20(float x[20]) { + float out[20]; + out[0] = +x[0] - x[1] - x[2] - x[3] - x[4] + x[5] - x[6] - x[7] - x[8] - x[9] + x[10] + x[11] - x[12] - x[13] + + x[14] + x[15] - x[16] + x[17] + x[18] - x[19]; + out[1] = -x[0] + x[1] - x[2] - x[3] - x[4] - x[5] + x[6] - x[7] - x[8] - x[9] + x[10] + x[11] + x[12] - x[13] - + x[14] - x[15] + x[16] - x[17] + x[18] + x[19]; + out[2] = -x[0] - x[1] + x[2] - x[3] - x[4] - x[5] - x[6] + x[7] - x[8] - x[9] - x[10] + x[11] + x[12] + x[13] - + x[14] + x[15] - x[16] + x[17] - x[18] + x[19]; + out[3] = -x[0] - x[1] - x[2] + x[3] - x[4] - x[5] - x[6] - x[7] + x[8] - x[9] - x[10] - x[11] + x[12] + x[13] + + x[14] + x[15] + x[16] - x[17] + x[18] - x[19]; + out[4] = -x[0] - x[1] - x[2] - x[3] + x[4] - x[5] - x[6] - x[7] - x[8] + x[9] + x[10] - x[11] - x[12] + x[13] + + x[14] - x[15] + x[16] + x[17] - x[18] + x[19]; + out[5] = -x[0] + x[1] + x[2] + x[3] + x[4] + x[5] - x[6] - x[7] - x[8] - x[9] - x[10] + x[11] - x[12] - x[13] + + x[14] + x[15] + x[16] - x[17] - x[18] + x[19]; + out[6] = +x[0] - x[1] + x[2] + x[3] + x[4] - x[5] + x[6] - x[7] - x[8] - x[9] + x[10] - x[11] + x[12] - x[13] - + x[14] + x[15] + x[16] + x[17] - x[18] - x[19]; + out[7] = +x[0] + x[1] - x[2] + x[3] + x[4] - x[5] - x[6] + x[7] - x[8] - x[9] - x[10] + x[11] - x[12] + x[13] - + x[14] - x[15] + x[16] + x[17] + x[18] - x[19]; + out[8] = +x[0] + x[1] + x[2] - x[3] + x[4] - x[5] - x[6] - x[7] + x[8] - x[9] - x[10] - x[11] + x[12] - x[13] + + x[14] - x[15] - x[16] + x[17] + x[18] + x[19]; + out[9] = +x[0] + x[1] + x[2] + x[3] - x[4] - x[5] - x[6] - x[7] - x[8] + x[9] + x[10] - x[11] - x[12] + x[13] - + x[14] + x[15] - x[16] - x[17] + x[18] + x[19]; + out[10] = -x[0] - x[1] + x[2] + x[3] - x[4] + x[5] - x[6] + 
x[7] + x[8] - x[9] + x[10] - x[11] - x[12] - x[13] - + x[14] - x[15] + x[16] + x[17] + x[18] + x[19]; + out[11] = -x[0] - x[1] - x[2] + x[3] + x[4] - x[5] + x[6] - x[7] + x[8] + x[9] - x[10] + x[11] - x[12] - x[13] - + x[14] + x[15] - x[16] + x[17] + x[18] + x[19]; + out[12] = +x[0] - x[1] - x[2] - x[3] + x[4] + x[5] - x[6] + x[7] - x[8] + x[9] - x[10] - x[11] + x[12] - x[13] - + x[14] + x[15] + x[16] - x[17] + x[18] + x[19]; + out[13] = +x[0] + x[1] - x[2] - x[3] - x[4] + x[5] + x[6] - x[7] + x[8] - x[9] - x[10] - x[11] - x[12] + x[13] - + x[14] + x[15] + x[16] + x[17] - x[18] + x[19]; + out[14] = -x[0] + x[1] + x[2] - x[3] - x[4] - x[5] + x[6] + x[7] - x[8] + x[9] - x[10] - x[11] - x[12] - x[13] + + x[14] + x[15] + x[16] + x[17] + x[18] - x[19]; + out[15] = -x[0] + x[1] - x[2] - x[3] + x[4] - x[5] - x[6] + x[7] + x[8] - x[9] + x[10] - x[11] - x[12] - x[13] - + x[14] + x[15] - x[16] - x[17] - x[18] - x[19]; + out[16] = +x[0] - x[1] + x[2] - x[3] - x[4] - x[5] - x[6] - x[7] + x[8] + x[9] - x[10] + x[11] - x[12] - x[13] - + x[14] - x[15] + x[16] - x[17] - x[18] - x[19]; + out[17] = -x[0] + x[1] - x[2] + x[3] - x[4] + x[5] - x[6] - x[7] - x[8] + x[9] - x[10] - x[11] + x[12] - x[13] - + x[14] - x[15] - x[16] + x[17] - x[18] - x[19]; + out[18] = -x[0] - x[1] + x[2] - x[3] + x[4] + x[5] + x[6] - x[7] - x[8] - x[9] - x[10] - x[11] - x[12] + x[13] - + x[14] - x[15] - x[16] - x[17] + x[18] - x[19]; + out[19] = +x[0] - x[1] - x[2] + x[3] - x[4] - x[5] + x[6] + x[7] - x[8] - x[9] - x[10] - x[11] - x[12] - x[13] + + x[14] - x[15] - x[16] - x[17] - x[18] + x[19]; +#pragma unroll + for (int i = 0; i < 20; i++) { + x[i] = out[i]; + } +} + +__device__ __forceinline__ void hadamard_mult_thread_28(float x[28]) { + float out[28]; + out[0] = +x[0] - x[1] - x[2] - x[3] - x[4] - x[5] - x[6] + x[7] + x[8] - x[9] - x[10] - x[11] - x[12] + x[13] + + x[14] - x[15] + x[16] - x[17] - x[18] + x[19] - x[20] + x[21] - x[22] - x[23] + x[24] + x[25] - x[26] - + x[27]; + out[1] = -x[0] + x[1] - 
x[2] - x[3] - x[4] - x[5] - x[6] + x[7] + x[8] + x[9] - x[10] - x[11] - x[12] - x[13] - + x[14] + x[15] - x[16] + x[17] - x[18] - x[19] + x[20] - x[21] + x[22] - x[23] - x[24] + x[25] + x[26] - + x[27]; + out[2] = -x[0] - x[1] + x[2] - x[3] - x[4] - x[5] - x[6] - x[7] + x[8] + x[9] + x[10] - x[11] - x[12] - x[13] + + x[14] - x[15] + x[16] - x[17] + x[18] - x[19] - x[20] - x[21] - x[22] + x[23] - x[24] - x[25] + x[26] + + x[27]; + out[3] = -x[0] - x[1] - x[2] + x[3] - x[4] - x[5] - x[6] - x[7] - x[8] + x[9] + x[10] + x[11] - x[12] - x[13] - + x[14] + x[15] - x[16] + x[17] - x[18] + x[19] - x[20] + x[21] - x[22] - x[23] + x[24] - x[25] - x[26] + + x[27]; + out[4] = -x[0] - x[1] - x[2] - x[3] + x[4] - x[5] - x[6] - x[7] - x[8] - x[9] + x[10] + x[11] + x[12] - x[13] - + x[14] - x[15] + x[16] - x[17] + x[18] - x[19] + x[20] + x[21] + x[22] - x[23] - x[24] + x[25] - x[26] - + x[27]; + out[5] = -x[0] - x[1] - x[2] - x[3] - x[4] + x[5] - x[6] - x[7] - x[8] - x[9] - x[10] + x[11] + x[12] + x[13] + + x[14] - x[15] - x[16] + x[17] - x[18] + x[19] - x[20] - x[21] + x[22] + x[23] - x[24] - x[25] + x[26] - + x[27]; + out[6] = -x[0] - x[1] - x[2] - x[3] - x[4] - x[5] + x[6] + x[7] - x[8] - x[9] - x[10] - x[11] + x[12] + x[13] - + x[14] + x[15] - x[16] - x[17] + x[18] - x[19] + x[20] - x[21] - x[22] + x[23] + x[24] - x[25] - x[26] + + x[27]; + out[7] = -x[0] - x[1] + x[2] + x[3] + x[4] + x[5] - x[6] + x[7] - x[8] - x[9] - x[10] - x[11] - x[12] - x[13] - + x[14] + x[15] + x[16] - x[17] - x[18] + x[19] + x[20] + x[21] - x[22] + x[23] - x[24] - x[25] + x[26] - + x[27]; + out[8] = -x[0] - x[1] - x[2] + x[3] + x[4] + x[5] + x[6] - x[7] + x[8] - x[9] - x[10] - x[11] - x[12] - x[13] + + x[14] - x[15] + x[16] + x[17] - x[18] - x[19] + x[20] - x[21] + x[22] - x[23] + x[24] - x[25] - x[26] + + x[27]; + out[9] = +x[0] - x[1] - x[2] - x[3] + x[4] + x[5] + x[6] - x[7] - x[8] + x[9] - x[10] - x[11] - x[12] - x[13] + + x[14] + x[15] - x[16] + x[17] + x[18] - x[19] - x[20] + x[21] - x[22] + x[23] 
- x[24] + x[25] - x[26] - + x[27]; + out[10] = +x[0] + x[1] - x[2] - x[3] - x[4] + x[5] + x[6] - x[7] - x[8] - x[9] + x[10] - x[11] - x[12] - x[13] - + x[14] + x[15] + x[16] - x[17] + x[18] + x[19] - x[20] - x[21] + x[22] - x[23] + x[24] - x[25] + x[26] - + x[27]; + out[11] = +x[0] + x[1] + x[2] - x[3] - x[4] - x[5] + x[6] - x[7] - x[8] - x[9] - x[10] + x[11] - x[12] - x[13] - + x[14] - x[15] + x[16] + x[17] - x[18] + x[19] + x[20] - x[21] - x[22] + x[23] - x[24] + x[25] - x[26] + + x[27]; + out[12] = +x[0] + x[1] + x[2] + x[3] - x[4] - x[5] - x[6] - x[7] - x[8] - x[9] - x[10] - x[11] + x[12] - x[13] + + x[14] - x[15] - x[16] + x[17] + x[18] - x[19] + x[20] + x[21] - x[22] - x[23] + x[24] - x[25] + x[26] - + x[27]; + out[13] = -x[0] + x[1] + x[2] + x[3] + x[4] - x[5] - x[6] - x[7] - x[8] - x[9] - x[10] - x[11] - x[12] + x[13] + + x[14] + x[15] - x[16] - x[17] + x[18] + x[19] - x[20] - x[21] + x[22] - x[23] - x[24] + x[25] - x[26] + + x[27]; + out[14] = -x[0] + x[1] - x[2] + x[3] + x[4] - x[5] + x[6] + x[7] - x[8] - x[9] + x[10] + x[11] - x[12] - x[13] + + x[14] - x[15] - x[16] - x[17] - x[18] - x[19] - x[20] - x[21] - x[22] + x[23] + x[24] + x[25] + x[26] - + x[27]; + out[15] = +x[0] - x[1] + x[2] - x[3] + x[4] + x[5] - x[6] - x[7] + x[8] - x[9] - x[10] + x[11] + x[12] - x[13] - + x[14] + x[15] - x[16] - x[17] - x[18] - x[19] - x[20] - x[21] - x[22] - x[23] + x[24] + x[25] + x[26] + + x[27]; + out[16] = -x[0] + x[1] - x[2] + x[3] - x[4] + x[5] + x[6] - x[7] - x[8] + x[9] - x[10] - x[11] + x[12] + x[13] - + x[14] - x[15] + x[16] - x[17] - x[18] - x[19] - x[20] + x[21] - x[22] - x[23] - x[24] + x[25] + x[26] + + x[27]; + out[17] = +x[0] - x[1] + x[2] - x[3] + x[4] - x[5] + x[6] + x[7] - x[8] - x[9] + x[10] - x[11] - x[12] + x[13] - + x[14] - x[15] - x[16] + x[17] - x[18] - x[19] - x[20] + x[21] + x[22] - x[23] - x[24] - x[25] + x[26] + + x[27]; + out[18] = +x[0] + x[1] - x[2] + x[3] - x[4] + x[5] - x[6] + x[7] + x[8] - x[9] - x[10] + x[11] - x[12] - x[13] - + x[14] - 
x[15] - x[16] - x[17] + x[18] - x[19] - x[20] + x[21] + x[22] + x[23] - x[24] - x[25] - x[26] + + x[27]; + out[19] = -x[0] + x[1] + x[2] - x[3] + x[4] - x[5] + x[6] - x[7] + x[8] + x[9] - x[10] - x[11] + x[12] - x[13] - + x[14] - x[15] - x[16] - x[17] - x[18] + x[19] - x[20] + x[21] + x[22] + x[23] + x[24] - x[25] - x[26] - + x[27]; + out[20] = +x[0] - x[1] + x[2] + x[3] - x[4] + x[5] - x[6] - x[7] - x[8] + x[9] + x[10] - x[11] - x[12] + x[13] - + x[14] - x[15] - x[16] - x[17] - x[18] - x[19] + x[20] - x[21] + x[22] + x[23] + x[24] + x[25] - x[26] - + x[27]; + out[21] = -x[0] + x[1] + x[2] - x[3] - x[4] + x[5] + x[6] - x[7] + x[8] - x[9] + x[10] + x[11] - x[12] + x[13] + + x[14] + x[15] - x[16] - x[17] - x[18] - x[19] + x[20] + x[21] - x[22] - x[23] - x[24] - x[25] - x[26] - + x[27]; + out[22] = +x[0] - x[1] + x[2] + x[3] - x[4] - x[5] + x[6] + x[7] - x[8] + x[9] - x[10] + x[11] + x[12] - x[13] + + x[14] + x[15] + x[16] - x[17] - x[18] - x[19] - x[20] - x[21] + x[22] - x[23] - x[24] - x[25] - x[26] - + x[27]; + out[23] = +x[0] + x[1] - x[2] + x[3] + x[4] - x[5] - x[6] - x[7] + x[8] - x[9] + x[10] - x[11] + x[12] + x[13] - + x[14] + x[15] + x[16] + x[17] - x[18] - x[19] - x[20] - x[21] - x[22] + x[23] - x[24] - x[25] - x[26] - + x[27]; + out[24] = -x[0] + x[1] + x[2] - x[3] + x[4] + x[5] - x[6] + x[7] - x[8] + x[9] - x[10] + x[11] - x[12] + x[13] - + x[14] - x[15] + x[16] + x[17] + x[18] - x[19] - x[20] - x[21] - x[22] - x[23] + x[24] - x[25] - x[26] - + x[27]; + out[25] = -x[0] - x[1] + x[2] + x[3] - x[4] + x[5] + x[6] + x[7] + x[8] - x[9] + x[10] - x[11] + x[12] - x[13] - + x[14] - x[15] - x[16] + x[17] + x[18] + x[19] - x[20] - x[21] - x[22] - x[23] - x[24] + x[25] - x[26] - + x[27]; + out[26] = +x[0] - x[1] - x[2] + x[3] + x[4] - x[5] + x[6] - x[7] + x[8] + x[9] - x[10] + x[11] - x[12] + x[13] - + x[14] - x[15] - x[16] - x[17] + x[18] + x[19] + x[20] - x[21] - x[22] - x[23] - x[24] - x[25] + x[26] - + x[27]; + out[27] = +x[0] + x[1] - x[2] - x[3] + x[4] + x[5] - 
x[6] + x[7] - x[8] + x[9] + x[10] - x[11] + x[12] - x[13] + + x[14] - x[15] - x[16] - x[17] - x[18] + x[19] + x[20] - x[21] - x[22] - x[23] - x[24] - x[25] - x[26] + + x[27]; +#pragma unroll + for (int i = 0; i < 28; i++) { + x[i] = out[i]; + } +} + +__device__ __forceinline__ void hadamard_mult_thread_40(float x[40]) { + float out[40]; + out[0] = +x[0] - x[1] - x[2] - x[3] - x[4] - x[5] - x[6] - x[7] - x[8] - x[9] - x[10] - x[11] - x[12] - x[13] - + x[14] - x[15] - x[16] - x[17] - x[18] - x[19] + x[20] - x[21] - x[22] - x[23] - x[24] - x[25] - x[26] - + x[27] - x[28] - x[29] - x[30] - x[31] - x[32] - x[33] - x[34] - x[35] - x[36] - x[37] - x[38] - x[39]; + out[1] = +x[0] + x[1] - x[2] + x[3] + x[4] - x[5] - x[6] - x[7] - x[8] + x[9] - x[10] + x[11] - x[12] + x[13] + + x[14] + x[15] + x[16] - x[17] - x[18] + x[19] + x[20] + x[21] - x[22] + x[23] + x[24] - x[25] - x[26] - + x[27] - x[28] + x[29] - x[30] + x[31] - x[32] + x[33] + x[34] + x[35] + x[36] - x[37] - x[38] + x[39]; + out[2] = +x[0] + x[1] + x[2] - x[3] + x[4] + x[5] - x[6] - x[7] - x[8] - x[9] + x[10] - x[11] + x[12] - x[13] + + x[14] + x[15] + x[16] + x[17] - x[18] - x[19] + x[20] + x[21] + x[22] - x[23] + x[24] + x[25] - x[26] - + x[27] - x[28] - x[29] + x[30] - x[31] + x[32] - x[33] + x[34] + x[35] + x[36] + x[37] - x[38] - x[39]; + out[3] = +x[0] - x[1] + x[2] + x[3] - x[4] + x[5] + x[6] - x[7] - x[8] - x[9] - x[10] + x[11] - x[12] + x[13] - + x[14] + x[15] + x[16] + x[17] + x[18] - x[19] + x[20] - x[21] + x[22] + x[23] - x[24] + x[25] + x[26] - + x[27] - x[28] - x[29] - x[30] + x[31] - x[32] + x[33] - x[34] + x[35] + x[36] + x[37] + x[38] - x[39]; + out[4] = +x[0] - x[1] - x[2] + x[3] + x[4] - x[5] + x[6] + x[7] - x[8] - x[9] - x[10] - x[11] + x[12] - x[13] + + x[14] - x[15] + x[16] + x[17] + x[18] + x[19] + x[20] - x[21] - x[22] + x[23] + x[24] - x[25] + x[26] + + x[27] - x[28] - x[29] - x[30] - x[31] + x[32] - x[33] + x[34] - x[35] + x[36] + x[37] + x[38] + x[39]; + out[5] = +x[0] + x[1] - x[2] - 
x[3] + x[4] + x[5] - x[6] + x[7] + x[8] - x[9] - x[10] - x[11] - x[12] + x[13] - + x[14] + x[15] - x[16] + x[17] + x[18] + x[19] + x[20] + x[21] - x[22] - x[23] + x[24] + x[25] - x[26] + + x[27] + x[28] - x[29] - x[30] - x[31] - x[32] + x[33] - x[34] + x[35] - x[36] + x[37] + x[38] + x[39]; + out[6] = +x[0] + x[1] + x[2] - x[3] - x[4] + x[5] + x[6] - x[7] + x[8] + x[9] - x[10] - x[11] - x[12] - x[13] + + x[14] - x[15] + x[16] - x[17] + x[18] + x[19] + x[20] + x[21] + x[22] - x[23] - x[24] + x[25] + x[26] - + x[27] + x[28] + x[29] - x[30] - x[31] - x[32] - x[33] + x[34] - x[35] + x[36] - x[37] + x[38] + x[39]; + out[7] = +x[0] + x[1] + x[2] + x[3] - x[4] - x[5] + x[6] + x[7] - x[8] + x[9] + x[10] - x[11] - x[12] - x[13] - + x[14] + x[15] - x[16] + x[17] - x[18] + x[19] + x[20] + x[21] + x[22] + x[23] - x[24] - x[25] + x[26] + + x[27] - x[28] + x[29] + x[30] - x[31] - x[32] - x[33] - x[34] + x[35] - x[36] + x[37] - x[38] + x[39]; + out[8] = +x[0] + x[1] + x[2] + x[3] + x[4] - x[5] - x[6] + x[7] + x[8] - x[9] + x[10] + x[11] - x[12] - x[13] - + x[14] - x[15] + x[16] - x[17] + x[18] - x[19] + x[20] + x[21] + x[22] + x[23] + x[24] - x[25] - x[26] + + x[27] + x[28] - x[29] + x[30] + x[31] - x[32] - x[33] - x[34] - x[35] + x[36] - x[37] + x[38] - x[39]; + out[9] = +x[0] - x[1] + x[2] + x[3] + x[4] + x[5] - x[6] - x[7] + x[8] + x[9] - x[10] + x[11] + x[12] - x[13] - + x[14] - x[15] - x[16] + x[17] - x[18] + x[19] + x[20] - x[21] + x[22] + x[23] + x[24] + x[25] - x[26] - + x[27] + x[28] + x[29] - x[30] + x[31] + x[32] - x[33] - x[34] - x[35] - x[36] + x[37] - x[38] + x[39]; + out[10] = +x[0] + x[1] - x[2] + x[3] + x[4] + x[5] + x[6] - x[7] - x[8] + x[9] + x[10] - x[11] + x[12] + x[13] - + x[14] - x[15] - x[16] - x[17] + x[18] - x[19] + x[20] + x[21] - x[22] + x[23] + x[24] + x[25] + x[26] - + x[27] - x[28] + x[29] + x[30] - x[31] + x[32] + x[33] - x[34] - x[35] - x[36] - x[37] + x[38] - x[39]; + out[11] = +x[0] - x[1] + x[2] - x[3] + x[4] + x[5] + x[6] + x[7] - x[8] - x[9] 
+ x[10] + x[11] - x[12] + x[13] + + x[14] - x[15] - x[16] - x[17] - x[18] + x[19] + x[20] - x[21] + x[22] - x[23] + x[24] + x[25] + x[26] + + x[27] - x[28] - x[29] + x[30] + x[31] - x[32] + x[33] + x[34] - x[35] - x[36] - x[37] - x[38] + x[39]; + out[12] = +x[0] + x[1] - x[2] + x[3] - x[4] + x[5] + x[6] + x[7] + x[8] - x[9] - x[10] + x[11] + x[12] - x[13] + + x[14] + x[15] - x[16] - x[17] - x[18] - x[19] + x[20] + x[21] - x[22] + x[23] - x[24] + x[25] + x[26] + + x[27] + x[28] - x[29] - x[30] + x[31] + x[32] - x[33] + x[34] + x[35] - x[36] - x[37] - x[38] - x[39]; + out[13] = +x[0] - x[1] + x[2] - x[3] + x[4] - x[5] + x[6] + x[7] + x[8] + x[9] - x[10] - x[11] + x[12] + x[13] - + x[14] + x[15] + x[16] - x[17] - x[18] - x[19] + x[20] - x[21] + x[22] - x[23] + x[24] - x[25] + x[26] + + x[27] + x[28] + x[29] - x[30] - x[31] + x[32] + x[33] - x[34] + x[35] + x[36] - x[37] - x[38] - x[39]; + out[14] = +x[0] - x[1] - x[2] + x[3] - x[4] + x[5] - x[6] + x[7] + x[8] + x[9] + x[10] - x[11] - x[12] + x[13] + + x[14] - x[15] + x[16] + x[17] - x[18] - x[19] + x[20] - x[21] - x[22] + x[23] - x[24] + x[25] - x[26] + + x[27] + x[28] + x[29] + x[30] - x[31] - x[32] + x[33] + x[34] - x[35] + x[36] + x[37] - x[38] - x[39]; + out[15] = +x[0] - x[1] - x[2] - x[3] + x[4] - x[5] + x[6] - x[7] + x[8] + x[9] + x[10] + x[11] - x[12] - x[13] + + x[14] + x[15] - x[16] + x[17] + x[18] - x[19] + x[20] - x[21] - x[22] - x[23] + x[24] - x[25] + x[26] - + x[27] + x[28] + x[29] + x[30] + x[31] - x[32] - x[33] + x[34] + x[35] - x[36] + x[37] + x[38] - x[39]; + out[16] = +x[0] - x[1] - x[2] - x[3] - x[4] + x[5] - x[6] + x[7] - x[8] + x[9] + x[10] + x[11] + x[12] - x[13] - + x[14] + x[15] + x[16] - x[17] + x[18] + x[19] + x[20] - x[21] - x[22] - x[23] - x[24] + x[25] - x[26] + + x[27] - x[28] + x[29] + x[30] + x[31] + x[32] - x[33] - x[34] + x[35] + x[36] - x[37] + x[38] + x[39]; + out[17] = +x[0] + x[1] - x[2] - x[3] - x[4] - x[5] + x[6] - x[7] + x[8] - x[9] + x[10] + x[11] + x[12] + x[13] - + x[14] - 
x[15] + x[16] + x[17] - x[18] + x[19] + x[20] + x[21] - x[22] - x[23] - x[24] - x[25] + x[26] - + x[27] + x[28] - x[29] + x[30] + x[31] + x[32] + x[33] - x[34] - x[35] + x[36] + x[37] - x[38] + x[39]; + out[18] = +x[0] + x[1] + x[2] - x[3] - x[4] - x[5] - x[6] + x[7] - x[8] + x[9] - x[10] + x[11] + x[12] + x[13] + + x[14] - x[15] - x[16] + x[17] + x[18] - x[19] + x[20] + x[21] + x[22] - x[23] - x[24] - x[25] - x[26] + + x[27] - x[28] + x[29] - x[30] + x[31] + x[32] + x[33] + x[34] - x[35] - x[36] + x[37] + x[38] - x[39]; + out[19] = +x[0] - x[1] + x[2] + x[3] - x[4] - x[5] - x[6] - x[7] + x[8] - x[9] + x[10] - x[11] + x[12] + x[13] + + x[14] + x[15] - x[16] - x[17] + x[18] + x[19] + x[20] - x[21] + x[22] + x[23] - x[24] - x[25] - x[26] - + x[27] + x[28] - x[29] + x[30] - x[31] + x[32] + x[33] + x[34] + x[35] - x[36] - x[37] + x[38] + x[39]; + out[20] = +x[0] - x[1] - x[2] - x[3] - x[4] - x[5] - x[6] - x[7] - x[8] - x[9] - x[10] - x[11] - x[12] - x[13] - + x[14] - x[15] - x[16] - x[17] - x[18] - x[19] - x[20] + x[21] + x[22] + x[23] + x[24] + x[25] + x[26] + + x[27] + x[28] + x[29] + x[30] + x[31] + x[32] + x[33] + x[34] + x[35] + x[36] + x[37] + x[38] + x[39]; + out[21] = +x[0] + x[1] - x[2] + x[3] + x[4] - x[5] - x[6] - x[7] - x[8] + x[9] - x[10] + x[11] - x[12] + x[13] + + x[14] + x[15] + x[16] - x[17] - x[18] + x[19] - x[20] - x[21] + x[22] - x[23] - x[24] + x[25] + x[26] + + x[27] + x[28] - x[29] + x[30] - x[31] + x[32] - x[33] - x[34] - x[35] - x[36] + x[37] + x[38] - x[39]; + out[22] = +x[0] + x[1] + x[2] - x[3] + x[4] + x[5] - x[6] - x[7] - x[8] - x[9] + x[10] - x[11] + x[12] - x[13] + + x[14] + x[15] + x[16] + x[17] - x[18] - x[19] - x[20] - x[21] - x[22] + x[23] - x[24] - x[25] + x[26] + + x[27] + x[28] + x[29] - x[30] + x[31] - x[32] + x[33] - x[34] - x[35] - x[36] - x[37] + x[38] + x[39]; + out[23] = +x[0] - x[1] + x[2] + x[3] - x[4] + x[5] + x[6] - x[7] - x[8] - x[9] - x[10] + x[11] - x[12] + x[13] - + x[14] + x[15] + x[16] + x[17] + x[18] - x[19] - 
x[20] + x[21] - x[22] - x[23] + x[24] - x[25] - x[26] + + x[27] + x[28] + x[29] + x[30] - x[31] + x[32] - x[33] + x[34] - x[35] - x[36] - x[37] - x[38] + x[39]; + out[24] = +x[0] - x[1] - x[2] + x[3] + x[4] - x[5] + x[6] + x[7] - x[8] - x[9] - x[10] - x[11] + x[12] - x[13] + + x[14] - x[15] + x[16] + x[17] + x[18] + x[19] - x[20] + x[21] + x[22] - x[23] - x[24] + x[25] - x[26] - + x[27] + x[28] + x[29] + x[30] + x[31] - x[32] + x[33] - x[34] + x[35] - x[36] - x[37] - x[38] - x[39]; + out[25] = +x[0] + x[1] - x[2] - x[3] + x[4] + x[5] - x[6] + x[7] + x[8] - x[9] - x[10] - x[11] - x[12] + x[13] - + x[14] + x[15] - x[16] + x[17] + x[18] + x[19] - x[20] - x[21] + x[22] + x[23] - x[24] - x[25] + x[26] - + x[27] - x[28] + x[29] + x[30] + x[31] + x[32] - x[33] + x[34] - x[35] + x[36] - x[37] - x[38] - x[39]; + out[26] = +x[0] + x[1] + x[2] - x[3] - x[4] + x[5] + x[6] - x[7] + x[8] + x[9] - x[10] - x[11] - x[12] - x[13] + + x[14] - x[15] + x[16] - x[17] + x[18] + x[19] - x[20] - x[21] - x[22] + x[23] + x[24] - x[25] - x[26] + + x[27] - x[28] - x[29] + x[30] + x[31] + x[32] + x[33] - x[34] + x[35] - x[36] + x[37] - x[38] - x[39]; + out[27] = +x[0] + x[1] + x[2] + x[3] - x[4] - x[5] + x[6] + x[7] - x[8] + x[9] + x[10] - x[11] - x[12] - x[13] - + x[14] + x[15] - x[16] + x[17] - x[18] + x[19] - x[20] - x[21] - x[22] - x[23] + x[24] + x[25] - x[26] - + x[27] + x[28] - x[29] - x[30] + x[31] + x[32] + x[33] + x[34] - x[35] + x[36] - x[37] + x[38] - x[39]; + out[28] = +x[0] + x[1] + x[2] + x[3] + x[4] - x[5] - x[6] + x[7] + x[8] - x[9] + x[10] + x[11] - x[12] - x[13] - + x[14] - x[15] + x[16] - x[17] + x[18] - x[19] - x[20] - x[21] - x[22] - x[23] - x[24] + x[25] + x[26] - + x[27] - x[28] + x[29] - x[30] - x[31] + x[32] + x[33] + x[34] + x[35] - x[36] + x[37] - x[38] + x[39]; + out[29] = +x[0] - x[1] + x[2] + x[3] + x[4] + x[5] - x[6] - x[7] + x[8] + x[9] - x[10] + x[11] + x[12] - x[13] - + x[14] - x[15] - x[16] + x[17] - x[18] + x[19] - x[20] + x[21] - x[22] - x[23] - x[24] - 
x[25] + x[26] + + x[27] - x[28] - x[29] + x[30] - x[31] - x[32] + x[33] + x[34] + x[35] + x[36] - x[37] + x[38] - x[39]; + out[30] = +x[0] + x[1] - x[2] + x[3] + x[4] + x[5] + x[6] - x[7] - x[8] + x[9] + x[10] - x[11] + x[12] + x[13] - + x[14] - x[15] - x[16] - x[17] + x[18] - x[19] - x[20] - x[21] + x[22] - x[23] - x[24] - x[25] - x[26] + + x[27] + x[28] - x[29] - x[30] + x[31] - x[32] - x[33] + x[34] + x[35] + x[36] + x[37] - x[38] + x[39]; + out[31] = +x[0] - x[1] + x[2] - x[3] + x[4] + x[5] + x[6] + x[7] - x[8] - x[9] + x[10] + x[11] - x[12] + x[13] + + x[14] - x[15] - x[16] - x[17] - x[18] + x[19] - x[20] + x[21] - x[22] + x[23] - x[24] - x[25] - x[26] - + x[27] + x[28] + x[29] - x[30] - x[31] + x[32] - x[33] - x[34] + x[35] + x[36] + x[37] + x[38] - x[39]; + out[32] = +x[0] + x[1] - x[2] + x[3] - x[4] + x[5] + x[6] + x[7] + x[8] - x[9] - x[10] + x[11] + x[12] - x[13] + + x[14] + x[15] - x[16] - x[17] - x[18] - x[19] - x[20] - x[21] + x[22] - x[23] + x[24] - x[25] - x[26] - + x[27] - x[28] + x[29] + x[30] - x[31] - x[32] + x[33] - x[34] - x[35] + x[36] + x[37] + x[38] + x[39]; + out[33] = +x[0] - x[1] + x[2] - x[3] + x[4] - x[5] + x[6] + x[7] + x[8] + x[9] - x[10] - x[11] + x[12] + x[13] - + x[14] + x[15] + x[16] - x[17] - x[18] - x[19] - x[20] + x[21] - x[22] + x[23] - x[24] + x[25] - x[26] - + x[27] - x[28] - x[29] + x[30] + x[31] - x[32] - x[33] + x[34] - x[35] - x[36] + x[37] + x[38] + x[39]; + out[34] = +x[0] - x[1] - x[2] + x[3] - x[4] + x[5] - x[6] + x[7] + x[8] + x[9] + x[10] - x[11] - x[12] + x[13] + + x[14] - x[15] + x[16] + x[17] - x[18] - x[19] - x[20] + x[21] + x[22] - x[23] + x[24] - x[25] + x[26] - + x[27] - x[28] - x[29] - x[30] + x[31] + x[32] - x[33] - x[34] + x[35] - x[36] - x[37] + x[38] + x[39]; + out[35] = +x[0] - x[1] - x[2] - x[3] + x[4] - x[5] + x[6] - x[7] + x[8] + x[9] + x[10] + x[11] - x[12] - x[13] + + x[14] + x[15] - x[16] + x[17] + x[18] - x[19] - x[20] + x[21] + x[22] + x[23] - x[24] + x[25] - x[26] + + x[27] - x[28] - x[29] - 
x[30] - x[31] + x[32] + x[33] - x[34] - x[35] + x[36] - x[37] - x[38] + x[39]; + out[36] = +x[0] - x[1] - x[2] - x[3] - x[4] + x[5] - x[6] + x[7] - x[8] + x[9] + x[10] + x[11] + x[12] - x[13] - + x[14] + x[15] + x[16] - x[17] + x[18] + x[19] - x[20] + x[21] + x[22] + x[23] + x[24] - x[25] + x[26] - + x[27] + x[28] - x[29] - x[30] - x[31] - x[32] + x[33] + x[34] - x[35] - x[36] + x[37] - x[38] - x[39]; + out[37] = +x[0] + x[1] - x[2] - x[3] - x[4] - x[5] + x[6] - x[7] + x[8] - x[9] + x[10] + x[11] + x[12] + x[13] - + x[14] - x[15] + x[16] + x[17] - x[18] + x[19] - x[20] - x[21] + x[22] + x[23] + x[24] + x[25] - x[26] + + x[27] - x[28] + x[29] - x[30] - x[31] - x[32] - x[33] + x[34] + x[35] - x[36] - x[37] + x[38] - x[39]; + out[38] = +x[0] + x[1] + x[2] - x[3] - x[4] - x[5] - x[6] + x[7] - x[8] + x[9] - x[10] + x[11] + x[12] + x[13] + + x[14] - x[15] - x[16] + x[17] + x[18] - x[19] - x[20] - x[21] - x[22] + x[23] + x[24] + x[25] + x[26] - + x[27] + x[28] - x[29] + x[30] - x[31] - x[32] - x[33] - x[34] + x[35] + x[36] - x[37] - x[38] + x[39]; + out[39] = +x[0] - x[1] + x[2] + x[3] - x[4] - x[5] - x[6] - x[7] + x[8] - x[9] + x[10] - x[11] + x[12] + x[13] + + x[14] + x[15] - x[16] - x[17] + x[18] + x[19] - x[20] + x[21] - x[22] - x[23] + x[24] + x[25] + x[26] + + x[27] - x[28] + x[29] - x[30] + x[31] - x[32] - x[33] - x[34] - x[35] + x[36] + x[37] - x[38] - x[39]; +#pragma unroll + for (int i = 0; i < 40; i++) { + x[i] = out[i]; + } +} diff --git a/sglang/python/sglang/jit_kernel/csrc/fast-hadamard-transform/hadamard_jit.cuh b/sglang/python/sglang/jit_kernel/csrc/fast-hadamard-transform/hadamard_jit.cuh new file mode 100644 index 0000000000000000000000000000000000000000..1be821f29c1c6d7f2a38858f0401c9e796fd1bca --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/fast-hadamard-transform/hadamard_jit.cuh @@ -0,0 +1,482 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include +#include + +#include + +#include + +#include "fast_hadamard_transform.h" +#include "fast_hadamard_transform_common.h" +#include "fast_hadamard_transform_special.h" +#include "static_switch.h" +#include +#include +#include + +namespace { + +using ::bf16_t; +using ::fp16_t; +using ::HadamardParamsBase; + +constexpr inline int ceil_log2(int val) { + int log = 0; + int p = 1; + while (p < val) { + p <<= 1; + ++log; + } + return log; +} + +template +struct FastHadamardKernelTraits { + using input_t = input_t_; + static constexpr int kNThreads = kNThreads_; + static constexpr int kLogN = kLogN_; + static constexpr int N = 1 << kLogN; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : 8; + static constexpr int kNExchangePerVec = sizeof(float) / sizeof(input_t); + using vec_t = typename BytesToType::Type; + static constexpr int kNChunks = N / (kNElts * kNThreads); + static constexpr int kSmemExchangeSize = (N * 4) < (32 * 1024) ? 
(N * 4) : (32 * 1024); + static constexpr int kNExchangeRounds = N * 4 / kSmemExchangeSize; + static_assert(kNExchangeRounds * kSmemExchangeSize == N * 4); + static constexpr int kSmemSize = kSmemExchangeSize; +}; + +template +struct FastHadamardMNKernelTraits { + using input_t = input_t_; + static constexpr int kNThreads = kNThreads_; + static constexpr int kLogN = kLogN_; + static constexpr int N = (1 << kLogN) * kMultiple; + static_assert(N <= kMaxDim); + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = 4; + static constexpr int kNExchangePerVec = sizeof(float) / sizeof(input_t); + using vec_t = typename BytesToType::Type; + static constexpr int kNChunks = N / (kNElts * kNThreads); + static_assert(kNChunks == kMultiple); + static constexpr int kSmemExchangeSize = (N * 4) < kMaxSmem ? (N * 4) : kMaxSmem; + static constexpr int kNExchangeRounds = N * 4 / kSmemExchangeSize; + static_assert(kNExchangeRounds * kSmemExchangeSize == N * 4); + static constexpr int kSmemSize = kSmemExchangeSize; +}; + +template +using FastHadamard12NTraits = FastHadamardMNKernelTraits; + +template +using FastHadamard20NTraits = FastHadamardMNKernelTraits; + +template +using FastHadamard28NTraits = FastHadamardMNKernelTraits; + +template +using FastHadamard40NTraits = FastHadamardMNKernelTraits; + +template +SGL_DEVICE void hadamard_mult_thread_chunk_12(float x[kNChunks][12]) { +#pragma unroll + for (int c = 0; c < kNChunks; ++c) { + hadamard_mult_thread_12(x[c]); + } +} + +template +SGL_DEVICE void hadamard_mult_thread_chunk_20(float x[kNChunks][20]) { +#pragma unroll + for (int c = 0; c < kNChunks; ++c) { + hadamard_mult_thread_20(x[c]); + } +} + +template +SGL_DEVICE void hadamard_mult_thread_chunk_28(float x[kNChunks][28]) { +#pragma unroll + for (int c = 0; c < kNChunks; ++c) { + hadamard_mult_thread_28(x[c]); + } +} + +template +SGL_DEVICE void hadamard_mult_thread_chunk_40(float x[kNChunks][40]) { 
+#pragma unroll + for (int c = 0; c < kNChunks; ++c) { + hadamard_mult_thread_40(x[c]); + } +} + +template +__global__ __launch_bounds__(Ktraits::kNThreads) void fast_hadamard_transform_kernel(HadamardParamsBase params) { + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNElts = Ktraits::kNElts; + constexpr int kNExchangePerVec = Ktraits::kNExchangePerVec; + constexpr int kNChunks = Ktraits::kNChunks; + using input_t = typename Ktraits::input_t; + using vec_t = typename Ktraits::vec_t; + + constexpr int kLogNElts = cilog2(Ktraits::kNElts); + static_assert(1 << kLogNElts == kNElts, "kNElts must be a power of 2"); + + constexpr int kWarpSize = kNThreads < 32 ? kNThreads : 32; + constexpr int kLogWarpSize = cilog2(kWarpSize); + static_assert(1 << kLogWarpSize == kWarpSize, "Warp size must be a power of 2"); + + constexpr int kNWarps = kNThreads / kWarpSize; + constexpr int kLogNWarps = cilog2(kNWarps); + static_assert(1 << kLogNWarps == kNWarps, "kNWarps must be a power of 2"); + + constexpr int kChunksPerExchange = Ktraits::kSmemExchangeSize / (sizeof(vec_t) * kNExchangePerVec * kNThreads); + static_assert(kChunksPerExchange * sizeof(vec_t) * kNExchangePerVec * kNThreads == Ktraits::kSmemExchangeSize); + constexpr int kNExchanges = kNChunks / kChunksPerExchange; + static_assert(kNExchanges * kChunksPerExchange == kNChunks); + + extern __shared__ char smem_[]; + vec_t* smem_exchange = reinterpret_cast(smem_); + + const int batch_id = static_cast(blockIdx.x); + input_t* x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride; + input_t* out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride; + + float x_vals[kNChunks][kNElts]; + load_input(x, x_vals, params.dim); + + hadamard_mult_thread(x_vals); + hadamard_mult_warp(x_vals); + + if constexpr (kNWarps > 1) { + exchange_smem_pre(x_vals, smem_exchange); + hadamard_mult_warp(x_vals); + exchange_smem_pre(x_vals, smem_exchange); + } + + if constexpr (kNChunks > 1) { + float 
x_vals_transposed[kNElts][kNChunks]; +#pragma unroll + for (int c = 0; c < kNChunks; ++c) { +#pragma unroll + for (int i = 0; i < kNElts; ++i) { + x_vals_transposed[i][c] = x_vals[c][i]; + } + } + + if constexpr (kNChunks == 12) { + hadamard_mult_thread_chunk_12(x_vals_transposed); + } else if constexpr (kNChunks == 20) { + hadamard_mult_thread_chunk_20(x_vals_transposed); + } else if constexpr (kNChunks == 28) { + hadamard_mult_thread_chunk_28(x_vals_transposed); + } else if constexpr (kNChunks == 40) { + hadamard_mult_thread_chunk_40(x_vals_transposed); + } else { + constexpr int kLogNChunks = cilog2(kNChunks); + static_assert(1 << kLogNChunks == kNChunks, "kNChunks must be a power of 2"); + hadamard_mult_thread(x_vals_transposed); + } + +#pragma unroll + for (int c = 0; c < kNChunks; ++c) { +#pragma unroll + for (int i = 0; i < kNElts; ++i) { + x_vals[c][i] = x_vals_transposed[i][c]; + } + } + } + + store_output(out, x_vals, params.dim, params.scale); +} + +template +inline void set_max_dynamic_smem() { + constexpr int kSmemSize = Ktraits::kSmemSize; + if constexpr (kSmemSize >= 48 * 1024) { + auto kernel = &fast_hadamard_transform_kernel; + host::RuntimeDeviceCheck(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } +} + +template +inline void launch_kernel(HadamardParamsBase& params, DLDevice device) { + constexpr int kSmemSize = Ktraits::kSmemSize; + set_max_dynamic_smem(); + auto kernel = &fast_hadamard_transform_kernel; + host::LaunchKernel(dim3(params.batch), dim3(Ktraits::kNThreads), device, kSmemSize)(kernel, params); + host::RuntimeDeviceCheck(); +} + +template +inline void fast_hadamard_transform_launch(HadamardParamsBase& params, DLDevice device) { + using Ktraits = FastHadamardKernelTraits; + launch_kernel(params, device); +} + +template +inline void fast_hadamard_transform_cuda(HadamardParamsBase& params, DLDevice device) { + if (params.log_N == 3) { + fast_hadamard_transform_launch<1, 3, input_t>(params, 
device); + } else if (params.log_N == 4) { + fast_hadamard_transform_launch<2, 4, input_t>(params, device); + } else if (params.log_N == 5) { + fast_hadamard_transform_launch<4, 5, input_t>(params, device); + } else if (params.log_N == 6) { + fast_hadamard_transform_launch<8, 6, input_t>(params, device); + } else if (params.log_N == 7) { + fast_hadamard_transform_launch<16, 7, input_t>(params, device); + } else if (params.log_N == 8) { + fast_hadamard_transform_launch<32, 8, input_t>(params, device); + } else if (params.log_N == 9) { + fast_hadamard_transform_launch<32, 9, input_t>(params, device); + } else if (params.log_N == 10) { + fast_hadamard_transform_launch<128, 10, input_t>(params, device); + } else if (params.log_N == 11) { + fast_hadamard_transform_launch<256, 11, input_t>(params, device); + } else if (params.log_N == 12) { + fast_hadamard_transform_launch<256, 12, input_t>(params, device); + } else if (params.log_N == 13) { + fast_hadamard_transform_launch<256, 13, input_t>(params, device); + } else if (params.log_N == 14) { + fast_hadamard_transform_launch<256, 14, input_t>(params, device); + } else if (params.log_N == 15) { + fast_hadamard_transform_launch<256, 15, input_t>(params, device); + } else { + host::Panic("fast_hadamard_transform: unsupported log_N=", params.log_N); + } +} + +template +inline void fast_hadamard_transform_12N_launch(HadamardParamsBase& params, DLDevice device) { + using Ktraits = FastHadamard12NTraits; + launch_kernel(params, device); +} + +template +inline void fast_hadamard_transform_12N_cuda(HadamardParamsBase& params, DLDevice device) { + if (params.log_N == 2) { + fast_hadamard_transform_12N_launch<1, 2, input_t>(params, device); + } else if (params.log_N == 3) { + fast_hadamard_transform_12N_launch<2, 3, input_t>(params, device); + } else if (params.log_N == 4) { + fast_hadamard_transform_12N_launch<4, 4, input_t>(params, device); + } else if (params.log_N == 5) { + fast_hadamard_transform_12N_launch<8, 5, 
input_t>(params, device); + } else if (params.log_N == 6) { + fast_hadamard_transform_12N_launch<16, 6, input_t>(params, device); + } else if (params.log_N == 7) { + fast_hadamard_transform_12N_launch<32, 7, input_t>(params, device); + } else if (params.log_N == 8) { + fast_hadamard_transform_12N_launch<64, 8, input_t>(params, device); + } else if (params.log_N == 9) { + fast_hadamard_transform_12N_launch<128, 9, input_t>(params, device); + } else if (params.log_N == 10) { + fast_hadamard_transform_12N_launch<256, 10, input_t>(params, device); + } else { + host::Panic("fast_hadamard_transform_12N: unsupported log_N=", params.log_N); + } +} + +template +inline void fast_hadamard_transform_20N_launch(HadamardParamsBase& params, DLDevice device) { + using Ktraits = FastHadamard20NTraits; + launch_kernel(params, device); +} + +template +inline void fast_hadamard_transform_20N_cuda(HadamardParamsBase& params, DLDevice device) { + if (params.log_N == 2) { + fast_hadamard_transform_20N_launch<1, 2, input_t>(params, device); + } else if (params.log_N == 3) { + fast_hadamard_transform_20N_launch<2, 3, input_t>(params, device); + } else if (params.log_N == 4) { + fast_hadamard_transform_20N_launch<4, 4, input_t>(params, device); + } else if (params.log_N == 5) { + fast_hadamard_transform_20N_launch<8, 5, input_t>(params, device); + } else if (params.log_N == 6) { + fast_hadamard_transform_20N_launch<16, 6, input_t>(params, device); + } else if (params.log_N == 7) { + fast_hadamard_transform_20N_launch<32, 7, input_t>(params, device); + } else if (params.log_N == 8) { + fast_hadamard_transform_20N_launch<64, 8, input_t>(params, device); + } else if (params.log_N == 9) { + fast_hadamard_transform_20N_launch<128, 9, input_t>(params, device); + } else if (params.log_N == 10) { + fast_hadamard_transform_20N_launch<256, 10, input_t>(params, device); + } else { + host::Panic("fast_hadamard_transform_20N: unsupported log_N=", params.log_N); + } +} + +template +inline void 
fast_hadamard_transform_28N_launch(HadamardParamsBase& params, DLDevice device) { + using Ktraits = FastHadamard28NTraits; + launch_kernel(params, device); +} + +template +inline void fast_hadamard_transform_28N_cuda(HadamardParamsBase& params, DLDevice device) { + if (params.log_N == 2) { + fast_hadamard_transform_28N_launch<1, 2, input_t>(params, device); + } else if (params.log_N == 3) { + fast_hadamard_transform_28N_launch<2, 3, input_t>(params, device); + } else if (params.log_N == 4) { + fast_hadamard_transform_28N_launch<4, 4, input_t>(params, device); + } else if (params.log_N == 5) { + fast_hadamard_transform_28N_launch<8, 5, input_t>(params, device); + } else if (params.log_N == 6) { + fast_hadamard_transform_28N_launch<16, 6, input_t>(params, device); + } else if (params.log_N == 7) { + fast_hadamard_transform_28N_launch<32, 7, input_t>(params, device); + } else if (params.log_N == 8) { + fast_hadamard_transform_28N_launch<64, 8, input_t>(params, device); + } else if (params.log_N == 9) { + fast_hadamard_transform_28N_launch<128, 9, input_t>(params, device); + } else if (params.log_N == 10) { + fast_hadamard_transform_28N_launch<256, 10, input_t>(params, device); + } else { + host::Panic("fast_hadamard_transform_28N: unsupported log_N=", params.log_N); + } +} + +template +inline void fast_hadamard_transform_40N_launch(HadamardParamsBase& params, DLDevice device) { + using Ktraits = FastHadamard40NTraits; + launch_kernel(params, device); +} + +template +inline void fast_hadamard_transform_40N_cuda(HadamardParamsBase& params, DLDevice device) { + if (params.log_N == 2) { + fast_hadamard_transform_40N_launch<1, 2, input_t>(params, device); + } else if (params.log_N == 3) { + fast_hadamard_transform_40N_launch<2, 3, input_t>(params, device); + } else if (params.log_N == 4) { + fast_hadamard_transform_40N_launch<4, 4, input_t>(params, device); + } else if (params.log_N == 5) { + fast_hadamard_transform_40N_launch<8, 5, input_t>(params, device); + } else if 
(params.log_N == 6) { + fast_hadamard_transform_40N_launch<16, 6, input_t>(params, device); + } else if (params.log_N == 7) { + fast_hadamard_transform_40N_launch<32, 7, input_t>(params, device); + } else if (params.log_N == 8) { + fast_hadamard_transform_40N_launch<64, 8, input_t>(params, device); + } else if (params.log_N == 9) { + fast_hadamard_transform_40N_launch<128, 9, input_t>(params, device); + } else if (params.log_N == 10) { + fast_hadamard_transform_40N_launch<256, 10, input_t>(params, device); + } else { + host::Panic("fast_hadamard_transform_40N: unsupported log_N=", params.log_N); + } +} + +inline void set_hadamard_params( + HadamardParamsBase& params, + int64_t batch, + int64_t dim, + int64_t multiple, + const tvm::ffi::TensorView x, + const tvm::ffi::TensorView out, + float scale) { + std::memset(¶ms, 0, sizeof(params)); + params.batch = static_cast(batch); + params.dim = static_cast(dim); + params.log_N = ceil_log2(static_cast(dim / multiple)); + params.x_ptr = const_cast(x.data_ptr()); + params.out_ptr = const_cast(out.data_ptr()); + params.x_batch_stride = x.stride(0); + params.out_batch_stride = out.stride(0); + params.scale = scale; +} + +template +inline void run_hadamard(const tvm::ffi::TensorView x, const tvm::ffi::TensorView out, float scale) { + using namespace host; + + auto N = SymbolicSize{"batch"}; + auto D = SymbolicSize{"dim"}; + auto SX = SymbolicSize{"x_batch_stride"}; + auto SO = SymbolicSize{"out_batch_stride"}; + auto device = SymbolicDevice{}; + device.set_options(); + + TensorMatcher({N, D}).with_strides({SX, 1}).with_dtype().with_device(device).verify(x); + TensorMatcher({N, D}).with_strides({SO, 1}).with_dtype().with_device(device).verify(out); + + const int64_t batch = N.unwrap(); + const int64_t dim = D.unwrap(); + + RuntimeCheck(dim % kMultiple == 0, "hadamard: dim must be divisible by ", kMultiple); + + HadamardParamsBase params; + set_hadamard_params(params, batch, dim, kMultiple, x, out, scale); + + if constexpr 
(kMultiple == 1) { + RuntimeCheck(dim % 8 == 0, "fast_hadamard_transform only supports hidden dim divisible by 8"); + RuntimeCheck(dim <= 32768, "fast_hadamard_transform only supports hidden dim <= 32768"); + fast_hadamard_transform_cuda(params, device.unwrap()); + } else if constexpr (kMultiple == 12) { + RuntimeCheck(dim % (4 * 12) == 0, "fast_hadamard_transform_12N only supports hidden dim divisible by 48"); + RuntimeCheck(dim <= 12 * 1024, "fast_hadamard_transform_12N only supports hidden dim <= 12288"); + fast_hadamard_transform_12N_cuda(params, device.unwrap()); + } else if constexpr (kMultiple == 20) { + RuntimeCheck(dim % (4 * 20) == 0, "fast_hadamard_transform_20N only supports hidden dim divisible by 80"); + RuntimeCheck(dim <= 20 * 1024, "fast_hadamard_transform_20N only supports hidden dim <= 20480"); + fast_hadamard_transform_20N_cuda(params, device.unwrap()); + } else if constexpr (kMultiple == 28) { + RuntimeCheck(dim % (4 * 28) == 0, "fast_hadamard_transform_28N only supports hidden dim divisible by 112"); + RuntimeCheck(dim <= 28 * 1024, "fast_hadamard_transform_28N only supports hidden dim <= 28672"); + fast_hadamard_transform_28N_cuda(params, device.unwrap()); + } else if constexpr (kMultiple == 40) { + RuntimeCheck(dim % (4 * 40) == 0, "fast_hadamard_transform_40N only supports hidden dim divisible by 160"); + RuntimeCheck(dim <= 40 * 1024, "fast_hadamard_transform_40N only supports hidden dim <= 40960"); + fast_hadamard_transform_40N_cuda(params, device.unwrap()); + } else { + Panic("Unsupported multiple"); + } +} + +template +struct HadamardKernel { + static void run(const tvm::ffi::TensorView x, const tvm::ffi::TensorView out, float scale) { + run_hadamard<1, DType>(x, out, scale); + } +}; + +template +struct Hadamard12NKernel { + static void run(const tvm::ffi::TensorView x, const tvm::ffi::TensorView out, float scale) { + run_hadamard<12, DType>(x, out, scale); + } +}; + +template +struct Hadamard20NKernel { + static void run(const 
tvm::ffi::TensorView x, const tvm::ffi::TensorView out, float scale) { + run_hadamard<20, DType>(x, out, scale); + } +}; + +template +struct Hadamard28NKernel { + static void run(const tvm::ffi::TensorView x, const tvm::ffi::TensorView out, float scale) { + run_hadamard<28, DType>(x, out, scale); + } +}; + +template +struct Hadamard40NKernel { + static void run(const tvm::ffi::TensorView x, const tvm::ffi::TensorView out, float scale) { + run_hadamard<40, DType>(x, out, scale); + } +}; + +} // namespace diff --git a/sglang/python/sglang/jit_kernel/csrc/fast-hadamard-transform/static_switch.h b/sglang/python/sglang/jit_kernel/csrc/fast-hadamard-transform/static_switch.h new file mode 100644 index 0000000000000000000000000000000000000000..aea354665811ebad2a744b8567bf186816135d0e --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/fast-hadamard-transform/static_switch.h @@ -0,0 +1,27 @@ +// Copied from https://github.com/sgl-project/fast-hadamard-transform + +// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) 
\ + [&] { \ + if (COND) { \ + static constexpr bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + static constexpr bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/sglang/python/sglang/jit_kernel/csrc/gemm/awq_dequantize.cuh b/sglang/python/sglang/jit_kernel/csrc/gemm/awq_dequantize.cuh new file mode 100644 index 0000000000000000000000000000000000000000..ac6b9a5ffc599af789f4614decf283c0ec4ed089 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/gemm/awq_dequantize.cuh @@ -0,0 +1,227 @@ +// Adapted from +// https://github.com/vllm-project/vllm/blob/eb59b5a6cba6727d3727c0372258db9002f687c1/csrc/quantization/awq/gemm_kernels.cu#L350 +#pragma once + +#include + +#include + +namespace device::awq { + +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + uint4 result; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const i4s = reinterpret_cast(source); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t BOTTOM_MASK = 0x000f000f; + static constexpr uint32_t TOP_MASK = 0x00f000f0; + static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; + + // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW + // dependency if we issue immediately before required. 
+ const uint32_t top_i4s = i4s >> 8; + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[1]) + : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[2]) + : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[3]) + : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + + // This is the half2 {1024, 1024} represented as an integer. + static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400; + // This is the half2 {1 / 16, 1 / 16} represented as an integer. + static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; + // This is the half2 {-64, -64} represented as an integer. + static constexpr uint32_t NEG_64 = 0xd400d400; + + // Finally, we construct the output numbers. 
+ // Convert elt_01 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_23 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); + // Convert elt_45 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_67 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); + + return result; +#else + assert(false); + return {}; +#endif +} + +__device__ uint4 dequantize_s4_to_bf16x2(uint32_t const& source) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + uint4 result; + uint32_t* h = reinterpret_cast(&result); + uint32_t const i4s = source; + + // Define masks and constants + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t EX = 0x43004300; + static constexpr uint32_t MUL = 0x3F803F80; + static constexpr uint32_t ADD = 0xC300C300; + + int lo0 = lop3<(0xf0 & 0xcc) | 0xaa>(i4s, MASK, EX); + int hi0 = lop3<(0xf0 & 0xcc) | 0xaa>(i4s >> 4, MASK, EX); + int lo1 = lop3<(0xf0 & 0xcc) | 0xaa>(i4s >> 8, MASK, EX); + int hi1 = lop3<(0xf0 & 0xcc) | 0xaa>(i4s >> 12, MASK, EX); + + nv_bfloat162* res = reinterpret_cast(h); + res[0] = __hfma2( + *reinterpret_cast(&lo0), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + res[1] = __hfma2( + *reinterpret_cast(&hi0), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + res[2] = __hfma2( + *reinterpret_cast(&lo1), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + res[3] = __hfma2( + *reinterpret_cast(&hi1), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + + return result; +#else + assert(false); + return {}; +#endif +} + +template +__global__ void __launch_bounds__(256) dequantize_weights( + int* __restrict__ qweight, + OutputT* __restrict__ scales, + int* __restrict__ qzeros, + OutputT* __restrict__ output, + int group_size, + int qweight_cols, + int qweight_rows) { + int col 
= blockIdx.x * blockDim.x + threadIdx.x; + int row = blockIdx.y * blockDim.y + threadIdx.y; + if (col >= qweight_cols || row >= qweight_rows) return; + + int group_idx = row / group_size; + int scale_offset = 8 * col + group_idx * qweight_cols * 8; + uint4 loaded_scale = *(uint4*)(scales + scale_offset); + + // Handle different data types + if constexpr (std::is_same::value) { + // FP16 path + uint4 zeros = dequantize_s4_to_fp16x2(qzeros[col + group_idx * qweight_cols]); + uint4 weight_fp16 = dequantize_s4_to_fp16x2(qweight[col + row * qweight_cols]); + + // Use PTX assembly for FP16 operations + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.x) : "r"(weight_fp16.x), "r"(zeros.x)); + asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.x) : "r"(weight_fp16.x), "r"(loaded_scale.x)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.y) : "r"(weight_fp16.y), "r"(zeros.y)); + asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.y) : "r"(weight_fp16.y), "r"(loaded_scale.y)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.z) : "r"(weight_fp16.z), "r"(zeros.z)); + asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.z) : "r"(weight_fp16.z), "r"(loaded_scale.z)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.w) : "r"(weight_fp16.w), "r"(zeros.w)); + asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.w) : "r"(weight_fp16.w), "r"(loaded_scale.w)); + + OutputT* output_ptr = output + 8 * col + 8 * row * qweight_cols; + *(uint4*)output_ptr = weight_fp16; + } else if constexpr (std::is_same::value) { + uint4 weight_raw = dequantize_s4_to_bf16x2(qweight[col + row * qweight_cols]); + uint4 zero_raw = dequantize_s4_to_bf16x2(qzeros[col + group_idx * qweight_cols]); + uint4 scale_raw = *reinterpret_cast(scales + scale_offset); + + // Vectorized processing (each uint4 contains 4 nv_bfloat162) + nv_bfloat162* weight_vec = reinterpret_cast(&weight_raw); + nv_bfloat162* zero_vec = 
reinterpret_cast(&zero_raw); + nv_bfloat162* scale_vec = reinterpret_cast(&scale_raw); + +// Single instruction dual-channel operation +#pragma unroll + for (int i = 0; i < 4; ++i) { // uint4 = 4 * nv_bfloat162 + weight_vec[i] = __hmul2(__hsub2(weight_vec[i], zero_vec[i]), scale_vec[i]); + } + + // Directly store to OutputT array (guaranteed contiguous memory) + OutputT* output_ptr = output + 8 * col + row * qweight_cols * 8; + static_assert(sizeof(uint4) == 8 * sizeof(OutputT), "Memory layout mismatch"); + *reinterpret_cast(output_ptr) = weight_raw; + } +} + +} // namespace device::awq + +// Host wrapper +template +void awq_dequantize( + tvm::ffi::TensorView output, + tvm::ffi::TensorView qweight, + tvm::ffi::TensorView scales, + tvm::ffi::TensorView qzeros) { + using namespace host; + + int64_t qweight_rows = qweight.size(0); + int64_t qweight_cols = qweight.size(1); + int64_t scales_rows = scales.size(0); + + // Validate tensors + SymbolicDevice cuda_device; + cuda_device.set_options(); + + TensorMatcher({qweight_rows, qweight_cols}).with_dtype().with_device(cuda_device).verify(qweight); + TensorMatcher({scales_rows, qweight_cols * 8}).with_dtype().with_device(cuda_device).verify(scales); + TensorMatcher({scales_rows, qweight_cols}).with_dtype().with_device(cuda_device).verify(qzeros); + TensorMatcher({qweight_rows, qweight_cols * 8}).with_dtype().with_device(cuda_device).verify(output); + + // Get device and stream + auto device = cuda_device.unwrap(); + auto stream = LaunchKernel::resolve_device(device); + + int group_size = static_cast(qweight_rows / scales_rows); + int x_num_threads = 16; + int y_num_threads = 16; + int x_blocks = (static_cast(qweight_cols) + x_num_threads - 1) / x_num_threads; + int y_blocks = (static_cast(qweight_rows) + y_num_threads - 1) / y_num_threads; + + dim3 num_blocks(x_blocks, y_blocks); + dim3 threads_per_block(x_num_threads, y_num_threads); + + // Get pointers + auto* qweight_ptr = reinterpret_cast(qweight.data_ptr()); + auto* 
scales_ptr = reinterpret_cast(scales.data_ptr()); + auto* qzeros_ptr = reinterpret_cast(qzeros.data_ptr()); + auto* output_ptr = reinterpret_cast(output.data_ptr()); + + LaunchKernel(num_blocks, threads_per_block, stream)( + device::awq::dequantize_weights, + qweight_ptr, + scales_ptr, + qzeros_ptr, + output_ptr, + group_size, + static_cast(qweight_cols), + static_cast(qweight_rows)); +} diff --git a/sglang/python/sglang/jit_kernel/csrc/gemm/marlin/awq_marlin_repack.cuh b/sglang/python/sglang/jit_kernel/csrc/gemm/marlin/awq_marlin_repack.cuh new file mode 100644 index 0000000000000000000000000000000000000000..7f1735433ae23fa06af4c1ebb609660c0a5e4017 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/gemm/marlin/awq_marlin_repack.cuh @@ -0,0 +1,251 @@ +#pragma once + +#include + +#include + +#include "marlin.cuh" + +namespace device::marlin { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +template +__global__ void awq_marlin_repack_kernel( + uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr, int size_k, int size_n) { + return; +} +#else + +template +__global__ void awq_marlin_repack_kernel( + uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr, int size_k, int size_n) { + constexpr int pack_factor = 32 / num_bits; + + int k_tiles = size_k / tile_k_size; + int n_tiles = size_n / tile_n_size; + int block_k_tiles = div_ceil(k_tiles, (int)gridDim.x); + + auto start_k_tile = blockIdx.x * block_k_tiles; + if (start_k_tile >= k_tiles) { + return; + } + + int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles); + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). 
+ cp_async_wait(); + __syncthreads(); + }; + + extern __shared__ int4 sh[]; + + constexpr int tile_n_ints = tile_n_size / pack_factor; + + constexpr int stage_n_threads = tile_n_ints / 4; + constexpr int stage_k_threads = tile_k_size; + constexpr int stage_size = stage_k_threads * stage_n_threads; + + auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + cp_async_fence(); + return; + } + + int first_n = n_tile_id * tile_n_size; + int first_n_packed = first_n / pack_factor; + + int4* sh_ptr = sh + stage_size * pipe; + + if (threadIdx.x < stage_size) { + auto k_id = threadIdx.x / stage_n_threads; + auto n_id = threadIdx.x % stage_n_threads; + + int first_k = k_tile_id * tile_k_size; + + cp_async4( + &sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast( + &(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) + first_n_packed + (n_id * 4)]))); + } + + cp_async_fence(); + }; + + auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + return; + } + + auto warp_id = threadIdx.x / 32; + auto th_id = threadIdx.x % 32; + + if (warp_id >= 4) { + return; + } + + int tc_col = th_id / 4; + int tc_row = (th_id % 4) * 2; + + constexpr int tc_offsets[4] = {0, 1, 8, 9}; + + int cur_n = warp_id * 16 + tc_col; + int cur_n_packed = cur_n / pack_factor; + int cur_n_pos = cur_n % pack_factor; + + constexpr int sh_stride = tile_n_ints; + constexpr uint32_t mask = (1 << num_bits) - 1; + + int4* sh_stage_ptr = sh + stage_size * pipe; + uint32_t* sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); + + // Undo interleaving + int cur_n_pos_unpacked; + if constexpr (num_bits == 4) { + constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7}; + cur_n_pos_unpacked = undo_pack[cur_n_pos]; + } else { + constexpr int undo_pack[4] = {0, 2, 1, 3}; + cur_n_pos_unpacked = undo_pack[cur_n_pos]; + } + + uint32_t vals[8]; +#pragma unroll + for (int i = 0; i < 4; i++) { + int cur_elem = tc_row + tc_offsets[i]; + + 
int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem]; + int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) + sh_stride * cur_elem]; + + vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask; + vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask; + } + + constexpr int tile_size_val = tile_k_size * tile_n_size / pack_factor; + int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size_val; + + // Result of: + // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h + if constexpr (num_bits == 4) { + constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + + uint32_t res = 0; +#pragma unroll + for (int i = 0; i < 8; i++) { + res |= vals[pack_idx[i]] << (i * 4); + } + + out_ptr[out_offset + th_id * 4 + warp_id] = res; + + } else { + constexpr int pack_idx[4] = {0, 2, 1, 3}; + + uint32_t res1 = 0; + uint32_t res2 = 0; +#pragma unroll + for (int i = 0; i < 4; i++) { + res1 |= vals[pack_idx[i]] << (i * 8); + res2 |= vals[4 + pack_idx[i]] << (i * 8); + } + + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; + } + }; + + auto start_pipes = [&](int k_tile_id, int n_tile_id) { +#pragma unroll + for (int pipe = 0; pipe < repack_stages - 1; pipe++) { + fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); + } + + wait_for_stage(); + }; +#pragma unroll + for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { + int n_tile_id = 0; + + start_pipes(k_tile_id, n_tile_id); + + while (n_tile_id < n_tiles) { +#pragma unroll + for (int pipe = 0; pipe < repack_stages; pipe++) { + fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1); + repack_tile(pipe, k_tile_id, n_tile_id + pipe); + wait_for_stage(); + } + n_tile_id += repack_stages; + } + } +} +#endif + +} // namespace 
device::marlin + +// Host wrapper: validates sizes and tensors, then launches the repack kernel with one block per SM +void awq_marlin_repack( + tvm::ffi::TensorView out, tvm::ffi::TensorView b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits) { + using namespace host; + using namespace device::marlin; + + // Validate alignment + RuntimeCheck(size_k % tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", tile_k_size); + RuntimeCheck(size_n % tile_n_size == 0, "size_n = ", size_n, " is not divisible by tile_n_size = ", tile_n_size); + RuntimeCheck(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); + + int const pack_factor = 32 / num_bits; + + // Validate tensors + SymbolicDevice cuda_device; + cuda_device.set_options(); + + TensorMatcher({size_k, size_n / pack_factor}).with_dtype().with_device(cuda_device).verify(b_q_weight); + + TensorMatcher({size_k / tile_size, size_n * tile_size / pack_factor}) + .with_dtype() + .with_device(cuda_device) + .verify(out); + + // Get device and stream + auto device = cuda_device.unwrap(); + auto stream = LaunchKernel::resolve_device(device); + + // Get pointers + auto* b_q_weight_ptr = reinterpret_cast(b_q_weight.data_ptr()); + auto* out_ptr = reinterpret_cast(out.data_ptr()); + + // Get device attributes; every CUDA runtime call is checked so a failure is reported at its source + int blocks = 0; + RuntimeDeviceCheck(cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, device.device_id)); + + int max_shared_mem = 0; + RuntimeDeviceCheck(cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, device.device_id)); + RuntimeCheck(max_shared_mem > 0, "max_shared_mem must be > 0"); + + // Dispatch based on num_bits + if (num_bits == 4) { + RuntimeDeviceCheck(cudaFuncSetAttribute( + awq_marlin_repack_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem)); + LaunchKernel(blocks, repack_threads, stream, max_shared_mem)( + awq_marlin_repack_kernel, + b_q_weight_ptr, + out_ptr, + static_cast(size_k), + static_cast(size_n)); + } else if (num_bits == 8) { + RuntimeDeviceCheck(cudaFuncSetAttribute( + awq_marlin_repack_kernel,
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem)); + LaunchKernel(blocks, repack_threads, stream, max_shared_mem)( + awq_marlin_repack_kernel, + b_q_weight_ptr, + out_ptr, + static_cast(size_k), + static_cast(size_n)); + } else { + RuntimeCheck(false, "Unsupported repack config: num_bits = ", num_bits); + } +} diff --git a/sglang/python/sglang/jit_kernel/csrc/gemm/marlin/dequant.h b/sglang/python/sglang/jit_kernel/csrc/gemm/marlin/dequant.h new file mode 100644 index 0000000000000000000000000000000000000000..764375f62280a6ddb4f8e28928c40dca58e23fdd --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/gemm/marlin/dequant.h @@ -0,0 +1,504 @@ +/* +Fast Dequantization (Converting INT4/INT8/FP4/FP8 to FP16/BF16) + +The process of fast dequantization can be summarized as a combination +of bitwise operations and floating-point computations: + +weight =>(bit_op / bitwise operations)=> +f16_value =>(flop / floating-point computation)=> +dequantized_weight + +Since the dequantized weights typically require subtracting the zero point and +applying a scale factor, the floating-point computation step can be fused with +the zero-point subtraction and scaling operations. + +The following are the parts that need to be modified for the fused operation +of zero-point subtraction and scaling. + +## INT4 => FP16/BF16 or INT8 => FP16 + +The floating-point computation is `__hsub2` + +If has zero points: + + flop(bit_op(weight)) - flop(bit_op(zp)) + = sub(bit_op(weight), bias) - sub(bit_op(zp), bias) + = bit_op(weight) - bit_op(zp) + +so we don't need additional modification. + +If has float zero points: + + flop(bit_op(weight)) - fzp + = sub(bit_op(weight), bias) - fzp + = bit_op(weight) - (fzp + bias) + +where the `fzp + bias` can be computed at weight loading. But this +may have accuracy issue, so we should not use this in most cases.
+ +If there are no zero points: + + scale(flop(bit_op(weight))) + = scale(sub(bit_op(weight), bias)) + = scale(bit_op(weight)) - scale(bias) + = fma(bit_op(weight), scale_factor, -scale(bias)) + +where `-scale(bias)` can be cached. But this may have accuracy issues, +so we should not use this in most cases. + + +## INT8 => BF16 + +INT8 => BF16 is a special case: it uses byte_perm instead of flop. +We cannot fuse byte_perm with scaling. + + +## FP4/FP8 => FP16/BF16 + + scale(flop(bit_op(weight))) + = scale(mul(bit_op(weight), multiplier)) + = mul(bit_op(weight), scale_factor * multiplier) + +where `scale_factor * multiplier` can be computed at weight loading. + +*/ + +#include "marlin_dtypes.cuh" + +namespace device::marlin { + +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +// Constructs destination register by taking bytes from 2 sources (based on +// mask) +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" : "=r"(res) : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +template +__device__ inline void dequant(int q, scalar_t2* frag_b); + +// +// Efficiently dequantize 4bit values packed in an int32 value into a full +// B-fragment of 4 fp16 values.
We mostly follow the strategy in the link below, +// with some small changes: +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 +// +template <> +__device__ inline void dequant(int q, half2* frag_b) { + const int MASK = 0x000f000f; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + // clang-format off + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // clang-format on + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2( + *reinterpret_cast(&hi), *reinterpret_cast(&MUL), *reinterpret_cast(&ADD)); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + dequant(q, frag_b); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. 
+ // clang-format off + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // clang-format on + // No zero point is fused here: `SUB` (1024.0) and `ADD` (-64.0) only cancel the + // offset that `EX` ORs into each half, so the outputs stay unsigned int4 values. + const int SUB = 0x64006400; + const int MUL = 0x2c002c00; + const int ADD = 0xd400d400; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2( + *reinterpret_cast(&hi), *reinterpret_cast(&MUL), *reinterpret_cast(&ADD)); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t EX = 0x43004300; + + // Guarantee that the `(a & b) | c` operations are LOP3s. + // clang-format off + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + // clang-format on + + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t SUB = 0x43084308; + + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&SUB)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&SUB)); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t SUB = 0x43004300; + + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&SUB)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&SUB)); +} + +// +// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or +// bf16 Reference: +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +// - BF16: +//
https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 +// +template <> +__device__ inline void dequant(int q, half2* frag_b) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + dequant(q, frag_b); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388736.f; + fp32_intermediates[1] -= 8388736.f; + fp32_intermediates[2] -= 8388736.f; + fp32_intermediates[3] -= 8388736.f; + + uint32_t* bf16_result_ptr = 
reinterpret_cast(frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388608.f; + fp32_intermediates[1] -= 8388608.f; + fp32_intermediates[2] -= 8388608.f; + fp32_intermediates[3] -= 8388608.f; + + uint32_t* bf16_result_ptr = reinterpret_cast(frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + // Constants for FP8 (E4M3) and FP16 formats + constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5; + constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT; + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + dequant(q, frag_b); + + // Constants for FP8 (E4M3) and FP16 formats + constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5; + + // Construct and 
apply exponent bias + constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); + const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); + + // Convert to half2 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + // Constants for FP8 (E4M3) and BF16 formats + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; + + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to BF16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); + + // Constants for FP8 (E4M3) and BF16 formats + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = (1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); + // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent + // position + constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; + const nv_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast(&BIAS)); + + // Convert to bfloat162 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5; + constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP4_EXPONENT; + constexpr int MASK = 0x70007000; + + // Extract and shift FP4 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 4; 
+ int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + dequant(q, frag_b); + + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1)); + const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); + + // Convert to half2 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + // Constants for FP4 (E2M1) and BF16 formats + constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP4_EXPONENT; + constexpr int MASK = 0x70007000; + + // Extract and shift FP4 values to BF16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 4; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); + + // Constants for FP4 (E2M1) and BF16 formats + constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = (1 << (BF16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1)); + // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent + // position + constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; + const nv_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast(&BIAS)); + + // Convert to bfloat162 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); +
frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template +__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b); + +template <> +__device__ inline void dequant_fp8_scales(int q, half2* frag_b) { + int Out1 = (q & 0xFF00FF00) >> 1; + ; + q <<= 8; + int Out2 = (q & 0xFF00FF00) >> 1; + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +}; + +template <> +__device__ inline void dequant_fp8_scales(int q, nv_bfloat162* frag_b) { + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to BF16 format + int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +}; + +// New version with s_type_id parameter for marlin_moe_wna16_v2 +template +__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b); + +template <> +__device__ inline void dequant_fp8_scales(int q, half2* frag_b) { + int Out1 = (q & 0xFF00FF00) >> 1; + ; + q <<= 8; + int Out2 = (q & 0xFF00FF00) >> 1; + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +}; + +template <> +__device__ inline void dequant_fp8_scales(int q, nv_bfloat162* frag_b) { + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to BF16 format + int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is 
intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant_fp8_scales(int q, nv_bfloat162* frag_b) { + // In this conversion, 2 ** -127 in FP8E8M0 would become 0 in BF16, + // but we assume that such a extreme value would not occur in real models. + int Out1 = (q & 0xFF00FF00) >> 1; + q <<= 7; + int Out2 = q & 0x7F807F80; + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +#endif + +} // namespace device::marlin diff --git a/sglang/python/sglang/jit_kernel/csrc/gemm/marlin/gptq_marlin.cuh b/sglang/python/sglang/jit_kernel/csrc/gemm/marlin/gptq_marlin.cuh new file mode 100644 index 0000000000000000000000000000000000000000..0f8983e87036c739bffa04fd20b4890d94c122f6 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/gemm/marlin/gptq_marlin.cuh @@ -0,0 +1,1001 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ + +#pragma once + +#include + +#include + +#include "kernel.h" +#include "marlin_template.h" + +namespace device::marlin { + +__global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){}; + +using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS); + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +__global__ void permute_cols_kernel( + int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, + int size_m, + int size_k, + int lda, + int block_rows) {} + +#else + +// For a given "a" of size [M,K] performs a permutation of the K columns based +// on the given "perm" indices. +__global__ void permute_cols_kernel( + int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, + int size_m, + int size_k, + int lda, + int block_rows) { + auto start_row = block_rows * blockIdx.x; + int finish_row = start_row + block_rows; + if (finish_row > size_m) { + finish_row = size_m; + } + int cur_block_rows = finish_row - start_row; + + int input_row_stride = lda * sizeof(half) / 16; + int output_row_stride = size_k * sizeof(half) / 16; + + auto permute_row = [&](int row) { + int iters = size_k / default_threads; + int rest = size_k % default_threads; + + int input_offset = row * input_row_stride; + int output_offset = row * output_row_stride; + + half const* a_row_half = reinterpret_cast(a_int4_ptr + input_offset); + half* out_half = reinterpret_cast(out_int4_ptr + output_offset); + + int base_k = 0; + + for (int i = 0; i < iters; i++) { + auto cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + + base_k += default_threads; + } + + if (rest) { + if (threadIdx.x < rest) { + auto cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + } + } + }; + + for (int i = 0; i < cur_block_rows; i++) { + int 
cur_row = start_row + i; + if (cur_row < size_m) { + permute_row(cur_row); + } + } +} + +typedef struct { + int thread_k; + int thread_n; + int num_threads; +} thread_config_t; + +thread_config_t small_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {128, 128, 256}, + {64, 128, 128}, + {128, 64, 128}}; + +thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, + {64, 128, 128}, + {128, 64, 128}}; + +typedef struct { + int blocks_per_sm; + thread_config_t tb_cfg; +} exec_config_t; + +int get_scales_cache_size( + thread_config_t const& th_config, + int prob_m, + int prob_n, + int prob_k, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full) { + bool cache_scales_chunk = has_act_order && !is_k_full; + + int tb_n = th_config.thread_n; + int tb_k = th_config.thread_k; + + // Get max scale groups per thread-block + int tb_groups; + if (group_size == -1) { + tb_groups = 1; + } else if (group_size == 0) { + tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size + } else { + tb_groups = div_ceil(tb_k, group_size); + } + + if (cache_scales_chunk) { + int load_groups = tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K + load_groups = max(load_groups, 32); // We load at least 32 scale groups + return load_groups * tb_n * 2; + } else { + int tb_scales = tb_groups * tb_n * 2; + + return tb_scales * pipe_stages; + } +} + +int get_kernel_cache_size( + thread_config_t const& th_config, + int thread_m_blocks, + int prob_m, + int prob_n, + int prob_k, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full, + int has_zp, + int is_zp_float) { + int pack_factor = 32 / num_bits; + + // Get B size + int tb_k = th_config.thread_k; + int tb_n = th_config.thread_n; + int tb_m = thread_m_blocks * 16; + int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; + int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 
4; + int sh_red_size = tb_m * (tb_n + 8); + int sh_s_size = + get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full); + int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0; + int sh_zp_size = 0; + if (has_zp) { + if (is_zp_float) + sh_zp_size = sh_s_size; + else if (num_bits == 4) + sh_zp_size = sh_s_size / 4; + else if (num_bits == 8) + sh_zp_size = sh_s_size / 2; + } + + int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size + sh_zp_size + sh_g_idx_size; + + return total_size; +} + +bool is_valid_config( + thread_config_t const& th_config, + int thread_m_blocks, + int prob_m, + int prob_n, + int prob_k, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full, + int has_zp, + int is_zp_float, + int max_shared_mem) { + // Sanity + if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) { + return false; + } + + // Verify K/N are divisible by thread K/N + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { + return false; + } + + // Verify min for thread K/N + if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { + return false; + } + + // num_threads must be at least 128 (= 4 warps) + if (th_config.num_threads < 128) { + return false; + } + + // Check that pipeline fits into cache + int cache_size = get_kernel_cache_size( + th_config, + thread_m_blocks, + prob_m, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float); + return cache_size <= max_shared_mem; +} + +#define _GET_IF( \ + W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ + else if ( \ + q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && m_block_size_8 == M_BLOCK_SIZE_8 && group_blocks == GROUP_BLOCKS && \ + num_threads == 
NUM_THREADS && is_zp_float == IS_ZP_FLOAT) { \ + kernel = Marlin< \ + scalar_t, \ + W_TYPE.id(), \ + NUM_THREADS, \ + THREAD_M_BLOCKS, \ + THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, \ + M_BLOCK_SIZE_8, \ + pipe_stages, \ + GROUP_BLOCKS, \ + IS_ZP_FLOAT>; \ + } + +// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false) +// this is the most common cases +// BIGGROUP: cases for big group size (group_blocks in [-1, 8]) +// FZP: cases for float-zero-point (is_zp_float = true) +// ACT: cases for act order case (group_blocks == 0) +// FP4: cases for nvfp4(e2m1) (group_blocks == 1) +#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + 
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define COMMON_GET_IF(W_TYPE) \ + COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \ + COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \ + COMMON_GET_IF_M1(W_TYPE, 4, 8, 128) \ + COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \ + COMMON_GET_IF_M234(W_TYPE, 8, 4, 128) \ + COMMON_GET_IF_M234(W_TYPE, 4, 8, 128) + +#define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define BIGGROUP_GET_IF(W_TYPE) \ + BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \ + BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \ + BIGGROUP_GET_IF_M1(W_TYPE, 4, 8, 128) \ + BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \ + BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \ + BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128) + +#define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + +#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, 
NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + +#define FP4_GET_IF(W_TYPE) \ + FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ + FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ + FP4_GET_IF_M1(W_TYPE, 4, 8, 128) \ + FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ + FP4_GET_IF_M234(W_TYPE, 8, 4, 128) \ + FP4_GET_IF_M234(W_TYPE, 4, 8, 128) + +// Float-zero-point kernels: we currently have 4-bit models only, and only with group_blocks == 4 +#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) + +#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) + +#define FZP_GET_IF(W_TYPE) \ + FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \ + FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \ + FZP_GET_IF_M1(W_TYPE, 4, 8, 128) \ + FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \ + FZP_GET_IF_M234(W_TYPE, 8, 4, 128) \ + FZP_GET_IF_M234(W_TYPE, 4, 8, 128) + +// Act-order kernels: every instantiation below passes group_blocks == 0 and is_zp_float == false +#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) + +#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) + +#define ACT_GET_IF(W_TYPE) \ + ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \ + ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \ + ACT_GET_IF_M1(W_TYPE, 4, 8, 128) \ + ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \ + 
ACT_GET_IF_M234(W_TYPE, 8, 4, 128) \ + ACT_GET_IF_M234(W_TYPE, 4, 8, 128) + +template +MarlinFuncPtr get_marlin_kernel( + const host::ScalarType q_type, + int thread_m_blocks, + int thread_n_blocks, + int thread_k_blocks, + bool m_block_size_8, + bool has_act_order, + bool has_zp, + int group_blocks, + int num_threads, + bool is_zp_float) { + int num_bits = q_type.size_bits(); + auto kernel = MarlinDefault; + if (false) { + } + + COMMON_GET_IF(host::kU4) + COMMON_GET_IF(host::kU4B8) + COMMON_GET_IF(host::kU8B128) + + FP4_GET_IF(host::kFE2M1f) + + BIGGROUP_GET_IF(host::kFE4M3fn) + + ACT_GET_IF(host::kU4B8) + ACT_GET_IF(host::kU8B128) + + if (std::is_same::value) { + if (false) { + } + FZP_GET_IF(host::kU4) + } + + return kernel; +} + +template +exec_config_t determine_exec_config( + const host::ScalarType& q_type, + int prob_m, + int prob_n, + int prob_k, + int thread_m_blocks, + bool m_block_size_8, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full, + bool has_zp, + bool is_zp_float, + int max_shared_mem, + int sms) { + exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}}; + thread_config_t* thread_configs = thread_m_blocks > 1 ? large_batch_thread_configs : small_batch_thread_configs; + int thread_configs_size = thread_m_blocks > 1 ? 
sizeof(large_batch_thread_configs) / sizeof(thread_config_t) + : sizeof(small_batch_thread_configs) / sizeof(thread_config_t); + + for (int i = 0; i < thread_configs_size; i++) { + thread_config_t th_config = thread_configs[i]; + + if (!is_valid_config( + th_config, + thread_m_blocks, + prob_m, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float, + max_shared_mem)) { + continue; + } + + int cache_size = get_kernel_cache_size( + th_config, + thread_m_blocks, + prob_m, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float); + + int group_blocks = 0; + if (!has_act_order) { + group_blocks = group_size == -1 ? -1 : group_size / 16; + } + + auto kernel = get_marlin_kernel( + q_type, + thread_m_blocks, + th_config.thread_n / 16, + th_config.thread_k / 16, + m_block_size_8, + has_act_order, + has_zp, + group_blocks, + th_config.num_threads, + is_zp_float); + + if (kernel == MarlinDefault) continue; + + // int m_tiles = div_ceil(prob_m, thread_m_blocks * 16); + // int n_tiles = prob_n / th_config.thread_n; + // int k_tiles = prob_k / th_config.thread_k; + + return {1, th_config}; + } + + return exec_cfg; +} + +template +void marlin_mm( + const void* A, + const void* B, + void* C, + void* C_tmp, + void* s, + void* s2, + void* zp, + void* g_idx, + void* perm, + void* a_tmp, + int prob_m, + int prob_n, + int prob_k, + int lda, + void* workspace, + host::ScalarType const& q_type, + bool has_act_order, + bool is_k_full, + bool has_zp, + int num_groups, + int group_size, + int dev, + cudaStream_t stream, + int thread_k_init, + int thread_n_init, + int sms, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float) { + if (has_zp) { + host::RuntimeCheck( + q_type == host::kU4 || q_type == host::kU8, "q_type must be u4 or u8 when has_zp = True. 
Got = ", q_type.str()); + } else { + host::RuntimeCheck( + q_type == host::kU4B8 || q_type == host::kU8B128 || q_type == host::kFE4M3fn || q_type == host::kFE2M1f, + "q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " + "has_zp = False. Got = ", + q_type.str()); + } + + host::RuntimeCheck( + prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); + + // group_blocks = group_size / 16; -1 is passed through when group_size == -1, + // and 0 is used for act-order without a full K. + int group_blocks = 0; + if (has_act_order) { + if (is_k_full) { + host::RuntimeCheck(group_size != -1); + // A group_size that is not a positive multiple of 16 would be truncated by the + // division below, and a result of 0 would make the modulo divide by zero. + host::RuntimeCheck( + group_size > 0 && group_size % 16 == 0, "group_size = ", group_size, " must be a positive multiple of 16"); + group_blocks = group_size / 16; + host::RuntimeCheck( + prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks); + } else { + host::RuntimeCheck(group_size == 0); + group_blocks = 0; + } + } else { + if (group_size == -1) { + group_blocks = -1; + } else { + // Same guard as the act-order branch: avoid truncation and a zero divisor. + host::RuntimeCheck( + group_size > 0 && group_size % 16 == 0, "group_size = ", group_size, " must be a positive multiple of 16"); + group_blocks = group_size / 16; + host::RuntimeCheck( + prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks); + } + } + + int num_bits = q_type.size_bits(); + const int4* A_ptr = (const int4*)A; + const int4* B_ptr = (const int4*)B; + int4* C_ptr = (int4*)C; + int4* C_tmp_ptr = (int4*)C_tmp; + const int4* s_ptr = (const int4*)s; + const uint16_t* s2_ptr = (const uint16_t*)s2; + const int4* zp_ptr = (const int4*)zp; + const int* g_idx_ptr = (const int*)g_idx; + const int* perm_ptr = (const int*)perm; + int4* a_tmp_ptr = (int4*)a_tmp; + + int* locks = (int*)workspace; + + if (has_act_order) { + // Permute A columns + int block_rows = div_ceil(prob_m, sms); + host::LaunchKernel(sms, default_threads, stream)( + permute_cols_kernel, A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, lda, block_rows); + A_ptr = a_tmp_ptr; + lda = prob_k; + + // If we have a full K, then we can run the non-act-order version of Marlin + // (since the weight rows are reordered by increasing group ids, and by + // having a full K, we have full original groups) + if (is_k_full) has_act_order = false; + } + + int max_shared_mem = 0; + 
host::RuntimeDeviceCheck(cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev)); + host::RuntimeCheck(max_shared_mem > 0); + + int max_par = 16; + if (prob_n <= 4096) max_par = 16 * 8; + int max_shared_mem_new = max_shared_mem; + int rest_m = prob_m; + int max_thread_m_blocks = 4; + while (rest_m) { + int par_count = rest_m / (max_thread_m_blocks * 16); + if (par_count > max_par) par_count = max_par; + int prob_m_split = par_count > 0 ? (par_count * (max_thread_m_blocks * 16)) : rest_m; + + int thread_k = thread_k_init; + int thread_n = thread_n_init; + + int thread_m_blocks = min(div_ceil(prob_m_split, 16), max_thread_m_blocks); + int m_block_size_8 = prob_m_split <= 8; + + // Set thread config + exec_config_t exec_cfg; + thread_config_t thread_tfg; + if (thread_k != -1 && thread_n != -1) { + thread_tfg = thread_config_t{thread_k, thread_n, default_threads}; + exec_cfg = exec_config_t{1, thread_tfg}; + host::RuntimeCheck(prob_n % thread_n == 0, "prob_n = ", prob_n, " is not divisible by thread_n = ", thread_n); + host::RuntimeCheck(prob_k % thread_k == 0, "prob_k = ", prob_k, " is not divisible by thread_k = ", thread_k); + } else { + // Auto config + exec_cfg = determine_exec_config( + q_type, + prob_m_split, + prob_n, + prob_k, + thread_m_blocks, + m_block_size_8, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float, + max_shared_mem, + sms); + thread_tfg = exec_cfg.tb_cfg; + if (thread_tfg.thread_k == -1 && max_thread_m_blocks > 1) { + max_thread_m_blocks--; + continue; + } + } + + int num_threads = thread_tfg.num_threads; + thread_k = thread_tfg.thread_k; + thread_n = thread_tfg.thread_n; + int blocks = sms * exec_cfg.blocks_per_sm; + if (exec_cfg.blocks_per_sm > 1) max_shared_mem_new = max_shared_mem / exec_cfg.blocks_per_sm - 1024; + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + + host::RuntimeCheck( + is_valid_config( + thread_tfg, + thread_m_blocks, + prob_m_split, 
+ prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float, + max_shared_mem_new), + "Invalid thread config: thread_m_blocks = ", + thread_m_blocks, + ", thread_k = ", + thread_tfg.thread_k, + ", thread_n = ", + thread_tfg.thread_n, + ", num_threads = ", + thread_tfg.num_threads, + " for MKN = [", + prob_m, + ", ", + prob_k, + ", ", + prob_n, + "] and num_bits = ", + num_bits, + ", prob_m_split = ", + prob_m_split, + ", group_size = ", + group_size, + ", has_act_order = ", + has_act_order, + ", is_k_full = ", + is_k_full, + ", has_zp = ", + has_zp, + ", is_zp_float = ", + is_zp_float, + ", max_shared_mem_new = ", + max_shared_mem_new); + + auto kernel = get_marlin_kernel( + q_type, + thread_m_blocks, + thread_n_blocks, + thread_k_blocks, + m_block_size_8, + has_act_order, + has_zp, + group_blocks, + num_threads, + is_zp_float); + + if (kernel == MarlinDefault) { + host::Panic( + "Unsupported shapes: MNK = [", + prob_m, + ", ", + prob_n, + ", ", + prob_k, + "]", + ", has_act_order = ", + has_act_order, + ", num_groups = ", + num_groups, + ", group_size = ", + group_size, + ", prob_m_split = ", + prob_m_split, + ", thread_m_blocks = ", + thread_m_blocks, + ", thread_n_blocks = ", + thread_n_blocks, + ", thread_k_blocks = ", + thread_k_blocks, + ", num_threads = ", + num_threads, + ", num_bits = ", + num_bits); + } + + host::RuntimeDeviceCheck( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem_new)); + + bool part_use_atomic_add = use_atomic_add && div_ceil(prob_m_split, 64) * prob_n <= 2048; + + host::LaunchKernel(blocks, num_threads, stream, max_shared_mem_new)( + kernel, + A_ptr, + B_ptr, + C_ptr, + C_tmp_ptr, + s_ptr, + s2_ptr, + zp_ptr, + g_idx_ptr, + num_groups, + prob_m_split, + prob_n, + prob_k, + lda, + locks, + part_use_atomic_add, + use_fp32_reduce, + max_shared_mem_new); + + A_ptr += prob_m_split * (lda / 8); + C_ptr += prob_m_split * (prob_n / 8); + rest_m -= prob_m_split; + 
} +} + +#endif + +} // namespace device::marlin + +template +void gptq_marlin_gemm( + tvm::ffi::TensorView a, + tvm::ffi::TensorView b_q_weight, + tvm::ffi::TensorView b_scales, + tvm::ffi::TensorView global_scale, + tvm::ffi::TensorView b_zeros, + tvm::ffi::TensorView g_idx, + tvm::ffi::TensorView perm, + tvm::ffi::TensorView c, + tvm::ffi::TensorView c_tmp, + tvm::ffi::TensorView a_tmp, + tvm::ffi::TensorView workspace, + int64_t b_q_type_id, + bool is_k_full, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float) { + using namespace host; + + ScalarType const b_q_type = ScalarType::from_id(b_q_type_id); + int pack_factor = 32 / b_q_type.size_bits(); + + // Bind symbolic sizes + auto M = SymbolicSize{"M"}; + auto K = SymbolicSize{"K"}; + auto N = SymbolicSize{"N"}; + auto device = SymbolicDevice{}; + device.set_options(); + + // Verify a: [M, K] + auto lda = SymbolicSize{"lda"}; + TensorMatcher({M, K}).with_strides({lda, 1}).with_dtype().with_device(device).verify(a); + + int64_t size_m = M.unwrap(); + int64_t size_k = K.unwrap(); + + // Verify b_q_weight: [K/tile_size, packed_N] + RuntimeCheck( + size_k % device::marlin::tile_size == 0, + "size_k = ", + size_k, + " is not divisible by tile_size = ", + device::marlin::tile_size); + int64_t expected_bqw_dim0 = size_k / device::marlin::tile_size; + auto bqw_dim0 = SymbolicSize{"bqw_dim0"}; + auto bqw_dim1 = SymbolicSize{"bqw_dim1"}; + bqw_dim0.set_value(expected_bqw_dim0); + TensorMatcher({bqw_dim0, bqw_dim1}).with_dtype().with_device(device).verify(b_q_weight); + + RuntimeCheck( + b_q_weight.size(1) % device::marlin::tile_size == 0, + "b_q_weight.size(1) = ", + b_q_weight.size(1), + " is not divisible by tile_size = ", + device::marlin::tile_size); + int64_t actual_size_n = (b_q_weight.size(1) / device::marlin::tile_size) * pack_factor; + N.set_value(actual_size_n); + int64_t size_n = N.unwrap(); + + // Verify stride alignment + int64_t a_stride0 = a.stride(0); + RuntimeCheck(a_stride0 % 8 == 0, 
"a.stride(0) must be divisible by 8"); + + // Verify b_scales: [num_groups, N] + auto num_groups_sym = SymbolicSize{"num_groups"}; + TensorMatcher({num_groups_sym, N}).with_device(device).verify(b_scales); + int num_groups = static_cast(num_groups_sym.unwrap()); + + // Verify c: [M, N] + TensorMatcher({M, N}).with_dtype().with_device(device).verify(c); + + // Early return for zero-size M + if (size_m == 0) return; + + // Determine has_act_order from g_idx/perm sizes + int64_t g_idx_size = g_idx.size(0); + int64_t perm_size = perm.size(0); + bool has_act_order = g_idx_size > 0 && perm_size > 0; + + if (has_act_order) { + RuntimeCheck( + (g_idx_size == size_k && perm_size == size_k), + "Unexpected g_idx.size(0) = ", + g_idx_size, + " and perm.size(0) = ", + perm_size, + ", where size_k = ", + size_k); + } + + // Determine has_zp from b_zeros size + int64_t b_zeros_size = b_zeros.size(0); + bool has_zp = b_zeros_size > 0; + + if (has_zp) { + RuntimeCheck( + b_q_type == kU4 || b_q_type == kU8, "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); + } else { + RuntimeCheck( + b_q_type == kU4B8 || b_q_type == kU8B128 || b_q_type == kFE4M3fn || b_q_type == kFE2M1f, + "b_q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " + "has_zp = False. 
Got = ", + b_q_type.str()); + } + + if (has_zp && is_zp_float) { + RuntimeCheck( + std::is_same::value, "Computation type must be float16 (half) when using float zero points."); + } + + // Verify b_zeros shape + if (has_zp) { + RuntimeCheck(b_zeros.dim() == 2, "b_zeros rank = ", b_zeros.dim(), " is not 2"); + if (is_zp_float) { + RuntimeCheck(b_zeros.size(1) == size_n, "b_zeros dim 1 = ", b_zeros.size(1), " is not size_n = ", size_n); + RuntimeCheck( + num_groups == b_zeros.size(0), "b_zeros dim 0 = ", b_zeros.size(0), " is not num_groups = ", num_groups); + RuntimeCheck(num_groups != -1, "num_groups must be != -1"); + } else { + RuntimeCheck( + b_zeros.size(0) == num_groups, "b_zeros dim 0 = ", b_zeros.size(0), " is not num_groups = ", num_groups); + RuntimeCheck( + b_zeros.size(1) == size_n / pack_factor, + "b_zeros dim 1 = ", + b_zeros.size(1), + " is not size_n / pack_factor = ", + size_n / pack_factor); + } + } + + // Verify global_scale + int64_t global_scale_size = global_scale.size(0); + if (global_scale_size > 0) { + RuntimeCheck(b_q_type == kFE2M1f, "global_scale can only be used for float4_e2m1f."); + } else { + RuntimeCheck(!(b_q_type == kFE2M1f), "the global_scale parameter must be passed for float4_e2m1f."); + } + + // Derive group_size + int group_size = -1; + if (has_act_order) { + if (is_k_full) { + RuntimeCheck(num_groups > 1, "For act_order, num_groups must be > 1"); + RuntimeCheck(size_k % num_groups == 0, "size_k = ", size_k, ", is not divisible by num_groups = ", num_groups); + group_size = static_cast(size_k / num_groups); + } else { + group_size = 0; + } + } else { + if (num_groups > 1) { + RuntimeCheck(size_k % num_groups == 0, "size_k = ", size_k, ", is not divisible by num_groups = ", num_groups); + group_size = static_cast(size_k / num_groups); + } else { + group_size = -1; + } + } + + // Verify workspace and get device info + RuntimeCheck( + size_n % device::marlin::min_thread_n == 0, + "size_n = ", + size_n, + ", is not divisible by 
min_thread_n = ", + device::marlin::min_thread_n); + + DLDevice dl_device = device.unwrap(); + int dev = dl_device.device_id; + cudaStream_t stream = LaunchKernel::resolve_device(dl_device); + + int sms = -1; + RuntimeDeviceCheck(cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev)); + + RuntimeCheck( + workspace.size(0) >= sms, "workspace.size(0) = ", workspace.size(0), " is below min_workspace_size = ", sms); + + // Hardcoded defaults (auto config) + int thread_k_init = -1; + int thread_n_init = -1; + + // Compute c_tmp and a_tmp pointers + // c_tmp and a_tmp are pre-allocated by caller + + device::marlin::marlin_mm( + a.data_ptr(), + b_q_weight.data_ptr(), + c.data_ptr(), + c_tmp.data_ptr(), + b_scales.data_ptr(), + global_scale.data_ptr(), + b_zeros.data_ptr(), + g_idx.data_ptr(), + perm.data_ptr(), + a_tmp.data_ptr(), + static_cast(size_m), + static_cast(size_n), + static_cast(size_k), + static_cast(a_stride0), + workspace.data_ptr(), + b_q_type, + has_act_order, + is_k_full, + has_zp, + num_groups, + group_size, + dev, + stream, + thread_k_init, + thread_n_init, + sms, + use_atomic_add, + use_fp32_reduce, + is_zp_float); +} diff --git a/sglang/python/sglang/jit_kernel/csrc/gemm/marlin/gptq_marlin_repack.cuh b/sglang/python/sglang/jit_kernel/csrc/gemm/marlin/gptq_marlin_repack.cuh new file mode 100644 index 0000000000000000000000000000000000000000..73bce7903f076fb5d9353be95fb3713a68270905 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/gemm/marlin/gptq_marlin_repack.cuh @@ -0,0 +1,362 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ + +#pragma once + +#include + +#include + +#include "marlin.cuh" + +namespace device::marlin { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +template +__global__ void gptq_marlin_repack_kernel( + uint32_t const* __restrict__ b_q_weight_ptr, + uint32_t const* __restrict__ perm_ptr, + uint32_t* __restrict__ out_ptr, + int size_k, + int size_n) { + return; +} +#else +template +__global__ void gptq_marlin_repack_kernel( + uint32_t const* __restrict__ b_q_weight_ptr, + uint32_t const* __restrict__ perm_ptr, + uint32_t* __restrict__ out_ptr, + int size_k, + int size_n) { + constexpr int pack_factor = 32 / num_bits; + + int k_tiles = size_k / tile_k_size; + int n_tiles = size_n / tile_n_size; + int block_k_tiles = div_ceil(k_tiles, gridDim.x); + + auto start_k_tile = blockIdx.x * block_k_tiles; + if (start_k_tile >= k_tiles) { + return; + } + + int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles); + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). 
+ cp_async_wait(); + __syncthreads(); + }; + + extern __shared__ int4 sh[]; + + constexpr int perm_size = tile_k_size / 4; + + int4* sh_perm_ptr = sh; + int4* sh_pipe_ptr = sh_perm_ptr; + if constexpr (has_perm) { + sh_pipe_ptr += perm_size; + } + + constexpr int tile_ints = tile_k_size / pack_factor; + + constexpr int stage_n_threads = tile_n_size / 4; + constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints; + constexpr int stage_size = stage_k_threads * stage_n_threads; + + auto load_perm_to_shared = [&](int k_tile_id) { + int first_k_int4 = (k_tile_id * tile_k_size) / 4; + + int4 const* perm_int4_ptr = reinterpret_cast(perm_ptr); + + if (threadIdx.x < perm_size) { + sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x]; + } + __syncthreads(); + }; + + auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + cp_async_fence(); + return; + } + + int first_n = n_tile_id * tile_n_size; + + int4* sh_ptr = sh_pipe_ptr + stage_size * pipe; + + if constexpr (has_perm) { + if (threadIdx.x < stage_size) { + auto k_id = threadIdx.x / stage_n_threads; + auto n_id = threadIdx.x % stage_n_threads; + + uint32_t const* sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); + + int src_k = sh_perm_int_ptr[k_id]; + int src_k_packed = src_k / pack_factor; + + cp_async4( + &sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast(&(b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)]))); + } + + } else { + if (threadIdx.x < stage_size) { + auto k_id = threadIdx.x / stage_n_threads; + auto n_id = threadIdx.x % stage_n_threads; + + int first_k = k_tile_id * tile_k_size; + int first_k_packed = first_k / pack_factor; + + cp_async4( + &sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast(&(b_q_weight_ptr[(first_k_packed + k_id) * size_n + first_n + (n_id * 4)]))); + } + } + + cp_async_fence(); + }; + + auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + return; + } 
+ + auto warp_id = threadIdx.x / 32; + auto th_id = threadIdx.x % 32; + + if (warp_id >= 4) { + return; + } + + int tc_col = th_id / 4; + int tc_row = (th_id % 4) * 2; + + constexpr int tc_offsets[4] = {0, 1, 8, 9}; + + int cur_n = warp_id * 16 + tc_col; + + constexpr int sh_stride = 64; + constexpr uint32_t mask = (1 << num_bits) - 1; + + int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; + uint32_t* sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); + + uint32_t* sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); + + uint32_t vals[8]; + + if constexpr (has_perm) { + for (int i = 0; i < 4; i++) { + int k_idx = tc_row + tc_offsets[i]; + + uint32_t src_k = sh_perm_int_ptr[k_idx]; + uint32_t src_k_pos = src_k % pack_factor; + + uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n]; + uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask; + + uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8]; + uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask; + + vals[i] = b1_cur_val; + vals[4 + i] = b2_cur_val; + } + + } else { + uint32_t b1_vals[tile_ints]; + uint32_t b2_vals[tile_ints]; + +#pragma unroll + for (int i = 0; i < tile_ints; i++) { + b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; + b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; + } + +#pragma unroll + for (int i = 0; i < 4; i++) { + int cur_elem = tc_row + tc_offsets[i]; + int cur_int = cur_elem / pack_factor; + int cur_pos = cur_elem % pack_factor; + + vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask; + vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask; + } + } + + constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; + int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; + + // Result of: + // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h + if constexpr (num_bits == 4) { + constexpr int 
pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + + uint32_t res = 0; +#pragma unroll + for (int i = 0; i < 8; i++) { + res |= vals[pack_idx[i]] << (i * 4); + } + + out_ptr[out_offset + th_id * 4 + warp_id] = res; + + } else { + constexpr int pack_idx[4] = {0, 2, 1, 3}; + + uint32_t res1 = 0; + uint32_t res2 = 0; +#pragma unroll + for (int i = 0; i < 4; i++) { + res1 |= vals[pack_idx[i]] << (i * 8); + res2 |= vals[4 + pack_idx[i]] << (i * 8); + } + + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; + } + }; + + auto start_pipes = [&](int k_tile_id, int n_tile_id) { +#pragma unroll + for (int pipe = 0; pipe < repack_stages - 1; pipe++) { + fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); + } + + wait_for_stage(); + }; +#pragma unroll + for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { + int n_tile_id = 0; + + if constexpr (has_perm) { + load_perm_to_shared(k_tile_id); + } + + start_pipes(k_tile_id, n_tile_id); + + while (n_tile_id < n_tiles) { +#pragma unroll + for (int pipe = 0; pipe < repack_stages; pipe++) { + fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1); + repack_tile(pipe, k_tile_id, n_tile_id + pipe); + wait_for_stage(); + } + n_tile_id += repack_stages; + } + } +} +#endif + +} // namespace device::marlin + +#define CALL_IF_REPACK(NUM_BITS, HAS_PERM) \ + else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ + host::RuntimeDeviceCheck(cudaFuncSetAttribute( \ + device::marlin::gptq_marlin_repack_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + max_shared_mem)); \ + host::LaunchKernel(blocks, device::marlin::repack_threads, stream, static_cast(max_shared_mem))( \ + device::marlin::gptq_marlin_repack_kernel, \ + b_q_weight_ptr, \ + perm_ptr, \ + out_ptr, \ + size_k, \ + size_n); \ + } + +void gptq_marlin_repack( + tvm::ffi::TensorView b_q_weight, + tvm::ffi::TensorView perm, + 
tvm::ffi::TensorView out, + int64_t size_k, + int64_t size_n, + int64_t num_bits) { + using namespace host; + + // Validate num_bits + RuntimeCheck(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); + int const pack_factor = 32 / static_cast(num_bits); + + // Validate size alignment + RuntimeCheck( + size_k % device::marlin::tile_k_size == 0, + "size_k = ", + size_k, + " is not divisible by tile_k_size = ", + device::marlin::tile_k_size); + RuntimeCheck( + size_n % device::marlin::tile_n_size == 0, + "size_n = ", + size_n, + " is not divisible by tile_n_size = ", + device::marlin::tile_n_size); + + // Validate b_q_weight + auto bqw_dim0 = SymbolicSize{"bqw_dim0"}; + auto bqw_dim1 = SymbolicSize{"bqw_dim1"}; + bqw_dim0.set_value(size_k / pack_factor); + bqw_dim1.set_value(size_n); + auto device_ = SymbolicDevice{}; + device_.set_options(); + TensorMatcher({bqw_dim0, bqw_dim1}).with_dtype().with_device(device_).verify(b_q_weight); + + // Validate out + auto out_dim0 = SymbolicSize{"out_dim0"}; + auto out_dim1 = SymbolicSize{"out_dim1"}; + out_dim0.set_value(size_k / device::marlin::tile_size); + out_dim1.set_value(size_n * device::marlin::tile_size / pack_factor); + TensorMatcher({out_dim0, out_dim1}).with_dtype().with_device(device_).verify(out); + + // Detect if there is act_order + bool has_perm = perm.size(0) != 0; + if (has_perm) { + // The repack kernel loads one permutation entry per K row of every tile, + // so a perm shorter than size_k would be read out of bounds. + RuntimeCheck(perm.size(0) == size_k, "perm.size(0) = ", perm.size(0), " must equal size_k = ", size_k); + } + + // Get ptrs + uint32_t const* b_q_weight_ptr = reinterpret_cast(b_q_weight.data_ptr()); + uint32_t const* perm_ptr = reinterpret_cast(perm.data_ptr()); + uint32_t* out_ptr = reinterpret_cast(out.data_ptr()); + + // Get dev info + DLDevice dl_device = device_.unwrap(); + int dev = dl_device.device_id; + cudaStream_t stream = LaunchKernel::resolve_device(dl_device); + int blocks; + RuntimeDeviceCheck(cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev)); + + int max_shared_mem = 0; + RuntimeDeviceCheck(cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev)); + 
RuntimeCheck(max_shared_mem > 0, "max_shared_mem must be > 0"); + + if (false) { + } + CALL_IF_REPACK(4, false) + CALL_IF_REPACK(4, true) + CALL_IF_REPACK(8, false) + CALL_IF_REPACK(8, true) + else { + Panic("Unsupported repack config: num_bits = ", num_bits, ", has_perm = ", has_perm); + } +} + +#undef CALL_IF_REPACK diff --git a/sglang/python/sglang/jit_kernel/csrc/gemm/marlin/kernel.h b/sglang/python/sglang/jit_kernel/csrc/gemm/marlin/kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..85af8c7a2a0f5baeafd0ed596e0d390efce50ab4 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/gemm/marlin/kernel.h @@ -0,0 +1,33 @@ + +#include + +#include "marlin.cuh" +#include "marlin_dtypes.cuh" + +#define MARLIN_KERNEL_PARAMS \ + const int4 *__restrict__ A, const int4 *__restrict__ B, int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ + const int4 *__restrict__ scales_ptr, const uint16_t *__restrict__ scale2_ptr, const int4 *__restrict__ zp_ptr, \ + const int *__restrict__ g_idx, int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \ + bool use_atomic_add, bool use_fp32_reduce, int max_shared_mem + +namespace device::marlin { +template < + typename scalar_t, // compute dtype, half or nv_bfloat16 + const host::ScalarTypeId w_type_id, // weight ScalarType id + const int threads, // number of threads in a threadblock + const int thread_m_blocks, // number of 16x16 blocks in the m + // dimension (batchsize) of the + // threadblock + const int thread_n_blocks, // same for n dimension (output) + const int thread_k_blocks, // same for k dimension (reduction) + const bool m_block_size_8, // whether m_block_size == 8 + // only works when thread_m_blocks == 1 + const int stages, // number of stages for the async global->shared + // fetch pipeline + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? 
+ > +__global__ void Marlin(MARLIN_KERNEL_PARAMS); + +} // namespace device::marlin diff --git a/sglang/python/sglang/jit_kernel/csrc/gemm/marlin/marlin.cuh b/sglang/python/sglang/jit_kernel/csrc/gemm/marlin/marlin.cuh new file mode 100644 index 0000000000000000000000000000000000000000..15b52d81cfba8387271062e942d2528b13b42342 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/gemm/marlin/marlin.cuh @@ -0,0 +1,85 @@ +#pragma once + +#include + +#include + +namespace device::marlin { +// Marlin params + +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. +static constexpr int default_threads = 256; + +static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory + +static constexpr int min_thread_n = 64; +static constexpr int min_thread_k = 64; +static constexpr int max_thread_n = 256; + +static constexpr int tile_size = 16; +static constexpr int max_par = 16; + +// Repack params +static constexpr int repack_stages = 8; + +static constexpr int repack_threads = 256; + +static constexpr int tile_k_size = tile_size; +static constexpr int tile_n_size = tile_k_size * 4; + +// Helpers +template +struct Vec { + T elems[n]; + __device__ T& operator[](int i) { + return elems[i]; + } +}; + +using I4 = Vec; + +using host::div_ceil; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +// No support for async +#else + +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), + "l"(glob_ptr), + "n"(BYTES)); +} + +__device__ inline void cp_async4(void* smem_ptr, const void* 
glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), + "n"(BYTES)); +} + +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +} + +#endif + +} // namespace device::marlin diff --git a/sglang/python/sglang/jit_kernel/csrc/gemm/marlin/marlin_dtypes.cuh b/sglang/python/sglang/jit_kernel/csrc/gemm/marlin/marlin_dtypes.cuh new file mode 100644 index 0000000000000000000000000000000000000000..20fa77bd046f678033c7d9cd8921c267f7dbeea0 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/gemm/marlin/marlin_dtypes.cuh @@ -0,0 +1,77 @@ +#ifndef _data_types_cuh +#define _data_types_cuh +#include + +#include "marlin.cuh" + +namespace device::marlin { + +template +class ScalarType {}; + +template <> +class ScalarType { + public: + using scalar_t = fp16_t; + using scalar_t2 = fp16x2_t; + + // Matrix fragments for tensor core instructions; their precise layout is + // documented here: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragS = Vec; + using FragZP = Vec; + + static __device__ float inline num2float(const fp16_t x) { + return __half2float(x); + } + + static __device__ fp16x2_t inline num2num2(const fp16_t x) { + return __half2half2(x); + } + + static __device__ fp16x2_t inline nums2num2(const fp16_t x1, const fp16_t x2) { + return __halves2half2(x1, x2); + } + + static __host__ __device__ fp16_t inline float2num(const float x) { + return __float2half(x); + } +}; + +template <> +class ScalarType { + public: + using scalar_t = bf16_t; + using scalar_t2 = bf16x2_t; + + using FragA = Vec; + using FragB = Vec; + using 
FragC = Vec; + using FragS = Vec; + using FragZP = Vec; + +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 + static __device__ float inline num2float(const bf16_t x) { + return __bfloat162float(x); + } + + static __device__ bf16x2_t inline num2num2(const bf16_t x) { + return __bfloat162bfloat162(x); + } + + static __device__ bf16x2_t inline nums2num2(const bf16_t x1, const bf16_t x2) { + return __halves2bfloat162(x1, x2); + } + + static __host__ __device__ bf16_t inline float2num(const float x) { + return __float2bfloat16(x); + } +#endif +}; + +} // namespace device::marlin + +#endif diff --git a/sglang/python/sglang/jit_kernel/csrc/gemm/marlin/marlin_template.h b/sglang/python/sglang/jit_kernel/csrc/gemm/marlin/marlin_template.h new file mode 100644 index 0000000000000000000000000000000000000000..6c4112e633fded7dbcdb1309503b63b0bc3ab711 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/gemm/marlin/marlin_template.h @@ -0,0 +1,1626 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ +#include + +#include "dequant.h" +#include "marlin.cuh" +#include "marlin_dtypes.cuh" + +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ + static_assert( \ + std::is_same::value || std::is_same::value, \ + "only float16 and bfloat16 is supported"); + +namespace device::marlin { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +template < + typename scalar_t, // compute dtype, half or nv_float16 + const host::ScalarTypeId w_type_id, // weight ScalarType id + const int threads, // number of threads in a threadblock + const int thread_m_blocks, // number of 16x16 blocks in the m + // dimension (batchsize) of the + // threadblock + const int thread_n_blocks, // same for n dimension (output) + const int thread_k_blocks, // same for k dimension (reduction) + const bool m_block_size_8, // whether m_block_size == 8 + // only works when thread_m_blocks == 1 + const int stages, // number of stages for the async global->shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? 
+ > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks, // extra global storage for barrier synchronization + bool use_fp32_reduce // whether to use fp32 global reduce +) {} + +} // namespace device::marlin + +#else + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. +template +__device__ inline void +mma(const typename ScalarType::FragA& a_frag, + const typename ScalarType::FragB& frag_b, + typename ScalarType::FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } +} + +template +__device__ inline void 
mma_trans( + const typename ScalarType::FragA& a_frag, + const typename ScalarType::FragB& frag_b, + const typename ScalarType::FragB& frag_b2, + typename ScalarType::FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + const uint32_t* b2 = reinterpret_cast(&frag_b2); + float* c = reinterpret_cast(&frag_c); + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), + "r"(b2[0]), + "r"(b[1]), + "r"(b2[1]), + "r"(a[0]), + "r"(a[1]), + "f"(c[0]), + "f"(c[1]), + "f"(c[2]), + "f"(c[3])); + } else if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), + "r"(b2[0]), + "r"(b[1]), + "r"(b2[1]), + "r"(a[0]), + "r"(a[1]), + "f"(c[0]), + "f"(c[1]), + "f"(c[2]), + "f"(c[3])); + } else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. 
+template +__device__ inline void ldsm(typename ScalarType::FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + if constexpr (count == 4) { + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); + } else if constexpr (count == 2) { + asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" : "=r"(a[0]), "=r"(a[1]) : "r"(smem)); + } else if constexpr (count == 1) { + asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n" : "=r"(a[0]) : "r"(smem)); + } else { + static_assert(count == 1 || count == 2 || count == 4, "invalid count"); + } +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. +template +__device__ inline void +scale(typename ScalarType::FragB& frag_b, typename ScalarType::FragS& frag_s, int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s = ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +template +__device__ inline void scale_and_sub(typename ScalarType::FragB& frag_b, scalar_t s, scalar_t zp) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s2 = ScalarType::num2num2(s); + scalar_t2 zp2 = ScalarType::num2num2(zp); + frag_b[0] = __hfma2(frag_b[0], s2, __hneg2(zp2)); + frag_b[1] = __hfma2(frag_b[1], s2, __hneg2(zp2)); +} + +template +__device__ inline void +sub_zp(typename ScalarType::FragB& frag_b, typename ScalarType::scalar_t2& frag_zp, int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 zp = ScalarType::num2num2(reinterpret_cast(&frag_zp)[i]); + frag_b[0] = __hsub2(frag_b[0], zp); + frag_b[1] = __hsub2(frag_b[1], zp); +} + +// Same as above, but for act_order (each K is multiplied individually) +template +__device__ inline void 
scale4( + typename ScalarType::FragB& frag_b, + typename ScalarType::FragS& frag_s_1, + typename ScalarType::FragS& frag_s_2, + typename ScalarType::FragS& frag_s_3, + typename ScalarType::FragS& frag_s_4, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s_val_1_2; + s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; + s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; + + scalar_t2 s_val_3_4; + s_val_3_4.x = reinterpret_cast(&frag_s_3)[i]; + s_val_3_4.y = reinterpret_cast(&frag_s_4)[i]; + + frag_b[0] = __hmul2(frag_b[0], s_val_1_2); + frag_b[1] = __hmul2(frag_b[1], s_val_3_4); +} + +// Given 2 floats multiply by 2 scales (halves) +template +__device__ inline void scale_float(float* c, typename ScalarType::FragS& s) { + scalar_t* s_ptr = reinterpret_cast(&s); + c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. 
+ asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(lock), "r"(val)); + } +} + +// Wait until value of lock to be negative, and then add 1 +__device__ inline void wait_negative_and_add(int* lock) { + if (threadIdx.x == 0) { + int state = 0; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); + while (state >= 0); + atomicAdd(lock, 1); + } + __syncthreads(); +} + +template < + typename scalar_t, // compute dtype, half or nv_float16 + const host::ScalarTypeId w_type_id, // weight ScalarType id + const int threads, // number of threads in a threadblock + const int thread_m_blocks, // number of 16x16 blocks in the m + // dimension (batchsize) of the + // threadblock + const int thread_n_blocks, // same for n dimension (output) + const int thread_k_blocks, // same for k dimension (reduction) + const bool m_block_size_8, // whether m_block_size == 8 + // only works when thread_m_blocks == 1 + const int stages, // number of stages for the async global->shared + // fetch pipeline + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? 
+ > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4 + // only) + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int lda, // A.stride(0), equal to prob_k is A is contiguous + int* locks, // extra global storage for barrier synchronization + bool use_atomic_add, // whether to use atomic add to reduce + bool use_fp32_reduce, // whether to use fp32 global reduce + int max_shared_mem) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. 
+ using Dtype = ScalarType; + using scalar_t2 = typename ScalarType::scalar_t2; + using FragA = typename ScalarType::FragA; + using FragB = typename ScalarType::FragB; + using FragC = typename ScalarType::FragC; + using FragS = typename ScalarType::FragS; + using FragZP = typename ScalarType::FragZP; + + static constexpr auto w_type = host::ScalarType::from_id(w_type_id); + constexpr bool has_zp = w_type == host::kU4 || w_type == host::kU8; + constexpr bool is_int_type = + w_type == host::kU4 || w_type == host::kU8 || w_type == host::kU4B8 || w_type == host::kU8B128; + // see comments of dequant.h for more details + constexpr bool dequant_skip_flop = !is_int_type || + has_zp && !is_zp_float && !std::is_same::value || + has_zp && !is_zp_float && !(w_type == host::kU8); + + scalar_t2 global_scale; + + if constexpr (w_type == host::kFE2M1f) { + uint16_t val = scale2_ptr[0]; + global_scale = Dtype::num2num2(*reinterpret_cast(&val)); + } + + constexpr bool has_act_order = group_blocks == 0; + constexpr int m_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); + + constexpr int pack_factor = 32 / w_type.size_bits(); + static_assert(thread_m_blocks == 1 || !m_block_size_8); + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > m_block_size) { + parallel = prob_m / m_block_size; + prob_m = m_block_size; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); + + if constexpr (!has_act_order && group_blocks != -1) { + if (group_blocks >= thread_k_blocks) { + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts + // in the middle of group. 
+ iters = (group_blocks / thread_k_blocks) * div_ceil(iters, (group_blocks / thread_k_blocks)); + } + } + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + int par_id = 0; + int locks_off = 0; + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * lda / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + slice_col = slice_col_par % n_tiles; + par_id = slice_col_par / n_tiles; + } + if (parallel * n_tiles >= gridDim.x) { + // when parallel * n_tiles >= sms + // then there are at most $sms$ conflict tile blocks + locks_off = blockIdx.x; + } else { + locks_off = (iters * blockIdx.x) / k_tiles - 1; + } + + // Compute all information about the current slice which is required for + // synchronization. 
+ auto init_slice = [&](bool first_init = false) { + slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = div_ceil(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (parallel * n_tiles >= gridDim.x) { + if (slice_count > 1 && slice_idx == slice_count - 1) { + locks_off++; + } + } else { + locks_off++; + } + + if (first_init && use_atomic_add && slice_count > 1 && slice_idx == 0) { + constexpr int threads_per_m = 16 * thread_n_blocks / 8; + int m_per_thread = div_ceil(thread_m_blocks * 16, threads / threads_per_m); + if (m_block_size_8) m_per_thread = div_ceil(8, threads / threads_per_m); + for (int i = 0; i < m_per_thread; i++) { + int row = threads / threads_per_m * i + threadIdx.x / threads_per_m; + if (row < prob_m) { + int col = slice_col * 16 * thread_n_blocks / 8 + threadIdx.x % threads_per_m; + C[row * prob_n / 8 + col] = {0, 0, 0, 0}; + } + } + // After write zero to output, write a negative value to lock. + // Every SM that processes the same slice would wait for + // the negative value, and then atomicAdd 1 to it. + // After all SMs are processed, the lock value would back to 0 again. 
+ __syncthreads(); + if (threadIdx.x == 0) locks[locks_off] = 1 - slice_count; + } + + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * lda / 8; + C += 16 * thread_m_blocks * prob_n / 8; + slice_col = 0; + par_id++; + } + }; + init_slice(true); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = lda / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * m_block_size; + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / (pack_factor * 4); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; + constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 
1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_rd_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks + ? thread_k_blocks / group_blocks / (w_type == host::kFE2M1f ? 2 : 1) + : 1; + constexpr int s_sh_stage = s_tb_groups * s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + constexpr int act_s_max_num_groups = 32; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + // Zero-points sizes/strides + int zp_gl_stride = is_zp_float ? prob_n / 8 : (prob_n / pack_factor) / 4; + constexpr int zp_sh_stride = is_zp_float ? 16 * thread_n_blocks / 8 : ((16 * thread_n_blocks) / pack_factor) / 4; + constexpr int zp_tb_groups = s_tb_groups; + constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; + int zp_gl_rd_delta = zp_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. 
+ int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % (16 / (m_block_size_8 ? 2 : 1))) + + (threadIdx.x % 32) / (16 / (m_block_size_8 ? 2 : 1)); + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + auto b_sh_wr = threadIdx.x * b_thread_vecs; + auto b_sh_rd = threadIdx.x * b_thread_vecs; + + // For act_order + constexpr int k_iter_size = tb_k / b_sh_wr_iters; + int slice_k_start = tb_k * slice_row; + int slice_k_finish = slice_k_start + tb_k * slice_iters; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd; + if constexpr (!has_act_order) { + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / (w_type == host::kFE2M1f ? 2 : 1) + + s_sh_stride * slice_col + threadIdx.x; + } + } + auto s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // Zero-points + int zp_gl_rd; + if constexpr (has_zp) { + if constexpr (group_blocks == -1) { + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } else { + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + zp_sh_stride * slice_col + threadIdx.x; + } + } + auto zp_sh_wr = threadIdx.x; + bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; + + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. 
+ int s_sh_rd; + if constexpr (group_blocks != -1 && w_type == host::kFE2M1f) { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + int warp_row = warp_id / n_warps; + + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + s_sh_rd = s_sh_rd * 2 + warp_row % 2; + + } else if constexpr (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + else if constexpr (group_blocks == -1 && (m_block_size_8 || (has_zp && !dequant_skip_flop))) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; + + // Zero-points have the same read layout as the scales + // (without column-wise case) + constexpr int num_col_threads = 8; + constexpr int num_row_threads = 4; + constexpr int num_ints_per_thread = 8 / pack_factor; + int zp_sh_rd; + if constexpr (has_zp) { + if constexpr (is_zp_float) { + if constexpr (group_blocks != -1) { + zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + } + } else { + zp_sh_rd = num_ints_per_thread * num_col_threads * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); + } + } + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. 
The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ (row % 8); + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { +#pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + constexpr int sh_red_size = (2 * thread_n_blocks + 1) * 16 * thread_m_blocks; + constexpr int sh_b_size = stages * b_sh_stage; + int4* sh_b = sh; + int4* sh_red = sh; + int4* sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); + int4* sh_zp = sh_g_idx + (stages * g_idx_stage); + constexpr int sh_s_size = has_act_order ? 
(act_s_max_num_groups * s_sh_stride) : (stages * s_sh_stage); + int4* sh_s = sh_zp + (stages * zp_sh_stage); + // shared memory reused by reduction should be smaller than + // shared memory used by weight. + static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <= stages * b_sh_stage); + int4* sh_a = sh_s + sh_s_size; + // constexpr int shm_size_used = + // stages * (g_idx_stage + zp_sh_stage) + sh_s_size + + // (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2][b_thread_vecs]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + int frag_qzp[2][num_ints_per_thread]; // Zero-points + FragZP frag_zp; // Zero-points in fp16 + FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ + + // Zero accumulators. + auto zero_accums = [&]() { +#pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + int sh_first_group_id = -1; + int sh_num_groups = -1; + + auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id, int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups > act_s_max_num_groups) { + sh_num_groups = act_s_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred( + &sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = + scales_ptr[row_offset + (i * s_gl_stride) + 
slice_n_offset + threadIdx.x]; + } + } + } + }; + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i]); + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { +#pragma unroll + for (int j = 0; j < b_thread_vecs; j++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); + } + + B_ptr[i] += b_gl_rd_delta_o; + } + + if constexpr (has_act_order) { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) { + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + + int4 const* cur_g_idx_stage_ptr = reinterpret_cast(&g_idx[cur_k]); + + if (threadIdx.x < g_idx_stage) { + cp_async4_pred(&sh_g_idx_stage[threadIdx.x], &cur_g_idx_stage_ptr[threadIdx.x]); + } + } + } else { + if constexpr (group_blocks != -1) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch scales if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } else { + for (int i = 0; i < s_tb_groups; i++) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } + } + + if constexpr (has_zp && group_blocks != -1) { + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch zero-points if this tile 
starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } else { + for (int i = 0; i < zp_tb_groups; i++) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + auto fetch_col_zp_to_shared = [&]() { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + }; + + auto fetch_col_scale_to_shared = [&]() { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. 
+ auto fetch_to_registers = [&](int k, int pipe) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + +#pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast(&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; + + auto fetch_scales_to_registers = [&](int k, int full_pipe) { + int pipe = full_pipe % stages; + + if constexpr (!has_act_order) { + // No act-order case + if constexpr (group_blocks == -1) { + // load only when starting a new slice + if (k == 0 && full_pipe == 0) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } else if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + if (k % b_sh_wr_iters == 0) { + int4* sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + reinterpret_cast(&frag_s[1])[0] = reinterpret_cast(&frag_s[0])[0]; + } + } else { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / (group_blocks * (w_type == host::kFE2M1f ? 
2 : 1)); + + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (w_type_id != host::kFE2M1f.id()) { + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } else { + reinterpret_cast(&frag_s[k % 2])[0] = + reinterpret_cast(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; + } + } + } + + return; + } + + // Act-order case + + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) { + return; + } + + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; + + // Progress to current iteration + cur_k += k_iter_size * (k % b_sh_wr_iters); + + // Determine "position" inside the thread-block (based on warp and + // thread-id) + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + + int warp_row = warp_id / n_warps; + int warp_col = warp_id % n_warps; + + cur_k += warp_row * 16; + + auto th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + (th_id / 4) * act_s_col_stride; + + if (is_same_group[pipe]) { + if (k % 2 == 0) { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + s_col_shift]; + } else { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + } + + for (int i = 1; i < 4; i++) { + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + } + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + constexpr int k_frag_offsets[4] = {0, 1, 8, 9}; // Tensor core offsets per thread + +#pragma unroll + for (int i = 0; i < 4; i++) { + int actual_k = cur_k 
+ k_frag_offsets[i]; + + int group_id = sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; + + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } + }; + + auto fetch_zp_to_registers = [&](int k, int full_pipe) { + // This code does not handle group_blocks == 0, + // which signifies act_order. + // has_zp implies AWQ, which doesn't have act_order, + static_assert(!has_zp || group_blocks != 0); + + if constexpr (has_zp && !is_zp_float) { + int pipe = full_pipe % stages; + + if constexpr (group_blocks == -1) { + // load only when starting a new slice + if (k == 0 && full_pipe == 0) { +#pragma unroll + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; + } + } + + } else if constexpr (group_blocks >= thread_k_blocks) { + if (k % b_sh_wr_iters == 0) { + int4* sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); +#pragma unroll + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + } else { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = 0; + + // Suppress bogus and persistent divide-by-zero warning +#pragma nv_diagnostic push +#pragma nv_diag_suppress divide_by_zero + cur_group_id = k_blocks / group_blocks; +#pragma nv_diagnostic pop + + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + sh_zp_stage += cur_group_id * zp_sh_stride; + +#pragma unroll + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + } + + else if constexpr (has_zp && is_zp_float) { + int pipe = full_pipe % stages; + + if constexpr (group_blocks != -1) { + if 
constexpr (group_blocks >= thread_k_blocks) { + if (k % b_sh_wr_iters == 0) { + int4* sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd]; + } + } else { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + // Suppress bogus and persistent divide-by-zero warning +#pragma nv_diagnostic push +#pragma nv_diag_suppress divide_by_zero + int cur_group_id = k_blocks / group_blocks; +#pragma nv_diagnostic pop + + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd + cur_group_id * zp_sh_stride]; + } + } + } + }; + + auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) { + dequant(q, frag_b_ptr); + }; + + // Execute the actual tensor core matmul of a sub-tile. 
+ bool is_first_matmul_in_slice = true; + auto matmul = [&](int k) { + int k2 = k % 2; + const bool is_new_zp = ((group_blocks != -1) && (group_blocks < thread_k_blocks || k == 0)) || + (group_blocks == -1 && is_first_matmul_in_slice); + if constexpr (has_zp && !is_zp_float) { + if (is_new_zp) { + if constexpr (group_blocks == -1) is_first_matmul_in_slice = false; + FragB frag_zp_0; + FragB frag_zp_1; + int zp_quant_0, zp_quant_1; + + if constexpr (w_type.size_bits() == 4) { + zp_quant_0 = frag_qzp[k2][0]; + zp_quant_1 = zp_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + zp_quant_0 = frag_qzp[k2][0]; + zp_quant_1 = frag_qzp[k2][1]; + } + + dequant_data(zp_quant_0, reinterpret_cast(&frag_zp)); + dequant_data(zp_quant_1, reinterpret_cast(&frag_zp) + 2); + } + } + if constexpr (!dequant_skip_flop && has_zp && is_zp_float) { + if (is_new_zp) { + reinterpret_cast(&frag_zp)[0] = reinterpret_cast(&frag_zpf[k2])[0]; + } + } + + if constexpr (w_type == host::kFE2M1f) { + int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; + int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; + + dequant_fp8_scales(s_quant_0, reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales(s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); + } + +// We have the m dimension as the inner loop in order to encourage overlapping +// dequantization and matmul operations. 
+#pragma unroll + for (int j = 0; j < 4; j++) { + FragB frag_b0; + FragB frag_b1; + int b_quant_0, b_quant_1; + + if constexpr (w_type_id == host::kFE2M1f.id()) { + b_quant_1 = frag_b_quant[k2][0][j]; + b_quant_0 = b_quant_1 << 8; + } else if constexpr (w_type.size_bits() == 4) { + b_quant_0 = frag_b_quant[k2][0][j]; + b_quant_1 = b_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k2]); + b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + } + + dequant_data(b_quant_0, reinterpret_cast(&frag_b0)); + dequant_data(b_quant_1, reinterpret_cast(&frag_b1)); + + if constexpr (dequant_skip_flop && has_zp && !is_zp_float) { + sub_zp(frag_b0, frag_zp[j], 0); + sub_zp(frag_b1, frag_zp[j], 1); + } + + // Apply scale to frag_b0 + if constexpr (has_act_order) { + static_assert(group_blocks != -1); + scale4( + frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); + scale4( + frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); + } else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float && group_blocks == -1) { + int idx = (threadIdx.x / 4) % 2; + scalar_t2 s2 = Dtype::nums2num2( + reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 0])[idx], + reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 1])[idx]); + if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2); + scale_and_sub(frag_b0, s2.x, frag_zp[j].x); + scale_and_sub(frag_b1, s2.y, frag_zp[j].y); + } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) { + if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], *reinterpret_cast(&frag_s[k2][j])); + scale_and_sub(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x); + scale_and_sub(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y); + } else if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k2][j], 0); + scale(frag_b1, frag_s[k2][j], 1); + } + +#pragma unroll + for (int i = 0; i < 
thread_m_blocks; i++) { + if constexpr (m_block_size_8) { + mma_trans(frag_a[k2][i], frag_b0, frag_b1, frag_c[i][j][0]); + } else { + mma(frag_a[k2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k2][i], frag_b1, frag_c[i][j][1]); + } + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) { + auto red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + +#pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { +#pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { +#pragma unroll + for (int j = 0; j < 4 * 2; j += (m_block_size_8 ? 2 : 1)) { + int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = reinterpret_cast(&sh_red[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh_red[red_sh_wr]); +#pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; + } + sh_red[red_sh_wr] = reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { +#pragma unroll + for (int i = 0; i < 4 * 2; i += (m_block_size_8 ? 
2 : 1)) { + float* c_rd = reinterpret_cast(&sh_red[red_sh_delta * i + red_sh_rd]); +#pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce_fp16 = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr; + if constexpr (m_block_size_8) { + c_gl_wr = c_gl_stride * ((threadIdx.x % 4) * 2) + 4 * (threadIdx.x / 32) + (threadIdx.x % 32) / 8; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + } else { + c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + } + constexpr int c_sh_wr_delta = active_threads; + auto c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { +// Interestingly, doing direct global accesses here really seems to mess up +// the compiler and lead to slowdowns, hence we also use async-copies even +// though these fetches are not actually asynchronous. +#pragma unroll + for (int i = 0; i < (m_block_size_8 ? 
2 : thread_m_blocks * 4); i++) { + if constexpr (m_block_size_8) { + cp_async4_pred( + &sh_red[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + i * c_gl_stride + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i], + (threadIdx.x % 4) * 2 + i < prob_m); + } else { + cp_async4_pred( + &sh_red[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); + } + } + cp_async_fence(); + cp_async_wait<0>(); + } + +#pragma unroll + for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) { + bool mask = (!m_block_size_8) && (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) || + (m_block_size_8) && ((threadIdx.x % 4) * 2 + i < prob_m); + if (mask) { + if (!first) { + int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta]; +#pragma unroll + for (int j = 0; j < 2 * 4; j++) { + int delta = 0; + if constexpr (m_block_size_8) { + delta = j % 2 == 1 ? -2 : 0; + } + reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta] += + Dtype::num2float(reinterpret_cast(&c_red)[j]); + } + } + if (!last) { + int4 c; +#pragma unroll + for (int j = 0; j < 2 * 4; j++) { + int delta = 0; + if constexpr (m_block_size_8) { + delta = j % 2 == 1 ? -2 : 0; + } + reinterpret_cast(&c)[j] = + Dtype::float2num(reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta]); + } + if constexpr (m_block_size_8) + C[c_gl_wr + i * c_gl_stride + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i] = c; + else + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = c; + } + } + } + } + }; + + // Globally reduce over threadblocks that compute the same column block. + // We use a tmp C buffer to reduce in full fp32 precision. 
+ auto global_reduce_fp32 = [&](bool first = false, bool last = false) { + constexpr int tb_m = thread_m_blocks * 16; + constexpr int tb_n = thread_n_blocks * 16; + + constexpr int c_size = tb_m * tb_n * sizeof(float) / 16; + + constexpr int active_threads = 32 * thread_n_blocks / 4; + bool is_th_active = threadIdx.x < active_threads; + + constexpr int num_floats = thread_m_blocks * 4 * 2 * 4; + constexpr int th_size = num_floats * sizeof(float) / 16; + + int c_cur_offset = locks_off * c_size; + + if (!is_th_active) { + return; + } + + if (!first) { + float* frag_c_ptr = reinterpret_cast(&frag_c); +#pragma unroll + for (int k = 0; k < th_size; k += (m_block_size_8 ? 2 : 1)) { + sh_red[threadIdx.x] = C_tmp[c_cur_offset + active_threads * k + threadIdx.x]; + + float* sh_c_ptr = reinterpret_cast(&sh_red[threadIdx.x]); +#pragma unroll + for (int f = 0; f < 4; f++) { + frag_c_ptr[k * 4 + f] += sh_c_ptr[f]; + } + } + } + + if (!last) { + int4* frag_c_ptr = reinterpret_cast(&frag_c); +#pragma unroll + for (int k = 0; k < th_size; k += (m_block_size_8 ? 2 : 1)) { + C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k]; + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. 
+ auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr; + if constexpr (m_block_size_8) { + c_sh_wr = (8 * c_sh_stride) * ((threadIdx.x % 32) % 4 * 2) + (threadIdx.x % 32) / 4; + c_sh_wr += 64 * (threadIdx.x / 32); + } else { + c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + } + + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS& s) { + scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); + + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr ( + !has_act_order && group_blocks == -1 && w_type.size_bits() == 4 && (has_zp && dequant_skip_flop || !has_zp)) { + res = __hmul2(res, s[0]); + } + + if constexpr (w_type == host::kFE2M1f) { + res = __hmul2(res, global_scale); + } + + if constexpr (m_block_size_8) { + ((scalar_t*)sh_red)[idx] = res.x; + ((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y; + } else { + ((scalar_t2*)sh_red)[idx] = res; + } + }; + + if (threadIdx.x / 32 < thread_n_blocks / 4) { +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { +#pragma unroll + for (int j = 0; j < 4; j++) { + if constexpr (m_block_size_8) { + int wr = c_sh_wr + 16 * j; + write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + 8, frag_c[i][j][0][2], 
frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 1]); + } else { + int wr = c_sh_wr + 8 * j; + write( + wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write( + wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write( + wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write( + wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + +#pragma unroll + for (int i = 0; i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) { + if (c_gl_wr < c_gl_wr_end) { + if (use_atomic_add && slice_count > 1) { + scalar_t2* C_half2 = reinterpret_cast(&C[c_gl_wr]); + scalar_t2* sh_red_half2 = reinterpret_cast(&sh_red[c_sh_rd]); +#pragma unroll + for (int a = 0; a < 4; a++) { + atomicAdd(&C_half2[a], sh_red_half2[a]); + } + } else { + C[c_gl_wr] = sh_red[c_sh_rd]; + } + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + __syncthreads(); + }; + + // Start global fetch and register load pipelines. 
+ auto start_pipes = [&]() { + +#pragma unroll + for (int i = 0; i < stages - 1; i++) { + if (has_act_order && i == 0) { + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + fetch_act_order_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); + } + + if constexpr (has_zp && !is_zp_float && group_blocks == -1) { + if (i == 0) { + fetch_col_zp_to_shared(); + if constexpr (!dequant_skip_flop) { + fetch_col_scale_to_shared(); + } + } + } + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + fetch_scales_to_registers(0, 0); + fetch_zp_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + if constexpr (has_act_order) { + slice_k_start_shared_fetch += tb_k * (stages - 1); + } + }; + if (slice_iters) { + start_pipes(); + } + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. 
+ +#pragma unroll + for (int pipe = 0; pipe < stages;) { +#pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + fetch_scales_to_registers(k + 1, pipe); + fetch_zp_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; + } + } + + a_gl_rd += a_gl_rd_delta_o * stages; + + if constexpr (has_act_order) { + slice_k_start += tb_k * stages; + + if (slice_k_start < prob_k) { + slice_k_start_shared_fetch += tb_k * stages; + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_act_order_scales_to_shared(false, first_group_id, last_group_id); + __syncthreads(); + } + } + } + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. 
+ if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) { + if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } + } + + thread_block_reduce(); + if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) { + if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + if constexpr (m_block_size_8) { + int idx = (threadIdx.x / 4) % 2; + scalar_t2* frag_s_half2 = reinterpret_cast(frag_s); +#pragma unroll + for (int i = 0; i < 8; i++) { + frag_s_half2[i] = Dtype::num2num2(reinterpret_cast(&frag_s_half2[i])[idx]); + } + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr ( + !has_act_order && group_blocks == -1 && w_type.size_bits() == 8 && (has_zp && dequant_skip_flop || !has_zp)) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { +#pragma unroll + for (int j = 0; j < 4; j++) { + scale_float(reinterpret_cast(&frag_c[i][j][0][0]), frag_s[j / 2][2 * (j % 2) + 0]); + scale_float( + reinterpret_cast(&frag_c[i][j][0][2]), frag_s[j / 2][2 * (j % 2) + (m_block_size_8 ? 
1 : 0)]); + + if constexpr (!m_block_size_8) { + scale_float(reinterpret_cast(&frag_c[i][j][1][0]), frag_s[j / 2][2 * (j % 2) + 1]); + scale_float(reinterpret_cast(&frag_c[i][j][1][2]), frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } + } + } + + if (slice_count > 1 && !use_atomic_add) { + // only globally reduce if there is more than one block in a slice + barrier_acquire(&locks[locks_off], slice_idx); + if (use_fp32_reduce) { + global_reduce_fp32(slice_idx == 0, last); + } else { + global_reduce_fp16(slice_idx == 0, last); + } + barrier_release(&locks[locks_off], last); + } + if (use_atomic_add && slice_count > 1 && slice_idx != 0) wait_negative_and_add(&locks[locks_off]); + if (last || use_atomic_add) + // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + is_first_matmul_in_slice = true; + init_slice(); + + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] -= b_gl_stride; + } + + // Update slice k/n for scales loading + if constexpr (has_act_order) { + slice_k_start = tb_k * slice_row; + slice_k_finish = slice_k_start + tb_k * slice_iters; + slice_k_start_shared_fetch = slice_k_start; + slice_n_offset = act_s_col_tb_stride * slice_col; + + } else { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } + + start_pipes(); + } + } + } +} + +} // namespace device::marlin + +#endif diff --git a/sglang/python/sglang/jit_kernel/csrc/gemm/marlin_moe/kernel.h b/sglang/python/sglang/jit_kernel/csrc/gemm/marlin_moe/kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..522a77d40fd15e413b325a03c9bad6a30ed41fa9 --- /dev/null +++ 
b/sglang/python/sglang/jit_kernel/csrc/gemm/marlin_moe/kernel.h @@ -0,0 +1,37 @@ + +#include + +#include "../marlin/marlin.cuh" +#include "../marlin/marlin_dtypes.cuh" + +#define MARLIN_KERNEL_PARAMS \ + const int4 *__restrict__ A, const int4 *__restrict__ B, int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ + const int4 *__restrict__ b_bias_ptr, const int4 *__restrict__ scales_ptr, \ + const uint16_t *__restrict__ scale2_ptr, const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ + const int32_t *__restrict__ sorted_token_ids_ptr, const int32_t *__restrict__ expert_ids_ptr, \ + const int32_t *__restrict__ num_tokens_past_padded_ptr, const float *__restrict__ topk_weights_ptr, int top_k, \ + bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, int prob_n, int prob_k, int *locks, \ + bool has_bias, bool use_atomic_add, bool use_fp32_reduce, int max_shared_mem + +namespace device::marlin_moe { +template < + typename scalar_t, // compute dtype, half or nv_float16 + const host::ScalarTypeId w_type_id, // weight ScalarType id + const host::ScalarTypeId s_type_id, // weight scale ScalarType id + const int threads, // number of threads in a threadblock + const int thread_m_blocks, // number of 16x16 blocks in the m + // dimension (batchsize) of the + // threadblock + const int thread_n_blocks, // same for n dimension (output) + const int thread_k_blocks, // same for k dimension (reduction) + const bool m_block_size_8, // whether m_block_size == 8 + // only works when thread_m_blocks == 1 + const int stages, // number of stages for the async global->shared + // fetch pipeline + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? 
+ > +__global__ void Marlin(MARLIN_KERNEL_PARAMS); + +} // namespace device::marlin_moe diff --git a/sglang/python/sglang/jit_kernel/csrc/gemm/marlin_moe/marlin_template.h b/sglang/python/sglang/jit_kernel/csrc/gemm/marlin_moe/marlin_template.h new file mode 100644 index 0000000000000000000000000000000000000000..bf7dcb2023014a223f4acce33b9eeab206fd6854 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/gemm/marlin_moe/marlin_template.h @@ -0,0 +1,1896 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ + +#include + +#include "../marlin/dequant.h" +#include "../marlin/marlin.cuh" +#include "../marlin/marlin_dtypes.cuh" + +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ + static_assert( \ + std::is_same::value || std::is_same::value, \ + "only float16 and bfloat16 is supported"); + +namespace device::marlin_moe { +using namespace device::marlin; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +template < + typename scalar_t, // compute dtype, half or nv_float16 + const host::ScalarTypeId w_type_id, // weight ScalarType id + const int threads, // number of threads in a threadblock + const int thread_m_blocks, // number of 16x16 blocks in the m + // dimension (batchsize) of the + // threadblock + const int thread_n_blocks, // same for n dimension (output) + const int thread_k_blocks, // same for k dimension (reduction) + const bool m_block_size_8, // whether m_block_size == 8 + // only works when thread_m_blocks == 1 + const int stages, // number of stages for the async global->shared + // fetch pipeline + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? 
+ > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k + const int32_t* __restrict__ sorted_token_ids_ptr, // moe sorted_ids + const int32_t* __restrict__ expert_ids_ptr, // moe expert ids + const int32_t* __restrict__ num_tokens_past_padded_ptr, // moe num tokens + const float* __restrict__ topk_weights_ptr, // moe top weights + int top_k, // num of experts per token + bool mul_topk_weights, // mul topk weights or not + bool is_ep, // expert parallelism + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks, // extra global storage for barrier synchronization + bool use_atomic_add, // whether to use atomic add to reduce + bool use_fp32_reduce, // whether to use fp32 global reduce + int max_shared_mem) {} + +} // namespace device::marlin_moe + +#else + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. 
+// Both supported compute types use the same operand packing: the A fragment is
+// passed as four 32-bit registers, the B fragment as two, and the four fp32
+// accumulators in `frag_c` are both read and overwritten (C = A * B + C).
+//
+//   scalar_t  compute dtype; `half` selects the f16 instruction, `nv_bfloat16`
+//             the bf16 one. Any other type fails to compile.
+//   a_frag    A operand fragment (read as 4 x uint32_t).
+//   frag_b    B operand fragment (read as 2 x uint32_t).
+//   frag_c    accumulator fragment (read/written as 4 x float).
+template <typename scalar_t>
+__device__ inline void
+mma(const typename ScalarType<scalar_t>::FragA& a_frag,
+    const typename ScalarType<scalar_t>::FragB& frag_b,
+    typename ScalarType<scalar_t>::FragC& frag_c) {
+  const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
+  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
+  float* c = reinterpret_cast<float*>(&frag_c);
+  if constexpr (std::is_same<scalar_t, half>::value) {
+    asm volatile(
+        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
+        "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
+        : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
+        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
+  } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
+    asm volatile(
+        "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
+        "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
+        : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
+        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
+  } else {
+    STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
+  }
+}
+
+// Operand-swapped variant of mma(): the instruction's four-register first
+// operand is assembled from the two B fragments, interleaved as
+// (b[0], b2[0], b[1], b2[1]), and its two-register second operand is the first
+// two registers of `a_frag`. The four fp32 accumulators in `frag_c` are read
+// and overwritten, exactly as in mma().
+// NOTE(review): presumably used when the m block is only 8 rows tall, so that
+// the weights fill the 16-wide side of the instruction -- confirm at call sites.
+template <typename scalar_t>
+__device__ inline void mma_trans(
+    const typename ScalarType<scalar_t>::FragA& a_frag,
+    const typename ScalarType<scalar_t>::FragB& frag_b,
+    const typename ScalarType<scalar_t>::FragB& frag_b2,
+    typename ScalarType<scalar_t>::FragC& frag_c) {
+  const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
+  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
+  const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2);
+  float* c = reinterpret_cast<float*>(&frag_c);
+  if constexpr (std::is_same<scalar_t, half>::value) {
+    asm volatile(
+        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
+        "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
+        : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
+        : "r"(b[0]),
+          "r"(b2[0]),
+          "r"(b[1]),
+          "r"(b2[1]),
+          "r"(a[0]),
+          "r"(a[1]),
+          "f"(c[0]),
+          "f"(c[1]),
+          "f"(c[2]),
+          "f"(c[3]));
+  } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
+    asm volatile(
+        "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
+        "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
+        : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
+        : "r"(b[0]),
+          "r"(b2[0]),
+          "r"(b[1]),
+          "r"(b2[1]),
+          "r"(a[0]),
+          "r"(a[1]),
+          "f"(c[0]),
+          "f"(c[1]),
+          "f"(c[2]),
+          "f"(c[3]));
+  } else {
+    STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
+  }
+}
+
+// Instruction for loading a matrix fragment of operand A from shared memory,
+// directly in tensor core layout. `count` selects the .x1/.x2/.x4 form of
+// ldmatrix, i.e. how many 8x8 b16 matrices (and 32-bit registers of `frag_a`)
+// are filled; count == 4 loads a full 16x16 fragment.
+//
+//   frag_a    destination fragment; its first `count` 32-bit words are written.
+//   smem_ptr  generic pointer into shared memory; converted to a 32-bit shared
+//             address before being handed to the instruction.
+// NOTE(review): the template parameter list was reconstructed with `count`
+// first; confirm the order against the call sites.
+template <int count, typename scalar_t>
+__device__ inline void ldsm(typename ScalarType<scalar_t>::FragA& frag_a, const void* smem_ptr) {
+  static_assert(count == 1 || count == 2 || count == 4, "invalid count");
+  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
+  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+  if constexpr (count == 4) {
+    asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
+                 : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
+                 : "r"(smem));
+  } else if constexpr (count == 2) {
+    asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" : "=r"(a[0]), "=r"(a[1]) : "r"(smem));
+  } else {
+    // count == 1 (guaranteed by the static_assert above)
+    asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n" : "=r"(a[0]) : "r"(smem));
+  }
+}
+
+// Multiply dequantized values by the corresponding quantization scale; used
+// only for grouped quantization.
+// Element `i` of `frag_s` (viewed as an array of scalar_t) is broadcast to both
+// lanes of a packed pair and multiplied into both packed pairs of `frag_b`.
+template <typename scalar_t>
+__device__ inline void
+scale(typename ScalarType<scalar_t>::FragB& frag_b, typename ScalarType<scalar_t>::FragS& frag_s, int i) {
+  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
+  scalar_t2 s = ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_s)[i]);
+  frag_b[0] = __hmul2(frag_b[0], s);
+  frag_b[1] = __hmul2(frag_b[1], s);
+}
+
+// Fused scale and zero-point subtraction: frag_b = frag_b * s - zp, applied to
+// both packed pairs with `s` and `zp` broadcast to both lanes.
+template <typename scalar_t>
+__device__ inline void scale_and_sub(typename ScalarType<scalar_t>::FragB& frag_b, scalar_t s, scalar_t zp) {
+  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
+  scalar_t2 s2 = ScalarType<scalar_t>::num2num2(s);
+  scalar_t2 zp2 = ScalarType<scalar_t>::num2num2(zp);
+  frag_b[0] = __hfma2(frag_b[0], s2, __hneg2(zp2));
+  frag_b[1] = __hfma2(frag_b[1], s2, __hneg2(zp2));
+}
+
+// Subtract a zero-point from both packed pairs of `frag_b`. The zero-point is
+// element `i` of `frag_zp` (viewed as an array of scalar_t), broadcast to both
+// lanes.
+template <typename scalar_t>
+__device__ inline void
+sub_zp(typename ScalarType<scalar_t>::FragB& frag_b, typename ScalarType<scalar_t>::scalar_t2& frag_zp, int i) {
+  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
+  scalar_t2 zp = ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_zp)[i]);
+  frag_b[0] = __hsub2(frag_b[0], zp);
+  frag_b[1] = __hsub2(frag_b[1], zp);
+}
+
+// Same as scale(), but for act_order (each K is multiplied individually):
+// element `i` of frag_s_1 / frag_s_2 scales the two lanes of frag_b[0], and
+// element `i` of frag_s_3 / frag_s_4 scales the two lanes of frag_b[1].
+template <typename scalar_t>
+__device__ inline void scale4(
+    typename ScalarType<scalar_t>::FragB& frag_b,
+    typename ScalarType<scalar_t>::FragS& frag_s_1,
+    typename ScalarType<scalar_t>::FragS& frag_s_2,
+    typename ScalarType<scalar_t>::FragS& frag_s_3,
+    typename ScalarType<scalar_t>::FragS& frag_s_4,
+    int i) {
+  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
+  scalar_t2 s_val_1_2;
+  s_val_1_2.x = reinterpret_cast<scalar_t*>(&frag_s_1)[i];
+  s_val_1_2.y = reinterpret_cast<scalar_t*>(&frag_s_2)[i];
+
+  scalar_t2 s_val_3_4;
+  s_val_3_4.x = reinterpret_cast<scalar_t*>(&frag_s_3)[i];
+  s_val_3_4.y = reinterpret_cast<scalar_t*>(&frag_s_4)[i];
+
+  frag_b[0] = __hmul2(frag_b[0], s_val_1_2);
+  frag_b[1] = __hmul2(frag_b[1], s_val_3_4);
+}
+
+// Given 2 floats multiply by 2 scales (halves): c[0] and c[1] are multiplied
+// by the first and second scalar_t element of `s`, each converted to float.
+template <typename scalar_t>
+__device__ inline void scale_float(float* c, typename ScalarType<scalar_t>::FragS& s) {
+  scalar_t* s_ptr = reinterpret_cast<scalar_t*>(&s);
+  c[0] = __fmul_rn(c[0], ScalarType<scalar_t>::num2float(s_ptr[0]));
+  c[1] = __fmul_rn(c[1], ScalarType<scalar_t>::num2float(s_ptr[1]));
+}
+
+// Wait until barrier reaches `count`, then lock for current threadblock.
+// Only thread 0 polls the lock; the rest of the threadblock waits at the
+// trailing __syncthreads().
+__device__ inline void barrier_acquire(int* lock, int count) {
+  if (threadIdx.x == 0) {
+    int state = -1;
+    do {
+      // Acquire load at GPU scope: once the expected count is observed, writes
+      // released by the threadblocks that bumped the barrier are visible here.
+      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
+    } while (state != count);
+  }
+  __syncthreads();
+}
+
+// Release barrier and increment visitation count. With `reset` set, thread 0
+// instead stores 0 to the lock with a plain write (no fence, no increment).
+__device__ inline void barrier_release(int* lock, bool reset = false) {
+  __syncthreads();
+  if (threadIdx.x == 0) {
+    if (reset) {
+      lock[0] = 0;
+      return;
+    }
+    int val = 1;
+    // Make sure that all writes since acquiring this barrier are visible
+    // globally, while releasing the barrier.
+    asm volatile("fence.acq_rel.gpu;\n");
+    asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(lock), "r"(val));
+  }
+}
+
+// Wait until the value of the lock is negative, then add 1 to it.
+// Only thread 0 polls and increments; the rest of the threadblock waits at the
+// trailing __syncthreads().
+__device__ inline void wait_negative_and_add(int* lock) {
+  if (threadIdx.x == 0) {
+    int state = 0;
+    do {
+      // Acquire load at GPU scope: once a negative value is observed, the
+      // writes that preceded its publication are visible to this threadblock.
+      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
+    } while (state >= 0);
+    atomicAdd(lock, 1);
+  }
+  __syncthreads();
+}
+
+template < + typename scalar_t, // compute dtype, half or nv_float16 + const host::ScalarTypeId w_type_id, // weight ScalarType id + const host::ScalarTypeId s_type_id, // weight scale ScalarType id + const int threads, // number of threads in a threadblock + const int thread_m_blocks, // number of 16x16 blocks in the m + // dimension (batchsize) of the + // threadblock + const int thread_n_blocks, // same for n dimension (output) + const int thread_k_blocks, // same for k dimension (reduction) + const bool m_block_size_8, // whether m_block_size == 8 + // only works when thread_m_blocks == 1 + const int stages, // number of stages for the async global->shared + // fetch pipeline + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type?
+ > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4* __restrict__ b_bias_ptr, + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4 + // only) + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k + const int32_t* __restrict__ sorted_token_ids_ptr, // moe sorted_ids + const int32_t* __restrict__ expert_ids_ptr, // moe expert ids + const int32_t* __restrict__ num_tokens_past_padded_ptr, // moe num tokens + const float* __restrict__ topk_weights_ptr, // moe top weights + int top_k, // num of experts per token + bool mul_topk_weights, // mul topk weights or not + bool is_ep, // expert parallelism + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks, // extra global storage for barrier synchronization + bool has_bias, + bool use_atomic_add, // whether to use atomic add to reduce + bool use_fp32_reduce, // whether to use fp32 global reduce + int max_shared_mem) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). 
Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. + using Dtype = ScalarType; + using scalar_t2 = typename ScalarType::scalar_t2; + using FragA = typename ScalarType::FragA; + using FragB = typename ScalarType::FragB; + using FragC = typename ScalarType::FragC; + using FragS = typename ScalarType::FragS; + using FragZP = typename ScalarType::FragZP; + + extern __shared__ int4 sh[]; + static constexpr auto w_type = host::ScalarType::from_id(w_type_id); + static constexpr auto s_type = host::ScalarType::from_id(s_type_id); + if constexpr (w_type == host::kFE2M1f) { + static_assert(s_type == host::kFE4M3fn && group_blocks == 1 || s_type == host::kFE8M0fnu && group_blocks == 2); + } else if constexpr (std::is_same::value) { + static_assert(s_type == host::kBFloat16); + } else if constexpr (std::is_same::value) { + static_assert(s_type == host::kFloat16); + } + + constexpr bool has_zp = w_type == host::kU4 || w_type == host::kU8; + constexpr bool is_int_type = + w_type == host::kU4 || w_type == host::kU8 || w_type == host::kU4B8 || w_type == host::kU8B128; + // see comments of dequant.h for more details + constexpr bool dequant_skip_flop = w_type == host::kFE4M3fn || w_type == host::kFE2M1f && s_type == host::kFE4M3fn || + has_zp && !is_zp_float && !std::is_same::value || + has_zp && !is_zp_float && !(w_type == host::kU8); + + scalar_t2 global_scale; + + constexpr bool has_act_order = group_blocks == 0; + + constexpr int pack_factor = 32 / w_type.size_bits(); + static_assert(thread_m_blocks == 1 || !m_block_size_8); + constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); + const int group_size = (!has_act_order && group_blocks == -1) ? 
prob_k : prob_k / num_groups; + const int scales_expert_stride = prob_n * prob_k / group_size / (w_type == host::kFE2M1f ? 16 : 8); + const int zp_expert_stride = + is_zp_float ? prob_n * prob_k / group_size / 8 : prob_n * prob_k / group_size / (pack_factor * 4); + const int b_bias_expert_stride = prob_n / 8; + + // parallel: num valid moe blocks + int num_tokens_past_padded = num_tokens_past_padded_ptr[0]; + int parallel = num_tokens_past_padded / moe_block_size; + int num_valid_blocks = parallel; + if (is_ep) { + for (int i = 0; i < parallel; i++) { + if (expert_ids_ptr[i] == -1) num_valid_blocks--; + } + } + int num_invalid_blocks = parallel - num_valid_blocks; + parallel = num_valid_blocks; + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); + + if constexpr (!has_act_order && group_blocks != -1) { + if (group_blocks >= thread_k_blocks) { + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts + // in the middle of group. 
+ iters = (group_blocks / thread_k_blocks) * div_ceil(iters, (group_blocks / thread_k_blocks)); + } + } + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + int par_id = 0; + int block_id = -1; + int64_t expert_id = 0; // use int64 to avoid computation result overflow + int old_expert_id = 0; + int64_t B_expert_off = 0; + + int4* sh_block_sorted_ids_int4 = sh; + int4* sh_rd_block_sorted_ids_int4 = sh_block_sorted_ids_int4 + moe_block_size / 4; + int4* sh_block_topk_weights_int4 = sh_rd_block_sorted_ids_int4 + moe_block_size / 4; + // sh_block_topk_weights_int4 only need (moe_block_size / 4); + // but we pad to align to 256 bytes + int4* sh_new = sh_block_topk_weights_int4 + moe_block_size / 2 + moe_block_size; + int32_t* sh_block_sorted_ids = reinterpret_cast(sh_block_sorted_ids_int4); + int32_t* sh_rd_block_sorted_ids = reinterpret_cast(sh_rd_block_sorted_ids_int4); + scalar_t2* sh_block_topk_weights = reinterpret_cast(sh_block_topk_weights_int4); + + int32_t block_num_valid_tokens = 0; + int32_t locks_off = 0; + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + slice_col = slice_col_par % n_tiles; + par_id = slice_col_par / n_tiles; + } + if (parallel * n_tiles >= gridDim.x) { + // when parallel * n_tiles >= sms + // then there are at most $sms$ conflict tile blocks + locks_off = blockIdx.x; + } else { + locks_off = (iters * blockIdx.x) / k_tiles - 1; + } + + // read moe block data given block_id + // block_sorted_ids / block_num_valid_tokens / block_topk_weights + auto read_moe_block_data = [&](int block_id) { + block_num_valid_tokens = 
moe_block_size; +#pragma unroll + for (int i = 0; i < moe_block_size / 4; i++) { + int4 sorted_token_ids_int4 = + reinterpret_cast(sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i]; + int* sorted_token_ids = reinterpret_cast(&sorted_token_ids_int4); +#pragma unroll + for (int j = 0; j < 4; j++) { + if (sorted_token_ids[j] >= prob_m * top_k) { + block_num_valid_tokens = i * 4 + j; + break; + } + } + if (block_num_valid_tokens != moe_block_size) break; + } + + __syncthreads(); + int tid4 = threadIdx.x / 4; + if (threadIdx.x % 4 == 0 && threadIdx.x < block_num_valid_tokens) { + sh_block_sorted_ids_int4[tid4] = + reinterpret_cast(sorted_token_ids_ptr)[block_id * moe_block_size / 4 + tid4]; + +#pragma unroll + for (int i = 0; i < 4; i++) + sh_rd_block_sorted_ids[tid4 * 4 + i] = sh_block_sorted_ids[tid4 * 4 + i] / top_k; + + if (mul_topk_weights) { +#pragma unroll + for (int i = 0; i < 4; i++) { + int idx = tid4 * 4 + i; + // idx = idx < block_num_valid_tokens ? idx : 0; + if (idx < block_num_valid_tokens) { + if constexpr (w_type == host::kFE2M1f && s_type == host::kFE4M3fn) { + sh_block_topk_weights[idx] = + __hmul2(global_scale, Dtype::num2num2(Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]]))); + } else { + sh_block_topk_weights[idx] = + Dtype::num2num2(Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]])); + } + } + } + } + } + __syncthreads(); + }; + + // when move to next moe block, find the next block_id and expert_id + // and then read moe block data + auto update_next_moe_block_data = [&]() { + if (par_id >= parallel) return; + + old_expert_id = expert_id; + if (num_invalid_blocks > 0) { + int skip_count = block_id == -1 ? 
par_id : 0; + block_id++; + for (int i = block_id; i < num_tokens_past_padded / moe_block_size; i++) { + expert_id = expert_ids_ptr[i]; + if (expert_id != -1) { + if (skip_count == 0) { + block_id = i; + break; + }; + skip_count--; + }; + } + } else { + block_id = par_id; + expert_id = expert_ids_ptr[block_id]; + } + + if constexpr (w_type == host::kFE2M1f && s_type == host::kFE4M3fn) { + uint16_t val = scale2_ptr[expert_id]; + global_scale = Dtype::num2num2(*reinterpret_cast(&val)); + } + + B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4); + scales_ptr += (expert_id - old_expert_id) * scales_expert_stride; + if constexpr (has_zp) { + zp_ptr += (expert_id - old_expert_id) * zp_expert_stride; + } + if constexpr (has_act_order) { + g_idx += (expert_id - old_expert_id) * prob_k; + } + if (has_bias) { + b_bias_ptr += (expert_id - old_expert_id) * b_bias_expert_stride; + } + + read_moe_block_data(block_id); + }; + + // Compute all information about the current slice which is required for + // synchronization. 
+ auto init_slice = [&](bool first_init = false) { + slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = div_ceil(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (parallel * n_tiles >= gridDim.x) { + if (slice_count > 1 && slice_idx == slice_count - 1) { + locks_off++; + } + } else { + locks_off++; + } + + if (first_init && use_atomic_add && slice_count > 1 && slice_idx == 0) { + constexpr int threads_per_m = 16 * thread_n_blocks / 8; + int m_per_thread = div_ceil(block_num_valid_tokens, threads / threads_per_m); + for (int i = 0; i < m_per_thread; i++) { + int row = threads / threads_per_m * i + threadIdx.x / threads_per_m; + if (row < block_num_valid_tokens) { + int64_t sorted_row = sh_block_sorted_ids[row]; + int col = slice_col * 16 * thread_n_blocks / 8 + threadIdx.x % threads_per_m; + C[sorted_row * prob_n / 8 + col] = {0, 0, 0, 0}; + } + } + // After write zero to output, write a negative value to lock. + // Every SM that processes the same slice would wait for + // the negative value, and then atomicAdd 1 to it. + // After all SMs are processed, the lock value would back to 0 again. 
+ __syncthreads(); + if (threadIdx.x == 0) locks[locks_off] = 1 - slice_count; + } + + if (slice_col == n_tiles) { + slice_col = 0; + par_id++; + update_next_moe_block_data(); + } + }; + + update_next_moe_block_data(); + init_slice(true); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / (pack_factor * 4); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; + constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 
1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_rd_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks + ? thread_k_blocks / group_blocks / (w_type == host::kFE2M1f ? 2 : 1) + : 1; + constexpr int s_sh_stage = s_tb_groups * s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + constexpr int act_s_max_num_groups = 32; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + // Zero-points sizes/strides + int zp_gl_stride = is_zp_float ? prob_n / 8 : (prob_n / pack_factor) / 4; + constexpr int zp_sh_stride = is_zp_float ? 16 * thread_n_blocks / 8 : ((16 * thread_n_blocks) / pack_factor) / 4; + constexpr int zp_tb_groups = s_tb_groups; + constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; + int zp_gl_rd_delta = zp_gl_stride; + + // Global A read index of current thread. + int a_gl_rd_row = threadIdx.x / a_gl_rd_delta_o; + int a_gl_rd_col = a_gl_rd_delta_o * slice_row + threadIdx.x % a_gl_rd_delta_o; + + // Shared write index of current thread. 
+ int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % (16 / (m_block_size_8 ? 2 : 1))) + + (threadIdx.x % 32) / (16 / (m_block_size_8 ? 2 : 1)); + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + auto b_sh_wr = threadIdx.x * b_thread_vecs; + auto b_sh_rd = threadIdx.x * b_thread_vecs; + + // For act_order + constexpr int k_iter_size = tb_k / b_sh_wr_iters; + int slice_k_start = tb_k * slice_row; + int slice_k_finish = slice_k_start + tb_k * slice_iters; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd; + if constexpr (!has_act_order) { + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / (w_type == host::kFE2M1f ? 2 : 1) + + s_sh_stride * slice_col + threadIdx.x; + } + } + auto s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // Zero-points + int zp_gl_rd; + if constexpr (has_zp) { + if constexpr (group_blocks == -1) { + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } else { + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + zp_sh_stride * slice_col + threadIdx.x; + } + } + auto zp_sh_wr = threadIdx.x; + bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; + + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. 
+ int s_sh_rd; + if constexpr (group_blocks != -1 && w_type == host::kFE2M1f) { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + int warp_row = warp_id / n_warps; + + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2; + + } else if constexpr (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + else if constexpr (group_blocks == -1 && (m_block_size_8 || (has_zp && !dequant_skip_flop))) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; + + int bias_sh_rd; + if constexpr (m_block_size_8) { + bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8; + } else { + bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; + } + + int bias_sh_wr = threadIdx.x; + int bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x; + + // Zero-points have the same read layout as the scales + // (without column-wise case) + constexpr int num_col_threads = 8; + constexpr int num_row_threads = 4; + constexpr int num_ints_per_thread = 8 / pack_factor; + int zp_sh_rd; + if constexpr (has_zp) { + if constexpr (is_zp_float) { + if constexpr (group_blocks != -1) { + zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + } + } else { + zp_sh_rd = num_ints_per_thread * num_col_threads * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); + } + } + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. 
The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ (row % 8); + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { +#pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + // Shared memory storage for global fetch pipelines. + constexpr int sh_red_size = (2 * thread_n_blocks + 1) * 16 * thread_m_blocks; + constexpr int sh_b_size = stages * b_sh_stage; + int4* sh_b = sh_new; + int4* sh_red = sh_new; + + constexpr int sh_size_b_red_min = (sh_red_size < sh_b_size ? sh_red_size : sh_b_size); + constexpr int sh_size_b_red_max = (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); + constexpr int sh_bias_size = (thread_n_blocks * 16 / 8); + constexpr int sh_b_red_bias_size = + sh_size_b_red_max > (sh_size_b_red_min + sh_bias_size) ? 
sh_size_b_red_max : (sh_size_b_red_min + sh_bias_size); + + int4* sh_bias = sh_new + sh_size_b_red_min; + int4* sh_g_idx = sh_new + sh_b_red_bias_size; + int4* sh_zp = sh_g_idx + (stages * g_idx_stage); + constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride) : (stages * s_sh_stage); + int4* sh_s = sh_zp + (stages * zp_sh_stage); + // shared memory reused by reduction should be smaller than + // shared memory used by weight. + static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <= stages * b_sh_stage); + int4* sh_a = sh_s + sh_s_size; + constexpr int shm_size_used = moe_block_size + stages * (g_idx_stage + zp_sh_stage) + sh_s_size + sh_b_red_bias_size; + + // all remaining shared memory is used to cache A (input) + // sh_a_max_row is at least ` stages * 16 * thread_m_blocks ` + int sh_a_max_row = ((max_shared_mem - 1024) / 16 - shm_size_used) / (thread_k_blocks * 2); + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2][b_thread_vecs]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; // No act-order + FragS frag_bias[2][4]; + FragS act_frag_s[2][4][4]; // For act-order + int frag_qzp[2][num_ints_per_thread]; // Zero-points + FragZP frag_zp; // Zero-points in fp16 + FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ + + // Zero accumulators. 
+ auto zero_accums = [&]() { +#pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + int sh_first_group_id = -1; + int sh_num_groups = -1; + + auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id, int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups > act_s_max_num_groups) { + sh_num_groups = act_s_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred( + &sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = + scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x]; + } + } + } + }; + + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. 
+ bool should_load_a = true; + int max_num_stage_groups = ((sh_a_max_row - moe_block_size) / moe_block_size + 1) / stages; + max_num_stage_groups = max(max_num_stage_groups, 1); + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true, int pipe_a = 0) { + if (pred) { + if (should_load_a) { + int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + int row = a_gl_rd_delta_i / a_gl_stride * i + a_gl_rd_row; + int64_t sorted_row = 0; + if (!m_block_size_8 || row < 8) sorted_row = sh_rd_block_sorted_ids[row]; + int64_t true_idx = sorted_row * a_gl_stride + a_gl_rd_col + a_gl_rd_delta_o * a_off; + cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[true_idx], row < block_num_valid_tokens); + } + } + + int4* sh_b_stage = sh_b + b_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { +#pragma unroll + for (int j = 0; j < b_thread_vecs; j++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j + B_expert_off); + } + + B_ptr[i] += b_gl_rd_delta_o; + } + + if constexpr (has_act_order) { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) { + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + + int4 const* cur_g_idx_stage_ptr = reinterpret_cast(&g_idx[cur_k]); + + if (threadIdx.x < g_idx_stage) { + cp_async4_pred(&sh_g_idx_stage[threadIdx.x], &cur_g_idx_stage_ptr[threadIdx.x]); + } + } + } else { + if constexpr (group_blocks != -1) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch scales if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } else { + for (int i = 0; i < s_tb_groups; i++) { + if (s_sh_wr_pred) { + 
cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } + } + + if constexpr (has_zp && group_blocks != -1) { + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch zero-points if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } else { + for (int i = 0; i < zp_tb_groups; i++) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + auto fetch_col_zp_to_shared = [&]() { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + }; + + auto fetch_col_scale_to_shared = [&]() { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. 
+ auto fetch_to_registers = [&](int k, int pipe, int pipe_a = 0) { + int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a; +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + +#pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast(&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; + + auto fetch_scales_to_registers = [&](int k, int full_pipe) { + int pipe = full_pipe % stages; + + if constexpr (!has_act_order) { + // No act-order case + if constexpr (group_blocks == -1) { + // load only when starting a new slice + if (k == 0 && full_pipe == 0) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } else if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + if (k % b_sh_wr_iters == 0) { + int4* sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + reinterpret_cast(&frag_s[1])[0] = reinterpret_cast(&frag_s[0])[0]; + } + } else { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / (group_blocks * 
(w_type == host::kFE2M1f ? 2 : 1)); + + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (w_type_id != host::kFE2M1f.id()) { + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } else if constexpr (group_blocks == 1 || thread_k_blocks > 4) { + reinterpret_cast(&frag_s[k % 2])[0] = + reinterpret_cast(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; + } else { + reinterpret_cast(&frag_s[k % 2])[0] = + reinterpret_cast(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) + k % 2]; + } + } + } + + return; + } + + // Act-order case + + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) { + return; + } + + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; + + // Progress to current iteration + cur_k += k_iter_size * (k % b_sh_wr_iters); + + // Determine "position" inside the thread-block (based on warp and + // thread-id) + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + + int warp_row = warp_id / n_warps; + int warp_col = warp_id % n_warps; + + cur_k += warp_row * 16; + + auto th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + (th_id / 4) * act_s_col_stride; + + if (is_same_group[pipe]) { + if (k % 2 == 0) { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + s_col_shift]; + } else { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + } + + for (int i = 1; i < 4; i++) { + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + } + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + 
int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + constexpr int k_frag_offsets[4] = {0, 1, 8, 9}; // Tensor core offsets per thread + +#pragma unroll + for (int i = 0; i < 4; i++) { + int actual_k = cur_k + k_frag_offsets[i]; + + int group_id = sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; + + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } + }; + + auto fetch_zp_to_registers = [&](int k, int full_pipe) { + // This code does not handle group_blocks == 0, + // which signifies act_order. + // has_zp implies AWQ, which doesn't have act_order, + static_assert(!has_zp || group_blocks != 0); + + if constexpr (has_zp && !is_zp_float) { + int pipe = full_pipe % stages; + + if constexpr (group_blocks == -1) { + // load only when starting a new slice + if (k == 0 && full_pipe == 0) { +#pragma unroll + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; + } + } + + } else if constexpr (group_blocks >= thread_k_blocks) { + if (k % b_sh_wr_iters == 0) { + int4* sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); +#pragma unroll + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + } else { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = 0; + + // Suppress bogus and persistent divide-by-zero warning +#pragma nv_diagnostic push +#pragma nv_diag_suppress divide_by_zero + cur_group_id = k_blocks / group_blocks; +#pragma nv_diagnostic pop + + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + sh_zp_stage += cur_group_id * zp_sh_stride; + +#pragma unroll + for (int i = 0; i < num_ints_per_thread; 
i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + } + + else if constexpr (has_zp && is_zp_float) { + int pipe = full_pipe % stages; + + if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + if (k % b_sh_wr_iters == 0) { + int4* sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd]; + } + } else { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + // Suppress bogus and persistent divide-by-zero warning +#pragma nv_diagnostic push +#pragma nv_diag_suppress divide_by_zero + int cur_group_id = k_blocks / group_blocks; +#pragma nv_diagnostic pop + + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd + cur_group_id * zp_sh_stride]; + } + } + } + }; + + auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) { + dequant(q, frag_b_ptr); + }; + + // Execute the actual tensor core matmul of a sub-tile. 
+ bool is_first_matmul_in_slice = true; + auto matmul = [&](int k) { + int k2 = k % 2; + const bool is_new_zp = ((group_blocks != -1) && (group_blocks < thread_k_blocks || k == 0)) || + (group_blocks == -1 && is_first_matmul_in_slice); + if constexpr (has_zp && !is_zp_float) { + if (is_new_zp) { + if constexpr (group_blocks == -1) is_first_matmul_in_slice = false; + int zp_quant_0, zp_quant_1; + + if constexpr (w_type.size_bits() == 4) { + zp_quant_0 = frag_qzp[k2][0]; + zp_quant_1 = zp_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + zp_quant_0 = frag_qzp[k2][0]; + zp_quant_1 = frag_qzp[k2][1]; + } + + dequant_data(zp_quant_0, reinterpret_cast(&frag_zp)); + dequant_data(zp_quant_1, reinterpret_cast(&frag_zp) + 2); + } + } + if constexpr (!dequant_skip_flop && has_zp && is_zp_float) { + if (is_new_zp) { + reinterpret_cast(&frag_zp)[0] = reinterpret_cast(&frag_zpf[k2])[0]; + } + } + + // Commented out FP4/FP8 scale dequantization since we don't generate + // kFE2M1f kernels to reduce compilation time + // if constexpr (w_type == host::kFE2M1f) { + // int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; + // int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; + // + // dequant_fp8_scales( + // s_quant_0, reinterpret_cast(&frag_s[k2])); + // dequant_fp8_scales( + // s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); + // } + +// We have the m dimension as the inner loop in order to encourage overlapping +// dequantization and matmul operations. 
+#pragma unroll + for (int j = 0; j < 4; j++) { + FragB frag_b0; + FragB frag_b1; + int b_quant_0, b_quant_1; + + if constexpr (w_type_id == host::kFE2M1f.id()) { + b_quant_1 = frag_b_quant[k2][0][j]; + b_quant_0 = b_quant_1 << 8; + } else if constexpr (w_type.size_bits() == 4) { + b_quant_0 = frag_b_quant[k2][0][j]; + b_quant_1 = b_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k2]); + b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + } + + dequant_data(b_quant_0, reinterpret_cast(&frag_b0)); + dequant_data(b_quant_1, reinterpret_cast(&frag_b1)); + + if constexpr (dequant_skip_flop && has_zp && !is_zp_float) { + sub_zp(frag_b0, frag_zp[j], 0); + sub_zp(frag_b1, frag_zp[j], 1); + } + + // Apply scale to frag_b0 + if constexpr (has_act_order) { + static_assert(group_blocks != -1); + scale4( + frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); + scale4( + frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); + } else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float && group_blocks == -1) { + int idx = (threadIdx.x / 4) % 2; + scalar_t2 s2 = Dtype::nums2num2( + reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 0])[idx], + reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 1])[idx]); + if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2); + scale_and_sub(frag_b0, s2.x, frag_zp[j].x); + scale_and_sub(frag_b1, s2.y, frag_zp[j].y); + } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) { + if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], *reinterpret_cast(&frag_s[k2][j])); + scale_and_sub(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x); + scale_and_sub(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y); + } else if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k2][j], 0); + scale(frag_b1, frag_s[k2][j], 1); + } + +#pragma unroll + for (int i = 0; i < 
thread_m_blocks; i++) { + if constexpr (m_block_size_8) { + mma_trans(frag_a[k2][i], frag_b0, frag_b1, frag_c[i][j][0]); + } else { + mma(frag_a[k2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k2][i], frag_b1, frag_c[i][j][1]); + } + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) { + auto red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + +#pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { +#pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { +#pragma unroll + for (int j = 0; j < 4 * 2; j += (m_block_size_8 ? 2 : 1)) { + int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = reinterpret_cast(&sh_red[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh_red[red_sh_wr]); +#pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; + } + sh_red[red_sh_wr] = reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { +#pragma unroll + for (int i = 0; i < 4 * 2; i += (m_block_size_8 ? 
2 : 1)) { + float* c_rd = reinterpret_cast(&sh_red[red_sh_delta * i + red_sh_rd]); +#pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce_fp16 = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + bool is_th_active = threadIdx.x < active_threads; + if (!is_th_active) { + return; + } + + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr; + if constexpr (m_block_size_8) { + c_gl_wr = c_gl_stride * ((threadIdx.x % 4) * 2) + 4 * (threadIdx.x / 32) + (threadIdx.x % 32) / 8; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + } else { + c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + } + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + if (!first) { + +#pragma unroll + for (int i = 0; i < (m_block_size_8 ? 
2 : thread_m_blocks * 4); i++) { + int c_idx; + if constexpr (m_block_size_8) + c_idx = c_gl_wr + i * c_gl_stride + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i; + else + c_idx = c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); + if (c_idx / c_gl_stride < block_num_valid_tokens) { + int64_t sorted_row = sh_block_sorted_ids[c_idx / c_gl_stride]; + int64_t true_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; + sh_red[c_sh_wr + c_sh_wr_delta * i] = C[true_idx]; + } + } + } + +#pragma unroll + for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) { + if (!first) { + int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta]; +#pragma unroll + for (int j = 0; j < 2 * 4; j++) { + int delta = 0; + if constexpr (m_block_size_8) { + delta = j % 2 == 1 ? -2 : 0; + } + reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta] += + Dtype::num2float(reinterpret_cast(&c_red)[j]); + } + } + if (!last) { + int4 c; +#pragma unroll + for (int j = 0; j < 2 * 4; j++) { + int delta = 0; + if constexpr (m_block_size_8) { + delta = j % 2 == 1 ? -2 : 0; + } + reinterpret_cast(&c)[j] = + Dtype::float2num(reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta]); + } + + int c_idx; + if constexpr (m_block_size_8) + c_idx = c_gl_wr + i * c_gl_stride + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i; + else + c_idx = c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); + if (c_idx / c_gl_stride < block_num_valid_tokens) { + int64_t sorted_row = sh_block_sorted_ids[c_idx / c_gl_stride]; + int64_t true_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; + C[true_idx] = c; + } + } + } + }; + + // Globally reduce over threadblocks that compute the same column block. + // We use a tmp C buffer to reduce in full fp32 precision. 
+ auto global_reduce_fp32 = [&](bool first = false, bool last = false) { + constexpr int tb_m = thread_m_blocks * 16; + constexpr int tb_n = thread_n_blocks * 16; + + constexpr int c_size = tb_m * tb_n * sizeof(float) / 16; + + constexpr int active_threads = 32 * thread_n_blocks / 4; + bool is_th_active = threadIdx.x < active_threads; + + constexpr int num_floats = thread_m_blocks * 4 * 2 * 4; + constexpr int th_size = num_floats * sizeof(float) / 16; + + int c_cur_offset = locks_off * c_size; + + if (!is_th_active) { + return; + } + + if (!first) { + float* frag_c_ptr = reinterpret_cast(&frag_c); +#pragma unroll + for (int k = 0; k < th_size; k++) { + if constexpr (m_block_size_8) { + if (k % 2) continue; + } else { + if (k / 8 * 16 + (threadIdx.x % 32) / 4 >= block_num_valid_tokens) continue; + } + + sh_red[threadIdx.x] = C_tmp[c_cur_offset + active_threads * k + threadIdx.x]; + + float* sh_c_ptr = reinterpret_cast(&sh_red[threadIdx.x]); +#pragma unroll + for (int f = 0; f < 4; f++) { + frag_c_ptr[k * 4 + f] += sh_c_ptr[f]; + } + } + } + + if (!last) { + int4* frag_c_ptr = reinterpret_cast(&frag_c); +#pragma unroll + for (int k = 0; k < th_size; k++) { + if constexpr (m_block_size_8) { + if (k % 2) continue; + } else { + if (k / 8 * 16 + (threadIdx.x % 32) / 4 >= block_num_valid_tokens) continue; + } + + C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k]; + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. 
+ auto write_result = [&](bool last) { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr; + if constexpr (m_block_size_8) { + c_sh_wr = (8 * c_sh_stride) * ((threadIdx.x % 32) % 4 * 2) + (threadIdx.x % 32) / 4; + c_sh_wr += 64 * (threadIdx.x / 32); + } else { + c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + } + + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) { + scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); + + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr ( + !has_act_order && group_blocks == -1 && w_type.size_bits() == 4 && (has_zp && dequant_skip_flop || !has_zp)) { + scalar_t2 tmp_scale = s[0]; + if constexpr (m_block_size_8) { + tmp_scale = Dtype::num2num2(reinterpret_cast(&s[0])[(threadIdx.x % 8) / 4]); + } + res = __hmul2(res, tmp_scale); + } + + if constexpr (w_type == host::kFE2M1f && s_type == host::kFE4M3fn) { + if (!mul_topk_weights) { + res = __hmul2(res, global_scale); + } + } + if (has_bias && last) { + scalar_t2 tmp_bias = b_bias[0]; + if constexpr (m_block_size_8) { + tmp_bias = Dtype::num2num2(reinterpret_cast(&b_bias[0])[(threadIdx.x % 8) / 4]); + } + res = __hadd2(res, tmp_bias); + } + + if constexpr (m_block_size_8) { + ((scalar_t*)sh_red)[idx] = res.x; + ((scalar_t*)sh_red)[idx + 8 * c_sh_stride] 
= res.y; + } else { + ((scalar_t2*)sh_red)[idx] = res; + } + }; + + if (threadIdx.x / 32 < thread_n_blocks / 4) { +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { +#pragma unroll + for (int j = 0; j < 4; j++) { + if constexpr (m_block_size_8) { + int wr = c_sh_wr + 16 * j; + write( + wr, + frag_c[i][j][0][0], + frag_c[i][j][0][1], + frag_s[j / 2][2 * (j % 2) + 0], + frag_bias[j / 2][2 * (j % 2) + 0]); + write( + wr + 8, + frag_c[i][j][0][2], + frag_c[i][j][0][3], + frag_s[j / 2][2 * (j % 2) + 1], + frag_bias[j / 2][2 * (j % 2) + 1]); + } else { + int wr = c_sh_wr + 8 * j; + write( + wr + (4 * c_sh_stride) * 0 + 0, + frag_c[i][j][0][0], + frag_c[i][j][0][1], + frag_s[j / 2][2 * (j % 2) + 0], + frag_bias[j / 2][2 * (j % 2) + 0]); + write( + wr + (4 * c_sh_stride) * 8 + 0, + frag_c[i][j][0][2], + frag_c[i][j][0][3], + frag_s[j / 2][2 * (j % 2) + 0], + frag_bias[j / 2][2 * (j % 2) + 0]); + write( + wr + (4 * c_sh_stride) * 0 + 4, + frag_c[i][j][1][0], + frag_c[i][j][1][1], + frag_s[j / 2][2 * (j % 2) + 1], + frag_bias[j / 2][2 * (j % 2) + 1]); + write( + wr + (4 * c_sh_stride) * 8 + 4, + frag_c[i][j][1][2], + frag_c[i][j][1][3], + frag_s[j / 2][2 * (j % 2) + 1], + frag_bias[j / 2][2 * (j % 2) + 1]); + } + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + +#pragma unroll + for (int i = 0; i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) { + int row = c_gl_wr / c_gl_stride; + if (row < block_num_valid_tokens) { + int64_t sorted_row = sh_block_sorted_ids[row]; + int64_t true_idx = sorted_row * c_gl_stride + c_gl_wr % c_gl_stride; + scalar_t2 topk_weight_score; + if (mul_topk_weights) topk_weight_score = sh_block_topk_weights[row]; + if (use_atomic_add && slice_count > 1 || mul_topk_weights) { + scalar_t2* C_half2 = reinterpret_cast(&C[true_idx]); + scalar_t2* sh_red_half2 = reinterpret_cast(&sh_red[c_sh_rd]); +#pragma unroll + for (int a = 0; a < 4; a++) { + scalar_t2 res = sh_red_half2[a]; + if (mul_topk_weights) { 
+ res = __hmul2(res, topk_weight_score); + } + + if (use_atomic_add && slice_count > 1) { + atomicAdd(&C_half2[a], res); + } else { + C_half2[a] = res; + }; + } + } else { + C[true_idx] = sh_red[c_sh_rd]; + } + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + __syncthreads(); + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { + +#pragma unroll + for (int i = 0; i < stages - 1; i++) { + if (has_act_order && i == 0) { + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + fetch_act_order_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); + } + + if constexpr (has_zp && !is_zp_float && group_blocks == -1) { + if (i == 0) { + fetch_col_zp_to_shared(); + if constexpr (!dequant_skip_flop) { + fetch_col_scale_to_shared(); + } + } + } + fetch_to_shared(i, i, i < slice_iters, i); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + fetch_scales_to_registers(0, 0); + fetch_zp_to_registers(0, 0); + a_gl_rd_col += a_gl_rd_delta_o * (stages - 1); + if constexpr (has_act_order) { + slice_k_start_shared_fetch += tb_k * (stages - 1); + } + }; + if (slice_iters) { + start_pipes(); + } + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. + + for (int stage_group_id = 0; stage_group_id < max_num_stage_groups; stage_group_id++) { +#pragma unroll + for (int pipe = 0; pipe < stages;) { +#pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + int idx = (pipe >= stages && stage_group_id == max_num_stage_groups - 1) ? 
(pipe - stages) + : (pipe + stage_group_id * stages); + fetch_to_registers(k + 1, pipe % stages, idx); + fetch_scales_to_registers(k + 1, pipe); + fetch_zp_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + int idx = (pipe >= 1 && stage_group_id == max_num_stage_groups - 1) + ? (pipe - 1) + : (pipe + (stage_group_id + 1) * stages - 1); + fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages, idx); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; + } + } + + a_gl_rd_col += a_gl_rd_delta_o * stages; + + if constexpr (has_act_order) { + slice_k_start += tb_k * stages; + + if (slice_k_start < prob_k) { + slice_k_start_shared_fetch += tb_k * stages; + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_act_order_scales_to_shared(false, first_group_id, last_group_id); + __syncthreads(); + } + } + } + if (slice_iters == 0) { + break; + } + } + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. 
+ if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) { + if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } + } + + thread_block_reduce(); + + if (has_bias && last) { + __syncthreads(); + cp_async4_pred(&sh_bias[bias_sh_wr], &b_bias_ptr[bias_gl_rd], threadIdx.x < 16 * thread_n_blocks / 8); + cp_async_fence(); + } + + if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) { + if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + if constexpr (m_block_size_8) { + int idx = (threadIdx.x / 4) % 2; + scalar_t2* frag_s_half2 = reinterpret_cast(frag_s); +#pragma unroll + for (int i = 0; i < 8; i++) { + frag_s_half2[i] = Dtype::num2num2(reinterpret_cast(&frag_s_half2[i])[idx]); + } + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr ( + !has_act_order && group_blocks == -1 && w_type.size_bits() == 8 && (has_zp && dequant_skip_flop || !has_zp)) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { +#pragma unroll + for (int j = 0; j < 4; j++) { + scale_float(reinterpret_cast(&frag_c[i][j][0][0]), frag_s[j / 2][2 * (j % 2) + 0]); + scale_float( + reinterpret_cast(&frag_c[i][j][0][2]), frag_s[j / 2][2 * (j % 2) + (m_block_size_8 ? 
1 : 0)]); + + if constexpr (!m_block_size_8) { + scale_float(reinterpret_cast(&frag_c[i][j][1][0]), frag_s[j / 2][2 * (j % 2) + 1]); + scale_float(reinterpret_cast(&frag_c[i][j][1][2]), frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } + } + } + + if (slice_count > 1 && !use_atomic_add) { + // only globally reduce if there is more than one block in a slice + barrier_acquire(&locks[locks_off], slice_idx); + if (use_fp32_reduce) { + global_reduce_fp32(slice_idx == 0, last); + } else { + global_reduce_fp16(slice_idx == 0, last); + } + barrier_release(&locks[locks_off], last); + } + + if (has_bias && last) { + cp_async_wait<0>(); + __syncthreads(); + reinterpret_cast(&frag_bias)[0] = sh_bias[bias_sh_rd]; + reinterpret_cast(&frag_bias)[1] = sh_bias[bias_sh_rd + 4]; + __syncthreads(); + } + + if (use_atomic_add && slice_count > 1 && slice_idx != 0) wait_negative_and_add(&locks[locks_off]); + if (last || use_atomic_add) + // only the last block in a slice actually writes the result + write_result(last); + int old_slice_row = slice_row; + slice_row = 0; + slice_col_par++; + slice_col++; + is_first_matmul_in_slice = true; + init_slice(); + + // Should we load A matrix in next slice? 
+ // `slice_col == 0`: when move to a new moe block + // `old_slice_row > 0`: + // when the last slice is not starting from k_index == 0 + // (only happen when it is the first slice of a threadblock) + // `prob_k > thread_k_blocks * 16 * stages * max_num_stage_groups`: + // when the required shared memory size is larger than + // the remaining shared memory + if (slice_col == 0 || old_slice_row || prob_k > thread_k_blocks * 16 * stages * max_num_stage_groups) { + should_load_a = true; + } else { + should_load_a = false; + } + + if (slice_iters) { + a_gl_rd_col = (threadIdx.x % a_gl_rd_delta_o); +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] -= b_gl_stride; + } + + bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x; + // Update slice k/n for scales loading + if constexpr (has_act_order) { + slice_k_start = tb_k * slice_row; + slice_k_finish = slice_k_start + tb_k * slice_iters; + slice_k_start_shared_fetch = slice_k_start; + slice_n_offset = act_s_col_tb_stride * slice_col; + } else { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } + start_pipes(); + } + } + } +} + +} // namespace device::marlin_moe + +#endif diff --git a/sglang/python/sglang/jit_kernel/csrc/gemm/marlin_moe/moe_wna16_marlin.cuh b/sglang/python/sglang/jit_kernel/csrc/gemm/marlin_moe/moe_wna16_marlin.cuh new file mode 100644 index 0000000000000000000000000000000000000000..81c021dc8eccac09b26b5dc89c12330c19a17693 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/gemm/marlin_moe/moe_wna16_marlin.cuh @@ -0,0 +1,1089 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ + +#pragma once + +#include + +#include + +#include "kernel.h" +#include "marlin_template.h" + +namespace device::marlin_moe { + +__global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){}; + +using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS); + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +template +__global__ void permute_cols_kernel( + int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, + const int32_t* __restrict__ sorted_token_ids_ptr, + const int32_t* __restrict__ expert_ids_ptr, + const int32_t* __restrict__ num_tokens_past_padded_ptr, + int size_m, + int size_k, + int top_k) {}; + +#else + +// For a given "a" of size [M,K] performs a permutation of the K columns based +// on the given "perm" indices. 
// Column-permutes the activation matrix "a" for the MoE Marlin path.
//
// Work is split by MoE block: each thread block walks the blocks
// blockIdx.x, blockIdx.x + gridDim.x, ... For every valid sorted token id
// `row` in a block, output row `row` receives input row `row / top_k` with its
// K columns gathered through the permutation of the block's expert:
//   out[row][k] = a[row / top_k][perm[expert][k]]
//
// Parameters:
//   a_int4_ptr                 input rows, addressed in int4 (16-byte) units
//   perm_int_ptr               permutation table, size_k entries per expert
//   out_int4_ptr               output rows, addressed in int4 units
//   sorted_token_ids_ptr       sorted token ids, moe_block_size per MoE block;
//                              ids >= size_m * top_k mark the end of the valid
//                              tokens in a block
//   expert_ids_ptr             expert id per MoE block, -1 = skip the block
//   num_tokens_past_padded_ptr single value: padded length of the sorted ids
//   size_m, size_k, top_k      problem sizes
//
// NOTE(review): the row stride is computed as size_k * sizeof(half) / 16, so
// size_k is presumably a multiple of 8, and the column loop presumably relies
// on blockDim.x == default_threads -- confirm both against the launcher.
template <int moe_block_size>
__global__ void permute_cols_kernel(
    int4 const* __restrict__ a_int4_ptr,
    int const* __restrict__ perm_int_ptr,
    int4* __restrict__ out_int4_ptr,
    const int32_t* __restrict__ sorted_token_ids_ptr,
    const int32_t* __restrict__ expert_ids_ptr,
    const int32_t* __restrict__ num_tokens_past_padded_ptr,
    int size_m,
    int size_k,
    int top_k) {
  // The ids of one MoE block are copied as moe_block_size / 4 int4 vectors; a
  // block size that is not a multiple of 4 would leave trailing ids unset.
  static_assert(moe_block_size % 4 == 0, "moe_block_size must be a multiple of 4 (ids are copied as int4)");

  const int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
  const int num_moe_blocks = div_ceil(num_tokens_past_padded, moe_block_size);

  // Written through an int4* below, hence the explicit 16-byte alignment.
  alignas(16) int32_t block_sorted_ids[moe_block_size];
  int block_num_valid_tokens = 0;
  int64_t old_expert_id = 0;
  int64_t expert_id = 0;
  // Length of one row measured in int4 elements.
  const int row_stride = size_k * sizeof(half) / 16;

  // Loads the sorted token ids of MoE block `block_id` and counts how many
  // leading entries are real tokens (id < size_m * top_k).
  auto read_moe_block_data = [&](int block_id) {
    block_num_valid_tokens = moe_block_size;

    int4* ids_dst = reinterpret_cast<int4*>(block_sorted_ids);
    int4 const* ids_src = reinterpret_cast<int4 const*>(sorted_token_ids_ptr);
    for (int i = 0; i < moe_block_size / 4; i++) {
      ids_dst[i] = ids_src[block_id * moe_block_size / 4 + i];
    }

    for (int i = 0; i < moe_block_size; i++) {
      if (block_sorted_ids[i] >= size_m * top_k) {
        block_num_valid_tokens = i;
        break;
      }
    }
  };

  // Gathers one row: the threads of the block cover the K columns in chunks of
  // default_threads, followed by one partial chunk for the remainder.
  auto permute_row = [&](int row) {
    const int iters = size_k / default_threads;
    const int rest = size_k % default_threads;

    // Source row is the token (row / top_k); destination row is `row` itself.
    const int in_offset = (row / top_k) * row_stride;
    const int out_offset = row * row_stride;

    half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + in_offset);
    half* out_half = reinterpret_cast<half*>(out_int4_ptr + out_offset);

    int base_k = 0;

    for (int i = 0; i < iters; i++) {
      auto cur_k = base_k + threadIdx.x;
      int src_pos = perm_int_ptr[cur_k];

      out_half[cur_k] = a_row_half[src_pos];

      base_k += default_threads;
    }

    if (rest) {
      if (threadIdx.x < rest) {
        auto cur_k = base_k + threadIdx.x;
        int src_pos = perm_int_ptr[cur_k];

        out_half[cur_k] = a_row_half[src_pos];
      }
    }
  };

  for (int index = blockIdx.x; index < num_moe_blocks; index += gridDim.x) {
    old_expert_id = expert_id;
    int tmp_expert_id = expert_ids_ptr[index];
    // Blocks tagged -1 are skipped; expert_id keeps its previous value.
    if (tmp_expert_id == -1) continue;
    expert_id = tmp_expert_id;
    // Slide the permutation pointer from the previous expert's table to the
    // current expert's table (size_k entries per expert).
    perm_int_ptr += (expert_id - old_expert_id) * size_k;
    read_moe_block_data(index);

    for (int i = 0; i < block_num_valid_tokens; i++)
      permute_row(block_sorted_ids[i]);
  }
}

// Thread-block tile shape used to pick a kernel instantiation.
struct thread_config_t {
  int thread_k;     // tile extent along K
  int thread_n;     // tile extent along N
  int num_threads;  // threads per thread block
};

// Candidate tile shapes, ordered by priority.
thread_config_t small_batch_thread_configs[] = {
    // thread_k, thread_n, num_threads
    {128, 128, 256},
    {64, 128, 128}};

// Candidate tile shapes, ordered by priority.
thread_config_t large_batch_thread_configs[] = {
    // thread_k, thread_n, num_threads
    {64, 256, 256},
    {64, 128, 128}};

// A tile shape together with the number of thread blocks per SM.
struct exec_config_t {
  int blocks_per_sm;
  thread_config_t tb_cfg;
};

// Returns the size reserved for cached scales for one thread-block tile.
//
// group_size == -1 selects a single scale group, group_size == 0 sizes for the
// worst case of 32-wide groups, any other value uses ceil(thread_k /
// group_size) groups. With act-order on a partial K (has_act_order &&
// !is_k_full) a chunk covering twice the pipeline depth, and at least 32
// groups, is reserved instead. prob_m, prob_n, prob_k and num_bits are
// accepted but not used here.
// NOTE(review): the factor 2 is presumably the byte width of a 16-bit scale,
// i.e. the result is in bytes -- confirm.
int get_scales_cache_size(
    thread_config_t const& th_config,
    int prob_m,
    int prob_n,
    int prob_k,
    int num_bits,
    int group_size,
    bool has_act_order,
    bool is_k_full) {
  const int tb_n = th_config.thread_n;
  const int tb_k = th_config.thread_k;

  // Maximum number of scale groups touched by one thread-block tile.
  int tb_groups;
  if (group_size == -1) {
    tb_groups = 1;
  } else if (group_size == 0) {
    tb_groups = div_ceil(tb_k, 32);  // Worst case is 32 group size
  } else {
    tb_groups = div_ceil(tb_k, group_size);
  }

  const bool cache_scales_chunk = has_act_order && !is_k_full;
  if (cache_scales_chunk) {
    int load_groups = tb_groups * pipe_stages * 2;  // Chunk size is 2x pipeline over dim K
    load_groups = max(load_groups, 32);             // We load at least 32 scale groups
    return load_groups * tb_n * 2;
  }

  const int tb_scales = tb_groups * tb_n * 2;
  return tb_scales * pipe_stages;
}

// Returns the total shared-memory size a kernel with the given tile shape is
// expected to need: A and B pipeline stages, the reduction/bias scratch area,
// scales, zero-points, g_idx and the per-block metadata.
// m_block_size_8 and prob_m/prob_n/prob_k are only forwarded or unused here.
int get_kernel_cache_size(
    thread_config_t const& th_config,
    bool m_block_size_8,
    int thread_m_blocks,
    int prob_m,
    int prob_n,
    int prob_k,
    int num_bits,
    int group_size,
    bool has_act_order,
    bool is_k_full,
    int has_zp,
    int is_zp_float) {
  const int pack_factor = 32 / num_bits;

  // Tile extents.
  const int tb_k = th_config.thread_k;
  const int tb_n = th_config.thread_n;
  const int tb_m = thread_m_blocks * 16;

  // shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights;
  // each of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32).
  // NOTE(review): three arrays are named but only one tb_m * 4 is counted --
  // confirm whether the remainder is covered by slack elsewhere.
  const int sh_block_meta_size = tb_m * 4;
  const int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
  const int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
  const int sh_red_size = tb_m * (tb_n + 8) * 2;
  const int sh_bias_size = tb_n * 2;

  // Scratch area: the larger of the B and reduction buffers, but never less
  // than the smaller of the two plus the bias buffer.
  int tmp_size = (sh_b_size > sh_red_size ? sh_red_size : sh_b_size) + sh_bias_size;
  tmp_size = max(max(sh_b_size, sh_red_size), tmp_size);

  const int sh_s_size =
      get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full);
  const int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0;

  // Zero-points: same size as the scales when stored as floats, otherwise a
  // fraction of it depending on the weight bit width.
  int sh_zp_size = 0;
  if (has_zp) {
    if (is_zp_float)
      sh_zp_size = sh_s_size;
    else if (num_bits == 4)
      sh_zp_size = sh_s_size / 4;
    else if (num_bits == 8)
      sh_zp_size = sh_s_size / 2;
  }

  return tmp_size + sh_a_size + sh_s_size + sh_zp_size + sh_g_idx_size + sh_block_meta_size;
}
th_config, + m_block_size_8, + thread_m_blocks, + prob_m, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float); + return cache_size + 512 <= max_shared_mem; +} + +#define _GET_IF( \ + W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ + else if ( \ + q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && m_block_size_8 == M_BLOCK_SIZE_8 && group_blocks == GROUP_BLOCKS && \ + num_threads == NUM_THREADS && is_zp_float == IS_ZP_FLOAT) { \ + constexpr auto S_TYPE = W_TYPE == host::kFE2M1f \ + ? (GROUP_BLOCKS == 1 ? host::kFE4M3fn : host::kFE8M0fnu) \ + : (std::is_same::value ? host::kFloat16 : host::kBFloat16); \ + kernel = Marlin< \ + scalar_t, \ + W_TYPE.id(), \ + S_TYPE.id(), \ + NUM_THREADS, \ + THREAD_M_BLOCKS, \ + THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, \ + M_BLOCK_SIZE_8, \ + pipe_stages, \ + GROUP_BLOCKS, \ + IS_ZP_FLOAT>; \ + } + +// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false) +// this is the most common cases +// BIGGROUP: cases for big group size (group_blocks in [-1, 8]) +// FZP: cases for float-zero-point (is_zp_float = true) +// ACT: cases for act order case (group_blocks == 0) +// NVFP4: cases for nvfp4(e2m1) (group_blocks == 1) +// MXFP4: cases for mxfp4(e2m1) (group_blocks == 2) +#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, 
NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define COMMON_GET_IF(W_TYPE) \ + COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \ + COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \ + COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \ + COMMON_GET_IF_M234(W_TYPE, 8, 4, 128) + +#define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 
4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define BIGGROUP_GET_IF(W_TYPE) \ + BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \ + BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \ + BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \ + BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) + +#define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + +#define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + +#define NVFP4_GET_IF(W_TYPE) \ + NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ + NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ + NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ + NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128) + +#define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) + +#define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) + +#define MXFP4_GET_IF(W_TYPE) \ + MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ + MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ + MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ + MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128) + +// We currently have 4-bit models only with group_blocks == 4 +#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) + 
+#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) + +#define FZP_GET_IF(W_TYPE) \ + FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \ + FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \ + FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \ + FZP_GET_IF_M234(W_TYPE, 8, 4, 128) + +// We currently have 4-bit models only with group_blocks == 4 +#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) + +#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) + +#define ACT_GET_IF(W_TYPE) \ + ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \ + ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \ + ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \ + ACT_GET_IF_M234(W_TYPE, 8, 4, 128) + +template +MarlinFuncPtr get_marlin_kernel( + const host::ScalarType q_type, + int thread_m_blocks, + int thread_n_blocks, + int thread_k_blocks, + bool m_block_size_8, + bool has_act_order, + bool has_zp, + int group_blocks, + int num_threads, + bool is_zp_float) { + int num_bits = q_type.size_bits(); + auto kernel = MarlinDefault; + if (false) { + } + + COMMON_GET_IF(host::kU4) + COMMON_GET_IF(host::kU4B8) + COMMON_GET_IF(host::kU8B128) + + NVFP4_GET_IF(host::kFE2M1f) + + BIGGROUP_GET_IF(host::kFE4M3fn) + + ACT_GET_IF(host::kU4B8) + ACT_GET_IF(host::kU8B128) + if (std::is_same::value) { + if (false) { + } + MXFP4_GET_IF(host::kFE2M1f) + } + + return kernel; +} + +template +exec_config_t determine_exec_config( + const host::ScalarType& q_type, + int prob_m, + int prob_n, + int 
prob_k, + int thread_m_blocks, + bool m_block_size_8, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full, + bool has_zp, + bool is_zp_float, + int max_shared_mem) { + exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}}; + thread_config_t* thread_configs = thread_m_blocks > 1 ? large_batch_thread_configs : small_batch_thread_configs; + int thread_configs_size = thread_m_blocks > 1 ? sizeof(large_batch_thread_configs) / sizeof(thread_config_t) + : sizeof(small_batch_thread_configs) / sizeof(thread_config_t); + + int count = 0; + constexpr int device_max_reg_size = 255 * 1024; + for (int i = 0; i < thread_configs_size; i++) { + thread_config_t th_config = thread_configs[i]; + + if (!is_valid_config( + th_config, + m_block_size_8, + thread_m_blocks, + prob_m, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float, + max_shared_mem)) { + continue; + } + + int cache_size = get_kernel_cache_size( + th_config, + m_block_size_8, + thread_m_blocks, + prob_m, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float); + + int group_blocks = 0; + if (!has_act_order) { + group_blocks = group_size == -1 ? 
-1 : (group_size / 16); + } + + auto kernel = get_marlin_kernel( + q_type, + thread_m_blocks, + th_config.thread_n / 16, + th_config.thread_k / 16, + m_block_size_8, + has_act_order, + has_zp, + group_blocks, + th_config.num_threads, + is_zp_float); + + if (kernel == MarlinDefault) continue; + + if (thread_m_blocks > 1) { + exec_cfg = {1, th_config}; + break; + } else { + cudaFuncAttributes attr; + cudaFuncGetAttributes(&attr, kernel); + int reg_size = max(attr.numRegs, 1) * th_config.num_threads * 4; + int allow_count = min(device_max_reg_size / reg_size, max_shared_mem / (cache_size + 1024)); + allow_count = max(min(allow_count, 4), 1); + if (allow_count > count) { + count = allow_count; + exec_cfg = {count, th_config}; + }; + } + } + + return exec_cfg; +} + +template +void marlin_mm( + const void* A, + const void* B, + void* C, + void* C_tmp, + void* b_bias, + void* s, + void* s2, + void* zp, + void* g_idx, + void* perm, + void* a_tmp, + void* sorted_token_ids, + void* expert_ids, + void* num_tokens_past_padded, + void* topk_weights, + int moe_block_size, + int top_k, + bool mul_topk_weights, + bool is_ep, + int prob_m, + int prob_n, + int prob_k, + void* workspace, + host::ScalarType const& q_type, + bool has_bias, + bool has_act_order, + bool is_k_full, + bool has_zp, + int num_groups, + int group_size, + int dev, + cudaStream_t stream, + int thread_k, + int thread_n, + int sms, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float) { + int thread_m_blocks = div_ceil(moe_block_size, 16); + bool m_block_size_8 = moe_block_size == 8; + + if (has_zp) { + host::RuntimeCheck( + q_type == host::kU4 || q_type == host::kU8, "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); + } else { + host::RuntimeCheck( + q_type == host::kU4B8 || q_type == host::kU8B128 || q_type == host::kFE4M3fn || q_type == host::kFE2M1f, + "q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " + "has_zp = False. 
Got = ", + q_type.str()); + } + + host::RuntimeCheck( + prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); + + int group_blocks = 0; + if (has_act_order) { + if (is_k_full) { + host::RuntimeCheck(group_size != -1); + group_blocks = group_size / 16; + host::RuntimeCheck( + prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks); + } else { + host::RuntimeCheck(group_size == 0); + group_blocks = 0; + } + } else { + if (group_size == -1) { + group_blocks = -1; + } else { + group_blocks = group_size / 16; + host::RuntimeCheck( + prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks); + } + } + + int num_bits = q_type.size_bits(); + const int4* A_ptr = (const int4*)A; + const int4* B_ptr = (const int4*)B; + int4* C_ptr = (int4*)C; + int4* C_tmp_ptr = (int4*)C_tmp; + const int4* bias_ptr = (const int4*)b_bias; + const int4* s_ptr = (const int4*)s; + const uint16_t* s2_ptr = (const uint16_t*)s2; + const int4* zp_ptr = (const int4*)zp; + const int* g_idx_ptr = (const int*)g_idx; + const int* perm_ptr = (const int*)perm; + int4* a_tmp_ptr = (int4*)a_tmp; + const int32_t* sorted_token_ids_ptr = (const int32_t*)sorted_token_ids; + const int32_t* expert_ids_ptr = (const int32_t*)expert_ids; + const int32_t* num_tokens_past_padded_ptr = (const int32_t*)num_tokens_past_padded; + const float* topk_weights_ptr = (const float*)topk_weights; + int* locks = (int*)workspace; + + if (has_act_order) { + // Permute A columns + auto perm_kernel = permute_cols_kernel<8>; + if (moe_block_size == 8) { + } else if (moe_block_size == 16) + perm_kernel = permute_cols_kernel<16>; + else if (moe_block_size == 32) + perm_kernel = permute_cols_kernel<32>; + else if (moe_block_size == 48) + perm_kernel = permute_cols_kernel<48>; + else if (moe_block_size == 64) + perm_kernel = permute_cols_kernel<64>; + else + host::Panic("unsupported moe_block_size ", 
moe_block_size); + + // clang-format off + perm_kernel<<>>( + A_ptr, perm_ptr, a_tmp_ptr, sorted_token_ids_ptr, expert_ids_ptr, + num_tokens_past_padded_ptr, prob_m, prob_k, top_k); + // clang-format on + A_ptr = a_tmp_ptr; + prob_m = prob_m * top_k; + top_k = 1; + + // If we have a full K, then we can run the non-act-order version of Marlin + // (since the weight rows are reordered by increasing group ids, and by + // having a full K, we have full original groups) + if (is_k_full) has_act_order = false; + } + + int max_shared_mem = 0; + host::RuntimeDeviceCheck(cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev)); + host::RuntimeCheck(max_shared_mem > 0); + + // Set thread config + exec_config_t exec_cfg; + thread_config_t thread_tfg; + if (thread_k != -1 && thread_n != -1) { + thread_tfg = thread_config_t{thread_k, thread_n, default_threads}; + exec_cfg = exec_config_t{1, thread_tfg}; + host::RuntimeCheck(prob_n % thread_n == 0, "prob_n = ", prob_n, " is not divisible by thread_n = ", thread_n); + host::RuntimeCheck(prob_k % thread_k == 0, "prob_k = ", prob_k, " is not divisible by thread_k = ", thread_k); + } else { + // Auto config + exec_cfg = determine_exec_config( + q_type, + prob_m, + prob_n, + prob_k, + thread_m_blocks, + m_block_size_8, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float, + max_shared_mem); + thread_tfg = exec_cfg.tb_cfg; + } + + int num_threads = thread_tfg.num_threads; + thread_k = thread_tfg.thread_k; + thread_n = thread_tfg.thread_n; + int blocks = sms * exec_cfg.blocks_per_sm; + if (exec_cfg.blocks_per_sm > 1) max_shared_mem = max_shared_mem / exec_cfg.blocks_per_sm - 1024; + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + + host::RuntimeCheck( + is_valid_config( + thread_tfg, + m_block_size_8, + thread_m_blocks, + prob_m, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float, + max_shared_mem), 
+ "Invalid thread config: thread_m_blocks = ", + thread_m_blocks, + ", thread_k = ", + thread_tfg.thread_k, + ", thread_n = ", + thread_tfg.thread_n, + ", num_threads = ", + thread_tfg.num_threads, + " for MKN = [", + prob_m, + ", ", + prob_k, + ", ", + prob_n, + "] and num_bits = ", + num_bits, + ", group_size = ", + group_size, + ", has_act_order = ", + has_act_order, + ", is_k_full = ", + is_k_full, + ", has_zp = ", + has_zp, + ", is_zp_float = ", + is_zp_float, + ", max_shared_mem = ", + max_shared_mem); + + auto kernel = get_marlin_kernel( + q_type, + thread_m_blocks, + thread_n_blocks, + thread_k_blocks, + m_block_size_8, + has_act_order, + has_zp, + group_blocks, + num_threads, + is_zp_float); + + if (kernel == MarlinDefault) { + host::Panic( + "Unsupported shapes: MNK = [", + prob_m, + ", ", + prob_n, + ", ", + prob_k, + "]", + ", has_act_order = ", + has_act_order, + ", num_groups = ", + num_groups, + ", group_size = ", + group_size, + ", thread_m_blocks = ", + thread_m_blocks, + ", thread_n_blocks = ", + thread_n_blocks, + ", thread_k_blocks = ", + thread_k_blocks, + ", num_bits = ", + num_bits); + } + + host::RuntimeDeviceCheck(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem)); + // clang-format off + kernel<<>>( + A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, + sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr, + topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m, + prob_n, prob_k, locks, has_bias, use_atomic_add, use_fp32_reduce, max_shared_mem); + // clang-format on +} + +#endif + +} // namespace device::marlin_moe + +template +void moe_wna16_marlin_gemm( + tvm::ffi::TensorView a, + tvm::ffi::TensorView c, + tvm::ffi::TensorView b_q_weight, + tvm::ffi::TensorView b_bias, + tvm::ffi::TensorView b_scales, + tvm::ffi::TensorView global_scale, + tvm::ffi::TensorView b_zeros, + tvm::ffi::TensorView g_idx, + tvm::ffi::TensorView perm, + 
tvm::ffi::TensorView workspace, + tvm::ffi::TensorView sorted_token_ids, + tvm::ffi::TensorView expert_ids, + tvm::ffi::TensorView num_tokens_post_padded, + tvm::ffi::TensorView topk_weights, + tvm::ffi::TensorView a_tmp, + tvm::ffi::TensorView c_tmp, + int64_t moe_block_size, + int64_t top_k, + bool mul_topk_weights, + bool is_ep, + int64_t b_q_type_id, + int64_t size_m, + int64_t size_n, + int64_t size_k, + bool has_act_order, + bool has_bias, + bool is_k_full, + bool has_zp, + int64_t num_groups, + int64_t group_size, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float) { + using namespace host; + + ScalarType const b_q_type = ScalarType::from_id(b_q_type_id); + int pack_factor = 32 / b_q_type.size_bits(); + + if (moe_block_size != 8) { + RuntimeCheck(moe_block_size % 16 == 0, "unsupported moe_block_size=", moe_block_size); + RuntimeCheck(moe_block_size >= 16 && moe_block_size <= 64, "unsupported moe_block_size=", moe_block_size); + } + + // Verify A + RuntimeCheck(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), ", size_m = ", size_m); + RuntimeCheck(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1), ", size_k = ", size_k); + + // Verify B + RuntimeCheck( + size_k % device::marlin::tile_size == 0, + "size_k = ", + size_k, + " is not divisible by tile_size = ", + device::marlin::tile_size); + RuntimeCheck( + (size_k / device::marlin::tile_size) == b_q_weight.size(1), + "Shape mismatch: b_q_weight.size(1) = ", + b_q_weight.size(1), + ", size_k = ", + size_k, + ", tile_size = ", + device::marlin::tile_size); + RuntimeCheck( + b_q_weight.size(2) % device::marlin::tile_size == 0, + "b_q_weight.size(2) = ", + b_q_weight.size(2), + " is not divisible by tile_size = ", + device::marlin::tile_size); + int64_t actual_size_n = (b_q_weight.size(2) / device::marlin::tile_size) * pack_factor; + RuntimeCheck(size_n == actual_size_n, "size_n = ", size_n, ", actual_size_n = ", actual_size_n); + + // Verify device and strides + auto 
device = SymbolicDevice{}; + device.set_options(); + TensorMatcher({-1, -1}).with_dtype().with_device(device).verify(a); + + device.verify(b_q_weight.device()); + RuntimeCheck(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); + + device.verify(b_scales.device()); + RuntimeCheck(b_scales.is_contiguous(), "b_scales is not contiguous"); + + // thread_k, thread_n, sms + int thread_k = -1; + int thread_n = -1; + int sms = -1; + DLDevice dl_device = device.unwrap(); + int dev = dl_device.device_id; + cudaStream_t stream = LaunchKernel::resolve_device(dl_device); + RuntimeDeviceCheck(cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev)); + + // Verify c (allocation done in Python) + device.verify(c.device()); + RuntimeCheck(c.is_contiguous(), "c is not contiguous"); + RuntimeCheck( + c.size(0) == size_m * top_k, "Shape mismatch: c.size(0) = ", c.size(0), ", size_m * topk = ", size_m * top_k); + RuntimeCheck(c.size(1) == size_n, "Shape mismatch: c.size(1) = ", c.size(1), ", size_n = ", size_n); + + // Alloc c_tmp: SKIP, done in Python + + // Detect groupsize: b_scales rank and dims + RuntimeCheck(b_scales.dim() == 3, "b_scales rank = ", b_scales.dim(), " is not 3"); + RuntimeCheck(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2), " is not size_n = ", size_n); + RuntimeCheck( + b_scales.size(1) == num_groups, "b_scales dim 1 = ", b_scales.size(1), " is not num_groups = ", num_groups); + + // Validate g_idx, perm (Optional unwrap done in Python; empty tensors when absent) + if (g_idx.size(g_idx.dim() - 1) > 0 && perm.size(perm.dim() - 1) > 0) { + device.verify(g_idx.device()); + RuntimeCheck(g_idx.is_contiguous(), "g_idx is not contiguous"); + device.verify(perm.device()); + RuntimeCheck(perm.is_contiguous(), "perm is not contiguous"); + + int64_t g_idx_last = g_idx.size(g_idx.dim() - 1); + int64_t perm_last = perm.size(perm.dim() - 1); + RuntimeCheck( + (g_idx_last == 0 && perm_last == 0) || (g_idx_last == size_k && perm_last == 
size_k), + "Unexpected g_idx.size(-1) = ", + g_idx_last, + " and perm.size(-1) = ", + perm_last, + ", where size_k = ", + size_k); + } + // has_act_order derivation: SKIP (passed as param) + + // Verify group_size consistency + if (has_act_order) { + // SKIP: a_tmp allocation done in Python + if (is_k_full) { + RuntimeCheck(num_groups > 1, "For act_order, num_groups must be > 1"); + RuntimeCheck(size_k % num_groups == 0, "size_k = ", size_k, ", is not divisible by num_groups = ", num_groups); + } + } else { + if (num_groups > 1) { + RuntimeCheck( + size_k % num_groups == 0, "size_k = ", size_k, ", is not divisible by b_scales.size(1) = ", num_groups); + } + } + + // Verify global_scale (Optional unwrap done in Python) + int64_t global_scale_size = global_scale.size(0); + if (global_scale_size > 0) { + RuntimeCheck(b_q_type == kFE2M1f && group_size == 16, "global_scale can only be used for nvfp4 format."); + } else { + RuntimeCheck( + !(b_q_type == kFE2M1f && group_size == 16), "the global_scale parameter must be passed for nvfp4 format."); + } + + // Verify b_bias (Optional unwrap done in Python) + if (has_bias) { + device.verify(b_bias.device()); + RuntimeCheck(b_bias.is_contiguous(), "b_bias is not contiguous"); + RuntimeCheck(b_bias.size(1) == size_n, "b_bias.size(0) != size_n"); + RuntimeCheck(b_bias.stride(1) == 1, "b_bias.stride(1) != 1"); + } + + // b_zeros Optional unwrap + has_zp derivation: SKIP (done in Python) + + // Verify b_q_type vs has_zp + if (has_zp) { + device.verify(b_zeros.device()); + RuntimeCheck(b_zeros.is_contiguous(), "b_zeros is not contiguous"); + RuntimeCheck( + b_q_type == kU4 || b_q_type == kU8, "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); + } else { + RuntimeCheck( + b_q_type == kU4B8 || b_q_type == kU8B128 || b_q_type == kFE4M3fn || b_q_type == kFE2M1f, + "b_q_type must be uint4b8, uint8b128, float8_e4m3fn or " + "float4_e2m1f when " + "has_zp = False. 
Got = ", + b_q_type.str()); + } + + if (has_zp && is_zp_float) { + RuntimeCheck( + std::is_same::value, + "Computation type must be float16 (half) when using float zero " + "points."); + } + + // Verify b_zeros + if (has_zp) { + RuntimeCheck(b_zeros.dim() == 3, "b_zeros rank = ", b_zeros.dim(), " is not 3"); + if (is_zp_float) { + RuntimeCheck(b_zeros.size(2) == size_n, "b_zeros dim 2 = ", b_zeros.size(2), " is not size_n = ", size_n); + RuntimeCheck( + num_groups == b_zeros.size(1), "b_zeros dim 1 = ", b_zeros.size(1), " is not num_groups = ", num_groups); + RuntimeCheck(num_groups != -1, "num_groups must be != -1"); + } else { + RuntimeCheck( + b_zeros.size(1) == num_groups, "b_zeros dim 1 = ", b_zeros.size(1), " is not num_groups = ", num_groups); + RuntimeCheck( + b_zeros.size(2) == size_n / pack_factor, + "b_zeros dim 2 = ", + b_zeros.size(2), + " is not size_n / pack_factor = ", + size_n / pack_factor); + } + } + + // Verify workspace size + RuntimeCheck( + size_n % device::marlin::min_thread_n == 0, + "size_n = ", + size_n, + ", is not divisible by min_thread_n = ", + device::marlin::min_thread_n); + + int64_t max_n_tiles = size_n / device::marlin::min_thread_n; + int64_t min_workspace_size = + std::min(max_n_tiles * (sorted_token_ids.size(0) / moe_block_size), static_cast(sms) * 4); + RuntimeCheck( + workspace.size(0) >= min_workspace_size, + "workspace.numel = ", + workspace.size(0), + " is below min_workspace_size = ", + min_workspace_size); + + // Early return for zero-size M (moved after all validation) + if (size_m == 0) return; + + device::marlin_moe::marlin_mm( + a.data_ptr(), + b_q_weight.data_ptr(), + c.data_ptr(), + c_tmp.data_ptr(), + b_bias.data_ptr(), + b_scales.data_ptr(), + global_scale.data_ptr(), + b_zeros.data_ptr(), + g_idx.data_ptr(), + perm.data_ptr(), + a_tmp.data_ptr(), + sorted_token_ids.data_ptr(), + expert_ids.data_ptr(), + num_tokens_post_padded.data_ptr(), + topk_weights.data_ptr(), + static_cast(moe_block_size), + 
static_cast(top_k), + mul_topk_weights, + is_ep, + static_cast(size_m), + static_cast(size_n), + static_cast(size_k), + workspace.data_ptr(), + b_q_type, + has_bias, + has_act_order, + is_k_full, + has_zp, + static_cast(num_groups), + static_cast(group_size), + dev, + stream, + thread_k, + thread_n, + sms, + use_atomic_add, + use_fp32_reduce, + is_zp_float); +} diff --git a/sglang/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_expert_quant.cuh b/sglang/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_expert_quant.cuh new file mode 100644 index 0000000000000000000000000000000000000000..f76936782090b21b2219c634a18b1509c76a9b03 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_expert_quant.cuh @@ -0,0 +1,712 @@ +#include +#include + +#include +#include +#include + +#include "nvfp4_quant.cuh" +#include +#include + +using namespace host; + +// Quantizes the provided PackedVec into the uint32_t output +template +SGL_DEVICE uint32_t cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, uint8_t* SFout) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + // Get absolute maximum values among the local 8 values. + auto localMax = __habs2(vec.elts[0]); + +// Local maximum value. +#pragma unroll + for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + localMax = __hmax2(localMax, __habs2(vec.elts[i])); + } + + // Get the absolute maximum among all 16 values (two threads). + localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); + // Get the final absolute maximum values. + float vecMax = float(__hmax(localMax.x, localMax.y)); + + // Get the SF (max value of the vector / max value of e2m1). + // maximum value of e2m1 = 6.0. + // TODO: use half as compute data type. + float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); + // 8 bits representation of the SF. + uint8_t fp8SFVal; + // Write the SF to global memory (STG.8). + if constexpr (UE8M0_SF) { + // Extract the 8 exponent bits from float32. 
+ // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits. + uint32_t tmp = reinterpret_cast(SFValue) >> 23; + fp8SFVal = tmp & 0xff; + // Convert back to fp32. + reinterpret_cast(SFValue) = tmp << 23; + } else { + // Here SFValue is always positive, so E4M3 is the same as UE4M3. + __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); + reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp; + // Convert back to fp32. + SFValue = float(tmp); + } + // Get the output scale. + // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * + // reciprocal(SFScaleVal)) + float outputScale = + SFValue != 0 ? reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal)) : 0.0f; + + if (SFout) { + // Write the SF to global memory (STG.8). + *SFout = fp8SFVal; + } + + // Convert the input to float. + float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; + +#pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + fp2Vals[i] = device::cast(vec.elts[i]); + fp2Vals[i].x *= outputScale; + fp2Vals[i].y *= outputScale; + } + + // Convert to e2m1 values. + uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals); + + // Write the e2m1 values to global memory. + return e2m1Vec; +#else + return 0; +#endif +} + +SGL_DEVICE float silu(const float& val) { + return val / (1.0f + __expf(-val)); +} + +template +SGL_DEVICE void silu_and_mul(PackedVec& x_vec, const PackedVec& y_vec) { + float2 x[CVT_FP4_ELTS_PER_THREAD / 2]; + float2 y[CVT_FP4_ELTS_PER_THREAD / 2]; + +#pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + x[i] = device::cast(x_vec.elts[i]); + y[i] = device::cast(y_vec.elts[i]); + x[i].x = silu(x[i].x) * y[i].x; + x[i].y = silu(x[i].y) * y[i].y; + x_vec.elts[i] = device::cast>(x[i]); + } +} + +// Use UE4M3 by default. 
+template +__global__ void +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__launch_bounds__(512, 4) cvt_fp16_to_fp4( +#else +cvt_fp16_to_fp4( +#endif + int32_t numRows, + int32_t numCols, + Type const* in, + float const* SFScale, + uint32_t* out, + uint32_t* SFout, + uint32_t* input_offset_by_experts, + uint32_t* output_scale_offset_by_experts, + int32_t* mask, + int n_experts, + bool low_latency) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + using PackedVec = PackedVec; + static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); + static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched."); + + // Input tensor row/col loops. + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD; + // TODO(kaixih@nvidia): For now, we assume mask is used together with + // silu_and_mal. Maybe we want a more general behavior of mask later. In the + // silu case, the input last dim doubles. + bool use_mask = mask != nullptr; + int actualColsPerRow = use_mask ? 
colsPerRow * 2 : colsPerRow; + + // Each global thread processes one element + for (int globalIdx = tid; globalIdx < numRows * colsPerRow; globalIdx += gridDim.x * blockDim.x) { + // Calculate which row and column this global thread should process + int rowIdx = globalIdx / colsPerRow; + int colIdx = globalIdx % colsPerRow; + + // Find index within the experts using different strategies based on expert + // count + int rowIdx_in_expert = 0; + int expert_idx = 0; + + if constexpr (SMALL_NUM_EXPERTS) { + for (int i = 0; i < n_experts; i++) { + uint32_t current_offset = __ldca(&input_offset_by_experts[i]); + uint32_t next_offset = __ldca(&input_offset_by_experts[i + 1]); + if (rowIdx >= current_offset && rowIdx < next_offset) { + rowIdx_in_expert = rowIdx - current_offset; + expert_idx = i; + break; + } + } + } else { + // Load input offsets into registers first, then do the computation. + // Local array size set to 17 because of register limit. + uint32_t local_offsets[17]; + for (int chunk_start = 0; chunk_start < n_experts; chunk_start += 16) { + *reinterpret_cast(local_offsets) = + __ldca(reinterpret_cast(&input_offset_by_experts[chunk_start])); + *reinterpret_cast(local_offsets + 4) = + __ldca(reinterpret_cast(&input_offset_by_experts[chunk_start + 4])); + *reinterpret_cast(local_offsets + 8) = + __ldca(reinterpret_cast(&input_offset_by_experts[chunk_start + 8])); + *reinterpret_cast(local_offsets + 12) = + __ldca(reinterpret_cast(&input_offset_by_experts[chunk_start + 12])); + local_offsets[16] = __ldca(&input_offset_by_experts[chunk_start + 16]); + +// Check against the 16 loaded offsets +#pragma unroll + for (int i = 0; i < 16; i++) { + if (rowIdx >= local_offsets[i] && rowIdx < local_offsets[i + 1]) { + rowIdx_in_expert = rowIdx - local_offsets[i]; + expert_idx = chunk_start + i; + break; + } + } + } + } + + // Early exit when using masks. 
+ if (use_mask && rowIdx_in_expert >= mask[expert_idx]) { + continue; + } + + int64_t inOffset = rowIdx * actualColsPerRow + colIdx; + PackedVec in_vec = reinterpret_cast(in)[inOffset]; + if (use_mask) { + PackedVec in_vec_mul = reinterpret_cast(in)[inOffset + colsPerRow]; + silu_and_mul(in_vec, in_vec_mul); + } + + // Get the output tensor offset. + // Same as inOffset because 8 elements are packed into one uint32_t. + int64_t outOffset = rowIdx * colsPerRow + colIdx; + auto& out_pos = out[outOffset]; + + // Get the global scaling factor, which will be applied to the SF. + // Note SFScale is the same as next GEMM's alpha, which is + // (448.f / (Alpha_A / 6.f)). + float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx]; + + int factor = CVT_FP4_SF_VEC_SIZE * 4; + // The actual output_scales dim is computed from the padded numCols. + int32_t numCols_padded = (numCols + factor - 1) / factor * factor; + int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4; + uint32_t* SFout_in_expert = SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout; + + auto sf_out = cvt_quant_to_fp4_get_sf_out_offset( + rowIdx_in_expert, colIdx, numCols, SFout_in_expert); + + out_pos = cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out); + } +#endif +} + +// Use UE4M3 by default. +template +__global__ void +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__launch_bounds__(512, 4) cvt_fp16_to_fp4_expert( +#else +cvt_fp16_to_fp4_expert( +#endif + int32_t numRows, + int32_t numCols, + Type const* in, + float const* SFScale, + uint32_t* out, + uint32_t* SFout, + int32_t* mask, + bool use_silu_and_mul, + int n_experts) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + using PackedVec = PackedVec; + static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); + static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched."); + + // Input tensor row/col loops. 
+ int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = (gridDim.x * blockDim.x) / n_experts; + int remainder = (gridDim.x * blockDim.x) % n_experts; + int expert_idx; + int tid_in_expert; + int actual_stride; + if (remainder > 0) { + int bound = remainder * (stride + 1); + if (tid < bound) { + expert_idx = tid / (stride + 1); + tid_in_expert = tid % (stride + 1); + actual_stride = stride + 1; + } else { + expert_idx = remainder + (tid - bound) / stride; + tid_in_expert = (tid - bound) % stride; + actual_stride = stride; + } + } else { + expert_idx = tid / stride; + tid_in_expert = tid % stride; + actual_stride = stride; + } + int m = numRows / n_experts; + int padded_m = (m + (128 - 1)) / 128 * 128; + + int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD; + // TODO(kaixih@nvidia): For now, we assume mask is used together with + // silu_and_mal. Maybe we want a more general behavior of mask later. In the + // silu case, the input last dim doubles. + bool use_mask = mask != nullptr; + int actualColsPerRow = use_silu_and_mul ? colsPerRow * 2 : colsPerRow; + + // Each global thread processes one element + for (int globalIdx = tid_in_expert + expert_idx * m * colsPerRow; globalIdx < (expert_idx + 1) * m * colsPerRow; + globalIdx += actual_stride) { + // Calculate which row and column this global thread should process + int rowIdx = globalIdx / colsPerRow; + int colIdx = globalIdx % colsPerRow; + + // Find index within the experts + int rowIdx_in_expert = rowIdx - expert_idx * m; + + // Early exit when using masks. + if (use_mask && rowIdx_in_expert >= mask[expert_idx]) { + break; + } + + int64_t inOffset = rowIdx * actualColsPerRow + colIdx; + PackedVec in_vec = reinterpret_cast(in)[inOffset]; + if (use_silu_and_mul) { + PackedVec in_vec_mul = reinterpret_cast(in)[inOffset + colsPerRow]; + silu_and_mul(in_vec, in_vec_mul); + } + + // Get the output tensor offset. + // Same as inOffset because 8 elements are packed into one uint32_t. 
+ int64_t outOffset = rowIdx * colsPerRow + colIdx; + auto& out_pos = out[outOffset]; + + // Get the global scaling factor, which will be applied to the SF. + // Note SFScale is the same as next GEMM's alpha, which is + // (448.f / (Alpha_A / 6.f)). + float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx]; + + int factor = CVT_FP4_SF_VEC_SIZE * 4; + // The actual output_scales dim is computed from the padded numCols. + int32_t numCols_padded = (numCols + factor - 1) / factor * factor; + int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4; + uint32_t* SFout_in_expert = SFout + expert_idx * padded_m * numCols_SFout; + + auto sf_out = cvt_quant_to_fp4_get_sf_out_offset( + rowIdx_in_expert, colIdx, numCols, SFout_in_expert); + + out_pos = cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out); + } +#endif +} + +// Kernel for LARGE_M_TOPK = true (large m_topk optimized version) +template +__global__ void +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__launch_bounds__(1024, 4) cvt_fp16_to_fp4( +#else +cvt_fp16_to_fp4( +#endif + int32_t numRows, + int32_t numCols, + Type const* in, + float const* SFScale, + uint32_t* out, + uint32_t* SFout, + uint32_t* input_offset_by_experts, + uint32_t* output_scale_offset_by_experts, + int32_t* mask, + int n_experts) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + using PackedVec = PackedVec; + static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); + static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched."); + extern __shared__ uint32_t shared_input_offsets[]; + + // Load input offsets into shared memory. + // If n_experts is larger than 4, use vectorized int4 to save instructions. + // If n_experts is smaller than 4, read directly. 
+ if constexpr (SMALL_NUM_EXPERTS) { + for (int i = threadIdx.x; i < n_experts + 1; i += blockDim.x) { + shared_input_offsets[i] = input_offset_by_experts[i]; + } + } else { + for (int i = threadIdx.x * 4; i < n_experts; i += blockDim.x * 4) { + *reinterpret_cast(&shared_input_offsets[i]) = *reinterpret_cast(&input_offset_by_experts[i]); + } + if (threadIdx.x == 0) { + shared_input_offsets[n_experts] = input_offset_by_experts[n_experts]; + } + } + + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD; + bool use_mask = mask != nullptr; + int actualColsPerRow = use_mask ? colsPerRow * 2 : colsPerRow; + + // Each global thread processes one element + for (int globalIdx = tid; globalIdx < numRows * colsPerRow; globalIdx += gridDim.x * blockDim.x) { + // Calculate which row and column this global thread should process + int rowIdx = globalIdx / colsPerRow; + int colIdx = globalIdx % colsPerRow; + + // Find expert using binary search for better performance with large m_topk + int rowIdx_in_expert = 0; + int expert_idx = 0; + + // Binary search through experts using shared memory + int left = 0, right = n_experts - 1; + while (left <= right) { + int mid = (left + right) / 2; + // Get offsets: shared_input_offsets[i] corresponds to + // input_offset_by_experts[i] + uint32_t mid_offset = shared_input_offsets[mid]; + uint32_t next_offset = shared_input_offsets[mid + 1]; + + if (rowIdx >= mid_offset && rowIdx < next_offset) { + rowIdx_in_expert = rowIdx - mid_offset; + expert_idx = mid; + break; + } else if (rowIdx < mid_offset) { + right = mid - 1; + } else { + left = mid + 1; + } + } + + if (use_mask && rowIdx_in_expert >= mask[expert_idx]) { + continue; + } + + int64_t inOffset = rowIdx * actualColsPerRow + colIdx; + + PackedVec in_vec = reinterpret_cast(in)[inOffset]; + if (use_mask) { + PackedVec in_vec_mul = reinterpret_cast(in)[inOffset + colsPerRow]; + silu_and_mul(in_vec, in_vec_mul); + } + + 
int64_t outOffset = rowIdx * colsPerRow + colIdx; + auto& out_pos = out[outOffset]; + + float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx]; + + int factor = CVT_FP4_SF_VEC_SIZE * 4; + int32_t numCols_padded = (numCols + factor - 1) / factor * factor; + int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4; + uint32_t* SFout_in_expert = SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout; + + auto sf_out = cvt_quant_to_fp4_get_sf_out_offset( + rowIdx_in_expert, colIdx, numCols, SFout_in_expert); + + out_pos = cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out); + } +#endif +} + +template +void quant_impl( + void* output, + void* output_scale, + void* input, + void* input_global_scale, + void* input_offset_by_experts, + void* output_scale_offset_by_experts, + void* mask, + bool use_silu_and_mul, + int m_topk, + int k, + int n_experts, + cudaStream_t stream) { + // TODO: this multiProcessorCount should be cached. + int device; + cudaGetDevice(&device); + int multiProcessorCount; + cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount, device); + + // Grid, Block size. + // Each thread converts 8 values. + int const workSizePerRow = k / ELTS_PER_THREAD; + int const totalWorkSize = m_topk * workSizePerRow; + dim3 block(std::min(workSizePerRow, 512)); + // Get number of blocks per SM (assume we can fully utilize the SM). + int const numBlocksPerSM = 2048 / block.x; + dim3 grid(std::min(static_cast((totalWorkSize + block.x - 1) / block.x), multiProcessorCount * numBlocksPerSM)); + while (grid.x <= multiProcessorCount && block.x > 64) { + grid.x *= 2; + block.x = (block.x + 1) / 2; + } + + // TODO(kaixih@nvidia): Should relax this to allow any grid size. 
+ if (mask != nullptr) { + grid.x = (grid.x + n_experts - 1) / n_experts * n_experts; + cvt_fp16_to_fp4_expert<<>>( + m_topk, + k, + reinterpret_cast(input), + reinterpret_cast(input_global_scale), + reinterpret_cast(output), + reinterpret_cast(output_scale), + reinterpret_cast(mask), + use_silu_and_mul, + n_experts); + return; + } + + int const blockRepeat = (totalWorkSize + block.x * grid.x - 1) / (block.x * grid.x); + if (blockRepeat > 1) { + size_t shared_mem_size = (n_experts + 1) * sizeof(uint32_t); + if (n_experts >= 4) { + cvt_fp16_to_fp4<<>>( + m_topk, + k, + reinterpret_cast(input), + reinterpret_cast(input_global_scale), + reinterpret_cast(output), + reinterpret_cast(output_scale), + reinterpret_cast(input_offset_by_experts), + reinterpret_cast(output_scale_offset_by_experts), + reinterpret_cast(mask), + n_experts); + } else { + cvt_fp16_to_fp4<<>>( + m_topk, + k, + reinterpret_cast(input), + reinterpret_cast(input_global_scale), + reinterpret_cast(output), + reinterpret_cast(output_scale), + reinterpret_cast(input_offset_by_experts), + reinterpret_cast(output_scale_offset_by_experts), + reinterpret_cast(mask), + n_experts); + } + } else { + if (n_experts >= 16) { + cvt_fp16_to_fp4<<>>( + m_topk, + k, + reinterpret_cast(input), + reinterpret_cast(input_global_scale), + reinterpret_cast(output), + reinterpret_cast(output_scale), + reinterpret_cast(input_offset_by_experts), + reinterpret_cast(output_scale_offset_by_experts), + reinterpret_cast(mask), + n_experts, + /* bool low_latency */ true); + } else { + cvt_fp16_to_fp4<<>>( + m_topk, + k, + reinterpret_cast(input), + reinterpret_cast(input_global_scale), + reinterpret_cast(output), + reinterpret_cast(output_scale), + reinterpret_cast(input_offset_by_experts), + reinterpret_cast(output_scale_offset_by_experts), + reinterpret_cast(mask), + n_experts, + /* bool low_latency */ true); + } + } +} + +inline int getSMVersion(int device_id) { + int sm_major = 0; + int sm_minor = 0; + 
RuntimeDeviceCheck(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device_id)); + RuntimeDeviceCheck(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device_id)); + return sm_major * 10 + sm_minor; +} + +void scaled_fp4_experts_quant_sm100a( + tvm::ffi::TensorView output, + tvm::ffi::TensorView output_scale, + tvm::ffi::TensorView input, + tvm::ffi::TensorView input_global_scale, + tvm::ffi::TensorView input_offset_by_experts, + tvm::ffi::TensorView output_scale_offset_by_experts) { + auto MTopK = SymbolicSize{"m_topk"}; + auto K = SymbolicSize{"k"}; + auto OutputCols = SymbolicSize{"output_cols"}; + auto OutputScaleRows = SymbolicSize{"output_scale_rows"}; + auto OutputScaleCols = SymbolicSize{"output_scale_cols"}; + auto NExperts = SymbolicSize{"n_experts"}; + auto OffsetSize = SymbolicSize{"offset_size"}; + auto device = SymbolicDevice{}; + + TensorMatcher({MTopK, K}) // + .with_dtype() + .template with_device(device) + .verify(input); + TensorMatcher({MTopK, OutputCols}) // + .with_dtype() + .with_device(device) + .verify(output); + TensorMatcher({OutputScaleRows, OutputScaleCols}) // + .with_dtype() + .with_device(device) + .verify(output_scale); + TensorMatcher({NExperts}) // + .with_dtype() + .with_device(device) + .verify(input_global_scale); + TensorMatcher({OffsetSize}) // + .with_dtype() + .with_device(device) + .verify(input_offset_by_experts) + .verify(output_scale_offset_by_experts); + + const int device_id = input.device().device_id; + RuntimeCheck(getSMVersion(device_id) >= 100, "fp4_quant is only supported on sm100+"); + + const int BLOCK_SIZE = 16; + const auto m_topk = static_cast(MTopK.unwrap()); + const auto k = static_cast(K.unwrap()); + RuntimeCheck(k % BLOCK_SIZE == 0, "k must be a multiple of 16"); + const auto n_experts = static_cast(NExperts.unwrap()); + const auto offset_size = static_cast(OffsetSize.unwrap()); + RuntimeCheck(offset_size == n_experts + 1, "input/output offset size mismatch"); + 
RuntimeCheck(static_cast(OutputCols.unwrap()) == k / 2, "output second dim mismatch"); + const int scales_k = k / BLOCK_SIZE; + const int padded_k = (scales_k + 3) / 4 * 4; + RuntimeCheck(static_cast(OutputScaleCols.unwrap()) * 4 == padded_k, "output_scale second dim mismatch"); + + const cudaStream_t stream = LaunchKernel::resolve_device(input.device()); + if (host::is_type(input.dtype())) { + quant_impl( + output.data_ptr(), + output_scale.data_ptr(), + input.data_ptr(), + input_global_scale.data_ptr(), + input_offset_by_experts.data_ptr(), + output_scale_offset_by_experts.data_ptr(), + nullptr, // mask + false, // use_silu_and_mul + m_topk, + k, + n_experts, + stream); + } else { + quant_impl<__nv_bfloat16>( + output.data_ptr(), + output_scale.data_ptr(), + input.data_ptr(), + input_global_scale.data_ptr(), + input_offset_by_experts.data_ptr(), + output_scale_offset_by_experts.data_ptr(), + nullptr, // mask + false, // use_silu_and_mul + m_topk, + k, + n_experts, + stream); + } +} + +void silu_and_mul_scaled_fp4_experts_quant_sm100a( + tvm::ffi::TensorView output, + tvm::ffi::TensorView output_scale, + tvm::ffi::TensorView input, + tvm::ffi::TensorView input_global_scale, + tvm::ffi::TensorView mask, + bool use_silu_and_mul) { + auto MTopK = SymbolicSize{"m_topk"}; + auto KBy2 = SymbolicSize{"k_by_2"}; + auto OutputCols = SymbolicSize{"output_cols"}; + auto OutputScaleRows = SymbolicSize{"output_scale_rows"}; + auto OutputScaleCols = SymbolicSize{"output_scale_cols"}; + auto NExperts = SymbolicSize{"n_experts"}; + auto device = SymbolicDevice{}; + + TensorMatcher({MTopK, KBy2}) // + .with_dtype() + .template with_device(device) + .verify(input); + TensorMatcher({MTopK, OutputCols}) // + .with_dtype() + .with_device(device) + .verify(output); + TensorMatcher({OutputScaleRows, OutputScaleCols}) // + .with_dtype() + .with_device(device) + .verify(output_scale); + TensorMatcher({NExperts}) // + .with_dtype() + .with_device(device) + .verify(input_global_scale); + 
TensorMatcher({NExperts}) // + .with_dtype() + .with_device(device) + .verify(mask); + + const int device_id = input.device().device_id; + RuntimeCheck(getSMVersion(device_id) >= 100, "fp4_quant is only supported on sm100+"); + + const int BLOCK_SIZE = 16; + const auto m_topk = static_cast(MTopK.unwrap()); + const auto k_by_2 = static_cast(KBy2.unwrap()); + int k = k_by_2; + if (use_silu_and_mul) { + RuntimeCheck(k_by_2 % 2 == 0, "k must be a multiple of 2"); + k = k_by_2 / 2; + } + const auto n_experts = static_cast(NExperts.unwrap()); + RuntimeCheck(static_cast(OutputCols.unwrap()) == k / 2, "output second dim mismatch"); + const int scales_k = k / BLOCK_SIZE; + const int padded_k = (scales_k + 3) / 4 * 4; + RuntimeCheck(static_cast(OutputScaleCols.unwrap()) * 4 == padded_k, "output_scale second dim mismatch"); + + const cudaStream_t stream = LaunchKernel::resolve_device(input.device()); + if (host::is_type(input.dtype())) { + quant_impl( + output.data_ptr(), + output_scale.data_ptr(), + input.data_ptr(), + input_global_scale.data_ptr(), + nullptr, // input_offset_by_experts + nullptr, // output_scale_offset_by_experts + mask.data_ptr(), + use_silu_and_mul, + m_topk, + k, + n_experts, + stream); + } else { + quant_impl<__nv_bfloat16>( + output.data_ptr(), + output_scale.data_ptr(), + input.data_ptr(), + input_global_scale.data_ptr(), + nullptr, // input_offset_by_experts + nullptr, // output_scale_offset_by_experts + mask.data_ptr(), + use_silu_and_mul, + m_topk, + k, + n_experts, + stream); + } +} diff --git a/sglang/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_quant.cuh b/sglang/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_quant.cuh new file mode 100644 index 0000000000000000000000000000000000000000..e2d696673acd2e4af990a4ae332d0ef3390cd3a1 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_quant.cuh @@ -0,0 +1,160 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include + +#include +#include + +#define ELTS_PER_THREAD 8 + +constexpr int CVT_FP4_ELTS_PER_THREAD = 8; +constexpr int CVT_FP4_SF_VEC_SIZE = 16; + +// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). +SGL_DEVICE uint32_t fp32_vec_to_e2m1(float (&array)[8]) { + // PTX instructions used here requires >= sm100f. +#if CUTLASS_ARCH_MMA_SM100A_ENABLED || CUTLASS_ARCH_MMA_SM103A_ENABLED || CUTLASS_ARCH_MMA_SM120A_ENABLED || \ + (defined(__CUDA_ARCH_FAMILY_SPECIFIC__) && (__CUDA_ARCH_FAMILY_SPECIFIC__ >= 1000)) + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0]), + "f"(array[1]), + "f"(array[2]), + "f"(array[3]), + "f"(array[4]), + "f"(array[5]), + "f"(array[6]), + "f"(array[7])); + return val; +#else + printf("fp32_vec_to_e2m1 is not supported on this architecture\n"); + __trap(); + return 0; +#endif +} + +// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). 
+SGL_DEVICE uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { + // PTX instructions used here requires >= sm100f. +#if CUTLASS_ARCH_MMA_SM100A_ENABLED || CUTLASS_ARCH_MMA_SM103A_ENABLED || CUTLASS_ARCH_MMA_SM120A_ENABLED || \ + (defined(__CUDA_ARCH_FAMILY_SPECIFIC__) && (__CUDA_ARCH_FAMILY_SPECIFIC__ >= 1000)) + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0].x), + "f"(array[0].y), + "f"(array[1].x), + "f"(array[1].y), + "f"(array[2].x), + "f"(array[2].y), + "f"(array[3].x), + "f"(array[3].y)); + return val; +#else + printf("fp32_vec_to_e2m1 is not supported on this architecture\n"); + __trap(); + return 0; +#endif +} + +// Fast reciprocal. +SGL_DEVICE float reciprocal_approximate_ftz(float a) { + float b; + asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); + return b; +} + +template +SGL_DEVICE uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, int numCols, SFType* SFout) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || CVT_FP4_NUM_THREADS_PER_SF == 2); + + // One pair of threads write one SF to global memory. + // TODO: stage through smem for packed STG.32 + // is it better than STG.8 from 4 threads ? + if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) { + // SF vector index (16 elements share one SF in the K dimension). + int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; + int32_t mIdx = rowIdx; + + // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] + // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] + + int32_t mTileIdx = mIdx / (32 * 4); + // SF vector size 16. 
+ int factor = CVT_FP4_SF_VEC_SIZE * 4; + int32_t numKTiles = (numCols + factor - 1) / factor; + int64_t mTileStride = numKTiles * 32 * 4 * 4; + + int32_t kTileIdx = (kIdx / 4); + int64_t kTileStride = 32 * 4 * 4; + + // M tile layout [32, 4] is column-major. + int32_t outerMIdx = (mIdx % 32); + int64_t outerMStride = 4 * 4; + + int32_t innerMIdx = (mIdx % (32 * 4)) / 32; + int64_t innerMStride = 4; + + int32_t innerKIdx = (kIdx % 4); + int64_t innerKStride = 1; + + // Compute the global offset. + int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + outerMIdx * outerMStride + + innerMIdx * innerMStride + innerKIdx * innerKStride; + + return reinterpret_cast(SFout) + SFOffset; + } +#endif + return nullptr; +} + +// Define a 16 bytes packed data type. +template +struct PackedVec { + packed_t elts[4]; +}; + +template <> +struct PackedVec<__nv_fp8_e4m3> { + __nv_fp8x2_e4m3 elts[8]; +}; diff --git a/sglang/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_quant_entry.cuh b/sglang/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_quant_entry.cuh new file mode 100644 index 0000000000000000000000000000000000000000..29b06dfc0a5d55fd41d4beb88ad99966136ace8b --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_quant_entry.cuh @@ -0,0 +1,68 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +void scaled_fp4_quant_sm100a_sm120a( + tvm::ffi::TensorView output, + tvm::ffi::TensorView input, + tvm::ffi::TensorView output_sf, + tvm::ffi::TensorView input_sf); + +void scaled_fp4_experts_quant_sm100a( + tvm::ffi::TensorView output, + tvm::ffi::TensorView output_scale, + tvm::ffi::TensorView input, + tvm::ffi::TensorView input_global_scale, + tvm::ffi::TensorView input_offset_by_experts, + tvm::ffi::TensorView output_scale_offset_by_experts); + +void silu_and_mul_scaled_fp4_experts_quant_sm100a( + tvm::ffi::TensorView output, + tvm::ffi::TensorView output_scale, + tvm::ffi::TensorView input, + tvm::ffi::TensorView input_global_scale, + tvm::ffi::TensorView mask, + bool use_silu_and_mul); + +void scaled_fp4_quant( + tvm::ffi::TensorView output, + tvm::ffi::TensorView input, + tvm::ffi::TensorView output_sf, + tvm::ffi::TensorView input_sf) { + scaled_fp4_quant_sm100a_sm120a(output, input, output_sf, input_sf); +} + +void scaled_fp4_experts_quant( + tvm::ffi::TensorView output, + tvm::ffi::TensorView output_scale, + tvm::ffi::TensorView input, + tvm::ffi::TensorView input_global_scale, + tvm::ffi::TensorView input_offset_by_experts, + tvm::ffi::TensorView output_scale_offset_by_experts) { + scaled_fp4_experts_quant_sm100a( + output, output_scale, input, input_global_scale, input_offset_by_experts, output_scale_offset_by_experts); +} + +void silu_and_mul_scaled_fp4_experts_quant( + tvm::ffi::TensorView output, + tvm::ffi::TensorView output_scale, + tvm::ffi::TensorView input, + tvm::ffi::TensorView input_global_scale, + tvm::ffi::TensorView mask, + bool use_silu_and_mul) { + silu_and_mul_scaled_fp4_experts_quant_sm100a(output, output_scale, input, input_global_scale, mask, use_silu_and_mul); +} diff --git a/sglang/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_quant_kernels.cuh b/sglang/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_quant_kernels.cuh new 
file mode 100644 index 0000000000000000000000000000000000000000..bac38d83e114a6b6c32ce8a21db290f20f6a4c95 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_quant_kernels.cuh @@ -0,0 +1,241 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include +#include + +#include "nvfp4_quant.cuh" +#include +#include + +using namespace host; + +// Quantizes the provided PackedVec into the uint32_t output +template +SGL_DEVICE uint32_t cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, uint8_t* SFout) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + // Get absolute maximum values among the local 8 values. + auto localMax = __habs2(vec.elts[0]); + +// Local maximum value. +#pragma unroll + for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + localMax = __hmax2(localMax, __habs2(vec.elts[i])); + } + + // Get the absolute maximum among all 16 values (two threads). + localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); + // Get the final absolute maximum values. + float vecMax = float(__hmax(localMax.x, localMax.y)); + + // Get the SF (max value of the vector / max value of e2m1). + // maximum value of e2m1 = 6.0. + // TODO: use half as compute data type. + float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); + // 8 bits representation of the SF. 
+ uint8_t fp8SFVal; + // Write the SF to global memory (STG.8). + if constexpr (UE8M0_SF) { + __nv_fp8_e8m0 tmp; + tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf); + SFValue = static_cast(tmp); + fp8SFVal = tmp.__x; + } else { + // Here SFValue is always positive, so E4M3 is the same as UE4M3. + __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); + fp8SFVal = tmp.__x; + SFValue = static_cast(tmp); + } + // Get the output scale. + // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * + // reciprocal(SFScaleVal)) + float outputScale = + SFValue != 0 ? reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal)) : 0.0f; + + if (SFout) { + // Write the SF to global memory (STG.8). + *SFout = fp8SFVal; + } + + // Convert the input to float. + float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; + +#pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + if constexpr (std::is_same_v) { + fp2Vals[i] = __half22float2(vec.elts[i]); + } else { + fp2Vals[i] = __bfloat1622float2(vec.elts[i]); + } + fp2Vals[i].x *= outputScale; + fp2Vals[i].y *= outputScale; + } + + // Convert to e2m1 values. + uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals); + + // Write the e2m1 values to global memory. + return e2m1Vec; +#else + return 0; +#endif +} + +// Use UE4M3 by default. +template +__global__ void +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__launch_bounds__(512, 4) cvt_fp16_to_fp4( +#else +cvt_fp16_to_fp4( +#endif + int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, uint32_t* out, uint32_t* SFout) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + using PackedVec = PackedVec; + static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); + static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched."); + + // Get the global scaling factor, which will be applied to the SF. 
+ // Note SFScale is the same as next GEMM's alpha, which is + // (448.f / (Alpha_A / 6.f)). + float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0]; + + // Input tensor row/col loops. + for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { + for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD; colIdx += blockDim.x) { + int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; + PackedVec in_vec = reinterpret_cast(in)[inOffset]; + // Get the output tensor offset. + // Same as inOffset because 8 elements are packed into one uint32_t. + int64_t outOffset = inOffset; + auto& out_pos = out[outOffset]; + + auto sf_out = + cvt_quant_to_fp4_get_sf_out_offset(rowIdx, colIdx, numCols, SFout); + + out_pos = cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out); + } + } +#endif +} + +template +void invokeFP4Quantization( + int m, + int n, + T const* input, + float const* SFScale, + int64_t* output, + int32_t* SFOuput, + bool useUE8M0, + int multiProcessorCount, + cudaStream_t stream) { + // Grid, Block size. + // Each thread converts 8 values. + dim3 block(std::min(int(n / ELTS_PER_THREAD), 512)); + // Get number of blocks per SM (assume we can fully utilize the SM). + int const numBlocksPerSM = 2048 / block.x; + dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM)); + + // Launch the cvt kernel. + if (useUE8M0) { + cvt_fp16_to_fp4<<>>( + m, n, input, SFScale, reinterpret_cast(output), reinterpret_cast(SFOuput)); + } else { + cvt_fp16_to_fp4<<>>( + m, n, input, SFScale, reinterpret_cast(output), reinterpret_cast(SFOuput)); + } +} + +// Instantiate the function. 
+template void invokeFP4Quantization( + int m, + int n, + half const* input, + float const* SFScale, + int64_t* output, + int32_t* SFOuput, + bool useUE8M0, + int multiProcessorCount, + cudaStream_t stream); + +template void invokeFP4Quantization( + int m, + int n, + __nv_bfloat16 const* input, + float const* SFScale, + int64_t* output, + int32_t* SFOuput, + bool useUE8M0, + int multiProcessorCount, + cudaStream_t stream); + +inline int getSMVersion(int device_id) { + int sm_major = 0; + int sm_minor = 0; + RuntimeDeviceCheck(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device_id)); + RuntimeDeviceCheck(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device_id)); + return sm_major * 10 + sm_minor; +} + +void scaled_fp4_quant_sm100a_sm120a( + tvm::ffi::TensorView output, + tvm::ffi::TensorView input, + tvm::ffi::TensorView output_sf, + tvm::ffi::TensorView input_sf) { + RuntimeCheck(input.device().device_type == kDLCUDA, "input must be a CUDA tensor"); + RuntimeCheck(output.device() == input.device(), "output and input must be on same device"); + RuntimeCheck(output_sf.device() == input.device(), "output_sf and input must be on same device"); + RuntimeCheck(input_sf.device() == input.device(), "input_sf and input must be on same device"); + RuntimeCheck(input.dim() == 2, "input must be a 2D tensor"); + RuntimeCheck(output.dim() == 2, "output must be a 2D tensor"); + RuntimeCheck(output_sf.dim() == 2, "output_sf must be a 2D tensor"); + RuntimeCheck(input_sf.numel() == 1, "input_sf must have exactly one element"); + RuntimeCheck(host::is_type(output.dtype()), "output must be uint8"); + RuntimeCheck(host::is_type(output_sf.dtype()), "output_sf must be int32"); + RuntimeCheck(host::is_type(input_sf.dtype()), "input_sf must be float32"); + RuntimeCheck( + host::is_type(input.dtype()) || host::is_type(input.dtype()), "input dtype must be fp16 or bf16"); + + const int device_id = input.device().device_id; + const auto 
sm_version = getSMVersion(device_id); + RuntimeCheck(sm_version >= 100, "fp4_quant is only supported on sm100+"); + + const int32_t m = static_cast(input.size(0)); + const int32_t n = static_cast(input.size(1)); + + RuntimeCheck(output.size(0) == m, "output row size mismatch"); + RuntimeCheck(output.size(1) == n / 2, "output column size mismatch"); + RuntimeCheck(n % 16 == 0, "The N dimension must be multiple of 16."); + + const int multiProcessorCount = static_cast(runtime::get_sm_count(device_id)); + + auto input_sf_ptr = static_cast(input_sf.data_ptr()); + auto sf_out = static_cast(output_sf.data_ptr()); + auto output_ptr = static_cast(output.data_ptr()); + const cudaStream_t stream = LaunchKernel::resolve_device(input.device()); + + constexpr bool useUE8M0 = false; + if (host::is_type(input.dtype())) { + auto input_ptr = reinterpret_cast(input.data_ptr()); + invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out, useUE8M0, multiProcessorCount, stream); + } else { + auto input_ptr = reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr()); + invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out, useUE8M0, multiProcessorCount, stream); + } +} diff --git a/sglang/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_entry.cuh b/sglang/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_entry.cuh new file mode 100644 index 0000000000000000000000000000000000000000..72d68f7d5b09cf1a1a565e7914f435f924a91378 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_entry.cuh @@ -0,0 +1,34 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +void cutlass_scaled_fp4_mm_sm100a_sm120a( + tvm::ffi::TensorView D, + tvm::ffi::TensorView A, + tvm::ffi::TensorView B, + tvm::ffi::TensorView A_sf, + tvm::ffi::TensorView B_sf, + tvm::ffi::TensorView alpha); + +void cutlass_scaled_fp4_mm( + tvm::ffi::TensorView D, + tvm::ffi::TensorView A, + tvm::ffi::TensorView B, + tvm::ffi::TensorView A_sf, + tvm::ffi::TensorView B_sf, + tvm::ffi::TensorView alpha) { + cutlass_scaled_fp4_mm_sm100a_sm120a(D, A, B, A_sf, B_sf, alpha); +} diff --git a/sglang/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_kernels.cuh b/sglang/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_kernels.cuh new file mode 100644 index 0000000000000000000000000000000000000000..9cc309f14b5518d37412c1a5fb39a076b7df8673 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_kernels.cuh @@ -0,0 +1,730 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include +#include + +#include +#include +#include +#include + +using namespace host; + +// clang-format off +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/util/packed_stride.hpp" +// clang-format on + +/** + * Helper function for checking CUTLASS errors + */ +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + RuntimeCheck(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \ + } + +using namespace cute; + +// Helper function for next power of 2 +inline uint32_t next_pow_2(uint32_t x) { + if (x == 0) return 1; + x--; + x |= x >> 1; + x |= x >> 2; + x |= x >> 4; + x |= x >> 8; + x |= x >> 16; + return x + 1; +} + +struct WorkspaceKey { + int device_id; + uintptr_t stream; + auto operator==(const WorkspaceKey&) const -> bool = default; +}; + +struct WorkspaceKeyHash { + auto operator()(const WorkspaceKey& key) const -> size_t { + size_t h1 = std::hash{}(key.device_id); + size_t h2 = std::hash{}(key.stream); + return h1 ^ (h2 + 0x9e3779b97f4a7c15ULL + (h1 << 6) + (h1 >> 2)); + } +}; + +struct WorkspaceState { + void* ptr = nullptr; + size_t bytes = 0; +}; + +inline auto get_cached_workspace(size_t required_bytes, int device_id, cudaStream_t stream) -> void* { + if (required_bytes == 0) { + return nullptr; + } + + thread_local std::unordered_map cache; + WorkspaceKey key{device_id, reinterpret_cast(stream)}; + auto& ws = cache[key]; + + if (ws.ptr != nullptr && ws.bytes >= required_bytes) { + return ws.ptr; + } + + RuntimeDeviceCheck(cudaSetDevice(device_id)); + if (ws.ptr != nullptr) { + RuntimeDeviceCheck(cudaFreeAsync(ws.ptr, stream)); + ws.ptr = nullptr; + ws.bytes = 0; + } + 
RuntimeDeviceCheck(cudaMallocAsync(&ws.ptr, required_bytes, stream)); + ws.bytes = required_bytes; + return ws.ptr; +} + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || \ + defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) +// Config(half_t/bfloat16_t) for M <= 128 +template +struct KernelConfigM128 { + using OutputType = T; + using MmaTileShape = Shape<_128, _256, _256>; + using ClusterShape = Shape; + using EpilogueTile = Shape<_128, _64>; // Avoid register spilling + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm; + using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100; + const static dim3 preferred_cluster; + const static dim3 fallback_cluster; +}; +template +const dim3 KernelConfigM128::preferred_cluster(1, 4, 1); +template +const dim3 KernelConfigM128::fallback_cluster(1, 2, 1); + +// Config(half_t/bfloat16_t) for M <= 256 +template +struct KernelConfigM256 { + using OutputType = T; + using MmaTileShape = Shape<_256, _256, _256>; + using ClusterShape = Shape; + using EpilogueTile = Shape<_128, _64>; // Avoid register spilling + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm; + using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100; + const static dim3 preferred_cluster; + const static dim3 fallback_cluster; +}; +template +const dim3 KernelConfigM256::preferred_cluster(2, 4, 1); +template +const dim3 KernelConfigM256::fallback_cluster(2, 1, 1); + +// Default config(half_t/bfloat16_t) for M > 256 +template +struct KernelConfigDefault { + using OutputType = T; + using MmaTileShape = Shape<_256, _256, _256>; + using ClusterShape = Shape; + using EpilogueTile = Shape<_128, _64>; // Avoid register spilling + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm; + using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100; + const static dim3 preferred_cluster; + const static dim3 fallback_cluster; +}; +template 
+const dim3 KernelConfigDefault::preferred_cluster(4, 4, 1); +template +const dim3 KernelConfigDefault::fallback_cluster(2, 1, 1); + +struct KernelConfigFp32 { + using OutputType = float; + using MmaTileShape = Shape<_128, _128, _256>; + using ClusterShape = Shape; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm; + using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100; + const static dim3 preferred_cluster; + const static dim3 fallback_cluster; +}; +const dim3 KernelConfigFp32::preferred_cluster = dim3(1, 4, 1); +const dim3 KernelConfigFp32::fallback_cluster = dim3(1, 2, 1); + +// SM120 specific configurations +struct sm120_fp4_config_M256 { + using ClusterShape = Shape<_1, _1, _1>; + using MmaTileShape = Shape<_128, _128, _128>; + using PerSmTileShape_MNK = Shape<_128, _128, _128>; +}; + +struct sm120_fp4_config_default { + using ClusterShape = Shape<_1, _1, _1>; + using MmaTileShape = Shape<_256, _128, _128>; + using PerSmTileShape_MNK = Shape<_256, _128, _128>; +}; + +template +struct Fp4GemmSm100 { + using Config = KernelConfig; // For generating args + using OutputType = typename KernelConfig::OutputType; + // A matrix configuration + using ElementA = cutlass::nv_float4_t; + using LayoutATag = cutlass::layout::RowMajor; + static constexpr int AlignmentA = 32; + + // B matrix configuration + using ElementB = cutlass::nv_float4_t; + using LayoutBTag = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = 32; + + // C/D matrix configuration + using ElementD = OutputType; + using ElementC = OutputType; + using LayoutCTag = cutlass::layout::RowMajor; + using LayoutDTag = cutlass::layout::RowMajor; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + // Kernel functional config + using ElementAccumulator = float; + using ArchTag = cutlass::arch::Sm100; 
+ using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; + + // Kernel Perf config + using MmaTileShape = typename KernelConfig::MmaTileShape; + using ClusterShape = typename KernelConfig::ClusterShape; + using EpilogueTile = typename KernelConfig::EpilogueTile; + using EpilogueSchedule = typename KernelConfig::EpilogueSchedule; + using MainloopSchedule = typename KernelConfig::MainloopSchedule; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementAccumulator, + void, + LayoutCTag, + AlignmentC, + ElementD, + LayoutDTag, + AlignmentD, + EpilogueSchedule, + cutlass::epilogue::fusion::LinearCombination>::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementA, + LayoutATag, + AlignmentA, + ElementB, + LayoutBTag, + AlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule>::CollectiveOp; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal, CollectiveMainloop, CollectiveEpilogue, void>; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + using StrideA = typename Gemm::GemmKernel::StrideA; + using LayoutA = decltype(cute::make_layout(make_shape(0, 0, 0), StrideA{})); + using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using LayoutB = decltype(cute::make_layout(make_shape(0, 0, 0), StrideB{})); + using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using LayoutC = decltype(cute::make_layout(make_shape(0, 0, 0), StrideC{})); + using StrideD = typename Gemm::GemmKernel::StrideD; + using LayoutD = 
decltype(cute::make_layout(make_shape(0, 0, 0), StrideD{})); +}; + +// SM120 specific GEMM template +template +struct Fp4GemmSm120 { + using ElementA = cutlass::nv_float4_t; + using LayoutATag = cutlass::layout::RowMajor; + static constexpr int AlignmentA = 32; + + using ElementB = cutlass::nv_float4_t; + using LayoutBTag = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = 32; + + using ElementD = OutType; + using ElementC = OutType; + using LayoutCTag = cutlass::layout::RowMajor; + using LayoutDTag = cutlass::layout::RowMajor; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + + using ElementAccumulator = float; + using ArchTag = cutlass::arch::Sm120; + using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; + + using MmaTileShape = typename Config::MmaTileShape; + using ClusterShape = typename Config::ClusterShape; + using PerSmTileShape_MNK = typename Config::PerSmTileShape_MNK; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + PerSmTileShape_MNK, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementAccumulator, + ElementC, + LayoutCTag, + AlignmentC, + ElementD, + LayoutDTag, + AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementA, + LayoutATag, + AlignmentA, + ElementB, + LayoutBTag, + AlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal, CollectiveMainloop, CollectiveEpilogue, void>; + + using Gemm = 
cutlass::gemm::device::GemmUniversalAdapter; +}; + +template +typename T::Gemm::Arguments args_from_options( + tvm::ffi::TensorView D, + tvm::ffi::TensorView A, + tvm::ffi::TensorView B, + tvm::ffi::TensorView A_sf, + tvm::ffi::TensorView B_sf, + tvm::ffi::TensorView alpha, + int64_t M, + int64_t N, + int64_t K) { + using ElementA = typename T::Gemm::ElementA; + using ElementB = typename T::Gemm::ElementB; + using ElementSFA = cutlass::float_ue4m3_t; + using ElementSFB = cutlass::float_ue4m3_t; + using ElementD = typename T::Gemm::ElementD; + using ElementCompute = float; + using StrideA = typename T::StrideA; + using StrideB = typename T::StrideB; + using StrideD = typename T::StrideD; + using Sm1xxBlkScaledConfig = typename T::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + + int m = static_cast(M); + int n = static_cast(N); + int k = static_cast(K); + auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1}); + auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1}); + auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {m, n, 1}); + + auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1)); + auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1)); + + typename T::Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {m, n, k, 1}, + {// Mainloop arguments + static_cast(A.data_ptr()), + stride_A, + static_cast(B.data_ptr()), + stride_B, + static_cast(A_sf.data_ptr()), + layout_SFA, + static_cast(B_sf.data_ptr()), + layout_SFB}, + { // Epilogue arguments + {}, // epilogue.thread + nullptr, + stride_D, + static_cast(D.data_ptr()), + stride_D}}; + auto& fusion_args = arguments.epilogue.thread; + fusion_args.alpha_ptr = static_cast(alpha.data_ptr()); + using KernelConfig = typename T::Config; + arguments.hw_info.cluster_shape = KernelConfig::preferred_cluster; + arguments.hw_info.cluster_shape_fallback = 
KernelConfig::fallback_cluster; + return arguments; +} + +template +void runGemm( + tvm::ffi::TensorView D, + tvm::ffi::TensorView A, + tvm::ffi::TensorView B, + tvm::ffi::TensorView A_sf, + tvm::ffi::TensorView B_sf, + tvm::ffi::TensorView alpha, + int64_t m, + int64_t n, + int64_t k, + cudaStream_t stream) { + typename T::Gemm gemm; + auto arguments = args_from_options(D, A, B, A_sf, B_sf, alpha, m, n, k); + + size_t workspace_size = T::Gemm::get_workspace_size(arguments); + int device_id = A.device().device_id; + void* workspace = get_cached_workspace(workspace_size, device_id, stream); + + CUTLASS_CHECK(gemm.can_implement(arguments)); + + CUTLASS_CHECK(gemm.initialize(arguments, workspace, stream)); + + CUTLASS_CHECK(gemm.run(arguments, workspace, stream)); +} + +// SM120 specific args_from_options function +template +typename Gemm::Arguments args_from_options_sm120( + tvm::ffi::TensorView D, + tvm::ffi::TensorView A, + tvm::ffi::TensorView B, + tvm::ffi::TensorView A_sf, + tvm::ffi::TensorView B_sf, + tvm::ffi::TensorView alpha, + int M, + int N, + int K) { + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementD = typename Gemm::ElementD; + using ElementSFA = cutlass::float_ue4m3_t; + using ElementSFB = cutlass::float_ue4m3_t; + using ElementCompute = float; + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + + auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}); + auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}); + auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}); + + auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1)); + auto layout_SFB = 
Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K, 1}, + {static_cast(A.data_ptr()), + stride_A, + static_cast(B.data_ptr()), + stride_B, + static_cast(A_sf.data_ptr()), + layout_SFA, + static_cast(B_sf.data_ptr()), + layout_SFB}, + {{}, static_cast(D.data_ptr()), stride_D, static_cast(D.data_ptr()), stride_D}}; + auto& fusion_args = arguments.epilogue.thread; + fusion_args.alpha_ptr = static_cast(alpha.data_ptr()); + + return arguments; +} + +// SM120 specific runGemm function +template +void runGemmSm120( + tvm::ffi::TensorView D, + tvm::ffi::TensorView A, + tvm::ffi::TensorView B, + tvm::ffi::TensorView A_sf, + tvm::ffi::TensorView B_sf, + tvm::ffi::TensorView alpha, + int M, + int N, + int K, + cudaStream_t stream) { + Gemm gemm; + + auto arguments = args_from_options_sm120(D, A, B, A_sf, B_sf, alpha, M, N, K); + + size_t workspace_size = Gemm::get_workspace_size(arguments); + int device_id = A.device().device_id; + void* workspace = get_cached_workspace(workspace_size, device_id, stream); + + CUTLASS_CHECK(gemm.can_implement(arguments)); + + CUTLASS_CHECK(gemm.initialize(arguments, workspace, stream)); + + CUTLASS_CHECK(gemm.run(arguments, workspace, stream)); +} + +// Dispatch function to select appropriate config based on M +template +void cutlassFp4GemmDispatch( + tvm::ffi::TensorView D, + tvm::ffi::TensorView A, + tvm::ffi::TensorView B, + tvm::ffi::TensorView A_sf, + tvm::ffi::TensorView B_sf, + tvm::ffi::TensorView alpha, + int64_t m, + int64_t n, + int64_t k, + cudaStream_t stream) { + if (m <= 128) { + // m in [1, 128] + runGemm>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } else if (m <= 256) { + // m in (128, 256] + runGemm>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } else { + // m in (256, inf) + runGemm>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } +} + +// Dispatch function to select appropriate config based on 
M +template <> +void cutlassFp4GemmDispatch( + tvm::ffi::TensorView D, + tvm::ffi::TensorView A, + tvm::ffi::TensorView B, + tvm::ffi::TensorView A_sf, + tvm::ffi::TensorView B_sf, + tvm::ffi::TensorView alpha, + int64_t m, + int64_t n, + int64_t k, + cudaStream_t stream) { + runGemm>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); +} + +// SM120 specific dispatch functions +void cutlass_fp4_bf16_gemm_dispatch_sm120( + tvm::ffi::TensorView D, + tvm::ffi::TensorView A, + tvm::ffi::TensorView B, + tvm::ffi::TensorView A_sf, + tvm::ffi::TensorView B_sf, + tvm::ffi::TensorView alpha, + int m, + int n, + int k, + cudaStream_t stream) { + uint32_t const mp2 = std::max(static_cast(16), next_pow_2(m)); + if (mp2 <= 256) { + runGemmSm120::Gemm>( + D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } else { + runGemmSm120::Gemm>( + D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } +} + +void cutlass_fp4_f16_gemm_dispatch_sm120( + tvm::ffi::TensorView D, + tvm::ffi::TensorView A, + tvm::ffi::TensorView B, + tvm::ffi::TensorView A_sf, + tvm::ffi::TensorView B_sf, + tvm::ffi::TensorView alpha, + int m, + int n, + int k, + cudaStream_t stream) { + uint32_t const mp2 = std::max(static_cast(16), next_pow_2(m)); + if (mp2 <= 256) { + runGemmSm120::Gemm>( + D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } else { + runGemmSm120::Gemm>( + D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } +} + +#else +template +void cutlassFp4GemmDispatch( + tvm::ffi::TensorView D, + tvm::ffi::TensorView A, + tvm::ffi::TensorView B, + tvm::ffi::TensorView A_sf, + tvm::ffi::TensorView B_sf, + tvm::ffi::TensorView alpha, + int64_t m, + int64_t n, + int64_t k, + cudaStream_t stream) { + RuntimeCheck( + false, + "Unsupported CUTLASS version. 
Set VLLM_CUTLASS_SRC_DIR to " + "a CUTLASS 3.8 source directory to enable support."); +} +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || + // defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) + +inline int getSMVersion(int device_id) { + int sm_major = 0; + int sm_minor = 0; + RuntimeDeviceCheck(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device_id)); + RuntimeDeviceCheck(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device_id)); + return sm_major * 10 + sm_minor; +} + +void cutlass_scaled_fp4_mm_sm100a_sm120a( + tvm::ffi::TensorView D, + tvm::ffi::TensorView A, + tvm::ffi::TensorView B, + tvm::ffi::TensorView A_sf, + tvm::ffi::TensorView B_sf, + tvm::ffi::TensorView alpha) { + RuntimeCheck(A.device().device_type == kDLCUDA, "a must be a CUDA tensor"); + RuntimeCheck(B.device().device_type == kDLCUDA, "b must be a CUDA tensor"); + RuntimeCheck(A_sf.device().device_type == kDLCUDA, "scale_a must be a CUDA tensor"); + RuntimeCheck(B_sf.device().device_type == kDLCUDA, "scale_b must be a CUDA tensor"); + RuntimeCheck(alpha.device().device_type == kDLCUDA, "alpha must be a CUDA tensor"); + RuntimeCheck(D.device().device_type == kDLCUDA, "out must be a CUDA tensor"); + + RuntimeCheck(A.device() == B.device(), "a and b must be on same device"); + RuntimeCheck(A.device() == A_sf.device(), "a and scale_a must be on same device"); + RuntimeCheck(A.device() == B_sf.device(), "a and scale_b must be on same device"); + RuntimeCheck(A.device() == alpha.device(), "a and alpha must be on same device"); + RuntimeCheck(A.device() == D.device(), "a and out must be on same device"); + + RuntimeCheck(A.is_contiguous(), "a must be contiguous"); + RuntimeCheck(B.is_contiguous(), "b must be contiguous"); + RuntimeCheck(A_sf.is_contiguous(), "scale_a must be contiguous"); + RuntimeCheck(B_sf.is_contiguous(), "scale_b must be contiguous"); + RuntimeCheck(alpha.is_contiguous(), "alpha must be 
contiguous"); + RuntimeCheck(D.is_contiguous(), "out must be contiguous"); + + RuntimeCheck(host::is_type(A.dtype()), "a must be uint8"); + RuntimeCheck(host::is_type(B.dtype()), "b must be uint8"); + RuntimeCheck(host::is_type(A_sf.dtype()), "scale_a must be float8_e4m3fn"); + RuntimeCheck(host::is_type(B_sf.dtype()), "scale_b must be float8_e4m3fn"); + RuntimeCheck(host::is_type(alpha.dtype()), "alpha must be float32"); + + RuntimeCheck(A.dim() == 2, "a must be a matrix"); + RuntimeCheck(B.dim() == 2, "b must be a matrix"); + RuntimeCheck(A_sf.dim() == 2, "scale_a must be a matrix"); + RuntimeCheck(B_sf.dim() == 2, "scale_b must be a matrix"); + RuntimeCheck(alpha.numel() == 1, "alpha must have exactly one element"); + + RuntimeCheck( + A.size(1) == B.size(1), + "a and b shapes cannot be multiplied (", + A.size(0), + "x", + A.size(1), + " and ", + B.size(0), + "x", + B.size(1), + ")"); + + const auto m = static_cast(A.size(0)); + const auto n = static_cast(B.size(0)); + const auto k = static_cast(A.size(1) * 2); + + RuntimeCheck(D.dim() == 2, "out must be 2D"); + RuntimeCheck(D.size(0) == m, "out first dim must equal m"); + RuntimeCheck(D.size(1) == n, "out second dim must equal n"); + + constexpr int alignment = 32; + RuntimeCheck(k % alignment == 0, "Expected k to be divisible by ", alignment, ", but got k: ", k); + RuntimeCheck(n % alignment == 0, "Expected n to be divisible by ", alignment, ", but got n: ", n); + + auto round_up = [](int64_t x, int64_t y) { return (x + y - 1) / y * y; }; + const int64_t rounded_m = round_up(m, 128); + const int64_t rounded_n = round_up(n, 128); + const int64_t rounded_k = round_up(k / 16, 4); + + RuntimeCheck( + A_sf.size(1) == B_sf.size(1), + "scale_a and scale_b shapes cannot be multiplied (", + A_sf.size(0), + "x", + A_sf.size(1), + " and ", + B_sf.size(0), + "x", + B_sf.size(1), + ")"); + RuntimeCheck( + A_sf.size(0) == rounded_m && A_sf.size(1) == rounded_k, + "scale_a must be padded/swizzled to shape (", + rounded_m, + 
"x", + rounded_k, + "), got (", + A_sf.size(0), + "x", + A_sf.size(1), + ")"); + RuntimeCheck( + B_sf.size(0) == rounded_n && B_sf.size(1) == rounded_k, + "scale_b must be padded/swizzled to shape (", + rounded_n, + "x", + rounded_k, + "), got (", + B_sf.size(0), + "x", + B_sf.size(1), + ")"); + + const cudaStream_t stream = LaunchKernel::resolve_device(A.device()); + const int sm_version = getSMVersion(A.device().device_id); + + if (sm_version >= 120) { + if (host::is_type(D.dtype())) { + cutlass_fp4_f16_gemm_dispatch_sm120( + D, A, B, A_sf, B_sf, alpha, static_cast(m), static_cast(n), static_cast(k), stream); + } else if (host::is_type(D.dtype())) { + cutlass_fp4_bf16_gemm_dispatch_sm120( + D, A, B, A_sf, B_sf, alpha, static_cast(m), static_cast(n), static_cast(k), stream); + } else { + Panic("Unsupported output data type of nvfp4 mm sm120"); + } + } else { + if (host::is_type(D.dtype())) { + cutlassFp4GemmDispatch(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } else if (host::is_type(D.dtype())) { + cutlassFp4GemmDispatch(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } else if (host::is_type(D.dtype())) { + cutlassFp4GemmDispatch(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } else { + Panic("Unsupported output data type of nvfp4 mm"); + } + } +} diff --git a/sglang/python/sglang/jit_kernel/csrc/gemm/per_tensor_quant_fp8.cuh b/sglang/python/sglang/jit_kernel/csrc/gemm/per_tensor_quant_fp8.cuh new file mode 100644 index 0000000000000000000000000000000000000000..17c5de8880c65cc7d20896340b5aad0f61335ab4 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/gemm/per_tensor_quant_fp8.cuh @@ -0,0 +1,139 @@ +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace { + +constexpr size_t kBlockSize = 256; + +// each warp will handle 512B data +template +__global__ void +per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output_s, const int64_t num_elements) { + using 
namespace device; + constexpr uint32_t VEC_SIZE = 16 / sizeof(T); + + const int64_t gid = blockIdx.x * blockDim.x + threadIdx.x; + + float max_value = 0.0f; + if (gid * VEC_SIZE + VEC_SIZE <= num_elements) { + using vec_t = AlignedVector; + const auto gmem_in = tile::Memory::thread(); + const auto input_vec = gmem_in.load(input, gid); +#pragma unroll + for (uint32_t i = 0; i < VEC_SIZE; ++i) { + const float value = static_cast(input_vec[i]); + max_value = math::max(max_value, math::abs(value)); + } + } else if (gid * VEC_SIZE < num_elements) { + [[unlikely]]; // poorly aligned case, do not optimize + const auto remainder = num_elements - gid * VEC_SIZE; + for (uint32_t i = 0; i < remainder; ++i) { + const float value = static_cast(input[gid * VEC_SIZE + i]); + max_value = math::max(max_value, math::abs(value)); + } + } + + // reduce within block and then atomic reduce between blocks + __shared__ float smem[kWarpThreads]; + cta::reduce_max(max_value, smem); + if (threadIdx.x == 0) { + const auto max_value = smem[0]; + atomic::max(output_s, max_value / math::FP8_E4M3_MAX); + } +} + +[[maybe_unused]] +SGL_DEVICE float fp8_e4m3_clip(float val) { + namespace math = device::math; + return math::max(math::min(val, math::FP8_E4M3_MAX), -math::FP8_E4M3_MAX); +} + +template +__global__ void per_tensor_quant_fp8_kernel( + const T* __restrict__ input, + DST_DTYPE* __restrict__ output, + const float* __restrict__ scale, + const int64_t num_elements) { + using namespace device; + constexpr uint32_t VEC_SIZE = 16 / sizeof(T); + + const int64_t gid = blockIdx.x * blockDim.x + threadIdx.x; + const float scale_val = 1.0f / (*scale); + + if (gid * VEC_SIZE + VEC_SIZE <= num_elements) { + using input_vec_t = AlignedVector; + using output_vec_t = AlignedVector; + const auto gmem_in = tile::Memory::thread(); + const auto gmem_out = tile::Memory::thread(); + const auto input_vec = gmem_in.load(input, gid); + output_vec_t output_vec; +#pragma unroll + for (uint32_t i = 0; i < VEC_SIZE; 
++i) { + const float value = fp8_e4m3_clip(static_cast(input_vec[i]) * scale_val); + output_vec[i] = static_cast(value); + } + gmem_out.store(output, output_vec, gid); + } else if (gid * VEC_SIZE < num_elements) { + [[unlikely]]; // poorly aligned case, do not optimize + const auto remainder = num_elements - gid * VEC_SIZE; + for (uint32_t i = 0; i < remainder; ++i) { + const float value = fp8_e4m3_clip(static_cast(input[gid * VEC_SIZE + i]) * scale_val); + output[gid * VEC_SIZE + i] = static_cast(value); + } + } +} + +template +void per_tensor_quant_fp8(tvm::ffi::TensorView input, tvm::ffi::TensorView output_q, tvm::ffi::TensorView output_s) { + using namespace host; + + auto device = SymbolicDevice{}; + auto N = SymbolicSize{"num_elements"}; + device.set_options(); + + TensorMatcher({N}) // + .with_dtype() + .with_device(device) + .verify(input); + TensorMatcher({N}) // + .with_dtype() + .with_device(device) + .verify(output_q); + TensorMatcher({1}) // + .with_dtype() + .with_device(device) + .verify(output_s); + + const auto num_elements = N.unwrap(); + + constexpr size_t kElementsPerBlock = kBlockSize * (16 / sizeof(DType)); + const uint32_t num_blocks = div_ceil(num_elements, kElementsPerBlock); + + if constexpr (!kIsStatic) { + LaunchKernel(num_blocks, kBlockSize, device.unwrap())( + per_tensor_absmax_kernel, + static_cast(input.data_ptr()), + static_cast(output_s.data_ptr()), + static_cast(num_elements)); + } + + LaunchKernel(num_blocks, kBlockSize, device.unwrap())( + per_tensor_quant_fp8_kernel, + static_cast(input.data_ptr()), + static_cast(output_q.data_ptr()), + static_cast(output_s.data_ptr()), + static_cast(num_elements)); +} + +} // namespace diff --git a/sglang/python/sglang/jit_kernel/csrc/gemm/per_token_group_quant_8bit.cuh b/sglang/python/sglang/jit_kernel/csrc/gemm/per_token_group_quant_8bit.cuh new file mode 100644 index 0000000000000000000000000000000000000000..20724c92bc99dd79beba45446825ce32fef45cb8 --- /dev/null +++ 
b/sglang/python/sglang/jit_kernel/csrc/gemm/per_token_group_quant_8bit.cuh @@ -0,0 +1,218 @@ +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace { + +constexpr int kThreadsPerGroup = 16; + +__device__ __forceinline__ float GroupReduceMax(float val, const int tid) { + unsigned mask = threadIdx.x % 32 >= 16 ? 0xffff0000 : 0x0000ffff; + val = fmaxf(val, __shfl_xor_sync(mask, val, 8)); + val = fmaxf(val, __shfl_xor_sync(mask, val, 4)); + val = fmaxf(val, __shfl_xor_sync(mask, val, 2)); + val = fmaxf(val, __shfl_xor_sync(mask, val, 1)); + return val; +} + +template +using scale_packed_t_t = std::conditional_t; + +template +using scale_element_t_t = std::conditional_t; + +template +__global__ void per_token_group_quant_8bit_kernel( + const T* __restrict__ input, + DST_DTYPE* __restrict__ output_q, + scale_packed_t_t* __restrict__ output_s, + const int group_size, + const int num_groups, + const int groups_per_block, + const float eps, + const float min_8bit, + const float max_8bit, + const int num_groups_per_row = 0, + const int scale_stride = 0) { + using namespace device; + namespace math = device::math; + + (void)num_groups; + + const int local_group_id = static_cast(threadIdx.x / kThreadsPerGroup); + const int lane_id = threadIdx.x % kThreadsPerGroup; + + const int64_t block_group_id = blockIdx.x * groups_per_block; + const int64_t global_group_id = block_group_id + local_group_id; + const int64_t block_group_offset = global_group_id * group_size; + + float local_absmax = eps; + + using scale_packed_t = scale_packed_t_t; + using scale_element_t = scale_element_t_t; + static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0); + + const T* group_input = input + block_group_offset; + DST_DTYPE* group_output = static_cast(output_q) + block_group_offset; + scale_element_t* scale_output = nullptr; + + if constexpr (kIsColumnMajor) { + constexpr int kElemsPerPack = 
static_cast(sizeof(scale_packed_t) / sizeof(scale_element_t)); + const int row_idx = global_group_id / num_groups_per_row; + const int col_idx_unpacked = global_group_id % num_groups_per_row; + const int col_idx = col_idx_unpacked / kElemsPerPack; + const int pack_idx = col_idx_unpacked % kElemsPerPack; + scale_output = reinterpret_cast(output_s) + + (col_idx * scale_stride * kElemsPerPack + row_idx * kElemsPerPack + pack_idx); + } else { + static_assert(!kScaleUE8M0); + scale_output = output_s + global_group_id; + } + + constexpr uint32_t kVecSize = 16 / sizeof(T); + using vec_t = AlignedVector; + const auto gmem_in = tile::Memory::thread(); + + const int32_t num_vec_elems = group_size / kVecSize; + + for (int32_t i = lane_id; i < num_vec_elems; i += kThreadsPerGroup) { + const vec_t input_vec = gmem_in.load(group_input, i); + +#pragma unroll + for (uint32_t j = 0; j < kVecSize; ++j) { + const float val = static_cast(input_vec[j]); + local_absmax = math::max(local_absmax, math::abs(val)); + } + } + + local_absmax = GroupReduceMax(local_absmax, lane_id); + + float y_s = local_absmax / max_8bit; + if constexpr (kScaleUE8M0) { + y_s = exp2f(ceilf(log2f(math::max(y_s, 1e-10f)))); + } + + scale_element_t y_s_quant; + if constexpr (kScaleUE8M0) { + y_s_quant = static_cast(((int)log2f(y_s)) + 127); + } else { + y_s_quant = y_s; + } + + if (lane_id == 0) { + *scale_output = y_s_quant; + } + + for (int32_t i = lane_id; i < num_vec_elems; i += kThreadsPerGroup) { + const vec_t input_vec = gmem_in.load(group_input, i); + +#pragma unroll + for (uint32_t j = 0; j < kVecSize; ++j) { + const float val = static_cast(input_vec[j]); + const float q_val = math::min(math::max(val / y_s, min_8bit), max_8bit); + group_output[i * kVecSize + j] = DST_DTYPE(q_val); + } + } +} + +inline int compute_groups_per_block(int64_t num_groups) { + if (num_groups % 16 == 0) return 16; + if (num_groups % 8 == 0) return 8; + if (num_groups % 4 == 0) return 4; + if (num_groups % 2 == 0) return 2; + 
return 1; +} + +template +void per_token_group_quant_8bit( + tvm::ffi::TensorView input, + tvm::ffi::TensorView output_q, + tvm::ffi::TensorView output_s, + int64_t group_size, + double eps, + double min_8bit, + double max_8bit, + bool scale_ue8m0) { + using namespace host; + + auto device = SymbolicDevice{}; + auto M = SymbolicSize{"num_tokens"}; + auto K = SymbolicSize{"hidden_dim"}; + device.set_options(); + + TensorMatcher({M, K}).with_dtype().with_device(device).verify(input); + TensorMatcher({M, K}).with_dtype().with_device(device).verify(output_q); + + const auto num_tokens = M.unwrap(); + const auto hidden_dim = K.unwrap(); + + const int64_t num_groups_per_row = hidden_dim / group_size; + const int64_t num_groups = num_tokens * num_groups_per_row; + + const int groups_per_block = compute_groups_per_block(num_groups); + const int num_blocks = num_groups / groups_per_block; + const int num_threads = groups_per_block * kThreadsPerGroup; + const bool is_column_major = output_s.stride(0) < output_s.stride(1); + const int scale_stride = output_s.stride(1); + + const float feps = static_cast(eps); + const float fmin8 = static_cast(min_8bit); + const float fmax8 = static_cast(max_8bit); + + if (is_column_major) { + if (scale_ue8m0) { + LaunchKernel(num_blocks, num_threads, input.device())( + per_token_group_quant_8bit_kernel, + static_cast(input.data_ptr()), + static_cast(output_q.data_ptr()), + static_cast(output_s.data_ptr()), + static_cast(group_size), + static_cast(num_groups), + static_cast(groups_per_block), + feps, + fmin8, + fmax8, + static_cast(num_groups_per_row), + scale_stride); + } else { + LaunchKernel(num_blocks, num_threads, input.device())( + per_token_group_quant_8bit_kernel, + static_cast(input.data_ptr()), + static_cast(output_q.data_ptr()), + static_cast(output_s.data_ptr()), + static_cast(group_size), + static_cast(num_groups), + static_cast(groups_per_block), + feps, + fmin8, + fmax8, + static_cast(num_groups_per_row), + scale_stride); + } + 
} else { + LaunchKernel(num_blocks, num_threads, input.device())( + per_token_group_quant_8bit_kernel, + static_cast(input.data_ptr()), + static_cast(output_q.data_ptr()), + static_cast(output_s.data_ptr()), + static_cast(group_size), + static_cast(num_groups), + static_cast(groups_per_block), + feps, + fmin8, + fmax8, + 0, + 0); + } +} +} // namespace diff --git a/sglang/python/sglang/jit_kernel/csrc/hicache.cuh b/sglang/python/sglang/jit_kernel/csrc/hicache.cuh new file mode 100644 index 0000000000000000000000000000000000000000..04f093a02ba669892e487acf8de22d49e058857b --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/hicache.cuh @@ -0,0 +1,340 @@ +#include +#include + +#include +#include + +#include + +#include +#include +#include + +namespace device { + +namespace details { + +template +inline constexpr auto get_mem_package() { + if constexpr (kUnit == 16) { + return uint4{}; + } else if constexpr (kUnit == 8) { + return uint2{}; + } else if constexpr (kUnit == 4) { + return uint1{}; + } else { + static_assert(kUnit == 16 || kUnit == 8 || kUnit == 4, "Unsupported memory package size"); + } +} + +template +using PackageType = decltype(get_mem_package()); + +SGL_DEVICE uint1 load_nc(const uint1* __restrict__ src) { + uint32_t tmp; + asm volatile("ld.global.L1::no_allocate.b32 %0,[%1];" : "=r"(tmp) : "l"(src)); + return uint1{tmp}; +} + +SGL_DEVICE uint2 load_nc(const uint2* __restrict__ src) { + uint32_t tmp0, tmp1; + asm volatile("ld.global.L1::no_allocate.v2.b32 {%0,%1},[%2];" : "=r"(tmp0), "=r"(tmp1) : "l"(src)); + return uint2{tmp0, tmp1}; +} + +SGL_DEVICE uint4 load_nc(const uint4* __restrict__ src) { + uint32_t tmp0, tmp1, tmp2, tmp3; + asm volatile("ld.global.L1::no_allocate.v4.b32 {%0,%1,%2,%3},[%4];" + : "=r"(tmp0), "=r"(tmp1), "=r"(tmp2), "=r"(tmp3) + : "l"(src)); + return uint4{tmp0, tmp1, tmp2, tmp3}; +} + +SGL_DEVICE void store_nc(uint1* __restrict__ dst, const uint1& value) { + uint32_t tmp = value.x; + asm 
volatile("st.global.L1::no_allocate.b32 [%0],%1;" ::"l"(dst), "r"(tmp)); +} + +SGL_DEVICE void store_nc(uint2* __restrict__ dst, const uint2& value) { + uint32_t tmp0 = value.x; + uint32_t tmp1 = value.y; + asm volatile("st.global.L1::no_allocate.v2.b32 [%0],{%1,%2};" ::"l"(dst), "r"(tmp0), "r"(tmp1)); +} + +SGL_DEVICE void store_nc(uint4* __restrict__ dst, const uint4& value) { + uint32_t tmp0 = value.x; + uint32_t tmp1 = value.y; + uint32_t tmp2 = value.z; + uint32_t tmp3 = value.w; + asm volatile( + "st.global.L1::no_allocate.v4.b32 [%0],{%1,%2,%3,%4};" ::"l"(dst), "r"(tmp0), "r"(tmp1), "r"(tmp2), "r"(tmp3)); +} + +} // namespace details + +template +SGL_DEVICE auto load_vec(const void* __restrict__ src) { + static_assert(kBytes % 128 == 0, "kBytes must be multiple of 128 bytes"); + static_assert(128 % kNumThreads == 0, "kNumThreads must divide 128 bytes"); + constexpr uint32_t kLoopCount = kBytes / 128; + using Package = details::PackageType<128 / kNumThreads>; + using Storage = AlignedStorage; + + const auto src_packed = static_cast(src); + const auto lane_id = threadIdx.x % kNumThreads; + Storage vec; + +#pragma unroll kLoopCount + for (uint32_t i = 0; i < kLoopCount; ++i) { + const auto j = i * kNumThreads + lane_id; + vec.data[i] = details::load_nc(&src_packed[j]); + } + + return vec; +} + +template +SGL_DEVICE void store_vec(void* __restrict__ dst, const Storage& vec) { + using Package = std::decay_t; + constexpr uint32_t kBytesPerLoop = sizeof(Package) * kNumThreads; + constexpr uint32_t kLoopCount = kBytes / kBytesPerLoop; + static_assert(kBytes % kBytesPerLoop == 0, "Invalid Storage configuration"); + + const auto dst_packed = static_cast(dst); + const auto lane_id = threadIdx.x % kNumThreads; + +#pragma unroll kLoopCount + for (uint32_t i = 0; i < kLoopCount; ++i) { + const auto j = i * kNumThreads + lane_id; + details::store_nc(&dst_packed[j], vec.data[i]); + } +} + +} // namespace device + +namespace { + +#define SGL_HICACHE_KERNEL __global__ 
__launch_bounds__(kBlockSize, 1) + +struct HicacheKernelParams { + void* __restrict__ k_cache_dst; + void* __restrict__ v_cache_dst; + const void* __restrict__ indices_dst; + void* __restrict__ k_cache_src; + void* __restrict__ v_cache_src; + const void* __restrict__ indices_src; + int64_t kv_cache_src_stride; + int64_t kv_cache_dst_stride; + uint32_t length; + uint32_t num_layers = 0; // only used in all_layer transfer +}; + +template +SGL_HICACHE_KERNEL void hicache_transfer_per_layer(const __grid_constant__ HicacheKernelParams params) { + using namespace device; + static_assert(kBlockSize % kWarpThreads == 0); + static_assert(kWarpThreads % kUnroll == 0); + + constexpr uint32_t kNumThreads = kWarpThreads / kUnroll; + constexpr uint32_t kWorkersPerBlock = kBlockSize / kNumThreads; + constexpr uint32_t kNumWorkers = kWorkersPerBlock * kBlockQuota; + + const auto& [ + k_cache_dst, v_cache_dst, indices_dst, // dst + k_cache_src, v_cache_src, indices_src, // src + kv_cache_src_stride, kv_cache_dst_stride, length, _ // metadata + ] = params; + + const uint32_t work_id = blockIdx.x * kWorkersPerBlock + threadIdx.x / kNumThreads; + for (uint32_t i = work_id; i < length; i += kNumWorkers) { + const auto pos_src = static_cast(indices_src)[i]; + const auto pos_dst = static_cast(indices_dst)[i]; + const auto src_k = pointer::offset(k_cache_src, pos_src * kv_cache_src_stride); + const auto dst_k = pointer::offset(k_cache_dst, pos_dst * kv_cache_dst_stride); + const auto src_v = pointer::offset(v_cache_src, pos_src * kv_cache_src_stride); + const auto dst_v = pointer::offset(v_cache_dst, pos_dst * kv_cache_dst_stride); + const auto vec_k = load_vec(src_k); + const auto vec_v = load_vec(src_v); + store_vec(dst_k, vec_k); + store_vec(dst_v, vec_v); + } +} + +template +SGL_HICACHE_KERNEL void hicache_transfer_all_layer(const __grid_constant__ HicacheKernelParams params) { + using namespace device; + using src_ptr_t = const void*; + using dst_ptr_t = void*; + + 
static_assert(kBlockSize % kWarpThreads == 0); + static_assert(kWarpThreads % kUnroll == 0); + + constexpr uint32_t kNumThreads = kWarpThreads / kUnroll; + constexpr uint32_t kWorkersPerBlock = kBlockSize / kNumThreads; + constexpr uint32_t kNumWorkers = kWorkersPerBlock * kBlockQuota; + + const auto& [ + k_ptr_dst, v_ptr_dst, indices_dst, // dst + k_ptr_src, v_ptr_src, indices_src, // src + kv_cache_src_stride, kv_cache_dst_stride, length, num_layers // metadata + ] = params; + + const uint32_t work_id = blockIdx.x * kWorkersPerBlock + threadIdx.x / kNumThreads; + for (uint32_t i = work_id; i < length; i += kNumWorkers) { + const auto pos_src = static_cast(indices_src)[i]; + const auto pos_dst = static_cast(indices_dst)[i]; + for (uint32_t layer = 0; layer < num_layers; ++layer) { + const auto k_cache_src = static_cast(k_ptr_src)[layer]; + const auto v_cache_src = static_cast(v_ptr_src)[layer]; + const auto k_cache_dst = static_cast(k_ptr_dst)[layer]; + const auto v_cache_dst = static_cast(v_ptr_dst)[layer]; + const auto src_k = pointer::offset(k_cache_src, pos_src * kv_cache_src_stride); + const auto dst_k = pointer::offset(k_cache_dst, pos_dst * kv_cache_dst_stride); + const auto src_v = pointer::offset(v_cache_src, pos_src * kv_cache_src_stride); + const auto dst_v = pointer::offset(v_cache_dst, pos_dst * kv_cache_dst_stride); + const auto vec_k = load_vec(src_k); + const auto vec_v = load_vec(src_v); + store_vec(dst_k, vec_k); + store_vec(dst_v, vec_v); + } + } +} + +template +struct HiCacheKernel { + template + static constexpr auto kernel_one = hicache_transfer_per_layer; + template + static constexpr auto kernel_all = hicache_transfer_all_layer; + + static void run_one( + const tvm::ffi::TensorView k_cache_dst, + const tvm::ffi::TensorView v_cache_dst, + const tvm::ffi::TensorView indices_dst, + const tvm::ffi::TensorView k_cache_src, + const tvm::ffi::TensorView v_cache_src, + const tvm::ffi::TensorView indices_src) { + using namespace host; + + auto D = 
SymbolicSize{"head dimension"}; + auto N = SymbolicSize{"src kv stride"}; + auto M = SymbolicSize{"dst kv stride"}; + auto L = SymbolicSize{"indices length"}; + auto cache_dtype = SymbolicDType{}; + auto indices_dtype = SymbolicDType{}; + auto indices_device = SymbolicDevice{}; + + TensorMatcher({-1, D}) // + .with_strides({N, 1}) + .with_dtype(cache_dtype) + .with_device() + .verify(k_cache_src) + .verify(v_cache_src); + TensorMatcher({-1, D}) // + .with_strides({M, 1}) + .with_dtype(cache_dtype) + .with_device() + .verify(k_cache_dst) + .verify(v_cache_dst); + TensorMatcher({L}) // + .with_dtype(indices_dtype) + .with_device(indices_device) + .verify(indices_src) + .verify(indices_dst); + + // verify dimension match + const auto dtype_size = dtype_bytes(cache_dtype.unwrap()); + const auto element_bytes = D.unwrap() * dtype_size; + RuntimeCheck(kElementSize == element_bytes, "HicacheKernel: cache dimension mismatch."); + + const auto k_cache_dst_ptr = k_cache_dst.data_ptr(); + const auto v_cache_dst_ptr = v_cache_dst.data_ptr(); + const auto k_cache_src_ptr = k_cache_src.data_ptr(); + const auto v_cache_src_ptr = v_cache_src.data_ptr(); + const auto indices_dst_ptr = indices_dst.data_ptr(); + const auto indices_src_ptr = indices_src.data_ptr(); + const auto length = static_cast(L.unwrap()); + const auto kv_cache_src_stride = static_cast(N.unwrap() * dtype_size); + const auto kv_cache_dst_stride = static_cast(M.unwrap() * dtype_size); + const auto use_int32 = indices_dtype.unwrap().bits == 32; + const auto device = indices_device.unwrap(); + + constexpr auto kWorkersPerBlock = kBlockSize / (device::kWarpThreads / kUnroll); + const auto num_blocks = std::min(div_ceil(length, kWorkersPerBlock), kBlockQuota); + const auto params = HicacheKernelParams{ + .k_cache_dst = k_cache_dst_ptr, + .v_cache_dst = v_cache_dst_ptr, + .indices_dst = indices_dst_ptr, + .k_cache_src = k_cache_src_ptr, + .v_cache_src = v_cache_src_ptr, + .indices_src = indices_src_ptr, + 
.kv_cache_src_stride = kv_cache_src_stride, + .kv_cache_dst_stride = kv_cache_dst_stride, + .length = length, + }; + const auto kernel = use_int32 ? kernel_one : kernel_one; + LaunchKernel(num_blocks, kBlockSize, device)(kernel, params); + } + + static void run_all( + const tvm::ffi::TensorView k_ptr_dst, + const tvm::ffi::TensorView v_ptr_dst, + const tvm::ffi::TensorView indices_dst, + const tvm::ffi::TensorView k_ptr_src, + const tvm::ffi::TensorView v_ptr_src, + const tvm::ffi::TensorView indices_src, + const int64_t kv_src_stride_bytes, + const int64_t kv_dst_stride_bytes) { + using namespace host; + + auto N = SymbolicSize{"num_layers"}; + auto L = SymbolicSize{"indices length"}; + auto dtype_ = SymbolicDType{}; + auto device_ = SymbolicDevice{}; + + TensorMatcher({N}) // + .with_dtype() + .with_device(device_) + .verify(k_ptr_src) + .verify(v_ptr_src) + .verify(k_ptr_dst) + .verify(v_ptr_dst); + TensorMatcher({L}) // + .with_dtype(dtype_) + .with_device(device_) + .verify(indices_src) + .verify(indices_dst); + + // verify dimension match + const auto k_cache_dst_ptr = k_ptr_dst.data_ptr(); + const auto v_cache_dst_ptr = v_ptr_dst.data_ptr(); + const auto k_cache_src_ptr = k_ptr_src.data_ptr(); + const auto v_cache_src_ptr = v_ptr_src.data_ptr(); + const auto indices_dst_ptr = indices_dst.data_ptr(); + const auto indices_src_ptr = indices_src.data_ptr(); + const auto length = static_cast(L.unwrap()); + const auto use_int32 = dtype_.unwrap().bits == 32; + const auto device = device_.unwrap(); + + constexpr auto kWorkersPerBlock = kBlockSize / (device::kWarpThreads / kUnroll); + const auto num_blocks = std::min(div_ceil(length, kWorkersPerBlock), kBlockQuota); + const auto params = HicacheKernelParams{ + .k_cache_dst = k_cache_dst_ptr, + .v_cache_dst = v_cache_dst_ptr, + .indices_dst = indices_dst_ptr, + .k_cache_src = k_cache_src_ptr, + .v_cache_src = v_cache_src_ptr, + .indices_src = indices_src_ptr, + .kv_cache_src_stride = kv_src_stride_bytes, + 
.kv_cache_dst_stride = kv_dst_stride_bytes, + .length = length, + .num_layers = static_cast(N.unwrap()), + }; + const auto kernel = use_int32 ? kernel_all : kernel_all; + LaunchKernel(num_blocks, kBlockSize, device)(kernel, params); + } +}; + +#undef SGL_HICACHE_KERNEL + +} // namespace diff --git a/sglang/python/sglang/jit_kernel/csrc/moe/nvfp4_blockwise_moe.cuh b/sglang/python/sglang/jit_kernel/csrc/moe/nvfp4_blockwise_moe.cuh new file mode 100644 index 0000000000000000000000000000000000000000..c3293fbfd08cfceb73b7d31b74a2ab46fcf0ec57 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/moe/nvfp4_blockwise_moe.cuh @@ -0,0 +1,882 @@ +#include +#include + +#include +#include + +#include +#include + +#include "cute/tensor.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/tensor_view_io.h" +#include +#include +#include +#include +#include + +using namespace host; +using namespace cute; + +struct WorkspaceKey { + int device_id; + uintptr_t stream; + auto operator==(const WorkspaceKey&) const -> bool = default; +}; + +struct 
WorkspaceKeyHash { + auto operator()(const WorkspaceKey& key) const -> size_t { + size_t h1 = std::hash{}(key.device_id); + size_t h2 = std::hash{}(key.stream); + return h1 ^ (h2 + 0x9e3779b97f4a7c15ULL + (h1 << 6) + (h1 >> 2)); + } +}; + +struct WorkspaceState { + void* ptr = nullptr; + size_t bytes = 0; +}; + +inline auto get_cached_workspace(size_t required_bytes, int device_id, cudaStream_t stream) -> void* { + if (required_bytes == 0) { + return nullptr; + } + + thread_local std::unordered_map cache; + WorkspaceKey key{device_id, reinterpret_cast(stream)}; + auto& ws = cache[key]; + + if (ws.ptr != nullptr && ws.bytes >= required_bytes) { + return ws.ptr; + } + + RuntimeDeviceCheck(cudaSetDevice(device_id)); + if (ws.ptr != nullptr) { + RuntimeDeviceCheck(cudaFreeAsync(ws.ptr, stream)); + ws.ptr = nullptr; + ws.bytes = 0; + } + RuntimeDeviceCheck(cudaMallocAsync(&ws.ptr, required_bytes, stream)); + ws.bytes = required_bytes; + return ws.ptr; +} + +inline int getSMVersion(int device_id) { + int sm_major = 0; + int sm_minor = 0; + RuntimeDeviceCheck(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device_id)); + RuntimeDeviceCheck(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device_id)); + return sm_major * 10 + sm_minor; +} + +template < + typename ElementAB, + typename ElementC, + typename ElementSF, + typename ElementAccumulator, + typename LayoutSFA, + typename LayoutSFB, + typename ScaleConfig> +__global__ void __get_group_gemm_starts( + ElementAB** a_offsets, + ElementAB** b_offsets, + ElementC** out_offsets, + ElementSF** a_scales_offsets, + ElementSF** b_scales_offsets, + ElementAccumulator** alpha_offsets, + LayoutSFA* layout_sfa_base_as_int, + LayoutSFB* layout_sfb_base_as_int, + ElementAB* a_base_as_int, + ElementAB* b_base_as_int, + ElementC* out_base_as_int, + ElementSF* a_scales_base_as_int, + ElementSF* b_scales_base_as_int, + ElementAccumulator* alphas_base_as_int, + const int32_t* expert_offsets, + 
const int32_t* sf_offsets, + const int32_t* problem_sizes_as_shapes, + const int K, + const int N) { + int64_t expert_id = threadIdx.x; + if (expert_id >= gridDim.x * blockDim.x) { + return; + } + // Originally int32_t but upcasting to int64_t to avoid overflow + // during offset calculations + int64_t expert_offset = static_cast(expert_offsets[expert_id]); + int64_t sf_offset = static_cast(sf_offsets[expert_id]); + // size for block in block scale. + int64_t group_size = 16; + int64_t m = static_cast(problem_sizes_as_shapes[expert_id * 3]); + int64_t n = static_cast(problem_sizes_as_shapes[expert_id * 3 + 1]); + int64_t k = static_cast(problem_sizes_as_shapes[expert_id * 3 + 2]); + assert((m >= 0 && n == N && k == K && k % 2 == 0) && "unexpected problem sizes"); + + int64_t half_k = static_cast(k / 2); + int64_t group_k = static_cast(k / group_size); + // Shape of A as uint8/byte = [M, K // 2] + // Shape of B as uint8/byte = [E, N, K // 2] + a_offsets[expert_id] = a_base_as_int + expert_offset * half_k; + + b_offsets[expert_id] = b_base_as_int + expert_id * n * half_k; + // Shape of C = [M, N] + out_offsets[expert_id] = out_base_as_int + expert_offset * n; + // Shape of a_scale = [sum(sf_sizes), K // group_size] + a_scales_offsets[expert_id] = a_scales_base_as_int + sf_offset * group_k; + + assert((reinterpret_cast(a_scales_offsets[expert_id]) % 128) == 0 && "TMA requires 128-byte alignment"); + + // Shape of B scale = [E, N, K // group_size] + b_scales_offsets[expert_id] = b_scales_base_as_int + expert_id * n * group_k; + assert((reinterpret_cast(b_scales_offsets[expert_id]) % 128) == 0 && "TMA requires 128-byte alignment"); + // Shape of alpha = [E] + alpha_offsets[expert_id] = alphas_base_as_int + expert_id; + + LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id; + LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id; + + *layout_sfa_ptr = ScaleConfig::tile_atom_to_shape_SFA( + cute::make_shape(static_cast(m), static_cast(n), 
static_cast(k), 1)); + *layout_sfb_ptr = ScaleConfig::tile_atom_to_shape_SFB( + cute::make_shape(static_cast(m), static_cast(n), static_cast(k), 1)); +} + +#define __CALL_GET_STARTS_KERNEL_BLOCKSCALE( \ + ELEMENT_AB_TYPE, SF_TYPE, TYPE_CHECK, C_TYPE, LayoutSFA, LayoutSFB, ScaleConfig) \ + else if (TYPE_CHECK) { \ + __get_group_gemm_starts \ + <<<1, num_experts, 0, stream>>>( \ + static_cast(a_starts.data_ptr()), \ + static_cast(b_starts.data_ptr()), \ + static_cast(out_starts.data_ptr()), \ + static_cast(a_scales_starts.data_ptr()), \ + static_cast(b_scales_starts.data_ptr()), \ + static_cast(alpha_starts.data_ptr()), \ + reinterpret_cast(layout_sfa.data_ptr()), \ + reinterpret_cast(layout_sfb.data_ptr()), \ + static_cast(a_tensors.data_ptr()), \ + static_cast(b_tensors.data_ptr()), \ + static_cast(out_tensors.data_ptr()), \ + static_cast(a_scales.data_ptr()), \ + static_cast(b_scales.data_ptr()), \ + static_cast(alphas.data_ptr()), \ + static_cast(expert_offsets.data_ptr()), \ + static_cast(sf_offsets.data_ptr()), \ + static_cast(problem_sizes.data_ptr()), \ + K, \ + N); \ + } + +template +void run_get_group_gemm_starts( + const tvm::ffi::TensorView a_starts, + const tvm::ffi::TensorView b_starts, + const tvm::ffi::TensorView out_starts, + const tvm::ffi::TensorView a_scales_starts, + const tvm::ffi::TensorView b_scales_starts, + const tvm::ffi::TensorView alpha_starts, + const tvm::ffi::TensorView layout_sfa, + const tvm::ffi::TensorView layout_sfb, + /*these are used for their base addresses*/ + tvm::ffi::TensorView const& a_tensors, + tvm::ffi::TensorView const& b_tensors, + tvm::ffi::TensorView const& out_tensors, + tvm::ffi::TensorView const& a_scales, + tvm::ffi::TensorView const& b_scales, + tvm::ffi::TensorView const& alphas, + tvm::ffi::TensorView const& expert_offsets, + tvm::ffi::TensorView const& sf_offsets, + tvm::ffi::TensorView const& problem_sizes, + int M, + int N, + int K) { + int num_experts = static_cast(expert_offsets.size(0)); + auto stream = 
LaunchKernel::resolve_device(a_tensors.device()); + + RuntimeCheck(out_tensors.size(1) == N, "Output tensor shape doesn't match expected shape"); + RuntimeCheck( + K / 2 == b_tensors.size(2), + "b_tensors(dim = 2) and a_tensors(dim = 1) trailing" + " dimension must match"); + if (false) { + } + //(ELEMENT_AB_TYPE, BS_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB, + // ScaleConfig) + __CALL_GET_STARTS_KERNEL_BLOCKSCALE( + cutlass::float_e2m1_t, + cutlass::float_ue4m3_t, + host::is_type(out_tensors.dtype()), + cutlass::bfloat16_t, + LayoutSFA, + LayoutSFB, + ScaleConfig) + __CALL_GET_STARTS_KERNEL_BLOCKSCALE( + cutlass::float_e2m1_t, + cutlass::float_ue4m3_t, + host::is_type(out_tensors.dtype()), + cutlass::half_t, + LayoutSFA, + LayoutSFB, + ScaleConfig) + else { + Panic("Invalid output type (must be float16 or bfloat16)"); + } +} + +void run_fp4_blockwise_scaled_group_mm_sm120( + tvm::ffi::TensorView output, + const tvm::ffi::TensorView a, + const tvm::ffi::TensorView b, + const tvm::ffi::TensorView a_blockscale, + const tvm::ffi::TensorView b_blockscales, + const tvm::ffi::TensorView alphas, + const tvm::ffi::TensorView ab_strides, + const tvm::ffi::TensorView c_strides, + const tvm::ffi::TensorView problem_sizes, + const tvm::ffi::TensorView expert_offsets, + const tvm::ffi::TensorView sf_offsets, + const tvm::ffi::TensorView a_ptrs, + const tvm::ffi::TensorView b_ptrs, + const tvm::ffi::TensorView out_ptrs, + const tvm::ffi::TensorView a_scales_ptrs, + const tvm::ffi::TensorView b_scales_ptrs, + const tvm::ffi::TensorView alpha_ptrs, + const tvm::ffi::TensorView layout_sfa, + const tvm::ffi::TensorView layout_sfb, + int M, + int N, + int K) { + using ProblemShape = cutlass::gemm::GroupProblemShape>; + using ElementType = cutlass::float_e2m1_t; + using ElementSFType = cutlass::float_ue4m3_t; + using ElementA = cutlass::nv_float4_t; + using ElementB = cutlass::nv_float4_t; + + using ElementC = cutlass::bfloat16_t; + using ElementD = cutlass::bfloat16_t; + 
using ElementAccumulator = float; + // Layout definitions + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + // Alignment constraints + static constexpr int AlignmentA = 32; + static constexpr int AlignmentB = 32; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + // Architecture definitions + using ArchTag = cutlass::arch::Sm120; + using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using ThreadBlockShape = Shape<_128, _128, _128>; + // NOTE: StageCountAuto maximizes the stage count based on the tile size + + using ClusterShape = Shape<_1, _1, _1>; + + using FusionOperation = + cutlass::epilogue::fusion::LinearCombination; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ThreadBlockShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementAccumulator, + ElementC, + LayoutC*, + AlignmentC, + ElementD, + LayoutC*, + AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation>::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementA, + LayoutA*, + AlignmentA, + ElementB, + LayoutB*, + AlignmentB, + ElementAccumulator, + ThreadBlockShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal; + + using Gemm1SM = cutlass::gemm::device::GemmUniversalAdapter; + using Gemm = Gemm1SM; + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename 
Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + using StrideD = typename Gemm::GemmKernel::InternalStrideD; + + using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA; + using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB; + using ScaleConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + + using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape; + int num_experts = static_cast(expert_offsets.size(0)); + + run_get_group_gemm_starts( + a_ptrs, + b_ptrs, + out_ptrs, + a_scales_ptrs, + b_scales_ptrs, + alpha_ptrs, + layout_sfa, + layout_sfb, + a, + b, + output, + a_blockscale, + b_blockscales, + alphas, + expert_offsets, + sf_offsets, + problem_sizes, + M, + N, + K); + + // Create an instance of the GEMM + Gemm gemm_op; + + // Initialize problem_sizes_as_shapes correctly + UnderlyingProblemShape* problem_sizes_as_shapes = static_cast(problem_sizes.data_ptr()); + + // Set the Scheduler info + cutlass::KernelHardwareInfo hw_info; + + using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions; + typename Gemm::GemmKernel::TileSchedulerArguments scheduler; + scheduler.raster_order = RasterOrderOptions::AlongM; + hw_info.device_id = a.device().device_id; + static std::unordered_map cached_sm_counts; + if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) { + cached_sm_counts[hw_info.device_id] = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + hw_info.sm_count = std::min(cached_sm_counts[hw_info.device_id], std::numeric_limits::max()); + + // Mainloop Arguments + typename GemmKernel::MainloopArguments mainloop_args{ + static_cast(a_ptrs.data_ptr()), + static_cast(ab_strides.data_ptr()), + static_cast(b_ptrs.data_ptr()), + static_cast(ab_strides.data_ptr()), + static_cast(a_scales_ptrs.data_ptr()), + reinterpret_cast(layout_sfa.data_ptr()), + 
static_cast(b_scales_ptrs.data_ptr()), + reinterpret_cast(layout_sfb.data_ptr())}; + + // Epilogue Arguments + typename GemmKernel::EpilogueArguments epilogue_args{ + {}, // epilogue.thread + nullptr, + static_cast(c_strides.data_ptr()), + static_cast(out_ptrs.data_ptr()), + static_cast(c_strides.data_ptr())}; + auto& fusion_args = epilogue_args.thread; + fusion_args.alpha_ptr_array = reinterpret_cast(alpha_ptrs.data_ptr()); + fusion_args.dAlpha = {_0{}, _0{}, 1}; + fusion_args.beta = 0.0f; + + // Gemm Arguments + typename GemmKernel::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {num_experts, problem_sizes_as_shapes, nullptr}, + mainloop_args, + epilogue_args, + hw_info, + scheduler}; + + size_t workspace_size = Gemm::get_workspace_size(args); + const cudaStream_t stream = LaunchKernel::resolve_device(a.device()); + void* workspace = get_cached_workspace(workspace_size, hw_info.device_id, stream); + + auto can_implement_status = gemm_op.can_implement(args); + RuntimeCheck( + can_implement_status == cutlass::Status::kSuccess, + "Failed to implement GEMM: ", + cutlassGetStatusString(can_implement_status)); + + // Run the GEMM + auto status = gemm_op.initialize(args, workspace); + RuntimeCheck(status == cutlass::Status::kSuccess, "Failed to initialize GEMM: ", cutlassGetStatusString(status)); + + status = gemm_op.run(args, workspace, stream); + RuntimeCheck(status == cutlass::Status::kSuccess, "Failed to run GEMM: ", cutlassGetStatusString(status)); +} + +template +void run_fp4_blockwise_scaled_group_mm_sm100( + tvm::ffi::TensorView output, + const tvm::ffi::TensorView a, + const tvm::ffi::TensorView b, + const tvm::ffi::TensorView a_blockscale, + const tvm::ffi::TensorView b_blockscales, + const tvm::ffi::TensorView alphas, + const tvm::ffi::TensorView ab_strides, + const tvm::ffi::TensorView c_strides, + const tvm::ffi::TensorView problem_sizes, + const tvm::ffi::TensorView expert_offsets, + const tvm::ffi::TensorView sf_offsets, + const 
tvm::ffi::TensorView a_ptrs, + const tvm::ffi::TensorView b_ptrs, + const tvm::ffi::TensorView out_ptrs, + const tvm::ffi::TensorView a_scales_ptrs, + const tvm::ffi::TensorView b_scales_ptrs, + const tvm::ffi::TensorView alpha_ptrs, + const tvm::ffi::TensorView layout_sfa, + const tvm::ffi::TensorView layout_sfb, + int M, + int N, + int K) { + using ProblemShape = cutlass::gemm::GroupProblemShape>; + using ElementType = cutlass::float_e2m1_t; + using ElementSFType = cutlass::float_ue4m3_t; + using ElementA = cutlass::nv_float4_t; + using ElementB = cutlass::nv_float4_t; + + using ElementC = OutType; + using ElementD = ElementC; + using ElementAccumulator = float; + // Layout definitions + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = LayoutC; + + // Alignment constraints + static constexpr int AlignmentA = 32; + static constexpr int AlignmentB = 32; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + // Architecture definitions + using ArchTag = cutlass::arch::Sm100; + using EpilogueOperatorClass = cutlass::arch::OpClassTensorOp; // Epilogue Operator class tag + using MainloopOperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Mainloop Operator class tag + using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based + // on the tile size + + using ClusterShape = Shape<_1, _1, _1>; + struct MMA1SMConfig { + using MmaTileShape = Shape<_128, _128, _128>; + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100; // Kernel to launch + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + }; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + EpilogueOperatorClass, + typename MMA1SMConfig::MmaTileShape, + 
ClusterShape, + Shape<_128, _64>, + ElementAccumulator, + ElementAccumulator, + ElementC, + LayoutC*, + AlignmentC, + ElementD, + LayoutC*, + AlignmentD, + typename MMA1SMConfig::EpilogueSchedule>::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + MainloopOperatorClass, + ElementA, + LayoutA*, + AlignmentA, + ElementB, + LayoutB*, + AlignmentB, + ElementAccumulator, + typename MMA1SMConfig::MmaTileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + typename MMA1SMConfig::KernelSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal; + + using Gemm1SM = cutlass::gemm::device::GemmUniversalAdapter; + using Gemm = Gemm1SM; + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + using StrideD = typename Gemm::GemmKernel::InternalStrideD; + + using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA; + using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB; + using ScaleConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + + using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape; + int num_experts = static_cast(expert_offsets.size(0)); + + run_get_group_gemm_starts( + a_ptrs, + b_ptrs, + out_ptrs, + a_scales_ptrs, + b_scales_ptrs, + alpha_ptrs, + layout_sfa, + layout_sfb, + a, + b, + output, + a_blockscale, + b_blockscales, + alphas, + expert_offsets, + sf_offsets, + problem_sizes, + M, + N, + K); + + // Create an instance of the GEMM + Gemm gemm_op; + + // Initialize problem_sizes_as_shapes correctly + UnderlyingProblemShape* problem_sizes_as_shapes = static_cast(problem_sizes.data_ptr()); + + // Set the Scheduler info + cutlass::KernelHardwareInfo hw_info; + using 
RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams< + typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions; + typename Gemm::GemmKernel::TileSchedulerArguments scheduler; + scheduler.raster_order = RasterOrderOptions::AlongM; + hw_info.device_id = a.device().device_id; + static std::unordered_map cached_sm_counts; + if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) { + cached_sm_counts[hw_info.device_id] = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + hw_info.sm_count = std::min(cached_sm_counts[hw_info.device_id], std::numeric_limits::max()); + + // Mainloop Arguments + typename GemmKernel::MainloopArguments mainloop_args{ + static_cast(a_ptrs.data_ptr()), + static_cast(ab_strides.data_ptr()), + static_cast(b_ptrs.data_ptr()), + static_cast(ab_strides.data_ptr()), + static_cast(a_scales_ptrs.data_ptr()), + reinterpret_cast(layout_sfa.data_ptr()), + static_cast(b_scales_ptrs.data_ptr()), + reinterpret_cast(layout_sfb.data_ptr())}; + + // Epilogue Arguments + typename GemmKernel::EpilogueArguments epilogue_args{ + {}, // epilogue.thread + nullptr, + static_cast(c_strides.data_ptr()), + static_cast(out_ptrs.data_ptr()), + static_cast(c_strides.data_ptr())}; + auto& fusion_args = epilogue_args.thread; + fusion_args.alpha_ptr_array = reinterpret_cast(alpha_ptrs.data_ptr()); + fusion_args.dAlpha = {_0{}, _0{}, 1}; + + // Gemm Arguments + typename GemmKernel::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {num_experts, problem_sizes_as_shapes, nullptr}, + mainloop_args, + epilogue_args, + hw_info, + scheduler}; + + size_t workspace_size = Gemm::get_workspace_size(args); + const cudaStream_t stream = LaunchKernel::resolve_device(a.device()); + void* workspace = get_cached_workspace(workspace_size, hw_info.device_id, stream); + + auto can_implement_status = gemm_op.can_implement(args); + RuntimeCheck( + can_implement_status == 
cutlass::Status::kSuccess, + "Failed to implement GEMM: ", + cutlassGetStatusString(can_implement_status)); + + // Run the GEMM + auto status = gemm_op.initialize(args, workspace); + RuntimeCheck(status == cutlass::Status::kSuccess, "Failed to initialize GEMM: ", cutlassGetStatusString(status)); + + status = gemm_op.run(args, workspace, stream); + RuntimeCheck(status == cutlass::Status::kSuccess, "Failed to run GEMM: ", cutlassGetStatusString(status)); +} + +void cutlass_fp4_group_mm_sm100a_sm120a( + tvm::ffi::TensorView output, + const tvm::ffi::TensorView a, + const tvm::ffi::TensorView b, + const tvm::ffi::TensorView a_blockscale, + const tvm::ffi::TensorView b_blockscales, + const tvm::ffi::TensorView alphas, + const tvm::ffi::TensorView ab_strides, + const tvm::ffi::TensorView c_strides, + const tvm::ffi::TensorView problem_sizes, + const tvm::ffi::TensorView expert_offsets, + const tvm::ffi::TensorView sf_offsets, + const tvm::ffi::TensorView a_ptrs, + const tvm::ffi::TensorView b_ptrs, + const tvm::ffi::TensorView out_ptrs, + const tvm::ffi::TensorView a_scales_ptrs, + const tvm::ffi::TensorView b_scales_ptrs, + const tvm::ffi::TensorView alpha_ptrs, + const tvm::ffi::TensorView layout_sfa, + const tvm::ffi::TensorView layout_sfb) { + auto check_cuda_contig = [](const tvm::ffi::TensorView t, const char* name) { + RuntimeCheck(t.device().device_type == kDLCUDA, name, " must be a CUDA tensor"); + RuntimeCheck(t.is_contiguous(), name, " must be contiguous"); + }; + + check_cuda_contig(output, "output"); + check_cuda_contig(a, "a"); + check_cuda_contig(b, "b"); + check_cuda_contig(a_blockscale, "a_blockscale"); + check_cuda_contig(b_blockscales, "b_blockscales"); + check_cuda_contig(alphas, "alphas"); + check_cuda_contig(ab_strides, "ab_strides"); + check_cuda_contig(c_strides, "c_strides"); + check_cuda_contig(problem_sizes, "problem_sizes"); + check_cuda_contig(expert_offsets, "expert_offsets"); + check_cuda_contig(sf_offsets, "sf_offsets"); + 
check_cuda_contig(a_ptrs, "a_ptrs"); + check_cuda_contig(b_ptrs, "b_ptrs"); + check_cuda_contig(out_ptrs, "out_ptrs"); + check_cuda_contig(a_scales_ptrs, "a_scales_ptrs"); + check_cuda_contig(b_scales_ptrs, "b_scales_ptrs"); + check_cuda_contig(alpha_ptrs, "alpha_ptrs"); + check_cuda_contig(layout_sfa, "layout_sfa"); + check_cuda_contig(layout_sfb, "layout_sfb"); + + RuntimeCheck( + output.device() == a.device() && a.device() == b.device() && a.device() == a_blockscale.device() && + a.device() == b_blockscales.device() && a.device() == alphas.device() && a.device() == ab_strides.device() && + a.device() == c_strides.device() && a.device() == problem_sizes.device() && + a.device() == expert_offsets.device() && a.device() == sf_offsets.device() && a.device() == a_ptrs.device() && + a.device() == b_ptrs.device() && a.device() == out_ptrs.device() && a.device() == a_scales_ptrs.device() && + a.device() == b_scales_ptrs.device() && a.device() == alpha_ptrs.device() && + a.device() == layout_sfa.device() && a.device() == layout_sfb.device(), + "all tensors must be on the same device"); + + RuntimeCheck(host::is_type(a.dtype()), "a must be uint8"); + RuntimeCheck(host::is_type(b.dtype()), "b must be uint8"); + RuntimeCheck(host::is_type(a_blockscale.dtype()), "a_blockscale must be float8_e4m3fn"); + RuntimeCheck(host::is_type(b_blockscales.dtype()), "b_blockscales must be float8_e4m3fn"); + RuntimeCheck(host::is_type(alphas.dtype()), "alphas must be float32"); + RuntimeCheck(host::is_type(ab_strides.dtype()), "ab_strides must be int64"); + RuntimeCheck(host::is_type(c_strides.dtype()), "c_strides must be int64"); + RuntimeCheck(host::is_type(problem_sizes.dtype()), "problem_sizes must be int32"); + RuntimeCheck(host::is_type(expert_offsets.dtype()), "expert_offsets must be int32"); + RuntimeCheck(host::is_type(sf_offsets.dtype()), "sf_offsets must be int32"); + RuntimeCheck(host::is_type(a_ptrs.dtype()), "a_ptrs must be int64"); + 
RuntimeCheck(host::is_type(b_ptrs.dtype()), "b_ptrs must be int64"); + RuntimeCheck(host::is_type(out_ptrs.dtype()), "out_ptrs must be int64"); + RuntimeCheck(host::is_type(a_scales_ptrs.dtype()), "a_scales_ptrs must be int64"); + RuntimeCheck(host::is_type(b_scales_ptrs.dtype()), "b_scales_ptrs must be int64"); + RuntimeCheck(host::is_type(alpha_ptrs.dtype()), "alpha_ptrs must be int64"); + RuntimeCheck(host::is_type(layout_sfa.dtype()), "layout_sfa must be int64"); + RuntimeCheck(host::is_type(layout_sfb.dtype()), "layout_sfb must be int64"); + RuntimeCheck( + host::is_type(output.dtype()) || host::is_type(output.dtype()), + "output must be bfloat16 or float16"); + + RuntimeCheck(a.dim() == 2, "a must be 2D"); + RuntimeCheck(b.dim() == 3, "b must be 3D"); + RuntimeCheck(a_blockscale.dim() == 2, "a_blockscale must be 2D"); + RuntimeCheck(b_blockscales.dim() == 3, "b_blockscales must be 3D"); + RuntimeCheck(alphas.dim() == 1, "alphas must be 1D"); + RuntimeCheck(ab_strides.dim() == 1, "ab_strides must be 1D"); + RuntimeCheck(c_strides.dim() == 1, "c_strides must be 1D"); + RuntimeCheck(problem_sizes.dim() == 2, "problem_sizes must be 2D"); + RuntimeCheck(expert_offsets.dim() == 1, "expert_offsets must be 1D"); + RuntimeCheck(sf_offsets.dim() == 1, "sf_offsets must be 1D"); + RuntimeCheck(a_ptrs.dim() == 1, "a_ptrs must be 1D"); + RuntimeCheck(b_ptrs.dim() == 1, "b_ptrs must be 1D"); + RuntimeCheck(out_ptrs.dim() == 1, "out_ptrs must be 1D"); + RuntimeCheck(a_scales_ptrs.dim() == 1, "a_scales_ptrs must be 1D"); + RuntimeCheck(b_scales_ptrs.dim() == 1, "b_scales_ptrs must be 1D"); + RuntimeCheck(alpha_ptrs.dim() == 1, "alpha_ptrs must be 1D"); + RuntimeCheck(layout_sfa.dim() == 2, "layout_sfa must be 2D"); + RuntimeCheck(layout_sfb.dim() == 2, "layout_sfb must be 2D"); + RuntimeCheck(problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)"); + + const int num_experts = static_cast(expert_offsets.size(0)); + RuntimeCheck(problem_sizes.size(0) == 
num_experts, "problem_sizes size mismatch with expert_offsets"); + RuntimeCheck(sf_offsets.size(0) == num_experts, "sf_offsets size mismatch with expert_offsets"); + RuntimeCheck(alphas.size(0) == num_experts, "alphas size mismatch with expert_offsets"); + RuntimeCheck(ab_strides.size(0) == num_experts, "ab_strides size mismatch with expert_offsets"); + RuntimeCheck(c_strides.size(0) == num_experts, "c_strides size mismatch with expert_offsets"); + RuntimeCheck(a_ptrs.size(0) == num_experts, "a_ptrs size mismatch with expert_offsets"); + RuntimeCheck(b_ptrs.size(0) == num_experts, "b_ptrs size mismatch with expert_offsets"); + RuntimeCheck(out_ptrs.size(0) == num_experts, "out_ptrs size mismatch with expert_offsets"); + RuntimeCheck(a_scales_ptrs.size(0) == num_experts, "a_scales_ptrs size mismatch with expert_offsets"); + RuntimeCheck(b_scales_ptrs.size(0) == num_experts, "b_scales_ptrs size mismatch with expert_offsets"); + RuntimeCheck(alpha_ptrs.size(0) == num_experts, "alpha_ptrs size mismatch with expert_offsets"); + RuntimeCheck(layout_sfa.size(0) == num_experts && layout_sfa.size(1) == 5, "layout_sfa must be [num_experts, 5]"); + RuntimeCheck(layout_sfb.size(0) == num_experts && layout_sfb.size(1) == 5, "layout_sfb must be [num_experts, 5]"); + + int M = static_cast(a.size(0)); + int N = static_cast(b.size(1)); + int K = static_cast(2 * b.size(2)); + RuntimeCheck(output.dim() == 2, "output must be 2D"); + RuntimeCheck(output.size(0) == M && output.size(1) == N, "output shape mismatch"); + + auto sm_version = getSMVersion(a.device().device_id); + if (sm_version == 100 || sm_version == 103) { + if (host::is_type(output.dtype())) { + run_fp4_blockwise_scaled_group_mm_sm100( + output, + a, + b, + a_blockscale, + b_blockscales, + alphas, + ab_strides, + c_strides, + problem_sizes, + expert_offsets, + sf_offsets, + a_ptrs, + b_ptrs, + out_ptrs, + a_scales_ptrs, + b_scales_ptrs, + alpha_ptrs, + layout_sfa, + layout_sfb, + M, + N, + K); + } else { + 
run_fp4_blockwise_scaled_group_mm_sm100( + output, + a, + b, + a_blockscale, + b_blockscales, + alphas, + ab_strides, + c_strides, + problem_sizes, + expert_offsets, + sf_offsets, + a_ptrs, + b_ptrs, + out_ptrs, + a_scales_ptrs, + b_scales_ptrs, + alpha_ptrs, + layout_sfa, + layout_sfb, + M, + N, + K); + } + } else if (sm_version >= 120) { + if (host::is_type(output.dtype())) { + run_fp4_blockwise_scaled_group_mm_sm120( + output, + a, + b, + a_blockscale, + b_blockscales, + alphas, + ab_strides, + c_strides, + problem_sizes, + expert_offsets, + sf_offsets, + a_ptrs, + b_ptrs, + out_ptrs, + a_scales_ptrs, + b_scales_ptrs, + alpha_ptrs, + layout_sfa, + layout_sfb, + M, + N, + K); + } else { + Panic("SM120 path currently supports only bfloat16 output"); + } + } else { + RuntimeCheck(false, "Unsupported SM version: ", sm_version); + } +} + +void cutlass_fp4_group_mm( + tvm::ffi::TensorView output, + const tvm::ffi::TensorView a, + const tvm::ffi::TensorView b, + const tvm::ffi::TensorView a_blockscale, + const tvm::ffi::TensorView b_blockscales, + const tvm::ffi::TensorView alphas, + const tvm::ffi::TensorView ab_strides, + const tvm::ffi::TensorView c_strides, + const tvm::ffi::TensorView problem_sizes, + const tvm::ffi::TensorView expert_offsets, + const tvm::ffi::TensorView sf_offsets, + const tvm::ffi::TensorView a_ptrs, + const tvm::ffi::TensorView b_ptrs, + const tvm::ffi::TensorView out_ptrs, + const tvm::ffi::TensorView a_scales_ptrs, + const tvm::ffi::TensorView b_scales_ptrs, + const tvm::ffi::TensorView alpha_ptrs, + const tvm::ffi::TensorView layout_sfa, + const tvm::ffi::TensorView layout_sfb) { + cutlass_fp4_group_mm_sm100a_sm120a( + output, + a, + b, + a_blockscale, + b_blockscales, + alphas, + ab_strides, + c_strides, + problem_sizes, + expert_offsets, + sf_offsets, + a_ptrs, + b_ptrs, + out_ptrs, + a_scales_ptrs, + b_scales_ptrs, + alpha_ptrs, + layout_sfa, + layout_sfb); +} diff --git 
a/sglang/python/sglang/jit_kernel/csrc/nsa/fused_store_index_cache.cuh b/sglang/python/sglang/jit_kernel/csrc/nsa/fused_store_index_cache.cuh new file mode 100644 index 0000000000000000000000000000000000000000..e649fda57db2d36527e89d7fbf119e51a1a839d4 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/csrc/nsa/fused_store_index_cache.cuh @@ -0,0 +1,124 @@ +#include +#include + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +namespace { + +struct FusedStoreCacheParam { + const void* __restrict__ input; + void* __restrict__ cache; + const void* __restrict__ indices; + uint32_t num_tokens; +}; + +[[maybe_unused]] +SGL_DEVICE float fp8_e4m3_clip(float val) { + namespace math = device::math; + return math::max(math::min(val, math::FP8_E4M3_MAX), -math::FP8_E4M3_MAX); +} + +[[maybe_unused]] +SGL_DEVICE fp8x2_e4m3_t pack_fp8(float x, float y) { + return fp8x2_e4m3_t{fp32x2_t{fp8_e4m3_clip(x), fp8_e4m3_clip(y)}}; +} + +template +__global__ void fused_store_indexer_cache(const __grid_constant__ FusedStoreCacheParam param) { + using namespace device; + + /// NOTE: 132 = 128 + 4 + constexpr int64_t kPageBytes = 132 << kPageBits; + + // each warp handles 128 elements, each block handles multiple rows + const auto& [input, cache, indices, num_tokens] = param; + const auto global_tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto global_wid = global_tid / 32; + const auto lane_id = threadIdx.x % 32; + + if (global_wid >= num_tokens) return; + + PDLWaitPrimary(); // wait for primary kernel + + // prefetch the index + const auto index = static_cast(indices)[global_wid]; + // always load the value from input (don't store if invalid) + using KeyT2 = packed_t; + using InStorage = AlignedVector; + using OutStorage = AlignedVector; + const auto elems = static_cast(input)[global_tid]; + const auto [x0, x1] = cast(elems[0]); + const auto [y0, y1] = cast(elems[1]); + const auto local_max = fmaxf(fmaxf(fabs(x0), fabs(x1)), 
fmaxf(fabs(y0), fabs(y1))); + const auto abs_max = warp::reduce_max(local_max); + // use normal fp32 scale + const auto scale = fmaxf(1e-4f, abs_max) / math::FP8_E4M3_MAX; + const auto inv_scale = 1.0f / scale; + const int32_t page = index >> kPageBits; + const int32_t offset = index & ((1 << kPageBits) - 1); + const auto page_ptr = pointer::offset(cache, page * kPageBytes); + const auto value_ptr = pointer::offset(page_ptr, offset * 128); + const auto scale_ptr = pointer::offset(page_ptr, 128 << kPageBits, offset * 4); + OutStorage result; + result[0] = pack_fp8(x0 * inv_scale, x1 * inv_scale); + result[1] = pack_fp8(y0 * inv_scale, y1 * inv_scale); + static_cast(value_ptr)[lane_id] = result; + static_cast(scale_ptr)[0] = scale; + + PDLTriggerSecondary(); // launch secondary kernel +} + +template +struct FusedStoreCacheIndexerKernel { + static constexpr int32_t kLogSize = std::countr_zero(kPageSize); + /// NOTE: 132 = 128 + 4 (128 represent K and 4 represent scale) + static constexpr int64_t kPageBytes = 132 * kPageSize; + static constexpr auto kernel = fused_store_indexer_cache; + + static_assert(std::has_single_bit(kPageSize), "kPageSize must be a power of 2"); + static_assert(1 << kLogSize == kPageSize); + + static void run(tvm::ffi::TensorView input, tvm::ffi::TensorView cache, tvm::ffi::TensorView indices) { + using namespace host; + + auto N = SymbolicSize{"num_tokens"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + TensorMatcher({N, 128}) // input + .with_dtype() + .with_device(device_) + .verify(input); + TensorMatcher({-1, -1}) // cache + .with_strides({kPageBytes, 1}) + .with_dtype() + .with_device(device_) + .verify(cache); + TensorMatcher({N}) // indices + .with_dtype() + .with_device(device_) + .verify(indices); + const auto num_tokens = static_cast(N.unwrap()); + const auto params = FusedStoreCacheParam{ + .input = input.data_ptr(), + .cache = cache.data_ptr(), + .indices = indices.data_ptr(), + .num_tokens = num_tokens, + }; + const 
auto kBlockSize = 128; + const auto num_blocks = div_ceil(num_tokens * 32, kBlockSize); + LaunchKernel(num_blocks, kBlockSize, device_.unwrap()).enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/sglang/python/sglang/jit_kernel/cutedsl_gdn.py b/sglang/python/sglang/jit_kernel/cutedsl_gdn.py new file mode 100644 index 0000000000000000000000000000000000000000..968d0cef03c750b5dc9e923325ff731e07869613 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/cutedsl_gdn.py @@ -0,0 +1,1494 @@ +"""CuTe DSL Fused Sigmoid Gating Delta Rule Kernel for GDN Decode.""" + +import logging +from typing import Dict, Optional, Tuple + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +import torch +from cutlass.cute.nvgpu import cpasync +from cutlass.cute.runtime import from_dlpack + +logger = logging.getLogger(__name__) + +_compiled_kernels: Dict[Tuple, object] = {} +_cu_seqlens_cache: Dict[Tuple, torch.Tensor] = {} +TILE_K = 128 +TILE_V = 32 +TILE_V_PADDED = 36 +TILE_V_SMALL = 16 +TILE_V_SMALL_PADDED = 20 +NUM_STAGES = 2 +NUM_THREADS = 128 +NUM_BLOCKS_PER_STATE_SMALL = 8 +NUM_THREADS_LARGE = 256 +NUM_WARPS_LARGE = 8 +V_PER_WARP = 4 +ROWS_PER_ITER = 8 +NUM_K_ITERS = TILE_K // ROWS_PER_ITER +SMALL_BATCH_THRESHOLD = 32 + + +def _define_kernels(): + """Define CuTe DSL kernels for normal and varlen decode modes.""" + + NUM_WARPS_SMALL = 4 + V_PER_WARP_SMALL = TILE_V_SMALL // NUM_WARPS_SMALL + ROWS_PER_ITER_SMALL = 32 // V_PER_WARP_SMALL + NUM_K_ITERS_SMALL = TILE_K // ROWS_PER_ITER_SMALL + + @cute.kernel + def gdn_kernel_small_batch( + tiled_copy_load: cute.TiledCopy, + h0_source: cute.Tensor, + smem_layout_staged: cute.Layout, + num_v_tiles: cutlass.Constexpr[int], + q: cute.Tensor, + k: cute.Tensor, + v: cute.Tensor, + a: cute.Tensor, + b: cute.Tensor, + A_log: cute.Tensor, + dt_bias: cute.Tensor, + o: cute.Tensor, + h0_indices: cute.Tensor, + softplus_beta: cutlass.Constexpr[float], + softplus_threshold: cutlass.Constexpr[float], 
+ scale: cutlass.Constexpr[float], + H: cutlass.Constexpr[int], + HV: cutlass.Constexpr[int], + use_qk_l2norm: cutlass.Constexpr[bool], + ): + """Small batch kernel for (N, 1, ...) format.""" + tidx, _, _ = cute.arch.thread_idx() + in_warp_tid = tidx % 32 + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + block_idx, _, _ = cute.arch.block_idx() + + batch_idx = block_idx // NUM_BLOCKS_PER_STATE_SMALL + batch_inner = block_idx % NUM_BLOCKS_PER_STATE_SMALL + num_v_tiles_per_block = num_v_tiles // NUM_BLOCKS_PER_STATE_SMALL + start_v_tile = batch_inner * num_v_tiles_per_block + + i_n = batch_idx // HV + i_hv = batch_idx % HV + i_h = i_hv // (HV // H) + + pool_idx = h0_indices[i_n] + + if pool_idx >= 0: + k_local = in_warp_tid // V_PER_WARP_SMALL + v_local = in_warp_tid % V_PER_WARP_SMALL + v_base = warp_idx * V_PER_WARP_SMALL + v_idx = v_base + v_local + + smem = cutlass.utils.SmemAllocator() + sData = smem.allocate_tensor(cutlass.Float32, smem_layout_staged, 128) + smem_o_layout = cute.make_layout((TILE_V_SMALL,), stride=(1,)) + smem_o = smem.allocate_tensor(cutlass.Float32, smem_o_layout, 128) + smem_k_layout = cute.make_layout((TILE_K,), stride=(1,)) + smem_q_layout = cute.make_layout((TILE_K,), stride=(1,)) + sK = smem.allocate_tensor(cutlass.Float32, smem_k_layout, 128) + sQ = smem.allocate_tensor(cutlass.Float32, smem_q_layout, 128) + + if tidx < TILE_K: + sK[tidx] = cutlass.Float32(k[i_n, 0, i_h, tidx]) + sQ[tidx] = cutlass.Float32(q[i_n, 0, i_h, tidx]) + + gSrc_batch = h0_source[(pool_idx, i_hv, None, None)] + gSrc = cute.local_tile(gSrc_batch, (TILE_K, TILE_V_SMALL), (0, None)) + thr_copy_load = tiled_copy_load.get_slice(tidx) + + prefetch_count = cutlass.min(NUM_STAGES - 1, num_v_tiles_per_block) + for v_tile_offset in range(prefetch_count): + v_tile = start_v_tile + v_tile_offset + stage = v_tile_offset % NUM_STAGES + gSrc_tile = gSrc[(None, None, v_tile)] + sData_stage = sData[(None, None, stage)] + thr_gSrc = 
thr_copy_load.partition_S(gSrc_tile) + thr_sData = thr_copy_load.partition_D(sData_stage) + cute.copy(tiled_copy_load, thr_gSrc, thr_sData) + cute.arch.cp_async_commit_group() + + r_A_log = cutlass.Float32(A_log[i_hv]) + r_dt_bias = cutlass.Float32(dt_bias[i_hv]) + r_a = cutlass.Float32(a[i_n, 0, i_hv]) + r_b = cutlass.Float32(b[i_n, 0, i_hv]) + + r_g = 0.0 + r_beta = 0.0 + if in_warp_tid == 0: + x = r_a + r_dt_bias + beta_x = softplus_beta * x + softplus_x = 0.0 + if beta_x <= softplus_threshold: + exp_beta_x = cute.exp(beta_x) + log_input = cutlass.Float32(1.0 + exp_beta_x) + log_result = cutlass.Float32(cute.log(log_input)) + softplus_x = cutlass.Float32( + (cutlass.Float32(1.0) / softplus_beta) * log_result + ) + else: + softplus_x = x + r_g_value = -cute.exp(r_A_log) * softplus_x + r_beta = 1.0 / (1.0 + cute.exp(-r_b)) + r_g = cute.exp(r_g_value) + + r_g = cute.arch.shuffle_sync(r_g, 0) + r_beta = cute.arch.shuffle_sync(r_beta, 0) + + cute.arch.barrier() + + if use_qk_l2norm: + sum_q_partial = 0.0 + sum_k_partial = 0.0 + if tidx < TILE_K: + q_val = sQ[tidx] + k_val = sK[tidx] + sum_q_partial = q_val * q_val + sum_k_partial = k_val * k_val + + for offset in [16, 8, 4, 2, 1]: + sum_q_partial += cute.arch.shuffle_sync_bfly( + sum_q_partial, offset=offset, mask=-1, mask_and_clamp=31 + ) + sum_k_partial += cute.arch.shuffle_sync_bfly( + sum_k_partial, offset=offset, mask=-1, mask_and_clamp=31 + ) + + if in_warp_tid == 0: + smem_o[warp_idx] = sum_q_partial + smem_o[warp_idx + 4] = sum_k_partial + cute.arch.barrier() + + inv_norm_q = 0.0 + inv_norm_k = 0.0 + if warp_idx == 0: + local_sum_q = 0.0 + local_sum_k = 0.0 + if in_warp_tid < NUM_WARPS_SMALL: + local_sum_q = smem_o[in_warp_tid] + local_sum_k = smem_o[in_warp_tid + 4] + for offset in [2, 1]: + local_sum_q += cute.arch.shuffle_sync_bfly( + local_sum_q, offset=offset, mask=-1, mask_and_clamp=31 + ) + local_sum_k += cute.arch.shuffle_sync_bfly( + local_sum_k, offset=offset, mask=-1, mask_and_clamp=31 + ) + if 
in_warp_tid == 0: + smem_o[0] = cute.rsqrt(local_sum_q + 1e-6) + smem_o[1] = cute.rsqrt(local_sum_k + 1e-6) + cute.arch.barrier() + + inv_norm_q = smem_o[0] + inv_norm_k = smem_o[1] + + if tidx < TILE_K: + sK[tidx] = sK[tidx] * inv_norm_k + sQ[tidx] = sQ[tidx] * scale * inv_norm_q + cute.arch.barrier() + else: + if tidx < TILE_K: + sQ[tidx] = sQ[tidx] * scale + cute.arch.barrier() + + for v_tile_offset in range(num_v_tiles_per_block): + v_tile = start_v_tile + v_tile_offset + stage = v_tile_offset % NUM_STAGES + + cute.arch.cp_async_wait_group(0) + cute.arch.barrier() + + next_v_tile_offset = v_tile_offset + prefetch_count + if next_v_tile_offset < num_v_tiles_per_block: + next_v_tile = start_v_tile + next_v_tile_offset + next_stage = next_v_tile_offset % NUM_STAGES + gSrc_next = gSrc[(None, None, next_v_tile)] + sData_next = sData[(None, None, next_stage)] + thr_gSrc = thr_copy_load.partition_S(gSrc_next) + thr_sData = thr_copy_load.partition_D(sData_next) + cute.copy(tiled_copy_load, thr_gSrc, thr_sData) + cute.arch.cp_async_commit_group() + + v_global = v_tile * TILE_V_SMALL + v_idx + r_v = cutlass.Float32(v[i_n, 0, i_hv, v_global]) + + sum_hk = 0.0 + for k_iter in cutlass.range_dynamic(NUM_K_ITERS_SMALL, unroll=16): + k_base = k_iter * ROWS_PER_ITER_SMALL + k_idx = k_base + k_local + h_val = sData[(k_idx, v_idx, stage)] * r_g + r_k_val = sK[k_idx] + sum_hk += h_val * r_k_val + + for offset in [4, 2, 1]: + sum_hk += cute.arch.shuffle_sync_bfly( + sum_hk, + offset=offset * V_PER_WARP_SMALL, + mask=-1, + mask_and_clamp=31, + ) + + v_new = (r_v - sum_hk) * r_beta + v_new = cute.arch.shuffle_sync(v_new, v_local) + + sum_hq = 0.0 + for k_iter in cutlass.range_dynamic(NUM_K_ITERS_SMALL, unroll=16): + k_base = k_iter * ROWS_PER_ITER_SMALL + k_idx = k_base + k_local + h_old = sData[(k_idx, v_idx, stage)] * r_g + r_k_val = sK[k_idx] + r_q_val = sQ[k_idx] + h_new = h_old + r_k_val * v_new + sData[(k_idx, v_idx, stage)] = h_new + sum_hq += h_new * r_q_val + + for offset in 
[4, 2, 1]: + sum_hq += cute.arch.shuffle_sync_bfly( + sum_hq, + offset=offset * V_PER_WARP_SMALL, + mask=-1, + mask_and_clamp=31, + ) + + if k_local == 0: + v_global_out = v_tile * TILE_V_SMALL + v_idx + o[(i_n, 0, i_hv, v_global_out)] = cutlass.BFloat16(sum_hq) + + cute.arch.barrier() + + for k_iter in range(NUM_K_ITERS_SMALL): + flat_idx = tidx + k_iter * 128 + k_write = flat_idx // TILE_V_SMALL + v_write = flat_idx % TILE_V_SMALL + if k_write < TILE_K: + h_val = sData[(k_write, v_write, stage)] + v_global_write = v_tile * TILE_V_SMALL + v_write + h0_source[(pool_idx, i_hv, k_write, v_global_write)] = h_val + + cute.arch.barrier() + + @cute.kernel + def gdn_kernel_small_batch_varlen( + tiled_copy_load: cute.TiledCopy, + h0_source: cute.Tensor, + smem_layout_staged: cute.Layout, + num_v_tiles: cutlass.Constexpr[int], + q: cute.Tensor, + k: cute.Tensor, + v: cute.Tensor, + a: cute.Tensor, + b: cute.Tensor, + A_log: cute.Tensor, + dt_bias: cute.Tensor, + o: cute.Tensor, + h0_indices: cute.Tensor, + softplus_beta: cutlass.Constexpr[float], + softplus_threshold: cutlass.Constexpr[float], + scale: cutlass.Constexpr[float], + H: cutlass.Constexpr[int], + HV: cutlass.Constexpr[int], + use_qk_l2norm: cutlass.Constexpr[bool], + ): + """Small batch kernel for varlen decode (1, N, ...) 
format.""" + tidx, _, _ = cute.arch.thread_idx() + in_warp_tid = tidx % 32 + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + block_idx, _, _ = cute.arch.block_idx() + + batch_idx = block_idx // NUM_BLOCKS_PER_STATE_SMALL + batch_inner = block_idx % NUM_BLOCKS_PER_STATE_SMALL + num_v_tiles_per_block = num_v_tiles // NUM_BLOCKS_PER_STATE_SMALL + start_v_tile = batch_inner * num_v_tiles_per_block + + i_n = batch_idx // HV + i_hv = batch_idx % HV + i_h = i_hv // (HV // H) + + pool_idx = h0_indices[i_n] + + if pool_idx >= 0: + k_local = in_warp_tid // V_PER_WARP_SMALL + v_local = in_warp_tid % V_PER_WARP_SMALL + v_base = warp_idx * V_PER_WARP_SMALL + v_idx = v_base + v_local + + smem = cutlass.utils.SmemAllocator() + sData = smem.allocate_tensor(cutlass.Float32, smem_layout_staged, 128) + smem_o_layout = cute.make_layout((TILE_V_SMALL,), stride=(1,)) + smem_o = smem.allocate_tensor(cutlass.Float32, smem_o_layout, 128) + smem_k_layout = cute.make_layout((TILE_K,), stride=(1,)) + smem_q_layout = cute.make_layout((TILE_K,), stride=(1,)) + sK = smem.allocate_tensor(cutlass.Float32, smem_k_layout, 128) + sQ = smem.allocate_tensor(cutlass.Float32, smem_q_layout, 128) + + if tidx < TILE_K: + sK[tidx] = cutlass.Float32(k[0, i_n, i_h, tidx]) + sQ[tidx] = cutlass.Float32(q[0, i_n, i_h, tidx]) + + gSrc_batch = h0_source[(pool_idx, i_hv, None, None)] + gSrc = cute.local_tile(gSrc_batch, (TILE_K, TILE_V_SMALL), (0, None)) + thr_copy_load = tiled_copy_load.get_slice(tidx) + + prefetch_count = cutlass.min(NUM_STAGES - 1, num_v_tiles_per_block) + for v_tile_offset in range(prefetch_count): + v_tile = start_v_tile + v_tile_offset + stage = v_tile_offset % NUM_STAGES + gSrc_tile = gSrc[(None, None, v_tile)] + sData_stage = sData[(None, None, stage)] + thr_gSrc = thr_copy_load.partition_S(gSrc_tile) + thr_sData = thr_copy_load.partition_D(sData_stage) + cute.copy(tiled_copy_load, thr_gSrc, thr_sData) + cute.arch.cp_async_commit_group() + + r_A_log = 
cutlass.Float32(A_log[i_hv]) + r_dt_bias = cutlass.Float32(dt_bias[i_hv]) + r_a = cutlass.Float32(a[i_n, i_hv]) + r_b = cutlass.Float32(b[i_n, i_hv]) + + r_g = 0.0 + r_beta = 0.0 + if in_warp_tid == 0: + x = r_a + r_dt_bias + beta_x = softplus_beta * x + softplus_x = 0.0 + if beta_x <= softplus_threshold: + exp_beta_x = cute.exp(beta_x) + log_input = cutlass.Float32(1.0 + exp_beta_x) + log_result = cutlass.Float32(cute.log(log_input)) + softplus_x = cutlass.Float32( + (cutlass.Float32(1.0) / softplus_beta) * log_result + ) + else: + softplus_x = x + r_g_value = -cute.exp(r_A_log) * softplus_x + r_beta = 1.0 / (1.0 + cute.exp(-r_b)) + r_g = cute.exp(r_g_value) + + r_g = cute.arch.shuffle_sync(r_g, 0) + r_beta = cute.arch.shuffle_sync(r_beta, 0) + + cute.arch.barrier() + + if use_qk_l2norm: + sum_q_partial = 0.0 + sum_k_partial = 0.0 + if tidx < TILE_K: + q_val = sQ[tidx] + k_val = sK[tidx] + sum_q_partial = q_val * q_val + sum_k_partial = k_val * k_val + + for offset in [16, 8, 4, 2, 1]: + sum_q_partial += cute.arch.shuffle_sync_bfly( + sum_q_partial, offset=offset, mask=-1, mask_and_clamp=31 + ) + sum_k_partial += cute.arch.shuffle_sync_bfly( + sum_k_partial, offset=offset, mask=-1, mask_and_clamp=31 + ) + + if in_warp_tid == 0: + smem_o[warp_idx] = sum_q_partial + smem_o[warp_idx + 4] = sum_k_partial + cute.arch.barrier() + + inv_norm_q = 0.0 + inv_norm_k = 0.0 + if warp_idx == 0: + local_sum_q = 0.0 + local_sum_k = 0.0 + if in_warp_tid < NUM_WARPS_SMALL: + local_sum_q = smem_o[in_warp_tid] + local_sum_k = smem_o[in_warp_tid + 4] + for offset in [2, 1]: + local_sum_q += cute.arch.shuffle_sync_bfly( + local_sum_q, offset=offset, mask=-1, mask_and_clamp=31 + ) + local_sum_k += cute.arch.shuffle_sync_bfly( + local_sum_k, offset=offset, mask=-1, mask_and_clamp=31 + ) + if in_warp_tid == 0: + smem_o[0] = cute.rsqrt(local_sum_q + 1e-6) + smem_o[1] = cute.rsqrt(local_sum_k + 1e-6) + cute.arch.barrier() + + inv_norm_q = smem_o[0] + inv_norm_k = smem_o[1] + + if tidx < 
TILE_K: + sK[tidx] = sK[tidx] * inv_norm_k + sQ[tidx] = sQ[tidx] * scale * inv_norm_q + cute.arch.barrier() + else: + if tidx < TILE_K: + sQ[tidx] = sQ[tidx] * scale + cute.arch.barrier() + + for v_tile_offset in range(num_v_tiles_per_block): + v_tile = start_v_tile + v_tile_offset + stage = v_tile_offset % NUM_STAGES + + cute.arch.cp_async_wait_group(0) + cute.arch.barrier() + + next_v_tile_offset = v_tile_offset + prefetch_count + if next_v_tile_offset < num_v_tiles_per_block: + next_v_tile = start_v_tile + next_v_tile_offset + next_stage = next_v_tile_offset % NUM_STAGES + gSrc_next = gSrc[(None, None, next_v_tile)] + sData_next = sData[(None, None, next_stage)] + thr_gSrc = thr_copy_load.partition_S(gSrc_next) + thr_sData = thr_copy_load.partition_D(sData_next) + cute.copy(tiled_copy_load, thr_gSrc, thr_sData) + cute.arch.cp_async_commit_group() + + v_global = v_tile * TILE_V_SMALL + v_idx + r_v = cutlass.Float32(v[0, i_n, i_hv, v_global]) + + sum_hk = 0.0 + for k_iter in cutlass.range_dynamic(NUM_K_ITERS_SMALL, unroll=16): + k_base = k_iter * ROWS_PER_ITER_SMALL + k_idx = k_base + k_local + h_val = sData[(k_idx, v_idx, stage)] * r_g + r_k_val = sK[k_idx] + sum_hk += h_val * r_k_val + + for offset in [4, 2, 1]: + sum_hk += cute.arch.shuffle_sync_bfly( + sum_hk, + offset=offset * V_PER_WARP_SMALL, + mask=-1, + mask_and_clamp=31, + ) + + v_new = (r_v - sum_hk) * r_beta + v_new = cute.arch.shuffle_sync(v_new, v_local) + + sum_hq = 0.0 + for k_iter in cutlass.range_dynamic(NUM_K_ITERS_SMALL, unroll=16): + k_base = k_iter * ROWS_PER_ITER_SMALL + k_idx = k_base + k_local + h_old = sData[(k_idx, v_idx, stage)] * r_g + r_k_val = sK[k_idx] + r_q_val = sQ[k_idx] + h_new = h_old + r_k_val * v_new + sData[(k_idx, v_idx, stage)] = h_new + sum_hq += h_new * r_q_val + + for offset in [4, 2, 1]: + sum_hq += cute.arch.shuffle_sync_bfly( + sum_hq, + offset=offset * V_PER_WARP_SMALL, + mask=-1, + mask_and_clamp=31, + ) + + if k_local == 0: + v_global_out = v_tile * TILE_V_SMALL + 
v_idx + o[(0, i_n, i_hv, v_global_out)] = cutlass.BFloat16(sum_hq) + + cute.arch.barrier() + + for k_iter in range(NUM_K_ITERS_SMALL): + flat_idx = tidx + k_iter * 128 + k_write = flat_idx // TILE_V_SMALL + v_write = flat_idx % TILE_V_SMALL + if k_write < TILE_K: + h_val = sData[(k_write, v_write, stage)] + v_global_write = v_tile * TILE_V_SMALL + v_write + h0_source[(pool_idx, i_hv, k_write, v_global_write)] = h_val + + cute.arch.barrier() + + @cute.kernel + def gdn_kernel_large_batch( + tiled_copy_load: cute.TiledCopy, + h0_source: cute.Tensor, + smem_layout_staged: cute.Layout, + num_v_tiles: cutlass.Constexpr[int], + q: cute.Tensor, + k: cute.Tensor, + v: cute.Tensor, + a: cute.Tensor, + b: cute.Tensor, + A_log: cute.Tensor, + dt_bias: cute.Tensor, + o: cute.Tensor, + h0_indices: cute.Tensor, + softplus_beta: cutlass.Constexpr[float], + softplus_threshold: cutlass.Constexpr[float], + scale: cutlass.Constexpr[float], + H: cutlass.Constexpr[int], + HV: cutlass.Constexpr[int], + use_qk_l2norm: cutlass.Constexpr[bool], + ): + """Large batch kernel for (N, 1, ...) 
format.""" + tidx, _, _ = cute.arch.thread_idx() + in_warp_tid = tidx % 32 + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + batch_idx, _, _ = cute.arch.block_idx() + i_n = batch_idx // HV + i_hv = batch_idx % HV + i_h = i_hv // (HV // H) + + pool_idx = h0_indices[i_n] + + if pool_idx >= 0: + k_local = in_warp_tid // V_PER_WARP + v_local = in_warp_tid % V_PER_WARP + v_base = warp_idx * V_PER_WARP + v_idx = v_base + v_local + + smem = cutlass.utils.SmemAllocator() + sData = smem.allocate_tensor(cutlass.Float32, smem_layout_staged, 128) + smem_o_layout = cute.make_layout((TILE_V,), stride=(1,)) + smem_o = smem.allocate_tensor(cutlass.Float32, smem_o_layout, 128) + smem_k_layout = cute.make_layout((TILE_K,), stride=(1,)) + smem_q_layout = cute.make_layout((TILE_K,), stride=(1,)) + sK = smem.allocate_tensor(cutlass.Float32, smem_k_layout, 128) + sQ = smem.allocate_tensor(cutlass.Float32, smem_q_layout, 128) + + if tidx < TILE_K: + sK[tidx] = cutlass.Float32(k[i_n, 0, i_h, tidx]) + sQ[tidx] = cutlass.Float32(q[i_n, 0, i_h, tidx]) + + gSrc_batch = h0_source[(pool_idx, i_hv, None, None)] + gSrc = cute.local_tile(gSrc_batch, (TILE_K, TILE_V), (0, None)) + thr_copy_load = tiled_copy_load.get_slice(tidx) + + prefetch_count = cutlass.min(NUM_STAGES - 1, num_v_tiles) + for v_tile in range(prefetch_count): + stage = v_tile % NUM_STAGES + gSrc_tile = gSrc[(None, None, v_tile)] + sData_stage = sData[(None, None, stage)] + thr_gSrc = thr_copy_load.partition_S(gSrc_tile) + thr_sData = thr_copy_load.partition_D(sData_stage) + cute.copy(tiled_copy_load, thr_gSrc, thr_sData) + cute.arch.cp_async_commit_group() + + r_A_log = cutlass.Float32(A_log[i_hv]) + r_dt_bias = cutlass.Float32(dt_bias[i_hv]) + r_a = cutlass.Float32(a[i_n, 0, i_hv]) + r_b = cutlass.Float32(b[i_n, 0, i_hv]) + + r_g = 0.0 + r_beta = 0.0 + if in_warp_tid == 0: + x = r_a + r_dt_bias + beta_x = softplus_beta * x + softplus_x = 0.0 + if beta_x <= softplus_threshold: + exp_beta_x = 
cute.exp(beta_x) + log_input = cutlass.Float32(1.0 + exp_beta_x) + log_result = cutlass.Float32(cute.log(log_input)) + softplus_x = cutlass.Float32( + (cutlass.Float32(1.0) / softplus_beta) * log_result + ) + else: + softplus_x = x + r_g_value = -cute.exp(r_A_log) * softplus_x + r_beta = 1.0 / (1.0 + cute.exp(-r_b)) + r_g = cute.exp(r_g_value) + + r_g = cute.arch.shuffle_sync(r_g, 0) + r_beta = cute.arch.shuffle_sync(r_beta, 0) + + cute.arch.barrier() + + if use_qk_l2norm: + sum_q_partial = 0.0 + sum_k_partial = 0.0 + if tidx < TILE_K: + q_val = sQ[tidx] + k_val = sK[tidx] + sum_q_partial = q_val * q_val + sum_k_partial = k_val * k_val + + for offset in [16, 8, 4, 2, 1]: + sum_q_partial += cute.arch.shuffle_sync_bfly( + sum_q_partial, offset=offset, mask=-1, mask_and_clamp=31 + ) + sum_k_partial += cute.arch.shuffle_sync_bfly( + sum_k_partial, offset=offset, mask=-1, mask_and_clamp=31 + ) + + if in_warp_tid == 0: + smem_o[warp_idx] = sum_q_partial + smem_o[warp_idx + 8] = sum_k_partial + cute.arch.barrier() + + inv_norm_q = 0.0 + inv_norm_k = 0.0 + if warp_idx == 0: + local_sum_q = 0.0 + local_sum_k = 0.0 + if in_warp_tid < NUM_WARPS_LARGE: + local_sum_q = smem_o[in_warp_tid] + local_sum_k = smem_o[in_warp_tid + 8] + for offset in [4, 2, 1]: + local_sum_q += cute.arch.shuffle_sync_bfly( + local_sum_q, offset=offset, mask=-1, mask_and_clamp=31 + ) + local_sum_k += cute.arch.shuffle_sync_bfly( + local_sum_k, offset=offset, mask=-1, mask_and_clamp=31 + ) + if in_warp_tid == 0: + smem_o[0] = cute.rsqrt(local_sum_q + 1e-6) + smem_o[1] = cute.rsqrt(local_sum_k + 1e-6) + cute.arch.barrier() + + inv_norm_q = smem_o[0] + inv_norm_k = smem_o[1] + + if tidx < TILE_K: + sK[tidx] = sK[tidx] * inv_norm_k + sQ[tidx] = sQ[tidx] * scale * inv_norm_q + cute.arch.barrier() + else: + if tidx < TILE_K: + sQ[tidx] = sQ[tidx] * scale + cute.arch.barrier() + + for v_tile in range(num_v_tiles): + stage = v_tile % NUM_STAGES + + cute.arch.cp_async_wait_group(0) + cute.arch.barrier() + + 
next_v_tile = v_tile + prefetch_count + if next_v_tile < num_v_tiles: + next_stage = next_v_tile % NUM_STAGES + gSrc_next = gSrc[(None, None, next_v_tile)] + sData_next = sData[(None, None, next_stage)] + thr_gSrc = thr_copy_load.partition_S(gSrc_next) + thr_sData = thr_copy_load.partition_D(sData_next) + cute.copy(tiled_copy_load, thr_gSrc, thr_sData) + cute.arch.cp_async_commit_group() + + v_global = v_tile * TILE_V + v_idx + r_v = cutlass.Float32(v[i_n, 0, i_hv, v_global]) + + sum_hk = 0.0 + for k_iter in cutlass.range_dynamic(NUM_K_ITERS, unroll=8): + k_base = k_iter * ROWS_PER_ITER + k_idx = k_base + k_local + h_val = sData[(k_idx, v_idx, stage)] * r_g + r_k_val = sK[k_idx] + sum_hk += h_val * r_k_val + + for offset in [4, 2, 1]: + sum_hk += cute.arch.shuffle_sync_bfly( + sum_hk, offset=offset * V_PER_WARP, mask=-1, mask_and_clamp=31 + ) + + v_new = (r_v - sum_hk) * r_beta + v_new = cute.arch.shuffle_sync(v_new, v_local) + + sum_hq = 0.0 + for k_iter in cutlass.range_dynamic(NUM_K_ITERS, unroll=8): + k_base = k_iter * ROWS_PER_ITER + k_idx = k_base + k_local + h_old = sData[(k_idx, v_idx, stage)] * r_g + r_k_val = sK[k_idx] + r_q_val = sQ[k_idx] + h_new = h_old + r_k_val * v_new + sData[(k_idx, v_idx, stage)] = h_new + sum_hq += h_new * r_q_val + + for offset in [4, 2, 1]: + sum_hq += cute.arch.shuffle_sync_bfly( + sum_hq, offset=offset * V_PER_WARP, mask=-1, mask_and_clamp=31 + ) + + if k_local == 0: + v_global_out = v_tile * TILE_V + v_idx + o[(i_n, 0, i_hv, v_global_out)] = cutlass.BFloat16(sum_hq) + + cute.arch.barrier() + + for k_iter in range(NUM_K_ITERS): + flat_idx = tidx + k_iter * 256 + k_write = flat_idx // TILE_V + v_write = flat_idx % TILE_V + if k_write < TILE_K: + h_val = sData[(k_write, v_write, stage)] + v_global_write = v_tile * TILE_V + v_write + h0_source[(pool_idx, i_hv, k_write, v_global_write)] = h_val + + cute.arch.barrier() + + @cute.kernel + def gdn_kernel_large_batch_varlen( + tiled_copy_load: cute.TiledCopy, + h0_source: 
cute.Tensor, + smem_layout_staged: cute.Layout, + num_v_tiles: cutlass.Constexpr[int], + q: cute.Tensor, + k: cute.Tensor, + v: cute.Tensor, + a: cute.Tensor, + b: cute.Tensor, + A_log: cute.Tensor, + dt_bias: cute.Tensor, + o: cute.Tensor, + h0_indices: cute.Tensor, + softplus_beta: cutlass.Constexpr[float], + softplus_threshold: cutlass.Constexpr[float], + scale: cutlass.Constexpr[float], + H: cutlass.Constexpr[int], + HV: cutlass.Constexpr[int], + use_qk_l2norm: cutlass.Constexpr[bool], + ): + """Large batch kernel for varlen decode (1, N, ...) format.""" + tidx, _, _ = cute.arch.thread_idx() + in_warp_tid = tidx % 32 + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + batch_idx, _, _ = cute.arch.block_idx() + i_n = batch_idx // HV + i_hv = batch_idx % HV + i_h = i_hv // (HV // H) + + pool_idx = h0_indices[i_n] + + if pool_idx >= 0: + k_local = in_warp_tid // V_PER_WARP + v_local = in_warp_tid % V_PER_WARP + v_base = warp_idx * V_PER_WARP + v_idx = v_base + v_local + + smem = cutlass.utils.SmemAllocator() + sData = smem.allocate_tensor(cutlass.Float32, smem_layout_staged, 128) + smem_o_layout = cute.make_layout((TILE_V,), stride=(1,)) + smem_o = smem.allocate_tensor(cutlass.Float32, smem_o_layout, 128) + smem_k_layout = cute.make_layout((TILE_K,), stride=(1,)) + smem_q_layout = cute.make_layout((TILE_K,), stride=(1,)) + sK = smem.allocate_tensor(cutlass.Float32, smem_k_layout, 128) + sQ = smem.allocate_tensor(cutlass.Float32, smem_q_layout, 128) + + if tidx < TILE_K: + sK[tidx] = cutlass.Float32(k[0, i_n, i_h, tidx]) + sQ[tidx] = cutlass.Float32(q[0, i_n, i_h, tidx]) + + gSrc_batch = h0_source[(pool_idx, i_hv, None, None)] + gSrc = cute.local_tile(gSrc_batch, (TILE_K, TILE_V), (0, None)) + thr_copy_load = tiled_copy_load.get_slice(tidx) + + prefetch_count = cutlass.min(NUM_STAGES - 1, num_v_tiles) + for v_tile in range(prefetch_count): + stage = v_tile % NUM_STAGES + gSrc_tile = gSrc[(None, None, v_tile)] + sData_stage = 
sData[(None, None, stage)] + thr_gSrc = thr_copy_load.partition_S(gSrc_tile) + thr_sData = thr_copy_load.partition_D(sData_stage) + cute.copy(tiled_copy_load, thr_gSrc, thr_sData) + cute.arch.cp_async_commit_group() + + r_A_log = cutlass.Float32(A_log[i_hv]) + r_dt_bias = cutlass.Float32(dt_bias[i_hv]) + r_a = cutlass.Float32(a[i_n, i_hv]) + r_b = cutlass.Float32(b[i_n, i_hv]) + + r_g = 0.0 + r_beta = 0.0 + if in_warp_tid == 0: + x = r_a + r_dt_bias + beta_x = softplus_beta * x + softplus_x = 0.0 + if beta_x <= softplus_threshold: + exp_beta_x = cute.exp(beta_x) + log_input = cutlass.Float32(1.0 + exp_beta_x) + log_result = cutlass.Float32(cute.log(log_input)) + softplus_x = cutlass.Float32( + (cutlass.Float32(1.0) / softplus_beta) * log_result + ) + else: + softplus_x = x + r_g_value = -cute.exp(r_A_log) * softplus_x + r_beta = 1.0 / (1.0 + cute.exp(-r_b)) + r_g = cute.exp(r_g_value) + + r_g = cute.arch.shuffle_sync(r_g, 0) + r_beta = cute.arch.shuffle_sync(r_beta, 0) + + cute.arch.barrier() + + if use_qk_l2norm: + sum_q_partial = 0.0 + sum_k_partial = 0.0 + if tidx < TILE_K: + q_val = sQ[tidx] + k_val = sK[tidx] + sum_q_partial = q_val * q_val + sum_k_partial = k_val * k_val + + for offset in [16, 8, 4, 2, 1]: + sum_q_partial += cute.arch.shuffle_sync_bfly( + sum_q_partial, offset=offset, mask=-1, mask_and_clamp=31 + ) + sum_k_partial += cute.arch.shuffle_sync_bfly( + sum_k_partial, offset=offset, mask=-1, mask_and_clamp=31 + ) + + if in_warp_tid == 0: + smem_o[warp_idx] = sum_q_partial + smem_o[warp_idx + 8] = sum_k_partial + cute.arch.barrier() + + inv_norm_q = 0.0 + inv_norm_k = 0.0 + if warp_idx == 0: + local_sum_q = 0.0 + local_sum_k = 0.0 + if in_warp_tid < NUM_WARPS_LARGE: + local_sum_q = smem_o[in_warp_tid] + local_sum_k = smem_o[in_warp_tid + 8] + for offset in [4, 2, 1]: + local_sum_q += cute.arch.shuffle_sync_bfly( + local_sum_q, offset=offset, mask=-1, mask_and_clamp=31 + ) + local_sum_k += cute.arch.shuffle_sync_bfly( + local_sum_k, offset=offset, 
mask=-1, mask_and_clamp=31 + ) + if in_warp_tid == 0: + smem_o[0] = cute.rsqrt(local_sum_q + 1e-6) + smem_o[1] = cute.rsqrt(local_sum_k + 1e-6) + cute.arch.barrier() + + inv_norm_q = smem_o[0] + inv_norm_k = smem_o[1] + + if tidx < TILE_K: + sK[tidx] = sK[tidx] * inv_norm_k + sQ[tidx] = sQ[tidx] * scale * inv_norm_q + cute.arch.barrier() + else: + if tidx < TILE_K: + sQ[tidx] = sQ[tidx] * scale + cute.arch.barrier() + + for v_tile in range(num_v_tiles): + stage = v_tile % NUM_STAGES + + cute.arch.cp_async_wait_group(0) + cute.arch.barrier() + + next_v_tile = v_tile + prefetch_count + if next_v_tile < num_v_tiles: + next_stage = next_v_tile % NUM_STAGES + gSrc_next = gSrc[(None, None, next_v_tile)] + sData_next = sData[(None, None, next_stage)] + thr_gSrc = thr_copy_load.partition_S(gSrc_next) + thr_sData = thr_copy_load.partition_D(sData_next) + cute.copy(tiled_copy_load, thr_gSrc, thr_sData) + cute.arch.cp_async_commit_group() + + v_global = v_tile * TILE_V + v_idx + r_v = cutlass.Float32(v[0, i_n, i_hv, v_global]) + + sum_hk = 0.0 + for k_iter in cutlass.range_dynamic(NUM_K_ITERS, unroll=8): + k_base = k_iter * ROWS_PER_ITER + k_idx = k_base + k_local + h_val = sData[(k_idx, v_idx, stage)] * r_g + r_k_val = sK[k_idx] + sum_hk += h_val * r_k_val + + for offset in [4, 2, 1]: + sum_hk += cute.arch.shuffle_sync_bfly( + sum_hk, offset=offset * V_PER_WARP, mask=-1, mask_and_clamp=31 + ) + + v_new = (r_v - sum_hk) * r_beta + v_new = cute.arch.shuffle_sync(v_new, v_local) + + sum_hq = 0.0 + for k_iter in cutlass.range_dynamic(NUM_K_ITERS, unroll=8): + k_base = k_iter * ROWS_PER_ITER + k_idx = k_base + k_local + h_old = sData[(k_idx, v_idx, stage)] * r_g + r_k_val = sK[k_idx] + r_q_val = sQ[k_idx] + h_new = h_old + r_k_val * v_new + sData[(k_idx, v_idx, stage)] = h_new + sum_hq += h_new * r_q_val + + for offset in [4, 2, 1]: + sum_hq += cute.arch.shuffle_sync_bfly( + sum_hq, offset=offset * V_PER_WARP, mask=-1, mask_and_clamp=31 + ) + + if k_local == 0: + v_global_out = 
def _create_jit_functions():
    """Create JIT-compiled launcher functions for all kernel variants.

    Builds four host-side ``@cute.jit`` launchers, one per kernel returned by
    ``_define_kernels()`` (small/large batch x fixed/varlen layout). Each
    launcher derives the launch geometry from ``h0_source`` / ``h0_indices``,
    builds the global->shared tiled copy and the staged shared-memory layout,
    and launches its kernel on ``stream``.

    All four launchers share one signature so they can be compiled and called
    interchangeably. ``cu_seqlens``, ``B``, ``T``, ``K``, ``V`` and
    ``use_initial_state`` are accepted but not referenced in any launcher body.

    Returns:
        Tuple ``(run_small_batch, run_small_batch_varlen, run_large_batch,
        run_large_batch_varlen)``.
    """

    gdn_small, gdn_small_varlen, gdn_large, gdn_large_varlen = _define_kernels()

    # Launcher for the small-batch kernel, (N, 1, ...) input layout.
    @cute.jit
    def run_small_batch(
        cu_seqlens: cute.Tensor,
        q: cute.Tensor,
        k: cute.Tensor,
        v: cute.Tensor,
        a: cute.Tensor,
        b: cute.Tensor,
        A_log: cute.Tensor,
        dt_bias: cute.Tensor,
        h0_source: cute.Tensor,
        h0_indices: cute.Tensor,
        o: cute.Tensor,
        softplus_beta: cutlass.Constexpr[float],
        softplus_threshold: cutlass.Constexpr[float],
        scale: cutlass.Constexpr[float],
        B: cutlass.Constexpr[int],
        T: cutlass.Constexpr[int],
        H: cutlass.Constexpr[int],
        HV: cutlass.Constexpr[int],
        K: cutlass.Constexpr[int],
        V: cutlass.Constexpr[int],
        use_initial_state: cutlass.Constexpr[bool],
        use_qk_l2norm: cutlass.Constexpr[bool],
        stream: cuda.CUstream,
    ):
        # h0_source is unpacked as a 4-D state pool; only hv_dim and v_dim are
        # used below (pool_size and k_dim are unpacked but unused).
        pool_size, hv_dim, k_dim, v_dim = h0_source.layout.shape
        n_indices = h0_indices.layout.shape[0]
        # One unit of work per (state index, value head) pair.
        batch_size = n_indices * hv_dim

        # Global->shared copy atom moving 128 bits (four Float32) per copy.
        copy_atom = cute.make_copy_atom(
            cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
            cutlass.Float32,
            num_bits_per_copy=128,
        )
        num_v_tiles_small = cute.ceil_div(v_dim, TILE_V_SMALL)
        # Staged tile (TILE_K, TILE_V_SMALL, NUM_STAGES); rows are strided by
        # TILE_V_SMALL_PADDED rather than TILE_V_SMALL.
        smem_layout_small = cute.make_layout(
            (TILE_K, TILE_V_SMALL, NUM_STAGES),
            stride=(TILE_V_SMALL_PADDED, 1, TILE_K * TILE_V_SMALL_PADDED),
        )
        # 32 x 4 = 128 copy threads, each owning a (1, 4) strip of values.
        thread_layout_small = cute.make_layout((32, 4), stride=(4, 1))
        val_layout_small = cute.make_layout((1, 4))
        tiled_copy_load_small = cute.make_tiled_copy_tv(
            copy_atom, thread_layout_small, val_layout_small
        )
        # Bytes (4 per Float32): staged state tiles + a TILE_V_SMALL scratch
        # vector + two TILE_K vectors (the kernel's sK and sQ) + 64 extra bytes.
        # NOTE(review): the 64 is presumably alignment slack for the 128-byte
        # aligned allocations in the kernel -- confirm it is sufficient.
        smem_bytes_small = (
            4 * TILE_K * TILE_V_SMALL_PADDED * NUM_STAGES
            + 4 * TILE_V_SMALL
            + 4 * TILE_K * 2
            + 64
        )

        gdn_small(
            tiled_copy_load_small,
            h0_source,
            smem_layout_small,
            num_v_tiles_small,
            q,
            k,
            v,
            a,
            b,
            A_log,
            dt_bias,
            o,
            h0_indices,
            softplus_beta,
            softplus_threshold,
            scale,
            H,
            HV,
            use_qk_l2norm,
        ).launch(
            # NUM_BLOCKS_PER_STATE_SMALL blocks cooperate on each work unit.
            grid=(batch_size * NUM_BLOCKS_PER_STATE_SMALL, 1, 1),
            block=[NUM_THREADS, 1, 1],
            smem=smem_bytes_small,
            stream=stream,
        )

    # Launcher for the small-batch kernel, varlen (1, N, ...) input layout.
    # Same setup as run_small_batch; only the launched kernel differs.
    @cute.jit
    def run_small_batch_varlen(
        cu_seqlens: cute.Tensor,
        q: cute.Tensor,
        k: cute.Tensor,
        v: cute.Tensor,
        a: cute.Tensor,
        b: cute.Tensor,
        A_log: cute.Tensor,
        dt_bias: cute.Tensor,
        h0_source: cute.Tensor,
        h0_indices: cute.Tensor,
        o: cute.Tensor,
        softplus_beta: cutlass.Constexpr[float],
        softplus_threshold: cutlass.Constexpr[float],
        scale: cutlass.Constexpr[float],
        B: cutlass.Constexpr[int],
        T: cutlass.Constexpr[int],
        H: cutlass.Constexpr[int],
        HV: cutlass.Constexpr[int],
        K: cutlass.Constexpr[int],
        V: cutlass.Constexpr[int],
        use_initial_state: cutlass.Constexpr[bool],
        use_qk_l2norm: cutlass.Constexpr[bool],
        stream: cuda.CUstream,
    ):
        pool_size, hv_dim, k_dim, v_dim = h0_source.layout.shape
        n_indices = h0_indices.layout.shape[0]
        batch_size = n_indices * hv_dim

        copy_atom = cute.make_copy_atom(
            cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
            cutlass.Float32,
            num_bits_per_copy=128,
        )
        num_v_tiles_small = cute.ceil_div(v_dim, TILE_V_SMALL)
        smem_layout_small = cute.make_layout(
            (TILE_K, TILE_V_SMALL, NUM_STAGES),
            stride=(TILE_V_SMALL_PADDED, 1, TILE_K * TILE_V_SMALL_PADDED),
        )
        thread_layout_small = cute.make_layout((32, 4), stride=(4, 1))
        val_layout_small = cute.make_layout((1, 4))
        tiled_copy_load_small = cute.make_tiled_copy_tv(
            copy_atom, thread_layout_small, val_layout_small
        )
        smem_bytes_small = (
            4 * TILE_K * TILE_V_SMALL_PADDED * NUM_STAGES
            + 4 * TILE_V_SMALL
            + 4 * TILE_K * 2
            + 64
        )

        gdn_small_varlen(
            tiled_copy_load_small,
            h0_source,
            smem_layout_small,
            num_v_tiles_small,
            q,
            k,
            v,
            a,
            b,
            A_log,
            dt_bias,
            o,
            h0_indices,
            softplus_beta,
            softplus_threshold,
            scale,
            H,
            HV,
            use_qk_l2norm,
        ).launch(
            grid=(batch_size * NUM_BLOCKS_PER_STATE_SMALL, 1, 1),
            block=[NUM_THREADS, 1, 1],
            smem=smem_bytes_small,
            stream=stream,
        )

    # Launcher for the large-batch kernel, (N, 1, ...) input layout.
    @cute.jit
    def run_large_batch(
        cu_seqlens: cute.Tensor,
        q: cute.Tensor,
        k: cute.Tensor,
        v: cute.Tensor,
        a: cute.Tensor,
        b: cute.Tensor,
        A_log: cute.Tensor,
        dt_bias: cute.Tensor,
        h0_source: cute.Tensor,
        h0_indices: cute.Tensor,
        o: cute.Tensor,
        softplus_beta: cutlass.Constexpr[float],
        softplus_threshold: cutlass.Constexpr[float],
        scale: cutlass.Constexpr[float],
        B: cutlass.Constexpr[int],
        T: cutlass.Constexpr[int],
        H: cutlass.Constexpr[int],
        HV: cutlass.Constexpr[int],
        K: cutlass.Constexpr[int],
        V: cutlass.Constexpr[int],
        use_initial_state: cutlass.Constexpr[bool],
        use_qk_l2norm: cutlass.Constexpr[bool],
        stream: cuda.CUstream,
    ):
        pool_size, hv_dim, k_dim, v_dim = h0_source.layout.shape
        n_indices = h0_indices.layout.shape[0]
        # One thread block per (state index, value head) pair.
        batch_size = n_indices * hv_dim

        copy_atom = cute.make_copy_atom(
            cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
            cutlass.Float32,
            num_bits_per_copy=128,
        )
        num_v_tiles = cute.ceil_div(v_dim, TILE_V)
        # Staged tile (TILE_K, TILE_V, NUM_STAGES) with rows strided by
        # TILE_V_PADDED.
        base_smem_layout = cute.make_layout(
            (TILE_K, TILE_V, NUM_STAGES),
            stride=(TILE_V_PADDED, 1, TILE_K * TILE_V_PADDED),
        )
        # 32 x 8 = 256 copy threads, each owning a (1, 4) strip of values.
        thread_layout = cute.make_layout((32, 8), stride=(8, 1))
        val_layout = cute.make_layout((1, 4))
        tiled_copy_load = cute.make_tiled_copy_tv(copy_atom, thread_layout, val_layout)
        # Same byte budget as the small variant, using the large tile sizes.
        smem_bytes = (
            4 * TILE_K * TILE_V_PADDED * NUM_STAGES + 4 * TILE_V + 4 * TILE_K * 2 + 64
        )

        gdn_large(
            tiled_copy_load,
            h0_source,
            base_smem_layout,
            num_v_tiles,
            q,
            k,
            v,
            a,
            b,
            A_log,
            dt_bias,
            o,
            h0_indices,
            softplus_beta,
            softplus_threshold,
            scale,
            H,
            HV,
            use_qk_l2norm,
        ).launch(
            grid=(batch_size, 1, 1),
            block=[NUM_THREADS_LARGE, 1, 1],
            smem=smem_bytes,
            stream=stream,
        )

    # Launcher for the large-batch kernel, varlen (1, N, ...) input layout.
    # Same setup as run_large_batch; only the launched kernel differs.
    @cute.jit
    def run_large_batch_varlen(
        cu_seqlens: cute.Tensor,
        q: cute.Tensor,
        k: cute.Tensor,
        v: cute.Tensor,
        a: cute.Tensor,
        b: cute.Tensor,
        A_log: cute.Tensor,
        dt_bias: cute.Tensor,
        h0_source: cute.Tensor,
        h0_indices: cute.Tensor,
        o: cute.Tensor,
        softplus_beta: cutlass.Constexpr[float],
        softplus_threshold: cutlass.Constexpr[float],
        scale: cutlass.Constexpr[float],
        B: cutlass.Constexpr[int],
        T: cutlass.Constexpr[int],
        H: cutlass.Constexpr[int],
        HV: cutlass.Constexpr[int],
        K: cutlass.Constexpr[int],
        V: cutlass.Constexpr[int],
        use_initial_state: cutlass.Constexpr[bool],
        use_qk_l2norm: cutlass.Constexpr[bool],
        stream: cuda.CUstream,
    ):
        pool_size, hv_dim, k_dim, v_dim = h0_source.layout.shape
        n_indices = h0_indices.layout.shape[0]
        batch_size = n_indices * hv_dim

        copy_atom = cute.make_copy_atom(
            cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
            cutlass.Float32,
            num_bits_per_copy=128,
        )
        num_v_tiles = cute.ceil_div(v_dim, TILE_V)
        base_smem_layout = cute.make_layout(
            (TILE_K, TILE_V, NUM_STAGES),
            stride=(TILE_V_PADDED, 1, TILE_K * TILE_V_PADDED),
        )
        thread_layout = cute.make_layout((32, 8), stride=(8, 1))
        val_layout = cute.make_layout((1, 4))
        tiled_copy_load = cute.make_tiled_copy_tv(copy_atom, thread_layout, val_layout)
        smem_bytes = (
            4 * TILE_K * TILE_V_PADDED * NUM_STAGES + 4 * TILE_V + 4 * TILE_K * 2 + 64
        )

        gdn_large_varlen(
            tiled_copy_load,
            h0_source,
            base_smem_layout,
            num_v_tiles,
            q,
            k,
            v,
            a,
            b,
            A_log,
            dt_bias,
            o,
            h0_indices,
            softplus_beta,
            softplus_threshold,
            scale,
            H,
            HV,
            use_qk_l2norm,
        ).launch(
            grid=(batch_size, 1, 1),
            block=[NUM_THREADS_LARGE, 1, 1],
            smem=smem_bytes,
            stream=stream,
        )

    return (
        run_small_batch,
        run_small_batch_varlen,
        run_large_batch,
        run_large_batch_varlen,
    )
# Lazily-built tuple of the four @cute.jit launchers (see _get_jit_functions).
_jit_functions = None


def _get_jit_functions():
    """Return the launcher tuple, creating it on first use.

    The result of ``_create_jit_functions()`` is stored in the module-level
    ``_jit_functions`` so the launchers are only defined once per process.
    """
    global _jit_functions
    if _jit_functions is None:
        _jit_functions = _create_jit_functions()
    return _jit_functions


def _get_compiled_kernel(
    N,
    H,
    HV,
    K,
    V,
    pool_size,
    use_small_batch,
    is_varlen_decode,
    scale: Optional[float] = None,
    softplus_beta: float = 1.0,
    softplus_threshold: float = 20.0,
    use_qk_l2norm: bool = True,
):
    """Get or compile the kernel for the given dimensions and constants.

    Compiled kernels are memoized in the module-level ``_compiled_kernels``
    dict. ``scale``, ``softplus_beta``, ``softplus_threshold`` and
    ``use_qk_l2norm`` are declared ``cutlass.Constexpr`` on the launchers, so
    they are fixed when the kernel is compiled. They are therefore part of the
    cache key: a kernel compiled for one setting is never reused for another.
    The defaults reproduce the values this function previously hard-coded, so
    existing 8-argument calls compile exactly the same kernels as before.

    Args:
        N: Number of entries in ``h0_indices`` (decode batch size).
        H: Number of q/k heads.
        HV: Number of v heads.
        K: Head dimension of q/k.
        V: Head dimension of v.
        pool_size: Leading dimension of the state pool ``h0_source``.
        use_small_batch: Select the small-batch launcher instead of the
            large-batch one.
        is_varlen_decode: Select the (1, N, ...) input layout instead of
            (N, 1, ...).
        scale: Multiplier applied to q inside the kernel; ``None`` means
            ``K ** -0.5``.
        softplus_beta: Softplus ``beta`` used in the gate computation.
        softplus_threshold: Value of ``beta * x`` above which softplus is
            replaced by the identity.
        use_qk_l2norm: Whether the kernel L2-normalizes q and k.

    Returns:
        The object returned by ``cute.compile`` for the selected launcher.
    """
    global _compiled_kernels

    # Normalize first so that an explicit scale equal to the default maps to
    # the same cache entry as scale=None.
    scale = K**-0.5 if scale is None else float(scale)
    softplus_beta = float(softplus_beta)
    softplus_threshold = float(softplus_threshold)
    use_qk_l2norm = bool(use_qk_l2norm)

    key = (
        N,
        H,
        HV,
        K,
        V,
        pool_size,
        use_small_batch,
        is_varlen_decode,
        scale,
        softplus_beta,
        softplus_threshold,
        use_qk_l2norm,
    )
    if key in _compiled_kernels:
        return _compiled_kernels[key]

    # Zero-filled placeholder tensors: they only provide shapes and dtypes to
    # cute.compile and are discarded afterwards.
    cu_seqlens = torch.zeros(N + 1, dtype=torch.int32, device="cuda")

    if is_varlen_decode:
        # Varlen layout: tokens on dim 1, gates a/b are 2-D (N, HV).
        q = torch.zeros(1, N, H, K, dtype=torch.bfloat16, device="cuda")
        k = torch.zeros(1, N, H, K, dtype=torch.bfloat16, device="cuda")
        v = torch.zeros(1, N, HV, V, dtype=torch.bfloat16, device="cuda")
        a = torch.zeros(N, HV, dtype=torch.bfloat16, device="cuda")
        b = torch.zeros(N, HV, dtype=torch.bfloat16, device="cuda")
        o = torch.zeros(1, N, HV, V, dtype=torch.bfloat16, device="cuda")
    else:
        # Batched layout: one token per sequence, gates a/b are 3-D (N, 1, HV).
        q = torch.zeros(N, 1, H, K, dtype=torch.bfloat16, device="cuda")
        k = torch.zeros(N, 1, H, K, dtype=torch.bfloat16, device="cuda")
        v = torch.zeros(N, 1, HV, V, dtype=torch.bfloat16, device="cuda")
        a = torch.zeros(N, 1, HV, dtype=torch.bfloat16, device="cuda")
        b = torch.zeros(N, 1, HV, dtype=torch.bfloat16, device="cuda")
        o = torch.zeros(N, 1, HV, V, dtype=torch.bfloat16, device="cuda")

    A_log = torch.zeros(HV, dtype=torch.float32, device="cuda")
    dt_bias = torch.zeros(HV, dtype=torch.bfloat16, device="cuda")
    h0_source = torch.zeros(pool_size, HV, K, V, dtype=torch.float32, device="cuda")
    h0_indices = torch.zeros(N, dtype=torch.int32, device="cuda")

    cu_seqlens_tensor = from_dlpack(cu_seqlens, assumed_align=16)
    q_tensor = from_dlpack(q, assumed_align=16)
    k_tensor = from_dlpack(k, assumed_align=16)
    v_tensor = from_dlpack(v, assumed_align=16)
    a_tensor = from_dlpack(a, assumed_align=16)
    b_tensor = from_dlpack(b, assumed_align=16)
    A_log_tensor = from_dlpack(A_log, assumed_align=16)
    dt_bias_tensor = from_dlpack(dt_bias, assumed_align=16)
    h0_source_tensor = from_dlpack(h0_source, assumed_align=16)
    h0_indices_tensor = from_dlpack(h0_indices, assumed_align=16)
    o_tensor = from_dlpack(o, assumed_align=16)

    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    run_small, run_small_varlen, run_large, run_large_varlen = _get_jit_functions()

    if use_small_batch:
        kernel_func = run_small_varlen if is_varlen_decode else run_small
    else:
        kernel_func = run_large_varlen if is_varlen_decode else run_large

    # B/T describe the leading two dims of q/k/v/o for the chosen layout.
    B_compile = 1 if is_varlen_decode else N
    T_compile = N if is_varlen_decode else 1

    compiled_kernel = cute.compile(
        kernel_func,
        cu_seqlens_tensor,
        q_tensor,
        k_tensor,
        v_tensor,
        a_tensor,
        b_tensor,
        A_log_tensor,
        dt_bias_tensor,
        h0_source_tensor,
        h0_indices_tensor,
        o_tensor,
        softplus_beta=softplus_beta,
        softplus_threshold=softplus_threshold,
        scale=scale,
        B=B_compile,
        T=T_compile,
        H=H,
        K=K,
        V=V,
        HV=HV,
        use_initial_state=True,
        use_qk_l2norm=use_qk_l2norm,
        stream=stream,
    )

    _compiled_kernels[key] = compiled_kernel
    logger.info(
        f"CuTe DSL GDN kernel compiled: N={N}, H={H}, HV={HV}, K={K}, V={V}, pool_size={pool_size}, small_batch={use_small_batch}, varlen={is_varlen_decode}"
    )

    return compiled_kernel
True, + softplus_beta: float = 1.0, + softplus_threshold: float = 20.0, +) -> torch.Tensor: + """CuTe DSL implementation of fused sigmoid gating delta rule update.""" + + B_q, T_q, H, K = q.shape + HV = v.shape[2] + V = v.shape[3] + N = initial_state_indices.shape[0] + + is_varlen_decode = B_q == 1 and T_q == N and N > 1 + if scale is None: + scale = K**-0.5 + + use_small_batch = N < SMALL_BATCH_THRESHOLD + + if initial_state_source.dim() == 1: + pool_size = initial_state_source.numel() // (HV * K * V) + h0_source = initial_state_source.view(pool_size, HV, K, V) + elif initial_state_source.dim() == 4: + pool_size = initial_state_source.shape[0] + h0_source = initial_state_source + else: + raise ValueError( + f"Unexpected initial_state_source shape: {initial_state_source.shape}" + ) + + if is_varlen_decode: + if a.dim() == 3: + a = a.squeeze(0) + if b.dim() == 3: + b = b.squeeze(0) + o = q.new_empty(1, N, HV, V, dtype=torch.bfloat16) + else: + if a.dim() == 2: + a = a.unsqueeze(1) + if b.dim() == 2: + b = b.unsqueeze(1) + o = q.new_empty(N, 1, HV, V, dtype=torch.bfloat16) + + q, k, v = [t.contiguous() for t in (q, k, v)] + + global _cu_seqlens_cache + if cu_seqlens is not None: + cu_seqlens_to_use = cu_seqlens + else: + cache_key = (N, str(q.device)) + if cache_key not in _cu_seqlens_cache: + _cu_seqlens_cache[cache_key] = torch.arange( + N + 1, dtype=torch.int32, device=q.device + ) + cu_seqlens_to_use = _cu_seqlens_cache[cache_key] + + cu_seqlens_tensor = from_dlpack( + cu_seqlens_to_use.detach(), assumed_align=16 + ).mark_layout_dynamic(leading_dim=0) + q_tensor = from_dlpack(q.detach(), assumed_align=16).mark_layout_dynamic( + leading_dim=q.ndim - 1 + ) + k_tensor = from_dlpack(k.detach(), assumed_align=16).mark_layout_dynamic( + leading_dim=k.ndim - 1 + ) + v_tensor = from_dlpack(v.detach(), assumed_align=16).mark_layout_dynamic( + leading_dim=v.ndim - 1 + ) + a_tensor = from_dlpack(a.detach(), assumed_align=16).mark_layout_dynamic( + leading_dim=a.ndim - 1 + ) 
+ b_tensor = from_dlpack(b.detach(), assumed_align=16).mark_layout_dynamic( + leading_dim=b.ndim - 1 + ) + A_log_tensor = from_dlpack(A_log.detach(), assumed_align=16).mark_layout_dynamic( + leading_dim=0 + ) + dt_bias_tensor = from_dlpack( + dt_bias.detach(), assumed_align=16 + ).mark_layout_dynamic(leading_dim=0) + h0_source_tensor = from_dlpack( + h0_source.detach(), assumed_align=16 + ).mark_layout_dynamic(leading_dim=h0_source.ndim - 1) + h0_indices_tensor = from_dlpack( + initial_state_indices.detach(), assumed_align=16 + ).mark_layout_dynamic(leading_dim=0) + o_tensor = from_dlpack(o.detach(), assumed_align=16).mark_layout_dynamic( + leading_dim=o.ndim - 1 + ) + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + compiled_kernel = _get_compiled_kernel( + N, H, HV, K, V, pool_size, use_small_batch, is_varlen_decode + ) + + compiled_kernel( + cu_seqlens_tensor, + q_tensor, + k_tensor, + v_tensor, + a_tensor, + b_tensor, + A_log_tensor, + dt_bias_tensor, + h0_source_tensor, + h0_indices_tensor, + o_tensor, + stream, + ) + + return o diff --git a/sglang/python/sglang/jit_kernel/diffusion/cutedsl/common/norm_fusion.py b/sglang/python/sglang/jit_kernel/diffusion/cutedsl/common/norm_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..01b0802848d0258013b5a7d2339617a4f7299061 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/diffusion/cutedsl/common/norm_fusion.py @@ -0,0 +1,201 @@ +from typing import Optional, Tuple, Union + +import cutlass +import cutlass.cute as cute +import torch +from einops import rearrange + +from sglang.jit_kernel.diffusion.cutedsl.common.reduce import ( + cta_reduce_sum, + warp_reduce_sum, +) + + +@cute.jit +def apply_norm_cta( + norm_type: cutlass.Constexpr, + num_warps: cutlass.Constexpr, + tidx: cutlass.Int32, + tXrX: cute.Tensor, + tWrW: Optional[cute.Tensor], + tBrB: Optional[cute.Tensor], + D: Union[cutlass.Int32, cutlass.Constexpr], + eps: Union[cutlass.Float32, cutlass.Constexpr], +) -> 
cute.Tensor: + if cutlass.const_expr(norm_type == "rms"): + return apply_rmsnorm_cta(num_warps, tidx, tXrX, tWrW, D, eps) + else: + return apply_layernorm_cta(num_warps, tidx, tXrX, tWrW, tBrB, D, eps) + + +@cute.jit +def apply_rmsnorm_cta( + num_warps: Union[cutlass.Int32, cutlass.Constexpr], + tidx: cutlass.Int32, + tXrX: cute.Tensor, + tWrW: Optional[cute.Tensor], + D: Union[cutlass.Int32, cutlass.Constexpr], + eps: Union[cutlass.Float32, cutlass.Constexpr], +) -> cute.Tensor: + """ + RMSNorm: + y[i] = x[i] / sqrt(sum(x ^ 2) / D + eps) * w[i] + """ + val = cute.Float32(0.0) + for idx in range(cute.size(tXrX)): + # Accumulate in FP32 to improve numerical precision. + x_fp32 = tXrX[idx].to(cutlass.Float32) + val += x_fp32 * x_fp32 + val = warp_reduce_sum(val) + acc_sq = cta_reduce_sum(val, num_warps, tidx) + factor = cute.rsqrt(acc_sq / D + eps) + tNrN = cute.make_fragment_like(tXrX) + if cutlass.const_expr(isinstance(tWrW, cute.Tensor)): + tNrN.store((tXrX.load() * factor * tWrW.load()).to(tNrN.element_type)) + else: + tNrN.store((tXrX.load() * factor).to(tNrN.element_type)) + return tNrN + + +@cute.jit +def apply_layernorm_cta( + num_warps: Union[cutlass.Int32, cutlass.Constexpr], + tidx: cutlass.Int32, + tXrX: cute.Tensor, + tWrW: Optional[cute.Tensor], + tBrB: Optional[cute.Tensor], + D: Union[cutlass.Int32, cutlass.Constexpr], + eps: Union[cutlass.Float32, cutlass.Constexpr], +) -> cute.Tensor: + """ + LayerNorm: + mean = sum(x) / D + var = sum((x - mean) ^ 2) / D + y[i] = (x[i] - mean) / sqrt(var + eps) * w[i] + b[i] + """ + # Reduce mean + val = cute.Float32(0.0) + for idx in range(cute.size(tXrX)): + # Accumulate in FP32 to improve numerical precision. + val += tXrX[idx].to(cutlass.Float32) + val = warp_reduce_sum(val) + val = cta_reduce_sum(val, num_warps, tidx) + mean = val / D + # Reduce variance + val = cute.Float32(0.0) + for idx in range(cute.size(tXrX)): + # Accumulate in FP32 to improve numerical precision. 
+ x_fp32 = tXrX[idx].to(cutlass.Float32) + val += (x_fp32 - mean) * (x_fp32 - mean) + val = warp_reduce_sum(val) + val = cta_reduce_sum(val, num_warps, tidx) + factor = cute.rsqrt(val / D + eps) + # Normalize + tNrN = cute.make_fragment_like(tXrX) + if cutlass.const_expr( + isinstance(tWrW, cute.Tensor) and isinstance(tBrB, cute.Tensor) + ): + tNrN.store( + ((tXrX.load() - mean) * factor * tWrW.load() + tBrB.load()).to( + tNrN.element_type + ) + ) + else: + tNrN.store(((tXrX.load() - mean) * factor).to(tNrN.element_type)) + return tNrN + + +################################################################################ +# BSFD Indexing +################################################################################ +# In diffusion norm-fusion kernels, we compute `norm(x) + y`, where +# `x` has shape [B, S, D] and `y` may come in various broadcastable forms: +# [1], [D], [1, D], [1, 1, D], [B, D], [B, 1, D], [B, S, D], or [B, F, 1, D]. +# +# For a given (batch_id, seq_id), the index mapping for `y` falls into 3 cases: +# 1) Scalar broadcast [1]: +# (batch_id, seq_id, *) -> (0) +# 2) Frame-based BSFD broadcast [B, F, 1, D]: +# frame_id = seq_id // len_frame +# (batch_id, seq_id, *) -> (batch_id, frame_id, *) +# 3) All other cases: +# `y` is broadcast to [B, S, D] (via view/expand, no materialization), +# and indexed as (batch_id, seq_id, *). +# +# This helper normalizes `y` into a BSFD-compatible view so that kernel +# indexing logic remains simple and uniform. +################################################################################ + + +def broadcast_tensor_for_bsfd( + tensor: Union[Optional[torch.Tensor], int], + B: int, + S: int, + D: int, +) -> Union[Optional[torch.Tensor], int]: + """ + Broadcast to (B, S, D) without memory copy for following shapes: + - [D], [1, D], [1, 1, D], [B, D], [B, 1, D], [B, S, D]. 
+ """ + + # Return directly for non-tensor value + if not isinstance(tensor, torch.Tensor): + return tensor + + if tensor.ndim == 1: + # Scalar [1] is preserved as-is and handled specially in CuTe kernel. + if tensor.numel() == 1: + return tensor + return rearrange(tensor, "d -> 1 1 d").expand(B, S, D) + if tensor.ndim == 2: + return rearrange(tensor, "b d -> b 1 d").expand(B, S, D) + if tensor.ndim == 3: + return tensor.expand(B, S, D) + if tensor.ndim == 4: + return tensor + raise ValueError(f"BSFD broadcast: unsupported tensor ndim: {tensor.ndim}.") + + +@cute.jit +def tensor_slice_for_bsfd( + mV: cute.Tensor, + thr_copy: cute.ThrCopy, + batch_id: cutlass.Int32, + seq_id: cutlass.Int32, + S: Union[cutlass.Int32, cutlass.Constexpr], + D: Union[cutlass.Int32, cutlass.Constexpr], +) -> Tuple[cute.Tensor, cute.Tensor]: + """ + Slice a BSFD-compatible tensor into a per-thread gmem tile and rmem fragment. + + Given a logical (batch_id, seq_id), this helper selects the corresponding + D-length slice from `mV` and prepares it for vectorized copy. + """ + gV: cute.Tensor + if cutlass.const_expr(cute.is_static(mV.layout) and cute.size(mV.layout) == 1): + # build a ((1,1),(1,)) layout so it could broadcast-align with the + # regular rmem fragment shape ((4,1),(k,)). + layout = cute.make_layout(shape=((1, 1), (1,))) + tVgV = cute.make_tensor(mV.iterator, layout) + tVrV = cute.make_rmem_tensor(layout, mV.element_type) + return tVgV, tVrV + + # Use `local_tile` instead of direct indexing to preserve gmem base pointer + # alignment required for vectorized loads. + if cutlass.const_expr(len(mV.shape) == 1): + gV = mV + elif cutlass.const_expr(len(mV.shape) == 3): + gV = cute.local_tile(mV, tiler=(1, 1, D), coord=(batch_id, seq_id, 0)) + gV = gV[0, 0, None] + elif cutlass.const_expr(len(mV.shape) == 4): + # Compute frame length at runtime (instead of compile time) to avoid + # specializing kernels on the frame dimension. 
+ frame_len = S // mV.shape[1] + frame_id = seq_id // frame_len + gV = cute.local_tile(mV, tiler=(1, 1, 1, D), coord=(batch_id, frame_id, 0, 0)) + gV = gV[0, 0, 0, None] + else: + raise NotImplementedError(f"BSFD slice: unsupported shape {mV.shape}.") + tVgV = thr_copy.partition_S(gV) + tVrV = cute.make_fragment_like(tVgV, tVgV.element_type) + return tVgV, tVrV diff --git a/sglang/python/sglang/jit_kernel/diffusion/cutedsl/common/reduce.py b/sglang/python/sglang/jit_kernel/diffusion/cutedsl/common/reduce.py new file mode 100644 index 0000000000000000000000000000000000000000..978d9bd6be2c50ef9d50f591c42865421f624757 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/diffusion/cutedsl/common/reduce.py @@ -0,0 +1,33 @@ +import math + +import cutlass +import cutlass.cute as cute + + +@cute.jit +def warp_reduce_sum(val: cute.Numeric, reduce_size: int = 32) -> cute.Numeric: + iters = int(math.log2(reduce_size)) + for i in range(iters): + val = val + cute.arch.shuffle_sync_down(val, offset=1 << (iters - i - 1)) + return val + + +@cute.jit +def cta_reduce_sum( + val: cute.Numeric, num_warps: cutlass.Constexpr, tidx: cutlass.Int32 +) -> cute.Numeric: + smem = cutlass.utils.SmemAllocator() + acc = smem.allocate_tensor(cutlass.Float32, num_warps + 1) + warp_id = tidx >> 5 + lane_id = tidx & 31 + if lane_id == 0: + acc[warp_id] = val + cute.arch.sync_threads() + if warp_id == 0: + val = acc[lane_id] if lane_id < num_warps else cutlass.Float32(0) + val = warp_reduce_sum(val) + if lane_id == 0: + acc[num_warps] = val + cute.arch.sync_threads() + val = acc[num_warps] + return val diff --git a/sglang/python/sglang/jit_kernel/diffusion/cutedsl/scale_residual_norm_scale_shift.py b/sglang/python/sglang/jit_kernel/diffusion/cutedsl/scale_residual_norm_scale_shift.py new file mode 100644 index 0000000000000000000000000000000000000000..734b9bbf78ad054eae8810aaa5a6debd5741d880 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/diffusion/cutedsl/scale_residual_norm_scale_shift.py @@ 
-0,0 +1,431 @@ +from typing import Optional, Tuple, Union + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +import torch + +from sglang.jit_kernel.diffusion.cutedsl.common.norm_fusion import ( + apply_norm_cta, + broadcast_tensor_for_bsfd, + tensor_slice_for_bsfd, +) +from sglang.jit_kernel.diffusion.cutedsl.utils import TORCH_TO_CUTE_DTYPE, WARP_SIZE + +_COMPILE_CACHE = {} + + +def to_cute_arg( + t, + *, + assume_aligned: Optional[int] = 32, + use_32bit_stride: bool = False, + enable_tvm_ffi: bool = True, +): + """ + Convert a Python value into a CuTeDSL value. + """ + if isinstance(t, torch.Tensor): + return cute.runtime.from_dlpack( + t, + assumed_align=assume_aligned, + use_32bit_stride=use_32bit_stride, + enable_tvm_ffi=enable_tvm_ffi, + ) + if isinstance(t, int): + return cutlass.Int32(t) + if isinstance(t, float): + return cutlass.Float32(t) + return t + + +def to_fake_cute_args(t: torch.Tensor): + if isinstance(t, torch.Tensor): + # Only keep the last dim as compile-time value to maximum compiled kernel reuse + # e.g. (1,2,1536):(3027,1536,1) -> (?,?,1536):(?,?,1) + D = t.shape[-1] + dtype = TORCH_TO_CUTE_DTYPE[t.dtype] + shape = (*(cute.sym_int() for _ in range(t.ndim - 1)), D) + stride = (*(cute.sym_int(divisibility=D) for _ in range(t.ndim - 1)), 1) + fake_t = cute.runtime.make_fake_tensor( + dtype, shape, stride, memspace=cute.AddressSpace.gmem, assumed_align=32 + ) + return fake_t + return to_cute_arg(t) + + +class ScaleResidualNormScaleShift: + @classmethod + def make_hash_key(cls, *inputs): + """ + Compile-time values: + - D: hidden dimension (size of the last dimension) + - norm_type: layer norm or RMS norm + - tensor dtype + - tensor rank (i.e., tensor.ndim) + + Runtime values: + - all other inputs + + This hash key defines the compile-time specialization boundary for + ScaleResidualNormScaleShift kernels. 
+ """ + + def _sig(val): + if isinstance(val, torch.Tensor): + return (val.dtype, val.ndim, val.shape[-1]) + return val + + return tuple(_sig(val) for val in inputs) + + def __init__(self, D: int, norm_type: str): + self.D = D + self.norm_type = norm_type # "layer" or "rms" + self.num_warps = self.D // 256 # num of warps per cta + self.num_threads = self.num_warps * WARP_SIZE # num of threads per cta + + @cute.jit + def __call__( + self, + mY, + mResOut, + mRes, + mX, + mGate, + mWeight, + mBias, + mScale, + mShift, + eps: cutlass.Float32 = cutlass.Float32(1e-5), + stream: cuda.CUstream = cuda.CUstream(cuda.CUstream_flags.CU_STREAM_DEFAULT), + ): + # Tensor shapes + B, S, _ = mX.shape # (batch, seq_len, hidden_dim) + # Vectorized copy configuration + num_vectorized = 8 # maximum num of elem per copy + atom_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + mX.element_type, + num_bits_per_copy=128, + ) + # Thread/value layouts for tiled copy + t_layout = cute.make_layout(self.num_threads) # thread layout within a CTA + v_layout = cute.make_layout(num_vectorized) # per-thread vector layout + tiled_copy = cute.make_tiled_copy_tv(atom_copy, t_layout, v_layout) + + self.kernel( + mY, + mResOut, + mRes, + mX, + mGate, + mWeight, + mBias, + mScale, + mShift, + tiled_copy, + eps, + ).launch( + grid=[B * S, 1, 1], + block=[self.num_threads, 1, 1], + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mY, + mResOut, + mRes, + mX, + mGate, + mWeight, + mBias, + mScale, + mShift, + tiled_copy: cute.TiledCopy, + eps: cutlass.Float32, + ): + _, S, _ = mX.shape + tidx, _, _ = cute.arch.thread_idx() # thread index + bid, _, _ = cute.arch.block_idx() # cta index + bidx = cutlass.Int32(bid // S) # batch index + bidy = cutlass.Int32(bid % S) # seq_len index + thr_copy = tiled_copy.get_slice(tidx) + + @cute.jit + def slice_if(mV): + if cutlass.const_expr(isinstance(mV, cute.Tensor)): + return tensor_slice_for_bsfd(mV, thr_copy, bidx, bidy, S, self.D) + return mV, mV 
+ + @cute.jit + def copy_if(src, dst): + if cutlass.const_expr( + isinstance(src, cute.Tensor) and isinstance(src, cute.Tensor) + ): + cute.autovec_copy(src, dst) # LDG.128 + + @cute.jit + def norm(x, weight, bias): + return apply_norm_cta( + self.norm_type, self.num_warps, tidx, x, weight, bias, self.D, eps + ) + + # Slice: retrieve the per-thread data slices for both global memory (gmem) + # and register memory (rmem). The layouts are: + # - ((4,2),(1)):((1,4),(0)) for fp32 + # - ((8,1),(1)):((1,0),(0)) for fp16/bf16 + tRgR, tRrR = slice_if(mRes) # residual + tXgX, tXrX = slice_if(mX) # x + tGgG, tGrG = slice_if(mGate) # gate + tROgRO, tROrRO = slice_if(mResOut) # residual_out + tWgW, tWrW = slice_if(mWeight) # weight + tBgB, tBrB = slice_if(mBias) # bias + tSCgSC, tSCrSC = slice_if(mScale) # scale + tSHgSH, tSHrSH = slice_if(mShift) # shift + tYgY, tYrY = slice_if(mY) # y + # Load: load tensor from global memory to registers + copy_if(tRgR, tRrR) # gmem -> rmem + copy_if(tXgX, tXrX) # gmem -> rmem + copy_if(tGgG, tGrG) # gmem -> rmem + copy_if(tWgW, tWrW) # gmem -> rmem + copy_if(tBgB, tBrB) # gmem -> rmem + + # For norm_scale_shift, output: + # - y = norm(x, weight, bias) * (1 + scale) + shift + # For scale_residual_norm_scale_shift, output: + # - residual_out = residual + gate * x + # - y = norm(residual_out, weight, bias) * (1 + scale) + shift + # Compute: value = * x + value = tXrX.load() + if cutlass.const_expr(isinstance(tGrG, cute.Tensor)): + value = tGrG.load() * value + # Compute: value = value + + if cutlass.const_expr(isinstance(tRrR, cute.Tensor)): + value = value + tRrR.load() + # Store: residual_out + if cutlass.const_expr(isinstance(tROrRO, cute.Tensor)): + tROrRO.store(value.to(tROrRO.element_type)) + copy_if(tROrRO, tROgRO) # rmem -> gmem + # Compute: value = norm(value) * + + tNrN = cute.make_rmem_tensor_like(tXrX, tXrX.element_type) + tNrN.store(value.to(tNrN.element_type)) + tNrN = norm(tNrN, tWrW, tBrB) + # Compute: value = value * (1 + ) + 
+ value = tNrN.load() + copy_if(tSCgSC, tSCrSC) # gmem -> rmem + copy_if(tSHgSH, tSHrSH) # gmem -> rmem + if cutlass.const_expr(isinstance(tSCrSC, cute.Tensor)): + value = value * (1 + tSCrSC.load()) + if cutlass.const_expr(isinstance(tSHrSH, cute.Tensor)): + value = value + tSHrSH.load() + # Store: y + tYrY.store(value.to(tYrY.element_type)) + copy_if(tYrY, tYgY) # rmem -> gmem + + +def validate_x(t: torch.Tensor, B: int, S: int, D: int): + if t.dtype not in (torch.float16, torch.bfloat16, torch.float32): + raise ValueError(f"Validate failed: unsupported dtype: {t.dtype}") + if t.shape != (B, S, D): + raise ValueError(f"Validate failed: unsupported tensor shape: {t.shape}.") + if t.stride()[-1] != 1: + raise ValueError(f"Validate failed: not contiguous on dim D.") + + +def validate_weight_bias(t: Optional[torch.Tensor], B: int, S: int, D: int): + if t is None: + return + if t.dtype not in (torch.float16, torch.bfloat16, torch.float32): + raise ValueError(f"Validate failed: unsupported dtype: {t.dtype}") + if t.shape != (D,): + raise ValueError(f"Validate failed: unsupported tensor shape: {t.shape}.") + if t.stride()[-1] != 1: + raise ValueError(f"Validate failed: not contiguous on dim D.") + + +def validate_scale_shift(t: torch.Tensor, B: int, S: int, D: int): + if t.dtype not in (torch.float16, torch.bfloat16, torch.float32): + raise ValueError(f"Validate failed: unsupported dtype: {t.dtype}") + failed = False + if t.ndim == 1 and (t.shape[0] not in (1, D)): + failed = True + elif t.ndim == 2 and ((t.shape[0] not in (1, B)) or t.shape[1] != D): + failed = True + elif t.ndim == 3 and ( + (t.shape[0] not in (1, B)) or (t.shape[1] not in (1, S) or t.shape[2] != D) + ): + failed = True + elif t.ndim == 4 and (t.shape[0] != B or t.shape[2] != 1 or t.shape[3] != D): + F = t.shape[1] + if S % F != 0: + raise ValueError(f"Validate failed: S({S}) must be divisible by F({F}).") + failed = True + if failed: + raise ValueError(f"Validate failed: unsupported tensor shape: 
{t.shape}.") + if t.stride()[-1] != 1: + raise ValueError(f"Validate failed: not contiguous on dim D.") + + +def validate_gate(t: Union[torch.Tensor, int], B: int, S: int, D: int): + if not isinstance(t, torch.Tensor): + return + validate_scale_shift(t, B, S, D) + + +@torch.library.custom_op("sglang::fused_norm_scale_shift", mutates_args=()) +def fused_norm_scale_shift( + x: torch.Tensor, + weight: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + scale: torch.Tensor, + shift: torch.Tensor, + norm_type: str, + eps: float = 1e-5, +) -> torch.Tensor: + """ + Fuse: norm(x) * (1 + scale) + shift + where norm is either layernorm or rmsnorm. + + Expects: + - x: [B, S, D] + - weight/bias: None, [D] + - scale/shift: [1], [D], [1/B, D], [1/B, 1/S, D] or [B, F, 1, D] + - norm_type: str, "layer" or "rms" + - eps: Optional[float], default: 1e-5 + + D must be a multiple of 256 and <= 8192 to enable LDG.128 vectorized loads per + thread and avoid predicated loads (e.g., bounds checks such as `index < D`). + """ + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + # Tensor Validation + BSD = x.shape + validate_x(x, *BSD) + validate_weight_bias(weight, *BSD) + validate_weight_bias(bias, *BSD) + validate_scale_shift(scale, *BSD) + validate_scale_shift(shift, *BSD) + + if norm_type == "layer" or norm_type == "rms": + D = x.shape[-1] + if D % 256 != 0 or D > 8192: + raise ValueError( + f"D={D} not supported, must be multiple of 256 and <= 8192" + ) + y = torch.empty_like(x) # create output tensor + scale = broadcast_tensor_for_bsfd(scale, *x.shape) # handle various shapes + shift = broadcast_tensor_for_bsfd(shift, *x.shape) # handle various shapes + # Use scalar placeholders for None tensors as a workaround, since the CuTe DSL + # TVM-FFI backend does not support None parameters. scalar values do not result + # in code generation and have no impact on runtime performance. 
+ weight = 1 if weight is None else weight + bias = 0 if bias is None else bias + ResOut, Residual, Gate = 0, 0, 1 + torch_tensors = [y, ResOut, Residual, x, Gate, weight, bias, scale, shift] + # Compile cache + hash_key = ScaleResidualNormScaleShift.make_hash_key(norm_type, *torch_tensors) + compiled_fn = _COMPILE_CACHE.get(hash_key) + if compiled_fn is None: + kernel = ScaleResidualNormScaleShift(D, norm_type) + fake_sig_args = [to_fake_cute_args(t) for t in torch_tensors] + compiled_fn = cute.compile( + kernel, *fake_sig_args, options="--enable-tvm-ffi" + ) + _COMPILE_CACHE[hash_key] = compiled_fn + # Execute + compiled_fn(*torch_tensors, eps, stream) + return y + else: + raise ValueError(f'norm_type must be one of "layer" and "rms"') + + +@fused_norm_scale_shift.register_fake +def _fused_norm_scale_shift_fake(x, weight, bias, scale, shift, norm_type, eps=1e-5): + y = x.new_empty(x.shape) + return y + + +@torch.library.custom_op( + "sglang::fused_scale_residual_norm_scale_shift", mutates_args=() +) +def fused_scale_residual_norm_scale_shift( + residual: torch.Tensor, + x: torch.Tensor, + gate: Optional[torch.Tensor], # Union[Optional[torch.Tensor], int] indeed + weight: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + scale: torch.Tensor, + shift: torch.Tensor, + norm_type: str, + eps: float = 1e-5, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Fuse: norm(residual + gate * x) * (1 + scale) + shift + where norm is either layernorm or rmsnorm. + + Expects: + - residual, x: [B, S, D] + - gate: None, [1], [D], [1/B, D], [1/B, 1/S, D] or [B, F, 1, D] + - weight/bias: None, [D] + - scale/shift: [1], [D], [1/B, D], [1/B, 1/S, D] or [B, F, 1, D] + - norm_type: str, "layer" or "rms" + - eps: Optional[float], default: 1e-5 + + D must be a multiple of 256 and <= 8192 to enable LDG.128 vectorized loads per + thread and avoid predicated loads (e.g., bounds checks such as `index < D`). 
+ """ + # Tensor Validation + BSD = x.shape + validate_x(x, *BSD) + validate_x(residual, *BSD) + validate_gate(gate, *BSD) + validate_weight_bias(weight, *BSD) + validate_weight_bias(bias, *BSD) + validate_scale_shift(scale, *BSD) + validate_scale_shift(shift, *BSD) + if norm_type == "layer" or norm_type == "rms": + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + # if norm_type == "layer" or norm_type == "rms": + D = x.shape[-1] + if D % 256 != 0 or D > 8192: + raise ValueError( + f"D={D} not supported, must be multiple of 256 and <= 8192" + ) + y = torch.empty_like(x) # create output tensor + resi_out = torch.empty_like(x) # create output tensor + gate = broadcast_tensor_for_bsfd(gate, *x.shape) # handle various shapes + scale = broadcast_tensor_for_bsfd(scale, *x.shape) # handle various shapes + shift = broadcast_tensor_for_bsfd(shift, *x.shape) # handle various shapes + # Use scalar placeholders for None tensors as a workaround, since the CuTe DSL + # TVM-FFI backend does not support None parameters. scalar values do not result + # in code generation and have no impact on runtime performance. 
+ gate = 1 if gate is None else gate + weight = 1 if weight is None else weight + bias = 0 if bias is None else bias + torch_tensors = [y, resi_out, residual, x, gate, weight, bias, scale, shift] + # Compile cache + hash_key = ScaleResidualNormScaleShift.make_hash_key(norm_type, *torch_tensors) + compiled_fn = _COMPILE_CACHE.get(hash_key) + if compiled_fn is None: + kernel = ScaleResidualNormScaleShift(D, norm_type) + fake_sig_args = [to_fake_cute_args(t) for t in torch_tensors] + compiled_fn = cute.compile( + kernel, *fake_sig_args, options="--enable-tvm-ffi" + ) + _COMPILE_CACHE[hash_key] = compiled_fn + # Execute + compiled_fn(*torch_tensors, eps, stream) + return y, resi_out + else: + raise ValueError(f'norm_type must be one of "layer" and "rms"') + + +@fused_scale_residual_norm_scale_shift.register_fake +def _fused_scale_residual_norm_scale_shift_fake( + residual, x, gate, weight, bias, scale, shift, norm_type, eps=1e-5 +): + y = x.new_empty(x.shape) + residual_out = x.new_empty(x.shape) + return y, residual_out diff --git a/sglang/python/sglang/jit_kernel/diffusion/cutedsl/utils.py b/sglang/python/sglang/jit_kernel/diffusion/cutedsl/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d23c2342b9a2de5b3a18096c0d07df51192c045b --- /dev/null +++ b/sglang/python/sglang/jit_kernel/diffusion/cutedsl/utils.py @@ -0,0 +1,10 @@ +import cutlass +import torch + +WARP_SIZE = 32 + +TORCH_TO_CUTE_DTYPE = { + torch.float16: cutlass.Float16, + torch.bfloat16: cutlass.BFloat16, + torch.float32: cutlass.Float32, +} diff --git a/sglang/python/sglang/jit_kernel/diffusion/triton/norm.py b/sglang/python/sglang/jit_kernel/diffusion/triton/norm.py new file mode 100644 index 0000000000000000000000000000000000000000..17a5bb1ca6bf56af0cdd62388964370642d82faf --- /dev/null +++ b/sglang/python/sglang/jit_kernel/diffusion/triton/norm.py @@ -0,0 +1,620 @@ +from typing import Optional + +import torch +import triton # type: ignore +import triton.language as tl # 
type: ignore +from torch import Tensor + + +# RMSNorm-fp32 +def maybe_contiguous_lastdim(x): + return x.contiguous() if x is not None and x.stride(-1) != 1 else x + + +def maybe_contiguous(x): + return x.contiguous() if x is not None else None + + +def triton_autotune_configs(): + # Return configs with a valid warp count for the current device + configs = [] + # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024 + max_threads_per_block = 1024 + # Default to warp size 32 if not defined by device + warp_size = getattr( + torch.get_device_module().get_device_properties( + torch.get_device_module().current_device() + ), + "warp_size", + 32, + ) + if warp_size is None: + warp_size = 32 + # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit + return [ + triton.Config({}, num_warps=warp_count) + for warp_count in [1, 2, 4, 8, 16, 32] + if warp_count * warp_size <= max_threads_per_block + ] + # return [triton.Config({}, num_warps=8)] + + +# Copied from flash-attn +@triton.autotune( + configs=triton_autotune_configs(), + key=[ + "N", + "HAS_RESIDUAL", + "STORE_RESIDUAL_OUT", + "IS_RMS_NORM", + "HAS_BIAS", + "HAS_WEIGHT", + "HAS_X1", + "HAS_W1", + "HAS_B1", + ], +) +# torch compile doesn't like triton.heuristics, so we set these manually when calling the kernel +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) +# @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None}) +# @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None}) +# @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None}) +@triton.jit +def _layer_norm_fwd_1pass_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + RESIDUAL, # pointer to the residual + X1, + W1, + B1, + Y1, + RESIDUAL_OUT, # pointer to the residual + ROWSCALE, + 
SEEDS, # Dropout seeds for each row + DROPOUT_MASK, + DROPOUT_MASK1, + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_res_row, + stride_res_out_row, + stride_x1_row, + stride_y1_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + dropout_p, # Dropout probability + zero_centered_weight, # If true, add 1.0 to the weight + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + STORE_RESIDUAL_OUT: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_DROPOUT: tl.constexpr, + STORE_DROPOUT_MASK: tl.constexpr, + HAS_ROWSCALE: tl.constexpr, + HAS_X1: tl.constexpr, + HAS_W1: tl.constexpr, + HAS_B1: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + X += row * stride_x_row + Y += row * stride_y_row + if HAS_RESIDUAL: + RESIDUAL += row * stride_res_row + if STORE_RESIDUAL_OUT: + RESIDUAL_OUT += row * stride_res_out_row + if HAS_X1: + X1 += row * stride_x1_row + if HAS_W1: + Y1 += row * stride_y1_row + # Compute mean and variance + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + row).to(tl.float32) + x *= rowscale + if HAS_DROPOUT: + # Compute dropout mask + # 7 rounds is good enough, and reduces register pressure + keep_mask = ( + tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + ) + x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0) + if STORE_DROPOUT_MASK: + tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N) + if HAS_X1: + x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + M + row).to(tl.float32) + x1 *= rowscale + if HAS_DROPOUT: + # Compute dropout mask + # 7 rounds is good enough, and reduces register 
pressure + keep_mask = ( + tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) + > dropout_p + ) + x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0) + if STORE_DROPOUT_MASK: + tl.store(DROPOUT_MASK1 + row * N + cols, keep_mask, mask=cols < N) + x += x1 + if HAS_RESIDUAL: + residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32) + x += residual + if STORE_RESIDUAL_OUT: + tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + tl.store(Mean + row, mean) + xbar = tl.where(cols < N, x - mean, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + if HAS_WEIGHT: + w = tl.load(W + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w += 1.0 + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + if HAS_WEIGHT: + y = x_hat * w + b if HAS_BIAS else x_hat * w + else: + y = x_hat + b if HAS_BIAS else x_hat + # Write output + tl.store(Y + cols, y, mask=mask) + if HAS_W1: + w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w1 += 1.0 + if HAS_B1: + b1 = tl.load(B1 + cols, mask=mask).to(tl.float32) + y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1 + tl.store(Y1 + cols, y1, mask=mask) + + +def _layer_norm_fwd( + x: Tensor, + weight: Tensor, + bias: Tensor, + eps: float, + residual: Optional[Tensor] = None, + x1: Optional[Tensor] = None, + weight1: Optional[Tensor] = None, + bias1: Optional[Tensor] = None, + dropout_p: float = 0.0, + rowscale: Optional[Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + residual_dtype: Optional[torch.dtype] = None, + zero_centered_weight: bool = False, + is_rms_norm: bool = False, + return_dropout_mask: bool = False, + out: Optional[Tensor] = None, + 
residual_out: Optional[Tensor] = None, +) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor): + # Need to wrap to handle the case where residual_out is a alias of x, which makes torch.library + # and torch.compile unhappy. Also allocate memory for out and residual_out if they are None + # so that _layer_norm_fwd_impl doesn't have to return them. + if out is None: + out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) + if residual is not None: + residual_dtype = residual.dtype + if residual_out is None and ( + residual is not None + or (residual_dtype is not None and residual_dtype != x.dtype) + or dropout_p > 0.0 + or rowscale is not None + or x1 is not None + ): + residual_out = torch.empty_like( + x, dtype=residual_dtype if residual_dtype is not None else x.dtype + ) + else: + residual_out = None + y1, mean, rstd, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd_impl( + x, + weight, + bias, + eps, + out, + residual=residual, + x1=x1, + weight1=weight1, + bias1=bias1, + dropout_p=dropout_p, + rowscale=rowscale, + zero_centered_weight=zero_centered_weight, + is_rms_norm=is_rms_norm, + return_dropout_mask=return_dropout_mask, + residual_out=residual_out, + ) + # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0 + if residual_out is None: + residual_out = x + return out, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 + + +# [2025-04-28] torch.library.triton_op ignores the schema argument, but here we need the schema +# since we're returning a tuple of tensors +def _layer_norm_fwd_impl( + x: Tensor, + weight: Optional[Tensor], + bias: Tensor, + eps: float, + out: Tensor, + residual: Optional[Tensor] = None, + x1: Optional[Tensor] = None, + weight1: Optional[Tensor] = None, + bias1: Optional[Tensor] = None, + dropout_p: float = 0.0, + rowscale: Optional[Tensor] = None, + zero_centered_weight: bool = False, + is_rms_norm: bool = False, + return_dropout_mask: 
bool = False, + residual_out: Optional[Tensor] = None, +) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor): + M, N = x.shape + assert x.stride(-1) == 1 + if residual is not None: + assert residual.stride(-1) == 1 + assert residual.shape == (M, N) + if weight is not None: + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + if x1 is not None: + assert x1.shape == x.shape + assert rowscale is None + assert x1.stride(-1) == 1 + if weight1 is not None: + assert weight1.shape == (N,) + assert weight1.stride(-1) == 1 + if bias1 is not None: + assert bias1.shape == (N,) + assert bias1.stride(-1) == 1 + if rowscale is not None: + assert rowscale.is_contiguous() + assert rowscale.shape == (M,) + assert out.shape == x.shape + assert out.stride(-1) == 1 + if residual_out is not None: + assert residual_out.shape == x.shape + assert residual_out.stride(-1) == 1 + if weight1 is not None: + y1 = torch.empty_like(out) + assert y1.stride(-1) == 1 + else: + y1 = None + mean = ( + torch.empty((M,), dtype=torch.float32, device=x.device) + if not is_rms_norm + else None + ) + rstd = torch.empty((M,), dtype=torch.float32, device=x.device) + if dropout_p > 0.0: + seeds = torch.randint( + 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64 + ) + else: + seeds = None + if return_dropout_mask and dropout_p > 0.0: + dropout_mask = torch.empty(M, N, device=x.device, dtype=torch.bool) + if x1 is not None: + dropout_mask1 = torch.empty(M, N, device=x.device, dtype=torch.bool) + else: + dropout_mask1 = None + else: + dropout_mask, dropout_mask1 = None, None + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + with torch.get_device_module().device(x.device.index): + 
torch.library.wrap_triton(_layer_norm_fwd_1pass_kernel)[(M,)]( + x, + out, + weight if weight is not None else x, # unused when HAS_WEIGHT == False + bias, + residual, + x1, + weight1, + bias1, + y1, + residual_out, + rowscale, + seeds, + dropout_mask, + dropout_mask1, + mean, + rstd, + x.stride(0), + out.stride(0), + residual.stride(0) if residual is not None else 0, + residual_out.stride(0) if residual_out is not None else 0, + x1.stride(0) if x1 is not None else 0, + y1.stride(0) if y1 is not None else 0, + M, + N, + eps, + dropout_p, + # Passing bool make torch inductor very unhappy since it then tries to compare to int_max + int(zero_centered_weight), + is_rms_norm, + BLOCK_N, + residual is not None, + residual_out is not None, + weight is not None, + bias is not None, + dropout_p > 0.0, + dropout_mask is not None, + rowscale is not None, + HAS_X1=x1 is not None, + HAS_W1=weight1 is not None, + HAS_B1=bias1 is not None, + ) + return y1, mean, rstd, seeds, dropout_mask, dropout_mask1 + + +class LayerNormFn: + + @staticmethod + def forward( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + zero_centered_weight=False, + is_rms_norm=False, + return_dropout_mask=False, + out_dtype=None, + out=None, + residual_out=None, + ): + x_shape_og = x.shape + # reshape input data into 2D tensor + x = maybe_contiguous_lastdim(x.reshape(-1, x.shape[-1])) + if residual is not None: + assert residual.shape == x_shape_og + residual = maybe_contiguous_lastdim( + residual.reshape(-1, residual.shape[-1]) + ) + if x1 is not None: + assert x1.shape == x_shape_og + assert rowscale is None, "rowscale is not supported with parallel LayerNorm" + x1 = maybe_contiguous_lastdim(x1.reshape(-1, x1.shape[-1])) + # weight can be None when elementwise_affine=False for LayerNorm + if weight is not None: + weight = weight.contiguous() + bias = maybe_contiguous(bias) + weight1 = 
def layer_norm_fn(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    residual_in_fp32=False,
    zero_centered_weight=False,
    is_rms_norm=False,
    return_dropout_mask=False,
    out_dtype=None,
    out=None,
    residual_out=None,
):
    """Functional entry point that forwards every argument to ``LayerNormFn.forward``.

    All parameters are passed through unchanged; the return value is whatever
    ``LayerNormFn.forward`` returns.
    """
    forwarded = dict(
        x=x,
        weight=weight,
        bias=bias,
        residual=residual,
        x1=x1,
        weight1=weight1,
        bias1=bias1,
        eps=eps,
        dropout_p=dropout_p,
        rowscale=rowscale,
        prenorm=prenorm,
        residual_in_fp32=residual_in_fp32,
        zero_centered_weight=zero_centered_weight,
        is_rms_norm=is_rms_norm,
        return_dropout_mask=return_dropout_mask,
        out_dtype=out_dtype,
        out=out,
        residual_out=residual_out,
    )
    return LayerNormFn.forward(**forwarded)
def norm_infer(
    x: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    eps: float,
    is_rms_norm: bool = False,
    out: Optional[Tensor] = None,
):
    """Inference-only LayerNorm / RMSNorm over the last dimension of a 2-D tensor.

    Args:
        x: Input of shape ``(M, N)``; made contiguous before the launch.
        weight: Optional scale of shape ``(N,)`` with unit stride.
        bias: Optional shift of shape ``(N,)`` with unit stride.
        eps: Value added to the variance before the reciprocal square root.
        is_rms_norm: When True the mean is not subtracted (RMSNorm).
        out: Optional pre-allocated output. Must have the same shape and device
            as ``x`` and unit stride along the last dimension, because the
            kernel addresses it as ``row * out.stride(0) + col``.

    Returns:
        The output tensor (``out`` when it was supplied).

    Raises:
        AssertionError: If ``weight``, ``bias`` or ``out`` have an unsupported
            layout.
        RuntimeError: If a row does not fit into a single fused block.
    """
    M, N = x.shape
    x = x.contiguous()
    if weight is not None:
        assert weight.shape == (N,)
        assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.shape == (N,)
        assert bias.stride(-1) == 1
    if out is None:
        out = torch.empty_like(x)
    else:
        # Same checks _layer_norm_fwd_impl applies to its output buffer; without
        # them a mis-shaped buffer makes the kernel write to wrong addresses.
        assert out.shape == x.shape, "out must have the same shape as x"
        assert out.stride(-1) == 1, "out must be contiguous in the last dimension"
        assert out.device == x.device, "out must live on the same device as x"
    # Less than 64KB per feature: the whole row is handled by one program.
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    num_warps = min(max(BLOCK_N // 256, 1), 8)
    _norm_infer_kernel[(M,)](
        x,
        out,
        weight if weight is not None else x,  # dummy when HAS_WEIGHT=False
        bias if bias is not None else x,  # dummy when HAS_BIAS=False
        x.stride(0),
        out.stride(0),
        M,
        N,
        eps,
        IS_RMS_NORM=is_rms_norm,
        HAS_WEIGHT=weight is not None,
        HAS_BIAS=bias is not None,
        BLOCK_N=BLOCK_N,
        num_warps=num_warps,
    )
    return out
# TODO: remove this when triton ascend bug is fixed
def fuse_scale_shift_native(
    x: torch.Tensor,
    scale: torch.Tensor,
    shift: torch.Tensor,
    block_l: int = 128,
    block_c: int = 128,
    scale_constant: float = 1.0,
):
    """Pure-PyTorch fallback computing ``x * (scale_constant + scale) + shift``.

    Args:
        x: Input tensor. For a 4-D ``scale`` it must be ``[B, L, C]``.
        scale: Scale tensor broadcastable against ``x``, or a per-frame
            ``[B, F, 1, C]`` tensor where frame ``f`` applies to tokens
            ``f * (L // F) .. (f + 1) * (L // F) - 1``.
        shift: Shift tensor broadcastable against ``x``. With a 4-D ``scale``
            it is per-token and is reshaped to ``[B, L, C]``.
        block_l, block_c: Unused; kept so the Triton launcher's tuning keywords
            are accepted.
        scale_constant: Constant added to ``scale``. It is appended after the
            tile sizes to keep the existing positional order, so pass it by
            keyword.

    Returns:
        A tensor with the same shape as ``x``.
    """
    if scale.dim() == 4:
        # Per-frame scale: group tokens by frame so that the [B, F, 1, C] scale
        # broadcasts over the tokens of each frame.
        B, L, C = x.shape
        num_frames = scale.shape[1]
        assert (
            L % num_frames == 0
        ), "seq_len must be divisible by num_frames for 4D scale/shift"
        x_frames = x.reshape(B, num_frames, L // num_frames, C)
        scaled = (x_frames * (scale_constant + scale)).reshape(B, L, C)
        return scaled + shift.reshape(B, L, C)
    return x * (scale_constant + scale) + shift


# TODO: remove this when triton ascend bug is fixed
def apply_rotary_embedding_native(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
) -> torch.Tensor:
    """Pure-PyTorch fallback for rotary embedding on adjacent element pairs.

    Elements ``(2i, 2i + 1)`` of the last dimension are rotated by the angle
    whose cosine/sine is ``cos[..., i]`` / ``sin[..., i]``.

    Args:
        x: Input whose last two dimensions are ``(num_heads, head_size)``.
        cos, sin: Tables whose last dimension is ``head_size // 2``. When
            ``interleaved`` is True they may instead span the full
            ``head_size``; every second column is then used.
        interleaved: Only affects how full-width ``cos``/``sin`` are handled.

    Returns:
        A tensor with the same shape and dtype as ``x``.
    """
    if interleaved and cos.shape[-1] == x.shape[-1]:
        # Full-width tables repeat each angle for both elements of a pair.
        cos = cos[..., ::2]
        sin = sin[..., ::2]
    # Insert a head axis so the per-token tables broadcast over all heads.
    cos = cos.unsqueeze(-2).to(x.dtype)
    sin = sin.unsqueeze(-2).to(x.dtype)
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    o1 = x1 * cos - x2 * sin
    o2 = x2 * cos + x1 * sin
    # Re-interleave the rotated halves back into (2i, 2i + 1) order.
    return torch.stack((o1, o2), dim=-1).flatten(-2)
def triton_one_pass_rms_norm(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6):
    """RMS-normalise ``x`` over its last dimension and scale by ``w``.

    Args:
        x: Input of any rank; flattened to ``(S, D)`` rows for the kernel.
        w: Weight holding exactly ``D`` elements.
        eps: Value added to the mean square before the reciprocal square root.

    Returns:
        A new contiguous tensor with the same shape and dtype as ``x``.

    Raises:
        AssertionError: If ``w`` does not hold exactly ``D`` elements.
    """
    shape = x.shape
    # The kernel reads w_ptr + d for d < D with unit stride, so the weight must
    # hold exactly D densely packed elements.
    assert (
        w.numel() == shape[-1]
    ), "weight must have as many elements as the last dimension of x"
    w = w.contiguous()
    x = x.contiguous()
    y = torch.empty_like(x)
    x_view = x.reshape(-1, shape[-1])
    y_view = y.reshape(-1, shape[-1])
    S, D = x_view.shape

    # Rows handled per program: grows with S, capped at 16, always a power of two.
    BLOCK_SIZE_SEQ = min(16, triton.next_power_of_2(max(1, S // 512)))
    grid = (triton.cdiv(S, BLOCK_SIZE_SEQ),)

    with torch.get_device_module().device(x.device):
        torch.library.wrap_triton(_rms_norm_tiled_onepass)[grid](
            y_view,
            x_view,
            w,
            S,
            D,
            eps,
            BLOCK_SIZE_DIM=triton.next_power_of_2(D),
            BLOCK_SIZE_SEQ=BLOCK_SIZE_SEQ,
        )
    return y
def apply_rotary_embedding(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
) -> torch.Tensor:
    """Apply rotary embedding to ``x`` with the Triton kernel.

    Args:
        x: ``[num_tokens, num_heads, head_size]`` or
            ``[bsz, num_tokens, num_heads, head_size]``. Non-contiguous inputs
            are copied to a contiguous layout first.
        cos, sin: Per-token tables; row ``t`` is used for token ``t``. The last
            dimension is ``head_size // 2``, or ``head_size`` when
            ``interleaved`` is True (every second column is then used).
        interleaved: Selects the full-width table handling described above.

    Returns:
        A new contiguous tensor with the same shape and dtype as ``x``.

    Raises:
        ValueError: If ``x`` is not 3-D or 4-D.
        AssertionError: If ``head_size`` is odd.
    """
    if x.dim() == 4:
        bsz, num_tokens, num_heads, head_size = x.shape
    elif x.dim() == 3:
        num_tokens, num_heads, head_size = x.shape
        bsz = 1
    else:
        raise ValueError(
            f"x must be 3D [tokens, heads, head_size] or 4D "
            f"[batch, tokens, heads, head_size], got {x.dim()}D"
        )

    assert head_size % 2 == 0, "head_size must be divisible by 2"

    # The kernel addresses input and output rows with one shared row stride, so
    # both must be densely packed. Allocating the output only after making x
    # contiguous stops empty_like from copying a permuted stride layout.
    x = x.contiguous()
    output = torch.empty_like(x)

    x_reshaped = x.view(-1, head_size)
    output_reshaped = output.view(-1, head_size)

    # One program per (batch, token, head) row.
    grid = (bsz * num_tokens * num_heads,)

    if interleaved and cos.shape[-1] == head_size:
        cos = cos[..., ::2].contiguous()
        sin = sin[..., ::2].contiguous()
    else:
        cos = cos.contiguous()
        sin = sin.contiguous()

    _rotary_embedding_kernel[grid](
        output_reshaped,
        x_reshaped,
        cos,
        sin,
        num_heads,
        head_size,
        num_tokens,
        x_reshaped.stride(0),
        cos.stride(0),
        sin.stride(0),
        interleaved,
    )

    return output
+ rows, + inner_dim, + seq_len, + num_frames, + frame_seqlen, + BLOCK_N: tl.constexpr, +): + pid_row = tl.program_id(0) + pid_col = tl.program_id(1) + + col_offsets = pid_col * BLOCK_N + tl.arange(0, BLOCK_N) + mask = col_offsets < inner_dim + + # Pointers for normalized and output + row_base = pid_row * inner_dim + norm_ptrs = normalized_ptr + row_base + col_offsets + out_ptrs = output_ptr + row_base + col_offsets + + # Pointers for scale (per-frame) and shift (per-token) + b_idx = pid_row // seq_len + t_idx = pid_row % seq_len + frame_idx_in_batch = t_idx // frame_seqlen + + scale_row_idx = b_idx * num_frames + frame_idx_in_batch + scale_ptrs = scale_ptr + scale_row_idx * inner_dim + col_offsets + # shift is per-token [B*L, C], indexed by pid_row directly + shift_ptrs = shift_ptr + pid_row * inner_dim + col_offsets + + normalized = tl.load(norm_ptrs, mask=mask, other=0.0) + scale = tl.load(scale_ptrs, mask=mask, other=0.0) + shift = tl.load(shift_ptrs, mask=mask, other=0.0) + + scale_const_tensor = tl.full([BLOCK_N], scale_constant, dtype=scale.dtype) + output = normalized * (scale_const_tensor + scale) + shift + + tl.store(out_ptrs, output, mask=mask) + + +@triton.jit +def fuse_scale_shift_kernel_blc_opt( + x_ptr, + shift_ptr, + scale_ptr, + scale_constant: tl.constexpr, # scale_constant is either 0 or 1., + y_ptr, + B, + L, + C, + stride_x_b, + stride_x_l, + stride_x_c, + stride_s_b, + stride_s_l, + stride_s_c, + stride_sc_b, + stride_sc_l, + stride_sc_c, + SCALE_IS_SCALAR: tl.constexpr, + SHIFT_IS_SCALAR: tl.constexpr, + BLOCK_L: tl.constexpr, + BLOCK_C: tl.constexpr, +): + pid_l = tl.program_id(0) + pid_c = tl.program_id(1) + pid_b = tl.program_id(2) + + l_offsets = pid_l * BLOCK_L + tl.arange(0, BLOCK_L) + c_offsets = pid_c * BLOCK_C + tl.arange(0, BLOCK_C) + + mask_l = l_offsets < L + mask_c = c_offsets < C + mask = mask_l[:, None] & mask_c[None, :] + + x_off = ( + pid_b * stride_x_b + + l_offsets[:, None] * stride_x_l + + c_offsets[None, :] * stride_x_c + 
) + x = tl.load(x_ptr + x_off, mask=mask, other=0) + + if SHIFT_IS_SCALAR: + shift_val = tl.load(shift_ptr) + shift = tl.full((BLOCK_L, BLOCK_C), shift_val, dtype=shift_val.dtype) + else: + s_off = ( + pid_b * stride_s_b + + l_offsets[:, None] * stride_s_l + + c_offsets[None, :] * stride_s_c + ) + shift = tl.load(shift_ptr + s_off, mask=mask, other=0) + + if SCALE_IS_SCALAR: + scale_val = tl.load(scale_ptr) + scale = tl.full((BLOCK_L, BLOCK_C), scale_val, dtype=scale_val.dtype) + else: + sc_off = ( + pid_b * stride_sc_b + + l_offsets[:, None] * stride_sc_l + + c_offsets[None, :] * stride_sc_c + ) + scale = tl.load(scale_ptr + sc_off, mask=mask, other=0) + + y = x * (scale_constant + scale) + shift + tl.store(y_ptr + x_off, y, mask=mask) + + +@triton.jit +def fuse_scale_shift_gate_select01_kernel_blc_opt( + x_ptr, + shift0_ptr, + scale0_ptr, + gate0_ptr, + shift1_ptr, + scale1_ptr, + gate1_ptr, + index_ptr, + y_ptr, + gate_out_ptr, + B, + L, + C, + stride_x_b, + stride_x_l, + stride_x_c, + stride_s0_b, + stride_s0_c, + stride_sc0_b, + stride_sc0_c, + stride_g0_b, + stride_g0_c, + stride_s1_b, + stride_s1_c, + stride_sc1_b, + stride_sc1_c, + stride_g1_b, + stride_g1_c, + stride_i_b, + stride_i_l, + stride_go_b, + stride_go_l, + stride_go_c, + BLOCK_L: tl.constexpr, + BLOCK_C: tl.constexpr, +): + pid_l = tl.program_id(0) + pid_c = tl.program_id(1) + pid_b = tl.program_id(2) + + l_offsets = pid_l * BLOCK_L + tl.arange(0, BLOCK_L) + c_offsets = pid_c * BLOCK_C + tl.arange(0, BLOCK_C) + + mask_l = l_offsets < L + mask_c = c_offsets < C + mask = mask_l[:, None] & mask_c[None, :] + + x_off = ( + pid_b * stride_x_b + + l_offsets[:, None] * stride_x_l + + c_offsets[None, :] * stride_x_c + ) + x = tl.load(x_ptr + x_off, mask=mask, other=0) + + idx_off = pid_b * stride_i_b + l_offsets * stride_i_l + idx = tl.load(index_ptr + idx_off, mask=mask_l, other=0).to(tl.int1)[:, None] + + s0_off = pid_b * stride_s0_b + c_offsets[None, :] * stride_s0_c + sc0_off = pid_b * stride_sc0_b + 
def fuse_scale_shift_kernel(
    x: torch.Tensor,
    scale: torch.Tensor,
    shift: torch.Tensor,
    scale_constant: float = 1.0,
    block_l: int = 128,
    block_c: int = 128,
):
    """Compute ``x * (scale_constant + scale) + shift`` with a fused Triton kernel.

    Args:
        x: Contiguous CUDA tensor of shape ``[B, L, C]``.
        scale: CUDA tensor. Either 4-D ``[B, F, 1, C]`` (one scale per frame,
            ``L`` divisible by ``F``), a scalar (0-D or one-element 1-D), 2-D
            ``[B, C]`` / ``[1, C]``, or 3-D and broadcastable to ``[B, L, C]``.
        shift: With a 4-D ``scale`` it is per-token and is reshaped to
            ``[B * L, C]``. Otherwise a scalar, 2-D, or anything expandable to
            ``[B, L, C]``.
        scale_constant: Constant added to ``scale``.
        block_l, block_c: Tile sizes for the non-4-D kernel.

    Returns:
        A new tensor with the same shape and dtype as ``x``.

    Raises:
        AssertionError: If ``x``/``scale`` are not on CUDA, ``x`` is not
            contiguous, or ``L`` is not divisible by the number of frames.
        ValueError: If ``scale`` has an unsupported rank.
    """
    assert x.is_cuda and scale.is_cuda
    assert x.is_contiguous()

    B, L, C = x.shape
    output = torch.empty_like(x)

    if scale.dim() == 4:
        # scale: [B, F, 1, C] (per-frame); shift: per-token
        rows = B * L
        x_2d = x.view(rows, C)
        output_2d = output.view(rows, C)
        grid = lambda META: (rows, triton.cdiv(C, META["BLOCK_N"]))
        num_frames = scale.shape[1]
        assert (
            L % num_frames == 0
        ), "seq_len must be divisible by num_frames for 4D scale/shift"
        frame_seqlen = L // num_frames

        # Compact scale [B, F, 1, C] -> [B*F, C] (per-frame)
        scale_reshaped = scale.squeeze(2).reshape(-1, C).contiguous()
        # shift is per-token [B, L, C] -> [B*L, C]
        shift_reshaped = shift.reshape(rows, C).contiguous()

        _fused_scale_shift_4d_kernel[grid](
            output_2d,
            x_2d,
            scale_reshaped,
            shift_reshaped,
            scale_constant,
            rows,
            C,
            L,
            num_frames,
            frame_seqlen,
        )
    else:
        # 2D: [B, C] or [1, C] -> treat as [B, 1, C] and broadcast over L
        # 3D: [B, L, C] (or broadcastable variants like [B, 1, C], [1, L, C], [1, 1, C])
        # Also support scalar (0D or 1-element)
        if scale.dim() == 0 or (scale.dim() == 1 and scale.numel() == 1):
            scale_blc = scale.reshape(1)
        elif scale.dim() == 2:
            scale_blc = scale[:, None, :]
        elif scale.dim() == 3:
            scale_blc = scale
        else:
            raise ValueError("scale must be 0D/1D(1)/2D/3D or 4D")

        if shift.dim() == 0 or (shift.dim() == 1 and shift.numel() == 1):
            shift_blc = shift.reshape(1)
        elif shift.dim() == 2:
            shift_blc = shift[:, None, :]
        elif shift.dim() == 3:
            shift_blc = shift
        else:
            # broadcast later via expand if possible
            shift_blc = shift

        need_scale_scalar = scale_blc.dim() == 1 and scale_blc.numel() == 1
        need_shift_scalar = shift_blc.dim() == 1 and shift_blc.numel() == 1

        if not need_scale_scalar:
            scale_exp = scale_blc.expand(B, L, C)
            s_sb, s_sl, s_sc = scale_exp.stride()
        else:
            s_sb = s_sl = s_sc = 0

        if not need_shift_scalar:
            shift_exp = shift_blc.expand(B, L, C)
            sh_sb, sh_sl, sh_sc = shift_exp.stride()
        else:
            sh_sb = sh_sl = sh_sc = 0

        # Copy fast-path: with scalar scale == 0 and shift == 0 the result is
        # x * scale_constant, which equals x only when scale_constant is 1. Any
        # other constant must go through the kernel.
        if need_scale_scalar and need_shift_scalar and scale_constant == 1:
            # bool() on a CUDA tensor is a synchronising read. The previous
            # non_blocking device-to-host copy could be inspected before the
            # transfer had completed.
            if not (bool(scale_blc.any()) or bool(shift_blc.any())):
                output.copy_(x)
                return output

        grid = (triton.cdiv(L, block_l), triton.cdiv(C, block_c), B)
        fuse_scale_shift_kernel_blc_opt[grid](
            x,
            shift_blc if need_shift_scalar else shift_exp,
            scale_blc if need_scale_scalar else scale_exp,
            scale_constant,
            output,
            B,
            L,
            C,
            x.stride(0),
            x.stride(1),
            x.stride(2),
            sh_sb,
            sh_sl,
            sh_sc,
            s_sb,
            s_sl,
            s_sc,
            SCALE_IS_SCALAR=need_scale_scalar,
            SHIFT_IS_SCALAR=need_shift_scalar,
            BLOCK_L=block_l,
            BLOCK_C=block_c,
            num_warps=4,
            num_stages=2,
        )
    return output
Callable, Optional, Tuple, Union + +import torch + +try: + from sgl_fa4.cute import flash_attn_varlen_func as _flash_attn_varlen_func +except Exception as _e: # pragma: no cover + _flash_attn_varlen_func = None + _flash_attn_import_error = _e +else: + _flash_attn_import_error = None + + +def _maybe_contiguous(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + return x.contiguous() if x is not None and x.stride(-1) != 1 else x + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + page_table: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + softcap: Optional[float] = None, + window_size: Tuple[Optional[int], Optional[int]] = (-1, -1), + learnable_sink: Optional[torch.Tensor] = None, + sinks: Optional[torch.Tensor] = None, + num_splits: int = 1, + pack_gqa: Optional[bool] = None, + score_mod: Optional[Callable] = None, + aux_tensors: Optional[list] = None, + return_softmax_lse: bool = False, + **_: object, +): + if _flash_attn_varlen_func is None: # pragma: no cover + raise ImportError( + "Vendored FlashAttention CUTE is not available (cannot import " + "sgl_fa4.cute). Please check your source tree." 
+ ) from _flash_attn_import_error + + q, k, v = [_maybe_contiguous(t) for t in (q, k, v)] + cu_seqlens_q, cu_seqlens_k = [ + _maybe_contiguous(t) for t in (cu_seqlens_q, cu_seqlens_k) + ] + seqused_q, seqused_k = [_maybe_contiguous(t) for t in (seqused_q, seqused_k)] + page_table = _maybe_contiguous(page_table) + + if learnable_sink is None and sinks is not None: + learnable_sink = sinks + + if window_size == (-1, -1): + window_size = (None, None) + + result = _flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_q=seqused_q, + seqused_k=seqused_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + page_table=page_table, + softmax_scale=softmax_scale, + causal=causal, + softcap=softcap, + window_size=window_size, + learnable_sink=learnable_sink, + num_splits=num_splits, + pack_gqa=pack_gqa, + score_mod=score_mod, + aux_tensors=aux_tensors, + ) + + if return_softmax_lse: + return result + if isinstance(result, tuple): + return result[0] + return result + + +def flash_attn_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k: Optional[torch.Tensor] = None, + v: Optional[torch.Tensor] = None, + qv: Optional[torch.Tensor] = None, + rotary_cos: Optional[torch.Tensor] = None, + rotary_sin: Optional[torch.Tensor] = None, + cache_seqlens: Optional[Union[int, torch.Tensor]] = None, + cache_batch_idx: Optional[torch.Tensor] = None, + cache_leftpad: Optional[torch.Tensor] = None, + page_table: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + rotary_seqlens: Optional[torch.Tensor] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), + attention_chunk: Optional[int] = 
None, + softcap: float = 0.0, + rotary_interleaved: bool = True, + scheduler_metadata=None, + num_splits: int = 0, + pack_gqa: Optional[bool] = None, + sm_margin: int = 0, + sinks: Optional[torch.Tensor] = None, + score_mod: Optional[Callable] = None, + aux_tensors: Optional[list] = None, + return_softmax_lse: bool = False, + **_: object, +): + if k is not None or v is not None or qv is not None: + raise NotImplementedError("FA4 does not support updating KV cache in-place.") + if rotary_cos is not None or rotary_sin is not None or rotary_seqlens is not None: + raise NotImplementedError("FA4 path does not support rotary embedding.") + if cache_batch_idx is not None or cache_leftpad is not None: + raise NotImplementedError( + "FA4 path does not support non-consecutive batch indices or left padding." + ) + if q_descale is not None or k_descale is not None or v_descale is not None: + raise NotImplementedError("FA4 path does not support descale.") + + if isinstance(cache_seqlens, int): + cache_seqlens = torch.full( + (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device + ) + + result = flash_attn_varlen_func( + q=q, + k=k_cache, + v=v_cache, + cu_seqlens_q=cu_seqlens_q, + seqused_k=cache_seqlens, + max_seqlen_q=max_seqlen_q, + page_table=page_table, + softmax_scale=softmax_scale, + causal=causal, + softcap=softcap if softcap != 0.0 else None, + window_size=window_size, + num_splits=num_splits if num_splits != 0 else 1, + pack_gqa=pack_gqa, + learnable_sink=sinks, + score_mod=score_mod, + aux_tensors=aux_tensors, + return_softmax_lse=True, + ) + + if return_softmax_lse: + return result + if isinstance(result, tuple): + return result[0] + return result diff --git a/sglang/python/sglang/jit_kernel/fused_metadata_copy.py b/sglang/python/sglang/jit_kernel/fused_metadata_copy.py new file mode 100644 index 0000000000000000000000000000000000000000..b4d347f6abadd671fe7d7dae8d349c3b832c7157 --- /dev/null +++ 
@cache_once
def _jit_fused_metadata_copy_module(
    forward_mode: int, has_real_page_table: bool, has_flashmla: bool
):
    """Build the JIT module for the single-backend fused metadata copy.

    The three arguments are turned into C++ template arguments, so each distinct
    combination yields its own kernel specialisation.

    Args:
        forward_mode: 0=DECODE, 1=TARGET_VERIFY, 2=DRAFT_EXTEND
        has_real_page_table: Whether real_page_table tensors are used
        has_flashmla: Whether FlashMLA metadata tensors are used

    Returns:
        Whatever ``load_jit`` returns for this specialisation.

    Raises:
        Exception: Any failure from ``load_jit`` is logged and re-raised.
    """
    template_args = make_cpp_args(forward_mode, has_real_page_table, has_flashmla)
    try:
        # Export name paired with the templated C++ entry point it binds to.
        wrapper = (
            "fused_metadata_copy",
            f"FusedMetadataCopyKernel<{template_args}>::run",
        )
        module = load_jit(
            "fused_metadata_copy",
            *template_args,
            cuda_files=["elementwise/fused_metadata_copy.cuh"],
            cuda_wrappers=[wrapper],
        )
    except Exception as e:
        logger.error(
            f"Failed to compile JIT fused metadata copy kernel "
            f"(forward_mode={forward_mode}, has_real_page_table={has_real_page_table}, "
            f"has_flashmla={has_flashmla}): {e}"
        )
        raise
    return module
+ + Args: + has_real_page_table: Whether real_page_table tensors are used + has_flashmla: Whether FlashMLA metadata tensors are used + """ + args = make_cpp_args(has_real_page_table, has_flashmla) + try: + return load_jit( + "fused_metadata_copy_multi", + *args, + cuda_files=["elementwise/fused_metadata_copy.cuh"], + cuda_wrappers=[ + ( + "fused_metadata_copy_multi", + f"FusedMetadataCopyMultiKernel<{args}>::run", + ) + ], + ) + except Exception as e: + logger.error( + f"Failed to compile JIT fused metadata copy multi kernel " + f"(has_real_page_table={has_real_page_table}, has_flashmla={has_flashmla}): {e}" + ) + raise + + +# ============================================================================ +# Public API +# ============================================================================ + + +def fused_metadata_copy_cuda( + cache_seqlens_src: torch.Tensor, + cu_seqlens_k_src: torch.Tensor, + page_indices_src: torch.Tensor, + nsa_cache_seqlens_src: torch.Tensor, + seqlens_expanded_src: Optional[torch.Tensor], + nsa_cu_seqlens_k_src: torch.Tensor, + real_page_table_src: Optional[torch.Tensor], + flashmla_num_splits_src: Optional[torch.Tensor], + flashmla_metadata_src: Optional[torch.Tensor], + cache_seqlens_dst: torch.Tensor, + cu_seqlens_k_dst: torch.Tensor, + page_table_1_dst: torch.Tensor, + nsa_cache_seqlens_dst: torch.Tensor, + seqlens_expanded_dst: Optional[torch.Tensor], + nsa_cu_seqlens_k_dst: torch.Tensor, + real_page_table_dst: Optional[torch.Tensor], + flashmla_num_splits_dst: Optional[torch.Tensor], + flashmla_metadata_dst: Optional[torch.Tensor], + forward_mode: int, + bs: int, + max_len: int, + max_seqlen_k: int, + seqlens_expanded_size: int, +) -> None: + """ + Fused metadata copy kernel for NSA backend CUDA graph replay. + + This function fuses multiple tensor copy operations into a single kernel launch, + reducing kernel launch overhead and improving performance. 
+ + Args: + cache_seqlens_src: Source cache sequence lengths [bs] + cu_seqlens_k_src: Source cumulative sequence lengths [bs+1] + page_indices_src: Source page indices [rows, max_len] + nsa_cache_seqlens_src: Source NSA cache sequence lengths [size] + seqlens_expanded_src: Optional source expanded sequence lengths [size] (required for TARGET_VERIFY/DRAFT_EXTEND) + nsa_cu_seqlens_k_src: Source NSA cumulative sequence lengths [size+1] + real_page_table_src: Optional source real page table [rows, cols] + flashmla_num_splits_src: Optional source FlashMLA num_splits [size+1] + flashmla_metadata_src: Optional source FlashMLA metadata tensor + cache_seqlens_dst: Destination cache sequence lengths [bs] + cu_seqlens_k_dst: Destination cumulative sequence lengths [bs+1] + page_table_1_dst: Destination page table [rows, stride] + nsa_cache_seqlens_dst: Destination NSA cache sequence lengths [size] + seqlens_expanded_dst: Optional destination expanded sequence lengths [size] (required for TARGET_VERIFY/DRAFT_EXTEND) + nsa_cu_seqlens_k_dst: Destination NSA cumulative sequence lengths [size+1] + real_page_table_dst: Optional destination real page table [rows, cols] + flashmla_num_splits_dst: Optional destination FlashMLA num_splits [size+1] + flashmla_metadata_dst: Optional destination FlashMLA metadata tensor + forward_mode: Forward mode (0=DECODE, 1=TARGET_VERIFY, 2=DRAFT_EXTEND) + bs: Batch size + max_len: Maximum length for decode/draft_extend mode + max_seqlen_k: Maximum sequence length for target_verify mode + seqlens_expanded_size: Size of expanded sequence lengths + """ + # Determine template parameters for kernel specialization + has_real_page_table = real_page_table_src is not None + has_flashmla = flashmla_num_splits_src is not None + + # Get JIT-compiled module for this configuration (cached after first use) + module = _jit_fused_metadata_copy_module( + forward_mode, has_real_page_table, has_flashmla + ) + + # Ensure all required source tensors are contiguous 
(required for kernel's linear indexing) + # This matches the CHECK_INPUT checks in the verified sgl-kernel implementation + cache_seqlens_src = cache_seqlens_src.contiguous() + cu_seqlens_k_src = cu_seqlens_k_src.contiguous() + page_indices_src = page_indices_src.contiguous() + nsa_cache_seqlens_src = nsa_cache_seqlens_src.contiguous() + if seqlens_expanded_src is not None: + seqlens_expanded_src = seqlens_expanded_src.contiguous() + nsa_cu_seqlens_k_src = nsa_cu_seqlens_k_src.contiguous() + + # Call JIT-compiled kernel (None values are passed as Optional with no value) + module.fused_metadata_copy( + cache_seqlens_src, + cu_seqlens_k_src, + page_indices_src, + nsa_cache_seqlens_src, + seqlens_expanded_src, + nsa_cu_seqlens_k_src, + real_page_table_src, + flashmla_num_splits_src, + flashmla_metadata_src, + cache_seqlens_dst, + cu_seqlens_k_dst, + page_table_1_dst, + nsa_cache_seqlens_dst, + seqlens_expanded_dst, + nsa_cu_seqlens_k_dst, + real_page_table_dst, + flashmla_num_splits_dst, + flashmla_metadata_dst, + bs, + max_len, + max_seqlen_k, + seqlens_expanded_size, + ) + + +def fused_metadata_copy_multi_cuda( + cache_seqlens_src: torch.Tensor, + cu_seqlens_k_src: torch.Tensor, + page_indices_src: torch.Tensor, + nsa_cache_seqlens_src: torch.Tensor, + nsa_cu_seqlens_k_src: torch.Tensor, + real_page_table_src: Optional[torch.Tensor], + flashmla_num_splits_src: Optional[torch.Tensor], + flashmla_metadata_src: Optional[torch.Tensor], + cache_seqlens_dst0: torch.Tensor, + cu_seqlens_k_dst0: torch.Tensor, + page_table_1_dst0: torch.Tensor, + nsa_cache_seqlens_dst0: torch.Tensor, + nsa_cu_seqlens_k_dst0: torch.Tensor, + real_page_table_dst0: Optional[torch.Tensor], + flashmla_num_splits_dst0: Optional[torch.Tensor], + flashmla_metadata_dst0: Optional[torch.Tensor], + cache_seqlens_dst1: torch.Tensor, + cu_seqlens_k_dst1: torch.Tensor, + page_table_1_dst1: torch.Tensor, + nsa_cache_seqlens_dst1: torch.Tensor, + nsa_cu_seqlens_k_dst1: torch.Tensor, + real_page_table_dst1: 
Optional[torch.Tensor], + flashmla_num_splits_dst1: Optional[torch.Tensor], + flashmla_metadata_dst1: Optional[torch.Tensor], + cache_seqlens_dst2: torch.Tensor, + cu_seqlens_k_dst2: torch.Tensor, + page_table_1_dst2: torch.Tensor, + nsa_cache_seqlens_dst2: torch.Tensor, + nsa_cu_seqlens_k_dst2: torch.Tensor, + real_page_table_dst2: Optional[torch.Tensor], + flashmla_num_splits_dst2: Optional[torch.Tensor], + flashmla_metadata_dst2: Optional[torch.Tensor], + bs: int, + max_len: int, + seqlens_expanded_size: int, +) -> None: + """ + Multi-backend fused metadata copy kernel for NSA backend CUDA graph replay. + + This function copies metadata from one source to THREE destinations in a single + kernel launch, eliminating the overhead of 3 separate kernel calls. Currently + only supports DECODE mode, which is the most common case. + + Args: + cache_seqlens_src: Source cache sequence lengths [bs] + cu_seqlens_k_src: Source cumulative sequence lengths [bs+1] + page_indices_src: Source page indices [bs, max_len] + nsa_cache_seqlens_src: Source NSA cache sequence lengths [bs] + nsa_cu_seqlens_k_src: Source NSA cumulative sequence lengths [bs+1] + real_page_table_src: Optional source real page table [bs, cols] + flashmla_num_splits_src: Optional source FlashMLA num_splits [bs+1] + flashmla_metadata_src: Optional source FlashMLA metadata tensor + cache_seqlens_dst0-2: Destination cache sequence lengths for backends 0-2 + cu_seqlens_k_dst0-2: Destination cumulative sequence lengths for backends 0-2 + page_table_1_dst0-2: Destination page tables for backends 0-2 + nsa_cache_seqlens_dst0-2: Destination NSA cache sequence lengths for backends 0-2 + nsa_cu_seqlens_k_dst0-2: Destination NSA cumulative sequence lengths for backends 0-2 + real_page_table_dst0-2: Optional destination real page tables for backends 0-2 + flashmla_num_splits_dst0-2: Optional destination FlashMLA num_splits for backends 0-2 + flashmla_metadata_dst0-2: Optional destination FlashMLA metadata tensors for 
backends 0-2 + bs: Batch size + max_len: Maximum length for decode mode + seqlens_expanded_size: Size of expanded sequence lengths + """ + # Determine template parameters for kernel specialization + has_real_page_table = real_page_table_src is not None + has_flashmla = flashmla_num_splits_src is not None + + # Get JIT-compiled module for this configuration (cached after first use) + module = _jit_fused_metadata_copy_multi_module(has_real_page_table, has_flashmla) + + # Ensure all source tensors are contiguous (required for kernel's linear indexing) + # This matches the CHECK_INPUT checks in the verified sgl-kernel implementation + cache_seqlens_src = cache_seqlens_src.contiguous() + cu_seqlens_k_src = cu_seqlens_k_src.contiguous() + page_indices_src = page_indices_src.contiguous() + nsa_cache_seqlens_src = nsa_cache_seqlens_src.contiguous() + nsa_cu_seqlens_k_src = nsa_cu_seqlens_k_src.contiguous() + + # Call JIT-compiled kernel (None values are passed as Optional with no value) + module.fused_metadata_copy_multi( + cache_seqlens_src, + cu_seqlens_k_src, + page_indices_src, + nsa_cache_seqlens_src, + nsa_cu_seqlens_k_src, + real_page_table_src, + flashmla_num_splits_src, + flashmla_metadata_src, + cache_seqlens_dst0, + cu_seqlens_k_dst0, + page_table_1_dst0, + nsa_cache_seqlens_dst0, + nsa_cu_seqlens_k_dst0, + real_page_table_dst0, + flashmla_num_splits_dst0, + flashmla_metadata_dst0, + cache_seqlens_dst1, + cu_seqlens_k_dst1, + page_table_1_dst1, + nsa_cache_seqlens_dst1, + nsa_cu_seqlens_k_dst1, + real_page_table_dst1, + flashmla_num_splits_dst1, + flashmla_metadata_dst1, + cache_seqlens_dst2, + cu_seqlens_k_dst2, + page_table_1_dst2, + nsa_cache_seqlens_dst2, + nsa_cu_seqlens_k_dst2, + real_page_table_dst2, + flashmla_num_splits_dst2, + flashmla_metadata_dst2, + bs, + max_len, + seqlens_expanded_size, + ) diff --git a/sglang/python/sglang/jit_kernel/fused_store_index_cache.py b/sglang/python/sglang/jit_kernel/fused_store_index_cache.py new file mode 100644 index 
0000000000000000000000000000000000000000..dcfbf45857496711b68d8fbe604e4fd3f8d25638 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/fused_store_index_cache.py @@ -0,0 +1,103 @@ +""" +This module provides JIT-compiled CUDA kernels for fusing multiple tensor +copy operations into single kernel launches, reducing kernel launch overhead +and improving CUDA graph replay performance. + +The kernels are compiled on-demand using TVM FFI and cached for subsequent use. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import torch + +from sglang.jit_kernel.utils import ( + cache_once, + is_arch_support_pdl, + load_jit, + make_cpp_args, +) + +if TYPE_CHECKING: + from tvm_ffi.module import Module + +logger = logging.getLogger(__name__) + + +@cache_once +def _jit_nsa_fused_store_module( + key_dtype: torch.dtype, indices_dtype: torch.dtype, page_size: int +) -> Module: + """ + Build a JIT module that exposes: + module.fused_store_index_k_cache(input_bf16, index_k_with_scale_u8, loc_i64) + """ + args = make_cpp_args(key_dtype, indices_dtype, page_size, is_arch_support_pdl()) + return load_jit( + "fused_store_index_k_cache", + *args, + cuda_files=["nsa/fused_store_index_cache.cuh"], + cuda_wrappers=[ + ( + "fused_store_index_k_cache", + # - Float = bf16_t (sgl_kernel/type.cuh) + # - IndicesT = int64_t (out_cache_loc is int64 in SGLang SetKAndS) + # - kPageSize = 64 (CUDA NSA) + f"FusedStoreCacheIndexerKernel<{args}>::run", + ) + ], + ) + + +@cache_once +def can_use_nsa_fused_store( + key_dtype: torch.dtype, indices_dtype: torch.dtype, page_size: int +) -> bool: + logger = logging.getLogger(__name__) + try: + _jit_nsa_fused_store_module(key_dtype, indices_dtype, page_size) + return True + except Exception as e: + logger.warning(f"Failed to load nsa fused store JIT kernel: {e}") + return False + + +def fused_store_index_k_cache( + key: torch.Tensor, + index_k_with_scale: torch.Tensor, + out_cache_loc: torch.Tensor, + page_size: int 
= 64, +) -> None: + """ + Fused: quantize bf16 key (N,128) -> fp8 + fp32 scale and write into NSATokenToKVPool.index_k_with_scale_buffer. + + key: (num_tokens, 128) bf16 (or reshapeable to it) + index_k_with_scale: (num_pages, 64*(128+4)) uint8 + out_cache_loc: (num_tokens,) int64 token indices in TokenToKVPool + """ + assert key.is_cuda + assert index_k_with_scale.is_cuda + assert out_cache_loc.is_cuda + + # 1) normalize shapes + if key.dim() != 2: + key = key.view(-1, key.shape[-1]) + assert key.shape[1] == 128, f"expected key last-dim=128, got {key.shape}" + + # 2) dtypes + assert key.dtype == torch.bfloat16, f"{key.dtype=}" + assert index_k_with_scale.dtype == torch.uint8, f"{index_k_with_scale.dtype=}" + assert out_cache_loc.dtype == torch.int64, f"{out_cache_loc.dtype=}" + + # 3) contiguity + if not key.is_contiguous(): + key = key.contiguous() + if not out_cache_loc.is_contiguous(): + out_cache_loc = out_cache_loc.contiguous() + if not index_k_with_scale.is_contiguous(): + index_k_with_scale = index_k_with_scale.contiguous() + + module = _jit_nsa_fused_store_module(key.dtype, out_cache_loc.dtype, page_size) + module.fused_store_index_k_cache(key, index_k_with_scale, out_cache_loc) diff --git a/sglang/python/sglang/jit_kernel/gptq_marlin.py b/sglang/python/sglang/jit_kernel/gptq_marlin.py new file mode 100644 index 0000000000000000000000000000000000000000..8f604730972f995468990d6a316c31a7035dfbf0 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/gptq_marlin.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +import torch + +from sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args + +if TYPE_CHECKING: + from sgl_kernel.scalar_type import ScalarType + from tvm_ffi.module import Module + +# Constants matching device::marlin:: in marlin.cuh +_MAX_THREAD_N = 256 + + +@cache_once +def _jit_gptq_marlin_module(dtype: torch.dtype) -> Module: + args = make_cpp_args(dtype) + return load_jit( + 
"gptq_marlin", + *args, + cuda_files=["gemm/marlin/gptq_marlin.cuh"], + cuda_wrappers=[("gptq_marlin_gemm", f"gptq_marlin_gemm<{args}>")], + ) + + +def _or_empty( + t: Optional[torch.Tensor], device: torch.device, dtype: torch.dtype +) -> torch.Tensor: + return t if t is not None else torch.empty(0, device=device, dtype=dtype) + + +def gptq_marlin_gemm( + a: torch.Tensor, + c: Optional[torch.Tensor], + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], + b_zeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool = True, + use_atomic_add: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False, +) -> torch.Tensor: + device = a.device + + # Allocate output if not provided + if c is None: + c = torch.empty((size_m, size_n), dtype=a.dtype, device=device) + + # Early return for zero-size M + if size_m == 0: + return c + + # Determine activation ordering + has_act_order = ( + g_idx is not None + and perm is not None + and g_idx.numel() > 0 + and perm.numel() > 0 + ) + + # Allocate c_tmp for fp32 reduce + if use_fp32_reduce: + sms = torch.cuda.get_device_properties(device).multi_processor_count + max_m_block = min(((size_m + 15) // 16) * 16, 64) + c_tmp = torch.empty( + sms * max_m_block * _MAX_THREAD_N, + dtype=torch.float32, + device=device, + ) + else: + c_tmp = torch.empty(0, dtype=torch.float32, device=device) + + # Allocate a_tmp for act_order column permutation + if has_act_order: + a_tmp = torch.empty((size_m, size_k), dtype=a.dtype, device=device) + else: + a_tmp = torch.empty(0, dtype=a.dtype, device=device) + + # Convert Optional tensors to empty tensors + global_scale_t = _or_empty(global_scale, device, a.dtype) + b_zeros_t = _or_empty(b_zeros, device, torch.int32) + g_idx_t = _or_empty(g_idx, device, torch.int32) + perm_t = _or_empty(perm, device, 
torch.int32) + + module = _jit_gptq_marlin_module(a.dtype) + module.gptq_marlin_gemm( + a, + b_q_weight, + b_scales, + global_scale_t, + b_zeros_t, + g_idx_t, + perm_t, + c, + c_tmp, + a_tmp, + workspace, + b_q_type.id, + is_k_full, + use_atomic_add, + use_fp32_reduce, + is_zp_float, + ) + + return c diff --git a/sglang/python/sglang/jit_kernel/gptq_marlin_repack.py b/sglang/python/sglang/jit_kernel/gptq_marlin_repack.py new file mode 100644 index 0000000000000000000000000000000000000000..f04a2ce816156232c0d9f754ad721dc87a210bb3 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/gptq_marlin_repack.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch + +from sglang.jit_kernel.utils import cache_once, load_jit + +if TYPE_CHECKING: + from tvm_ffi.module import Module + +# Constants matching device::marlin:: in marlin.cuh +_TILE_SIZE = 16 + + +@cache_once +def _jit_gptq_marlin_repack_module() -> Module: + return load_jit( + "gptq_marlin_repack", + cuda_files=["gemm/marlin/gptq_marlin_repack.cuh"], + cuda_wrappers=[("gptq_marlin_repack", "gptq_marlin_repack")], + ) + + +def gptq_marlin_repack( + b_q_weight: torch.Tensor, + perm: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, +) -> torch.Tensor: + pack_factor = 32 // num_bits + + # Allocate output tensor + out = torch.empty( + (size_k // _TILE_SIZE, size_n * _TILE_SIZE // pack_factor), + dtype=b_q_weight.dtype, + device=b_q_weight.device, + ) + + module = _jit_gptq_marlin_repack_module() + module.gptq_marlin_repack(b_q_weight, perm, out, size_k, size_n, num_bits) + return out diff --git a/sglang/python/sglang/jit_kernel/hadamard.py b/sglang/python/sglang/jit_kernel/hadamard.py new file mode 100644 index 0000000000000000000000000000000000000000..25930ce942d33285d40f95791087351abb16489e --- /dev/null +++ b/sglang/python/sglang/jit_kernel/hadamard.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable + 
+import torch + +from sglang.jit_kernel.utils import KERNEL_PATH, cache_once, load_jit, make_cpp_args + +if TYPE_CHECKING: + from tvm_ffi.module import Module + + +@cache_once +def _jit_hadamard_module(dtype: torch.dtype) -> Module: + args = make_cpp_args(dtype) + hadamard_include_dir = (KERNEL_PATH / "csrc" / "fast-hadamard-transform").resolve() + return load_jit( + "hadamard", + *args, + cuda_files=["fast-hadamard-transform/hadamard_jit.cuh"], + cuda_wrappers=[ + ("hadamard_transform", f"HadamardKernel<{args}>::run"), + ("hadamard_transform_12n", f"Hadamard12NKernel<{args}>::run"), + ("hadamard_transform_20n", f"Hadamard20NKernel<{args}>::run"), + ("hadamard_transform_28n", f"Hadamard28NKernel<{args}>::run"), + ("hadamard_transform_40n", f"Hadamard40NKernel<{args}>::run"), + ], + extra_include_paths=[str(hadamard_include_dir)], + ) + + +def _hadamard_transform_impl( + x: torch.Tensor, + scale: float, + pad_multiple: int, + kernel_fn: Callable, +) -> torch.Tensor: + if not x.is_cuda: + raise RuntimeError(f"{kernel_fn.__name__} only supports CUDA tensors") + + shapes_og = x.size() + dim_og = x.size(-1) + x = x.reshape(-1, dim_og) + if x.stride(-1) != 1: + x = x.contiguous() + + needs_pad = dim_og % pad_multiple != 0 + if needs_pad: + x = torch.nn.functional.pad(x, (0, pad_multiple - dim_og % pad_multiple)) + + out = torch.empty_like(x) + kernel_fn(x, out, scale) + + if needs_pad: + out = out[:, :dim_og] + return out.reshape(shapes_og) + + +def hadamard_transform(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + module = _jit_hadamard_module(x.dtype) + return _hadamard_transform_impl(x, scale, 8, module.hadamard_transform) + + +def hadamard_transform_12n(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + module = _jit_hadamard_module(x.dtype) + return _hadamard_transform_impl(x, scale, 4 * 12, module.hadamard_transform_12n) + + +def hadamard_transform_20n(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + module = _jit_hadamard_module(x.dtype) + 
return _hadamard_transform_impl(x, scale, 4 * 20, module.hadamard_transform_20n) + + +def hadamard_transform_28n(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + module = _jit_hadamard_module(x.dtype) + return _hadamard_transform_impl(x, scale, 4 * 28, module.hadamard_transform_28n) + + +def hadamard_transform_40n(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + module = _jit_hadamard_module(x.dtype) + return _hadamard_transform_impl(x, scale, 4 * 40, module.hadamard_transform_40n) diff --git a/sglang/python/sglang/jit_kernel/hicache.py b/sglang/python/sglang/jit_kernel/hicache.py new file mode 100644 index 0000000000000000000000000000000000000000..aa78d5cb4789e68ea62998159e0e5604238f8d48 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/hicache.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args + +if TYPE_CHECKING: + import torch + from tvm_ffi.module import Module + +DEFAULT_BLOCK_QUOTA = 2 + + +@cache_once +def _jit_hicache_module(*, element_size: int, unroll: int, block_quota: int) -> Module: + args = make_cpp_args( + element_size, + unroll, + block_quota, + 1024, # num_threads, can be tuned for performance + ) + return load_jit( + "hicache", + *args, + cuda_files=["hicache.cuh"], + cuda_wrappers=[ + ("launch_one", f"&HiCacheKernel<{args}>::run_one"), + ("launch_all", f"&HiCacheKernel<{args}>::run_all"), + ], + ) + + +def can_use_hicache_jit_kernel( + *, + element_size: int, + unroll: int | None = None, # can be tuned for performance + block_quota: int | None = None, # can be tuned for less interference +) -> bool: + logger = logging.getLogger(__name__) + if element_size % 128 != 0: + logger.warning(f"Unsupported {element_size = } for JIT HiCache kernel") + return False + try: + unroll = unroll or _default_unroll(element_size) + block_quota = block_quota or DEFAULT_BLOCK_QUOTA + _jit_hicache_module( + 
element_size=element_size, + unroll=unroll, + block_quota=block_quota, + ) + return True + except Exception as e: + logger.warning(f"Failed to load JIT HiCache kernel: {e}") + return False + + +def _default_unroll(element_size: int) -> int: + if element_size <= 512: + return 4 + + if element_size <= 1024: + return 2 + + # fallback: no unroll + return 1 + + +def transfer_hicache_one_layer( + k_cache_dst: torch.Tensor, + v_cache_dst: torch.Tensor, + indices_dst: torch.Tensor, + k_cache_src: torch.Tensor, + v_cache_src: torch.Tensor, + indices_src: torch.Tensor, + *, + element_dim: int | None = None, + unroll: int | None = None, # can be tuned for performance + block_quota: int | None = None, # can be tuned for less interference +) -> None: + element_dim = element_dim or k_cache_dst.size(-1) + k_cache_src = k_cache_src.view(-1, element_dim) + v_cache_src = v_cache_src.view(-1, element_dim) + k_cache_dst = k_cache_dst.view(-1, element_dim) + v_cache_dst = v_cache_dst.view(-1, element_dim) + element_size = element_dim * k_cache_dst.element_size() + block_quota = block_quota or DEFAULT_BLOCK_QUOTA + unroll = unroll or _default_unroll(element_size) + module = _jit_hicache_module( + element_size=element_size, + unroll=unroll, + block_quota=block_quota, + ) + module.launch_one( + k_cache_dst, + v_cache_dst, + indices_dst, + k_cache_src, + v_cache_src, + indices_src, + ) + + +def transfer_hicache_all_layer( + k_ptr_dst: torch.Tensor, + v_ptr_dst: torch.Tensor, + indices_dst: torch.Tensor, + k_ptr_src: torch.Tensor, + v_ptr_src: torch.Tensor, + indices_src: torch.Tensor, + *, + kv_cache_src_stride_bytes: int, + kv_cache_dst_stride_bytes: int, + element_size: int | None = None, + unroll: int | None = None, # can be tuned for performance + block_quota: int | None = None, # can be tuned for less interference +) -> None: + if element_size is None: # assume both contiguous + assert kv_cache_dst_stride_bytes == kv_cache_src_stride_bytes + element_size = kv_cache_dst_stride_bytes + 
+ block_quota = block_quota or DEFAULT_BLOCK_QUOTA + unroll = unroll or _default_unroll(element_size) + module = _jit_hicache_module( + element_size=element_size, + unroll=unroll, + block_quota=block_quota, + ) + module.launch_all( + k_ptr_dst, + v_ptr_dst, + indices_dst, + k_ptr_src, + v_ptr_src, + indices_src, + kv_cache_src_stride_bytes, + kv_cache_dst_stride_bytes, + ) diff --git a/sglang/python/sglang/jit_kernel/include/sgl_kernel/atomic.cuh b/sglang/python/sglang/jit_kernel/include/sgl_kernel/atomic.cuh new file mode 100644 index 0000000000000000000000000000000000000000..574f79b8d82aa34829cacd2db81d643b6316277a --- /dev/null +++ b/sglang/python/sglang/jit_kernel/include/sgl_kernel/atomic.cuh @@ -0,0 +1,23 @@ +#pragma once +#include + +namespace device::atomic { + +SGL_DEVICE float max(float* addr, float value) { +#ifndef USE_ROCM + float old; + old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) + : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); + return old; +#else + int* addr_as_i = (int*)addr; + int old = *addr_as_i, assumed; + do { + assumed = old; + old = atomicCAS(addr_as_i, assumed, __float_as_int(fmaxf(value, __int_as_float(assumed)))); + } while (assumed != old); + return __int_as_float(old); +#endif +} + +} // namespace device::atomic diff --git a/sglang/python/sglang/jit_kernel/include/sgl_kernel/cta.cuh b/sglang/python/sglang/jit_kernel/include/sgl_kernel/cta.cuh new file mode 100644 index 0000000000000000000000000000000000000000..28db34f02595b6da8e9a746b226a9d026f0acc0a --- /dev/null +++ b/sglang/python/sglang/jit_kernel/include/sgl_kernel/cta.cuh @@ -0,0 +1,22 @@ +#pragma once +#include +#include +#include + +namespace device::cta { + +template +SGL_DEVICE void reduce_max(T value, float* smem, float min_value = 0.0f) { + const uint32_t warp_id = threadIdx.x / kWarpThreads; + smem[warp_id] = warp::reduce_max(value); + __syncthreads(); + if (warp_id == 0) { + const auto tx = threadIdx.x; + 
const auto local_value = tx * kWarpThreads < blockDim.x ? smem[tx] : min_value; + const auto max_value = warp::reduce_max(local_value); + smem[0] = max_value; + } + // no extra sync; it is caller's responsibility to sync if needed +} + +} // namespace device::cta diff --git a/sglang/python/sglang/jit_kernel/include/sgl_kernel/impl/norm.cuh b/sglang/python/sglang/jit_kernel/include/sgl_kernel/impl/norm.cuh new file mode 100644 index 0000000000000000000000000000000000000000..cd024acd4604da140260b95fefb35dfa9bedb9ba --- /dev/null +++ b/sglang/python/sglang/jit_kernel/include/sgl_kernel/impl/norm.cuh @@ -0,0 +1,168 @@ +#pragma once +#include +#include +#include +#include +#include + +#include +#include + +namespace host::norm { + +/** + * \brief Check if the given configuration is supported. + * \tparam T Element type (only fp16_t/bf16_t is supported) + * \tparam kDim Dimension size (usually hidden size) + */ +template +inline constexpr bool is_config_supported() { + if (!std::is_same_v && !std::is_same_v) return false; + if (kDim <= 256) { + return (kDim == 64 || kDim == 128 || kDim == 256); + } else { + return (kDim % 256 == 0 && kDim <= 8192); + } +} + +/** + * \brief Determine whether to use cta norm based on dimension size. + * TL;DR: use warp norm for dim <= 256, cta norm otherwise. + * \tparam T Element type (fp16_t or bf16_t) + * \tparam kDim Dimension size (usually hidden size) + * \note This function assumes that the configuration is supported. + * \see `is_config_supported` + */ +template +inline constexpr bool should_use_cta() { + static_assert(is_config_supported(), "Unsupported norm configuration"); + return kDim > 256; +} + +/** + * \brief Get the number of threads per CTA for cta norm. 
+ * \tparam T Element type (fp16_t or bf16_t) + * \tparam kDim Dimension size (usually hidden size) + * \return Number of threads per CTA + */ +template +inline constexpr uint32_t get_cta_threads() { + static_assert(should_use_cta()); + return (kDim / 256) * device::kWarpThreads; +} + +} // namespace host::norm + +namespace device::norm { + +namespace details { + +template +SGL_DEVICE AlignedVector apply_norm_impl( + const AlignedVector input, + const AlignedVector weight, + const float eps, + [[maybe_unused]] float* smem_buffer, + [[maybe_unused]] uint32_t num_warps) { + float sum_of_squares = 0.0f; + +#pragma unroll + for (auto i = 0u; i < N; ++i) { + const auto fp32_input = cast(input[i]); + sum_of_squares += fp32_input.x * fp32_input.x; + sum_of_squares += fp32_input.y * fp32_input.y; + } + + sum_of_squares = warp::reduce_sum(sum_of_squares); + float norm_factor; + if constexpr (kUseCTA) { + // need to synchronize across the cta + const auto warp_id = threadIdx.x / kWarpThreads; + smem_buffer[warp_id] = sum_of_squares; + __syncthreads(); + // use the first warp to reduce + if (warp_id == 0) { + const auto tx = threadIdx.x; + const auto local_sum = tx < num_warps ? smem_buffer[tx] : 0.0f; + sum_of_squares = warp::reduce_sum(local_sum); + smem_buffer[32] = math::rsqrt(sum_of_squares / kDim + eps); + } + __syncthreads(); + norm_factor = smem_buffer[32]; + } else { + norm_factor = math::rsqrt(sum_of_squares / kDim + eps); + } + + AlignedVector output; + +#pragma unroll + for (auto i = 0u; i < N; ++i) { + const auto fp32_input = cast(input[i]); + const auto fp32_weight = cast(weight[i]); + output[i] = cast({ + fp32_input.x * norm_factor * fp32_weight.x, + fp32_input.y * norm_factor * fp32_weight.y, + }); + } + + return output; +} + +} // namespace details + +/** + * \brief Apply norm using warp-level implementation. 
+ * \tparam kDim Dimension size + * \tparam T Element type (fp16_t or bf16_t) + * \param input Input vector + * \param weight Weight vector + * \param eps Epsilon value for numerical stability + * \return Normalized output vector + */ +template +SGL_DEVICE T apply_norm_warp(const T& input, const T& weight, float eps) { + static_assert(kDim <= 256, "Warp norm only supports dim <= 256"); + return details::apply_norm_impl(input, weight, eps, nullptr, 0); +} + +/** + * \brief Apply norm using CTA-level implementation. + * \tparam kDim Dimension size + * \tparam T Element type (fp16_t or bf16_t) + * \param input Input vector + * \param weight Weight vector + * \param eps Epsilon value for numerical stability + * \param smem Shared memory buffer + * \param num_warps Number of warps in the CTA + * \return Normalized output vector + */ +template +SGL_DEVICE T apply_norm_cta( + const T& input, const T& weight, float eps, float* smem, uint32_t num_warps = blockDim.x / kWarpThreads) { + static_assert(kDim > 256, "CTA norm only supports dim > 256"); + return details::apply_norm_impl(input, weight, eps, smem, num_warps); +} + +/** + * \brief Storage type for norm operation. + * For warp norm, the storage size depends on kDim. + * For cta norm, the storage size is fixed to 16B. + * We will also pack the input 16-bit floats into 32-bit types + * for faster CUDA core operations. + * + * \tparam T Element type (fp16_t or bf16_t) + * \tparam kDim Dimension size + */ +template +using StorageType = std::conditional_t< // storage type + (kDim > 256), // whether to use cta norm + AlignedVector, 4>, // cta norm storage, fixed to 16B + AlignedVector, kDim / (2 * kWarpThreads)> // warp norm storage + >; + +/** + * \brief Minimum shared memory buffer size (in float elements, not bytes) required for cta norm.
+ */ +inline constexpr uint32_t kSmemBufferSize = 33; + +} // namespace device::norm diff --git a/sglang/python/sglang/jit_kernel/include/sgl_kernel/math.cuh b/sglang/python/sglang/jit_kernel/include/sgl_kernel/math.cuh new file mode 100644 index 0000000000000000000000000000000000000000..2287b31e24463a8c852523e33ce2e47eb6d369a5 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/include/sgl_kernel/math.cuh @@ -0,0 +1,51 @@ +#pragma once +#include + +namespace device::math { + +inline constexpr float log2e = 1.44269504088896340736f; +inline constexpr float loge2 = 0.693147180559945309417f; +inline constexpr float FP8_E4M3_MAX = 448.0f; +static_assert(log2e * loge2 == 1.0f, "log2e * loge2 must be 1"); + +template +SGL_DEVICE T max(T a, T b) { + return dtype_trait::max(a, b); +} + +template +SGL_DEVICE T min(T a, T b) { + return dtype_trait::min(a, b); +} + +template +SGL_DEVICE T abs(T a) { + return dtype_trait::abs(a); +} + +template +SGL_DEVICE T sqrt(T a) { + return dtype_trait::sqrt(a); +} + +template +SGL_DEVICE T rsqrt(T a) { + return dtype_trait::rsqrt(a); +} + +template +SGL_DEVICE T exp(T a) { + return dtype_trait::exp(a); +} + +template +SGL_DEVICE T sin(T a) { + return dtype_trait::sin(a); +} + +template +SGL_DEVICE T cos(T a) { + return dtype_trait::cos(a); +} + +} // namespace device::math diff --git a/sglang/python/sglang/jit_kernel/include/sgl_kernel/runtime.cuh b/sglang/python/sglang/jit_kernel/include/sgl_kernel/runtime.cuh new file mode 100644 index 0000000000000000000000000000000000000000..bf582999bfdcd324bc95831eba795e5d9424be16 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/include/sgl_kernel/runtime.cuh @@ -0,0 +1,49 @@ +#pragma once + +#include + +#include +#include +#include + +namespace host::runtime { + +// Return the maximum number of active blocks per SM for the given kernel +template +inline auto get_blocks_per_sm(T&& kernel, int32_t block_dim, std::size_t dynamic_smem = 0) -> uint32_t { + int num_blocks_per_sm = 0; + 
RuntimeDeviceCheck( + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel, block_dim, dynamic_smem)); + return static_cast(num_blocks_per_sm); +} + +// Return the number of SMs for the given device +inline auto get_sm_count(int device_id) -> uint32_t { + int sm_count; + RuntimeDeviceCheck(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device_id)); + return static_cast(sm_count); +} + +// Return the Major compute capability for the given device +inline auto get_cc_major(int device_id) -> int { + int cc_major; + RuntimeDeviceCheck(cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, device_id)); + return cc_major; +} + +// Return the runtime version +inline auto get_runtime_version() -> int { + int runtime_version; + RuntimeDeviceCheck(cudaRuntimeGetVersion(&runtime_version)); + return runtime_version; +} + +// Return the maximum dynamic shared memory per block for the given kernel +template +inline auto get_available_dynamic_smem_per_block(T&& kernel, int num_blocks, int block_size) -> std::size_t { + std::size_t smem_size; + RuntimeDeviceCheck(cudaOccupancyAvailableDynamicSMemPerBlock(&smem_size, kernel, num_blocks, block_size)); + return smem_size; +} + +} // namespace host::runtime diff --git a/sglang/python/sglang/jit_kernel/include/sgl_kernel/scalar_type.hpp b/sglang/python/sglang/jit_kernel/include/sgl_kernel/scalar_type.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d229d3a975c30c2f060c09cfe648c8d9c8d9ad3c --- /dev/null +++ b/sglang/python/sglang/jit_kernel/include/sgl_kernel/scalar_type.hpp @@ -0,0 +1,334 @@ +#pragma once + +#include +#include +#ifndef __CUDACC__ +#include +#endif + +namespace host { + +// +// ScalarType can represent a wide range of floating point and integer types, +// in particular it can be used to represent sub-byte data types (something +// that torch.dtype currently does not support). 
+// +// The type definitions on the Python side can be found in: vllm/scalar_type.py +// these type definitions should be kept up to date with any Python API changes +// here. +// +class ScalarType { + public: + enum NanRepr : uint8_t { + NAN_NONE = 0, // nans are not supported + NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s + NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s + + NAN_REPR_ID_MAX + }; + + constexpr ScalarType( + uint8_t exponent, + uint8_t mantissa, + bool signed_, + int32_t bias, + bool finite_values_only = false, + NanRepr nan_repr = NAN_IEEE_754) + : exponent(exponent), + mantissa(mantissa), + signed_(signed_), + bias(bias), + finite_values_only(finite_values_only), + nan_repr(nan_repr) {}; + + static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) { + return ScalarType(0, size_bits - 1, true, bias); + } + + static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) { + return ScalarType(0, size_bits, false, bias); + } + + // IEEE 754 compliant floating point type + static constexpr ScalarType float_IEEE754(uint8_t exponent, uint8_t mantissa) { + assert(mantissa > 0 && exponent > 0); + return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754); + } + + // IEEE 754 non-compliant floating point type + static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa, bool finite_values_only, NanRepr nan_repr) { + assert(nan_repr < NAN_REPR_ID_MAX); + assert(mantissa > 0 && exponent > 0); + assert(nan_repr != NAN_IEEE_754); + return ScalarType(exponent, mantissa, true, 0, finite_values_only, nan_repr); + } + + uint8_t const exponent; // size of the exponent field (0 for integer types) + uint8_t const mantissa; // size of the mantissa field (size of the integer + // excluding the sign bit for integer types) + bool const signed_; // flag if the type supports negative numbers (i.e. 
has a + // sign bit) + int32_t const bias; // stored values equal value + bias, + // used for quantized type + + // Extra Floating point info + bool const finite_values_only; // i.e. no +/-inf if true + NanRepr const nan_repr; // how NaNs are represented + // (not applicable for integer types) + + using Id = int64_t; + + private: + // Field size in id + template + static constexpr size_t member_id_field_width() { + using T = std::decay_t; + return std::is_same_v ? 1 : sizeof(T) * 8; + } + + template + static constexpr auto reduce_members_helper(Fn f, Init val, Member member, Rest... rest) { + auto new_val = f(val, member); + if constexpr (sizeof...(rest) > 0) { + return reduce_members_helper(f, new_val, rest...); + } else { + return new_val; + }; + } + + template + constexpr auto reduce_members(Fn f, Init init) const { + // Should be in constructor order for `from_id` + return reduce_members_helper(f, init, exponent, mantissa, signed_, bias, finite_values_only, nan_repr); + }; + + template + static constexpr auto reduce_member_types(Fn f, Init init) { + constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE); + return dummy_type.reduce_members(f, init); + }; + + static constexpr auto id_size_bits() { + return reduce_member_types( + [](int acc, auto member) -> int { return acc + member_id_field_width(); }, 0); + } + + public: + // unique id for this scalar type that can be computed at compile time for + // c++17 template specialization this is not needed once we migrate to + // c++20 and can pass literal classes as template parameters + constexpr Id id() const { + static_assert(id_size_bits() <= sizeof(Id) * 8, "ScalarType id is too large to be stored"); + + auto or_and_advance = [](std::pair result, auto member) -> std::pair { + auto [id, bit_offset] = result; + auto constexpr bits = member_id_field_width(); + return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1)) << bit_offset, bit_offset + bits}; + }; + return reduce_members(or_and_advance, 
std::pair{}).first; + } + + // create a ScalarType from an id, for c++17 template specialization, + // this is not needed once we migrate to c++20 and can pass literal + // classes as template parameters + static constexpr ScalarType from_id(Id id) { + auto extract_and_advance = [id](auto result, auto member) { + using T = decltype(member); + auto [tuple, bit_offset] = result; + auto constexpr bits = member_id_field_width(); + auto extracted_val = static_cast((int64_t(id) >> bit_offset) & ((uint64_t(1) << bits) - 1)); + auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val)); + return std::pair{new_tuple, bit_offset + bits}; + }; + + auto [tuple_args, _] = reduce_member_types(extract_and_advance, std::pair, int>{}); + return std::apply([](auto... args) { return ScalarType(args...); }, tuple_args); + } + + constexpr int64_t size_bits() const { + return mantissa + exponent + is_signed(); + } + constexpr bool is_signed() const { + return signed_; + } + constexpr bool is_integer() const { + return exponent == 0; + } + constexpr bool is_floating_point() const { + return exponent > 0; + } + constexpr bool is_ieee_754() const { + return is_floating_point() && finite_values_only == false && nan_repr == NAN_IEEE_754; + } + constexpr bool has_nans() const { + return is_floating_point() && nan_repr != NAN_NONE; + } + constexpr bool has_infs() const { + return is_floating_point() && finite_values_only == false; + } + constexpr bool has_bias() const { + return bias != 0; + } + +#ifndef __CUDACC__ + private: + double _floating_point_max() const { + assert(mantissa <= 52 && exponent <= 11); + + uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) { + max_mantissa -= 1; + } + + uint64_t max_exponent = (uint64_t(1) << exponent) - 2; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) { + assert(exponent < 11); + max_exponent += 1; + } + + // adjust the exponent to match that of a double + // for now we assume 
the exponent bias is the standard 2^(e-1) -1, (where e + // is the exponent bits), there is some precedent for non-standard biases, + // example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes + // but to avoid premature over complication we are just assuming the + // standard exponent bias until there is a need to support non-standard + // biases + uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1; + uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11 + + uint64_t max_exponent_double = max_exponent - exponent_bias + exponent_bias_double; + + // shift the mantissa into the position for a double and + // the exponent + uint64_t double_raw = (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52); + + return *reinterpret_cast(&double_raw); + } + + constexpr std::variant _raw_max() const { + if (is_floating_point()) { + return {_floating_point_max()}; + } else { + assert(size_bits() < 64 || (size_bits() == 64 && is_signed())); + return {(int64_t(1) << mantissa) - 1}; + } + } + + constexpr std::variant _raw_min() const { + if (is_floating_point()) { + assert(is_signed()); + constexpr uint64_t sign_bit_double = (uint64_t(1) << 63); + + double max = _floating_point_max(); + uint64_t max_raw = *reinterpret_cast(&max); + uint64_t min_raw = max_raw | sign_bit_double; + return {*reinterpret_cast(&min_raw)}; + } else { + assert(!is_signed() || size_bits() <= 64); + if (is_signed()) { + // set the top bit to 1 (i.e. INT64_MIN) and the rest to 0 + // then perform an arithmetic shift right to set all the bits above + // (size_bits() - 1) to 1 + return {INT64_MIN >> (64 - size_bits())}; + } else { + return {int64_t(0)}; + } + } + } + + public: + // Max representable value for this scalar type. + // (accounting for bias if there is one) + constexpr std::variant max() const { + return std::visit([this](auto x) -> std::variant { return {x - bias}; }, _raw_max()); + } + + // Min representable value for this scalar type. 
+ // (accounting for bias if there is one) + constexpr std::variant min() const { + return std::visit([this](auto x) -> std::variant { return {x - bias}; }, _raw_min()); + } +#endif // __CUDACC__ + + public: + std::string str() const { + /* naming generally follows: https://github.com/jax-ml/ml_dtypes + * for floating point types (leading f) the scheme is: + * `float_em[flags]` + * flags: + * - no-flags: means it follows IEEE 754 conventions + * - f: means finite values only (no infinities) + * - n: means nans are supported (non-standard encoding) + * for integer types the scheme is: + * `[u]int[b]` + * - if bias is not present it means its zero + */ + if (is_floating_point()) { + auto ret = + "float" + std::to_string(size_bits()) + "_e" + std::to_string(exponent) + "m" + std::to_string(mantissa); + if (!is_ieee_754()) { + if (finite_values_only) { + ret += "f"; + } + if (nan_repr != NAN_NONE) { + ret += "n"; + } + } + return ret; + } else { + auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits()); + if (has_bias()) { + ret += "b" + std::to_string(bias); + } + return ret; + } + } + + constexpr bool operator==(ScalarType const& other) const { + return mantissa == other.mantissa && exponent == other.exponent && bias == other.bias && signed_ == other.signed_ && + finite_values_only == other.finite_values_only && nan_repr == other.nan_repr; + } +}; + +using ScalarTypeId = ScalarType::Id; + +// "rust style" names generally following: +// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70 +static inline constexpr auto kS4 = ScalarType::int_(4); +static inline constexpr auto kU4 = ScalarType::uint(4); +static inline constexpr auto kU4B8 = ScalarType::uint(4, 8); +static inline constexpr auto kS8 = ScalarType::int_(8); +static inline constexpr auto kU8 = ScalarType::uint(8); +static inline constexpr auto kU8B128 = ScalarType::uint(8, 128); + +static inline constexpr auto kFE2M1f = 
ScalarType::float_(2, 1, true, ScalarType::NAN_NONE); +static inline constexpr auto kFE3M2f = ScalarType::float_(3, 2, true, ScalarType::NAN_NONE); +static inline constexpr auto kFE4M3fn = ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); +static inline constexpr auto kFE8M0fnu = ScalarType(8, 0, false, 0, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); +static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2); +static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7); +static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10); + +// Fixed width style names, generally following: +// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57 +static inline constexpr auto kInt4 = kS4; +static inline constexpr auto kUint4 = kU4; +static inline constexpr auto kUint4b8 = kU4B8; +static inline constexpr auto kInt8 = kS8; +static inline constexpr auto kUint8 = kU8; +static inline constexpr auto kUint8b128 = kU8B128; + +static inline constexpr auto kFloat4_e2m1f = kFE2M1f; +static inline constexpr auto kFloat6_e3m2f = kFE3M2f; +static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn; +static inline constexpr auto kFloat8_e5m2 = kFE5M2; +static inline constexpr auto kFloat16_e8m7 = kFE8M7; +static inline constexpr auto kFloat16_e5m10 = kFE5M10; + +// colloquial names +static inline constexpr auto kHalf = kFE5M10; +static inline constexpr auto kFloat16 = kHalf; +static inline constexpr auto kBFloat16 = kFE8M7; + +static inline constexpr auto kFloat16Id = kFloat16.id(); +} // namespace host diff --git a/sglang/python/sglang/jit_kernel/include/sgl_kernel/source_location.h b/sglang/python/sglang/jit_kernel/include/sgl_kernel/source_location.h new file mode 100644 index 0000000000000000000000000000000000000000..9616fa7daccdfc31c35d88d5e5fb7c4aa5c160a8 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/include/sgl_kernel/source_location.h @@ -0,0 +1,34 @@ 
+#pragma once +#include + +/// NOTE: fallback to a minimal source_location implementation +#if defined(__cpp_lib_source_location) +#include + +using source_location_t = std::source_location; + +#else + +struct source_location_fallback { + public: + static constexpr source_location_fallback current() noexcept { + return source_location_fallback{}; + } + constexpr source_location_fallback() noexcept = default; + constexpr unsigned line() const noexcept { + return 0; + } + constexpr unsigned column() const noexcept { + return 0; + } + constexpr const char* file_name() const noexcept { + return ""; + } + constexpr const char* function_name() const noexcept { + return ""; + } +}; + +using source_location_t = source_location_fallback; + +#endif diff --git a/sglang/python/sglang/jit_kernel/include/sgl_kernel/tensor.h b/sglang/python/sglang/jit_kernel/include/sgl_kernel/tensor.h new file mode 100644 index 0000000000000000000000000000000000000000..de7ed8a0c671bf760f9980289a5fbc0fb1bc0921 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/include/sgl_kernel/tensor.h @@ -0,0 +1,539 @@ +#pragma once +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef __CUDACC__ +#include +#endif + +namespace host { + +namespace details { + +inline constexpr auto kAnyDeviceID = -1; +inline constexpr auto kAnySize = static_cast(-1); +inline constexpr auto kNullSize = static_cast(-1); +inline constexpr auto kNullDType = static_cast(18u); +inline constexpr auto kNullDevice = static_cast(-1); + +struct SizeRef; +struct DTypeRef; +struct DeviceRef; + +template +struct _dtype_trait {}; + +template +struct _dtype_trait { + inline static constexpr DLDataType value = { + .code = std::is_signed_v ? 
DLDataTypeCode::kDLInt : DLDataTypeCode::kDLUInt, + .bits = static_cast(sizeof(T) * 8), + .lanes = 1}; +}; + +template +struct _dtype_trait { + inline static constexpr DLDataType value = { + .code = DLDataTypeCode::kDLFloat, .bits = static_cast(sizeof(T) * 8), .lanes = 1}; +}; + +#ifdef __CUDACC__ +template <> +struct _dtype_trait { + inline static constexpr DLDataType value = {.code = DLDataTypeCode::kDLFloat, .bits = 16, .lanes = 1}; +}; +template <> +struct _dtype_trait { + inline static constexpr DLDataType value = {.code = DLDataTypeCode::kDLBfloat, .bits = 16, .lanes = 1}; +}; +template <> +struct _dtype_trait { + inline static constexpr DLDataType value = {.code = DLDataTypeCode::kDLFloat8_e4m3fn, .bits = 8, .lanes = 1}; +}; +#endif + +template +struct _device_trait { + inline static constexpr DLDevice value = {.device_type = Code, .device_id = kAnyDeviceID}; +}; + +template +inline constexpr auto kDTypeList = std::array{_dtype_trait::value...}; + +template +inline constexpr auto kDeviceList = std::array{_device_trait::value...}; + +template +struct PrintAbleSpan { + explicit PrintAbleSpan(std::span data) : data(data) {} + std::span data; +}; + +// define DLDataType comparison and printing in root namespace +inline constexpr auto kDeviceStringMap = [] { + constexpr auto map = std::array, 16>{ + std::pair{DLDeviceType::kDLCPU, "cpu"}, + std::pair{DLDeviceType::kDLCUDA, "cuda"}, + std::pair{DLDeviceType::kDLCUDAHost, "cuda_host"}, + std::pair{DLDeviceType::kDLOpenCL, "opencl"}, + std::pair{DLDeviceType::kDLVulkan, "vulkan"}, + std::pair{DLDeviceType::kDLMetal, "metal"}, + std::pair{DLDeviceType::kDLVPI, "vpi"}, + std::pair{DLDeviceType::kDLROCM, "rocm"}, + std::pair{DLDeviceType::kDLROCMHost, "rocm_host"}, + std::pair{DLDeviceType::kDLExtDev, "ext_dev"}, + std::pair{DLDeviceType::kDLCUDAManaged, "cuda_managed"}, + std::pair{DLDeviceType::kDLOneAPI, "oneapi"}, + std::pair{DLDeviceType::kDLWebGPU, "webgpu"}, + std::pair{DLDeviceType::kDLHexagon, "hexagon"}, + 
std::pair{DLDeviceType::kDLMAIA, "maia"}, + std::pair{DLDeviceType::kDLTrn, "trn"}, + }; + constexpr auto max_type = stdr::max(map | stdv::keys); + auto result = std::array{}; + for (const auto& [code, name] : map) { + result[static_cast(code)] = name; + } + return result; +}(); + +struct PrintableDevice { + DLDevice device; +}; + +inline auto& operator<<(std::ostream& os, DLDevice device) { + const auto& mapping = kDeviceStringMap; + const auto entry = static_cast(device.device_type); + RuntimeCheck(entry < mapping.size()); + const auto name = mapping[entry]; + RuntimeCheck(!name.empty(), "Unknown device: ", int(device.device_type)); + os << name; + if (device.device_id != kAnyDeviceID && device.device_type != DLDeviceType::kDLCPU) { + os << ":" << device.device_id; + } + return os; +} + +inline auto& operator<<(std::ostream& os, PrintableDevice pd) { + return os << pd.device; +} + +template +inline auto& operator<<(std::ostream& os, PrintAbleSpan span) { + os << "["; + for (const auto i : irange(span.data.size())) { + if (i > 0) { + os << ", "; + } + os << span.data[i]; + } + os << "]"; + return os; +} + +} // namespace details + +template +inline bool is_type(DLDataType dtype) { + return dtype == details::_dtype_trait::value; +} + +struct SymbolicSize { + public: + SymbolicSize(std::string_view annotation = {}) : m_value(details::kNullSize), m_annotation(annotation) {} + SymbolicSize(const SymbolicSize&) = delete; + SymbolicSize& operator=(const SymbolicSize&) = delete; + + auto get_name() const -> std::string_view { + return m_annotation; + } + + auto set_value(int64_t value) -> void { + RuntimeCheck(!this->has_value(), "Size value already set"); + m_value = value; + } + + auto has_value() const -> bool { + return m_value != details::kNullSize; + } + + auto get_value() const -> std::optional { + return this->has_value() ? 
std::optional{m_value} : std::nullopt; + } + + auto unwrap(DebugInfo info = {}) const -> int64_t { + RuntimeCheck(info, this->has_value(), "Size value is not set"); + return m_value; + } + + auto verify(int64_t value, const char* prefix, int64_t dim) -> void { + if (this->has_value()) { + if (m_value != value) { + [[unlikely]]; + Panic("Size mismatch for ", m_name_str(prefix, dim), ": expected ", m_value, " but got ", value); + } + } else { + this->set_value(value); + } + } + + auto value_or_name(const char* prefix, int64_t dim) const -> std::string { + if (const auto value = this->get_value()) { + return std::to_string(*value); + } else { + return m_name_str(prefix, dim); + } + } + + private: + auto m_name_str(const char* prefix, int64_t dim) const -> std::string { + std::ostringstream os; + os << prefix << '#' << dim; + if (!m_annotation.empty()) os << "('" << m_annotation << "')"; + return std::move(os).str(); + } + + std::int64_t m_value; + std::string_view m_annotation; +}; + +inline auto operator==(DLDevice lhs, DLDevice rhs) -> bool { + return lhs.device_type == rhs.device_type && lhs.device_id == rhs.device_id; +} + +struct SymbolicDType { + public: + SymbolicDType() : m_value({details::kNullDType, 0, 0}) {} + SymbolicDType(const SymbolicDType&) = delete; + SymbolicDType& operator=(const SymbolicDType&) = delete; + + auto set_value(DLDataType value) -> void { + RuntimeCheck(!this->has_value(), "Dtype value already set"); + RuntimeCheck( + m_check(value), "Dtype value [", value, "] not in the allowed options: ", details::PrintAbleSpan{m_options}); + m_value = value; + } + + auto has_value() const -> bool { + return m_value.code != details::kNullDType; + } + + auto get_value() const -> std::optional { + return this->has_value() ? 
std::optional{m_value} : std::nullopt; + } + + auto unwrap(DebugInfo info = {}) const -> DLDataType { + RuntimeCheck(info, this->has_value(), "Dtype value is not set"); + return m_value; + } + + auto set_options(std::span options) -> void { + m_options = options; + } + + template + auto set_options() -> void { + m_options = details::kDTypeList; + } + + auto verify(DLDataType dtype) -> void { + if (this->has_value()) { + RuntimeCheck(m_value == dtype, "DType mismatch: expected ", m_value, " but got ", dtype); + } else { + this->set_value(dtype); + } + } + + template + auto is_type() const -> bool { + return ::host::is_type(m_value); + } + + private: + auto m_check(DLDataType value) const -> bool { + return stdr::empty(m_options) || (stdr::find(m_options, value) != stdr::end(m_options)); + } + + std::span m_options; + DLDataType m_value; +}; + +struct SymbolicDevice { + public: + SymbolicDevice() : m_value({details::kNullDevice, details::kAnyDeviceID}) {} + SymbolicDevice(const SymbolicDevice&) = delete; + SymbolicDevice& operator=(const SymbolicDevice&) = delete; + + auto set_value(DLDevice value) -> void { + RuntimeCheck(!this->has_value(), "Device value already set"); + RuntimeCheck( + m_check(value), + "Device value [", + details::PrintableDevice{value}, + "] not in the allowed options: ", + details::PrintAbleSpan{m_options}); + m_value = value; + } + + auto has_value() const -> bool { + return m_value.device_type != details::kNullDevice; + } + + auto get_value() const -> std::optional { + return this->has_value() ? 
std::optional{m_value} : std::nullopt; + } + + auto unwrap(DebugInfo info = {}) const -> DLDevice { + RuntimeCheck(info, this->has_value(), "Device value is not set"); + return m_value; + } + + auto set_options(std::span options) -> void { + m_options = options; + } + + template + auto set_options() -> void { + m_options = details::kDeviceList; + } + + auto verify(DLDevice device) -> void { + if (this->has_value()) { + RuntimeCheck( + m_value == device, + "Device mismatch: expected ", + details::PrintableDevice{m_value}, + " but got ", + details::PrintableDevice{device}); + } else { + this->set_value(device); + } + } + + private: + auto m_check(DLDevice value) const -> bool { + return stdr::empty(m_options) || (stdr::any_of(m_options, [value](const DLDevice& opt) { + // device type must exactly match + if (opt.device_type != value.device_type) return false; + // device id can be wildcarded + return opt.device_id == details::kAnyDeviceID || opt.device_id == value.device_id; + })); + } + + std::span m_options; + DLDevice m_value; +}; + +namespace details { + +template +struct BaseRef { + public: + BaseRef(const BaseRef&) = delete; + BaseRef& operator=(const BaseRef&) = delete; + + auto operator->() const -> T* { + return m_ref; + } + auto operator*() const -> T& { + return *m_ref; + } + auto rebind(T& other) -> void { + m_ref = &other; + } + + explicit BaseRef() : m_ref(&m_cache), m_cache() {} + BaseRef(T& size) : m_ref(&size), m_cache() {} + + private: + T* m_ref; + T m_cache; +}; + +struct SizeRef : BaseRef { + using BaseRef::BaseRef; + SizeRef(int64_t value) { + if (value != kAnySize) { + (**this).set_value(value); + } else { + // otherwise, we can match any size + } + } +}; + +struct DTypeRef : BaseRef { + using BaseRef::BaseRef; + DTypeRef(DLDataType options) { + (**this).set_value(options); + } + DTypeRef(std::initializer_list options) { + (**this).set_options(options); + } + DTypeRef(std::span options) { + (**this).set_options(options); + } +}; + +struct 
DeviceRef : BaseRef { + using BaseRef::BaseRef; + DeviceRef(DLDevice options) { + (**this).set_value(options); + } + DeviceRef(std::initializer_list options) { + (**this).set_options(options); + } + DeviceRef(std::span options) { + (**this).set_options(options); + } +}; + +} // namespace details + +struct TensorMatcher { + private: + using SizeRef = details::SizeRef; + using DTypeRef = details::DTypeRef; + using DeviceRef = details::DeviceRef; + + public: + TensorMatcher(const TensorMatcher&) = delete; + TensorMatcher& operator=(const TensorMatcher&) = delete; + + explicit TensorMatcher(std::initializer_list shape) : m_shape(shape), m_strides(), m_dtype() {} + + auto with_strides(std::initializer_list strides) && -> TensorMatcher&& { + // no partial update allowed + RuntimeCheck(m_strides.size() == 0, "Strides already specified"); + RuntimeCheck(m_shape.size() == strides.size(), "Strides size must match shape size"); + m_strides = strides; + return std::move(*this); + } + + template + auto with_dtype(DTypeRef&& dtype) && -> TensorMatcher&& { + m_init_dtype(); + m_dtype.rebind(*dtype); + m_dtype->set_options(); + return std::move(*this); + } + + template + auto with_dtype() && -> TensorMatcher&& { + static_assert(sizeof...(Ts) > 0, "At least one dtype option must be specified"); + m_init_dtype(); + m_dtype->set_options(); + return std::move(*this); + } + + template + auto with_device(DeviceRef&& device) && -> TensorMatcher&& { + m_init_device(); + m_device.rebind(*device); + m_device->set_options(); + return std::move(*this); + } + + template + auto with_device() && -> TensorMatcher&& { + static_assert(sizeof...(Codes) > 0, "At least one device option must be specified"); + m_init_device(); + m_device->set_options(); + return std::move(*this); + } + + // once we start verification, we cannot modify anymore + auto verify(tvm::ffi::TensorView view, DebugInfo info = {}) const&& -> const TensorMatcher&& { + try { + m_verify_impl(view); + } catch (PanicError& e) { + auto 
oss = std::ostringstream{}; + oss << "Tensor match failed for "; + s_print_tensor(oss, view); + oss << " at " << info.file_name() << ":" << info.line() << "\n- Root cause: " << e.root_cause(); + throw PanicError(std::move(oss).str()); + } + return std::move(*this); + } + + private: + static auto s_print_tensor(std::ostringstream& oss, tvm::ffi::TensorView view) -> void { + oss << "Tensor<"; + int64_t dim = 0; + for (const auto& size : view.shape()) { + if (dim++ > 0) oss << ", "; + oss << size; + } + oss << ">[strides=<"; + dim = 0; + for (const auto& stride : view.strides()) { + if (dim++ > 0) { + oss << ", "; + } + oss << stride; + } + oss << ">, dtype=" << view.dtype(); + oss << ", device=" << details::PrintableDevice{view.device()} << "]"; + } + + auto m_verify_impl(tvm::ffi::TensorView view) const -> void { + const auto dim = static_cast(view.dim()); + RuntimeCheck(dim == m_shape.size(), "Tensor dimension mismatch: expected ", m_shape.size(), " but got ", dim); + for (const auto i : irange(dim)) { + m_shape[i]->verify(view.size(i), "shape", i); + } + if (m_has_strides()) { + for (const auto i : irange(dim)) { + if (view.size(i) != 1 || !m_strides[i]->has_value()) { + // skip stride check for size 1 dimension + m_strides[i]->verify(view.stride(i), "stride", i); + } + } + } else { + RuntimeCheck(view.is_contiguous(), "Tensor is not contiguous as expected"); + } + // since we may double verify, we will force to check + m_dtype->verify(view.dtype()); + m_device->verify(view.device()); + } + + auto m_init_dtype() -> void { + RuntimeCheck(!m_has_dtype, "DType already specified"); + m_has_dtype = true; + } + + auto m_init_device() -> void { + RuntimeCheck(!m_has_device, "Device already specified"); + m_has_device = true; + } + + auto m_has_strides() const -> bool { + return !m_strides.empty(); + } + + std::span m_shape; + std::span m_strides; + DTypeRef m_dtype; + DeviceRef m_device; + bool m_has_dtype = false; + bool m_has_device = false; +}; + +} // namespace host 
diff --git a/sglang/python/sglang/jit_kernel/include/sgl_kernel/tile.cuh b/sglang/python/sglang/jit_kernel/include/sgl_kernel/tile.cuh new file mode 100644 index 0000000000000000000000000000000000000000..1328c64f91ffdfe520c4f40bfc5ef93a23ff0c58 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/include/sgl_kernel/tile.cuh @@ -0,0 +1,36 @@ +#pragma once +#include + +#include + +namespace device::tile { + +template +struct Memory { + public: + SGL_DEVICE constexpr Memory(uint32_t tid, uint32_t tsize) : tid(tid), tsize(tsize) {} + SGL_DEVICE static constexpr Memory thread() { + return Memory{0, 1}; + } + SGL_DEVICE static Memory warp(int warp_threads = kWarpThreads) { + return Memory{static_cast(threadIdx.x % warp_threads), static_cast(warp_threads)}; + } + SGL_DEVICE static Memory cta(int cta_threads = blockDim.x) { + return Memory{static_cast(threadIdx.x), static_cast(cta_threads)}; + } + SGL_DEVICE T load(const void* ptr, int64_t offset = 0) const { + return static_cast(ptr)[tid + offset * tsize]; + } + SGL_DEVICE void store(void* ptr, T val, int64_t offset = 0) const { + static_cast(ptr)[tid + offset * tsize] = val; + } + SGL_DEVICE bool in_bound(int64_t element_count, int64_t offset = 0) const { + return tid + offset * tsize < element_count; + } + + private: + uint32_t tid; + uint32_t tsize; +}; + +} // namespace device::tile diff --git a/sglang/python/sglang/jit_kernel/include/sgl_kernel/type.cuh b/sglang/python/sglang/jit_kernel/include/sgl_kernel/type.cuh new file mode 100644 index 0000000000000000000000000000000000000000..f06bc1407ea7de509f8149277281f6151fa08e34 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/include/sgl_kernel/type.cuh @@ -0,0 +1,78 @@ +#pragma once +#include + +template +struct dtype_trait {}; + +#define SGL_REGISTER_DTYPE_TRAIT(TYPE, PACK2, ...) 
\ + template <> \ + struct dtype_trait { \ + using self_t = TYPE; \ + using packed_t = PACK2; \ + template \ + SGL_DEVICE static self_t from(const S& value) { \ + return static_cast(value); \ + } \ + __VA_ARGS__ \ + } + +#define SGL_REGISTER_TYPE_END static_assert(true) + +#define SGL_REGISTER_FROM_FUNCTION(FROM, FN) \ + SGL_DEVICE static self_t from(const FROM& x) { \ + return FN(x); \ + } \ + static_assert(true) + +#define SGL_REGISTER_UNARY_FUNCTION(NAME, FN) \ + SGL_DEVICE static self_t NAME(const self_t& x) { \ + return FN(x); \ + } \ + static_assert(true) + +#define SGL_REGISTER_BINARY_FUNCTION(NAME, FN) \ + SGL_DEVICE static self_t NAME(const self_t& x, const self_t& y) { \ + return FN(x, y); \ + } \ + static_assert(true) + +SGL_REGISTER_DTYPE_TRAIT( + fp32_t, fp32x2_t, SGL_REGISTER_TYPE_END; // + SGL_REGISTER_FROM_FUNCTION(fp16_t, __half2float); + SGL_REGISTER_FROM_FUNCTION(bf16_t, __bfloat162float); + SGL_REGISTER_UNARY_FUNCTION(abs, fabsf); + SGL_REGISTER_UNARY_FUNCTION(sqrt, sqrtf); + SGL_REGISTER_UNARY_FUNCTION(rsqrt, rsqrtf); + SGL_REGISTER_UNARY_FUNCTION(exp, expf); + SGL_REGISTER_UNARY_FUNCTION(sin, sinf); + SGL_REGISTER_UNARY_FUNCTION(cos, cosf); + SGL_REGISTER_BINARY_FUNCTION(max, fmaxf); + SGL_REGISTER_BINARY_FUNCTION(min, fminf);); +SGL_REGISTER_DTYPE_TRAIT(fp16_t, fp16x2_t); +SGL_REGISTER_DTYPE_TRAIT(bf16_t, bf16x2_t); + +/// TODO: Add ROCM implementation +SGL_REGISTER_DTYPE_TRAIT( + fp32x2_t, fp32x4_t, SGL_REGISTER_TYPE_END; SGL_REGISTER_FROM_FUNCTION(fp16x2_t, __half22float2); + SGL_REGISTER_FROM_FUNCTION(bf16x2_t, __bfloat1622float2);); + +SGL_REGISTER_DTYPE_TRAIT( + fp16x2_t, void, SGL_REGISTER_TYPE_END; SGL_REGISTER_FROM_FUNCTION(fp32x2_t, __float22half2_rn);); + +SGL_REGISTER_DTYPE_TRAIT( + bf16x2_t, void, SGL_REGISTER_TYPE_END; SGL_REGISTER_FROM_FUNCTION(fp32x2_t, __float22bfloat162_rn);); + +#undef SGL_REGISTER_DTYPE_TRAIT +#undef SGL_REGISTER_FROM_FUNCTION + +template +using packed_t = typename dtype_trait::packed_t; + +namespace device 
{ + +template +SGL_DEVICE To cast(const From& value) { + return dtype_trait::from(value); +} + +} // namespace device diff --git a/sglang/python/sglang/jit_kernel/include/sgl_kernel/utils.cuh b/sglang/python/sglang/jit_kernel/include/sgl_kernel/utils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..eae04e68d732bd2bb4e9b529a2adc76eba5b4988 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/include/sgl_kernel/utils.cuh @@ -0,0 +1,231 @@ +#pragma once + +#include + +#include +#include + +#include +#include +#include +#ifndef USE_ROCM +#include +#include +#include +#include +#else +#include +#include +#include +#ifndef __grid_constant__ +#define __grid_constant__ +#endif +using cudaError_t = hipError_t; +using cudaStream_t = hipStream_t; +using cudaLaunchConfig_t = hipLaunchConfig_t; +using cudaLaunchAttribute = hipLaunchAttribute; +inline constexpr auto cudaSuccess = hipSuccess; +#define cudaStreamPerThread hipStreamPerThread +#define cudaGetErrorString hipGetErrorString +#define cudaGetLastError hipGetLastError +#define cudaLaunchKernel hipLaunchKernel +#endif + +#ifndef USE_ROCM +using fp32_t = float; +using fp16_t = __half; +using bf16_t = __nv_bfloat16; +using fp8_e4m3_t = __nv_fp8_e4m3; +using fp8_e5m2_t = __nv_fp8_e5m2; + +using fp32x2_t = float2; +using fp16x2_t = __half2; +using bf16x2_t = __nv_bfloat162; +using fp8x2_e4m3_t = __nv_fp8x2_e4m3; +using fp8x2_e5m2_t = __nv_fp8x2_e5m2; + +using fp32x4_t = float4; +#else +using fp32_t = float; +using fp16_t = __half; +using bf16_t = __hip_bfloat16; +using fp8_e4m3_t = uint8_t; +using fp8_e5m2_t = uint8_t; +using fp32x2_t = float2; +using fp16x2_t = half2; +using bf16x2_t = __hip_bfloat162; +using fp8x2_e4m3_t = uint16_t; +using fp8x2_e5m2_t = uint16_t; +using fp32x4_t = float4; +#endif + +/* + * LDG Support + */ +#ifndef USE_ROCM +#define SGLANG_LDG(arg) __ldg(arg) +#else +#define SGLANG_LDG(arg) *(arg) +#endif + +namespace device { + +#define SGL_DEVICE __forceinline__ __device__ + +inline 
constexpr auto kWarpThreads = 32u; +inline constexpr auto kFullMask = 0xffffffffu; + +template +SGL_DEVICE void PDLWaitPrimary() { +#if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + if constexpr (kUsePDL) { + asm volatile("griddepcontrol.wait;" ::: "memory"); + } +#endif +} + +template +SGL_DEVICE void PDLTriggerSecondary() { +#if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + if constexpr (kUsePDL) { + asm volatile("griddepcontrol.launch_dependents;" :::); + } +#endif +} + +/** + * \brief Load data with the specified type and offset from a void pointer. + * \tparam T The type to load. + * \param ptr The base pointer. + * \param offset The offset in number of elements of type T. + */ +template +SGL_DEVICE T load_as(const void* ptr, int64_t offset = 0) { + return static_cast(ptr)[offset]; +} + +/** + * \brief Store data with the specified type and offset to a void pointer. + * \tparam T The type to store. + * \param ptr The base pointer. + * \param val The value to store. + * \param offset The offset in number of elements of type T. + * \note we use type_identity_t to force the caller to explicitly specify + * the template parameter `T`, which can avoid accidentally using the wrong type. + */ +template +SGL_DEVICE void store_as(void* ptr, std::type_identity_t val, int64_t offset = 0) { + static_cast(ptr)[offset] = val; +} + +namespace pointer { + +// we only allow void * pointer arithmetic for safety + +template +SGL_DEVICE auto offset(void* ptr, U... offset) -> void* { + return static_cast(ptr) + (... + offset); +} + +template +SGL_DEVICE auto offset(const void* ptr, U... offset) -> const void* { + return static_cast(ptr) + (... 
+ offset); +} + +} // namespace pointer + +} // namespace device + +namespace host { + +inline void RuntimeDeviceCheck(::cudaError_t error, DebugInfo location = {}) { + if (error != ::cudaSuccess) { + [[unlikely]]; + ::host::panic(location, "CUDA error: ", ::cudaGetErrorString(error)); + } +} + +inline void RuntimeDeviceCheck(DebugInfo location = {}) { + return RuntimeDeviceCheck(::cudaGetLastError(), location); +} + +struct LaunchKernel { + public: + explicit LaunchKernel( + dim3 grid_dim, + dim3 block_dim, + DLDevice device, + std::size_t dynamic_shared_mem_bytes = 0, + DebugInfo location = {}) noexcept + : m_config(s_make_config(grid_dim, block_dim, resolve_device(device), dynamic_shared_mem_bytes)), + m_location(location) {} + + explicit LaunchKernel( + dim3 grid_dim, + dim3 block_dim, + cudaStream_t stream, + std::size_t dynamic_shared_mem_bytes = 0, + DebugInfo location = {}) noexcept + : m_config(s_make_config(grid_dim, block_dim, stream, dynamic_shared_mem_bytes)), m_location(location) {} + + LaunchKernel(const LaunchKernel&) = delete; + LaunchKernel& operator=(const LaunchKernel&) = delete; + + static auto resolve_device(DLDevice device) -> cudaStream_t { + return static_cast(::TVMFFIEnvGetStream(device.device_type, device.device_id)); + } + + auto enable_pdl(bool enabled = true) -> LaunchKernel& { +#ifdef USE_ROCM + (void)enabled; + m_config.numAttrs = 0; +#else + if (enabled) { + m_attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + m_attrs[0].val.programmaticStreamSerializationAllowed = true; + m_config.numAttrs = 1; + m_config.attrs = m_attrs; + } else { + m_config.numAttrs = 0; + } +#endif + return *this; + } + + template + auto operator()(T&& kernel, Args&&... 
args) const -> void { +#ifdef USE_ROCM + hipLaunchKernelGGL( + std::forward(kernel), + m_config.gridDim, + m_config.blockDim, + m_config.dynamicSmemBytes, + m_config.stream, + std::forward(args)...); + RuntimeDeviceCheck(m_location); +#else + RuntimeDeviceCheck(::cudaLaunchKernelEx(&m_config, kernel, std::forward(args)...), m_location); +#endif + } + + private: + static auto s_make_config( // Make a config for kernel launch + dim3 grid_dim, + dim3 block_dim, + cudaStream_t stream, + std::size_t smem) -> cudaLaunchConfig_t { + auto config = ::cudaLaunchConfig_t{}; + config.gridDim = grid_dim; + config.blockDim = block_dim; + config.dynamicSmemBytes = smem; + config.stream = stream; + config.numAttrs = 0; + return config; + } + + cudaLaunchConfig_t m_config; + const DebugInfo m_location; + cudaLaunchAttribute m_attrs[1]; +}; + +} // namespace host diff --git a/sglang/python/sglang/jit_kernel/include/sgl_kernel/utils.h b/sglang/python/sglang/jit_kernel/include/sgl_kernel/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..78eae19fc83e76d91495a1abbe5c1c2afa90b796 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/include/sgl_kernel/utils.h @@ -0,0 +1,158 @@ +#pragma once + +// ref: https://forums.developer.nvidia.com/t/c-20s-source-location-compilation-error-when-using-nvcc-12-1/258026/3 +#ifdef __CUDACC__ +#include +#if CUDA_VERSION <= 12010 + +#pragma push_macro("__cpp_consteval") +#pragma push_macro("_NODISCARD") +#pragma push_macro("__builtin_LINE") + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wbuiltin-macro-redefined" +#define __cpp_consteval 201811L +#pragma clang diagnostic pop + +#ifdef _NODISCARD +#undef _NODISCARD +#define _NODISCARD +#endif + +#define consteval constexpr + +#include "source_location.h" + +#undef consteval +#pragma pop_macro("__cpp_consteval") +#pragma pop_macro("_NODISCARD") +#else // __CUDACC__ && CUDA_VERSION > 12010 +#include "source_location.h" +#endif +#else // no __CUDACC__ +#include 
"source_location.h" +#endif + +#include + +#include +#include +#include +#include +#include +#include + +namespace host { + +template +inline constexpr bool dependent_false_v = false; + +struct DebugInfo : public source_location_t { + DebugInfo(source_location_t loc = source_location_t::current()) : source_location_t(loc) {} +}; + +struct PanicError : public std::runtime_error { + public: + explicit PanicError(std::string msg) : runtime_error(msg), m_message(std::move(msg)) {} + auto root_cause() const -> std::string_view { + const auto str = std::string_view{m_message}; + const auto pos = str.find(": "); + return pos == std::string_view::npos ? str : str.substr(pos + 2); + } + + private: + std::string m_message; +}; + +template +[[noreturn]] +inline auto panic(DebugInfo location, Args&&... args) -> void { + std::ostringstream os; + os << "Runtime check failed at " << location.file_name() << ":" << location.line(); + if constexpr (sizeof...(args) > 0) { + os << ": "; + (os << ... << std::forward(args)); + } else { + os << " in " << location.function_name(); + } + throw PanicError(std::move(os).str()); +} + +template +struct RuntimeCheck { + template + explicit RuntimeCheck(Cond&& condition, Args&&... args, DebugInfo location = {}) { + if (condition) return; + [[unlikely]] ::host::panic(location, std::forward(args)...); + } + template + explicit RuntimeCheck(DebugInfo location, Cond&& condition, Args&&... args) { + if (condition) return; + [[unlikely]] ::host::panic(location, std::forward(args)...); + } +}; + +template +struct Panic { + explicit Panic(Args&&... args, DebugInfo location = {}) { + ::host::panic(location, std::forward(args)...); + } + explicit Panic(DebugInfo location, Args&&... args) { + ::host::panic(location, std::forward(args)...); + } + [[noreturn]] ~Panic() { + std::terminate(); + } +}; + +template +explicit RuntimeCheck(Cond&&, Args&&...) -> RuntimeCheck; + +template +explicit RuntimeCheck(DebugInfo, Cond&&, Args&&...) 
-> RuntimeCheck; + +template +explicit Panic(Args&&...) -> Panic; + +template +explicit Panic(DebugInfo, Args&&...) -> Panic; + +namespace pointer { + +// we only allow void * pointer arithmetic for safety + +template +inline auto offset(void* ptr, U... offset) -> void* { + return static_cast(ptr) + (... + offset); +} + +template +inline auto offset(const void* ptr, U... offset) -> const void* { + return static_cast(ptr) + (... + offset); +} + +} // namespace pointer + +template +inline constexpr auto div_ceil(T a, U b) { + return (a + b - 1) / b; +} + +inline auto dtype_bytes(DLDataType dtype) -> std::size_t { + return static_cast(dtype.bits / 8); +} + +namespace stdr = std::ranges; +namespace stdv = stdr::views; + +template +inline auto irange(T end) { + return stdv::iota(static_cast(0), end); +} + +template +inline auto irange(T start, T end) { + return stdv::iota(start, end); +} + +} // namespace host diff --git a/sglang/python/sglang/jit_kernel/include/sgl_kernel/vec.cuh b/sglang/python/sglang/jit_kernel/include/sgl_kernel/vec.cuh new file mode 100644 index 0000000000000000000000000000000000000000..c1150d4282a9fa3cdeeb929a8697f9cc575325d2 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/include/sgl_kernel/vec.cuh @@ -0,0 +1,88 @@ +#pragma once +#include + +#include +#include + +namespace device { + +namespace details { + +template +struct uint_trait {}; + +template <> +struct uint_trait<1> { + using type = uint8_t; +}; + +template <> +struct uint_trait<2> { + using type = uint16_t; +}; + +template <> +struct uint_trait<4> { + using type = uint32_t; +}; + +template <> +struct uint_trait<8> { + using type = uint64_t; +}; + +template +using sized_int = typename uint_trait::type; + +} // namespace details + +template +struct alignas(sizeof(T) * N) AlignedStorage { + T data[N]; +}; + +template +struct AlignedVector { + private: + /// NOTE: 1. must be pow of two 2. 
16 * 8 = 128 byte, which is the max vector size supported by most devices + static_assert((N > 0 && (N & (N - 1)) == 0) && sizeof(T) * N <= 32, "CUDA only support at most 256B vector op"); + using element_t = typename details::sized_int; + using storage_t = AlignedStorage; + + public: + template + SGL_DEVICE void load(const U* ptr, std::size_t offset = 0) { + static_assert(std::is_same_v || std::is_same_v); + m_storage = reinterpret_cast(ptr)[offset]; + } + template + SGL_DEVICE void store(U* ptr, std::size_t offset = 0) const { + static_assert(std::is_same_v || std::is_same_v); + reinterpret_cast(ptr)[offset] = m_storage; + } + SGL_DEVICE void fill(T value) { + const auto store_value = *reinterpret_cast(&value); +#pragma unroll + for (std::size_t i = 0; i < N; ++i) { + m_storage.data[i] = store_value; + } + } + + SGL_DEVICE auto operator[](std::size_t idx) -> T& { + return reinterpret_cast(&m_storage)[idx]; + } + SGL_DEVICE auto operator[](std::size_t idx) const -> T { + return reinterpret_cast(&m_storage)[idx]; + } + SGL_DEVICE auto data() -> T* { + return reinterpret_cast(&m_storage); + } + SGL_DEVICE auto data() const -> const T* { + return reinterpret_cast(&m_storage); + } + + private: + storage_t m_storage; +}; + +} // namespace device diff --git a/sglang/python/sglang/jit_kernel/include/sgl_kernel/warp.cuh b/sglang/python/sglang/jit_kernel/include/sgl_kernel/warp.cuh new file mode 100644 index 0000000000000000000000000000000000000000..d69526e97f293680ad371f1cdc6fd265044ecce1 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/include/sgl_kernel/warp.cuh @@ -0,0 +1,25 @@ +#pragma once +#include + +// Some warp primitives +namespace device::warp { + +static constexpr uint32_t kFullMask = 0xffffffffu; + +template +SGL_DEVICE T reduce_sum(T value, uint32_t active_mask = kFullMask) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + value = value + __shfl_xor_sync(active_mask, value, mask, 32); + return value; +} + +template +SGL_DEVICE T reduce_max(T 
value, uint32_t active_mask = kFullMask) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + value = math::max(value, __shfl_xor_sync(active_mask, value, mask, 32)); + return value; +} + +} // namespace device::warp diff --git a/sglang/python/sglang/jit_kernel/kvcache.py b/sglang/python/sglang/jit_kernel/kvcache.py new file mode 100644 index 0000000000000000000000000000000000000000..46a14612b6ff12740ba2fd4ade3b769da75a8935 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/kvcache.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import torch + +from sglang.jit_kernel.utils import ( + cache_once, + is_arch_support_pdl, + load_jit, + make_cpp_args, +) + +if TYPE_CHECKING: + from tvm_ffi.module import Module + + +@cache_once +def _jit_kvcache_module(row_bytes: int) -> Module: + args = make_cpp_args(row_bytes, is_arch_support_pdl()) + return load_jit( + "kvcache", + *args, + cuda_files=["elementwise/kvcache.cuh"], + cuda_wrappers=[("store_cache", f"StoreKVCacheKernel<{args}>::run")], + ) + + +@cache_once +def can_use_store_cache(size: int) -> bool: + logger = logging.getLogger(__name__) + if size % 4 != 0: + logger.warning( + f"Unsupported row_bytes={size} for JIT KV-Cache kernel:" + " must be multiple of 4" + ) + return False + try: + _jit_kvcache_module(size) + return True + except Exception as e: + logger.warning( + f"Failed to load JIT KV-Cache kernel " f"with row_bytes={size}: {e}" + ) + return False + + +def store_cache( + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + indices: torch.Tensor, + *, + row_bytes: int = 0, + num_split: int = 0, # can be tuned for performance +) -> None: + """Store key and value tensors into KV cache at specified indices. + + Args: + k (torch.Tensor): Key tensor of shape (batch_size, H * D). + v (torch.Tensor): Value tensor of shape (batch_size, H * D). 
+ k_cache (torch.Tensor): Key cache tensor of shape (num_pages, H * D). + v_cache (torch.Tensor): Value cache tensor of shape (num_pages, H * D). + indices (torch.Tensor): Indices tensor of shape (batch_size,). + """ + row_bytes = row_bytes or k.shape[-1] * k.element_size() + module = _jit_kvcache_module(row_bytes) + if num_split <= 0: + if row_bytes % 2048 == 0: + num_split = 4 + elif row_bytes % 1024 == 0: + num_split = 2 + else: + num_split = 1 + module.store_cache( + k, + v, + k_cache, + v_cache, + indices, + num_split, + ) diff --git a/sglang/python/sglang/jit_kernel/moe_wna16_marlin.py b/sglang/python/sglang/jit_kernel/moe_wna16_marlin.py new file mode 100644 index 0000000000000000000000000000000000000000..7e17e5ccd74ef036fa8da39b02ae1cf115a1371f --- /dev/null +++ b/sglang/python/sglang/jit_kernel/moe_wna16_marlin.py @@ -0,0 +1,172 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +import torch + +from sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args + +if TYPE_CHECKING: + from sgl_kernel.scalar_type import ScalarType + from tvm_ffi.module import Module + +# Constants matching device::marlin_moe:: in marlin.cuh +_MAX_THREAD_N = 256 + + +@cache_once +def _jit_moe_wna16_marlin_module(dtype: torch.dtype) -> Module: + args = make_cpp_args(dtype) + return load_jit( + "moe_wna16_marlin", + *args, + cuda_files=["gemm/marlin_moe/moe_wna16_marlin.cuh"], + cuda_wrappers=[ + ( + "moe_wna16_marlin_gemm", + f"moe_wna16_marlin_gemm<{args}>", + ) + ], + ) + + +def _or_empty( + t: Optional[torch.Tensor], device: torch.device, dtype: torch.dtype +) -> torch.Tensor: + return t if t is not None else torch.empty(0, device=device, dtype=dtype) + + +def moe_wna16_marlin_gemm( + a: torch.Tensor, + c_or_none: Optional[torch.Tensor], + b_q_weight: torch.Tensor, + b_bias_or_none: Optional[torch.Tensor], + b_scales: torch.Tensor, + global_scale_or_none: Optional[torch.Tensor], + b_zeros_or_none: Optional[torch.Tensor], + 
g_idx_or_none: Optional[torch.Tensor], + perm_or_none: Optional[torch.Tensor], + workspace: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + topk_weights: torch.Tensor, + moe_block_size: int, + top_k: int, + mul_topk_weights: bool, + is_ep: bool, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool = True, + use_atomic_add: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False, +) -> torch.Tensor: + device = a.device + + # Allocate output if not provided + if c_or_none is not None: + c = c_or_none + else: + c = torch.empty((size_m * top_k, size_n), dtype=a.dtype, device=device) + + # Early return for zero-size M + if size_m == 0: + return c + + # Determine activation ordering + has_act_order = ( + g_idx_or_none is not None + and perm_or_none is not None + and g_idx_or_none.numel() > 0 + and perm_or_none.numel() > 0 + and g_idx_or_none.size(-1) > 0 + and perm_or_none.size(-1) > 0 + ) + + # Determine has_zp + has_zp = b_zeros_or_none is not None and b_zeros_or_none.numel() > 0 + + # Determine has_bias + has_bias = b_bias_or_none is not None + + # Derive num_groups and group_size from b_scales + num_groups = b_scales.size(1) + if has_act_order: + if is_k_full: + group_size = size_k // num_groups + else: + group_size = 0 + else: + if num_groups > 1: + group_size = size_k // num_groups + else: + group_size = -1 + + # Allocate a_tmp for act_order column permutation + if has_act_order: + a_tmp = torch.empty((size_m * top_k, size_k), dtype=a.dtype, device=device) + else: + a_tmp = torch.empty(0, dtype=a.dtype, device=device) + + # Allocate c_tmp for fp32 reduce + if use_fp32_reduce and not use_atomic_add: + sms = torch.cuda.get_device_properties(device).multi_processor_count + # max num of threadblocks is sms * 4 + max_c_tmp_size = min( + size_n * sorted_token_ids.size(0), + sms * 4 * moe_block_size * _MAX_THREAD_N, + ) + if moe_block_size == 8: + 
max_c_tmp_size *= 2 + c_tmp = torch.empty(max_c_tmp_size, dtype=torch.float32, device=device) + else: + c_tmp = torch.empty(0, dtype=torch.float32, device=device) + + # Convert Optional tensors to empty tensors + g_idx_t = _or_empty(g_idx_or_none, device, torch.int32) + perm_t = _or_empty(perm_or_none, device, torch.int32) + b_zeros_t = _or_empty(b_zeros_or_none, device, a.dtype) + b_bias_t = _or_empty(b_bias_or_none, device, a.dtype) + global_scale_t = _or_empty(global_scale_or_none, device, a.dtype) + + module = _jit_moe_wna16_marlin_module(a.dtype) + module.moe_wna16_marlin_gemm( + a, + c, + b_q_weight, + b_bias_t, + b_scales, + global_scale_t, + b_zeros_t, + g_idx_t, + perm_t, + workspace, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + topk_weights, + a_tmp, + c_tmp, + moe_block_size, + top_k, + mul_topk_weights, + is_ep, + b_q_type.id, + size_m, + size_n, + size_k, + has_act_order, + has_bias, + is_k_full, + has_zp, + num_groups, + group_size, + use_atomic_add, + use_fp32_reduce, + is_zp_float, + ) + + return c diff --git a/sglang/python/sglang/jit_kernel/norm.py b/sglang/python/sglang/jit_kernel/norm.py new file mode 100644 index 0000000000000000000000000000000000000000..e3b2aee1b11061c25666228b169918f27229471d --- /dev/null +++ b/sglang/python/sglang/jit_kernel/norm.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Optional + +import torch + +from sglang.jit_kernel.utils import ( + cache_once, + is_arch_support_pdl, + load_jit, + make_cpp_args, +) + +if TYPE_CHECKING: + from tvm_ffi.module import Module + + +@cache_once +def _jit_qknorm_module(head_dim: int, dtype: torch.dtype) -> Module: + args = make_cpp_args(head_dim, is_arch_support_pdl(), dtype) + return load_jit( + "qknorm", + *args, + cuda_files=["elementwise/qknorm.cuh"], + cuda_wrappers=[("qknorm", f"QKNormKernel<{args}>::run")], + ) + + +@cache_once +def _jit_rmsnorm_module(hidden_size: int, dtype: torch.dtype) -> Module: + 
args = make_cpp_args(hidden_size, is_arch_support_pdl(), dtype) + return load_jit( + "rmsnorm", + *args, + cuda_files=["elementwise/rmsnorm.cuh"], + cuda_wrappers=[("rmsnorm", f"RMSNormKernel<{args}>::run")], + ) + + +@cache_once +def _jit_fused_add_rmsnorm_module(dtype: torch.dtype) -> Module: + args = make_cpp_args(dtype) + return load_jit( + "fused_add_rmsnorm", + *args, + cuda_files=["elementwise/fused_add_rmsnorm.cuh"], + cuda_wrappers=[("fused_add_rmsnorm", f"FusedAddRMSNormKernel<{args}>::run")], + ) + + +@cache_once +def _jit_qknorm_across_heads_module(dtype: torch.dtype) -> Module: + args = make_cpp_args(dtype) + return load_jit( + "qknorm_across_heads", + *args, + cuda_files=["elementwise/qknorm_across_heads.cuh"], + cuda_wrappers=[ + ("qknorm_across_heads", f"QKNormAcrossHeadsKernel<{args}>::run") + ], + ) + + +@cache_once +def can_use_fused_inplace_qknorm(head_dim: int, dtype: torch.dtype) -> bool: + logger = logging.getLogger(__name__) + if head_dim not in [64, 128, 256, 512, 1024]: + logger.warning(f"Unsupported head_dim={head_dim} for JIT QK-Norm kernel") + return False + try: + _jit_qknorm_module(head_dim, dtype) + return True + except Exception as e: + logger.warning(f"Failed to load JIT QK-Norm kernel: {e}") + return False + + +def fused_inplace_qknorm( + q: torch.Tensor, + k: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + eps: float = 1e-6, + *, + head_dim: int = 0, +) -> None: + head_dim = head_dim or q.size(-1) + module = _jit_qknorm_module(head_dim, q.dtype) + module.qknorm(q, k, q_weight, k_weight, eps) + + +def rmsnorm( + input: torch.Tensor, + weight: torch.Tensor, + output: Optional[torch.Tensor] = None, + eps: float = 1e-6, +) -> None: + output = output if output is not None else input + hidden_size = input.size(-1) + module = _jit_rmsnorm_module(hidden_size, input.dtype) + module.rmsnorm(input, weight, output, eps) + + +def fused_add_rmsnorm( + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + 
eps: float = 1e-6, +) -> None: + module = _jit_fused_add_rmsnorm_module(input.dtype) + module.fused_add_rmsnorm(input, residual, weight, eps) + + +def fused_inplace_qknorm_across_heads( + q: torch.Tensor, + k: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + eps: float = 1e-6, +) -> None: + """ + Fused inplace QK normalization across all heads. + + Args: + q: Query tensor of shape [batch_size, num_heads * head_dim] + k: Key tensor of shape [batch_size, num_heads * head_dim] + q_weight: Query weight tensor of shape [num_heads * head_dim] + k_weight: Key weight tensor of shape [num_heads * head_dim] + eps: Epsilon for numerical stability + """ + module = _jit_qknorm_across_heads_module(q.dtype) + module.qknorm_across_heads(q, k, q_weight, k_weight, eps) diff --git a/sglang/python/sglang/jit_kernel/nvfp4.py b/sglang/python/sglang/jit_kernel/nvfp4.py new file mode 100644 index 0000000000000000000000000000000000000000..2fc6631752dfef9a9ea799bfe43b9a82055b5570 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/nvfp4.py @@ -0,0 +1,628 @@ +from __future__ import annotations + +import importlib.util +import os +import pathlib +from contextlib import contextmanager +from typing import TYPE_CHECKING, Optional, Tuple + +import torch + +from sglang.jit_kernel.utils import cache_once, load_jit +from sglang.srt.utils.custom_op import register_custom_op + +if TYPE_CHECKING: + from tvm_ffi.module import Module + + +_FLOAT4_E2M1_MAX = 6.0 +_FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + + +def _find_package_root(package: str) -> Optional[pathlib.Path]: + spec = importlib.util.find_spec(package) + if spec is None or spec.origin is None: + return None + return pathlib.Path(spec.origin).resolve().parent + + +def _resolve_cutlass_include_paths() -> list[str]: + include_paths: list[str] = [] + + flashinfer_root = _find_package_root("flashinfer") + if flashinfer_root is not None: + candidates = [ + flashinfer_root / "data" / "cutlass" / "include", + 
flashinfer_root / "data" / "cutlass" / "tools" / "util" / "include", + ] + for path in candidates: + if path.exists(): + include_paths.append(str(path)) + + deep_gemm_root = _find_package_root("deep_gemm") + if deep_gemm_root is not None: + candidate = deep_gemm_root / "include" + if candidate.exists(): + include_paths.append(str(candidate)) + + # De-duplicate while preserving order. + unique_paths = [] + seen = set() + for path in include_paths: + if path in seen: + continue + seen.add(path) + unique_paths.append(path) + return unique_paths + + +def _nvfp4_cuda_flags() -> list[str]: + return [ + "-DNDEBUG", + "-DFLASHINFER_ENABLE_F16", + "-DCUTE_USE_PACKED_TUPLE=1", + "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1", + "-DCUTLASS_VERSIONS_GENERATED", + "-DCUTLASS_TEST_LEVEL=0", + "-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1", + "-DCUTLASS_DEBUG_TRACE_LEVEL=0", + "--expt-extended-lambda", + ] + + +def _parse_cuda_version() -> tuple[int, int]: + v = torch.version.cuda + if not v: + return (0, 0) + parts = v.split(".") + if len(parts) < 2: + return (0, 0) + try: + return int(parts[0]), int(parts[1]) + except ValueError: + return (0, 0) + + +def _get_nvfp4_cuda_arch_list() -> str: + if not torch.cuda.is_available(): + raise RuntimeError("NVFP4 JIT kernels require CUDA.") + major, minor = torch.cuda.get_device_capability() + if major < 10: + raise RuntimeError( + f"NVFP4 JIT kernels require compute capability >= 10.0, got {major}.{minor}." + ) + # NVFP4 kernels use architecture-family-specific instructions and must be + # compiled for `sm_*a` targets (e.g. sm_100a), not plain sm_100. + archs = [f"{major}.{minor}a"] + cuda_major, _cuda_minor = _parse_cuda_version() + if cuda_major >= 13 and "10.3a" not in archs: + # Match sgl-kernel AOT fatbin behavior on CUDA 13+ for Blackwell. + archs.append("10.3a") + # Preserve order while de-duplicating. 
+ seen = set() + ordered_archs: list[str] = [] + for arch in archs: + if arch in seen: + continue + seen.add(arch) + ordered_archs.append(arch) + return " ".join(ordered_archs) + + +@contextmanager +def _nvfp4_arch_env(): + key = "TVM_FFI_CUDA_ARCH_LIST" + old_val = os.environ.get(key) + os.environ[key] = _get_nvfp4_cuda_arch_list() + try: + yield + finally: + if old_val is None: + os.environ.pop(key, None) + else: + os.environ[key] = old_val + + +@cache_once +def _jit_nvfp4_quant_module() -> Module: + extra_include_paths = _resolve_cutlass_include_paths() + if not extra_include_paths: + raise RuntimeError( + "Cannot find CUTLASS headers required for NVFP4 JIT quantization. " + "Please install flashinfer or deep_gemm with CUTLASS headers." + ) + + with _nvfp4_arch_env(): + return load_jit( + "nvfp4_quant", + cuda_files=[ + "gemm/nvfp4/nvfp4_quant_kernels.cuh", + ], + cuda_wrappers=[ + ("scaled_fp4_quant", "scaled_fp4_quant_sm100a_sm120a"), + ], + extra_include_paths=extra_include_paths, + extra_cuda_cflags=_nvfp4_cuda_flags(), + ) + + +@cache_once +def _jit_nvfp4_expert_quant_module() -> Module: + extra_include_paths = _resolve_cutlass_include_paths() + if not extra_include_paths: + raise RuntimeError( + "Cannot find CUTLASS headers required for NVFP4 JIT expert quantization. " + "Please install flashinfer or deep_gemm with CUTLASS headers." 
+ ) + + with _nvfp4_arch_env(): + return load_jit( + "nvfp4_expert_quant", + cuda_files=[ + "gemm/nvfp4/nvfp4_expert_quant.cuh", + ], + cuda_wrappers=[ + ("scaled_fp4_experts_quant", "scaled_fp4_experts_quant_sm100a"), + ( + "silu_and_mul_scaled_fp4_experts_quant", + "silu_and_mul_scaled_fp4_experts_quant_sm100a", + ), + ], + extra_include_paths=extra_include_paths, + extra_cuda_cflags=_nvfp4_cuda_flags(), + ) + + +@cache_once +def _jit_nvfp4_scaled_mm_module() -> Module: + extra_include_paths = _resolve_cutlass_include_paths() + if not extra_include_paths: + raise RuntimeError( + "Cannot find CUTLASS headers required for NVFP4 JIT GEMM. " + "Please install flashinfer or deep_gemm with CUTLASS headers." + ) + + with _nvfp4_arch_env(): + return load_jit( + "nvfp4_scaled_mm", + cuda_files=[ + "gemm/nvfp4/nvfp4_scaled_mm_kernels.cuh", + "gemm/nvfp4/nvfp4_scaled_mm_entry.cuh", + ], + cuda_wrappers=[("cutlass_scaled_fp4_mm", "cutlass_scaled_fp4_mm")], + extra_include_paths=extra_include_paths, + extra_cuda_cflags=_nvfp4_cuda_flags(), + ) + + +@cache_once +def _jit_nvfp4_blockwise_moe_module() -> Module: + extra_include_paths = _resolve_cutlass_include_paths() + if not extra_include_paths: + raise RuntimeError( + "Cannot find CUTLASS headers required for NVFP4 JIT MoE grouped GEMM. " + "Please install flashinfer or deep_gemm with CUTLASS headers." 
+ ) + + with _nvfp4_arch_env(): + return load_jit( + "nvfp4_blockwise_moe", + cuda_files=[ + "moe/nvfp4_blockwise_moe.cuh", + ], + cuda_wrappers=[ + ("cutlass_fp4_group_mm", "cutlass_fp4_group_mm_sm100a_sm120a") + ], + extra_include_paths=extra_include_paths, + extra_cuda_cflags=_nvfp4_cuda_flags(), + ) + + +def cutlass_scaled_fp4_mm( + a: torch.Tensor, + b: torch.Tensor, + block_scale_a: torch.Tensor, + block_scale_b: torch.Tensor, + alpha: torch.Tensor, + out_dtype: torch.dtype, +) -> torch.Tensor: + assert a.ndim == 2 and b.ndim == 2 + m, n = a.shape[0], b.shape[0] + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + module = _jit_nvfp4_scaled_mm_module() + module.cutlass_scaled_fp4_mm(out, a, b, block_scale_a, block_scale_b, alpha) + return out + + +def cutlass_fp4_group_mm( + a_fp4: torch.Tensor, + b_fp4: torch.Tensor, + a_blockscale: torch.Tensor, + b_blockscale: torch.Tensor, + alphas: torch.Tensor, + out_dtype: torch.dtype, + params: dict[str, torch.Tensor], +) -> torch.Tensor: + m_topk = a_fp4.shape[0] + n = b_fp4.shape[1] + output = torch.empty((m_topk, n), device=a_fp4.device, dtype=out_dtype) + num_experts = int(params["expert_offsets"].numel()) + device = a_fp4.device + + # Backward compatibility: older callers may not pass scratch tensors. 
+ a_ptrs = params.get( + "a_ptrs", torch.empty((num_experts,), dtype=torch.int64, device=device) + ) + b_ptrs = params.get( + "b_ptrs", torch.empty((num_experts,), dtype=torch.int64, device=device) + ) + out_ptrs = params.get( + "out_ptrs", torch.empty((num_experts,), dtype=torch.int64, device=device) + ) + a_scales_ptrs = params.get( + "a_scales_ptrs", torch.empty((num_experts,), dtype=torch.int64, device=device) + ) + b_scales_ptrs = params.get( + "b_scales_ptrs", torch.empty((num_experts,), dtype=torch.int64, device=device) + ) + alpha_ptrs = params.get( + "alpha_ptrs", torch.empty((num_experts,), dtype=torch.int64, device=device) + ) + layout_sfa = params.get( + "layout_sfa", torch.empty((num_experts, 5), dtype=torch.int64, device=device) + ) + layout_sfb = params.get( + "layout_sfb", torch.empty((num_experts, 5), dtype=torch.int64, device=device) + ) + + _cutlass_fp4_group_mm_custom_op( + output, + a_fp4, + b_fp4, + a_blockscale, + b_blockscale, + alphas, + params["ab_strides"], + params["c_strides"], + params["problem_sizes"], + params["expert_offsets"], + params["blockscale_offsets"], + a_ptrs, + b_ptrs, + out_ptrs, + a_scales_ptrs, + b_scales_ptrs, + alpha_ptrs, + layout_sfa, + layout_sfb, + ) + return output + + +@register_custom_op( + op_name="scaled_fp4_quant", + mutates_args=["output", "output_scale"], +) +def _scaled_fp4_quant_custom_op( + input: torch.Tensor, + output: torch.Tensor, + output_scale: torch.Tensor, + input_global_scale: torch.Tensor, +) -> None: + module = _jit_nvfp4_quant_module() + module.scaled_fp4_quant(output, input, output_scale, input_global_scale) + + +def scaled_fp4_quant( + input: torch.Tensor, input_global_scale: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """Quantize input tensor to FP4 and return packed FP4 tensor + swizzled scales.""" + assert input.ndim >= 1, f"input.ndim needs to be >= 1, but got {input.ndim}." 
+ other_dims = 1 if input.ndim == 1 else -1 + input = input.reshape(other_dims, input.shape[-1]) + m, n = input.shape + block_size = 16 + device = input.device + + assert n % block_size == 0, f"last dim has to be multiple of 16, but got {n}." + assert input.dtype in ( + torch.float16, + torch.bfloat16, + ), f"input.dtype needs to be fp16 or bf16 but got {input.dtype}." + + output = torch.empty((m, n // 2), device=device, dtype=torch.uint8) + + rounded_m = ((m + 128 - 1) // 128) * 128 + scale_n = n // block_size + rounded_n = ((scale_n + 4 - 1) // 4) * 4 + if rounded_n > scale_n: + output_scale = torch.zeros( + (rounded_m, rounded_n // 4), device=device, dtype=torch.int32 + ) + else: + output_scale = torch.empty( + (rounded_m, rounded_n // 4), device=device, dtype=torch.int32 + ) + + _scaled_fp4_quant_custom_op(input, output, output_scale, input_global_scale) + output_scale = output_scale.view(torch.float8_e4m3fn) + return output, output_scale + + +def _shuffle_rows_torch( + input_tensor: torch.Tensor, + dst2src_map: torch.Tensor, + output_tensor_shape: tuple[int, int], +) -> torch.Tensor: + # Keep compatibility when sgl-kernel is slimmed and shuffle_rows may not be present. 
+ output = input_tensor.index_select(0, dst2src_map.to(dtype=torch.int64)) + return output.view(output_tensor_shape) + + +@register_custom_op( + op_name="scaled_fp4_experts_quant", + mutates_args=["output", "output_scales"], +) +def _scaled_fp4_experts_quant_custom_op( + output: torch.Tensor, + output_scales: torch.Tensor, + input_tensor: torch.Tensor, + input_global_scale: torch.Tensor, + expert_offsets: torch.Tensor, + blockscale_offsets: torch.Tensor, +) -> None: + module = _jit_nvfp4_expert_quant_module() + module.scaled_fp4_experts_quant( + output, + output_scales, + input_tensor, + input_global_scale, + expert_offsets, + blockscale_offsets, + ) + + +def scaled_fp4_experts_quant( + input_tensor: torch.Tensor, + input_global_scale: torch.Tensor, + expert_offsets: torch.Tensor, + blockscale_offsets: torch.Tensor, + topk: int, + expert_map: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Quantize packed MoE activations to NVFP4.""" + assert ( + input_tensor.ndim == 2 + ), f"input.ndim needs to be == 2, but got {input_tensor.ndim}." + if expert_map is not None: + m, k = input_tensor.shape + output_tensor_shape = (m * topk, k) + input_tensor = _shuffle_rows_torch( + input_tensor, expert_map, output_tensor_shape + ) + + m_numtopk, k = input_tensor.shape + max_tokens_per_expert = int(os.environ.get("MODELOPT_MAX_TOKENS_PER_EXPERT", 65536)) + assert m_numtopk <= max_tokens_per_expert * topk, ( + f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT({max_tokens_per_expert})" + f" for cutlass_moe_fp4, observed m_numtopk = {m_numtopk}. Use" + " MODELOPT_MAX_TOKENS_PER_EXPERT to set this value." + ) + scales_k = k // 16 + # output_scales is int32-packed FP8 scales, so second dim is in int32 units. 
+ padded_k_in_int32 = (scales_k + 3) // 4 + + output = torch.empty( + m_numtopk, k // 2, device=input_tensor.device, dtype=torch.uint8 + ) + if padded_k_in_int32 * 4 > scales_k: + output_scales = torch.zeros( + max_tokens_per_expert * topk, + padded_k_in_int32, + dtype=torch.int32, + device=input_tensor.device, + ) + else: + output_scales = torch.empty( + max_tokens_per_expert * topk, + padded_k_in_int32, + dtype=torch.int32, + device=input_tensor.device, + ) + + _scaled_fp4_experts_quant_custom_op( + output, + output_scales, + input_tensor, + input_global_scale, + expert_offsets, + blockscale_offsets, + ) + output_scales = output_scales.view(torch.float8_e4m3fn) + return output, output_scales + + +@register_custom_op( + op_name="scaled_fp4_grouped_quant", + mutates_args=["output", "output_scales"], +) +def _scaled_fp4_grouped_quant_custom_op( + input_tensor: torch.Tensor, + output: torch.Tensor, + output_scales: torch.Tensor, + input_global_scale: torch.Tensor, + mask: torch.Tensor, +) -> None: + l, m, k = input_tensor.shape + del l, m + module = _jit_nvfp4_expert_quant_module() + module.silu_and_mul_scaled_fp4_experts_quant( + output.view(-1, k // 2), + output_scales.view(-1, output_scales.shape[-1]), + input_tensor.view(-1, k), + input_global_scale, + mask, + False, + ) + + +def scaled_fp4_grouped_quant( + input_tensor: torch.Tensor, + input_global_scale: torch.Tensor, + mask: torch.Tensor, +): + """Quantize grouped GEMM inputs to FP4 and return logical (m, k//2, l).""" + device = input_tensor.device + l, m, k = input_tensor.shape + sf_vec_size = 16 + assert k % sf_vec_size == 0, f"k must be multiple of 16, but got {k}." 
+ + scale_k = k // sf_vec_size + padded_k = (scale_k + (4 - 1)) // 4 * 4 + padded_k_int32 = padded_k // 4 + padded_m = (m + (128 - 1)) // 128 * 128 + output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8) + output_scales = torch.empty( + l, padded_m, padded_k_int32, device=device, dtype=torch.int32 + ) + + _scaled_fp4_grouped_quant_custom_op( + input_tensor, + output, + output_scales, + input_global_scale, + mask, + ) + + output = output.permute(1, 2, 0) + output_scales = output_scales.view(torch.float8_e4m3fn).view( + l, padded_m // 128, padded_k // 4, 32, 4, 4 + ) + output_scales = output_scales.permute(3, 4, 1, 5, 2, 0) + return output, output_scales + + +@register_custom_op( + op_name="silu_and_mul_scaled_fp4_grouped_quant", + mutates_args=["output", "output_scales"], +) +def _silu_and_mul_scaled_fp4_grouped_quant_custom_op( + input_tensor: torch.Tensor, + output: torch.Tensor, + output_scales: torch.Tensor, + input_global_scale: torch.Tensor, + mask: torch.Tensor, +) -> None: + l, m, k_by_2 = input_tensor.shape + del l, m + module = _jit_nvfp4_expert_quant_module() + module.silu_and_mul_scaled_fp4_experts_quant( + output.view(-1, output.shape[-1]), + output_scales.view(-1, output_scales.shape[-1]), + input_tensor.view(-1, k_by_2), + input_global_scale, + mask, + True, + ) + + +def silu_and_mul_scaled_fp4_grouped_quant( + input_tensor: torch.Tensor, + input_global_scale: torch.Tensor, + mask: torch.Tensor, +): + """Apply SiLU-and-mul then quantize grouped GEMM inputs to FP4.""" + device = input_tensor.device + l, m, k_by_2 = input_tensor.shape + k = k_by_2 // 2 + sf_vec_size = 16 + assert k % sf_vec_size == 0, f"k must be multiple of 16, but got {k}." 
+ + scale_k = k // sf_vec_size + padded_k = (scale_k + (4 - 1)) // 4 * 4 + padded_k_int32 = padded_k // 4 + padded_m = (m + (128 - 1)) // 128 * 128 + output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8) + output_scales = torch.empty( + l, padded_m, padded_k_int32, device=device, dtype=torch.int32 + ) + + _silu_and_mul_scaled_fp4_grouped_quant_custom_op( + input_tensor, + output, + output_scales, + input_global_scale, + mask, + ) + + output = output.permute(1, 2, 0) + output_scales = output_scales.view(torch.float8_e4m3fn).view( + l, padded_m // 128, padded_k // 4, 32, 4, 4 + ) + output_scales = output_scales.permute(3, 4, 1, 5, 2, 0) + return output, output_scales + + +@register_custom_op( + op_name="cutlass_fp4_group_mm", + mutates_args=[ + "output", + "a_ptrs", + "b_ptrs", + "out_ptrs", + "a_scales_ptrs", + "b_scales_ptrs", + "alpha_ptrs", + "layout_sfa", + "layout_sfb", + ], +) +def _cutlass_fp4_group_mm_custom_op( + output: torch.Tensor, + a_fp4: torch.Tensor, + b_fp4: torch.Tensor, + a_blockscale: torch.Tensor, + b_blockscale: torch.Tensor, + alphas: torch.Tensor, + ab_strides: torch.Tensor, + c_strides: torch.Tensor, + problem_sizes: torch.Tensor, + expert_offsets: torch.Tensor, + blockscale_offsets: torch.Tensor, + a_ptrs: torch.Tensor, + b_ptrs: torch.Tensor, + out_ptrs: torch.Tensor, + a_scales_ptrs: torch.Tensor, + b_scales_ptrs: torch.Tensor, + alpha_ptrs: torch.Tensor, + layout_sfa: torch.Tensor, + layout_sfb: torch.Tensor, +) -> None: + module = _jit_nvfp4_blockwise_moe_module() + module.cutlass_fp4_group_mm( + output, + a_fp4, + b_fp4, + a_blockscale, + b_blockscale, + alphas, + ab_strides, + c_strides, + problem_sizes, + expert_offsets, + blockscale_offsets, + a_ptrs, + b_ptrs, + out_ptrs, + a_scales_ptrs, + b_scales_ptrs, + alpha_ptrs, + layout_sfa, + layout_sfb, + ) + + +def suggest_nvfp4_global_scale(x: torch.Tensor) -> torch.Tensor: + """Utility for tests/benchmarks: return global scale used by NVFP4 quantization.""" + tensor_amax 
= torch.abs(x).max().to(torch.float32) + return _FLOAT8_E4M3_MAX * _FLOAT4_E2M1_MAX / tensor_amax diff --git a/sglang/python/sglang/jit_kernel/per_tensor_quant_fp8.py b/sglang/python/sglang/jit_kernel/per_tensor_quant_fp8.py new file mode 100644 index 0000000000000000000000000000000000000000..9225aa45d1e5426d429774ad8e0f2b46386838fa --- /dev/null +++ b/sglang/python/sglang/jit_kernel/per_tensor_quant_fp8.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch + +from sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args +from sglang.srt.utils.custom_op import register_custom_op + +if TYPE_CHECKING: + from tvm_ffi.module import Module + + +@cache_once +def _jit_per_tensor_quant_fp8_module(is_static: bool, dtype: torch.dtype) -> Module: + args = make_cpp_args(is_static, dtype) + return load_jit( + "per_tensor_quant_fp8", + *args, + cuda_files=["gemm/per_tensor_quant_fp8.cuh"], + cuda_wrappers=[("per_tensor_quant_fp8", f"per_tensor_quant_fp8<{args}>")], + ) + + +@register_custom_op( + op_name="per_tensor_quant_fp8", + mutates_args=["output_q", "output_s"], +) +def per_tensor_quant_fp8( + input: torch.Tensor, + output_q: torch.Tensor, + output_s: torch.Tensor, + is_static: bool = False, +) -> None: + """ + Per-tensor quantization to FP8 format. 
+ + Args: + input: Input tensor to quantize (float, half, or bfloat16) + output_q: Output quantized tensor (fp8_e4m3) + output_s: Output scale tensor (float scalar or 1D tensor with 1 element) + is_static: If True, assumes scale is pre-computed and skips absmax computation + """ + module = _jit_per_tensor_quant_fp8_module(is_static, input.dtype) + module.per_tensor_quant_fp8(input.view(-1), output_q.view(-1), output_s.view(-1)) diff --git a/sglang/python/sglang/jit_kernel/per_token_group_quant_8bit.py b/sglang/python/sglang/jit_kernel/per_token_group_quant_8bit.py new file mode 100644 index 0000000000000000000000000000000000000000..fdfbb29802e8997042ffafcfafdc96ebf35271b5 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/per_token_group_quant_8bit.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch + +from sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args +from sglang.srt.utils.custom_op import register_custom_op + +if TYPE_CHECKING: + from tvm_ffi.module import Module + +from sglang.jit_kernel.utils import CPP_DTYPE_MAP as OUTPUT_DTYPE_MAP + + +@cache_once +def _jit_per_token_group_quant_8bit_module( + dtype: torch.dtype, output_type: torch.dtype +) -> Module: + input_args = make_cpp_args(dtype) + out_cpp = OUTPUT_DTYPE_MAP[output_type] + return load_jit( + "per_token_group_quant_8bit", + cuda_files=["gemm/per_token_group_quant_8bit.cuh"], + cuda_wrappers=[ + ( + "per_token_group_quant_8bit", + f"per_token_group_quant_8bit<{input_args}, {out_cpp}>", + ) + ], + ) + + +@register_custom_op( + op_name="per_token_group_quant_8bit", + mutates_args=["output_q", "output_s"], +) +def _per_token_group_quant_8bit_custom_op( + input: torch.Tensor, + output_q: torch.Tensor, + output_s: torch.Tensor, + group_size: int, + eps: float, + fp8_min: float, + fp8_max: float, + scale_ue8m0: bool = False, +) -> None: + """ + Per-token-group quantization to 8-bit format. 
+ + Args: + input: Input tensor to quantize (float, half, or bfloat16). + output_q: Output quantized tensor (e.g., fp8_e4m3 or int8). + output_s: Output scale tensor. + group_size: The size of the group for quantization. + eps: A small value to avoid division by zero. + fp8_min: The minimum value of the 8-bit data type. + fp8_max: The maximum value of the 8-bit data type. + scale_ue8m0: Whether to use UE8M0 format for scales. + """ + module = _jit_per_token_group_quant_8bit_module(input.dtype, output_q.dtype) + module.per_token_group_quant_8bit( + input, + output_q, + output_s, + group_size, + eps, + fp8_min, + fp8_max, + scale_ue8m0, + ) + return None + + +def per_token_group_quant_8bit( + input: torch.Tensor, + output_q: torch.Tensor, + output_s: torch.Tensor, + group_size: int, + eps: float, + fp8_min: float, + fp8_max: float, + scale_ue8m0: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + _per_token_group_quant_8bit_custom_op( + input=input, + output_q=output_q, + output_s=output_s, + group_size=group_size, + eps=eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + scale_ue8m0=scale_ue8m0, + ) + return output_q, output_s diff --git a/sglang/python/sglang/jit_kernel/pos_enc.py b/sglang/python/sglang/jit_kernel/pos_enc.py new file mode 100644 index 0000000000000000000000000000000000000000..16f8ac37e9302e07d174524bee129e61e1351d49 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/pos_enc.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch + +from sglang.jit_kernel.utils import cache_once, load_jit +from sglang.srt.utils.custom_op import register_custom_op + +if TYPE_CHECKING: + from tvm_ffi.module import Module + + +@cache_once +def _jit_rotary_embedding_module() -> Module: + return load_jit( + "rotary_embedding", + cuda_files=["elementwise/pos_enc.cuh"], + cuda_wrappers=[("rotary_embedding", "RotaryEmbeddingKernel::run")], + ) + + +@register_custom_op( + op_name="rotary_embedding_with_key", + 
mutates_args=["query", "key"], +) +def rotary_embedding_with_key( + positions: torch.Tensor, # [batch_size, seq_len] or [num_tokens] + query: torch.Tensor, # [batch_size, seq_len, num_heads * head_size] or + # [num_tokens, num_heads * head_size] or + # [batch_size, seq_len, num_heads, head_size] or + # [num_tokens, num_heads, head_size] + key: torch.Tensor, # [batch_size, seq_len, num_kv_heads * head_size] or + # [num_tokens, num_kv_heads * head_size] or + # [batch_size, seq_len, num_heads, head_size] or + # [num_tokens, num_heads, head_size] + head_size: int, + cos_sin_cache: torch.Tensor, # [max_position, rot_dim] + is_neox: bool = True, +) -> None: + """ + Apply rotary embedding to query and key tensors. + + Args: + positions: Position indices of shape [num_tokens] or [batch_size, seq_len] + query: Query tensor of shape [num_tokens, num_heads, head_size] or [num_tokens, num_heads * head_size] + key: Key tensor of shape [num_tokens, num_kv_heads, head_size] or [num_tokens, num_kv_heads * head_size] + cos_sin_cache: Cosine and sine cache of shape [max_position, rot_dim] + is_neox: Whether to use GPT-NeoX style rotary embedding (True) or GPT-J style (False) + """ + module = _jit_rotary_embedding_module() + module.rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox) + + +@register_custom_op( + op_name="rotary_embedding_without_key", + mutates_args=["query"], +) +def rotary_embedding_without_key( + positions: torch.Tensor, + query: torch.Tensor, + head_size: int, + cos_sin_cache: torch.Tensor, + is_neox: bool = True, +) -> None: + module = _jit_rotary_embedding_module() + module.rotary_embedding(positions, query, None, head_size, cos_sin_cache, is_neox) + + +def rotary_embedding( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + head_size: int, + cos_sin_cache: torch.Tensor, + is_neox: bool = True, +): + if key is None: + rotary_embedding_without_key( + positions, query, head_size, cos_sin_cache, is_neox + ) + else: + 
rotary_embedding_with_key( + positions, query, key, head_size, cos_sin_cache, is_neox + ) + return query, key diff --git a/sglang/python/sglang/jit_kernel/rope.py b/sglang/python/sglang/jit_kernel/rope.py new file mode 100644 index 0000000000000000000000000000000000000000..02c4e5fb6e89a54cdc28062c03333a091de9096e --- /dev/null +++ b/sglang/python/sglang/jit_kernel/rope.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional + +import torch + +from sglang.jit_kernel.utils import ( + cache_once, + is_arch_support_pdl, + load_jit, + make_cpp_args, +) +from sglang.srt.utils.custom_op import register_custom_op + +if TYPE_CHECKING: + from tvm_ffi.module import Module + + +@cache_once +def _jit_fused_rope_module(is_neox: bool, rope_dim: int, dtype: torch.dtype) -> Module: + args = make_cpp_args(is_neox, rope_dim, is_arch_support_pdl(), dtype) + return load_jit( + "fused_rope", + *args, + cuda_files=["elementwise/rope.cuh"], + cuda_wrappers=[ + ("run_rope", f"FusedRopeKernel<{args}>::run"), + ("run_rope_store", f"FusedRopeKernel<{args}>::run_fused"), + ], + ) + + +@dataclass +class FusedSetKVBufferArg: + """ + value : Optional[torch.Tensor] + Value tensor, shape: ``(nnz, num_v_heads * head_size)``. + k_buffer : Optional[torch.Tensor] + Buffer for keys, shape: ``(nnz, num_k_heads * head_size)``. + v_buffer : Optional[torch.Tensor] + Buffer for values, shape: ``(nnz, num_v_heads * head_size)``. + cache_loc : Optional[torch.Tensor] + Cache location tensor, used for indexing kv cache. + """ + + value: torch.Tensor + k_buffer: torch.Tensor + v_buffer: torch.Tensor + cache_loc: torch.Tensor + + +@register_custom_op(mutates_args=["q", "k"]) +def apply_rope_inplace( + q: torch.Tensor, + k: torch.Tensor, + cos_sin_cache: torch.Tensor, + positions: torch.Tensor, + *, + is_neox: bool, + rope_dim: int = 0, +) -> None: + """ + Fused inplace rotary position embedding for query and key tensors. 
+ + Args: + q: Query tensor of shape [num_tokens, num_qo_heads, rope_dim]. + k: Key tensor of shape [num_tokens, num_kv_heads, rope_dim]. + cos_sin_cache: Cosine/sine cache of shape [max_position, rope_dim], + where the first half along dim=-1 is cos and the second half is sin. + Must be float32. + positions: Position indices of shape [num_tokens], int32 or int64. + is_neox: Whether to use GPT-NeoX style (True) or GPT-J interleaved style (False). + rope_dim: Rotary embedding dimension. Defaults to cos_sin_cache.size(-1). + """ + rope_dim = rope_dim or cos_sin_cache.size(-1) + module = _jit_fused_rope_module(is_neox, rope_dim, q.dtype) + module.run_rope(q, k, cos_sin_cache, positions) + + +@register_custom_op(mutates_args=["q", "k_cache", "v_cache"]) +def apply_rope_inplace_with_kvcache( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + cos_sin_cache: torch.Tensor, + positions: torch.Tensor, + out_loc: torch.Tensor, + *, + is_neox: bool, + rope_dim: int = 0, +) -> None: + """ + Fused inplace RoPE + KV cache store. + + Applies rotary position embedding to q inplace. For k, applies RoPE and + stores the result in k_cache. The original v is also stored in v_cache. + + Args: + q: Query tensor of shape [num_tokens, num_qo_heads, head_dim]. + k: Key tensor of shape [num_tokens, num_kv_heads, head_dim]. + v: Value tensor of shape [num_tokens, num_kv_heads, head_dim]. + k_cache: Key cache of shape [cache_size, num_kv_heads * head_dim]. + v_cache: Value cache of shape [cache_size, num_kv_heads * head_dim]. + cos_sin_cache: Cosine/sine cache of shape [max_position, rope_dim], float32. + positions: Position indices of shape [num_tokens], int32 or int64. + out_loc: Cache write locations of shape [num_tokens], same dtype as positions. + is_neox: Whether to use GPT-NeoX style (True) or GPT-J interleaved (False). + rope_dim: Rotary embedding dimension. Defaults to cos_sin_cache.size(-1). 
+ """ + rope_dim = rope_dim or cos_sin_cache.size(-1) + v = v.view_as(k) + module = _jit_fused_rope_module(is_neox, rope_dim, q.dtype) + module.run_rope_store(q, k, v, k_cache, v_cache, cos_sin_cache, positions, out_loc) + + +# NOTE: this name is intentionally set as the old kernel in `sgl_kernel` +def apply_rope_with_cos_sin_cache_inplace( + q: torch.Tensor, + k: torch.Tensor, + cos_sin_cache: torch.Tensor, + positions: torch.Tensor, + *, + is_neox: bool, + rope_dim: int = 0, + fused_args: Optional[FusedSetKVBufferArg] = None, +) -> None: + """ + Apply RoPE to q and k inplace, with optional fused kv cache store. + + If `fused_args` is provided, it will perform fused RoPE and KV cache store. + Otherwise, it will only apply RoPE inplace. + + Args: + q: Query tensor of shape [num_tokens, num_qo_heads, head_dim]. + k: Key tensor of shape [num_tokens, num_kv_heads, head_dim]. + cos_sin_cache: Cosine/sine cache of shape [max_position, rope_dim], float32. + positions: Position indices of shape [num_tokens], int32 or int64. + is_neox: Whether to use GPT-NeoX style (True) or GPT-J interleaved (False). + rope_dim: Rotary embedding dimension. Defaults to cos_sin_cache.size(-1). + fused_args: Optional arguments for fused RoPE + KV cache store. If None, + only RoPE will be applied inplace without touching kv cache. 
+ """ + if fused_args is not None: + apply_rope_inplace_with_kvcache( + q, + k, + fused_args.value, + fused_args.k_buffer, + fused_args.v_buffer, + cos_sin_cache, + positions, + fused_args.cache_loc, + is_neox=is_neox, + rope_dim=rope_dim, + ) + else: + apply_rope_inplace( + q, k, cos_sin_cache, positions, is_neox=is_neox, rope_dim=rope_dim + ) diff --git a/sglang/python/sglang/jit_kernel/tests/test_add_constant.py b/sglang/python/sglang/jit_kernel/tests/test_add_constant.py new file mode 100644 index 0000000000000000000000000000000000000000..36ea024bae405dfa2f6eb1142e256696b4c307dc --- /dev/null +++ b/sglang/python/sglang/jit_kernel/tests/test_add_constant.py @@ -0,0 +1,16 @@ +import pytest +import torch + +from sglang.jit_kernel.add_constant import add_constant + + +@pytest.mark.parametrize("size", [1, 2, 127, 128, 1024, 1025]) +@pytest.mark.parametrize("constant", [0, 1, 7, 1024, -3]) +def test_add_constant(size: int, constant: int) -> None: + src = torch.arange(0, size, dtype=torch.int32, device="cuda") + dst = add_constant(src, constant) + assert torch.all(dst == src + constant) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/sglang/python/sglang/jit_kernel/tests/test_awq_dequantize.py b/sglang/python/sglang/jit_kernel/tests/test_awq_dequantize.py new file mode 100644 index 0000000000000000000000000000000000000000..d2970e99bd60616aaac28f988409b3d974bcdc6d --- /dev/null +++ b/sglang/python/sglang/jit_kernel/tests/test_awq_dequantize.py @@ -0,0 +1,164 @@ +import itertools + +import pytest +import torch + +from sglang.jit_kernel.awq_dequantize import awq_dequantize as jit_awq_dequantize + +try: + from sgl_kernel import awq_dequantize as aot_awq_dequantize + + AOT_AVAILABLE = True +except ImportError: + AOT_AVAILABLE = False + + +def reverse_awq_order(t: torch.Tensor): + bits = 4 + AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7] + reverse_order_tensor = torch.arange( + t.shape[-1], + dtype=torch.int32, + device=t.device, + ) + 
reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits) + reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER] + reverse_order_tensor = reverse_order_tensor.view(-1) + + t = t[:, reverse_order_tensor] & 0xF + return t + + +# qweights - [R , C // 8], int32 +# scales - [R // G, C ], float16 +# zeros - [R // G, C // 8], int32 +def awq_dequantize_torch( + qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor, group_size: int +) -> torch.Tensor: + if group_size == -1: + group_size = qweight.shape[0] + + bits = 4 + shifts = torch.arange(0, 32, bits, device=qzeros.device) + + iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to( + torch.int8 + ) + + iweights = iweights.view(iweights.shape[0], -1) + + zeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to( + torch.int8 + ) + zeros = zeros.view(qzeros.shape[0], -1) + zeros = reverse_awq_order(zeros) + + iweights = reverse_awq_order(iweights) + + iweights = torch.bitwise_and(iweights, (2**bits) - 1) + zeros = torch.bitwise_and(zeros, (2**bits) - 1) + + scales = scales.repeat_interleave(group_size, dim=0) + zeros = zeros.repeat_interleave(group_size, dim=0) + return (iweights - zeros) * scales + + +@pytest.mark.parametrize( + "qweight_row,qweight_col,is_bf16_act", + list( + itertools.product( + [128, 256, 512, 1024, 3584], + [16, 32, 64, 128, 448], + [True, False], + ) + ), +) +def test_awq_dequantize_jit_vs_torch( + qweight_row: int, qweight_col: int, is_bf16_act: bool +): + device = torch.device("cuda") + qweight = torch.randint( + 0, + torch.iinfo(torch.int32).max, + (qweight_row, qweight_col), + dtype=torch.int32, + device=device, + ) + group_size = qweight_row + scales_row = qweight_row // group_size + scales_col = qweight_col * 8 + + if is_bf16_act: + scales = torch.rand(scales_row, scales_col, dtype=torch.bfloat16, device=device) + else: + scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device) + + qzeros = 
torch.randint( + 0, + torch.iinfo(torch.int32).max, + (scales_row, qweight_col), + dtype=torch.int32, + device=device, + ) + + # Run both implementations + torch_out = awq_dequantize_torch(qweight, scales, qzeros, group_size) + jit_out = jit_awq_dequantize(qweight, scales, qzeros) + + # Compare results (approximate due to different computation paths) + torch.testing.assert_close( + torch_out.to(torch.float32), jit_out.to(torch.float32), rtol=1e-3, atol=1e-5 + ) + + +@pytest.mark.parametrize( + "qweight_row,qweight_col,is_bf16_act", + list( + itertools.product( + [128, 256, 512, 1024, 3584], + [16, 32, 64, 128, 448], + [True, False], + ) + ), +) +def test_awq_dequantize_jit_vs_aot( + qweight_row: int, qweight_col: int, is_bf16_act: bool +): + if not AOT_AVAILABLE: + pytest.skip("sgl_kernel AOT not available") + + device = torch.device("cuda") + qweight = torch.randint( + 0, + torch.iinfo(torch.int32).max, + (qweight_row, qweight_col), + dtype=torch.int32, + device=device, + ) + group_size = qweight_row + scales_row = qweight_row // group_size + scales_col = qweight_col * 8 + + if is_bf16_act: + scales = torch.rand(scales_row, scales_col, dtype=torch.bfloat16, device=device) + else: + scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device) + + qzeros = torch.randint( + 0, + torch.iinfo(torch.int32).max, + (scales_row, qweight_col), + dtype=torch.int32, + device=device, + ) + + # Run both implementations + aot_out = aot_awq_dequantize(qweight, scales, qzeros) + jit_out = jit_awq_dequantize(qweight, scales, qzeros) + + # Bitwise equality + torch.testing.assert_close(jit_out, aot_out, rtol=0, atol=0) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/sglang/python/sglang/jit_kernel/tests/test_awq_marlin_moe_repack.py b/sglang/python/sglang/jit_kernel/tests/test_awq_marlin_moe_repack.py new file mode 100644 index 0000000000000000000000000000000000000000..e4741b3739aa983307ff8fbe26bf165285596c6e --- /dev/null +++ 
b/sglang/python/sglang/jit_kernel/tests/test_awq_marlin_moe_repack.py @@ -0,0 +1,115 @@ +import numpy as np +import pytest +import torch +from sgl_kernel.scalar_type import scalar_types + +from sglang.jit_kernel.awq_marlin_repack import ( + awq_marlin_moe_repack as jit_awq_marlin_moe_repack, +) +from sglang.srt.layers.quantization.utils import pack_cols, quantize_weights + +try: + from sgl_kernel import awq_marlin_moe_repack as aot_awq_marlin_moe_repack + + AOT_AVAILABLE = True +except ImportError: + AOT_AVAILABLE = False + + +def awq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + if num_bits == 4: + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = np.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() + q_w = q_w.reshape((-1, size_n)).contiguous() + + return pack_cols(q_w, num_bits, size_k, size_n) + + +@pytest.mark.parametrize("num_bits", [4]) +@pytest.mark.parametrize("num_experts", [2, 4, 8]) +@pytest.mark.parametrize("k_tiles,n_tiles", [(1, 1), (2, 2), (4, 4)]) +@pytest.mark.parametrize("group_size", [16, 32]) +def test_awq_marlin_moe_repack_jit_vs_aot( + num_bits, num_experts, k_tiles, n_tiles, group_size +): + if not AOT_AVAILABLE: + pytest.skip("sgl_kernel AOT not available") + + tile_k, tile_n = 16, 64 + size_k = k_tiles * tile_k + size_n = n_tiles * tile_n + pack_factor = 32 // num_bits + + # Create per-expert AWQ-packed weights + b_q_weight = torch.empty( + (num_experts, size_k, size_n // pack_factor), + dtype=torch.int32, + device="cuda", + ) + for e in range(num_experts): + b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device="cuda") + w_ref, q_w, s, zp = quantize_weights( + b_weight, scalar_types.uint4, group_size, zero_points=True + ) + b_q_weight[e] = awq_pack(q_w, num_bits, size_k, size_n) + + perm = 
torch.empty((num_experts, 0), dtype=torch.int32, device="cuda") + + out_jit = jit_awq_marlin_moe_repack(b_q_weight, perm, size_k, size_n, num_bits) + out_aot = aot_awq_marlin_moe_repack(b_q_weight, perm, size_k, size_n, num_bits) + + torch.cuda.synchronize() + + # Bitwise equality + torch.testing.assert_close(out_jit, out_aot, rtol=0, atol=0) + + +@pytest.mark.parametrize("num_bits", [4]) +@pytest.mark.parametrize("num_experts", [2, 4]) +@pytest.mark.parametrize("k_tiles,n_tiles", [(1, 1), (2, 2)]) +@pytest.mark.parametrize("group_size", [16, 32]) +def test_awq_marlin_moe_repack_shape( + num_bits, num_experts, k_tiles, n_tiles, group_size +): + tile_k, tile_n = 16, 64 + size_k = k_tiles * tile_k + size_n = n_tiles * tile_n + pack_factor = 32 // num_bits + + # Create per-expert AWQ-packed weights + b_q_weight = torch.empty( + (num_experts, size_k, size_n // pack_factor), + dtype=torch.int32, + device="cuda", + ) + for e in range(num_experts): + b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device="cuda") + w_ref, q_w, s, zp = quantize_weights( + b_weight, scalar_types.uint4, group_size, zero_points=True + ) + b_q_weight[e] = awq_pack(q_w, num_bits, size_k, size_n) + + perm = torch.empty((num_experts, 0), dtype=torch.int32, device="cuda") + + out = jit_awq_marlin_moe_repack(b_q_weight, perm, size_k, size_n, num_bits) + torch.cuda.synchronize() + + assert out.is_cuda and out.dtype == torch.int32 + expected_shape = (num_experts, size_k // 16, size_n * (num_bits // 2)) + assert list(out.shape) == list(expected_shape) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/sglang/python/sglang/jit_kernel/tests/test_awq_marlin_repack.py b/sglang/python/sglang/jit_kernel/tests/test_awq_marlin_repack.py new file mode 100644 index 0000000000000000000000000000000000000000..ba959ccacd0f339ed112148f2e61c56cb4a5feca --- /dev/null +++ b/sglang/python/sglang/jit_kernel/tests/test_awq_marlin_repack.py @@ -0,0 +1,101 @@ +import numpy as np 
+import pytest +import torch +from sgl_kernel.scalar_type import scalar_types + +from sglang.jit_kernel.awq_marlin_repack import ( + awq_marlin_repack as jit_awq_marlin_repack, +) +from sglang.srt.layers.quantization.utils import pack_cols, quantize_weights +from sglang.test.test_marlin_utils import get_weight_perm, marlin_weights + +try: + from sgl_kernel import awq_marlin_repack as aot_awq_marlin_repack + + AOT_AVAILABLE = True +except ImportError: + AOT_AVAILABLE = False + + +def awq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + if num_bits == 4: + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = np.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() + q_w = q_w.reshape((-1, size_n)).contiguous() + + return pack_cols(q_w, num_bits, size_k, size_n) + + +@pytest.mark.parametrize("num_bits", [4, 8]) +@pytest.mark.parametrize("k_tiles,n_tiles", [(1, 1), (2, 2), (4, 4)]) +@pytest.mark.parametrize("group_size", [16, 32]) +def test_awq_marlin_repack_jit_vs_aot(num_bits, k_tiles, n_tiles, group_size): + if not AOT_AVAILABLE: + pytest.skip("sgl_kernel AOT not available") + + tile_k, tile_n = 16, 64 + size_k = k_tiles * tile_k + size_n = n_tiles * tile_n + + b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device="cuda") + + w_ref, q_w, s, zp = quantize_weights( + b_weight, scalar_types.uint4, group_size, zero_points=True + ) + + q_w_awq = awq_pack(q_w, num_bits, size_k, size_n) + + out_jit = jit_awq_marlin_repack(q_w_awq, size_k, size_n, num_bits) + out_aot = aot_awq_marlin_repack(q_w_awq, size_k, size_n, num_bits) + + torch.cuda.synchronize() + + # Bitwise equality + torch.testing.assert_close(out_jit, out_aot, rtol=0, atol=0) + + +@pytest.mark.parametrize("num_bits", [4, 8]) +@pytest.mark.parametrize("k_tiles,n_tiles", [(1, 1), (2, 
2)]) +@pytest.mark.parametrize("group_size", [16, 32]) +def test_awq_marlin_repack_correct(num_bits, k_tiles, n_tiles, group_size): + tile_k, tile_n = 16, 64 + size_k = k_tiles * tile_k + size_n = n_tiles * tile_n + pack_factor = 32 // num_bits + + b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device="cuda") + + w_ref, q_w, s, zp = quantize_weights( + b_weight, scalar_types.uint4, group_size, zero_points=True + ) + + q_w_awq = awq_pack(q_w, num_bits, size_k, size_n) + + weight_perm = get_weight_perm(num_bits) + q_w_marlin = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) + + out_gpu = jit_awq_marlin_repack(q_w_awq, size_k, size_n, num_bits) + assert out_gpu.is_cuda and out_gpu.dtype == torch.int32 + + expected_cols = size_n * tile_k // pack_factor + assert list(out_gpu.shape) == [size_k // tile_k, expected_cols] + + torch.cuda.synchronize() + + torch.testing.assert_close(out_gpu, q_w_marlin) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/sglang/python/sglang/jit_kernel/tests/test_concat_mla.py b/sglang/python/sglang/jit_kernel/tests/test_concat_mla.py new file mode 100644 index 0000000000000000000000000000000000000000..cf1013c5edf3a55c76f4f4c8e6d47bf0920886b3 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/tests/test_concat_mla.py @@ -0,0 +1,169 @@ +import itertools + +import pytest +import torch +import triton + + +def torch_concat_mla_k( + k: torch.Tensor, k_nope: torch.Tensor, k_rope: torch.Tensor +) -> None: + """Reference PyTorch implementation for concat_mla_k.""" + # k_nope: [num_tokens, num_heads, nope_head_dim] + # k_rope: [num_tokens, 1, rope_head_dim] + # k: [num_tokens, num_heads, nope_head_dim + rope_head_dim] + nope_head_dim = k_nope.shape[-1] + k[:, :, :nope_head_dim] = k_nope + # Broadcast k_rope across all heads + k[:, :, nope_head_dim:] = k_rope.expand(-1, k.shape[1], -1) + + +def torch_concat_mla_absorb_q( + a: torch.Tensor, b: torch.Tensor, out: torch.Tensor +) -> None: + 
"""Reference PyTorch implementation for concat_mla_absorb_q.""" + # a: [dim_0, dim_1, a_last_dim] + # b: [dim_0, dim_1, b_last_dim] + # out: [dim_0, dim_1, a_last_dim + b_last_dim] + a_last_dim = a.shape[-1] + out[:, :, :a_last_dim] = a + out[:, :, a_last_dim:] = b + + +def sgl_kernel_concat_mla_k( + k: torch.Tensor, k_nope: torch.Tensor, k_rope: torch.Tensor +) -> None: + """AOT compiled sgl_kernel implementation.""" + from sgl_kernel import concat_mla_k + + concat_mla_k(k, k_nope, k_rope) + + +def sgl_kernel_concat_mla_absorb_q( + a: torch.Tensor, b: torch.Tensor, out: torch.Tensor +) -> None: + """AOT compiled sgl_kernel implementation.""" + from sgl_kernel import concat_mla_absorb_q + + result = concat_mla_absorb_q(a, b) # AOT returns output + out.copy_(result) # Copy to provided tensor for comparison + + +def jit_concat_mla_k( + k: torch.Tensor, k_nope: torch.Tensor, k_rope: torch.Tensor +) -> None: + """JIT compiled implementation.""" + from sglang.jit_kernel.concat_mla import concat_mla_k + + concat_mla_k(k, k_nope, k_rope) + + +def jit_concat_mla_absorb_q( + a: torch.Tensor, b: torch.Tensor, out: torch.Tensor +) -> None: + """JIT compiled implementation - wrapper for test compatibility.""" + from sglang.jit_kernel.concat_mla import concat_mla_absorb_q + + result = concat_mla_absorb_q(a, b) + out.copy_(result) + + +# Constants matching the kernel +NUM_LOCAL_HEADS = 128 +QK_NOPE_HEAD_DIM = 128 +QK_ROPE_HEAD_DIM = 64 +K_HEAD_DIM = QK_NOPE_HEAD_DIM + QK_ROPE_HEAD_DIM + +A_LAST_DIM = 512 +B_LAST_DIM = 64 +OUT_LAST_DIM = A_LAST_DIM + B_LAST_DIM + +DEVICE = "cuda" +DTYPE = torch.bfloat16 + +# Test configurations +NUM_TOKENS_LIST = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024] + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS_LIST) +def test_concat_mla_k_jit_vs_torch(num_tokens: int) -> None: + """Test JIT kernel against PyTorch reference.""" + k_jit = torch.empty( + num_tokens, NUM_LOCAL_HEADS, K_HEAD_DIM, device=DEVICE, dtype=DTYPE + ) + k_torch = 
torch.empty( + num_tokens, NUM_LOCAL_HEADS, K_HEAD_DIM, device=DEVICE, dtype=DTYPE + ) + + k_nope = torch.randn( + num_tokens, NUM_LOCAL_HEADS, QK_NOPE_HEAD_DIM, device=DEVICE, dtype=DTYPE + ) + k_rope = torch.randn(num_tokens, 1, QK_ROPE_HEAD_DIM, device=DEVICE, dtype=DTYPE) + + torch_concat_mla_k(k_torch, k_nope, k_rope) + jit_concat_mla_k(k_jit, k_nope, k_rope) + + triton.testing.assert_close(k_jit, k_torch, atol=0, rtol=0) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS_LIST) +def test_concat_mla_k_jit_vs_aot(num_tokens: int) -> None: + """Test JIT kernel against AOT kernel for bitwise equivalence.""" + k_jit = torch.empty( + num_tokens, NUM_LOCAL_HEADS, K_HEAD_DIM, device=DEVICE, dtype=DTYPE + ) + k_aot = torch.empty( + num_tokens, NUM_LOCAL_HEADS, K_HEAD_DIM, device=DEVICE, dtype=DTYPE + ) + + k_nope = torch.randn( + num_tokens, NUM_LOCAL_HEADS, QK_NOPE_HEAD_DIM, device=DEVICE, dtype=DTYPE + ) + k_rope = torch.randn(num_tokens, 1, QK_ROPE_HEAD_DIM, device=DEVICE, dtype=DTYPE) + + sgl_kernel_concat_mla_k(k_aot, k_nope, k_rope) + jit_concat_mla_k(k_jit, k_nope, k_rope) + + triton.testing.assert_close(k_jit, k_aot, atol=0, rtol=0) + + +DIM_0_LIST = [1, 2, 4, 8, 16, 32] +DIM_1_LIST = [1, 2, 4, 8, 16, 128] + + +@pytest.mark.parametrize( + "dim_0,dim_1", + list(itertools.product(DIM_0_LIST, DIM_1_LIST)), +) +def test_concat_mla_absorb_q_jit_vs_torch(dim_0: int, dim_1: int) -> None: + """Test JIT kernel against PyTorch reference.""" + a = torch.randn(dim_0, dim_1, A_LAST_DIM, device=DEVICE, dtype=DTYPE) + b = torch.randn(dim_0, dim_1, B_LAST_DIM, device=DEVICE, dtype=DTYPE) + out_jit = torch.empty(dim_0, dim_1, OUT_LAST_DIM, device=DEVICE, dtype=DTYPE) + out_torch = torch.empty(dim_0, dim_1, OUT_LAST_DIM, device=DEVICE, dtype=DTYPE) + + torch_concat_mla_absorb_q(a, b, out_torch) + jit_concat_mla_absorb_q(a, b, out_jit) + + triton.testing.assert_close(out_jit, out_torch, atol=0, rtol=0) + + +@pytest.mark.parametrize( + "dim_0,dim_1", + 
list(itertools.product(DIM_0_LIST, DIM_1_LIST)), +) +def test_concat_mla_absorb_q_jit_vs_aot(dim_0: int, dim_1: int) -> None: + """Test JIT kernel against AOT kernel for bitwise equivalence.""" + a = torch.randn(dim_0, dim_1, A_LAST_DIM, device=DEVICE, dtype=DTYPE) + b = torch.randn(dim_0, dim_1, B_LAST_DIM, device=DEVICE, dtype=DTYPE) + out_jit = torch.empty(dim_0, dim_1, OUT_LAST_DIM, device=DEVICE, dtype=DTYPE) + out_aot = torch.empty(dim_0, dim_1, OUT_LAST_DIM, device=DEVICE, dtype=DTYPE) + + sgl_kernel_concat_mla_absorb_q(a, b, out_aot) + jit_concat_mla_absorb_q(a, b, out_jit) + + triton.testing.assert_close(out_jit, out_aot, atol=0, rtol=0) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/sglang/python/sglang/jit_kernel/tests/test_cutedsl_gdn.py b/sglang/python/sglang/jit_kernel/tests/test_cutedsl_gdn.py new file mode 100644 index 0000000000000000000000000000000000000000..4139f24ae94c532debe9f49af2c96c84ae4bb74f --- /dev/null +++ b/sglang/python/sglang/jit_kernel/tests/test_cutedsl_gdn.py @@ -0,0 +1,299 @@ +"""Tests for CuTe DSL fused sigmoid gating delta rule kernel (GDN).""" + +import numpy as np +import pytest +import torch + +try: + import cuda.bindings.driver as cuda_driver + import cutlass # noqa: F401 + from cutlass.cute.runtime import from_dlpack + + from sglang.jit_kernel import cutedsl_gdn + + CUTEDSL_AVAILABLE = True +except ImportError: + CUTEDSL_AVAILABLE = False + cutedsl_gdn = None + +try: + from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import ( + fused_sigmoid_gating_delta_rule_update, + ) + + TRITON_AVAILABLE = True +except ImportError: + TRITON_AVAILABLE = False + + +def run_triton_kernel(A_log, dt_bias, q, k, v, a, b, initial_state, indices, scale): + return fused_sigmoid_gating_delta_rule_update( + A_log=A_log, + a=a, + dt_bias=dt_bias, + softplus_beta=1.0, + softplus_threshold=20.0, + q=q, + k=k, + v=v, + b=b, + initial_state_source=initial_state, + initial_state_indices=indices, + 
scale=scale, + use_qk_l2norm_in_kernel=True, + cu_seqlens=None, + ) + + +@pytest.mark.skipif(not CUTEDSL_AVAILABLE, reason="CuTe DSL not available") +@pytest.mark.skipif(not TRITON_AVAILABLE, reason="Triton kernel not available") +@pytest.mark.parametrize("B", [16, 128]) +def test_cutedsl_gdn_precision(B: int): + """Test precision of CuTe DSL GDN kernel against Triton reference.""" + torch.manual_seed(2025) + T, H, K, V, HV = 1, 16, 128, 128, 32 + scale = K**-0.5 + + A_log = torch.randn(HV, dtype=torch.float32, device="cuda") + dt_bias = torch.randn(HV, dtype=torch.bfloat16, device="cuda") + a = torch.randn(B, T, HV, dtype=torch.bfloat16, device="cuda") + b = torch.randn(B, T, HV, dtype=torch.bfloat16, device="cuda") + q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device="cuda") + k = torch.randn(B, T, H, K, dtype=torch.bfloat16, device="cuda") + v = torch.randn(B, T, HV, V, dtype=torch.bfloat16, device="cuda") + indices = torch.arange(B, dtype=torch.int32, device="cuda") + state_cutedsl = torch.randn(B, HV, K, V, dtype=torch.float32, device="cuda") + state_triton = state_cutedsl.clone().reshape(-1).contiguous() + + # Warmup compilation + _ = cutedsl_gdn.cutedsl_fused_sigmoid_gating_delta_rule_update( + A_log, dt_bias, q, k, v, a, b, state_cutedsl.clone(), indices, scale=scale + ) + torch.cuda.synchronize() + + # Fresh state for actual test + state_cutedsl = torch.randn(B, HV, K, V, dtype=torch.float32, device="cuda") + state_triton = state_cutedsl.clone().reshape(-1).contiguous() + + out_cutedsl = cutedsl_gdn.cutedsl_fused_sigmoid_gating_delta_rule_update( + A_log, dt_bias, q, k, v, a, b, state_cutedsl, indices, scale=scale + ) + out_triton = run_triton_kernel( + A_log, dt_bias, q, k, v, a, b, state_triton, indices, scale + ) + + # Check precision: diff > 0.1 must be < 1% of elements + abs_diff = (out_triton.float() - out_cutedsl.float()).abs() + max_diff = abs_diff.max().item() + mean_diff = abs_diff.mean().item() + fail_rate = (abs_diff > 
0.1).float().mean().item() * 100 + has_nan = torch.isnan(out_cutedsl).any() or torch.isinf(out_cutedsl).any() + + kernel_type = "SmallBatch" if B < 32 else "LargeBatch" + print( + f"\n B={B} ({kernel_type}): max_diff={max_diff:.2e}, mean_diff={mean_diff:.2e}, fail_rate={fail_rate:.2f}%" + ) + + assert not has_nan, "Output contains NaN/Inf" + assert fail_rate < 1.0, f"Fail rate {fail_rate:.2f}% >= 1%" + + +@pytest.mark.skipif( + True, + reason="Skip the performance test because the speedup ratio is highly unstable in the CI environment. ", +) +@pytest.mark.skipif(not CUTEDSL_AVAILABLE, reason="CuTe DSL not available") +@pytest.mark.skipif(not TRITON_AVAILABLE, reason="Triton kernel not available") +@pytest.mark.parametrize("B", [1, 128]) +def test_cutedsl_gdn_performance(B: int): + """Benchmark CuTe DSL GDN kernel against Triton reference.""" + torch.manual_seed(2025) + T, H, K, V, HV = 1, 16, 128, 128, 32 + N = B + scale = K**-0.5 + is_varlen = True + warmup, bench_iters, run_iters = 10, 100, 10 + + A_log = torch.randn(HV, dtype=torch.float32, device="cuda") + dt_bias = torch.randn(HV, dtype=torch.bfloat16, device="cuda") + indices = torch.arange(N, dtype=torch.int32, device="cuda") + state_cutedsl = torch.randn(N, HV, K, V, dtype=torch.float32, device="cuda") + state_triton = state_cutedsl.reshape(-1).contiguous() + cu_seqlens = torch.zeros(N + 1, dtype=torch.int32, device="cuda") + o_cutedsl = torch.zeros(1, N, HV, V, dtype=torch.bfloat16, device="cuda") + + # Prepare tensors for multiple runs + q_list, k_list, v_list, a_list, b_list = [], [], [], [], [] + q_tensor_list, k_tensor_list, v_tensor_list, a_tensor_list, b_tensor_list = ( + [], + [], + [], + [], + [], + ) + q_triton, k_triton, v_triton, a_triton, b_triton = [], [], [], [], [] + + for ri in range(run_iters): + torch.manual_seed(2025 + ri) + q_i = torch.randn(1, N, H, K, dtype=torch.bfloat16, device="cuda") + k_i = torch.randn(1, N, H, K, dtype=torch.bfloat16, device="cuda") + v_i = torch.randn(1, N, HV, 
V, dtype=torch.bfloat16, device="cuda") + a_i = torch.randn(N, HV, dtype=torch.bfloat16, device="cuda") + b_i = torch.randn(N, HV, dtype=torch.bfloat16, device="cuda") + + q_list.append(q_i) + k_list.append(k_i) + v_list.append(v_i) + a_list.append(a_i) + b_list.append(b_i) + q_tensor_list.append(from_dlpack(q_i, assumed_align=16)) + k_tensor_list.append(from_dlpack(k_i, assumed_align=16)) + v_tensor_list.append(from_dlpack(v_i, assumed_align=16)) + a_tensor_list.append(from_dlpack(a_i, assumed_align=16)) + b_tensor_list.append(from_dlpack(b_i, assumed_align=16)) + q_triton.append(q_i.transpose(0, 1).contiguous()) + k_triton.append(k_i.transpose(0, 1).contiguous()) + v_triton.append(v_i.transpose(0, 1).contiguous()) + a_triton.append(a_i.unsqueeze(1).contiguous()) + b_triton.append(b_i.unsqueeze(1).contiguous()) + + A_log_t = from_dlpack(A_log, assumed_align=16) + dt_bias_t = from_dlpack(dt_bias, assumed_align=16) + h0_t = from_dlpack(state_cutedsl, assumed_align=16) + idx_t = from_dlpack(indices, assumed_align=16) + o_t = from_dlpack(o_cutedsl, assumed_align=16) + cu_t = from_dlpack(cu_seqlens, assumed_align=16) + + torch_stream = torch.cuda.Stream() + stream = cuda_driver.CUstream(torch_stream.cuda_stream) + + # Compile kernels + compiled = cutedsl_gdn._get_compiled_kernel(N, H, HV, K, V, N, N < 32, is_varlen) + torch.cuda.synchronize() + + for ri in range(run_iters): + _ = run_triton_kernel( + A_log, + dt_bias, + q_triton[ri], + k_triton[ri], + v_triton[ri], + a_triton[ri], + b_triton[ri], + state_triton, + indices, + scale, + ) + torch.cuda.synchronize() + + def run_cutedsl(): + for ri in range(run_iters): + compiled( + cu_t, + q_tensor_list[ri], + k_tensor_list[ri], + v_tensor_list[ri], + a_tensor_list[ri], + b_tensor_list[ri], + A_log_t, + dt_bias_t, + h0_t, + idx_t, + o_t, + stream, + ) + + def run_triton(): + for ri in range(run_iters): + _ = run_triton_kernel( + A_log, + dt_bias, + q_triton[ri], + k_triton[ri], + v_triton[ri], + a_triton[ri], + 
b_triton[ri], + state_triton, + indices, + scale, + ) + + # Warmup + with torch.cuda.stream(torch_stream): + run_cutedsl() + torch.cuda.synchronize() + run_triton() + torch.cuda.synchronize() + + # Capture CUDA graphs + graph_triton = torch.cuda.CUDAGraph() + graph_cutedsl = torch.cuda.CUDAGraph() + try: + with torch.cuda.graph(graph_triton): + run_triton() + with torch.cuda.graph(graph_cutedsl, stream=torch_stream): + run_cutedsl() + torch.cuda.synchronize() + except Exception: + graph_triton = graph_cutedsl = None + + # Warmup with graphs + for _ in range(warmup): + if graph_cutedsl: + graph_cutedsl.replay() + else: + with torch.cuda.stream(torch_stream): + run_cutedsl() + torch.cuda.synchronize() + + if graph_triton: + graph_triton.replay() + else: + run_triton() + torch.cuda.synchronize() + + # Benchmark + triton_times, cutedsl_times = [], [] + for _ in range(bench_iters): + start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event( + enable_timing=True + ) + start.record() + if graph_triton: + graph_triton.replay() + else: + run_triton() + end.record() + torch.cuda.synchronize() + triton_times.append(start.elapsed_time(end)) + + start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event( + enable_timing=True + ) + with torch.cuda.stream(torch_stream): + start.record() + if graph_cutedsl: + graph_cutedsl.replay() + else: + run_cutedsl() + end.record() + torch.cuda.synchronize() + cutedsl_times.append(start.elapsed_time(end)) + + triton_mean = np.mean(triton_times) / run_iters * 1000 + triton_std = np.std(triton_times) / run_iters * 1000 + cutedsl_mean = np.mean(cutedsl_times) / run_iters * 1000 + cutedsl_std = np.std(cutedsl_times) / run_iters * 1000 + speedup = triton_mean / cutedsl_mean + + kernel_type = "SmallBatch" if B < 32 else "LargeBatch" + print( + f"\n B={B} ({kernel_type}): Triton={triton_mean:.2f}±{triton_std:.2f}μs, CuTeDSL={cutedsl_mean:.2f}±{cutedsl_std:.2f}μs, speedup={speedup:.2f}x" + ) + + min_speedup = 1.0 if B < 32 else 
1.15 + assert speedup >= min_speedup, f"Speedup {speedup:.2f}x < {min_speedup}x for B={B}" + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/sglang/python/sglang/jit_kernel/tests/test_flash_attention_4.py b/sglang/python/sglang/jit_kernel/tests/test_flash_attention_4.py new file mode 100644 index 0000000000000000000000000000000000000000..1540d4601b1faf1bb45668a8c38a0c8747613a79 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/tests/test_flash_attention_4.py @@ -0,0 +1,1504 @@ +# Adapted from https://github.com/Dao-AILab/flash-attention/blob/8ecf128f683266735ba68e3c106ff67a2611886e/tests/cute/test_flash_attn.py + +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + +import itertools +import math + +import pytest +import torch +import torch.nn.functional as F +from einops import rearrange, repeat + +from sglang.jit_kernel.flash_attention_v4 import flash_attn_varlen_func + +# Skip this test on Hopper machine +skip_condition = torch.cuda.get_device_capability() < (10, 0) + + +def apply_rotary_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: torch.Tensor | int | None = 0, + interleaved: bool = False, +) -> torch.Tensor: + rotary_dim = cos.shape[-1] * 2 + x_rot = x[..., :rotary_dim] + x_pass = x[..., rotary_dim:] + + cos = cos.to(dtype=x.dtype) + sin = sin.to(dtype=x.dtype) + + if x_rot.dim() < 2: + raise ValueError(f"apply_rotary_emb expects x.dim() >= 2, got {x_rot.dim()}") + + b = x_rot.shape[0] + s = x_rot.shape[1] + + if seqlen_offsets is None: + seqlen_offsets = 0 + + if isinstance(seqlen_offsets, int): + positions = ( + torch.arange(s, device=x_rot.device, dtype=torch.long) + seqlen_offsets + ) + cos_s = cos.index_select(0, positions) + sin_s = sin.index_select(0, positions) + cos_s = cos_s.unsqueeze(0).expand(b, -1, -1) + sin_s = sin_s.unsqueeze(0).expand(b, -1, -1) + else: + if seqlen_offsets.dim() != 1 or seqlen_offsets.shape[0] != b: + 
raise ValueError( + "apply_rotary_emb expects seqlen_offsets to be int or shape [batch]" + ) + positions = torch.arange(s, device=x_rot.device, dtype=torch.long).unsqueeze( + 0 + ) + seqlen_offsets.to(dtype=torch.long).unsqueeze(1) + cos_s = cos.index_select(0, positions.reshape(-1)).view(b, s, -1) + sin_s = sin.index_select(0, positions.reshape(-1)).view(b, s, -1) + + x_rot = x_rot.reshape(b, s, -1, rotary_dim) + cos_s = cos_s.unsqueeze(2) + sin_s = sin_s.unsqueeze(2) + + if interleaved: + x1 = x_rot[..., ::2] + x2 = x_rot[..., 1::2] + o1 = x1 * cos_s - x2 * sin_s + o2 = x2 * cos_s + x1 * sin_s + x_rot = torch.stack((o1, o2), dim=-1).flatten(-2) + else: + x1, x2 = torch.chunk(x_rot, 2, dim=-1) + o1 = x1 * cos_s - x2 * sin_s + o2 = x2 * cos_s + x1 * sin_s + x_rot = torch.cat((o1, o2), dim=-1) + + x_rot = x_rot.reshape_as(x[..., :rotary_dim]) + return torch.cat((x_rot, x_pass), dim=-1) + + +def unpad_input(hidden_states, attention_mask, unused_mask=None): + """ + Arguments: + hidden_states: (batch, seqlen, ...) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. + Return: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. + indices: (total_nnz), the indices of masked tokens from the flattened input sequence. + cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. + max_seqlen_in_batch: int + seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask. 
+ """ + all_masks = ( + (attention_mask + unused_mask) if unused_mask is not None else attention_mask + ) + seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) + used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the + # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim + # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to + # index with integer indices. + return ( + rearrange(hidden_states, "b s ... -> (b s) ...")[indices], + indices, + cu_seqlens, + max_seqlen_in_batch, + used_seqlens_in_batch, + ) + + +def pad_input(hidden_states, indices, batch, seqlen): + """ + Arguments: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask. + indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. + batch: int, batch size for the padded sequence. + seqlen: int, maximum sequence length for the padded sequence. + Return: + hidden_states: (batch, seqlen, ...) + """ + dim = hidden_states.shape[1:] + output = torch.zeros( + (batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype + ) + output[indices] = hidden_states + return rearrange(output, "(b s) ... 
-> b s ...", b=batch) + + +def generate_random_padding_mask( + max_seqlen, batch_size, device, mode="random", zero_lengths=False +): + assert mode in ["full", "random", "third"] + if mode == "full": + lengths = torch.full( + (batch_size, 1), max_seqlen, device=device, dtype=torch.int32 + ) + elif mode == "random": + lengths = torch.randint( + max(0 if zero_lengths else 1, max_seqlen - 20), + max_seqlen + 1, + (batch_size, 1), + device=device, + ) + elif mode == "third": + lengths = torch.randint( + max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device + ) + else: + # This should never happen due to the assertion above, but for linter + lengths = torch.full( + (batch_size, 1), max_seqlen, device=device, dtype=torch.int32 + ) + + if zero_lengths: + # Generate zero-lengths every 5 batches and the last batch. + for i in range(batch_size): + if i % 5 == 0: + lengths[i] = 0 + lengths[-1] = 0 + padding_mask = ( + repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) + < lengths + ) + return padding_mask + + +def generate_qkv( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + qv=None, + kvpacked=False, + qkvpacked=False, + query_unused_mask=None, + key_unused_mask=None, +): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, d) + k: (batch_size, seqlen_k, nheads_k, d) + v: (batch_size, seqlen_k, nheads_k, d_v) + query_padding_mask: (batch_size, seqlen), bool + key_padding_mask: (batch_size, seqlen), bool + """ + assert not (kvpacked and qkvpacked) + batch_size, seqlen_q, nheads, d = q.shape + d_v = v.shape[-1] + _, seqlen_k, nheads_k, _ = k.shape + assert k.shape == (batch_size, seqlen_k, nheads_k, d) + assert v.shape == (batch_size, seqlen_k, nheads_k, d_v) + if query_unused_mask is not None or key_unused_mask is not None: + assert not kvpacked + assert not qkvpacked + + if query_padding_mask is not None: + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input( + q, query_padding_mask, query_unused_mask 
+ ) + output_pad_fn = lambda output_unpad: pad_input( + output_unpad, indices_q, batch_size, seqlen_q + ) + qv_unpad = ( + rearrange(qv, "b s ... -> (b s) ...")[indices_q] if qv is not None else None + ) + else: + q_unpad = rearrange(q, "b s h d -> (b s) h d") + cu_seqlens_q = torch.arange( + 0, + (batch_size + 1) * seqlen_q, + step=seqlen_q, + dtype=torch.int32, + device=q_unpad.device, + ) + seqused_q = None + max_seqlen_q = seqlen_q + output_pad_fn = lambda output_unpad: rearrange( + output_unpad, "(b s) h d -> b s h d", b=batch_size + ) + qv_unpad = rearrange(qv, "b s ... -> (b s) ...") if qv is not None else None + + if key_padding_mask is not None: + k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input( + k, key_padding_mask, key_unused_mask + ) + v_unpad, *rest = unpad_input(v, key_padding_mask, key_unused_mask) + else: + k_unpad = rearrange(k, "b s h d -> (b s) h d") + v_unpad = rearrange(v, "b s h d -> (b s) h d") + cu_seqlens_k = torch.arange( + 0, + (batch_size + 1) * seqlen_k, + step=seqlen_k, + dtype=torch.int32, + device=k_unpad.device, + ) + seqused_k = None + max_seqlen_k = seqlen_k + + if qkvpacked: + assert (query_padding_mask == key_padding_mask).all() + assert nheads == nheads_k + qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) + qkv = torch.stack([q, k, v], dim=2) + if query_padding_mask is not None: + dqkv_pad_fn = lambda dqkv_unpad: pad_input( + dqkv_unpad, indices_q, batch_size, seqlen_q + ) + else: + dqkv_pad_fn = lambda dqkv_unpad: rearrange( + dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size + ) + return ( + qkv_unpad.detach().requires_grad_(), + cu_seqlens_q, + max_seqlen_q, + qkv.detach().requires_grad_(), + output_pad_fn, + dqkv_pad_fn, + ) + elif kvpacked: + kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) + kv = torch.stack([k, v], dim=2) + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + dkv_pad_fn = lambda dkv_unpad: pad_input( + dkv_unpad, indices_k, batch_size, seqlen_k + ) + 
else: + dkv_pad_fn = lambda dkv_unpad: rearrange( + dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size + ) + return ( + q_unpad.detach().requires_grad_(), + kv_unpad.detach().requires_grad_(), + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + kv.detach().requires_grad_(), + output_pad_fn, + dq_pad_fn, + dkv_pad_fn, + ) + else: + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + dk_pad_fn = lambda dk_unpad: pad_input( + dk_unpad, indices_k, batch_size, seqlen_k + ) + else: + dk_pad_fn = lambda dk_unpad: rearrange( + dk_unpad, "(b s) h d -> b s h d", b=batch_size + ) + return ( + q_unpad.detach().requires_grad_(), + k_unpad.detach().requires_grad_(), + v_unpad.detach().requires_grad_(), + qv_unpad.detach() if qv is not None else None, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + k.detach().requires_grad_(), + v.detach().requires_grad_(), + qv.detach() if qv is not None else None, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) + + +def construct_local_mask( + seqlen_q, + seqlen_k, + window_size=(None, None), + sink_token_length=0, + query_padding_mask=None, + key_padding_mask=None, + key_leftpad=None, + device=None, +): + row_idx = rearrange( + torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1" + ) + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + if key_leftpad is not None: + key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") + col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) + col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) + sk = ( + seqlen_k + if key_padding_mask is None + else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sq = ( + seqlen_q + if query_padding_mask is None + else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + ) + if window_size[0] is None: + return col_idx > row_idx + sk - sq + window_size[1] + else: + sk = 
torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + return torch.logical_or( + col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), + torch.logical_and( + col_idx < row_idx + sk - sq - window_size[0], + col_idx >= sink_token_length, + ), + ) + + +def construct_chunk_mask( + seqlen_q, + seqlen_k, + attention_chunk, + query_padding_mask=None, + key_padding_mask=None, + key_leftpad=None, + device=None, +): + row_idx = rearrange( + torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1" + ) + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + if key_leftpad is not None: + key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") + col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) + col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) + sk = ( + seqlen_k + if key_padding_mask is None + else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sq = ( + seqlen_q + if query_padding_mask is None + else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + # Subtract remainder instead of divide and then multiply to take care of negative values + col_limit_left_chunk = row_idx + sk - sq - (row_idx + sk - sq) % attention_chunk + return torch.logical_or( + col_idx < col_limit_left_chunk, + col_idx >= col_limit_left_chunk + attention_chunk, + ) + + +def attention_ref( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + key_leftpad=None, + attn_bias=None, + dropout_p=0.0, + dropout_mask=None, + causal=False, + qv=None, + q_descale=None, + k_descale=None, + v_descale=None, + window_size=(None, None), + attention_chunk=0, + sink_token_length=0, + learnable_sink=None, + softcap=0.0, + upcast=True, + reorder_ops=False, + intermediate_dtype=None, +): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, head_dim) + k: (batch_size, seqlen_k, nheads, head_dim) + v: (batch_size, seqlen_k, 
nheads, head_dim_v) + qv: (batch_size, seqlen_q, nheads, head_dim_v) + query_padding_mask: (batch_size, seqlen_q) + key_padding_mask: (batch_size, seqlen_k) + attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) + dropout_p: float + dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) + causal: whether to apply causal masking + upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast + output back to fp16/bf16. + reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.) + without changing the math. This is to estimate the numerical error from operation + reordering. + Output: + output: (batch_size, seqlen_q, nheads, head_dim_v) + attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout + """ + if causal: + window_size = (window_size[0], 0) + dtype_og = q.dtype + if upcast: + q, k, v = q.float(), k.float(), v.float() + qv = qv.float() if qv is not None else None + if q_descale is not None: + q_descale = repeat(q_descale, "b h -> b 1 (h g) 1", g=q.shape[2] // k.shape[2]) + q = (q.float() * q_descale).to(q.dtype) + qv = (qv.float() * q_descale).to(qv.dtype) if qv is not None else None + if k_descale is not None: + k = (k.float() * rearrange(k_descale, "b h -> b 1 h 1")).to(dtype=k.dtype) + if v_descale is not None: + v = (v.float() * rearrange(v_descale, "b h -> b 1 h 1")).to(dtype=v.dtype) + seqlen_q, seqlen_k = q.shape[1], k.shape[1] + k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) + v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) + d = q.shape[-1] + dv = v.shape[-1] + softmax_scale = 1.0 / math.sqrt(d if qv is None else d + dv) + if not reorder_ops: + scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k) + else: + scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) + if qv is not None: + scores = scores + torch.einsum("bthd,bshd->bhts", qv * softmax_scale, v) + if softcap > 0: + scores = torch.tanh(scores / 
softcap) * softcap + if key_padding_mask is not None: + scores.masked_fill_( + rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf") + ) + local_mask = None + if window_size[0] is not None or window_size[1] is not None: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + sink_token_length, + query_padding_mask, + key_padding_mask, + key_leftpad=key_leftpad, + device=q.device, + ) + if attention_chunk > 0: + chunk_mask = construct_chunk_mask( + seqlen_q, + seqlen_k, + attention_chunk, + query_padding_mask, + key_padding_mask, + key_leftpad=key_leftpad, + device=q.device, + ) + local_mask = ( + torch.logical_or(local_mask, chunk_mask) + if local_mask is not None + else chunk_mask + ) + if local_mask is not None: + scores.masked_fill_(local_mask, float("-inf")) + if attn_bias is not None: + scores = scores + attn_bias + if learnable_sink is None: + attention = torch.softmax(scores, dim=-1).to(v.dtype) + else: + scores_fp32 = scores.to(torch.float32) + logits_max = torch.amax(scores_fp32, dim=-1, keepdim=True) + learnable_sink = rearrange(learnable_sink, "h -> h 1 1") + logits_or_sinks_max = torch.maximum(learnable_sink, logits_max) + unnormalized_scores = torch.exp(scores_fp32 - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + torch.exp( + learnable_sink - logits_or_sinks_max + ) + attention = (unnormalized_scores / normalizer).to(v.dtype) + # We want to mask here so that the attention matrix doesn't have any NaNs + # Otherwise we'll get NaN in dV + if query_padding_mask is not None: + attention = attention.masked_fill( + rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0 + ) + # Without this we might get NaN in dv + if key_padding_mask is not None: + attention = attention.masked_fill( + rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0 + ) + # Some rows might be completely masked out so we fill them with zero instead of NaN + if local_mask is not None: + attention = attention.masked_fill( + 
torch.all(local_mask, dim=-1, keepdim=True), 0.0 + ) + dropout_scaling = 1.0 / (1 - dropout_p) + # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling + # output = torch.einsum('bhts,bshd->bthd', attention_drop , v) + if dropout_mask is not None: + attention_drop = attention.masked_fill(~dropout_mask, 0.0) + else: + attention_drop = attention + if intermediate_dtype is not None: + attention_drop = attention_drop.to(intermediate_dtype).to(attention_drop.dtype) + output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) + if query_padding_mask is not None: + output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) + + +@pytest.mark.skipif( + skip_condition, reason="FA4 Requires compute capability of 10 or above." +) +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mqa"]) +@pytest.mark.parametrize("has_learnable_sink", [False, True]) +# @pytest.mark.parametrize("has_learnable_sink", [False]) +# @pytest.mark.parametrize("has_qv", [False, True]) +@pytest.mark.parametrize("has_qv", [False]) +# @pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("deterministic", [False]) +# @pytest.mark.parametrize("softcap", [0.0, 15.0]) +@pytest.mark.parametrize("softcap", [0.0]) +# @pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) +# @pytest.mark.parametrize("add_unused_qkv", [False, True]) +@pytest.mark.parametrize("add_unused_qkv", [False]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) +# 
@pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) +# @pytest.mark.parametrize("d", [64, 96, 128]) +@pytest.mark.parametrize("d", [64, 128]) +# @pytest.mark.parametrize("d", [192]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + # (1, 1), + # (1, 3), + # (2, 1), + (511, 1), + (3, 513), + (64, 128), + (128, 128), + (256, 256), + # (113, 203), + # (128, 217), + # (113, 211), + # (108, 256), + # (256, 512), + (307, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (2048, 2048), + ], +) +def test_flash_attn_varlen_output( + seqlen_q, + seqlen_k, + d, + add_unused_qkv, + causal, + local, + softcap, + deterministic, + has_qv, + has_learnable_sink, + mha_type, + dtype, +): + if ( + causal or local + ): # Right now we only support causal attention with seqlen_k == seqlen_q + seqlen_k = seqlen_q + device = "cuda" + # set seed + torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) + batch_size = 49 if seqlen_q <= 1024 else 7 + nheads = 6 + # batch_size = 1 + # nheads = 1 + nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1) + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d]) + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k else [0] + attention_chunk_vals = [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + q_ref = torch.randn( + batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref + ) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. 
+ q_ref = (q_ref * softcap / 4).detach().requires_grad_() + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + v_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + if has_qv: + qv_ref = ( + torch.randn( + batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + else: + qv_ref = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = ( + (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + ) + if has_learnable_sink: + learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) + else: + learnable_sink = None + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [ + torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) + * 2 + for _ in range(3) + ] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach() if has_qv else None + query_padding_mask = generate_random_padding_mask( + seqlen_q, batch_size, device, mode="random", zero_lengths=False + ) + # TODO: test zero_lengths + key_padding_mask = generate_random_padding_mask( + # seqlen_k, batch_size, device, mode="random", zero_lengths=True + seqlen_k, + batch_size, + device, + mode="random", + zero_lengths=False, + ) + + def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): + if add_unused: + another_mask = generate_random_padding_mask(max_seq_len, bs, device) + attn_mask = torch.logical_and(padding_mask, another_mask) + unused_mask = torch.logical_xor( + torch.logical_or(padding_mask, another_mask), attn_mask + ) + else: + attn_mask = padding_mask + unused_mask = None + return 
attn_mask, unused_mask + + query_padding_mask, query_unused_mask = _gen_unused_masks( + query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device + ) + # query_padding_mask[:] = True + # query_unused_mask = None + key_padding_mask, key_unused_mask = _gen_unused_masks( + key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device + ) + + if causal or local: + key_padding_mask = query_padding_mask + + result = generate_qkv( + q, + k, + v, + query_padding_mask, + key_padding_mask, + qv=qv, + kvpacked=False, + query_unused_mask=query_unused_mask, + key_unused_mask=key_unused_mask, + ) + ( + q_unpad, # 0 + k_unpad, # 1 + v_unpad, # 2 + qv_unpad, # 3 + cu_seqlens_q, # 4 + cu_seqlens_k, # 5 + seqused_q, # 6 + seqused_k, # 7 + max_seqlen_q, # 8 + max_seqlen_k, # 9 + q, # 10 + k, # 11 + v, # 12 + qv, # 13 + output_pad_fn, # 14 + dq_pad_fn, # 15 + dk_pad_fn, # 16 + ) = result + q_unpad, k_unpad, v_unpad = [ + x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad) + ] + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + learnable_sink=learnable_sink, + softcap=softcap, + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + learnable_sink=learnable_sink, + softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + ) + + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + + if query_unused_mask is not None: + q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 
1") + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + pack_gqa_vals = [False, True, None] + # num_splits_vals = [1, 3] + num_splits_vals = [1] + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + out_unpad, lse = flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + # max_seqlen_q and max_seqlen_k not needed for FA4 + seqused_q=seqused_q, + seqused_k=seqused_k, + causal=causal, + window_size=window_size, + softcap=softcap, + sinks=learnable_sink, # FA4 uses learnable_sink, not sinks + pack_gqa=pack_gqa, + return_softmax_lse=True, + ) + out = output_pad_fn(out_unpad) + if query_unused_mask is not None: + out.masked_fill_(q_zero_masking, 0.0) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most 3x the numerical error + # of a Pytorch implementation. 
+ assert (out - out_ref).abs().max().item() <= rtol * ( + out_pt - out_ref + ).abs().max().item() + fwd_atol + + if ( + dtype != torch.float8_e4m3fn + and not has_qv + and not dv > 256 + and not attention_chunk != 0 + and dv == d + and not has_learnable_sink + and False + ): + g_unpad = torch.randn_like(out_unpad) + do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) + # import flash_attn_3_cuda + # dq_unpad, dk_unpad, dv_unpad, softmax_d, dq_accum, lse_log2 = flash_attn_3_cuda.bwd_varlen( + # g_unpad, + # q_unpad, + # k_unpad, + # v_unpad, + # out_unpad, + # lse, + # None, + # None, + # None, + # cu_seqlens_q, + # cu_seqlens_k, + # None, None, + # max_seqlen_q, + # max_seqlen_k, + # d ** (-0.5), + # causal, + # window_size[0], window_size[1], + # softcap, + # deterministic, + # 0, # sm_margin + # ) + dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad( + out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad + ) + dq = dq_pad_fn(dq_unpad) + dk = dk_pad_fn(dk_unpad) + dv = dk_pad_fn(dv_unpad) + if key_unused_mask is not None: + k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1") + dk.masked_fill_(k_zero_masking, 0.0) + dv.masked_fill_(k_zero_masking, 0.0) + if query_unused_mask is not None: + dq.masked_fill_(q_zero_masking, 0.0) + # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") + # assert (softmax_d - do_o).abs().max().item() <= 1e-5 + # assert dq_accum.abs().max().item() == 0.0 + g = output_pad_fn(g_unpad) + + # qk = torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float() + # qk = torch.masked_fill(qk, rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) + # P = torch.softmax(qk, -1) + # dP = P * (dS - (g.float() * out.float()).sum(-1).transpose(1, 2).unsqueeze(-1)) + # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) + # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) + # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) + + # dq, dk, dv = 
torch.autograd.grad(out, (q, k, v), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad( + out_ref, (q_ref, k_ref, v_ref), g + ) + dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + # breakpoint() + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dq - dq_ref).abs().max().item() <= rtol * ( + dq_pt - dq_ref + ).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dk - dk_ref).abs().max().item() <= rtol * ( + dk_pt - dk_ref + ).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dv - dv_ref).abs().max().item() <= rtol * ( + dv_pt - dv_ref + ).abs().max().item() + dv_atol + + +@pytest.mark.skipif( + skip_condition, reason="FA4 Requires compute capability of 10 or above." 
+) +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +@pytest.mark.parametrize("has_learnable_sink", [False, True]) +# @pytest.mark.parametrize("has_learnable_sink", [False]) +# @pytest.mark.parametrize("new_kv", [False, True]) +@pytest.mark.parametrize("new_kv", [False]) +# @pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("local", [False]) +# @pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("causal", [True]) +# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False]) +@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [False]) +# @pytest.mark.parametrize("has_rotary_seqlens", [False, True]) +@pytest.mark.parametrize("has_rotary_seqlens", [False]) +# @pytest.mark.parametrize("rotary_interleaved", [False, True]) +@pytest.mark.parametrize("rotary_interleaved", [True]) +# @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) +@pytest.mark.parametrize("rotary_fraction", [0.0]) +# @pytest.mark.parametrize("page_size", [None] + ([1, 4, 128])) +# @pytest.mark.parametrize("page_size", [None, 128]) +@pytest.mark.parametrize("page_size", [128]) +# @pytest.mark.parametrize("has_leftpad", [False, True]) +@pytest.mark.parametrize("has_leftpad", [False]) +# @pytest.mark.parametrize("has_batch_idx", [False, True]) +@pytest.mark.parametrize("has_batch_idx", [False]) +# @pytest.mark.parametrize("varlen_q", [False, True]) +@pytest.mark.parametrize("varlen_q", [False]) +# @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize("d", [128]) 
+@pytest.mark.parametrize("d", [64]) +# @pytest.mark.parametrize("d", [192]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 128), + (1, 339), + (3, 1024), + (64, 800), + (64, 256), + (3, 799), + (64, 2048), + (16, 20000), + # # (1, 128 * 1024), + # # (16, 128 * 1024), + # (128, 128), + # (256, 512), # To test appending KV with more than 1 block + # (2048, 3577), # Enough tile to test persistent scheduler + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +def test_flash_attn_kvcache( + seqlen_q, + seqlen_k, + d, + varlen_q, + has_batch_idx, + has_leftpad, + page_size, + rotary_fraction, + rotary_interleaved, + has_rotary_seqlens, + seqlen_new_eq_seqlen_q, + causal, + local, + new_kv, + has_learnable_sink, + mha_type, + dtype, +): + if page_size is not None and seqlen_k % page_size != 0: + pytest.skip() + if seqlen_q > seqlen_k and new_kv: + pytest.skip() + if not new_kv and rotary_fraction > 0.0: + pytest.skip() + if rotary_fraction == 0.0 and has_rotary_seqlens: + pytest.skip() + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 5 + # batch_size = 1 + batch_size_cache = batch_size if not has_batch_idx else batch_size * 2 + nheads = 6 + # nheads = 1 + # rotary_dim must be a multiple of 16, and must be <= d + rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 + nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) + assert nheads % nheads_k == 0 + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + dv_vals = [d] + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if (causal or local) else [0] + attention_chunk_vals = [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + # has_qv = d == 64 and dv >= 256 + has_qv = False + q = ( + torch.randn(batch_size, seqlen_q, nheads, 
d, device=device, dtype=dtype_ref) + .to(dtype) + .to(dtype_ref) + ) + if has_qv: + qv = ( + torch.randn( + batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + else: + qv = None + if varlen_q: + query_padding_mask = generate_random_padding_mask( + seqlen_q, batch_size, device, mode="random" + ) + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input( + q, query_padding_mask + ) + output_pad_fn = lambda output_unpad: pad_input( + output_unpad, indices_q, batch_size, seqlen_q + ) + qv_unpad = ( + rearrange(qv, "b s ... -> (b s) ...")[indices_q] if has_qv else None + ) + else: + query_padding_mask = None + q_unpad = q + qv_unpad = qv + cu_seqlens_q, max_seqlen_q = None, None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = ( + (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + ) + if has_learnable_sink: + learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) + else: + learnable_sink = None + + seqlen_new = ( + seqlen_q + if seqlen_new_eq_seqlen_q + else torch.randint(1, seqlen_q + 1, (1,)).item() + ) + cu_seqlens_k_new = None + key_new_padding_mask = None + if new_kv: + k = ( + torch.randn( + batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + v = ( + torch.randn( + batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + if varlen_q: # k & v are also varlen + key_new_padding_mask = generate_random_padding_mask( + seqlen_new, batch_size, device, mode="random" + ) + k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input( + k, key_new_padding_mask + ) + v_unpad, *rest = unpad_input(v, key_new_padding_mask) + else: + k_unpad, v_unpad = k, v + else: + k, v, k_unpad, v_unpad = None, None, None, None + if page_size is None: + k_cache = ( + torch.randn( + batch_size_cache, + seqlen_k, + nheads_k, + d, + 
device=device, + dtype=dtype_ref, + ) + .to(dtype) + .to(dtype_ref) + ) + v_cache = ( + torch.randn( + batch_size_cache, + seqlen_k, + nheads_k, + dv, + device=device, + dtype=dtype_ref, + ) + .to(dtype) + .to(dtype_ref) + ) + page_table = None + num_blocks = None + else: + ( + k_cache, + v_cache, + page_table, + k_cache_paged, + v_cache_paged, + num_blocks, + ) = _generate_block_kvcache( + seqlen_k, + page_size, + batch_size_cache, + nheads_k, + d, + dv, + device, + dtype, + dtype_ref, + ) + cache_seqlens = torch.randint( + 0 if new_kv else 1, + # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough + ( + ( + seqlen_k + - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + + 1 + ) + if new_kv + else (seqlen_k + 1) + ), + (batch_size,), + dtype=torch.int32, + device=device, + ) + if has_leftpad: + cache_leftpad = torch.cat( + [ + ( + torch.randint( + 0, + cache_seqlens[i].item(), + (1,), + dtype=torch.int32, + device=device, + ) + if cache_seqlens[i].item() > 0 + else torch.zeros(1, dtype=torch.int32, device=device) + ) + for i in range(batch_size) + ] + ) + else: + cache_leftpad = None + if has_batch_idx: + cache_batch_idx = torch.randperm( + batch_size_cache, dtype=torch.int32, device=device + )[:batch_size] + else: + cache_batch_idx = None + arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + if not new_kv: + key_padding_mask = arange < cache_seqlens_expanded + else: + k_new_seqlens = ( + key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new + ) + key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens + if has_leftpad: + key_padding_mask = torch.logical_and( + key_padding_mask, + arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k), + ) + # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) + rotary_seqlens = cache_seqlens if not has_rotary_seqlens else cache_seqlens // 2 + 
if rotary_dim > 0: + angle = ( + torch.rand( + seqlen_k if page_size is None else num_blocks * page_size, + rotary_dim // 2, + device=device, + ) + * 2 + * math.pi + ) + cos = torch.cos(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + if causal or local: + q_ro = apply_rotary_emb( + q, + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=seqlen_q, + ) + # q_ro = q + k_ro = apply_rotary_emb( + k, + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ) + else: + cos, sin = None, None + q_ro, k_ro = q, k + # k_cache[:, 64:] = -1 + k_cache_ref = ( + k_cache if not has_batch_idx else k_cache[cache_batch_idx] + ).clone() + v_cache_ref = ( + v_cache if not has_batch_idx else v_cache[cache_batch_idx] + ).clone() + if new_kv: + update_mask = torch.logical_and( + cache_seqlens_expanded <= arange, + arange < cache_seqlens_expanded + k_new_seqlens, + ) + k_to_update = rearrange(k_ro, "b s ... -> (b s) ...") + v_to_update = rearrange(v, "b s ... 
-> (b s) ...") + if varlen_q: + k_to_update = k_to_update[indices_k] + v_to_update = v_to_update[indices_k] + k_cache_ref[update_mask] = k_to_update + v_cache_ref[update_mask] = v_to_update + k_cache_rep = repeat( + k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k + ) + v_cache_rep = repeat( + v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k + ) + out_ref, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv, + window_size=window_size, + learnable_sink=learnable_sink, + attention_chunk=attention_chunk, + key_leftpad=cache_leftpad, + ) + out_pt, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv, + window_size=window_size, + learnable_sink=learnable_sink, + attention_chunk=attention_chunk, + upcast=False, + reorder_ops=True, + key_leftpad=cache_leftpad, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + ) + q = q.to(dtype) + q_unpad = q_unpad.to(dtype) if varlen_q else None + k_cache = k_cache.to(dtype) + v_cache = v_cache.to(dtype) + k_cache_paged = k_cache_paged.to(dtype) if page_size is not None else None + v_cache_paged = v_cache_paged.to(dtype) if page_size is not None else None + k = k.to(dtype) if k is not None else None + v = v.to(dtype) if v is not None else None + k_unpad = k_unpad.to(dtype) if k_unpad is not None else None + v_unpad = v_unpad.to(dtype) if v_unpad is not None else None + qv = qv.to(dtype) if qv is not None else None + qv_unpad = qv_unpad.to(dtype) if (varlen_q and qv is not None) else None + cos = cos.to(dtype) if cos is not None else None + sin = sin.to(dtype) if sin is not None else None + k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone() + v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone() + # num_splits_vals = [1, 0] + num_splits_vals = [1] + # precompute_metadata_vals = [False, True] + 
precompute_metadata_vals = [False] + for num_splits, precompute_metadata in itertools.product( + num_splits_vals, precompute_metadata_vals + ): + # if precompute_metadata: + # scheduler_metadata = get_scheduler_metadata( + # batch_size, max_seqlen_q if varlen_q else seqlen_q, seqlen_k, nheads, nheads_k, d, + # cache_seqlens, q.dtype, headdim_v=dv, cu_seqlens_q=cu_seqlens_q, + # cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad, + # max_seqlen_k_new=seqlen_new, page_size=page_size, + # causal=causal, window_size=window_size, attention_chunk=attention_chunk, + # num_splits=num_splits + # ) + # else: + # scheduler_metadata = None + scheduler_metadata = None + # Repeat to test metadata reuse + for _ in range(1 if not precompute_metadata else 2): + if page_size is None: + k_cache.copy_(k_cache_saved) + v_cache.copy_(v_cache_saved) + else: + k_cache_paged.copy_(k_cache_saved) + v_cache_paged.copy_(v_cache_saved) + # For FA4, use flash_attn_varlen_func directly instead of flash_attn_with_kvcache + # This matches the pattern from the original FA4 test + out, lse = flash_attn_varlen_func( + q if not varlen_q else q_unpad, + k_cache if page_size is None else k_cache_paged, + v_cache if page_size is None else v_cache_paged, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=None, # FA4 doesn't use cu_seqlens_k for KV cache + # max_seqlen_q and max_seqlen_k not needed for FA4 + seqused_k=cache_seqlens, # Use cache_seqlens as seqused_k + page_table=page_table, + causal=causal, + window_size=window_size, + sinks=learnable_sink, # FA4 uses learnable_sink, not sinks + softcap=0.0, + pack_gqa=None, + return_softmax_lse=True, + ) + if varlen_q: + out = output_pad_fn(out) + # out = flash_attn_with_kvcache( + # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size + # ) + # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) + # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) + # m = qk.amax(-1, 
keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) + # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) + # probs = torch.softmax(qk, dim=-1) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + if new_kv: + if page_size is None: + k_cache_select = ( + k_cache.to(dtype_ref) + if not has_batch_idx + else k_cache.to(dtype_ref)[cache_batch_idx] + ) + v_cache_select = ( + v_cache.to(dtype_ref) + if not has_batch_idx + else v_cache.to(dtype_ref)[cache_batch_idx] + ) + else: + k_cache_select = rearrange( + k_cache_paged.to(dtype_ref)[ + ( + page_table + if not has_batch_idx + else page_table[cache_batch_idx] + ).flatten() + ], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + v_cache_select = rearrange( + v_cache_paged.to(dtype_ref)[ + ( + page_table + if not has_batch_idx + else page_table[cache_batch_idx] + ).flatten() + ], + "(b nblocks) block_size ... 
-> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) + v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) + if dtype is not torch.float8_e4m3fn: + assert torch.equal(v_cache_select, v_cache_ref) + else: + assert torch.allclose( + v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3 + ) + # breakpoint() + # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: + if rotary_dim == 0: + assert torch.equal(k_cache_select, k_cache_ref) + else: + # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): + # breakpoint() + if dtype is not torch.float8_e4m3fn: + assert torch.allclose( + k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3 + ) + else: + assert torch.allclose( + k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1 + ) + mult = 4 if dtype == torch.float8_e4m3fn else 2 + assert (out - out_ref).abs().max().item() <= mult * ( + out_pt - out_ref + ).abs().max().item() + 1e-5 + mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5 + assert (out - out_ref).abs().mean().item() <= mult_mean * ( + out_pt - out_ref + ).abs().mean().item() + + +def _generate_block_kvcache( + seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref +): + num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3 + k_cache_paged = ( + torch.randn(num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref) + .to(dtype) + .to(dtype_ref) + ) + v_cache_paged = ( + torch.randn(num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype_ref) + .to(dtype) + .to(dtype_ref) + ) + page_table = rearrange( + torch.randperm(num_blocks, dtype=torch.int32, device=device), + "(b nblocks) -> b nblocks", + b=batch_size, + ) + k_cache = rearrange( + k_cache_paged[page_table.flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + v_cache = rearrange( + v_cache_paged[page_table.flatten()], + "(b nblocks) block_size ... 
-> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/sglang/python/sglang/jit_kernel/tests/test_fused_add_rmsnorm.py b/sglang/python/sglang/jit_kernel/tests/test_fused_add_rmsnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..52c2dc6128abbb9c9c4c8556e9233875a4ccf88c --- /dev/null +++ b/sglang/python/sglang/jit_kernel/tests/test_fused_add_rmsnorm.py @@ -0,0 +1,55 @@ +import itertools + +import pytest +import torch + + +def sglang_jit_fused_add_rmsnorm( + input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float +) -> None: + from sglang.jit_kernel.norm import fused_add_rmsnorm + + fused_add_rmsnorm(input, residual, weight, eps) + + +def flashinfer_fused_add_rmsnorm( + input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float +) -> None: + from flashinfer.norm import fused_add_rmsnorm + + fused_add_rmsnorm(input, residual, weight, eps=eps) + + +BS_LIST = [2**n for n in range(0, 14)] +BS_LIST += [x + 1 + i for i, x in enumerate(BS_LIST)] +HIDDEN_SIZE_LIST = [512, 1024, 1536, 2048, 3072, 4096, 5120, 6144, 7168, 8192] +DEVICE = "cuda" +DTYPE = torch.bfloat16 + + +@pytest.mark.parametrize( + "batch_size,hidden_size", list(itertools.product(BS_LIST, HIDDEN_SIZE_LIST)) +) +def test_fused_add_rmsnorm(batch_size: int, hidden_size: int) -> None: + input = torch.randn(batch_size, hidden_size, device=DEVICE, dtype=DTYPE) + residual = torch.randn(batch_size, hidden_size, device=DEVICE, dtype=DTYPE) + weight = torch.randn(hidden_size, device=DEVICE, dtype=DTYPE) + + input_sglang = input.clone() + residual_sglang = residual.clone() + input_flashinfer = input.clone() + residual_flashinfer = residual.clone() + sglang_jit_fused_add_rmsnorm( + input_sglang, residual_sglang, weight, torch.finfo(torch.bfloat16).eps + ) + 
flashinfer_fused_add_rmsnorm( + input_flashinfer, residual_flashinfer, weight, torch.finfo(torch.bfloat16).eps + ) + torch.testing.assert_close(input_sglang, input_flashinfer, atol=1e-2, rtol=1e-2) + torch.testing.assert_close( + residual_sglang, residual_flashinfer, atol=1e-2, rtol=1e-2 + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/sglang/python/sglang/jit_kernel/tests/test_fused_metadata_copy.py b/sglang/python/sglang/jit_kernel/tests/test_fused_metadata_copy.py new file mode 100644 index 0000000000000000000000000000000000000000..9cffd8d88108f803351e1cc8417c14a80191fe99 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/tests/test_fused_metadata_copy.py @@ -0,0 +1,1067 @@ +""" +Comprehensive tests for JIT-compiled fused metadata copy kernels. + +This test suite verifies: +1. Single-backend fused kernel (fused_metadata_copy_cuda) - all forward modes +2. Multi-backend fused kernel (fused_metadata_copy_multi_cuda) - 3 backends at once +3. Correctness against reference implementations +4. 
def create_test_metadata(
    bs: int,
    max_len: int,
    max_seqlen_k: int,
    seqlens_expanded_size: int,
    has_real_page_table: bool = False,
    has_flashmla: bool = False,
    device: str = "cuda",
):
    """Build random source tensors and zero-filled destination buffers.

    Every tensor is int32 and is allocated on ``device``. Some destination
    buffers are wider than their sources: ``page_table_1`` has 16 more columns
    than ``page_indices`` and the ``real_page_table`` destination has 8 more
    columns than its source.

    Args:
        bs: Batch size; length of ``cache_seqlens``.
        max_len: Column count of ``page_indices`` and exclusive upper bound of
            the random sequence lengths.
        max_seqlen_k: Exclusive upper bound of the ``seqlens_expanded`` values.
        seqlens_expanded_size: Length of the ``nsa_*`` / expanded tensors.
        has_real_page_table: When true, also create the ``real_page_table`` pair.
        has_flashmla: When true, also create the ``flashmla_*`` pairs.
        device: Torch device string used for every allocation.

    Returns:
        ``{"src": {...}, "dst": {...}}``. Optional entries are ``None`` when the
        corresponding flag is off.
    """
    tensor_opts = {"dtype": torch.int32, "device": device}

    def random_ints(low, high, *shape):
        return torch.randint(low, high, shape, **tensor_opts)

    def zeros(*shape):
        return torch.zeros(shape, **tensor_opts)

    def leading_zero_cumsum(lengths):
        # Element 0 stays 0; the rest is the running total of ``lengths``.
        totals = zeros(lengths.shape[0] + 1)
        totals[1:] = torch.cumsum(lengths, dim=0)
        return totals

    # NOTE: the order of the random draws below is kept stable so a seeded run
    # produces the same tensors.
    src = {}
    src["cache_seqlens"] = random_ints(1, max_len, bs)
    src["cu_seqlens_k"] = leading_zero_cumsum(src["cache_seqlens"])
    src["page_indices"] = random_ints(0, 1000, bs, max_len)
    src["nsa_cache_seqlens"] = random_ints(1, max_len, seqlens_expanded_size)
    src["seqlens_expanded"] = random_ints(1, max_seqlen_k, seqlens_expanded_size)
    src["nsa_cu_seqlens_k"] = leading_zero_cumsum(src["nsa_cache_seqlens"])
    src["real_page_table"] = None
    src["flashmla_num_splits"] = None
    src["flashmla_metadata"] = None

    dst = {
        "cache_seqlens": zeros(bs),
        "cu_seqlens_k": zeros(bs + 1),
        "page_table_1": zeros(bs, max_len + 16),
        "nsa_cache_seqlens": zeros(seqlens_expanded_size),
        "nsa_seqlens_expanded": zeros(seqlens_expanded_size),
        "nsa_cu_seqlens_k": zeros(seqlens_expanded_size + 1),
        "real_page_table": None,
        "flashmla_num_splits": None,
        "flashmla_metadata": None,
    }

    if has_real_page_table:
        page_cols = max_len // 2
        src["real_page_table"] = random_ints(0, 1000, bs, page_cols)
        dst["real_page_table"] = zeros(bs, page_cols + 8)

    if has_flashmla:
        src["flashmla_num_splits"] = random_ints(1, 10, seqlens_expanded_size + 1)
        dst["flashmla_num_splits"] = zeros(seqlens_expanded_size + 1)
        # FlashMLA metadata is typically (num_sm_parts, TileSchedulerMetaDataSize);
        # the tests use a simplified flat size.
        metadata_len = 128
        src["flashmla_metadata"] = random_ints(0, 100, metadata_len)
        dst["flashmla_metadata"] = zeros(metadata_len)

    return {"src": src, "dst": dst}
def reference_copy_target_verify(src, dst, max_seqlen_k, seqlens_expanded_size):
    """Reference implementation: individual .copy_() for TARGET_VERIFY mode.

    Copies each source tensor into the leading region of its destination
    buffer, one ``copy_`` per tensor. Destination tensors are modified in place.

    Args:
        src: Dict of source tensors. Reads ``cache_seqlens``, ``cu_seqlens_k``,
            ``page_indices``, ``seqlens_expanded``, ``nsa_cache_seqlens``,
            ``nsa_cu_seqlens_k`` and the optional ``real_page_table``,
            ``flashmla_num_splits``, ``flashmla_metadata`` (``None`` to skip).
        dst: Dict of destination tensors, written in place.
        max_seqlen_k: Unused; accepted so the signature matches
            ``reference_copy_draft_extend``.
        seqlens_expanded_size: Number of leading entries copied for the
            expanded / ``nsa_*`` tensors.

    Returns:
        None.
    """
    dst["cache_seqlens"].copy_(src["cache_seqlens"])
    # Element 0 of the cumulative tensors is never written by this reference.
    dst["cu_seqlens_k"][1:].copy_(src["cu_seqlens_k"][1:])

    rows, cols = src["page_indices"].shape
    dst["page_table_1"][:rows, :cols].copy_(src["page_indices"])

    expanded = slice(0, seqlens_expanded_size)
    dst["nsa_seqlens_expanded"][expanded].copy_(src["seqlens_expanded"])
    dst["nsa_cache_seqlens"][expanded].copy_(src["nsa_cache_seqlens"])

    cumulative = slice(1, seqlens_expanded_size + 1)
    dst["nsa_cu_seqlens_k"][cumulative].copy_(src["nsa_cu_seqlens_k"][cumulative])

    if src["real_page_table"] is not None:
        rows, cols = src["real_page_table"].shape
        dst["real_page_table"][:rows, :cols].copy_(src["real_page_table"])

    if src["flashmla_num_splits"] is not None:
        flashmla_size = seqlens_expanded_size + 1
        dst["flashmla_num_splits"][:flashmla_size].copy_(
            src["flashmla_num_splits"][:flashmla_size]
        )

    if src["flashmla_metadata"] is not None:
        dst["flashmla_metadata"].copy_(src["flashmla_metadata"])
def test_fused_metadata_copy_dtype_validation():
    """Check that ``fused_metadata_copy_cuda`` rejects a non-int32 tensor.

    Two calls are made and each must raise ``RuntimeError`` with a message
    matching "must have dtype int32": first with an int64 ``cache_seqlens``
    source, then with an int64 ``cache_seqlens`` destination. All other
    tensors are int32. The test is skipped when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")

    from sglang.jit_kernel.fused_metadata_copy import fused_metadata_copy_cuda

    device = "cuda"
    bs, max_len, max_seqlen_k = 2, 128, 256
    seqlens_expanded_size = bs

    def zeros(shape, dtype=torch.int32):
        return torch.zeros(shape, dtype=dtype, device=device)

    def randint(low, high, shape, dtype=torch.int32):
        return torch.randint(low, high, shape, dtype=dtype, device=device)

    # Source with the WRONG dtype (int64 instead of int32).
    int64_cache_seqlens_src = randint(1, max_len, (bs,), dtype=torch.int64)
    remaining_srcs = (
        zeros(bs + 1),  # cu_seqlens_k_src
        randint(0, 1000, (bs, max_len)),  # page_indices_src
        randint(1, max_len, (seqlens_expanded_size,)),  # nsa_cache_seqlens_src
        randint(1, max_seqlen_k, (seqlens_expanded_size,)),  # seqlens_expanded_src
        zeros(seqlens_expanded_size + 1),  # nsa_cu_seqlens_k_src
    )

    # Destinations with the correct dtype.
    int32_cache_seqlens_dst = zeros(bs)
    remaining_dsts = (
        zeros(bs + 1),  # cu_seqlens_k_dst
        zeros((bs, max_len + 16)),  # page_table_1_dst
        zeros(seqlens_expanded_size),  # nsa_cache_seqlens_dst
        zeros(seqlens_expanded_size),  # nsa_seqlens_expanded_dst
        zeros(seqlens_expanded_size + 1),  # nsa_cu_seqlens_k_dst
    )
    # real_page_table, flashmla_num_splits, flashmla_metadata are not supplied.
    absent_optionals = (None, None, None)

    def expect_dtype_error(cache_seqlens_src, cache_seqlens_dst):
        with pytest.raises(RuntimeError, match="must have dtype int32"):
            fused_metadata_copy_cuda(
                cache_seqlens_src,
                *remaining_srcs,
                *absent_optionals,
                cache_seqlens_dst,
                *remaining_dsts,
                *absent_optionals,
                0,  # forward_mode
                bs,
                max_len,
                max_seqlen_k,
                seqlens_expanded_size,
            )

    # Case 1: wrong dtype on a source tensor.
    expect_dtype_error(int64_cache_seqlens_src, int32_cache_seqlens_dst)

    # Case 2: wrong dtype on a destination tensor.
    int32_cache_seqlens_src = randint(1, max_len, (bs,))
    int64_cache_seqlens_dst = zeros(bs, dtype=torch.int64)
    expect_dtype_error(int32_cache_seqlens_src, int64_cache_seqlens_dst)
@pytest.mark.parametrize("bs", [16, 32])
def test_fused_metadata_copy_large_batch(bs):
    """Compare the fused DECODE-mode copy with the reference at larger batch sizes.

    Builds metadata with both optional groups enabled, runs
    ``reference_copy_decode`` and ``fused_metadata_copy_cuda`` on separate
    clones of the destination buffers, and requires every non-``None``
    destination tensor to be equal. Skipped when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")

    from sglang.jit_kernel.fused_metadata_copy import fused_metadata_copy_cuda

    forward_mode = 0  # DECODE
    max_len, max_seqlen_k = 128, 256
    seqlens_expanded_size = bs

    data = create_test_metadata(
        bs=bs,
        max_len=max_len,
        max_seqlen_k=max_seqlen_k,
        seqlens_expanded_size=seqlens_expanded_size,
        has_real_page_table=True,
        has_flashmla=True,
    )
    src = data["src"]

    def fresh_destinations():
        return {
            name: (None if buf is None else buf.clone())
            for name, buf in data["dst"].items()
        }

    dst_ref = fresh_destinations()
    dst_fused = fresh_destinations()

    reference_copy_decode(src, dst_ref, max_len)

    # Positional order expected by the kernel: sources first, then destinations.
    src_order = (
        "cache_seqlens",
        "cu_seqlens_k",
        "page_indices",
        "nsa_cache_seqlens",
        "seqlens_expanded",
        "nsa_cu_seqlens_k",
        "real_page_table",
        "flashmla_num_splits",
        "flashmla_metadata",
    )
    dst_order = (
        "cache_seqlens",
        "cu_seqlens_k",
        "page_table_1",
        "nsa_cache_seqlens",
        "nsa_seqlens_expanded",
        "nsa_cu_seqlens_k",
        "real_page_table",
        "flashmla_num_splits",
        "flashmla_metadata",
    )
    fused_metadata_copy_cuda(
        *[src[name] for name in src_order],
        *[dst_fused[name] for name in dst_order],
        forward_mode,
        bs,
        max_len,
        max_seqlen_k,
        seqlens_expanded_size,
    )

    # Verify all tensors match
    for key, expected in dst_ref.items():
        if expected is None:
            continue
        assert torch.equal(expected, dst_fused[key]), f"{key} mismatch"
cache_seqlens_src = torch.randint( + 1, max_len, (bs,), dtype=torch.int32, device=device + ) + cu_seqlens_k_src = torch.zeros(bs + 1, dtype=torch.int32, device=device) + cu_seqlens_k_src[1:] = torch.cumsum(cache_seqlens_src, dim=0) + + page_indices_src = torch.randint( + 0, 1000, (bs, max_len), dtype=torch.int32, device=device + ) + nsa_cache_seqlens_src = torch.randint( + 1, max_len, (seqlens_expanded_size,), dtype=torch.int32, device=device + ) + nsa_cu_seqlens_k_src = torch.zeros( + seqlens_expanded_size + 1, dtype=torch.int32, device=device + ) + nsa_cu_seqlens_k_src[1:] = torch.cumsum(nsa_cache_seqlens_src, dim=0) + + # Optional tensors + real_page_table_src = None + if has_real_page_table: + real_page_table_cols = max_len // 2 + real_page_table_src = torch.randint( + 0, 1000, (bs, real_page_table_cols), dtype=torch.int32, device=device + ) + + flashmla_num_splits_src = None + flashmla_metadata_src = None + if has_flashmla: + flashmla_num_splits_src = torch.randint( + 1, 10, (seqlens_expanded_size + 1,), dtype=torch.int32, device=device + ) + flashmla_metadata_size = 128 + flashmla_metadata_src = torch.randint( + 0, 100, (flashmla_metadata_size,), dtype=torch.int32, device=device + ) + + # Create destination tensors for 3 backends + def create_dst_tensors(): + cache_seqlens_dst = torch.zeros(bs, dtype=torch.int32, device=device) + cu_seqlens_k_dst = torch.zeros(bs + 1, dtype=torch.int32, device=device) + page_table_1_dst = torch.zeros( + (bs, max_len + 16), dtype=torch.int32, device=device + ) + nsa_cache_seqlens_dst = torch.zeros( + seqlens_expanded_size, dtype=torch.int32, device=device + ) + nsa_cu_seqlens_k_dst = torch.zeros( + seqlens_expanded_size + 1, dtype=torch.int32, device=device + ) + + real_page_table_dst = None + if has_real_page_table: + real_page_table_cols = max_len // 2 + real_page_table_dst = torch.zeros( + (bs, real_page_table_cols + 8), dtype=torch.int32, device=device + ) + + flashmla_num_splits_dst = None + flashmla_metadata_dst = None + 
def reference_copy_for_loop(src, dst_list, bs, max_len):
    """Reference implementation: copy the same sources into each backend in turn.

    For every destination dict in ``dst_list`` this performs one ``copy_`` per
    tensor, modifying the destinations in place.

    Args:
        src: Dict of source tensors. The optional entries ``real_page_table``,
            ``flashmla_num_splits`` and ``flashmla_metadata`` may be ``None``,
            in which case that copy is skipped.
        dst_list: Iterable of destination dicts, one per backend.
        bs: Batch size; bounds the ``nsa_cu_seqlens_k`` and
            ``flashmla_num_splits`` slices.
        max_len: Number of leading ``page_table_1`` columns that are written.
    """

    def copy_into_backend(target):
        # Simulate what init_forward_metadata_replay_cuda_graph_from_precomputed does
        cumulative = slice(1, bs + 1)
        target["cache_seqlens_int32"].copy_(src["cache_seqlens"])
        target["cu_seqlens_k"][1:].copy_(src["cu_seqlens_k"][1:])
        target["page_table_1"][:, :max_len].copy_(src["page_indices"])
        target["nsa_cache_seqlens_int32"].copy_(src["nsa_cache_seqlens"])
        target["nsa_cu_seqlens_k"][cumulative].copy_(src["nsa_cu_seqlens_k"][cumulative])

        page_table = src["real_page_table"]
        if page_table is not None:
            n_rows, n_cols = page_table.shape
            target["real_page_table"][:n_rows, :n_cols].copy_(page_table)

        num_splits = src["flashmla_num_splits"]
        if num_splits is not None:
            split_count = bs + 1
            target["flashmla_num_splits"][:split_count].copy_(num_splits[:split_count])

        tile_metadata = src["flashmla_metadata"]
        if tile_metadata is not None:
            target["flashmla_metadata"].copy_(tile_metadata)

    for backend in dst_list:
        copy_into_backend(backend)
@pytest.mark.parametrize("bs", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("has_real_page_table", [False, True])
@pytest.mark.parametrize("has_flashmla", [False, True])
def test_fused_metadata_copy_multi(bs, has_real_page_table, has_flashmla):
    """Test fused multi-backend metadata copy kernel against for-loop version.

    The same source tensors are copied into three backend destination sets,
    once with ``reference_copy_for_loop`` and once with a single
    ``fused_metadata_copy_multi_cuda`` call, and every destination tensor is
    compared for exact equality. Wall-clock times for both paths are printed
    for information only; they are never asserted on.

    Args:
        bs: Batch size (also used as ``seqlens_expanded_size``).
        has_real_page_table: Include the optional ``real_page_table`` tensors.
        has_flashmla: Include the optional ``flashmla_*`` tensors.

    Raises:
        AssertionError: If any compared tensor differs between the two paths.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")

    from sglang.jit_kernel.fused_metadata_copy import fused_metadata_copy_multi_cuda

    max_len = 128
    seqlens_expanded_size = bs

    # Create test data
    data = create_test_metadata_multi(
        bs=bs,
        max_len=max_len,
        seqlens_expanded_size=seqlens_expanded_size,
        has_real_page_table=has_real_page_table,
        has_flashmla=has_flashmla,
    )
    src = data["src"]

    def clone_backend(name):
        return {
            k: v.clone() if v is not None else None for k, v in data[name].items()
        }

    # Separate destination tensors for the reference (for-loop) and fused paths.
    backend_names = ("dst0", "dst1", "dst2")
    dst_refs = [clone_backend(name) for name in backend_names]
    dst_fused = [clone_backend(name) for name in backend_names]

    # Run reference implementation (for-loop)
    torch.cuda.synchronize()
    loop_start = time.perf_counter()
    reference_copy_for_loop(src, dst_refs, bs, max_len)
    torch.cuda.synchronize()
    loop_time = time.perf_counter() - loop_start

    # Positional layout of the kernel call: 8 sources, then 8 destinations per
    # backend, then the scalar parameters. Built before the timer starts so
    # argument gathering is not part of the measurement.
    src_order = (
        "cache_seqlens",
        "cu_seqlens_k",
        "page_indices",
        "nsa_cache_seqlens",
        "nsa_cu_seqlens_k",
        "real_page_table",
        "flashmla_num_splits",
        "flashmla_metadata",
    )
    dst_order = (
        "cache_seqlens_int32",
        "cu_seqlens_k",
        "page_table_1",
        "nsa_cache_seqlens_int32",
        "nsa_cu_seqlens_k",
        "real_page_table",
        "flashmla_num_splits",
        "flashmla_metadata",
    )
    kernel_args = [src[key] for key in src_order]
    for backend in dst_fused:
        kernel_args.extend(backend[key] for key in dst_order)
    kernel_args.extend((bs, max_len, seqlens_expanded_size))

    # Run fused kernel
    # NOTE(review): there is no warm-up, so the first invocation in a process
    # may include one-time JIT compilation; treat the printed speedup as
    # indicative only.
    torch.cuda.synchronize()
    fused_start = time.perf_counter()
    fused_metadata_copy_multi_cuda(*kernel_args)
    torch.cuda.synchronize()
    fused_time = time.perf_counter() - fused_start

    speedup = loop_time / fused_time if fused_time > 0 else 0
    print(
        f"\n[VERIFY] bs={bs}, real_page_table={has_real_page_table}, flashmla={has_flashmla}"
    )
    print(
        f"[VERIFY] Fused time: {fused_time*1000:.3f}ms, Loop time: {loop_time*1000:.3f}ms, Speedup: {speedup:.2f}x"
    )

    # Compare results for all 3 backends
    always_compared = (
        "cache_seqlens_int32",
        "cu_seqlens_k",
        "page_table_1",
        "nsa_cache_seqlens_int32",
        "nsa_cu_seqlens_k",
    )
    # Optional tensors are compared only when their feature flag is on.
    conditionally_compared = (
        ("real_page_table", has_real_page_table),
        ("flashmla_num_splits", has_flashmla),
        ("flashmla_metadata", has_flashmla),
    )

    max_diff = 0.0
    all_match = True

    for backend_idx, (dst_ref, dst_out) in enumerate(zip(dst_refs, dst_fused)):
        keys = list(always_compared)
        keys.extend(
            key
            for key, enabled in conditionally_compared
            if enabled and dst_ref[key] is not None
        )
        for key in keys:
            if torch.equal(dst_ref[key], dst_out[key]):
                continue
            diff = (dst_ref[key].float() - dst_out[key].float()).abs().max().item()
            max_diff = max(max_diff, diff)
            all_match = False
            print(f"[ERROR] Backend {backend_idx} {key}: MISMATCH! Max diff: {diff}")

    if not all_match:
        error_msg = (
            f"Fused metadata copy verification FAILED! "
            f"Maximum difference: {max_diff}. "
            f"The fused kernel produces different results than the for-loop version."
        )
        print(f"[ERROR] {error_msg}")
        raise AssertionError(error_msg)

    print("[VERIFY] Verification PASSED - all tensors match!")
max_len + ) + fused_metadata_copy_multi_cuda( + data["src"]["cache_seqlens"], + data["src"]["cu_seqlens_k"], + data["src"]["page_indices"], + data["src"]["nsa_cache_seqlens"], + data["src"]["nsa_cu_seqlens_k"], + data["src"]["real_page_table"], + data["src"]["flashmla_num_splits"], + data["src"]["flashmla_metadata"], + dst_fused_0["cache_seqlens_int32"], + dst_fused_0["cu_seqlens_k"], + dst_fused_0["page_table_1"], + dst_fused_0["nsa_cache_seqlens_int32"], + dst_fused_0["nsa_cu_seqlens_k"], + dst_fused_0["real_page_table"], + dst_fused_0["flashmla_num_splits"], + dst_fused_0["flashmla_metadata"], + dst_fused_1["cache_seqlens_int32"], + dst_fused_1["cu_seqlens_k"], + dst_fused_1["page_table_1"], + dst_fused_1["nsa_cache_seqlens_int32"], + dst_fused_1["nsa_cu_seqlens_k"], + dst_fused_1["real_page_table"], + dst_fused_1["flashmla_num_splits"], + dst_fused_1["flashmla_metadata"], + dst_fused_2["cache_seqlens_int32"], + dst_fused_2["cu_seqlens_k"], + dst_fused_2["page_table_1"], + dst_fused_2["nsa_cache_seqlens_int32"], + dst_fused_2["nsa_cu_seqlens_k"], + dst_fused_2["real_page_table"], + dst_fused_2["flashmla_num_splits"], + dst_fused_2["flashmla_metadata"], + bs, + max_len, + seqlens_expanded_size, + ) + torch.cuda.synchronize() + + # Actual timing + torch.cuda.synchronize() + loop_start = time.perf_counter() + reference_copy_for_loop(data["src"], [dst_ref_0, dst_ref_1, dst_ref_2], bs, max_len) + torch.cuda.synchronize() + loop_time = time.perf_counter() - loop_start + + torch.cuda.synchronize() + fused_start = time.perf_counter() + fused_metadata_copy_multi_cuda( + data["src"]["cache_seqlens"], + data["src"]["cu_seqlens_k"], + data["src"]["page_indices"], + data["src"]["nsa_cache_seqlens"], + data["src"]["nsa_cu_seqlens_k"], + data["src"]["real_page_table"], + data["src"]["flashmla_num_splits"], + data["src"]["flashmla_metadata"], + dst_fused_0["cache_seqlens_int32"], + dst_fused_0["cu_seqlens_k"], + dst_fused_0["page_table_1"], + 
dst_fused_0["nsa_cache_seqlens_int32"], + dst_fused_0["nsa_cu_seqlens_k"], + dst_fused_0["real_page_table"], + dst_fused_0["flashmla_num_splits"], + dst_fused_0["flashmla_metadata"], + dst_fused_1["cache_seqlens_int32"], + dst_fused_1["cu_seqlens_k"], + dst_fused_1["page_table_1"], + dst_fused_1["nsa_cache_seqlens_int32"], + dst_fused_1["nsa_cu_seqlens_k"], + dst_fused_1["real_page_table"], + dst_fused_1["flashmla_num_splits"], + dst_fused_1["flashmla_metadata"], + dst_fused_2["cache_seqlens_int32"], + dst_fused_2["cu_seqlens_k"], + dst_fused_2["page_table_1"], + dst_fused_2["nsa_cache_seqlens_int32"], + dst_fused_2["nsa_cu_seqlens_k"], + dst_fused_2["real_page_table"], + dst_fused_2["flashmla_num_splits"], + dst_fused_2["flashmla_metadata"], + bs, + max_len, + seqlens_expanded_size, + ) + torch.cuda.synchronize() + fused_time = time.perf_counter() - fused_start + + speedup = loop_time / fused_time if fused_time > 0 else 0 + print( + f"\n[PERF] Large batch (bs={bs}): Fused={fused_time*1000:.3f}ms, Loop={loop_time*1000:.3f}ms, Speedup={speedup:.2f}x" + ) + + # Verify correctness + for backend_idx, (dst_ref, dst_fused) in enumerate( + [ + (dst_ref_0, dst_fused_0), + (dst_ref_1, dst_fused_1), + (dst_ref_2, dst_fused_2), + ] + ): + for key in dst_ref: + if dst_ref[key] is not None and dst_fused[key] is not None: + assert torch.equal( + dst_ref[key], dst_fused[key] + ), f"Backend {backend_idx} {key} mismatch" + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/sglang/python/sglang/jit_kernel/tests/test_fused_norm_scale_shift.py b/sglang/python/sglang/jit_kernel/tests/test_fused_norm_scale_shift.py new file mode 100644 index 0000000000000000000000000000000000000000..592103b12fdf95c83a99a218523840f5ad20b45b --- /dev/null +++ b/sglang/python/sglang/jit_kernel/tests/test_fused_norm_scale_shift.py @@ -0,0 +1,236 @@ +from typing import Optional, Tuple + +import pytest +import torch +from einops import rearrange +from torch import Tensor + +from 
def _apply_scale_shift(y: Tensor, scale: Tensor, shift: Tensor) -> Tensor:
    """Return ``y * (1 + scale) + shift`` with the modulation tensors aligned to ``y``.

    Two layouts are handled explicitly:

    * ``scale`` is 4-D: ``y``'s second axis is split into ``scale.shape[1]``
      groups, the modulation is applied on the 4-D view, and the result is
      merged back to 3-D.
    * otherwise: a 2-D ``scale`` / ``shift`` gets a singleton middle axis
      inserted; any other rank is used as-is and left to broadcasting.
    """
    if scale.ndim != 4:
        if scale.ndim == 2:
            scale = rearrange(scale, "b d -> b 1 d")
        if shift.ndim == 2:
            shift = rearrange(shift, "b d -> b 1 d")
        return y * (1 + scale) + shift

    frames = scale.shape[1]
    per_frame = rearrange(y, "b (f l) d -> b f l d", f=frames)
    modulated = per_frame * (1 + scale) + shift
    return rearrange(modulated, "b f l d -> b (f l) d")
torch.rms_norm(x, x.shape[-1:], eps=eps, weight=weight) + return _apply_scale_shift(norm, scale, shift).to(original_dtype) + + +def fused_scale_residual_norm_scale_shift_ref( + residual: Tensor, + x: Tensor, + gate: Optional[Tensor] | int, + weight: Optional[Tensor], + bias: Optional[Tensor], + scale: Tensor, + shift: Tensor, + norm_type: str, + eps: float, +): + original_dtype = x.dtype + residual, x, gate, weight, bias, scale, shift = ( + v.float() if isinstance(v, Tensor) else v + for v in [residual, x, gate, weight, bias, scale, shift] + ) + if isinstance(gate, int): + x = residual + gate * x + else: + if gate.ndim == 4: + num_frame = gate.shape[1] + x_fld = rearrange(x, "b (f l) d -> b f l d", f=num_frame) + x = residual + rearrange(x_fld * gate, "b f l d -> b (f l) d") + else: + gate = rearrange(gate, "b d -> b 1 d") if gate.ndim == 2 else gate + x = residual + gate * x + if norm_type == "layer": + norm = torch.layer_norm(x, x.shape[-1:], eps=eps, weight=weight, bias=bias) + else: + norm = torch.rms_norm(x, x.shape[-1:], eps=eps, weight=weight) + y_ref = _apply_scale_shift(norm, scale, shift) + return y_ref.to(original_dtype), x.to(original_dtype) + + +def _make_tensor(index_mode: str, shape: Tuple, dtype: torch.dtype): + if index_mode == "NAT": + return None + return torch.randn(*SHAPE_MAP[index_mode](*shape), device=DEVICE, dtype=dtype) + + +@torch.no_grad() +def run_norm_scale_shift( + shape=SHAPES[0], + dtype=DTYPES[0], + affine_dtype=DTYPES[0], + scale_dtype=DTYPES[0], + shift_dtype=DTYPES[0], + norm_type=NORM_TYPES[0], + affine_mode=AFFINE_MODES[0], + scale_mode="BSD", + shift_mode="BSD", + eps=1e-5, +): + x = _make_tensor("BSD", shape, dtype) + weight = _make_tensor(affine_mode, shape, affine_dtype) + bias = _make_tensor(affine_mode, shape, affine_dtype) + scale = _make_tensor(scale_mode, shape, scale_dtype) + shift = _make_tensor(shift_mode, shape, shift_dtype) + y_dev = fused_norm_scale_shift(x, weight, bias, scale, shift, norm_type, eps) + y_ref = 
fused_norm_scale_shift_ref(x, weight, bias, scale, shift, norm_type, eps) + torch.testing.assert_close(y_dev, y_ref, atol=_tol(dtype), rtol=_tol(dtype)) + + +@torch.no_grad() +def run_scale_resi_norm_scale_shift( + shape=SHAPES[0], + dtype=DTYPES[0], + affine_dtype=DTYPES[0], + scale_dtype=DTYPES[0], + shift_dtype=DTYPES[0], + norm_type=NORM_TYPES[0], + affine_mode=AFFINE_MODES[0], + gate_mode="B1D", + scale_mode="BSD", + shift_mode="BSD", + eps=1e-5, +): + residual = _make_tensor("BSD", shape, dtype) + x = _make_tensor("BSD", shape, dtype) + gate = _make_tensor(gate_mode, shape, dtype) + weight = _make_tensor(affine_mode, shape, affine_dtype) + bias = _make_tensor(affine_mode, shape, affine_dtype) + scale = _make_tensor(scale_mode, shape, scale_dtype) + shift = _make_tensor(shift_mode, shape, shift_dtype) + y_dev, res_dev = fused_scale_residual_norm_scale_shift( + residual, x, gate, weight, bias, scale, shift, norm_type, eps + ) + y_ref, res_ref = fused_scale_residual_norm_scale_shift_ref( + residual, x, gate, weight, bias, scale, shift, norm_type, eps + ) + torch.testing.assert_close(y_dev, y_ref, atol=_tol(dtype), rtol=_tol(dtype)) + torch.testing.assert_close(res_dev, res_ref, atol=_tol(dtype), rtol=_tol(dtype)) + + +@pytest.mark.parametrize("norm_type", NORM_TYPES) +class TestFusedNormScaleShift: + @pytest.mark.parametrize("shape", SHAPES) + @pytest.mark.parametrize("dtype", DTYPES) + def test_shape_dtype(self, shape, dtype, norm_type): + run_norm_scale_shift(shape=shape, dtype=dtype, norm_type=norm_type) + + @pytest.mark.parametrize("dtype", DTYPES) + def test_dtype_0(self, dtype, norm_type): + run_norm_scale_shift(affine_dtype=dtype, norm_type=norm_type) + + @pytest.mark.parametrize("dtype", DTYPES) + def test_dtype_1(self, dtype, norm_type): + run_norm_scale_shift(scale_dtype=dtype, shift_dtype=dtype, norm_type=norm_type) + + @pytest.mark.parametrize("affine_mode", AFFINE_MODES) + def test_normtype_affine(self, affine_mode, norm_type): + 
run_norm_scale_shift(affine_mode=affine_mode, norm_type=norm_type) + + @pytest.mark.parametrize("index_mode", INDEX_MODES) + def test_index_mode(self, index_mode, norm_type): + run_norm_scale_shift( + scale_mode=index_mode, shift_mode=index_mode, norm_type=norm_type + ) + + +@pytest.mark.parametrize("norm_type", NORM_TYPES) +class TestFusedScaleResidualNormScaleShift: + @pytest.mark.parametrize("shape", SHAPES) + @pytest.mark.parametrize("dtype", DTYPES) + def test_shape_dtype(self, shape, dtype, norm_type): + run_scale_resi_norm_scale_shift(shape=shape, dtype=dtype, norm_type=norm_type) + + @pytest.mark.parametrize("dtype", DTYPES) + def test_dtype_0(self, dtype, norm_type): + run_scale_resi_norm_scale_shift(affine_dtype=dtype, norm_type=norm_type) + + @pytest.mark.parametrize("dtype", DTYPES) + def test_dtype_1(self, dtype, norm_type): + run_scale_resi_norm_scale_shift( + scale_dtype=dtype, shift_dtype=dtype, norm_type=norm_type + ) + + @pytest.mark.parametrize("affine_mode", AFFINE_MODES) + def test_normtype_affine(self, affine_mode, norm_type): + run_scale_resi_norm_scale_shift(affine_mode=affine_mode, norm_type=norm_type) + + @pytest.mark.parametrize("index_mode", INDEX_MODES) + def test_scale_shift_index_mode(self, index_mode, norm_type): + run_scale_resi_norm_scale_shift( + scale_mode=index_mode, shift_mode=index_mode, norm_type=norm_type + ) + + @pytest.mark.parametrize("index_mode", INDEX_MODES) + def test_gate_index_mode(self, index_mode, norm_type): + run_scale_resi_norm_scale_shift(gate_mode=index_mode, norm_type=norm_type) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/sglang/python/sglang/jit_kernel/tests/test_fused_store_index_cache.py b/sglang/python/sglang/jit_kernel/tests/test_fused_store_index_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..00edb5819cea5329376be9a3f7c019e3588e9d6f --- /dev/null +++ b/sglang/python/sglang/jit_kernel/tests/test_fused_store_index_cache.py @@ -0,0 
+1,453 @@
+"""
+Test for fused_store_index_k_cache kernel.
+
+Design Notes:
+    1. torch.cuda.synchronize() needed after TVM FFI kernel call.
+    2. _split_buffer used buf[:, :vb].reshape(-1) which COPIES data for
+       non-contiguous slices → reference buffer stayed all-zeros.
+       Fix: use flat byte-offset indexing.
+    3. act_quant may use a different quantization scheme → generous tolerance.
+    4. FP8 E4M3 1-ULP rounding differences between CUDA hardware cast
+       (__nv_fp8_e4m3) and PyTorch .to(float8_e4m3fn) at tie-break points.
+       Adjacent FP8 representable values at the high end differ by up to 32
+       in float space (e.g. 288, 320, 352, ..., 448).
+       Need to compare dequantized values with FP8-appropriate tolerance.
+"""
+
+from __future__ import annotations
+
+from typing import Optional, Tuple
+
+import pytest
+import torch
+
+try:
+    from sglang.jit_kernel.fused_store_index_cache import (
+        can_use_nsa_fused_store,
+        fused_store_index_k_cache,
+    )
+
+    HAS_FUSED = True
+except ImportError:
+    HAS_FUSED = False
+
+try:
+    from sglang.srt.utils import is_hip
+
+    _is_hip = is_hip()
+except ImportError:
+    _is_hip = False
+
+try:
+    from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
+
+    _is_fp8_fnuz = is_fp8_fnuz()
+except ImportError:
+    _is_fp8_fnuz = False
+
+PAGE_SIZE = 64
+HEAD_DIM = 128
+FP8_E4M3_MAX = 448.0
+# Resolved lazily-safe: a plain ``torch.float8_e4m3fn`` attribute access would raise
+# AttributeError while the module is imported on torch builds without that dtype,
+# so the "not available" skip in _skip_if_unavailable() could never be reached.
+FP8_DTYPE = getattr(torch, "float8_e4m3fn", None)
+SCALE_BYTES = 4  # one float32 scale per token
+BYTES_PER_TOKEN = HEAD_DIM + SCALE_BYTES  # 128 fp8 bytes + 4 scale bytes
+BYTES_PER_PAGE = PAGE_SIZE * BYTES_PER_TOKEN
+
+
+def _skip_if_unavailable(page_size: int = PAGE_SIZE):
+    """Skip the calling test unless the fused store kernel can run here.
+
+    Checks, in order: CUDA present, not HIP, not FP8 FNUZ, the FP8 E4M3 dtype
+    exists in torch, the kernel module imported, and
+    ``can_use_nsa_fused_store`` accepts (bfloat16, int64, page_size).
+    """
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA required")
+    if _is_hip:
+        pytest.skip("Fused store kernel is CUDA-specific")
+    if _is_fp8_fnuz:
+        pytest.skip("Fused store path disabled for FP8 FNUZ")
+    if FP8_DTYPE is None:
+        pytest.skip("torch.float8_e4m3fn not available")
+    if not HAS_FUSED:
+        pytest.skip("fused_store_index_cache not importable")
+    if not can_use_nsa_fused_store(torch.bfloat16, torch.int64, page_size):
+        pytest.skip("JIT kernel unavailable / failed to compile")
+
+
+def _num_pages(loc: torch.Tensor, page_size: int, extra: int = 1) -> int:
+    """Return the page count needed to hold the largest index in ``loc``, plus ``extra``."""
+    return int(loc.max().item()) // page_size + 1 + extra
+
+
+def _make_buffer(num_pages: int, page_size: int = PAGE_SIZE) -> torch.Tensor:
+    """Allocate a zeroed uint8 CUDA buffer of shape (num_pages, page_size * BYTES_PER_TOKEN)."""
+    return torch.zeros(
+        (num_pages, page_size * BYTES_PER_TOKEN),
+        dtype=torch.uint8,
+        device="cuda",
+    )
+
+
+def _read_token_from_buffer(
+    buf: torch.Tensor,
+    token_idx: int,
+    page_size: int = PAGE_SIZE,
+) -> Tuple[torch.Tensor, float]:
+    """
+    Read a single token's fp8 values and scale from the paged buffer
+    using flat byte offsets.
+
+    Within a page the fp8 values of all tokens come first (HEAD_DIM bytes
+    each), followed by the float32 scales (SCALE_BYTES bytes each).
+
+    Returns:
+        (fp8 values converted to float32, scale as a Python float).
+    """
+    page = token_idx // page_size
+    offset = token_idx % page_size
+    page_bytes = page_size * BYTES_PER_TOKEN
+
+    buf_flat = buf.reshape(-1)
+
+    val_start = page * page_bytes + offset * HEAD_DIM
+    fp8_bytes = buf_flat[val_start : val_start + HEAD_DIM]
+    fp8_vals = fp8_bytes.view(FP8_DTYPE).float()
+
+    # The scale region starts right after the page's value region.
+    scale_start = page * page_bytes + HEAD_DIM * page_size + offset * SCALE_BYTES
+    scale_bytes = buf_flat[scale_start : scale_start + SCALE_BYTES]
+    scale = scale_bytes.view(torch.float32).item()
+
+    return fp8_vals, scale
+
+
+def _write_token_to_buffer(
+    buf: torch.Tensor,
+    token_idx: int,
+    fp8_data: torch.Tensor,
+    scale: float,
+    page_size: int = PAGE_SIZE,
+) -> None:
+    """
+    Write a single token's fp8 values and scale into the paged buffer
+    using flat byte offsets on buf.reshape(-1) (which is a true view
+    since buf is contiguous).
+
+    Uses the same page layout as _read_token_from_buffer: values first,
+    then scales.
+    """
+    page = token_idx // page_size
+    offset = token_idx % page_size
+    page_bytes = page_size * BYTES_PER_TOKEN
+
+    buf_flat = buf.reshape(-1)
+
+    val_start = page * page_bytes + offset * HEAD_DIM
+    buf_flat[val_start : val_start + HEAD_DIM] = fp8_data.view(torch.uint8)
+
+    scale_start = page * page_bytes + HEAD_DIM * page_size + offset * SCALE_BYTES
+    scale_t = torch.tensor([scale], dtype=torch.float32, device=buf.device)
+    buf_flat[scale_start : scale_start + SCALE_BYTES] = scale_t.view(torch.uint8)
+
+
+def _gather_tokens(
+    buf: torch.Tensor,
+    loc: torch.Tensor,
+    page_size: int = PAGE_SIZE,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Read every token listed in ``loc`` back from ``buf``.
+
+    Returns:
+        (values, scales): float32 tensors of shape (N, HEAD_DIM) and (N,),
+        in the order of ``loc``.
+    """
+    N = loc.shape[0]
+    fp8_f32 = torch.empty((N, HEAD_DIM), dtype=torch.float32, device=buf.device)
+    scales = torch.empty((N,), dtype=torch.float32, device=buf.device)
+    for i in range(N):
+        idx = int(loc[i].item())
+        vals, s = _read_token_from_buffer(buf, idx, page_size)
+        fp8_f32[i] = vals
+        scales[i] = s
+    return fp8_f32, scales
+
+
+# Reference kernel
+def _reference_quantize_and_store(
+    key_bf16: torch.Tensor,
+    loc: torch.Tensor,
+    num_pages: int,
+    page_size: int = PAGE_SIZE,
+) -> torch.Tensor:
+    """
+    Reference kernel of the fused kernel's quantization:
+        abs_max = max(|row|)
+        scale   = max(1e-4, abs_max) / 448
+        fp8_val = clip(val / scale, -448, 448) -> cast to fp8
+
+    Returns a freshly allocated paged buffer with row ``i`` stored at
+    token index ``loc[i]``.
+    """
+    N = key_bf16.shape[0]
+    key_f32 = key_bf16.float()
+    buf = _make_buffer(num_pages, page_size)
+
+    for i in range(N):
+        row = key_f32[i]
+        abs_max = row.abs().max().item()
+        scale = max(1e-4, abs_max) / FP8_E4M3_MAX
+        inv_scale = 1.0 / scale
+        quantized = (row * inv_scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)
+        quantized_fp8 = quantized.to(FP8_DTYPE)
+
+        idx = int(loc[i].item())
+        _write_token_to_buffer(buf, idx, quantized_fp8, scale, page_size)
+
+    return buf
+
+
+def _import_act_quant():
+    """Return ``act_quant`` from the NSA triton kernels, or None if importing it fails."""
+    try:
+        from sglang.srt.layers.attention.nsa.triton_kernel import act_quant
+
+        return act_quant
+    except Exception:
+        return None
+
+
+def _ref_store_via_act_quant(
+    key_bf16: torch.Tensor,
+    loc: torch.Tensor,
+    num_pages: int,
+    page_size: int = PAGE_SIZE,
+    block_size: int = 128,
+    scale_fmt: Optional[str] = None,
+) -> Optional[torch.Tensor]:
+    """Quantize ``key_bf16`` with ``act_quant`` and store it in a paged buffer.
+
+    Returns None when ``act_quant`` cannot be imported; otherwise a buffer with
+    row ``i`` stored at token index ``loc[i]``.
+    """
+    act_quant = _import_act_quant()
+    if act_quant is None:
+        return None
+
+    # Retry without scale_fmt if the three-argument call raises TypeError.
+    try:
+        k_fp8, k_scale = act_quant(key_bf16, block_size, scale_fmt)
+    except TypeError:
+        k_fp8, k_scale = act_quant(key_bf16, block_size)
+
+    # Drop a singleton middle axis so both outputs are indexed by token.
+    if k_fp8.dim() == 3 and k_fp8.shape[1] == 1:
+        k_fp8 = k_fp8.squeeze(1)
+    if k_scale is not None and k_scale.dim() == 3 and k_scale.shape[1] == 1:
+        k_scale = k_scale.squeeze(1)
+    k_scale = k_scale.view(-1).float()
+
+    buf = _make_buffer(num_pages, page_size)
+    N = key_bf16.shape[0]
+    for i in range(N):
+        idx = int(loc[i].item())
+        _write_token_to_buffer(
+            buf, idx, k_fp8[i].to(FP8_DTYPE), k_scale[i].item(), page_size
+        )
+    return buf
+
+
+# TEST 1: Fused kernel vs. its own algorithm (pure-Python reference)
+#
+# NOTE on FP8 rounding:
+#   CUDA hardware fp8 cast (__nv_fp8_e4m3) and PyTorch .to(float8_e4m3fn)
+#   may round differently at tie-break points.  This causes up to 1-ULP
+#   differences in the FP8 codes.  In FP8 E4M3, adjacent representable
+#   values at the high end differ by up to 32 in float space (e.g.
+#   288 vs 320).  After dequantization (fp8_float * scale), the error
+#   from 1-ULP is:  scale * ulp ≈ (abs_max/448) * 32 ≈ 0.07 * abs_max.
+#   For randn inputs (abs_max ≈ 3-4), this is about 0.2-0.3.
+#
+#   We therefore compare dequantized values with tolerances that
+#   accommodate 1-ULP FP8 rounding, NOT byte-exact fp8 codes.
+@pytest.mark.parametrize(
+    "num_tokens,base_index",
+    [(1, 0), (32, 0), (64, 0), (128, 64), (257, 65), (512, 0)],
+)
+def test_fused_kernel_matches_own_algorithm(num_tokens: int, base_index: int):
+    """Compare fused CUDA kernel against a pure-Python implementation
+    of the *same* quantization formula."""
+    _skip_if_unavailable()
+    device = torch.device("cuda")
+
+    key = torch.randn((num_tokens, HEAD_DIM), device=device, dtype=torch.bfloat16)
+    loc = (
+        base_index + torch.randperm(num_tokens, device=device, dtype=torch.int64)
+    ).contiguous()
+    num_pages = _num_pages(loc, PAGE_SIZE)
+
+    # Reference kernel
+    ref_buf = _reference_quantize_and_store(key, loc, num_pages)
+
+    # Fused kernel
+    out_buf = _make_buffer(num_pages)
+    fused_store_index_k_cache(key, out_buf, loc, page_size=PAGE_SIZE)
+    torch.cuda.synchronize()
+
+    out_f, out_s = _gather_tokens(out_buf, loc)
+    ref_f, ref_s = _gather_tokens(ref_buf, loc)
+
+    # 1) Scales must match tightly (same f32 formula, no rounding ambiguity)
+    torch.testing.assert_close(out_s, ref_s, rtol=1e-5, atol=1e-7)
+
+    # 2) Most FP8 codes should match; allow rare 1-ULP differences.
+    #    1-ULP at FP8 E4M3 high end = 32 in float space.
+    mismatch = out_f != ref_f
+    mismatch_frac = mismatch.float().mean().item()
+    assert mismatch_frac < 0.01, (
+        f"Too many FP8 code mismatches: {mismatch_frac:.2%} "
+        f"(expected < 1% from rounding tie-breaks)"
+    )
+
+    # 3) Where codes differ, the difference should be exactly 1 ULP.
+    #    In FP8 E4M3: if the float-cast value is V, the adjacent value
+    #    differs by ~V * 0.1 (relative) at most.
+    if mismatch.any():
+        diff = (out_f[mismatch] - ref_f[mismatch]).abs()
+        rel_diff = diff / ref_f[mismatch].abs().clamp(min=1e-6)
+        # 1-ULP relative difference for E4M3 is at most ~12.5% (2^-3)
+        assert rel_diff.max().item() <= 0.15, (
+            f"FP8 code difference exceeds 1-ULP: max relative diff = "
+            f"{rel_diff.max().item():.4f}"
+        )
+
+    # 4) Dequantized values should be close.
+    #    Max error from 1-ULP: scale * fp8_ulp ≈ (abs_max/448) * 32
+    #    For randn abs_max ≈ 3-4: max_err ≈ 0.21 - 0.29
+    out_deq = out_f * out_s.unsqueeze(-1)
+    ref_deq = ref_f * ref_s.unsqueeze(-1)
+    torch.testing.assert_close(out_deq, ref_deq, rtol=0.15, atol=0.5)
+
+
+# TEST 2: Cross-check against act_quant
+@pytest.mark.parametrize("scale_fmt", [None, "fp32"])
+def test_fused_kernel_vs_act_quant_semantic(scale_fmt: Optional[str]):
+    """Both fused kernel and act_quant should approximately reconstruct
+    the original bf16 values."""
+    _skip_if_unavailable()
+    device = torch.device("cuda")
+
+    num_tokens = 257
+    base_index = 65
+    key = torch.randn((num_tokens, HEAD_DIM), device=device, dtype=torch.bfloat16)
+    loc = (
+        base_index + torch.randperm(num_tokens, device=device, dtype=torch.int64)
+    ).contiguous()
+    num_pages = _num_pages(loc, PAGE_SIZE)
+
+    ref_buf = _ref_store_via_act_quant(key, loc, num_pages, scale_fmt=scale_fmt)
+    if ref_buf is None:
+        pytest.skip("act_quant not available")
+
+    out_buf = _make_buffer(num_pages)
+    fused_store_index_k_cache(key, out_buf, loc, page_size=PAGE_SIZE)
+    torch.cuda.synchronize()
+
+    out_f, out_s = _gather_tokens(out_buf, loc)
+    ref_f, ref_s = _gather_tokens(ref_buf, loc)
+
+    out_deq = out_f * out_s.unsqueeze(-1)
+    ref_deq = ref_f * ref_s.unsqueeze(-1)
+    orig_f32 = key.float()
+
+    # Fused kernel should reconstruct original within FP8 precision
+    torch.testing.assert_close(
+        out_deq,
+        orig_f32,
+        rtol=0.15,
+        atol=5e-2,
+        msg="Fused kernel dequantized values don't approximate original",
+    )
+
+    # act_quant may use a very different scale policy.
+    # A poor reconstruction is only a failure when the output is mostly zero;
+    # otherwise the cross-check is skipped rather than failed.
+    try:
+        torch.testing.assert_close(
+            ref_deq,
+            orig_f32,
+            rtol=0.25,
+            atol=0.5,
+            msg="act_quant dequantized values don't approximate original",
+        )
+    except AssertionError:
+        nonzero_frac = (ref_deq.abs() > 1e-6).float().mean().item()
+        if nonzero_frac < 0.5:
+            pytest.fail(
+                f"act_quant output looks mostly zero ({nonzero_frac:.1%} nonzero)."
+            )
+        else:
+            pytest.skip(
+                f"act_quant uses a very different quantization scheme "
+                f"(scale_fmt={scale_fmt}). Fused kernel validated independently."
+            )
+
+    torch.testing.assert_close(
+        out_deq,
+        ref_deq,
+        rtol=0.3,
+        atol=0.5,
+        msg="Fused and act_quant dequantized values diverge too much",
+    )
+
+
+# TEST 3: Roundtrip reconstruction
+@pytest.mark.parametrize("num_tokens", [1, 64, 257])
+def test_roundtrip_reconstruction(num_tokens: int):
+    """Store with the fused kernel, read back, and compare with the bf16 input."""
+    _skip_if_unavailable()
+    device = torch.device("cuda")
+
+    key = torch.randn((num_tokens, HEAD_DIM), device=device, dtype=torch.bfloat16)
+    loc = torch.arange(num_tokens, device=device, dtype=torch.int64)
+    num_pages = _num_pages(loc, PAGE_SIZE)
+
+    buf = _make_buffer(num_pages)
+    fused_store_index_k_cache(key, buf, loc, page_size=PAGE_SIZE)
+    torch.cuda.synchronize()
+
+    fp8_f32, scales = _gather_tokens(buf, loc)
+    reconstructed = fp8_f32 * scales.unsqueeze(-1)
+    original = key.float()
+
+    torch.testing.assert_close(reconstructed, original, rtol=0.15, atol=5e-2)
+
+    # Guard against a kernel that silently wrote nothing for some rows.
+    per_row_energy = reconstructed.abs().sum(dim=-1)
+    orig_energy = original.abs().sum(dim=-1)
+    mask = orig_energy > 0.1
+    assert (
+        per_row_energy[mask] > 0.01
+    ).all(), "Some tokens have zero reconstruction — kernel may not be writing output"
+
+
+# TEST 4: Boundary conditions
+def test_single_token():
+    """A single token at index 0 in a one-page buffer round-trips."""
+    _skip_if_unavailable()
+    device = torch.device("cuda")
+
+    key = torch.randn((1, HEAD_DIM), device=device, dtype=torch.bfloat16)
+    loc = torch.tensor([0], device=device, dtype=torch.int64)
+
+    buf = _make_buffer(1)
+    fused_store_index_k_cache(key, buf, loc, page_size=PAGE_SIZE)
+    torch.cuda.synchronize()
+
+    fp8_f32, scales = _gather_tokens(buf, loc)
+    reconstructed = fp8_f32 * scales.unsqueeze(-1)
+    torch.testing.assert_close(reconstructed, key.float(), rtol=0.15, atol=5e-2)
+
+
+# TEST 5: Zero input conditions
+def test_zero_input():
+    """All-zero rows must store zero codes and the floor scale 1e-4 / 448."""
+    _skip_if_unavailable()
+    device = torch.device("cuda")
+
+    key = torch.zeros((4, HEAD_DIM), device=device, dtype=torch.bfloat16)
+    loc = torch.arange(4, device=device, dtype=torch.int64)
+
+    buf = _make_buffer(1)
+    fused_store_index_k_cache(key, buf, loc, page_size=PAGE_SIZE)
+    torch.cuda.synchronize()
+
+    fp8_f32, scales = _gather_tokens(buf, loc)
+
+    expected_scale = 1e-4 / FP8_E4M3_MAX
+    torch.testing.assert_close(
+        scales,
+        torch.full_like(scales, expected_scale),
+        rtol=1e-5,
+        atol=1e-10,
+    )
+    assert (fp8_f32 == 0).all()
+
+
+# TEST 6: Sanity check — verify reference itself writes non-zero data
+def test_reference_writes_nonzero():
+    """The pure-Python reference alone must write data that round-trips."""
+    _skip_if_unavailable()
+    device = torch.device("cuda")
+
+    key = torch.randn((8, HEAD_DIM), device=device, dtype=torch.bfloat16)
+    loc = torch.arange(8, device=device, dtype=torch.int64)
+
+    buf = _reference_quantize_and_store(key, loc, num_pages=1)
+
+    fp8_f32, scales = _gather_tokens(buf, loc)
+    deq = fp8_f32 * scales.unsqueeze(-1)
+
+    assert deq.abs().sum().item() > 0, "Reference buffer is all zeros — error!"
+    torch.testing.assert_close(deq, key.float(), rtol=0.15, atol=5e-2)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v", "-s"])
diff --git a/sglang/python/sglang/jit_kernel/tests/test_fused_verify_triton_gdn.py b/sglang/python/sglang/jit_kernel/tests/test_fused_verify_triton_gdn.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6e048a40b5c09ceecd1f0898587da53c770116f
--- /dev/null
+++ b/sglang/python/sglang/jit_kernel/tests/test_fused_verify_triton_gdn.py
@@ -0,0 +1,231 @@
+"""Tests for fused sigmoid gating delta rule MTP kernel (GDN target_verify).
+
+Compares the fused kernel `fused_sigmoid_gating_delta_rule_update` against
+the reference two-step implementation:
+  1. g, beta = fused_gdn_gating(A_log, a, b, dt_bias)
+  2. o = fused_recurrent_gated_delta_rule_update(q, k, v, g, beta, ...)
+"""
+
+import pytest
+import torch
+
+try:
+    from sglang.srt.layers.attention.fla.fused_gdn_gating import fused_gdn_gating
+    from sglang.srt.layers.attention.fla.fused_recurrent import (
+        fused_recurrent_gated_delta_rule_update,
+    )
+    from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (
+        fused_sigmoid_gating_delta_rule_update,
+    )
+
+    KERNELS_AVAILABLE = True
+except ImportError:
+    KERNELS_AVAILABLE = False
+
+
+def _make_tensors(N, T, H, HV, K, V, device="cuda", seed=2025):
+    """Create input tensors for GDN target_verify.
+
+    N sequences of T tokens each are packed into one batch row of length
+    N * T; ``cu_seqlens`` holds the N + 1 sequence boundaries.
+
+    Returns:
+        (A_log, dt_bias, a, b, q, k, v, initial_state, indices, cu_seqlens)
+    """
+    torch.manual_seed(seed)
+    A_log = torch.randn(HV, dtype=torch.float32, device=device)
+    dt_bias = torch.randn(HV, dtype=torch.bfloat16, device=device)
+    a = torch.randn(1, N * T, HV, dtype=torch.bfloat16, device=device)
+    b = torch.randn(1, N * T, HV, dtype=torch.bfloat16, device=device)
+    q = torch.randn(1, N * T, H, K, dtype=torch.bfloat16, device=device)
+    k = torch.randn(1, N * T, H, K, dtype=torch.bfloat16, device=device)
+    v = torch.randn(1, N * T, HV, V, dtype=torch.bfloat16, device=device)
+    indices = torch.arange(N, dtype=torch.int32, device=device)
+    initial_state = torch.randn(N, HV, K, V, dtype=torch.float, device=device)
+    cu_seqlens = torch.arange(0, N * T + 1, T, dtype=torch.int32, device=device)
+    return A_log, dt_bias, a, b, q, k, v, initial_state, indices, cu_seqlens
+
+
+def run_reference(
+    A_log,
+    dt_bias,
+    q,
+    k,
+    v,
+    a,
+    b,
+    initial_state_source,
+    initial_state_indices,
+    cu_seqlens,
+    disable_state_update=True,
+    intermediate_states_buffer=None,
+    intermediate_state_indices=None,
+    cache_steps=None,
+    retrieve_parent_token=None,
+):
+    """Reference: fused_gdn_gating + fused_recurrent_gated_delta_rule_update."""
+    # fused_gdn_gating expects 2D [seq_len, HV]
+    a_2d = a.view(-1, a.shape[-1])
+    b_2d = b.view(-1, b.shape[-1])
+    g, beta = fused_gdn_gating(A_log, a_2d, b_2d, dt_bias)
+    # fused_recurrent expects 3D [B, T, HV]
+    g = g.view(a.shape)
+    beta = beta.view(b.shape)
+
+    # fused_recurrent requires intermediate_state_indices when cu_seqlens is used
+    if cu_seqlens is not None and intermediate_state_indices is None:
+        N = len(cu_seqlens) - 1
+        intermediate_state_indices = torch.arange(N, dtype=torch.int32, device=q.device)
+
+    return fused_recurrent_gated_delta_rule_update(
+        q=q,
+        k=k,
+        v=v,
+        g=g,
+        beta=beta,
+        initial_state_source=initial_state_source,
+        initial_state_indices=initial_state_indices,
+        cu_seqlens=cu_seqlens,
+        use_qk_l2norm_in_kernel=True,
+        disable_state_update=disable_state_update,
+        intermediate_states_buffer=intermediate_states_buffer,
+        intermediate_state_indices=intermediate_state_indices,
+        cache_steps=cache_steps,
+        retrieve_parent_token=retrieve_parent_token,
+    )
+
+
+def run_fused_mtp(
+    A_log,
+    dt_bias,
+    q,
+    k,
+    v,
+    a,
+    b,
+    initial_state_source,
+    initial_state_indices,
+    cu_seqlens,
+    disable_state_update=True,
+    intermediate_states_buffer=None,
+    intermediate_state_indices=None,
+    cache_steps=None,
+    retrieve_parent_token=None,
+):
+    """Fused: fused_sigmoid_gating_delta_rule_update."""
+    return fused_sigmoid_gating_delta_rule_update(
+        A_log=A_log,
+        dt_bias=dt_bias,
+        q=q,
+        k=k,
+        v=v,
+        a=a,
+        b=b,
+        initial_state_source=initial_state_source,
+        initial_state_indices=initial_state_indices,
+        cu_seqlens=cu_seqlens,
+        use_qk_l2norm_in_kernel=True,
+        softplus_beta=1.0,
+        softplus_threshold=20.0,
+        is_kda=False,
+        disable_state_update=disable_state_update,
+        intermediate_states_buffer=intermediate_states_buffer,
+        intermediate_state_indices=intermediate_state_indices,
+        cache_steps=cache_steps,
+        retrieve_parent_token=retrieve_parent_token,
+    )
+
+
+# _make_tensors allocates on "cuda", so skip (rather than error) without a GPU.
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
+@pytest.mark.skipif(not KERNELS_AVAILABLE, reason="Kernel not available")
+@pytest.mark.parametrize("N", [1, 8, 16])
+@pytest.mark.parametrize("T", [1, 4, 8])
+def test_fused_gdn_mtp_precision(N: int, T: int):
+    """Compare fused MTP output against reference."""
+    H, HV, K, V = 16, 32, 128, 128
+
+    A_log, dt_bias, a, b, q, k, v, state, indices, cu_seqlens = _make_tensors(
+        N, T, H, HV, K, V
+    )
+
+    # Each implementation gets its own copy of the initial state.
+    state_ref = state.clone()
+    state_fused = state.clone()
+
+    out_ref = run_reference(
+        A_log,
+        dt_bias,
+        q,
+        k,
+        v,
+        a,
+        b,
+        state_ref,
+        indices,
+        cu_seqlens,
+        disable_state_update=True,
+    )
+    out_fused = run_fused_mtp(
+        A_log,
+        dt_bias,
+        q,
+        k,
+        v,
+        a,
+        b,
+        state_fused,
+        indices,
+        cu_seqlens,
+        disable_state_update=True,
+    )
+
+    torch.testing.assert_close(out_ref, out_fused, rtol=1e-2, atol=1e-2)
+
+
+# _make_tensors allocates on "cuda", so skip (rather than error) without a GPU.
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
+@pytest.mark.skipif(not KERNELS_AVAILABLE, reason="Kernels not available")
+@pytest.mark.parametrize("N", [1, 16, 128])
+def test_mtp_single_step_decode(N: int):
+    """Verify MTP kernel matches reference for T=1 (decode scenario)."""
+    T = 1
+    H, HV, K, V = 16, 32, 128, 128
+
+    A_log, dt_bias, a, b, q, k, v, state, indices, cu_seqlens = _make_tensors(
+        N, T, H, HV, K, V
+    )
+
+    # State updates are enabled here, so each implementation mutates its own copy.
+    state_ref = state.clone()
+    state_fused = state.clone()
+
+    out_ref = run_reference(
+        A_log,
+        dt_bias,
+        q,
+        k,
+        v,
+        a,
+        b,
+        state_ref,
+        indices,
+        cu_seqlens,
+        disable_state_update=False,
+    )
+    out_fused = run_fused_mtp(
+        A_log,
+        dt_bias,
+        q,
+        k,
+        v,
+        a,
+        b,
+        state_fused,
+        indices,
+        cu_seqlens,
+        disable_state_update=False,
+    )
+
+    torch.testing.assert_close(out_ref, out_fused, rtol=1e-2, atol=1e-2)
+
+    # Also verify states match after update
+    state_diff = (state_ref.float() - state_fused.float()).abs()
+    state_max_diff = state_diff.max().item()
+    # Percentage of state elements whose absolute difference exceeds 0.1.
+    state_fail_rate = (state_diff > 0.1).float().mean().item() * 100
+    print(
+        f"  single_step state N={N}: max_diff={state_max_diff:.2e}, "
+        f"fail_rate={state_fail_rate:.2f}%"
+    )
+    assert state_fail_rate < 0.01, f"State mismatch: fail_rate={state_fail_rate:.2f}%"
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v", "-s"])
diff --git a/sglang/python/sglang/jit_kernel/tests/test_gptq_marlin.py b/sglang/python/sglang/jit_kernel/tests/test_gptq_marlin.py
new file mode 100644
index 
0000000000000000000000000000000000000000..c7cdb1e6ce98253a91bb415712ea305bc9ade1e5
--- /dev/null
+++ b/sglang/python/sglang/jit_kernel/tests/test_gptq_marlin.py
@@ -0,0 +1,99 @@
+import pytest
+import torch
+from sgl_kernel.scalar_type import scalar_types
+
+from sglang.jit_kernel.gptq_marlin import gptq_marlin_gemm
+from sglang.srt.layers.quantization.marlin_utils import marlin_make_workspace
+from sglang.test.test_marlin_utils import awq_marlin_quantize, marlin_quantize
+
+# (m, n, k) multipliers; n and k are scaled by n_chunk / k_chunk, m is used as is.
+MNK_FACTORS = [
+    (1, 1, 1),
+    (1, 4, 8),
+    (13, 17, 67),
+    (257, 13, 11),
+]
+
+
+@pytest.mark.parametrize("k_chunk", [128])
+@pytest.mark.parametrize("n_chunk", [64, 256])
+@pytest.mark.parametrize("quant_type", [scalar_types.uint4, scalar_types.uint4b8])
+@pytest.mark.parametrize("group_size", [-1, 128])
+@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
+@pytest.mark.parametrize("act_order", [False, True])
+def test_gptq_marlin_gemm(
+    k_chunk,
+    n_chunk,
+    quant_type,
+    group_size,
+    mnk_factors,
+    act_order,
+):
+    """Check the JIT ``gptq_marlin_gemm`` against ``torch.matmul`` on dequantized weights.
+
+    Parameter combinations the test does not cover are reported as skipped
+    (a bare ``return`` would make pytest count them as passed).
+    """
+    m_factor, n_factor, k_factor = mnk_factors
+    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
+
+    size_m = m_factor
+    size_k = k_chunk * k_factor
+    size_n = n_chunk * n_factor
+
+    if act_order:
+        if group_size == -1:
+            pytest.skip("act_order is not tested with group_size == -1")
+        if group_size == size_k:
+            pytest.skip("act_order is not tested with group_size == size_k")
+        if has_zp:
+            pytest.skip("act_order is not tested with zero-point quant types")
+
+    if size_k % group_size != 0:
+        pytest.skip("size_k must be divisible by group_size")
+
+    a_input = torch.randn((size_m, size_k), dtype=torch.float16, device="cuda")
+    b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device="cuda")
+
+    if has_zp:
+        # Zero-point path: no act-order metadata is produced.
+        w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
+            b_weight, quant_type, group_size
+        )
+        g_idx = None
+        sort_indices = None
+        marlin_s2 = None
+    else:
+        w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
+            b_weight, quant_type, group_size, act_order
+        )
+        marlin_zp = None
+        marlin_s2 = None
+
+    workspace = marlin_make_workspace(w_ref.device)
+
+    output = gptq_marlin_gemm(
+        a_input,
+        None,
+        marlin_q_w,
+        marlin_s,
+        marlin_s2,
+        marlin_zp,
+        g_idx,
+        sort_indices,
+        workspace,
+        quant_type,
+        a_input.shape[0],
+        b_weight.shape[1],
+        a_input.shape[1],
+        is_k_full=True,
+        use_atomic_add=False,
+        use_fp32_reduce=False,
+        is_zp_float=False,
+    )
+
+    output_ref = torch.matmul(a_input, w_ref)
+    torch.cuda.synchronize()
+
+    # JIT kernel should produce approximately correct results vs torch.matmul
+    # Metric: mean absolute error normalized by the mean magnitude of the reference.
+    mean_rel_err = torch.mean(torch.abs(output - output_ref)) / torch.mean(
+        torch.abs(output_ref)
+    )
+    assert mean_rel_err < 0.04
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v", "-s"])
diff --git a/sglang/python/sglang/jit_kernel/tests/test_gptq_marlin_repack.py b/sglang/python/sglang/jit_kernel/tests/test_gptq_marlin_repack.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c571dbff292d230a955f0b925e305b86eb69a8c
--- /dev/null
+++ b/sglang/python/sglang/jit_kernel/tests/test_gptq_marlin_repack.py
@@ -0,0 +1,90 @@
+import pytest
+import torch
+from sgl_kernel.scalar_type import scalar_types
+
+from sglang.jit_kernel.gptq_marlin_repack import gptq_marlin_repack
+from sglang.srt.layers.quantization.utils import (
+    gptq_quantize_weights,
+    pack_rows,
+    sort_weights,
+)
+from sglang.test.test_marlin_utils import get_weight_perm, marlin_weights
+
+MARLIN_K_CHUNKS = [128]
+MARLIN_N_CHUNKS = [64, 256]
+
+# (m, n, k) multipliers; only the n and k factors are used by this test.
+MNK_FACTORS = [
+    (1, 1, 1),
+    (1, 4, 8),
+    (1, 7, 5),
+    (13, 17, 67),
+    (26, 37, 13),
+    (67, 13, 11),
+    (257, 13, 11),
+    (658, 13, 11),
+]
+
+
+@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
+@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
+@pytest.mark.parametrize("quant_type", [scalar_types.uint4b8])
+@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
+@pytest.mark.parametrize("act_order", [False, True])
+@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
+def test_gptq_marlin_repack(
+    k_chunk, n_chunk, quant_type, group_size, act_order, mnk_factors
+):
+    """Check the JIT ``gptq_marlin_repack`` against the ``marlin_weights`` reference.
+
+    Parameter combinations the test does not cover are reported as skipped
+    (a bare ``return`` would make pytest count them as passed).
+    """
+    m_factor, n_factor, k_factor = mnk_factors
+
+    size_k = k_chunk * k_factor
+    size_n = n_chunk * n_factor
+
+    # Filter act_order
+    if act_order:
+        if group_size == -1:
+            pytest.skip("act_order is not tested with group_size == -1")
+        if group_size == size_k:
+            pytest.skip("act_order is not tested with group_size == size_k")
+
+    # Normalize group_size
+    if group_size == -1:
+        group_size = size_k
+    assert group_size <= size_k
+
+    if size_k % group_size != 0:
+        pytest.skip("size_k must be divisible by group_size")
+
+    # Create input
+    b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device="cuda")
+
+    # Quantize (and apply act_order if provided)
+    w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
+        b_weight, quant_type, group_size, act_order
+    )
+
+    q_w_gptq = pack_rows(q_w, quant_type.size_bits, size_k, size_n)
+
+    # For act_order, sort the "weights" and "g_idx" so that group ids are
+    # increasing
+    sort_indices = torch.empty(0, dtype=torch.int, device=b_weight.device)
+    if act_order:
+        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
+
+    marlin_layout_perm = get_weight_perm(quant_type.size_bits)
+    q_w_marlin_ref = marlin_weights(
+        q_w, size_k, size_n, quant_type.size_bits, marlin_layout_perm
+    )
+
+    # Run JIT repack kernel
+    jit_output = gptq_marlin_repack(
+        q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits
+    )
+
+    torch.cuda.synchronize()
+
+    # JIT should match the reference (computed from CPU marlin_weights)
+    torch.testing.assert_close(jit_output, q_w_marlin_ref)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v", "-s"])
diff --git a/sglang/python/sglang/jit_kernel/tests/test_moe_wna16_marlin.py b/sglang/python/sglang/jit_kernel/tests/test_moe_wna16_marlin.py
new file mode 100644
index 0000000000000000000000000000000000000000..e40f82461719515ebb7bba6e3c370ae5aa8229c5
--- /dev/null
+++ b/sglang/python/sglang/jit_kernel/tests/test_moe_wna16_marlin.py
@@ -0,0 +1,327 @@
+import itertools
+
+import pytest
+import torch
+from sgl_kernel import moe_wna16_marlin_gemm as aot_moe_wna16_marlin_gemm
+from sgl_kernel.scalar_type import scalar_types
+
+from sglang.jit_kernel.moe_wna16_marlin import moe_wna16_marlin_gemm
+from 
def stack_and_dev(tensors: list[torch.Tensor]):
    """Stack tensors along a new leading axis.

    The stacked result is moved to the device of the first tensor in
    ``tensors``.
    """
    target_device = tensors[0].device
    stacked = torch.stack(tensors, dim=0)
    return stacked.to(target_device)


def _get_scalar_type(num_bits: int, has_zp: bool):
    """Pick the ``scalar_types`` entry for a bit width and zero-point setting.

    With zero points only 4 bits is accepted and ``uint4`` is returned.
    Without zero points, 4 bits maps to ``uint4b8`` and any other width to
    ``uint8b128``.
    """
    if not has_zp:
        if num_bits == 4:
            return scalar_types.uint4b8
        return scalar_types.uint8b128
    assert num_bits == 4
    return scalar_types.uint4


def _setup_moe_weights(e, n, k, quant_type, group_size, act_order, dtype):
    """Set up quantized MoE weights for a single gate (e experts, output n, input k).

    Returns ``(w_ref, qweight, scales, zeros, g_idx, sort_indices)``, each
    stacked over the expert axis. ``zeros`` is ``None`` on the
    ``marlin_quantize`` path; ``g_idx`` and ``sort_indices`` are ``None`` on
    the ``awq_marlin_quantize`` path.
    """
    uses_zero_points = quant_type in [scalar_types.uint4, scalar_types.uint8]

    dense = torch.randn((e, n, k), device="cuda", dtype=dtype) / 20

    refs, packed, group_scales = [], [], []
    zero_points, group_indices, sort_orders = [], [], []

    for expert in range(e):
        if uses_zero_points:
            w_ref, qweight, scales, zeros = awq_marlin_quantize(
                dense[expert].transpose(1, 0), quant_type, group_size
            )
            zero_points.append(zeros)
        else:
            # Random permutation handed to marlin_quantize for this expert.
            test_perm = torch.randperm(k)
            w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
                dense[expert].transpose(1, 0),
                quant_type,
                group_size,
                act_order,
                test_perm,
            )
            group_indices.append(g_idx)
            sort_orders.append(sort_indices)

        # Entries common to both quantization paths.
        refs.append(w_ref.T)
        packed.append(qweight)
        group_scales.append(scales)

    def _stack_or_none(items):
        # Lists that were never filled (the other path ran) become None.
        return stack_and_dev(items) if items else None

    w_ref = stack_and_dev(refs)
    qweight = stack_and_dev(packed).contiguous()
    scales = stack_and_dev(group_scales)
    g_idx = _stack_or_none(group_indices)
    sort_indices = _stack_or_none(sort_orders)
    zeros = _stack_or_none(zero_points)

    return w_ref, qweight, scales, zeros, g_idx, sort_indices
def generate_test_cases():
    """Enumerate the parameter tuples for the MoE Marlin GEMM test.

    Each case is ``(m, n, k, e, topk, dtype, group_size, act_order,
    quant_type)``. The full cartesian product of the axes below is filtered
    by ``is_valid``.
    """
    axes = (
        [1, 123],  # m
        [128, 1024],  # n
        [256],  # k
        [4],  # e
        [2],  # topk
        [torch.float16, torch.bfloat16],  # dtype
        [128],  # group_size
        [False, True],  # act_order
        [scalar_types.uint4, scalar_types.uint4b8],  # quant_type
    )

    def is_valid(m, n, k, e, topk, dtype, group_size, act_order, quant_type):
        if act_order:
            # act_order needs a real group size smaller than k ...
            if group_size in (-1, k):
                return False
            # ... and is not combined with zero-point quantization types.
            if quant_type in [scalar_types.uint4, scalar_types.uint8]:
                return False
        # A positive group size has to divide k evenly.
        return not (group_size > 0 and k % group_size != 0)

    return [case for case in itertools.product(*axes) if is_valid(*case)]
+ + # --- Run JIT kernel --- + c_jit = torch.empty((m * topk, 2 * n), dtype=dtype, device="cuda") + c_jit = _run_single_gemm( + moe_wna16_marlin_gemm, + a, + c_jit, + qweight1, + scales1, + zeros1, + g_idx1, + sort_indices1, + workspace, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + topk_weights, + scalar_type, + block_size_m, + topk, + m, + 2 * n, + k, + False, + True, + use_atomic_add, + ) + + torch.cuda.synchronize() + + # --- Check bitwise equality with AOT kernel --- + c_aot = torch.empty((m * topk, 2 * n), dtype=dtype, device="cuda") + c_aot = _run_single_gemm_aot( + a, + c_aot, + qweight1, + scales1, + zeros1, + g_idx1, + sort_indices1, + workspace, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + topk_weights, + scalar_type, + block_size_m, + topk, + m, + 2 * n, + k, + False, + True, + use_atomic_add, + ) + torch.cuda.synchronize() + torch.testing.assert_close(c_jit, c_aot, rtol=0, atol=0) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/sglang/python/sglang/jit_kernel/tests/test_nvfp4_blockwise_moe.py b/sglang/python/sglang/jit_kernel/tests/test_nvfp4_blockwise_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..6eb809d3385f6551adeca869afc9ee49f4b777ff --- /dev/null +++ b/sglang/python/sglang/jit_kernel/tests/test_nvfp4_blockwise_moe.py @@ -0,0 +1,127 @@ +import pytest +import torch + +from sglang.jit_kernel.nvfp4 import ( + cutlass_fp4_group_mm, + scaled_fp4_experts_quant, + scaled_fp4_quant, +) + +FLOAT4_E2M1_MAX = 6.0 +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + + +def _nvfp4_supported() -> bool: + return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (10, 0) + + +def _round_up(x: int, y: int) -> int: + return ((x + y - 1) // y) * y + + +def _build_expert_offsets( + m_per_expert: list[int], device: torch.device +) -> torch.Tensor: + offsets = [0] + for m in m_per_expert: + offsets.append(offsets[-1] + m) + return torch.tensor(offsets, 
dtype=torch.int32, device=device) + + +def _build_blockscale_offsets( + m_per_expert: list[int], device: torch.device +) -> torch.Tensor: + offsets = [0] + for m in m_per_expert: + offsets.append(offsets[-1] + _round_up(m, 128)) + return torch.tensor(offsets, dtype=torch.int32, device=device) + + +@pytest.mark.skipif( + not _nvfp4_supported(), reason="NVFP4 requires compute capability >= 10.0" +) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_nvfp4_blockwise_moe_grouped_mm(dtype: torch.dtype) -> None: + torch.manual_seed(0) + device = torch.device("cuda") + + num_experts = 4 + m_per_expert = [33, 17, 48, 29] + n = 256 + k = 128 + + expert_offsets_full = _build_expert_offsets(m_per_expert, device) + blockscale_offsets_full = _build_blockscale_offsets(m_per_expert, device) + + total_m = int(expert_offsets_full[-1].item()) + a = torch.randn((total_m, k), device=device, dtype=dtype) * 0.1 + b = torch.randn((num_experts, n, k), device=device, dtype=dtype) * 0.1 + + a_global_scale = torch.empty((num_experts,), device=device, dtype=torch.float32) + for i in range(num_experts): + start = int(expert_offsets_full[i].item()) + end = int(expert_offsets_full[i + 1].item()) + amax = a[start:end].abs().max().to(torch.float32) + a_global_scale[i] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax + + b_global_scale = torch.empty((num_experts,), device=device, dtype=torch.float32) + for i in range(num_experts): + bmax = b[i].abs().max().to(torch.float32) + b_global_scale[i] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / bmax + + a_fp4, a_blockscale = scaled_fp4_experts_quant( + a, + a_global_scale, + expert_offsets_full, + blockscale_offsets_full, + topk=1, + ) + + b_fp4 = torch.empty((num_experts, n, k // 2), device=device, dtype=torch.uint8) + b_blockscale = torch.empty( + (num_experts, _round_up(n, 128), _round_up(k // 16, 4)), + device=device, + dtype=torch.float8_e4m3fn, + ) + for i in range(num_experts): + b_fp4_i, b_scale_i = scaled_fp4_quant(b[i], 
b_global_scale[i]) + b_fp4[i].copy_(b_fp4_i) + b_blockscale[i].copy_(b_scale_i) + + alphas = (1.0 / (a_global_scale * b_global_scale)).to(torch.float32) + + params = { + "ab_strides": torch.full((num_experts,), k, dtype=torch.int64, device=device), + "c_strides": torch.full((num_experts,), n, dtype=torch.int64, device=device), + "problem_sizes": torch.tensor( + [[m, n, k] for m in m_per_expert], dtype=torch.int32, device=device + ), + "expert_offsets": expert_offsets_full[:-1].contiguous(), + "blockscale_offsets": blockscale_offsets_full[:-1].contiguous(), + "a_ptrs": torch.empty((num_experts,), dtype=torch.int64, device=device), + "b_ptrs": torch.empty((num_experts,), dtype=torch.int64, device=device), + "out_ptrs": torch.empty((num_experts,), dtype=torch.int64, device=device), + "a_scales_ptrs": torch.empty((num_experts,), dtype=torch.int64, device=device), + "b_scales_ptrs": torch.empty((num_experts,), dtype=torch.int64, device=device), + "alpha_ptrs": torch.empty((num_experts,), dtype=torch.int64, device=device), + "layout_sfa": torch.empty((num_experts, 5), dtype=torch.int64, device=device), + "layout_sfb": torch.empty((num_experts, 5), dtype=torch.int64, device=device), + } + + out = cutlass_fp4_group_mm( + a_fp4, + b_fp4, + a_blockscale, + b_blockscale, + alphas, + dtype, + params, + ) + + ref = torch.empty((total_m, n), device=device, dtype=dtype) + for i in range(num_experts): + start = int(expert_offsets_full[i].item()) + end = int(expert_offsets_full[i + 1].item()) + ref[start:end] = torch.matmul(a[start:end], b[i].t()) + + torch.testing.assert_close(out, ref, atol=1e-1, rtol=1e-1) diff --git a/sglang/python/sglang/jit_kernel/tests/test_nvfp4_gemm.py b/sglang/python/sglang/jit_kernel/tests/test_nvfp4_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..9a76cd8009d76f8cd933152f3beca3998e1383b2 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/tests/test_nvfp4_gemm.py @@ -0,0 +1,142 @@ +import pytest +import torch + +from 
# Magnitudes selected by the three low bits of a 4-bit code; bit 3 is the sign.
K_E2M1_TO_FLOAT = [
    0.0,
    0.5,
    1.0,
    1.5,
    2.0,
    3.0,
    4.0,
    6.0,
]


def e2m1_to_fp32(int4_value: int) -> float:
    """Decode one 4-bit code into a Python float.

    The low three bits index ``K_E2M1_TO_FLOAT``; bit 3 (``0x8``) negates the
    result.
    """
    magnitude = K_E2M1_TO_FLOAT[int4_value & 0x7]
    return -magnitude if int4_value & 0x8 else magnitude


def break_fp4_bytes(a: torch.Tensor) -> torch.Tensor:
    """Unpack a ``(m, n)`` uint8 tensor of packed 4-bit codes to ``(m, 2n)`` floats.

    Each byte holds two codes; the low nibble is emitted first, then the high
    nibble. The result lives on ``a.device``.

    All 16 codes are decoded once into a lookup table and gathered with tensor
    indexing, instead of decoding every element in a Python loop. The values
    are the same as calling ``e2m1_to_fp32`` per nibble.
    """
    assert a.dtype == torch.uint8
    m, n = a.shape

    lut = torch.tensor([e2m1_to_fp32(code) for code in range(16)], device=a.device)

    # uint8 tensors used as indices are treated as masks, so cast to long.
    low_codes = (a & 0x0F).to(torch.long)
    high_codes = ((a & 0xF0) >> 4).to(torch.long)

    return torch.stack((lut[low_codes], lut[high_codes]), dim=-1).reshape(m, n * 2)


def convert_swizzled_to_linear(
    a_sf_swizzled: torch.Tensor, m: int, k: int, block_size: int
) -> torch.Tensor:
    """Reorder a tiled scale-factor tensor into ``(m, k // block_size)`` row-major form.

    The input is viewed as ``(1, m_tiles, k_tiles, 32, 4, 4)``, its axes are
    permuted to ``(0, 1, 4, 3, 2, 5)``, and the result is flattened to
    ``(m_tiles * 128, k_tiles * 4)`` before the padding rows/columns are
    sliced off.
    """
    # Unpacking raises unless the input is 2-D; the sizes themselves are unused.
    _, _ = a_sf_swizzled.shape

    k_span = block_size * 4
    m_tiles = (m + 128 - 1) // 128
    k_tiles = (k + k_span - 1) // k_span

    tiled = a_sf_swizzled.reshape(1, m_tiles, k_tiles, 32, 4, 4)
    linear = tiled.permute(0, 1, 4, 3, 2, 5).reshape(
        m_tiles * 128, k_tiles * k_span // block_size
    )
    return linear[0:m, 0 : k // block_size]
@pytest.mark.skipif(
    not _nvfp4_supported(), reason="NVFP4 requires compute capability >= 10.0"
)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shape", SHAPES)
def test_nvfp4_gemm(dtype: torch.dtype, shape: tuple[int, int, int]) -> None:
    """Compare ``cutlass_scaled_fp4_mm`` with a dequantize-then-matmul reference.

    ``shape`` is ``(m, n, packed_k)``; the unpacked reduction size is
    ``k = 2 * packed_k``. Both operands are quantized with
    ``scaled_fp4_quant`` and the same quantized data feeds the kernel and the
    reference from ``get_ref_results``.
    """
    # Fixed seed so failures are reproducible, as in the sibling NVFP4 tests.
    torch.manual_seed(0)

    m, n, packed_k = shape
    k = packed_k * 2
    block_size = 16

    a_dtype = torch.randn((m, k), dtype=dtype, device="cuda")
    b_dtype = torch.randn((n, k), dtype=dtype, device="cuda")

    # The global scale is derived from the largest *magnitude*; a signed max
    # would ignore a negative element that is larger in absolute value.
    a_amax = a_dtype.abs().amax().to(torch.float32)
    b_amax = b_dtype.abs().amax().to(torch.float32)
    a_global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / a_amax
    b_global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / b_amax

    # Undoes both global scales on the GEMM output.
    alpha = 1.0 / (a_global_scale * b_global_scale)

    a_fp4, a_scale_interleaved = scaled_fp4_quant(a_dtype, a_global_scale)
    b_fp4, b_scale_interleaved = scaled_fp4_quant(b_dtype, b_global_scale)

    expected_out = get_ref_results(
        a_fp4,
        b_fp4,
        a_scale_interleaved,
        b_scale_interleaved,
        a_global_scale,
        b_global_scale,
        block_size,
    )

    out = cutlass_scaled_fp4_mm(
        a_fp4,
        b_fp4,
        a_scale_interleaved,
        b_scale_interleaved,
        alpha,
        dtype,
    )

    torch.testing.assert_close(out, expected_out.to(dtype=dtype), atol=1e-1, rtol=1e-1)
# Value for each 4-bit code 0..15; codes 8..15 mirror 0..7 with a negative sign.
E2M1_TO_FLOAT32 = [
    0.0,
    0.5,
    1.0,
    1.5,
    2.0,
    3.0,
    4.0,
    6.0,
    0.0,
    -0.5,
    -1.0,
    -1.5,
    -2.0,
    -3.0,
    -4.0,
    -6.0,
]


def cast_from_fp4(x: torch.Tensor, m: int, n: int) -> torch.Tensor:
    """Decode packed 4-bit codes into a float32 tensor of shape ``(m, n)``.

    Every element of ``x`` contributes two codes: its low nibble first, then
    its high nibble. Codes are looked up in ``E2M1_TO_FLOAT32``.
    """
    lut = torch.tensor(E2M1_TO_FLOAT32, device=x.device, dtype=torch.float32)
    low_nibble = (x & 0xF).to(torch.long)
    high_nibble = ((x >> 4) & 0xF).to(torch.long)
    codes = torch.stack((low_nibble, high_nibble), dim=-1).reshape(-1)
    return lut[codes].reshape(m, n)
def get_reciprocal(x):
    """Return ``1 / x``, with zero mapped to zero instead of infinity.

    Accepts either a ``torch.Tensor`` (element-wise) or a Python number.

    For tensors the zero fill value is created on ``x.device`` with
    ``x.dtype``, so the call does not depend on ``torch.where`` accepting a
    CPU scalar tensor alongside tensors on another device.
    """
    if isinstance(x, torch.Tensor):
        zero = torch.zeros((), dtype=x.dtype, device=x.device)
        return torch.where(x == 0, zero, 1.0 / x)
    return 0.0 if x == 0 else 1.0 / x
@pytest.mark.skipif(
    not _nvfp4_supported(), reason="NVFP4 requires compute capability >= 10.0"
)
@pytest.mark.parametrize("shape", PAD_SHAPES)
def test_quantize_to_fp4_padded(shape: tuple[int, int]) -> None:
    """Check ``scaled_fp4_quant`` against ``ref_nvfp4_quant`` on the PAD_SHAPES inputs.

    Both the decoded FP4 values and the de-swizzled block scales must match
    the reference.
    """
    torch.manual_seed(42)
    rows, cols = shape
    x = torch.randn((rows, cols), dtype=torch.float16, device="cuda")

    amax = x.abs().max().to(torch.float32)
    global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax

    expected_values, expected_scales = ref_nvfp4_quant(x, global_scale)
    packed, swizzled_scales = scaled_fp4_quant(x, global_scale)

    actual_scales = recover_swizzled_scales(swizzled_scales, rows, cols)
    actual_values = cast_from_fp4(packed, rows, cols)

    torch.testing.assert_close(actual_values, expected_values)
    torch.testing.assert_close(actual_scales, expected_scales)


@pytest.mark.skipif(
    not _nvfp4_supported(), reason="NVFP4 requires compute capability >= 10.0"
)
@pytest.mark.parametrize("shape", [(2, 128, 512), (2, 100, 128)])
def test_quantize_to_fp4_grouped(shape: tuple[int, int, int]) -> None:
    """Check ``scaled_fp4_grouped_quant`` group-by-group against ``scaled_fp4_quant``.

    Only the first ``mask[g]`` rows of each group are compared.
    """
    torch.manual_seed(42)
    num_groups, m, k = shape

    x = torch.randn((num_groups, m, k), dtype=torch.bfloat16, device="cuda")
    mask = torch.randint(
        1, max(2, m // 2), (num_groups,), dtype=torch.int32, device="cuda"
    )
    per_group_amax = x.abs().amax(dim=(1, 2)).to(torch.float32)
    x_sf_global = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / per_group_amax

    output, output_scales = scaled_fp4_grouped_quant(x, x_sf_global, mask)

    # Bring the group axis to the front so each group can be indexed directly.
    output = output.permute(2, 0, 1)
    padded_m = ((m + 128 - 1) // 128) * 128
    output_scales = output_scales.permute(5, 2, 4, 0, 1, 3).view(
        num_groups, padded_m, -1
    )

    for g in range(num_groups):
        valid_rows = mask[g]
        ref_fp4, ref_swizzled = scaled_fp4_quant(x[g], x_sf_global[g])
        torch.testing.assert_close(ref_fp4[:valid_rows], output[g][:valid_rows])

        scale_ref = recover_swizzled_scales(ref_swizzled, m, k)
        scale_ans = recover_swizzled_scales(output_scales[g], m, k)
        torch.testing.assert_close(scale_ref[:valid_rows], scale_ans[:valid_rows])
@pytest.mark.skipif(
    not _nvfp4_supported(), reason="NVFP4 requires compute capability >= 10.0"
)
@pytest.mark.parametrize("shape", [(4, 96, 256), (8, 128, 512)])
def test_silu_and_mul_quantize_to_fp4_grouped(shape: tuple[int, int, int]) -> None:
    """Check the fused silu-and-mul grouped quantization against the unfused path.

    The reference applies ``_silu_and_mul_reference`` and then
    ``scaled_fp4_grouped_quant``; the fused kernel receives the raw input.
    Only the first ``mask[g]`` rows of each group are compared.
    """
    torch.manual_seed(42)
    num_groups, m, k = shape

    # The input carries 2 * k features per row.
    x = torch.randn((num_groups, m, k * 2), dtype=torch.bfloat16, device="cuda")
    mask = torch.randint(
        1, max(2, m // 2), (num_groups,), dtype=torch.int32, device="cuda"
    )

    ref_y = _silu_and_mul_reference(x)

    per_group_amax = ref_y.abs().amax(dim=(1, 2)).to(torch.float32)
    y_sf_global = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / per_group_amax

    ref_output, ref_output_scales = scaled_fp4_grouped_quant(ref_y, y_sf_global, mask)
    output, output_scales = silu_and_mul_scaled_fp4_grouped_quant(x, y_sf_global, mask)

    padded_m = ((m + 128 - 1) // 128) * 128

    def _group_major_scales(scales):
        # Same axis permutation for both paths, then one 2-D slab per group.
        return scales.permute(5, 2, 4, 0, 1, 3).view(num_groups, padded_m, -1)

    output = output.permute(2, 0, 1)
    ref_output = ref_output.permute(2, 0, 1)
    output_scales = _group_major_scales(output_scales)
    ref_output_scales = _group_major_scales(ref_output_scales)

    for g in range(num_groups):
        valid_rows = mask[g]
        torch.testing.assert_close(ref_output[g, :valid_rows], output[g, :valid_rows])

        scale_ref = recover_swizzled_scales(ref_output_scales[g], m, k)
        scale_ans = recover_swizzled_scales(output_scales[g], m, k)
        torch.testing.assert_close(scale_ref[:valid_rows], scale_ans[:valid_rows])
def sglang_scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``input`` with the JIT ``per_tensor_quant_fp8`` kernel.

    Args:
        input: Tensor to quantize.
        scale: Optional scale tensor. When given, the kernel is called with
            ``is_static=True``. When omitted, a zero-initialized one-element
            float32 tensor is passed with ``is_static=False``.

    Returns:
        ``(output, scale)`` where ``output`` has the module-level
        ``fp8_type_`` dtype.
    """
    # Use the module-level, platform-aware FP8 dtype rather than shadowing it
    # with a hard-coded float8_e4m3fn, so this helper agrees with the tests
    # below that allocate their outputs with ``fp8_type_``.
    output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
    is_static = scale is not None
    if scale is None:
        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
    per_tensor_quant_fp8(input, output, scale, is_static)

    return output, scale


def torch_scaled_fp8_quant(tensor, inv_scale):
    """Reference quantization in plain PyTorch.

    Multiplies ``tensor`` (as float32) by the reciprocal of ``inv_scale``,
    clamps to the finite range of ``fp8_type_`` and casts to that dtype, so
    the reference matches the dtype used for the kernel output.
    """
    finfo = torch.finfo(fp8_type_)
    scale = inv_scale.reciprocal()
    qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
    return qweight.to(fp8_type_)
scale = torch.rand(1, dtype=torch.float32, device=device) + sglang_out, _ = sglang_scaled_fp8_quant(x, scale) + torch_out = torch_scaled_fp8_quant(x, scale) + + torch.testing.assert_close( + sglang_out.float(), torch_out.float(), rtol=1e-3, atol=1e-3 + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/sglang/python/sglang/jit_kernel/tests/test_per_token_group_quant_8bit.py b/sglang/python/sglang/jit_kernel/tests/test_per_token_group_quant_8bit.py new file mode 100644 index 0000000000000000000000000000000000000000..eebd495277bd7606cb9a1a63b257058e1630c47e --- /dev/null +++ b/sglang/python/sglang/jit_kernel/tests/test_per_token_group_quant_8bit.py @@ -0,0 +1,205 @@ +import itertools + +import pytest +import torch + +from sglang.srt.utils import is_hip + +_is_hip = is_hip() +fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn + +from sgl_kernel.test_utils import ( + assert_all_close_or_tiny_diff, + create_per_token_group_quant_test_data, +) + +from sglang.jit_kernel.per_token_group_quant_8bit import ( + per_token_group_quant_8bit as sglang_per_token_group_quant_8bit, +) +from sglang.srt.layers.quantization.fp8_kernel import ( + create_per_token_group_quant_fp8_output_scale, +) +from sglang.srt.layers.quantization.fp8_kernel import ( + per_token_group_quant_8bit as triton_per_token_group_quant_8bit, +) + +configs = list( + itertools.product( + [1, 4, 16, 64, 127, 128, 512, 1024, 4096, 8192], # num_tokens + [128, 256, 384, 512, 1024, 1536, 1664, 2048, 4096, 7168, 16384], # hidden_dim + [16, 32, 64, 128], # group_size + [None], # num_ranks + [fp8_type_], # dtype + [ + dict( + column_major_scales=False, + scale_tma_aligned=False, + scale_ue8m0=False, + fuse_silu_and_mul=False, + masked_layout_mode=None, + ), + dict( + column_major_scales=True, + scale_tma_aligned=False, + scale_ue8m0=False, + fuse_silu_and_mul=False, + masked_layout_mode=None, + ), + dict( + column_major_scales=True, + scale_tma_aligned=True, + 
@pytest.mark.parametrize(
    "num_tokens, hidden_dim, group_size, num_ranks, dst_dtype, flags", configs
)
def test_per_token_group_quant_with_column_major(
    num_tokens,
    hidden_dim,
    group_size,
    num_ranks,
    dst_dtype,
    flags,
):
    """Compare the JIT per-token-group quantization with the Triton implementation.

    The Triton path receives every entry of ``flags`` except
    ``masked_layout_mode``. The JIT path is always invoked with a plain
    (row-major, non-TMA-aligned, non-ue8m0) scale buffer and an output of the
    same shape as the input.
    """
    arch_major, _ = torch.cuda.get_device_capability(torch.cuda.current_device())

    # pytest.skip raises, so nothing after it runs; no trailing return needed.
    if flags["scale_ue8m0"] and (arch_major <= 9):
        pytest.skip("Only Blackwell need ue8m0 fusion")

    if (flags["scale_ue8m0"] and (group_size != 128)) or (
        (dst_dtype == torch.int8) and flags["column_major_scales"]
    ):
        pytest.skip(
            "combination not exercised: scale_ue8m0 with group_size != 128, "
            "or int8 with column_major_scales"
        )

    x, masked_m = create_per_token_group_quant_test_data(
        num_tokens=num_tokens, hidden_dim=hidden_dim, num_ranks=num_ranks, flags=flags
    )

    triton_kwargs = dict(
        x=x,
        masked_m=masked_m,
        group_size=group_size,
        eps=1e-10,
        dst_dtype=dst_dtype,
        **{k: v for k, v in flags.items() if k not in ["masked_layout_mode"]},
    )

    def _postprocess(x_q, x_s):
        # Zero everything past each entry's masked length so both outputs
        # are compared only on the rows that are valid.
        if masked_m is not None:
            print(f"Mask tokens after {masked_m} to be zero")
            for i in range(len(masked_m)):
                x_q[i, masked_m[i] :, :] = 0
                x_s[i, masked_m[i] :, :] = 0
        return x_q, x_s

    x_q_triton, x_s_triton = _postprocess(
        *triton_per_token_group_quant_8bit(**triton_kwargs)
    )

    # The JIT call below is made without silu-and-mul fusion, so its output
    # keeps the full last dimension of the input.
    fuse_silu_and_mul = False
    out_shape = (*x.shape[:-1], x.shape[-1] // (2 if fuse_silu_and_mul else 1))

    fp8_dtype = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8_dtype).max
    fp8_min = -fp8_max
    x_q = torch.empty(out_shape, device=x.device, dtype=fp8_dtype)
    x_s = create_per_token_group_quant_fp8_output_scale(
        x_shape=out_shape,
        device=x.device,
        group_size=group_size,
        column_major_scales=False,
        scale_tma_aligned=False,
        scale_ue8m0=False,
    )

    jit_kwargs = dict(
        input=x,
        output_q=x_q,
        output_s=x_s,
        group_size=group_size,
        eps=1e-10,
        fp8_max=fp8_max,
        fp8_min=fp8_min,
    )
    x_q_sglang, x_s_sglang = _postprocess(
        *sglang_per_token_group_quant_8bit(**jit_kwargs)
    )

    try:
        assert_all_close_or_tiny_diff(x_q_triton, x_q_sglang)
        torch.testing.assert_close(
            x_s_triton.contiguous(),
            x_s_sglang.contiguous(),
            rtol=1e-3,
            atol=1e-5,
            msg=lambda message: message + f" {x_s_triton=} {x_s_sglang=}",
        )
    except AssertionError:
        # Dump the full inputs and outputs before re-raising to aid debugging.
        print(
            f"{x.shape=} {x_q_triton.shape=} {x_s_triton.shape=} {x_q_sglang.shape=} {x_s_sglang.shape=}"
        )
        print(f"{x=}")
        print(f"{masked_m=}")
        print(f"{x_q_triton=}")
        print(f"{x_s_triton=}")
        print(f"{x_q_sglang=}")
        print(f"{x_s_sglang=}")

        raise
def create_test_inputs(
    head_size, batch_size, seq_len, device, dtype, num_q_heads, num_kv_heads
):
    """Create test inputs.

    Returns ``(query, key, pos_ids)``: query and key are random tensors
    flattened to ``(batch_size * seq_len, num_heads, head_size)``, and
    ``pos_ids`` holds one random int64 position per token drawn from
    ``[0, min(seq_len * 2, 100))``.
    """
    total_tokens = batch_size * seq_len

    def _random_heads(num_heads):
        # Draw in (batch, seq, heads, head_size) layout, then merge batch and seq.
        sample = torch.randn(
            batch_size, seq_len, num_heads, head_size, dtype=dtype, device=device
        )
        return sample

    # Keep the draw order query -> key -> positions so seeded runs reproduce.
    query = _random_heads(num_q_heads)
    key = _random_heads(num_kv_heads)

    position_limit = min(seq_len * 2, 100)
    pos_ids = torch.randint(
        0, position_limit, (total_tokens,), dtype=torch.long, device=device
    )

    return (
        query.view(total_tokens, num_q_heads, head_size),
        key.view(total_tokens, num_kv_heads, head_size),
        pos_ids,
    )


def create_cos_sin_cache(rotary_dim, max_position_embeddings, base, dtype, device):
    """Create cos/sin cache for rotary embedding.

    The cache has ``max(max_position_embeddings, 100)`` rows and
    ``rotary_dim`` columns: the first half of each row holds the cosines and
    the second half the sines.
    """
    num_positions = max(max_position_embeddings, 100)
    half_dim = rotary_dim // 2

    exponents = (
        torch.arange(0, rotary_dim, 2, dtype=torch.float32, device=device)
        / rotary_dim
    )
    inv_freq = 1.0 / (base**exponents)
    positions = torch.arange(num_positions, dtype=torch.float32, device=device)
    angles = torch.outer(positions, inv_freq)

    cache = torch.zeros(num_positions, rotary_dim, dtype=dtype, device=device)
    cache[:, :half_dim] = torch.cos(angles).to(dtype)
    cache[:, half_dim:] = torch.sin(angles).to(dtype)

    return cache
vLLM torch native +def _apply_rotary_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, +) -> torch.Tensor: + """ + Args: + x: [num_tokens, num_heads, head_size] + cos: [num_tokens, head_size // 2] + sin: [num_tokens, head_size // 2] + is_neox_style: Whether to use the Neox-style or GPT-J-style rotary + positional embeddings. + """ + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) + + +class RotaryEmbedding(torch.nn.Module): + # Reference: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.dtype = dtype + + cache = self._compute_cos_sin_cache() + self.cos_sin_cache: torch.Tensor + self.register_buffer("cos_sin_cache", cache, persistent=False) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim + ) + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + t = torch.arange(self.max_position_embeddings, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward_native( + self, + 
positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """A PyTorch-native implementation of forward().""" + + if offsets is not None: + positions = positions + offsets + + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions) + + cos, sin = cos_sin.chunk(2, dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + # Modification: convert to the correct dtype + query = query.to(self.dtype) + + if key is not None: + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + + key = key.to(self.dtype) + + return query, key + + +def get_torch_rotary_embedding( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device +): + """Initialize Torch Native RotaryEmbedding based on vLLM implementation.""" + return RotaryEmbedding( + head_size=head_size, + rotary_dim=rotary_dim, + max_position_embeddings=max_position_embeddings, + base=base, + is_neox_style=is_neox_style, + dtype=dtype, + ).to(device) + + +def get_sgl_rotary_embedding( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device +): + """Initialize SglKernelRotaryEmbedding.""" + try: + from sgl_kernel.testing.rotary_embedding import SglKernelRotaryEmbedding + except ImportError: + pytest.skip( + "SglKernelRotaryEmbedding is not available. Test case can be removed." 
+ ) + + return SglKernelRotaryEmbedding( + head_size=head_size, + rotary_dim=rotary_dim, + max_position_embeddings=max_position_embeddings, + base=base, + is_neox_style=is_neox_style, + dtype=dtype, + ).to(device) + + +def compare_results(jit_out, sgl_out, dtype): + """Assert both outputs are None, or both are NaN-free and numerically close.""" + if jit_out is None: + assert sgl_out is None + return + + assert sgl_out is not None + + # Check for NaN values + assert not torch.isnan(jit_out).any(), "NaN in JIT results" + assert not torch.isnan(sgl_out).any(), "NaN in SGL results" + + # Compare results (atol/rtol are 1e-5 for float32, 1e-2 for any other dtype) + atol = 1e-2 if dtype != torch.float32 else 1e-5 + rtol = 1e-2 if dtype != torch.float32 else 1e-5 + + torch.testing.assert_close(jit_out, sgl_out, atol=atol, rtol=rtol) + + +@pytest.mark.parametrize( + "head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device, batch_size, seq_len, num_q_heads, num_kv_heads", + [ + # GPT-OSS cases + *[ + (64, 64, 4096, 8000, True, torch.bfloat16, "cuda", bs, sl, 8, 8) + for bs, sl in [(1, 1), (32, 1), (128, 1), (512, 1), (2, 512), (4, 4096)] + ], + # Other cases + (64, 64, 32, 8000, True, torch.bfloat16, "cuda", 32, 32, 1, 1), + (256, 128, 4096, 10000, True, torch.bfloat16, "cuda", 2, 512, 4, 2), + (512, 128, 311, 10000, True, torch.bfloat16, "cuda", 3, 39, 4, 2), + (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 32, 8), + (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 16, 4), + (512, 128, 311, 10000, False, torch.bfloat16, "cuda", 3, 39, 4, 2), + (64, 64, 32, 8000, True, torch.float32, "cuda", 32, 32, 1, 1), + (256, 128, 4096, 10000, True, torch.float32, "cuda", 2, 512, 4, 2), + (512, 128, 311, 10000, True, torch.float32, "cuda", 3, 39, 4, 2), + (128, 128, 2048, 10000, False, torch.float32, "cuda", 2, 512, 32, 8), + (128, 128, 2048, 10000, False, torch.float32, "cuda", 2, 512, 16, 4), + (512, 128, 311, 10000, False, torch.float32, "cuda", 3, 39, 4, 2), + # Additional test cases for different head 
sizes and dtypes + (64, 32, 1024, 10000, True, torch.float16, "cuda", 16, 64, 8, 4), + (128, 64, 2048, 10000, True, torch.float16, "cuda", 8, 128, 16, 8), + (256, 128, 4096, 10000, True, torch.float16, "cuda", 4, 256, 8, 4), + ], +) +@pytest.mark.parametrize( + "key_is_none", + [True, False], +) +def test_correctness( + head_size, + rotary_dim, + max_position_embeddings, + base, + is_neox_style, + dtype, + device, + batch_size, + seq_len, + num_q_heads, + num_kv_heads, + key_is_none, +): + """Test correctness of JIT rotary embedding implementation.""" + # Create inputs and caches + query, key, pos_ids = create_test_inputs( + head_size, batch_size, seq_len, device, dtype, num_q_heads, num_kv_heads + ) + cos_sin_cache = create_cos_sin_cache( + rotary_dim, max_position_embeddings, base, dtype, device + ) + + # Initialize torch kernel + torch_rotary_emb = get_torch_rotary_embedding( + head_size, + rotary_dim, + max_position_embeddings, + base, + is_neox_style, + dtype, + device, + ) + torch_rotary_emb.cos_sin_cache = cos_sin_cache + r = torch.randn_like(query) + + # Apply rotary embeddings + query_jit, key_jit = query.clone(), key.clone() + query_torch, key_torch = query.clone(), key.clone() + stream_jit = torch.get_device_module("cuda").Stream() + stream_kernel = torch.get_device_module("cuda").Stream() + + if key_is_none: + key_jit = None + key_torch = None + triton_burn(100.0, grid=(1024,)) + + r_jit, r_torch = r.clone(), r.clone() + torch.cuda.synchronize() + + with torch.cuda.stream(stream_jit): + # Test if rotary_embedding runs on stream_jit + triton_burn(100.0, grid=(1024,)) + query_jit = query_jit + r_jit + query_jit_out, key_jit_out = rotary_embedding( + positions=pos_ids, + query=query_jit, + key=key_jit, + head_size=head_size, + cos_sin_cache=cos_sin_cache, + is_neox=is_neox_style, + ) + + with torch.cuda.stream(stream_kernel): + triton_burn(100.0, grid=(1024,)) + query_torch = query_torch + r_torch + query_torch_out, key_torch_out = 
torch_rotary_emb.forward_native( + positions=pos_ids, query=query_torch, key=key_torch + ) + + torch.cuda.synchronize() + compare_results(query_jit_out, query_torch_out, dtype) + compare_results(key_jit_out, key_torch_out, dtype) + + +@pytest.mark.parametrize( + "head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device, batch_size, seq_len, num_q_heads, num_kv_heads", + [ + # Small scale + (64, 64, 4096, 8000, True, torch.bfloat16, "cuda", 1, 1, 8, 8), + (64, 64, 4096, 8000, True, torch.bfloat16, "cuda", 4, 16, 8, 8), + # Medium scale + (64, 64, 4096, 8000, True, torch.bfloat16, "cuda", 8, 64, 8, 8), + (64, 64, 4096, 8000, True, torch.bfloat16, "cuda", 16, 128, 8, 8), + # Large scale + (64, 64, 4096, 8000, True, torch.bfloat16, "cuda", 32, 512, 8, 8), + (64, 64, 4096, 8000, True, torch.bfloat16, "cuda", 64, 1024, 8, 8), + ], +) +def test_performance( + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style, + dtype, + device, + batch_size, + seq_len, + num_q_heads, + num_kv_heads, +): + """Performance test comparing JIT and SGL implementations with accuracy validation.""" + # Create inputs and caches + query, key, pos_ids = create_test_inputs( + head_size, batch_size, seq_len, device, dtype, num_q_heads, num_kv_heads + ) + cos_sin_cache = create_cos_sin_cache( + rotary_dim, max_position_embeddings, base, dtype, device + ) + + # Initialize SGL kernel + sgl_rotary_emb = get_sgl_rotary_embedding( + head_size, + rotary_dim, + max_position_embeddings, + base, + is_neox_style, + dtype, + device, + ) + sgl_rotary_emb.cos_sin_cache = cos_sin_cache + + warmup = 3 + + # Warmup runs + for _ in range(warmup): + query_warm, key_warm = query.clone(), key.clone() + rotary_embedding( + positions=pos_ids, + query=query_warm, + key=key_warm, + head_size=head_size, + cos_sin_cache=cos_sin_cache, + is_neox=is_neox_style, + ) + + query_sgl_warm, key_sgl_warm = query.clone(), key.clone() + sgl_rotary_emb.forward_cuda( + 
positions=pos_ids, query=query_sgl_warm, key=key_sgl_warm + ) + + iteration = 100 + + # Time JIT implementation + torch.cuda.synchronize() + start_time = time.time() + for _ in range(iteration): + query_jit, key_jit = query.clone(), key.clone() + rotary_embedding( + positions=pos_ids, + query=query_jit, + key=key_jit, + head_size=head_size, + cos_sin_cache=cos_sin_cache, + is_neox=is_neox_style, + ) + torch.cuda.synchronize() + jit_time = (time.time() - start_time) / iteration + + # Time SGL implementation + torch.cuda.synchronize() + start_time = time.time() + for _ in range(iteration): + query_sgl, key_sgl = query.clone(), key.clone() + sgl_rotary_emb.forward_cuda(positions=pos_ids, query=query_sgl, key=key_sgl) + torch.cuda.synchronize() + sgl_time = (time.time() - start_time) / iteration + + # Accuracy validation during performance test + # Run one more time to get outputs for comparison + query_jit_final, key_jit_final = query.clone(), key.clone() + query_sgl_final, key_sgl_final = query.clone(), key.clone() + + query_jit_out, key_jit_out = rotary_embedding( + positions=pos_ids, + query=query_jit_final, + key=key_jit_final, + head_size=head_size, + cos_sin_cache=cos_sin_cache, + is_neox=is_neox_style, + ) + + query_sgl_out, key_sgl_out = sgl_rotary_emb.forward_cuda( + positions=pos_ids, query=query_sgl_final, key=key_sgl_final + ) + + # Validate accuracy + compare_results(query_jit_out, query_sgl_out, dtype) + compare_results(key_jit_out, key_sgl_out, dtype) + + # Print results + total_tokens = batch_size * seq_len + print( + f"\nPerformance Test - Batch={batch_size}, SeqLen={seq_len}, Tokens={total_tokens}" + ) + print(f"JIT: {jit_time*1000:.9f}ms, SGL: {sgl_time*1000:.9f}ms") + if sgl_time > 0: + speedup = sgl_time / jit_time if jit_time > 0 else float("inf") + print(f"Speedup (SGL/JIT): {speedup:.2f}x") + + assert jit_time >= 0 and sgl_time >= 0 + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git 
a/sglang/python/sglang/jit_kernel/tests/test_qknorm.py b/sglang/python/sglang/jit_kernel/tests/test_qknorm.py new file mode 100644 index 0000000000000000000000000000000000000000..ee72e9ec6bd4f1443d5b45f37754c04507a26be4 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/tests/test_qknorm.py @@ -0,0 +1,93 @@ +import itertools + +import pytest +import torch +import triton + + +def sglang_aot_qknorm( + q: torch.Tensor, + k: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, +) -> None: + from sgl_kernel import rmsnorm + + head_dim = q.shape[-1] + q = q.view(-1, head_dim) + k = k.view(-1, head_dim) + rmsnorm(q, q_weight, out=q) + rmsnorm(k, k_weight, out=k) + + +def sglang_jit_qknorm( + q: torch.Tensor, + k: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, +) -> None: + from sglang.jit_kernel.norm import fused_inplace_qknorm + + fused_inplace_qknorm(q, k, q_weight, k_weight) + + +def flashinfer_qknorm( + q: torch.Tensor, + k: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, +) -> None: + from flashinfer.norm import rmsnorm + + rmsnorm(q, q_weight, out=q) + rmsnorm(k, k_weight, out=k) + + +@torch.compile() +def torch_impl_qknorm( + q: torch.Tensor, + k: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + eps: float = 1e-6, +) -> None: + q_mean = q.float().pow(2).mean(dim=-1, keepdim=True) + k_mean = k.float().pow(2).mean(dim=-1, keepdim=True) + q_norm = (q_mean + eps).rsqrt() + k_norm = (k_mean + eps).rsqrt() + q.copy_(q.float() * q_norm * q_weight.float()) + k.copy_(k.float() * k_norm * k_weight.float()) + + +BS_LIST = [2**n for n in range(0, 14)] +BS_LIST += [x + 1 + i for i, x in enumerate(BS_LIST)] +N_K_LIST = [2, 4] +N_Q_LIST = [8, 16] +HEAD_DIM_LIST = [64, 128, 256, 512, 1024] +DEVICE = "cuda" +DTYPE = torch.bfloat16 + +# NOTE(dark): sgl_kernel use flashinfer template, which is bitwise identical to flashinfer impl. 
+# However, sgl-jit-kernel, flashinfer, torch_impl, may have small numerical differences. +# so we allow a small rel/abs tolerance in correctness test. + + +@pytest.mark.parametrize( + "batch_size,n_k,n_q,head_dim", + list(itertools.product(BS_LIST, N_K_LIST, N_Q_LIST, HEAD_DIM_LIST)), +) +def test_qknorm(batch_size: int, n_k: int, n_q: int, head_dim: int) -> None: + q = torch.randn(batch_size, n_q, head_dim, device=DEVICE, dtype=DTYPE) + k = torch.randn(batch_size, n_k, head_dim, device=DEVICE, dtype=DTYPE) + q_weight = torch.randn(head_dim, device=DEVICE, dtype=DTYPE) + k_weight = torch.randn(head_dim, device=DEVICE, dtype=DTYPE) + q_k_aot = (q.clone(), k.clone()) + q_k_jit = (q.clone(), k.clone()) + sglang_aot_qknorm(q_k_aot[0], q_k_aot[1], q_weight, k_weight) + sglang_jit_qknorm(q_k_jit[0], q_k_jit[1], q_weight, k_weight) + triton.testing.assert_close(q_k_aot[0], q_k_jit[0], atol=1e-2, rtol=1e-2) + triton.testing.assert_close(q_k_aot[1], q_k_jit[1], atol=1e-2, rtol=1e-2) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/sglang/python/sglang/jit_kernel/tests/test_qknorm_across_heads.py b/sglang/python/sglang/jit_kernel/tests/test_qknorm_across_heads.py new file mode 100644 index 0000000000000000000000000000000000000000..f090d82ada0b68a8ea67ab6d8450430f0ee81438 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/tests/test_qknorm_across_heads.py @@ -0,0 +1,75 @@ +import itertools + +import pytest +import torch +import triton + + +def sglang_jit_qknorm_across_heads( + q: torch.Tensor, + k: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, +) -> None: + from sglang.jit_kernel.norm import fused_inplace_qknorm_across_heads + + fused_inplace_qknorm_across_heads(q, k, q_weight, k_weight) + + +def sglang_aot_qknorm_across_heads( + q: torch.Tensor, + k: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, +) -> None: + from sgl_kernel import rmsnorm + + rmsnorm(q, q_weight, out=q) + rmsnorm(k, k_weight, out=k) 
+ + +@torch.compile() +def torch_impl_qknorm_across_heads( + q: torch.Tensor, + k: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + eps: float = 1e-6, +) -> None: + q_mean = q.float().pow(2).mean(dim=-1, keepdim=True) + k_mean = k.float().pow(2).mean(dim=-1, keepdim=True) + q_norm = (q_mean + eps).rsqrt() + k_norm = (k_mean + eps).rsqrt() + q.copy_(q.float() * q_norm * q_weight.float()) + k.copy_(k.float() * k_norm * k_weight.float()) + + +BS_LIST = [2**n for n in range(0, 14)] +BS_LIST += [x + 1 + i for i, x in enumerate(BS_LIST)] +HIDDEN_DIM_LIST = [512, 1024, 2048, 4096] +DEVICE = "cuda" +DTYPE = torch.bfloat16 + + +@pytest.mark.parametrize( + "batch_size,hidden_dim", + list(itertools.product(BS_LIST, HIDDEN_DIM_LIST)), +) +def test_qknorm_across_heads(batch_size: int, hidden_dim: int) -> None: + q = torch.randn(batch_size, hidden_dim, device=DEVICE, dtype=DTYPE) + k = torch.randn(batch_size, hidden_dim, device=DEVICE, dtype=DTYPE) + q_weight = torch.randn(hidden_dim, device=DEVICE, dtype=DTYPE) + k_weight = torch.randn(hidden_dim, device=DEVICE, dtype=DTYPE) + + q_k_jit = (q.clone(), k.clone()) + q_k_aot = (q.clone(), k.clone()) + + sglang_jit_qknorm_across_heads(q_k_jit[0], q_k_jit[1], q_weight, k_weight) + sglang_aot_qknorm_across_heads(q_k_aot[0], q_k_aot[1], q_weight, k_weight) + + triton.testing.assert_close(q_k_jit[0], q_k_aot[0], atol=1e-2, rtol=1e-2) + triton.testing.assert_close(q_k_jit[1], q_k_aot[1], atol=1e-2, rtol=1e-2) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/sglang/python/sglang/jit_kernel/tests/test_renorm.py b/sglang/python/sglang/jit_kernel/tests/test_renorm.py new file mode 100644 index 0000000000000000000000000000000000000000..a25a8b1f6fbc11a85460c117138dc86483cf97ac --- /dev/null +++ b/sglang/python/sglang/jit_kernel/tests/test_renorm.py @@ -0,0 +1,118 @@ +# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/main/tests/test_sampling.py +# and 
/sgl-workspace/sglang/sgl-kernel/tests/test_sampling.py + +import pytest +import sgl_kernel +import torch + + +@pytest.mark.parametrize("batch_size", [1, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) +@pytest.mark.parametrize("k", [10, 100, 500]) +def test_top_k_renorm_probs(batch_size, vocab_size, k): + """Test top_k_renorm_probs kernel for correctness. + + This test validates that the kernel correctly: + 1. Identifies the top-k probabilities + 2. Masks out non-top-k values + 3. Renormalizes the remaining probabilities to sum to 1 + """ + if k > vocab_size: + pytest.skip("k should be less than vocab_size") + torch.manual_seed(42) + pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0") + normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + sorted_prob, _ = torch.sort(normalized_prob, descending=True) + pivot = sorted_prob[:, k - 1] + mask = (normalized_prob >= pivot.unsqueeze(-1)).int() + renorm_prob_ground_truth = normalized_prob.clone() + renorm_prob_ground_truth[mask == 0] = 0 + renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum( + dim=-1, keepdim=True + ) + + renorm_prob = sgl_kernel.top_k_renorm_prob(normalized_prob, k) + for i in range(batch_size): + torch.testing.assert_close( + renorm_prob_ground_truth[i], + renorm_prob[i], + rtol=1e-3, + atol=1e-3, + ) + + +@pytest.mark.parametrize("batch_size", [1, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) +@pytest.mark.parametrize("p", [0.1, 0.5, 0.9]) +def test_top_p_renorm_probs(batch_size, vocab_size, p): + """Test top_p_renorm_probs kernel for correctness. + + This test validates that the kernel correctly: + 1. Computes the cumulative probability distribution + 2. Identifies tokens in the top-p threshold + 3. Masks out tokens outside the threshold + 4. 
Renormalizes the remaining probabilities to sum to 1 + """ + torch.manual_seed(42) + pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0") + normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + sorted_prob, indices = torch.sort(normalized_prob, descending=False) + cdf = torch.cumsum(sorted_prob, dim=-1) + mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device="cuda:0") + mask.scatter_add_(1, indices, (cdf >= (1 - p)).int()) + renorm_prob_ground_truth = normalized_prob.clone() + renorm_prob_ground_truth[mask == 0] = 0 + renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum( + dim=-1, keepdim=True + ) + + renorm_prob = sgl_kernel.top_p_renorm_prob(normalized_prob, p) + torch.testing.assert_close( + renorm_prob_ground_truth, + renorm_prob, + rtol=1e-3, + atol=1e-3, + ) + + +@pytest.mark.parametrize("batch_size", [1, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) +@pytest.mark.parametrize("k", [10, 100, 500]) +@pytest.mark.parametrize("neginf_input", [False, True]) +def test_top_k_mask_logits(batch_size, vocab_size, k, neginf_input): + """Test top_k_mask_logits kernel for correctness. + + This test validates that the kernel correctly: + 1. Identifies the top-k logits + 2. Masks non-top-k values to -inf + 3. Preserves the top-k values + 4. Handles negative infinity inputs gracefully + + The test verifies correctness by comparing softmax(top_k_mask_logits(logits)) + with top_k_renorm_prob(probs), which should be equivalent. 
+ """ + if k > vocab_size: + pytest.skip("k should be less than vocab_size") + torch.manual_seed(42) + logits = torch.randn(batch_size, vocab_size, device="cuda:0") * 5 + if neginf_input: + # Randomly assign some logits to -inf to test edge cases + num_neginf = torch.randint(1, vocab_size * batch_size, (1,)).item() + idxs = torch.randperm(batch_size * vocab_size, device="cuda:0")[:num_neginf] + logits[idxs // vocab_size, idxs % vocab_size] = -float("inf") + + probs = torch.softmax(logits, dim=-1) + masked_logits = sgl_kernel.top_k_mask_logits(logits, k) + renormed_probs = torch.softmax(masked_logits, dim=-1) + renormed_probs_ref = sgl_kernel.top_k_renorm_prob(probs, k) + + torch.testing.assert_close( + renormed_probs, + renormed_probs_ref, + rtol=1e-3, + atol=1e-3, + ) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/sglang/python/sglang/jit_kernel/tests/test_rmsnorm.py b/sglang/python/sglang/jit_kernel/tests/test_rmsnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..501124daf72906d08bde7caa38a2fb8a434aa453 --- /dev/null +++ b/sglang/python/sglang/jit_kernel/tests/test_rmsnorm.py @@ -0,0 +1,41 @@ +import itertools + +import pytest +import torch +import triton + + +def sglang_jit_rmsnorm(input: torch.Tensor, weight: torch.Tensor) -> None: + from sglang.jit_kernel.norm import rmsnorm + + rmsnorm(input, weight, output=input) + + +def flashinfer_rmsnorm(input: torch.Tensor, weight: torch.Tensor) -> None: + from flashinfer.norm import rmsnorm + + rmsnorm(input, weight, out=input) + + +BS_LIST = [2**n for n in range(0, 14)] +BS_LIST += [x + 1 + i for i, x in enumerate(BS_LIST)] +HIDDEN_SIZE_LIST = [512, 1024, 1536, 2048, 3072, 4096, 5120, 6144, 7168, 8192] +DEVICE = "cuda" +DTYPE = torch.bfloat16 + + +@pytest.mark.parametrize( + "batch_size,hidden_size", list(itertools.product(BS_LIST, HIDDEN_SIZE_LIST)) +) +def test_rmsnorm(batch_size: int, hidden_size: int) -> None: + input = torch.randn(batch_size, hidden_size, 
device=DEVICE, dtype=DTYPE) + weight = torch.randn(hidden_size, device=DEVICE, dtype=DTYPE) + input_sglang = input.clone() + input_flashinfer = input.clone() + sglang_jit_rmsnorm(input_sglang, weight) + flashinfer_rmsnorm(input_flashinfer, weight) + triton.testing.assert_close(input_sglang, input_flashinfer, atol=1e-2, rtol=1e-2) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/sglang/python/sglang/jit_kernel/tests/test_rope.py b/sglang/python/sglang/jit_kernel/tests/test_rope.py new file mode 100644 index 0000000000000000000000000000000000000000..0aba5cf4cefacd57bdf814927ce70f8f30fa7ede --- /dev/null +++ b/sglang/python/sglang/jit_kernel/tests/test_rope.py @@ -0,0 +1,244 @@ +import pytest +import torch +import triton + +DEVICE = "cuda" +DTYPE = torch.bfloat16 +MAX_SEQ_LEN = 131072 # common seq length +ROPE_BASE = 10000.0 +CACHE_SIZE = 1024 * 128 + + +def create_cos_sin_cache( + rotary_dim: int, + max_position: int = MAX_SEQ_LEN, + base: float = ROPE_BASE, +) -> torch.Tensor: + """Create cos/sin cache compatible with SGLang layout: [max_pos, rotary_dim].""" + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, rotary_dim, 2, dtype=torch.float32, device=DEVICE) + / rotary_dim + ) + ) + t = torch.arange(max_position, dtype=torch.float32, device=DEVICE) + freqs = torch.einsum("i,j->ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) # [max_pos, rotary_dim] + return cache + + +# --------------------------------------------------------------------------- +# Implementation wrappers +# --------------------------------------------------------------------------- + + +def sglang_jit_rope( + q: torch.Tensor, + k: torch.Tensor, + cos_sin_cache: torch.Tensor, + positions: torch.Tensor, + is_neox: bool, +) -> None: + from sglang.jit_kernel.rope import apply_rope_inplace + + apply_rope_inplace(q, k, cos_sin_cache, positions, is_neox=is_neox) + + +def flashinfer_rope( + q: torch.Tensor, + k: 
torch.Tensor, + cos_sin_cache: torch.Tensor, + positions: torch.Tensor, + is_neox: bool, +) -> None: + from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace + + head_size = q.shape[-1] + # flashinfer expects [nnz, num_heads * head_size] + q_2d = q.view(q.shape[0], -1) + k_2d = k.view(k.shape[0], -1) + apply_rope_with_cos_sin_cache_inplace( + positions=positions, + query=q_2d, + key=k_2d, + head_size=head_size, + cos_sin_cache=cos_sin_cache, + is_neox=is_neox, + ) + + +def torch_impl_rope( + q: torch.Tensor, + k: torch.Tensor, + cos_sin_cache: torch.Tensor, + positions: torch.Tensor, + is_neox: bool, +) -> None: + # TODO: implement a pure-PyTorch reference for extra coverage + pass + + +# --------------------------------------------------------------------------- +# Test parameters +# --------------------------------------------------------------------------- + +BS_LIST = [2**x for x in range(12)] +BS_LIST += [x + 1 for x in BS_LIST] # odd sizes to stress non-aligned paths +NUM_KV_HEADS_LIST = [1, 2, 8] +GQA_RATIO = [1, 4, 8] +ROPE_DIM_LIST = [64, 128, 256, 512] +IS_NEOX_LIST = [False, True] +DTYPE_LIST = [torch.bfloat16, torch.float16] + + +@pytest.mark.parametrize("batch_size", BS_LIST) +@pytest.mark.parametrize("gqa_ratio", GQA_RATIO) +@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS_LIST) +@pytest.mark.parametrize("rope_dim", ROPE_DIM_LIST) +@pytest.mark.parametrize("is_neox", IS_NEOX_LIST) +@pytest.mark.parametrize("dtype", DTYPE_LIST) +def test_rope( + batch_size: int, + gqa_ratio: int, + num_kv_heads: int, + rope_dim: int, + is_neox: bool, + dtype: torch.dtype, +) -> None: + num_qo_heads = num_kv_heads * gqa_ratio + q = torch.randn(batch_size, num_qo_heads, rope_dim, device=DEVICE, dtype=dtype) + k = torch.randn(batch_size, num_kv_heads, rope_dim, device=DEVICE, dtype=dtype) + positions = torch.randint( + 0, MAX_SEQ_LEN, (batch_size,), device=DEVICE, dtype=torch.int64 + ) + cos_sin_cache = create_cos_sin_cache(rope_dim) + + q_fi, k_fi = 
q.clone(), k.clone() + q_jit, k_jit = q.clone(), k.clone() + + flashinfer_rope(q_fi, k_fi, cos_sin_cache, positions, is_neox) + sglang_jit_rope(q_jit, k_jit, cos_sin_cache, positions, is_neox) + + atol = rtol = 1e-2 + triton.testing.assert_close(q_fi, q_jit, atol=atol, rtol=rtol) + triton.testing.assert_close(k_fi, k_jit, atol=atol, rtol=rtol) + + +@pytest.mark.parametrize("dtype", [torch.int32, torch.int64]) +def test_rope_position_dtypes(dtype: torch.dtype) -> None: + """Ensure both int32 and int64 position tensors work correctly.""" + batch_size, num_qo_heads, num_kv_heads, rope_dim = 16384, 16, 2, 128 + is_neox = True + + q = torch.randn(batch_size, num_qo_heads, rope_dim, device=DEVICE, dtype=DTYPE) + k = torch.randn(batch_size, num_kv_heads, rope_dim, device=DEVICE, dtype=DTYPE) + positions = torch.randint(0, MAX_SEQ_LEN, (batch_size,), device=DEVICE, dtype=dtype) + cos_sin_cache = create_cos_sin_cache(rope_dim) + + q_fi, k_fi = q.clone(), k.clone() + q_jit, k_jit = q.clone(), k.clone() + + flashinfer_rope(q_fi, k_fi, cos_sin_cache, positions.long(), is_neox) + sglang_jit_rope(q_jit, k_jit, cos_sin_cache, positions, is_neox) + + atol = rtol = 1e-2 + triton.testing.assert_close(q_fi, q_jit, atol=atol, rtol=rtol) + triton.testing.assert_close(k_fi, k_jit, atol=atol, rtol=rtol) + + +@pytest.mark.parametrize("batch_size", BS_LIST) +@pytest.mark.parametrize("is_neox", IS_NEOX_LIST) +@pytest.mark.parametrize("rope_dim", [64, 80, 96, 128]) +@pytest.mark.parametrize("head_dim", [64, 128, 256]) +def test_partial_rope(batch_size: int, is_neox: bool, rope_dim: int, head_dim: int): + if head_dim < rope_dim: + pytest.skip("Invalid config: head_dim must be >= rope_dim.") + num_qo_heads, num_kv_heads = 8, 2 + + q = torch.randn(batch_size, num_qo_heads, head_dim, device=DEVICE, dtype=DTYPE) + k = torch.randn(batch_size, num_kv_heads, head_dim, device=DEVICE, dtype=DTYPE) + positions = torch.randint(0, MAX_SEQ_LEN, (batch_size,), device=DEVICE) + cos_sin_cache = 
create_cos_sin_cache(rope_dim) + + q_fi, k_fi = q.clone(), k.clone() + q_jit, k_jit = q.clone(), k.clone() + rope = ..., slice(rope_dim) # NOTE: by default flashinfer rotates the first rope_dim + + flashinfer_rope(q_fi, k_fi, cos_sin_cache, positions.long(), is_neox) + sglang_jit_rope(q_jit[rope], k_jit[rope], cos_sin_cache, positions, is_neox) + + atol = rtol = 1e-2 + triton.testing.assert_close(q_fi, q_jit, atol=atol, rtol=rtol) + triton.testing.assert_close(k_fi, k_jit, atol=atol, rtol=rtol) + + +@pytest.mark.parametrize("batch_size", BS_LIST) +@pytest.mark.parametrize("gqa_ratio", GQA_RATIO) +@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS_LIST) +@pytest.mark.parametrize("rope_dim", ROPE_DIM_LIST) +@pytest.mark.parametrize("is_neox", IS_NEOX_LIST) +def test_fused_rope_store( + batch_size: int, + gqa_ratio: int, + num_kv_heads: int, + rope_dim: int, + is_neox: bool, +) -> None: + """Test fused RoPE + KV cache store against separate RoPE + manual store.""" + from sglang.jit_kernel.rope import apply_rope_inplace_with_kvcache + + num_qo_heads = num_kv_heads * gqa_ratio + dtype = DTYPE + + q = torch.randn(batch_size, num_qo_heads, rope_dim, device=DEVICE, dtype=dtype) + k = torch.randn(batch_size, num_kv_heads, rope_dim, device=DEVICE, dtype=dtype) + v = torch.randn(batch_size, num_kv_heads, rope_dim, device=DEVICE, dtype=dtype) + positions = torch.randint( + 0, MAX_SEQ_LEN, (batch_size,), device=DEVICE, dtype=torch.int64 + ) + out_loc = torch.randperm(CACHE_SIZE, device=DEVICE, dtype=torch.int64)[:batch_size] + cos_sin_cache = create_cos_sin_cache(rope_dim) + + row_size = num_kv_heads * rope_dim + k_cache_ref = torch.zeros(CACHE_SIZE, row_size, device=DEVICE, dtype=dtype) + v_cache_ref = torch.zeros(CACHE_SIZE, row_size, device=DEVICE, dtype=dtype) + k_cache_fused = torch.zeros(CACHE_SIZE, row_size, device=DEVICE, dtype=dtype) + v_cache_fused = torch.zeros(CACHE_SIZE, row_size, device=DEVICE, dtype=dtype) + + # --- reference: separate RoPE then manual scatter 
--- + q_ref, k_ref = q.clone(), k.clone() + flashinfer_rope(q_ref, k_ref, cos_sin_cache, positions, is_neox) + k_cache_ref[out_loc] = k_ref.view(batch_size, -1) + v_cache_ref[out_loc] = v.view(batch_size, -1) + + # --- fused kernel --- + q_fused, k_fused = q.clone(), k.clone() + v_fused = v.clone() + apply_rope_inplace_with_kvcache( + q_fused, + k_fused, + v_fused, + k_cache_fused, + v_cache_fused, + cos_sin_cache, + positions, + out_loc, + is_neox=is_neox, + ) + + atol = rtol = 1e-2 + # q should match RoPE-only result + triton.testing.assert_close(q_ref, q_fused, atol=atol, rtol=rtol) + # k_cache should contain the rotated k + triton.testing.assert_close( + k_cache_ref[out_loc], k_cache_fused[out_loc], atol=atol, rtol=rtol + ) + # v_cache should be an exact copy + assert torch.all(v_cache_ref[out_loc] == v_cache_fused[out_loc]), "v_cache mismatch" + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/sglang/python/sglang/jit_kernel/tests/test_store_cache.py b/sglang/python/sglang/jit_kernel/tests/test_store_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..ea168d0b22ee2c8c342220bd1c09c3ca955e673a --- /dev/null +++ b/sglang/python/sglang/jit_kernel/tests/test_store_cache.py @@ -0,0 +1,123 @@ +import itertools + +import pytest +import torch + +from sglang.jit_kernel.kvcache import can_use_store_cache, store_cache + +BS_LIST = [2**n for n in range(0, 15)] +BS_LIST += [x + 1 + i for i, x in enumerate(BS_LIST)] +HIDDEN_DIMS = [64, 128, 256, 512, 1024, 96, 98, 100] +CACHE_SIZE = 1024 * 1024 +DTYPE = torch.bfloat16 +DEVICE = "cuda" + + +@pytest.mark.parametrize( + "batch_size,element_dim", + list(itertools.product(BS_LIST, HIDDEN_DIMS)), +) +def test_store_cache(batch_size: int, element_dim: int) -> None: + k = torch.randn((batch_size, element_dim), dtype=DTYPE, device=DEVICE) + v = torch.randn((batch_size, element_dim), dtype=DTYPE, device=DEVICE) + k_cache = torch.randn((CACHE_SIZE, element_dim), dtype=DTYPE, 
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize(
    "batch_size,element_dim",
    list(itertools.product(REPR_BS, REPR_DIMS)),
)
def test_store_cache_dtypes(
    batch_size: int, element_dim: int, dtype: torch.dtype
) -> None:
    """Store random k/v rows into random caches and check exact round-trip per dtype."""
    opts = {"dtype": dtype, "device": DEVICE}
    src_shape = (batch_size, element_dim)
    cache_shape = (SMALL_CACHE, element_dim)

    k = torch.randn(src_shape, **opts)
    v = torch.randn(src_shape, **opts)
    k_cache = torch.randn(cache_shape, **opts)
    v_cache = torch.randn(cache_shape, **opts)
    # Distinct destination rows, one per batch element.
    indices = torch.randperm(SMALL_CACHE, device=DEVICE)[:batch_size]

    store_cache(k, v, k_cache, v_cache, indices)

    # The rows addressed by `indices` must now hold k / v bit-for-bit.
    for cache, source in ((k_cache, k), (v_cache, v)):
        assert torch.all(cache[indices] == source)
torch.all(k_cache[indices.long()] == k) + assert torch.all(v_cache[indices.long()] == v) + + +def _valid_num_splits(element_dim: int, dtype: torch.dtype) -> list: + """Return the list of valid num_split values for a given element_dim/dtype.""" + row_bytes = element_dim * dtype.itemsize + splits = [1] + if row_bytes % (2 * 128) == 0: + splits.append(2) + if row_bytes % (4 * 128) == 0: + splits.append(4) + return splits + + +_NUM_SPLIT_CASES = [ + (_dim, _ns, _dtype) + for _dtype in [torch.float16, torch.bfloat16, torch.float32] + for _dim in REPR_DIMS + for _ns in _valid_num_splits(_dim, _dtype) +] + + +@pytest.mark.parametrize("element_dim,num_split,dtype", _NUM_SPLIT_CASES) +def test_store_cache_num_split( + element_dim: int, num_split: int, dtype: torch.dtype +) -> None: + batch_size = 128 + k = torch.randn((batch_size, element_dim), dtype=dtype, device=DEVICE) + v = torch.randn((batch_size, element_dim), dtype=dtype, device=DEVICE) + k_cache = torch.randn((SMALL_CACHE, element_dim), dtype=dtype, device=DEVICE) + v_cache = torch.randn((SMALL_CACHE, element_dim), dtype=dtype, device=DEVICE) + indices = torch.randperm(SMALL_CACHE, device=DEVICE)[:batch_size] + + # Verify each num_split kernel path (1, 2, 4) produces correct results + store_cache(k, v, k_cache, v_cache, indices, num_split=num_split) + + assert torch.all(k_cache[indices] == k) + assert torch.all(v_cache[indices] == v) + + +def test_can_use_store_cache() -> None: + assert can_use_store_cache(128) + assert can_use_store_cache(256) + assert can_use_store_cache(1024) + assert can_use_store_cache(2048) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/sglang/python/sglang/jit_kernel/tests/test_timestep_embedding.py b/sglang/python/sglang/jit_kernel/tests/test_timestep_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..068363774a91f5947c65cd699e4adb9bb2c70f5b --- /dev/null +++ b/sglang/python/sglang/jit_kernel/tests/test_timestep_embedding.py 
def get_timestep_embedding_reference(
    timesteps: torch.Tensor,
    dim: int,
    *,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """Pure-PyTorch sinusoidal timestep embedding used as the test reference.

    :param timesteps: 1-D tensor of timesteps (converted to float32 internally).
    :param dim: Width of the returned embedding.
    :param flip_sin_to_cos: If true, the cosine half comes first instead of the sine half.
    :param downscale_freq_shift: Subtracted from ``dim // 2`` in the exponent denominator.
    :param scale: Multiplier applied to the angles before sin/cos.
    :param max_period: Base whose logarithm sets the frequency spacing.
    :return: float32 tensor of shape ``(len(timesteps), dim)``; for odd ``dim`` the
        last column is zero padding.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    device = timesteps.device
    half_dim = dim // 2
    t = timesteps.to(torch.float32)

    # freqs[i] = exp(-ln(max_period) * i / (half_dim - downscale_freq_shift))
    log_period = torch.log(
        torch.tensor(max_period, dtype=torch.float32, device=device)
    )
    steps = torch.arange(start=0, end=half_dim, dtype=torch.float32, device=device)
    freqs = torch.exp((-log_period * steps) / (half_dim - downscale_freq_shift))

    # Outer product of timesteps and frequencies, then the optional scale.
    angles = scale * (t[:, None].float() * freqs[None, :])

    sin_half, cos_half = torch.sin(angles), torch.cos(angles)
    ordered = [cos_half, sin_half] if flip_sin_to_cos else [sin_half, cos_half]
    emb = torch.cat(ordered, dim=-1)

    if dim % 2 == 1:
        # Odd widths get one trailing zero column.
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
+@pytest.mark.parametrize("dim", [32, 256, 512, 1536, 8192]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("flip_sin_to_cos", [False, True]) +@pytest.mark.parametrize("downscale_freq_shift", [0, 1]) +@pytest.mark.parametrize("scale", [1, 0.01]) +def test_timestep_embedding_correctness_with_diffusers( + batch_size, dim, flip_sin_to_cos, downscale_freq_shift, scale, dtype +): + device = "cuda" + t = torch.randint(low=0, high=1000, size=(batch_size,), device=device).to(dtype) + torch_output = get_timestep_embedding_reference( + t, + dim, + flip_sin_to_cos=flip_sin_to_cos, + downscale_freq_shift=downscale_freq_shift, + scale=scale, + max_period=10000, + ) + cuda_output = timestep_embedding_cuda( + t, + dim, + flip_sin_to_cos=flip_sin_to_cos, + downscale_freq_shift=downscale_freq_shift, + scale=scale, + max_period=10000, + ) + torch.testing.assert_close(torch_output, cuda_output, atol=1e-3, rtol=1e-3) + + +def test_timestep_embedding_perf(): + if os.environ.get("SGLANG_RUN_JIT_KERNEL_PERF_TESTS") != "1": + pytest.skip("Perf test disabled by default") + if tabulate is None: + pytest.skip("Optional dependency 'tabulate' is not installed") + + NUM_BATCH = [1, 2, 8, 63, 256, 512, 613, 1024, 1536] + NUM_DIM = [32, 64, 128, 256, 512, 1024, 2048, 4096] + + def perf_kernel_fn(kernel_fn: callable, *args, **kwargs): + warmup_times = 4 + repeat_times = 20 + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + for _ in range(warmup_times): + output_fn = kernel_fn(*args, **kwargs) + torch.cuda.synchronize() + + start.record() + for _ in range(repeat_times): + output_fn = kernel_fn(*args, **kwargs) + end.record() + end.synchronize() + return start.elapsed_time(end) / repeat_times + + device = "cuda" + results = [] + + cuda_speedups = [] + for B in NUM_BATCH: + for dim in NUM_DIM: + t = torch.linspace(0, max(100000, B), steps=B, device=device).to( + torch.float32 + ) + time_torch = 
@cache_once
def _jit_timestep_embedding_module(dtype: torch.dtype) -> Module:
    """Return the JIT module exporting ``timestep_embedding`` for ``dtype``.

    The dtype is converted to its C++ template argument and used both as part
    of the module marker and as the template parameter of the exported kernel.
    """
    template_args = make_cpp_args(dtype)
    kernel_name = f"timestep_embedding<{template_args}>"
    return load_jit(
        "timestep_embedding",
        *template_args,
        cuda_files=["diffusion/timestep_embedding.cuh"],
        cuda_wrappers=[("timestep_embedding", kernel_name)],
    )
def cache_once(fn: F) -> F:
    """
    Memoize ``fn`` on its call arguments.

    NOTE: `functools.lru_cache` is not compatible with `torch.compile`,
    so this small hand-written decorator is used in its place.

    The cache key is the positional-argument tuple plus the keyword arguments
    sorted by name, so keyword order does not matter, while passing the same
    value positionally versus by keyword yields distinct entries. All
    arguments must be hashable.
    """
    memo = {}

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        ordered_kwargs = tuple(sorted(kwargs.items(), key=lambda item: item[0]))
        cache_key = (args, ordered_kwargs)
        if cache_key in memo:
            return memo[cache_key]
        value = fn(*args, **kwargs)
        memo[cache_key] = value
        return value

    return wrapper  # type: ignore
def load_jit(
    *args: str,
    cpp_files: List[str] | None = None,
    cuda_files: List[str] | None = None,
    cpp_wrappers: List[Tuple[str, str]] | None = None,
    cuda_wrappers: List[Tuple[str, str]] | None = None,
    extra_cflags: List[str] | None = None,
    extra_cuda_cflags: List[str] | None = None,
    extra_ldflags: List[str] | None = None,
    extra_include_paths: List[str] | None = None,
    build_directory: str | None = None,
) -> Module:
    """
    Build and load a JIT module from C++/CUDA sources under ``KERNEL_PATH / "csrc"``.

    A wrapper is a tuple ``(export_name, kernel_name)``: ``export_name`` is the
    name callable from Python and ``kernel_name`` is the C++/CUDA entity it
    exports. Each listed source file is pulled in through an ``#include`` line
    and each wrapper is appended as an export macro.

    :param args: Unique marker of the JIT module; joined into the module name,
        so it must differ between kernels.
    :type args: str
    :param cpp_files: C++ source files, relative to the ``csrc`` directory.
    :type cpp_files: List[str] | None
    :param cuda_files: CUDA source files, relative to the ``csrc`` directory.
    :type cuda_files: List[str] | None
    :param cpp_wrappers: ``(export_name, kernel_name)`` pairs for the C++ sources.
    :type cpp_wrappers: List[Tuple[str, str]] | None
    :param cuda_wrappers: ``(export_name, kernel_name)`` pairs for the CUDA sources.
    :type cuda_wrappers: List[Tuple[str, str]] | None
    :param extra_cflags: Extra C++ compiler flags, appended to the defaults.
    :type extra_cflags: List[str] | None
    :param extra_cuda_cflags: Extra CUDA compiler flags, appended to the defaults.
    :type extra_cuda_cflags: List[str] | None
    :param extra_ldflags: Extra linker flags, appended to the defaults.
    :type extra_ldflags: List[str] | None
    :param extra_include_paths: Extra include paths, appended to the defaults.
    :type extra_include_paths: List[str] | None
    :param build_directory: The build directory for JIT compilation.
    :type build_directory: str | None
    :return: A just-in-time (JIT) compiled module.
    :rtype: Module
    """

    from tvm_ffi.cpp import load_inline

    def _inline_sources(files, wrappers):
        # One `#include "<absolute path>"` per file, then one export macro per wrapper.
        resolved = [(KERNEL_PATH / "csrc" / name).resolve() for name in (files or [])]
        lines = [f'#include "{path}"' for path in resolved]
        lines += [_make_wrapper(pair) for pair in (wrappers or [])]
        return lines

    user_cflags = extra_cflags or []
    user_cuda_cflags = extra_cuda_cflags or []
    user_ldflags = extra_ldflags or []
    user_include_paths = extra_include_paths or []

    cpp_sources = _inline_sources(cpp_files, cpp_wrappers)
    cuda_sources = _inline_sources(cuda_files, cuda_wrappers)

    # Override TVM_FFI_CUDA_ARCH_LIST only if the caller has not set it.
    arch_env = "TVM_FFI_CUDA_ARCH_LIST"
    set_by_us = arch_env not in os.environ

    if is_hip_runtime():
        # HIP builds use the HIP flag set and define USE_ROCM ahead of user flags.
        base_cuda_cflags = DEFAULT_HIP_CFLAGS
        user_cuda_cflags = ["-DUSE_ROCM"] + user_cuda_cflags
    else:
        base_cuda_cflags = DEFAULT_CUDA_CFLAGS

    if set_by_us:
        os.environ[arch_env] = _get_cuda_arch_list()
    try:
        return load_inline(
            "sgl_kernel_jit_" + "_".join(map(str, args)),
            cpp_sources=cpp_sources,
            cuda_sources=cuda_sources,
            extra_cflags=DEFAULT_CFLAGS + user_cflags,
            extra_cuda_cflags=base_cuda_cflags + user_cuda_cflags,
            extra_ldflags=DEFAULT_LDFLAGS + user_ldflags,
            extra_include_paths=DEFAULT_INCLUDE + user_include_paths,
            build_directory=build_directory,
        )
    finally:
        # Restore the environment to its original state (variable absent).
        if set_by_us:
            del os.environ[arch_env]
a/sglang/python/sglang/lang/__pycache__/choices.cpython-311.pyc b/sglang/python/sglang/lang/__pycache__/choices.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f94811b9c474f940703a0712ee8163763d886235 Binary files /dev/null and b/sglang/python/sglang/lang/__pycache__/choices.cpython-311.pyc differ diff --git a/sglang/python/sglang/lang/__pycache__/interpreter.cpython-311.pyc b/sglang/python/sglang/lang/__pycache__/interpreter.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e8d67124c1fd09a0dbdee758d49c2967f7424ee Binary files /dev/null and b/sglang/python/sglang/lang/__pycache__/interpreter.cpython-311.pyc differ diff --git a/sglang/python/sglang/lang/__pycache__/ir.cpython-311.pyc b/sglang/python/sglang/lang/__pycache__/ir.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..838e6bd56f1cca7cb29c189bb5354323d4f4ee1c Binary files /dev/null and b/sglang/python/sglang/lang/__pycache__/ir.cpython-311.pyc differ diff --git a/sglang/python/sglang/lang/__pycache__/tracer.cpython-311.pyc b/sglang/python/sglang/lang/__pycache__/tracer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db4b64dc86aefd36f90cb772249f538689cc22b0 Binary files /dev/null and b/sglang/python/sglang/lang/__pycache__/tracer.cpython-311.pyc differ diff --git a/sglang/python/sglang/lang/api.py b/sglang/python/sglang/lang/api.py new file mode 100644 index 0000000000000000000000000000000000000000..745c656ee12fe7e68a1e869367f9e21969fe0056 --- /dev/null +++ b/sglang/python/sglang/lang/api.py @@ -0,0 +1,292 @@ +"""Public APIs of the language.""" + +import re +from typing import Callable, List, Optional, Union + +from sglang.global_config import global_config +from sglang.lang.backend.base_backend import BaseBackend +from sglang.lang.choices import ChoicesSamplingMethod, token_length_normalized +from sglang.lang.ir import ( + SglExpr, + SglExprList, + SglFunction, + 
def gen(
    name: Optional[str] = None,
    max_tokens: Optional[int] = None,
    min_tokens: Optional[int] = None,
    n: Optional[int] = None,
    stop: Optional[Union[str, List[str]]] = None,
    stop_token_ids: Optional[List[int]] = None,
    stop_regex: Optional[Union[str, List[str]]] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    min_p: Optional[float] = None,
    frequency_penalty: Optional[float] = None,
    presence_penalty: Optional[float] = None,
    ignore_eos: Optional[bool] = None,
    return_logprob: Optional[bool] = None,
    logprob_start_len: Optional[int] = None,
    top_logprobs_num: Optional[int] = None,
    return_text_in_logprobs: Optional[bool] = None,
    dtype: Optional[Union[type, str]] = None,
    choices: Optional[List[str]] = None,
    choices_method: Optional[ChoicesSamplingMethod] = None,
    regex: Optional[str] = None,
    json_schema: Optional[str] = None,
):
    """Call the model to generate. See the meaning of the arguments in docs/backend/sampling_params.md

    When ``choices`` is non-empty an ``SglSelect`` is returned instead of an
    ``SglGen``; in that case only ``name``, ``choices``, ``temperature`` and
    ``choices_method`` are used. An invalid ``regex`` raises ``re.error``.
    """

    if choices:
        select_temperature = temperature
        if select_temperature is None:
            select_temperature = 0.0
        method = choices_method
        if method is None:
            method = token_length_normalized
        return SglSelect(name, choices, select_temperature, method)

    # check regex is valid (re.compile raises re.error on a bad pattern)
    if regex is not None:
        re.compile(regex)

    gen_args = (
        name,
        max_tokens,
        min_tokens,
        n,
        stop,
        stop_token_ids,
        stop_regex,
        temperature,
        top_p,
        top_k,
        min_p,
        frequency_penalty,
        presence_penalty,
        ignore_eos,
        return_logprob,
        logprob_start_len,
        top_logprobs_num,
        return_text_in_logprobs,
        dtype,
        regex,
        json_schema,
    )
    return SglGen(*gen_args)
def _role_common(name: str, expr: Optional[SglExpr] = None):
    """Wrap ``expr`` (if given) between the begin/end markers of role ``name``."""
    parts = [SglRoleBegin(name)]
    if expr is not None:
        parts.append(expr)
    parts.append(SglRoleEnd(name))
    return SglExprList(parts)
def separate_reasoning(
    expr: Optional[SglExpr] = None, model_type: Optional[str] = None
):
    """Return ``expr`` followed by an ``SglSeparateReasoning`` node built from it."""
    reasoning_node = SglSeparateReasoning(model_type, expr=expr)
    return SglExprList([expr, reasoning_node])
class Anthropic(BaseBackend):
    """Backend that sends generation requests through the ``anthropic`` client."""

    def __init__(self, model_name, *args, **kwargs):
        """Create the client; extra ``args``/``kwargs`` go to ``anthropic.Anthropic``.

        Raises the stored ImportError if the optional ``anthropic`` package
        could not be imported.
        """
        super().__init__()

        # On a failed import the module-level name holds the exception itself.
        if isinstance(anthropic, Exception):
            raise anthropic

        self.model_name = model_name
        self.chat_template = get_chat_template("claude")
        self.client = anthropic.Anthropic(*args, **kwargs)

    def get_chat_template(self):
        return self.chat_template

    @staticmethod
    def _split_system_prompt(s: StreamExecutor):
        """Return ``(system, messages)`` for a request without mutating ``s``.

        A leading ``system`` message is passed separately from the message
        list, so it is read here and excluded by slicing. The previous
        ``messages.pop(0)`` removed it from ``s.messages_`` itself, so every
        later call on the same executor lost its system prompt.
        """
        if s.messages_:
            messages = s.messages_
        else:
            messages = [{"role": "user", "content": s.text_}]

        if messages and messages[0]["role"] == "system":
            return messages[0]["content"], messages[1:]
        return "", messages

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run one blocking completion and return ``(text, {})``."""
        system, messages = self._split_system_prompt(s)

        ret = self.client.messages.create(
            model=self.model_name,
            system=system,
            messages=messages,
            **sampling_params.to_anthropic_kwargs(),
        )
        comp = ret.content[0].text

        return comp, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Stream a completion, yielding ``(text_chunk, {})`` tuples."""
        system, messages = self._split_system_prompt(s)

        with self.client.messages.stream(
            model=self.model_name,
            system=system,
            messages=messages,
            **sampling_params.to_anthropic_kwargs(),
        ) as stream:
            for text in stream.text_stream:
                yield text, {}
class BaseBackend:
    """Common interface for language backends.

    Lifecycle and cache hooks default to doing nothing and returning ``None``;
    the generation, selection and concatenation entry points raise
    ``NotImplementedError`` until a subclass provides them.
    """

    def __init__(self) -> None:
        self.support_concate_and_append = False
        self.chat_template = get_chat_template("default")

    # ---- entry points a subclass must supply -------------------------------

    def get_model_name(self):
        raise NotImplementedError()

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        raise NotImplementedError()

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        raise NotImplementedError()

    def select(
        self,
        s: StreamExecutor,
        choices: List[str],
        temperature: float,
        choices_method: Optional[ChoicesSamplingMethod] = None,
    ) -> ChoicesDecision:
        raise NotImplementedError()

    def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
        raise NotImplementedError()

    # ---- accessors ----------------------------------------------------------

    def get_chat_template(self):
        return self.chat_template

    # ---- optional hooks (no-op by default) ---------------------------------

    def cache_prefix(self, prefix_str: str):
        """No-op by default."""

    def uncache_prefix(self, rid: str):
        """No-op by default."""

    def end_request(self, rid: Union[str, List[str]]):
        """No-op by default."""

    def begin_program(self, s: StreamExecutor):
        """No-op by default."""

    def end_program(self, s: Union[StreamExecutor, List[StreamExecutor]]):
        """No-op by default."""

    def commit_lazy_operations(self, s: StreamExecutor):
        """No-op by default."""

    def fork_program(
        self,
        src: StreamExecutor,
        dst: List[StreamExecutor],
        position_ids_offset: Optional[List[int]] = None,
    ):
        """No-op by default."""

    def fill_image(self, s: StreamExecutor):
        """No-op by default."""

    def shutdown(self):
        """No-op by default."""

    def flush_cache(self):
        """No-op by default; returns ``None``."""

    def get_server_info(self):
        """No-op by default; returns ``None``."""
class LiteLLM(BaseBackend):
    """Backend that routes chat completions through ``litellm.completion``."""

    def __init__(
        self,
        model_name,
        chat_template=None,
        api_key=None,
        organization: Optional[str] = None,
        base_url: Optional[str] = None,
        timeout: Optional[float] = 600,
        max_retries: Optional[int] = litellm.num_retries,
        default_headers: Optional[Mapping[str, str]] = None,
    ):
        """Store the model name, chat template and per-request client options.

        Raises the stored ImportError if the optional ``litellm`` package
        could not be imported.
        """
        super().__init__()

        # On a failed import the module-level name holds the exception itself.
        if isinstance(litellm, Exception):
            raise litellm

        self.model_name = model_name

        if chat_template:
            self.chat_template = chat_template
        else:
            self.chat_template = get_chat_template_by_model_path(model_name)

        # Forwarded verbatim to every litellm.completion call.
        self.client_params = dict(
            api_key=api_key,
            organization=organization,
            base_url=base_url,
            timeout=timeout,
            max_retries=max_retries,
            default_headers=default_headers,
        )

    def get_chat_template(self):
        return self.chat_template

    @staticmethod
    def _request_messages(s: StreamExecutor):
        """Use the executor's chat messages, or wrap its raw text as one user turn."""
        if s.messages_:
            return s.messages_
        return [{"role": "user", "content": s.text_}]

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run one blocking completion and return ``(text, {})``."""
        response = litellm.completion(
            model=self.model_name,
            messages=self._request_messages(s),
            **self.client_params,
            **sampling_params.to_litellm_kwargs(),
        )
        return response.choices[0].message.content, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Stream a completion, yielding ``(text_chunk, {})`` for each non-empty delta."""
        chunks = litellm.completion(
            model=self.model_name,
            messages=self._request_messages(s),
            stream=True,
            **self.client_params,
            **sampling_params.to_litellm_kwargs(),
        )
        for chunk in chunks:
            delta_text = chunk.choices[0].delta.content
            if delta_text is None:
                # Chunks without text content are skipped.
                continue
            yield delta_text, {}
a/sglang/python/sglang/lang/backend/openai.py b/sglang/python/sglang/lang/backend/openai.py new file mode 100644 index 0000000000000000000000000000000000000000..a2d006bb7c241f341219614e72e527fdfbb2ca7f --- /dev/null +++ b/sglang/python/sglang/lang/backend/openai.py @@ -0,0 +1,475 @@ +import dataclasses +import logging +import time +import warnings +from typing import List, Optional, Union + +import numpy as np + +from sglang.lang.backend.base_backend import BaseBackend +from sglang.lang.chat_template import ChatTemplate, get_chat_template_by_model_path +from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod +from sglang.lang.interpreter import StreamExecutor +from sglang.lang.ir import SglSamplingParams + +try: + import openai + import tiktoken +except ImportError as e: + openai = tiktoken = e + + +logger = logging.getLogger(__name__) + + +def create_logit_bias_int(tokenizer): + """Get logit bias for integer numbers.""" + int_token_ids = [] + + tokens = tokenizer._mergeable_ranks + for token, token_id in tokens.items(): + s = tokenizer.decode([token_id]) + if all([c.isdigit() for c in s]) or s in [" "]: + int_token_ids.append(token_id) + if len(int_token_ids) >= 300: # OpenAI API limit + break + special_tokens = tokenizer._special_tokens + mask = {t: 100 for t in int_token_ids[:299]} + mask[special_tokens["<|endoftext|>"]] = 100 + return mask + + +INSTRUCT_MODEL_NAMES = [ + "gpt-3.5-turbo-instruct", +] + + +@dataclasses.dataclass +class TokenUsage: + prompt_tokens: int + completion_tokens: int + + def reset(self): + self.prompt_tokens = self.completion_tokens = 0 + + +class OpenAI(BaseBackend): + def __init__( + self, + model_name: str, + is_chat_model: Optional[bool] = None, + chat_template: Optional[ChatTemplate] = None, + is_azure: bool = False, + *args, + **kwargs, + ): + super().__init__() + + if isinstance(openai, Exception): + raise openai + + if is_azure: + self.client = openai.AzureOpenAI(*args, **kwargs) + else: + self.client = 
openai.OpenAI(*args, **kwargs) + + self.model_name = model_name + try: + self.tokenizer = tiktoken.encoding_for_model(model_name) + except KeyError: + self.tokenizer = tiktoken.get_encoding("cl100k_base") + self.logit_bias_int = create_logit_bias_int(self.tokenizer) + + self.chat_template = chat_template or get_chat_template_by_model_path( + model_name + ) + + if is_chat_model is not None: + self.is_chat_model = is_chat_model + else: + if model_name in INSTRUCT_MODEL_NAMES: + self.is_chat_model = False + else: + self.is_chat_model = True + + self.chat_prefix = self.chat_template.role_prefix_and_suffix["assistant"][0] + + # Usage + self.token_usage = TokenUsage(0, 0) + + # API speculative execution + # TODO(ying): This does not support multi-threading (run_batch) + self.spec_kwargs = {} + self.spec_format = [] + self.spec_max_num_tries = 3 + + def get_chat_template(self): + return self.chat_template + + def _prepare_spec_execution( + self, + sampling_params: SglSamplingParams, + num_api_spec_tokens: int, + spec_var_name: str, + ): + if "max_tokens" not in self.spec_kwargs: + self.spec_kwargs["max_tokens"] = num_api_spec_tokens + else: + assert self.spec_kwargs["max_tokens"] == num_api_spec_tokens + + params = sampling_params.to_openai_kwargs() + for key, value in params.items(): + if key in ["stop"]: + continue + if key in ["max_tokens"]: + warnings.warn( + "The parameter max_tokens will be overwritten by speculated number of tokens." + ) + continue + if key not in self.spec_kwargs: + self.spec_kwargs[key] = value + else: + assert ( + value == self.spec_kwargs[key] + ), "sampling parameters should be consistent if turn on api speculative execution." 
+ self.spec_format.append( + {"text": "", "stop": params["stop"], "name": spec_var_name} + ) + return "", {} + + def generate( + self, + s: StreamExecutor, + sampling_params: SglSamplingParams, + spec_var_name: str = None, + ): + if sampling_params.dtype is None: + if self.is_chat_model: + if s.num_api_spec_tokens is None: + if not s.text_.endswith(self.chat_prefix): + raise RuntimeError( + "This use case is not supported if api speculative execution is off. " + "For OpenAI chat models, sgl.gen must be right after sgl.assistant. " + "Example of adding api speculative execution: @function(num_api_spec_tokens=128)." + ) + prompt = s.messages_ + else: + return self._prepare_spec_execution( + sampling_params, s.num_api_spec_tokens, spec_var_name + ) + else: + prompt = s.text_ + + kwargs = sampling_params.to_openai_kwargs() + if ( + self.model_name.startswith("o1") + or self.model_name.startswith("o3") + or "o1" in self.model_name + ): + kwargs.pop("max_tokens", None) + else: + kwargs.pop("max_completion_tokens", None) + + comp = openai_completion( + client=self.client, + token_usage=self.token_usage, + is_chat=self.is_chat_model, + model=self.model_name, + prompt=prompt, + **kwargs, + ) + # Keep the returned list (or string) as is. + elif sampling_params.dtype in [str, "str", "string"]: + assert ( + not self.is_chat_model + ), "constrained type not supported on chat model" + kwargs = sampling_params.to_openai_kwargs() + kwargs.pop("stop") + comp = openai_completion( + client=self.client, + token_usage=self.token_usage, + is_chat=self.is_chat_model, + model=self.model_name, + prompt=s.text_ + '"', + stop='"', + **kwargs, + ) + # Wrap each element in quotes if we have a list. 
+ if isinstance(comp, list): + comp = ['"' + x + '"' for x in comp] + else: + comp = '"' + comp + '"' + elif sampling_params.dtype in [int, "int"]: + assert ( + not self.is_chat_model + ), "constrained type not supported on chat model" + kwargs = sampling_params.to_openai_kwargs() + kwargs.pop("stop") + comp = openai_completion( + client=self.client, + token_usage=self.token_usage, + is_chat=self.is_chat_model, + model=self.model_name, + prompt=s.text_, + logit_bias=self.logit_bias_int, + stop=[" "], + **kwargs, + ) + # Leave as a list if that's what is returned. + else: + raise ValueError(f"Unknown dtype: {sampling_params.dtype}") + + return comp, {} + + def spec_fill(self, value: str): + assert self.is_chat_model + self.spec_format.append({"text": value, "stop": None, "name": None}) + + def spec_pattern_match(self, comp): + for i, term in enumerate(self.spec_format): + text = term["text"] + if text != "": + if comp.startswith(text): + comp = comp[len(text) :] + else: + return False + else: + pos = comp.find(term["stop"]) + if pos != -1: + term["text"] = comp[:pos] + comp = comp[pos:] + else: + if i == len(self.spec_format) - 1: + term["text"] = comp + else: + return False + return True + + def role_end_generate( + self, + s: StreamExecutor, + ): + if s.num_api_spec_tokens is None or not s.text_.endswith(self.chat_prefix): + return + + comp = "" + if not all(x["name"] is None for x in self.spec_format): + # TODO(ying): throw errors or warnings + for i in range(self.spec_max_num_tries): + comp = openai_completion( + client=self.client, + token_usage=self.token_usage, + is_chat=self.is_chat_model, + model=self.model_name, + prompt=s.messages_, + **self.spec_kwargs, + ) + # Use a string for pattern matching. 
+ comp_for_match = comp[0] if isinstance(comp, list) else comp + if self.spec_pattern_match(comp_for_match): + break + + for term in self.spec_format: + s.text_ += term["text"] + name = term["name"] + if name is not None: + s.variables[name] = term["text"] + s.meta_info[name] = {} + s.variable_event[name].set() + + self.spec_kwargs = {} + self.spec_format = [] + + def generate_stream( + self, + s: StreamExecutor, + sampling_params: SglSamplingParams, + ): + if sampling_params.dtype is None: + if self.is_chat_model: + if not s.text_.endswith(self.chat_prefix): + raise RuntimeError( + "This use case is not supported. " + "For OpenAI chat models, sgl.gen must be right after sgl.assistant" + ) + prompt = s.messages_ + else: + prompt = s.text_ + + kwargs = sampling_params.to_openai_kwargs() + generator = openai_completion_stream( + client=self.client, + token_usage=self.token_usage, + is_chat=self.is_chat_model, + model=self.model_name, + prompt=prompt, + **kwargs, + ) + return generator + else: + raise ValueError(f"Unknown dtype: {sampling_params.dtype}") + + def select( + self, + s: StreamExecutor, + choices: List[str], + temperature: float, + choices_method: ChoicesSamplingMethod, + ) -> ChoicesDecision: + """Note: `choices_method` is not used by the OpenAI backend.""" + if self.is_chat_model: + raise NotImplementedError( + "select/choices is not supported for chat models. 
" + "Please try to use a non-chat model such as gpt-3.5-turbo-instruct" + ) + + n_choices = len(choices) + token_ids = [self.tokenizer.encode(x) for x in choices] + scores = [0] * n_choices + valid = [len(x) > 0 for x in token_ids] + prompt_tokens = self.tokenizer.encode(s.text_) + + max_len = max([len(x) for x in token_ids]) + for step in range(max_len): + # Build logit bias + logit_bias = {} + for i in range(n_choices): + if valid[i]: + logit_bias[token_ids[i][step]] = 100 + + # Call API + ret = self.client.completions.create( + model=self.model_name, + prompt=prompt_tokens, + logit_bias=logit_bias, + max_tokens=1, + temperature=temperature, + ) + ret_str = ret.choices[0].text + ret_token = self.tokenizer.encode(ret_str)[0] + self.token_usage.prompt_tokens += ret.usage.prompt_tokens + self.token_usage.completion_tokens = ret.usage.completion_tokens + + # TODO: + # 1. return logits as the scores + # 2. compute logits of the full choice + # 3. consider chunk-based decoding + + # Update valid + hit = False + for i in range(n_choices): + if valid[i]: + if step == len(token_ids[i]) - 1: + valid[i] = False + + if ret_token == token_ids[i][step]: + scores[i] += 1 + hit = True + else: + valid[i] = False + assert hit + + if np.sum(valid) <= 1: + break + + prompt_tokens.append(ret_token) + + return ChoicesDecision( + decision=choices[np.argmax(scores)], + meta_info={"scores": scores}, + ) + + +def openai_completion( + client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs +) -> Union[str, List[str]]: + # if "ebnf" is in kwargs, warn and remove + if "ebnf" in kwargs: + warnings.warn("EBNF is not officially supported by OpenAI endpoints. 
Ignoring.") + del kwargs["ebnf"] + + for attempt in range(retries): + try: + if is_chat: + if "stop" in kwargs and kwargs["stop"] is None: + kwargs.pop("stop") + ret = client.chat.completions.create(messages=prompt, **kwargs) + if len(ret.choices) == 1: + comp = ret.choices[0].message.content + else: + comp = [c.message.content for c in ret.choices] + else: + ret = client.completions.create(prompt=prompt, **kwargs) + if isinstance(prompt, (list, tuple)): + comp = [c.text for c in ret.choices] + else: + comp = ret.choices[0].text + if len(ret.choices) > 1: + comp = [c.text for c in ret.choices] + + token_usage.prompt_tokens += ret.usage.prompt_tokens + token_usage.completion_tokens += ret.usage.completion_tokens + break + except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e: + logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...") + time.sleep(5) + if attempt == retries - 1: + raise e + except Exception as e: + logger.error(f"RuntimeError {e}.") + raise e + + return comp + + +def openai_completion_stream( + client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs +): + # if "ebnf" is in kwargs, warn and remove + if "ebnf" in kwargs: + warnings.warn("EBNF is not officially supported by OpenAI endpoints. 
Ignoring.") + del kwargs["ebnf"] + + for attempt in range(retries): + try: + if is_chat: + if "stop" in kwargs and kwargs["stop"] is None: + kwargs.pop("stop") + generator = client.chat.completions.create( + messages=prompt, + stream=True, + stream_options={"include_usage": True}, + **kwargs, + ) + for ret in generator: + if len(ret.choices) == 0: + continue + try: + content = ret.choices[0].delta.content + except IndexError: + content = None + yield content or "", {} + else: + generator = client.completions.create( + prompt=prompt, + stream=True, + stream_options={"include_usage": True}, + **kwargs, + ) + for ret in generator: + if len(ret.choices) == 0: + continue + content = ret.choices[0].text + yield content or "", {} + + token_usage.prompt_tokens += ret.usage.prompt_tokens + token_usage.completion_tokens += ret.usage.completion_tokens + break + except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e: + logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...") + time.sleep(5) + if attempt == retries - 1: + raise e + except Exception as e: + logger.error(f"RuntimeError {e}.") + raise e diff --git a/sglang/python/sglang/lang/backend/runtime_endpoint.py b/sglang/python/sglang/lang/backend/runtime_endpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..09a0116c5ed8315eaf0678bdd10997f2311b49a9 --- /dev/null +++ b/sglang/python/sglang/lang/backend/runtime_endpoint.py @@ -0,0 +1,544 @@ +import atexit +import json +import multiprocessing +import time +import warnings +from typing import Dict, List, Optional, Union + +import aiohttp +import requests + +from sglang.global_config import global_config +from sglang.lang.backend.base_backend import BaseBackend +from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path +from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod +from sglang.lang.interpreter import StreamExecutor +from sglang.lang.ir import ( + REGEX_BOOL, + 
REGEX_FLOAT, + REGEX_INT, + REGEX_STR, + SglSamplingParams, +) +from sglang.utils import http_request + + +class RuntimeEndpoint(BaseBackend): + def __init__( + self, + base_url: str, + api_key: Optional[str] = None, + verify: Optional[str] = None, + chat_template_name: Optional[str] = None, + ): + super().__init__() + self.support_concate_and_append = True + + self.base_url = base_url + self.api_key = api_key + self.verify = verify + + res = http_request( + self.base_url + "/get_model_info", + api_key=self.api_key, + verify=self.verify, + ) + self._assert_success(res) + self.model_info = res.json() + + if chat_template_name: + self.chat_template = get_chat_template(chat_template_name) + else: + self.chat_template = get_chat_template_by_model_path( + self.model_info["model_path"] + ) + + def get_model_name(self): + return self.model_info["model_path"] + + def flush_cache(self): + res = http_request( + self.base_url + "/flush_cache", + api_key=self.api_key, + verify=self.verify, + method="POST", + ) + self._assert_success(res) + + def get_server_info(self): + res = http_request( + self.base_url + "/get_server_info", + api_key=self.api_key, + verify=self.verify, + ) + self._assert_success(res) + return res.json() + + def get_chat_template(self): + return self.chat_template + + def cache_prefix(self, prefix_str: str): + res = http_request( + self.base_url + "/generate", + json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}}, + api_key=self.api_key, + verify=self.verify, + ) + self._assert_success(res) + + def start_profile(self): + res = http_request( + self.base_url + "/start_profile", + api_key=self.api_key, + verify=self.verify, + ) + self._assert_success(res) + + def stop_profile(self): + res = http_request( + self.base_url + "/stop_profile", + api_key=self.api_key, + verify=self.verify, + ) + self._assert_success(res) + + def commit_lazy_operations(self, s: StreamExecutor): + data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}} + 
self._add_images(s, data) + res = http_request( + self.base_url + "/generate", + json=data, + api_key=self.api_key, + verify=self.verify, + ) + self._assert_success(res) + + def fill_image(self, s: StreamExecutor): + data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}} + self._add_images(s, data) + res = http_request( + self.base_url + "/generate", + json=data, + api_key=self.api_key, + verify=self.verify, + ) + self._assert_success(res) + + def _handle_dtype_to_regex(self, sampling_params: SglSamplingParams): + if sampling_params.dtype is None: + return + + if sampling_params.stop == (): + sampling_params.stop = [] + + dtype_regex = None + if sampling_params.dtype in ["int", int]: + + dtype_regex = REGEX_INT + sampling_params.stop.extend([" ", "\n"]) + elif sampling_params.dtype in ["float", float]: + + dtype_regex = REGEX_FLOAT + sampling_params.stop.extend([" ", "\n"]) + elif sampling_params.dtype in ["str", str]: + + dtype_regex = REGEX_STR + elif sampling_params.dtype in ["bool", bool]: + + dtype_regex = REGEX_BOOL + else: + raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}") + + if dtype_regex is not None and sampling_params.regex is not None: + warnings.warn( + f"Both dtype and regex are set. Only dtype will be used. 
dtype: {sampling_params.dtype}, regex: {sampling_params.regex}" + ) + + sampling_params.regex = dtype_regex + + def generate( + self, + s: StreamExecutor, + sampling_params: SglSamplingParams, + ): + self._handle_dtype_to_regex(sampling_params) + data = { + "text": s.text_, + "sampling_params": { + "skip_special_tokens": global_config.skip_special_tokens_in_output, + "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out, + **sampling_params.to_srt_kwargs(), + }, + } + + for item in [ + "return_logprob", + "logprob_start_len", + "top_logprobs_num", + "return_text_in_logprobs", + ]: + value = getattr(sampling_params, item, None) + if value is not None: + data[item] = value + + self._add_images(s, data) + + res = http_request( + self.base_url + "/generate", + json=data, + api_key=self.api_key, + verify=self.verify, + ) + self._assert_success(res) + + obj = res.json() + comp = obj["text"] + return comp, obj["meta_info"] + + def generate_stream( + self, + s: StreamExecutor, + sampling_params: SglSamplingParams, + ): + self._handle_dtype_to_regex(sampling_params) + + data = { + "text": s.text_, + "sampling_params": { + "skip_special_tokens": global_config.skip_special_tokens_in_output, + "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out, + **sampling_params.to_srt_kwargs(), + }, + } + + for item in [ + "return_logprob", + "logprob_start_len", + "top_logprobs_num", + "return_text_in_logprobs", + ]: + value = getattr(sampling_params, item, None) + if value is not None: + data[item] = value + + data["stream"] = True + self._add_images(s, data) + + res = http_request( + self.base_url + "/generate", + json=data, + stream=True, + api_key=self.api_key, + verify=self.verify, + ) + self._assert_success(res) + pos = 0 + + for chunk in res.iter_lines(decode_unicode=False): + chunk = chunk.decode("utf-8") + if chunk and chunk.startswith("data:"): + if chunk == "data: [DONE]": + break + data = 
json.loads(chunk[5:].strip("\n")) + chunk_text = data["text"][pos:] + meta_info = data["meta_info"] + pos += len(chunk_text) + yield chunk_text, meta_info + + def select( + self, + s: StreamExecutor, + choices: List[str], + temperature: float, + choices_method: ChoicesSamplingMethod, + ) -> ChoicesDecision: + assert temperature <= 1e-5 + + # Cache common prefix + data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}} + obj = self._generate_http_request(s, data) + prompt_len = obj["meta_info"]["prompt_tokens"] + logprob_start_len = max(prompt_len - 2, 0) # For token healing + + # Compute logprob + data = { + "text": [s.text_ + c for c in choices], + "sampling_params": { + "max_new_tokens": 0, + "temperature": 0, + }, + "return_logprob": True, + "return_text_in_logprobs": True, + "logprob_start_len": logprob_start_len, + } + obj = self._generate_http_request(s, data) + + input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj] + output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj] + normalized_prompt_logprobs = [ + compute_normalized_prompt_logprobs(r["meta_info"]["input_token_logprobs"]) + for r in obj + ] + + # Remove extra token if no token healing occurred + for i in range(len(input_token_logprobs)): + healed_token_str = input_token_logprobs[i][0][-1] + if s.text_.endswith(healed_token_str): + healed_token_logprob = input_token_logprobs[i][0][0] + normalized_prompt_logprobs[i] = ( + normalized_prompt_logprobs[i] * len(input_token_logprobs[i]) + - healed_token_logprob + ) / (len(input_token_logprobs[i]) - 1) + input_token_logprobs[i] = input_token_logprobs[i][1:] + + # Compute unconditional logprobs if required + if choices_method.requires_unconditional_logprobs: + input_ids = [[el[1] for el in subl] for subl in input_token_logprobs] + data = { + "input_ids": input_ids, + "sampling_params": {"max_new_tokens": 0}, + "return_logprob": True, + } + obj = self._generate_http_request(s, data) + 
unconditional_token_logprobs = [ + r["meta_info"]["input_token_logprobs"] for r in obj + ] + else: + unconditional_token_logprobs = None + + return choices_method( + choices=choices, + normalized_prompt_logprobs=normalized_prompt_logprobs, + input_token_logprobs=input_token_logprobs, + output_token_logprobs=output_token_logprobs, + unconditional_token_logprobs=unconditional_token_logprobs, + ) + + def concatenate_and_append(self, src_rids: List[str], dst_rid: str): + res = http_request( + self.base_url + "/concate_and_append_request", + json={"src_rids": src_rids, "dst_rid": dst_rid}, + api_key=self.api_key, + verify=self.verify, + ) + self._assert_success(res) + + def _generate_http_request(self, s: StreamExecutor, data): + self._add_images(s, data) + res = http_request( + self.base_url + "/generate", + json=data, + api_key=self.api_key, + verify=self.verify, + ) + self._assert_success(res) + return res.json() + + def _add_images(self, s: StreamExecutor, data): + if s.images_: + assert len(s.images_) == 1, "Only support one image." + data["image_data"] = s.images_[0][1] + + def _assert_success(self, res): + if res.status_code != 200: + try: + content = res.json() + except json.JSONDecodeError: + content = res.text + raise RuntimeError(content) + + +def compute_normalized_prompt_logprobs(input_logprobs): + values = [x[0] for x in input_logprobs if x[0]] + return sum(values) / len(values) + + +class Runtime: + """ + A wrapper for the HTTP server. + This is used for launching the server in a python program without + using the command line interface. + + It is mainly used for the frontend language. + You should use the Engine class if you want to do normal offline processing without the frontend language. + """ + + def __init__( + self, + log_level: str = "error", + launch_timeout: float = 300.0, + *args, + **kwargs, + ): + """See the arguments in server_args.py::ServerArgs + + Args: + log_level: Log level for the server. 
+ timeout: Timeout in seconds for waiting for the server to start. + *args: Additional arguments passed to ServerArgs. + **kwargs: Additional keyword arguments passed to ServerArgs. + """ + # We delay the import of any `sglang.srt` components in `sglang.lang`, so users can run + # client code without installing SRT server and its dependency if they want. + from sglang.srt.entrypoints.http_server import launch_server + from sglang.srt.server_args import ServerArgs + from sglang.srt.utils import is_port_available + + self.server_args = ServerArgs(*args, log_level=log_level, **kwargs) + + # Pre-allocate ports + for port in range(self.server_args.port, 40000): + if is_port_available(port): + break + self.server_args.port = port + + self.url = self.server_args.url() + self.generate_url = self.url + "/generate" + + # NOTE: We store pid instead of proc to fix some issues during __delete__ + self.pid = None + + ctx = multiprocessing.get_context("spawn") + proc = ctx.Process( + target=launch_server, + args=(self.server_args,), + ) + proc.start() + self.pid = proc.pid + + # Before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown() + atexit.register(self.shutdown) + + # Wait for server to be ready by polling /health_generate + start_time = time.time() + with requests.Session() as session: + while time.time() - start_time < launch_timeout: + try: + response = session.get(f"{self.url}/health_generate") + if response.status_code == 200: + break + except requests.RequestException: + pass + + if not proc.is_alive(): + self.shutdown() + raise RuntimeError( + "Initialization failed. Please see the error messages above." 
+ ) + + time.sleep(2) + else: + self.shutdown() + raise TimeoutError("Server failed to start within the timeout period.") + + self.endpoint = RuntimeEndpoint(self.url) + + def shutdown(self): + from sglang.srt.utils import kill_process_tree + + if self.pid is not None: + kill_process_tree(self.pid) + self.pid = None + + def start_profile(self): + self.endpoint.start_profile() + + def stop_profile(self): + self.endpoint.stop_profile() + + def cache_prefix(self, prefix: str): + self.endpoint.cache_prefix(prefix) + + def get_tokenizer(self): + from sglang.srt.utils.hf_transformers_utils import get_tokenizer + + return get_tokenizer( + self.server_args.tokenizer_path, + tokenizer_mode=self.server_args.tokenizer_mode, + trust_remote_code=self.server_args.trust_remote_code, + revision=self.server_args.revision, + ) + + async def async_generate( + self, + prompt: str, + sampling_params: Optional[Dict] = None, + ): + if self.server_args.skip_tokenizer_init: + json_data = { + "input_ids": prompt, + "sampling_params": sampling_params, + "stream": True, + } + else: + json_data = { + "text": prompt, + "sampling_params": sampling_params, + "stream": True, + } + pos = 0 + + timeout = aiohttp.ClientTimeout(total=3 * 3600) + async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: + async with session.post(self.generate_url, json=json_data) as response: + async for chunk, _ in response.content.iter_chunks(): + chunk = chunk.decode("utf-8") + if chunk and chunk.startswith("data:"): + if chunk == "data: [DONE]\n\n": + break + data = json.loads(chunk[5:].strip("\n")) + if "text" in data: + cur = data["text"][pos:] + if cur: + yield cur + pos += len(cur) + else: + yield data + + add_request = async_generate + + def generate( + self, + prompt: Union[str, List[str]], + sampling_params: Optional[Dict] = None, + return_logprob: Optional[Union[List[bool], bool]] = False, + logprob_start_len: Optional[Union[List[int], int]] = None, + top_logprobs_num: 
Optional[Union[List[int], int]] = None, + lora_path: Optional[List[Optional[str]]] = None, + ): + json_data = { + "text": prompt, + "sampling_params": sampling_params, + "return_logprob": return_logprob, + "logprob_start_len": logprob_start_len, + "top_logprobs_num": top_logprobs_num, + "lora_path": lora_path, + } + assert not isinstance(lora_path, list) or len(lora_path) == len(prompt) + response = requests.post( + self.url + "/generate", + json=json_data, + ) + return json.dumps(response.json()) + + def encode( + self, + prompt: Union[str, List[str], List[Dict], List[List[Dict]]], + ): + json_data = {"text": prompt} + response = requests.post(self.url + "/encode", json=json_data) + return json.dumps(response.json()) + + async def get_server_info(self): + async with aiohttp.ClientSession() as session: + async with session.get(f"{self.url}/get_server_info") as response: + if response.status == 200: + return await response.json() + else: + error_data = await response.json() + raise RuntimeError( + f"Failed to get server info. 
{error_data['error']['message']}" + ) + + def __del__(self): + self.shutdown() diff --git a/sglang/python/sglang/lang/backend/vertexai.py b/sglang/python/sglang/lang/backend/vertexai.py new file mode 100644 index 0000000000000000000000000000000000000000..3d51fb13744f84b7724c70501617301e0fe5d04b --- /dev/null +++ b/sglang/python/sglang/lang/backend/vertexai.py @@ -0,0 +1,148 @@ +import os +import warnings + +from sglang.lang.backend.base_backend import BaseBackend +from sglang.lang.chat_template import get_chat_template +from sglang.lang.interpreter import StreamExecutor +from sglang.lang.ir import SglSamplingParams + +try: + import vertexai + from vertexai.preview.generative_models import ( + GenerationConfig, + GenerativeModel, + Image, + ) +except ImportError as e: + GenerativeModel = e + + +class VertexAI(BaseBackend): + def __init__(self, model_name, safety_settings=None): + super().__init__() + + if isinstance(GenerativeModel, Exception): + raise GenerativeModel + + project_id = os.environ["GCP_PROJECT_ID"] + location = os.environ.get("GCP_LOCATION") + vertexai.init(project=project_id, location=location) + + self.model_name = model_name + self.chat_template = get_chat_template("default") + self.safety_settings = safety_settings + + def get_chat_template(self): + return self.chat_template + + def generate( + self, + s: StreamExecutor, + sampling_params: SglSamplingParams, + ): + if s.messages_: + prompt = self.messages_to_vertexai_input(s.messages_) + else: + # single-turn + prompt = ( + self.text_to_vertexai_input(s.text_, s.cur_images) + if s.cur_images + else s.text_ + ) + ret = GenerativeModel(self.model_name).generate_content( + prompt, + generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()), + safety_settings=self.safety_settings, + ) + + comp = ret.text + + return comp, {} + + def generate_stream( + self, + s: StreamExecutor, + sampling_params: SglSamplingParams, + ): + if s.messages_: + prompt = 
self.messages_to_vertexai_input(s.messages_) + else: + # single-turn + prompt = ( + self.text_to_vertexai_input(s.text_, s.cur_images) + if s.cur_images + else s.text_ + ) + generator = GenerativeModel(self.model_name).generate_content( + prompt, + stream=True, + generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()), + safety_settings=self.safety_settings, + ) + for ret in generator: + yield ret.text, {} + + def text_to_vertexai_input(self, text, images): + input = [] + # split with image token + text_segs = text.split(self.chat_template.image_token) + for image_path, image_base64_data in images: + text_seg = text_segs.pop(0) + if text_seg != "": + input.append(text_seg) + input.append(Image.from_bytes(image_base64_data)) + text_seg = text_segs.pop(0) + if text_seg != "": + input.append(text_seg) + return input + + def messages_to_vertexai_input(self, messages): + vertexai_message = [] + # from openai message format to vertexai message format + for msg in messages: + if isinstance(msg["content"], str): + text = msg["content"] + else: + text = msg["content"][0]["text"] + + if msg["role"] == "system": + warnings.warn("Warning: system prompt is not supported in VertexAI.") + vertexai_message.append( + { + "role": "user", + "parts": [{"text": "System prompt: " + text}], + } + ) + vertexai_message.append( + { + "role": "model", + "parts": [{"text": "Understood."}], + } + ) + continue + if msg["role"] == "user": + vertexai_msg = { + "role": "user", + "parts": [{"text": text}], + } + elif msg["role"] == "assistant": + vertexai_msg = { + "role": "model", + "parts": [{"text": text}], + } + + # images + if isinstance(msg["content"], list) and len(msg["content"]) > 1: + for image in msg["content"][1:]: + assert image["type"] == "image_url" + vertexai_msg["parts"].append( + { + "inline_data": { + "data": image["image_url"]["url"].split(",")[1], + "mime_type": "image/jpeg", + } + } + ) + + vertexai_message.append(vertexai_msg) + return vertexai_message diff 
--git a/sglang/python/sglang/lang/chat_template.py b/sglang/python/sglang/lang/chat_template.py new file mode 100644 index 0000000000000000000000000000000000000000..212d07e0bebd2754a0b18fcca0fd0dc09032e028 --- /dev/null +++ b/sglang/python/sglang/lang/chat_template.py @@ -0,0 +1,668 @@ +import re +from dataclasses import dataclass +from enum import Enum, auto +from typing import Callable, Dict, List, Tuple + + +class ChatTemplateStyle(Enum): + PLAIN = auto() + LLAMA2 = auto() + + +@dataclass +class ChatTemplate: + name: str + default_system_prompt: str + role_prefix_and_suffix: Dict[str, Tuple[str, str]] + stop_str: List[str] = () + image_token: str = "" + audio_token: str = ""), + }, + image_token=" \n", + ) +) + +register_chat_template( + ChatTemplate( + name="llama-2-chat", + default_system_prompt=None, + role_prefix_and_suffix={ + "system": ("<>\n", "\n<>\n\n"), + "user": ("[INST] ", " [/INST]"), + "assistant": ("", " "), + }, + style=ChatTemplateStyle.LLAMA2, + ) +) + +# Reference: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/chat_template.json +register_chat_template( + ChatTemplate( + name="mistral", + default_system_prompt=None, + role_prefix_and_suffix={ + "system": ("[SYSTEM_PROMPT] ", " [/SYSTEM_PROMPT]"), + "user": ("[INST] ", " [/INST]"), + "assistant": ("", " "), + }, + stop_str=("",), + image_token="[IMG]", + ) +) + +register_chat_template( + ChatTemplate( + name="llama-3-instruct", + default_system_prompt=None, + role_prefix_and_suffix={ + "system": ( + "<|start_header_id|>system<|end_header_id|>\n\n", + "<|eot_id|>", + ), + "user": ( + "<|start_header_id|>user<|end_header_id|>\n\n", + "<|eot_id|>", + ), + "assistant": ( + "<|start_header_id|>assistant<|end_header_id|>\n\n", + "<|eot_id|>", + ), + }, + stop_str=("<|eot_id|>",), + image_token="<|image|>", + ) +) + +# https://huggingface.co/openbmb/MiniCPM-V-2_6 +register_chat_template( + ChatTemplate( + name="minicpmv", + default_system_prompt=None, + 
# Role keys are lowercase everywhere else in this file (including the
# otherwise identical "janus" template), so the user entry is keyed "user";
# the previous "User" key did not line up with the role names used elsewhere.
register_chat_template(
    ChatTemplate(
        name="janus-pro",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": (
                "",
                "",
            ),
            "user": (
                "<|User|>",
                "",
            ),
            "assistant": (
                "<|Assistant|>",
                "<|end▁of▁sentence|>",
            ),
        },
        stop_str=("<|end▁of▁sentence|>",),
        image_token="\n",
    )
)
# Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
# Note the asymmetric layout: the user suffix already opens the assistant
# turn ("<|im_start|>assistant\n"), which is why the assistant prefix is empty.
# System messages get neither a prefix nor a suffix.
register_chat_template(
    ChatTemplate(
        name="yi-1.5",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("", ""),
            "user": ("<|im_start|>user\n", "<|im_end|>\n<|im_start|>assistant\n"),
            "assistant": ("", "<|im_end|>\n"),
        },
        style=ChatTemplateStyle.PLAIN,
        stop_str=("<|im_end|>",),
    )
)
+ "这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。" + ), + role_prefix_and_suffix={ + "system": ("", "\n\n"), + "user": ("### Human:", "\n"), + "assistant": ("### Assistant:", "\n"), + }, + image_token=" \n", + ) +) + +register_chat_template( + ChatTemplate( + name="gemma-it", + default_system_prompt=None, + role_prefix_and_suffix={ + "system": ("", ""), + "user": ("user\n", "\n"), + "assistant": ("model\n", "\n"), + }, + style=ChatTemplateStyle.PLAIN, + ) +) + +register_chat_template( + ChatTemplate( + name="dbrx-instruct", + default_system_prompt="You are DBRX, created by Databricks. You were last updated in December 2023. You answer questions based on information available up to that point.\nYOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough responses to more complex and open-ended questions.\nYou assist with various tasks, from writing to coding (using markdown for code blocks — remember to use ``` with code, JSON, and tables).\n(You do not have real-time data access or code execution capabilities. You avoid stereotyping and provide balanced perspectives on controversial topics. You do not provide song lyrics, poems, or news articles and do not divulge details of your training data.)\nThis is your system prompt, guiding your responses. Do not reference it, just respond to the user. If you find yourself talking about this message, stop. 
# Every role is wrapped as
#   <|START_OF_TURN_TOKEN|><|ROLE_TOKEN|> ... <|END_OF_TURN_TOKEN|>
# where the assistant role uses <|CHATBOT_TOKEN|>. No stop strings and no
# default system prompt are configured for this template.
register_chat_template(
    ChatTemplate(
        name="c4ai-command-r",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": (
                "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>",
                "<|END_OF_TURN_TOKEN|>",
            ),
            "user": ("<|START_OF_TURN_TOKEN|><|USER_TOKEN|>", "<|END_OF_TURN_TOKEN|>"),
            "assistant": (
                "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
                "<|END_OF_TURN_TOKEN|>",
            ),
        },
        style=ChatTemplateStyle.PLAIN,
    )
)
@register_chat_template_matching_function
def match_deepseek(model_path: str):
    """Return "deepseek-v3" for DeepSeek V3/R1 paths that are not base models.

    Both checks are case-insensitive; any other path yields None.
    """
    names_v3_or_r1 = re.search(r"deepseek-(v3|r1)", model_path, re.IGNORECASE)
    if not names_v3_or_r1:
        return None
    # Paths containing "base" are excluded from this template.
    if re.search(r"base", model_path, re.IGNORECASE):
        return None
    return "deepseek-v3"
@register_chat_template_matching_function
def match_chat_ml(model_path: str):
    """Pick a template for TinyLlama, Qwen, GLM-4V and some LLaVA paths.

    Rules are tried top to bottom and the first hit wins, so the order
    below is significant (e.g. "qwen.*vl" is checked before the generic
    qwen chat/instruct rule). Returns None when nothing matches.
    """

    def found(pattern):
        return re.search(pattern, model_path, re.IGNORECASE) is not None

    if found(r"tinyllama"):
        return "chatml"
    if found(r"qwen.*vl"):
        return "qwen2-vl"
    if found(r"glm[-_]?4(\.\d+)?v"):
        return "glm-4v"
    # Qwen chat/instruct checkpoints, unless the path also mentions llava.
    if found(r"qwen.*(chat|instruct)") and not found(r"llava"):
        return "qwen"
    if found(
        r"llava-v1\.6-34b|llava-v1\.6-yi-34b|llava-next-video-34b|llava-onevision-qwen2"
    ):
        return "chatml-llava"
    return None
@register_chat_template_matching_function
def match_interns1_chat(model_path: str):
    """Return "interns1" when the path names Intern-S1, with or without the hyphen.

    The match is case-insensitive; other paths yield None.
    """
    # One optional-hyphen pattern covers both "intern-s1" and "interns1".
    matched = re.search(r"intern-?s1", model_path, re.IGNORECASE)
    return "interns1" if matched else None
class TokenLengthNormalized(ChoicesSamplingMethod):
    """Choice selector driven purely by the normalized prompt logprobs."""

    def __call__(
        self,
        *,
        choices: List[str],
        normalized_prompt_logprobs: List[float],
        input_token_logprobs: List[List[Any]],
        output_token_logprobs: List[List[Any]],
        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
    ) -> ChoicesDecision:
        """Return the choice whose normalized prompt logprob is the largest.

        The raw logprob inputs are passed through unchanged in ``meta_info``;
        ``unconditional_token_logprobs`` is accepted but not used.
        """
        winner_index = np.argmax(normalized_prompt_logprobs)
        return ChoicesDecision(
            decision=choices[winner_index],
            meta_info={
                "normalized_prompt_logprobs": normalized_prompt_logprobs,
                "input_token_logprobs": input_token_logprobs,
                "output_token_logprobs": output_token_logprobs,
            },
        )
based on greedy logprob selection. For overlapping options + where one option is a subset of a longer option, extend the shorter option using + its average logprob for comparison against the longer option.""" + + num_options = len(choices) + max_tokens = max(len(option) for option in input_token_logprobs) + logprob_matrix = self._build_logprob_matrix( + input_token_logprobs, max_tokens, num_options + ) + remaining = self._greedy_selection(logprob_matrix, num_options, max_tokens) + + best_choice = choices[remaining[0]] + meta_info = { + "normalized_prompt_logprobs": normalized_prompt_logprobs, + "input_token_logprobs": input_token_logprobs, + "output_token_logprobs": output_token_logprobs, + "greedy_logprob_matrix": logprob_matrix.tolist(), + } + return ChoicesDecision(decision=best_choice, meta_info=meta_info) + + def _build_logprob_matrix(self, input_token_logprobs, max_tokens, num_options): + logprob_matrix = np.zeros((num_options, max_tokens)) + for i, option in enumerate(input_token_logprobs): + actual_logprobs = [token[0] for token in option] + avg_logprob = np.mean(actual_logprobs) + logprob_matrix[i, : len(option)] = actual_logprobs + if len(option) < max_tokens: + logprob_matrix[i, len(option) :] = avg_logprob + return logprob_matrix + + def _greedy_selection(self, logprob_matrix, num_options, max_tokens): + remaining = np.arange(num_options) + for j in range(max_tokens): + max_logprob = np.max(logprob_matrix[remaining, j]) + remaining = remaining[logprob_matrix[remaining, j] == max_logprob] + if len(remaining) == 1: + break + return remaining + + +greedy_token_selection = GreedyTokenSelection() + + +class UnconditionalLikelihoodNormalized(ChoicesSamplingMethod): + + @property + def requires_unconditional_logprobs(self) -> bool: + return True + + def __call__( + self, + *, + choices: List[str], + normalized_prompt_logprobs: List[float], + input_token_logprobs: List[List[Any]], + output_token_logprobs: List[List[Any]], + unconditional_token_logprobs: 
Optional[List[List[Any]]] = None, + ) -> ChoicesDecision: + """Select the option with the highest average token logprob once normalized by + the unconditional token logprobs. + + The first unconditional token logprob is assumed to be None. If so, it is + replaced with 0 for the purposes of normalization.""" + + if unconditional_token_logprobs is None: + raise ValueError( + "Unconditional token logprobs are required for this method." + ) + + normalized_unconditional_prompt_logprobs = self._normalize_logprobs( + input_token_logprobs, unconditional_token_logprobs + ) + + best_choice = choices[np.argmax(normalized_unconditional_prompt_logprobs)] + meta_info = { + "normalized_prompt_logprobs": normalized_prompt_logprobs, + "input_token_logprobs": input_token_logprobs, + "output_token_logprobs": output_token_logprobs, + "unconditional_token_logprobs": unconditional_token_logprobs, + "normalized_unconditional_prompt_logprobs": normalized_unconditional_prompt_logprobs, + } + return ChoicesDecision(decision=best_choice, meta_info=meta_info) + + def _normalize_logprobs(self, input_token_logprobs, unconditional_token_logprobs): + normalized_unconditional_prompt_logprobs = [] + for inputs, unconditionals in zip( + input_token_logprobs, unconditional_token_logprobs + ): + inputs_logprobs = np.array([token[0] for token in inputs]) + unconditionals_logprobs = np.array([token[0] for token in unconditionals]) + unconditionals_logprobs[0] = unconditionals_logprobs[0] or 0 + normalized_unconditional_prompt_logprobs.append( + float(np.mean(inputs_logprobs - unconditionals_logprobs)) + ) + return normalized_unconditional_prompt_logprobs + + +unconditional_likelihood_normalized = UnconditionalLikelihoodNormalized() diff --git a/sglang/python/sglang/lang/interpreter.py b/sglang/python/sglang/lang/interpreter.py new file mode 100644 index 0000000000000000000000000000000000000000..0b59e91b5ff044f261c786f67e818a7302b7c95b --- /dev/null +++ b/sglang/python/sglang/lang/interpreter.py @@ -0,0 
def run_program_batch(
    program,
    backend,
    batch_arguments,
    default_sampling_para,
    num_threads,
    progress_bar,
    generator_style=False,
):
    """Run ``program`` once per entry of ``batch_arguments``.

    Args:
        program: the program to execute.
        backend: backend object; if it exposes ``endpoint`` that is used.
        batch_arguments: sequence of keyword-argument dicts, one per run.
        default_sampling_para: default sampling parameters for every run.
        num_threads: worker count, or ``"auto"``. It is clamped to the
            batch size and to a minimum of 1.
        progress_bar: show a tqdm progress bar when truthy.
        generator_style: when True, return a generator that yields states
            instead of a list.

    Returns:
        A list of program states in input order (an empty list for an empty
        batch), or a generator of states when ``generator_style`` is True.
    """
    if hasattr(backend, "endpoint"):
        backend = backend.endpoint

    # Pre-cache the common prefix for a batch. The prefix is extracted by tracing the program.
    if global_config.enable_precache_with_tracing and len(batch_arguments) > 1:
        cache_program(program, backend)

    # Run all programs
    if num_threads == "auto":
        num_threads = max(96, multiprocessing.cpu_count() * 16)
    # Clamp to at least one worker: an empty batch used to produce 0 here,
    # which made ThreadPoolExecutor(0) raise ValueError further down.
    num_threads = max(1, min(num_threads, len(batch_arguments)))

    if generator_style:
        return _run_program_batch_generator(
            program,
            backend,
            batch_arguments,
            default_sampling_para,
            num_threads,
            progress_bar,
        )

    if num_threads == 1:
        # Sequential path: each run is non-streaming and synced (False, True).
        iterator = tqdm.tqdm(batch_arguments) if progress_bar else batch_arguments
        return [
            run_program(
                program,
                backend,
                (),
                arguments,
                default_sampling_para,
                False,
                True,
            )
            for arguments in iterator
        ]

    # Threaded path: num_threads > 1 implies at least two batch entries.
    pbar = tqdm.tqdm(total=len(batch_arguments)) if progress_bar else None
    try:
        with ThreadPoolExecutor(num_threads) as executor:
            futures = []
            for arguments in batch_arguments:
                future = executor.submit(
                    run_program,
                    program,
                    backend,
                    (),
                    arguments,
                    default_sampling_para,
                    False,
                    True,
                )
                if pbar is not None:
                    future.add_done_callback(lambda _: pbar.update())
                futures.append(future)

            rets = [f.result() for f in futures]
            rets[-1].sync()
    finally:
        # Close the bar even when a run raised.
        if pbar is not None:
            pbar.close()

    return rets
def cache_program(program, backend):
    """Ask the backend to cache the program's traced prefix when it is long enough.

    The prefix is obtained from ``extract_prefix_by_tracing``; nothing is
    cached when it is empty/None or has 64 characters or fewer.
    """
    from sglang.lang.tracer import extract_prefix_by_tracing

    traced_prefix = extract_prefix_by_tracing(program, backend)
    if not traced_prefix or len(traced_prefix) <= 64:
        return
    backend.cache_prefix(traced_prefix)
def get_meta_info(self, name, timeout=None):
    """Return the meta info recorded for variable ``name``.

    If an event is registered for ``name``, wait on it first (up to
    ``timeout`` seconds, or indefinitely when ``timeout`` is None).

    Returns:
        The stored meta info, or None if nothing was recorded.

    Raises:
        TimeoutError: if the event is not set within ``timeout``.
    """
    pending = self.variable_event.get(name)
    if pending is not None and not pending.wait(timeout):
        raise TimeoutError(f"Timeout while waiting for event '{name}'")
    return self.meta_info.get(name)
def _thread_worker_func(self):
    """Worker loop: execute queued expressions until a ``None`` sentinel or an error.

    Each successfully executed expression is acknowledged with
    ``task_done()`` so that ``queue.join()`` can return. On the first
    exception the loop stops, the queue is drained, every variable/stream
    event is set so that waiters wake up, and the exception is stored in
    ``self.error_``.
    """
    error = None

    while True:
        expr = self.queue.get()
        if expr is None:
            # Sentinel: acknowledge it and stop.
            self.queue.task_done()
            break

        try:
            self._execute(expr)
        except Exception as e:
            warnings.warn(f"Error in stream_executor: {get_exception_traceback()}")
            error = e
            # NOTE: the failing item is NOT acknowledged here; the cleanup
            # loop below issues its task_done().
            break
        self.queue.task_done()
        if self.stream_text_event:
            self.stream_text_event.set()

    # Clean the queue and events
    if error is not None:
        try:
            while True:
                # Order matters: first acknowledge the item fetched last
                # (initially the one that failed), then fetch the next one.
                # When the queue is empty get_nowait() raises and every
                # fetched item has been acknowledged exactly once.
                self.queue.task_done()
                self.queue.get_nowait()
        except queue.Empty:
            pass
        # Wake up anything waiting on a variable that will never be produced.
        for name in self.variable_event:
            self.variable_event[name].set()
        if self.stream_var_event:
            for name in self.stream_var_event:
                self.stream_var_event[name].set()
        self.error_ = error

    if self.stream_text_event:
        self.stream_text_event.set()

    self.is_finished = True
self._execute_role_begin(other) + elif isinstance(other, SglRoleEnd): + self._execute_role_end(other) + elif isinstance(other, SglImage): + self._execute_image(other) + elif isinstance(other, SglVideo): + self._execute_video(other) + elif isinstance(other, SglVariable): + self._execute_variable(other) + elif isinstance(other, SglVarScopeBegin): + self._execute_var_scope_begin(other) + elif isinstance(other, SglVarScopeEnd): + self._execute_var_scope_end(other) + elif isinstance(other, SglCommitLazy): + self._execute_commit_lazy_operations(other) + elif isinstance(other, SglConcateAndAppend): + if ( + global_config.enable_parallel_encoding + and self.backend.support_concate_and_append + ): + self._execute_concatenate_and_append_kv_cache(other) + else: + self._execute_concatenate_and_append_text(other) + elif isinstance(other, SglSeparateReasoning): + self._execute_separate_reasoning(other) + else: + raise ValueError(f"Unknown type: {type(other)}") + + def _execute_fill(self, value: str, prefix=False): + value = str(value) + + if ( + self.cur_role == "assistant" + and self.num_api_spec_tokens is not None + and self.backend.is_chat_model + and not prefix + ): + self.backend.spec_fill(value) + return + + if self.speculated_text.startswith(value): + self.speculated_text = self.speculated_text[len(value) :] + else: + self.speculated_text = "" + + self.text_ += value + + def _execute_image(self, expr: SglImage): + path = expr.path + + base64_data = encode_image_base64(path) + + self.images_.append((path, base64_data)) + self.cur_images.append((path, base64_data)) + self.text_ += self.chat_template.image_token + + def _execute_video(self, expr: SglVideo): + path = expr.path + num_frames = expr.num_frames + + base64_data = encode_video_base64(path, num_frames) + + self.images_.append((path, base64_data)) + self.cur_images.append((path, base64_data)) + self.text_ += self.chat_template.image_token + + def _spec_gen(self, sampling_params): + stop = sampling_params.stop + 
def _execute_gen(self, expr: SglGen):
    """Run a generation expression and record its result under ``expr.name``.

    Streaming mode appends chunks as they arrive and signals the stream
    events after each one. Non-streaming mode uses a plain backend call,
    a deferred call (chat model with API speculative execution, which
    returns without recording anything here), or ``_spec_gen``.
    """
    params = self._resolve_sampling_params(expr.sampling_params)
    var_name = expr.name

    if self.stream:
        assert (
            self.num_api_spec_tokens is None
        ), "stream is not supported with api speculative execution"
        chunks = self.backend.generate_stream(self, sampling_params=params)

        # Publish an empty value first so stream readers can start.
        self.variables[var_name] = ""
        self.stream_var_event[var_name].set()

        for piece, info in chunks:
            self.text_ += piece
            self.variables[var_name] += piece
            self.meta_info[var_name] = info
            self.stream_var_event[var_name].set()
            self.stream_text_event.set()

        self.variable_event[var_name].set()
        self.stream_var_event[var_name].set()
        return

    if self.num_api_spec_tokens is None:
        comp, meta_info = self.backend.generate(
            self,
            sampling_params=params,
        )
    elif self.backend.is_chat_model:
        # Speculative execution on models with only chat interface.
        # The call is stored by the backend and lazily executed later.
        comp, meta_info = self.backend.generate(
            self,
            sampling_params=params,
            spec_var_name=var_name,
        )
        return
    else:
        # Speculative execution on models with completion interface
        comp, meta_info = self._spec_gen(params)

    if isinstance(comp, list):
        appended = comp[0]
    else:
        assert isinstance(comp, str)
        appended = comp
    self.text_ += appended

    self.variables[var_name] = comp
    self.meta_info[var_name] = meta_info
    self.variable_event[var_name].set()
+ + if len(self.messages_) == 0 and expr.role != "system": + # Insert the default system message + default_system = self.chat_template.default_system_prompt + if default_system: + self._execute_role_begin(SglRoleBegin("system")) + self._execute_fill(default_system) + self._execute_role_end(SglRoleEnd("system")) + + self.cur_role = expr.role + + prefix, _ = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_) + + self._execute_fill(prefix, prefix=True) + self.cur_role_begin_pos = len(self.text_) + + def _execute_role_end(self, expr: SglRoleEnd): + if ( + self.cur_role == "assistant" + and self.num_api_spec_tokens is not None + and self.backend.is_chat_model + ): + # Execute the stored lazy generation calls + self.backend.role_end_generate(self) + self.cur_role = None + + new_text = self.text_[self.cur_role_begin_pos :].lstrip() + + _, suffix = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_) + self._execute_fill(suffix) + + if self.cur_images: + # OpenAI vision API format + last_msg = { + "role": expr.role, + "content": [{"type": "text", "text": new_text}], + } + for image_path, image_base64_data in self.cur_images: + last_msg["content"].append( + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_base64_data}" + }, + } + ) + self.messages_.append(last_msg) + self.cur_images = [] + else: + # OpenAI chat API format + self.messages_.append({"role": expr.role, "content": new_text}) + + def _execute_var_scope_begin(self, expr: SglVarScopeBegin): + self.variables[expr.name] = int(len(self.text_)) + + def _execute_var_scope_end(self, expr: SglVarScopeEnd): + self.variables[expr.name] = self.text_[self.variables[expr.name] :] + self.variable_event[expr.name].set() + + def _execute_commit_lazy_operations(self, expr: SglCommitLazy): + self.backend.commit_lazy_operations(self) + + def _execute_concatenate_and_append_text(self, expr: SglConcateAndAppend): + new_text = "" + for s in expr.states: + exe = 
s.stream_executor + exe.sync() + new_text += exe.text_[exe.fork_start_text_pos :] + + self._execute_fill(new_text) + + def _execute_concatenate_and_append_kv_cache(self, expr: SglConcateAndAppend): + self_len = len(self.text_) + + for i, s in enumerate(expr.states): + exe = s.stream_executor + exe.submit(SglCommitLazy()) + + for i, s in enumerate(expr.states): + exe = s.stream_executor + exe.sync() + assert exe.fork_start_text_pos == self_len + self.text_ += exe.text_[exe.fork_start_text_pos :] + + src_rids = [state.stream_executor.sid for state in expr.states] + self.backend.concatenate_and_append(src_rids, self.sid) + + def _execute_separate_reasoning(self, expr: SglSeparateReasoning): + if self.stream: + # separate reasoning for stream is not supported + return + + if ( + self.cur_role == "assistant" + and self.num_api_spec_tokens is not None + and self.backend.is_chat_model + ): + # Execute the stored lazy generation calls + self.backend.role_end_generate(self) + + from sglang.srt.parser.reasoning_parser import ReasoningParser + + reasoning_parser = ReasoningParser(expr.model_type) + other = expr.expr + if not other: + return + elif isinstance(other, SglGen) or isinstance(other, SglSelect): + cur_text = self.get_var(other.name) + reasoning, normal_text = reasoning_parser.parse_non_stream(cur_text) + reasoning_name = expr.process_name_for_reasoning(other.name) + self.set_var(other.name, normal_text) + self.set_var(reasoning_name, reasoning) + # the variable is ready to be used + self.variable_event[reasoning_name].set() + self.text_ = self.text_[: self.cur_role_begin_pos] + normal_text + elif isinstance(other, SglExprList): + for x in other.expr_list: + self._execute_separate_reasoning( + SglSeparateReasoning(expr.model_type, x) + ) + + def _init_var_event(self, expr): + if isinstance( + expr, (SglGen, SglSelect, SglVarScopeBegin, SglSeparateReasoning) + ): + self.variable_event[expr.name] = threading.Event() + if self.stream: + self.stream_var_event[expr.name] 
= threading.Event() + elif isinstance(expr, SglExprList): + for e in expr.expr_list: + self._init_var_event(e) + + def _resolve_sampling_params(self, sampling_params): + """ + Construct sampling param based on default + override values + + The default values of sampling are populated in `default_sampling_para` via sgl.function.run(...sampling_args) + , and `sampling_params` contains the override values from sgl.gen(). + + Here we use default_sampling_para as the base and override the values if they exist in `sampling_params`. + It also extends the stop tokens based on the chat template. + """ + + # deepcopy is required because the dict has lists inside + clone = copy.deepcopy(self.default_sampling_para) + + for item in [ + "max_new_tokens", + "min_new_tokens", + "n", + "stop", + "stop_token_ids", + "stop_regex", + "temperature", + "top_p", + "top_k", + "min_p", + "frequency_penalty", + "presence_penalty", + "ignore_eos", + "return_logprob", + "logprob_start_len", + "top_logprobs_num", + "return_text_in_logprobs", + "dtype", + "regex", + "json_schema", + ]: + value = getattr(sampling_params, item, None) + if value is not None: + setattr(clone, item, value) + + if self.chat_template.stop_str: + if clone.stop == (): + clone.stop = [] + elif isinstance(clone.stop, str): + clone.stop = [clone.stop] + clone.stop += self.chat_template.stop_str + + return clone + + def __del__(self): + self.end() + + +class ProgramState: + """The state of an SGL program.""" + + def __init__(self, stream_executor: StreamExecutor): + self.stream_executor = stream_executor + + def _role_common(self, name: str, expr: Optional[SglExpr] = None): + if expr is not None: + role_expr = SglExprList([SglRoleBegin(name), expr, SglRoleEnd(name)]) + self.stream_executor.submit(role_expr) + return role_expr + else: + + @contextmanager + def role_scope(): + self.stream_executor.submit(SglRoleBegin(name)) + yield + self.stream_executor.submit(SglRoleEnd(name)) + + return role_scope() + + def system(self, 
expr: Optional[SglExpr] = None): + return self._role_common("system", expr) + + def user(self, expr: Optional[SglExpr] = None): + return self._role_common("user", expr) + + def assistant(self, expr: Optional[SglExpr] = None): + return self._role_common("assistant", expr) + + @contextmanager + def var_scope(self, name: str): + self.stream_executor.submit(SglVarScopeBegin(name)) + yield + self.stream_executor.submit(SglVarScopeEnd(name)) + + def fork( + self, + size: int = 1, + position_ids_offset: Optional[List[int]] = None, + ): + stream_executors = self.stream_executor.fork(size, position_ids_offset) + states = [ProgramState(x) for x in stream_executors] + state_group = ProgramStateGroup(states, self) + return state_group + + @contextmanager + def copy(self, position_ids_offset: Optional[List[int]] = None): + state_group = self.fork(1, position_ids_offset) + try: + yield state_group[0] + finally: + state_group.join() + + def text(self): + return self.stream_executor.text() + + def messages(self): + return self.stream_executor.messages() + + def sync(self): + return self.stream_executor.sync() + + def error(self): + return self.stream_executor.error() + + def text_iter(self, var_name: Optional[str] = None): + if self.stream_executor.stream: + prev = 0 + if var_name is None: + event = self.stream_executor.stream_text_event + while True: + event.wait() + event.clear() + out = str(self.stream_executor.text_[prev:]) + prev += len(out) + if out: + yield out + if self.stream_executor.is_finished: + break + else: + event = None + while not event: + if var_name in self.stream_executor.stream_var_event: + event = self.stream_executor.stream_var_event[var_name] + if self.stream_executor.is_finished: + yield "" + return + + while True: + event.wait() + event.clear() + out = str(self.stream_executor.variables[var_name][prev:]) + prev += len(out) + if out: + yield out + if self.stream_executor.variable_event[var_name].is_set(): + break + else: + if var_name is None: + yield 
self.text() + else: + yield self.get_var(var_name) + + async def text_async_iter( + self, var_name: Optional[str] = None, return_meta_data: bool = False + ): + loop = asyncio.get_running_loop() + + if self.stream_executor.stream: + prev = 0 + if var_name is None: + event = self.stream_executor.stream_text_event + while True: + await loop.run_in_executor(None, event.wait) + event.clear() + out = str(self.stream_executor.text_[prev:]) + prev += len(out) + if out: + yield out + if self.stream_executor.is_finished: + break + else: + event = None + while not event: + if var_name in self.stream_executor.stream_var_event: + event = self.stream_executor.stream_var_event[var_name] + if self.stream_executor.is_finished: + yield "" + return + + while True: + await loop.run_in_executor(None, event.wait) + event.clear() + out = str(self.stream_executor.variables[var_name][prev:]) + prev += len(out) + if out: + if return_meta_data: + yield out, self.stream_executor.meta_info[var_name] + else: + yield out + if self.stream_executor.variable_event[var_name].is_set(): + break + else: + if var_name is None: + yield self.text() + else: + yield self.get_var(var_name) + + def get_var(self, name): + return self.stream_executor.get_var(name) + + def set_var(self, name, value): + return self.stream_executor.set_var(name, value) + + def get_meta_info(self, name): + return self.stream_executor.get_meta_info(name) + + def __iadd__(self, other): + if other is None: + raise ValueError("Tried to append None to state.") + self.stream_executor.submit(other) + return self + + def __getitem__(self, name): + return self.get_var(name) + + def __setitem__(self, name, value): + self.set_var(name, value) + + def __contains__(self, name): + return name in self.stream_executor.variables + + def __del__(self): + self.stream_executor.end() + + def __repr__(self) -> str: + return f"ProgramState({self.text()})" + + +class ProgramStateGroup: + def __init__( + self, states: List[ProgramState], src_state: 
Optional[ProgramState] = None + ): + self.states = states + self.src_state = src_state + + def join(self, mode: str = "gather_variable"): + if mode == "gather_variable": + # Copy variables back + src_vars = self.src_state.stream_executor.variables + src_var_set = set(src_vars.keys()) + for child_state in self.states: + child_state.stream_executor.sync() + child_vars = child_state.stream_executor.variables + new_vars = set(child_vars.keys()) - src_var_set + + for k in new_vars: + if k in src_vars: + src_vars[k].append(child_vars[k]) + else: + src_vars[k] = [child_vars[k]] + elif mode == "concate_and_append": + # Concatenate and append KV cache + self.src_state += SglConcateAndAppend(self.states) + # Need a sync here. Otherwise, `states` can be deleted. + self.src_state.stream_executor.sync() + else: + raise ValueError(f"Invalid join mode: {mode}") + + for s in self.states: + s.stream_executor.end() + + def __getitem__(self, i: int): + return self.states[i] + + def __setitem__(self, i: int, value): + assert self.states[i] == value + + def __iadd__(self, other): + if isinstance(other, Callable): + # lambda function + for i in range(len(self.states)): + self.states[i] += other(i) + elif isinstance(other, SglExpr): + for i in range(len(self.states)): + self.states[i] += other + elif isinstance(other, (list, tuple)): + for i in range(len(self.states)): + self.states[i] += other[i] + else: + raise ValueError(f"Invalid value: {other}") + + return self diff --git a/sglang/python/sglang/lang/ir.py b/sglang/python/sglang/lang/ir.py new file mode 100644 index 0000000000000000000000000000000000000000..43da723b8ec9ce2f6360c38982a30a7a43f70f72 --- /dev/null +++ b/sglang/python/sglang/lang/ir.py @@ -0,0 +1,643 @@ +"""The intermediate representation.""" + +import dataclasses +import inspect +import warnings +from typing import List, Optional, Union + +from sglang.global_config import global_config +from sglang.lang.choices import ChoicesSamplingMethod + +REGEX_INT = r"[-+]?[0-9]+[ 
\n]*" +REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+[ \n]*" +REGEX_BOOL = r"(True|False)" +REGEX_STR = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg + + +@dataclasses.dataclass +class SglSamplingParams: + """Sampling options for a generation call; the to_*_kwargs methods translate them into backend-specific keyword arguments.""" + + max_new_tokens: int = 128 + min_new_tokens: int = 0 + n: int = 1 + stop: Union[str, List[str]] = () + stop_token_ids: Optional[List[int]] = () + stop_regex: Optional[Union[str, List[str]]] = () + temperature: float = 1.0 + top_p: float = 1.0 + top_k: int = -1 # -1 means disable + min_p: float = 0.0 + frequency_penalty: float = 0.0 + presence_penalty: float = 0.0 + ignore_eos: bool = False + return_logprob: Optional[bool] = None + logprob_start_len: Optional[int] = None + top_logprobs_num: Optional[int] = None + return_text_in_logprobs: Optional[bool] = None + json_schema: Optional[str] = None + + # for constrained generation, not included in to_xxx_kwargs + dtype: Optional[str] = None + regex: Optional[str] = None + + def clone(self): + return SglSamplingParams( + self.max_new_tokens, + self.min_new_tokens, + self.n, + self.stop, + self.stop_token_ids, + self.stop_regex, + self.temperature, + self.top_p, + self.top_k, + self.min_p, + self.frequency_penalty, + self.presence_penalty, + self.ignore_eos, + self.return_logprob, + self.logprob_start_len, + self.top_logprobs_num, + self.return_text_in_logprobs, + self.json_schema, + ) + + def to_openai_kwargs(self): + # OpenAI does not support top_k, so we drop it here + if self.regex is not None: + warnings.warn("Regular expression is not supported in the OpenAI backend.") + return { + "max_tokens": self.max_new_tokens, + "max_completion_tokens": self.max_new_tokens, + "n": self.n, + "stop": self.stop or None, + "temperature": self.temperature, + "top_p": self.top_p, + "frequency_penalty": self.frequency_penalty, + "presence_penalty": self.presence_penalty, + } + + def to_vertexai_kwargs(self): + if self.regex is not None: + warnings.warn( + "Regular expression is not supported in the VertexAI 
backend." + ) + return { + "candidate_count": 1, + "max_output_tokens": self.max_new_tokens, + "stop_sequences": self.stop, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k if self.top_k > 0 else None, + } + + def to_anthropic_kwargs(self): + # Anthropic does not support frequency_penalty or presence_penalty, so we drop it here + if self.regex is not None: + warnings.warn( + "Regular expression is not supported in the Anthropic backend." + ) + return { + "max_tokens": self.max_new_tokens, + "stop_sequences": ( + self.stop if isinstance(self.stop, (list, tuple)) else [self.stop] + ), + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + } + + def to_litellm_kwargs(self): + if self.regex is not None: + warnings.warn("Regular expression is not supported in the LiteLLM backend.") + return { + "max_tokens": self.max_new_tokens, + "stop": self.stop or None, + "temperature": self.temperature, + "top_p": self.top_p, + "frequency_penalty": self.frequency_penalty, + "presence_penalty": self.presence_penalty, + } + + def to_srt_kwargs(self): + return { + "max_new_tokens": self.max_new_tokens, + "min_new_tokens": self.min_new_tokens, + "n": self.n, + "stop": self.stop, + "stop_token_ids": self.stop_token_ids, + "stop_regex": self.stop_regex, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + "min_p": self.min_p, + "frequency_penalty": self.frequency_penalty, + "presence_penalty": self.presence_penalty, + "ignore_eos": self.ignore_eos, + "regex": self.regex, + "json_schema": self.json_schema, + } + + +class SglFunction: + def __init__(self, func, num_api_spec_tokens=None, bind_arguments=None): + self.func = func + self.num_api_spec_tokens = num_api_spec_tokens + self.bind_arguments = bind_arguments or {} + self.pin_prefix_rid = None + + # Parse arguments + argspec = inspect.getfullargspec(func) + assert argspec.args[0] == "s", 'The first argument must be "s"' + self.arg_names = 
argspec.args[1:] + self.arg_defaults = argspec.defaults if argspec.defaults is not None else [] + + def bind(self, **kwargs): + """Return a new SglFunction with `kwargs` merged into the already-bound arguments.""" + assert all(key in self.arg_names for key in kwargs) + + new_bind_dict = {**self.bind_arguments, **kwargs} + # Forward num_api_spec_tokens so the bound function keeps the setting of the original one. + return SglFunction( + self.func, + num_api_spec_tokens=self.num_api_spec_tokens, + bind_arguments=new_bind_dict, + ) + + def run( + self, + *args, + max_new_tokens: int = 128, + n: int = 1, + stop: Optional[Union[str, List[str]]] = None, + stop_token_ids: Optional[List[int]] = None, + stop_regex: Optional[Union[str, List[str]]] = None, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, + min_p: float = 0.0, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + ignore_eos: bool = False, + return_logprob: Optional[bool] = None, + logprob_start_len: Optional[int] = None, + top_logprobs_num: Optional[int] = None, + return_text_in_logprobs: Optional[bool] = None, + stream: bool = False, + backend=None, + use_thread: bool = True, + **kwargs, + ): + from sglang.lang.interpreter import run_program + + # avoid using [] as the default arg: https://nikos7am.com/posts/mutable-default-arguments/ + if stop is None: + stop = [] + if stop_token_ids is None: + stop_token_ids = [] + if stop_regex is None: + stop_regex = [] + + default_sampling_para = SglSamplingParams( + max_new_tokens=max_new_tokens, + n=n, + stop=stop, + stop_token_ids=stop_token_ids, + stop_regex=stop_regex, + temperature=temperature, + top_p=top_p, + top_k=top_k, + min_p=min_p, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + ignore_eos=ignore_eos, + return_logprob=return_logprob, + logprob_start_len=logprob_start_len, + top_logprobs_num=top_logprobs_num, + return_text_in_logprobs=return_text_in_logprobs, + ) + backend = backend or global_config.default_backend + return run_program( + self, + backend, + args, + kwargs, + default_sampling_para, + stream, + use_thread=use_thread, + ) + + def run_batch( + self, + batch_kwargs, + *, + max_new_tokens: int = 128, + n: int = 1, + 
stop: Optional[Union[str, List[str]]] = None, + stop_token_ids: Optional[List[int]] = None, + stop_regex: Optional[Union[str, List[str]]] = None, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, + min_p: float = 0.0, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + ignore_eos: bool = False, + return_logprob: Optional[bool] = None, + logprob_start_len: Optional[int] = None, + top_logprobs_num: Optional[int] = None, + return_text_in_logprobs: Optional[bool] = None, + backend=None, + num_threads: Union[str, int] = "auto", + progress_bar: bool = False, + generator_style: bool = False, + ): + from sglang.lang.interpreter import run_program_batch + + if stop is None: + stop = [] + if stop_token_ids is None: + stop_token_ids = [] + if stop_regex is None: + stop_regex = [] + + assert isinstance(batch_kwargs, (list, tuple)) + if len(batch_kwargs) == 0: + return [] + if not isinstance(batch_kwargs[0], dict): + num_programs = len(batch_kwargs) + # change the list of argument values to dict of arg_name -> arg_value + batch_kwargs = [ + {self.arg_names[i]: v for i, v in enumerate(arg_values)} + for arg_values in batch_kwargs + if isinstance(arg_values, (list, tuple)) + and len(self.arg_names) - len(self.arg_defaults) + <= len(arg_values) + <= len(self.arg_names) + ] + # Ensure to raise an exception if the number of arguments mismatch + if len(batch_kwargs) != num_programs: + raise Exception("Given arguments mismatch the SGL function signature") + + default_sampling_para = SglSamplingParams( + max_new_tokens=max_new_tokens, + n=n, + stop=stop, + stop_token_ids=stop_token_ids, + stop_regex=stop_regex, + temperature=temperature, + top_p=top_p, + top_k=top_k, + min_p=min_p, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + ignore_eos=ignore_eos, + return_logprob=return_logprob, + logprob_start_len=logprob_start_len, + top_logprobs_num=top_logprobs_num, + return_text_in_logprobs=return_text_in_logprobs, + ) + 
backend = backend or global_config.default_backend + return run_program_batch( + self, + backend, + batch_kwargs, + default_sampling_para, + num_threads, + progress_bar, + generator_style=generator_style, + ) + + def trace(self, *, backend=None, **kwargs): + from sglang.lang.tracer import trace_program + + backend = backend or global_config.default_backend + return trace_program(self, kwargs, backend) + + def cache(self, backend=None): + from sglang.lang.interpreter import cache_program + + backend = backend or global_config.default_backend + return cache_program(self, backend) + + def __call__(self, *args, **kwargs): + from sglang.lang.tracer import TracingScope + + tracing_scope = TracingScope.get_current_scope() + if tracing_scope is None: + return self.run(*args, **kwargs) + else: + kwargs["backend"] = tracing_scope.tracer_state.backend + return self.trace(*args, **kwargs) + + +class SglExpr: + node_ct = 0 + + def __init__(self): + self.node_id = SglExpr.node_ct + self.prev_node = None + self.pid = None + SglExpr.node_ct += 1 + + def __add__(self, other): + if isinstance(other, str): + other = SglConstantText(other) + assert isinstance(other, SglExpr) + + return self.concatenate_ir(self, other) + + def __radd__(self, other): + if isinstance(other, str): + other = SglConstantText(other) + assert isinstance(other, SglExpr), f"{other}" + + return self.concatenate_ir(other, self) + + def concatenate_ir(self, a, b): + if isinstance(a, SglExprList): + if isinstance(b, SglExprList): + return SglExprList(a.expr_list + b.expr_list) + else: + return SglExprList(a.expr_list + [b]) + elif isinstance(b, SglExprList): + return SglExprList([a] + b.expr_list) + + return SglExprList([a, b]) + + def print_graph_dfs(self): + ret = [""] + visited = set() + + def dfs_print(x): + if x is None or x in visited: + return + visited.add(x) + + # Print dependency + if x.prev_node is not None: + dfs_print(x.prev_node) + + if isinstance(x, SglExprList): + for y in x.expr_list: + 
dfs_print(y) + # elif isinstance(x, SglRole): + # dfs_print(x.expr) + elif isinstance(x, SglVariable): + dfs_print(x.source) + + # Print the node itself + if isinstance(x, (SglFork, SglGetForkItem)): + ret[0] += f"%{x.node_id} = {x}\n" + else: + if x.prev_node is not None: + ret[0] += ( + f"%{x.node_id} = %{x.prev_node.node_id} + " + str(x) + "\n" + ) + else: + ret[0] += f"%{x.node_id} = " + str(x) + "\n" + + dfs_print(self) + return ret[0] + + +class SglExprList(SglExpr): + def __init__(self, expr_list: List[SglExpr]): + super().__init__() + self.expr_list = expr_list + + def __repr__(self): + return f"ExprList({self.expr_list})" + + +class SglArgument(SglExpr): + def __init__(self, name: str, value: str): + super().__init__() + self.name = name + self.value = value + + def __repr__(self): + return f"Argument(name={self.name}, value={repr(self.value)})" + + def __len__(self): + return len(self.value) + + def __getitem__(self, i): + return self.value[i] + + def __int__(self): + return self.value + + def __bool__(self): + return self.value + + def __format__(self, *args): + raise TypeError( + "Cannot put argument inside a f-string. " + "This is not compatible with the tracer. 
" + ) + + +class SglImage(SglExpr): + def __init__(self, path: str): + # Run SglExpr.__init__ so this node gets node_id / prev_node / pid like the other IR nodes. + super().__init__() + self.path = path + + def __repr__(self) -> str: + return f"SglImage({self.path})" + + +class SglVideo(SglExpr): + def __init__(self, path: str, num_frames: int): + # Run SglExpr.__init__ so this node gets node_id / prev_node / pid like the other IR nodes. + super().__init__() + self.path = path + self.num_frames = num_frames + + def __repr__(self) -> str: + return f"SglVideo({self.path}, {self.num_frames})" + + +class SglGen(SglExpr): + def __init__( + self, + name: Optional[str] = None, + max_new_tokens: Optional[int] = None, + min_new_tokens: Optional[int] = None, + n: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stop_token_ids: Optional[List[int]] = None, + stop_regex: Optional[Union[str, List[str]]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + min_p: Optional[float] = None, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, + ignore_eos: Optional[bool] = None, + return_logprob: Optional[bool] = None, + logprob_start_len: Optional[int] = None, + top_logprobs_num: Optional[int] = None, + return_text_in_logprobs: Optional[bool] = None, + dtype: Optional[type] = None, + regex: Optional[str] = None, + json_schema: Optional[str] = None, + ): + """Call the model to generate. 
See the meaning of the arguments in docs/backend/sampling_params.md""" + super().__init__() + self.name = name + self.sampling_params = SglSamplingParams( + max_new_tokens=max_new_tokens, + min_new_tokens=min_new_tokens, + n=n, + stop=stop, + stop_regex=stop_regex, + stop_token_ids=stop_token_ids, + temperature=temperature, + top_p=top_p, + top_k=top_k, + min_p=min_p, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + ignore_eos=ignore_eos, + return_logprob=return_logprob, + logprob_start_len=logprob_start_len, + top_logprobs_num=top_logprobs_num, + return_text_in_logprobs=return_text_in_logprobs, + dtype=dtype, + regex=regex, + json_schema=json_schema, + ) + + def __repr__(self): + return f"Gen('{self.name}')" + + +class SglConstantText(SglExpr): + def __init__(self, value: str): + super().__init__() + self.value = value + + def __repr__(self): + return f"Constant({repr(self.value)})" + + +class SglRoleBegin(SglExpr): + def __init__(self, role: str): + super().__init__() + self.role = role + + def __repr__(self): + return f"RoleBegin({self.role})" + + +class SglRoleEnd(SglExpr): + def __init__(self, role: str): + super().__init__() + self.role = role + + def __repr__(self): + return f"RoleEnd({self.role})" + + +class SglSelect(SglExpr): + + def __init__( + self, + name: str, + choices: List[str], + temperature: float, + choices_method: ChoicesSamplingMethod, + ): + super().__init__() + self.name = name + self.choices = choices + self.temperature = temperature + self.choices_method = choices_method + + def __repr__(self): + return f"Select({self.name}, choices={self.choices}, choices_method={self.choices_method})" + + +class SglFork(SglExpr): + def __init__(self, number: int, position_ids_offset=None): + super().__init__() + self.number = number + self.position_ids_offset = position_ids_offset + + def __repr__(self): + return ( + f"Fork(%{self.prev_node.node_id}, number={self.number}, " + f"position_ids_offset={self.position_ids_offset})" 
+ ) + + +class SglGetForkItem(SglExpr): + def __init__(self, index: int): + super().__init__() + self.index = index + + def __repr__(self): + return f"GetForkItem(%{self.prev_node.node_id}, index={self.index})" + + +class SglVariable(SglExpr): + def __init__(self, name: str, source): + super().__init__() + self.name = name + self.source = source + + def __repr__(self): + return f"Variable('{self.name}', source=%{self.source.node_id})" + + +class SglVarScopeBegin(SglExpr): + def __init__(self, name: str): + super().__init__() + self.name = name + + def __repr__(self): + return f"VarScopeBegin('{self.name}')" + + +class SglVarScopeEnd(SglExpr): + def __init__(self, name: str): + super().__init__() + self.name = name + + def __repr__(self): + return f"VarScopeEnd('{self.name}')" + + +class SglConcateAndAppend(SglExpr): + def __init__(self, states): + super().__init__() + self.states = states + + def __repr__(self): + return f"ConcatenateAndAppend('{self.states}')" + + +class SglCommitLazy(SglExpr): + def __init__(self): + super().__init__() + + def __repr__(self): + return "CommitLazy()" + + +class SglSeparateReasoning(SglExpr): + def __init__(self, model_type: str, expr: SglExpr): + super().__init__() + self.model_type = model_type + + self.expr = expr + self.name = None + self._process_expr(expr) + + def process_name_for_reasoning(self, name): + if not name: + raise ValueError("name must be provided") + return f"{name}_reasoning_content" + + def _process_expr(self, expr): + if isinstance(expr, SglGen): + self.name = self.process_name_for_reasoning(expr.name) + elif isinstance(expr, SglSelect): + self.name = self.process_name_for_reasoning(expr.name) + elif isinstance(expr, SglExprList): + for x in expr.expr_list: + self._process_expr(x) + + def __repr__(self): + return f"SeparateReasoning(model_type={self.model_type}, name={self.name})" diff --git a/sglang/python/sglang/lang/tracer.py b/sglang/python/sglang/lang/tracer.py new file mode 100644 index 
# 0000000000000000000000000000000000000000..0a2a744f92996ecbc50b285474627026aae160ba --- /dev/null +++ b/sglang/python/sglang/lang/tracer.py @@ -0,0 +1,279 @@
"""Tracing a program."""

import uuid
from typing import Any, Dict, List, Optional

from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.interpreter import ProgramState, ProgramStateGroup
from sglang.lang.ir import (
    SglArgument,
    SglConstantText,
    SglExpr,
    SglExprList,
    SglFork,
    SglGen,
    SglGetForkItem,
    SglRoleBegin,
    SglRoleEnd,
    SglSelect,
    SglVariable,
    SglVarScopeBegin,
    SglVarScopeEnd,
)


class StopTracing(Exception):
    """Raised by the tracer to abandon a prefix-only trace.

    Thrown from ``fork`` and from ``_execute`` (for expression types it does
    not handle explicitly) when ``only_trace_prefix`` is set.
    """


def extract_prefix_by_tracing(program, backend):
    """Return the leading run of constant text that ``program`` emits.

    The program function is executed against a prefix-only tracer whose
    arguments are ``SglArgument`` placeholders (overridden by
    ``program.bind_arguments``). ``StopTracing``, ``TypeError`` and
    ``AttributeError`` raised while running the program are swallowed, and
    whatever nodes were recorded up to that point are used.

    Returns:
        The concatenated ``value`` of every ``SglConstantText`` node that
        precedes the first non-constant node (``""`` if there is none).
    """
    # Placeholder for every declared argument; bound arguments take priority.
    arguments = {
        arg_name: SglArgument(arg_name, None) for arg_name in program.arg_names
    }
    arguments.update(program.bind_arguments)

    tracer = TracerProgramState(backend, arguments, only_trace_prefix=True)
    try:
        with TracingScope(tracer):
            tracer.ret_value = program.func(tracer, **arguments)
    except (StopTracing, TypeError, AttributeError):
        # Some exceptions may not be caught
        pass

    # Collect constant text until the first node of any other kind.
    prefix = ""
    for node in tracer.flatten_nodes():
        if not isinstance(node, SglConstantText):
            break
        prefix += node.value
    return prefix


def trace_program(program, arguments, backend):
    """Run ``program`` under a full (non-prefix) tracer and return the tracer.

    Args:
        program: object exposing ``arg_names``, ``bind_arguments`` and ``func``.
        arguments: mapping of argument name to value. It is updated in place:
            missing names get ``SglArgument`` placeholders and
            ``program.bind_arguments`` is merged on top.
        backend: backend to trace against; ``None`` selects a fresh
            ``BaseBackend()``.

    Returns:
        The ``TracerProgramState`` with ``ret_value`` set to the program
        function's return value.
    """
    if backend is None:
        backend = BaseBackend()

    # Decide which names are missing before touching ``arguments``.
    missing_names = [name for name in program.arg_names if name not in arguments]
    placeholders = {}
    for name in missing_names:
        placeholders[name] = SglArgument(name, None)
    arguments.update(placeholders)
    arguments.update(program.bind_arguments)

    tracer = TracerProgramState(backend, arguments, only_trace_prefix=False)
    with TracingScope(tracer):
        tracer.ret_value = program.func(tracer, **arguments)
    return tracer


class TracerProgramState(ProgramState):
    """Program state that records IR nodes instead of executing them.

    Every executed expression is appended to ``nodes`` and chained to its
    predecessor through ``prev_node``. Chat-role bookkeeping (``messages_``,
    ``cur_role``) is kept so role prefixes/suffixes from the chat template
    are recorded as constant text.
    """

    def __init__(self, backend, arguments, only_trace_prefix):
        self.pid = uuid.uuid4().hex
        # Prefer the backend's ``endpoint`` attribute when it has one.
        self.backend = backend.endpoint if hasattr(backend, "endpoint") else backend
        self.arguments: Dict[str, Any] = arguments
        self.only_trace_prefix = only_trace_prefix

        # Recorded trace.
        self.nodes = []
        self.last_node = None
        self.variables = {}
        self.ret_value = None

        # Chat state.
        self.messages_ = []
        self.cur_role = None
        self.chat_template = self.backend.get_chat_template()

        # States created while this one was the active tracing scope.
        self.child_states = []

        active_scope = TracingScope.get_current_scope()
        if active_scope is not None:
            active_scope.add_child_state(self)

    ##################################
    ########### Public API ###########
    ##################################

    def fork(self, size: int = 1, position_ids_offset: Optional[List[int]] = None):
        """Split the trace into ``size`` child tracer states.

        Raises ``StopTracing`` in prefix-only mode. Each child starts from an
        ``SglGetForkItem`` node hanging off a shared ``SglFork`` node and gets
        copies of the current variables and messages. ``position_ids_offset``
        is accepted but not used here.
        """
        assert size >= 1

        if self.only_trace_prefix:
            raise StopTracing()

        fork_node = SglFork(size)
        fork_node.prev_node = self.last_node

        # Create all children first, then attach their fork-item nodes.
        children = []
        for _ in range(size):
            children.append(
                TracerProgramState(
                    self.backend, self.arguments, self.only_trace_prefix
                )
            )

        for index, child in enumerate(children):
            item_node = SglGetForkItem(index)
            item_node.prev_node = fork_node
            child.last_node = item_node
            child.variables = dict(self.variables)
            child.messages_ = list(self.messages_)
            child.cur_role = self.cur_role
            child.chat_template = self.chat_template

        return ProgramStateGroup(children, self)

    ##################################
    ########## Internal API ##########
    ##################################

    def _append_node(self, other: SglExpr):
        """Record ``other`` and link it to the previously recorded node."""
        self.nodes.append(other)
        other.prev_node = self.last_node
        self.last_node = other

    def _execute(self, other: SglExpr):
        """Dispatch ``other`` to the handler for its IR type; returns ``self``.

        Plain strings are wrapped in ``SglConstantText``. Expression types
        without a dedicated handler are recorded as-is, or raise
        ``StopTracing`` in prefix-only mode.
        """
        if isinstance(other, str):
            other = SglConstantText(other)

        other.pid = self.pid

        # The isinstance checks run in the same order as the original chain.
        if isinstance(other, SglConstantText):
            self._execute_fill(other)
            return self
        if isinstance(other, SglGen):
            self._execute_gen(other)
            return self
        if isinstance(other, SglSelect):
            self._execute_select(other)
            return self
        if isinstance(other, SglExprList):
            for item in other.expr_list:
                self._execute(item)
            return self
        if isinstance(other, SglRoleBegin):
            self._execute_role_begin(other)
            return self
        if isinstance(other, SglRoleEnd):
            self._execute_role_end(other)
            return self
        if isinstance(other, SglVarScopeBegin):
            # NOTE(review): no _execute_var_scope_begin is defined in this
            # class; presumably inherited from ProgramState - confirm.
            self._execute_var_scope_begin(other)
            return self
        if isinstance(other, SglVarScopeEnd):
            self._execute_var_scope_end(other)
            return self

        if self.only_trace_prefix:
            raise StopTracing()
        self._append_node(other)
        return self

    def __iadd__(self, other):
        """``state += expr`` records ``expr`` on this tracer."""
        self._execute(other)
        return self

    def _execute_fill(self, expr: SglConstantText):
        """Record constant text; a bare ``str`` is wrapped first."""
        node = SglConstantText(expr) if isinstance(expr, str) else expr
        self._append_node(node)

    def _execute_gen(self, expr: SglGen):
        """Record a generation node and register a variable for its output."""
        if expr.name is not None:
            var_name = expr.name
        else:
            var_name = "gen_" + str(len(self.variables))
        self.variables[var_name] = SglVariable(var_name, source=expr)
        self._append_node(expr)

    def _execute_select(self, expr: SglSelect):
        """Record a select node and register a variable for its result."""
        if expr.name is not None:
            var_name = expr.name
        else:
            var_name = "select_" + str(len(self.variables))
        self.variables[var_name] = SglVariable(var_name, source=expr)
        self._append_node(expr)

    def _execute_role_begin(self, expr: SglRoleBegin):
        """Open a chat role and record the template's role prefix.

        When this is the first message and it is not a system message, the
        chat template's default system prompt (if non-empty) is recorded
        first as a complete system turn.
        """
        assert self.cur_role is None, "Nested roles are not allowed."

        is_first_message = len(self.messages_) == 0
        if is_first_message and expr.role != "system":
            system_prompt = self.chat_template.default_system_prompt
            if system_prompt:
                self._execute_role_begin(SglRoleBegin("system"))
                self._execute_fill(system_prompt)
                self._execute_role_end(SglRoleEnd("system"))

        self.cur_role = expr.role

        affixes = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
        role_prefix, _ = affixes
        self._execute_fill(role_prefix)

    def _execute_role_end(self, expr: SglRoleEnd):
        """Record the role suffix, log an empty-content message, close the role."""
        affixes = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
        _, role_suffix = affixes
        self._execute_fill(role_suffix)

        self.messages_.append({"role": expr.role, "content": ""})
        self.cur_role = None

    def _execute_var_scope_end(self, expr: SglVarScopeEnd):
        """Bind ``expr.name`` to a variable sourced from the last recorded node."""
        self.variables[expr.name] = SglVariable(expr.name, source=self.last_node)

    def get_var(self, name):
        """Look up ``name`` among the arguments, then among traced variables.

        An argument whose value is ``None`` is treated as absent. A traced
        variable is returned as a fresh ``SglVariable`` built from its name
        and source; an unknown name raises ``KeyError``.
        """
        bound_value = self.arguments.get(name, None)
        if bound_value is None:
            traced = self.variables[name]
            return SglVariable(traced.name, traced.source)
        return bound_value

    def flatten_nodes(self):
        """Return the recorded nodes with every ``SglExprList`` expanded in order."""
        flat = []
        # Explicit stack; children are pushed reversed so they pop in order.
        pending = list(reversed(self.nodes))
        while pending:
            node = pending.pop()
            if isinstance(node, SglExprList):
                pending.extend(reversed(list(node.expr_list)))
            else:
                flat.append(node)
        return flat

    def __del__(self):
        # Intentionally a no-op.
        pass


class TracingScope:
    """Context manager tracking the active tracer as a stack of scopes.

    ``cur_scope`` is class-level state; each scope remembers the scope that
    was active when it was constructed and restores it on exit.
    """

    cur_scope = None

    def __init__(self, tracer_state: TracerProgramState):
        self.tracer_state = tracer_state
        self.last_scope = TracingScope.cur_scope

    def __enter__(self):
        TracingScope.cur_scope = self
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        TracingScope.cur_scope = self.last_scope

    @staticmethod
    def get_current_scope():
        """Return the innermost active scope, or ``None`` outside any scope."""
        return TracingScope.cur_scope

    def add_child_state(self, state: TracerProgramState):
        """Register ``state`` with this scope's tracer and every enclosing one."""
        scope = self
        while scope is not None:
            scope.tracer_state.child_states.append(state)
            scope = scope.last_scope

# diff --git a/sglang/python/sglang/launch_server.py
b/sglang/python/sglang/launch_server.py new file mode 100644 index 0000000000000000000000000000000000000000..fe06a9289ac8935314ab3fcd46fdb210b1a9bcdd --- /dev/null +++ b/sglang/python/sglang/launch_server.py @@ -0,0 +1,64 @@ +"""Launch the inference server.""" + +import asyncio +import os +import sys + +from sglang.srt.server_args import prepare_server_args +from sglang.srt.utils import kill_process_tree +from sglang.srt.utils.common import suppress_noisy_warnings + +suppress_noisy_warnings() + + +def run_server(server_args): + """Run the server based on server_args.grpc_mode and server_args.encoder_only.""" + if server_args.encoder_only: + if server_args.grpc_mode: + from sglang.srt.disaggregation.encode_grpc_server import ( + serve_grpc_encoder, + ) + + asyncio.run(serve_grpc_encoder(server_args)) + else: + from sglang.srt.disaggregation.encode_server import launch_server + + launch_server(server_args) + elif server_args.grpc_mode: + from sglang.srt.entrypoints.grpc_server import serve_grpc + + asyncio.run(serve_grpc(server_args)) + elif server_args.use_ray: + try: + from sglang.srt.ray.http_server import launch_server + except ImportError: + raise ImportError( + "Ray is required for --use-ray mode. " + "Install it with: pip install 'sglang[ray]'" + ) + + launch_server(server_args) + else: + # Default mode: HTTP mode. 
+ from sglang.srt.entrypoints.http_server import launch_server + + launch_server(server_args) + + +if __name__ == "__main__": + import warnings + + warnings.warn( + "'python -m sglang.launch_server' is still supported, but " + "'sglang serve' is the recommended entrypoint.\n" + " Example: sglang serve --model-path [options]", + UserWarning, + stacklevel=1, + ) + + server_args = prepare_server_args(sys.argv[1:]) + + try: + run_server(server_args) + finally: + kill_process_tree(os.getpid(), include_parent=False) diff --git a/sglang/python/sglang/multimodal_gen/.claude/CLAUDE.md b/sglang/python/sglang/multimodal_gen/.claude/CLAUDE.md new file mode 100644 index 0000000000000000000000000000000000000000..26cb86ae6bbc3fb0879b54c5b84379aa45613ebb --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/.claude/CLAUDE.md @@ -0,0 +1,108 @@ +# CLAUDE.md — sglang-diffusion (multimodal_gen) + +## What is this? + +SGLang's diffusion/multimodal generation subsystem. Separate from the LLM runtime (`srt`). Supports 20+ image/video diffusion models (Wan, FLUX, HunyuanVideo, LTX, Qwen-Image, etc.) with distributed inference, LoRA, and multiple attention backends. 
+ +## Quick Start + +```bash +# One-shot generation +sglang generate --model-path Wan-AI/Wan2.1-T2V-1.3B-Diffusers --prompt "A curious raccoon" --save-output + +# Start server +sglang serve --model-path Wan-AI/Wan2.1-T2V-1.3B-Diffusers --num-gpus 4 + +# Python API +from sglang import DiffGenerator +gen = DiffGenerator.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers") +result = gen.generate(sampling_params_kwargs={"prompt": "A curious raccoon"}) +``` + +## Architecture + +``` +CLI / Python API / HTTP Server (FastAPI + OpenAI-compatible) + ↓ ZMQ +Scheduler (request queue, batching, dispatch) + ↓ multiprocessing pipes +GPU Worker(s) → ComposedPipeline (stages: TextEncode → Denoise → Decode) +``` + +### Key Directories + +``` +runtime/ +├── entrypoints/ # CLI (generate/serve), HTTP server, Python API (DiffGenerator) +├── managers/ # scheduler.py, gpu_worker.py +├── pipelines_core/ # ComposedPipelineBase, stages/, schedule_batch.py (Req/OutputBatch) +├── pipelines/ # Model-specific pipelines (wan, flux, hunyuan, ltx, qwen_image, ...) 
+├── models/ # encoders/, dits/, vaes/, schedulers/ +├── layers/ # attention/, lora/, quantization/ +├── loader/ # Model loading, weight utils +├── server_args.py # ServerArgs (all CLI/config params) +└── distributed/ # Multi-GPU (TP, SP: ulysses/ring) +configs/ +├── pipeline_configs/ # Per-model pipeline configs +├── sample/ # SamplingParams +└── models/ # DiT, VAE, Encoder configs +``` + +### Key Classes + +| Class | Location | Purpose | +|-------|----------|---------| +| `DiffGenerator` | `runtime/entrypoints/diffusion_generator.py` | Python API entry point | +| `ComposedPipelineBase` | `runtime/pipelines_core/composed_pipeline_base.py` | Pipeline orchestrator (stages) | +| `Scheduler` | `runtime/managers/scheduler.py` | ZMQ event loop, request dispatch | +| `GPUWorker` | `runtime/managers/gpu_worker.py` | GPU inference worker | +| `Req` / `OutputBatch` | `runtime/pipelines_core/schedule_batch.py` | Request/output containers | +| `ServerArgs` | `runtime/server_args.py` | All config params | +| `SamplingParams` | `configs/sample/sampling_params.py` | Generation params | +| `PipelineConfig` | `configs/pipeline_configs/base.py` | Model structure config | + +### Key Functions + +| Function | Module | Purpose | +|----------|--------|---------| +| `build_pipeline()` | `runtime/pipelines_core/__init__.py` | Instantiate pipeline from model_path | +| `get_model_info()` | `registry.py` | Resolve pipeline + config classes | +| `launch_server()` | `runtime/launch_server.py` | Start multi-process server | + +## Adding a New Model + +1. Create pipeline in `runtime/pipelines/` extending `ComposedPipelineBase` +2. Define stages via `create_pipeline_stages()` (TextEncoding → Denoising → Decoding) +3. Add config in `configs/pipeline_configs/` +4. Register in `registry.py` via `register_configs()` + +## Multi-GPU + +```bash +# Sequence parallelism (video frames across GPUs) +sglang serve --model-path ... 
--num-gpus 4 --ulysses-degree 2 --ring-degree 2 + +# Tensor parallelism (model layers across GPUs) +sglang serve --model-path ... --num-gpus 2 --tp-size 2 +``` + +## Testing + +```bash +# Tests live in test/ subdirectory +python -m pytest python/sglang/multimodal_gen/test/ + +# No need to pre-download models — auto-downloaded at runtime +# Dependencies assumed already installed via `pip install -e "python[diffusion]"` +``` + +## Perf Measurement + +Look for `Pixel data generated successfully in xxxx seconds` in console output. With warmup enabled, use the line containing `warmup excluded` for accurate timing. + +## Env Vars + +Defined in `envs.py` (300+ vars). Key ones: +- `SGLANG_DIFFUSION_ATTENTION_BACKEND` — attention backend override +- `SGLANG_CACHE_DIT_ENABLED` — enable Cache-DiT acceleration +- `SGLANG_CLOUD_STORAGE_TYPE` — cloud output storage (s3, etc.) diff --git a/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/SKILL.md b/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/SKILL.md new file mode 100644 index 0000000000000000000000000000000000000000..2cdd24b2c159628b9611b02b75f3ad72abcb1abb --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/SKILL.md @@ -0,0 +1,68 @@ +--- +name: diffusion-kernel +description: Index for SGLang Diffusion kernel development skills. +--- + +# Diffusion Kernel Skills + +## Rule: Follow User Kernel Language Preference + +If the user explicitly states a preference for **Triton** or **CUDA**, follow that preference when implementing and optimizing kernels (even if the other option could work). Do not “pick for convenience”. 
+ +## Directory Layout + +``` +python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/ +├── SKILL.md +├── add-triton-kernel.md +├── add-cuda-kernel.md +├── diffusion-benchmark-and-profile.md +├── nsight-profiler.md +├── use-efficient-diffusion-kernels.md +├── references/ +│ ├── kernel-templates.md # Copy-paste CUDA kernel templates (sglang JIT style) +│ ├── troubleshooting.md # Build/perf/integration issues & fixes +│ ├── h100-optimization-guide.md # H100 (sm_90) deep dive +│ ├── a100-optimization-guide.md # A100 (sm_80) deep dive +│ └── t4-optimization-guide.md # T4 (sm_75, FP16 only) deep dive +└── scripts/ + ├── bench_diffusion_rmsnorm.py # RMSNorm micro-benchmark vs PyTorch + └── bench_diffusion_denoise.py # End-to-end denoise benchmark (sglang generate) +``` + +## Index + +- [add-triton-kernel.md](./add-triton-kernel.md) + + Step-by-step guide for adding a new Triton kernel to SGLang Diffusion's `jit_kernel/diffusion/triton/` module, including authoring, autotune, `torch.compile` compatibility, integration, and tests. Use for fused elementwise ops, norm variants, RoPE variants, or when NPU/CPU fallback is needed. + +- [add-cuda-kernel.md](./add-cuda-kernel.md) + + Step-by-step guide for adding a JIT CUDA kernel. CUDA source goes in `jit_kernel/csrc/diffusion/.cuh`; Python wrapper at `jit_kernel/diffusion/.py`. Uses SGLang's JIT compilation system (`load_jit`, `cache_once`) and internal abstractions (`TensorMatcher`, `device::AlignedVector`, `host::LaunchKernel`, `device::warp::reduce_sum`). Use for bandwidth-bound reductions (RMSNorm, LayerNorm) or ops needing fine-grained vectorization and shared memory control. Adapted from [HuggingFace kernels cuda-kernels skill](https://github.com/huggingface/kernels/tree/main/skills/cuda-kernels). 
+ +- [use-efficient-diffusion-kernels.md](./use-efficient-diffusion-kernels.md) + + Practical guidance for using SGLang Diffusion fused kernels and fast CUDA paths, including constraints, fallbacks, and where the fused ops are wired into the runtime. + +- [diffusion-benchmark-and-profile.md](./diffusion-benchmark-and-profile.md) + + Denoise-stage benchmark and profiling guide for SGLang Diffusion models. Three profiling levels: Level 1 (torch.profiler — kernel time ranking), Level 2 (nsys — category breakdown), Level 3 (ncu — per-kernel bandwidth/occupancy/roofline analysis). **ncu is critical for kernel optimization** — always use it when writing or tuning custom kernels to verify hardware saturation. + +- [nsight-profiler.md](./nsight-profiler.md) + + Advanced profiling skill for NVIDIA Nsight Systems / Nsight Compute: collecting traces, reading reports, and interpreting kernel-level performance metrics. + +## References (GPU optimization guides, templates, troubleshooting) + +Loaded by `add-cuda-kernel.md`. Adapted from [HuggingFace kernels cuda-kernels skill](https://github.com/huggingface/kernels/tree/main/skills/cuda-kernels). 
+ +- [references/kernel-templates.md](references/kernel-templates.md) — copy-paste ready sglang JIT CUDA templates: element-wise (SiLU), row-reduction (RMSNorm), fused AdaLN, Python wrapper, test, benchmark +- [references/troubleshooting.md](references/troubleshooting.md) — build errors, performance issues, torch.compile compatibility, kernel injection pitfalls +- [references/h100-optimization-guide.md](references/h100-optimization-guide.md) — H100 (sm_90): AlignedVector benchmarks, warp reductions, occupancy, TMA, PDL +- [references/a100-optimization-guide.md](references/a100-optimization-guide.md) — A100 (sm_80): cp.async, TF32, 2:4 sparsity, H100→A100 migration checklist +- [references/t4-optimization-guide.md](references/t4-optimization-guide.md) — T4 (sm_75): FP16 only, 320 GB/s bandwidth, 64 KB shared mem, 16 GB memory management + +## Scripts (runnable benchmarks) + +- [scripts/bench_diffusion_rmsnorm.py](scripts/bench_diffusion_rmsnorm.py) — RMSNorm micro-benchmark: JIT CUDA vs PyTorch, correctness check, bandwidth efficiency analysis +- [scripts/bench_diffusion_denoise.py](scripts/bench_diffusion_denoise.py) — end-to-end denoise benchmark via `sglang generate`, baseline vs custom kernels comparison table diff --git a/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/add-cuda-kernel.md b/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/add-cuda-kernel.md new file mode 100644 index 0000000000000000000000000000000000000000..0b275766df935c87a42b8e61d687b8275875e6f9 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/add-cuda-kernel.md @@ -0,0 +1,542 @@ +--- +name: add-cuda-kernel +description: Step-by-step guide for adding a new JIT CUDA kernel to SGLang Diffusion. CUDA source files go in jit_kernel/csrc/diffusion/.cuh; Python wrapper at jit_kernel/diffusion/.py. Use when implementing optimized CUDA kernels for diffusion model operators (RMSNorm, RoPE, AdaLN, GEGLU, etc.) 
on NVIDIA GPUs (H100, A100). Covers kernel authoring with sglang abstractions, JIT compilation, Python wrapper, integration into the denoise stage, and benchmarking. Adapted from https://github.com/huggingface/kernels/tree/main/skills/cuda-kernels. +--- + +# Adding a CUDA Kernel to SGLang Diffusion (JIT Style) + +> **Origin**: This skill is adapted from the [HuggingFace kernels cuda-kernels skill](https://github.com/huggingface/kernels/tree/main/skills/cuda-kernels), rewritten to follow SGLang's JIT compilation system and internal abstractions. +> +> **Extended references** (in this directory's `references/` and `scripts/`): +> - [references/kernel-templates.md](references/kernel-templates.md) — copy-paste ready templates for element-wise, row-reduction (RMSNorm), fused AdaLN +> - [references/troubleshooting.md](references/troubleshooting.md) — build errors, perf issues, integration pitfalls +> - [references/h100-optimization-guide.md](references/h100-optimization-guide.md) — H100 (sm_90) deep dive +> - [references/a100-optimization-guide.md](references/a100-optimization-guide.md) — A100 (sm_80) deep dive +> - [references/t4-optimization-guide.md](references/t4-optimization-guide.md) — T4 (sm_75, FP16 only) deep dive +> - [scripts/bench_diffusion_rmsnorm.py](scripts/bench_diffusion_rmsnorm.py) — RMSNorm micro-benchmark vs PyTorch +> - [scripts/bench_diffusion_denoise.py](scripts/bench_diffusion_denoise.py) — end-to-end denoise benchmark with/without kernels + +## When to Use CUDA vs Triton + +| Scenario | Use | +|----------|-----| +| Fused elementwise / norm variants / RoPE | **Triton** (`add-triton-kernel.md`) — faster iteration | +| Bandwidth-bound reduction (RMSNorm, LayerNorm) requiring max vectorization | **CUDA** — full control over `__nv_bfloat162` / `float4` vectorization | +| Attention pattern or tile-based ops needing shared memory tuning | **CUDA** — warp-level primitives, shared memory layout | +| Prototype or NPU/CPU fallback needed | **Triton** — 
portable across backends | + +For most diffusion-model elementwise ops, **start with Triton**. Switch to CUDA when profiling shows Triton can't reach hardware bandwidth limits. + +## Directory Layout + +``` +python/sglang/jit_kernel/ +├── csrc/ +│ ├── diffusion/ # JIT CUDA source files for diffusion kernels (this skill) +│ │ ├── timestep_embedding.cuh # existing example +│ │ ├── rmsnorm.cuh # NEW: add here +│ │ └── adaln.cuh # NEW: add here +│ └── elementwise/ # shared JIT CUDA csrc (non-diffusion) +├── diffusion/ +│ ├── triton/ # Triton kernels (scale_shift, norm, rope, ...) +│ ├── cutedsl/ # CuTe DSL kernels +│ └── rmsnorm.py # NEW: CUDA JIT Python wrapper (add here) +├── timestep_embedding.py # existing CUDA diffusion kernel Python wrapper (legacy) +``` + +New diffusion CUDA kernel source files go into `python/sglang/jit_kernel/csrc/diffusion/<op_name>.cuh`. +The Python wrapper goes at `python/sglang/jit_kernel/diffusion/<op_name>.py` +(inside `diffusion/`, alongside the `triton/` and `cutedsl/` subdirectories). + +--- + +## SGLang Kernel Abstractions (Required) + +Always use these — do **not** use raw CUDA primitives directly. + +```cpp +#include <sgl_kernel/tensor.h> // TensorMatcher, SymbolicSize, SymbolicDevice +#include <sgl_kernel/type.cuh> // fp16_t, bf16_t, fp32_t, dtype_trait, packed_t +#include <sgl_kernel/utils.h> // RuntimeCheck, div_ceil +#include <sgl_kernel/utils.cuh> // LaunchKernel, SGL_DEVICE, type aliases +#include <sgl_kernel/vec.cuh> // AlignedVector — 128-bit vector loads +#include <sgl_kernel/warp.cuh> // warp::reduce_sum, warp::reduce_max +#include <sgl_kernel/math.cuh> // device::math::rsqrt, sqrt, ... +#include <sgl_kernel/tile.cuh> // tile::Memory (strided access pattern) +``` + +Key types: `fp16_t` = `__half`, `bf16_t` = `__nv_bfloat16`, `fp32_t` = `float`. +Packed variants: `fp16x2_t`, `bf16x2_t`. Use `packed_t<T>` for the 2-element alias. + +--- + +## Step 1: Write the CUDA Kernel + +Create `python/sglang/jit_kernel/csrc/diffusion/rmsnorm.cuh` (RMSNorm as example). + +### 1a.
Vectorized RMSNorm Kernel + +```cpp +#include <sgl_kernel/tensor.h> +#include <sgl_kernel/utils.h> +#include <sgl_kernel/utils.cuh> +#include <sgl_kernel/vec.cuh> +#include <sgl_kernel/warp.cuh> +#include <sgl_kernel/math.cuh> + +#include <dlpack/dlpack.h> +#include <tvm/ffi/container/tensor.h> + +namespace { + +// --------------------------------------------------------------- +// RMSNorm kernel: y = x / rms(x) * weight +// T = fp16_t | bf16_t | fp32_t +// kVecN = vectorized elements per load (8 for fp16/bf16, 4 for fp32) +// --------------------------------------------------------------- +template <typename T, int kVecN> +__global__ void rmsnorm_kernel( + T* __restrict__ dst, + const T* __restrict__ src, + const T* __restrict__ weight, // may be nullptr if no affine weight + uint32_t hidden_size, + uint32_t n_vecs, // hidden_size / kVecN + float eps) +{ + using vec_t = device::AlignedVector<T, kVecN>; + + const uint32_t row = blockIdx.x; + const T* row_src = src + row * hidden_size; + T* row_dst = dst + row * hidden_size; + + // --- Pass 1: accumulate sum of squares (vectorized) --- + float sum_sq = 0.f; + for (uint32_t vi = threadIdx.x; vi < n_vecs; vi += blockDim.x) { + vec_t v; + v.load(row_src, vi); + #pragma unroll + for (int i = 0; i < kVecN; ++i) { + float val = static_cast<float>(v[i]); + sum_sq += val * val; + } + } + + // --- Warp reduction --- + sum_sq = device::warp::reduce_sum(sum_sq); + + // --- Block reduction via shared memory --- + __shared__ float smem[32]; + if (threadIdx.x % 32 == 0) { + smem[threadIdx.x / 32] = sum_sq; + } + __syncthreads(); + if (threadIdx.x < 32) { + sum_sq = (threadIdx.x < blockDim.x / 32) ?
smem[threadIdx.x] : 0.f; + sum_sq = device::warp::reduce_sum(sum_sq); + if (threadIdx.x == 0) { + smem[0] = sum_sq; // publish the block-wide total + } + } + __syncthreads(); + sum_sq = smem[0]; // broadcast: warps other than warp 0 still hold only their own partial sum + + const float rms_inv = device::math::rsqrt(sum_sq / static_cast<float>(hidden_size) + eps); + + // --- Pass 2: normalize + apply weight (vectorized) --- + for (uint32_t vi = threadIdx.x; vi < n_vecs; vi += blockDim.x) { + vec_t v_in, v_w, v_out; + v_in.load(row_src, vi); + if (weight != nullptr) { + v_w.load(weight, vi); + } + #pragma unroll + for (int i = 0; i < kVecN; ++i) { + float val = static_cast<float>(v_in[i]) * rms_inv; + if (weight != nullptr) { + val *= static_cast<float>(v_w[i]); + } + v_out[i] = static_cast<T>(val); + } + v_out.store(row_dst, vi); + } +} + +// --------------------------------------------------------------- +// Launcher +// --------------------------------------------------------------- +template <typename T> +void rmsnorm( + tvm::ffi::TensorView dst, + tvm::ffi::TensorView src, + tvm::ffi::TensorView weight, // pass empty / nullptr for no-weight case + float eps) +{ + using namespace host; + + // Validate + SymbolicSize B{"batch_tokens"}, H{"hidden_size"}; + SymbolicDevice device; + device.set_options<kDLCUDA>(); + + TensorMatcher({B, H}) + .with_dtype<T>() + .with_device(device) + .verify(dst) + .verify(src); + + const uint32_t num_rows = static_cast<uint32_t>(B.unwrap()); + const uint32_t hidden = static_cast<uint32_t>(H.unwrap()); + const DLDevice dev = device.unwrap(); + + RuntimeCheck(hidden % (16 / sizeof(T)) == 0, + "rmsnorm: hidden_size must be divisible by vector width, got ", hidden); + + constexpr int kVecN = 16 / sizeof(T); // 128-bit vector: 8×fp16/bf16, 4×fp32 + const uint32_t n_vecs = hidden / kVecN; + + // Thread count: enough warps to cover n_vecs, max 512 threads + uint32_t threads = std::min(n_vecs, 512u); + threads = (threads + 31) / 32 * 32; // round up to warp boundary + + const T* w_ptr = (weight.data_ptr() != nullptr) + ?
static_cast<const T*>(weight.data_ptr()) : nullptr; + + LaunchKernel(num_rows, threads, dev)( + rmsnorm_kernel<T, kVecN>, + static_cast<T*>(dst.data_ptr()), + static_cast<const T*>(src.data_ptr()), + w_ptr, + hidden, + n_vecs, + eps); +} + +} // namespace +``` + +--- + +## Step 2: Python Wrapper + +Create `python/sglang/jit_kernel/diffusion/rmsnorm.py`: + +```python +from __future__ import annotations +from typing import TYPE_CHECKING + +import torch + +from sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args + +if TYPE_CHECKING: + from tvm_ffi.module import Module + + +@cache_once +def _jit_rmsnorm_module(dtype: torch.dtype) -> Module: + args = make_cpp_args(dtype) + return load_jit( + "diffusion_rmsnorm", + *args, + cuda_files=["diffusion/rmsnorm.cuh"], # relative to csrc/ + cuda_wrappers=[("rmsnorm", f"rmsnorm<{args}>")], + extra_cuda_cflags=["-O3", "--use_fast_math"], + ) + + +def diffusion_rmsnorm( + src: torch.Tensor, + weight: torch.Tensor | None = None, + eps: float = 1e-6, + out: torch.Tensor | None = None, +) -> torch.Tensor: + """ + RMSNorm for diffusion DiT layers. + + y = x / rms(x) * weight (weight=None → no affine scaling) + + Supported dtypes: float16, bfloat16, float32. + hidden_size must be divisible by 8 (fp16/bf16) or 4 (fp32).
+ """ + assert src.is_cuda, "src must be a CUDA tensor" + assert src.dtype in (torch.float16, torch.bfloat16, torch.float32) + + if out is None: + out = torch.empty_like(src) + + # Pass a zero-sized tensor when weight is absent (launcher checks data_ptr == nullptr) + w = weight if weight is not None else torch.empty(0, dtype=src.dtype, device=src.device) + + module = _jit_rmsnorm_module(src.dtype) + module.rmsnorm(out, src, w, eps) + return out +``` + +**Key rules for the wrapper:** +- Use `cache_once` — never `functools.lru_cache` (breaks `torch.compile`) +- First arg(s) to `load_jit` form the unique build cache key +- `cuda_files` are relative to `python/sglang/jit_kernel/csrc/` +- `cuda_wrappers`: `(python_name, cpp_template_instantiation)` + +--- + +## Step 3: Integrate into Denoising Stage + +The kernel replaces a slow operator inside the DiT forward pass. Find the correct module in: + +``` +python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py +python/sglang/multimodal_gen/runtime/models/dits/.py +``` + +**Pattern — monkey-patch the DiT block's RMSNorm:** + +```python +from sglang.jit_kernel.diffusion.rmsnorm import diffusion_rmsnorm + +def _patch_rmsnorm(model: torch.nn.Module) -> None: + for name, module in model.named_modules(): + cls_name = type(module).__name__ + if cls_name in ("RMSNorm", "LlamaRMSNorm") or "RMSNorm" in cls_name: + eps = getattr(module, "eps", getattr(module, "variance_epsilon", 1e-6)) + has_weight = hasattr(module, "weight") and module.weight is not None + + if has_weight: + def _make_fwd(mod, epsilon): + def forward(x): + return diffusion_rmsnorm(x, weight=mod.weight, eps=epsilon) + return forward + module.forward = _make_fwd(module, eps) + else: + def _make_fwd_noweight(epsilon): + def forward(x): + return diffusion_rmsnorm(x, weight=None, eps=epsilon) + return forward + module.forward = _make_fwd_noweight(eps) +``` + +**Critical:** inject kernels **before** `torch.compile` and before any CPU offload is enabled. 
+ +--- + +## Step 4: Key Kernel Patterns Reference + +### Diffusion-Specific Operators + +| Operator | Kernel Pattern | Notes | +|----------|---------------|-------| +| **RMSNorm** | 2-pass row reduction + vectorized normalize | Weight may be `None` (`elementwise_affine=False`) | +| **AdaLN modulation** | `y = norm(x) * (1 + scale) + shift` | Fuse norm + scale + shift in one pass | +| **RoPE 3D** | Read `(t, h, w)` cos/sin tables, apply to `(q, k)` | Layout: `[batch, t*h*w, heads, head_dim]` | +| **GEGLU** | Split last dim → `linear * gelu(gate)` | Input `[B, L, 2*H]` → output `[B, L, H]` | +| **SiLU gate** | `out = a * sigmoid(a)` fused | Avoid separate elementwise ops | + +### Vectorized Memory Access + +```cpp +// BF16: 8 elements × 2 bytes = 16 bytes per vector load (AlignedVector<bf16_t, 8>) +// FP16: 8 elements × 2 bytes = 16 bytes (AlignedVector<fp16_t, 8>) +// FP32: 4 elements × 4 bytes = 16 bytes (AlignedVector<fp32_t, 4>) +constexpr int kVecN = 16 / sizeof(T); +using vec_t = device::AlignedVector<T, kVecN>; +``` + +### Warp / Block Reductions + +```cpp +// Warp reduction (within 32 threads) +float result = device::warp::reduce_sum(partial); + +// Block reduction via shared memory (see rmsnorm example above) +__shared__ float smem[32]; +// ...
write warp-leaders into smem, sync, reduce again in warp 0, then broadcast the total back through smem +``` + +### Thread Configuration + +```cpp +// Element-wise (RoPE, GEGLU, SiLU): simple 1D grid +constexpr uint32_t kBlock = 256; +uint32_t grid = host::div_ceil(total_elements, kBlock); +LaunchKernel(grid, kBlock, dev)(kernel, ...); + +// Row reduction (RMSNorm, LayerNorm): one block per row +uint32_t threads = std::min(hidden_size / kVecN, 512u); +threads = (threads + 31) / 32 * 32; +LaunchKernel(num_rows, threads, dev)(kernel, ...); +``` + +--- + +## Step 5: GPU Architecture Targets + +| GPU | Compute Cap | Memory BW | BF16 | Key Note | +|-----|------------|-----------|------|----------| +| H100 | sm_90 | 3.35 TB/s | Yes | Primary target; 132 SMs, 228 KB shared mem/SM | +| A100 | sm_80 | 2.0 TB/s | Yes | 108 SMs, 164 KB shared mem/SM | +| T4 | sm_75 | 320 GB/s | **No** | FP16 only; no `__nv_bfloat16` | + +If kernel requires SM90+ features (e.g., TMA, wgmma), raise a clear error: + +```python +if torch.cuda.get_device_capability()[0] < 9: + raise RuntimeError("This kernel requires SM90 (H100/Hopper) or later") +``` + +**Grid sizing for H100** (132 SMs): aim for grid multiples of 132 for good occupancy.
+ +--- + +## Step 6: Tests + +Create `python/sglang/jit_kernel/tests/test_diffusion_rmsnorm.py`: + +```python +import pytest +import torch +from sglang.jit_kernel.diffusion.rmsnorm import diffusion_rmsnorm + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("shape", [(1, 2048), (4, 3072), (16, 4096)]) +@pytest.mark.parametrize("has_weight", [True, False]) +def test_rmsnorm_correctness(dtype, shape, has_weight): + batch, hidden = shape + src = torch.randn(batch, hidden, dtype=dtype, device="cuda") + weight = torch.randn(hidden, dtype=dtype, device="cuda") if has_weight else None + + out_jit = diffusion_rmsnorm(src, weight=weight, eps=1e-6) + + # Reference: torch.nn.functional + ref = torch.nn.functional.rms_norm( + src.float(), (hidden,), weight.float() if weight is not None else None, eps=1e-6 + ).to(dtype) + + tol = {"rtol": 1e-2, "atol": 1e-2} if dtype != torch.float32 else {"rtol": 1e-5, "atol": 1e-6} + torch.testing.assert_close(out_jit, ref, **tol) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) +``` + +--- + +## Step 7: Benchmark + +Create `python/sglang/jit_kernel/benchmark/bench_diffusion_rmsnorm.py`: + +```python +import torch +import triton.testing + +from sglang.jit_kernel.benchmark.utils import DEFAULT_DEVICE, DEFAULT_DTYPE, run_benchmark +from sglang.jit_kernel.diffusion.rmsnorm import diffusion_rmsnorm + +SHAPES = [(4096, 2048), (4096, 3072), (4096, 4096)] + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["hidden"], + x_vals=[s[1] for s in SHAPES], + line_arg="provider", + line_vals=["jit_cuda", "torch"], + line_names=["SGLang JIT CUDA", "PyTorch rms_norm"], + styles=[("blue", "-"), ("red", "--")], + ylabel="us", + plot_name="diffusion-rmsnorm", + args={}, + ) +) +def benchmark(hidden: int, provider: str): + src = torch.randn(4096, hidden, dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE) + w = torch.ones(hidden, dtype=DEFAULT_DTYPE, 
device=DEFAULT_DEVICE) + + if provider == "jit_cuda": + fn = lambda: diffusion_rmsnorm(src, weight=w, eps=1e-6) + else: + fn = lambda: torch.nn.functional.rms_norm(src, (hidden,), w, eps=1e-6) + + return run_benchmark(fn) + + +if __name__ == "__main__": + benchmark.run(print_data=True) +``` + +--- + +## Step 8: Profile with Nsight Compute (required) + +After correctness + benchmarking, you must collect **Nsight Compute (ncu)** data to validate: + +- Whether the kernel reaches reasonable bandwidth/throughput (avoid false positives where it is “faster” but under-utilizes hardware) +- Whether there are clear occupancy / register / shared memory limiters + +Use the canonical docs in this directory (do not duplicate CLI details across multiple skills): + +- `diffusion-benchmark-and-profile.md` → Step 3.5 (ncu workflow, including CUDA graph profiling) +- `nsight-profiler.md` (metrics interpretation: bandwidth / occupancy / roofline / stall reasons) + +--- + +## Common Pitfalls + +| Issue | Fix | +|-------|-----| +| `RMSNorm weight is None` | Use `type(module).__name__` check; pass `None` weight explicitly | +| `isinstance(m, torch.nn.RMSNorm)` misses diffusers variants | Use `"RMSNorm" in type(m).__name__` | +| Kernel patched after `torch.compile` | Inject **before** any compile call | +| Kernel patched after `enable_model_cpu_offload()` | Inject **before** CPU offload | +| `hidden_size` not divisible by `kVecN` | Add `RuntimeCheck(hidden % kVecN == 0, ...)` in launcher | +| `torch.compile` fails with custom CUDA kernel | Register as `@torch.library.custom_op` or use Triton instead | +| T4 GPU with BF16 kernel | Gate on compute capability; T4 is `sm_75`, no native BF16 | + +--- + +## Summary of Files + +``` +python/sglang/jit_kernel/csrc/diffusion/ +└── rmsnorm.cuh # NEW: JIT CUDA kernel source + +python/sglang/jit_kernel/diffusion/ +└── rmsnorm.py # NEW: Python wrapper + load_jit + +python/sglang/jit_kernel/tests/ +└── test_diffusion_rmsnorm.py # NEW: correctness tests 
+ +python/sglang/jit_kernel/benchmark/ +└── bench_diffusion_rmsnorm.py # NEW: benchmark +``` + +--- + +## References + +### This Skill's Extended Docs (references/ and scripts/) + +| File | Contents | +|------|----------| +| [references/kernel-templates.md](references/kernel-templates.md) | Copy-paste templates: element-wise, RMSNorm, AdaLN, Python wrapper, test, benchmark | +| [references/troubleshooting.md](references/troubleshooting.md) | Build errors, perf issues, torch.compile compatibility, debugging checklist | +| [references/h100-optimization-guide.md](references/h100-optimization-guide.md) | H100 (sm_90): memory hierarchy, warp reductions, occupancy, vectorization benchmarks | +| [references/a100-optimization-guide.md](references/a100-optimization-guide.md) | A100 (sm_80): cp.async, TF32, 2:4 sparsity, H100→A100 migration checklist | +| [references/t4-optimization-guide.md](references/t4-optimization-guide.md) | T4 (sm_75): FP16 only, low bandwidth, tile size limits, memory constraints | +| [scripts/bench_diffusion_rmsnorm.py](scripts/bench_diffusion_rmsnorm.py) | Micro-benchmark: JIT CUDA RMSNorm vs PyTorch, correctness check, bandwidth analysis | +| [scripts/bench_diffusion_denoise.py](scripts/bench_diffusion_denoise.py) | End-to-end: `sglang generate` baseline vs custom kernels, comparison table | + +### SGLang Internals + +- **JIT system**: `add-jit-kernel` skill (`sglang/.claude/skills/add-jit-kernel/SKILL.md`) +- **JIT utils**: `python/sglang/jit_kernel/utils.py` — `cache_once`, `load_jit`, `make_cpp_args` +- **Abstractions**: `python/sglang/jit_kernel/include/sgl_kernel/` — `tensor.h`, `utils.cuh`, `vec.cuh`, `warp.cuh`, `math.cuh`, `tile.cuh` +- **Real csrc examples**: `python/sglang/jit_kernel/csrc/elementwise/rmsnorm.cuh`, `python/sglang/jit_kernel/csrc/elementwise/qknorm.cuh` + +### Other Diffusion Kernel Skills (this directory) + +- **Triton alternative**: `add-triton-kernel.md` — prefer Triton unless bandwidth analysis shows CUDA needed +- 
**Existing fused kernels**: `use-efficient-diffusion-kernels.md` — check here first before writing new kernels +- **Profiling**: `diffusion-benchmark-and-profile.md` — workflow to identify bottleneck before implementing +- **Nsight Compute deep dive**: `nsight-profiler.md` — full guide: occupancy analysis, roofline model, warp efficiency, kernel comparison + +### External + +- [HuggingFace kernels cuda-kernels skill](https://github.com/huggingface/kernels/tree/main/skills/cuda-kernels) — original source adapted for this skill diff --git a/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/add-triton-kernel.md b/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/add-triton-kernel.md new file mode 100644 index 0000000000000000000000000000000000000000..e6072cf769b52907a58877b5bfe121efed7f59c5 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/add-triton-kernel.md @@ -0,0 +1,512 @@ +--- +name: add-triton-kernel +description: Step-by-step guide for adding a new Triton kernel to SGLang Diffusion's jit_kernel module. Use when implementing fused elementwise ops, norm variants, RoPE variants, or any other lightweight GPU kernel for diffusion models using Triton JIT. Covers kernel authoring, autotune, torch.compile compatibility, layer integration, and tests. +--- + +# Adding a Triton Kernel to SGLang Diffusion + +This guide walks through adding a Triton kernel to `python/sglang/jit_kernel/diffusion/triton/`. +We use a fused elementwise operation as the running example: `y = x * (1 + scale) + shift` (AdaLN modulation). + +--- + +## Directory Layout + +``` +python/sglang/jit_kernel/diffusion/ +├── triton/ +│ ├── scale_shift.py # AdaLN scale/shift fused kernels +│ ├── norm.py # LayerNorm / RMSNorm fused kernels +│ ├── rmsnorm_onepass.py # One-pass RMSNorm for small hidden size +│ └── rotary.py # RoPE kernel +└── cutedsl/ + └── ... 
# CuTe DSL kernels (see use-efficient-diffusion-kernels.md) +``` + +New Triton kernels go into `triton/.py`. + +--- + +## Step 1: Write the Triton Kernel + +Create `python/sglang/jit_kernel/diffusion/triton/.py`. + +### 1a. Imports + +```python +import torch +import triton # type: ignore +import triton.language as tl # type: ignore +``` + +Always use `# type: ignore` on triton imports — the stubs are incomplete. + +### 1b. The `@triton.jit` Kernel Function + +Follow the naming convention `__kernel` (private, underscore prefix). + +```python +@triton.autotune( + configs=[ + triton.Config({"BLOCK_C": 64}, num_warps=2), + triton.Config({"BLOCK_C": 128}, num_warps=4), + triton.Config({"BLOCK_C": 256}, num_warps=4), + triton.Config({"BLOCK_C": 512}, num_warps=8), + ], + key=["C"], # re-tune when hidden dim changes +) +@triton.jit +def _fused_scale_shift_kernel( + # Pointers — always pass raw tensors; Triton takes .data_ptr() internally + x_ptr, + scale_ptr, + shift_ptr, + y_ptr, + # Dimensions + B, # batch size + L, # sequence length + C, # hidden / channel dim + # Strides — pass every stride separately; do NOT assume contiguous + stride_xb, stride_xl, stride_xc, + stride_sb, stride_sc, + stride_yb, stride_yl, stride_yc, + # Compile-time constants (tl.constexpr) + BLOCK_C: tl.constexpr, +): + # Grid: (cdiv(L, 1), B) — one program per (batch, token) + pid_l = tl.program_id(0) + pid_b = tl.program_id(1) + + c_offs = tl.arange(0, BLOCK_C) + mask = c_offs < C + + x_row = pid_b * stride_xb + pid_l * stride_xl + y_row = pid_b * stride_yb + pid_l * stride_yl + s_row = pid_b * stride_sb + + x = tl.load(x_ptr + x_row + c_offs * stride_xc, mask=mask, other=0.0) + scale = tl.load(scale_ptr + s_row + c_offs * stride_sc, mask=mask, other=0.0) + shift = tl.load(shift_ptr + s_row + c_offs * stride_sc, mask=mask, other=0.0) + + y = x * (1.0 + scale) + shift + tl.store(y_ptr + y_row + c_offs * stride_yc, y, mask=mask) +``` + +**Rules:** +- All pointer arguments are raw (Triton extracts 
`.data_ptr()` internally when called via `kernel[grid](...)`). +- Pass every stride as a separate scalar — never assume a tensor is contiguous inside the kernel. +- Use `tl.constexpr` for block sizes and boolean flags (`HAS_RESIDUAL`, `IS_RMS_NORM`, etc.). +- Use `mask=mask, other=0.0` on every `tl.load` to avoid out-of-bounds reads. +- Compute in `tl.float32` when precision matters (`x.to(tl.float32)`), then cast back to output dtype before `tl.store`. +- Use `tl.fma(a, b, c)` (`a*b + c`) for fused multiply-add — avoids rounding errors and maps to a single instruction. + +### 1c. `@triton.autotune` Guidelines + +| `key` entry | When to include | +|-------------|-----------------| +| `"C"` / `"hidden_dim"` | Always — block tile size depends on C | +| `"IS_RMS_NORM"` | When the kernel has a `constexpr` boolean flag that changes code paths | +| `"HAS_RESIDUAL"` | Same — constexpr path branching | +| Shape / batch / seq | Usually NOT — autotune cost outweighs benefit | + +Keep configs in ascending `BLOCK_C` order with matching `num_warps` (warp × 32 threads ≤ 1024). + +### 1d. `torch.compile` Compatibility + +When the kernel is called inside a `torch.compile`-d region, wrap the launch with `torch.library.wrap_triton`: + +```python +with torch.get_device_module().device(x.device): + torch.library.wrap_triton(_fused_scale_shift_kernel)[grid]( + x, scale, shift, y, + B, L, C, + x.stride(0), x.stride(1), x.stride(2), + scale.stride(0), scale.stride(1), + y.stride(0), y.stride(1), y.stride(2), + ) +``` + +Use `wrap_triton` when the kernel is called from a layer that runs under `torch.compile`. +Skip it for utility kernels called only at Python graph boundaries. + +--- + +## Step 2: Write the Python Launcher + +The launcher is a regular Python function (public, no underscore) in the same file. + +```python +def fused_scale_shift( + x: torch.Tensor, + scale: torch.Tensor, + shift: torch.Tensor, +) -> torch.Tensor: + """ + Fused AdaLN modulation: y = x * (1 + scale) + shift. 
+ + Args: + x: [B, L, C], CUDA, contiguous + scale: [B, C], CUDA + shift: [B, C], CUDA (same shape as scale) + + Returns: + y: same shape and dtype as x + """ + # --- Precondition checks --- + assert x.is_cuda, "x must be on CUDA" + assert x.is_contiguous(), "x must be contiguous" + assert scale.is_cuda and shift.is_cuda + assert x.ndim == 3, f"x must be 3D [B, L, C], got {x.shape}" + assert scale.shape == shift.shape + B, L, C = x.shape + + # Allocate output + y = torch.empty_like(x) + + # Grid: one program per token + grid = (L, B) + + _fused_scale_shift_kernel[grid]( + x, scale, shift, y, + B, L, C, + x.stride(0), x.stride(1), x.stride(2), + scale.stride(0), scale.stride(1), + y.stride(0), y.stride(1), y.stride(2), + ) + return y +``` + +**Rules:** +- Validate CUDA placement and shape/dtype **before** launching — use `assert` with a helpful message. +- Call `.contiguous()` on inputs that the kernel requires contiguous **before** the launch, not inside it. +- Allocate the output with `torch.empty_like(x)` — never reuse input buffers unless the op is explicitly in-place. +- The `grid` is a tuple or a lambda `(META)` when block sizes are auto-tuned: + +```python +# Static grid (block size fixed) +grid = (triton.cdiv(L, BLOCK_L), triton.cdiv(C, BLOCK_C), B) + +# Dynamic grid (block size comes from autotune) +grid = lambda META: (triton.cdiv(L, META["BLOCK_C"]), B) +``` + +### Handling Non-Contiguous Inputs + +Never call `.contiguous()` silently — it copies data. Instead, pass strides to the kernel and let it handle arbitrary layouts. 
Only call `.contiguous()` when the kernel genuinely requires it (e.g., after a reshape): + +```python +# OK: reshape + contiguous needed for 2D view trick +if not x.is_contiguous(): # view only works on contiguous + x = x.contiguous() +x_2d = x.view(B * L, C) +``` + +--- + +## Step 3: Integrate into the Layer + +Call the new kernel from the appropriate layer file in +`python/sglang/multimodal_gen/runtime/layers/` (typically `layernorm.py` or `elementwise.py`). + +```python +# In layernorm.py or elementwise.py +import torch + +def apply_scale_shift(x, scale, shift): + if x.is_cuda: + from sglang.jit_kernel.diffusion.triton.my_op import fused_scale_shift + return fused_scale_shift(x, scale, shift) + # Pure-PyTorch fallback for non-CUDA execution + # (broadcast scale/shift from [B, C] to [B, L, C]) + return x * (1.0 + scale.unsqueeze(1)) + shift.unsqueeze(1) +``` + +**Rules:** +- Gate on `x.is_cuda` — the Triton kernel only runs on CUDA; the fallback handles everything else. +- The launcher raises `AssertionError` on invalid inputs (wrong shape, CPU tensor, etc.) — do **not** silently catch these. Let them propagate so bugs are visible during development. +- Add `logger.warning_once(...)` only when falling back due to a **known hardware limitation** (e.g., unsupported SM compute capability), not for wrong-input errors. + +--- + +## Step 4: Write Tests + +Create `python/sglang/jit_kernel/tests/test_.py`.
+ +```python +import pytest +import torch + +from sglang.jit_kernel.diffusion.triton.my_op import fused_scale_shift + + +def _ref_fused_scale_shift(x, scale, shift): + """PyTorch reference implementation.""" + # Broadcast scale/shift from [B, C] to [B, L, C] + return x * (1.0 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +@pytest.fixture(autouse=True) +def require_cuda(): + if not torch.cuda.is_available(): + pytest.skip("CUDA required") + + +@pytest.mark.parametrize("B,L,C", [ + (1, 6, 3072), # Qwen (small batch) + (1, 1024, 1536), # Wan + (2, 512, 3072), # typical training shape + (1, 1, 256), # edge: L=1 +]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +def test_fused_scale_shift_correctness(B, L, C, dtype): + torch.manual_seed(0) + x = torch.randn(B, L, C, dtype=dtype, device="cuda") + scale = torch.randn(B, C, dtype=dtype, device="cuda") * 0.1 + shift = torch.randn(B, C, dtype=dtype, device="cuda") * 0.1 + + out = fused_scale_shift(x, scale, shift) + ref = _ref_fused_scale_shift(x.float(), scale.float(), shift.float()).to(dtype) + + atol = 1e-5 if dtype == torch.float32 else 1e-2 + torch.testing.assert_close(out, ref, atol=atol, rtol=atol, + msg=f"Mismatch at B={B} L={L} C={C} dtype={dtype}") + + +def test_fused_scale_shift_non_cuda_raises(): + x = torch.randn(1, 4, 64) + scale = torch.randn(1, 64) + shift = torch.randn(1, 64) + with pytest.raises(AssertionError, match="CUDA"): + fused_scale_shift(x, scale, shift) + + +def test_fused_scale_shift_output_dtype_preserved(): + x = torch.randn(1, 8, 128, dtype=torch.bfloat16, device="cuda") + scale = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda") + shift = torch.zeros(1, 128, dtype=torch.bfloat16, device="cuda") + out = fused_scale_shift(x, scale, shift) + assert out.dtype == torch.bfloat16 + assert out.shape == x.shape + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) +``` + +Run: + +```bash +pytest python/sglang/jit_kernel/tests/test_.py -v +``` + 
+**Test coverage requirements:** +1. Reference comparison against pure-PyTorch for all supported dtypes (fp16, bf16, fp32). +2. Edge shapes: `L=1`, `C` not a multiple of the largest BLOCK_C, large `B`. +3. Error cases: CPU tensor, wrong shape. +4. Output dtype and shape preservation. + +--- + +## Step 5: Add a Benchmark (required) + +Create `python/sglang/jit_kernel/benchmark/bench_.py`. + +```python +import torch +import triton.testing + +from sglang.jit_kernel.diffusion.triton.my_op import fused_scale_shift + + +SHAPES = [ + # (B, L, C) — representative diffusion shapes + (1, 6, 3072), # Qwen image + (1, 1024, 1536), # Wan video + (1, 4096, 3072), # FLUX double-stream +] + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["B", "L", "C"], + x_vals=SHAPES, + line_arg="provider", + line_vals=["triton", "torch"], + line_names=["Triton Fused", "PyTorch"], + styles=[("blue", "-"), ("red", "--")], + ylabel="µs (median)", + plot_name="fused-scale-shift", + args={}, + ) +) +def benchmark(B, L, C, provider): + dtype = torch.bfloat16 + x = torch.randn(B, L, C, dtype=dtype, device="cuda") + scale = torch.randn(B, C, dtype=dtype, device="cuda") + shift = torch.randn(B, C, dtype=dtype, device="cuda") + + if provider == "triton": + fn = lambda: fused_scale_shift(x, scale, shift) + else: + fn = lambda: x * (1.0 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + ms, *_ = triton.testing.do_bench_cudagraph(fn, quantiles=[0.5, 0.2, 0.8]) + return ms * 1000 # µs + + +if __name__ == "__main__": + benchmark.run(print_data=True) +``` + +Run: + +```bash +python python/sglang/jit_kernel/benchmark/bench_.py +``` + +--- + +## Step 6: Profile with Nsight Compute (required for optimization work) + +After correctness tests, you must use **ncu (Nsight Compute)** to validate hardware efficiency (bandwidth/throughput/occupancy/bottleneck type). + +To avoid duplicating ncu CLI details across multiple skills, this skill does not repeat command flags. 
Follow the canonical docs: + +- `diffusion-benchmark-and-profile.md` → Step 3.5 (ncu workflow, including CUDA graph profiling) +- `nsight-profiler.md` (metrics interpretation: bandwidth / occupancy / roofline / warp stalls) + +--- + +## Common Patterns Reference + +### Pattern 1: Autotune over a 2D tile (L × C) + +Used in `scale_shift.py` (`fuse_scale_shift_kernel_blc_opt`): + +```python +@triton.jit +def _kernel(..., BLOCK_L: tl.constexpr, BLOCK_C: tl.constexpr): + pid_l = tl.program_id(0) + pid_c = tl.program_id(1) + pid_b = tl.program_id(2) + l_offs = pid_l * BLOCK_L + tl.arange(0, BLOCK_L) + c_offs = pid_c * BLOCK_C + tl.arange(0, BLOCK_C) + mask = (l_offs[:, None] < L) & (c_offs[None, :] < C) + ... + +# Launch: +grid = (triton.cdiv(L, BLOCK_L), triton.cdiv(C, BLOCK_C), B) +_kernel[grid](..., BLOCK_L=block_l, BLOCK_C=block_c, num_warps=4, num_stages=2) +``` + +### Pattern 2: One-pass RMSNorm for small hidden size + +Used in `rmsnorm_onepass.py`: + +```python +@triton.jit +def _rms_norm_tiled_onepass(y_ptr, x_ptr, w_ptr, + SEQ: tl.constexpr, DIM: tl.constexpr, EPS: tl.constexpr, + BLOCK_SIZE_SEQ: tl.constexpr, BLOCK_SIZE_DIM: tl.constexpr): + seq_blk_id = tl.program_id(0) + seq_id = seq_blk_id * BLOCK_SIZE_SEQ + seq_offset = seq_id + tl.arange(0, BLOCK_SIZE_SEQ)[:, None] + d_offset = tl.arange(0, BLOCK_SIZE_DIM)[None, :] + ... + x = tl.load(x_ptr + seq_offset * DIM + d_offset, mask=..., other=0.0).to(tl.float32) + mean_sq = tl.sum(x * x, axis=1, keep_dims=True) / DIM + rstd = tl.math.rsqrt(mean_sq + EPS) + tl.store(y_ptr + ..., x * rstd * w, mask=...) 
+ +# Launch with wrap_triton for torch.compile compat: +with torch.get_device_module().device(x.device): + torch.library.wrap_triton(_rms_norm_tiled_onepass)[grid]( + y_view, x_view, w, + S, D, eps, + BLOCK_SIZE_DIM=triton.next_power_of_2(D), + BLOCK_SIZE_SEQ=BLOCK_SIZE_SEQ, + ) +``` + +### Pattern 3: `tl.constexpr` boolean flags for conditional paths + +Used in `norm.py` and `scale_shift.py`: + +```python +@triton.jit +def _kernel(..., + IS_RMS_NORM: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + SCALE_IS_SCALAR: tl.constexpr): + ... + if IS_RMS_NORM: + var = tl.sum(x * x, axis=0) / N + else: + mean = tl.sum(x, axis=0) / N + var = tl.sum((x - mean) ** 2, axis=0) / N + + if HAS_RESIDUAL: + x = x + tl.load(residual_ptr + ...) + + if SCALE_IS_SCALAR: + scale_val = tl.load(scale_ptr) + scale = tl.full([BLOCK_N], scale_val, dtype=scale_val.dtype) + else: + scale = tl.load(scale_ptr + col_offsets, mask=mask, other=0.0) +``` + +Autotune key must include these booleans so the compiler generates separate specializations. 
+ +### Pattern 4: Computing in fp32, storing in original dtype + +Always up-cast to `tl.float32` for reductions and math, then down-cast before storing: + +```python +x_f32 = x.to(tl.float32) +scale_f32 = scale.to(tl.float32) +y_f32 = x_f32 * (1.0 + scale_f32) + shift_f32 +tl.store(y_ptr + offsets, y_f32.to(x.dtype), mask=mask) +``` + +--- + +## Checklist Before Submitting + +### Prerequisites +- [ ] `ncu --version` prints a valid Nsight Compute version (required for Step 6 profiling) + +### Implementation +- [ ] Kernel file at `python/sglang/jit_kernel/diffusion/triton/.py` +- [ ] All pointer arguments passed with separate stride scalars +- [ ] Every `tl.load` uses `mask=` and `other=` +- [ ] Autotune `key` includes all `constexpr` flags that change code paths +- [ ] `torch.library.wrap_triton` used if kernel runs inside `torch.compile` region +- [ ] PyTorch fallback path in the layer integration (see Step 3) + +### Validation +- [ ] Tests pass: `pytest python/sglang/jit_kernel/tests/test_.py -v` +- [ ] Benchmark runs: `python python/sglang/jit_kernel/benchmark/bench_.py` +- [ ] **Correctness verified**: Triton output matches PyTorch reference within tolerance +- [ ] Nsight Compute profile collected (`ncu --set full`); achieved occupancy ≥ 50% and memory throughput ≥ 70% of peak (or bottleneck documented) + +--- + +## Summary of Files Created/Modified + +``` +python/sglang/jit_kernel/diffusion/triton/.py # NEW: Triton kernel + launcher +python/sglang/jit_kernel/tests/test_.py # NEW: correctness tests +python/sglang/jit_kernel/benchmark/bench_.py # NEW: performance benchmark +python/sglang/multimodal_gen/runtime/layers/layernorm.py # MODIFIED: integrate into layer + (or elementwise.py, depending on op type) +``` + +## References + +- `python/sglang/jit_kernel/diffusion/triton/scale_shift.py` — 2D tile pattern, scalar broadcast, 4D shape handling +- `python/sglang/jit_kernel/diffusion/triton/rmsnorm_onepass.py` — `wrap_triton`, tiled one-pass reduction +-
`python/sglang/jit_kernel/diffusion/triton/norm.py` — complex autotune with many `constexpr` flags +- `python/sglang/jit_kernel/diffusion/triton/rotary.py` — per-head grid, interleaved RoPE +- `nsight-profiler.md` — full Nsight Compute guide: occupancy analysis, roofline model, warp efficiency, kernel comparison +- `diffusion-benchmark-and-profile.md` — how to verify the kernel's impact on denoise latency +- `use-efficient-diffusion-kernels.md` — overview of existing fused kernel entry points diff --git a/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/diffusion-benchmark-and-profile.md b/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/diffusion-benchmark-and-profile.md new file mode 100644 index 0000000000000000000000000000000000000000..653688017db6daa629e574b1133e56d99c2dea00 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/diffusion-benchmark-and-profile.md @@ -0,0 +1,478 @@ +--- +name: diffusion-benchmark-and-profile +description: Denoise-stage benchmark and per-layer kernel profiling guide for SGLang Diffusion models. Use when measuring denoising latency, profiling DiT kernel breakdown with torch.profiler or nsys+gputrc2graph.py, investigating performance bottlenecks, or optimizing with custom Triton/CUDA kernels. Always verify output correctness before and after any optimization. +--- + +# SGLang Diffusion Benchmark and Profile Guide + +**Primary Metric: Denoise Latency** +The denoising loop latency — total DiT forward pass time across all inference steps — is the dominant cost (>80% of end-to-end) and the **sole optimization target** for kernel work. End-to-end latency is recorded as a secondary check only. + +> **Correctness First**: Faster but incorrect output is not an improvement. Always compare generated images/videos against a reference baseline before and after any change. 
+ +--- + +## Prerequisites + +```bash +#!/usr/bin/env bash +# Quick dependency check: first argument is the label, the rest is the command to run +check() { local label="$1"; shift; "$@" &>/dev/null && echo "[OK] $label" || echo "[MISS] $label"; } +check "sglang" python3 -c "import sglang" +check "torch+CUDA" python3 -c "import torch; assert torch.cuda.is_available()" +check "torch.profiler" python3 -c "import torch.profiler" +check "nsys (Level 2)" which nsys +check "pandas" python3 -c "import pandas" +check "plotly" python3 -c "import plotly" +``` + +**Minimum for benchmarking**: `sglang`, `torch` with CUDA. +**Level 1 profiling**: `torch.profiler` (bundled with torch). +**Level 2 profiling**: `nsys`, `pandas`, `plotly` + `gputrc2graph.py` from the sglang repo. + +Download input images required by some models: +```bash +mkdir -p /workspace/gen_benchmark/figs +wget -O /workspace/gen_benchmark/figs/cat.png \ + https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png +wget -O /workspace/gen_benchmark/figs/astronaut.jpg \ + https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg +``` + +--- + +## Benchmark Commands + +All commands include `--warmup` and `--enable-torch-compile` for real production performance. Add `--perf-dump-path .json` for machine-readable output. + +### Perf dump & before/after compare + +For every benchmark run, always write a perf dump JSON: + +```bash +sglang generate ... --warmup --perf-dump-path .json +``` + +Before/after comparison (outputs a Markdown table suitable for PR descriptions): + +```bash +# Baseline (on main branch or before changes) +sglang generate ... --warmup --perf-dump-path baseline.json + +# New (after changes) +sglang generate ...
--warmup --perf-dump-path new.json + +python python/sglang/multimodal_gen/benchmarks/compare_perf.py baseline.json new.json +``` + +### Qwen-Image-2512 (1024×1024, 50 steps) +```bash +sglang generate \ + --model-path=Qwen/Qwen-Image-2512 \ + --prompt="A futuristic cyberpunk city at night, neon lights reflecting on wet streets, highly detailed, 8k" \ + '--negative-prompt= ' \ + --width=1024 --height=1024 --num-inference-steps=50 --guidance-scale=4.0 \ + --seed=42 --save-output --enable-torch-compile --warmup \ + --dit-cpu-offload false --text-encoder-cpu-offload false +``` + +### Qwen-Image-Edit-2511 (image editing, 1024×1024, 50 steps) +```bash +sglang generate \ + --model-path=Qwen/Qwen-Image-Edit-2511 \ + '--prompt=Transform into anime style' '--negative-prompt= ' \ + --image-path=/workspace/gen_benchmark/figs/cat.png \ + --width=1024 --height=1024 --num-inference-steps=50 --guidance-scale=4.0 \ + --seed=42 --save-output --enable-torch-compile --warmup \ + --dit-cpu-offload false --text-encoder-cpu-offload false +``` + +### FLUX.1-dev (1024×1024, 50 steps) +```bash +sglang generate \ + --model-path=black-forest-labs/FLUX.1-dev \ + --prompt="A futuristic cyberpunk city at night, neon lights reflecting on wet streets, highly detailed, 8k" \ + --width=1024 --height=1024 --num-inference-steps=50 --guidance-scale=4.0 \ + --seed=42 --save-output --enable-torch-compile --warmup +``` + +### FLUX.2-dev (1024×1024) +```bash +sglang generate \ + --model-path black-forest-labs/FLUX.2-dev \ + --prompt "A Logo With Bold Large Text: SGL Diffusion" \ + --width=1024 --height=1024 \ + --dit-layerwise-offload false --enable-torch-compile --warmup \ + --dit-cpu-offload false --text-encoder-cpu-offload true --vae-cpu-offload false +``` + +### Z-Image-Turbo (1024×1024, 9 steps) +```bash +sglang generate \ + --model-path=Tongyi-MAI/Z-Image-Turbo \ + --prompt='A fantasy landscape with mountains and a river, detailed, vibrant colors' \ + --width=1024 --height=1024 --num-inference-steps=9 
--guidance-scale=0.0 \ + --seed=42 --save-output --enable-torch-compile --warmup \ + --dit-cpu-offload false --text-encoder-cpu-offload false +``` + +### Wan2.2-T2V-A14B 720P (8 GPUs, 81 frames, 40 steps) +```bash +sglang generate \ + --model-path=Wan-AI/Wan2.2-T2V-A14B-Diffusers \ + --prompt="A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon." \ + --negative-prompt=" " --720p --num-inference-steps=40 --num-frames=81 \ + --guidance-scale=5.0 --seed=42 --save-output \ + --num-gpus=8 --enable-cfg-parallel --ulysses-degree=4 \ + --dit-layerwise-offload true --dit-cpu-offload false \ + --vae-cpu-offload false --text-encoder-cpu-offload true \ + --warmup --enable-torch-compile +``` + +### Wan2.2-TI2V-5B 720P (single GPU, 81 frames, 50 steps) +```bash +sglang generate \ + --model-path Wan-AI/Wan2.2-TI2V-5B-Diffusers \ + --prompt "An astronaut hatching from an egg, on the surface of the moon..." \ + --negative-prompt "Bright tones, overexposed, static, blurred details..." \ + --image-path=/workspace/gen_benchmark/figs/astronaut.jpg \ + --num-frames 81 --720p --num-inference-steps 50 --guidance-scale 5.0 \ + --seed 42 --save-output \ + --dit-layerwise-offload false --dit-cpu-offload false \ + --vae-cpu-offload false --text-encoder-cpu-offload false \ + --enable-torch-compile --warmup +``` + +**Key metrics** (all models): denoise latency ★, end-to-end latency, peak GPU memory. 
+ +--- + +## Performance Bottleneck Workflow + +### Step 1: Identify the Slow DiT Operation + +Add `--log-level=info` and observe: +- **Denoise loop latency** ★ — primary target +- Per-step DiT latency — denoise ÷ steps + +### Step 2: Profile with torch.profiler (Level 1) + +```bash +SGLANG_TORCH_PROFILER_DIR=/workspace/profiles \ +sglang generate \ + --model-path=black-forest-labs/FLUX.1-dev \ + --prompt="A futuristic cyberpunk city at night" \ + --width=1024 --height=1024 --num-inference-steps=50 \ + --seed=42 --enable-torch-compile --warmup \ + --profile --num-profiled-timesteps 3 +``` + +Parse the trace without a browser: +```python +import gzip, json, collections, glob, os + +log_dir = os.environ.get("SGLANG_TORCH_PROFILER_DIR", "./logs") +trace_path = sorted(glob.glob(f"{log_dir}/*.trace.json.gz"), key=os.path.getmtime, reverse=True)[0] +with gzip.open(trace_path, "rb") as f: + data = json.loads(f.read()) + +cuda_ops = collections.defaultdict(lambda: {"total_us": 0, "count": 0}) +for e in data.get("traceEvents", []): + if e.get("cat") in ("kernel", "gpu_memcpy") and "dur" in e: + cuda_ops[e.get("name","unknown")]["total_us"] += e["dur"] + cuda_ops[e.get("name","unknown")]["count"] += 1 + +print(f"{'Kernel':<80} {'Total(ms)':>10} {'Count':>6}") +for name, s in sorted(cuda_ops.items(), key=lambda x: -x[1]["total_us"])[:30]: + print(f"{name:<80} {s['total_us']/1000:>10.3f} {s['count']:>6}") +``` + +Add `record_function` scopes in the DiT block for per-layer attribution: +```python +with torch.profiler.record_function(f"dit_block_{idx}.attn"): + x = self.attn(x) +with torch.profiler.record_function(f"dit_block_{idx}.norm"): + x = self.norm(x) +``` + +**Expected dominant kernels per DiT sub-component:** + +| Sub-component | Expected kernel | +|--------------|-----------------| +| QKV / output / MLP projections | `cutlass_gemm` / `ampere_*_gemm` | +| Attention | `flash_attn_fwd` / `fmha_*` (FA3/FA4) | +| AdaLN modulation | `fuse_scale_shift_kernel` | +| RMSNorm / 
LayerNorm | `sgl_kernel_rmsnorm` / Triton norm | +| SiLU gate | `vectorized_elementwise_kernel` | +| RoPE | `apply_rotary_embedding` (Triton) | +| QK Norm | `fused_inplace_qknorm` (JIT) | + +### Step 3: Deep CUDA Kernel Breakdown (Level 2 — nsys) + +```bash +# Pass A — collect nsys trace (skip warmup with --delay) +nsys profile -t cuda -o /workspace/profiles/flux_dev -f true \ + --trace-fork-before-exec=true --delay 120 --duration 60 \ + sglang generate \ + --model-path=black-forest-labs/FLUX.1-dev \ + --prompt="A futuristic cyberpunk city at night" \ + --width=1024 --height=1024 --num-inference-steps=50 \ + --seed=42 --enable-torch-compile --warmup + +# Pass B — measure wall-clock time without profiling +time sglang generate --model-path=black-forest-labs/FLUX.1-dev \ + --width=1024 --height=1024 --num-inference-steps=50 --seed=42 \ + --enable-torch-compile --warmup +# Record ELAPSED_SEC from Pass B +``` + +Create classification JSON at `examples/profiler/nsys_profile_tools/sglang_diffusion_engine_model.json`: +```json +{ + "sglang": { + "diffusion": { + "gemm|nvjet|cutlass": "gemm", + "flash|fmha|fwd_flash": "attn", + "fuse_scale_shift|scale_shift_gate": "adaln_modulation", + "_norm_|Norm|rmsnorm|fused_add_rmsnorm": "norm", + "rotary|rope": "rope", + "act_and_mul|silu|gelu": "activation", + "ncclDevKernel|all_gather|all_reduce": "nccl_comm", + "triton": "triton_kernel", + "CUDA mem": "non-gpu-H_D_memops", + ".*": "misc" + } + } +} +``` + +Run analysis: +```bash +cd examples/profiler/nsys_profile_tools +python3 gputrc2graph.py \ + --in_file /workspace/profiles/flux_dev.nsys-rep,sglang,diffusion,ELAPSED_SEC \ + --out_dir /workspace/profiles/analysis \ + --title "FLUX.1-dev denoise kernel breakdown" + +# Read results +python3 - << 'EOF' +import pandas as pd +df = pd.read_csv("/workspace/profiles/analysis/result.csv") +summary = df.groupby("Category")["Elapsed Time (sec)"].sum().sort_values(ascending=False) +total = summary.sum() +for cat, sec in summary.items(): + 
print(f"{cat:<30} {sec:>8.3f}s ({sec/total*100:>5.1f}%)") +EOF +``` + +**What the category breakdown tells you:** + +| Category high | Investigation | +|--------------|---------------| +| `gemm` dominant | Check tensor parallelism; QKV/MLP bottleneck | +| `attn` dominant | Verify FA3/FA4 is active | +| `adaln_modulation` high | Verify fused `fuse_scale_shift_kernel` is used | +| `norm` high | Verify `sgl_kernel_rmsnorm` / CuTe DSL path; check D alignment | +| `nccl_comm` high | Multi-GPU: tune Ulysses degree | +| `triton_kernel` high | Identify which Triton kernel; consider CUDA replacement | +| `non-gpu-H_D_memops` high | Accidental CPU offload or `.cpu()` calls mid-denoising | +| `CPU(non-GPU)` high | Python dispatch overhead / torch.compile graph breaks | + +### Step 3.5: Per-Kernel Deep Analysis (Level 3 — ncu) + +**CRITICAL**: `ncu` (Nsight Compute) is the essential tool for kernel-level optimization. While nsys and torch.profiler tell you **which** kernels are slow, only ncu tells you **why** — memory bandwidth utilization, compute throughput, occupancy limiters, warp stall reasons, and roofline position. **Always use ncu when optimizing or writing custom kernels.** + +#### When to use ncu + +- After writing a new Triton or CUDA kernel — verify it saturates hardware bandwidth +- When a kernel shows up as a top bottleneck in Level 1/2 profiling +- When comparing your fused kernel vs PyTorch baseline or torch.compile output +- When tuning Triton autotune configs (block sizes, num_warps) + +#### Basic ncu workflow + +```bash +# 1. Profile a specific kernel by name (skip warmup launches, collect 3 invocations) +ncu --kernel-name "_fused_gated_residual_add_kernel" \ + --launch-skip 10 --launch-count 3 \ + --set full \ + -o /workspace/ncu_reports/gated_residual \ + sglang generate \ + --model-path=black-forest-labs/FLUX.1-dev \ + --prompt="test" --width=1024 --height=1024 \ + --num-inference-steps=5 --seed=42 + +# 2. 
Profile all kernels in a short run (use few steps to limit time) +ncu --launch-skip 50 --launch-count 200 \ + --set full \ + -o /workspace/ncu_reports/all_kernels \ + sglang generate \ + --model-path=black-forest-labs/FLUX.1-dev \ + --prompt="test" --width=1024 --height=1024 \ + --num-inference-steps=3 --seed=42 + +# 3. For CUDA graph mode, use --graph-profiling=node to profile inside the graph +ncu --graph-profiling node \ + --kernel-name "_fused_gated_residual_add_kernel" \ + --launch-skip 5 --launch-count 3 \ + --set full \ + -o /workspace/ncu_reports/gated_residual_cudagraph \ + sglang generate \ + --model-path=black-forest-labs/FLUX.1-dev \ + --prompt="test" --width=1024 --height=1024 \ + --num-inference-steps=5 --seed=42 \ + --enable-piecewise-cuda-graph +``` + +#### Reading ncu results (CLI, no GUI needed) + +```bash +# Summary of all profiled kernels +ncu --import /workspace/ncu_reports/gated_residual.ncu-rep --page raw --csv 2>/dev/null | head -50 + +# Key metrics to extract: +ncu --import /workspace/ncu_reports/gated_residual.ncu-rep \ + --page details --csv 2>/dev/null | python3 -c " +import csv, sys +reader = csv.DictReader(sys.stdin) +key_metrics = [ + 'gpu__time_duration.avg', # Kernel duration + 'sm__throughput.avg.pct_of_peak_sustained_elapsed', # SM utilization + 'dram__throughput.avg.pct_of_peak_sustained_elapsed', # DRAM bandwidth util + 'l1tex__throughput.avg.pct_of_peak_sustained_elapsed', # L1 throughput + 'sm__warps_active.avg.pct_of_peak_sustained_active', # Achieved occupancy + 'launch__occupancy_limit_registers', # Occupancy limiter + 'launch__occupancy_limit_shared_mem', +] +for row in reader: + name = row.get('Metric Name', '') + if any(m in name for m in key_metrics): + print(f'{name:<60} {row.get(\"Metric Value\",\"\")}') +" +``` + +#### Interpreting ncu results for kernel optimization + +| Metric | Good | Action if bad | +|--------|------|--------------| +| DRAM throughput > 80% peak | Memory-bound, near optimal | Already saturating 
HBM — fuse with adjacent ops to reduce total memory traffic | +| DRAM throughput < 50% peak | Not saturating memory bandwidth | Check coalescing, increase vector width, tune BLOCK sizes | +| SM throughput > 60% peak | Compute-bound, near optimal | Reduce arithmetic, use faster instructions (e.g., FMA) | +| SM throughput < 30% peak | Underutilized compute | Increase occupancy, reduce warp stalls, check instruction mix | +| Achieved occupancy > 50% | Acceptable for most kernels | — | +| Achieved occupancy < 25% | Too few active warps | Reduce register pressure or shared memory; increase block size | + +#### Comparing before/after with ncu + +```bash +# Profile baseline kernel +ncu --kernel-name "vectorized_elementwise_kernel" \ + --launch-skip 10 --launch-count 3 --set full \ + -o /workspace/ncu_reports/baseline ./program + +# Profile optimized kernel +ncu --kernel-name "_fused_gated_residual_add_kernel" \ + --launch-skip 10 --launch-count 3 --set full \ + -o /workspace/ncu_reports/optimized ./program + +# Compare key metrics +for report in baseline optimized; do + echo "=== $report ===" + ncu --import /workspace/ncu_reports/${report}.ncu-rep \ + --page details --csv 2>/dev/null | grep -E "time_duration|throughput.*pct|occupancy" +done +``` + +**Decision rule after ncu analysis:** +- Kernel already at >80% DRAM bandwidth → fuse with neighbors to reduce total traffic +- Kernel at <50% DRAM bandwidth → tune block sizes, fix coalescing, increase vectorization +- Kernel compute-bound (SM util high, DRAM low) → reduce FLOPs or switch to a faster algorithm +- Low occupancy → reduce registers (simplify kernel) or increase block size in autotune configs + +### Step 4: Apply Kernel Optimization + +After pinpointing the slow op, choose the right tool: + +| Scenario | Skill to use | +|----------|-------------| +| New fused elementwise, norm variant, RoPE variant | **`add-triton-kernel.md`** — Triton JIT, faster iteration, NPU fallback | +| Bandwidth-bound reduction (RMSNorm) 
needing max vectorization | **`add-cuda-kernel.md`** — CUDA JIT with `AlignedVector`, warp reductions | +| Attention or tile-based op needing shared memory tuning | **`add-cuda-kernel.md`** — full control over CUDA primitives | +| Slow op already covered by existing fused kernel | **`use-efficient-diffusion-kernels.md`** — check constraints & enable | + +**Quick decision rule**: start with Triton. Switch to CUDA JIT only when profiling shows Triton can't saturate hardware bandwidth. + +Both kernel types use SGLang's JIT compilation: +- **Triton**: `python/sglang/jit_kernel/diffusion/triton/<op_name>.py` +- **CUDA JIT**: `python/sglang/jit_kernel/csrc/diffusion/<op_name>.cuh` + wrapper `python/sglang/jit_kernel/diffusion/<op_name>.py` + +### Step 5: torch.compile Coverage + +```bash +TORCH_COMPILE_DEBUG=1 sglang generate ... +``` +- Dynamic shape changes trigger recompilation → fix resolution and frame count when benchmarking +- `tensor.item()` in conditional branches causes graph breaks → rewrite as tensor ops + +### Step 6: Multi-GPU Efficiency (Wan2.2-T2V-A14B) + +- Verify `--ulysses-degree` evenly divides `--num-gpus` +- Confirm `--enable-cfg-parallel` is active (requires `guidance_scale > 1`) +- `--dit-layerwise-offload true` introduces CPU↔GPU transfer overhead; disable when memory permits + +--- + +## Optimization Workflow Summary + +``` +0. BASELINE + sglang generate --seed=42 --save-output → save reference images/videos + ↓ +1. BENCHMARK + Run benchmark commands above → record denoise latency baseline + ↓ +2. LEVEL 1 PROFILE (torch.profiler) + --profile --num-profiled-timesteps 3 + → parse .trace.json.gz → rank ops by CUDA time + → identify slow DiT layer (norm / attn / mlp / rope / adaln) + ↓ +3. LEVEL 2 PROFILE (nsys + gputrc2graph.py) + → result.csv category breakdown (gemm / attn / adaln / norm / triton / cpu) + → confirm where GPU time is concentrated + ↓ +4. 
LEVEL 3 PROFILE (ncu — per-kernel deep analysis) ★ CRITICAL + → ncu --set full on target kernel(s) + → extract DRAM bandwidth util, SM throughput, achieved occupancy + → determine if kernel is memory-bound, compute-bound, or latency-bound + → for CUDA graph: use --graph-profiling node + ↓ +5. KERNEL OPTIMIZATION + Existing fused kernel? → use-efficient-diffusion-kernels.md + New Triton kernel? → add-triton-kernel.md + New CUDA JIT kernel? → add-cuda-kernel.md + After writing kernel → ncu again to verify bandwidth/occupancy ★ + ↓ +6. VERIFY CORRECTNESS + sglang generate --seed=42 --save-output → diff against reference + If output differs beyond tolerance → reject optimization + ↓ +7. RE-BENCHMARK + Verify denoise latency improvement; no regression on other models +``` + +--- + +## Checklist Before Merging + +### Correctness (must pass first) +- [ ] Reference outputs collected with `--seed=42 --save-output` **before** any change +- [ ] After change: regenerate with identical args and compare +- [ ] No visible quality degradation in generated images / videos +- [ ] Correctness verified on all benchmark models + +### Performance (only after correctness passes) +- [ ] All benchmark models executed; denoise latency ★, end-to-end, peak memory recorded +- [ ] No regression in denoise latency vs. 
previous baseline (±2% tolerance) +- [ ] New kernel shows measurable improvement on at least 2 models +- [ ] No new torch.compile graph breaks introduced +- [ ] Results reproducible with all offloads disabled and fixed `--seed=42` diff --git a/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/nsight-profiler.md b/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/nsight-profiler.md new file mode 100644 index 0000000000000000000000000000000000000000..0c0b9e8578b63dcde1658728b29ebd7857912df1 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/nsight-profiler.md @@ -0,0 +1,277 @@ +--- +name: nsight-profiler +description: Expert skill for NVIDIA Nsight Systems and Nsight Compute profiling tools. Configure profiling sessions, analyze kernel reports, interpret occupancy metrics, roofline model data, memory bandwidth bottlenecks, and warp execution efficiency. +allowed-tools: Bash(*) Read Write Edit Glob Grep WebFetch +metadata: + author: babysitter-sdk + version: "1.0.0" + category: performance-profiling + backlog-id: SK-002 + source: "Adapted from https://github.com/lobehub/lobehub (.agents/skills/nsight-profiler)" +--- + +> **Source**: This skill is adapted from the [lobehub/lobehub](https://github.com/lobehub/lobehub) open-source repository (`.agents/skills/nsight-profiler`). Original author: `babysitter-sdk`. + +# nsight-profiler + +You are **nsight-profiler** - a specialized skill for NVIDIA Nsight Systems and Nsight Compute profiling tools. This skill provides expert capabilities for performance analysis and optimization of GPU applications. 
+ +## Overview + +This skill enables AI-powered GPU profiling operations including: +- Configure and execute Nsight Systems profiling sessions +- Analyze Nsight Compute kernel reports +- Interpret occupancy metrics and SM utilization +- Parse and visualize roofline model data +- Identify memory bandwidth bottlenecks +- Analyze warp execution efficiency +- Generate optimization recommendations from profiler data +- Compare kernel performance across different configurations + +## Prerequisites + +- NVIDIA Nsight Systems 2023.1+ +- NVIDIA Nsight Compute 2023.1+ +- CUDA Toolkit 11.0+ +- GPU with compute capability 7.0+ (for full profiling features) + +## Capabilities + +### 1. Nsight Systems Profiling + +System-wide performance analysis: + +```bash +# Basic system profile +nsys profile -o report ./cuda_program + +# Profile with CUDA API tracing +nsys profile -t cuda,nvtx,osrt -o report ./cuda_program + +# Capture GPU metrics +nsys profile --gpu-metrics-device=all -o report ./cuda_program + +# Profile specific duration +nsys profile -d 10 -o report ./cuda_program + +# Export to multiple formats (one type per command) +nsys export -t sqlite report.nsys-rep +nsys export -t json report.nsys-rep + +# Generate summary statistics +nsys stats report.nsys-rep +``` + +### 2. Nsight Compute Profiling + +Detailed kernel analysis: + +```bash +# Profile all kernels +ncu -o profile ./cuda_program + +# Profile specific kernel +ncu --kernel-name myKernel -o profile ./cuda_program + +# Full metric collection +ncu --set full -o profile ./cuda_program + +# Roofline analysis +ncu --set roofline -o profile ./cuda_program + +# Memory analysis +ncu --section MemoryWorkloadAnalysis -o profile ./cuda_program + +# Compare two runs +ncu --import baseline.ncu-rep --diff ./cuda_program +``` + +### 3. 
Occupancy Analysis + +Analyze and optimize occupancy: + +```bash +# Collect occupancy metrics +ncu --section Occupancy -o occupancy ./cuda_program + +# Key metrics to analyze: +# - Achieved Occupancy +# - Theoretical Occupancy +# - Block Limit (registers, shared memory, warps) +# - Occupancy Limiter +``` + +```cuda +// Query occupancy in code +int numBlocks; +int blockSize = 256; +cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &numBlocks, myKernel, blockSize, sharedMemSize); + +float occupancy = (numBlocks * blockSize) / + (float)deviceProp.maxThreadsPerMultiProcessor; +printf("Theoretical Occupancy: %.2f%%\n", occupancy * 100); +``` + +### 4. Roofline Model Analysis + +Performance bound analysis: + +```bash +# Generate roofline data +ncu --set roofline -o roofline ./cuda_program + +# Key metrics: +# - Achieved FLOP/s +# - Achieved Memory Bandwidth +# - Arithmetic Intensity (FLOP/byte) +# - Ridge Point +``` + +Interpretation guide: +- Arithmetic intensity left of the ridge point (under the sloped memory roof): Memory bound +- Arithmetic intensity right of the ridge point (under the flat compute roof): Compute bound +- Point sitting on the roofline: Optimal utilization for that arithmetic intensity; a point far below it suggests latency or occupancy limits + +### 5. Memory Bandwidth Analysis + +Identify memory bottlenecks: + +```bash +# Memory analysis sections +ncu --section MemoryWorkloadAnalysis \ + --section MemoryWorkloadAnalysis_Chart \ + --section MemoryWorkloadAnalysis_Tables \ + -o memory ./cuda_program +``` + +Key metrics: +- Global Load/Store Throughput +- L1/L2 Cache Hit Rate +- Shared Memory Bandwidth +- Memory Transactions per Request + +### 6. 
Warp Execution Analysis + +Analyze warp efficiency: + +```bash +# Warp state analysis +ncu --section WarpStateStatistics -o warp ./cuda_program + +# Scheduler statistics +ncu --section SchedulerStatistics -o scheduler ./cuda_program +``` + +Key metrics: +- Warp Cycles Per Issued Instruction +- Eligible Warps Per Active Cycle +- Active Warps Per Scheduler +- Stall Reasons (memory, sync, execution) + +### 7. 
Kernel Comparison + +Compare kernel variants: + +```bash +# Step 1: Profile baseline +ncu --set full -o baseline ./program_v1 + +# Step 2: Profile optimized version +ncu --set full -o optimized ./program_v2 + +# Step 3: Export both profiles to CSV, then compare with Python (no GUI needed) +# Note: --import can only be specified once; --page diff is not a valid page value. +ncu --import baseline.ncu-rep --page details --csv > baseline_details.csv +ncu --import optimized.ncu-rep --page details --csv > optimized_details.csv + +python3 -c " +import csv +def load(p): + return {r.get('Metric Name',''): r.get('Metric Value','') + for r in csv.DictReader(open(p))} +b = load('baseline_details.csv') +o = load('optimized_details.csv') +for k in sorted(set(b) | set(o)): + bv, ov = b.get(k,''), o.get(k,'') + if bv != ov: + print(f'{k[:55]:<55} {bv} -> {ov}') +" +``` + +### 8. Performance Recommendations + +Automated analysis: + +```bash +# Get optimization recommendations +ncu --section SpeedOfLight \ + --section SpeedOfLight_RooflineChart \ + -o speedoflight ./cuda_program + +# Export with recommendations +ncu --import profile.ncu-rep --page details --csv > details.csv +``` + +## Common Profiling Workflows + +### Workflow 1: Initial Performance Assessment + +```bash +# Step 1: System overview +nsys profile -t cuda -o system_overview ./program +nsys stats system_overview.nsys-rep + +# Step 2: Identify hot kernels +ncu --launch-skip 10 --launch-count 5 -o hot_kernels ./program + +# Step 3: Deep dive on bottleneck kernel +ncu --kernel-name hotKernel --set full -o detailed ./program +``` + +### Workflow 2: Memory Optimization + +```bash +# Analyze memory access patterns +ncu --section SourceCounters \ + --section MemoryWorkloadAnalysis \ + --kernel-name targetKernel \ + -o memory_analysis ./program + +# Check for coalescing issues +ncu --metrics l1tex__t_sectors_pipe_lsu_mem_global_op_ld.sum,\ +l1tex__t_requests_pipe_lsu_mem_global_op_ld.sum \ + -o coalescing ./program +``` + +### 
Workflow 3: Occupancy Optimization + +```bash +# Profile with occupancy focus +ncu --section Occupancy \ + --section LaunchStatistics \ + -o occupancy ./program +``` + +**Interpreting occupancy limiters** (from the `Occupancy` section report): + +| Limiter shown | Fix | +|---------------|-----| +| `Registers` | Reduce register pressure: use fewer local variables, add `maxnreg` hint | +| `Shared Memory` | Decrease shared memory allocation or use 32-bit instead of 64-bit | +| `Block Size` | Increase threads per block; ensure block size is a multiple of warp size (32) | +| `Warp Limit` | Already at theoretical max for this SM; no action needed | + +> **For Triton kernels**: block sizes are controlled via `@triton.autotune` configs, not CLI flags. To test occupancy at different block sizes, add or modify the `triton.Config({"BLOCK_C": N}, num_warps=W)` entries in the autotune list and re-run. Do **not** pass `--block-size` as a CLI argument — the Triton benchmark script does not accept it. + +## Dependencies + +- Nsight Systems 2023.1+ +- Nsight Compute 2023.1+ +- CUDA Toolkit 11.0+ + +## Constraints + +- Full profiling requires root/admin privileges +- Some metrics only available on specific GPU architectures +- Profiling adds overhead; results may differ from production +- Nsight Compute profiles one kernel invocation at a time by default diff --git a/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/references/a100-optimization-guide.md b/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/references/a100-optimization-guide.md new file mode 100644 index 0000000000000000000000000000000000000000..af3a24c0969da6195da95176980022f036d1f128 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/references/a100-optimization-guide.md @@ -0,0 +1,283 @@ +# A100 GPU Optimization Guide — SGLang Diffusion JIT Kernels + +Deep dive into A100-specific optimizations for diffusion model CUDA kernels in SGLang's JIT 
system. + +> **Adapted from**: [HuggingFace kernels cuda-kernels skill](https://github.com/huggingface/kernels/tree/main/skills/cuda-kernels) + +--- + +## A100 Ampere Architecture Overview + +| Component | A100 40GB | A100 80GB | Notes | +|-----------|-----------|-----------|-------| +| Compute Capability | sm_80 | sm_80 | Use `"-arch=sm_80"` in `extra_cuda_cflags` | +| SMs | 108 | 108 | Grid: aim for multiples of 108 | +| Shared Memory | 164 KB/SM | 164 KB/SM | Configurable: 48/96/164 KB | +| L2 Cache | 40 MB | 40 MB | Less than H100 (50 MB) | +| Memory Bandwidth | 1.55 TB/s | 2.0 TB/s | HBM2e | +| Max Threads/SM | 2048 | 2048 | Same as H100 | +| Tensor Cores | 3rd gen | 3rd gen | FP16, BF16, TF32, INT8, INT4 | + +### A100 vs H100 Comparison + +| Feature | A100 | H100 | Impact on JIT Kernels | +|---------|------|------|-----------------------| +| Memory BW | 2.0 TB/s | 3.35 TB/s | H100 ~67% faster for memory-bound ops | +| SMs | 108 | 132 | Adjust persistent kernel grid sizing | +| Shared Mem/SM | 164 KB | 192 KB | Reduce max tile sizes on A100 | +| L2 Cache | 40 MB | 50 MB | Attention tile reuse still works well | +| TMA | No | Yes | Can't use `cp.async.bulk` on A100 | +| FP8 | No | Yes | Use FP16/BF16 only on A100 | + +--- + +## Memory Access Optimization + +Same coalescing and vectorization rules as H100; lower bandwidth makes them even more critical. + +### `AlignedVector` Vectorization (same pattern as H100) + +```cpp +#include + +constexpr int kVecN = 16 / sizeof(T); // 8 for bf16/fp16, 4 for fp32 +using vec_t = device::AlignedVector; + +vec_t v; +v.load(src, vi); +// ... process elements ... +v.store(dst, vi); +``` + +**Expected A100 performance (BF16 RMSNorm):** + +| Implementation | A100 (ms) | H100 (ms) | A100 Speedup | +|:---|:---:|:---:|:---:| +| Scalar loads | ~0.10 | 0.065 | 1.00x | +| `AlignedVector` | ~0.03 | 0.019 | ~3x | + +**Target bandwidth**: 30–40% of A100's 2.0 TB/s = 600–800 GB/s. 
+ +### Shared Memory Configuration + +```cpp +// A100 max: 164 KB/SM +cudaFuncSetAttribute( + your_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + 164 * 1024 // 164 KB max on A100 +); +``` + +Attention tile sizes for A100: + +``` +BLOCK_SIZE_M = 128 (Q block) +BLOCK_SIZE_N = 64 (K,V block) +Tile = 128×64×2 = 16 KB (FP16) — fits in 164 KB shared mem +``` + +--- + +## Occupancy Tuning + +**Grid sizing for A100 (108 SMs):** + +```cpp +#include + +// Cap blocks to SM × occupancy (same pattern as H100) +static const uint32_t max_occ = host::runtime::get_blocks_per_sm(kernel, kBlockSize); +static const uint32_t num_sm = host::runtime::get_sm_count(device.unwrap().device_id); +const uint32_t num_blocks = std::min(num_sm * max_occ, host::div_ceil(n, kBlockSize)); +``` + +**Recommended block sizes (same as H100):** + +| Kernel Type | Threads/Block | Notes | +|-------------|---------------|-------| +| Element-wise | 256 | High occupancy | +| Row reduction | 512 | Full reduction per row | +| Tiled/attention | 256 | Balance shared mem | + +--- + +## A100-Specific Features + +### Async Memory Copy (sm_80) + +A100 introduced `cp.async` for overlapping compute and memory. Use this in custom kernels for prefetching: + +```cuda +#if __CUDA_ARCH__ >= 800 +// Async copy from global to shared (A100+) +__pipeline_memcpy_async(smem_ptr, global_ptr, bytes); +__pipeline_commit(); +__pipeline_wait_prior(0); +#endif +``` + +### TF32 Mode (A100 specific) + +Enables FP32-range with FP16-like throughput for GEMM. Enable in Python: + +```python +# Enable TF32 for matmuls (A100+) +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +``` + +TF32 is automatic for FP32 GEMMs via cuBLAS — no kernel changes needed. 
+ +### Structured Sparsity (2:4) + +A100 tensor cores support 50% structured sparsity: + +```python +from torch.sparse import to_sparse_semi_structured +sparse_weight = to_sparse_semi_structured(dense_weight) +# ~2x GEMM speedup for matmul with sparse weight +``` + +--- + +## JIT Compilation for A100 + +```python +return load_jit( + "my_kernel", + *args, + cuda_files=["diffusion/my_kernel.cuh"], + cuda_wrappers=[("my_kernel", f"my_kernel<{args}>")], + extra_cuda_cflags=[ + "-O3", + "--use_fast_math", + "-arch=sm_80", # A100 only; omit for multi-arch + ], +) +``` + +**Multi-arch (A100 + H100):** + +```python +extra_cuda_cflags=[ + "-O3", "--use_fast_math", + "-gencode=arch=compute_80,code=sm_80", # A100 + "-gencode=arch=compute_90,code=sm_90", # H100 +] +``` + +Runtime arch guard (in Python wrapper): + +```python +cap = torch.cuda.get_device_capability() +if cap < (8, 0): + raise RuntimeError(f"This kernel requires sm_80 (A100) or later, got sm_{cap[0]}{cap[1]}") +``` + +--- + +## H100 → A100 Migration Checklist + +When porting an H100-optimized kernel to A100: + +| Item | H100 | A100 | Change Required | +|------|------|------|-----------------| +| Shared memory | 192 KB | 164 KB | Reduce `cudaFuncSetAttribute` size | +| Grid sizing | ×132 SMs | ×108 SMs | `get_sm_count()` handles automatically | +| TMA bulk copy | Available | **Not available** | Remove `cp.async.bulk`; use standard `__pipeline_memcpy_async` | +| FP8 | Available | **Not available** | Fall back to FP16/BF16 | +| PDL | Supported | **Not available** | PDL requires sm_90+; do not call `.enable_pdl(true)` on A100 | +| Warp shuffles | Same | Same | No changes | +| `AlignedVector` | Same | Same | No changes | + +**Conditional compilation:** + +```cuda +#if __CUDA_ARCH__ >= 900 + // H100-only: TMA, FP8, thread block clusters + #define USE_TMA 1 +#elif __CUDA_ARCH__ >= 800 + // A100: cp.async, TF32, 2:4 sparsity + #define USE_ASYNC_COPY 1 +#endif +``` + +--- + +## Precision Notes + +| Type | Available on A100 | Notes | 
+|------|-------------------|-------| +| FP16 | Yes | Good, watch overflow in attention | +| BF16 | Yes | Preferred for training and inference | +| TF32 | Yes (A100 specific) | Auto for FP32 GEMMs | +| FP8 | **No** | H100 only | + +--- + +## Performance Profiling + +### NVIDIA Nsight Systems (nsys) + +```bash +nsys profile -o a100_profile python scripts/bench_diffusion_rmsnorm.py + +# Key metrics to watch: +# - Kernel duration +# - Memory transfer time +# - GPU idle time +# - Stream utilization +``` + +### NVIDIA Nsight Compute (ncu) + +```bash +# Full metrics +ncu --set full -o a100_metrics.ncu-rep \ + python scripts/bench_diffusion_rmsnorm.py + +# Specific metrics for bandwidth / occupancy checks +ncu --metrics sm__throughput.avg.pct_of_peak_sustained_elapsed,\ +dram__throughput.avg.pct_of_peak_sustained_elapsed \ + python scripts/bench_diffusion_rmsnorm.py + +# Key metrics for A100 diffusion kernels: +# - Achieved occupancy (sm__warps_active.avg.pct_of_peak_sustained_active) +# - Memory throughput (dram__throughput.avg.pct_of_peak_sustained_elapsed) +# → Target: 30–40% of 2.0 TB/s (600–800 GB/s) for vectorized kernels +# - Compute throughput (sm__throughput.avg.pct_of_peak_sustained_elapsed) +# - Warp stall reasons (smsp__warp_issue_stalled_*.avg.pct_of_peak_sustained_active) +# - Kernel time (gpu__time_duration.avg) +``` + +### Common A100 Performance Issues + +1. **Memory bound below target**: `dram__throughput` < 30% + - Fix: Use `AlignedVector` (128-bit vector loads) + +2. **Low occupancy**: Grid too small for 108 SMs + - Fix: Use `runtime::get_sm_count()` persistent kernel pattern + +3. **No TF32 for FP32 GEMMs**: torch.backends.cuda.matmul.allow_tf32 not set + - Fix: `torch.backends.cuda.matmul.allow_tf32 = True` + +--- + +## Best Practices Summary (A100) + +1. **Bandwidth**: Even more critical than H100 — profile with `ncu` first +2. **Vectorization**: `AlignedVector` gives ~3x over scalar +3. **TF32**: Enable for any FP32 matmul workload +4. 
**Shared memory**: Cap at 164 KB; use `cudaFuncSetAttribute` +5. **Grid sizing**: Multiples of 108 SMs via `runtime::get_sm_count` +6. **cp.async**: Use for prefetching in tiled kernels +7. **Multi-arch**: Build for both `sm_80` and `sm_90` to support both GPUs +8. **Same abstractions**: `AlignedVector`, `TensorMatcher`, `LaunchKernel` work identically + +## Reference Benchmark Results (A100 80GB, BF16) + +| Kernel | Shape | A100 (ms) | H100 (ms) | H100 Speedup | +|--------|-------|-----------|-----------|--------------| +| RMSNorm | [2, 1024, 2048] | ~0.08 | 0.054 | 1.5x | +| GEGLU | [2, 1024, 4096] | ~0.05 | 0.030 | 1.7x | diff --git a/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/references/h100-optimization-guide.md b/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/references/h100-optimization-guide.md new file mode 100644 index 0000000000000000000000000000000000000000..c7e6975758821e1273255a0ea9d77c9379535089 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/references/h100-optimization-guide.md @@ -0,0 +1,364 @@ +# H100 GPU Optimization Guide — SGLang Diffusion JIT Kernels + +Deep dive into H100-specific optimizations for diffusion model CUDA kernels, written for SGLang's JIT kernel system. 
+ +> **Adapted from**: [HuggingFace kernels cuda-kernels skill](https://github.com/huggingface/kernels/tree/main/skills/cuda-kernels) + +--- + +## H100 Hopper Architecture Overview + +| Component | Specification | Optimization Implication | +|-----------|---------------|--------------------------| +| Compute Capability | sm_90 | Use `extra_cuda_cflags=["-arch=sm_90"]` in `load_jit` | +| SMs | 132 | Grid: aim for multiples of 132 | +| Shared Memory | 192 KB/SM | Configurable: 96/144/192 KB | +| L2 Cache | 50 MB | Tile K,V of attention to fit in L2 | +| Memory Bandwidth | 3.35 TB/s | BF16 vectorized: achieves ~38% (~1.27 TB/s) | +| Max Threads/SM | 2048 | Max 16 blocks of 128 threads per SM | +| Warp Size | 32 | All reductions use `warp::reduce_sum` | +| Registers | 64K 32-bit/SM | 255 per thread max | + +### New Hopper Features (sm_90+) + +1. **Thread Block Clusters** — groups cooperating via Distributed Shared Memory +2. **TMA (Tensor Memory Accelerator)** — hardware-accelerated bulk copies +3. **FP8 support** — native 8-bit floating point in tensor cores +4. **PDL (Programmatic Dependent Launch)** — enable with `.enable_pdl(true)` in `LaunchKernel` + +Gate sm_90+ features with a runtime check before calling `load_jit`: + +```python +if torch.cuda.get_device_capability()[0] < 9: + raise RuntimeError("This kernel requires H100 (sm_90+)") +``` + +--- + +## Memory Hierarchy Optimization + +### Coalesced Global Memory Access + +```cpp +// GOOD: threads read consecutive addresses → 128-byte transaction per warp +uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; +fp16_t val = src[idx]; + +// BAD: strided access → multiple transactions, lower effective bandwidth +uint32_t idx = threadIdx.x * stride; // avoid stride > 1 +``` + +**Transaction sizes**: 32 bytes minimum, 128 bytes optimal (full warp, FP32). + +### Vectorized Memory Access with `AlignedVector` + +SGLang's `AlignedVector` provides 128-bit (16-byte) vector loads. 
Always use this instead of raw pointer reinterprets. + +```cpp +#include <sgl_kernel/vec.cuh> + +// 16 bytes per load: 8×bf16_t, 8×fp16_t, or 4×fp32_t +constexpr int kVecN = 16 / sizeof(T); +using vec_t = device::AlignedVector<T, kVecN>; + +// Load +vec_t v; +v.load(src, vi); // loads src[vi * kVecN .. vi * kVecN + kVecN - 1] + +// Process +#pragma unroll +for (int i = 0; i < kVecN; ++i) { + float val = static_cast<float>(v[i]); + // ... compute ... + v[i] = static_cast<T>(result); +} + +// Store +v.store(dst, vi); +``` + +**RMSNorm benchmark (H100 80GB, BF16):** + +| Implementation | Time (ms) | Speedup | +|:---|:---:|:---:| +| Scalar loads | 0.065 | 1.00x | +| `AlignedVector` | 0.019 | **3.37x** | + +Bandwidth achieved: **~38% of 3.35 TB/s** = 1.27 TB/s. + +### L2 Cache Utilization (50 MB) + +For attention, tile K and V so they stay in L2 while Q iterates: + +``` +BLOCK_SIZE_M = 128 (Q block) +BLOCK_SIZE_N = 64 (K,V block) +With head_dim=64: tile = 128×64×2 = 16 KB (FP16), multiple tiles fit in L2 +``` + +### Shared Memory Configuration + +Request max shared memory for attention kernels: + +```cpp +// In launcher (after selecting kernel function pointer): +cudaFuncSetAttribute( + your_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + 192 * 1024 // 192 KB max on H100 +); +``` + +Shared memory has 32 banks (4 bytes/bank). Avoid conflicts with padding: + +```cpp +__shared__ float data[32][33]; // 33 instead of 32 → no bank conflict +``` + +--- + +## Warp & CTA Reductions (SGLang Abstractions) + +Use `sgl_kernel/warp.cuh` and `sgl_kernel/cta.cuh` — never raw `__shfl_xor_sync`. 
+ +```cpp +#include <sgl_kernel/warp.cuh> +#include <sgl_kernel/cta.cuh> + +// Warp-level sum (uses __shfl_xor_sync internally) +float result = device::warp::reduce_sum(partial); + +// Warp-level max +float mx = device::warp::reduce_max(val); + +// CTA-wide max via shared memory +__shared__ float smem[32]; +device::cta::reduce_max(val, smem, -1e38f); +// smem[0] holds the result after __syncthreads() +``` + +**Block reduction pattern for RMSNorm:** + +```cpp +// 1. 
Warp reduction +sum_sq = device::warp::reduce_sum(sum_sq); + +// 2. Write warp leaders to smem +__shared__ float smem_r[32]; +if (threadIdx.x % 32 == 0) smem_r[threadIdx.x / 32] = sum_sq; +__syncthreads(); + +// 3. Final warp reduction over warp leaders +if (threadIdx.x < 32) { + sum_sq = (threadIdx.x < blockDim.x / 32) ? smem_r[threadIdx.x] : 0.f; + sum_sq = device::warp::reduce_sum(sum_sq); +} +__syncthreads(); +``` + +--- + +## Occupancy Tuning + +``` +Occupancy = Active Warps per SM / Max Warps per SM (64) + +Limiting factors on H100: + 1. Registers: 65536 / (threads_per_block × regs_per_thread) + 2. Shared Memory: 192 KB / smem_per_block + 3. Threads: 2048 / threads_per_block +``` + +**Recommended block sizes:** + +| Kernel Type | Threads/Block | Warps | Reasoning | +|-------------|---------------|-------|-----------| +| Element-wise (RoPE, GEGLU) | 256 | 8 | High occupancy, simple | +| Row reduction (RMSNorm, LayerNorm) | 256–512 | 8–16 | Enough threads for full reduction | +| Tiled (attention) | 256 | 8 | Balance shared mem and registers | + +**Persistent kernel pattern** (cap grid to SM × occupancy): + +```cpp +#include + +static const uint32_t max_occ = host::runtime::get_blocks_per_sm(kernel, kBlockSize); +static const uint32_t num_sm = host::runtime::get_sm_count(device.unwrap().device_id); +const uint32_t num_blocks = std::min(num_sm * max_occ, host::div_ceil(n, kBlockSize)); +host::LaunchKernel(num_blocks, kBlockSize, device.unwrap())(kernel, params); +``` + +--- + +## Precision and Numerical Stability + +| Type | Exponent Bits | Mantissa Bits | Range | Use Case | +|------|--------------|---------------|-------|----------| +| FP16 | 5 | 10 | ±65504 | Inference; attention score overflow risk | +| BF16 | 8 | 7 | ±3.39×10³⁸ | Training/inference preferred; safer for attn | +| FP32 | 8 | 23 | ±3.39×10³⁸ | Accumulation only | + +**Mixed precision pattern** (always accumulate in FP32): + +```cpp +// Input via AlignedVector +vec_t v; +v.load(src, vi); +float 
acc = 0.f; +#pragma unroll +for (int i = 0; i < kVecN; ++i) { + float val = static_cast(v[i]); // promote to FP32 + acc += val * val; +} +// Output +v[i] = static_cast(fp32_result); // demote back +``` + +--- + +## Diffusion-Specific Patterns + +### DiT Block Operators + +| Operator | Pattern | Key Constraint | +|----------|---------|----------------| +| **RMSNorm** | 2-pass row reduction | weight may be `None` | +| **AdaLN** | `norm(x) * (1 + scale) + shift` | fuse norm+scale+shift | +| **RoPE 3D** | `[B, t*h*w, heads, head_dim]` | layout: `seq = t*h*w` | +| **GEGLU** | `gelu(gate) * value`, input `[B,L,2H]` | don't use for LTX-Video (uses GELU) | +| **SiLU gate** | `x * sigmoid(x)` | fuse with MLP linear | + +### Online Softmax (for custom attention) + +```cuda +// Numerically stable without materializing full [seq×seq] score matrix +float row_max = -INFINITY, row_sum = 0.f; +for each K block: + compute local_scores + new_max = max(row_max, max(local_scores)) + rescale = exp(row_max - new_max) + row_sum = row_sum * rescale + sum(exp(local_scores - new_max)) + out_acc = out_acc * rescale + softmax(local_scores) @ V_block + row_max = new_max +``` + +--- + +## Profiling and Debugging + +### NVIDIA Nsight Systems (nsys) + +System-wide profiling to see kernel durations, memory transfers, and GPU idle time: + +```bash +nsys profile -o profile_report python scripts/bench_diffusion_rmsnorm.py + +# Key metrics to watch: +# - Kernel duration +# - Memory transfer time +# - GPU idle time +# - Stream utilization +``` + +For end-to-end denoise profiling via `sglang generate`, see `diffusion-benchmark-and-profile.md` (Level 2: nsys + gputrc2graph.py). 
+ +### NVIDIA Nsight Compute (ncu) + +Detailed per-kernel analysis for tuning individual JIT CUDA kernels: + +```bash +# Full metrics — use when you need everything (slow) +ncu --set full -o metrics.ncu-rep \ + python scripts/bench_diffusion_rmsnorm.py + +# Specific metrics — use for targeted bandwidth / occupancy checks +ncu --metrics sm__throughput.avg.pct_of_peak_sustained_elapsed,\ +dram__throughput.avg.pct_of_peak_sustained_elapsed \ + python scripts/bench_diffusion_rmsnorm.py + +# Key metrics for diffusion JIT kernels: +# - Achieved occupancy (sm__warps_active.avg.pct_of_peak_sustained_active) +# - Memory throughput (dram__throughput.avg.pct_of_peak_sustained_elapsed) +# - Compute throughput (sm__throughput.avg.pct_of_peak_sustained_elapsed) +# - Warp stall reasons (smsp__warp_issue_stalled_*.avg.pct_of_peak_sustained_active) +# - L1 cache hit rate (l1tex__t_requests_pipe_lsu_mem_global_op_ld.sum) +``` + +### Common Performance Issues + +1. **Low occupancy**: Too many registers or shared memory per block + - Check: `--ptxas-options=-v` in `extra_cuda_cflags` to see register count + - Fix: Reduce `--maxrregcount=N`; use smaller block size + +2. **Memory bound, low bandwidth**: Achieved < 30% of 3.35 TB/s + - Check: `dram__throughput.avg.pct_of_peak_sustained_elapsed` + - Fix: Switch to `AlignedVector` for 128-bit vector loads + +3. **Shared memory bank conflicts**: `l1tex__data_bank_conflicts_pipe_lmem_op_st.sum` is high + - Fix: Add padding — `__shared__ float data[32][33]` + +4. **Warp divergence**: Conditional branches splitting warps + - Check: `smsp__warp_issue_stalled_branch.avg.pct_of_peak_sustained_active` + - Fix: Restructure so elements with identical branches are in the same warp + +5. **Too many small kernels**: High kernel launch overhead + - Fix: Fuse operations (e.g., norm + scale + shift → AdaLN in one kernel) + +--- + +## JIT Compilation Notes + +SGLang's JIT compiles kernels on first use via `load_jit`. 
For H100-specific flags: + +```python +return load_jit( + "my_kernel", + *args, + cuda_files=["diffusion/my_kernel.cuh"], + cuda_wrappers=[("my_kernel", f"my_kernel<{args}>")], + extra_cuda_cflags=[ + "-O3", + "--use_fast_math", + "-arch=sm_90", # H100 only; omit for multi-arch + "--ptxas-options=-v", # Remove after tuning + ], +) +``` + +For multi-arch (H100 + A100): + +```python +extra_cuda_cflags=[ + "-O3", + "--use_fast_math", + "-gencode=arch=compute_80,code=sm_80", # A100 + "-gencode=arch=compute_90,code=sm_90", # H100 +] +``` + +--- + +## Best Practices Summary + +1. **Memory access**: Coalesce writes, align to 128-byte boundaries +2. **Vectorization**: Use `AlignedVector` for all element-wise loads/stores +3. **Reductions**: Use `warp::reduce_sum/max`, then shared memory pattern above +4. **Precision**: BF16 for I/O, FP32 for accumulation; use `static_cast` +5. **Block size**: 256 threads default; 512 for reductions; tune with `runtime::get_blocks_per_sm` +6. **Grid sizing**: Multiples of 132 SMs; use persistent kernel pattern for small N +7. **Shared memory**: Add padding (`[32][33]`) to avoid bank conflicts +8. **Profile**: Run `ncu` before claiming a speedup; check dram throughput % +9. **Fuse**: Combine norm + scale + shift into a single pass to reduce memory traffic +10. **Abstractions**: Always use `TensorMatcher`, `AlignedVector`, `LaunchKernel` — never raw CUDA + +## Reference Benchmark Results (H100 80GB, BF16) + +| Kernel | Shape | Time (ms) | +|--------|-------|-----------| +| RMSNorm | [2, 1024, 2048] | 0.054 | +| GEGLU | [2, 1024, 4096] → [2, 1024, 2048] | 0.030 | +| RoPE 3D | [2, 480, 8, 64] | 1.670 | +| RMSNorm vectorized | [1, 1024, 2048] | 0.019 | +| RMSNorm vectorized | [4, 4096, 3072] | 0.157 | + +> See `kernel-templates.md` for copy-paste ready sglang JIT kernel implementations. 
diff --git a/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/references/kernel-templates.md b/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/references/kernel-templates.md new file mode 100644 index 0000000000000000000000000000000000000000..899a6347cf63e887e6273aaa16c93ade11f76d8f --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/references/kernel-templates.md @@ -0,0 +1,569 @@ +# CUDA Kernel Templates — SGLang Diffusion JIT Style + +Copy-paste ready templates for JIT CUDA kernels in `python/sglang/jit_kernel/csrc/diffusion/`. +All templates use SGLang's internal abstractions; no raw CUDA headers needed. + +> **Adapted from**: [HuggingFace kernels cuda-kernels skill](https://github.com/huggingface/kernels/tree/main/skills/cuda-kernels) + +--- + +## Prerequisite: Standard Includes + +Every kernel file in `csrc/diffusion/` starts with: + +```cpp +#include // TensorMatcher, SymbolicSize, SymbolicDevice +#include // fp16_t, bf16_t, fp32_t, dtype_trait, packed_t +#include // RuntimeCheck, Panic, div_ceil +#include // LaunchKernel, SGL_DEVICE, type aliases +#include // AlignedVector +#include // warp::reduce_sum, warp::reduce_max +#include // device::math::rsqrt, sqrt, ... +#include // tile::Memory (strided access pattern) + +#include +#include +``` + +**Key type aliases** (from `utils.cuh`): +- `fp16_t` = `__half`, `fp16x2_t` = `__half2` +- `bf16_t` = `__nv_bfloat16`, `bf16x2_t` = `__nv_bfloat162` +- `fp32_t` = `float`, `fp32x2_t` = `float2` +- `SGL_DEVICE` = `__forceinline__ __device__` + +--- + +## Template 1: Element-wise Operation + +Use for ops that process elements independently: RoPE, SiLU, GEGLU, scale+bias. 
+ +### `.cuh` file: `csrc/diffusion/silu_gate.cuh` + +```cpp +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace { + +// SiLU gate: out[i] = x[i] * sigmoid(x[i]) +// Input layout: [B, L, hidden] +template +__global__ void silu_gate_kernel( + T* __restrict__ dst, + const T* __restrict__ src, + uint32_t n_vecs, + uint32_t n_remainder, + uint32_t n_total) +{ + using vec_t = device::AlignedVector; + + const uint32_t stride = blockDim.x * gridDim.x; + + // --- vectorized body --- + for (uint32_t vi = blockIdx.x * blockDim.x + threadIdx.x; vi < n_vecs; vi += stride) { + vec_t v; + v.load(src, vi); + #pragma unroll + for (int i = 0; i < kVecN; ++i) { + float val = static_cast(v[i]); + float sig = 1.f / (1.f + device::math::exp(-val)); + v[i] = static_cast(val * sig); + } + v.store(dst, vi); + } + + // --- scalar tail (for sizes not divisible by kVecN) --- + const uint32_t base = n_vecs * kVecN; + for (uint32_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n_remainder; i += stride) { + float val = static_cast(src[base + i]); + float sig = 1.f / (1.f + device::math::exp(-val)); + dst[base + i] = static_cast(val * sig); + } +} + +template +void silu_gate(tvm::ffi::TensorView dst, tvm::ffi::TensorView src) { + using namespace host; + + SymbolicSize N{"num_elements"}; + SymbolicDevice device; + device.set_options(); + + TensorMatcher({N}) + .with_dtype() + .with_device(device) + .verify(dst) + .verify(src); + + const uint32_t n = static_cast(N.unwrap()); + const DLDevice dev = device.unwrap(); + RuntimeCheck(n > 0, "silu_gate: num_elements must be > 0"); + + constexpr int kVecN = 16 / sizeof(T); // 128-bit vector load + const uint32_t n_vecs = n / kVecN; + const uint32_t n_rem = n % kVecN; + + constexpr uint32_t kBlock = 256; + const uint32_t grid = div_ceil(std::max(n_vecs, n_rem), kBlock); + + LaunchKernel(grid, kBlock, dev)( + silu_gate_kernel, + static_cast(dst.data_ptr()), + static_cast(src.data_ptr()), + n_vecs, n_rem, n); +} + 
+} // namespace +``` + +### Python wrapper: `diffusion/silu_gate.py` + +```python +from __future__ import annotations +import torch +from sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args + +@cache_once +def _jit_silu_gate_module(dtype: torch.dtype): + args = make_cpp_args(dtype) + return load_jit( + "diffusion_silu_gate", + *args, + cuda_files=["diffusion/silu_gate.cuh"], + cuda_wrappers=[("silu_gate", f"silu_gate<{args}>")], + extra_cuda_cflags=["-O3", "--use_fast_math"], + ) + +def diffusion_silu_gate(src: torch.Tensor, out: torch.Tensor | None = None) -> torch.Tensor: + assert src.is_cuda and src.dtype in (torch.float16, torch.bfloat16, torch.float32) + if out is None: + out = torch.empty_like(src) + module = _jit_silu_gate_module(src.dtype) + module.silu_gate(out, src) + return out +``` + +--- + +## Template 2: Row-wise Reduction (RMSNorm / LayerNorm) + +Use for ops that reduce across the last dimension of each row. + +### `.cuh` file: `csrc/diffusion/rmsnorm.cuh` + +```cpp +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace { + +// RMSNorm: y = x / rms(x) * weight +// One block per row; vectorized loads/stores; warp + shared-mem reduction +template +__global__ void rmsnorm_kernel( + T* __restrict__ dst, + const T* __restrict__ src, + const T* __restrict__ weight, // nullptr if no affine weight + uint32_t hidden, + uint32_t n_vecs, + float eps) +{ + using vec_t = device::AlignedVector; + + const uint32_t row = blockIdx.x; + const T* row_src = src + row * hidden; + T* row_dst = dst + row * hidden; + + // Pass 1: sum of squares + float sum_sq = 0.f; + for (uint32_t vi = threadIdx.x; vi < n_vecs; vi += blockDim.x) { + vec_t v; + v.load(row_src, vi); + #pragma unroll + for (int i = 0; i < kVecN; ++i) { + float val = static_cast(v[i]); + sum_sq += val * val; + } + } + + // Warp + block reduction + sum_sq = device::warp::reduce_sum(sum_sq); + __shared__ float smem[32]; + if (threadIdx.x % 32 == 0) 
smem[threadIdx.x / 32] = sum_sq; + __syncthreads(); + if (threadIdx.x < 32) { + sum_sq = (threadIdx.x < blockDim.x / 32) ? smem[threadIdx.x] : 0.f; + sum_sq = device::warp::reduce_sum(sum_sq); + } + // Only warp 0 holds the block total; broadcast it so every warp uses the same RMS + if (threadIdx.x == 0) smem[0] = sum_sq; + __syncthreads(); + sum_sq = smem[0]; + + const float rms_inv = device::math::rsqrt(sum_sq / static_cast(hidden) + eps); + + // Pass 2: normalize + optional weight + for (uint32_t vi = threadIdx.x; vi < n_vecs; vi += blockDim.x) { + vec_t v_in, v_out; + v_in.load(row_src, vi); + if (weight != nullptr) { + vec_t v_w; + v_w.load(weight, vi); + #pragma unroll + for (int i = 0; i < kVecN; ++i) + v_out[i] = static_cast(static_cast(v_in[i]) * rms_inv + * static_cast(v_w[i])); + } else { + #pragma unroll + for (int i = 0; i < kVecN; ++i) + v_out[i] = static_cast(static_cast(v_in[i]) * rms_inv); + } + v_out.store(row_dst, vi); + } +} + +template +void rmsnorm( + tvm::ffi::TensorView dst, + tvm::ffi::TensorView src, + tvm::ffi::TensorView weight, // data_ptr == nullptr → no weight + float eps) +{ + using namespace host; + + SymbolicSize B{"batch_tokens"}, H{"hidden_size"}; + SymbolicDevice device; + device.set_options(); + + TensorMatcher({B, H}) + .with_dtype() + .with_device(device) + .verify(dst) + .verify(src); + + const uint32_t num_rows = static_cast(B.unwrap()); + const uint32_t hidden = static_cast(H.unwrap()); + const DLDevice dev = device.unwrap(); + + constexpr int kVecN = 16 / sizeof(T); + RuntimeCheck(hidden % kVecN == 0, + "rmsnorm: hidden_size (", hidden, ") must be divisible by ", kVecN); + const uint32_t n_vecs = hidden / kVecN; + + uint32_t threads = std::min(n_vecs, 512u); + threads = (threads + 31) / 32 * 32; + + const T* w_ptr = (weight.data_ptr() != nullptr) + ? 
static_cast(weight.data_ptr()) : nullptr; + + LaunchKernel(num_rows, threads, dev)( + rmsnorm_kernel, + static_cast(dst.data_ptr()), + static_cast(src.data_ptr()), + w_ptr, hidden, n_vecs, eps); +} + +} // namespace +``` + +--- + +## Template 3: Fused Row-Reduction + Element-wise (AdaLN) + +Combines RMSNorm + AdaLN modulation into one pass: `y = norm(x) * (1 + scale) + shift`. + +### `.cuh` file: `csrc/diffusion/adaln.cuh` + +```cpp +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace { + +// AdaLN: y = norm(x) * (1 + scale) + shift +// scale, shift: [batch, hidden] (one per row) +template +__global__ void adaln_kernel( + T* __restrict__ dst, + const T* __restrict__ src, + const T* __restrict__ weight, + const T* __restrict__ scale, + const T* __restrict__ shift, + uint32_t hidden, + uint32_t n_vecs, + float eps) +{ + using vec_t = device::AlignedVector; + + const uint32_t row = blockIdx.x; + const T* row_src = src + row * hidden; + const T* row_scale = scale + row * hidden; + const T* row_shift = shift + row * hidden; + T* row_dst = dst + row * hidden; + + // Pass 1: compute RMS + float sum_sq = 0.f; + for (uint32_t vi = threadIdx.x; vi < n_vecs; vi += blockDim.x) { + vec_t v; + v.load(row_src, vi); + #pragma unroll + for (int i = 0; i < kVecN; ++i) { + float val = static_cast(v[i]); + sum_sq += val * val; + } + } + sum_sq = device::warp::reduce_sum(sum_sq); + __shared__ float smem[32]; + if (threadIdx.x % 32 == 0) smem[threadIdx.x / 32] = sum_sq; + __syncthreads(); + if (threadIdx.x < 32) { + sum_sq = (threadIdx.x < blockDim.x / 32) ? 
smem[threadIdx.x] : 0.f; + sum_sq = device::warp::reduce_sum(sum_sq); + } + // Only warp 0 holds the block total; broadcast it so every warp uses the same RMS + if (threadIdx.x == 0) smem[0] = sum_sq; + __syncthreads(); + sum_sq = smem[0]; + const float rms_inv = device::math::rsqrt(sum_sq / static_cast(hidden) + eps); + + // Pass 2: normalize + modulate + for (uint32_t vi = threadIdx.x; vi < n_vecs; vi += blockDim.x) { + vec_t v_in, v_w, v_sc, v_sh, v_out; + v_in.load(row_src, vi); + v_w.load(weight, vi); + v_sc.load(row_scale, vi); + v_sh.load(row_shift, vi); + #pragma unroll + for (int i = 0; i < kVecN; ++i) { + float x = static_cast(v_in[i]) * rms_inv * static_cast(v_w[i]); + float sc = static_cast(v_sc[i]); + float sh = static_cast(v_sh[i]); + v_out[i] = static_cast(x * (1.f + sc) + sh); + } + v_out.store(row_dst, vi); + } +} + +template +void adaln( + tvm::ffi::TensorView dst, + tvm::ffi::TensorView src, + tvm::ffi::TensorView weight, + tvm::ffi::TensorView scale, + tvm::ffi::TensorView shift, + float eps) +{ + using namespace host; + + SymbolicSize B{"batch_tokens"}, H{"hidden_size"}; + SymbolicDevice device; + device.set_options(); + + TensorMatcher({B, H}) + .with_dtype() + .with_device(device) + .verify(dst).verify(src).verify(weight).verify(scale).verify(shift); + + const uint32_t num_rows = static_cast(B.unwrap()); + const uint32_t hidden = static_cast(H.unwrap()); + const DLDevice dev = device.unwrap(); + + constexpr int kVecN = 16 / sizeof(T); + RuntimeCheck(hidden % kVecN == 0, "adaln: hidden_size must be divisible by ", kVecN); + const uint32_t n_vecs = hidden / kVecN; + + uint32_t threads = std::min(n_vecs, 512u); + threads = (threads + 31) / 32 * 32; + + LaunchKernel(num_rows, threads, dev)( + adaln_kernel, + static_cast(dst.data_ptr()), + static_cast(src.data_ptr()), + static_cast(weight.data_ptr()), + static_cast(scale.data_ptr()), + static_cast(shift.data_ptr()), + hidden, n_vecs, eps); +} + +} // namespace +``` + +--- + +## Template 4: Python Wrapper (generic pattern) + +File location: `python/sglang/jit_kernel/diffusion/.py` + +```python +from __future__ import annotations +from typing 
import TYPE_CHECKING + +import torch + +from sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args + +if TYPE_CHECKING: + from tvm_ffi.module import Module + + +@cache_once +def _jit_module(dtype: torch.dtype) -> Module: + """Cache key: dtype (and any other template params you need).""" + args = make_cpp_args(dtype) + return load_jit( + "diffusion_your_op", # unique build cache key + *args, + cuda_files=["diffusion/your_op.cuh"], # relative to csrc/ + cuda_wrappers=[("your_op", f"your_op<{args}>")], + extra_cuda_cflags=["-O3", "--use_fast_math"], + ) + + +def diffusion_your_op( + src: torch.Tensor, + out: torch.Tensor | None = None, +) -> torch.Tensor: + """ + Your op description. + + Supported dtypes: float16, bfloat16, float32. + """ + assert src.is_cuda, "src must be a CUDA tensor" + assert src.dtype in (torch.float16, torch.bfloat16, torch.float32), ( + f"Unsupported dtype {src.dtype}" + ) + if out is None: + out = torch.empty_like(src) + + module = _jit_module(src.dtype) + module.your_op(out, src) + return out +``` + +**`make_cpp_args` conversion table:** + +| `torch.dtype` | C++ type | +|---------------|----------| +| `torch.float16` | `fp16_t` | +| `torch.bfloat16` | `bf16_t` | +| `torch.float32` | `fp32_t` | + +--- + +## Template 5: Correctness Test + +```python +# python/sglang/jit_kernel/tests/test_diffusion_.py +import pytest +import torch +from sglang.jit_kernel.diffusion. 
import diffusion_ + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("shape", [(1, 2048), (4, 3072), (16, 4096)]) +def test__correctness(dtype, shape): + src = torch.randn(*shape, dtype=dtype, device="cuda") + + out_jit = diffusion_(src) + ref = reference_(src.float()).to(dtype) # reference in fp32 + + tol = {"rtol": 1e-2, "atol": 1e-2} if dtype != torch.float32 else {"rtol": 1e-5, "atol": 1e-6} + torch.testing.assert_close(out_jit, ref, **tol) + + +def test__out_param(): + src = torch.randn(1024, 2048, dtype=torch.bfloat16, device="cuda") + out = torch.empty_like(src) + result = diffusion_(src, out=out) + assert result is out + + +def test__cpu_error(): + src = torch.randn(128, dtype=torch.float16) # CPU tensor + with pytest.raises(AssertionError): + diffusion_(src) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) +``` + +--- + +## Template 6: Benchmark + +```python +# python/sglang/jit_kernel/benchmark/bench_diffusion_.py +import torch +import triton.testing + +from sglang.jit_kernel.benchmark.utils import DEFAULT_DEVICE, DEFAULT_DTYPE, run_benchmark +from sglang.jit_kernel.diffusion. 
import diffusion_ + +SHAPES = [(4096, 2048), (4096, 3072), (4096, 4096)] + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["hidden"], + x_vals=[s[1] for s in SHAPES], + line_arg="provider", + line_vals=["jit_cuda", "torch"], + line_names=["SGLang JIT CUDA", "PyTorch"], + styles=[("blue", "-"), ("red", "--")], + ylabel="us", + plot_name="diffusion-", + args={}, + ) +) +def benchmark(hidden: int, provider: str): + src = torch.randn(4096, hidden, dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE) + + if provider == "jit_cuda": + fn = lambda: diffusion_(src) + else: + fn = lambda: reference_(src) # torch baseline + + return run_benchmark(fn) + + +if __name__ == "__main__": + benchmark.run(print_data=True) +``` + +--- + +## Summary of New Files per Kernel + +``` +python/sglang/jit_kernel/csrc/diffusion/ +└── .cuh # CUDA kernel + launcher + +python/sglang/jit_kernel/diffusion/ +└── .py # Python wrapper (load_jit + cache_once) + +python/sglang/jit_kernel/tests/ +└── test_diffusion_.py # correctness tests + +python/sglang/jit_kernel/benchmark/ +└── bench_diffusion_.py # triton.testing benchmark +``` + +> See `scripts/bench_diffusion_rmsnorm.py` and `scripts/bench_diffusion_denoise.py` for full runnable examples. diff --git a/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/references/t4-optimization-guide.md b/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/references/t4-optimization-guide.md new file mode 100644 index 0000000000000000000000000000000000000000..50e298cc8ddeccdc2bfd34f3edb51970e5d6a147 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/references/t4-optimization-guide.md @@ -0,0 +1,335 @@ +# T4 GPU Optimization Guide — SGLang Diffusion JIT Kernels + +T4 is a Turing architecture GPU (GCP n1+T4, AWS g4dn) commonly used for cloud inference. +Its key constraint for diffusion kernels: **no BF16 support** — FP16 only. 
+ +> **Adapted from**: [HuggingFace kernels cuda-kernels skill](https://github.com/huggingface/kernels/tree/main/skills/cuda-kernels) + +--- + +## T4 Turing Architecture Overview + +| Component | T4 | A100 | H100 | +|-----------|-----|------|------| +| Compute Capability | sm_75 | sm_80 | sm_90 | +| SMs | 40 | 108 | 132 | +| Shared Memory/SM | **64 KB** | 164 KB | 192 KB | +| L2 Cache | 4 MB | 40 MB | 50 MB | +| Memory Bandwidth | **320 GB/s** | 2.0 TB/s | 3.35 TB/s | +| Memory | 16 GB GDDR6 | 40–80 GB HBM2e | 80 GB HBM3 | +| Max Threads/SM | **1024** | 2048 | 2048 | +| BF16 Support | **No** | Yes | Yes | + +### Critical T4 Constraints + +1. **No BFloat16** — must use FP16 everywhere +2. **320 GB/s bandwidth** — ~10x lower than H100; vectorization is critical +3. **16 GB memory** — limits model size; use offloading +4. **64 KB shared memory/SM** — smaller attention tiles +5. **Max 1024 threads/SM** — half of A100/H100; affects occupancy calculations + +--- + +## No BF16: Always Use FP16 + +This is the most impactful constraint. **Never use `bf16_t` or `__nv_bfloat16` on T4.** + +**Python wrapper guard:** + +```python +import torch +from sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args + +@cache_once +def _jit_rmsnorm_module(dtype: torch.dtype): + # T4 (sm_75) does not support BF16 + cap = torch.cuda.get_device_capability() + if cap < (8, 0) and dtype == torch.bfloat16: + raise RuntimeError( + f"T4 (sm_75) does not support BF16. Use torch.float16 instead. 
" + f"Got dtype={dtype}" + ) + args = make_cpp_args(dtype) + return load_jit( + "diffusion_rmsnorm", + *args, + cuda_files=["diffusion/rmsnorm.cuh"], + cuda_wrappers=[("rmsnorm", f"rmsnorm<{args}>")], + ) +``` + +**Conditional type in kernel:** + +```cuda +#if __CUDA_ARCH__ >= 800 + // A100/H100: BF16 available + using DefaultHalf = bf16_t; +#else + // T4/Turing: FP16 only + using DefaultHalf = fp16_t; +#endif +``` + +**Runtime detection helper:** + +```python +def get_diffusion_dtype() -> torch.dtype: + """Return the appropriate half-precision dtype for the current GPU.""" + cap = torch.cuda.get_device_capability() + if cap >= (8, 0): + return torch.bfloat16 # A100/H100: prefer BF16 + else: + return torch.float16 # T4/older: FP16 only +``` + +--- + +## Memory Access Optimization + +With only 320 GB/s, **vectorization is more critical on T4 than on A100/H100**. + +### `AlignedVector` (same abstraction, FP16 only) + +```cpp +#include + +// On T4, T must be fp16_t or fp32_t (NOT bf16_t) +constexpr int kVecN = 16 / sizeof(T); // 8 for fp16, 4 for fp32 +using vec_t = device::AlignedVector; +``` + +**Target bandwidth**: 40–50% of T4's 320 GB/s = 128–160 GB/s. + +### Increase Arithmetic Intensity + +With low bandwidth, fusing ops saves more on T4 than on H100: + +```cpp +// BAD on T4: separate passes → 2× memory traffic +output1[i] = input[i] * scale; // pass 1 +output2[i] = output1[i] + bias; // pass 2 + +// GOOD: fuse → single memory read, single write +float val = static_cast(v[i]); +val = val * scale + bias; +val = device::math::max(val, 0.f); // ReLU +v[i] = static_cast(val); +``` + +### Expected T4 Performance + +| Kernel | T4 (ms) | A100 (ms) | H100 (ms) | T4 vs H100 | +|--------|---------|-----------|-----------|------------| +| RMSNorm [2, 1024, 2048] | ~0.5 | ~0.08 | 0.054 | ~9x slower | +| GEGLU [2, 1024, 4096] | ~0.3 | ~0.05 | 0.030 | ~10x slower | + +--- + +## Shared Memory Configuration + +T4 max: **64 KB/SM**. Use smaller tiles vs A100/H100. 
+ +```cpp +// T4: request max shared memory (64 KB) +cudaFuncSetAttribute( + your_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + 64 * 1024 +); +``` + +**Attention tile sizes for T4** (halved vs H100): + +``` +H100/A100: BLOCK_SIZE_M = 128, BLOCK_SIZE_N = 64 +T4: BLOCK_SIZE_M = 64, BLOCK_SIZE_N = 32 ← reduced for 64 KB limit +``` + +--- + +## Occupancy Tuning + +T4 max: **1024 threads/SM** (vs 2048 on A100/H100). This halves max occupancy for a given block size. + +**Block sizes for T4:** + +| Kernel Type | Threads/Block | Notes | +|-------------|---------------|-------| +| Element-wise | 256 | Same as H100 | +| Row reduction | 256–512 | Avoid > 512 to fit multiple blocks/SM | +| Tiled/attention | 128–256 | Small tiles due to 64 KB shared mem | + +**Grid sizing for T4 (40 SMs)** — `runtime::get_sm_count` handles this automatically: + +```cpp +// get_sm_count() returns 40 on T4, 108 on A100, 132 on H100 +const uint32_t num_sm = host::runtime::get_sm_count(device.unwrap().device_id); +``` + +--- + +## Numerical Stability with FP16 + +FP16 has a smaller dynamic range (±65504) vs BF16 (±3.39×10³⁸). Watch for overflow in attention: + +```cuda +// Scale attention scores to prevent FP16 overflow +float scale_factor = 1.0f / sqrtf(static_cast(head_dim)); +// For very long sequences on T4, may need additional scaling: +// if (score * scale_factor > 65000.f) { /* clamp */ } +``` + +Always accumulate in FP32: + +```cpp +float acc = 0.f; // FP32 accumulation +for (uint32_t vi = threadIdx.x; vi < n_vecs; vi += blockDim.x) { + vec_t v; + v.load(src, vi); + #pragma unroll + for (int i = 0; i < kVecN; ++i) { + float val = static_cast(v[i]); // fp16 → fp32 + acc += val * val; + } +} +``` + +--- + +## Memory Management for 16 GB + +T4's 16 GB requires careful planning for large diffusion models. 
+ +**sglang generate flags for T4:** + +```bash +# Enable CPU offloading to fit within 16 GB: DiT weights, text encoder and VAE go to CPU; +# resolution and step count are reduced. Comments are kept off the continued lines because +# text after a trailing `\` would end the command early. +sglang generate \ + --model-path=black-forest-labs/FLUX.1-dev \ + --dit-cpu-offload true \ + --text-encoder-cpu-offload true \ + --vae-cpu-offload true \ + --width=512 --height=512 \ + --num-inference-steps=20 \ + --seed=42 +``` + +**Resolution recommendations for T4:** + +| Model | H100/A100 | T4 | +|-------|-----------|-----| +| FLUX.1-dev | 1024×1024 | 512×512 | +| Wan2.2-TI2V-5B | 720P | 480P | +| FLUX.2-dev | 1024×1024 | 512×512 | + +--- + +## JIT Compilation for T4 + +```python +return load_jit( + "my_kernel", + *args, + cuda_files=["diffusion/my_kernel.cuh"], + cuda_wrappers=[("my_kernel", f"my_kernel<{args}>")], + extra_cuda_cflags=[ + "-O3", + "--use_fast_math", + "-arch=sm_75", # T4 only; omit for multi-arch + ], +) +``` + +**Multi-arch (T4 + A100 + H100):** + +```python +extra_cuda_cflags=[ + "-O3", "--use_fast_math", + "-gencode=arch=compute_75,code=sm_75", # T4 + "-gencode=arch=compute_80,code=sm_80", # A100 + "-gencode=arch=compute_90,code=sm_90", # H100 +] +``` + +--- + +## H100/A100 → T4 Migration Checklist + +| Item | H100/A100 | T4 | Action | +|------|-----------|-----|--------| +| BF16 | Available | **Not available** | Replace `bf16_t` with `fp16_t`; guard in Python wrapper | +| Shared memory | 164–192 KB | **64 KB** | Halve tile sizes | +| Grid sizing | ×108/132 SMs | ×40 SMs | `get_sm_count()` auto-handles | +| Max threads/SM | 2048 | **1024** | Don't exceed 512 threads/block | +| Memory | 40–80 GB | **16 GB** | Enable CPU offloading | +| cp.async | Available | No (Turing has limited async) | Remove async copy patterns | +| `AlignedVector` | Same | Same | No changes | +| `warp::reduce_sum` | Same | Same | No changes | + +--- + +## Performance Profiling + +### NVIDIA Nsight Systems (nsys) + +```bash +nsys profile -o t4_profile python scripts/bench_diffusion_rmsnorm.py + +# Key metrics to watch: +# - 
Kernel duration +# - Memory transfer time +# - GPU idle time +# - Stream utilization +``` + +### NVIDIA Nsight Compute (ncu) + +```bash +# Full metrics +ncu --set full -o t4_metrics.ncu-rep \ + python scripts/bench_diffusion_rmsnorm.py + +# Specific metrics — T4 is memory-bound; focus on dram throughput +ncu --metrics sm__throughput.avg.pct_of_peak_sustained_elapsed,\ +dram__throughput.avg.pct_of_peak_sustained_elapsed \ + python scripts/bench_diffusion_rmsnorm.py + +# Key metrics for T4 diffusion kernels: +# - Memory throughput (dram__throughput.avg.pct_of_peak_sustained_elapsed) +# → Target: 40–50% of 320 GB/s (128–160 GB/s) for vectorized kernels +# - SM utilization (sm__throughput.avg.pct_of_peak_sustained_elapsed) +# → Target high with only 40 SMs +# - Achieved occupancy (sm__warps_active.avg.pct_of_peak_sustained_active) +# → Max 1024 threads/SM on T4 — block size ≤ 512 for decent occupancy +# - Warp stall reasons (smsp__warp_issue_stalled_*.avg.pct_of_peak_sustained_active) +``` + +### Common T4 Bottlenecks + +1. **Memory Bandwidth** — 320 GB/s is the primary limit; if `dram__throughput` < 40% → use `AlignedVector` +2. **Limited Memory** — 16 GB; enable `--dit-cpu-offload`/`--vae-cpu-offload` as needed +3. **No BF16** — guard in Python wrapper; FP16 overflow risk in long-sequence attention +4. **Smaller tiles** — 64 KB shared memory; reduce `BLOCK_SIZE_M/N` vs H100 + +--- + +## Best Practices Summary (T4) + +1. **No BF16**: Guard in Python wrapper, raise clear error +2. **Vectorization**: Even more critical at 320 GB/s — always use `AlignedVector` +3. **Tile sizes**: 64 KB shared memory limit → halve BLOCK_SIZE vs H100 +4. **Block size**: Max 512 threads/block for decent occupancy (max 1024 threads/SM) +5. **Grid sizing**: 40 SMs — `runtime::get_sm_count()` auto-handles +6. **FP32 accumulation**: Always accumulate in FP32 to avoid FP16 overflow +7. **Memory**: Plan for 16 GB; use `--dit-cpu-offload`/`--vae-cpu-offload` as needed +8. 
**Fuse more**: Low bandwidth makes kernel fusion more impactful than on H100 +9. **Multi-arch build**: Always build for `sm_75,sm_80,sm_90` together + +## T4 Cloud Instance Quick Reference + +| Provider | Instance | Notes | +|----------|----------|-------| +| GCP | n1-standard-4 + T4 | Most common inference setup | +| AWS | g4dn.xlarge | 1× T4, 16 GB | +| AWS | g4dn.12xlarge | 4× T4, 64 GB total | +| Azure | NC4as T4 v3 | 1× T4 | diff --git a/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/references/troubleshooting.md b/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/references/troubleshooting.md new file mode 100644 index 0000000000000000000000000000000000000000..5dc4637c8d1141b7096137f4ae6538db0fdb538a --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/references/troubleshooting.md @@ -0,0 +1,328 @@ +# Troubleshooting Guide — SGLang Diffusion JIT CUDA Kernels + +Common issues and solutions when writing and integrating JIT CUDA kernels for SGLang Diffusion. + +> **Adapted from**: [HuggingFace kernels cuda-kernels skill](https://github.com/huggingface/kernels/tree/main/skills/cuda-kernels) + +--- + +## Build / Compile Issues + +### 1. JIT compilation fails: "No such file or directory" + +**Problem:** `load_jit` cannot find your `.cuh` file. + +``` +FileNotFoundError: .../jit_kernel/csrc/diffusion/your_op.cuh not found +``` + +**Fix:** Ensure the file is under `python/sglang/jit_kernel/csrc/diffusion/`. The path passed to `cuda_files` is relative to `csrc/`: + +```python +# CORRECT — file lives at csrc/diffusion/your_op.cuh +load_jit(..., cuda_files=["diffusion/your_op.cuh"]) +# resolves to: python/sglang/jit_kernel/csrc/diffusion/your_op.cuh + +# ALSO CORRECT — absolute path (pathlib replaces the csrc/ prefix) +load_jit(..., cuda_files=["/full/absolute/path/to/your_op.cuh"]) +``` + +### 2. 
Type conversion errors (FP16/BF16) + +**Problem:** Implicit FP16/BF16 conversion fails because PyTorch compiles with `-D__CUDA_NO_HALF_OPERATORS__`: + +``` +error: no suitable conversion function from "__half" to "float" exists +``` + +**Fix:** SGLang's `static_cast<float>` works because `fp16_t` and `bf16_t` are typedef'd with proper conversion operators. Always use explicit casts: + +```cpp +// CORRECT — explicit cast +float val = static_cast<float>(v[i]); // fp16_t / bf16_t → float +v[i] = static_cast<T>(fp32_result); // float → T + +// WRONG — implicit conversion (disabled by PyTorch build flags) +float val = v[i]; // compile error +v[i] = fp32_result; // compile error +``` + +If you need the raw intrinsics for packed types: +```cpp +// bf16x2_t → two floats +bf16x2_t packed = ...; +float v0 = __bfloat162float(packed.x); +float v1 = __bfloat162float(packed.y); +``` + +### 3. Template instantiation explodes / slow first compile + +**Problem:** Many template combinations make the first JIT compile very slow. + +**Fix:** Reduce template argument combinations. Move compile-time constants to runtime if they don't affect performance critically: + +```cpp +// Fewer template args = fewer instantiations +template <typename T> // only dtype varies +void my_op(tvm::ffi::TensorView dst, tvm::ffi::TensorView src, int block_size); +``` + +### 4. SM check: kernel requires sm_90 but device is sm_80 + +**Problem:** Kernel uses H100-only features on A100. + +**Fix:** Add a Python guard before calling `load_jit`: + +```python +cap = torch.cuda.get_device_capability() +if cap[0] < 9: + raise RuntimeError( + f"This kernel requires H100 (sm_90+). " + f"Got compute capability {cap[0]}.{cap[1]}. " + f"Use the Triton fallback instead: diffusion_triton_()" + ) +``` + +--- + +## Performance Issues + +### 5. Kernel is slower than Triton / PyTorch baseline + +**Steps to diagnose:** + +1. Check dtype: are you using `bf16_t` on T4? (T4 has no BF16 — silently falls back to slow emulation) +2. 
Check vectorization: is `hidden_size` divisible by `kVecN = 16/sizeof(T)` (8 for bf16, 4 for fp32)? +3. Profile with `ncu`: + ```bash + ncu --set full --csv -o metrics.csv \ + python -c "from sglang.jit_kernel.diffusion.rmsnorm import diffusion_rmsnorm; ..." + ``` + Look at `dram__throughput.avg.pct_of_peak_sustained_elapsed` — if < 30%, check coalescing. + +4. Check occupancy: run with `--ptxas-options=-v` in `extra_cuda_cflags` to see register usage. + +### 6. Shared memory bank conflicts + +**Problem:** `ncu` reports high `l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum` (stores) or `l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum` (loads). + +**Fix:** Add padding to shared memory arrays: + +```cpp +// Conflict (all threads hit same bank when stride=32) +__shared__ float data[32][32]; + +// Fixed with padding +__shared__ float data[32][33]; // 33 instead of 32 +``` + +### 7. Low occupancy from too many registers + +**Problem:** `nvcc --ptxas-options=-v` shows high register count; occupancy < 25%. + +**Fix:** Add `--maxrregcount=N` to limit registers: + +```python +extra_cuda_cflags=["-O3", "--use_fast_math", "--maxrregcount=64"] +``` + +Reduces registers per thread at the cost of possible register spilling to local memory. + +--- + +## Integration Issues + +### 8. RMSNorm weight is None (`elementwise_affine=False`) + +**Problem:** +``` +AttributeError: 'NoneType' object has no attribute 'data_ptr' +``` + +**Root Cause:** DiT transformer blocks often use `RMSNorm(dim, elementwise_affine=False)` — no learnable weight. + +**Fix in Python wrapper:** pass an empty tensor when weight is absent; the kernel launcher checks `data_ptr == nullptr`: + +```python +w = weight if weight is not None else torch.empty(0, dtype=src.dtype, device=src.device) +module.rmsnorm(out, src, w, eps) +``` + +**Fix in `.cuh` launcher:** + +```cpp +const T* w_ptr = (weight.data_ptr() != nullptr) + ? static_cast<const T*>(weight.data_ptr()) : nullptr; +// ... pass w_ptr to kernel ... 
+``` + +**Fix in module patching:** + +```python +has_weight = hasattr(module, "weight") and module.weight is not None +if has_weight: + def _fwd(mod, eps): + def forward(x): return diffusion_rmsnorm(x, weight=mod.weight, eps=eps) + return forward + module.forward = _fwd(module, module.eps) +else: + def _fwd_noweight(eps): + def forward(x): return diffusion_rmsnorm(x, weight=None, eps=eps) + return forward + module.forward = _fwd_noweight(module.eps) +``` + +### 9. `isinstance(module, torch.nn.RMSNorm)` misses diffusion variants + +**Problem:** Patching doesn't apply because diffusers / sglang diffusion models define their own `RMSNorm` class that is **not** a subclass of `torch.nn.RMSNorm`. + +**Fix:** Match by class name string: + +```python +# WRONG — misses diffusers/sglang RMSNorm +if isinstance(module, torch.nn.RMSNorm): + +# CORRECT — catches all variants +if type(module).__name__ == "RMSNorm": +# or for broader matching: +if "RMSNorm" in type(module).__name__: +``` + +### 10. Kernel patching doesn't persist after CPU offloading + +**Problem:** After calling `pipe.enable_model_cpu_offload()`, patched modules revert. + +**Fix:** Always inject **after** moving to CUDA, **before** enabling any offloading: + +```python +pipe = load_pipeline(...) +pipe.to("cuda") # 1. Move to CUDA +inject_optimized_kernels(pipe) # 2. Patch modules +pipe.enable_model_cpu_offload() # 3. Now safe to enable offloading +``` + +### 11. Kernel patched after `torch.compile` + +**Problem:** Module is already compiled; patching its `forward` after compilation has no effect. + +**Fix:** Apply patches **before** any `torch.compile` call: + +```python +inject_optimized_kernels(pipe) # FIRST: patch +pipe.transformer = torch.compile(...) # SECOND: compile +``` + +--- + +## `torch.compile` Compatibility + +### 12. 
Custom CUDA kernel causes graph break + +**Problem:** +``` +torch._dynamo.exc.Unsupported: Attempted to call function marked as skipped +``` +or: +``` +torch._dynamo.exc.TorchRuntimeError: Cannot access data pointer of Tensor (FakeTensor) +``` + +**Root Cause:** `torch.compile` traces with "fake tensors" that have no real data. Any kernel that calls `.data_ptr()` during tracing fails. + +**Options:** + +**Option A (simplest):** Don't use `torch.compile` with CUDA JIT kernels — use Triton instead: +```python +# Triton kernels are torch.compile compatible +from sglang.jit_kernel.diffusion.triton.norm import fused_rmsnorm +``` + +**Option B:** Register as a `@torch.library.custom_op` (advanced): +```python +import torch + +@torch.library.custom_op("diffusion_jit::rmsnorm", mutates_args={"out"}) +def _rmsnorm_op(out: torch.Tensor, src: torch.Tensor, + weight: torch.Tensor, eps: float) -> None: + module = _jit_rmsnorm_module(src.dtype) + module.rmsnorm(out, src, weight, eps) + +@_rmsnorm_op.register_fake +def _(out, src, weight, eps): + pass # no shape changes; output already allocated in 'out' +``` + +**Performance trade-off:** + +| Approach | Speedup (denoise) | torch.compile | Notes | +|----------|-------------------|---------------|-------| +| CUDA JIT kernel | best | Yes (via `torch.library.custom_op`) | Performance-optimal regardless of whether `torch.compile` is enabled; use `custom_op` + `register_fake` for compile compatibility | +| Triton kernel | good | Yes | Use when you need faster iteration/portability, or when you do not have a well-tuned CUDA kernel yet | +| Triton + compile | good | Yes | Use for end-to-end `torch.compile` integration convenience; typically slower than a well-tuned CUDA kernel | + +### 13. Unstable benchmark results from JIT timing + +**Problem:** First few runs are slow due to JIT compilation; timing is noisy. + +**Fix:** Use `triton.testing.do_bench` / `run_benchmark` which use CUDA-graph-based timing automatically. 
Always do a warmup run first: + +```python +# Pre-compile by running once before timing +diffusion_rmsnorm(dummy_src, weight=dummy_w, eps=1e-6) +torch.cuda.synchronize() +# Now time +result = run_benchmark(lambda: diffusion_rmsnorm(src, weight=w, eps=1e-6)) +``` + +--- + +## Debugging Checklist + +```bash +# 1. Verify CUDA device and compute capability +python -c "import torch; print(torch.cuda.get_device_name(), torch.cuda.get_device_capability())" + +# 2. Force synchronous CUDA execution to get real error location +CUDA_LAUNCH_BLOCKING=1 python scripts/bench_diffusion_rmsnorm.py + +# 3. Run memory sanitizer to catch illegal accesses +compute-sanitizer --tool memcheck python scripts/bench_diffusion_rmsnorm.py + +# 4. Check register and shared memory usage +# Add to extra_cuda_cflags: "--ptxas-options=-v" + +# 5a. Kernel-level profiling — full metrics +ncu --set full -o metrics.ncu-rep \ + python scripts/bench_diffusion_rmsnorm.py + +# 5b. Kernel-level profiling — targeted bandwidth + occupancy check +ncu --metrics sm__throughput.avg.pct_of_peak_sustained_elapsed,\ +dram__throughput.avg.pct_of_peak_sustained_elapsed \ + python scripts/bench_diffusion_rmsnorm.py + +# Key metrics to interpret: +# - sm__throughput : compute utilization % of peak +# - dram__throughput: memory bandwidth % of peak (target ≥ 30% on H100/A100) +# - smsp__warp_issue_stalled_*: warp stall breakdown (memory_dependency / math_pipe) + +# 6. System-level profiling (per-op breakdown inside sglang generate) +nsys profile -o denoise_profile \ + sglang generate --model-path=black-forest-labs/FLUX.1-dev \ + --width=1024 --height=1024 --num-inference-steps=50 \ + --seed=42 --enable-torch-compile --warmup + +# 7. 
Verify a patched module produces correct output +python - << 'EOF' +import torch +from sglang.jit_kernel.diffusion.rmsnorm import diffusion_rmsnorm + +x = torch.randn(4, 2048, dtype=torch.bfloat16, device="cuda") +w = torch.ones(2048, dtype=torch.bfloat16, device="cuda") + +out_jit = diffusion_rmsnorm(x, weight=w, eps=1e-6) +out_ref = torch.nn.functional.rms_norm(x.float(), (2048,), w.float(), eps=1e-6).to(torch.bfloat16) + +max_diff = (out_jit - out_ref).abs().max().item() +print(f"Max diff: {max_diff:.2e} ({'PASS' if max_diff < 0.02 else 'FAIL'})") +EOF +``` diff --git a/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/scripts/bench_diffusion_denoise.py b/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/scripts/bench_diffusion_denoise.py new file mode 100644 index 0000000000000000000000000000000000000000..2455147ae5ab20673d4f65eb50e5a51fe1f7204a --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/scripts/bench_diffusion_denoise.py @@ -0,0 +1,475 @@ +""" +End-to-end denoise-stage benchmark for SGLang Diffusion with/without custom JIT CUDA kernels. + +Measures denoise latency (primary metric ★) and peak GPU memory. +All model configs are kept in exact sync with diffusion-benchmark-and-profile.md. 
+ +Adapted from: https://github.com/huggingface/kernels/tree/main/skills/cuda-kernels + +Usage: + # Baseline — single model + python scripts/bench_diffusion_denoise.py --model flux + + # With custom JIT CUDA kernels + python scripts/bench_diffusion_denoise.py --model flux --custom-kernels + + # Side-by-side comparison + python scripts/bench_diffusion_denoise.py --model flux --compare + + # All 7 models, comparison + python scripts/bench_diffusion_denoise.py --all --compare + +Input images required for image-guided models: + mkdir -p /workspace/gen_benchmark/figs + wget -O /workspace/gen_benchmark/figs/cat.png \ + https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png + wget -O /workspace/gen_benchmark/figs/astronaut.jpg \ + https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg +""" + +import argparse +import json +import os +import subprocess +import time +from pathlib import Path +from typing import Optional + +# --------------------------------------------------------------------------- +# Model configs — kept in exact sync with diffusion-benchmark-and-profile.md +# Each entry produces the same `sglang generate` command as shown in that doc. +# --------------------------------------------------------------------------- +MODELS = { + # 1. Qwen/Qwen-Image-2512 — Text-to-Image, 1024×1024, 50 steps + "qwen": { + "path": "Qwen/Qwen-Image-2512", + "prompt": "A futuristic cyberpunk city at night, neon lights reflecting on wet streets, highly detailed, 8k", + "negative_prompt": " ", + "extra_args": [ + "--width=1024", + "--height=1024", + "--num-inference-steps=50", + "--guidance-scale=4.0", + "--dit-cpu-offload", + "false", + "--text-encoder-cpu-offload", + "false", + ], + }, + # 2. 
Qwen/Qwen-Image-Edit-2511 — Image Editing, 1024×1024, 50 steps + # Requires: /workspace/gen_benchmark/figs/cat.png + "qwen-edit": { + "path": "Qwen/Qwen-Image-Edit-2511", + "prompt": "Transform into anime style", + "negative_prompt": " ", + "image_path": "/workspace/gen_benchmark/figs/cat.png", + "extra_args": [ + "--width=1024", + "--height=1024", + "--num-inference-steps=50", + "--guidance-scale=4.0", + "--dit-cpu-offload", + "false", + "--text-encoder-cpu-offload", + "false", + ], + }, + # 3. black-forest-labs/FLUX.1-dev — Text-to-Image, 1024×1024, 50 steps + "flux": { + "path": "black-forest-labs/FLUX.1-dev", + "prompt": "A futuristic cyberpunk city at night, neon lights reflecting on wet streets, highly detailed, 8k", + "extra_args": [ + "--width=1024", + "--height=1024", + "--num-inference-steps=50", + "--guidance-scale=4.0", + ], + }, + # 4. black-forest-labs/FLUX.2-dev — Text-to-Image, 1024×1024 + "flux2": { + "path": "black-forest-labs/FLUX.2-dev", + "prompt": "A Logo With Bold Large Text: SGL Diffusion", + "extra_args": [ + "--width=1024", + "--height=1024", + "--dit-layerwise-offload", + "false", + "--dit-cpu-offload", + "false", + "--text-encoder-cpu-offload", + "true", + "--vae-cpu-offload", + "false", + ], + }, + # 5. Tongyi-MAI/Z-Image-Turbo — Turbo Text-to-Image, 1024×1024, 9 steps + "zimage": { + "path": "Tongyi-MAI/Z-Image-Turbo", + "prompt": "A fantasy landscape with mountains and a river, detailed, vibrant colors", + "extra_args": [ + "--width=1024", + "--height=1024", + "--num-inference-steps=9", + "--guidance-scale=0.0", + "--dit-cpu-offload", + "false", + "--text-encoder-cpu-offload", + "false", + ], + }, + # 6. Wan-AI/Wan2.2-T2V-A14B-Diffusers — Text-to-Video, 720P, 8 GPUs, 81 frames, 40 steps + "wan-t2v": { + "path": "Wan-AI/Wan2.2-T2V-A14B-Diffusers", + "prompt": "A cat and a dog baking a cake together in a kitchen. 
The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon.", + "negative_prompt": " ", + "extra_args": [ + "--720p", + "--num-inference-steps=40", + "--num-frames=81", + "--guidance-scale=5.0", + "--num-gpus=8", + "--enable-cfg-parallel", + "--ulysses-degree=4", + "--dit-layerwise-offload", + "true", + "--dit-cpu-offload", + "false", + "--vae-cpu-offload", + "false", + "--text-encoder-cpu-offload", + "true", + ], + }, + # 7. Wan-AI/Wan2.2-TI2V-5B-Diffusers — Text-Image-to-Video, 720P, 1 GPU, 81 frames, 50 steps + # Requires: /workspace/gen_benchmark/figs/astronaut.jpg + "wan-ti2v": { + "path": "Wan-AI/Wan2.2-TI2V-5B-Diffusers", + "prompt": "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot.", + "negative_prompt": "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards", + "image_path": "/workspace/gen_benchmark/figs/astronaut.jpg", + "extra_args": [ + "--num-frames", + "81", + "--720p", + "--num-inference-steps", + "50", + "--guidance-scale", + "5.0", + "--dit-layerwise-offload", + "false", + "--dit-cpu-offload", + "false", + "--vae-cpu-offload", + "false", + "--text-encoder-cpu-offload", + "false", + ], + }, +} + + +def build_sglang_cmd( + model_key: str, + use_custom_kernels: bool, + perf_dump_path: Optional[str] = None, + warmup: bool = True, + torch_compile: bool = True, + seed: int = 42, + save_output: bool = True, +) -> list[str]: + """ + Build the `sglang generate` command for the given model. 
+ Matches the commands in diffusion-benchmark-and-profile.md exactly. + """ + cfg = MODELS[model_key] + + cmd = [ + "sglang", + "generate", + f"--model-path={cfg['path']}", + f"--prompt={cfg['prompt']}", + f"--seed={seed}", + "--log-level=info", + ] + + if "negative_prompt" in cfg: + cmd.append(f"--negative-prompt={cfg['negative_prompt']}") + + if "image_path" in cfg: + cmd.append(f"--image-path={cfg['image_path']}") + + cmd.extend(cfg["extra_args"]) + + if save_output: + cmd.append("--save-output") + if warmup: + cmd.append("--warmup") + if torch_compile: + cmd.append("--enable-torch-compile") + if perf_dump_path: + cmd.extend(["--perf-dump-path", perf_dump_path]) + + return cmd + + +def run_benchmark_once( + model_key: str, + use_custom_kernels: bool, + output_dir: Path, + warmup: bool = True, +) -> dict: + """Run a single benchmark pass and return results dict.""" + label = "custom" if use_custom_kernels else "baseline" + perf_path = output_dir / f"{model_key}_{label}.json" + + cmd = build_sglang_cmd( + model_key, + use_custom_kernels=use_custom_kernels, + perf_dump_path=str(perf_path), + warmup=warmup, + ) + + env = os.environ.copy() + if use_custom_kernels: + # NOTE: This env var is a convention for user-implemented kernel injection + # logic. SGLang runtime does not read it by default — you must add handling + # in your denoising stage or model code to check this var and apply patches. 
+ env["SGLANG_DIFFUSION_CUSTOM_CUDA_KERNELS"] = "1" + + print(f"\n{'=' * 64}") + print(f"[{label.upper()}] {model_key}") + print(" " + " \\\n ".join(cmd)) + print() + + t0 = time.time() + result = subprocess.run(cmd, env=env, text=True) + elapsed = time.time() - t0 + + if result.returncode != 0: + print(f" ERROR: exit code {result.returncode}") + return {"model": model_key, "label": label, "error": True, "elapsed_s": elapsed} + + metrics = {"model": model_key, "label": label, "elapsed_s": elapsed, "error": False} + if perf_path.exists(): + try: + with open(perf_path) as f: + perf = json.load(f) + + # e2e latency: total_duration_ms (set by PerformanceLogger.dump_benchmark_report) + total_ms = perf.get("total_duration_ms") + metrics["e2e_latency_s"] = ( + float(total_ms) / 1000.0 if total_ms is not None else None + ) + + # denoise latency: look in "steps" list for the "DenoisingStage" entry + # steps = [{"name": "DenoisingStage", "duration_ms": 1234.5}, ...] + denoise_latency_s = None + for step in perf.get("steps", []): + if ( + step.get("name") == "DenoisingStage" + and step.get("duration_ms") is not None + ): + denoise_latency_s = float(step["duration_ms"]) / 1000.0 + break + + # fallback: sum all per-step durations from denoise_steps_ms + # denoise_steps_ms = [{"step": 0, "duration_ms": 100.5}, ...] 
+ if denoise_latency_s is None: + denoise_steps = perf.get("denoise_steps_ms", []) + if denoise_steps: + denoise_latency_s = ( + sum(s.get("duration_ms", 0.0) for s in denoise_steps) / 1000.0 + ) + metrics["denoise_latency_s"] = denoise_latency_s + + # peak memory: max peak_reserved_mb across all memory checkpoints (→ GB) + # memory_checkpoints = {"after_DenoisingStage": {"peak_reserved_mb": 12288.0, ...}} + peak_memory_gb = None + for snapshot in perf.get("memory_checkpoints", {}).values(): + peak_mb = snapshot.get("peak_reserved_mb") + if peak_mb is not None: + candidate = float(peak_mb) / 1024.0 + if peak_memory_gb is None or candidate > peak_memory_gb: + peak_memory_gb = candidate + metrics["peak_memory_gb"] = peak_memory_gb + + except Exception as e: + print(f" Warning: could not parse perf dump: {e}") + + return metrics + + +def print_results_table(results: list[dict]): + """Print baseline vs custom kernel comparison table.""" + print() + print("=" * 80) + print("BENCHMARK RESULTS — Denoise Latency (primary metric ★)") + print("(Models and params match diffusion-benchmark-and-profile.md)") + print("=" * 80) + + by_model: dict[str, dict] = {} + for r in results: + by_model.setdefault(r["model"], {})[r["label"]] = r + + print( + f"{'Model':<16} {'Baseline(s)':>12} {'Custom(s)':>10} {'Speedup':>9} {'Peak Mem(GB)':>14}" + ) + print("-" * 64) + + for model_key in MODELS: # preserve order + if model_key not in by_model: + continue + runs = by_model[model_key] + base = runs.get("baseline", {}) + custom = runs.get("custom", {}) + + base_lat = base.get("denoise_latency_s") + custom_lat = custom.get("denoise_latency_s") + peak_mem = base.get("peak_memory_gb") or custom.get("peak_memory_gb") + + speedup = f"{base_lat / custom_lat:.2f}x" if base_lat and custom_lat else "n/a" + base_s = f"{base_lat:.2f}" if base_lat else "n/a" + custom_s = f"{custom_lat:.2f}" if custom_lat else "n/a" + mem_s = f"{peak_mem:.1f}" if isinstance(peak_mem, float) else "n/a" + + 
print(f"{model_key:<16} {base_s:>12} {custom_s:>10} {speedup:>9} {mem_s:>14}") + + print("-" * 64) + print() + print("★ Denoise latency = total DiT forward pass time across all inference steps.") + print( + " See diffusion-benchmark-and-profile.md for full Level 1/2 profiling workflow." + ) + + +def inject_kernels_example(): + """ + Show the kernel injection pattern used when SGLANG_DIFFUSION_CUSTOM_CUDA_KERNELS=1. + After implementing add-cuda-kernel.md, this logic lives in denoising.py or + the model's transformer.py — NOT in this script. + + Call patch_rmsnorm(dit_model) BEFORE torch.compile and BEFORE any CPU offloading. + """ + import torch.nn as nn + + try: + from sglang.jit_kernel.diffusion.rmsnorm import diffusion_rmsnorm + except ImportError: + print( + "diffusion.rmsnorm JIT kernel not available. " + "Implement add-cuda-kernel.md first." + ) + return + + def patch_rmsnorm(model: nn.Module, verbose: bool = False) -> int: + """Monkey-patch all RMSNorm variants to use the JIT CUDA kernel.""" + patched = 0 + for name, module in model.named_modules(): + if "RMSNorm" not in type(module).__name__: + continue + eps = getattr(module, "eps", getattr(module, "variance_epsilon", 1e-6)) + has_weight = hasattr(module, "weight") and module.weight is not None + + if has_weight: + + def _make(mod, ep): + def fwd(x): + return diffusion_rmsnorm(x, weight=mod.weight, eps=ep) + + return fwd + + module.forward = _make(module, eps) + else: + + def _make_no_w(ep): + def fwd(x): + return diffusion_rmsnorm(x, weight=None, eps=ep) + + return fwd + + module.forward = _make_no_w(eps) + + patched += 1 + if verbose: + print(f" Patched: {name} (weight={has_weight})") + return patched + + return patch_rmsnorm + + +def main(): + parser = argparse.ArgumentParser( + description="SGLang Diffusion denoise benchmark — baseline vs JIT CUDA kernels" + ) + parser.add_argument( + "--model", + choices=list(MODELS.keys()), + help="Model to benchmark (default: flux)", + ) + 
parser.add_argument("--all", action="store_true", help="Benchmark all 7 models") + parser.add_argument( + "--custom-kernels", + action="store_true", + help="Run with custom JIT CUDA kernels (SGLANG_DIFFUSION_CUSTOM_CUDA_KERNELS=1)", + ) + parser.add_argument( + "--no-custom-kernels", + action="store_true", + help="Run baseline (no custom kernels)", + ) + parser.add_argument( + "--compare", + action="store_true", + help="Run both baseline and custom, print comparison table", + ) + parser.add_argument( + "--output-dir", + type=str, + default="/workspace/gen_benchmark/bench_results", + help="Directory for perf dump JSON files", + ) + parser.add_argument("--no-warmup", action="store_true", help="Skip warmup") + parser.add_argument( + "--show-injection-example", + action="store_true", + help="Print kernel injection pattern and exit", + ) + + args = parser.parse_args() + + if args.show_injection_example: + patch_fn = inject_kernels_example() + if patch_fn: + print( + "patch_rmsnorm function defined. " + "Call it on the DiT model before torch.compile and CPU offloading." + ) + return + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + warmup = not args.no_warmup + + models_to_run = list(MODELS.keys()) if args.all else [args.model or "flux"] + results = [] + + for model_key in models_to_run: + if args.compare: + results.append(run_benchmark_once(model_key, False, output_dir, warmup)) + results.append(run_benchmark_once(model_key, True, output_dir, warmup)) + elif args.custom_kernels: + results.append(run_benchmark_once(model_key, True, output_dir, warmup)) + else: + results.append(run_benchmark_once(model_key, False, output_dir, warmup)) + + if results: + print_results_table(results) + + print(f"Perf dump JSONs → {output_dir}") + print( + "Compare across runs: follow diffusion-benchmark-and-profile.md → Perf dump & before/after compare." 
+ ) + + +if __name__ == "__main__": + main() diff --git a/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/scripts/bench_diffusion_rmsnorm.py b/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/scripts/bench_diffusion_rmsnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..d02788c73a26f60ffd754943dacb276bbdf51ca4 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/scripts/bench_diffusion_rmsnorm.py @@ -0,0 +1,193 @@ +""" +Micro-benchmark for the SGLang Diffusion JIT CUDA RMSNorm kernel. + +Compares: + 1. SGLang JIT CUDA kernel (diffusion_rmsnorm) + 2. PyTorch baseline (torch.nn.functional.rms_norm) + +Adapted from: https://github.com/huggingface/kernels/tree/main/skills/cuda-kernels + +Usage: + python scripts/bench_diffusion_rmsnorm.py + +Requirements: + pip install triton # for triton.testing timing utilities + # SGLang must be installed and CUDA available +""" + +import time +from typing import Tuple + +import torch + +# --------------------------------------------------------------------------- +# Import the JIT CUDA kernel. +# When you implement add-cuda-kernel.md, the file will be at: +# python/sglang/jit_kernel/diffusion/rmsnorm.py +# --------------------------------------------------------------------------- +try: + from sglang.jit_kernel.diffusion.rmsnorm import diffusion_rmsnorm + + JIT_AVAILABLE = True +except ImportError: + JIT_AVAILABLE = False + print( + "WARNING: diffusion.rmsnorm JIT kernel not available. " + "Run after implementing add-cuda-kernel.md." 
+ ) + + +def pytorch_rmsnorm( + x: torch.Tensor, + weight: torch.Tensor | None = None, + eps: float = 1e-6, +) -> torch.Tensor: + """Reference PyTorch implementation of RMSNorm.""" + hidden = x.shape[-1] + return torch.nn.functional.rms_norm( + x.float(), (hidden,), weight.float() if weight is not None else None, eps=eps + ).to(x.dtype) + + +def benchmark_kernel( + func, + args, + warmup: int = 20, + iterations: int = 100, +) -> Tuple[float, float]: + """Benchmark a kernel function. Returns (avg_ms, min_ms).""" + for _ in range(warmup): + func(*args) + torch.cuda.synchronize() + + times = [] + for _ in range(iterations): + torch.cuda.synchronize() + t0 = time.perf_counter() + func(*args) + torch.cuda.synchronize() + times.append((time.perf_counter() - t0) * 1000) + + return sum(times) / len(times), min(times) + + +def run_benchmark(): + print("=" * 72) + print("SGLang Diffusion RMSNorm Micro-Benchmark: JIT CUDA vs PyTorch") + print("=" * 72) + print(f"Device: {torch.cuda.get_device_name(0)}") + cap = torch.cuda.get_device_capability() + print(f"Compute Capability: sm_{cap[0]}{cap[1]}") + print() + + if not JIT_AVAILABLE: + print("Skipping JIT kernel benchmark (kernel not available).") + return + + # Determine dtype: T4 (sm_75) has no BF16 + dtype = torch.bfloat16 if cap >= (8, 0) else torch.float16 + print(f"Dtype: {dtype}") + print() + + # Typical DiT hidden sizes for sglang diffusion models: + # FLUX.1-dev: hidden=3072 + # Qwen-Image: hidden=2048 + # Wan2.2: hidden=4096 + configs = [ + # (batch_tokens, hidden_size, has_weight) + (1024, 2048, True), # Qwen-Image: 1 sample × 1024 tokens + (4096, 2048, True), # Qwen-Image: larger batch + (1024, 3072, True), # FLUX: 1 sample × 1024 tokens + (4096, 3072, True), # FLUX: larger + (4096, 4096, True), # Wan2.2 + (4096, 2048, False), # no-weight (elementwise_affine=False) + (16384, 3072, True), # long sequence + ] + + print( + f"{'Config':<32} {'JIT(ms)':>10} {'PyTorch(ms)':>12} {'Speedup':>9} {'Weight'}" + ) + print("-" * 
72) + + total_speedup = 0 + n = 0 + + for batch_tokens, hidden, has_weight in configs: + x = torch.randn(batch_tokens, hidden, dtype=dtype, device="cuda") + weight = torch.ones(hidden, dtype=dtype, device="cuda") if has_weight else None + + jit_avg, _ = benchmark_kernel( + diffusion_rmsnorm, (x, weight, 1e-6), warmup=20, iterations=100 + ) + pt_avg, _ = benchmark_kernel( + pytorch_rmsnorm, (x, weight, 1e-6), warmup=20, iterations=100 + ) + + speedup = pt_avg / jit_avg + total_speedup += speedup + n += 1 + + w_str = "yes" if has_weight else "no " + cfg = f"[{batch_tokens}×{hidden}]" + print(f"{cfg:<32} {jit_avg:>10.3f} {pt_avg:>12.3f} {speedup:>8.2f}x {w_str}") + + print("-" * 72) + print(f"{'Average Speedup':>56} {total_speedup / n:.2f}x") + print() + + # ----------------------------------------------------------------------- + # Correctness check + # ----------------------------------------------------------------------- + print("Correctness Check (BF16 tolerance 0.02):") + x = torch.randn(4096, 3072, dtype=dtype, device="cuda") + weight = torch.ones(3072, dtype=dtype, device="cuda") + + out_jit = diffusion_rmsnorm(x, weight=weight, eps=1e-6) + out_ref = pytorch_rmsnorm(x, weight=weight, eps=1e-6) + + max_diff = (out_jit - out_ref).abs().max().item() + rel_diff = ((out_jit - out_ref).abs() / (out_ref.abs() + 1e-8)).max().item() + passed = max_diff < 0.02 + + print(f" Max absolute diff: {max_diff:.2e}") + print(f" Max relative diff: {rel_diff:.2e}") + print(f" Correctness: {'PASS ✓' if passed else 'FAIL ✗'}") + print() + + # ----------------------------------------------------------------------- + # Memory bandwidth analysis + # ----------------------------------------------------------------------- + print("Memory Bandwidth Analysis:") + bt, hid = 4096, 3072 + x = torch.randn(bt, hid, dtype=dtype, device="cuda") + weight = torch.ones(hid, dtype=dtype, device="cuda") + + bytes_per_elem = dtype.itemsize + total_bytes = ( + bt * hid + hid + bt * hid + ) * 
bytes_per_elem # read x + read w + write out + jit_avg, _ = benchmark_kernel(diffusion_rmsnorm, (x, weight, 1e-6)) + + bandwidth_gbps = (total_bytes / 1e9) / (jit_avg / 1000) + theoretical_bw = { + (9, 0): 3350, # H100: 3.35 TB/s + (8, 0): 2000, # A100 80GB + }.get( + cap, 320 + ) # T4: 320 GB/s + efficiency = bandwidth_gbps / theoretical_bw * 100 + + print(f" Shape: [{bt} × {hid}] dtype: {dtype}") + print(f" Total data: {total_bytes / 1e6:.1f} MB") + print(f" Achieved: {bandwidth_gbps:.1f} GB/s") + print(f" Theoretical ({torch.cuda.get_device_name(0)}): {theoretical_bw} GB/s") + print(f" Bandwidth efficiency: {efficiency:.1f}%") + print() + print("Target: ≥ 30% efficiency (H100/A100), ≥ 40% (T4)") + + +if __name__ == "__main__": + if not torch.cuda.is_available(): + print("CUDA not available.") + else: + run_benchmark() diff --git a/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/use-efficient-diffusion-kernels.md b/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/use-efficient-diffusion-kernels.md new file mode 100644 index 0000000000000000000000000000000000000000..c77216ce4077fbac83d1c1ec33b81fcdb9d52347 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-kernel/use-efficient-diffusion-kernels.md @@ -0,0 +1,112 @@ +--- +name: use-efficient-diffusion-kernels +description: Guidance for using SGLang Diffusion fused kernels and fast CUDA paths. Use when mapping fusion patterns in diffusion inference, choosing fused ops or attention backends, handling RoPE/QK norm performance pitfalls, or integrating new diffusion models with kernel-aware constraints. +--- + +# Use Efficient Diffusion Kernels + +**Overview** +This skill focuses on SGLang Diffusion (`sglang.multimodal_gen`) kernel fusion patterns and fast CUDA paths. Prefer existing fused ops (Triton, CuTe DSL, sgl-kernel). Make constraints and fallbacks explicit. 
+ +**Key Files** +- `python/sglang/multimodal_gen/runtime/layers/layernorm.py` +- `python/sglang/multimodal_gen/runtime/layers/elementwise.py` +- `python/sglang/multimodal_gen/runtime/layers/rotary_embedding/utils.py` +- `python/sglang/jit_kernel/diffusion/triton/scale_shift.py` +- `python/sglang/jit_kernel/diffusion/triton/norm.py` +- `python/sglang/jit_kernel/diffusion/triton/rmsnorm_onepass.py` +- `python/sglang/jit_kernel/diffusion/triton/rotary.py` +- `python/sglang/jit_kernel/diffusion/cutedsl/scale_residual_norm_scale_shift.py` +- `python/sglang/jit_kernel/norm.py` +- `python/sglang/multimodal_gen/runtime/platforms/cuda.py` +- `python/sglang/multimodal_gen/runtime/layers/attention/selector.py` +- `docs/diffusion/performance/attention_backends.md` (repo root) + +**Core Fusion Patterns** + +1. Scale/Shift elementwise fusion (AdaLN modulation) +- Kernels: `fuse_scale_shift_kernel`, `fuse_scale_shift_gate_select01_kernel` +- Locations: `elementwise.py`, `layernorm.py`, `qwen_image.py`, `triton/scale_shift.py` +- Use cases: `x * (1 + scale) + shift` and `a * (k + b) + c` +- Constraints: `x` must be CUDA and contiguous. `scale/shift` support 0D/1D/2D/3D/4D broadcast. 4D `[B, F, 1, C]` requires `L % F == 0`. +- NPU fallback: `scale_shift.py` swaps to `npu_fallback` native path. + +2. Norm + Scale/Shift fusion (CuTe DSL) +- Kernels: `fused_norm_scale_shift`, `fused_scale_residual_norm_scale_shift` +- Locations: `layernorm.py`, `cutedsl/scale_residual_norm_scale_shift.py` +- Use cases: + - `y = norm(x) * (1 + scale) + shift` + - `y = norm(residual + gate * x) * (1 + scale) + shift` +- Constraints: `D % 256 == 0` and `D <= 8192`. `x/residual/gate/scale/shift` must pass shape and stride validation. Dtypes limited to fp16/bf16/fp32. +- Behavior: CuTe DSL compilation cached by `(dtype, ndim, D, norm_type)`. `None` tensors replaced by scalar placeholders. If constraints fail, `layernorm.py` warns and falls back to native PyTorch. + +3. 
Triton LayerNorm/RMSNorm fusion +- Kernels: `rms_norm_fn`, `layer_norm_fn`, `norm_infer` +- Locations: `triton/norm.py`, `layernorm.py` +- Use cases: fp32 RMSNorm with residual/dropout/rowscale/x1 branches, and inference-friendly `norm_infer`. +- Constraints: last dim must be contiguous, and `N * element_size < 64KB`. + +4. Triton one-pass RMSNorm (small hidden size fast path) +- Kernel: `triton_one_pass_rms_norm` +- Locations: `triton/rmsnorm_onepass.py`, `layernorm.py` +- Use case: `hidden_size <= 128` in `RMSNorm.forward_cuda`. + +5. Triton RoPE fusion +- Kernel: `apply_rotary_embedding` +- Locations: `triton/rotary.py`, `rotary_embedding/utils.py` +- Use case: GPT-J style RoPE when not Neox. +- Constraints: `head_size` must be even. +- NPU fallback: `npu_fallback.apply_rotary_embedding_native`. + +**Faster CUDA Kernel Usage Points** + +1. sgl-kernel RMSNorm and fused add RMSNorm +- Location: `layernorm.py` +- Behavior: CUDA uses `sgl_kernel.fused_add_rmsnorm` and `sgl_kernel.rmsnorm`. `hidden_size <= 128` uses Triton one-pass. ROCm falls back to native. + +2. Attention backend selection (FlashAttention, Sage, SDPA) +- Locations: `platforms/cuda.py`, `attention/selector.py`, `docs/diffusion/performance/attention_backends.md` +- Behavior: CUDA prefers FlashAttention (FA3/FA4) when supported, otherwise Torch SDPA. Force via `--attention-backend` or `global_force_attn_backend`. + +3. FlashInfer RoPE (Q/K inplace) +- Location: `rotary_embedding/utils.py` +- Behavior: `flashinfer.rope.apply_rope_with_cos_sin_cache_inplace` when available, otherwise Triton RoPE fallback. + +**QK Norm Optimization** + +- Entry point: `apply_qk_norm` in `layernorm.py`. +- Fast path: JIT fused inplace QK norm from `python/sglang/jit_kernel/norm.py` via `fused_inplace_qknorm`. +- Preconditions for fused path: + - CUDA only. + - `allow_inplace=True` and `q_eps == k_eps`. + - `can_use_fused_inplace_qknorm(head_dim, dtype)` returns true. + - Supported head dims: `64, 128, 256, 512, 1024`. 
+- Behavior: Fused path operates on `q` and `k` in place after reshaping to `[B, -1, head_dim]`. If preconditions fail, fall back to per-tensor RMSNorm. + +**Common Entry Points in Diffusion Models** +- AdaLN modulation: `LayerNormScaleShift`, `RMSNormScaleShift`, `ScaleResidual*` in `layernorm.py`. +- Qwen-Image gating: `fuse_scale_shift_gate_select01_kernel` in `qwen_image.py`. +- QK norm: `apply_qk_norm` used in `flux.py`, `flux_2.py`, `qwen_image.py`, `zimage.py`, `wanvideo.py`, `ltx_2.py`, `hunyuanvideo.py`. +- RoPE: `_apply_rotary_emb` prefers Triton; Q/K RoPE prefers FlashInfer when present. + +**Constraints and Fallbacks** +- `scale_shift` Triton requires CUDA + contiguous `x`. NPU swaps to native. +- CuTe DSL fused norms require `D % 256 == 0` and `D <= 8192`. +- Triton norm kernels error on feature size >= 64KB. +- FlashAttention requires fp16/bf16 and SM80+; otherwise SDPA. + +**Integration Checklist for New Models** + +1. Reuse `LayerNormScaleShift` or `ScaleResidual*` modules instead of re-implementing fusion logic. +2. Keep tensors contiguous and satisfy D alignment (`% 256`) and size (`<= 8192`) for CuTe fused paths. +3. Use `fuse_scale_shift_kernel` for AdaLN modulation and keep a PyTorch fallback. +4. Use `apply_qk_norm` and ensure head_dim is in the supported list for fused QK norm. +5. If using FlashInfer RoPE, avoid `pack qkv` and ensure Q/K are contiguous. +6. For attention, follow `selector.py` priority; override with CLI only if needed. + +**When Extending or Modifying Kernels** +- Add `torch.library.custom_op` and `register_fake` for compile and meta support. +- Keep CuTe compile cache keys aligned to `(dtype, ndim, D, norm_type)`. +- Avoid implicit broadcasts that force hidden `contiguous()` copies. +- Preserve NPU and ROCm fallback paths. +- **Always verify with ncu** (`ncu --set full`) that the kernel achieves adequate memory bandwidth utilization (>70% of peak for bandwidth-bound ops) and occupancy (>50%). 
See `diffusion-benchmark-and-profile.md` Step 3.5 for the ncu workflow. diff --git a/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-perf/SKILL.md b/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-perf/SKILL.md new file mode 100644 index 0000000000000000000000000000000000000000..25b2b53829e5a09807d355c73e16f9b5101a75cb --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/.claude/skills/diffusion-perf/SKILL.md @@ -0,0 +1,18 @@ +--- +name: diffusion-perf +description: Deprecated alias (merged into diffusion-kernel). +user-invocable: false +allowed-tools: Bash, Read +argument-hint: [--prompt "..."] [--baseline baseline.json] +--- + +# Diffusion Performance Measurement + +This skill has been merged into the canonical docs under `diffusion-kernel`: + +- `../diffusion-kernel/diffusion-benchmark-and-profile.md` → **Perf dump & before/after compare** + +Follow that document as the single source of truth: + +- Always run `sglang generate ... --warmup --perf-dump-path .json` +- Use `python python/sglang/multimodal_gen/benchmarks/compare_perf.py ` to generate a PR-ready comparison table diff --git a/sglang/python/sglang/multimodal_gen/README.md b/sglang/python/sglang/multimodal_gen/README.md new file mode 100644 index 0000000000000000000000000000000000000000..313567fdecb1ba0b9c245e786c8d24946c4e3b54 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/README.md @@ -0,0 +1,95 @@ +
+ +
+ +**SGLang diffusion is an inference framework for accelerated image/video generation.** + +SGLang diffusion features an end-to-end unified pipeline for accelerating diffusion models. It is designed to be modular and extensible, allowing users to easily add new models and optimizations. + +## Key Features + +SGLang Diffusion has the following features: + - Broad model support: Wan series, FastWan series, Hunyuan, Qwen-Image, Qwen-Image-Edit, Flux, Z-Image, GLM-Image + - Fast inference speed: empowered by highly optimized kernels from sgl-kernel and an efficient scheduler loop + - Ease of use: OpenAI-compatible API, CLI, and Python SDK support + - Multi-platform support: NVIDIA GPUs (H100, H200, A100, B200, 4090) and AMD GPUs (MI300X, MI325X) + +### AMD/ROCm Support + +SGLang Diffusion supports AMD Instinct GPUs through ROCm. On AMD platforms, we use the Triton attention backend and leverage AITER kernels for optimized layernorm and other operations. See the [installation guide](https://github.com/sgl-project/sglang/tree/main/docs/diffusion/installation.md) for setup instructions. + +### Moore Threads/MUSA Support + +SGLang Diffusion supports Moore Threads GPUs (MTGPU) through the MUSA software stack. On MUSA platforms, we use the Torch SDPA backend for attention. See the [installation guide](https://github.com/sgl-project/sglang/tree/main/docs/diffusion/installation.md) for setup instructions. + +## Getting Started + +```bash +uv pip install 'sglang[diffusion]' --prerelease=allow +``` + +For more installation methods (e.g. pypi, uv, docker, ROCm/AMD, MUSA/Moore Threads), check [install.md](https://github.com/sgl-project/sglang/tree/main/docs/diffusion/installation.md). 
+ +## Inference + +Here's a minimal example to generate a video using the default settings: + +```python +from sglang.multimodal_gen import DiffGenerator + +def main(): + # Create a diff generator from a pre-trained model + generator = DiffGenerator.from_pretrained( + model_path="Wan-AI/Wan2.1-T2V-1.3B-Diffusers", + num_gpus=1, # Adjust based on your hardware + ) + + # Generate the video + video = generator.generate( + sampling_params_kwargs=dict( + prompt="A curious raccoon peers through a vibrant field of yellow sunflowers, its eyes wide with interest.", + return_frames=True, # Also return frames from this call (defaults to False) + output_path="my_videos/", # Controls where videos are saved + save_output=True + ) + ) + +if __name__ == '__main__': + main() +``` + +Or, more simply, with the CLI: + +```bash +sglang generate --model-path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ + --text-encoder-cpu-offload --pin-cpu-memory \ + --prompt "A curious raccoon" \ + --save-output +``` + +### LoRA support + +Apply LoRA adapters via `--lora-path`: + +```bash +sglang generate \ + --model-path Qwen/Qwen-Image-Edit-2511 \ + --lora-path prithivMLmods/Qwen-Image-Edit-2511-Anime \ + --prompt "Transform into anime." \ + --image-path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png" \ + --save-output +``` + +For more usage examples (e.g. OpenAI compatible API, server mode), check [cli.md](https://github.com/sgl-project/sglang/tree/main/docs/diffusion/cli.md). + +## Contributing + +All contributions are welcome. The contribution guide is available [here](https://github.com/sgl-project/sglang/tree/main/docs/diffusion/contributing.md). + +## Acknowledgement + +We learnt and reused code from the following projects: + +- [FastVideo](https://github.com/hao-ai-lab/FastVideo.git). The major components of this repo are based on a fork of FastVideo on Sept. 24, 2025. +- [xDiT](https://github.com/xdit-project/xDiT). 
We used the parallelism library from it. +- [diffusers](https://github.com/huggingface/diffusers) We used the pipeline design from it. diff --git a/sglang/python/sglang/multimodal_gen/__init__.py b/sglang/python/sglang/multimodal_gen/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d7545560e435feb6082cf2dec89b945101efe77 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/__init__.py @@ -0,0 +1,8 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo +from sglang.multimodal_gen.configs.pipeline_configs import PipelineConfig +from sglang.multimodal_gen.configs.sample import SamplingParams +from sglang.multimodal_gen.runtime.entrypoints.diffusion_generator import DiffGenerator + +__all__ = ["DiffGenerator", "PipelineConfig", "SamplingParams"] + +# Trigger multimodal CI tests diff --git a/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/README.md b/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3d1146121e03989de57c349b39ac835f1fdb2252 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/README.md @@ -0,0 +1,59 @@ +# ComfyUI SGLDiffusion Plugin + +A ComfyUI plugin for integrating with SGLang Diffusion server, supporting image and video generation capabilities. + +## Installation + +1. **Install SGLang**: Follow the [Installation Guide](../../docs/install.md) to install `sglang[diffusion]`. +2. **Install Plugin**: Copy this entire directory (`ComfyUI_SGLDiffusion`) to your ComfyUI `custom_nodes/` folder. +3. **Restart ComfyUI**: Restart ComfyUI to load the plugin. + +## Usage + +The plugin supports two modes of operation: **Server Mode** (via HTTP API) and **Integrated Mode** (tight integration with ComfyUI). 
+ +### Supported Models +- **Z-Image**: High-speed image generation models (e.g., `Z-Image-Turbo`) +- **FLUX**: State-of-the-art text-to-image models (e.g., `FLUX.1-dev`) +- **Qwen-Image**: Multi-modal image generation models (e.g., `Qwen-Image`,`Qwen-Image-2512`). *Note: Image editing support is currently experimental and may have some issues.* + +### Mode 1: Server Mode (HTTP API) +Connect to a standalone SGLang Diffusion server. + +1. **Start SGLang Diffusion Server**: Ensure the server is running and accessible. +2. **Connect to Server**: Use the `SGLDiffusion Server Model` node to connect (default: `http://localhost:3000/v1`). +3. **Generate Content**: + - `SGLDiffusion Generate Image`: For text-to-image and image editing. + - `SGLDiffusion Generate Video`: For text-to-video and image-to-video. +4. **LoRA Support**: Use `SGLDiffusion Server Set LoRA` and `SGLDiffusion Server Unset LoRA`. + +### Mode 2: Integrated Mode (Tight Integration) +Leverage SGLang's high-performance sampling directly within ComfyUI while using ComfyUI's front-end nodes (CLIP, VAE, etc.). + +1. **Load Model**: Use the `SGLDiffusion UNET Loader` node to load your diffusion model. +2. **Configure Options**: Use the `SGLDiffusion Options` node to set runtime parameters like `num_gpus`, `tp_size`, `model_type`, or `enable_torch_compile`. +3. **Sample**: Connect the loaded model to standard ComfyUI samplers. SGLang will handle the sampling process efficiently. +4. **LoRA Support**: Use the `SGLDiffusion LoRA Loader` for native LoRA integration. + +## Example Workflows + +Reference workflow files are provided in the `workflows/` directory: + +- **`flux_sgld_sp.json`**: Multi-GPU (Sequence Parallelism) workflow for FLUX models. High-performance inference across multiple cards. +- **`qwen_image_sgld.json`**: Qwen-Image generation with LoRA support. Optimized for multi-modal image tasks. +- **`z-image_sgld.json`**: High-speed image generation using Z-Image. 
+- **`sgld_text2img.json`**: Server-mode text-to-image generation with LoRA support. +- **`sgld_image2video.json`**: Server-mode image-to-video generation. + +For other workflows supporting the models, you can easily use SGLang by replacing the official `UNET Loader` node with the `SGLDUNETLoader` node. Similarly, for LoRA support, replace the official LoRA loader with the `SGLDiffusion LoRA Loader`. + +To use these workflows: +1. Open ComfyUI. +2. Load the workflow JSON file from the `workflows/` directory. +3. Adjust the parameters and model paths as needed. +4. Run the workflow. + + +## Current Implementation + +This plugin provides a high-performance backend for diffusion models in ComfyUI. By leveraging SGLang's optimized kernels and parallelization techniques (Tensor Parallelism, TeaCache, etc.), it significantly accelerates the sampling process, especially for large models like FLUX. diff --git a/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/__init__.py b/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a34ef84f975988416b2267f478b95e495c6abe3e --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/__init__.py @@ -0,0 +1,13 @@ +""" +ComfyUI SGLang Diffusion nodes package. 
+""" + +try: + from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS + + __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] +except ImportError: + # ComfyUI dependencies not available (e.g., in test environment) + NODE_CLASS_MAPPINGS = {} + NODE_DISPLAY_NAME_MAPPINGS = {} + __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] diff --git a/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/core/__init__.py b/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a7d01106b3cab6a27f86718d2229700741d75a02 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/core/__init__.py @@ -0,0 +1,14 @@ +""" +Core components for SGLang Diffusion ComfyUI integration. +Provides generator, model patcher, and server API client. +""" + +from .generator import SGLDiffusionGenerator +from .model_patcher import SGLDModelPatcher +from .server_api import SGLDiffusionServerAPI + +__all__ = [ + "SGLDiffusionGenerator", + "SGLDModelPatcher", + "SGLDiffusionServerAPI", +] diff --git a/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/core/generator.py b/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/core/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..a0ddb94523676d7dd121314decadec3ad58e8ec2 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/core/generator.py @@ -0,0 +1,232 @@ +""" +Generator for SGLang Diffusion ComfyUI integration. +""" + +import logging +import os + +import psutil +from comfy import model_detection, model_management +from comfy.utils import ( + calculate_parameters, + load_torch_file, + state_dict_prefix_replace, + unet_to_diffusers, +) + +logger = logging.getLogger(__name__) + +try: + from sglang.multimodal_gen import DiffGenerator +except ImportError: + logger.error( + "Error: sglang.multimodal_gen is not installed. 
class SGLDiffusionGenerator:
    """Manage the SGLang ``DiffGenerator`` used by the ComfyUI integration.

    At most one generator is alive per instance, together with the executor
    and the model patcher built around it. ``load_model`` hands the existing
    patcher back when it is called again with identical arguments and tears
    everything down before loading anything else.
    """

    def __init__(self):
        self.model_path = None
        self.generator = None
        self.executor = None
        # Snapshot of the arguments of the last successful load_model() call.
        self.last_options = None
        # Patcher produced by that call; returned again on a cache hit.
        self.model_patcher = None

        # ComfyUI "image_model" identifier -> SGLang pipeline class name.
        self.pipeline_class_dict = {
            "flux": "ComfyUIFluxPipeline",
            "lumina2": "ComfyUIZImagePipeline",  # zimage
            "qwen_image": "ComfyUIQwenImagePipeline",
            "qwen_image_edit": "ComfyUIQwenImageEditPipeline",
        }
        # ComfyUI "image_model" identifier -> executor wrapping the generator.
        self.executor_class_dict = {
            "flux": FluxExecutor,
            "lumina2": ZImageExecutor,
            "qwen_image": QwenImageExecutor,
            "qwen_image_edit": QwenImageEditExecutor,
        }

    def __del__(self):
        self.close_generator()

    def init_generator(
        self, model_path: str, pipeline_class_name: str, kwargs: dict = None
    ):
        """Create the diffusion generator, or return the one already alive.

        Args:
            model_path: Path of the model handed to ``DiffGenerator.from_pretrained``.
            pipeline_class_name: Name of the SGLang pipeline class to instantiate.
            kwargs: Extra keyword arguments for ``from_pretrained``. The dict is
                copied, so the caller's object is never modified.

        Returns:
            The live generator.
        """
        if self.generator is not None:
            return self.generator
        # Work on a copy: injecting the flag into the caller's dict would leak
        # into any structure that still references it (e.g. a cache key).
        generator_kwargs = dict(kwargs) if kwargs else {}
        # Set comfyui_mode for ComfyUI integration
        generator_kwargs["comfyui_mode"] = True
        self.generator = DiffGenerator.from_pretrained(
            model_path=model_path,
            pipeline_class_name=pipeline_class_name,
            **generator_kwargs,
        )
        return self.generator

    def kill_generator(self):
        """Kill worker processes manually because generator shutdown cannot terminate them.

        Every process other than the current one whose command line contains
        ``sgl_diffusion::`` is terminated, and force-killed if it does not
        exit within five seconds.
        """
        current_pid = os.getpid()
        worker_processes = []
        for proc in psutil.process_iter(["pid", "name", "cmdline"]):
            try:
                cmdline_parts = proc.info["cmdline"]
                if not cmdline_parts:
                    continue
                if "sgl_diffusion::" not in " ".join(cmdline_parts):
                    continue
                if proc.info["pid"] != current_pid:
                    worker_processes.append(proc)
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                continue

        if not worker_processes:
            return

        logger.info(f"Found {len(worker_processes)} worker processes to terminate...")
        for proc in worker_processes:
            try:
                logger.info(
                    f"Terminating worker process {proc.info['pid']}: {proc.info['name']}"
                )
                proc.terminate()
                proc.wait(timeout=5)
            except psutil.TimeoutExpired:
                logger.warning(
                    f"Process {proc.info['pid']} did not terminate, forcing kill..."
                )
                try:
                    proc.kill()
                    proc.wait(timeout=2)
                except (
                    psutil.NoSuchProcess,
                    psutil.AccessDenied,
                    psutil.TimeoutExpired,
                ):
                    # Best effort: the process is gone, not ours, or truly stuck.
                    pass
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                pass

    def close_generator(self):
        """Close and cleanup the generator and all associated resources.

        References are cleared even when ``shutdown()`` raises, so a failed
        shutdown can never leave a half-dead generator behind for reuse.
        """
        # getattr: this also runs from __del__ on partially built instances.
        generator = getattr(self, "generator", None)
        try:
            if generator is not None:
                try:
                    generator.shutdown()
                finally:
                    self.kill_generator()
        finally:
            self.last_options = None
            self.model_path = None
            self.generator = None
            self.executor = None
            self.model_patcher = None

    def get_comfyui_model(self, model_path: str, model_options: dict = None):
        """Build the ComfyUI model object describing the checkpoint.

        The state dict is only used to detect the model configuration; actual
        network creation is disabled via ``disable_unet_model_creation``.

        Args:
            model_path: Checkpoint file to inspect.
            model_options: Optional dict; the keys ``dtype`` and
                ``custom_operations`` are read.

        Returns:
            ``(comfyui_model, model_config, model_type)``, or ``None`` when no
            model configuration can be derived from the checkpoint.

        Raises:
            ValueError: If the detected model type has no SGLang pipeline.
        """
        if model_options is None:
            model_options = {}
        dtype = model_options.get("dtype", None)

        # Allow loading unets from checkpoint files
        sd = load_torch_file(model_path)
        diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
        stripped_sd = state_dict_prefix_replace(
            sd, {diffusion_model_prefix: ""}, filter_keys=True
        )
        if len(stripped_sd) > 0:
            sd = stripped_sd

        parameters = calculate_parameters(sd)
        load_device = model_management.get_torch_device()

        # Guard against a failed detection so the error below stays readable.
        detected_config = model_detection.detect_unet_config(sd, "") or {}
        model_type = detected_config.get("image_model", None)
        if model_type is None or model_type not in self.pipeline_class_dict:
            raise ValueError(f"Unsupported model type: {model_type}")

        model_config = model_detection.model_config_from_unet(sd, "")
        if model_config is None:
            converted_sd = model_detection.convert_diffusers_mmdit(sd, "")
            if converted_sd is not None:  # diffusers mmdit
                model_config = model_detection.model_config_from_unet(converted_sd, "")
            else:  # diffusers unet
                model_config = model_detection.model_config_from_diffusers_unet(sd)
            if model_config is None:
                return None
        # No weights are loaded here, so the state dict is not remapped further.

        if dtype is None:
            unet_dtype = model_management.unet_dtype(
                model_params=parameters,
                supported_dtypes=model_config.supported_inference_dtypes,
            )
        else:
            unet_dtype = dtype

        manual_cast_dtype = model_management.unet_manual_cast(
            unet_dtype, load_device, model_config.supported_inference_dtypes
        )
        model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
        model_config.custom_operations = model_options.get("custom_operations", None)
        model_config.unet_config["disable_unet_model_creation"] = True
        comfyui_model = model_config.get_model({})
        return comfyui_model, model_config, model_type

    def load_model(
        self, model_path: str, model_options: dict = None, sgld_options: dict = None
    ):
        """Load model and return model patcher.

        Args:
            model_path: Checkpoint file to load.
            model_options: Options forwarded to ``get_comfyui_model``.
            sgld_options: Keyword arguments for the SGLang generator. An
                optional ``model_type`` entry overrides the detected type when
                it names a supported pipeline. The dict is not modified.

        Returns:
            The ``SGLDModelPatcher`` wrapping the loaded model. A repeated call
            with identical arguments returns the existing patcher.

        Raises:
            ValueError: If the checkpoint is unsupported or undetectable.
        """
        # Shallow snapshots keep the cache key independent of the caller's
        # dicts, which may be mutated after this call returns.
        requested = {
            "model_path": model_path,
            "model_options": dict(model_options) if model_options is not None else None,
            "sgld_options": dict(sgld_options) if sgld_options is not None else None,
        }
        if (
            self.generator is not None
            and self.model_patcher is not None
            and self.last_options == requested
        ):
            return self.model_patcher

        self.close_generator()

        detected = self.get_comfyui_model(model_path, model_options)
        if detected is None:
            raise ValueError(
                f"Could not detect a model configuration for: {model_path}"
            )
        comfyui_model, model_config, model_type = detected
        if model_type is None or model_type not in self.pipeline_class_dict:
            raise ValueError(f"Unsupported model type: {model_type}")

        generator_kwargs = dict(sgld_options) if sgld_options else {}
        requested_type = generator_kwargs.pop("model_type", None)
        if requested_type is not None and requested_type in self.pipeline_class_dict:
            model_type = requested_type

        pipeline_class_name = self.pipeline_class_dict[model_type]
        self.generator = self.init_generator(
            model_path, pipeline_class_name, generator_kwargs
        )

        executor_class = self.executor_class_dict[model_type]
        self.executor = executor_class(
            self.generator, model_path, comfyui_model, model_config
        )
        comfyui_model.diffusion_model = self.executor

        load_device = model_management.get_torch_device()
        offload_device = model_management.unet_offload_device()

        self.model_patcher = SGLDModelPatcher(
            comfyui_model, load_device, offload_device, model_type=model_type
        )
        # Record the cache key only once everything has been built.
        self.model_path = model_path
        self.last_options = requested
        return self.model_patcher
n.lora_cache = copy.copy(self.lora_cache) + return n + + def model_size(self): + """Get the model size in bytes.""" + if self.model_type in self.model_size_dict: + return self.model_size_dict[self.model_type] + else: + return 0 + + def load( + self, + device_to=None, + lowvram_model_memory=0, + force_patch_weights=False, + full_load=False, + ): + """Load model (no-op for SGLang Diffusion).""" + pass + + def patch_model( + self, + device_to=None, + lowvram_model_memory=0, + load_weights=True, + force_patch_weights=False, + ): + """Patch model (no-op for SGLang Diffusion).""" + pass + + def unpatch_model(self, device_to=None, unpatch_weights=True): + """Unpatch model (no-op for SGLang Diffusion).""" + pass diff --git a/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/core/server_api.py b/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/core/server_api.py new file mode 100644 index 0000000000000000000000000000000000000000..2f86efe79f65524b8ac3653ea723e5706365523b --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/core/server_api.py @@ -0,0 +1,539 @@ +""" +SGLang Diffusion Server API client. +Provides a low-level interface for interacting with SGLang Diffusion HTTP server. +""" + +import base64 +import io +import os +import time +from typing import Any, Dict, Optional + +import requests +from PIL import Image + + +class SGLDiffusionServerAPI: + """Client for SGLang Diffusion HTTP server API.""" + + def __init__(self, base_url: str, api_key: str = "sk-proj-1234567890"): + """ + Initialize the API client. 
+ + Args: + base_url: Base URL of the SGLang Diffusion server (e.g., "http://localhost:30010/v1") + api_key: API key for authentication (default: "sk-proj-1234567890") + """ + # Ensure base_url doesn't end with /v1 if it's already there + if base_url.endswith("/v1"): + self.base_url = base_url + elif base_url.endswith("/v1/"): + self.base_url = base_url.rstrip("/") + else: + self.base_url = f"{base_url.rstrip('/')}/v1" + + self.api_key = api_key + self.headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {api_key}", + } + + def get_model_info(self) -> Dict[str, Any]: + """ + Get information about the model served by this server. + + Returns: + Dictionary containing model information including: + - model_path: Path to the model + - task_type: Type of task (e.g., "T2V", "I2I") + - pipeline_name: Name of the pipeline + - num_gpus: Number of GPUs + - dit_precision: DiT model precision + - vae_precision: VAE model precision + """ + try: + # Remove /v1 from base_url for /models endpoint + models_url = self.base_url.removesuffix("/v1") + "/models" + response = requests.get(models_url, headers=self.headers, timeout=30) + response.raise_for_status() + return response.json() + except requests.exceptions.RequestException as e: + raise RuntimeError(f"Failed to get model info: {str(e)}") + + def generate_image( + self, + prompt: str, + image_path: Optional[str] = None, + mask_path: Optional[str] = None, + size: Optional[str] = None, + width: Optional[int] = None, + height: Optional[int] = None, + n: int = 1, + negative_prompt: Optional[str] = None, + guidance_scale: Optional[float] = None, + num_inference_steps: Optional[int] = None, + seed: Optional[int] = None, + enable_teacache: bool = False, + response_format: str = "b64_json", + quality: Optional[str] = "auto", + style: Optional[str] = "vivid", + background: Optional[str] = "auto", + output_format: Optional[str] = None, + generator_device: Optional[str] = "cuda", + ) -> Dict[str, Any]: + """ + 
Generate or edit an image using SGLang Diffusion API. + If image_path is provided, calls the edit endpoint; otherwise calls the generation endpoint. + + Args: + prompt: Text prompt for image generation/editing + image_path: Optional path to input image file for editing. If provided, uses edit API. + mask_path: Optional path to mask image file (only used when image_path is provided) + size: Image size in format "WIDTHxHEIGHT" (e.g., "1024x1024") + width: Image width (used if size is not provided) + height: Image height (used if size is not provided) + n: Number of images to generate (1-10) + negative_prompt: Negative prompt to avoid certain elements + guidance_scale: Classifier-free guidance scale + num_inference_steps: Number of denoising steps + seed: Random seed for reproducible generation + enable_teacache: Enable TEA cache acceleration + response_format: Response format ("b64_json" or "url") + quality: Image quality ("auto", "standard", "hd") - only for generation + style: Image style ("vivid" or "natural") - only for generation + background: Background type ("auto", "transparent", "opaque") + output_format: Output format ("png", "jpeg", "webp") + generator_device: Device for random generator ("cuda" or "cpu") + + Returns: + Dictionary containing the API response with generated/edited image data + """ + if not prompt: + raise ValueError("Prompt cannot be empty") + + # Determine size + if size is None: + if width is not None and height is not None: + size = f"{width}x{height}" + else: + size = "1024x1024" + + # Build common parameters + common_params = self._build_image_common_params( + prompt=prompt, + size=size, + n=n, + response_format=response_format, + negative_prompt=negative_prompt, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + seed=seed, + enable_teacache=enable_teacache, + background=background, + output_format=output_format, + generator_device=generator_device, + ) + + # If image_path is provided, use edit endpoint + if 
def generate_video(
    self,
    prompt: str,
    size: Optional[str] = None,
    width: Optional[int] = None,
    height: Optional[int] = None,
    seconds: Optional[int] = 4,
    fps: Optional[int] = None,
    num_frames: Optional[int] = None,
    negative_prompt: Optional[str] = None,
    guidance_scale: Optional[float] = None,
    num_inference_steps: Optional[int] = None,
    seed: Optional[int] = None,
    enable_teacache: bool = False,
    generator_device: Optional[str] = "cuda",
    input_reference: Optional[str] = None,
    output_path: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Generate a video using SGLang Diffusion API and wait for completion.

    The job is created with a POST to ``{base_url}/videos`` and then polled via
    GET ``{base_url}/videos/{id}`` every 5 seconds, for at most one hour.

    Args:
        prompt: Text prompt for video generation
        size: Video size in format "WIDTHxHEIGHT" (e.g., "1280x720")
        width: Video width (used if size is not provided)
        height: Video height (used if size is not provided)
        seconds: Duration of the video in seconds
        fps: Frames per second
        num_frames: Number of frames (overrides seconds * fps if provided)
        negative_prompt: Negative prompt to avoid certain elements
        guidance_scale: Classifier-free guidance scale
        num_inference_steps: Number of denoising steps
        seed: Random seed for reproducible generation (only sent when >= 0)
        enable_teacache: Enable TEA cache acceleration
        generator_device: Device for random generator ("cuda" or "cpu")
        input_reference: Path to input reference image for image-to-video
        output_path: Forwarded to the server as "output_path" when provided

    Returns:
        The status dictionary returned by the server once its "status" is
        "completed".

    Raises:
        ValueError: If prompt is empty.
        RuntimeError: If the job cannot be created, the create response has no
            job id, the server reports the job as failed, or polling hits 5
            consecutive request errors.
        TimeoutError: If the job does not complete within one hour.
    """
    if not prompt:
        raise ValueError("Prompt cannot be empty")

    # Determine size
    if size is None:
        if width is not None and height is not None:
            size = f"{width}x{height}"
        else:
            size = "720x1280"

    # Prepare request payload
    payload: Dict[str, Any] = {
        "prompt": prompt,
        "size": size,
    }

    # Add optional parameters
    if seconds is not None:
        payload["seconds"] = seconds
    if fps is not None:
        payload["fps"] = fps
    if num_frames is not None:
        payload["num_frames"] = num_frames
    if negative_prompt:
        payload["negative_prompt"] = negative_prompt
    if guidance_scale is not None:
        payload["guidance_scale"] = guidance_scale
    if num_inference_steps is not None:
        payload["num_inference_steps"] = num_inference_steps
    if seed is not None and seed >= 0:
        payload["seed"] = seed
    if enable_teacache:
        payload["enable_teacache"] = True
    if generator_device:
        payload["generator_device"] = generator_device
    if input_reference:
        payload["input_reference"] = input_reference
    if output_path:
        payload["output_path"] = output_path

    try:
        # Create video generation job
        response = requests.post(
            f"{self.base_url}/videos",
            json=payload,
            headers=self.headers,
            timeout=30,
        )
        response.raise_for_status()
        video_job = response.json()
        video_id = video_job.get("id")
        # Without an id the status URL would be ".../videos/None"; fail right
        # away instead of polling a bogus endpoint until the error budget runs out.
        if not video_id:
            raise RuntimeError(
                "Video generation failed: server response did not include a job id"
            )

        # Wait for completion with fixed polling
        poll_interval = 5  # 5 seconds
        max_wait_time = 3600  # 1 hour
        max_consecutive_errors = 5
        consecutive_errors = 0
        start_time = time.time()

        while time.time() - start_time < max_wait_time:
            try:
                status_response = requests.get(
                    f"{self.base_url}/videos/{video_id}",
                    headers=self.headers,
                    timeout=30,
                )
                status_response.raise_for_status()
                status = status_response.json()

                # Reset error counter on successful request
                consecutive_errors = 0

                if status.get("status") == "completed":
                    return status
                elif status.get("status") == "failed":
                    error = status.get("error", {})
                    error_msg = (
                        error.get("message", "Unknown error")
                        if error
                        else "Unknown error"
                    )
                    raise RuntimeError(f"Video generation failed: {error_msg}")
            except requests.exceptions.ConnectionError as e:
                # Connection errors - likely server is down
                consecutive_errors += 1
                if consecutive_errors >= max_consecutive_errors:
                    raise RuntimeError(
                        f"Lost connection to server after {consecutive_errors} consecutive errors. "
                        f"Server may be unavailable: {str(e)}"
                    ) from e
            except requests.exceptions.RequestException as e:
                # Other network errors - continue polling but track errors
                consecutive_errors += 1
                if consecutive_errors >= max_consecutive_errors:
                    raise RuntimeError(
                        f"Network error after {consecutive_errors} consecutive failures: {str(e)}"
                    ) from e

            time.sleep(poll_interval)

        raise TimeoutError(
            f"Video generation timed out after {max_wait_time} seconds"
        )
    except requests.exceptions.RequestException as e:
        raise RuntimeError(f"Failed to generate video: {str(e)}") from e
+ + Returns: + Dictionary containing common parameters + """ + params: Dict[str, Any] = { + "prompt": prompt, + "size": size, + "n": max(1, min(n, 10)), + "response_format": response_format, + } + + # Add optional parameters + if negative_prompt: + params["negative_prompt"] = negative_prompt + if guidance_scale is not None: + params["guidance_scale"] = guidance_scale + if num_inference_steps is not None: + params["num_inference_steps"] = num_inference_steps + if seed is not None and seed >= 0: + params["seed"] = seed + if enable_teacache: + params["enable_teacache"] = True + if background: + params["background"] = background + if output_format: + params["output_format"] = output_format + if generator_device: + params["generator_device"] = generator_device + + return params + + def _get_content_type(self, file_path: str) -> str: + """Get content type based on file extension.""" + ext = os.path.splitext(file_path)[1].lower() + content_types = { + ".png": "image/png", + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".webp": "image/webp", + } + return content_types.get(ext, "image/png") + + def decode_image_from_response( + self, response_data: Dict[str, Any], index: int = 0 + ) -> Image.Image: + """ + Decode base64 image from API response. 
def set_lora(
    self,
    lora_nickname: str,
    lora_path: Optional[str] = None,
    target: str = "all",
) -> Dict[str, Any]:
    """
    Activate a LoRA adapter on the server by POSTing to ``{base_url}/set_lora``.

    Args:
        lora_nickname: The nickname of the adapter (required).
        lora_path: Path to the LoRA adapter (local path or HF repo id).
            Required for the first load; optional if re-activating a cached nickname.
            Only included in the request when non-empty.
        target: Which transformer(s) to apply the LoRA to. One of:
            - "all": Apply to all transformers (default)
            - "transformer": Apply only to the primary transformer (high noise for Wan2.2)
            - "transformer_2": Apply only to transformer_2 (low noise for Wan2.2)
            - "critic": Apply only to the critic model

    Returns:
        Dictionary containing the API response with status and message

    Raises:
        ValueError: If lora_nickname is empty.
        RuntimeError: If the HTTP request fails.
    """
    if not lora_nickname:
        raise ValueError("lora_nickname cannot be empty")

    # Mandatory fields first; the adapter path is appended only when supplied.
    body: Dict[str, Any] = {"lora_nickname": lora_nickname, "target": target}
    if lora_path:
        body["lora_path"] = lora_path

    endpoint = f"{self.base_url}/set_lora"
    try:
        reply = requests.post(endpoint, json=body, headers=self.headers, timeout=30)
        reply.raise_for_status()
        return reply.json()
    except requests.exceptions.RequestException as e:
        raise RuntimeError(f"Failed to set LoRA adapter: {str(e)}")
if __name__ == "__main__":
    # Manual smoke test against a locally running server: print the model info,
    # then run one video or image generation depending on the reported task type.
    api = SGLDiffusionServerAPI(
        base_url="http://localhost:30010/v1", api_key="sk-proj-1234567890"
    )
    # Query the server once and reuse the result, so the value printed is the
    # same one the branch below is taken on.
    model_info = api.get_model_info()
    print(model_info)
    if model_info.get("task_type") in ("T2V", "I2V"):
        print(
            api.generate_video(
                prompt="A calico cat playing a piano on stage",
                num_inference_steps=1,
                size="480x480",
            )
        )
    else:
        print(
            api.generate_image(
                prompt="A calico cat playing a piano on stage", size="1024x1024"
            )
        )
class SGLDiffusionExecutor(torch.nn.Module):
    """Base executor class for SGLang Diffusion models in ComfyUI.

    Stores the generator, model path, wrapped model and config handed in by the
    loader, and provides the latent pack/unpack helpers shared by subclasses.
    """

    def __init__(self, generator, model_path, model, config):
        super().__init__()
        self.generator = generator
        self.model_path = model_path
        self.model = model
        # dtype is taken from the "dtype" entry of config.unet_config.
        self.dtype = config.unet_config["dtype"]
        self.config = config
        self.loras = []

    @staticmethod
    def should_suppress_logs(timestep):
        """Return True when ``timestep`` is below 1.0.

        Accepts either a tensor (read via ``.item()``, so it must hold a single
        element) or a plain number.
        """
        if torch.is_tensor(timestep):
            return bool((timestep < 1.0).item())
        return bool(timestep < 1.0)

    def set_lora(self, lora_nickname=None, lora_path=None, strength=None, target=None):
        """Forward a LoRA request to ``self.generator.set_lora``.

        Nothing is sent when ``lora_nickname`` is None or empty. The previous
        ``len(lora_nickname)`` check raised TypeError for the default None.
        """
        if not lora_nickname:
            return
        self.generator.set_lora(
            lora_nickname=lora_nickname,
            lora_path=lora_path,
            strength=strength,
            target=target,
        )

    def _unpack_latents(self, latents, height, width, channels):
        """Unpack latents from packed format to standard format.

        Inverse of ``_pack_latents``: the input is viewed as
        (batch, height // 2, width // 2, channels, 2, 2) and rearranged to
        (batch, channels, height, width).
        """
        batch_size = latents.shape[0]
        latents = latents.view(batch_size, height // 2, width // 2, channels, 2, 2)
        latents = latents.permute(0, 3, 1, 4, 2, 5)
        latents = latents.reshape(batch_size, channels, height, width)

        return latents

    def _pack_latents(self, latents):
        """Pack latents from standard format to packed format.

        Takes a (batch, channels, height, width) tensor and groups every 2x2
        spatial patch, returning
        (batch, (height // 2) * (width // 2), channels * 4).
        """
        batch_size, num_channels_latents, height, width = latents.shape
        latents = latents.view(
            batch_size, num_channels_latents, height // 2, 2, width // 2, 2
        )
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(
            batch_size, (height // 2) * (width // 2), num_channels_latents * 4
        )
        return latents
class FluxExecutor(SGLDiffusionExecutor):
    """Executor for Flux models in ComfyUI."""

    def __init__(self, generator, model_path, model, config):
        super().__init__(generator, model_path, model, config)

    def forward(self, x, timestep, context, y=None, guidance=None, **kwargs):
        """Run one denoising step through the SGLang Diffusion scheduler.

        Args:
            x: Latents; unpacked below as (B, C, H, W).
            timestep: Multiplied by 1000 before being placed on the request.
            context: Placed second in ``req.prompt_embeds``.
            y: Placed first in ``req.prompt_embeds`` and in ``req.pooled_embeds``.
            guidance: Accepted for signature compatibility only. This method
                never places it on the request; the previous unused
                ``guidance * 1000.0`` was removed because it raised TypeError
                for the default None.
            **kwargs: Ignored.

        Returns:
            ``output_batch.noise_pred`` unpacked back to (B, C, H, W) and moved
            to ``x.device``.
        """
        hidden_states = self._pack_latents(x)
        timesteps = timestep * 1000.0
        encoder_hidden_states = context
        pooled_projections = y

        _, C, H, W = x.shape
        # Requested pixel size is 8x the latent size.
        # NOTE(review): presumably the VAE downscale factor; confirm.
        height = H * 8
        width = W * 8
        # Create SamplingParams
        sampling_params = SamplingParams.from_user_sampling_params_args(
            self.model_path,
            server_args=self.generator.server_args,
            prompt=" ",
            guidance_scale=3.5,  # Flux typically uses embedded_cfg_scale=3.5
            height=height,
            width=width,
            num_frames=1,
            num_inference_steps=1,
            save_output=False,
            suppress_logs=self.should_suppress_logs(timestep),
        )

        # Prepare request (converts SamplingParams to Req)
        req = prepare_request(
            server_args=self.generator.server_args,
            sampling_params=sampling_params,
        )
        req.latents = hidden_states  # Set as [B, S, D] format directly
        req.timesteps = timesteps  # ComfyUI's timesteps parameter
        req.prompt_embeds = [pooled_projections, encoder_hidden_states]  # [CLIP, T5]
        req.raw_latent_shape = torch.tensor(hidden_states.shape, dtype=torch.long)

        # Set pooled_projections (required by Flux)
        req.pooled_embeds = [pooled_projections]  # List format as per Req definition
        req.do_classifier_free_guidance = False
        req.generator = [
            torch.Generator("cuda") for _ in range(req.num_outputs_per_prompt)
        ]

        # Send request to scheduler
        output_batch = self.generator._send_to_scheduler_and_wait_for_response([req])
        noise_pred = output_batch.noise_pred
        return self._unpack_latents(noise_pred, H, W, C).to(x.device)
a/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/executors/qwen_image.py b/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/executors/qwen_image.py new file mode 100644 index 0000000000000000000000000000000000000000..b89186801b4719c16b9a7c3f25de605061900608 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/executors/qwen_image.py @@ -0,0 +1,174 @@ +""" +QwenImage executor for SGLang Diffusion ComfyUI integration. +""" + +import torch + +try: + from sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams + from sglang.multimodal_gen.runtime.entrypoints.utils import prepare_request +except ImportError: + print( + "Error: sglang.multimodal_gen is not installed. Please install it using 'pip install sglang[diffusion]'" + ) + +import comfy.ldm.common_dit + +from .base import SGLDiffusionExecutor + + +class QwenImageExecutor(SGLDiffusionExecutor): + """Executor for QwenImage models in ComfyUI.""" + + def __init__(self, generator, model_path, model, config): + super().__init__(generator, model_path, model, config) + self.patch_size = 2 + + def _pack_latents(self, x): + """Process hidden states for QwenImage model.""" + bs, c, t, h, w = x.shape + patch_size = self.patch_size + latents = comfy.ldm.common_dit.pad_to_patch_size( + x, (1, self.patch_size, self.patch_size) + ) + orig_shape = latents.shape + latents = latents.view( + orig_shape[0], + orig_shape[1], + orig_shape[-3], + orig_shape[-2] // 2, + 2, + orig_shape[-1] // 2, + 2, + ) + latents = latents.permute(0, 2, 3, 5, 1, 4, 6) + latents = latents.reshape( + orig_shape[0], + orig_shape[-3] * (orig_shape[-2] // 2) * (orig_shape[-1] // 2), + orig_shape[1] * 4, + ) + return latents, orig_shape + + def _unpack_latents(self, latents, num_embeds, orig_shape, x): + """Unpack hidden states from packed format to standard format.""" + latents = latents[:, :num_embeds].view( + orig_shape[0], + orig_shape[-3], + orig_shape[-2] // 2, + orig_shape[-1] // 
class QwenImageEditExecutor(QwenImageExecutor):
    """Executor for QwenImageEdit models in ComfyUI."""

    def __init__(self, generator, model_path, model, config):
        super().__init__(generator, model_path, model, config)

    def forward(
        self,
        x,
        timestep,
        context,
        attention_mask=None,
        ref_latents=None,
        additional_t_cond=None,
        transformer_options=None,
        **kwargs
    ):
        """Forward pass for QwenImageEdit model.

        Args:
            x: Latents to denoise; packed with ``_pack_latents``.
            timestep: Multiplied by 1000 before being placed on the request.
            context: Wrapped in a one-element list as ``req.prompt_embeds``.
            attention_mask: Accepted but not used by this method.
            ref_latents: Optional sequence of condition-image latents; only the
                first entry is used.
            additional_t_cond: Accepted but not used by this method.
            transformer_options: Accepted but not used by this method. The
                default is None rather than a shared ``{}`` so that no mutable
                object is shared between calls.
            **kwargs: Ignored.

        Returns:
            ``output_batch.noise_pred`` unpacked to the padded latent shape and
            cropped to the last two dimensions of ``x``.
        """
        latents, orig_shape = self._pack_latents(x)
        num_embeds = latents.shape[1]
        height = orig_shape[-2] * 8
        width = orig_shape[-1] * 8

        # Prepare vae_image_sizes for the condition image (ref_latents)
        vae_image_sizes = []
        pack_ref_latents = None

        # TODO: SGLang Diffusion does not support multiple condition images yet,
        # so only the first one is used.
        if ref_latents is not None and len(ref_latents) > 0:
            pack_ref_latents, orig_ref_shape = self._pack_latents(ref_latents[0])
            # Stored as (last dim, second-to-last dim) of the padded reference.
            vae_image_sizes = [(orig_ref_shape[-1], orig_ref_shape[-2])]

        sampling_params = SamplingParams.from_user_sampling_params_args(
            self.model_path,
            server_args=self.generator.server_args,
            prompt=" ",
            guidance_scale=1.0,
            image_path="",
            height=height,
            width=width,
            num_frames=1,
            num_inference_steps=1,
            save_output=False,
            suppress_logs=self.should_suppress_logs(timestep),
        )

        # Prepare request (converts SamplingParams to Req)
        req = prepare_request(
            server_args=self.generator.server_args,
            sampling_params=sampling_params,
        )
        # Set ComfyUI-specific inputs directly on the Req object
        req.latents = latents
        req.image_latent = pack_ref_latents
        req.timesteps = timestep * 1000.0
        req.vae_image_sizes = vae_image_sizes
        req.prompt_embeds = [context]
        req.raw_latent_shape = torch.tensor(latents.shape, dtype=torch.long)
        req.do_classifier_free_guidance = False
        req.generator = [
            torch.Generator("cuda") for _ in range(req.num_outputs_per_prompt)
        ]

        output_batch = self.generator._send_to_scheduler_and_wait_for_response([req])
        noise_pred = output_batch.noise_pred

        return self._unpack_latents(noise_pred, num_embeds, orig_shape, x)
class ZImageExecutor(SGLDiffusionExecutor):
    """ComfyUI executor that sends ZImage denoising steps to SGLang Diffusion."""

    def __init__(self, generator, model_path, model, config):
        super().__init__(generator, model_path, model, config)

    def forward(self, x, timesteps, context, **kwargs):
        """Run a single denoising step for a ZImage model.

        ``x`` is unpacked as a 4-D tensor. A dimension is inserted at index 2
        before it is sent, and dimension 0 of ``context`` is squeezed. The
        returned tensor is ``noise_pred`` with its first two dimensions swapped,
        moved to ``x.device``.
        """
        _, _, latent_h, latent_w = x.shape

        params = SamplingParams.from_user_sampling_params_args(
            self.model_path,
            server_args=self.generator.server_args,
            prompt=" ",
            guidance_scale=1.0,
            height=latent_h * 8,
            width=latent_w * 8,
            num_frames=1,  # For images
            num_inference_steps=1,  # Single step for ComfyUI
            save_output=False,
            suppress_logs=self.should_suppress_logs(timesteps),
        )

        # Turn the sampling parameters into a request object.
        request = prepare_request(
            server_args=self.generator.server_args,
            sampling_params=params,
        )

        expanded = x.unsqueeze(2)
        # Attach the ComfyUI inputs directly to the request.
        request.latents = expanded
        request.timesteps = timesteps * 1000.0
        request.prompt_embeds = [context.squeeze(0)]  # must be a list of tensors
        request.raw_latent_shape = torch.tensor(expanded.shape, dtype=torch.long)
        request.do_classifier_free_guidance = False
        request.generator = [
            torch.Generator("cuda") for _ in range(request.num_outputs_per_prompt)
        ]

        result = self.generator._send_to_scheduler_and_wait_for_response([request])
        return result.noise_pred.permute(1, 0, 2, 3).to(x.device)
--git a/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/nodes.py b/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..c28d7d0e059c1ca0353873ebaa8fcb1322e1da59 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/nodes.py @@ -0,0 +1,715 @@ +""" +ComfyUI nodes for SGLang Diffusion integration. +Provides nodes for connecting to SGLang Diffusion server and generating images/videos. +""" + +import os +import uuid + +import folder_paths +import torch + +from .core import SGLDiffusionGenerator, SGLDiffusionServerAPI +from .utils import ( + convert_b64_to_tensor_image, + convert_video_to_comfy_video, + get_image_path, + is_empty_image, +) + + +class SGLDOptions: + @classmethod + def INPUT_TYPES(cls): + return { + "required": {}, + "optional": { + "model_type": ( + ["auto-detect", "qwen_image", "qwen_image_edit", "flux", "lumina2"], + {"default": "auto-detect"}, + ), + "enable_torch_compile": ( + "BOOLEAN", + {"default": False}, + ), + "num_gpus": ("INT", {"default": 1, "min": 1, "step": 1}), + "tp_size": ("INT", {"default": -1, "min": -1, "step": 1}), + "sp_degree": ("INT", {"default": -1, "min": -1, "step": 1}), + "ulysses_degree": ( + "INT", + { + "default": -1, + "min": -1, + "step": 1, + }, + ), + "ring_degree": ( + "INT", + { + "default": -1, + "min": -1, + "step": 1, + }, + ), + "dp_size": ("INT", {"default": 1, "min": 1, "step": 1}), + "dp_degree": ("INT", {"default": 1, "min": 1, "step": 1}), + "enable_cfg_parallel": ( + "BOOLEAN", + {"default": False}, + ), + "attention_backend": ( + "STRING", + {"default": ""}, + ), + }, + } + + RETURN_TYPES = ("SGLD_OPTIONS",) + RETURN_NAMES = ("sgld_options",) + FUNCTION = "create_options" + CATEGORY = "SGLDiffusion" + + def create_options( + self, + model_type: str = "auto-detect", + enable_torch_compile: bool = False, + num_gpus: int = 1, + tp_size: int = -1, + sp_degree: int = -1, + 
class SGLDLoraLoader:
    """ComfyUI node that registers a LoRA on a model and pushes the full LoRA
    list to the SGLang Diffusion executor."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("MODEL",),
                "lora_name": (folder_paths.get_filename_list("loras"),),
                "strength_model": (
                    "FLOAT",
                    {"default": 1.0, "min": 0, "max": 10, "step": 0.01},
                ),
                "nickname": ("STRING", {"default": ""}),
                "target": (
                    ["all", "transformer", "transformer_2", "critic"],
                    {"default": "all"},
                ),
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "load_lora"

    CATEGORY = "SGLDiffusion"

    def load_lora(
        self, model, lora_name, strength_model=1.0, nickname="", target="all"
    ):
        """Load LoRA adapter using SGLang Diffusion API.

        The LoRA is recorded in ``patches`` of a clone of ``model`` under
        ``nickname`` (or a generated "lora<uuid4>" name when empty). Every entry
        of the clone's ``patches`` is then sent to
        ``model.model.diffusion_model.set_lora`` as parallel lists.

        Returns:
            A one-element tuple holding the patched clone. Returning the clone
            rather than the untouched input lets a downstream loader see this
            LoRA in ``patches`` and include it in its own request.
        """
        lora_path = folder_paths.get_full_path("loras", lora_name)
        assert model is not None
        patched = model.clone()
        adapter_name = nickname if nickname != "" else str("lora" + str(uuid.uuid4()))
        # set lora in the model
        patched.patches[adapter_name] = (lora_path, strength_model, target)

        # prepare input for the SGLang Diffusion API
        lora_input = {
            "lora_nickname": [],
            "lora_path": [],
            "strength": [],
            "target": [],
        }
        # NOTE(review): assumes every patches entry is a
        # (path, strength, target) tuple written by this node; confirm.
        for name, lora_info in patched.patches.items():
            lora_input["lora_nickname"].append(name)
            lora_input["lora_path"].append(lora_info[0])
            lora_input["strength"].append(lora_info[1])
            lora_input["target"].append(lora_info[2])

        # call the SGLang Diffusion API
        model.model.diffusion_model.set_lora(**lora_input)
        return (patched,)
class SGLDiffusionGenerateImage:
    """Node to generate images using SGLang Diffusion."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "sgld_client": ("SGLD_CLIENT",),
                "positive_prompt": (
                    "STRING",
                    {
                        "default": "",
                        "tooltip": "Text prompt for image generation",
                    },
                ),
            },
            "optional": {
                "negative_prompt": (
                    "STRING",
                    {
                        "default": "",
                        "tooltip": "Negative prompt to avoid certain elements",
                    },
                ),
                "image": (
                    "IMAGE",
                    {
                        "default": None,
                        "tooltip": "input image to use for editing",
                    },
                ),
                "seed": (
                    "INT",
                    {
                        "default": 1024,
                        "min": -1,
                        "max": 2**32 - 1,
                    },
                ),
                "steps": (
                    "INT",
                    {
                        "default": 6,
                        "min": 1,
                        "max": 100,
                        "step": 1,
                    },
                ),
                "cfg": (
                    "FLOAT",
                    {
                        "default": 7.0,
                        "min": 1.0,
                        "max": 20.0,
                        "step": 0.1,
                    },
                ),
                "width": (
                    "INT",
                    {
                        "default": 1024,
                        "min": 256,
                        "max": 4096,
                        "step": 64,
                    },
                ),
                "height": (
                    "INT",
                    {
                        "default": 1024,
                        "min": 256,
                        "max": 4096,
                        "step": 64,
                    },
                ),
                "enable_teacache": (
                    "BOOLEAN",
                    {
                        "default": False,
                    },
                ),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "generate_image"
    CATEGORY = "SGLDiffusion"
    OUTPUT_NODE = False

    def generate_image(
        self,
        sgld_client: SGLDiffusionServerAPI,
        positive_prompt: str,
        negative_prompt: str = "",
        image: torch.Tensor = None,
        seed: int = 1024,
        steps: int = 6,
        cfg: float = 7.0,
        width: int = 1024,
        height: int = 1024,
        enable_teacache: bool = False,
    ):
        """Generate image using SGLang Diffusion API.

        Builds the request from the node inputs, calls
        ``sgld_client.generate_image`` with ``response_format="b64_json"`` and
        converts the first returned image with ``convert_b64_to_tensor_image``.

        When ``image`` is given and ``is_empty_image(image)`` is true, only its
        size (``shape[2]`` x ``shape[1]``) is used to override the requested
        size. Otherwise the image is passed to the API as ``image_path``.

        Returns:
            A one-element tuple containing the converted image.

        Raises:
            ValueError: If positive_prompt is empty.
            RuntimeError: If the API call fails or the response carries no
                base64 image data.
        """
        if not positive_prompt:
            raise ValueError("Prompt cannot be empty")

        size = f"{width}x{height}"

        # Prepare request parameters
        request_params = {
            "prompt": positive_prompt,
            "size": size,
            "response_format": "b64_json",
        }

        # Add optional parameters if provided
        if negative_prompt:
            request_params["negative_prompt"] = negative_prompt
        if cfg is not None:
            request_params["guidance_scale"] = cfg
        if steps is not None:
            request_params["num_inference_steps"] = steps
        if seed is not None and seed >= 0:
            request_params["seed"] = seed
        if enable_teacache:
            request_params["enable_teacache"] = True
        if image is not None:
            # An empty image only carries the desired output size
            if is_empty_image(image):
                width, height = image.shape[2], image.shape[1]
                size = f"{width}x{height}"
                request_params["size"] = size
            else:
                request_params["image_path"] = get_image_path(image)

        # Call API
        try:
            response = sgld_client.generate_image(**request_params)
        except Exception as e:
            raise RuntimeError(f"Failed to generate image: {str(e)}") from e

        # Decode base64 image. Use tolerant lookups so a response without
        # "data" or "b64_json" raises the RuntimeError below, not a KeyError.
        entries = response.get("data") if isinstance(response, dict) else None
        first = entries[0] if entries else None
        image_data = first.get("b64_json") if isinstance(first, dict) else None
        if not image_data:
            raise RuntimeError("No image data in response")
        image = convert_b64_to_tensor_image(image_data)

        return (image,)
"min": -1, + "max": 2**32 - 1, + }, + ), + "steps": ( + "INT", + { + "default": 6, + "min": 1, + "max": 100, + "step": 1, + }, + ), + "cfg": ( + "FLOAT", + { + "default": 7.0, + "min": 1.0, + "max": 20.0, + "step": 0.1, + }, + ), + "width": ( + "INT", + { + "default": 1280, + "min": 256, + "max": 4096, + "step": 1, + }, + ), + "height": ( + "INT", + { + "default": 720, + "min": 256, + "max": 4096, + "step": 1, + }, + ), + "num_frames": ( + "INT", + { + "default": 120, + "min": 1, + "max": 1000, + "step": 1, + }, + ), + "fps": ( + "INT", + { + "default": 24, + "min": 1, + "max": 60, + "step": 1, + }, + ), + "seconds": ( + "INT", + { + "default": 5, + "min": 1, + "max": 60, + "step": 1, + }, + ), + "enable_teacache": ( + "BOOLEAN", + { + "default": False, + }, + ), + }, + } + + RETURN_TYPES = ("VIDEO", "STRING") + RETURN_NAMES = ("video", "video_path") + FUNCTION = "generate_video" + CATEGORY = "SGLDiffusion" + OUTPUT_NODE = False + + def generate_video( + self, + sgld_client: SGLDiffusionServerAPI, + positive_prompt: str, + negative_prompt: str = "", + image: torch.Tensor = None, + seed: int = 1024, + steps: int = 6, + cfg: float = 7.0, + width: int = 1280, + height: int = 720, + num_frames: int = 120, + fps: int = 24, + seconds: int = 5, + enable_teacache: bool = False, + ): + """Generate video using SGLang Diffusion API.""" + if not positive_prompt: + raise ValueError("Prompt cannot be empty") + + size = f"{width}x{height}" + output_dir = folder_paths.get_temp_directory() + + # Prepare request parameters + request_params = { + "prompt": positive_prompt, + "size": size, + "seconds": seconds, + "fps": fps, + "output_path": output_dir, + } + + # Add optional parameters if provided + if negative_prompt: + request_params["negative_prompt"] = negative_prompt + if cfg is not None: + request_params["guidance_scale"] = cfg + if steps is not None: + request_params["num_inference_steps"] = steps + if seed is not None and seed >= 0: + request_params["seed"] = seed + if 
class SGLDiffusionServerSetLora:
    """Node to set LoRA adapter for SGLang Diffusion server."""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare node inputs: client and LoRA name required, nickname and target optional."""
        return {
            "required": {
                "sgld_client": ("SGLD_CLIENT",),
                "lora_name": (
                    "STRING",
                    {
                        "default": "",
                        "tooltip": "The name of the LoRA adapter",
                    },
                ),
            },
            "optional": {
                "lora_nickname": (
                    "STRING",
                    {
                        "default": "",
                        "tooltip": "The nickname of the LoRA adapter",
                    },
                ),
                "target": (
                    [
                        "all",
                        "transformer",
                        "transformer_2",
                        "critic",
                    ],
                    {
                        "default": "all",
                        "tooltip": "Which transformer(s) to apply the LoRA to",
                    },
                ),
            },
        }

    RETURN_TYPES = ("SGLD_CLIENT",)
    RETURN_NAMES = ("sgld_client",)
    FUNCTION = "set_lora"
    CATEGORY = "SGLDiffusion"
    OUTPUT_NODE = False

    def set_lora(
        self,
        sgld_client: SGLDiffusionServerAPI,
        lora_name: str = "",
        lora_nickname: str = "",
        target: str = "all",
    ):
        """Set LoRA adapter using SGLang Diffusion API.

        Args:
            sgld_client: Client whose ``set_lora(**params)`` performs the request.
            lora_name: Adapter name/path, sent as ``lora_path``. Must be non-empty.
            lora_nickname: Optional nickname; defaults to ``lora_name`` without its
                file extension.
            target: Which transformer(s) to apply the adapter to.

        Returns:
            ``(sgld_client,)`` so the client can be chained to downstream nodes.

        Raises:
            ValueError: If ``lora_name`` is empty.
            RuntimeError: If the API call fails.
        """
        # Fail fast locally rather than sending an empty adapter path to the server.
        if not lora_name:
            raise ValueError("LoRA name cannot be empty")

        if lora_nickname == "":
            lora_nickname = os.path.splitext(lora_name)[0]

        # Prepare request parameters
        request_params = {
            "lora_nickname": lora_nickname,
            "lora_path": lora_name,
            "target": target,
        }

        # Call API; the response body is not needed, only success/failure.
        try:
            sgld_client.set_lora(**request_params)
        except Exception as e:
            raise RuntimeError(f"Failed to set LoRA adapter: {str(e)}") from e

        return (sgld_client,)
a/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/test/README.md b/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/test/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5246d29231f376eeb5d289ed309b049579437e49 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/test/README.md @@ -0,0 +1,66 @@ +# ComfyUI SGLDiffusion Pipeline Tests + +This directory contains tests for each ComfyUI pipeline integration. + +## Test Files + +- `test_zimage_pipeline.py` - Tests for ComfyUIZImagePipeline +- `test_flux_pipeline.py` - Tests for ComfyUIFluxPipeline +- `test_qwen_image_pipeline.py` - Tests for ComfyUIQwenImagePipeline +- `test_qwen_image_edit_pipeline.py` - Tests for ComfyUIQwenImageEditPipeline (I2I/edit mode) + +## Running Tests + +### Run all tests + +```bash +pytest python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/test/ -v -s +``` + +### Run a specific test file + +```bash +pytest python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/test/test_zimage_pipeline.py -v -s +``` + +## Environment Variables + +You can configure model paths via environment variables. 
Model paths support two formats: +- **Safetensors file**: Path to a single `.safetensors` file (e.g., `/path/to/model.safetensors`) +- **Diffusers format**: HuggingFace model ID or local diffusers directory (e.g., `Tongyi-MAI/Z-Image-Turbo`) + +Environment variables: +- `SGLANG_TEST_ZIMAGE_MODEL_PATH` - Path to ZImage model (default: `Tongyi-MAI/Z-Image-Turbo`) +- `SGLANG_TEST_FLUX_MODEL_PATH` - Path to Flux model (default: `black-forest-labs/FLUX.1-dev`) +- `SGLANG_TEST_QWEN_IMAGE_MODEL_PATH` - Path to QwenImage model (default: `Qwen/Qwen-Image`) +- `SGLANG_TEST_QWEN_IMAGE_EDIT_MODEL_PATH` - Path to QwenImageEdit model (default: `Qwen/Qwen-Image-Edit-2511`) + +Examples: + +```bash +# Using HuggingFace model ID (diffusers format) +export SGLANG_TEST_ZIMAGE_MODEL_PATH="Tongyi-MAI/Z-Image-Turbo" +pytest python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/test/test_zimage_pipeline.py -v -s + +# Using safetensors file +export SGLANG_TEST_ZIMAGE_MODEL_PATH="/path/to/z_image_turbo_bf16.safetensors" +pytest python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/test/test_zimage_pipeline.py -v -s +``` + +## Test Structure + +Each test file follows a similar structure: + +1. **Setup**: Creates a `DiffGenerator` with the appropriate pipeline class +2. **Input Preparation**: Creates dummy tensors for latents, timesteps, and embeddings +3. **Request Preparation**: Uses `prepare_request` to convert `SamplingParams` to `Req` +4. **ComfyUI Inputs**: Sets ComfyUI-specific inputs directly on the `Req` object +5. **Execution**: Sends request to scheduler and waits for response +6. 
def test_comfyui_flux_pipeline_direct() -> None:
    """Test ComfyUIFluxPipeline with custom inputs.

    Builds a generator for the pipeline, fills a request with constant tensors
    in place of real latents/embeddings, sends it to the scheduler and checks
    that ``noise_pred`` comes back as a CUDA bfloat16 tensor.
    """
    model_path = os.environ.get(
        "SGLANG_TEST_FLUX_MODEL_PATH",
        "black-forest-labs/FLUX.1-dev",  # Supports both safetensors file and diffusers format
    )

    generator = DiffGenerator.from_pretrained(
        model_path=model_path,
        pipeline_class_name="ComfyUIFluxPipeline",
        num_gpus=2,
        comfyui_mode=True,
    )

    batch_size = 1
    hidden_states_seq_len = 3600
    hidden_states_dim = 64
    height = 1280
    width = 720

    encoder_seq_len = 512
    encoder_dim = 4096
    pooled_dim = 768

    hidden_states = torch.ones(
        batch_size,
        hidden_states_seq_len,
        hidden_states_dim,
        device="cuda",
        dtype=torch.bfloat16,
    )

    encoder_hidden_states = torch.ones(
        batch_size,
        encoder_seq_len,
        encoder_dim,
        device="cuda",
        dtype=torch.bfloat16,
    )

    pooled_projections = torch.ones(
        batch_size,
        pooled_dim,
        device="cuda",
        dtype=torch.bfloat16,
    )

    timesteps = torch.tensor([1000], dtype=torch.long, device="cuda")

    sampling_params = SamplingParams.from_user_sampling_params_args(
        generator.server_args.model_path,
        server_args=generator.server_args,
        prompt="a beautiful girl",
        height=height,
        width=width,
        num_frames=1,
        num_inference_steps=1,
        save_output=True,
        return_trajectory_latents=True,
    )

    req = prepare_request(
        server_args=generator.server_args,
        sampling_params=sampling_params,
    )

    req.latents = hidden_states
    req.timesteps = timesteps
    req.raw_latent_shape = torch.tensor(hidden_states.shape, dtype=torch.long)

    clip_dim = 768
    # NOTE(review): the first positive slot holds the pooled projections while the
    # first negative slot below holds a (batch, 77, clip_dim) placeholder - confirm
    # whether the pipeline reads slot 0 at all before making them symmetric.
    req.prompt_embeds = [pooled_projections, encoder_hidden_states]

    if req.guidance_scale > 1.0:
        dummy_neg_clip_embedding = torch.zeros(
            batch_size,
            77,
            clip_dim,
            device="cuda",
            dtype=torch.bfloat16,
        )
        negative_encoder_hidden_states = torch.ones(
            batch_size,
            encoder_seq_len,
            encoder_dim,
            device="cuda",
            dtype=torch.bfloat16,
        )
        req.negative_prompt_embeds = [
            dummy_neg_clip_embedding,
            negative_encoder_hidden_states,
        ]
    else:
        req.negative_prompt_embeds = None

    req.pooled_embeds = [pooled_projections]
    req.neg_pooled_embeds = []

    # Classifier-free guidance needs both a scale above 1 and negative embeddings.
    if (
        req.guidance_scale > 1.0
        and req.negative_prompt_embeds is not None
        and len(req.negative_prompt_embeds) > 0
    ):
        req.do_classifier_free_guidance = True
    else:
        req.do_classifier_free_guidance = False

    if req.seed is not None:
        # One generator per requested output, each seeded with a distinct offset.
        generator_device = req.generator_device
        device_str = "cuda" if generator_device == "cuda" else "cpu"
        req.generator = [
            torch.Generator(device_str).manual_seed(req.seed + i)
            for i in range(req.num_outputs_per_prompt)
        ]
    else:
        req.generator = [
            torch.Generator("cuda") for _ in range(req.num_outputs_per_prompt)
        ]

    output_batch = generator._send_to_scheduler_and_wait_for_response([req])
    noise_pred = output_batch.noise_pred

    assert noise_pred is not None, "noise_pred should not be None in OutputBatch"
    assert isinstance(noise_pred, torch.Tensor), "noise_pred should be a torch.Tensor"
    assert (
        noise_pred.device.type == "cuda"
    ), f"noise_pred should be on cuda, got {noise_pred.device}"
    assert (
        noise_pred.dtype == torch.bfloat16
    ), f"noise_pred should be bfloat16, got {noise_pred.dtype}"

    print("✓ Successfully retrieved noise_pred from OutputBatch!")
    print(f"  noise_pred shape: {noise_pred.shape}")
    print(f"  noise_pred dtype: {noise_pred.dtype}")
    print(f"  noise_pred device: {noise_pred.device}")

    latents = output_batch.output if output_batch.output is not None else req.latents
    assert latents is not None, "latents should not be None"
    print(f"latents.shape: {latents.shape}")
def test_comfyui_qwen_image_edit_pipeline_direct() -> None:
    """Test ComfyUIQwenImageEditPipeline with edit mode (I2I) and custom inputs.

    Constant tensors stand in for the noisy latents, the condition-image
    latents and the text embeddings; the test passes when the scheduler
    returns ``noise_pred`` as a CUDA bfloat16 tensor.
    """
    model_path = os.environ.get(
        "SGLANG_TEST_QWEN_IMAGE_EDIT_MODEL_PATH",
        "Qwen/Qwen-Image-Edit-2511",  # Supports both safetensors file and diffusers format
    )

    generator = DiffGenerator.from_pretrained(
        model_path=model_path,
        pipeline_class_name="ComfyUIQwenImageEditPipeline",
        num_gpus=1,
        comfyui_mode=True,
        dit_layerwise_offload=False,
    )

    def ones_on_gpu(*shape):
        # Every dummy input is an all-ones bfloat16 tensor on the GPU.
        return torch.ones(*shape, device="cuda", dtype=torch.bfloat16)

    batch_size = 1
    height, width = 720, 1280

    # The condition image is described in latent units: pixels // VAE scale factor.
    vae_scale_factor = 8
    condition_height_latent = 1328 // vae_scale_factor
    condition_width_latent = 1328 // vae_scale_factor

    # (batch, noisy_image_seq_len, hidden_states_dim)
    noisy_image_latents = ones_on_gpu(batch_size, 3600, 64)
    # (batch, condition_image_seq_len, condition_image_dim)
    condition_image_latents = ones_on_gpu(batch_size, 6889, 64)
    # (batch, encoder_seq_len, encoder_dim)
    encoder_hidden_states = ones_on_gpu(batch_size, 45, 3584)

    timesteps = torch.tensor([1000], dtype=torch.long, device="cuda")

    server_args = generator.server_args
    sampling_params = SamplingParams.from_user_sampling_params_args(
        server_args.model_path,
        server_args=server_args,
        prompt=" ",
        guidance_scale=1.0,
        height=height,
        width=width,
        image_path="",
        num_frames=1,
        num_inference_steps=1,
        seed=42,
        save_output=False,
        return_frames=False,
    )

    req = prepare_request(
        server_args=generator.server_args,
        sampling_params=sampling_params,
    )

    # Inject the pre-computed inputs directly on the request.
    req.latents = noisy_image_latents
    req.image_latent = condition_image_latents
    req.timesteps = timesteps
    req.prompt_embeds = [encoder_hidden_states]
    req.negative_prompt_embeds = None
    req.vae_image_sizes = [(condition_width_latent, condition_height_latent)]
    req.raw_latent_shape = torch.tensor(noisy_image_latents.shape, dtype=torch.long)

    # CFG is enabled only with a scale above 1 and negative embeddings present.
    wants_cfg = req.guidance_scale > 1.0 and req.negative_prompt_embeds is not None
    req.do_classifier_free_guidance = True if wants_cfg else False

    num_outputs = req.num_outputs_per_prompt
    if req.seed is None:
        req.generator = [torch.Generator("cuda") for _ in range(num_outputs)]
    else:
        rng_device = "cpu" if req.generator_device == "cpu" else "cuda"
        req.generator = [
            torch.Generator(rng_device).manual_seed(req.seed + offset)
            for offset in range(num_outputs)
        ]

    output_batch = generator._send_to_scheduler_and_wait_for_response([req])
    noise_pred = output_batch.noise_pred

    assert noise_pred is not None, "noise_pred should not be None in OutputBatch"
    assert isinstance(noise_pred, torch.Tensor), "noise_pred should be a torch.Tensor"
    assert (
        noise_pred.device.type == "cuda"
    ), f"noise_pred should be on cuda, got {noise_pred.device}"
    assert (
        noise_pred.dtype == torch.bfloat16
    ), f"noise_pred should be bfloat16, got {noise_pred.dtype}"

    print(f"✓ Successfully retrieved noise_pred from OutputBatch (Edit Mode)!")
    print(f"  noise_pred shape: {noise_pred.shape}")
    print(f"  noise_pred dtype: {noise_pred.dtype}")
    print(f"  noise_pred device: {noise_pred.device}")

    latents = req.latents if output_batch.output is None else output_batch.output
    assert latents is not None, "latents should not be None"
device="cuda") + + sampling_params = SamplingParams.from_user_sampling_params_args( + generator.server_args.model_path, + server_args=generator.server_args, + prompt=" ", + guidance_scale=3.0, + height=height, + width=width, + num_frames=1, + num_inference_steps=1, + seed=42, + save_output=False, + return_frames=False, + ) + + req = prepare_request( + server_args=generator.server_args, + sampling_params=sampling_params, + ) + + req.latents = hidden_states + req.timesteps = timesteps + req.prompt_embeds = [encoder_hidden_states] + req.negative_prompt_embeds = [encoder_hidden_states] + req.raw_latent_shape = torch.tensor(hidden_states.shape, dtype=torch.long) + + if req.guidance_scale > 1.0 and req.negative_prompt_embeds is not None: + req.do_classifier_free_guidance = True + else: + req.do_classifier_free_guidance = False + + if req.seed is not None: + generator_device = req.generator_device + device_str = "cpu" if generator_device == "cpu" else "cuda" + req.generator = [ + torch.Generator(device_str).manual_seed(req.seed + i) + for i in range(req.num_outputs_per_prompt) + ] + else: + req.generator = [ + torch.Generator("cuda") for _ in range(req.num_outputs_per_prompt) + ] + + output_batch = generator._send_to_scheduler_and_wait_for_response([req]) + noise_pred = output_batch.noise_pred + + assert noise_pred is not None, "noise_pred should not be None in OutputBatch" + assert isinstance(noise_pred, torch.Tensor), "noise_pred should be a torch.Tensor" + assert ( + noise_pred.device.type == "cuda" + ), f"noise_pred should be on cuda, got {noise_pred.device}" + assert ( + noise_pred.dtype == torch.bfloat16 + ), f"noise_pred should be bfloat16, got {noise_pred.dtype}" + + print(f"✓ Successfully retrieved noise_pred from OutputBatch!") + print(f" noise_pred shape: {noise_pred.shape}") + print(f" noise_pred dtype: {noise_pred.dtype}") + print(f" noise_pred device: {noise_pred.device}") + + latents = output_batch.output if output_batch.output is not None else req.latents 
def test_comfyui_zimage_pipeline_direct() -> None:
    """Test ComfyUIZImagePipeline with custom inputs.

    Constant tensors stand in for the latents and the text context; the test
    passes when the scheduler returns ``noise_pred`` as a CUDA bfloat16 tensor.
    """
    model_path = os.environ.get(
        "SGLANG_TEST_ZIMAGE_MODEL_PATH",
        "Tongyi-MAI/Z-Image-Turbo",  # Supports both safetensors file and diffusers format
    )

    generator = DiffGenerator.from_pretrained(
        model_path=model_path,
        pipeline_class_name="ComfyUIZImagePipeline",
        num_gpus=1,
        sp_degree=1,
        comfyui_mode=True,
    )

    def ones_on_gpu(*shape):
        # Every dummy input is an all-ones bfloat16 tensor on the GPU.
        return torch.ones(*shape, device="cuda", dtype=torch.bfloat16)

    batch_size = 1
    num_channels = 16
    num_frames = 1
    height, width = 720, 1280
    # Latent spatial size is the pixel size divided by 8.
    latent_height, latent_width = height // 8, width // 8

    latents = ones_on_gpu(
        batch_size, num_channels, num_frames, latent_height, latent_width
    )

    timesteps = torch.tensor([1000], dtype=torch.long, device="cuda")

    # Unbatched text context: (context_seq_len, context_dim).
    context = ones_on_gpu(19, 2560)

    server_args = generator.server_args
    sampling_params = SamplingParams.from_user_sampling_params_args(
        server_args.model_path,
        server_args=server_args,
        prompt="a beautiful girl",
        guidance_scale=1.0,
        height=height,
        width=width,
        num_frames=1,
        num_inference_steps=1,
        seed=42,
        save_output=False,
        return_frames=False,
    )

    req = prepare_request(
        server_args=generator.server_args,
        sampling_params=sampling_params,
    )

    # Inject the pre-computed inputs directly on the request.
    req.latents = latents
    req.timesteps = timesteps
    req.prompt_embeds = [context]
    req.negative_prompt_embeds = None
    req.raw_latent_shape = torch.tensor(latents.shape, dtype=torch.long)

    # CFG is enabled only with a scale above 1 and negative embeddings present.
    wants_cfg = req.guidance_scale > 1.0 and req.negative_prompt_embeds is not None
    req.do_classifier_free_guidance = True if wants_cfg else False

    num_outputs = req.num_outputs_per_prompt
    if req.seed is None:
        req.generator = [torch.Generator("cuda") for _ in range(num_outputs)]
    else:
        rng_device = "cpu" if req.generator_device == "cpu" else "cuda"
        req.generator = [
            torch.Generator(rng_device).manual_seed(req.seed + offset)
            for offset in range(num_outputs)
        ]

    output_batch = generator._send_to_scheduler_and_wait_for_response([req])
    noise_pred = output_batch.noise_pred

    assert noise_pred is not None, "noise_pred should not be None in OutputBatch"
    assert isinstance(noise_pred, torch.Tensor), "noise_pred should be a torch.Tensor"
    assert (
        noise_pred.device.type == "cuda"
    ), f"noise_pred should be on cuda, got {noise_pred.device}"
    assert (
        noise_pred.dtype == torch.bfloat16
    ), f"noise_pred should be bfloat16, got {noise_pred.dtype}"

    print(f"✓ Successfully retrieved noise_pred from OutputBatch!")
    print(f"  noise_pred shape: {noise_pred.shape}")
    print(f"  noise_pred dtype: {noise_pred.dtype}")
    print(f"  noise_pred device: {noise_pred.device}")

    latents = req.latents if output_batch.output is None else output_batch.output
    assert latents is not None, "latents should not be None"
def _ensure_dir(path: str) -> None:
    """Create ``path`` (including parents) if it does not exist yet."""
    os.makedirs(path, exist_ok=True)


# Channel counts recognised when guessing a tensor's layout (gray, RGB, RGBA).
_CHANNEL_COUNTS = (1, 3, 4)


def _as_hwc(image: torch.Tensor) -> torch.Tensor:
    """Return ``image`` as a (H, W, C) view without changing its values.

    A 4-D input contributes only its first batch element. A 3-D tensor is
    treated as channels-first only when its first dimension looks like a
    channel count and its last dimension does not; when both do, the tensor is
    left as-is, because the rest of this module works with channels-last
    (BHWC) images. A 2-D tensor gets a trailing singleton channel.
    """
    if image.dim() == 4:
        image = image[0]
    if image.dim() == 3:
        looks_channels_first = image.shape[0] in _CHANNEL_COUNTS
        looks_channels_last = image.shape[-1] in _CHANNEL_COUNTS
        if looks_channels_first and not looks_channels_last:
            image = image.permute(1, 2, 0)
    elif image.dim() == 2:
        image = image.unsqueeze(-1)
    return image


def _to_numpy_image(image: torch.Tensor) -> np.ndarray:
    """Convert ComfyUI image tensor to uint8 numpy array (H, W, C).

    Values are clipped to [0, 1] and scaled by 255 (truncating); a
    single-channel image is expanded to three identical channels.
    """
    np_img = _as_hwc(image).detach().cpu().numpy()
    np_img = np.clip(np_img, 0.0, 1.0)
    np_img = (np_img * 255).astype(np.uint8)
    if np_img.shape[-1] == 1:
        np_img = np.repeat(np_img, 3, axis=-1)
    return np_img


def _to_hwc_tensor(image: torch.Tensor) -> torch.Tensor:
    """Convert ComfyUI image tensor to HWC format (normalized [0, 1]).

    The result is always a new tensor (``torch.clamp`` is out-of-place), so
    the caller's input is never modified.
    """
    img = torch.clamp(_as_hwc(image), 0.0, 1.0)
    if img.shape[-1] == 1:
        img = img.repeat(1, 1, 3)
    return img


def is_empty_image(image: torch.Tensor, tolerance: float = 1e-6) -> bool:
    """
    Check if the input image is an empty/solid color image (like ComfyUI's empty image).

    Args:
        image: Input tensor image in ComfyUI format (BCHW, CHW, HWC, etc.)
        tolerance: Tolerance for floating point comparison (default: 1e-6)

    Returns:
        True if the image is empty (all pixels have same color), False otherwise.
        ``None`` and zero-sized images are also reported as empty.
    """
    if image is None:
        return True

    # Convert to HWC format
    img_hwc = _to_hwc_tensor(image)

    # A tensor without any pixels has nothing to compare; treat it as empty
    # instead of failing on the first-pixel lookup.
    if img_hwc.numel() == 0:
        return True

    # Flatten to one row per pixel and compare every pixel with the first one.
    channels = img_hwc.shape[-1]
    pixels = img_hwc.reshape(-1, channels)
    max_diff = torch.max(torch.abs(pixels - pixels[0]))

    return max_diff.item() <= tolerance


def get_image_path(image: torch.Tensor) -> str:
    """
    Save tensor image to ComfyUI temp directory as PNG and return the path.
    """
    temp_dir = folder_paths.get_temp_directory()
    # The temp directory is not guaranteed to exist yet; create it on demand.
    _ensure_dir(temp_dir)

    # Build a collision-resistant file name: timestamp plus a short random suffix.
    ts = time.strftime("%Y%m%d-%H%M%S")
    unique = uuid.uuid4().hex[:8]
    file_name = f"sgl_output_{ts}_{unique}.png"
    file_path = os.path.join(temp_dir, file_name)

    # Save image
    np_img = _to_numpy_image(image)
    img = Image.fromarray(np_img)
    img.save(file_path, format="PNG")

    return file_path
class SGLDVideoInput(VideoInput):
    """``VideoInput`` implementation backed by a video file path."""

    def __init__(self, video_path: str, height: int, width: int):
        super().__init__()

        self.width = width
        self.height = height
        self.video_path = video_path

    def get_dimensions(self) -> tuple[int, int]:
        """
        Report the size of the video.

        Returns:
            ``(width, height)`` as passed to the constructor.
        """
        return (self.width, self.height)

    def get_components(self):
        """
        Return the video components as a one-element list holding the video path.
        Implemented because the VideoInput abstract base class requires it.
        """
        return [self.video_path]

    def save_to(self, path: str, format=None, codec=None, metadata=None):
        """
        Copy the backing video file to ``path``.

        ``format``, ``codec`` and ``metadata`` are accepted but not used here.
        Nothing is written when the backing file does not exist.
        """
        source = self.video_path
        if not os.path.exists(source):
            return

        # Ensure destination directory exists (a bare file name has no directory part)
        target_dir = os.path.dirname(path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        shutil.copy2(source, path)
+ """ + video_input = SGLDVideoInput(video_path, height, width) + return video_input diff --git a/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/workflows/flux_sgld_sp.json b/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/workflows/flux_sgld_sp.json new file mode 100644 index 0000000000000000000000000000000000000000..d158a07084480cd5ad0296afc9b8710bfafbb8d1 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/workflows/flux_sgld_sp.json @@ -0,0 +1,222 @@ +{ + "8": { + "inputs": { + "samples": [ + "40", + 0 + ], + "vae": [ + "10", + 0 + ] + }, + "class_type": "VAEDecode", + "_meta": { + "title": "VAE Decode" + } + }, + "10": { + "inputs": { + "vae_name": "ae.safetensors" + }, + "class_type": "VAELoader", + "_meta": { + "title": "Load VAE" + } + }, + "11": { + "inputs": { + "clip_name1": "t5xxl_fp16.safetensors", + "clip_name2": "clip_l.safetensors", + "type": "flux", + "device": "default" + }, + "class_type": "DualCLIPLoader", + "_meta": { + "title": "DualCLIPLoader" + } + }, + "17": { + "inputs": { + "scheduler": "normal", + "steps": 25, + "denoise": 1, + "model": [ + "46", + 0 + ] + }, + "class_type": "BasicScheduler", + "_meta": { + "title": "BasicScheduler" + } + }, + "38": { + "inputs": { + "model": [ + "46", + 0 + ], + "conditioning": [ + "42", + 0 + ] + }, + "class_type": "BasicGuider", + "_meta": { + "title": "BasicGuider" + } + }, + "39": { + "inputs": { + "filename_prefix": "ComfyUI", + "images": [ + "8", + 0 + ] + }, + "class_type": "SaveImage", + "_meta": { + "title": "Save Image" + } + }, + "40": { + "inputs": { + "noise": [ + "45", + 0 + ], + "guider": [ + "38", + 0 + ], + "sampler": [ + "47", + 0 + ], + "sigmas": [ + "17", + 0 + ], + "latent_image": [ + "44", + 0 + ] + }, + "class_type": "SamplerCustomAdvanced", + "_meta": { + "title": "SamplerCustomAdvanced" + } + }, + "42": { + "inputs": { + "guidance": 3.5, + "conditioning": [ + "43", + 0 + ] + }, + "class_type": "FluxGuidance", + "_meta": 
{ + "title": "FluxGuidance" + } + }, + "43": { + "inputs": { + "text": "beautiful photography of a gonger haired artist with Lots of Colorful coloursplashes in face and pn her hands, she is natural, having her hair in a casual bun, looking happily into camera, cinematic,", + "speak_and_recognation": { + "__value__": [ + false, + true + ] + }, + "clip": [ + "11", + 0 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "44": { + "inputs": { + "width": 1024, + "height": 1024, + "batch_size": 1 + }, + "class_type": "EmptySD3LatentImage", + "_meta": { + "title": "EmptySD3LatentImage" + } + }, + "45": { + "inputs": { + "noise_seed": 747172083610812 + }, + "class_type": "RandomNoise", + "_meta": { + "title": "RandomNoise" + } + }, + "46": { + "inputs": { + "max_shift": 1.15, + "base_shift": 0.5, + "width": 1024, + "height": 1024, + "model": [ + "51", + 0 + ] + }, + "class_type": "ModelSamplingFlux", + "_meta": { + "title": "ModelSamplingFlux" + } + }, + "47": { + "inputs": { + "sampler_name": "euler" + }, + "class_type": "KSamplerSelect", + "_meta": { + "title": "KSamplerSelect" + } + }, + "51": { + "inputs": { + "unet_name": "flux1-dev.safetensors", + "weight_dtype": "default", + "sgld_options": [ + "52", + 0 + ] + }, + "class_type": "SGLDUNETLoader", + "_meta": { + "title": "SGLDiffusion UNET Loader" + } + }, + "52": { + "inputs": { + "model_type": "auto-detect", + "enable_torch_compile": false, + "num_gpus": 2, + "tp_size": -1, + "sp_degree": -1, + "ulysses_degree": -1, + "ring_degree": -1, + "dp_size": 1, + "dp_degree": 1, + "enable_cfg_parallel": false, + "attention_backend": "", + "cache_strategy": "none" + }, + "class_type": "SGLDOptions", + "_meta": { + "title": "SGLDiffusion Options" + } + } + } diff --git a/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/workflows/qwen_image_sgld.json b/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/workflows/qwen_image_sgld.json new file mode 
100644 index 0000000000000000000000000000000000000000..f7f31c5d2bef7522ea528f03594d591d04247d4c --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/workflows/qwen_image_sgld.json @@ -0,0 +1,165 @@ +{ + "3": { + "inputs": { + "seed": 808633539418610, + "steps": 4, + "cfg": 1, + "sampler_name": "euler", + "scheduler": "simple", + "denoise": 1, + "model": [ + "66", + 0 + ], + "positive": [ + "6", + 0 + ], + "negative": [ + "7", + 0 + ], + "latent_image": [ + "58", + 0 + ] + }, + "class_type": "KSampler", + "_meta": { + "title": "KSampler" + } + }, + "6": { + "inputs": { + "text": "\"A vibrant, warm neon-lit street scene in Hong Kong at the afternoon, with a mix of colorful Chinese and English signs glowing brightly. The atmosphere is lively, cinematic, and rain-washed with reflections on the pavement. The colors are vivid, full of pink, blue, red, and green hues. Crowded buildings with overlapping neon signs. 1980s Hong Kong style. Signs include:\n\"龍鳳冰室\" \"金華燒臘\" \"HAPPY HAIR\" \"鴻運茶餐廳\" \"EASY BAR\" \"永發魚蛋粉\" \"添記粥麵\" \"SUNSHINE MOTEL\" \"美都餐室\" \"富記糖水\" \"太平館\" \"雅芳髮型屋\" \"STAR KTV\" \"銀河娛樂城\" \"百樂門舞廳\" \"BUBBLE CAFE\" \"萬豪麻雀館\" \"CITY LIGHTS BAR\" \"瑞祥香燭莊\" \"文記文具\" \"GOLDEN JADE HOTEL\" \"LOVELY BEAUTY\" \"合興百貨\" \"興旺電器\" And the background is warm yellow street and with all stores' lights on.", + "speak_and_recognation": { + "__value__": [ + false, + true + ] + }, + "clip": [ + "38", + 0 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Positive Prompt)" + } + }, + "7": { + "inputs": { + "text": "", + "speak_and_recognation": { + "__value__": [ + false, + true + ] + }, + "clip": [ + "38", + 0 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Negative Prompt)" + } + }, + "8": { + "inputs": { + "samples": [ + "3", + 0 + ], + "vae": [ + "39", + 0 + ] + }, + "class_type": "VAEDecode", + "_meta": { + "title": "VAE Decode" + } + }, + "38": { + "inputs": { + 
"clip_name": "qwen_2.5_vl_7b_fp8_scaled.safetensors", + "type": "qwen_image", + "device": "default" + }, + "class_type": "CLIPLoader", + "_meta": { + "title": "Load CLIP" + } + }, + "39": { + "inputs": { + "vae_name": "qwen_image_vae.safetensors" + }, + "class_type": "VAELoader", + "_meta": { + "title": "Load VAE" + } + }, + "58": { + "inputs": { + "width": 1328, + "height": 1328, + "batch_size": 1 + }, + "class_type": "EmptySD3LatentImage", + "_meta": { + "title": "EmptySD3LatentImage" + } + }, + "60": { + "inputs": { + "filename_prefix": "ComfyUI" + }, + "class_type": "SaveImage", + "_meta": { + "title": "Save Image" + } + }, + "66": { + "inputs": { + "shift": 3.1000000000000005, + "model": [ + "78", + 0 + ] + }, + "class_type": "ModelSamplingAuraFlow", + "_meta": { + "title": "ModelSamplingAuraFlow" + } + }, + "77": { + "inputs": { + "unet_name": "qwen_image_2512_bf16.safetensors", + "weight_dtype": "default" + }, + "class_type": "SGLDUNETLoader", + "_meta": { + "title": "SGLDiffusion UNET Loader" + } + }, + "78": { + "inputs": { + "lora_name": "Qwen-Image-2512-Lightning-4steps-V1.0-bf16.safetensors", + "strength_model": 1, + "nickname": "", + "target": "all", + "model": [ + "77", + 0 + ] + }, + "class_type": "SGLDLoraLoader", + "_meta": { + "title": "SGLDiffusion LoRA Loader" + } + } + } diff --git a/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/workflows/sgld_image2video.json b/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/workflows/sgld_image2video.json new file mode 100644 index 0000000000000000000000000000000000000000..e3d5456ff6db174440c259d9d130339057aa4f07 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/workflows/sgld_image2video.json @@ -0,0 +1,97 @@ +{ + "1": { + "inputs": { + "base_url": "http://localhost:3000/v1", + "api_key": "sk-proj-1234567890" + }, + "class_type": "SGLDiffusionServerModel", + "_meta": { + "title": "SGLDiffusion Server Model" + } + }, + "3": { + "inputs": { + 
"prompt": "The girl turn the body and spin around in place.", + "main": "none", + "lighting": "none", + "speak_and_recognation": { + "__value__": [ + false, + true + ] + } + }, + "class_type": "easy prompt", + "_meta": { + "title": "Prompt" + } + }, + "4": { + "inputs": { + "text": "", + "anything": [ + "1", + 1 + ] + }, + "class_type": "easy showAnything", + "_meta": { + "title": "Show Any" + } + }, + "15": { + "inputs": { + "positive_prompt": [ + "3", + 0 + ], + "negative_prompt": "", + "seed": 2435791308, + "steps": 50, + "cfg": 4, + "width": 704, + "height": 1280, + "num_frames": 16, + "fps": 16, + "seconds": 1, + "enable_teacache": false, + "sgld_client": [ + "1", + 0 + ], + "image": [ + "17", + 0 + ] + }, + "class_type": "SGLDiffusionGenerateVideo", + "_meta": { + "title": "SGLDiffusion Generate Video" + } + }, + "16": { + "inputs": { + "filename_prefix": "video/ComfyUI", + "format": "auto", + "codec": "auto", + "video-preview": "", + "video": [ + "15", + 0 + ] + }, + "class_type": "SaveVideo", + "_meta": { + "title": "save video" + } + }, + "17": { + "inputs": { + "image": "tmpe_w0bd_0.jpg" + }, + "class_type": "LoadImage", + "_meta": { + "title": "load image" + } + } + } diff --git a/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/workflows/sgld_text2img.json b/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/workflows/sgld_text2img.json new file mode 100644 index 0000000000000000000000000000000000000000..65e74c7b3276c0d39c14c69d905a726980ff684b --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/workflows/sgld_text2img.json @@ -0,0 +1,109 @@ +{ + "1": { + "inputs": { + "base_url": "http://localhost:3000/v1", + "api_key": "sk-proj-1234567890" + }, + "class_type": "SGLDiffusionServerModel", + "_meta": { + "title": "SGLDiffusion Server Model" + } + }, + "3": { + "inputs": { + "prompt": "a bicycle, illustration in the style of SMPL, thick black lines on a white background", + "main": "none", + 
"lighting": "none", + "speak_and_recognation": { + "__value__": [ + false, + true + ] + } + }, + "class_type": "easy prompt", + "_meta": { + "title": "Prompt" + } + }, + "4": { + "inputs": { + "text": "", + "anything": [ + "1", + 1 + ] + }, + "class_type": "easy showAnything", + "_meta": { + "title": "Show Any" + } + }, + "5": { + "inputs": { + "filename_prefix": "ComfyUI", + "images": [ + "6", + 0 + ] + }, + "class_type": "SaveImage", + "_meta": { + "title": "save image" + } + }, + "6": { + "inputs": { + "positive_prompt": [ + "3", + 0 + ], + "negative_prompt": "", + "seed": 4215918563, + "steps": 50, + "cfg": 4, + "width": 512, + "height": 512, + "enable_teacache": false, + "sgld_client": [ + "11", + 0 + ], + "image": [ + "14", + 0 + ] + }, + "class_type": "SGLDiffusionGenerateImage", + "_meta": { + "title": "SGLDiffusion Generate Image" + } + }, + "11": { + "inputs": { + "lora_name": "dvyio/flux-lora-simple-illustration", + "lora_nickname": "", + "target": "all", + "sgld_client": [ + "1", + 0 + ] + }, + "class_type": "SGLDiffusionSetLora", + "_meta": { + "title": "SGLDiffusion Set LoRA" + } + }, + "14": { + "inputs": { + "width": 512, + "height": 512, + "batch_size": 1, + "color": 0 + }, + "class_type": "EmptyImage", + "_meta": { + "title": "empty image" + } + } + } diff --git a/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/workflows/z-image_sgld.json b/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/workflows/z-image_sgld.json new file mode 100644 index 0000000000000000000000000000000000000000..8c56c29f46fa68bdfd1519c9640aa0c90f6ae81f --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/workflows/z-image_sgld.json @@ -0,0 +1,140 @@ +{ + "3": { + "inputs": { + "seed": 3338398, + "steps": 9, + "cfg": 1, + "sampler_name": "euler", + "scheduler": "simple", + "denoise": 1, + "model": [ + "28", + 0 + ], + "positive": [ + "6", + 0 + ], + "negative": [ + "7", + 0 + ], + "latent_image": [ + "13", + 0 + ] + }, + 
"class_type": "KSampler", + "_meta": { + "title": "KSampler" + } + }, + "6": { + "inputs": { + "text": "cute anime style girl with massive fluffy fennec ears and a big fluffy tail blonde messy long hair blue eyes wearing a maid outfit with a long black gold leaf pattern dress and a white apron, it is a postcard held by a hand in front of a beautiful realistic city at sunset and there is cursive writing that says \"ZImage, Now in ComfyUI\"", + "speak_and_recognation": { + "__value__": [ + false, + true + ] + }, + "clip": [ + "18", + 0 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Positive Prompt)" + } + }, + "7": { + "inputs": { + "text": "blurry ugly bad", + "speak_and_recognation": { + "__value__": [ + false, + true + ] + }, + "clip": [ + "18", + 0 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Negative Prompt)" + } + }, + "8": { + "inputs": { + "samples": [ + "3", + 0 + ], + "vae": [ + "17", + 0 + ] + }, + "class_type": "VAEDecode", + "_meta": { + "title": "VAE Decode" + } + }, + "9": { + "inputs": { + "filename_prefix": "ComfyUI", + "images": [ + "8", + 0 + ] + }, + "class_type": "SaveImage", + "_meta": { + "title": "Save Image" + } + }, + "13": { + "inputs": { + "width": 1024, + "height": 1024, + "batch_size": 1 + }, + "class_type": "EmptySD3LatentImage", + "_meta": { + "title": "EmptySD3LatentImage" + } + }, + "17": { + "inputs": { + "vae_name": "ae.safetensors" + }, + "class_type": "VAELoader", + "_meta": { + "title": "VAE Loader" + } + }, + "18": { + "inputs": { + "clip_name": "qwen_3_4b.safetensors", + "type": "lumina2", + "device": "default" + }, + "class_type": "CLIPLoader", + "_meta": { + "title": "CLIP Loader" + } + }, + "28": { + "inputs": { + "unet_name": "z_image_turbo_bf16.safetensors", + "weight_dtype": "default" + }, + "class_type": "SGLDUNETLoader", + "_meta": { + "title": "SGLDiffusion UNET Loader" + } + } +} diff --git 
a/sglang/python/sglang/multimodal_gen/apps/webui/README.md b/sglang/python/sglang/multimodal_gen/apps/webui/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f0046e6f78fa7a31b78c99e45c1cfe1ca06a45b8 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/apps/webui/README.md @@ -0,0 +1,58 @@ +# SGLang Diffusion WebUI User Guide + +SGLang Diffusion WebUI provides an intuitive Gradio-based interface for image and video generation, supporting parameter +tuning and real-time previews. + +## Prerequisites + +The WebUI runs on Gradio. To get started, install Gradio first: + +```bash +pip install gradio==6.1.0 +``` + +## Launch WebUI Service + +SGLang Diffusion now includes an integrated WebUI. Simply add the `--webui` parameter when starting the service. + +### Launch Text-to-Image Service + +```bash +sglang serve --model-path black-forest-labs/FLUX.1-dev --num-gpus 1 --webui --webui-port 2333 +``` + +### Launch Text-to-Video Service + +```bash +sglang serve --model-path Wan-AI/Wan2.2-T2V-A14B-Diffusers --num-gpus 1 --webui --webui-port 2333 +``` + +### Launch Image-to-Image Service +```bash +sglang serve --model-path Qwen/Qwen-Image-Edit-2511 --num-gpus 1 --webui --webui-port 2333 +``` + +### Launch Image-to-Video Service +```bash +sglang serve --model-path Wan-AI/Wan2.2-TI2V-5B-Diffusers --num-gpus 1 --webui --webui-port 2333 +``` + +## Port Forwarding + +If the WebUI service is running on a remote machine, use **SSH port forwarding** to securely access it from +your local machine. + +In most cases, your IDE (such as VS Code or Cursor) can handle this automatically; check its remote development +or port forwarding features. Otherwise, run the following command manually, replacing `${WEBUI_PORT}` with the value passed to `--webui-port` (`2333` in the examples above): + +```bash +ssh -L ${WEBUI_PORT}:localhost:${WEBUI_PORT} user_name@machine_name +``` + +Learn more about port forwarding: [Port Forwarding](https://en.wikipedia.org/wiki/Port_forwarding).
def run_sgl_diffusion_webui(server_args: ServerArgs):
    """
    Launch the Gradio WebUI for an SGLang Diffusion model and block until it exits.

    The model's task name is looked up on the hub (ModelScope when
    ``SGLANG_USE_MODELSCOPE`` is set, Hugging Face otherwise). It decides whether
    the UI shows the video-only controls and the video output, or the image output.

    Args:
        server_args: Server configuration. ``model_path`` and ``webui_port`` are
            read here; the whole object is also passed to the scheduler client
            and to request preparation.

    Raises:
        ValueError: If ``model_path`` has fewer than two path components and is
            not a valid repo id, or if the model's task name is not one of the
            supported image/video tasks.
    """
    # import gradio in function to avoid CI crash

    import gradio as gr

    def resolve_model_repo_id(model_path: str) -> str:
        """Return ``model_path`` if it is a valid repo id, else ``<parent>/<name>`` of the path."""
        from pathlib import Path

        from huggingface_hub.utils import HFValidationError, validate_repo_id

        try:
            validate_repo_id(model_path)
            return model_path
        except HFValidationError:
            pass

        # Not usable as a repo id: treat it as a filesystem path and build the
        # candidate id from its last two components.
        p = Path(model_path).expanduser()
        parts = p.parts

        if len(parts) < 2:
            raise ValueError(f"Invalid model_path: {model_path}")

        candidate = f"{parts[-2]}/{parts[-1]}"
        validate_repo_id(candidate)  # let it raise if invalid
        return candidate

    repo_id = resolve_model_repo_id(server_args.model_path)
    if envs.SGLANG_USE_MODELSCOPE.get():
        from modelscope.hub.api import HubApi

        api = HubApi()
        model_info_obj = api.model_info(repo_id)
        # Strip the "-synthesis" suffix so the name can be matched against the
        # task names checked below.
        task_name = model_info_obj.tasks[0]["Name"].replace("-synthesis", "")
    else:
        from huggingface_hub import model_info

        task_name = model_info(repo_id).pipeline_tag

    # init client
    sync_scheduler_client.initialize(server_args)

    if task_name in ("text-to-video", "image-to-video", "video-to-video"):
        task_type = "video"
    elif task_name in ["text-to-image", "image-to-image"]:
        task_type = "image"
    else:
        raise ValueError(
            f"The task name {task_name} of model {server_args.model_path} is not a valid task name. Please check the model path."
        )
    video_visible_only = task_type == "video"
    image_visible_only = task_type == "image"

    # server_args will be reused in gradio_generate function
    def gradio_generate(
        prompt,
        negative_prompt,
        reference_image_paths_str,
        seed,
        num_frames,
        frames_per_second,
        width,
        height,
        num_inference_steps,
        guidance_scale,
        enable_teacache,
    ):
        """
        Callback of the "Generate" button.

        NOTE: The inputs and outputs of a function called by a gradio button must
        be gradio components, so ``server_args`` is taken from the enclosing scope
        instead of being passed in as an argument.

        Returns:
            ``(image, None)`` for image tasks or ``(None, video_path)`` for video
            tasks, matching ``outputs=[image_out, video_out]``.
        """
        if reference_image_paths_str:
            # U+FF0C is the full-width comma; normalise it to an ASCII comma so
            # the split below works, and tell the user about it.
            fullwidth_comma = "\uff0c"
            if fullwidth_comma in reference_image_paths_str:
                logger.warning(
                    f"Warning: please use English comma to separate the reference image paths, and the reference image paths is: {reference_image_paths_str}"
                )
                reference_image_paths_str = reference_image_paths_str.replace(
                    fullwidth_comma, ","
                )
            # Drop empty entries (trailing or doubled commas, whitespace only).
            image_path = [
                path.strip()
                for path in reference_image_paths_str.split(",")
                if path.strip()
            ] or None
        else:
            image_path = None

        sampling_params_kwargs = dict(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image_path=image_path,
            seed=seed,
            num_frames=num_frames,
            fps=frames_per_second,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            enable_teacache=enable_teacache,
            return_file_paths_only=False,
        )
        sampling_params = SamplingParams.from_user_sampling_params_args(
            server_args.model_path,
            server_args=server_args,
            **sampling_params_kwargs,
        )
        batch = prepare_request(
            server_args=server_args,
            sampling_params=sampling_params,
        )
        result = sync_scheduler_client.forward([batch])
        save_file_path = str(os.path.join(batch.output_path, batch.output_file_name))
        if result.output is None:
            sampling_params_str = "\n".join(
                [f"{key}: {value}" for key, value in sampling_params_kwargs.items()]
            )
            no_output_msg = f"No output is generated by client, and their sampling params is: {sampling_params_str}"

            if batch.data_type == DataType.VIDEO:
                # A video may already be on disk even though no output came back;
                # show that file instead of failing.
                if os.path.exists(save_file_path):
                    logger.warning(no_output_msg)
                    return None, save_file_path
                else:
                    no_output_msg += f"\nAnd the expected output file was not found at: {save_file_path}"
                    raise ValueError(no_output_msg)
            else:
                raise ValueError(no_output_msg)

        frames = post_process_sample(
            result.output[0],
            batch.data_type,
            batch.fps,
            batch.save_output,
            save_file_path,
        )
        if batch.data_type == DataType.VIDEO:
            # gradio video need video path to show video
            return None, save_file_path
        else:
            return frames[0], None

    with gr.Blocks() as demo:
        gr.Markdown("# 🚀 SGLang Diffusion Application")
        with gr.Row():
            launched_model_box = gr.Textbox(label="Model", value=server_args.model_path)
            task_name_box = gr.Textbox(label="Task name", value=task_name)

        with gr.Row():
            with gr.Column(scale=4):
                prompt = gr.Textbox(label="Prompt", value="A curious raccoon")
                negative_prompt = gr.Textbox(
                    label="Negative_prompt",
                    value="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards",
                )
            with gr.Column(scale=1):
                seed = gr.Number(label="seed", precision=0, value=1234)
                run_btn = gr.Button("Generate", variant="primary", size="lg")

        with gr.Row():
            with gr.Column():
                width = gr.Number(label="width", precision=0, value=720)
                height = gr.Number(label="height", precision=0, value=480)
                num_inference_steps = gr.Slider(
                    minimum=0, maximum=50, value=20, step=1, label="num_inference_steps"
                )
                guidance_scale = gr.Slider(
                    minimum=0.0, maximum=10, value=5, step=0.01, label="guidance_scale"
                )
                # The two frame controls are only shown for video tasks.
                num_frames = gr.Slider(
                    minimum=1,
                    maximum=181,
                    value=81,
                    step=1,
                    label="num_frames",
                    visible=video_visible_only,
                )
                frames_per_second = gr.Slider(
                    minimum=4,
                    maximum=60,
                    value=16,
                    step=1,
                    label="frames_per_second",
                    visible=video_visible_only,
                )
                reference_image_paths_str = gr.Textbox(
                    label="reference images",
                    placeholder="Examples: 'image1.png, image2.png' or 'https://example.com/image1.png, https://example.com/image2.png'",
                )
                enable_teacache = gr.Checkbox(label="enable_teacache", value=False)

            with gr.Column():
                image_out = gr.Image(
                    label="Generated Image", visible=image_visible_only, format="png"
                )
                video_out = gr.Video(
                    label="Generated Video", visible=video_visible_only
                )

        # The order of `inputs` must match the parameter order of gradio_generate.
        run_btn.click(
            fn=gradio_generate,
            inputs=[
                prompt,
                negative_prompt,
                reference_image_paths_str,
                seed,
                num_frames,
                frames_per_second,
                width,
                height,
                num_inference_steps,
                guidance_scale,
                enable_teacache,
            ],
            outputs=[image_out, video_out],
        )

    # prevent_thread_lock=True makes launch() return so the banner can be printed;
    # block_thread() at the end keeps the process alive afterwards.
    _, local_url, _ = demo.launch(
        server_port=server_args.webui_port,
        quiet=True,
        prevent_thread_lock=True,
        show_error=True,
    )

    # print banner
    delimiter = "=" * 80
    url = local_url or f"http://localhost:{server_args.webui_port}"
    print(f"""
{delimiter}
\033[1mSGLang Diffusion WebUI available at:\033[0m \033[1;4;92m{url}\033[0m
{delimiter}
""")

    demo.block_thread()
@dataclass
class BenchArgs:
    """Benchmark configuration for multimodal generation.

    The field defaults below are the single source of truth: ``add_cli_args``
    takes its argparse defaults from them, so the CLI and a directly constructed
    ``BenchArgs()`` cannot drift apart.
    """

    # Diffusion Model Configuration
    num_inference_steps: int = 20
    guidance_scale: float = 7.5
    seed: int = 42
    disable_safety_checker: bool = False

    # Output Configuration
    width: int = 32
    height: int = 32
    num_frames: int = 1
    fps: int = 24

    # Dataset & Benchmark
    dataset: str = "random"
    dataset_path: str = ""
    task_name: str = "unknown"
    num_prompts: int = 10
    batch_size: int = 1

    # Benchmark Execution
    skip_warmup: bool = False
    output_file: str = ""
    disable_tqdm: bool = False

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Add benchmark-specific CLI arguments.

        Every option's ``dest`` equals the name of the matching dataclass field,
        which is what ``from_cli_args`` relies on.
        """
        # Diffusion Model Configuration
        parser.add_argument(
            "--num-inference-steps",
            type=int,
            default=BenchArgs.num_inference_steps,
            help="Number of denoising steps",
        )
        parser.add_argument(
            "--guidance-scale",
            type=float,
            default=BenchArgs.guidance_scale,
            help="Classifier-free guidance scale",
        )
        parser.add_argument(
            "--seed", type=int, default=BenchArgs.seed, help="Random seed"
        )
        parser.add_argument(
            "--disable-safety-checker",
            action="store_true",
            help="Disable NSFW detection",
        )

        # Output Configuration
        parser.add_argument(
            "--width", type=int, default=BenchArgs.width, help="Image/video width"
        )
        parser.add_argument(
            "--height", type=int, default=BenchArgs.height, help="Image/video height"
        )
        parser.add_argument(
            "--num-frames",
            type=int,
            default=BenchArgs.num_frames,
            help="Number of frames for video",
        )
        parser.add_argument(
            "--fps", type=int, default=BenchArgs.fps, help="FPS for video"
        )

        # Dataset & Benchmark
        parser.add_argument(
            "--dataset",
            type=str,
            default=BenchArgs.dataset,
            choices=["vbench", "random"],
            help="Dataset to use",
        )
        parser.add_argument(
            "--dataset-path",
            type=str,
            default=BenchArgs.dataset_path,
            help="Path to dataset (prompts file or image directory)",
        )
        parser.add_argument(
            "--task-name",
            type=str,
            default=BenchArgs.task_name,
            help="Task name for benchmark identification",
        )
        parser.add_argument(
            "--num-prompts",
            type=int,
            default=BenchArgs.num_prompts,
            help="Total number of prompts to benchmark",
        )
        parser.add_argument(
            "--batch-size",
            type=int,
            default=BenchArgs.batch_size,
            help="Batch size per generation call (currently only bs=1 is supported)",
        )

        # Benchmark Execution
        parser.add_argument(
            "--skip-warmup", action="store_true", help="Skip warmup batch"
        )
        parser.add_argument(
            "--output-file",
            type=str,
            default=BenchArgs.output_file,
            help="Output JSON file for results (append mode)",
        )
        parser.add_argument(
            "--disable-tqdm",
            action="store_true",
            help="Disable progress bar",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Create BenchArgs from parsed CLI arguments.

        Only attributes named after dataclass fields are read, so ``args`` may
        carry unrelated options as well.
        """
        return cls(
            **{field.name: getattr(args, field.name) for field in dataclasses.fields(cls)}
        )
generation-specific throughput metrics.""" + successful = [o for o in outputs if o.success] + num_success = sum(o.num_samples for o in successful) + total_frames = sum(o.total_frames for o in successful) + peak_memory = max((o.peak_memory_mb for o in outputs), default=0) + + width, height, frames = resolution + pixels_per_sample = width * height * frames + total_pixels = num_success * pixels_per_sample + + metrics = { + "num_requests": num_requests, + "successful_requests": num_success, + "failed_requests": num_requests - num_success, + "total_duration_seconds": total_duration, + "total_frames_generated": total_frames, + "total_pixels_generated": total_pixels, + "images_per_second": num_success / total_duration if total_duration > 0 else 0, + "frames_per_second": total_frames / total_duration if total_duration > 0 else 0, + "megapixels_per_second": ( + total_pixels / (total_duration * 1e6) if total_duration > 0 else 0 + ), + "requests_per_second": ( + num_success / total_duration if total_duration > 0 else 0 + ), + "latency_per_request_seconds": ( + total_duration / num_success if num_success > 0 else 0 + ), + "peak_memory_mb": peak_memory, + } + + return metrics + + +def throughput_test( + server_args: ServerArgs, + bench_args: BenchArgs, +) -> Dict[str, Any]: + """Main throughput benchmark function.""" + configure_logger(server_args=server_args) + logger.info("Starting offline throughput benchmark...") + + engine = initialize_engine(server_args) + + logger.info(f"Loading {bench_args.dataset} dataset...") + if bench_args.dataset == "vbench": + bench_args.task_name = engine.server_args.pipeline_config.task_type + dataset = VBenchDataset(bench_args) + elif bench_args.dataset == "random": + dataset = RandomDataset(bench_args) + else: + raise ValueError(f"Unknown dataset: {bench_args.dataset}") + + sampling_params = { + "guidance_scale": bench_args.guidance_scale, + "num_inference_steps": bench_args.num_inference_steps, + "height": bench_args.height, + "width": 
def display_results(
    metrics: Dict[str, Any],
    bench_args: BenchArgs,
    model_path: str,
):
    """Print the offline benchmark summary to stdout.

    The report is a banner followed by four groups of ``label: value`` rows
    (run configuration, request counts, generated volume, throughput), each
    group closed by a divider.

    Args:
        metrics: Result dict; the keys read here are listed in the tables below.
        bench_args: Supplies dataset name, resolution and step count.
        model_path: Model identifier shown in the first row.
    """
    banner = " Offline Throughput Benchmark Result "
    print("\n{s:{c}^{n}}".format(s=banner, n=110, c="="))

    def metric(key, scale=None):
        # Defer the lookup so rows are printed in order and a missing key
        # only fails once the report reaches that row.
        if scale is None:
            return lambda: metrics[key]
        return lambda: metrics[key] / scale

    groups = (
        (
            ("Model:", lambda: model_path),
            ("Dataset:", lambda: bench_args.dataset),
            (
                "Resolution:",
                lambda: f"{bench_args.width}x{bench_args.height}x{bench_args.num_frames}",
            ),
            ("Num Inference Steps:", lambda: bench_args.num_inference_steps),
        ),
        (
            ("Total Requests:", metric("num_requests")),
            ("Successful Requests:", metric("successful_requests")),
            ("Failed Requests:", metric("failed_requests")),
            ("Total Duration (seconds):", metric("total_duration_seconds")),
        ),
        (
            ("Frames Generated:", metric("total_frames_generated")),
            ("Megapixels Generated:", metric("total_pixels_generated", 1e6)),
        ),
        (
            ("Frame Throughput (frames/sec):", metric("frames_per_second")),
            ("MP Throughput (MP/sec):", metric("megapixels_per_second")),
            ("Requests Per Second:", metric("requests_per_second")),
            ("Latency Per Request (sec):", metric("latency_per_request_seconds")),
            ("Peak Memory (MB):", metric("peak_memory_mb")),
        ),
    )

    closing_index = len(groups) - 1
    for index, rows in enumerate(groups):
        for label, read_value in rows:
            print_value_formatted(label, read_value())
        # Inner groups end with a thin rule; the report ends with a full-width one.
        if index == closing_index:
            print_divider(110, "=")
        else:
            print_divider(75)
def main():
    """Parse the command line, run the offline throughput benchmark, return its metrics."""
    cli = argparse.ArgumentParser(
        description="Offline throughput benchmark for multimodal generation models"
    )

    # Server options are registered first, then the benchmark-specific ones.
    for arg_group in (ServerArgs, BenchArgs):
        arg_group.add_cli_args(cli)

    parsed = cli.parse_args()

    server_args = ServerArgs.from_cli_args(parsed)
    bench_args = BenchArgs.from_cli_args(parsed)

    # Publish the server configuration before the benchmark starts.
    set_global_server_args(server_args)

    return throughput_test(server_args, bench_args)
async def async_request_image_sglang(
    input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one image request and record its outcome.

    When ``input.image_paths`` is non-empty the request is sent as
    multipart/form-data with the image files attached (image edit);
    otherwise it is sent as a JSON body (text-to-image).

    Args:
        input: Request description (URL, model, prompt, optional size,
            extra body fields, optional image paths).
        session: Open aiohttp session used for the POST.
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A ``RequestFuncOutput`` with ``success``, ``error``, ``response_body``,
        ``peak_memory_mb`` (when the server reports it) and ``latency`` set.
        If an image path does not exist the function returns early with
        ``success=False`` and no request is sent.
    """
    output = RequestFuncOutput()
    output.start_time = time.perf_counter()

    # File objects opened for the upload. They are closed in ``finally`` so a
    # failed or aborted request cannot leak descriptors; closing an already
    # closed file is a no-op, so this is safe even if the HTTP client closes
    # them itself.
    upload_handles = []
    try:
        if input.image_paths:
            # Validate every path before opening anything, so the early
            # return below never leaves a half-built upload behind.
            missing = next(
                (path for path in input.image_paths if not os.path.exists(path)),
                None,
            )
            if missing is not None:
                output.error = f"Image file not found: {missing}"
                output.success = False
                if pbar:
                    pbar.update(1)
                return output

            # Use multipart/form-data for image edits
            data = aiohttp.FormData()
            data.add_field("model", input.model)
            data.add_field("prompt", input.prompt)
            data.add_field("response_format", "b64_json")

            if input.width and input.height:
                data.add_field("size", f"{input.width}x{input.height}")

            # Merge extra parameters (form fields must be strings)
            for key, value in input.extra_body.items():
                data.add_field(key, str(value))

            for img_path in input.image_paths:
                handle = open(img_path, "rb")
                upload_handles.append(handle)
                data.add_field(
                    "image",
                    handle,
                    filename=os.path.basename(img_path),
                    content_type="application/octet-stream",
                )
            request_kwargs = {"data": data}
        else:
            # Use JSON for text-to-image generation
            payload = {
                "model": input.model,
                "prompt": input.prompt,
                "n": 1,
                "response_format": "b64_json",
            }

            if input.width and input.height:
                payload["size"] = f"{input.width}x{input.height}"

            # Merge extra parameters
            payload.update(input.extra_body)
            request_kwargs = {"json": payload}

        try:
            async with session.post(input.api_url, **request_kwargs) as response:
                if response.status == 200:
                    resp_json = await response.json()
                    output.response_body = resp_json
                    output.success = True
                    if "peak_memory_mb" in resp_json:
                        output.peak_memory_mb = resp_json["peak_memory_mb"]
                else:
                    output.error = f"HTTP {response.status}: {await response.text()}"
                    output.success = False
        except Exception as e:
            # Best effort: a failed request is recorded, not raised, so the
            # rest of the benchmark keeps running.
            output.error = str(e)
            output.success = False
    finally:
        for handle in upload_handles:
            handle.close()

    output.latency = time.perf_counter() - output.start_time

    if pbar:
        pbar.update(1)
    return output
Submit Job + job_id = None + # Check if we need to upload images (Multipart) or just send JSON + if input.image_paths and len(input.image_paths) > 0: + # Use multipart/form-data + data = aiohttp.FormData() + data.add_field("model", input.model) + data.add_field("prompt", input.prompt) + + if input.width and input.height: + data.add_field("size", f"{input.width}x{input.height}") + + # Add extra body fields to form data if possible, or assume simple key-values + # Note: Nested dicts in extra_body might need JSON serialization if API expects it stringified + if input.extra_body: + data.add_field("extra_body", json.dumps(input.extra_body)) + + # Explicitly add fps/num_frames if they are not in extra_body (bench_serving logic overrides) + if input.num_frames: + data.add_field("num_frames", str(input.num_frames)) + if input.fps: + data.add_field("fps", str(input.fps)) + + # Add image file + # Currently only support single image upload as 'input_reference' per API spec + img_path = input.image_paths[0] + if os.path.exists(img_path): + data.add_field( + "input_reference", + open(img_path, "rb"), + filename=os.path.basename(img_path), + content_type="application/octet-stream", + ) + else: + output.error = f"Image file not found: {img_path}" + output.success = False + if pbar: + pbar.update(1) + return output + + try: + async with session.post(input.api_url, data=data) as response: + if response.status == 200: + resp_json = await response.json() + job_id = resp_json.get("id") + else: + output.error = ( + f"Submit failed HTTP {response.status}: {await response.text()}" + ) + output.success = False + if pbar: + pbar.update(1) + return output + except Exception as e: + output.error = f"Submit exception: {str(e)}" + output.success = False + if pbar: + pbar.update(1) + return output + + else: + # Use JSON + payload: Dict[str, Any] = { + "model": input.model, + "prompt": input.prompt, + } + if input.width and input.height: + payload["size"] = f"{input.width}x{input.height}" + if 
def calculate_metrics(outputs: List[RequestFuncOutput], total_duration: float):
    """Summarise request outcomes for the serving benchmark.

    Latency statistics are computed over successful requests only. Peak-memory
    statistics use successful requests that reported a positive value. Every
    statistic falls back to 0 when it has no samples, and throughput is 0 when
    ``total_duration`` is not positive.

    Args:
        outputs: One result object per request sent.
        total_duration: Wall-clock seconds for the whole run.

    Returns:
        Dict with the duration, request counts, throughput, latency
        mean/median/p99/p50 and peak-memory max/mean/median.
    """
    latencies = []
    peak_memories = []
    failed_count = 0
    for result in outputs:
        if not result.success:
            failed_count += 1
            continue
        latencies.append(result.latency)
        if result.peak_memory_mb > 0:
            peak_memories.append(result.peak_memory_mb)

    def stat(reducer, samples, *extra):
        # An empty sample list yields 0 and the reducer is never called.
        if not samples:
            return 0
        return reducer(samples, *extra)

    completed_count = len(latencies)
    if total_duration > 0:
        qps = completed_count / total_duration
    else:
        qps = 0

    return {
        "duration": total_duration,
        "completed_requests": completed_count,
        "failed_requests": failed_count,
        "throughput_qps": qps,
        "latency_mean": stat(np.mean, latencies),
        "latency_median": stat(np.median, latencies),
        "latency_p99": stat(np.percentile, latencies, 99),
        "latency_p50": stat(np.percentile, latencies, 50),
        "peak_memory_mb_max": stat(max, peak_memories),
        "peak_memory_mb_mean": stat(np.mean, peak_memories),
        "peak_memory_mb_median": stat(np.median, peak_memories),
    }
+ ) + + time.sleep(1) + + +async def benchmark(args): + from huggingface_hub import model_info + + # Construct base_url if not provided + if args.base_url is None: + args.base_url = f"http://{args.host}:{args.port}" + + # Wait for service + wait_for_service(args.base_url) + + # Fetch model info + try: + resp = requests.get(f"{args.base_url}/v1/model_info", timeout=5) + if resp.status_code == 200: + info = resp.json() + if "model_path" in info and info["model_path"]: + args.model = info["model_path"] + logger.info(f"Updated model name from server: {args.model}") + except Exception as e: + logger.info(f"Failed to fetch model info: {e}. Using default: {args.model}") + + valid_tasks = ( + "text-to-video", + "image-to-video", + "video-to-video", + "text-to-image", + "image-to-image", + ) + + # Resolve task_name with priority: args.task > local config > HF pipeline_tag + if args.task: + task_name = args.task + logger.info(f"Using task from --task: {task_name}") + elif os.path.exists(args.model): + config_path = os.path.join(args.model, "config.json") + if os.path.exists(config_path): + with open(config_path, "r") as f: + config = json.load(f) + task_name = config.get("pipeline_tag", "text-to-image") + logger.info(f"Inferred task from local config.json: {task_name}") + else: + task_name = "text-to-image" + logger.info(f"No config.json found, defaulting task to: {task_name}") + else: + task_name = model_info(args.model).pipeline_tag + logger.info(f"Inferred task from HuggingFace pipeline_tag: {task_name}") + + if task_name not in valid_tasks: + raise ValueError( + f"Task '{task_name}' is not a valid multimodal generation task. 
" + f"Use --task to specify one of: {', '.join(valid_tasks)}" + ) + + if task_name in ("text-to-video", "image-to-video", "video-to-video"): + api_url = f"{args.base_url}/v1/videos" + request_func = async_request_video_sglang + else: # text-to-image or image-to-image + api_url = ( + f"{args.base_url}/v1/images/edits" + if task_name == "image-to-image" + else f"{args.base_url}/v1/images/generations" + ) + request_func = async_request_image_sglang + + setattr(args, "task_name", task_name) + + if args.dataset == "vbench": + dataset = VBenchDataset(args, api_url, args.model) + elif args.dataset == "random": + dataset = RandomDataset(args, api_url, args.model) + else: + raise ValueError(f"Unknown dataset: {args.dataset}") + + logger.info(f"Loading requests...") + requests_list = dataset.get_requests() + logger.info(f"Prepared {len(requests_list)} requests from {args.dataset} dataset.") + + # Limit concurrency + if args.max_concurrency is not None: + semaphore = asyncio.Semaphore(args.max_concurrency) + else: + semaphore = None + + async def limited_request_func(req, session, pbar): + if semaphore: + async with semaphore: + return await request_func(req, session, pbar) + else: + return await request_func(req, session, pbar) + + # Run benchmark + pbar = tqdm(total=len(requests_list), disable=args.disable_tqdm) + + async with aiohttp.ClientSession() as session: + start_time = time.perf_counter() + tasks = [] + for req in requests_list: + if args.request_rate != float("inf"): + # Poisson process: inter-arrival times follow exponential distribution + interval = np.random.exponential(1.0 / args.request_rate) + await asyncio.sleep(interval) + + task = asyncio.create_task(limited_request_func(req, session, pbar)) + tasks.append(task) + + outputs = await asyncio.gather(*tasks) + total_duration = time.perf_counter() - start_time + + pbar.close() + + # Calculate metrics + metrics = calculate_metrics(outputs, total_duration) + + print("\n{s:{c}^{n}}".format(s=" Serving Benchmark 
if __name__ == "__main__":
    # Command-line entry point: build the parser, configure logging, then run
    # the async benchmark to completion.
    parser = argparse.ArgumentParser(
        description="Benchmark serving for diffusion models."
    )
    # Deprecated option; its help text now names the flag it belongs to.
    parser.add_argument(
        "--backend",
        type=str,
        default=None,
        help="DEPRECATED: --backend is deprecated and will be ignored. "
        "The task is taken from --task, or inferred from --model when --task is not given.",
    )
    parser.add_argument(
        "--base-url",
        type=str,
        default=None,
        help="Base URL of the server (e.g., http://localhost:30000). Overrides host/port.",
    )
    parser.add_argument("--host", type=str, default="localhost", help="Server host.")
    parser.add_argument("--port", type=int, default=30000, help="Server port.")
    parser.add_argument("--model", type=str, default="default", help="Model name.")
    parser.add_argument(
        "--dataset",
        type=str,
        default="vbench",
        choices=["vbench", "random"],
        help="Dataset to use.",
    )
    # benchmark() resolves the task as: --task > local config.json > HuggingFace
    # pipeline_tag, so the help text states that order.
    parser.add_argument(
        "--task",
        type=str,
        choices=[
            "text-to-video",
            "image-to-video",
            "text-to-image",
            "image-to-image",
            "video-to-video",
        ],
        default=None,
        help="Generation task. When given, it takes priority. Otherwise the task is inferred "
        "from the model's local config.json or its huggingface pipeline_tag.",
    )
    parser.add_argument(
        "--dataset-path",
        type=str,
        default=None,
        help="Path to local dataset file (optional).",
    )
    parser.add_argument(
        "--num-prompts", type=int, default=10, help="Number of prompts to benchmark."
    )
    parser.add_argument(
        "--max-concurrency",
        type=int,
        default=1,
        help="Maximum number of concurrent requests, default to `1`. This can be used "
        "to help simulate an environment where a higher level component "
        "is enforcing a maximum number of concurrent requests. While the "
        "--request-rate argument controls the rate at which requests are "
        "initiated, this argument will control how many are actually allowed "
        "to execute at a time. This means that when used in combination, the "
        "actual request rate may be lower than specified with --request-rate, "
        "if the server is not processing requests fast enough to keep up.",
    )
    parser.add_argument(
        "--request-rate",
        type=float,
        default=float("inf"),
        help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
        "Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.",
    )
    parser.add_argument("--width", type=int, default=None, help="Image/Video width.")
    parser.add_argument("--height", type=int, default=None, help="Image/Video height.")
    parser.add_argument(
        "--num-frames", type=int, default=None, help="Number of frames (for video)."
    )
    parser.add_argument("--fps", type=int, default=None, help="FPS (for video).")
    parser.add_argument(
        "--output-file", type=str, default=None, help="Output JSON file for metrics."
    )
    parser.add_argument(
        "--disable-tqdm", action="store_true", help="Disable progress bar."
    )
    parser.add_argument(
        "--log-level",
        type=str,
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Log level.",
    )

    args = parser.parse_args()

    configure_logger(args)

    asyncio.run(benchmark(args))
def get_perf_status_emoji(
    baseline: float,
    new: float,
    rel_tol: float = 0.1,
    min_abs_tol: float = 120.0,
) -> str:
    """
    Classify ``new`` against ``baseline`` as slower, faster or unchanged.

    The tolerance band around the baseline is delegated to
    ``calculate_upper_bound`` / ``calculate_lower_bound``; a value above the
    upper bound is reported as slower, one below the lower bound as faster.

    Returns:
        "🔴" when slower, "🟢" when faster, "⚪️" when inside the band.
    """
    slow_threshold = calculate_upper_bound(baseline, rel_tol, min_abs_tol)
    fast_threshold = calculate_lower_bound(baseline, rel_tol, min_abs_tol)

    if new > slow_threshold:
        return "🔴"
    if new < fast_threshold:
        return "🟢"
    return "⚪️"


def consolidate_steps(
    steps_list: List[Dict[str, Any]],
) -> Tuple[Dict[str, float], List[str], Dict[str, int]]:
    """
    Sum step durations per stage, folding every ``denoising_step_<n>`` entry
    into a single "Denoising Loop" stage.

    Returns:
        - aggregated_durations: {name: duration_ms}
        - ordered_names: stage names in order of first appearance
        - counts: {name: number_of_steps_folded_into_it}
    """
    # "denoising_step_0", "denoising_step_1", ... all map to one group.
    denoise_step = re.compile(r"^denoising_step_(\d+)$")

    totals = {}
    tally = {}
    first_seen = []

    for entry in steps_list:
        raw_name = entry.get("name", "unknown")
        elapsed = entry.get("duration_ms", 0.0)

        stage = "Denoising Loop" if denoise_step.match(raw_name) else raw_name

        if stage not in totals:
            # First occurrence fixes the stage's position in the report.
            totals[stage] = 0.0
            tally[stage] = 0
            first_seen.append(stage)

        totals[stage] += elapsed
        tally[stage] += 1

    return totals, first_seen, tally
_load_benchmark_file(file_path: str) -> Dict[str, Any]: + """Loads a benchmark JSON file.""" + with open(file_path, "r", encoding="utf-8") as f: + return json.load(f) + + +def _get_status_emoji_from_diff_percent(diff_pct): + if diff_pct < -2.0: + return "✅" + elif diff_pct > 2.0: + return "❌" + else: + return "⚪️" + + +def _print_single_comparison_report( + others_data, base_e2e, combined_order, base_durations, others_processed, base_counts +): + new_data = others_data[0] + new_e2e = new_data.get("total_duration_ms", 0) + diff_ms, diff_pct = calculate_diff(base_e2e, new_e2e) + status = _get_status_emoji_from_diff_percent(diff_pct) + + print("#### 1. High-level Summary") + print("| Metric | Baseline | New | Diff | Status |") + print("| :--- | :--- | :--- | :--- | :--- |") + print( + f"| **E2E Latency** | {base_e2e:.2f} ms | {new_e2e:.2f} ms | **{diff_ms:+.2f} ms ({diff_pct:+.1f}%)** | {status} |" + ) + print( + f"| **Throughput** | {1000 / base_e2e if base_e2e else 0:.2f} req/s | {1000 / new_e2e if new_e2e else 0:.2f} req/s | - | - |" + ) + print("\n") + + print("#### 2. 
Stage Breakdown") + print("| Stage Name | Baseline (ms) | New (ms) | Diff (ms) | Diff (%) | Status |") + print("| :--- | :--- | :--- | :--- | :--- | :--- |") + + new_durations, _, new_counts = others_processed[0] + + for stage in combined_order: + b_val = base_durations.get(stage, 0.0) + n_val = new_durations.get(stage, 0.0) + b_count = base_counts.get(stage, 1) + n_count = new_counts.get(stage, 1) + + s_diff, s_pct = calculate_diff(b_val, n_val) + + count_str = "" + if stage == "Denoising Loop": + count_str = ( + f" ({n_count} steps)" + if n_count == b_count + else f" ({b_count}->{n_count} steps)" + ) + + status_emoji = get_perf_status_emoji(b_val, n_val) + print( + f"| {stage}{count_str} | {b_val:.2f} | {n_val:.2f} | {s_diff:+.2f} | {s_pct:+.1f}% | {status_emoji} |" + ) + + +def _print_multi_comparison_report( + base_e2e, + others_data, + other_labels, + combined_order, + base_durations, + others_processed, +): + print("#### 1. High-level Summary") + header = "| Metric | Baseline | " + " | ".join(other_labels) + " |" + sep = "| :--- | :--- | " + " | ".join([":---"] * len(other_labels)) + " |" + print(header) + print(sep) + + # E2E Row + row_e2e = f"| **E2E Latency** | {base_e2e:.2f} ms |" + for i, d in enumerate(others_data): + val = d.get("total_duration_ms", 0) + diff_ms, diff_pct = calculate_diff(base_e2e, val) + + status = _get_status_emoji_from_diff_percent(diff_pct) + + row_e2e += f" {val:.2f} ms ({diff_pct:+.1f}%) {status} |" + print(row_e2e) + print("\n") + + print("#### 2. Stage Breakdown") + # Header: Stage | Baseline | Label1 | Label2 ... 
def compare_benchmarks(file_paths: List[str], output_format: str = "markdown"):
    """
    Compares benchmark JSON files and prints a report.
    First file is baseline, others will be compared against it.

    Args:
        file_paths: Paths to benchmark JSON files; at least two are required.
        output_format: Only "markdown" produces output; any other value
            makes the function return without printing a report.

    Problems (fewer than two files, unreadable or invalid JSON) are reported
    with a printed message and an early return rather than an exception.
    """
    if len(file_paths) < 2:
        print("Error: Need at least 2 files to compare.")
        return

    try:
        data_list = [_load_benchmark_file(f) for f in file_paths]
    except Exception as e:
        print(f"Error loading benchmark files: {e}")
        return

    base_data = data_list[0]
    others_data = data_list[1:]

    # Use filenames as labels if multiple comparisons, else just "New"
    other_labels = [os.path.basename(p) for p in file_paths[1:]]

    base_e2e = base_data.get("total_duration_ms", 0)

    base_durations, base_order, base_counts = consolidate_steps(
        base_data.get("steps", [])
    )

    others_processed = []
    for d in others_data:
        dur, order, counts = consolidate_steps(d.get("steps", []))
        others_processed.append((dur, order, counts))

    combined_order = []
    # Collect all unique stages maintaining order from newest to baseline
    for _, order, _ in reversed(others_processed):
        for name in order:
            if name not in combined_order:
                combined_order.append(name)
    for name in base_order:
        if name not in combined_order:
            combined_order.append(name)

    if output_format == "markdown":
        print("### Performance Comparison Report\n")

        if len(others_data) == 1:
            _print_single_comparison_report(
                others_data,
                base_e2e,
                combined_order,
                base_durations,
                others_processed,
                base_counts,
            )
        else:
            _print_multi_comparison_report(
                base_e2e,
                others_data,
                other_labels,
                combined_order,
                base_durations,
                others_processed,
            )

        print("\n")
        # Metadata, wrapped in a collapsible <details> section.
        # NOTE(review): the tag literals were lost from this copy of the file
        # (the strings were left unterminated); <details>/<summary> is the
        # reconstruction — confirm against the upstream source.
        print("<details>")
        print("<summary>Metadata</summary>\n")
        print(f"- Baseline Commit: `{base_data.get('commit_hash', 'N/A')}`")
        for i, d in enumerate(others_data):
            label = "New" if len(others_data) == 1 else other_labels[i]
            print(f"- {label} Commit: `{d.get('commit_hash', 'N/A')}`")
        print(f"- Timestamp: {datetime.now().isoformat()}")
        print("</details>")
pass + + @abstractmethod + def __getitem__(self, idx: int) -> RequestFuncInput: + pass + + def get_requests(self) -> List[RequestFuncInput]: + return [self[i] for i in range(len(self))] + + +class VBenchDataset(BaseDataset): + """ + Dataset loader for VBench prompts. + Supports t2v, i2v. + """ + + T2V_PROMPT_URL = "https://raw.githubusercontent.com/Vchitect/VBench/master/prompts/prompts_per_dimension/subject_consistency.txt" + I2V_DOWNLOAD_SCRIPT_URL = "https://raw.githubusercontent.com/Vchitect/VBench/master/vbench2_beta_i2v/download_data.sh" + + def __init__(self, args, api_url: str = "", model: str = ""): + super().__init__(args, api_url, model) + self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "sglang") + self.items = self._load_data() + + def _load_data(self) -> List[Dict[str, Any]]: + if self.args.task_name in ("text-to-video", "text-to-image", "video-to-video"): + return self._load_t2v_prompts() + elif self.args.task_name in ("image-to-video", "image-to-image"): + return self._load_i2v_data() + else: + raise ValueError( + f"Illegal task name is found in VBenchDataset {self.args.task_name}" + ) + + def _download_file(self, url: str, dest_path: str) -> None: + """Download a file from URL to destination path.""" + os.makedirs(os.path.dirname(dest_path), exist_ok=True) + resp = requests.get(url) + resp.raise_for_status() + with open(dest_path, "w") as f: + f.write(resp.text) + + def _load_t2v_prompts(self) -> List[Dict[str, Any]]: + path = self.args.dataset_path + + if not path: + path = os.path.join(self.cache_dir, "vbench_subject_consistency.txt") + if not os.path.exists(path): + logger.info(f"Downloading VBench T2V prompts to {path}...") + try: + self._download_file(self.T2V_PROMPT_URL, path) + except Exception as e: + logger.info(f"Failed to download VBench prompts: {e}") + return [{"prompt": "A cat sitting on a bench"}] * 50 + + prompts = [] + with open(path, "r") as f: + for line in f: + line = line.strip() + if line: + 
prompts.append({"prompt": line}) + + return self._resize_data(prompts) + + def _auto_download_i2v_dataset(self) -> Optional[str]: + """Auto-download VBench I2V dataset and return the dataset directory.""" + vbench_i2v_dir = os.path.join(self.cache_dir, "vbench_i2v", "vbench2_beta_i2v") + info_json_path = os.path.join(vbench_i2v_dir, "data", "i2v-bench-info.json") + crop_dir = os.path.join(vbench_i2v_dir, "data", "crop") + origin_dir = os.path.join(vbench_i2v_dir, "data", "origin") + + if ( + os.path.exists(info_json_path) + and is_dir_not_empty(crop_dir) + and is_dir_not_empty(origin_dir) + ): + return vbench_i2v_dir + + logger.info(f"Downloading VBench I2V dataset to {vbench_i2v_dir}...") + try: + cache_root = os.path.join(self.cache_dir, "vbench_i2v") + script_path = os.path.join(cache_root, "download_data.sh") + + self._download_file(self.I2V_DOWNLOAD_SCRIPT_URL, script_path) + os.chmod(script_path, 0o755) + + logger.info("Executing download_data.sh (this may take a while)...") + + result = subprocess.run( + ["bash", script_path], + cwd=cache_root, + capture_output=True, + text=True, + ) + if result.returncode != 0: + raise RuntimeError(f"Download script failed: {result.stderr}") + missing_packages = re.findall(r"(\S+): command not found", result.stderr) + if missing_packages: + missing_packages = list(set(missing_packages)) + package_list = ", ".join(f"'{cmd}'" for cmd in missing_packages) + raise RuntimeError( + f"Download script failed because the following commands are not installed: {package_list}.\n" + "Please install them (e.g., on Ubuntu: `sudo apt install ...`) and try again." 
+ ) + logger.info( + f"Successfully downloaded VBench I2V dataset to {vbench_i2v_dir}" + ) + except Exception as e: + logger.info(f"Failed to download VBench I2V dataset: {e}") + logger.info("Please manually download following instructions at:") + logger.info( + "https://github.com/Vchitect/VBench/tree/master/vbench2_beta_i2v#22-download" + ) + return None + + return vbench_i2v_dir if os.path.exists(info_json_path) else None + + def _load_from_i2v_json(self, json_path: str) -> List[Dict[str, Any]]: + """Load I2V data from i2v-bench-info.json format.""" + with open(json_path, "r") as f: + items = json.load(f) + + base_dir = os.path.dirname( + os.path.dirname(json_path) + ) # Go up to vbench2_beta_i2v + origin_dir = os.path.join(base_dir, "data", "origin") + + data = [] + for item in items: + img_path = os.path.join(origin_dir, item.get("file_name", "")) + if os.path.exists(img_path): + data.append({"prompt": item.get("caption", ""), "image_path": img_path}) + else: + logger.warning(f"Image not found: {img_path}") + + logger.info(f"Loaded {len(data)} I2V samples from VBench I2V dataset") + return data + + def _scan_directory_for_images(self, path: str) -> List[Dict[str, Any]]: + """Scan directory for image files.""" + exts = ["*.jpg", "*.jpeg", "*.png", "*.webp"] + files = [] + + for ext in exts: + files.extend(glob.glob(os.path.join(path, ext))) + files.extend(glob.glob(os.path.join(path, ext.upper()))) + + origin_dir = os.path.join(path, "data", "origin") + if os.path.exists(origin_dir): + files.extend(glob.glob(os.path.join(origin_dir, ext))) + files.extend(glob.glob(os.path.join(origin_dir, ext.upper()))) + + return [ + {"prompt": os.path.splitext(os.path.basename(f))[0], "image_path": f} + for f in files + ] + + def _create_dummy_data(self) -> List[Dict[str, Any]]: + """Create dummy data with a placeholder image in cache directory.""" + logger.info("No I2V data found. 
Using dummy placeholders.") + + dummy_image = os.path.join(self.cache_dir, "dummy_image.jpg") + if not os.path.exists(dummy_image): + os.makedirs(self.cache_dir, exist_ok=True) + img = Image.new("RGB", (100, 100), color="red") + img.save(dummy_image) + logger.info(f"Created dummy image at {dummy_image}") + + return [{"prompt": "A moving cat", "image_path": dummy_image}] * 10 + + def _load_i2v_data(self) -> List[Dict[str, Any]]: + """Load I2V data from VBench I2V dataset or user-provided path.""" + path = self.args.dataset_path + if not path: + path = self._auto_download_i2v_dataset() + if not path: + return self._resize_data(self._create_dummy_data()) + + info_json_candidates = [ + os.path.join(path, "data", "i2v-bench-info.json"), + path if path.endswith(".json") else None, + ] + + for json_path in info_json_candidates: + if json_path and os.path.exists(json_path): + try: + return self._resize_data(self._load_from_i2v_json(json_path)) + except Exception as e: + logger.info(f"Failed to load {json_path}: {e}") + + if os.path.isdir(path): + data = self._scan_directory_for_images(path) + if data: + return self._resize_data(data) + + return self._resize_data(self._create_dummy_data()) + + def _resize_data(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Resize data to match num_prompts.""" + if not self.args.num_prompts: + return data + + if len(data) < self.args.num_prompts: + factor = (self.args.num_prompts // len(data)) + 1 + data = data * factor + + return data[: self.args.num_prompts] + + def __len__(self) -> int: + return len(self.items) + + def __getitem__(self, idx: int) -> RequestFuncInput: + item = self.items[idx] + return RequestFuncInput( + prompt=item.get("prompt", ""), + api_url=self.api_url, + model=self.model, + width=self.args.width, + height=self.args.height, + num_frames=self.args.num_frames, + fps=self.args.fps, + image_paths=[item["image_path"]] if "image_path" in item else None, + ) + + +class RandomDataset(BaseDataset): + def 
__init__(self, args, api_url: str = "", model: str = ""): + super().__init__(args, api_url, model) + self.num_prompts = args.num_prompts or 100 + + def __len__(self) -> int: + return self.num_prompts + + def __getitem__(self, idx: int) -> RequestFuncInput: + return RequestFuncInput( + prompt=f"Random prompt {idx} for benchmarking diffusion models", + api_url=self.api_url, + model=self.model, + width=self.args.width, + height=self.args.height, + num_frames=self.args.num_frames, + fps=self.args.fps, + ) diff --git a/sglang/python/sglang/multimodal_gen/configs/__init__.py b/sglang/python/sglang/multimodal_gen/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dfff5f2c4e4b6ed28bdc8a06400012d504abbf0b --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/__init__.py @@ -0,0 +1,3 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# Configs for pipelines, and pipeline modules (in models folder) diff --git a/sglang/python/sglang/multimodal_gen/configs/backend/vmoba/wan_1.3B_77_448_832.json b/sglang/python/sglang/multimodal_gen/configs/backend/vmoba/wan_1.3B_77_448_832.json new file mode 100644 index 0000000000000000000000000000000000000000..1e55b5f2e3d0dc0619eb33a1471bebb78d845662 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/backend/vmoba/wan_1.3B_77_448_832.json @@ -0,0 +1,16 @@ +{ + "temporal_chunk_size": 2, + "temporal_topk": 2, + "spatial_chunk_size": [4, 13], + "spatial_topk": 6, + "st_chunk_size": [4, 4, 13], + "st_topk": 18, + "moba_select_mode": "topk", + "moba_threshold": 0.25, + "moba_threshold_type": "query_head", + "first_full_layer": 0, + "first_full_step": 12, + "temporal_layer": 1, + "spatial_layer": 1, + "st_layer": 1 +} diff --git a/sglang/python/sglang/multimodal_gen/configs/backend/vmoba/wan_1.3B_77_480_832.json b/sglang/python/sglang/multimodal_gen/configs/backend/vmoba/wan_1.3B_77_480_832.json new file mode 100644 index 
0000000000000000000000000000000000000000..ddf66f48e554a1b0cc2e9d26ca0a90c34d370f5a --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/backend/vmoba/wan_1.3B_77_480_832.json @@ -0,0 +1,16 @@ +{ + "temporal_chunk_size": 2, + "temporal_topk": 3, + "spatial_chunk_size": [3, 4], + "spatial_topk": 20, + "st_chunk_size": [4, 6, 4], + "st_topk": 15, + "moba_select_mode": "threshold", + "moba_threshold": 0.25, + "moba_threshold_type": "query_head", + "first_full_layer": 0, + "first_full_step": 12, + "temporal_layer": 1, + "spatial_layer": 1, + "st_layer": 1 +} diff --git a/sglang/python/sglang/multimodal_gen/configs/models/__init__.py b/sglang/python/sglang/multimodal_gen/configs/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..62c0aadfd7cda9ff1584b00681acbb51624432c7 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/models/__init__.py @@ -0,0 +1,8 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +from sglang.multimodal_gen.configs.models.base import ModelConfig +from sglang.multimodal_gen.configs.models.dits.base import DiTConfig +from sglang.multimodal_gen.configs.models.encoders.base import EncoderConfig +from sglang.multimodal_gen.configs.models.vaes.base import VAEConfig + +__all__ = ["ModelConfig", "VAEConfig", "DiTConfig", "EncoderConfig"] diff --git a/sglang/python/sglang/multimodal_gen/configs/models/adapter/base.py b/sglang/python/sglang/multimodal_gen/configs/models/adapter/base.py new file mode 100644 index 0000000000000000000000000000000000000000..abd22f11554b0c8f9acd322fadc3a9659cc2116d --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/models/adapter/base.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field +from typing import Any + +from sglang.multimodal_gen.configs.models.base import ArchConfig, ModelConfig +from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum + + +@dataclass +class 
AdapterArchConfig(ArchConfig): + _fsdp_shard_conditions: list = field(default_factory=list) + _compile_conditions: list = field(default_factory=list) + + # convert weights name from HF-format to SGLang-dit-format + param_names_mapping: dict = field(default_factory=dict) + + # Reverse mapping for saving checkpoints: custom -> hf + reverse_param_names_mapping: dict = field(default_factory=dict) + _supported_attention_backends: set[AttentionBackendEnum] = field( + default_factory=lambda: { + AttentionBackendEnum.SLIDING_TILE_ATTN, + AttentionBackendEnum.SAGE_ATTN, + AttentionBackendEnum.FA, + AttentionBackendEnum.AITER, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.VIDEO_SPARSE_ATTN, + AttentionBackendEnum.VMOBA_ATTN, + AttentionBackendEnum.SAGE_ATTN_3, + } + ) + + hidden_size: int = 0 + num_attention_heads: int = 0 + num_channels_latents: int = 0 + exclude_lora_layers: list[str] = field(default_factory=list) + boundary_ratio: float | None = None + + def __post_init__(self) -> None: + if not self._compile_conditions: + self._compile_conditions = self._fsdp_shard_conditions.copy() + + +@dataclass +class AdapterConfig(ModelConfig): + arch_config: AdapterArchConfig = field(default_factory=AdapterArchConfig) + + # sglang-diffusion Adapter-specific parameters + prefix: str = "" + + @staticmethod + def add_cli_args(parser: Any, prefix: str = "dit-config") -> Any: + """Add CLI arguments for AdapterConfig fields""" + parser.add_argument( + f"--{prefix}.prefix", + type=str, + dest=f"{prefix.replace('-', '_')}.prefix", + default=AdapterConfig.prefix, + help="Prefix for the Adapter", + ) + + return parser diff --git a/sglang/python/sglang/multimodal_gen/configs/models/adapter/ltx_2_connector.py b/sglang/python/sglang/multimodal_gen/configs/models/adapter/ltx_2_connector.py new file mode 100644 index 0000000000000000000000000000000000000000..d8f6c43ef09210eee5c81ee5aaf5e83aa3ded17c --- /dev/null +++ 
b/sglang/python/sglang/multimodal_gen/configs/models/adapter/ltx_2_connector.py @@ -0,0 +1,33 @@ +from dataclasses import dataclass, field + +from sglang.multimodal_gen.configs.models.adapter.base import ( + AdapterArchConfig, + AdapterConfig, +) + + +@dataclass +class LTX2ConnectorArchConfig(AdapterArchConfig): + audio_connector_attention_head_dim: int = 128 + audio_connector_num_attention_heads: int = 30 + audio_connector_num_layers: int = 2 + audio_connector_num_learnable_registers: int = 128 + caption_channels: int = 3840 + causal_temporal_positioning: bool = False + connector_rope_base_seq_len: int = 4096 + rope_double_precision: bool = True + rope_theta: float = 10000.0 + rope_type: str = "split" + text_proj_in_factor: int = 49 + video_connector_attention_head_dim: int = 128 + video_connector_num_attention_heads: int = 30 + video_connector_num_layers: int = 2 + video_connector_num_learnable_registers: int = 128 + + +@dataclass +class LTX2ConnectorConfig(AdapterConfig): + + arch_config: AdapterArchConfig = field(default_factory=LTX2ConnectorArchConfig) + + prefix: str = "LTX2" diff --git a/sglang/python/sglang/multimodal_gen/configs/models/base.py b/sglang/python/sglang/multimodal_gen/configs/models/base.py new file mode 100644 index 0000000000000000000000000000000000000000..6619117be610f26a2e9237bc589cc332f6e8837e --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/models/base.py @@ -0,0 +1,100 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field, fields +from typing import Any, Dict + +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +# 1. ArchConfig contains all fields from diffuser's/transformer's config.json (i.e. all fields related to the architecture of the model) +# 2. ArchConfig should be inherited & overridden by each model arch_config +# 3. 
Any field in ArchConfig is fixed upon initialization, and should be hidden away from users +@dataclass +class ArchConfig: + stacked_params_mapping: list[tuple[str, str, str]] = field( + default_factory=list + ) # mapping from huggingface weight names to custom names + extra_attrs: Dict[str, Any] = field(default_factory=dict) + + def __getattr__(self, name: str): + d = object.__getattribute__(self, "__dict__") + extras = d.get("extra_attrs") + if extras is not None and name in extras: + return extras[name] + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{name}'" + ) + + def __setattr__(self, key, value): + if key in type(self).__dataclass_fields__: + object.__setattr__(self, key, value) + else: + d = object.__getattribute__(self, "__dict__") + extras = d.get("extra_attrs") + if extras is None: + extras = {} + d["extra_attrs"] = extras + extras[key] = value + + +@dataclass +class ModelConfig: + # Every model config parameter can be categorized into either ArchConfig or everything else + # Diffuser/Transformer parameters + arch_config: ArchConfig = field(default_factory=ArchConfig) + + # sglang-diffusion-specific parameters here + # i.e. 
STA, quantization, teacache + + def __getattr__(self, name): + # Only called if 'name' is not found in ModelConfig directly + if hasattr(self.arch_config, name): + return getattr(self.arch_config, name) + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'" + ) + + def __getstate__(self): + # Return a dictionary of attributes to pickle + # Convert to dict and exclude any problematic attributes + state = self.__dict__.copy() + return state + + def __setstate__(self, state): + # Restore instance attributes from the unpickled state + self.__dict__.update(state) + + # This should be used only when loading from transformers/diffusers + def update_model_arch(self, source_model_dict: dict[str, Any]) -> None: + """ + Update arch_config with source_model_dict + """ + arch_config = self.arch_config + + for key, value in source_model_dict.items(): + setattr(arch_config, key, value) + + if hasattr(arch_config, "__post_init__"): + arch_config.__post_init__() + + def update_model_config(self, source_model_dict: dict[str, Any]) -> None: + assert ( + "arch_config" not in source_model_dict + ), "Source model config shouldn't contain arch_config." 
+ + valid_fields = {f.name for f in fields(self)} + + for key, value in source_model_dict.items(): + if key in valid_fields: + setattr(self, key, value) + else: + logger.warning( + "%s does not contain field '%s'!", type(self).__name__, key + ) + raise AttributeError(f"Invalid field: {key}") + + if hasattr(self, "__post_init__"): + self.__post_init__() diff --git a/sglang/python/sglang/multimodal_gen/configs/models/bridges/__init__.py b/sglang/python/sglang/multimodal_gen/configs/models/bridges/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8e860b8903b54d1785ad865d87edf7676cab760b --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/models/bridges/__init__.py @@ -0,0 +1,7 @@ +# SPDX-License-Identifier: Apache-2.0 + +from sglang.multimodal_gen.configs.models.bridges.mova_dual_tower import ( + MOVADualTowerConfig, +) + +__all__ = ["MOVADualTowerConfig"] diff --git a/sglang/python/sglang/multimodal_gen/configs/models/bridges/mova_dual_tower.py b/sglang/python/sglang/multimodal_gen/configs/models/bridges/mova_dual_tower.py new file mode 100644 index 0000000000000000000000000000000000000000..4bf0b83c4309065c038cb9a5236733549d7cc858 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/models/bridges/mova_dual_tower.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Configuration for MOVA dual tower bridge model.""" + +from dataclasses import dataclass, field + +from sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig + + +def _is_conditioner_block(name: str, module) -> bool: + """Check if module is a ConditionalCrossAttentionBlock.""" + return "ConditionalCrossAttentionBlock" in type(module).__name__ + + +@dataclass +class MOVADualTowerArchConfig(DiTArchConfig): + _fsdp_shard_conditions: list = field( + default_factory=lambda: [_is_conditioner_block] + ) + + # Model architecture parameters + visual_layers: int = 40 + audio_layers: int = 30 + visual_hidden_dim: int = 5120 + 
audio_hidden_dim: int = 1536 + audio_fps: float = 50.0 + head_dim: int = 128 + interaction_strategy: str = "full" + apply_cross_rope: bool = True + apply_first_frame_bias_in_rope: bool = False + trainable_condition_scale: bool = False + pooled_adaln: bool = False + eps: float = 1e-6 + + def __post_init__(self): + super().__post_init__() + self.hidden_size = self.visual_hidden_dim + self.num_attention_heads = self.visual_hidden_dim // self.head_dim + + +@dataclass +class MOVADualTowerConfig(DiTConfig): + arch_config: DiTArchConfig = field(default_factory=MOVADualTowerArchConfig) diff --git a/sglang/python/sglang/multimodal_gen/configs/models/dits/__init__.py b/sglang/python/sglang/multimodal_gen/configs/models/dits/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ba114c18fa895936506bcbc0c376308765cef859 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/models/dits/__init__.py @@ -0,0 +1,17 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +from sglang.multimodal_gen.configs.models.dits.helios import HeliosConfig +from sglang.multimodal_gen.configs.models.dits.hunyuan3d import Hunyuan3DDiTConfig +from sglang.multimodal_gen.configs.models.dits.hunyuanvideo import HunyuanVideoConfig +from sglang.multimodal_gen.configs.models.dits.mova_audio import MOVAAudioConfig +from sglang.multimodal_gen.configs.models.dits.mova_video import MOVAVideoConfig +from sglang.multimodal_gen.configs.models.dits.wanvideo import WanVideoConfig + +__all__ = [ + "HeliosConfig", + "HunyuanVideoConfig", + "WanVideoConfig", + "Hunyuan3DDiTConfig", + "MOVAAudioConfig", + "MOVAVideoConfig", +] diff --git a/sglang/python/sglang/multimodal_gen/configs/models/dits/base.py b/sglang/python/sglang/multimodal_gen/configs/models/dits/base.py new file mode 100644 index 0000000000000000000000000000000000000000..71ad7c66397817d9ad05fd9908bc29ebea8d141f --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/models/dits/base.py @@ 
-0,0 +1,78 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field +from typing import Any + +from sglang.multimodal_gen.configs.models.base import ArchConfig, ModelConfig +from sglang.multimodal_gen.runtime.layers.quantization import QuantizationConfig +from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum + + +@dataclass +class DiTArchConfig(ArchConfig): + _fsdp_shard_conditions: list = field(default_factory=list) + _compile_conditions: list = field(default_factory=list) + + # convert weights name from HF-format to SGLang-dit-format + param_names_mapping: dict = field(default_factory=dict) + + # convert weights name from misc-format to HF-format + # usually applicable if the LoRA is trained with official repo implementation + lora_param_names_mapping: dict = field(default_factory=dict) + + # Reverse mapping for saving checkpoints: custom -> hf + reverse_param_names_mapping: dict = field(default_factory=dict) + _supported_attention_backends: set[AttentionBackendEnum] = field( + default_factory=lambda: { + AttentionBackendEnum.SLIDING_TILE_ATTN, + AttentionBackendEnum.SAGE_ATTN, + AttentionBackendEnum.FA, + AttentionBackendEnum.AITER, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.VIDEO_SPARSE_ATTN, + AttentionBackendEnum.SPARSE_VIDEO_GEN_2_ATTN, + AttentionBackendEnum.VMOBA_ATTN, + AttentionBackendEnum.SAGE_ATTN_3, + } + ) + + hidden_size: int = 0 + num_attention_heads: int = 0 + num_channels_latents: int = 0 + exclude_lora_layers: list[str] = field(default_factory=list) + boundary_ratio: float | None = None + + def __post_init__(self) -> None: + if not self._compile_conditions: + self._compile_conditions = self._fsdp_shard_conditions.copy() + + +@dataclass +class DiTConfig(ModelConfig): + arch_config: DiTArchConfig = field(default_factory=DiTArchConfig) + + # sglang-diffusion DiT-specific parameters + prefix: str = "" + quant_config: 
QuantizationConfig | None = None + + @staticmethod + def add_cli_args(parser: Any, prefix: str = "dit-config") -> Any: + """Add CLI arguments for DiTConfig fields""" + parser.add_argument( + f"--{prefix}.prefix", + type=str, + dest=f"{prefix.replace('-', '_')}.prefix", + default=DiTConfig.prefix, + help="Prefix for the DiT model", + ) + + parser.add_argument( + f"--{prefix}.quant-config", + type=str, + dest=f"{prefix.replace('-', '_')}.quant_config", + default=None, + help="Quantization configuration for the DiT model", + ) + + return parser diff --git a/sglang/python/sglang/multimodal_gen/configs/models/dits/flux.py b/sglang/python/sglang/multimodal_gen/configs/models/dits/flux.py new file mode 100644 index 0000000000000000000000000000000000000000..0e5a23090de2b0c43afbeee907620a5f9d62d4fc --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/models/dits/flux.py @@ -0,0 +1,71 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field +from typing import Tuple + +from sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig + + +@dataclass +class FluxArchConfig(DiTArchConfig): + patch_size: int = 1 + in_channels: int = 64 + out_channels: int | None = None + num_layers: int = 19 + num_single_layers: int = 38 + attention_head_dim: int = 128 + num_attention_heads: int = 24 + joint_attention_dim: int = 4096 + pooled_projection_dim: int = 768 + guidance_embeds: bool = False + axes_dims_rope: Tuple[int, int, int] = (16, 56, 56) + + stacked_params_mapping: list[tuple[str, str, str]] = field(default_factory=list) + + # nunchaku checkpoint uses different weight names; map to sglang flux layout + param_names_mapping: dict = field( + default_factory=lambda: { + # HF diffusers format + r"^transformer\.(\w*)\.(.*)$": r"\1.\2", + # transformer_blocks nunchaku format (raw export - before internal conversion) + r"^transformer_blocks\.(\d+)\.mlp_fc1\.(.*)$": 
r"transformer_blocks.\1.ff.net.0.proj.\2", + r"^transformer_blocks\.(\d+)\.mlp_fc2\.(.*)$": r"transformer_blocks.\1.ff.net.2.\2", + r"^transformer_blocks\.(\d+)\.mlp_context_fc1\.(.*)$": r"transformer_blocks.\1.ff_context.net.0.proj.\2", + r"^transformer_blocks\.(\d+)\.mlp_context_fc2\.(.*)$": r"transformer_blocks.\1.ff_context.net.2.\2", + r"^transformer_blocks\.(\d+)\.qkv_proj\.(.*)$": r"transformer_blocks.\1.attn.to_qkv.\2", + r"^transformer_blocks\.(\d+)\.qkv_proj_context\.(.*)$": r"transformer_blocks.\1.attn.to_added_qkv.\2", + r"^transformer_blocks\.(\d+)\.out_proj\.(.*)$": r"transformer_blocks.\1.attn.to_out.0.\2", + r"^transformer_blocks\.(\d+)\.out_proj_context\.(.*)$": r"transformer_blocks.\1.attn.to_add_out.\2", + r"^transformer_blocks\.(\d+)\.norm_q\.(.*)$": r"transformer_blocks.\1.attn.norm_q.\2", + r"^transformer_blocks\.(\d+)\.norm_k\.(.*)$": r"transformer_blocks.\1.attn.norm_k.\2", + r"^transformer_blocks\.(\d+)\.norm_added_q\.(.*)$": r"transformer_blocks.\1.attn.norm_added_q.\2", + r"^transformer_blocks\.(\d+)\.norm_added_k\.(.*)$": r"transformer_blocks.\1.attn.norm_added_k.\2", + # transformer_blocks nunchaku format (already converted with convert_flux_state_dict) + r"^transformer_blocks\.(\d+)\.attn\.add_qkv_proj\.(.*)$": r"transformer_blocks.\1.attn.to_added_qkv.\2", + # single_transformer_blocks nunchaku format (raw export - before internal conversion) + r"^single_transformer_blocks\.(\d+)\.qkv_proj\.(.*)$": r"single_transformer_blocks.\1.attn.to_qkv.\2", + r"^single_transformer_blocks\.(\d+)\.out_proj\.(.*)$": r"single_transformer_blocks.\1.attn.to_out.0.\2", + r"^single_transformer_blocks\.(\d+)\.norm_q\.(.*)$": r"single_transformer_blocks.\1.attn.norm_q.\2", + r"^single_transformer_blocks\.(\d+)\.norm_k\.(.*)$": r"single_transformer_blocks.\1.attn.norm_k.\2", + # nunchaku quantization parameter name conversions (apply to all blocks) + r"^(.*)\.smooth_orig$": r"\1.smooth_factor_orig", + r"^(.*)\.smooth$": r"\1.smooth_factor", + 
r"^(.*)\.lora_down$": r"\1.proj_down", + r"^(.*)\.lora_up$": r"\1.proj_up", + } + ) + + def __post_init__(self): + super().__post_init__() + self.out_channels = self.out_channels or self.in_channels + self.hidden_size = self.num_attention_heads * self.attention_head_dim + self.num_channels_latents = self.out_channels + + +@dataclass +class FluxConfig(DiTConfig): + + arch_config: DiTArchConfig = field(default_factory=FluxArchConfig) + + prefix: str = "Flux" diff --git a/sglang/python/sglang/multimodal_gen/configs/models/dits/glmimage.py b/sglang/python/sglang/multimodal_gen/configs/models/dits/glmimage.py new file mode 100644 index 0000000000000000000000000000000000000000..5fc2553e4ff74e737b7fedc25ddc3e8ee1ca3817 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/models/dits/glmimage.py @@ -0,0 +1,39 @@ +from dataclasses import dataclass, field + +from sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig + + +@dataclass +class GlmImageArchConfig(DiTArchConfig): + patch_size: int = 2 + in_channels: int = 16 + out_channels: int | None = 16 + num_layers: int = 30 + attention_head_dim: int = 128 + num_attention_heads: int = 32 + condition_dim: int = 256 + prior_vq_quantizer_codebook_size: int = 16384 + text_embed_dim: int = 1472 + time_embed_dim: int = 512 + + stacked_params_mapping: list[tuple[str, str, str]] = field(default_factory=list) + + param_names_mapping: dict = field( + default_factory=lambda: { + # LoRA mappings + r"^(transformer_blocks\.\d+\.attn\..*\.lora_[AB])\.default$": r"\1", + } + ) + + def __post_init__(self): + super().__post_init__() + self.out_channels = self.out_channels or self.in_channels + self.hidden_size = self.num_attention_heads * self.attention_head_dim + self.num_channels_latents = self.out_channels + + +@dataclass +class GlmImageDitConfig(DiTConfig): + arch_config: DiTArchConfig = field(default_factory=GlmImageArchConfig) + + prefix: str = "glmimage" diff --git 
a/sglang/python/sglang/multimodal_gen/configs/models/dits/helios.py b/sglang/python/sglang/multimodal_gen/configs/models/dits/helios.py new file mode 100644 index 0000000000000000000000000000000000000000..15f73fb02719c3a9b8f2cbe905fd708d74668983 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/models/dits/helios.py @@ -0,0 +1,80 @@ +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field + +from sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig + + +def is_blocks(n: str, m) -> bool: + return "blocks" in n and str.isdigit(n.split(".")[-1]) + + +@dataclass +class HeliosArchConfig(DiTArchConfig): + _fsdp_shard_conditions: list = field(default_factory=lambda: [is_blocks]) + + param_names_mapping: dict = field( + default_factory=lambda: { + # Patch embeddings + r"^patch_embedding\.(.*)$": r"patch_embedding.proj.\1", + # Condition embedder: text + r"^condition_embedder\.text_embedder\.linear_1\.(.*)$": r"condition_embedder.text_embedder.fc_in.\1", + r"^condition_embedder\.text_embedder\.linear_2\.(.*)$": r"condition_embedder.text_embedder.fc_out.\1", + # Condition embedder: time + r"^condition_embedder\.time_embedder\.linear_1\.(.*)$": r"condition_embedder.time_embedder.mlp.fc_in.\1", + r"^condition_embedder\.time_embedder\.linear_2\.(.*)$": r"condition_embedder.time_embedder.mlp.fc_out.\1", + r"^condition_embedder\.time_proj\.(.*)$": r"condition_embedder.time_modulation.linear.\1", + # Blocks: self-attention (keep attn1. prefix, drop .0. from to_out) + r"^blocks\.(\d+)\.attn1\.to_out\.0\.(.*)$": r"blocks.\1.attn1.to_out.\2", + # Blocks: cross-attention output (drop .0. 
from to_out) + r"^blocks\.(\d+)\.attn2\.to_out\.0\.(.*)$": r"blocks.\1.attn2.to_out.\2", + # Blocks: feed-forward + r"^blocks\.(\d+)\.ffn\.net\.0\.proj\.(.*)$": r"blocks.\1.ffn.fc_in.\2", + r"^blocks\.(\d+)\.ffn\.net\.2\.(.*)$": r"blocks.\1.ffn.fc_out.\2", + # Blocks: cross-attn residual norm + r"^blocks\.(\d+)\.norm2\.(.*)$": r"blocks.\1.self_attn_residual_norm.\2", + } + ) + + reverse_param_names_mapping: dict = field(default_factory=lambda: {}) + + lora_param_names_mapping: dict = field(default_factory=lambda: {}) + + patch_size: tuple[int, int, int] = (1, 2, 2) + text_len: int = 226 + num_attention_heads: int = 40 + attention_head_dim: int = 128 + in_channels: int = 16 + out_channels: int = 16 + text_dim: int = 4096 + freq_dim: int = 256 + ffn_dim: int = 13824 + num_layers: int = 40 + cross_attn_norm: bool = True + qk_norm: str = "rms_norm_across_heads" + eps: float = 1e-6 + added_kv_proj_dim: int | None = None + rope_max_seq_len: int = 1024 + pos_embed_seq_len: int | None = None + exclude_lora_layers: list[str] = field(default_factory=lambda: ["embedder"]) + + # Helios-specific + rope_dim: tuple[int, int, int] = (44, 42, 42) + rope_theta: float = 10000.0 + guidance_cross_attn: bool = True + zero_history_timestep: bool = True + has_multi_term_memory_patch: bool = True + is_amplify_history: bool = False + history_scale_mode: str = "per_head" + + def __post_init__(self): + super().__post_init__() + self.out_channels = self.out_channels or self.in_channels + self.hidden_size = self.num_attention_heads * self.attention_head_dim + self.num_channels_latents = self.out_channels + + +@dataclass +class HeliosConfig(DiTConfig): + arch_config: DiTArchConfig = field(default_factory=HeliosArchConfig) + + prefix: str = "Helios" diff --git a/sglang/python/sglang/multimodal_gen/configs/models/dits/hunyuan3d.py b/sglang/python/sglang/multimodal_gen/configs/models/dits/hunyuan3d.py new file mode 100644 index 
0000000000000000000000000000000000000000..edcac78ca19ce755ce637e41a069faf0a274ca05 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/models/dits/hunyuan3d.py @@ -0,0 +1,44 @@ +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field + +from sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig + + +@dataclass +class Hunyuan3DDiTArchConfig(DiTArchConfig): + """Architecture config for Hunyuan3D DiT (Flux-style for Hunyuan3D-2.0).""" + + param_names_mapping: dict = field( + default_factory=lambda: { + r"(.*)\.img_mlp\.0\.(.*)$": r"\1.img_mlp.fc_in.\2", + r"(.*)\.img_mlp\.2\.(.*)$": r"\1.img_mlp.fc_out.\2", + r"(.*)\.txt_mlp\.0\.(.*)$": r"\1.txt_mlp.fc_in.\2", + r"(.*)\.txt_mlp\.2\.(.*)$": r"\1.txt_mlp.fc_out.\2", + } + ) + + in_channels: int = 64 + hidden_size: int = 1024 + num_attention_heads: int = 16 + num_layers: int = 16 + num_single_layers: int = 32 + mlp_ratio: float = 4.0 + context_in_dim: int = 1536 + axes_dim: tuple[int, ...] 
= (64,) + theta: int = 10000 + qkv_bias: bool = True + guidance_embed: bool = False + time_factor: float = 1000.0 + + def __post_init__(self) -> None: + if self.num_channels_latents == 0: + self.num_channels_latents = self.in_channels + super().__post_init__() + + +@dataclass +class Hunyuan3DDiTConfig(DiTConfig): + """DiT configuration for Hunyuan3D shape generation (Flux-style).""" + + arch_config: Hunyuan3DDiTArchConfig = field(default_factory=Hunyuan3DDiTArchConfig) + subfolder: str = "hunyuan3d-dit-v2-0" diff --git a/sglang/python/sglang/multimodal_gen/configs/models/dits/hunyuanvideo.py b/sglang/python/sglang/multimodal_gen/configs/models/dits/hunyuanvideo.py new file mode 100644 index 0000000000000000000000000000000000000000..1cae921ff4fd325f6c2cfc0a237374bbde87da39 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/models/dits/hunyuanvideo.py @@ -0,0 +1,184 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field + +import torch + +from sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig + + +def is_double_block(n: str, m) -> bool: + return "double" in n and str.isdigit(n.split(".")[-1]) + + +def is_single_block(n: str, m) -> bool: + return "single" in n and str.isdigit(n.split(".")[-1]) + + +def is_refiner_block(n: str, m) -> bool: + return "refiner" in n and str.isdigit(n.split(".")[-1]) + + +def is_txt_in(n: str, m) -> bool: + return n.split(".")[-1] == "txt_in" + + +@dataclass +class HunyuanVideoArchConfig(DiTArchConfig): + _fsdp_shard_conditions: list = field( + default_factory=lambda: [is_double_block, is_single_block, is_refiner_block] + ) + + _compile_conditions: list = field( + default_factory=lambda: [is_double_block, is_single_block, is_txt_in] + ) + + param_names_mapping: dict = field( + default_factory=lambda: { + # 1. 
context_embedder.time_text_embed submodules (specific rules, applied first): + r"^context_embedder\.time_text_embed\.timestep_embedder\.linear_1\.(.*)$": r"txt_in.t_embedder.mlp.fc_in.\1", + r"^context_embedder\.time_text_embed\.timestep_embedder\.linear_2\.(.*)$": r"txt_in.t_embedder.mlp.fc_out.\1", + r"^context_embedder\.proj_in\.(.*)$": r"txt_in.input_embedder.\1", + r"^context_embedder\.time_text_embed\.text_embedder\.linear_1\.(.*)$": r"txt_in.c_embedder.fc_in.\1", + r"^context_embedder\.time_text_embed\.text_embedder\.linear_2\.(.*)$": r"txt_in.c_embedder.fc_out.\1", + r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.norm1\.(.*)$": r"txt_in.refiner_blocks.\1.norm1.\2", + r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.norm2\.(.*)$": r"txt_in.refiner_blocks.\1.norm2.\2", + r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.attn\.to_q\.(.*)$": ( + r"txt_in.refiner_blocks.\1.self_attn_qkv.\2", + 0, + 3, + ), + r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.attn\.to_k\.(.*)$": ( + r"txt_in.refiner_blocks.\1.self_attn_qkv.\2", + 1, + 3, + ), + r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.attn\.to_v\.(.*)$": ( + r"txt_in.refiner_blocks.\1.self_attn_qkv.\2", + 2, + 3, + ), + r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.attn\.to_out\.0\.(.*)$": r"txt_in.refiner_blocks.\1.self_attn_proj.\2", + r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.ff\.net\.0(?:\.proj)?\.(.*)$": r"txt_in.refiner_blocks.\1.mlp.fc_in.\2", + r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.ff\.net\.2(?:\.proj)?\.(.*)$": r"txt_in.refiner_blocks.\1.mlp.fc_out.\2", + r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.norm_out\.linear\.(.*)$": r"txt_in.refiner_blocks.\1.adaLN_modulation.linear.\2", + # 3. x_embedder mapping: + r"^x_embedder\.proj\.(.*)$": r"img_in.proj.\1", + # 4. 
Top-level time_text_embed mappings: + r"^time_text_embed\.timestep_embedder\.linear_1\.(.*)$": r"time_in.mlp.fc_in.\1", + r"^time_text_embed\.timestep_embedder\.linear_2\.(.*)$": r"time_in.mlp.fc_out.\1", + r"^time_text_embed\.guidance_embedder\.linear_1\.(.*)$": r"guidance_in.mlp.fc_in.\1", + r"^time_text_embed\.guidance_embedder\.linear_2\.(.*)$": r"guidance_in.mlp.fc_out.\1", + r"^time_text_embed\.text_embedder\.linear_1\.(.*)$": r"vector_in.fc_in.\1", + r"^time_text_embed\.text_embedder\.linear_2\.(.*)$": r"vector_in.fc_out.\1", + # 5. transformer_blocks mapping: + r"^transformer_blocks\.(\d+)\.norm1\.linear\.(.*)$": r"double_blocks.\1.img_mod.linear.\2", + r"^transformer_blocks\.(\d+)\.norm1_context\.linear\.(.*)$": r"double_blocks.\1.txt_mod.linear.\2", + r"^transformer_blocks\.(\d+)\.attn\.norm_q\.(.*)$": r"double_blocks.\1.img_attn_q_norm.\2", + r"^transformer_blocks\.(\d+)\.attn\.norm_k\.(.*)$": r"double_blocks.\1.img_attn_k_norm.\2", + r"^transformer_blocks\.(\d+)\.attn\.to_q\.(.*)$": ( + r"double_blocks.\1.img_attn_qkv.\2", + 0, + 3, + ), + r"^transformer_blocks\.(\d+)\.attn\.to_k\.(.*)$": ( + r"double_blocks.\1.img_attn_qkv.\2", + 1, + 3, + ), + r"^transformer_blocks\.(\d+)\.attn\.to_v\.(.*)$": ( + r"double_blocks.\1.img_attn_qkv.\2", + 2, + 3, + ), + r"^transformer_blocks\.(\d+)\.attn\.add_q_proj\.(.*)$": ( + r"double_blocks.\1.txt_attn_qkv.\2", + 0, + 3, + ), + r"^transformer_blocks\.(\d+)\.attn\.add_k_proj\.(.*)$": ( + r"double_blocks.\1.txt_attn_qkv.\2", + 1, + 3, + ), + r"^transformer_blocks\.(\d+)\.attn\.add_v_proj\.(.*)$": ( + r"double_blocks.\1.txt_attn_qkv.\2", + 2, + 3, + ), + r"^transformer_blocks\.(\d+)\.attn\.to_out\.0\.(.*)$": r"double_blocks.\1.img_attn_proj.\2", + # Corrected: merge attn.to_add_out into the main projection. 
+ r"^transformer_blocks\.(\d+)\.attn\.to_add_out\.(.*)$": r"double_blocks.\1.txt_attn_proj.\2", + r"^transformer_blocks\.(\d+)\.attn\.norm_added_q\.(.*)$": r"double_blocks.\1.txt_attn_q_norm.\2", + r"^transformer_blocks\.(\d+)\.attn\.norm_added_k\.(.*)$": r"double_blocks.\1.txt_attn_k_norm.\2", + r"^transformer_blocks\.(\d+)\.ff\.net\.0(?:\.proj)?\.(.*)$": r"double_blocks.\1.img_mlp.fc_in.\2", + r"^transformer_blocks\.(\d+)\.ff\.net\.2(?:\.proj)?\.(.*)$": r"double_blocks.\1.img_mlp.fc_out.\2", + r"^transformer_blocks\.(\d+)\.ff_context\.net\.0(?:\.proj)?\.(.*)$": r"double_blocks.\1.txt_mlp.fc_in.\2", + r"^transformer_blocks\.(\d+)\.ff_context\.net\.2(?:\.proj)?\.(.*)$": r"double_blocks.\1.txt_mlp.fc_out.\2", + # 6. single_transformer_blocks mapping: + r"^single_transformer_blocks\.(\d+)\.attn\.norm_q\.(.*)$": r"single_blocks.\1.q_norm.\2", + r"^single_transformer_blocks\.(\d+)\.attn\.norm_k\.(.*)$": r"single_blocks.\1.k_norm.\2", + r"^single_transformer_blocks\.(\d+)\.attn\.to_q\.(.*)$": ( + r"single_blocks.\1.linear1.\2", + 0, + 4, + ), + r"^single_transformer_blocks\.(\d+)\.attn\.to_k\.(.*)$": ( + r"single_blocks.\1.linear1.\2", + 1, + 4, + ), + r"^single_transformer_blocks\.(\d+)\.attn\.to_v\.(.*)$": ( + r"single_blocks.\1.linear1.\2", + 2, + 4, + ), + r"^single_transformer_blocks\.(\d+)\.proj_mlp\.(.*)$": ( + r"single_blocks.\1.linear1.\2", + 3, + 4, + ), + # Corrected: map proj_out to modulation.linear rather than a separate proj_out branch. + r"^single_transformer_blocks\.(\d+)\.proj_out\.(.*)$": r"single_blocks.\1.linear2.\2", + r"^single_transformer_blocks\.(\d+)\.norm\.linear\.(.*)$": r"single_blocks.\1.modulation.linear.\2", + # 7. 
Final layers mapping: + r"^norm_out\.linear\.(.*)$": r"final_layer.adaLN_modulation.linear.\1", + r"^proj_out\.(.*)$": r"final_layer.linear.\1", + } + ) + + reverse_param_names_mapping: dict = field(default_factory=lambda: {}) + + patch_size: int = 2 + patch_size_t: int = 1 + in_channels: int = 16 + out_channels: int = 16 + num_attention_heads: int = 24 + attention_head_dim: int = 128 + mlp_ratio: float = 4.0 + num_layers: int = 20 + num_single_layers: int = 40 + num_refiner_layers: int = 2 + rope_axes_dim: tuple[int, int, int] = (16, 56, 56) + guidance_embeds: bool = False + dtype: torch.dtype | None = None + text_embed_dim: int = 4096 + pooled_projection_dim: int = 768 + rope_theta: int = 256 + qk_norm: str = "rms_norm" + exclude_lora_layers: list[str] = field( + default_factory=lambda: ["img_in", "txt_in", "time_in", "vector_in"] + ) + + def __post_init__(self): + super().__post_init__() + self.hidden_size: int = self.attention_head_dim * self.num_attention_heads + self.num_channels_latents: int = self.in_channels + + +@dataclass +class HunyuanVideoConfig(DiTConfig): + arch_config: DiTArchConfig = field(default_factory=HunyuanVideoArchConfig) + + prefix: str = "Hunyuan" diff --git a/sglang/python/sglang/multimodal_gen/configs/models/dits/ltx_2.py b/sglang/python/sglang/multimodal_gen/configs/models/dits/ltx_2.py new file mode 100644 index 0000000000000000000000000000000000000000..d2bee4ba0fa6c468caf5d605aaeaa971c7cadfd2 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/models/dits/ltx_2.py @@ -0,0 +1,176 @@ +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field +from enum import Enum + +from sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig + + +class LTXModelType(Enum): + """ + Model type enum mirroring upstream `LTXModelType`. 
+ + Upstream reference: + - `LTX-2/packages/ltx-core/src/ltx_core/model/transformer/model.py::LTXModelType` + """ + + AudioVideo = "ltx av model" + VideoOnly = "ltx video only model" + AudioOnly = "ltx audio only model" + + def is_video_enabled(self) -> bool: + return self in (LTXModelType.AudioVideo, LTXModelType.VideoOnly) + + def is_audio_enabled(self) -> bool: + return self in (LTXModelType.AudioVideo, LTXModelType.AudioOnly) + + +class LTX2RopeType(str, Enum): + """ + Minimal RoPE type enum mirroring LTX-2 upstream `LTXRopeType`. + + Upstream reference: + - `LTX-2/packages/ltx-core/src/ltx_core/model/transformer/rope.py::LTXRopeType` + """ + + INTERLEAVED = "interleaved" + SPLIT = "split" + + +class LTX2AttentionFunction(str, Enum): + """ + Placeholder enum for upstream `AttentionFunction.DEFAULT`. + + Upstream reference: + - `LTX-2/packages/ltx-core/src/ltx_core/model/transformer/attention.py` + """ + + DEFAULT = "default" + + +def is_blocks(n: str, m) -> bool: + return "blocks" in n and str.isdigit(n.split(".")[-1]) + + +@dataclass +class LTX2ArchConfig(DiTArchConfig): + """Architecture configuration for LTX-2 Video Transformer.""" + + _fsdp_shard_conditions: list = field(default_factory=lambda: [is_blocks]) + + param_names_mapping: dict = field( + default_factory=lambda: { + # Parameter name mappings from HuggingFace checkpoint keys to SGLang module names. + # We use upstream variable names (patchify_proj, adaln_single) but HF uses different keys. 
+ # + # HF key -> SGLang key (upstream naming) + r"^proj_in\.(.*)$": r"patchify_proj.\1", + r"^time_embed\.(.*)$": r"adaln_single.\1", + r"^audio_proj_in\.(.*)$": r"audio_patchify_proj.\1", + r"^audio_time_embed\.(.*)$": r"audio_adaln_single.\1", + # FeedForward + r"(.*)ff\.net\.0\.proj\.(.*)$": r"\1ff.proj_in.\2", + r"(.*)ff\.net\.2\.(.*)$": r"\1ff.proj_out.\2", + # Attention Norms + r"(.*)\.norm_q\.(.*)$": r"\1.q_norm.\2", + r"(.*)\.norm_k\.(.*)$": r"\1.k_norm.\2", + # Scale Shift Tables (Global) + r"^av_cross_attn_video_scale_shift\.(.*)$": r"av_ca_video_scale_shift_adaln_single.\1", + r"^av_cross_attn_audio_scale_shift\.(.*)$": r"av_ca_audio_scale_shift_adaln_single.\1", + r"^av_cross_attn_video_a2v_gate\.(.*)$": r"av_ca_a2v_gate_adaln_single.\1", + r"^av_cross_attn_audio_v2a_gate\.(.*)$": r"av_ca_v2a_gate_adaln_single.\1", + # Scale Shift Tables (Block Level) + # HF: scale_shift_table_a2v_ca_video -> SGLang: video_a2v_cross_attn_scale_shift_table + r"(.*)scale_shift_table_a2v_ca_video": r"\1video_a2v_cross_attn_scale_shift_table", + r"(.*)scale_shift_table_a2v_ca_audio": r"\1audio_a2v_cross_attn_scale_shift_table", + } + ) + + reverse_param_names_mapping: dict = field( + default_factory=lambda: { + # Reverse mapping: SGLang module names -> HF checkpoint keys (for saving). 
+ r"^patchify_proj\.(.*)$": r"proj_in.\1", + r"^adaln_single\.(.*)$": r"time_embed.\1", + r"^audio_patchify_proj\.(.*)$": r"audio_proj_in.\1", + r"^audio_adaln_single\.(.*)$": r"audio_time_embed.\1", + # FeedForward + r"(.*)ff\.proj_in\.(.*)$": r"\1ff.net.0.proj.\2", + r"(.*)ff\.proj_out\.(.*)$": r"\1ff.net.2.\2", + # Attention Norms + r"(.*)\.q_norm\.(.*)$": r"\1.norm_q.\2", + r"(.*)\.k_norm\.(.*)$": r"\1.norm_k.\2", + # Scale Shift Tables (Global) + r"^av_ca_video_scale_shift_adaln_single\.(.*)$": r"av_cross_attn_video_scale_shift.\1", + r"^av_ca_audio_scale_shift_adaln_single\.(.*)$": r"av_cross_attn_audio_scale_shift.\1", + r"^av_ca_a2v_gate_adaln_single\.(.*)$": r"av_cross_attn_video_a2v_gate.\1", + r"^av_ca_v2a_gate_adaln_single\.(.*)$": r"av_cross_attn_audio_v2a_gate.\1", + # Scale Shift Tables (Block Level) + # SGLang: video_a2v_cross_attn_scale_shift_table -> HF: scale_shift_table_a2v_ca_video + r"(.*)video_a2v_cross_attn_scale_shift_table": r"\1scale_shift_table_a2v_ca_video", + r"(.*)audio_a2v_cross_attn_scale_shift_table": r"\1scale_shift_table_a2v_ca_audio", + } + ) + + lora_param_names_mapping: dict = field( + default_factory=lambda: { + # LoRA parameter name mappings from official repo format to HF format. + # This is applied before param_names_mapping when loading LoRA adapters. + # Will be populated if LoRA adapters use different naming conventions. 
+ } + ) + + # Model type and attention configuration + model_type: LTXModelType = LTXModelType.AudioVideo + attention_type: LTX2AttentionFunction = LTX2AttentionFunction.DEFAULT + rope_type: LTX2RopeType = LTX2RopeType.INTERLEAVED + double_precision_rope: bool = False + + # Video parameters + num_attention_heads: int = 32 + attention_head_dim: int = 128 + in_channels: int = 128 + out_channels: int = 128 + num_layers: int = 48 + cross_attention_dim: int = 4096 + norm_eps: float = 1e-6 + caption_channels: int = 3840 + positional_embedding_theta: float = 10000.0 + positional_embedding_max_pos: list[int] | None = None + timestep_scale_multiplier: int = 1000 + use_middle_indices_grid: bool = True + + # Audio parameters + audio_num_attention_heads: int = 32 + audio_attention_head_dim: int = 64 + audio_in_channels: int = 128 + audio_out_channels: int = 128 + audio_cross_attention_dim: int = 2048 + audio_positional_embedding_max_pos: list[int] | None = None + av_ca_timestep_scale_multiplier: int = 1 + + # SGLang-specific parameters + patch_size: tuple[int, int, int] = (1, 2, 2) + text_len: int = 512 + + def __post_init__(self): + super().__post_init__() + # Video derived values + self.hidden_size = self.num_attention_heads * self.attention_head_dim + self.num_channels_latents = self.out_channels + if self.positional_embedding_max_pos is None: + self.positional_embedding_max_pos = [20, 2048, 2048] + + # Audio derived values + self.audio_hidden_size = ( + self.audio_num_attention_heads * self.audio_attention_head_dim + ) + if self.audio_positional_embedding_max_pos is None: + self.audio_positional_embedding_max_pos = [20] + + +@dataclass +class LTX2Config(DiTConfig): + """Configuration for LTX-2 Video Transformer.""" + + arch_config: LTX2ArchConfig = field(default_factory=LTX2ArchConfig) + + prefix: str = "ltx2" diff --git a/sglang/python/sglang/multimodal_gen/configs/models/dits/mova_audio.py b/sglang/python/sglang/multimodal_gen/configs/models/dits/mova_audio.py new file 
mode 100644 index 0000000000000000000000000000000000000000..1240d5f44b0da82935611d7f74fd1fe36d1876a8 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/models/dits/mova_audio.py @@ -0,0 +1,67 @@ +# Copied and adapted from: mossVG/mova/diffusion/models/wan_audio_dit.py +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass, field + +from sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig + + +def _is_blocks(n: str, m) -> bool: + return "blocks" in n and str.isdigit(n.split(".")[-1]) + + +@dataclass +class MOVAAudioArchConfig(DiTArchConfig): + _fsdp_shard_conditions: list = field(default_factory=lambda: [_is_blocks]) + + param_names_mapping: dict = field( + default_factory=lambda: { + r"^blocks\.(\d+)\.ffn\.0\.(.*)$": r"blocks.\1.ffn.fc_in.\2", + r"^blocks\.(\d+)\.ffn\.2\.(.*)$": r"blocks.\1.ffn.fc_out.\2", + r"^blocks\.(\d+)\.norm3\.(.*)$": r"blocks.\1.self_attn_norm.\2", + r"^text_embedding\.0\.(.*)$": r"text_embedding.fc_in.\1", + r"^text_embedding\.2\.(.*)$": r"text_embedding.fc_out.\1", + r"^time_embedding\.0\.(.*)$": r"time_embedding.fc_in.\1", + r"^time_embedding\.2\.(.*)$": r"time_embedding.fc_out.\1", + r"^img_emb\.proj\.1\.(.*)$": r"img_emb.fc_in.\1", + r"^img_emb\.proj\.3\.(.*)$": r"img_emb.fc_out.\1", + } + ) + reverse_param_names_mapping: dict = field(default_factory=dict) + lora_param_names_mapping: dict = field(default_factory=dict) + + dim: int = 1536 + in_dim: int = 128 + ffn_dim: int = 6144 + out_dim: int = 128 + text_dim: int = 4096 + freq_dim: int = 256 + eps: float = 1e-6 + patch_size: tuple[int, int, int] = (1, 2, 2) + num_heads: int = 12 + num_layers: int = 30 + has_image_input: bool = False + has_image_pos_emb: bool = False + has_ref_conv: bool = False + add_control_adapter: bool = False + in_dim_control_adapter: int = 24 + separated_timestep: bool = False + require_vae_embedding: bool = False + require_clip_embedding: bool = False + fuse_vae_embedding_in_latents: bool = False + 
vae_type: str = "dac" + + def __post_init__(self): + super().__post_init__() + self.hidden_size = self.dim + self.num_attention_heads = self.num_heads + self.num_channels_latents = self.out_dim + assert ( + not self.has_image_input + ), "has_image_input must be False; it's a config from Diffsynth Studio, which means the model uses CLIP for image encoding (we don't)." + + +@dataclass +class MOVAAudioConfig(DiTConfig): + arch_config: DiTArchConfig = field(default_factory=MOVAAudioArchConfig) + prefix: str = "mova_audio" diff --git a/sglang/python/sglang/multimodal_gen/configs/models/dits/mova_video.py b/sglang/python/sglang/multimodal_gen/configs/models/dits/mova_video.py new file mode 100644 index 0000000000000000000000000000000000000000..66156be4fb825e203960906ec4ec0ebdd0af222d --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/models/dits/mova_video.py @@ -0,0 +1,66 @@ +# Copied and adapted from: mossVG/mova/diffusion/models/wan_video_dit.py +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass, field + +from sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig + + +def _is_blocks(n: str, m) -> bool: + return "blocks" in n and str.isdigit(n.split(".")[-1]) + + +@dataclass +class MOVAVideoArchConfig(DiTArchConfig): + _fsdp_shard_conditions: list = field(default_factory=lambda: [_is_blocks]) + + param_names_mapping: dict = field( + default_factory=lambda: { + r"^blocks\.(\d+)\.ffn\.0\.(.*)$": r"blocks.\1.ffn.fc_in.\2", + r"^blocks\.(\d+)\.ffn\.2\.(.*)$": r"blocks.\1.ffn.fc_out.\2", + r"^blocks\.(\d+)\.norm3\.(.*)$": r"blocks.\1.self_attn_norm.\2", + r"^text_embedding\.0\.(.*)$": r"text_embedding.fc_in.\1", + r"^text_embedding\.2\.(.*)$": r"text_embedding.fc_out.\1", + r"^time_embedding\.0\.(.*)$": r"time_embedding.fc_in.\1", + r"^time_embedding\.2\.(.*)$": r"time_embedding.fc_out.\1", + r"^img_emb\.proj\.1\.(.*)$": r"img_emb.fc_in.\1", + r"^img_emb\.proj\.3\.(.*)$": r"img_emb.fc_out.\1", + } + ) + 
reverse_param_names_mapping: dict = field(default_factory=dict) + lora_param_names_mapping: dict = field(default_factory=dict) + + dim: int = 5120 + in_dim: int = 16 + ffn_dim: int = 13824 + out_dim: int = 16 + text_dim: int = 4096 + freq_dim: int = 256 + eps: float = 1e-6 + patch_size: tuple[int, int, int] = (1, 2, 2) + num_heads: int = 40 + num_layers: int = 40 + has_image_input: bool = False + has_image_pos_emb: bool = False + has_ref_conv: bool = False + add_control_adapter: bool = False + in_dim_control_adapter: int = 24 + separated_timestep: bool = False + require_vae_embedding: bool = True + require_clip_embedding: bool = True + fuse_vae_embedding_in_latents: bool = False + + def __post_init__(self): + super().__post_init__() + self.hidden_size = self.dim + self.num_attention_heads = self.num_heads + self.num_channels_latents = self.out_dim + assert ( + not self.has_image_input + ), "has_image_input must be False; it's a config from Diffsynth Studio, which means the model uses CLIP for image encoding (we don't)." 
+ + +@dataclass +class MOVAVideoConfig(DiTConfig): + arch_config: DiTArchConfig = field(default_factory=MOVAVideoArchConfig) + prefix: str = "mova_video" diff --git a/sglang/python/sglang/multimodal_gen/configs/models/dits/qwenimage.py b/sglang/python/sglang/multimodal_gen/configs/models/dits/qwenimage.py new file mode 100644 index 0000000000000000000000000000000000000000..be1dfe57ae71ba9691c1fbfbdaf281237b9f15a9 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/models/dits/qwenimage.py @@ -0,0 +1,63 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field +from typing import Tuple + +from sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig + + +@dataclass +class QwenImageArchConfig(DiTArchConfig): + patch_size: int = 1 + in_channels: int = 64 + out_channels: int | None = None + num_layers: int = 19 + num_single_layers: int = 38 + attention_head_dim: int = 128 + num_attention_heads: int = 24 + joint_attention_dim: int = 4096 + pooled_projection_dim: int = 768 + guidance_embeds: bool = False + axes_dims_rope: Tuple[int, int, int] = (16, 56, 56) + zero_cond_t: bool = False + + stacked_params_mapping: list[tuple[str, str, str]] = field(default_factory=list) + + param_names_mapping: dict = field( + default_factory=lambda: { + # LoRA mappings + r"^(transformer_blocks\.\d+\.attn\..*\.lora_[AB])\.default$": r"\1", + # SVDquant mappings + r"(.*)\.add_qkv_proj\.(.+)$": r"\1.to_added_qkv.\2", + r"(transformer_blocks\.\d+\.(img_mlp|txt_mlp)\..*\.(smooth_factor_orig|wcscales))$": r"\1", + r".*\.wtscale$": r"", + } + ) + + def __post_init__(self): + super().__post_init__() + self.out_channels = self.out_channels or self.in_channels + self.hidden_size = self.num_attention_heads * self.attention_head_dim + self.num_channels_latents = self.out_channels + + +@dataclass +class QwenImageEditPlus_2511_ArchConfig(DiTArchConfig): + zero_cond_t: bool = 
True + + +@dataclass +class QwenImageDitConfig(DiTConfig): + arch_config: DiTArchConfig = field(default_factory=QwenImageArchConfig) + + prefix: str = "qwenimage" + + +@dataclass +class QwenImageEditPlus_2511_DitConfig(DiTConfig): + arch_config: DiTArchConfig = field( + default_factory=QwenImageEditPlus_2511_ArchConfig + ) + + prefix: str = "qwenimageedit" diff --git a/sglang/python/sglang/multimodal_gen/configs/models/dits/wanvideo.py b/sglang/python/sglang/multimodal_gen/configs/models/dits/wanvideo.py new file mode 100644 index 0000000000000000000000000000000000000000..3430c001f6d2d1c3c96eee16f85b9761228458bf --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/models/dits/wanvideo.py @@ -0,0 +1,105 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field + +from sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig + + +def is_blocks(n: str, m) -> bool: + return "blocks" in n and str.isdigit(n.split(".")[-1]) + + +@dataclass +class WanVideoArchConfig(DiTArchConfig): + _fsdp_shard_conditions: list = field(default_factory=lambda: [is_blocks]) + + param_names_mapping: dict = field( + default_factory=lambda: { + r"^patch_embedding\.(.*)$": r"patch_embedding.proj.\1", + r"^condition_embedder\.text_embedder\.linear_1\.(.*)$": r"condition_embedder.text_embedder.fc_in.\1", + r"^condition_embedder\.text_embedder\.linear_2\.(.*)$": r"condition_embedder.text_embedder.fc_out.\1", + r"^condition_embedder\.time_embedder\.linear_1\.(.*)$": r"condition_embedder.time_embedder.mlp.fc_in.\1", + r"^condition_embedder\.time_embedder\.linear_2\.(.*)$": r"condition_embedder.time_embedder.mlp.fc_out.\1", + r"^condition_embedder\.time_proj\.(.*)$": r"condition_embedder.time_modulation.linear.\1", + r"^condition_embedder\.image_embedder\.ff\.net\.0\.proj\.(.*)$": r"condition_embedder.image_embedder.ff.fc_in.\1", + 
r"^condition_embedder\.image_embedder\.ff\.net\.2\.(.*)$": r"condition_embedder.image_embedder.ff.fc_out.\1", + r"^blocks\.(\d+)\.attn1\.to_q\.(.*)$": r"blocks.\1.to_q.\2", + r"^blocks\.(\d+)\.attn1\.to_k\.(.*)$": r"blocks.\1.to_k.\2", + r"^blocks\.(\d+)\.attn1\.to_v\.(.*)$": r"blocks.\1.to_v.\2", + r"^blocks\.(\d+)\.attn1\.to_out\.0\.(.*)$": r"blocks.\1.to_out.\2", + r"^blocks\.(\d+)\.attn1\.norm_q\.(.*)$": r"blocks.\1.norm_q.\2", + r"^blocks\.(\d+)\.attn1\.norm_k\.(.*)$": r"blocks.\1.norm_k.\2", + r"^blocks\.(\d+)\.attn1\.attn_op\.local_attn\.proj_l\.(.*)$": r"blocks.\1.attn1.local_attn.proj_l.\2", + r"^blocks\.(\d+)\.attn2\.to_out\.0\.(.*)$": r"blocks.\1.attn2.to_out.\2", + r"^blocks\.(\d+)\.ffn\.net\.0\.proj\.(.*)$": r"blocks.\1.ffn.fc_in.\2", + r"^blocks\.(\d+)\.ffn\.net\.2\.(.*)$": r"blocks.\1.ffn.fc_out.\2", + r"^blocks\.(\d+)\.norm2\.(.*)$": r"blocks.\1.self_attn_residual_norm.norm.\2", + } + ) + + reverse_param_names_mapping: dict = field(default_factory=lambda: {}) + + # Some LoRA adapters use the original official layer names instead of hf layer names, + # so apply this before the param_names_mapping + lora_param_names_mapping: dict = field( + default_factory=lambda: { + r"^blocks\.(\d+)\.self_attn\.q\.(.*)$": r"blocks.\1.attn1.to_q.\2", + r"^blocks\.(\d+)\.self_attn\.k\.(.*)$": r"blocks.\1.attn1.to_k.\2", + r"^blocks\.(\d+)\.self_attn\.v\.(.*)$": r"blocks.\1.attn1.to_v.\2", + r"^blocks\.(\d+)\.self_attn\.o\.(.*)$": r"blocks.\1.attn1.to_out.0.\2", + r"^blocks\.(\d+)\.cross_attn\.q\.(.*)$": r"blocks.\1.attn2.to_q.\2", + r"^blocks\.(\d+)\.cross_attn\.k\.(.*)$": r"blocks.\1.attn2.to_k.\2", + r"^blocks\.(\d+)\.cross_attn\.v\.(.*)$": r"blocks.\1.attn2.to_v.\2", + r"^blocks\.(\d+)\.cross_attn\.o\.(.*)$": r"blocks.\1.attn2.to_out.0.\2", + r"^blocks\.(\d+)\.ffn\.0\.(.*)$": r"blocks.\1.ffn.fc_in.\2", + r"^blocks\.(\d+)\.ffn\.2\.(.*)$": r"blocks.\1.ffn.fc_out.\2", + } + ) + + patch_size: tuple[int, int, int] = (1, 2, 2) + text_len = 512 + num_attention_heads: int 
= 40 + attention_head_dim: int = 128 + in_channels: int = 16 + out_channels: int = 16 + text_dim: int = 4096 + freq_dim: int = 256 + ffn_dim: int = 13824 + num_layers: int = 40 + cross_attn_norm: bool = True + qk_norm: str = "rms_norm_across_heads" + eps: float = 1e-6 + image_dim: int | None = None + added_kv_proj_dim: int | None = None + rope_max_seq_len: int = 1024 + pos_embed_seq_len: int | None = None + exclude_lora_layers: list[str] = field(default_factory=lambda: ["embedder"]) + + # Wan MoE + boundary_ratio: float | None = None + + # Causal Wan + local_attn_size: int = ( + -1 + ) # Window size for temporal local attention (-1 indicates global attention) + sink_size: int = ( + 0 # Size of the attention sink, we keep the first `sink_size` frames unchanged when rolling the KV cache + ) + num_frames_per_block: int = 3 + sliding_window_num_frames: int = 21 + attention_type: str = "original" + sla_topk: float = 0.1 + + def __post_init__(self): + super().__post_init__() + self.out_channels = self.out_channels or self.in_channels + self.hidden_size = self.num_attention_heads * self.attention_head_dim + self.num_channels_latents = self.out_channels + + +@dataclass +class WanVideoConfig(DiTConfig): + arch_config: DiTArchConfig = field(default_factory=WanVideoArchConfig) + + prefix: str = "Wan" diff --git a/sglang/python/sglang/multimodal_gen/configs/models/dits/zimage.py b/sglang/python/sglang/multimodal_gen/configs/models/dits/zimage.py new file mode 100644 index 0000000000000000000000000000000000000000..33c50e0cb5ad6fbe8e0df42b8ae817060420749b --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/models/dits/zimage.py @@ -0,0 +1,78 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field +from typing import Tuple + +from sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig + + +def is_zimage_layer(n: str, m) -> bool: + """Returns if the 
module should be sharded for Z-Image model.""" + if "layers" in n and str.isdigit(n.split(".")[-1]): + return True + if ("noise_refiner" in n or "context_refiner" in n) and str.isdigit( + n.split(".")[-1] + ): + return True + return False + + +@dataclass +class ZImageArchConfig(DiTArchConfig): + all_patch_size: Tuple[int, ...] = (2,) + all_f_patch_size: Tuple[int, ...] = (1,) + in_channels: int = 16 + out_channels: int | None = None + dim: int = 3840 + num_layers: int = 30 + n_refiner_layers: int = 2 + num_attention_heads: int = 30 + n_kv_heads: int = 30 + norm_eps: float = 1e-5 + qk_norm: bool = True + cap_feat_dim: int = 2560 + rope_theta: float = 256.0 + t_scale: float = 1000.0 + axes_dims: Tuple[int, int, int] = (32, 48, 48) + axes_lens: Tuple[int, int, int] = (1024, 512, 512) + + _fsdp_shard_conditions: list = field(default_factory=lambda: [is_zimage_layer]) + + stacked_params_mapping: list[tuple[str, str, str]] = field( + default_factory=lambda: [ + # (param_name, shard_name, shard_id) + (".feed_forward.w13", ".feed_forward.w1", "gate"), + (".feed_forward.w13", ".feed_forward.w3", "up"), + ] + ) + + param_names_mapping: dict = field( + default_factory=lambda: { + r"(.*)\.feed_forward\.w1\.weight$": (r"\1.feed_forward.w13.weight", 0, 2), + r"(.*)\.feed_forward\.w3\.weight$": (r"\1.feed_forward.w13.weight", 1, 2), + r"(.*)\.feed_forward\.w1\.(lora_A|lora_B)$": ( + r"\1.feed_forward.w13.\2", + 0, + 2, + ), + r"(.*)\.feed_forward\.w3\.(lora_A|lora_B)$": ( + r"\1.feed_forward.w13.\2", + 1, + 2, + ), + } + ) + + def __post_init__(self): + super().__post_init__() + self.out_channels = self.out_channels or self.in_channels + self.num_channels_latents = self.in_channels + self.hidden_size = self.dim + + +@dataclass +class ZImageDitConfig(DiTConfig): + arch_config: ZImageArchConfig = field(default_factory=ZImageArchConfig) + + prefix: str = "zimage" diff --git a/sglang/python/sglang/multimodal_gen/configs/models/encoders/__init__.py 
b/sglang/python/sglang/multimodal_gen/configs/models/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f0b25420eca0545f1735a055f230831770dacf93 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/models/encoders/__init__.py @@ -0,0 +1,29 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +from sglang.multimodal_gen.configs.models.encoders.base import ( + BaseEncoderOutput, + EncoderConfig, + ImageEncoderConfig, + TextEncoderConfig, +) +from sglang.multimodal_gen.configs.models.encoders.clip import ( + CLIPTextConfig, + CLIPVisionConfig, +) +from sglang.multimodal_gen.configs.models.encoders.gemma_3 import Gemma3Config +from sglang.multimodal_gen.configs.models.encoders.llama import LlamaConfig +from sglang.multimodal_gen.configs.models.encoders.qwen3 import Qwen3TextConfig +from sglang.multimodal_gen.configs.models.encoders.t5 import T5Config + +__all__ = [ + "EncoderConfig", + "TextEncoderConfig", + "ImageEncoderConfig", + "BaseEncoderOutput", + "CLIPTextConfig", + "CLIPVisionConfig", + "LlamaConfig", + "Qwen3TextConfig", + "T5Config", + "Gemma3Config", +] diff --git a/sglang/python/sglang/multimodal_gen/configs/models/encoders/base.py b/sglang/python/sglang/multimodal_gen/configs/models/encoders/base.py new file mode 100644 index 0000000000000000000000000000000000000000..0e0568dfcf071a523065e16417708d1e33e3f849 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/models/encoders/base.py @@ -0,0 +1,92 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field +from typing import Any + +import torch + +from sglang.multimodal_gen.configs.models.base import ArchConfig, ModelConfig +from sglang.multimodal_gen.runtime.layers.quantization import QuantizationConfig +from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum + + +@dataclass +class EncoderArchConfig(ArchConfig): + 
_fsdp_shard_conditions: list = field(default_factory=lambda: []) + architectures: list[str] = field(default_factory=lambda: []) + _supported_attention_backends: set[AttentionBackendEnum] = field( + default_factory=lambda: { + AttentionBackendEnum.FA, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.SAGE_ATTN_3, + } + ) + output_hidden_states: bool = False + use_return_dict: bool = True + + +@dataclass +class TextEncoderArchConfig(EncoderArchConfig): + vocab_size: int = 0 + hidden_size: int = 0 + num_hidden_layers: int = 0 + num_attention_heads: int = 0 + pad_token_id: int = 0 + eos_token_id: int = 0 + text_len: int = 0 + hidden_state_skip_layer: int = 0 + decoder_start_token_id: int = 0 + output_past: bool = True + scalable_attention: bool = True + tie_word_embeddings: bool = False + stacked_params_mapping: list[tuple[str, str, str]] = field( + default_factory=list + ) # mapping from huggingface weight names to custom names + tokenizer_kwargs: dict[str, Any] = field(default_factory=dict) + _fsdp_shard_conditions: list = field(default_factory=lambda: []) + + def __post_init__(self) -> None: + self.tokenizer_kwargs = { + "truncation": True, + "max_length": self.text_len, + "return_tensors": "pt", + } + + +@dataclass +class ImageEncoderArchConfig(EncoderArchConfig): + pass + + +@dataclass +class BaseEncoderOutput: + last_hidden_state: torch.FloatTensor | None = None + pooler_output: torch.FloatTensor | None = None + hidden_states: tuple[torch.FloatTensor, ...] | None = None + attentions: tuple[torch.FloatTensor, ...] 
| None = None + attention_mask: torch.Tensor | None = None + + +@dataclass +class EncoderConfig(ModelConfig): + arch_config: ArchConfig = field(default_factory=EncoderArchConfig) + + prefix: str = "" + quant_config: QuantizationConfig | None = None + lora_config: Any | None = None + + +@dataclass +class TextEncoderConfig(EncoderConfig): + arch_config: ArchConfig = field(default_factory=TextEncoderArchConfig) + + # Use the SP Group of the transformer as the TP Group of T5. + parallel_folding: bool = False + # "sp" or "ulysses" or "ring" + parallel_folding_mode: str = "sp" + + +@dataclass +class ImageEncoderConfig(EncoderConfig): + arch_config: ArchConfig = field(default_factory=ImageEncoderArchConfig) diff --git a/sglang/python/sglang/multimodal_gen/configs/models/encoders/clip.py b/sglang/python/sglang/multimodal_gen/configs/models/encoders/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..ff9a90b32a9339c7f57fa2f227d636ebae125503 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/models/encoders/clip.py @@ -0,0 +1,101 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field + +from sglang.multimodal_gen.configs.models.encoders.base import ( + ImageEncoderArchConfig, + ImageEncoderConfig, + TextEncoderArchConfig, + TextEncoderConfig, +) +from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum + + +def _is_transformer_layer(n: str, m) -> bool: + return "layers" in n and str.isdigit(n.split(".")[-1]) + + +def _is_embeddings(n: str, m) -> bool: + return n.endswith("embeddings") + + +@dataclass +class CLIPTextArchConfig(TextEncoderArchConfig): + vocab_size: int = 49408 + hidden_size: int = 512 + intermediate_size: int = 2048 + projection_dim: int = 512 + num_hidden_layers: int = 12 + num_attention_heads: int = 8 + max_position_embeddings: int = 77 + hidden_act: str = "quick_gelu" + layer_norm_eps: float = 1e-5 + 
dropout: float = 0.0 + attention_dropout: float = 0.0 + initializer_range: float = 0.02 + initializer_factor: float = 1.0 + pad_token_id: int = 1 + bos_token_id: int = 49406 + eos_token_id: int = 49407 + text_len: int = 77 + _supported_attention_backends: set[AttentionBackendEnum] = field( + default_factory=lambda: { + AttentionBackendEnum.TORCH_SDPA, # Force TORCH_SDPA to support attention_mask + } + ) + stacked_params_mapping: list[tuple[str, str, str]] = field( + default_factory=lambda: [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + ) + _fsdp_shard_conditions: list = field( + default_factory=lambda: [_is_transformer_layer, _is_embeddings] + ) + + +@dataclass +class CLIPVisionArchConfig(ImageEncoderArchConfig): + hidden_size: int = 768 + intermediate_size: int = 3072 + projection_dim: int = 512 + num_hidden_layers: int = 12 + num_attention_heads: int = 12 + num_channels: int = 3 + image_size: int = 224 + patch_size: int = 32 + hidden_act: str = "quick_gelu" + layer_norm_eps: float = 1e-5 + dropout: float = 0.0 + attention_dropout: float = 0.0 + initializer_range: float = 0.02 + initializer_factor: float = 1.0 + stacked_params_mapping: list[tuple[str, str, str]] = field( + default_factory=lambda: [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + ) + + +@dataclass +class CLIPTextConfig(TextEncoderConfig): + arch_config: TextEncoderArchConfig = field(default_factory=CLIPTextArchConfig) + + num_hidden_layers_override: int | None = None + require_post_norm: bool | None = None + prefix: str = "clip" + + +@dataclass +class CLIPVisionConfig(ImageEncoderConfig): + arch_config: ImageEncoderArchConfig = field(default_factory=CLIPVisionArchConfig) + + num_hidden_layers_override: int | None = None + require_post_norm: bool | None = None + prefix: str = "clip" diff --git 
a/sglang/python/sglang/multimodal_gen/configs/models/encoders/gemma_3.py b/sglang/python/sglang/multimodal_gen/configs/models/encoders/gemma_3.py new file mode 100644 index 0000000000000000000000000000000000000000..64636985f9611fa69c0f3c3acf008f6ded876187 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/models/encoders/gemma_3.py @@ -0,0 +1,81 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass, field + +from sglang.multimodal_gen.configs.models.encoders.base import ( + TextEncoderArchConfig, + TextEncoderConfig, +) + + +def _is_transformer_layer(n: str, m) -> bool: + return "layers" in n and str.isdigit(n.split(".")[-1]) + + +def _is_embeddings(n: str, m) -> bool: + return n.endswith("embed_tokens") + + +def _is_final_norm(n: str, m) -> bool: + return n.endswith("norm") + + +@dataclass +class Gemma3ArchConfig(TextEncoderArchConfig): + """Minimal Gemma text-encoder config for tokenizer kwargs. + + Note: runtime will load the actual `text_encoder/` module from the model repo + (e.g. Gemma3Model) via transformers; this config mainly controls tokenization. 
+ """ + + vocab_size: int = 32000 + hidden_size: int = 4096 + intermediate_size: int = 11008 + num_hidden_layers: int = 32 + num_attention_heads: int = 32 + num_key_value_heads: int | None = None + hidden_act: str = "gelu_pytorch_tanh" + max_position_embeddings: int = 2048 + initializer_range: float = 0.02 + rms_norm_eps: float = 1e-6 + use_cache: bool = True + pad_token_id: int = 0 + bos_token_id: int = 1 + eos_token_id: int = 2 + pretraining_tp: int = 1 + tie_word_embeddings: bool = True + rope_theta: float = 10000.0 + rope_scaling: dict | None = None + rope_local_base_freq: float = 10000.0 + sliding_window: int = 4096 + layer_types: list[str] = field(default_factory=list) + query_pre_attn_scalar: int | None = None + attention_bias: bool = False + attention_dropout: float = 0.0 + mlp_bias: bool = False + head_dim: int | None = None + hidden_state_skip_layer: int = 2 + text_len: int = 1024 + + stacked_params_mapping: list[tuple[str, str, str]] = field( + default_factory=lambda: [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", "0"), # type: ignore + (".gate_up_proj", ".up_proj", "1"), # type: ignore + ] + ) + _fsdp_shard_conditions: list = field( + default_factory=lambda: [_is_transformer_layer, _is_embeddings, _is_final_norm] + ) + + +@dataclass +class Gemma3Config(TextEncoderConfig): + arch_config: TextEncoderArchConfig = field(default_factory=Gemma3ArchConfig) + + prefix: str = "gemma_3" diff --git a/sglang/python/sglang/multimodal_gen/configs/models/encoders/llama.py b/sglang/python/sglang/multimodal_gen/configs/models/encoders/llama.py new file mode 100644 index 0000000000000000000000000000000000000000..41d98cab2eeb192e378d6500f2e20a5174038b59 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/models/encoders/llama.py @@ -0,0 +1,69 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# 
SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field + +from sglang.multimodal_gen.configs.models.encoders.base import ( + TextEncoderArchConfig, + TextEncoderConfig, +) + + +def _is_transformer_layer(n: str, m) -> bool: + return "layers" in n and str.isdigit(n.split(".")[-1]) + + +def _is_embeddings(n: str, m) -> bool: + return n.endswith("embed_tokens") + + +def _is_final_norm(n: str, m) -> bool: + return n.endswith("norm") + + +@dataclass +class LlamaArchConfig(TextEncoderArchConfig): + vocab_size: int = 32000 + hidden_size: int = 4096 + intermediate_size: int = 11008 + num_hidden_layers: int = 32 + num_attention_heads: int = 32 + num_key_value_heads: int | None = None + hidden_act: str = "silu" + max_position_embeddings: int = 2048 + initializer_range: float = 0.02 + rms_norm_eps: float = 1e-6 + use_cache: bool = True + pad_token_id: int = 0 + bos_token_id: int = 1 + eos_token_id: int = 2 + pretraining_tp: int = 1 + tie_word_embeddings: bool = False + rope_theta: float = 10000.0 + rope_scaling: float | None = None + attention_bias: bool = False + attention_dropout: float = 0.0 + mlp_bias: bool = False + head_dim: int | None = None + hidden_state_skip_layer: int = 2 + text_len: int = 256 + stacked_params_mapping: list[tuple[str, str, str]] = field( + default_factory=lambda: [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), # type: ignore + (".gate_up_proj", ".up_proj", 1), # type: ignore + ] + ) + _fsdp_shard_conditions: list = field( + default_factory=lambda: [_is_transformer_layer, _is_embeddings, _is_final_norm] + ) + + +@dataclass +class LlamaConfig(TextEncoderConfig): + arch_config: TextEncoderArchConfig = field(default_factory=LlamaArchConfig) + + prefix: str = "llama" diff --git a/sglang/python/sglang/multimodal_gen/configs/models/encoders/qwen3.py 
b/sglang/python/sglang/multimodal_gen/configs/models/encoders/qwen3.py new file mode 100644 index 0000000000000000000000000000000000000000..ed48da96a90f16af395102f9e3a0549027892801 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/models/encoders/qwen3.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Qwen3 text encoder configuration for SGLang diffusion models.""" + +from dataclasses import dataclass, field + +from sglang.multimodal_gen.configs.models.encoders.base import ( + TextEncoderArchConfig, + TextEncoderConfig, +) + + +def _is_transformer_layer(n: str, m) -> bool: + return "layers" in n and str.isdigit(n.split(".")[-1]) + + +def _is_embeddings(n: str, m) -> bool: + return n.endswith("embed_tokens") + + +def _is_final_norm(n: str, m) -> bool: + return n.endswith("norm") + + +@dataclass +class Qwen3TextArchConfig(TextEncoderArchConfig): + """Architecture config for Qwen3 text encoder. + + Qwen3 is similar to LLaMA but with QK-Norm (RMSNorm on Q and K before attention). 
+ """ + + vocab_size: int = 151936 + hidden_size: int = 2560 + intermediate_size: int = 9728 + num_hidden_layers: int = 36 + num_attention_heads: int = 32 + num_key_value_heads: int = 8 + hidden_act: str = "silu" + max_position_embeddings: int = 40960 + initializer_range: float = 0.02 + rms_norm_eps: float = 1e-6 + use_cache: bool = True + pad_token_id: int = 151643 + bos_token_id: int = 151643 + eos_token_id: int = 151645 + tie_word_embeddings: bool = True + rope_theta: float = 1000000.0 + rope_scaling: dict | None = None + attention_bias: bool = False + attention_dropout: float = 0.0 + mlp_bias: bool = False + head_dim: int = 128 + text_len: int = 512 + output_hidden_states: bool = True # Klein needs hidden states from layers 9, 18, 27 + + # Stacked params for weight loading with tensor parallelism + stacked_params_mapping: list[tuple[str, str, str]] = field( + default_factory=lambda: [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + ) + + # FSDP sharding conditions for CPU offload + _fsdp_shard_conditions: list = field( + default_factory=lambda: [_is_transformer_layer, _is_embeddings, _is_final_norm] + ) + + def __post_init__(self) -> None: + self.tokenizer_kwargs = { + "padding": "max_length", + "truncation": True, + "max_length": self.text_len, + "return_tensors": "pt", + } + + +@dataclass +class Qwen3TextConfig(TextEncoderConfig): + """Top-level config for Qwen3 text encoder.""" + + arch_config: TextEncoderArchConfig = field(default_factory=Qwen3TextArchConfig) + prefix: str = "qwen3" diff --git a/sglang/python/sglang/multimodal_gen/configs/models/encoders/qwen_image.py b/sglang/python/sglang/multimodal_gen/configs/models/encoders/qwen_image.py new file mode 100644 index 0000000000000000000000000000000000000000..d8268321378cbf7ee3e78a86709d33ff34f81acd --- /dev/null +++ 
b/sglang/python/sglang/multimodal_gen/configs/models/encoders/qwen_image.py @@ -0,0 +1,68 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field + +from sglang.multimodal_gen.configs.models.encoders.base import ( + TextEncoderArchConfig, + TextEncoderConfig, +) + + +def _is_transformer_layer(n: str, m) -> bool: + return "layers" in n and str.isdigit(n.split(".")[-1]) + + +def _is_embeddings(n: str, m) -> bool: + return n.endswith("embed_tokens") + + +def _is_final_norm(n: str, m) -> bool: + return n.endswith("norm") + + +@dataclass +class QwenImageArchConfig(TextEncoderArchConfig): + vocab_size: int = 32000 + hidden_size: int = 4096 + intermediate_size: int = 11008 + num_hidden_layers: int = 32 + num_attention_heads: int = 32 + num_key_value_heads: int | None = None + hidden_act: str = "silu" + max_position_embeddings: int = 2048 + initializer_range: float = 0.02 + rms_norm_eps: float = 1e-6 + use_cache: bool = True + pad_token_id: int = -1 + eos_token_id: int = 2 + pretraining_tp: int = 1 + tie_word_embeddings: bool = False + rope_theta: float = 10000.0 + rope_scaling: float | None = None + attention_bias: bool = False + attention_dropout: float = 0.0 + mlp_bias: bool = False + head_dim: int | None = None + hidden_state_skip_layer: int = 2 + text_len: int = 512 + + stacked_params_mapping: list[tuple[str, str, str]] = field( + default_factory=lambda: [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), # type: ignore + (".gate_up_proj", ".up_proj", 1), # type: ignore + ] + ) + _fsdp_shard_conditions: list = field( + default_factory=lambda: [_is_transformer_layer, _is_embeddings, _is_final_norm] + ) + + +@dataclass +class Qwen2_5VLConfig(TextEncoderConfig): + arch_config: TextEncoderArchConfig = field(default_factory=QwenImageArchConfig) + # 
prefix: str = "qwen_image" diff --git a/sglang/python/sglang/multimodal_gen/configs/models/encoders/t5.py b/sglang/python/sglang/multimodal_gen/configs/models/encoders/t5.py new file mode 100644 index 0000000000000000000000000000000000000000..65856908896c1605fb9b86490976197ee77d0e11 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/models/encoders/t5.py @@ -0,0 +1,97 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +import argparse +from dataclasses import dataclass, field + +from sglang.multimodal_gen.configs.models.encoders.base import ( + TextEncoderArchConfig, + TextEncoderConfig, +) + + +def _is_transformer_layer(n: str, m) -> bool: + return "block" in n and str.isdigit(n.split(".")[-1]) + + +def _is_embeddings(n: str, m) -> bool: + return n.endswith("shared") + + +def _is_final_layernorm(n: str, m) -> bool: + return n.endswith("final_layer_norm") + + +@dataclass +class T5ArchConfig(TextEncoderArchConfig): + vocab_size: int = 32128 + d_model: int = 512 + d_kv: int = 64 + d_ff: int = 2048 + num_layers: int = 6 + num_decoder_layers: int | None = None + num_heads: int = 8 + relative_attention_num_buckets: int = 32 + relative_attention_max_distance: int = 128 + dropout_rate: float = 0.1 + layer_norm_epsilon: float = 1e-6 + initializer_factor: float = 1.0 + feed_forward_proj: str = "relu" + dense_act_fn: str = "" + is_gated_act: bool = False + is_encoder_decoder: bool = True + use_cache: bool = True + pad_token_id: int = 0 + eos_token_id: int = 1 + classifier_dropout: float = 0.0 + text_len: int = 512 + stacked_params_mapping: list[tuple[str, str, str]] = field( + default_factory=lambda: [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q", "q"), + (".qkv_proj", ".k", "k"), + (".qkv_proj", ".v", "v"), + ] + ) + _fsdp_shard_conditions: list = field( + default_factory=lambda: [ + _is_transformer_layer, + _is_embeddings, + _is_final_layernorm, + ] + ) + + # Referenced from 
https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/configuration_t5.py + def __post_init__(self): + super().__post_init__() + act_info = self.feed_forward_proj.split("-") + self.dense_act_fn: str = act_info[-1] + self.is_gated_act: bool = act_info[0] == "gated" + if self.feed_forward_proj == "gated-gelu": + self.dense_act_fn = "gelu_new" + + self.tokenizer_kwargs = { + "padding": "max_length", + "truncation": True, + "max_length": self.text_len, + "add_special_tokens": True, + "return_attention_mask": True, + "return_tensors": "pt", + } + + +@dataclass +class T5Config(TextEncoderConfig): + arch_config: TextEncoderArchConfig = field(default_factory=T5ArchConfig) + + prefix: str = "t5" + # Use the SP Group of the transformer as the TP Group of T5. + parallel_folding: bool = False + # "sp" or "ulysses" or "ring" + parallel_folding_mode: str = "sp" + + @staticmethod + def add_cli_args( + parser: argparse.ArgumentParser, prefix: str = "t5-config" + ) -> argparse.ArgumentParser: + return parser diff --git a/sglang/python/sglang/multimodal_gen/configs/models/vaes/__init__.py b/sglang/python/sglang/multimodal_gen/configs/models/vaes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b9dd60fe5efb59b4a9eb36f22809b151ab2d704 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/models/vaes/__init__.py @@ -0,0 +1,13 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +from sglang.multimodal_gen.configs.models.vaes.dac import DacVAEConfig +from sglang.multimodal_gen.configs.models.vaes.hunyuan3d import Hunyuan3DVAEConfig +from sglang.multimodal_gen.configs.models.vaes.hunyuanvae import HunyuanVAEConfig +from sglang.multimodal_gen.configs.models.vaes.wanvae import WanVAEConfig + +__all__ = [ + "DacVAEConfig", + "HunyuanVAEConfig", + "WanVAEConfig", + "Hunyuan3DVAEConfig", +] diff --git a/sglang/python/sglang/multimodal_gen/configs/models/vaes/base.py 
b/sglang/python/sglang/multimodal_gen/configs/models/vaes/base.py new file mode 100644 index 0000000000000000000000000000000000000000..344a37e6bf83b7e9b4edd4f2296577a7f4a8d0f4 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/models/vaes/base.py @@ -0,0 +1,156 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +import argparse +import dataclasses +from dataclasses import dataclass, field +from typing import Any + +import torch + +from sglang.multimodal_gen.configs.models.base import ArchConfig, ModelConfig +from sglang.multimodal_gen.utils import StoreBoolean + + +@dataclass +class VAEArchConfig(ArchConfig): + scaling_factor: float | torch.Tensor = 0 + + temporal_compression_ratio: int = 4 + # or vae_scale_factor? + spatial_compression_ratio: int = 8 + + +@dataclass +class VAEConfig(ModelConfig): + arch_config: VAEArchConfig = field(default_factory=VAEArchConfig) + + # sglang-diffusion VAE-specific parameters + load_encoder: bool = True + load_decoder: bool = True + + tile_sample_min_height: int = 256 + tile_sample_min_width: int = 256 + tile_sample_min_num_frames: int = 16 + tile_sample_stride_height: int = 192 + tile_sample_stride_width: int = 192 + tile_sample_stride_num_frames: int = 12 + blend_num_frames: int = 0 + + use_tiling: bool = True + use_temporal_tiling: bool = True + use_parallel_tiling: bool = True + use_temporal_scaling_frames: bool = True + + def __post_init__(self): + self.blend_num_frames = ( + self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames + ) + + def post_init(self): + pass + + @staticmethod + def add_cli_args(parser: Any, prefix: str = "vae-config") -> Any: + """Add CLI arguments for VAEConfig fields""" + parser.add_argument( + f"--{prefix}.load-encoder", + action=StoreBoolean, + dest=f"{prefix.replace('-', '_')}.load_encoder", + default=VAEConfig.load_encoder, + help="Whether to load the VAE encoder", + ) + parser.add_argument( + 
f"--{prefix}.load-decoder", + action=StoreBoolean, + dest=f"{prefix.replace('-', '_')}.load_decoder", + default=VAEConfig.load_decoder, + help="Whether to load the VAE decoder", + ) + parser.add_argument( + f"--{prefix}.tile-sample-min-height", + type=int, + dest=f"{prefix.replace('-', '_')}.tile_sample_min_height", + default=VAEConfig.tile_sample_min_height, + help="Minimum height for VAE tile sampling", + ) + parser.add_argument( + f"--{prefix}.tile-sample-min-width", + type=int, + dest=f"{prefix.replace('-', '_')}.tile_sample_min_width", + default=VAEConfig.tile_sample_min_width, + help="Minimum width for VAE tile sampling", + ) + parser.add_argument( + f"--{prefix}.tile-sample-min-num-frames", + type=int, + dest=f"{prefix.replace('-', '_')}.tile_sample_min_num_frames", + default=VAEConfig.tile_sample_min_num_frames, + help="Minimum number of frames for VAE tile sampling", + ) + parser.add_argument( + f"--{prefix}.tile-sample-stride-height", + type=int, + dest=f"{prefix.replace('-', '_')}.tile_sample_stride_height", + default=VAEConfig.tile_sample_stride_height, + help="Stride height for VAE tile sampling", + ) + parser.add_argument( + f"--{prefix}.tile-sample-stride-width", + type=int, + dest=f"{prefix.replace('-', '_')}.tile_sample_stride_width", + default=VAEConfig.tile_sample_stride_width, + help="Stride width for VAE tile sampling", + ) + parser.add_argument( + f"--{prefix}.tile-sample-stride-num-frames", + type=int, + dest=f"{prefix.replace('-', '_')}.tile_sample_stride_num_frames", + default=VAEConfig.tile_sample_stride_num_frames, + help="Stride number of frames for VAE tile sampling", + ) + parser.add_argument( + f"--{prefix}.blend-num-frames", + type=int, + dest=f"{prefix.replace('-', '_')}.blend_num_frames", + default=VAEConfig.blend_num_frames, + help="Number of frames to blend for VAE tile sampling", + ) + parser.add_argument( + f"--{prefix}.use-tiling", + action=StoreBoolean, + dest=f"{prefix.replace('-', '_')}.use_tiling", + 
default=VAEConfig.use_tiling, + help="Whether to use tiling for VAE", + ) + parser.add_argument( + f"--{prefix}.use-temporal-tiling", + action=StoreBoolean, + dest=f"{prefix.replace('-', '_')}.use_temporal_tiling", + default=VAEConfig.use_temporal_tiling, + help="Whether to use temporal tiling for VAE", + ) + parser.add_argument( + f"--{prefix}.use-parallel-tiling", + action=StoreBoolean, + dest=f"{prefix.replace('-', '_')}.use_parallel_tiling", + default=VAEConfig.use_parallel_tiling, + help="Whether to use parallel tiling for VAE", + ) + + return parser + + def get_vae_scale_factor(self): + return 2 ** (len(self.arch_config.block_out_channels) - 1) + + def encode_sample_mode(self): + return "argmax" + + @classmethod + def from_cli_args(cls, args: argparse.Namespace) -> "VAEConfig": + kwargs = {} + for attr in dataclasses.fields(cls): + value = getattr(args, attr.name, None) + if value is not None: + kwargs[attr.name] = value + return cls(**kwargs) diff --git a/sglang/python/sglang/multimodal_gen/configs/models/vaes/dac.py b/sglang/python/sglang/multimodal_gen/configs/models/vaes/dac.py new file mode 100644 index 0000000000000000000000000000000000000000..63f59c6a52b12bfc3a0fc1edda60f84a59e1c29a --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/models/vaes/dac.py @@ -0,0 +1,30 @@ +# Copied and adapted from: mossVG/mova/diffusion/models/dac_vae.py +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass, field +from typing import List + +from sglang.multimodal_gen.configs.models.base import ArchConfig, ModelConfig + + +@dataclass +class DacVAEArchConfig(ArchConfig): + codebook_dim: int = 8 + codebook_size: int = 1024 + continuous: bool = True + decoder_dim: int = 2048 + decoder_rates: List[int] = field(default_factory=lambda: [8, 5, 4, 3, 2]) + encoder_dim: int = 128 + encoder_rates: List[int] = field(default_factory=lambda: [2, 3, 4, 5, 8]) + hop_length: int = 3840 + latent_dim: int = 128 + n_codebooks: int = 9 + quantizer_dropout: 
@dataclass
class DacVAEConfig(ModelConfig):
    """Model-level configuration for the DAC VAE.

    Pairs a ``DacVAEArchConfig`` with two boolean flags that both default to
    ``True``.
    """

    # default_factory builds a fresh DacVAEArchConfig for every instance
    # instead of sharing one default object across instances.
    arch_config: DacVAEArchConfig = field(default_factory=DacVAEArchConfig)
    # NOTE(review): presumably controls whether encoder weights are loaded —
    # confirm against the component loader.
    load_encoder: bool = True
    # NOTE(review): presumably controls whether decoder weights are loaded —
    # confirm against the component loader.
    load_decoder: bool = True
@dataclass
class FluxVAEConfig(VAEConfig):
    """VAE configuration for Flux; tiling is disabled by default."""

    arch_config: FluxVAEArchConfig = field(default_factory=FluxVAEArchConfig)

    use_feature_cache: bool = True

    use_tiling: bool = False
    use_temporal_tiling: bool = False
    use_parallel_tiling: bool = False

    def __post_init__(self):
        # Twice the gap between the minimum tile length and the tile stride.
        frame_gap = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames
        self.blend_num_frames = frame_gap * 2

    def post_init(self):
        """Derive ``vae_scale_factor`` and mirror it into ``spatial_compression_ratio``.

        Preference order: a non-empty ``block_out_channels`` (if the arch
        config has one), then a non-empty ``dim_mult``, and finally the
        configured ``scale_factor_spatial``.
        """
        arch = self.arch_config
        channels = getattr(arch, "block_out_channels", None)
        if channels:
            factor = 2 ** (len(channels) - 1)
        elif arch.dim_mult:
            factor = 2 ** (len(arch.dim_mult) - 1)
        else:
            factor = arch.scale_factor_spatial
        arch.vae_scale_factor = factor

        arch.spatial_compression_ratio = arch.vae_scale_factor
@dataclass
class GlmImageVAEConfig(VAEConfig):
    """VAE configuration for GLM-Image; tiling is disabled by default."""

    arch_config: GlmImageVAEArchConfig = field(default_factory=GlmImageVAEArchConfig)

    use_feature_cache: bool = True

    use_tiling: bool = False
    use_temporal_tiling: bool = False
    use_parallel_tiling: bool = False

    def get_vae_scale_factor(self):
        """Return 2 raised to the number of ``temperal_downsample`` entries."""
        stage_count = len(self.arch_config.temperal_downsample)
        return 2**stage_count

    def __post_init__(self):
        # Twice the gap between the minimum tile length and the tile stride.
        min_frames = self.tile_sample_min_num_frames
        stride_frames = self.tile_sample_stride_num_frames
        self.blend_num_frames = (min_frames - stride_frames) * 2

    def post_init(self):
        """Store the scale factor on the arch config and mirror it into
        ``spatial_compression_ratio``."""
        arch = self.arch_config
        arch.vae_scale_factor = 2 ** len(arch.temperal_downsample)
        arch.spatial_compression_ratio = arch.vae_scale_factor
@dataclass
class HunyuanVAEArchConfig(VAEArchConfig):
    """Architecture defaults for the Hunyuan video VAE.

    ``spatial_compression_ratio`` is recomputed from ``block_out_channels`` in
    ``__post_init__``; with the default four-entry ``block_out_channels`` the
    computed value (8) equals the class-level default.
    """

    in_channels: int = 3
    out_channels: int = 3
    latent_channels: int = 16
    # One block-type name per entry; the defaults list four down blocks and
    # four up blocks, matching the four entries of block_out_channels.
    down_block_types: tuple[str, ...] = (
        "HunyuanVideoDownBlock3D",
        "HunyuanVideoDownBlock3D",
        "HunyuanVideoDownBlock3D",
        "HunyuanVideoDownBlock3D",
    )
    up_block_types: tuple[str, ...] = (
        "HunyuanVideoUpBlock3D",
        "HunyuanVideoUpBlock3D",
        "HunyuanVideoUpBlock3D",
        "HunyuanVideoUpBlock3D",
    )
    block_out_channels: tuple[int, ...] = (128, 256, 512, 512)
    layers_per_block: int = 2
    act_fn: str = "silu"
    norm_num_groups: int = 32
    scaling_factor: float = 0.476986
    # Overwritten unconditionally in __post_init__ (see below).
    spatial_compression_ratio: int = 8
    temporal_compression_ratio: int = 4
    mid_block_add_attention: bool = True

    def __post_init__(self):
        # Derived as 2 ** (number of block_out_channels entries - 1); this
        # replaces any spatial_compression_ratio passed to the constructor.
        self.spatial_compression_ratio: int = 2 ** (len(self.block_out_channels) - 1)
@dataclass
class LTXAudioVAEConfig(VAEConfig):
    """VAE configuration whose architecture defaults to ``LTXAudioVAEArchConfig``."""

    # default_factory builds a fresh arch config per instance rather than
    # sharing one default object.
    arch_config: LTXAudioVAEArchConfig = field(default_factory=LTXAudioVAEArchConfig)
@dataclass
class LTXVideoVAEConfig(VAEConfig):
    """VAE configuration whose architecture defaults to ``LTXVideoVAEArchConfig``."""

    # default_factory builds a fresh arch config per instance rather than
    # sharing one default object.
    arch_config: LTXVideoVAEArchConfig = field(default_factory=LTXVideoVAEArchConfig)
@dataclass
class QwenImageVAEConfig(VAEConfig):
    """VAE configuration for Qwen-Image; tiling is disabled by default."""

    arch_config: QwenImageVAEArchConfig = field(default_factory=QwenImageVAEArchConfig)

    use_feature_cache: bool = True

    use_tiling: bool = False
    use_temporal_tiling: bool = False
    use_parallel_tiling: bool = False

    def get_vae_scale_factor(self):
        """Return 2 raised to the number of ``temperal_downsample`` entries."""
        downsample_flags = self.arch_config.temperal_downsample
        return 2 ** len(downsample_flags)

    def __post_init__(self):
        # Twice the gap between the minimum tile length and the tile stride.
        frame_gap = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames
        self.blend_num_frames = frame_gap * 2

    def post_init(self):
        """Store the scale factor on the arch config and mirror it into
        ``spatial_compression_ratio``."""
        arch = self.arch_config
        arch.vae_scale_factor = 2 ** len(arch.temperal_downsample)
        arch.spatial_compression_ratio = arch.vae_scale_factor
@dataclass
class WanVAEConfig(VAEConfig):
    """VAE configuration for Wan.

    Tiling flags default to off; feature caching and parallel encode/decode
    default to on.
    """

    arch_config: WanVAEArchConfig = field(default_factory=WanVAEArchConfig)
    use_feature_cache: bool = True

    use_tiling: bool = False
    use_temporal_tiling: bool = False
    use_parallel_tiling: bool = False

    use_parallel_encode: bool = True
    use_parallel_decode: bool = True

    def __post_init__(self):
        # Twice the gap between the minimum tile length and the tile stride.
        min_frames = self.tile_sample_min_num_frames
        stride_frames = self.tile_sample_stride_num_frames
        self.blend_num_frames = (min_frames - stride_frames) * 2
@dataclass
class VocoderConfig(ModelConfig):
    """Model-level vocoder configuration wrapping a ``VocoderArchConfig``."""

    arch_config: VocoderArchConfig = field(default_factory=VocoderArchConfig)

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "VocoderConfig":
        """Build a config from parsed CLI arguments.

        Only dataclass fields that are present on ``args`` with a non-``None``
        value are passed to the constructor; all others keep their defaults.
        """
        overrides = {
            spec.name: found
            for spec in dataclasses.fields(cls)
            if (found := getattr(args, spec.name, None)) is not None
        }
        return cls(**overrides)
@dataclass
class LTXVocoderConfig(VocoderConfig):
    """Vocoder configuration whose architecture defaults to ``LTXVocoderArchConfig``."""

    # default_factory builds a fresh arch config per instance rather than
    # sharing one default object.
    arch_config: LTXVocoderArchConfig = field(default_factory=LTXVocoderArchConfig)
WanT2V480PConfig, + WanT2V720PConfig, +) +from sglang.multimodal_gen.configs.pipeline_configs.zimage import ZImagePipelineConfig + +__all__ = [ + "DiffusersGenericPipelineConfig", + "HeliosDistilledConfig", + "HeliosMidConfig", + "HeliosT2VConfig", + "HunyuanConfig", + "FastHunyuanConfig", + "Hunyuan3D2PipelineConfig", + "FluxPipelineConfig", + "Flux2PipelineConfig", + "Flux2KleinPipelineConfig", + "Flux2FinetunedPipelineConfig", + "PipelineConfig", + "SlidingTileAttnConfig", + "MOVAPipelineConfig", + "WanT2V480PConfig", + "WanI2V480PConfig", + "WanT2V720PConfig", + "WanI2V720PConfig", + "SelfForcingWanT2V480PConfig", + "ZImagePipelineConfig", + "LTX2PipelineConfig", +] diff --git a/sglang/python/sglang/multimodal_gen/configs/pipeline_configs/base.py b/sglang/python/sglang/multimodal_gen/configs/pipeline_configs/base.py new file mode 100644 index 0000000000000000000000000000000000000000..4556354671afeee50e96ba68cfd2e4c5d68db536 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/pipeline_configs/base.py @@ -0,0 +1,856 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +import json +import os +from collections.abc import Callable +from dataclasses import asdict, dataclass, field, fields +from enum import Enum, auto +from typing import Any + +import numpy as np +import PIL +import torch +from einops import rearrange + +from sglang.multimodal_gen.configs.models import ( + DiTConfig, + EncoderConfig, + ModelConfig, + VAEConfig, +) +from sglang.multimodal_gen.configs.models.encoders import BaseEncoderOutput +from sglang.multimodal_gen.configs.models.encoders.t5 import T5Config +from sglang.multimodal_gen.configs.sample.sampling_params import DataType +from sglang.multimodal_gen.configs.utils import update_config_from_args +from sglang.multimodal_gen.runtime.distributed.communication_op import ( + sequence_model_parallel_all_gather, +) +from sglang.multimodal_gen.runtime.distributed.parallel_state 
class ModelTaskType(Enum):
    """Generation task a model is built for, with helpers describing its
    image-input needs and the kind of data it outputs."""

    # TODO: check if I2V/TI2V models can work w/wo text

    I2V = auto()  # Image to Video
    T2V = auto()  # Text to Video
    TI2V = auto()  # Text and Image to Video

    T2I = auto()  # Text to Image
    I2I = auto()  # Image to Image
    TI2I = auto()  # Image to Image or Text-Image to Image
    I2M = auto()  # Image to Mesh

    def is_image_gen(self) -> bool:
        """True for the tasks whose output is an image (T2I, I2I, TI2I)."""
        return self in (ModelTaskType.T2I, ModelTaskType.I2I, ModelTaskType.TI2I)

    def requires_image_input(self) -> bool:
        """True for I2V, I2I and I2M."""
        return self in (ModelTaskType.I2V, ModelTaskType.I2I, ModelTaskType.I2M)

    def accepts_image_input(self) -> bool:
        """True for every task except the text-only ones (T2V, T2I)."""
        return self in (
            ModelTaskType.I2V,
            ModelTaskType.I2I,
            ModelTaskType.TI2I,
            ModelTaskType.TI2V,
            ModelTaskType.I2M,
        )

    def data_type(self) -> DataType:
        """Map the task to its output kind: mesh, image, or (otherwise) video."""
        if self is ModelTaskType.I2M:
            return DataType.MESH
        return DataType.IMAGE if self.is_image_gen() else DataType.VIDEO
def maybe_unpad_latents(latents, batch):
    """Trim tokens along dim 1 that exceed the count implied by
    ``batch.raw_latent_shape``.

    A 3-element raw shape is read as ``[B, S, D]`` and ``S`` is the target
    token count; any other raw shape contributes the product of its last two
    dimensions. ``latents`` is returned unchanged when it is not longer than
    the target.
    """
    shape = batch.raw_latent_shape
    expected_tokens = shape[1] if len(shape) == 3 else shape[-2] * shape[-1]
    if latents.shape[1] <= expected_tokens:
        return latents
    # Extra tokens are padding added for sequence parallelism; drop them.
    return latents[:, :expected_tokens, :]
"bf16" + + # VAE configuration + vae_config: VAEConfig = field(default_factory=VAEConfig) + vae_precision: str = "fp32" + vae_tiling: bool = True + vae_sp: bool = True + + # Image encoder configuration + image_encoder_config: EncoderConfig = field(default_factory=EncoderConfig) + image_encoder_precision: str = "fp32" + + # Text encoder configuration + DEFAULT_TEXT_ENCODER_PRECISIONS = ("fp32",) + text_encoder_configs: tuple[EncoderConfig, ...] = field( + default_factory=lambda: (EncoderConfig(),) + ) + # See PRECISION_TO_TYPE for detailed mapping + text_encoder_precisions: tuple[str, ...] = field(default_factory=lambda: ("fp32",)) + text_encoder_extra_args: list[dict] = field(default_factory=lambda: [{}]) + + # image encoding + image_encoder_extra_args: dict = field(default_factory=lambda: {}) + + def postprocess_image(self, image): + return image.last_hidden_state + + preprocess_text_funcs: tuple[Callable[[str], str], ...] = field( + default_factory=lambda: (preprocess_text,) + ) + + # get prompt_embeds from encoder output + postprocess_text_funcs: tuple[Callable[[BaseEncoderOutput], torch.tensor], ...] 
def postprocess_image_latent(self, latent_condition, batch):
    """Prepend a first-frame conditioning mask to the encoded condition latents.

    The mask marks only the first pixel frame as conditioned. That frame is
    repeated ``temporal_compression_ratio`` times, the frame axis is folded
    into groups of ``temporal_compression_ratio``, and the group members
    become channels, so the mask lines up with the latent frame axis.

    Args:
        latent_condition: encoded condition latents; concatenated with the
            mask along dim 1, so dims 0 and 2-4 must match the mask's.
        batch: object providing ``num_frames``, ``height`` and ``width``.

    Returns:
        ``torch.cat([mask, latent_condition], dim=1)`` on
        ``latent_condition``'s device.
    """
    arch = self.vae_config.arch_config
    spatial_ratio = arch.spatial_compression_ratio
    temporal_ratio = arch.temporal_compression_ratio
    latent_height = batch.height // spatial_ratio
    latent_width = batch.width // spatial_ratio
    # Size the mask from the latents themselves: a hard-coded batch of 1
    # makes the final concatenation fail for batched condition latents.
    batch_size = latent_condition.shape[0]

    # 1 for the first frame, 0 for every later frame.
    mask = torch.ones(batch_size, 1, batch.num_frames, latent_height, latent_width)
    mask[:, :, 1:] = 0
    # Repeat the first frame so it fills one complete temporal group.
    first_frame = torch.repeat_interleave(mask[:, :, 0:1], repeats=temporal_ratio, dim=2)
    mask = torch.cat([first_frame, mask[:, :, 1:, :]], dim=2)
    # Fold frames into (groups, temporal_ratio), then move the group members
    # to the channel axis.
    mask = mask.view(batch_size, -1, temporal_ratio, latent_height, latent_width)
    mask = mask.transpose(1, 2)
    mask = mask.to(latent_condition.device)
    return torch.cat([mask, latent_condition], dim=1)
def shard_latents_for_sp(self, batch, latents):
    """Return this rank's slice of ``latents`` along dim 2, plus a flag
    saying whether sharding happened.

    ``(latents, False)`` is returned untouched when
    ``batch.enable_sequence_shard`` is set with more than one SP rank, or
    when ``latents`` is not 5-D. Otherwise dim 2 is zero-padded up to a
    multiple of the SP world size, split into that many equal chunks, and
    the chunk at this rank's index is returned with ``True``.
    """
    # general logic for video models
    world_size = get_sp_world_size()
    rank = get_sp_parallel_rank()
    if (batch.enable_sequence_shard and world_size > 1) or latents.dim() != 5:
        return latents, False

    num_frames = latents.shape[2]
    # Pad to next multiple of SP degree if needed
    if num_frames > 0 and num_frames % world_size != 0:
        logger.debug(
            "Padding latents to next multiple of SP degree, performance is sub-optimal"
        )
        pad_shape = list(latents.shape)
        pad_shape[2] = world_size - (num_frames % world_size)
        # new_zeros inherits dtype and device from latents.
        latents = torch.cat([latents, latents.new_zeros(pad_shape)], dim=2)

    assert latents.shape[2] % world_size == 0
    chunks = rearrange(
        latents, "b c (n t) h w -> b c n t h w", n=world_size
    ).contiguous()
    return chunks[:, :, rank], True
add_cli_args( + parser: FlexibleArgumentParser, prefix: str = "" + ) -> FlexibleArgumentParser: + prefix_with_dot = f"{prefix}." if (prefix.strip() != "") else "" + + # model_path will be conflicting with the model_path in ServerArgs, + # so we add it separately if prefix is not empty + if prefix_with_dot != "": + parser.add_argument( + f"--{prefix_with_dot}model-path", + type=str, + dest=f"{prefix_with_dot.replace('-', '_')}model_path", + default=PipelineConfig.model_path, + help="Path to the pretrained model", + ) + + parser.add_argument( + f"--{prefix_with_dot}pipeline-config-path", + type=str, + dest=f"{prefix_with_dot.replace('-', '_')}pipeline_config_path", + default=PipelineConfig.pipeline_config_path, + help="Path to the pipeline config", + ) + parser.add_argument( + f"--{prefix_with_dot}embedded-cfg-scale", + type=float, + dest=f"{prefix_with_dot.replace('-', '_')}embedded_cfg_scale", + default=PipelineConfig.embedded_cfg_scale, + help="Embedded CFG scale", + ) + parser.add_argument( + f"--{prefix_with_dot}flow-shift", + type=float, + dest=f"{prefix_with_dot.replace('-', '_')}flow_shift", + default=PipelineConfig.flow_shift, + help="Flow shift parameter", + ) + + # DiT configuration + parser.add_argument( + f"--{prefix_with_dot}dit-precision", + type=str, + dest=f"{prefix_with_dot.replace('-', '_')}dit_precision", + default=PipelineConfig.dit_precision, + choices=["fp32", "fp16", "bf16"], + help="Precision for the DiT model", + ) + + # VAE configuration + parser.add_argument( + f"--{prefix_with_dot}vae-precision", + type=str, + dest=f"{prefix_with_dot.replace('-', '_')}vae_precision", + default=PipelineConfig.vae_precision, + choices=["fp32", "fp16", "bf16"], + help="Precision for VAE", + ) + parser.add_argument( + f"--{prefix_with_dot}vae-tiling", + action=StoreBoolean, + dest=f"{prefix_with_dot.replace('-', '_')}vae_tiling", + default=PipelineConfig.vae_tiling, + help="Enable VAE tiling", + ) + parser.add_argument( + f"--{prefix_with_dot}vae-sp", + 
action=StoreBoolean, + dest=f"{prefix_with_dot.replace('-', '_')}vae_sp", + help="Enable VAE spatial parallelism", + ) + + # Text encoder configuration + parser.add_argument( + f"--{prefix_with_dot}text-encoder-precisions", + nargs="+", + type=str, + dest=f"{prefix_with_dot.replace('-', '_')}text_encoder_precisions", + default=PipelineConfig.DEFAULT_TEXT_ENCODER_PRECISIONS, + choices=["fp32", "fp16", "bf16"], + help="Precision for each text encoder", + ) + + # Image encoder configuration + parser.add_argument( + f"--{prefix_with_dot}image-encoder-precision", + type=str, + dest=f"{prefix_with_dot.replace('-', '_')}image_encoder_precision", + default=PipelineConfig.image_encoder_precision, + choices=["fp32", "fp16", "bf16"], + help="Precision for image encoder", + ) + + # DMD parameters + parser.add_argument( + f"--{prefix_with_dot}dmd-denoising-steps", + type=parse_int_list, + default=PipelineConfig.dmd_denoising_steps, + help="Comma-separated list of denoising steps (e.g., '1000,757,522')", + ) + + # Add VAE configuration arguments + from sglang.multimodal_gen.configs.models.vaes.base import VAEConfig + + VAEConfig.add_cli_args(parser, prefix=f"{prefix_with_dot}vae-config") + + # Add DiT configuration arguments + from sglang.multimodal_gen.configs.models.dits.base import DiTConfig + + DiTConfig.add_cli_args(parser, prefix=f"{prefix_with_dot}dit-config") + + # Add T5 configuration arguments + from sglang.multimodal_gen.configs.models.encoders.t5 import T5Config + + T5Config.add_cli_args(parser, prefix=f"{prefix_with_dot}t5-config") + + return parser + + def update_config_from_dict(self, args: dict[str, Any], prefix: str = "") -> None: + prefix_with_dot = f"{prefix}." 
if (prefix.strip() != "") else "" + update_config_from_args(self, args, prefix, pop_args=True) + update_config_from_args( + self.vae_config, args, f"{prefix_with_dot}vae_config", pop_args=True + ) + update_config_from_args( + self.dit_config, args, f"{prefix_with_dot}dit_config", pop_args=True + ) + for text_encoder_config in self.text_encoder_configs: + if isinstance(text_encoder_config, T5Config): + update_config_from_args( + text_encoder_config, + args, + f"{prefix_with_dot}t5_config", + pop_args=True, + ) + + @classmethod + def from_kwargs( + cls, kwargs: dict[str, Any], config_cli_prefix: str = "" + ) -> "PipelineConfig": + """ + Load PipelineConfig from kwargs Dictionary, as part of the ServerArg initialization process + kwargs: dictionary of kwargs + config_cli_prefix: prefix of CLI arguments for this PipelineConfig instance + """ + from sglang.multimodal_gen.registry import get_model_info + + prefix_with_dot = ( + f"{config_cli_prefix}." if (config_cli_prefix.strip() != "") else "" + ) + model_path: str | None = kwargs.get( + prefix_with_dot + "model_path", None + ) or kwargs.get("model_path") + pipeline_config_or_path: str | PipelineConfig | dict[str, Any] | None = ( + kwargs.get(prefix_with_dot + "pipeline_config", None) + or kwargs.get("pipeline_config") + ) + if model_path is None: + raise ValueError("model_path is required in kwargs") + + # Check if model_path is a safetensors file and pipeline_class_name is specified + pipeline_class_name = kwargs.get( + prefix_with_dot + "pipeline_class_name" + ) or kwargs.get("pipeline_class_name") + is_safetensors_file = os.path.isfile(model_path) and model_path.endswith( + ".safetensors" + ) + + # 1. 
Get the pipeline config class from the registry + from sglang.multimodal_gen.configs.pipeline_configs.flux import ( + Flux2PipelineConfig, + ) + from sglang.multimodal_gen.registry import get_pipeline_config_classes + + # If model_path is a safetensors file and pipeline_class_name is specified, + # try to get PipelineConfig from the registry first + if is_safetensors_file and pipeline_class_name: + config_classes = get_pipeline_config_classes(pipeline_class_name) + if config_classes is not None: + pipeline_config_cls, _ = config_classes + logger.info( + f"Detected safetensors file with {pipeline_class_name}, " + f"using {pipeline_config_cls.__name__} directly without model_index.json" + ) + else: + model_info = get_model_info( + model_path, + backend=kwargs.get("backend"), + model_id=kwargs.get("model_id"), + ) + if model_info is None: + from sglang.multimodal_gen.registry import ( + _PIPELINE_CONFIG_REGISTRY, + _discover_and_register_pipelines, + ) + + _discover_and_register_pipelines() + available_pipelines = list(_PIPELINE_CONFIG_REGISTRY.keys()) + raise ValueError( + f"Could not get model info for '{model_path}'. " + f"If using a safetensors file, please specify a valid pipeline_class_name. " + f"Available pipelines with config classes: {available_pipelines}" + ) + pipeline_config_cls = model_info.pipeline_config_cls + else: + model_info = get_model_info( + model_path, + backend=kwargs.get("backend"), + model_id=kwargs.get("model_id"), + ) + if model_info is None: + raise ValueError( + f"Could not get model info for '{model_path}'. " + f"If using a safetensors file, please specify pipeline_class_name" + ) + # 1.5. 
Adjust pipeline config for fine-tuned VAE if needed + pipeline_config_cls = model_info.pipeline_config_cls + vae_path = kwargs.get(prefix_with_dot + "vae_path") or kwargs.get("vae_path") + if vae_path is None: + component_paths = kwargs.get( + prefix_with_dot + "component_paths" + ) or kwargs.get("component_paths") + if isinstance(component_paths, dict): + vae_path = component_paths.get("vae") + + # Check if this is a Flux2 model with fal/FLUX.2-Tiny-AutoEncoder + if ( + isinstance(pipeline_config_cls, type) + and issubclass(pipeline_config_cls, Flux2PipelineConfig) + and vae_path is not None + and "FLUX.2-Tiny-AutoEncoder" in vae_path + ): + from sglang.multimodal_gen.configs.pipeline_configs.flux_finetuned import ( + Flux2FinetunedPipelineConfig, + ) + + pipeline_config_cls = Flux2FinetunedPipelineConfig + + pipeline_config = pipeline_config_cls() + + # 2. Load PipelineConfig from a json file or a PipelineConfig object if provided + if isinstance(pipeline_config_or_path, str): + pipeline_config.load_from_json(pipeline_config_or_path) + kwargs[prefix_with_dot + "pipeline_config_path"] = pipeline_config_or_path + elif isinstance(pipeline_config_or_path, PipelineConfig): + pipeline_config = pipeline_config_or_path + elif isinstance(pipeline_config_or_path, dict): + pipeline_config.update_pipeline_config(pipeline_config_or_path) + + # 3. Update PipelineConfig from CLI arguments if provided + kwargs[prefix_with_dot + "model_path"] = model_path + pipeline_config.update_config_from_dict(kwargs, config_cli_prefix) + return pipeline_config + + def check_pipeline_config(self) -> None: + if self.vae_sp and not self.vae_tiling: + raise ValueError( + "Currently enabling vae_sp requires enabling vae_tiling, please set --vae-tiling to True." 
+ ) + + if len(self.text_encoder_configs) != len(self.text_encoder_precisions): + raise ValueError( + f"Length of text encoder configs ({len(self.text_encoder_configs)}) must be equal to length of text encoder precisions ({len(self.text_encoder_precisions)})" + ) + + if len(self.text_encoder_configs) != len(self.preprocess_text_funcs): + raise ValueError( + f"Length of text encoder configs ({len(self.text_encoder_configs)}) must be equal to length of text preprocessing functions ({len(self.preprocess_text_funcs)})" + ) + + if len(self.preprocess_text_funcs) != len(self.postprocess_text_funcs): + raise ValueError( + f"Length of text postprocess functions ({len(self.postprocess_text_funcs)}) must be equal to length of text preprocessing functions ({len(self.preprocess_text_funcs)})" + ) + + def dump_to_json(self, file_path: str): + output_dict = shallow_asdict(self) + del_keys = [] + for key, value in output_dict.items(): + if isinstance(value, ModelConfig): + model_dict = asdict(value) + # Model Arch Config should be hidden away from the users + model_dict.pop("arch_config") + output_dict[key] = model_dict + elif isinstance(value, tuple) and all( + isinstance(v, ModelConfig) for v in value + ): + model_dicts = [] + for v in value: + model_dict = asdict(v) + # Model Arch Config should be hidden away from the users + model_dict.pop("arch_config") + model_dicts.append(model_dict) + output_dict[key] = model_dicts + elif isinstance(value, tuple) and all(callable(f) for f in value): + # Skip dumping functions + del_keys.append(key) + + for key in del_keys: + output_dict.pop(key, None) + + with open(file_path, "w") as f: + json.dump(output_dict, f, indent=2) + + def load_from_json(self, file_path: str): + with open(file_path) as f: + input_pipeline_dict = json.load(f) + self.update_pipeline_config(input_pipeline_dict) + + def update_pipeline_config(self, source_pipeline_dict: dict[str, Any]) -> None: + for f in fields(self): + key = f.name + if key in source_pipeline_dict: 
+ current_value = getattr(self, key) + new_value = source_pipeline_dict[key] + + # If it's a nested ModelConfig, update it recursively + if isinstance(current_value, ModelConfig): + current_value.update_model_config(new_value) + elif isinstance(current_value, tuple) and all( + isinstance(v, ModelConfig) for v in current_value + ): + assert len(current_value) == len( + new_value + ), "Users shouldn't delete or add text encoder config objects in your json" + for target_config, source_config in zip( + current_value, new_value, strict=True + ): + target_config.update_model_config(source_config) + else: + setattr(self, key, new_value) + + if hasattr(self, "__post_init__"): + self.__post_init__() + + +@dataclass +class ImagePipelineConfig(PipelineConfig): + """Base config for image generation pipelines with token-like latents [B, S, D].""" + + def _prepare_sigmas(self, sigmas, num_inference_steps): + sigmas = ( + np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + if sigmas is None + else sigmas + ) + return sigmas + + def shard_latents_for_sp(self, batch, latents): + # latents: [B, H * W, C] + sp_world_size, rank_in_sp_group = get_sp_world_size(), get_sp_parallel_rank() + seq_len = latents.shape[1] + + # TODO: reuse code in PipelineConfig::shard_latents_for_sp + # Pad to next multiple of SP degree if needed + if seq_len % sp_world_size != 0: + pad_len = sp_world_size - (seq_len % sp_world_size) + pad = torch.zeros( + (*latents.shape[:1], pad_len, *latents.shape[2:]), + dtype=latents.dtype, + device=latents.device, + ) + latents = torch.cat([latents, pad], dim=1) + + sharded_tensor = rearrange( + latents, "b (n s) d -> b n s d", n=sp_world_size + ).contiguous() + sharded_tensor = sharded_tensor[:, rank_in_sp_group, :, :] + return sharded_tensor, True + + def gather_latents_for_sp(self, latents): + # For image latents [B, S_local, D], gather along sequence dim=1 + latents = sequence_model_parallel_all_gather(latents, dim=1) + return latents + + def 
_unpad_and_unpack_latents(self, latents, batch): + vae_scale_factor = self.vae_config.arch_config.vae_scale_factor + channels = self.dit_config.arch_config.in_channels + batch_size = latents.shape[0] + + height = 2 * (int(batch.height) // (vae_scale_factor * 2)) + width = 2 * (int(batch.width) // (vae_scale_factor * 2)) + + latents = maybe_unpad_latents(latents, batch) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + return latents, batch_size, channels, height, width + + +@dataclass +class SpatialImagePipelineConfig(ImagePipelineConfig): + """Base config for spatial image pipelines (e.g. GLM-Image) with 4D latents (B, C, H', W'). + + Overrides shard_latents_for_sp / gather_latents_for_sp to shard along the height dimension + so that each SP rank gets (B, C, H'_local, W') instead of using the token-style (B, S, C) path. + """ + + def shard_latents_for_sp(self, batch, latents): + # 4D latents (B, C, H', W') -> shard along H' (dim=2); otherwise fall back to base (B, S, C) + sp_world_size = get_sp_world_size() + if sp_world_size <= 1: + return latents, False + if latents.dim() != 4: + return super().shard_latents_for_sp(batch, latents) + + # (B, C, H', W') + _, _, h_lat, w_lat = latents.shape + if h_lat % sp_world_size != 0: + pad_len = sp_world_size - (h_lat % sp_world_size) + pad = torch.zeros( + (latents.shape[0], latents.shape[1], pad_len, latents.shape[3]), + dtype=latents.dtype, + device=latents.device, + ) + latents = torch.cat([latents, pad], dim=2) + h_lat = latents.shape[2] + rank_in_sp_group = get_sp_parallel_rank() + chunk_size = h_lat // sp_world_size + h0 = rank_in_sp_group * chunk_size + h1 = h0 + chunk_size + sharded = latents[:, :, h0:h1, :].contiguous() + return sharded, True + + def gather_latents_for_sp(self, latents): + if get_sp_world_size() <= 1: + return latents + if latents.dim() != 4: + return super().gather_latents_for_sp(latents) + # Gather along dim=2 (H') 
to match shard_latents_for_sp + return sequence_model_parallel_all_gather(latents, dim=2) + + +@dataclass +class SlidingTileAttnConfig(PipelineConfig): + """Configuration for sliding tile attention.""" + + # Override any BaseConfig defaults as needed + # Add sliding tile specific parameters + window_size: int = 16 + stride: int = 8 + + # You can provide custom defaults for inherited fields + height: int = 576 + width: int = 1024 + + # Additional configuration specific to sliding tile attention + pad_to_square: bool = False + use_overlap_optimization: bool = True + + +def parse_int_list(value: str) -> list[int]: + """Parse a comma-separated string of integers into a list.""" + if not value: + return [] + return [int(x.strip()) for x in value.split(",")] diff --git a/sglang/python/sglang/multimodal_gen/configs/pipeline_configs/diffusers_generic.py b/sglang/python/sglang/multimodal_gen/configs/pipeline_configs/diffusers_generic.py new file mode 100644 index 0000000000000000000000000000000000000000..96ab3c4736ffd43c25a399fa3b3dd853bbb31618 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/pipeline_configs/diffusers_generic.py @@ -0,0 +1,80 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Generic pipeline configuration for diffusers backend. + +This module provides a minimal pipeline configuration that works with the diffusers backend. +Since diffusers handles its own model loading and configuration, this config is intentionally minimal. +""" + +from dataclasses import dataclass, field +from typing import Any + +from sglang.multimodal_gen.configs.models import DiTConfig, EncoderConfig, VAEConfig +from sglang.multimodal_gen.configs.pipeline_configs.base import ( + ModelTaskType, + PipelineConfig, +) + + +@dataclass +class DiffusersGenericPipelineConfig(PipelineConfig): + """ + Generic pipeline configuration for diffusers backend. + + This is a minimal configuration since the diffusers backend handles most + configuration internally. 
It provides sensible defaults for the required fields. + """ + + # default to T2I since it's the most common + task_type: ModelTaskType = ModelTaskType.T2I + + dit_precision: str = "bf16" + vae_precision: str = "bf16" + + should_use_guidance: bool = True + embedded_cfg_scale: float = 1.0 + flow_shift: float | None = None + disable_autocast: bool = True # let diffusers handle dtype + + # diffusers handles its own loading + dit_config: DiTConfig = field(default_factory=DiTConfig) + vae_config: VAEConfig = field(default_factory=VAEConfig) + image_encoder_config: EncoderConfig = field(default_factory=EncoderConfig) + text_encoder_configs: tuple[EncoderConfig, ...] = field( + default_factory=lambda: (EncoderConfig(),) + ) + text_encoder_precisions: tuple[str, ...] = field(default_factory=lambda: ("fp16",)) + + # VAE settings + vae_tiling: bool = False # diffusers handles this + vae_slicing: bool = False # slice VAE decode for lower memory usage + vae_sp: bool = False + + # Quantization config for pipeline-level quantization + # See: https://huggingface.co/docs/diffusers/main/en/quantization/overview + # Use PipelineQuantizationConfig for component-level control: + # from diffusers.quantizers import PipelineQuantizationConfig + # quantization_config = PipelineQuantizationConfig( + # quant_backend="bitsandbytes_4bit", + # quant_kwargs={"load_in_4bit": True, "bnb_4bit_compute_dtype": torch.bfloat16}, + # components_to_quantize=["transformer", "text_encoder_2"], + # ) + quantization_config: Any = None + + def check_pipeline_config(self) -> None: + """ + Override to skip most validation since diffusers handles its own config. + """ + pass + + def adjust_size(self, width, height, image): + """ + Pass through - diffusers handles size adjustments. + """ + return width, height + + def adjust_num_frames(self, num_frames): + """ + Pass through - diffusers handles frame count. 
+ """ + return num_frames diff --git a/sglang/python/sglang/multimodal_gen/configs/pipeline_configs/flux.py b/sglang/python/sglang/multimodal_gen/configs/pipeline_configs/flux.py new file mode 100644 index 0000000000000000000000000000000000000000..71d0c0128372c490eac13ba6097272e14014fc15 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/pipeline_configs/flux.py @@ -0,0 +1,696 @@ +import math +from dataclasses import dataclass, field +from typing import Callable, List, Optional + +import PIL +import torch +from diffusers.image_processor import VaeImageProcessor + +from sglang.multimodal_gen.configs.models import DiTConfig, EncoderConfig, VAEConfig +from sglang.multimodal_gen.configs.models.dits.flux import FluxConfig +from sglang.multimodal_gen.configs.models.encoders import ( + BaseEncoderOutput, + CLIPTextConfig, + T5Config, + TextEncoderConfig, +) +from sglang.multimodal_gen.configs.models.encoders.base import TextEncoderArchConfig +from sglang.multimodal_gen.configs.models.encoders.qwen3 import Qwen3TextConfig +from sglang.multimodal_gen.configs.models.encoders.qwen_image import ( + _is_transformer_layer, +) +from sglang.multimodal_gen.configs.models.vaes.flux import Flux2VAEConfig, FluxVAEConfig +from sglang.multimodal_gen.configs.pipeline_configs.base import ( + ImagePipelineConfig, + ModelTaskType, + preprocess_text, + shard_rotary_emb_for_sp, +) +from sglang.multimodal_gen.configs.pipeline_configs.hunyuan import ( + clip_postprocess_text, + clip_preprocess_text, +) +from sglang.multimodal_gen.configs.pipeline_configs.qwen_image import _pack_latents +from sglang.multimodal_gen.runtime.distributed import get_local_torch_device + + +def t5_postprocess_text(outputs: BaseEncoderOutput, _text_inputs) -> torch.Tensor: + return outputs.last_hidden_state + + +@dataclass +class FluxPipelineConfig(ImagePipelineConfig): + """Configuration for the FLUX pipeline.""" + + embedded_cfg_scale: float = 3.5 + + task_type: ModelTaskType = ModelTaskType.T2I + + 
vae_tiling: bool = False + + vae_sp: bool = False + + dit_config: DiTConfig = field(default_factory=FluxConfig) + # VAE + vae_config: VAEConfig = field(default_factory=FluxVAEConfig) + + enable_autocast: bool = False + + # Text encoding stage + text_encoder_configs: tuple[EncoderConfig, ...] = field( + default_factory=lambda: (CLIPTextConfig(), T5Config()) + ) + + text_encoder_precisions: tuple[str, ...] = field( + default_factory=lambda: ("bf16", "bf16") + ) + + preprocess_text_funcs: tuple[Callable[[str], str], ...] = field( + default_factory=lambda: (clip_preprocess_text, preprocess_text), + ) + + postprocess_text_funcs: tuple[Callable[[str], str], ...] = field( + default_factory=lambda: (clip_postprocess_text, t5_postprocess_text) + ) + + text_encoder_extra_args: list[dict] = field( + default_factory=lambda: [ + dict( + max_length=77, + padding="max_length", + truncation=True, + return_overflowing_tokens=False, + return_length=False, + ), + None, + ] + ) + + def prepare_sigmas(self, sigmas, num_inference_steps): + return self._prepare_sigmas(sigmas, num_inference_steps) + + def prepare_latent_shape(self, batch, batch_size, num_frames): + height = 2 * ( + batch.height // (self.vae_config.arch_config.vae_scale_factor * 2) + ) + width = 2 * (batch.width // (self.vae_config.arch_config.vae_scale_factor * 2)) + num_channels_latents = self.dit_config.arch_config.in_channels // 4 + shape = (batch_size, num_channels_latents, height, width) + return shape + + def maybe_pack_latents(self, latents, batch_size, batch): + height = 2 * ( + batch.height // (self.vae_config.arch_config.vae_scale_factor * 2) + ) + width = 2 * (batch.width // (self.vae_config.arch_config.vae_scale_factor * 2)) + num_channels_latents = self.dit_config.arch_config.in_channels // 4 + # pack latents + return _pack_latents(latents, batch_size, num_channels_latents, height, width) + + def get_pos_prompt_embeds(self, batch): + return batch.prompt_embeds[1] + + def get_neg_prompt_embeds(self, batch): + 
return batch.negative_prompt_embeds[1] + + def _prepare_latent_image_ids(self, original_height, original_width, device): + vae_scale_factor = self.vae_config.arch_config.vae_scale_factor + height = int(original_height) // (vae_scale_factor * 2) + width = int(original_width) // (vae_scale_factor * 2) + latent_image_ids = torch.zeros(height, width, 3, device=device) + latent_image_ids[..., 1] = ( + latent_image_ids[..., 1] + torch.arange(height, device=device)[:, None] + ) + latent_image_ids[..., 2] = ( + latent_image_ids[..., 2] + torch.arange(width, device=device)[None, :] + ) + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = ( + latent_image_ids.shape + ) + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids + + def get_freqs_cis(self, prompt_embeds, width, height, device, rotary_emb, batch): + txt_ids = torch.zeros(prompt_embeds.shape[1], 3, device=device) + img_ids = self._prepare_latent_image_ids( + original_height=height, + original_width=width, + device=device, + ) + + # NOTE(mick): prepare it here, to avoid unnecessary computations + img_cos, img_sin = rotary_emb.forward(img_ids) + img_cos = shard_rotary_emb_for_sp(img_cos) + img_sin = shard_rotary_emb_for_sp(img_sin) + + txt_cos, txt_sin = rotary_emb.forward(txt_ids) + + cos = torch.cat([txt_cos, img_cos], dim=0).to(device=device) + sin = torch.cat([txt_sin, img_sin], dim=0).to(device=device) + return cos, sin + + def post_denoising_loop(self, latents, batch): + # unpack latents for flux + ( + latents, + batch_size, + channels, + height, + width, + ) = self._unpad_and_unpack_latents(latents, batch) + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + return latents + + def prepare_pos_cond_kwargs(self, batch, device, rotary_emb, dtype): + return { + "freqs_cis": self.get_freqs_cis( + batch.prompt_embeds[1], + batch.width, + batch.height, + device, + 
rotary_emb, + batch, + ), + "pooled_projections": ( + batch.pooled_embeds[0] if batch.pooled_embeds else None + ), + } + + def prepare_neg_cond_kwargs(self, batch, device, rotary_emb, dtype): + return { + "freqs_cis": self.get_freqs_cis( + batch.negative_prompt_embeds[1], + batch.width, + batch.height, + device, + rotary_emb, + batch, + ), + "pooled_projections": ( + batch.neg_pooled_embeds[0] if batch.neg_pooled_embeds else None + ), + } + + +def _prepare_latent_ids( + latents: torch.Tensor, # (B, C, H, W) +): + r""" + Generates 4D position coordinates (T, H, W, L) for latent tensors. + + Args: + latents (torch.Tensor): + Latent tensor of shape (B, C, H, W) + + Returns: + torch.Tensor: + Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0, + H=[0..H-1], W=[0..W-1], L=0 + """ + + batch_size, _, height, width = latents.shape + + t = torch.arange(1) # [0] - time dimension + h = torch.arange(height) + w = torch.arange(width) + layer = torch.arange(1) # [0] - layer dimension + + # Create position IDs: (H*W, 4) + latent_ids = torch.cartesian_prod(t, h, w, layer) + + # Expand to batch: (B, H*W, 4) + latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1) + return latent_ids + + +def _unpack_latents_with_ids( + x: torch.Tensor, x_ids: torch.Tensor +) -> list[torch.Tensor]: + """ + using position ids to scatter tokens into place + """ + x_list = [] + x_ids = x_ids.to(device=x.device) + for data, pos in zip(x, x_ids): + _, ch = data.shape # noqa: F841 + h_ids = pos[:, 1].to(torch.int64) + w_ids = pos[:, 2].to(torch.int64) + + h = torch.max(h_ids) + 1 + w = torch.max(w_ids) + 1 + + flat_ids = h_ids * w + w_ids + + out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype) + out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data) + + # reshape from (H * W, C) to (H, W, C) and permute to (C, H, W) + + out = out.view(h, w, ch).permute(2, 0, 1) + x_list.append(out) + + return torch.stack(x_list, dim=0) + + +def 
_patchify_latents(latents): + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.view( + batch_size, num_channels_latents, height // 2, 2, width // 2, 2 + ) + latents = latents.permute(0, 1, 3, 5, 2, 4) + latents = latents.reshape( + batch_size, num_channels_latents * 4, height // 2, width // 2 + ) + return latents + + +def _unpatchify_latents(latents): + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.reshape( + batch_size, num_channels_latents // (2 * 2), 2, 2, height, width + ) + latents = latents.permute(0, 1, 4, 2, 5, 3) + latents = latents.reshape( + batch_size, num_channels_latents // (2 * 2), height * 2, width * 2 + ) + return latents + + +def _prepare_text_ids( + x: torch.Tensor, # (B, L, D) or (L, D) + t_coord: Optional[torch.Tensor] = None, +): + B, L, _ = x.shape + out_ids = [] + + for i in range(B): + t = torch.arange(1) if t_coord is None else t_coord[i] + h = torch.arange(1) + w = torch.arange(1) + layer = torch.arange(L) + + coords = torch.cartesian_prod(t, h, w, layer) + out_ids.append(coords) + + return torch.stack(out_ids) + + +def _prepare_image_ids( + image_latents: List[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...] + scale: int = 10, +): + if not isinstance(image_latents, list): + raise ValueError( + f"Expected `image_latents` to be a list, got {type(image_latents)}." 
+ ) + + # create time offset for each reference image + t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))] + t_coords = [t.view(-1) for t in t_coords] + + image_latent_ids = [] + for x, t in zip(image_latents, t_coords): + x = x.squeeze(0) + _, height, width = x.shape + + x_ids = torch.cartesian_prod( + t, torch.arange(height), torch.arange(width), torch.arange(1) + ) + image_latent_ids.append(x_ids) + + image_latent_ids = torch.cat(image_latent_ids, dim=0) + image_latent_ids = image_latent_ids.unsqueeze(0) + + return image_latent_ids + + +def flux2_postprocess_text(outputs: BaseEncoderOutput, _text_inputs) -> torch.Tensor: + hidden_states_layers: list[int] = [10, 20, 30] + + out = torch.stack([outputs.hidden_states[k] for k in hidden_states_layers], dim=1) + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape( + batch_size, seq_len, num_channels * hidden_dim + ) + + return prompt_embeds + + +def flux2_klein_postprocess_text( + outputs: BaseEncoderOutput, _text_inputs +) -> torch.Tensor: + hidden_states_layers: list[int] = [9, 18, 27] + + out = torch.stack([outputs.hidden_states[k] for k in hidden_states_layers], dim=1) + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape( + batch_size, seq_len, num_channels * hidden_dim + ) + + return prompt_embeds + + +@dataclass +class Flux2MistralTextArchConfig(TextEncoderArchConfig): + stacked_params_mapping: list[tuple[str, str, str]] = field( + default_factory=lambda: [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + ) + _fsdp_shard_conditions: list = field( + default_factory=lambda: [_is_transformer_layer] + ) + + def __post_init__(self): + self.tokenizer_kwargs = { + "padding": "max_length", + "truncation": True, + "max_length": 512, + "add_special_tokens": True, + "return_attention_mask": True, + 
"return_tensors": "pt", + } + + +@dataclass +class Flux2MistralTextConfig(TextEncoderConfig): + arch_config: TextEncoderArchConfig = field( + default_factory=Flux2MistralTextArchConfig + ) + + +def format_text_input(prompts: List[str], system_message: str = None): + # Remove [IMG] tokens from prompts to avoid Pixtral validation issues + # when truncation is enabled. The processor counts [IMG] tokens and fails + # if the count changes after truncation. + cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts] + + return [ + [ + { + "role": "system", + "content": [{"type": "text", "text": system_message}], + }, + {"role": "user", "content": [{"type": "text", "text": prompt}]}, + ] + for prompt in cleaned_txt + ] + + +def flux_2_preprocess_text(prompt: str): + system_message = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation." + return format_text_input([prompt], system_message=system_message) + + +# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents +def flux2_pack_latents(latents): + batch_size, num_channels, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1) + + return latents + + +@dataclass +class Flux2PipelineConfig(FluxPipelineConfig): + embedded_cfg_scale: float = 4.0 + + task_type: ModelTaskType = ModelTaskType.TI2I + + text_encoder_precisions: tuple[str, ...] = field(default_factory=lambda: ("bf16",)) + + text_encoder_configs: tuple[EncoderConfig, ...] = field( + default_factory=lambda: (Flux2MistralTextConfig(),) + ) + preprocess_text_funcs: tuple[Callable[[str], str], ...] = field( + default_factory=lambda: (flux_2_preprocess_text,), + ) + + postprocess_text_funcs: tuple[Callable[[str], str], ...] 
= field( + default_factory=lambda: (flux2_postprocess_text,) + ) + vae_config: VAEConfig = field(default_factory=Flux2VAEConfig) + + def tokenize_prompt(self, prompts: list[str], tokenizer, tok_kwargs) -> dict: + # flatten to 1-d list + prompts = [p for prompt in prompts for p in prompt] + inputs = tokenizer.apply_chat_template( + prompts, + add_generation_prompt=False, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding="max_length", + truncation=True, + # 2048 from official github repo, 512 from diffusers + max_length=512, + ) + + return inputs + + def prepare_latent_shape(self, batch, batch_size, num_frames): + height = 2 * ( + batch.height // (self.vae_config.arch_config.vae_scale_factor * 2) + ) + width = 2 * (batch.width // (self.vae_config.arch_config.vae_scale_factor * 2)) + num_channels_latents = self.dit_config.arch_config.in_channels + shape = (batch_size, num_channels_latents, height // 2, width // 2) + return shape + + def get_pos_prompt_embeds(self, batch): + return batch.prompt_embeds[0] + + def get_neg_prompt_embeds(self, batch): + return batch.negative_prompt_embeds[0] + + def calculate_condition_image_size( + self, image, width, height + ) -> Optional[tuple[int, int]]: + vae_scale_factor = self.vae_config.arch_config.vae_scale_factor + multiple_of = vae_scale_factor * 2 + + target_area: int = 1024 * 1024 + if width is not None and height is not None: + new_width, new_height = width, height + if width * height > target_area: + scale = math.sqrt(target_area / (width * height)) + new_width = int(width * scale) + new_height = int(height * scale) + + # Flux requires multiples of (VAE scale 8 * Patch size 2) + new_width = (new_width // multiple_of) * multiple_of + new_height = (new_height // multiple_of) * multiple_of + + if new_width != width or new_height != height: + return new_width, new_height + + return None + + def preprocess_condition_image( + self, image, target_width, target_height, vae_image_processor: VaeImageProcessor + ): 
def postprocess_image_latent(self, latent_condition, batch):
    """Pack a condition-image latent and tile it across the batch.

    The latent is packed with ``self.maybe_pack_latents``, a leading
    size-1 dimension is dropped and re-added, and the result is repeated
    ``batch.batch_size`` times along dim 0.

    Shape notes carried over from the original comments (not checked by
    this code): latent (1, 128, 32, 32) -> packed (1, 1024, 128).
    """
    copies = batch.batch_size
    tokens = self.maybe_pack_latents(latent_condition, None, batch)
    # squeeze(0) only removes dim 0 when it has size 1; unsqueeze(0) then
    # restores a leading batch axis before tiling.
    return tokens.squeeze(0).unsqueeze(0).repeat(copies, 1, 1)
shard_rotary_emb_for_sp(cond_sin) + img_cos = torch.cat([img_cos, cond_cos], dim=0) + img_sin = torch.cat([img_sin, cond_sin], dim=0) + + txt_cos, txt_sin = rotary_emb.forward(txt_ids) + + cos = torch.cat([txt_cos, img_cos], dim=0).to(device=device) + sin = torch.cat([txt_sin, img_sin], dim=0).to(device=device) + return cos, sin + + def prepare_pos_cond_kwargs(self, batch, device, rotary_emb, dtype): + return { + "freqs_cis": self.get_freqs_cis( + batch.prompt_embeds[0], + batch.width, + batch.height, + device, + rotary_emb, + batch, + ) + } + + def prepare_neg_cond_kwargs(self, batch, device, rotary_emb, dtype): + return {} + + def maybe_pack_latents(self, latents, batch_size, batch): + return flux2_pack_latents(latents) + + def maybe_prepare_latent_ids(self, latents): + return _prepare_latent_ids(latents) + + def postprocess_vae_encode(self, image_latents, vae): + # patchify + image_latents = _patchify_latents(image_latents) + return image_latents + + def _check_vae_has_bn(self, vae): + """Check if VAE has bn attribute (cached check to avoid repeated hasattr calls).""" + if not hasattr(self, "_vae_has_bn_cache"): + self._vae_has_bn_cache = hasattr(vae, "bn") and vae.bn is not None + return self._vae_has_bn_cache + + def preprocess_decoding(self, latents, server_args=None, vae=None): + """Preprocess latents before decoding. + + Dynamically adapts based on VAE type: + - Standard Flux2 VAE (has bn): needs unpatchify (128 channels -> 32 channels) + - Distilled VAE (no bn): keeps patchified latents (128 channels) + """ + if vae is not None and self._check_vae_has_bn(vae): + return _unpatchify_latents(latents) + return latents + + def get_decode_scale_and_shift(self, device, dtype, vae): + """Get scale and shift for decoding. 
def slice_noise_pred(self, noise, latents):
    """Trim ``noise`` along dim 1 to the length of ``latents`` along dim 1.

    Per the original comment, this removes the noise predicted over the
    input image tokens.
    """
    keep = latents.size(1)
    return noise[:, :keep]
def tokenize_prompt(self, prompts: list[str], tokenizer, tok_kwargs) -> dict:
    """Render each prompt through the chat template, then tokenize the batch.

    Args:
        prompts: Prompt strings. If the first element is a list, the input
            is flattened one level first.
        tokenizer: Object providing ``apply_chat_template`` and ``__call__``.
        tok_kwargs: Optional tokenizer keyword arguments. ``max_length``,
            ``padding``, ``truncation`` and ``return_tensors`` default to
            512, "max_length", True and "pt"; any other keys are forwarded
            unchanged. The caller's dict is not mutated.

    Returns:
        Whatever ``tokenizer(...)`` returns for the rendered texts.
    """
    if prompts and isinstance(prompts[0], list):
        flattened = []
        for group in prompts:
            flattened.extend(group)
        prompts = flattened

    rendered = []
    for text in prompts:
        conversation = [{"role": "user", "content": text}]
        template_args = {"tokenize": False, "add_generation_prompt": True}
        try:
            chat_text = tokenizer.apply_chat_template(
                conversation, **template_args, enable_thinking=False
            )
        except TypeError:
            # Retry without `enable_thinking` when the first call raised
            # TypeError (e.g. the keyword is not accepted).
            chat_text = tokenizer.apply_chat_template(conversation, **template_args)
        rendered.append(chat_text)

    options = dict(tok_kwargs or {})
    call_args = {
        "padding": options.pop("padding", "max_length"),
        "truncation": options.pop("truncation", True),
        "max_length": options.pop("max_length", 512),
        "return_tensors": options.pop("return_tensors", "pt"),
    }
    return tokenizer(rendered, **call_args, **options)
def preprocess_decoding(
    self, latents: torch.Tensor, server_args=None, vae=None
) -> torch.Tensor:
    """Prepare latents for VAE decoding.

    5D input ``(batch, channels, frames, height, width)`` is reduced to 4D:
    a single frame is squeezed out; multiple frames are folded into the
    batch axis, giving ``(batch * frames, channels, height, width)``.
    The result is then unpatchified only when a ``vae`` is supplied and
    ``self._check_vae_has_bn(vae)`` is true.

    Args:
        latents: 4D or 5D latent tensor.
        server_args: Accepted for interface compatibility; unused here.
        vae: VAE instance used to decide whether to unpatchify.

    Returns:
        The latents to hand to the VAE decoder.
    """
    if latents.ndim == 5:
        b, c, t, h, w = latents.shape
        if t != 1:
            # Move frames next to batch, then merge the two axes.
            frame_major = latents.permute(0, 2, 1, 3, 4).contiguous()
            latents = frame_major.view(b * t, c, h, w)
        else:
            latents = latents.squeeze(2)

    needs_unpatchify = vae is not None and self._check_vae_has_bn(vae)
    return _unpatchify_latents(latents) if needs_unpatchify else latents
@dataclass
class GlmImagePipelineConfig(SpatialImagePipelineConfig):
    """Configuration for the GlmImage pipeline.

    Overrides precision, task type, and the text-encoder / DiT / VAE configs,
    and supplies the conditioning kwargs and decode statistics used by the
    pipeline stages.
    """

    vae_precision: str = "bf16"

    should_use_guidance: bool = False
    task_type: ModelTaskType = ModelTaskType.T2I

    vae_tiling: bool = False

    vae_sp: bool = False

    # GLM-Image uses T5 text encoder; base default is EncoderConfig() which lacks
    # parallel_folding and causes AttributeError + fallback to native T5 with missing weights.
    # (This field was previously declared twice with the same default; it is
    # declared once here, at its original position in the field order.)
    text_encoder_configs: tuple[EncoderConfig, ...] = field(
        default_factory=lambda: (T5Config(),)
    )

    dit_config: DiTConfig = field(default_factory=GlmImageDitConfig)
    # VAE
    vae_config: VAEConfig = field(default_factory=GlmImageVAEConfig)

    enable_autocast: bool = False

    def __post_init__(self):
        """Derive the VAE scale factor and build the image processor from it."""
        self.vae_scale_factor = self.vae_config.get_vae_scale_factor()
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    def get_freqs_cis(self, batch, device, rotary_emb, dtype):
        """Compute rotary embeddings for the latent grid of ``batch``.

        An uninitialised ``(1, 1, H, W)`` tensor at latent resolution
        (image size // vae_scale_factor) is passed to ``rotary_emb``.
        NOTE(review): presumably only its shape/device/dtype are used -
        confirm against the rotary embedding implementation.
        """
        height = batch.height // self.vae_scale_factor
        width = batch.width // self.vae_scale_factor
        hidden_states = torch.empty(1, 1, height, width, device=device, dtype=dtype)
        freqs_cis = rotary_emb(hidden_states)
        return freqs_cis

    def prepare_pos_cond_kwargs(self, batch, device, rotary_emb, dtype):
        """Build kwargs for the conditional pass (KV caches in "read" mode)."""
        return {
            "prior_token_id": batch.prior_token_id,
            "prior_token_drop": batch.prior_token_drop_cond,
            "crop_coords": batch.crop_coords,
            "target_size": batch.target_size,
            "kv_caches": batch.kv_caches,
            "kv_caches_mode": "read",
            "freqs_cis": self.get_freqs_cis(batch, device, rotary_emb, dtype),
        }

    def prepare_neg_cond_kwargs(self, batch, device, rotary_emb, dtype):
        """Build kwargs for the unconditional pass (KV caches in "skip" mode)."""
        return {
            "prior_token_id": batch.prior_token_id,
            "prior_token_drop": batch.prior_token_drop_uncond,
            "crop_coords": batch.crop_coords,
            "target_size": batch.target_size,
            "kv_caches": batch.kv_caches,
            "kv_caches_mode": "skip",
            "freqs_cis": self.get_freqs_cis(batch, device, rotary_emb, dtype),
        }

    def get_decode_scale_and_shift(self, device, dtype, vae):
        """Return ``(1 / latents_std, latents_mean)`` from the VAE config.

        Both tensors are reshaped to ``(1, latent_channels, 1, 1)`` and moved
        to ``device`` / ``dtype``. ``vae`` is not used by this implementation.
        """
        latents_mean = (
            torch.tensor(self.vae_config.latents_mean)
            .view(1, self.vae_config.latent_channels, 1, 1)
            .to(device, dtype)
        )
        latents_std = (
            torch.tensor(self.vae_config.latents_std)
            .view(1, self.vae_config.latent_channels, 1, 1)
            .to(device, dtype)
        )
        return 1.0 / latents_std, latents_mean

    def post_denoising_loop(self, latents, batch):
        """Clear ``batch.kv_caches`` if present and return latents as bfloat16."""
        if getattr(batch, "kv_caches", None) is not None:
            batch.kv_caches.clear()
        return latents.bfloat16()

    def post_decoding(self, frames, server_args):
        """Run the image processor's ``postprocess`` with ``output_type="latent"``."""
        return self.image_processor.postprocess(frames, output_type="latent")
def umt5_postprocess_text(outputs: BaseEncoderOutput, _text_inputs) -> torch.Tensor:
    """Zero-pad UMT5 hidden states to HELIOS_MAX_SEQUENCE_LENGTH tokens.

    For each sequence, the first ``k`` hidden-state rows are kept, where
    ``k`` is the number of positive entries in its attention-mask row; the
    remainder up to HELIOS_MAX_SEQUENCE_LENGTH is filled with zeros.

    Args:
        outputs: Encoder output providing ``attention_mask`` and
            ``last_hidden_state``.
        _text_inputs: Unused.

    Returns:
        The padded per-sequence embeddings stacked along dim 0.
    """
    target_len = HELIOS_MAX_SEQUENCE_LENGTH
    attention: torch.Tensor = outputs.attention_mask
    states: torch.Tensor = outputs.last_hidden_state
    valid_counts = attention.gt(0).sum(dim=1).long()
    assert torch.isnan(states).sum() == 0

    padded_rows = []
    for row, count in zip(states, valid_counts, strict=True):
        kept = row[:count]
        filler = kept.new_zeros(target_len - kept.size(0), kept.size(1))
        padded_rows.append(torch.cat([kept, filler]))
    return torch.stack(padded_rows, dim=0)
"""Configuration for the Helios T2V pipeline.""" + + task_type: ModelTaskType = ModelTaskType.T2V + + # DiT + dit_config: DiTConfig = field(default_factory=HeliosConfig) + + # VAE (same as Wan) + vae_config: VAEConfig = field(default_factory=WanVAEConfig) + vae_tiling: bool = False + vae_sp: bool = False + + # Denoising stage + flow_shift: float | None = 1.0 + + # Text encoding stage (UMT5 is T5-compatible) + text_encoder_configs: tuple[EncoderConfig, ...] = field( + default_factory=lambda: ( + T5Config(arch_config=T5ArchConfig(text_len=HELIOS_MAX_SEQUENCE_LENGTH)), + ) + ) + postprocess_text_funcs: tuple[Callable[[BaseEncoderOutput], torch.Tensor], ...] = ( + field(default_factory=lambda: (umt5_postprocess_text,)) + ) + + # Precision for each component + precision: str = "bf16" + vae_precision: str = "fp32" + text_encoder_precisions: tuple[str, ...] = field(default_factory=lambda: ("fp32",)) + + # Helios-specific chunked denoising params + num_latent_frames_per_chunk: int = 9 + history_sizes: list[int] = field(default_factory=lambda: [16, 2, 1]) + is_cfg_zero_star: bool = False + zero_steps: int = 1 + keep_first_frame: bool = True + + # Stage 2 (Pyramid SR) & Stage 3 (DMD) params + is_enable_stage2: bool = False + pyramid_num_stages: int = 3 + pyramid_num_inference_steps_list: list[int] = field( + default_factory=lambda: [10, 10, 10] + ) + is_distilled: bool = False + is_amplify_first_chunk: bool = False + scheduler_type: str = "unipc" + gamma: float = 1 / 3 + + def __post_init__(self): + self.vae_config.load_encoder = False + self.vae_config.load_decoder = True + + +@dataclass +class HeliosMidConfig(HeliosT2VConfig): + """Configuration for Helios-Mid (Stage 1 + Stage 2 pyramid SR).""" + + is_enable_stage2: bool = True + is_cfg_zero_star: bool = True + pyramid_num_inference_steps_list: list[int] = field( + default_factory=lambda: [20, 20, 20] + ) + + +@dataclass +class HeliosDistilledConfig(HeliosT2VConfig): + """Configuration for Helios-Distilled (Stage 1 + Stage 
2 + Stage 3 DMD).""" + + is_enable_stage2: bool = True + is_distilled: bool = True + is_amplify_first_chunk: bool = True + scheduler_type: str = "dmd" + pyramid_num_inference_steps_list: list[int] = field( + default_factory=lambda: [10, 10, 10] + ) diff --git a/sglang/python/sglang/multimodal_gen/configs/pipeline_configs/hunyuan.py b/sglang/python/sglang/multimodal_gen/configs/pipeline_configs/hunyuan.py new file mode 100644 index 0000000000000000000000000000000000000000..d45dfadb2582fb553d460b9836be9e093cbf2144 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/pipeline_configs/hunyuan.py @@ -0,0 +1,114 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import TypedDict + +import torch + +from sglang.multimodal_gen.configs.models import DiTConfig, EncoderConfig, VAEConfig +from sglang.multimodal_gen.configs.models.dits import HunyuanVideoConfig +from sglang.multimodal_gen.configs.models.encoders import ( + BaseEncoderOutput, + CLIPTextConfig, + LlamaConfig, +) +from sglang.multimodal_gen.configs.models.vaes import HunyuanVAEConfig +from sglang.multimodal_gen.configs.pipeline_configs.base import ( + ModelTaskType, + PipelineConfig, +) + +PROMPT_TEMPLATE_ENCODE_VIDEO = ( + "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." + "4. background environment, light, style and atmosphere." + "5. 
def llama_postprocess_text(outputs: BaseEncoderOutput, _text_inputs) -> torch.Tensor:
    """Select the prompt embedding from the Llama encoder outputs.

    Takes the hidden state two layers before the last one
    (``hidden_states[-3]``) and drops the first ``crop_start`` positions
    along dim 1, where ``crop_start`` comes from ``prompt_template_video``.
    NOTE(review): presumably ``crop_start`` is the token length of the
    template prefix - confirm against the tokenizer.

    Args:
        outputs: Encoder output; ``hidden_states`` must not be None.
        _text_inputs: Unused.

    Returns:
        The cropped hidden-state tensor.
    """
    hidden_state_skip_layer = 2
    assert outputs.hidden_states is not None
    hidden_states: tuple[torch.Tensor, ...] = outputs.hidden_states
    # Annotated with the `torch.Tensor` type; `torch.tensor` is a factory
    # function, not a type.
    last_hidden_state: torch.Tensor = hidden_states[-(hidden_state_skip_layer + 1)]
    crop_start = prompt_template_video.get("crop_start", -1)
    last_hidden_state = last_hidden_state[:, crop_start:]
    return last_hidden_state
= ( + field(default_factory=lambda: (llama_postprocess_text, clip_postprocess_text)) + ) + + # Precision for each component + dit_precision: str = "bf16" + vae_precision: str = "fp16" + text_encoder_precisions: tuple[str, ...] = field( + default_factory=lambda: ("fp16", "fp16") + ) + + def __post_init__(self): + self.vae_config.load_encoder = False + self.vae_config.load_decoder = True + + +@dataclass +class FastHunyuanConfig(HunyuanConfig): + """Configuration specifically optimized for FastHunyuan weights.""" + + # Override HunyuanConfig defaults + flow_shift: int = 17 + + # No need to re-specify guidance_scale or embedded_cfg_scale as they + # already have the desired values from HunyuanConfig diff --git a/sglang/python/sglang/multimodal_gen/configs/pipeline_configs/hunyuan3d.py b/sglang/python/sglang/multimodal_gen/configs/pipeline_configs/hunyuan3d.py new file mode 100644 index 0000000000000000000000000000000000000000..07ee0a475b5cefab5fe17be8bd373a8e7aa89ef6 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/pipeline_configs/hunyuan3d.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field +from typing import Optional + +from sglang.multimodal_gen.configs.models import DiTConfig, VAEConfig +from sglang.multimodal_gen.configs.models.dits.hunyuan3d import Hunyuan3DDiTConfig +from sglang.multimodal_gen.configs.models.vaes.hunyuan3d import Hunyuan3DVAEConfig +from sglang.multimodal_gen.configs.pipeline_configs.base import ( + ModelTaskType, + PipelineConfig, +) + + +@dataclass +class Hunyuan3D2PipelineConfig(PipelineConfig): + """Pipeline configuration for Hunyuan3D image-to-mesh generation.""" + + task_type: ModelTaskType = ModelTaskType.I2M + + # Subfolder paths + shape_subfolder: str = "hunyuan3d-dit-v2-0" + paint_subfolder: str = "hunyuan3d-paint-v2-0" + delight_subfolder: str = "hunyuan3d-delight-v2-0" + + # DiT configuration + dit_config: DiTConfig = field(default_factory=Hunyuan3DDiTConfig) + 
def prepare_latent_shape(self, batch, batch_size, num_frames):
    """Return ``(batch_size, *latent_shape)`` using the VAE arch config.

    ``batch`` and ``num_frames`` are not used by this implementation.
    """
    per_sample = tuple(self.vae_config.arch_config.latent_shape)
    return (batch_size,) + per_sample
def pack_text_embeds(
    text_hidden_states: torch.Tensor,
    sequence_lengths: torch.Tensor,
    padding_side: str = "left",
    scale_factor: int = 8,
    eps: float = 1e-6,
) -> torch.Tensor:
    """
    Normalize per-layer text-encoder hidden states over valid tokens and flatten them.

    For every (batch, layer) pair the masked mean is subtracted, the result is divided
    by the masked ``max - min`` range (plus ``eps``) and multiplied by ``scale_factor``.
    Padded positions are zeroed in the output.

    Args:
        text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`):
            Per-layer hidden states from a text encoder.
        sequence_lengths (`torch.Tensor` of shape `(batch_size,)`):
            The number of valid (non-padded) tokens for each batch instance.
        padding_side (`str`, *optional*, defaults to `"left"`):
            Whether the tokenizer pads on the `"left"` or the `"right"`.
        scale_factor (`int`, *optional*, defaults to `8`):
            Factor the normalized hidden states are multiplied by.
        eps (`float`, *optional*, defaults to `1e-6`):
            Small positive value added to denominators for numerical stability.

    Returns:
        `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`, in the
        input dtype and on the input device.

    Raises:
        ValueError: If `padding_side` is neither `"left"` nor `"right"`.
    """
    bsz, n_tokens, width, depth = text_hidden_states.shape
    in_dtype = text_hidden_states.dtype

    # Boolean validity mask over token positions, one row per batch entry.
    positions = torch.arange(n_tokens, device=text_hidden_states.device).unsqueeze(0)
    if padding_side == "right":
        valid = positions < sequence_lengths[:, None]
    elif padding_side == "left":
        first_valid = n_tokens - sequence_lengths[:, None]
        valid = positions >= first_valid
    else:
        raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
    valid = valid[:, :, None, None]  # [batch_size, seq_len, 1, 1]

    # Statistics are reduced over (seq_len, hidden_dim), i.e. per batch and per layer.
    reduce_dims = (1, 2)
    masked_sum = text_hidden_states.masked_fill(~valid, 0.0).sum(
        dim=reduce_dims, keepdim=True
    )
    valid_count = (sequence_lengths * width).view(bsz, 1, 1, 1)
    mean = masked_sum / (valid_count + eps)

    lowest = text_hidden_states.masked_fill(~valid, float("inf")).amin(
        dim=reduce_dims, keepdim=True
    )
    highest = text_hidden_states.masked_fill(~valid, float("-inf")).amax(
        dim=reduce_dims, keepdim=True
    )

    scaled = (text_hidden_states - mean) / (highest - lowest + eps)
    scaled = scaled * scale_factor

    # Merge (hidden_dim, num_layers) and blank out the padded tokens again.
    packed = scaled.flatten(2)
    valid_flat = valid.squeeze(-1).expand(-1, -1, width * depth)
    packed = packed.masked_fill(~valid_flat, 0.0)
    return packed.to(dtype=in_dtype)
def prepare_audio_latent_shape(self, batch, batch_size, num_frames):
    """Return the audio latent shape ``(batch_size, length, channels * mel_bins)``.

    Adapted from the diffusers pipeline's ``prepare_audio_latents``.

    ``length`` is the clip duration (``num_frames / batch.fps``) times the
    latent rate ``sample_rate / mel_hop_length / temporal_compression``,
    rounded to the nearest integer. The last axis is ``latent_channels``
    times ``mel_bins // mel_compression_ratio``.
    """
    seconds = num_frames / batch.fps
    arch = self.audio_vae_config.arch_config

    latent_rate = (
        float(arch.sample_rate)
        / float(arch.mel_hop_length)
        / float(self.audio_vae_temporal_compression_ratio)
    )
    length = round(seconds * latent_rate)

    compressed_mel_bins = arch.mel_bins // self.audio_vae_mel_compression_ratio
    feature_dim = arch.latent_channels * compressed_mel_bins

    return (batch_size, length, feature_dim)
def prepare_sigmas(self, sigmas, num_inference_steps):
    """Return the sigma schedule, building a default one when none is given.

    Args:
        sigmas: Caller-provided schedule, or ``None`` to request the default.
        num_inference_steps: Number of steps for the default schedule.

    Returns:
        ``sigmas`` unchanged when it is not ``None``; otherwise a list of
        ``steps`` values ``1.0 - i / steps`` for ``i`` in ``range(steps)``.

    Raises:
        ValueError: If a default schedule is requested with a non-positive
            step count.
    """
    if sigmas is not None:
        return sigmas

    steps = int(num_inference_steps)
    if steps <= 0:
        raise ValueError(f"num_inference_steps must be positive, got {steps}")

    schedule = []
    for i in range(steps):
        schedule.append(1.0 - i / steps)
    return schedule
+ # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // self.patch_size_t + post_patch_height = height // self.patch_size + post_patch_width = width // self.patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + self.patch_size_t, + post_patch_height, + self.patch_size, + post_patch_width, + self.patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + def _infer_video_latent_frames_and_tokens_per_frame( + self, batch, seq_len: int + ) -> tuple[int, int]: + """Infer latent-frame count and tokens-per-frame for packed token latents [B, S, D]. + + Notes: + - This assumes `patch_size_t == 1` (no temporal patching). + - Tokens are ordered as (frame, height, width) after packing. + """ + if int(self.patch_size_t) != 1: + raise ValueError( + "LTX-2 SP time-sharding for packed token latents currently requires " + f"{self.patch_size_t=}. (Expected 1)" + ) + if int(seq_len) <= 0: + raise ValueError(f"Expected {seq_len=} > 0 for packed token latents.") + if int(self.vae_scale_factor) <= 0: + raise ValueError(f"Invalid {self.vae_scale_factor=}. Must be > 0.") + if int(self.patch_size) <= 0: + raise ValueError(f"Invalid {self.patch_size=}. Must be > 0.") + + latent_height = int(batch.height) // int(self.vae_scale_factor) + latent_width = int(batch.width) // int(self.vae_scale_factor) + if latent_height <= 0 or latent_width <= 0: + raise ValueError( + "Invalid latent H/W computed from batch.height/width: " + f"{batch.height=} {batch.width=} {self.vae_scale_factor=}" + ) + if (latent_height % int(self.patch_size)) != 0 or ( + latent_width % int(self.patch_size) + ) != 0: + raise ValueError( + "Invalid spatial patching for packed token latents. 
Expected latent H/W " + "to be divisible by patch_size, got " + f"{latent_height=} {latent_width=} {self.patch_size=}." + ) + + post_patch_h = latent_height // int(self.patch_size) + post_patch_w = latent_width // int(self.patch_size) + tokens_per_frame = int(post_patch_h) * int(post_patch_w) + if tokens_per_frame <= 0: + raise ValueError( + f"Invalid tokens_per_frame={tokens_per_frame} from " + f"{latent_height=} {latent_width=} {self.patch_size=}" + ) + if int(seq_len) % int(tokens_per_frame) != 0: + raise ValueError( + f"LTX-2 token latents seq_len={seq_len} is not divisible by " + f"tokens_per_frame={tokens_per_frame}. Cannot time-shard for SP." + ) + latent_num_frames = int(seq_len) // int(tokens_per_frame) + return int(latent_num_frames), int(tokens_per_frame) + + def shard_latents_for_sp(self, batch, latents): + """Shard LTX-2 packed token latents across SP ranks by latent time (frame) dimension.""" + sp_world_size = get_sp_world_size() + if sp_world_size <= 1: + return latents, False + + # Default behavior for 5D latents. + if isinstance(latents, torch.Tensor) and latents.ndim == 5: + return super().shard_latents_for_sp(batch, latents) + + # LTX-2 packed token latents [B, S, D] + if not (isinstance(latents, torch.Tensor) and latents.ndim == 3): + return latents, False + + sp_rank = get_sp_parallel_rank() + seq_len = int(latents.shape[1]) + latent_frames, tokens_per_frame = ( + self._infer_video_latent_frames_and_tokens_per_frame(batch, seq_len) + ) + + # Pad whole frames so `latent_frames` is divisible by `sp_world_size`. 
def maybe_pack_audio_latents(self, latents, batch_size, batch):
    """Pack audio latents ``[B, C, L, M]`` into tokens ``[B, L, C * M]``.

    A 3-D input is treated as already packed and returned as-is.
    ``batch_size`` and ``batch`` are accepted but not read here.
    """
    already_packed = latents.dim() == 3
    if not already_packed:
        # [B, C, L, M] -> [B, L, C, M] -> [B, L, C * M]
        length_major = latents.transpose(1, 2)
        latents = length_major.flatten(2, 3)
    return latents
def get_decode_scale_and_shift(self, device, dtype, vae):
    """Build the ``(scale, shift)`` pair returned for decoding.

    The scaling factor is the first truthy value among
    ``vae.config.scaling_factor``, ``vae.scaling_factor`` and
    ``self.vae_config.arch_config.scaling_factor``, falling back to ``1.0``.

    Returns:
        ``(scaling_factor / latents_std, latents_mean)`` shaped for
        broadcasting over ``[B, C, F, H, W]`` when the VAE exposes both
        ``latents_mean`` and ``latents_std`` as tensors; otherwise
        ``(scaling_factor, None)``.
    """
    # Lazy fallback chain: later sources are only consulted when needed.
    scaling_factor = getattr(getattr(vae, "config", None), "scaling_factor", None)
    if not scaling_factor:
        scaling_factor = getattr(vae, "scaling_factor", None)
    if not scaling_factor:
        scaling_factor = getattr(self.vae_config.arch_config, "scaling_factor", None)
    if not scaling_factor:
        scaling_factor = 1.0
    if isinstance(scaling_factor, (int, float)) and float(scaling_factor) == 0.0:
        scaling_factor = 1.0

    mean = getattr(vae, "latents_mean", None)
    std = getattr(vae, "latents_std", None)
    has_stats = isinstance(mean, torch.Tensor) and isinstance(std, torch.Tensor)

    if has_stats:
        per_channel = (1, -1, 1, 1, 1)
        mean = mean.to(device=device, dtype=dtype).view(*per_channel)
        std = std.to(device=device, dtype=dtype).view(*per_channel)

    sf = torch.tensor(float(scaling_factor), device=device, dtype=dtype).view(
        1, 1, 1, 1, 1
    )
    if has_stats:
        return sf / std, mean
    return sf, None
+ batch_size = latents.size(0) + latents = latents.reshape( + batch_size, + num_frames, + height, + width, + -1, + patch_size_t, + patch_size, + patch_size, + ) + latents = ( + latents.permute(0, 4, 1, 5, 2, 6, 3, 7) + .flatten(6, 7) + .flatten(4, 5) + .flatten(2, 3) + ) + return latents + + @staticmethod + def _denormalize_latents( + latents: torch.Tensor, + latents_mean: torch.Tensor, + latents_std: torch.Tensor, + scaling_factor: float = 1.0, + ) -> torch.Tensor: + # Denormalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std / scaling_factor + latents_mean + return latents + + @staticmethod + def _denormalize_audio_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor + ): + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents * latents_std) + latents_mean + + @staticmethod + def _unpack_audio_latents( + latents: torch.Tensor, + latent_length: int, + num_mel_bins: int, + patch_size: int | None = None, + patch_size_t: int | None = None, + ) -> torch.Tensor: + # Unpacks an audio patch sequence of shape [B, S, D] into a latent spectrogram tensor of shape [B, C, L, M], + # where L is the latent audio length and M is the number of mel bins. + if patch_size is not None and patch_size_t is not None: + batch_size = latents.size(0) + latents = latents.reshape( + batch_size, latent_length, num_mel_bins, -1, patch_size_t, patch_size + ) + latents = latents.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3) + else: + # Assume [B, S, D] = [B, L, C * M], which implies that patch_size = M and patch_size_t = 1. 
+ latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2) + return latents + + def _unpad_and_unpack_latents(self, latents, audio_latents, batch, vae, audio_vae): + # Calculate latent dimensions + # Assuming batch has height, width, num_frames + height = batch.height + width = batch.width + num_frames = batch.num_frames + + # Get compression ratios + # Default LTX-2 values if not present in config + vae_spatial_compression_ratio = getattr( + self.vae_config.arch_config, "spatial_compression_ratio", 32 + ) + vae_temporal_compression_ratio = getattr( + self.vae_config.arch_config, "temporal_compression_ratio", 8 + ) + + latent_height = height // vae_spatial_compression_ratio + latent_width = width // vae_spatial_compression_ratio + latent_num_frames = (num_frames - 1) // vae_temporal_compression_ratio + 1 + + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.patch_size, + self.patch_size_t, + ) + + sample_rate = self.audio_vae_config.arch_config.sample_rate + hop_length = self.audio_vae_config.arch_config.mel_hop_length + temporal_compression = self.audio_vae_temporal_compression_ratio + duration_s = num_frames / batch.fps + + latents_per_second = ( + float(sample_rate) / float(hop_length) / float(temporal_compression) + ) + audio_num_frames = round(duration_s * latents_per_second) + + num_mel_bins = self.audio_vae_config.arch_config.mel_bins + mel_compression_ratio = self.audio_vae_mel_compression_ratio + latent_mel_bins = num_mel_bins // mel_compression_ratio + + audio_latents_mean = getattr(audio_vae, "latents_mean", None) + audio_latents_std = getattr(audio_vae, "latents_std", None) + if ( + isinstance(audio_latents_mean, torch.Tensor) + and isinstance(audio_latents_std, torch.Tensor) + and audio_latents_mean.numel() == audio_latents_std.numel() + ): + audio_latents_mean = audio_latents_mean.to( + device=audio_latents.device, dtype=audio_latents.dtype + ) + audio_latents_std = audio_latents_std.to( 
+ device=audio_latents.device, dtype=audio_latents.dtype + ) + if audio_latents.ndim == 3: + if audio_latents.shape[-1] != audio_latents_mean.numel(): + raise ValueError( + f"audio_latents last dim {audio_latents.shape[-1]} " + f"does not match audio_vae stats {audio_latents_mean.numel()}" + ) + audio_latents = audio_latents * audio_latents_std.view( + 1, 1, -1 + ) + audio_latents_mean.view(1, 1, -1) + elif audio_latents.ndim == 2: + if audio_latents.shape[-1] != audio_latents_mean.numel(): + raise ValueError( + f"audio_latents last dim {audio_latents.shape[-1]} " + f"does not match audio_vae stats {audio_latents_mean.numel()}" + ) + audio_latents = audio_latents * audio_latents_std.view( + 1, -1 + ) + audio_latents_mean.view(1, -1) + else: + audio_latents = audio_latents * audio_latents_std + audio_latents_mean + + audio_latents = self._unpack_audio_latents( + audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins + ) + + return latents, audio_latents + + +@dataclasses.dataclass +class LTX2I2VPipelineConfig(LTX2PipelineConfig): + task_type: ModelTaskType = ModelTaskType.TI2V diff --git a/sglang/python/sglang/multimodal_gen/configs/pipeline_configs/mova.py b/sglang/python/sglang/multimodal_gen/configs/pipeline_configs/mova.py new file mode 100644 index 0000000000000000000000000000000000000000..4612a9eef493b6af221428d071f7af82fe5803b5 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/pipeline_configs/mova.py @@ -0,0 +1,185 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +MOVA pipeline configuration. 
+""" + +from dataclasses import dataclass, field + +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image + +from sglang.multimodal_gen.configs.models.dits import MOVAAudioConfig, MOVAVideoConfig +from sglang.multimodal_gen.configs.models.encoders import T5Config +from sglang.multimodal_gen.configs.models.vaes import DacVAEConfig, WanVAEConfig +from sglang.multimodal_gen.configs.pipeline_configs.base import ( + ModelTaskType, + PipelineConfig, +) +from sglang.multimodal_gen.configs.pipeline_configs.wan import t5_postprocess_text +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +@dataclass +class MOVAPipelineConfig(PipelineConfig): + """Configuration for MOVA (text+image -> video+audio) pipelines.""" + + task_type: ModelTaskType = ModelTaskType.I2V + + # Model configs + dit_config: MOVAVideoConfig = field(default_factory=MOVAVideoConfig) + audio_dit_config: MOVAAudioConfig = field(default_factory=MOVAAudioConfig) + + # Video VAE (Wan) + Audio VAE (DAC) + vae_config: WanVAEConfig = field(default_factory=WanVAEConfig) + audio_vae_config: DacVAEConfig = field(default_factory=DacVAEConfig) + audio_vae_precision: str = "fp32" + + # Text encoder (UMT5 compatible) + text_encoder_configs: tuple = field(default_factory=lambda: (T5Config(),)) + postprocess_text_funcs: tuple = field( + default_factory=lambda: (t5_postprocess_text,) + ) + + # MOVA specific + audio_vae_type: str = "dac" + boundary_ratio: float | None = 0.9 + + # temporal alignment: MOVA expects (num_frames - 1) % 4 == 0 + time_division_factor: int = 4 + time_division_remainder: int = 1 + + def _center_crop_and_resize( + self, image: torch.Tensor | Image.Image, target_height: int, target_width: int + ) -> torch.Tensor | Image.Image: + if not isinstance(image, (Image.Image, torch.Tensor)): + raise TypeError(f"Unsupported image type: {type(image)}") + if isinstance(image, Image.Image): + image = 
torch.from_numpy(np.array(image)) + + if image.ndim == 2: + image = image[..., None] + + if not image.dtype.is_floating_point: + image = image.to(torch.float32).div(255.0) + + if image.ndim == 3: + if image.shape[0] in (1, 3, 4) and image.shape[-1] not in (1, 3, 4): + image = image.unsqueeze(0) + else: + image = image.permute(2, 0, 1).unsqueeze(0) + elif image.ndim == 4: + if image.shape[1] not in (1, 3, 4) and image.shape[-1] in (1, 3, 4): + image = image.permute(0, 3, 1, 2) + + image_height, image_width = image.shape[-2], image.shape[-1] + if image_height == target_height and image_width == target_width: + return image + + logger.info( + "Center cropping and resizing image to %dx%d", target_width, target_height + ) + + if image_height * target_width < image_width * target_height: + cropped_width = (image_height * target_width) // target_height + left = (image_width - cropped_width) // 2 + image = image[..., :, left : left + cropped_width] + else: + cropped_height = (image_width * target_height) // target_width + top = (image_height - cropped_height) // 2 + image = image[..., top : top + cropped_height, :] + + image = F.interpolate( + image, + size=(target_height, target_width), + mode="bilinear", + align_corners=False, + antialias=True, + ) + return image + + def adjust_num_frames(self, num_frames: int) -> int: + if num_frames is None: + return num_frames + if num_frames % self.time_division_factor != self.time_division_remainder: + adjusted = ( + (num_frames + self.time_division_factor - 1) + // self.time_division_factor + * self.time_division_factor + + self.time_division_remainder + ) + logger.warning( + "`num_frames` (%s) is not compatible with MOVA temporal constraints. 
" + "Rounding to %s.", + num_frames, + adjusted, + ) + return adjusted + return num_frames + + def preprocess_condition_image( + self, image, target_width, target_height, _vae_image_processor + ): + image = self._center_crop_and_resize(image, target_height, target_width) + return image, (target_width, target_height) + + def prepare_latent_shape(self, batch, batch_size, num_frames): + spatial = self.vae_config.arch_config.spatial_compression_ratio + length = (num_frames - 1) // self.time_division_factor + 1 + shape = ( + batch_size, + self.dit_config.arch_config.out_dim, + length, + batch.height // spatial, + batch.width // spatial, + ) + return shape + + def prepare_audio_latent_shape(self, batch_size, num_samples, audio_vae): + latent_T = (num_samples + audio_vae.hop_length - 1) // audio_vae.hop_length + return (batch_size, audio_vae.latent_dim, latent_T) + + def normalize_video_latents(self, latents: torch.Tensor, video_vae) -> torch.Tensor: + latents_mean = getattr(video_vae.config, "latents_mean", None) + latents_std = getattr(video_vae.config, "latents_std", None) + if latents_mean is None or latents_std is None: + return latents + mean = torch.tensor( + latents_mean, device=latents.device, dtype=latents.dtype + ).view(1, video_vae.config.z_dim, 1, 1, 1) + inv_std = ( + 1.0 / torch.tensor(latents_std, device=latents.device, dtype=latents.dtype) + ).view(1, video_vae.config.z_dim, 1, 1, 1) + return (latents - mean) * inv_std + + def denormalize_video_latents( + self, latents: torch.Tensor, video_vae + ) -> torch.Tensor: + latents_mean = getattr(video_vae.config, "latents_mean", None) + latents_std = getattr(video_vae.config, "latents_std", None) + if latents_mean is None or latents_std is None: + return latents + mean = torch.tensor( + latents_mean, device=latents.device, dtype=latents.dtype + ).view(1, video_vae.config.z_dim, 1, 1, 1) + std = torch.tensor( + latents_std, device=latents.device, dtype=latents.dtype + ).view(1, video_vae.config.z_dim, 1, 1, 1) + 
return latents * std + mean + + +@dataclass +class MOVA360PConfig(MOVAPipelineConfig): + """Configuration for MOVA 360P (text+image -> video+audio) pipelines.""" + + max_area: int = 352 * 640 + + +@dataclass +class MOVA720PConfig(MOVAPipelineConfig): + """Configuration for MOVA 720P (text+image -> video+audio) pipelines.""" + + max_area: int = 720 * 1280 diff --git a/sglang/python/sglang/multimodal_gen/configs/pipeline_configs/qwen_image.py b/sglang/python/sglang/multimodal_gen/configs/pipeline_configs/qwen_image.py new file mode 100644 index 0000000000000000000000000000000000000000..40f42c0c70712ad304fee4e169822c0ffc886709 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/pipeline_configs/qwen_image.py @@ -0,0 +1,579 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +from dataclasses import dataclass, field +from typing import Callable + +import torch + +from sglang.multimodal_gen.configs.models import DiTConfig, EncoderConfig, VAEConfig +from sglang.multimodal_gen.configs.models.dits.qwenimage import ( + QwenImageDitConfig, + QwenImageEditPlus_2511_DitConfig, +) +from sglang.multimodal_gen.configs.models.encoders.qwen_image import Qwen2_5VLConfig +from sglang.multimodal_gen.configs.models.vaes.qwenimage import QwenImageVAEConfig +from sglang.multimodal_gen.configs.pipeline_configs.base import ( + ImagePipelineConfig, + ModelTaskType, + maybe_unpad_latents, + shard_rotary_emb_for_sp, +) +from sglang.multimodal_gen.runtime.models.vision_utils import resize +from sglang.multimodal_gen.utils import calculate_dimensions + + +def _extract_masked_hidden(hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + + return split_result + + +def qwen_image_preprocess_text(prompt): + prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, 
shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + + template = prompt_template_encode + txt = template.format(prompt) + return txt + + +def qwen_image_postprocess_text(outputs, _text_inputs, drop_idx=34): + # squeeze the batch dim + hidden_states = outputs.hidden_states[-1] + split_hidden_states = _extract_masked_hidden( + hidden_states, _text_inputs.attention_mask + ) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [ + torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) + for u in split_hidden_states + ] + ) + return prompt_embeds + + +# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents +def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view( + batch_size, num_channels_latents, height // 2, 2, width // 2, 2 + ) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape( + batch_size, (height // 2) * (width // 2), num_channels_latents * 4 + ) + + return latents + + +@dataclass +class QwenImagePipelineConfig(ImagePipelineConfig): + """Configuration for the QwenImage pipeline.""" + + should_use_guidance: bool = False + task_type: ModelTaskType = ModelTaskType.T2I + + vae_tiling: bool = False + + vae_sp: bool = False + + dit_config: DiTConfig = field(default_factory=QwenImageDitConfig) + # VAE + vae_config: VAEConfig = field(default_factory=QwenImageVAEConfig) + + enable_autocast: bool = False + + # Text encoding stage + text_encoder_configs: tuple[EncoderConfig, ...] = field( + default_factory=lambda: (Qwen2_5VLConfig(),) + ) + + text_encoder_precisions: tuple[str, ...] = field(default_factory=lambda: ("bf16",)) + + preprocess_text_funcs: tuple[Callable[[str], str], ...] 
def prepare_image_processor_kwargs(self, batch, neg=False):
    """Build the image-processor kwargs for the edit-style prompt template.

    Args:
        batch: Request object; ``batch.prompt`` or ``batch.negative_prompt``
            is read depending on ``neg``.
        neg: When true, encode the negative prompt instead of the positive one.

    Returns:
        ``{"text": [templated_prompt], "padding": True}``, or an empty dict
        when the selected prompt is empty or ``None``.
    """
    prompt = batch.negative_prompt if neg else batch.prompt
    if not prompt:
        return {}

    prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
    # Format the prompt selected above. Formatting `batch.prompt` here would
    # make the negative branch encode the positive prompt.
    txt = prompt_template_encode.format(prompt)
    return dict(text=[txt], padding=True)
@staticmethod
def get_freqs_cis(img_shapes, txt_seq_lens, rotary_emb, device, dtype):
    """Build the image and text cos/sin caches from ``rotary_emb``.

    ``rotary_emb(img_shapes, txt_seq_lens, device=device)`` must return a pair
    of complex tensors (image, text). For each one the real part and the
    imaginary part are converted to float32 and concatenated on the last
    dimension. ``dtype`` is accepted but not used here.

    Returns:
        ``(img_cos_sin_cache, txt_cos_sin_cache)``.
    """

    def _cos_sin_cache(freqs):
        # flashinfer RoPE expects a float32 cos/sin cache concatenated on the last dim
        cos_half = freqs.real.to(dtype=torch.float32).contiguous()
        sin_half = freqs.imag.to(dtype=torch.float32).contiguous()
        return torch.cat([cos_half, sin_half], dim=-1)

    # img_shapes: for global entire image
    img_freqs, txt_freqs = rotary_emb(img_shapes, txt_seq_lens, device=device)
    return _cos_sin_cache(img_freqs), _cos_sin_cache(txt_freqs)
txt_seq_lens, + "freqs_cis": (img_cache, txt_cache), + "img_shapes": img_shapes, + } + + def prepare_pos_cond_kwargs(self, batch, device, rotary_emb, dtype): + return self._prepare_cond_kwargs( + batch, batch.prompt_embeds, rotary_emb, device, dtype + ) + + def prepare_neg_cond_kwargs(self, batch, device, rotary_emb, dtype): + return self._prepare_cond_kwargs( + batch, batch.negative_prompt_embeds, rotary_emb, device, dtype + ) + + def post_denoising_loop(self, latents, batch): + # unpack latents for qwen-image + ( + latents, + batch_size, + channels, + height, + width, + ) = self._unpad_and_unpack_latents(latents, batch) + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + return latents + + +@dataclass +class QwenImageEditPipelineConfig(QwenImagePipelineConfig): + """Configuration for the QwenImageEdit pipeline.""" + + task_type: ModelTaskType = ModelTaskType.I2I + + def _prepare_edit_cond_kwargs( + self, batch, prompt_embeds, rotary_emb, device, dtype + ): + batch_size = batch.latents.shape[0] + assert batch_size == 1 + height = batch.height + width = batch.width + image_size = batch.original_condition_image_size + edit_width, edit_height, _ = calculate_dimensions( + 1024 * 1024, image_size[0] / image_size[1] + ) + vae_scale_factor = self.get_vae_scale_factor() + + img_shapes = [ + [ + ( + 1, + height // vae_scale_factor // 2, + width // vae_scale_factor // 2, + ), + ( + 1, + edit_height // vae_scale_factor // 2, + edit_width // vae_scale_factor // 2, + ), + ], + ] * batch_size + txt_seq_lens = [prompt_embeds[0].shape[1]] + + if rotary_emb is None: + return { + "img_shapes": img_shapes, + "txt_seq_lens": txt_seq_lens, + "freqs_cis": None, + } + + freqs_cis = QwenImagePipelineConfig.get_freqs_cis( + img_shapes, txt_seq_lens, rotary_emb, device, dtype + ) + + # perform sp shard on noisy image tokens + noisy_img_seq_len = ( + 1 * (height // vae_scale_factor // 2) * (width // vae_scale_factor // 2) + ) + + img_cache, txt_cache = freqs_cis 
+ noisy_img_cache = shard_rotary_emb_for_sp(img_cache[:noisy_img_seq_len, :]) + img_cache = torch.cat( + [noisy_img_cache, img_cache[noisy_img_seq_len:, :]], dim=0 + ).to(device=device) + return { + "txt_seq_lens": txt_seq_lens, + "freqs_cis": (img_cache, txt_cache), + "img_shapes": img_shapes, + } + + def preprocess_condition_image( + self, image, target_width, target_height, _vae_image_processor + ): + return resize(image, target_height, target_width, resize_mode="default"), ( + target_width, + target_height, + ) + + def postprocess_image_latent(self, latent_condition, batch): + batch_size = batch.batch_size + if batch_size > latent_condition.shape[0]: + if batch_size % latent_condition.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // latent_condition.shape[0] + image_latents = latent_condition.repeat( + additional_image_per_prompt, 1, 1, 1 + ) + else: + raise ValueError( + f"Cannot duplicate `image` of batch size {latent_condition.shape[0]} to {batch_size} text prompts." 
def prepare_image_processor_kwargs(self, batch, neg=False) -> dict:
    """Build the image-processor kwargs for the multi-image edit template.

    Each condition image contributes a ``"Picture N: ..."`` placeholder that is
    prepended to every prompt before it is inserted into the chat template.

    Args:
        batch: Request object; reads ``batch.prompt`` or
            ``batch.negative_prompt`` (a string or a list of strings) and
            ``batch.condition_image`` (a list, a single image, or ``None``).
        neg: When true, encode the negative prompt instead of the positive one.

    Returns:
        ``{"text": [one templated string per prompt], "padding": True}``.
    """
    prompt = batch.prompt if not neg else batch.negative_prompt
    prompt_list = [prompt] if isinstance(prompt, str) else prompt

    # Normalize the condition images to a list. The placeholder prefix used to
    # be assigned only for list inputs, so a single image raised NameError.
    image_list = batch.condition_image
    if image_list is None:
        image_list = []
    elif not isinstance(image_list, list):
        image_list = [image_list]

    prompt_template_encode = (
        "<|im_start|>system\nDescribe the key features of the input image "
        "(color, shape, size, texture, objects, background), then explain how "
        "the user's text instruction should alter or modify the image. Generate "
        "a new image that meets the user's requirements while maintaining "
        "consistency with the original input where appropriate.<|im_end|>\n"
        "<|im_start|>user\n{}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"

    # Pictures are numbered from 1 in the prompt.
    base_img_prompt = "".join(
        img_prompt_template.format(index)
        for index in range(1, len(image_list) + 1)
    )
    txt = [prompt_template_encode.format(base_img_prompt + p) for p in prompt_list]
    return dict(text=txt, padding=True)
new_images.append(vae_image_processor.preprocess(img, height, width)) + vae_image_sizes.append((width, height)) + batch.vae_image = new_images + batch.vae_image_sizes = vae_image_sizes + return batch + + def _prepare_edit_cond_kwargs( + self, batch, prompt_embeds, rotary_emb, device, dtype + ): + batch_size = batch.latents.shape[0] + assert batch_size == 1 + height = batch.height + width = batch.width + + vae_scale_factor = self.get_vae_scale_factor() + + img_shapes = [ + [ + (1, height // vae_scale_factor // 2, width // vae_scale_factor // 2), + *[ + ( + 1, + vae_height // vae_scale_factor // 2, + vae_width // vae_scale_factor // 2, + ) + for vae_width, vae_height in batch.vae_image_sizes + ], + ], + ] * batch_size + txt_seq_lens = [prompt_embeds[0].shape[1]] + + freqs_cis = QwenImageEditPlusPipelineConfig.get_freqs_cis( + img_shapes, txt_seq_lens, rotary_emb, device, dtype + ) + + # perform sp shard on noisy image tokens + noisy_img_seq_len = ( + 1 * (height // vae_scale_factor // 2) * (width // vae_scale_factor // 2) + ) + + if isinstance(freqs_cis[0], torch.Tensor) and freqs_cis[0].dim() == 2: + img_cache, txt_cache = freqs_cis + noisy_img_cache = shard_rotary_emb_for_sp(img_cache[:noisy_img_seq_len, :]) + img_cache = torch.cat( + [noisy_img_cache, img_cache[noisy_img_seq_len:, :]], dim=0 + ).to(device=device) + return { + "txt_seq_lens": txt_seq_lens, + "freqs_cis": (img_cache, txt_cache), + "img_shapes": img_shapes, + } + + (img_cos, img_sin), (txt_cos, txt_sin) = freqs_cis + noisy_img_cos = shard_rotary_emb_for_sp(img_cos[:noisy_img_seq_len, :]) + noisy_img_sin = shard_rotary_emb_for_sp(img_sin[:noisy_img_seq_len, :]) + + # concat back the img_cos for input image (since it is not sp-shared later) + img_cos = torch.cat([noisy_img_cos, img_cos[noisy_img_seq_len:, :]], dim=0).to( + device=device + ) + img_sin = torch.cat([noisy_img_sin, img_sin[noisy_img_seq_len:, :]], dim=0).to( + device=device + ) + + return { + "txt_seq_lens": txt_seq_lens, + "freqs_cis": 
((img_cos, img_sin), (txt_cos, txt_sin)), + "img_shapes": img_shapes, + } + + +@dataclass +class QwenImageEditPlus_2511_PipelineConfig(QwenImageEditPlusPipelineConfig): + dit_config: DiTConfig = field(default_factory=QwenImageEditPlus_2511_DitConfig) + + +@dataclass +class QwenImageLayeredPipelineConfig(QwenImageEditPipelineConfig): + resolution: int = 640 # TODO: allow user to set resolution + vae_precision: str = "bf16" + + def _prepare_edit_cond_kwargs( + self, batch, prompt_embeds, rotary_emb, device, dtype + ): + batch_size = batch.latents.shape[0] + assert batch_size == 1 + height = batch.height + width = batch.width + image_size = batch.original_condition_image_size + + vae_scale_factor = self.get_vae_scale_factor() + + img_shapes = batch.img_shapes + txt_seq_lens = batch.txt_seq_lens + + freqs_cis = QwenImageEditPlusPipelineConfig.get_freqs_cis( + img_shapes, txt_seq_lens, rotary_emb, device, dtype + ) + + # perform sp shard on noisy image tokens + noisy_img_seq_len = ( + 1 * (height // vae_scale_factor // 2) * (width // vae_scale_factor // 2) + ) + + img_cache, txt_cache = freqs_cis + noisy_img_cache = shard_rotary_emb_for_sp(img_cache[:noisy_img_seq_len, :]) + img_cache = torch.cat( + [noisy_img_cache, img_cache[noisy_img_seq_len:, :]], dim=0 + ).to(device=device) + + return { + "txt_seq_lens": txt_seq_lens, + "img_shapes": img_shapes, + "freqs_cis": (img_cache, txt_cache), + "additional_t_cond": torch.tensor([0], device=device, dtype=torch.long), + } + + def _unpad_and_unpack_latents(self, latents, batch): + vae_scale_factor = self.vae_config.arch_config.vae_scale_factor + channels = self.dit_config.arch_config.in_channels + batch_size = latents.shape[0] + layers = batch.num_frames + + height = 2 * (int(batch.height) // (vae_scale_factor * 2)) + width = 2 * (int(batch.width) // (vae_scale_factor * 2)) + + latents = maybe_unpad_latents(latents, batch) + latents = latents.view( + batch_size, layers + 1, height // 2, width // 2, channels // 4, 2, 2 + ) + 
latents = latents.permute(0, 1, 4, 2, 5, 3, 6) + + latents = latents.reshape( + batch_size, layers + 1, channels // (2 * 2), height, width + ) + latents = latents.permute(0, 2, 1, 3, 4) # (b, c, f, h, w) + return latents, batch_size, channels, height, width + + def allow_set_num_frames(self): + return True + + def post_denoising_loop(self, latents, batch): + # unpack latents for qwen-image + ( + latents, + batch_size, + channels, + height, + width, + ) = self._unpad_and_unpack_latents(latents, batch) + b, c, f, h, w = latents.shape + latents = latents[:, :, 1:] # remove the first frame as it is the origin input + latents = latents.permute(0, 2, 1, 3, 4).view(-1, c, 1, h, w) + # latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + return latents diff --git a/sglang/python/sglang/multimodal_gen/configs/pipeline_configs/wan.py b/sglang/python/sglang/multimodal_gen/configs/pipeline_configs/wan.py new file mode 100644 index 0000000000000000000000000000000000000000..6a824e67881fc88907cd6cfc593ba1639dee41e3 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/pipeline_configs/wan.py @@ -0,0 +1,236 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from collections.abc import Callable +from dataclasses import dataclass, field + +import torch + +from sglang.multimodal_gen.configs.models import DiTConfig, EncoderConfig, VAEConfig +from sglang.multimodal_gen.configs.models.dits import WanVideoConfig +from sglang.multimodal_gen.configs.models.encoders import ( + BaseEncoderOutput, + CLIPVisionConfig, + T5Config, +) +from sglang.multimodal_gen.configs.models.vaes import WanVAEConfig +from sglang.multimodal_gen.configs.pipeline_configs.base import ( + ModelTaskType, + PipelineConfig, +) +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +def t5_postprocess_text(outputs: BaseEncoderOutput, _text_inputs) -> torch.Tensor: + 
mask: torch.Tensor = outputs.attention_mask + hidden_state: torch.Tensor = outputs.last_hidden_state + seq_lens = mask.gt(0).sum(dim=1).long() + assert torch.isnan(hidden_state).sum() == 0 + prompt_embeds = [u[:v] for u, v in zip(hidden_state, seq_lens, strict=True)] + prompt_embeds_tensor: torch.Tensor = torch.stack( + [ + torch.cat([u, u.new_zeros(512 - u.size(0), u.size(1))]) + for u in prompt_embeds + ], + dim=0, + ) + return prompt_embeds_tensor + + +@dataclass +class WanI2VCommonConfig(PipelineConfig): + # for all wan i2v pipelines + def adjust_num_frames(self, num_frames): + vae_scale_factor_temporal = self.vae_config.arch_config.scale_factor_temporal + if num_frames % vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = ( + num_frames // vae_scale_factor_temporal * vae_scale_factor_temporal + 1 + ) + return num_frames + return num_frames + + +@dataclass +class WanT2V480PConfig(PipelineConfig): + """Base configuration for Wan T2V 1.3B pipeline architecture.""" + + task_type: ModelTaskType = ModelTaskType.T2V + # WanConfig-specific parameters with defaults + # DiT + dit_config: DiTConfig = field(default_factory=WanVideoConfig) + + # VAE + vae_config: VAEConfig = field(default_factory=WanVAEConfig) + vae_tiling: bool = False + vae_sp: bool = False + + # Denoising stage + flow_shift: float | None = 3.0 + + # Text encoding stage + text_encoder_configs: tuple[EncoderConfig, ...] = field( + default_factory=lambda: (T5Config(),) + ) + postprocess_text_funcs: tuple[Callable[[BaseEncoderOutput], torch.Tensor], ...] = ( + field(default_factory=lambda: (t5_postprocess_text,)) + ) + + # Precision for each component + precision: str = "bf16" + vae_precision: str = "fp32" + text_encoder_precisions: tuple[str, ...] 
= field(default_factory=lambda: ("fp32",)) + + # WanConfig-specific added parameters + + def __post_init__(self): + self.vae_config.load_encoder = False + self.vae_config.load_decoder = True + + +@dataclass +class TurboWanT2V480PConfig(WanT2V480PConfig): + """Base configuration for Wan T2V 1.3B pipeline architecture.""" + + flow_shift: float | None = 8.0 + dmd_denoising_steps: list[int] | None = field( + default_factory=lambda: [988, 932, 852, 608] + ) + + +@dataclass +class WanT2V720PConfig(WanT2V480PConfig): + """Base configuration for Wan T2V 14B 720P pipeline architecture.""" + + # WanConfig-specific parameters with defaults + + # Denoising stage + flow_shift: float | None = 5.0 + + +@dataclass +class WanI2V480PConfig(WanT2V480PConfig, WanI2VCommonConfig): + """Base configuration for Wan I2V 14B 480P pipeline architecture.""" + + max_area: int = 480 * 832 + # WanConfig-specific parameters with defaults + task_type: ModelTaskType = ModelTaskType.I2V + # Precision for each component + image_encoder_config: EncoderConfig = field(default_factory=CLIPVisionConfig) + image_encoder_precision: str = "fp32" + + image_encoder_extra_args: dict = field( + default_factory=lambda: dict( + output_hidden_states=True, + ) + ) + + def postprocess_image(self, image): + return image.hidden_states[-2] + + def __post_init__(self) -> None: + self.vae_config.load_encoder = True + self.vae_config.load_decoder = True + + +@dataclass +class WanI2V720PConfig(WanI2V480PConfig): + """Base configuration for Wan I2V 14B 720P pipeline architecture.""" + + max_area: int = 720 * 1280 + # WanConfig-specific parameters with defaults + + # Denoising stage + flow_shift: float | None = 5.0 + + +@dataclass +class TurboWanI2V720Config(WanI2V720PConfig): + flow_shift: float | None = 8.0 + dmd_denoising_steps: list[int] | None = field( + default_factory=lambda: [996, 932, 852, 608] + ) + boundary_ratio: float | None = 0.9 + + def __post_init__(self) -> None: + self.dit_config.boundary_ratio = 
self.boundary_ratio + + +@dataclass +class FastWan2_1_T2V_480P_Config(WanT2V480PConfig): + """Base configuration for FastWan T2V 1.3B 480P pipeline architecture with DMD""" + + # WanConfig-specific parameters with defaults + + # Denoising stage + flow_shift: float | None = 8.0 + dmd_denoising_steps: list[int] | None = field( + default_factory=lambda: [1000, 757, 522] + ) + + +@dataclass +class Wan2_2_TI2V_5B_Config(WanT2V480PConfig, WanI2VCommonConfig): + flow_shift: float | None = 5.0 + task_type: ModelTaskType = ModelTaskType.TI2V + expand_timesteps: bool = True + # ti2v, 5B + vae_stride = (4, 16, 16) + + def prepare_latent_shape(self, batch, batch_size, num_frames): + F = num_frames + z_dim = self.vae_config.arch_config.z_dim + vae_stride = self.vae_stride + oh = batch.height + ow = batch.width + shape = (batch_size, z_dim, F, oh // vae_stride[1], ow // vae_stride[2]) + return shape + + def __post_init__(self) -> None: + self.vae_config.load_encoder = True + self.vae_config.load_decoder = True + self.dit_config.expand_timesteps = self.expand_timesteps + + +@dataclass +class FastWan2_2_TI2V_5B_Config(Wan2_2_TI2V_5B_Config): + flow_shift: float | None = 5.0 + dmd_denoising_steps: list[int] | None = field( + default_factory=lambda: [1000, 757, 522] + ) + + +@dataclass +class Wan2_2_T2V_A14B_Config(WanT2V480PConfig): + flow_shift: float | None = 12.0 + boundary_ratio: float | None = 0.875 + + def __post_init__(self) -> None: + self.dit_config.boundary_ratio = self.boundary_ratio + + +@dataclass +class Wan2_2_I2V_A14B_Config(WanI2V480PConfig): + flow_shift: float | None = 5.0 + boundary_ratio: float | None = 0.900 + + def __post_init__(self) -> None: + super().__post_init__() + self.dit_config.boundary_ratio = self.boundary_ratio + + +# ============================================= +# ============= Causal Self-Forcing ============= +# ============================================= +@dataclass +class SelfForcingWanT2V480PConfig(WanT2V480PConfig): + is_causal: bool = 
True + flow_shift: float | None = 5.0 + dmd_denoising_steps: list[int] | None = field( + default_factory=lambda: [1000, 750, 500, 250] + ) + warp_denoising_step: bool = True diff --git a/sglang/python/sglang/multimodal_gen/configs/pipeline_configs/zimage.py b/sglang/python/sglang/multimodal_gen/configs/pipeline_configs/zimage.py new file mode 100644 index 0000000000000000000000000000000000000000..e2e180a9765aa7dc48acc570e2cb81b3fcfb82fc --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/pipeline_configs/zimage.py @@ -0,0 +1,328 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo +import math +from dataclasses import dataclass, field +from typing import Callable + +import torch + +from sglang.multimodal_gen.configs.models import DiTConfig, EncoderConfig, VAEConfig +from sglang.multimodal_gen.configs.models.dits.zimage import ZImageDitConfig +from sglang.multimodal_gen.configs.models.encoders import BaseEncoderOutput +from sglang.multimodal_gen.configs.models.encoders.qwen3 import Qwen3TextConfig +from sglang.multimodal_gen.configs.models.vaes.flux import FluxVAEConfig +from sglang.multimodal_gen.configs.pipeline_configs.base import ( + ImagePipelineConfig, + ModelTaskType, +) +from sglang.multimodal_gen.runtime.distributed.communication_op import ( + sequence_model_parallel_all_gather, +) +from sglang.multimodal_gen.runtime.distributed.parallel_state import ( + get_sp_parallel_rank, + get_sp_world_size, +) + + +def zimage_preprocess_text(prompt: str): + messages = [ + {"role": "user", "content": prompt}, + ] + return messages + + +def zimage_postprocess_text(outputs: BaseEncoderOutput, _text_inputs) -> torch.Tensor: + device = outputs.hidden_states[-2].device + prompt_mask = _text_inputs.attention_mask.to(device).bool() + return outputs.hidden_states[-2][0][prompt_mask[0]] + + +class TransformersModelConfig(EncoderConfig): + tokenizer_kwargs: dict = field(default_factory=lambda: {}) + + +@dataclass +class 
ZImagePipelineConfig(ImagePipelineConfig): + should_use_guidance: bool = False + task_type: ModelTaskType = ModelTaskType.T2I + dit_config: DiTConfig = field(default_factory=ZImageDitConfig) + vae_config: VAEConfig = field(default_factory=FluxVAEConfig) + text_encoder_configs: tuple[EncoderConfig, ...] = field( + default_factory=lambda: (Qwen3TextConfig(),) + ) + + preprocess_text_funcs: tuple[Callable, ...] = field( + default_factory=lambda: (zimage_preprocess_text,) + ) + postprocess_text_funcs: tuple[Callable, ...] = field( + default_factory=lambda: (zimage_postprocess_text,) + ) + + SEQ_LEN_MULTIPLE: int = 32 + PATCH_SIZE: int = 2 + F_PATCH_SIZE: int = 1 + + def tokenize_prompt(self, prompts: list[str], tokenizer, tok_kwargs) -> dict: + # flatten to 1-d list + inputs = tokenizer.apply_chat_template( + prompts, + tokenize=True, + add_generation_prompt=True, + enable_thinking=True, + padding="max_length", + max_length=512, # TODO (yhyang201): set max length according to config + truncation=True, + return_tensors="pt", + return_dict=True, + ) + return inputs + + @staticmethod + def _ceil_to_multiple(x: int, m: int) -> int: + if m <= 0: + return x + return int(math.ceil(x / m) * m) + + def _build_zimage_sp_plan(self, batch) -> dict: + """Build a minimal SP plan on batch for zimage (spatial sharding + cap sharding).""" + sp_size = get_sp_world_size() + rank = get_sp_parallel_rank() + + raw_latent_shape = getattr(batch, "raw_latent_shape", None) + if raw_latent_shape is not None and len(raw_latent_shape) >= 5: + H = int(raw_latent_shape[3]) + W = int(raw_latent_shape[4]) + else: + H = int( + batch.height // self.vae_config.arch_config.spatial_compression_ratio + ) + W = int( + batch.width // self.vae_config.arch_config.spatial_compression_ratio + ) + + # Rule: shard along the larger spatial dimension (W/H), implemented via optional H/W transpose. + # Choose the larger of H and W for sharding, so H_eff = max(H, W). 
+ swap_hw = W > H + H_eff = W if swap_hw else H + W_eff = H if swap_hw else W + + # ZImage uses PATCH_SIZE=2 for spatial patchify; shard in token space and convert back to latent rows. + H_tok = H_eff // self.PATCH_SIZE + W_tok = W_eff // self.PATCH_SIZE + H_tok_pad = self._ceil_to_multiple(H_tok, sp_size) + H_tok_local = H_tok_pad // sp_size + h0_tok = rank * H_tok_local + + # Cap/text sharding: avoid duplicating cap tokens across ranks. + cap_len = ( + int(batch.prompt_embeds[0].size(0)) + if getattr(batch, "prompt_embeds", None) + else 0 + ) + cap_total = self._ceil_to_multiple(cap_len, self.SEQ_LEN_MULTIPLE * sp_size) + cap_local = cap_total // sp_size + cap_start = rank * cap_local + + plan = { + "sp_size": sp_size, + "rank": rank, + "swap_hw": swap_hw, + "H": H, + "W": W, + "H_eff": H_eff, + "W_eff": W_eff, + "H_tok": H_tok, + "W_tok": W_tok, + "H_tok_pad": H_tok_pad, + "H_tok_local": H_tok_local, + "h0_tok": h0_tok, + "cap_total": cap_total, + "cap_local": cap_local, + "cap_start": cap_start, + } + batch._zimage_sp_plan = plan + return plan + + def _get_zimage_sp_plan(self, batch) -> dict: + plan = getattr(batch, "_zimage_sp_plan", None) + sp_size = get_sp_world_size() + if plan is None or plan.get("sp_size") != sp_size: + plan = self._build_zimage_sp_plan(batch) + return plan + + def _shard_cap(self, cap: torch.Tensor, plan: dict) -> torch.Tensor: + """cap: [L, D] -> [cap_local, D], padded by repeating last token.""" + if plan["sp_size"] <= 1: + return cap + # print(f"cap shape: {cap.shape}") # [L, 2560] for zimage-turbo + L = cap.size(0) + cap_total = plan["cap_total"] + if cap_total > L: + cap = torch.cat([cap, cap[-1:].repeat(cap_total - L, 1)], dim=0) + start = plan["cap_start"] + local = plan["cap_local"] + return cap[start : start + local] + + def get_pos_prompt_embeds(self, batch): + # Keep ZImage model signature: encoder_hidden_states is List[Tensor] + if get_sp_world_size() <= 1: + return batch.prompt_embeds + plan = self._get_zimage_sp_plan(batch) 
+ return [self._shard_cap(batch.prompt_embeds[0], plan)] + + def shard_latents_for_sp(self, batch, latents): + sp_size = get_sp_world_size() + if sp_size <= 1 or latents.dim() != 5: + return latents, False + + plan = self._get_zimage_sp_plan(batch) + + # Layout: [B, C, T, H, W]. Always shard on dim=3 by optionally swapping H/W. + if plan["swap_hw"]: + latents = latents.transpose(3, 4).contiguous() + + # Pad on effective-H so that H_tok is divisible by sp. + H_eff = latents.size(3) + + H_tok = H_eff // self.PATCH_SIZE + pad_tok = plan["H_tok_pad"] - H_tok + pad_lat = pad_tok * self.PATCH_SIZE + if pad_lat > 0: + pad = latents[:, :, :, -1:, :].repeat(1, 1, 1, pad_lat, 1) + latents = torch.cat([latents, pad], dim=3) + h0 = plan["h0_tok"] * self.PATCH_SIZE + h1 = (plan["h0_tok"] + plan["H_tok_local"]) * self.PATCH_SIZE + latents = latents[:, :, :, h0:h1, :] + + batch._zimage_sp_swap_hw = plan["swap_hw"] + return latents, True + + def gather_latents_for_sp(self, latents): + # Gather on effective-H dim=3 (matches shard_latents_for_sp); swap-back is handled in post_denoising_loop. + latents = latents.contiguous() + if get_sp_world_size() <= 1 or latents.dim() != 5: + return latents + return sequence_model_parallel_all_gather(latents, dim=3) + + def post_denoising_loop(self, latents, batch): + # Restore swapped H/W and crop padded spatial dims before final reshape. 
+ if latents.dim() == 5 and getattr(batch, "_zimage_sp_swap_hw", False): + latents = latents.transpose(3, 4).contiguous() + raw_latent_shape = getattr(batch, "raw_latent_shape", None) + if raw_latent_shape is not None and latents.dim() == 5: + latents = latents[:, :, :, : raw_latent_shape[3], : raw_latent_shape[4]] + + bs, channels, num_frames, height, width = latents.shape + if raw_latent_shape is not None and num_frames > raw_latent_shape[2]: + latents = latents[:, :, : raw_latent_shape[2], :, :] + num_frames = raw_latent_shape[2] + if num_frames != 1: + return latents[:, :, 0, :, :] + return latents.view(bs, channels, height, width) + + def get_freqs_cis(self, prompt_embeds, width, height, device, rotary_emb, batch): + def create_coordinate_grid(size, start=None, device=None): + if start is None: + start = (0 for _ in size) + + axes = [ + torch.arange(x0, x0 + span, dtype=torch.int32, device=device) + for x0, span in zip(start, size) + ] + grids = torch.meshgrid(axes, indexing="ij") + return torch.stack(grids, dim=-1) + + sp_size = get_sp_world_size() + if sp_size > 1: + # SP path: build local-only freqs_cis matching local cap/x. + plan = self._get_zimage_sp_plan(batch) + + # cap (local) + cap_pos_ids = create_coordinate_grid( + size=(plan["cap_local"], 1, 1), + start=(1 + plan["cap_start"], 0, 0), + device=device, + ).flatten(0, 2) + cap_freqs_cis = rotary_emb(cap_pos_ids) + + # image (local, effective H-shard). Use cap_total for a stable offset across ranks/passes. 
+ F_tokens = 1 + H_tokens_local = plan["H_tok_local"] + W_tokens = plan["W_tok"] + img_pos_ids = create_coordinate_grid( + size=(F_tokens, H_tokens_local, W_tokens), + start=(plan["cap_total"] + 1, plan["h0_tok"], 0), + device=device, + ).flatten(0, 2) + img_pad_len = (-img_pos_ids.shape[0]) % self.SEQ_LEN_MULTIPLE + if img_pad_len: + pad_ids = create_coordinate_grid( + size=(1, 1, 1), start=(0, 0, 0), device=device + ).flatten(0, 2) + img_pos_ids = torch.cat( + [img_pos_ids, pad_ids.repeat(img_pad_len, 1)], dim=0 + ) + x_freqs_cis = rotary_emb(img_pos_ids) + return (cap_freqs_cis, x_freqs_cis) + + cap_ori_len = prompt_embeds.size(0) + cap_padding_len = (-cap_ori_len) % self.SEQ_LEN_MULTIPLE + cap_padded_pos_ids = create_coordinate_grid( + size=(cap_ori_len + cap_padding_len, 1, 1), + start=(1, 0, 0), + device=device, + ).flatten(0, 2) + + F = 1 + H = height // self.vae_config.arch_config.spatial_compression_ratio + W = width // self.vae_config.arch_config.spatial_compression_ratio + + pH, pW = self.PATCH_SIZE, self.PATCH_SIZE + pF = self.F_PATCH_SIZE + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + image_ori_len = F_tokens * H_tokens * W_tokens + image_padding_len = (-image_ori_len) % self.SEQ_LEN_MULTIPLE + + image_ori_pos_ids = create_coordinate_grid( + size=(F_tokens, H_tokens, W_tokens), + start=(cap_ori_len + cap_padding_len + 1, 0, 0), + device=device, + ).flatten(0, 2) + image_padding_pos_ids = ( + create_coordinate_grid( + size=(1, 1, 1), + start=(0, 0, 0), + device=device, + ) + .flatten(0, 2) + .repeat(image_padding_len, 1) + ) + image_padded_pos_ids = torch.cat( + [image_ori_pos_ids, image_padding_pos_ids], dim=0 + ) + cap_freqs_cis = rotary_emb(cap_padded_pos_ids) + x_freqs_cis = rotary_emb(image_padded_pos_ids) + return (cap_freqs_cis, x_freqs_cis) + + def prepare_pos_cond_kwargs(self, batch, device, rotary_emb, dtype): + return { + "freqs_cis": self.get_freqs_cis( + batch.prompt_embeds[0], + batch.width, + batch.height, + device, + 
rotary_emb, + batch, + ), + } + + def prepare_neg_cond_kwargs(self, batch, device, rotary_emb, dtype): + return { + "freqs_cis": self.get_freqs_cis( + batch.prompt_embeds[0], + batch.width, + batch.height, + device, + rotary_emb, + batch, + ), + } diff --git a/sglang/python/sglang/multimodal_gen/configs/quantization.py b/sglang/python/sglang/multimodal_gen/configs/quantization.py new file mode 100644 index 0000000000000000000000000000000000000000..f0bd7f9c8a5b552a8ecf16fee9f6b7afcb8ed822 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/quantization.py @@ -0,0 +1,179 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import os +import re +from dataclasses import dataclass +from typing import Any + +import torch + +from sglang.multimodal_gen.runtime.layers.quantization.configs.nunchaku_config import ( + is_nunchaku_available, +) +from sglang.multimodal_gen.runtime.platforms import current_platform +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.utils import StoreBoolean + +logger = init_logger(__name__) + + +@dataclass +class NunchakuSVDQuantArgs: + """CLI-facing configuration for Nunchaku (SVDQuant) inference. + + This is intentionally lightweight and only contains arguments needed to + construct `runtime.layers.quantization.nunchaku_config.NunchakuConfig`. 
+ """ + + enable_svdquant: bool = False + transformer_weights_path: str | None = None + quantization_precision: str | None = None # "int4" or "nvfp4" + quantization_rank: int | None = None + quantization_act_unsigned: bool = False + + def _adjust_config(self) -> None: + """infer precision and rank from filename if not provided""" + if self.transformer_weights_path and not self.enable_svdquant: + filename = os.path.basename(self.transformer_weights_path) + if re.search(r"svdq-(int4|fp4)_r(\d+)", filename): + self.enable_svdquant = True + + if not self.enable_svdquant or not self.transformer_weights_path: + return + + inferred_precision = None + inferred_rank = None + + filename = os.path.basename(self.transformer_weights_path) + # Expected pattern: svdq-{precision}_r{rank}-... + # e.g., svdq-int4_r32-qwen-image.safetensors + match = re.search(r"svdq-(int4|fp4)_r(\d+)", filename) + + if match: + p_str, r_str = match.groups() + inferred_precision = "nvfp4" if p_str == "fp4" else "int4" + inferred_rank = int(r_str) + + if self.quantization_precision is None: + self.quantization_precision = inferred_precision or "int4" + if inferred_precision: + logger.info( + f"inferred --quantization-precision: {self.quantization_precision} " + f"from --transformer-weights-path: {self.transformer_weights_path}" + ) + + if self.quantization_rank is None: + self.quantization_rank = inferred_rank or 32 + if inferred_rank: + logger.info( + f"inferred --quantization-rank: {self.quantization_rank} " + f"from --transformer-weights-path: {self.transformer_weights_path}" + ) + + def validate(self) -> None: + # TODO: warn if the served model doesn't support nunchaku + self._adjust_config() + + if not self.enable_svdquant: + return + + if not current_platform.is_cuda(): + raise ValueError( + "Nunchaku SVDQuant is only supported on NVIDIA CUDA GPUs " + "(Ampere SM8x or SM12x)." 
+ ) + + device_count = torch.cuda.device_count() + + unsupported: list[str] = [] + for i in range(device_count): + major, minor = torch.cuda.get_device_capability(i) + if major == 9: + unsupported.append(f"cuda:{i} (SM{major}{minor}, Hopper)") + elif major not in (8, 12): + unsupported.append(f"cuda:{i} (SM{major}{minor})") + + if unsupported: + raise ValueError( + "Nunchaku SVDQuant is currently only supported on Ampere (SM8x) or SM12x GPUs; " + "Hopper (SM90) is not supported. " + f"Unsupported devices: {', '.join(unsupported)}. " + "Disable it with --enable-svdquant false." + ) + + if not self.transformer_weights_path: + raise ValueError( + "--enable-svdquant requires --transformer-weights-path to be set" + ) + + if not is_nunchaku_available(): + raise ValueError( + "Nunchaku is enabled, but not installed. Please refer to https://nunchaku.tech/docs/nunchaku/installation/installation.html for detailed installation methods." + ) + + if self.quantization_precision not in ("int4", "nvfp4"): + raise ValueError( + f"Invalid --quantization-precision: {self.quantization_precision}. " + "Must be one of: int4, nvfp4" + ) + + if self.quantization_rank <= 0: + raise ValueError( + f"Invalid --quantization-rank: {self.quantization_rank}. Must be > 0" + ) + + @staticmethod + def add_cli_args(parser) -> None: + parser.add_argument( + "--enable-svdquant", + action=StoreBoolean, + default=NunchakuSVDQuantArgs.enable_svdquant, + help="Enable Nunchaku SVDQuant (W4A4-style) inference.", + ) + parser.add_argument( + "--transformer-weights-path", + type=str, + default=NunchakuSVDQuantArgs.transformer_weights_path, + help=( + "Path to pre-quantized transformer weights. Can be a single .safetensors " + "file, a directory, or a HuggingFace repo ID. Used by Nunchaku (SVDQuant) and quantized single-file checkpoints." + ), + ) + parser.add_argument( + "--quantization-precision", + type=str, + default=None, + help="Quantization precision: int4 or nvfp4. 
If not specified, inferred from model path or defaults to int4.", + ) + parser.add_argument( + "--quantization-rank", + type=int, + default=None, + help="SVD low-rank dimension (e.g., 32). If not specified, inferred from model path or defaults to 32.", + ) + parser.add_argument( + "--quantization-act-unsigned", + action=StoreBoolean, + default=NunchakuSVDQuantArgs.quantization_act_unsigned, + help="Use unsigned activation quantization (if supported).", + ) + + @classmethod + def from_dict(cls, kwargs: dict[str, Any]) -> "NunchakuSVDQuantArgs": + # Map CLI/config keys to dataclass fields (keep backwards compatibility). + path = ( + kwargs.get("transformer_weights_path") + or kwargs.get("transformer_quantized_path") + or kwargs.get("quantized_model_path") + ) + return cls( + enable_svdquant=bool(kwargs.get("enable_svdquant", cls.enable_svdquant)), + transformer_weights_path=path, + quantization_precision=kwargs.get("quantization_precision"), + quantization_rank=kwargs.get("quantization_rank"), + quantization_act_unsigned=bool( + kwargs.get("quantization_act_unsigned", cls.quantization_act_unsigned) + ), + ) diff --git a/sglang/python/sglang/multimodal_gen/configs/sample/__init__.py b/sglang/python/sglang/multimodal_gen/configs/sample/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0a81bce02754733ea143f7c14cdffe8a2132a294 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/sample/__init__.py @@ -0,0 +1,8 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +from sglang.multimodal_gen.configs.sample.diffusers_generic import ( + DiffusersGenericSamplingParams, +) +from sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams + +__all__ = ["SamplingParams", "DiffusersGenericSamplingParams"] diff --git a/sglang/python/sglang/multimodal_gen/configs/sample/diffusers_generic.py b/sglang/python/sglang/multimodal_gen/configs/sample/diffusers_generic.py new file mode 100644 index 
0000000000000000000000000000000000000000..d30ea5ddb51d002380b69a464f6952c5ad135d3d --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/sample/diffusers_generic.py @@ -0,0 +1,52 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Generic sampling parameters for diffusers backend. + +This module provides generic sampling parameters that work with any diffusers pipeline. +""" + +from dataclasses import dataclass, field +from typing import Any + +from sglang.multimodal_gen.configs.sample.sampling_params import ( + DataType, + SamplingParams, +) + + +@dataclass +class DiffusersGenericSamplingParams(SamplingParams): + """ + Generic sampling parameters for diffusers backend. + + These parameters cover the most common options across different diffusers pipelines. + The diffusers pipeline will use whichever parameters it supports. + + For pipeline-specific parameters, use `diffusers_kwargs` dict which will be + passed directly to the diffusers pipeline call. + """ + + # Override defaults with more conservative values that work across pipelines + num_frames: int = 1 # default to image generation + height: int = 1024 + width: int = 1024 + num_inference_steps: int = 30 + guidance_scale: float = 7.5 + negative_prompt: str = "" + + # extra kwargs to pass directly to the diffusers pipeline + # example: {"output_type": "latent", "return_dict": False} + diffusers_kwargs: dict[str, Any] = field(default_factory=dict) + + def __post_init__(self) -> None: + if self.num_frames > 1: + self.data_type = DataType.VIDEO + else: + self.data_type = DataType.IMAGE + + if self.width is None: + self.width_not_provided = True + self.width = 1024 + if self.height is None: + self.height_not_provided = True + self.height = 1024 diff --git a/sglang/python/sglang/multimodal_gen/configs/sample/flux.py b/sglang/python/sglang/multimodal_gen/configs/sample/flux.py new file mode 100644 index 0000000000000000000000000000000000000000..692df8332ac2ec33061432d3dc7afaa4b315b807 --- /dev/null +++ 
@dataclass
class GlmImageSamplingParams(SamplingParams):
    """Sampling defaults for GLM image generation (single frame)."""

    # The annotation is required: without it the assignment is only a plain
    # class attribute, the dataclass machinery does not see it as a field
    # override, and instances keep the long default negative prompt inherited
    # from SamplingParams instead of the empty string intended here.
    negative_prompt: str = ""

    num_frames: int = 1
    guidance_scale: float = 1.5
    num_inference_steps: int = 30
SamplingParams + + +@dataclass +class HeliosT2VSamplingParams(SamplingParams): + # Video parameters + height: int = 384 + width: int = 640 + num_frames: int = 99 + fps: int = 24 + + # Denoising stage + guidance_scale: float = 5.0 + negative_prompt: str = ( + "Bright tones, overexposed, static, blurred details, subtitles, style, " + "works, paintings, images, static, overall gray, worst quality, low quality, " + "JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, " + "poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, " + "still picture, messy background, three legs, many people in the background, " + "walking backwards" + ) + num_inference_steps: int = 50 + + # Helios T2V supported resolutions + supported_resolutions: list[tuple[int, int]] | None = field( + default_factory=lambda: [ + (640, 384), # ~5:3 + (384, 640), # ~3:5 + (832, 480), # ~16:9-ish + (480, 832), # ~9:16-ish + ] + ) + + +@dataclass +class HeliosMidSamplingParams(HeliosT2VSamplingParams): + """Sampling params for Helios-Mid (Stage 2 pyramid SR).""" + + num_inference_steps: int = 20 + + +@dataclass +class HeliosDistilledSamplingParams(HeliosT2VSamplingParams): + """Sampling params for Helios-Distilled (DMD, no CFG needed).""" + + guidance_scale: float = 1.0 + num_inference_steps: int = 10 diff --git a/sglang/python/sglang/multimodal_gen/configs/sample/hunyuan.py b/sglang/python/sglang/multimodal_gen/configs/sample/hunyuan.py new file mode 100644 index 0000000000000000000000000000000000000000..ae69dbd62ccd019674aab925f5925b3b25243a31 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/sample/hunyuan.py @@ -0,0 +1,55 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field + +from sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams +from sglang.multimodal_gen.configs.sample.teacache import TeaCacheParams + + +@dataclass 
class HunyuanSamplingParams(SamplingParams):
    """Sampling defaults for HunyuanVideo: 1280x720, 125 frames at 24 fps, 50 steps."""

    num_inference_steps: int = 50

    num_frames: int = 125
    height: int = 720
    width: int = 1280
    fps: int = 24

    guidance_scale: float = 1.0

    # HunyuanVideo supported resolutions.
    # Tuples are (width, height): SamplingParams._adjust compares
    # (self.width, self.height) against this list. The ratio comments below
    # are written as width:height.
    supported_resolutions: list[tuple[int, int]] | None = field(
        default_factory=lambda: [
            # 540p resolutions
            (960, 544),  # 16:9 (landscape)
            (544, 960),  # 9:16 (portrait)
            (832, 624),  # 4:3
            (624, 832),  # 3:4
            (720, 720),  # 1:1
            # 720p resolutions (recommended)
            (1280, 720),  # 16:9 (landscape, matches the default width/height)
            (720, 1280),  # 9:16 (portrait)
            (832, 1104),  # 3:4
            (1104, 832),  # 4:3
            (960, 960),  # 1:1
        ]
    )

    # TeaCache settings passed to TeaCacheParams for this model
    teacache_params: TeaCacheParams = field(
        default_factory=lambda: TeaCacheParams(
            teacache_thresh=0.15,
            coefficients=[
                7.33226126e02,
                -4.01131952e02,
                6.75869174e01,
                -3.14987800e00,
                9.61237896e-02,
            ],
        )
    )
self.guidance_scale = max(5.0, min(self.guidance_scale, 6.5)) + super().__post_init__() diff --git a/sglang/python/sglang/multimodal_gen/configs/sample/ltx_2.py b/sglang/python/sglang/multimodal_gen/configs/sample/ltx_2.py new file mode 100644 index 0000000000000000000000000000000000000000..5d2e92b58f067230016f2dd75efb71437f78a547 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/sample/ltx_2.py @@ -0,0 +1,40 @@ +import dataclasses + +from sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams + + +@dataclasses.dataclass +class LTX2SamplingParams(SamplingParams): + """Sampling parameters for LTX-2.""" + + # Match the reference defaults used by ltx-pipelines (one-stage). + # See: LTX-2/packages/ltx-pipelines/src/ltx_pipelines/utils/constants.py + seed: int = 10 + + # Video parameters + height: int = 512 + width: int = 768 + num_frames: int = 121 + fps: int = 24 + + # Audio specific + generate_audio: bool = True + + # Denoising parameters + guidance_scale: float = 4.0 + num_inference_steps: int = 40 + + # Match ltx-pipelines default negative prompt (covers video + audio artifacts). 
+ negative_prompt: str = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." 
@dataclass
class MOVA_360P_SamplingParams(MOVASamplingParams):
    """Sampling defaults for the MOVA 360P variant, restricted to 352/640 resolutions."""

    # Video parameters (MOVA 360P); same values as the MOVASamplingParams defaults
    height: int = 352
    width: int = 640

    # MOVA 360P supported resolutions.
    # SamplingParams._adjust reads these tuples as (width, height); both
    # orientations are listed, so the default 640x352 matches the second entry.
    # NOTE(review): when neither width nor height is set, _adjust falls back to
    # entry 0, i.e. width=352, height=640 — confirm that ordering is intended.
    supported_resolutions: list[tuple[int, int]] = field(
        default_factory=lambda: [
            (352, 640),
            (640, 352),
        ]
    )
@dataclass
class QwenImageSamplingParams(SamplingParams):
    """Sampling defaults for Qwen-Image: single frame, 50 steps, guidance scale 4.0."""

    # A single space rather than an empty string: SamplingParams._adjust skips
    # strip() for whitespace-only negative prompts, so this value is preserved.
    negative_prompt: str = " "
    num_frames: int = 1
    # Denoising stage
    guidance_scale: float = 4.0
    num_inference_steps: int = 50
import init_logger +from sglang.multimodal_gen.utils import StoreBoolean, expand_path_fields + +logger = init_logger(__name__) + +if TYPE_CHECKING: + from sglang.multimodal_gen.runtime.server_args import ServerArgs + + +def _json_safe(obj: Any): + """ + Recursively convert objects to JSON-serializable forms. + - Enums -> their name + - Sets/Tuples -> lists + - Dicts/Lists -> recursively processed + """ + if isinstance(obj, Enum): + return obj.name + if isinstance(obj, dict): + return {k: _json_safe(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple, set)): + return [_json_safe(v) for v in obj] + return obj + + +def generate_request_id() -> str: + return str(uuid.uuid4()) + + +def _sanitize_filename(name: str, replacement: str = "_", max_length: int = 150) -> str: + """Create a filesystem- and ffmpeg-friendly filename. + + - Normalize to ASCII (drop accents and unsupported chars) + - Replace spaces with underscores + - Replace any char not in [A-Za-z0-9_.-] with replacement + - Collapse multiple underscores + - Trim leading/trailing dots/underscores and limit length + """ + normalized = unicodedata.normalize("NFKD", name) + ascii_name = normalized.encode("ascii", "ignore").decode("ascii") + ascii_name = ascii_name.replace(" ", "_") + ascii_name = re.sub(r"[^A-Za-z0-9._-]", replacement, ascii_name) + ascii_name = re.sub(r"_+", "_", ascii_name).strip("._") + if not ascii_name: + ascii_name = "output" + if max_length and len(ascii_name) > max_length: + ascii_name = ascii_name[:max_length] + return ascii_name + + +class DataType(Enum): + IMAGE = auto() + VIDEO = auto() + MESH = auto() + + def get_default_extension(self) -> str: + if self == DataType.IMAGE: + return "png" + if self == DataType.VIDEO: + return "mp4" + return "glb" + + +@dataclass +class SamplingParams: + """ + Sampling parameters for generation. 
+ """ + + data_type: DataType = DataType.VIDEO + + request_id: str | None = None + + # All fields below are copied from ForwardBatch + + # Image inputs + image_path: str | list[str] | None = None + + # Text inputs + prompt: str | list[str] | None = None + negative_prompt: str = ( + "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + ) + prompt_path: str | None = None + output_path: str | None = None + output_file_name: str | None = None + output_quality: str | None = "default" + output_compression: int | None = None + + # Frame interpolation + enable_frame_interpolation: bool = False + frame_interpolation_exp: int = 1 # 1=2x, 2=4x + frame_interpolation_scale: float = 1.0 # RIFE inference scale (0.5 for high-res) + frame_interpolation_model_path: str | None = ( + None # local dir or HF repo ID with flownet.pkl (default: elfgum/RIFE-4.22.lite) + ) + + # Batch info + num_outputs_per_prompt: int = 1 + seed: int = 42 + generator_device: str = "cuda" # Device for random generator: "cuda" or "cpu" + + # Original dimensions (before VAE scaling) + num_frames: int = 1 # Default for image models + num_frames_round_down: bool = ( + False # Whether to round down num_frames if it's not divisible by num_gpus + ) + height: int | None = None + width: int | None = None + # NOTE: this is temporary, we need a way to know if width or height is not provided, or do the image resize earlier + height_not_provided: bool = False + width_not_provided: bool = False + fps: int = 24 + + # Resolution validation + supported_resolutions: list[tuple[int, int]] | None = ( + None # None means all resolutions allowed + ) + + # Denoising parameters + 
num_inference_steps: int = None + guidance_scale: float = 1.0 + guidance_scale_2: float = None + true_cfg_scale: float = None # for CFG vs guidance distillation (e.g., QwenImage) + guidance_rescale: float = 0.0 + cfg_normalization: float | bool = 0.0 + boundary_ratio: float | None = None + + # TeaCache parameters + enable_teacache: bool = False + + # Profiling + profile: bool = False + num_profiled_timesteps: int = 5 + profile_all_stages: bool = False + + # Debugging + debug: bool = False + perf_dump_path: str | None = None + + # Misc + save_output: bool = True + return_frames: bool = False + return_trajectory_latents: bool = False # returns all latents for each timestep + return_trajectory_decoded: bool = False # returns decoded latents for each timestep + # if True, disallow user params to override subclass-defined protected fields + no_override_protected_fields: bool = False + # whether to adjust num_frames for multi-GPU friendly splitting (default: True) + adjust_frames: bool = True + # if True, suppress verbose logging for this request + suppress_logs: bool = False + + return_file_paths_only: bool = True + enable_sequence_shard: bool | None = None + + def _set_output_file_ext(self): + # add extension if needed + if not any( + self.output_file_name.endswith(ext) + for ext in [".mp4", ".jpg", ".png", ".webp", ".obj", ".glb"] + ): + self.output_file_name = ( + f"{self.output_file_name}.{self.data_type.get_default_extension()}" + ) + + def _set_output_file_name(self): + # settle output_file_name + if ( + self.output_file_name is None + and self.prompt + and isinstance(self.prompt, str) + ): + # generate a random filename + # get a hash of current params + params_dict = dataclasses.asdict(self) + # Avoid recursion + params_dict["output_file_name"] = "" + + # Convert to a stable JSON string + params_str = json.dumps(_json_safe(params_dict), sort_keys=True) + # Create a hash + hasher = hashlib.sha256() + hasher.update(params_str.encode("utf-8")) + param_hash = 
hasher.hexdigest()[:8] + + timestamp = time.strftime("%Y%m%d-%H%M%S") + base = f"{self.prompt[:100]}_{timestamp}_{param_hash}" + self.output_file_name = base + + if self.output_file_name is None: + timestamp = time.strftime("%Y%m%d-%H%M%S") + self.output_file_name = f"output_{timestamp}" + + self.output_file_name = _sanitize_filename(self.output_file_name) + + # Ensure a proper extension is present + self._set_output_file_ext() + + def __post_init__(self) -> None: + assert self.num_frames >= 1 + + if self.width is None: + self.width_not_provided = True + if self.height is None: + self.height_not_provided = True + + # Handle output_quality to output_compression conversion + if self.output_compression is None and self.output_quality is not None: + self.output_compression = self._adjust_output_quality( + self.output_quality, self.data_type + ) + + self._validate() + + # Allow env var to override num_inference_steps (for faster CI testing on AMD) + env_steps = os.environ.get("SGLANG_TEST_NUM_INFERENCE_STEPS") + if env_steps is not None and self.num_inference_steps is not None: + self.num_inference_steps = int(env_steps) + + def _adjust_output_quality(self, output_quality: str, data_type: DataType) -> int: + """Convert output_quality string to compression level.""" + output_quality_mapper = {"maximum": 100, "high": 90, "medium": 55, "low": 35} + if output_quality == "default": + return 50 if data_type == DataType.VIDEO else 75 + return output_quality_mapper.get(output_quality) + + def _validate(self): + """ + check if the sampling params is correct by itself + """ + if self.prompt_path and not self.prompt_path.endswith(".txt"): + raise ValueError( + f"prompt_path must be a txt file, got {self.prompt_path!r}" + ) + + # These are always required to be sane regardless of pipeline. 
+ if ( + not isinstance(self.num_outputs_per_prompt, int) + or self.num_outputs_per_prompt <= 0 + ): + raise ValueError( + f"num_outputs_per_prompt must be a positive int, got {self.num_outputs_per_prompt!r}" + ) + + # Used by seconds() and video writer; fps <= 0 is always invalid. + if not isinstance(self.fps, int) or self.fps <= 0: + raise ValueError(f"fps must be a positive int, got {self.fps!r}") + + # num_frames is already asserted in __post_init__, but keep a friendly error here too + # (e.g., when validation is triggered from other code paths). + if not isinstance(self.num_frames, int) or self.num_frames <= 0: + raise ValueError( + f"num_frames must be a positive int, got {self.num_frames!r}" + ) + + if self.num_inference_steps is not None: + if ( + not isinstance(self.num_inference_steps, int) + or self.num_inference_steps <= 0 + ): + raise ValueError( + f"num_inference_steps must be a positive int, got {self.num_inference_steps!r}" + ) + + # Numeric hyperparams should not be NaN/Inf and should be within basic ranges. + # Note: bool is a subclass of int; reject it explicitly to avoid silent surprises. 
+ def _finite_non_negative_float( + name: str, value: Any, allow_none: bool = True + ) -> None: + if value is None and allow_none: + return + if isinstance(value, bool) or not isinstance(value, (int, float)): + raise ValueError(f"{name} must be a number, got {value!r}") + if not math.isfinite(float(value)): + raise ValueError(f"{name} must be finite, got {value!r}") + if float(value) < 0.0: + raise ValueError(f"{name} must be non-negative, got {value!r}") + + _finite_non_negative_float( + "guidance_scale", self.guidance_scale, allow_none=True + ) + _finite_non_negative_float( + "guidance_scale_2", self.guidance_scale_2, allow_none=True + ) + _finite_non_negative_float( + "true_cfg_scale", self.true_cfg_scale, allow_none=True + ) + _finite_non_negative_float( + "guidance_rescale", self.guidance_rescale, allow_none=False + ) + + if self.cfg_normalization is None: + self.cfg_normalization = 0.0 + elif isinstance(self.cfg_normalization, bool): + self.cfg_normalization = 1.0 if self.cfg_normalization else 0.0 + + if self.boundary_ratio is not None: + if isinstance(self.boundary_ratio, bool) or not isinstance( + self.boundary_ratio, (int, float) + ): + raise ValueError( + f"boundary_ratio must be a number, got {self.boundary_ratio!r}" + ) + if not math.isfinite(float(self.boundary_ratio)): + raise ValueError( + f"boundary_ratio must be finite, got {self.boundary_ratio!r}" + ) + if not (0.0 <= float(self.boundary_ratio) <= 1.0): + raise ValueError( + f"boundary_ratio must be within [0, 1], got {self.boundary_ratio!r}" + ) + + def check_sampling_param(self): + # Keep backward-compatibility for old call sites. 
+ self._validate() + + def _validate_with_pipeline_config(self, pipeline_config): + """ + check if the sampling params is compatible and valid with server_args + """ + if pipeline_config.task_type.requires_image_input(): + # requires image input + if self.image_path is None: + raise ValueError( + f"Served model with task type '{pipeline_config.task_type.name}' requires an 'image_path' input, but none was provided" + ) + + if not pipeline_config.task_type.accepts_image_input(): + # does not support image input + if self.image_path is not None: + raise ValueError( + f"input_reference is not supported for {pipeline_config.task_type.name} models." + ) + + def _adjust( + self, + server_args, + ): + """ + final adjustment, called after merged with user params + """ + expand_path_fields(self) + + # TODO: SamplingParams should not rely on ServerArgs + pipeline_config = server_args.pipeline_config + + if self.guidance_scale is None: + try: + from sglang.multimodal_gen.configs.pipeline_configs.hunyuan3d import ( + Hunyuan3D2PipelineConfig, + ) + + if isinstance(pipeline_config, Hunyuan3D2PipelineConfig): + self.guidance_scale = pipeline_config.guidance_scale + else: + self.guidance_scale = 1.0 + except ImportError: + self.guidance_scale = 1.0 + + self.data_type = server_args.pipeline_config.task_type.data_type() + + if self.output_path is None: + if server_args.output_path is not None: + self.output_path = server_args.output_path + logger.debug( + f"Overriding output_path with server configuration: {self.output_path}" + ) + else: + self.save_output = False + + # Process negative prompt + if self.negative_prompt is not None and not self.negative_prompt.isspace(): + # avoid stripping default negative prompt: ' ' for qwen-image + self.negative_prompt = self.negative_prompt.strip() + + # Validate dimensions + if self.num_frames <= 0: + raise ValueError( + f"height, width, and num_frames must be positive integers, got " + f"height={self.height}, width={self.width}, " + 
f"num_frames={self.num_frames}" + ) + + # Validate resolution against pipeline-specific supported resolutions + if self.height is None and self.width is None: + if self.supported_resolutions is not None: + self.width, self.height = self.supported_resolutions[0] + logger.info( + f"Resolution unspecified, using default: {self.supported_resolutions[0]}" + ) + + if self.height is not None and self.width is not None: + if self.supported_resolutions is not None: + if (self.width, self.height) not in self.supported_resolutions: + supported_str = ", ".join( + [f"{w}x{h}" for w, h in self.supported_resolutions] + ) + error_msg = ( + f"Unsupported resolution: {self.width}x{self.height}, output quality may suffer. " + f"Supported resolutions: {supported_str}" + ) + logger.warning(error_msg) + + pipeline_name_lower = server_args.pipeline_config.__class__.__name__.lower() + + if "wan" in pipeline_name_lower and ( + self.enable_sequence_shard is None or self.enable_sequence_shard + ): + self.enable_sequence_shard = True + logger.debug("Automatically enabled enable_sequence_shard") + else: + self.enable_sequence_shard = False + + if self.enable_sequence_shard: + self.adjust_frames = False + logger.info( + f"Sequence dimension shard is enabled, disabling frame adjustment for better performance" + ) + + if pipeline_config.task_type.is_image_gen(): + # settle num_frames + if not server_args.pipeline_config.allow_set_num_frames(): + logger.debug(f"Setting `num_frames` to 1 for image generation model") + self.num_frames = 1 + + else: + # mandatory frame adjusting logic, mod + # NOTE: We must apply adjust_num_frames BEFORE the SP alignment logic below. + # If we apply it after, adjust_num_frames might modify the frame count + # and break the divisibility constraint (alignment) required by num_gpus. 
+ original_num_frames = self.num_frames + self.num_frames = server_args.pipeline_config.adjust_num_frames( + original_num_frames + ) + logger.info( + "Adjusting number of frames from %s to %s based on model", + original_num_frames, + self.num_frames, + ) + + if self.adjust_frames: + # Adjust number of frames based on number of GPUs for video task + use_temporal_scaling_frames = ( + pipeline_config.vae_config.use_temporal_scaling_frames + ) + num_frames = self.num_frames + num_gpus = server_args.num_gpus + temporal_scale_factor = ( + pipeline_config.vae_config.arch_config.temporal_compression_ratio + ) + + if use_temporal_scaling_frames: + orig_latent_num_frames = ( + num_frames - 1 + ) // temporal_scale_factor + 1 + else: + orig_latent_num_frames = num_frames + + if orig_latent_num_frames % server_args.num_gpus != 0: + # Adjust latent frames to be divisible by number of GPUs + if self.num_frames_round_down: + # Ensure we have at least 1 batch per GPU + new_latent_num_frames = ( + max(1, (orig_latent_num_frames // num_gpus)) * num_gpus + ) + else: + new_latent_num_frames = ( + math.ceil(orig_latent_num_frames / num_gpus) * num_gpus + ) + + if use_temporal_scaling_frames: + # Convert back to number of frames, ensuring num_frames-1 is a multiple of temporal_scale_factor + new_num_frames = ( + new_latent_num_frames - 1 + ) * temporal_scale_factor + 1 + else: + new_num_frames = new_latent_num_frames + + logger.info( + "Adjusting number of frames from %s to %s based on number of GPUs (%s)", + self.num_frames, + new_num_frames, + server_args.num_gpus, + ) + self.num_frames = new_num_frames + + if not server_args.comfyui_mode: + self._set_output_file_name() + + @classmethod + def from_pretrained(cls, model_path: str, **kwargs) -> "SamplingParams": + from sglang.multimodal_gen.registry import get_model_info + + backend = kwargs.pop("backend", None) + model_id = kwargs.pop("model_id", None) + model_info = get_model_info(model_path, backend=backend, model_id=model_id) + 
@staticmethod
def from_user_sampling_params_args(
    model_path: str, server_args: "ServerArgs", *args, **kwargs
):
    """Build the final sampling params for a request.

    Model defaults are loaded via ``SamplingParams.from_pretrained``. If that
    raises AttributeError/ValueError and ``model_path`` is an existing
    ``.safetensors`` file, the params class is looked up in the registry from
    ``server_args.pipeline_class_name`` instead; any other failure is
    re-raised. User-supplied arguments (minus ``diffusers_kwargs``) are then
    merged in, followed by the server-dependent adjustment and the
    pipeline-config validation.
    """
    try:
        sampling_params = SamplingParams.from_pretrained(
            model_path, backend=server_args.backend, model_id=server_args.model_id
        )
    except (AttributeError, ValueError):
        # Only a bare safetensors file (no model_index.json) gets the
        # registry-based fallback; everything else propagates unchanged.
        is_safetensors_file = os.path.isfile(model_path) and model_path.endswith(
            ".safetensors"
        )
        if not is_safetensors_file:
            raise

        pipeline_class_name = getattr(server_args, "pipeline_class_name", None)

        from sglang.multimodal_gen.registry import get_pipeline_config_classes

        config_classes = None
        if pipeline_class_name:
            config_classes = get_pipeline_config_classes(pipeline_class_name)

        if config_classes is None:
            raise ValueError(
                f"Could not get pipeline config classes for {pipeline_class_name}"
            )

        _, sampling_params_cls = config_classes
        try:
            sampling_params = sampling_params_cls()
            logger.info(
                f"Using {sampling_params_cls.__name__} for {pipeline_class_name} safetensors file (no model_index.json): %s",
                model_path,
            )
        except Exception as import_error:
            # Best effort: fall back to the generic defaults.
            logger.warning(
                f"Failed to instantiate {sampling_params_cls.__name__}: {import_error}. "
                "Using default SamplingParams"
            )
            sampling_params = SamplingParams()

    # diffusers_kwargs is dropped before constructing the user-side params.
    user_kwargs = {
        key: value for key, value in kwargs.items() if key != "diffusers_kwargs"
    }
    user_sampling_params = SamplingParams(*args, **user_kwargs)
    # TODO: refactor
    sampling_params._merge_with_user_params(user_sampling_params)
    sampling_params._adjust(server_args)

    sampling_params._validate_with_pipeline_config(server_args.pipeline_config)

    return sampling_params
Use space-separated values for multiple prompts, e.g., --prompt 'prompt 1' 'prompt 2'", + ) + parser.add_argument( + "--negative-prompt", + type=str, + default=SamplingParams.negative_prompt, + help="Negative text prompt for generation", + ) + parser.add_argument( + "--prompt-path", + type=str, + default=SamplingParams.prompt_path, + help="Path to a text file containing the prompt", + ) + parser.add_argument( + "--output-file-name", + type=str, + default=SamplingParams.output_file_name, + help="Name of the output file", + ) + parser.add_argument( + "--output-quality", + type=str, + default=SamplingParams.output_quality, + help="Output quality setting (default, low, medium, high, maximum)", + ) + parser.add_argument( + "--output-compression", + type=int, + default=SamplingParams.output_compression, + help="Output compression level (0-100, higher means better quality but larger file size)", + ) + parser.add_argument( + "--num-outputs-per-prompt", + type=int, + default=SamplingParams.num_outputs_per_prompt, + help="Number of outputs to generate per prompt", + ) + parser.add_argument( + "--seed", + type=int, + default=SamplingParams.seed, + help="Random seed for generation", + ) + parser.add_argument( + "--generator-device", + type=str, + default=SamplingParams.generator_device, + choices=["cuda", "musa", "cpu"], + help="Device for random generator (cuda, musa or cpu). 
Default: cuda", + ) + parser.add_argument( + "--num-frames", + type=int, + default=SamplingParams.num_frames, + help="Number of frames to generate", + ) + parser.add_argument( + "--height", + type=int, + default=SamplingParams.height, + help="Height of generated output", + ) + parser.add_argument( + "--width", + type=int, + default=SamplingParams.width, + help="Width of generated output", + ) + # resolution shortcuts + parser.add_argument( + "--4k", + action="store_true", + dest="resolution_4k", + help="Set resolution to 4K (3840x2160)", + ) + parser.add_argument( + "--2k", + action="store_true", + dest="resolution_2k", + help="Set resolution to 2K (2560x1440)", + ) + parser.add_argument( + "--1080p", + action="store_true", + dest="resolution_1080p", + help="Set resolution to 1080p (1920x1080)", + ) + parser.add_argument( + "--720p", + action="store_true", + dest="resolution_720p", + help="Set resolution to 720p (1280x720)", + ) + + parser.add_argument( + "--fps", + type=int, + default=SamplingParams.fps, + help="Frames per second for saved output", + ) + parser.add_argument( + "--num-inference-steps", + type=int, + default=SamplingParams.num_inference_steps, + help="Number of denoising steps", + ) + parser.add_argument( + "--guidance-scale", + type=float, + default=SamplingParams.guidance_scale, + help="Classifier-free guidance scale", + ) + parser.add_argument( + "--guidance-scale-2", + type=float, + default=SamplingParams.guidance_scale_2, + dest="guidance_scale_2", + help="Secondary guidance scale for dual-guidance models (e.g., Wan low-noise expert)", + ) + parser.add_argument( + "--guidance-rescale", + type=float, + default=SamplingParams.guidance_rescale, + help="Guidance rescale factor", + ) + parser.add_argument( + "--cfg-normalization", + type=float, + default=SamplingParams.cfg_normalization, # type: ignore[arg-type] + dest="cfg_normalization", + help=("CFG renormalization factor (for Z-Image). 
"), + ) + parser.add_argument( + "--boundary-ratio", + type=float, + default=SamplingParams.boundary_ratio, + help="Boundary timestep ratio", + ) + parser.add_argument( + "--save-output", + action="store_true", + default=SamplingParams.save_output, + help="Whether to save the output to disk", + ) + parser.add_argument( + "--no-save-output", + action="store_false", + dest="save_output", + help="Don't save the output to disk", + ) + parser.add_argument( + "--return-frames", + action="store_true", + default=SamplingParams.return_frames, + help="Whether to return the raw frames", + ) + parser.add_argument( + "--image-path", + type=str, + nargs="+", + default=SamplingParams.image_path, + help=( + "Path(s) to input image(s) for image-to-image / image-to-video " + "generation. For multiple images, pass them as space-separated " + "values, e.g.: " + '--image-path "img1.png" "img2.png"' + ), + ) + parser.add_argument( + "--moba-config-path", + type=str, + default=None, + help="Path to a JSON file containing V-MoBA specific configurations.", + ) + parser.add_argument( + "--return-trajectory-latents", + action="store_true", + default=SamplingParams.return_trajectory_latents, + help="Whether to return the trajectory", + ) + parser.add_argument( + "--return-trajectory-decoded", + action="store_true", + default=SamplingParams.return_trajectory_decoded, + help="Whether to return the decoded trajectory", + ) + parser.add_argument( + "--diffusers-kwargs", + type=str, + default=None, + help="JSON string of extra kwargs to pass to diffusers pipeline. " + 'Example: \'{"output_type": "latent", "clip_skip": 2}\'', + ) + parser.add_argument( + "--no-override-protected-fields", + action="store_true", + default=SamplingParams.no_override_protected_fields, + help=( + "If set, disallow user params to override fields defined in subclasses." 
@classmethod
def get_cli_args(cls, args: argparse.Namespace):
    """Extract the parsed CLI values that correspond to dataclass fields of ``cls``.

    Resolution shortcut flags are applied first: the first flag that is set,
    checked in the order 4k, 2k, 1080p, 720p, overwrites ``args.width`` and
    ``args.height`` in place.

    Args:
        args: Parsed argparse namespace. It is mutated: width/height may be
            overwritten by a shortcut, and ``height_not_provided`` /
            ``width_not_provided`` are set to ``False``.

    Returns:
        dict mapping each name that is both a dataclass field of ``cls`` and an
        attribute already present on ``args`` to its value on ``args``.
    """
    # (namespace attribute, width, height) in priority order; only the first hit applies.
    resolution_shortcuts = (
        ("resolution_4k", 3840, 2160),
        ("resolution_2k", 2560, 1440),
        ("resolution_1080p", 1920, 1080),
        ("resolution_720p", 1280, 720),
    )
    for flag_name, shortcut_width, shortcut_height in resolution_shortcuts:
        if getattr(args, flag_name, False):
            args.width = shortcut_width
            args.height = shortcut_height
            break

    field_names = {f.name for f in dataclasses.fields(cls)}
    # The selection is computed before the two markers below are attached, so
    # the markers are only returned if the namespace already carried them.
    selected = field_names & set(vars(args).keys())
    args.height_not_provided = False
    args.width_not_provided = False
    return {name: getattr(args, name) for name in selected}
@dataclass
class WanTeaCacheParams(CacheParams):
    """TeaCache settings for Wan models.

    Wan keeps two coefficient lists and chooses between them with
    ``use_ret_steps``; the same flag also drives ``ret_steps`` and
    ``get_cutoff_steps``.
    """

    cache_type: str = "teacache"
    teacache_thresh: float = 0.0
    # Selects the "ret steps" variant of every derived value below.
    use_ret_steps: bool = True
    ret_steps_coeffs: list[float] = field(default_factory=list)
    non_ret_steps_coeffs: list[float] = field(default_factory=list)

    @property
    def coefficients(self) -> list[float]:
        """Return the coefficient list matching the current ``use_ret_steps`` mode."""
        return (
            self.ret_steps_coeffs if self.use_ret_steps else self.non_ret_steps_coeffs
        )

    @property
    def ret_steps(self) -> int:
        """Return 10 (5 * 2) when ``use_ret_steps`` is set, otherwise 2 (1 * 2)."""
        return 5 * 2 if self.use_ret_steps else 1 * 2

    def get_cutoff_steps(self, num_inference_steps: int) -> int:
        """Return ``2 * num_inference_steps``, minus 2 when ``use_ret_steps`` is off."""
        doubled_steps = num_inference_steps * 2
        return doubled_steps if self.use_ret_steps else doubled_steps - 2
SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field + +from sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams +from sglang.multimodal_gen.configs.sample.teacache import WanTeaCacheParams + + +@dataclass +class WanT2V_1_3B_SamplingParams(SamplingParams): + # Video parameters + height: int = 480 + width: int = 832 + num_frames: int = 81 + fps: int = 16 + + # Denoising stage + guidance_scale: float = 3.0 + negative_prompt: str = ( + "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + ) + num_inference_steps: int = 50 + + # Wan T2V 1.3B supported resolutions + supported_resolutions: list[tuple[int, int]] | None = field( + default_factory=lambda: [ + (832, 480), # 16:9 + (480, 832), # 9:16 + ] + ) + + teacache_params: WanTeaCacheParams = field( + default_factory=lambda: WanTeaCacheParams( + teacache_thresh=0.08, + ret_steps_coeffs=[ + -5.21862437e04, + 9.23041404e03, + -5.28275948e02, + 1.36987616e01, + -4.99875664e-02, + ], + non_ret_steps_coeffs=[ + 2.39676752e03, + -1.31110545e03, + 2.01331979e02, + -8.29855975e00, + 1.37887774e-01, + ], + ) + ) + + +@dataclass +class WanT2V_14B_SamplingParams(SamplingParams): + # Video parameters + height: int = 720 + width: int = 1280 + num_frames: int = 81 + fps: int = 16 + + # Denoising stage + guidance_scale: float = 5.0 + negative_prompt: str = ( + "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused 
fingers, still picture, messy background, three legs, many people in the background, walking backwards" + ) + num_inference_steps: int = 50 + + # Wan T2V 14B supported resolutions + supported_resolutions: list[tuple[int, int]] | None = field( + default_factory=lambda: [ + (1280, 720), # 16:9 + (720, 1280), # 9:16 + (832, 480), # 16:9 + (480, 832), # 9:16 + ] + ) + + teacache_params: WanTeaCacheParams = field( + default_factory=lambda: WanTeaCacheParams( + teacache_thresh=0.20, + use_ret_steps=False, + ret_steps_coeffs=[ + -3.03318725e05, + 4.90537029e04, + -2.65530556e03, + 5.87365115e01, + -3.15583525e-01, + ], + non_ret_steps_coeffs=[ + -5784.54975374, + 5449.50911966, + -1811.16591783, + 256.27178429, + -13.02252404, + ], + ) + ) + + +@dataclass +class WanI2V_14B_480P_SamplingParam(WanT2V_1_3B_SamplingParams): + # Denoising stage + guidance_scale: float = 5.0 + num_inference_steps: int = 50 + # num_inference_steps: int = 40 + + # Wan I2V 480P supported resolutions (override parent) + supported_resolutions: list[tuple[int, int]] | None = field( + default_factory=lambda: [ + (832, 480), # 16:9 + (480, 832), # 9:16 + ] + ) + + teacache_params: WanTeaCacheParams = field( + default_factory=lambda: WanTeaCacheParams( + teacache_thresh=0.26, + ret_steps_coeffs=[ + -3.03318725e05, + 4.90537029e04, + -2.65530556e03, + 5.87365115e01, + -3.15583525e-01, + ], + non_ret_steps_coeffs=[ + -5784.54975374, + 5449.50911966, + -1811.16591783, + 256.27178429, + -13.02252404, + ], + ) + ) + + +@dataclass +class WanI2V_14B_720P_SamplingParam(WanT2V_14B_SamplingParams): + # Denoising stage + guidance_scale: float = 5.0 + num_inference_steps: int = 50 + # num_inference_steps: int = 40 + + # Wan I2V 720P supported resolutions (override parent) + supported_resolutions: list[tuple[int, int]] | None = field( + default_factory=lambda: [ + (1280, 720), # 16:9 + (720, 1280), # 9:16 + (832, 480), # 16:9 + (480, 832), # 9:16 + ] + ) + + teacache_params: WanTeaCacheParams = field( + 
default_factory=lambda: WanTeaCacheParams( + teacache_thresh=0.3, + ret_steps_coeffs=[ + -3.03318725e05, + 4.90537029e04, + -2.65530556e03, + 5.87365115e01, + -3.15583525e-01, + ], + non_ret_steps_coeffs=[ + -5784.54975374, + 5449.50911966, + -1811.16591783, + 256.27178429, + -13.02252404, + ], + ) + ) + + +@dataclass +class FastWanT2V480PConfig(WanT2V_1_3B_SamplingParams): + # DMD parameters + # dmd_denoising_steps: list[int] | None = field(default_factory=lambda: [1000, 757, 522]) + num_inference_steps: int = 3 + num_frames: int = 61 + height: int = 448 + width: int = 832 + fps: int = 16 + + +# ============================================= +# ============= Wan2.1 Fun Models ============= +# ============================================= +@dataclass +class Wan2_1_Fun_1_3B_InP_SamplingParams(SamplingParams): + """Sampling parameters for Wan2.1 Fun 1.3B InP model.""" + + height: int = 480 + width: int = 832 + num_frames: int = 81 + fps: int = 16 + negative_prompt: str | None = ( + "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" + ) + guidance_scale: float = 6.0 + num_inference_steps: int = 50 + + +# ============================================= +# ============= Wan2.2 TI2V Models ============= +# ============================================= +@dataclass +class Wan2_2_Base_SamplingParams(SamplingParams): + """Sampling parameters for Wan2.2 TI2V 5B model.""" + + negative_prompt: str | None = ( + "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" + ) + + +@dataclass +class Wan2_2_TI2V_5B_SamplingParam(Wan2_2_Base_SamplingParams): + """Sampling parameters for Wan2.2 TI2V 5B model.""" + + height: int = 704 + width: int = 1280 + num_frames: int = 121 + fps: int = 24 + guidance_scale: float = 5.0 + num_inference_steps: int = 50 + + # Wan2.2 TI2V 5B supported resolutions + supported_resolutions: 
list[tuple[int, int]] | None = field( + default_factory=lambda: [ + (1280, 704), # 16:9-ish + (704, 1280), # 9:16-ish + ] + ) + + +@dataclass +class Wan2_2_T2V_A14B_SamplingParam(Wan2_2_Base_SamplingParams): + guidance_scale: float = 4.0 # high_noise + guidance_scale_2: float = 3.0 # low_noise + num_inference_steps: int = 40 + fps: int = 16 + + num_frames: int = 81 + + # Wan2.2 T2V A14B supported resolutions + supported_resolutions: list[tuple[int, int]] | None = field( + default_factory=lambda: [ + (1280, 720), # 16:9 + (720, 1280), # 9:16 + (832, 480), # 16:9 + (480, 832), # 9:16 + ] + ) + + +@dataclass +class Wan2_2_I2V_A14B_SamplingParam(Wan2_2_Base_SamplingParams): + guidance_scale: float = 3.5 # high_noise + guidance_scale_2: float = 3.5 # low_noise + num_inference_steps: int = 40 + fps: int = 16 + + num_frames: int = 81 + + # Wan2.2 I2V A14B supported resolutions + supported_resolutions: list[tuple[int, int]] | None = field( + default_factory=lambda: [ + (1280, 720), # 16:9 + (720, 1280), # 9:16 + (832, 480), # 16:9 + (480, 832), # 9:16 + ] + ) + + +@dataclass +class Turbo_Wan2_2_I2V_A14B_SamplingParam(Wan2_2_Base_SamplingParams): + guidance_scale: float = 3.5 # high_noise + guidance_scale_2: float = 3.5 # low_noise + num_inference_steps: int = 4 + fps: int = 16 + + +# ============================================= +# ============= Causal Self-Forcing ============= +# ============================================= +@dataclass +class SelfForcingWanT2V480PConfig(WanT2V_1_3B_SamplingParams): + pass diff --git a/sglang/python/sglang/multimodal_gen/configs/sample/zimage.py b/sglang/python/sglang/multimodal_gen/configs/sample/zimage.py new file mode 100644 index 0000000000000000000000000000000000000000..77a9dabf90dea2f2ec6e9f9149126ed24f6efdb9 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/configs/sample/zimage.py @@ -0,0 +1,44 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses 
def update_config_from_args(
    config: Any, args_dict: dict[str, Any], prefix: str = "", pop_args: bool = False
) -> bool:
    """
    Update configuration object from arguments dictionary.

    Entries whose value is ``None`` are always ignored, and only attributes
    that already exist on ``config`` are written.

    Args:
        config: The configuration object to update
        args_dict: Dictionary containing arguments
        prefix: Prefix for the configuration parameters in the args_dict.
            If None or blank, assumes direct attribute mapping without prefix.
            Otherwise only keys of the form ``"<prefix>.<attr>"`` are considered.
        pop_args: If True, keys that were applied to ``config`` are removed from
            ``args_dict`` afterwards (``"model_path"`` is never removed).

    Returns:
        True only when ``pop_args`` is True and at least one key was applied;
        with ``pop_args=False`` the result is always False, even if ``config``
        was updated.
    """
    # Keys that must stay in args_dict even after being applied.
    args_not_to_remove = [
        "model_path",
    ]
    args_to_remove = []

    # The docstring has always promised that ``prefix=None`` is accepted, but
    # calling ``.strip()`` on None raised AttributeError; treat None as blank.
    if prefix is None or prefix.strip() == "":
        # Handle top-level attributes (no prefix)
        for key, value in args_dict.items():
            if hasattr(config, key) and value is not None:
                if key == "text_encoder_precisions" and isinstance(value, list):
                    # This one attribute is stored as a tuple, not a list.
                    setattr(config, key, tuple(value))
                else:
                    setattr(config, key, value)
                if pop_args:
                    args_to_remove.append(key)
    else:
        # Handle nested attributes with prefix
        prefix_with_dot = f"{prefix}."
        for key, value in args_dict.items():
            if key.startswith(prefix_with_dot) and value is not None:
                attr_name = key[len(prefix_with_dot) :]
                if hasattr(config, attr_name):
                    setattr(config, attr_name, value)
                    if pop_args:
                        args_to_remove.append(key)

    # Removal is deferred so the dict is not mutated while being iterated.
    if pop_args:
        for key in args_to_remove:
            if key not in args_not_to_remove:
                args_dict.pop(key)

    return len(args_to_remove) > 0
+ +### Usage + +You can use `moba_attn_varlen` in the following ways: + +**Install from source:** +```bash +python setup.py install +``` + +**Import after installation:** +```python +from vmoba import moba_attn_varlen +``` + +**Or import directly from the project root:** +```python +from csrc.attn.vmoba_attn.vmoba import moba_attn_varlen +``` + +### Verify if you have successfully installed + +```bash +python csrc/attn/vmoba_attn/vmoba/vmoba.py +``` diff --git a/sglang/python/sglang/multimodal_gen/csrc/attn/vmoba_attn/setup.py b/sglang/python/sglang/multimodal_gen/csrc/attn/vmoba_attn/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..3a1bdb67f476f4281b0afd84db37eb323661a061 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/csrc/attn/vmoba_attn/setup.py @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: Apache-2.0 + +from setuptools import find_packages, setup + +PACKAGE_NAME = "vmoba" +VERSION = "0.0.0" +AUTHOR = "JianzongWu" +DESCRIPTION = "VMoBA: Mixture-of-Block Attention for Video Diffusion Models" +URL = "https://github.com/KwaiVGI/VMoBA" + +setup( + name=PACKAGE_NAME, + version=VERSION, + author=AUTHOR, + description=DESCRIPTION, + url=URL, + packages=find_packages(), + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + ], + python_requires=">=3.12", + install_requires=[ + "flash-attn >= 2.7.1", + ], +) diff --git a/sglang/python/sglang/multimodal_gen/csrc/attn/vmoba_attn/tests/test_vmoba_attn.py b/sglang/python/sglang/multimodal_gen/csrc/attn/vmoba_attn/tests/test_vmoba_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..9350f3c9e15b5a1ee17dab6a1b90c1b060b83318 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/csrc/attn/vmoba_attn/tests/test_vmoba_attn.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: Apache-2.0 + +import random + +import pytest +import torch +from sglang.multimodal_gen.csrc.attn.vmoba_attn.vmoba import moba_attn_varlen + 
def generate_test_data(
    batch_size, total_seqlen, num_heads, head_dim, dtype, device="cuda"
):
    """
    Generates random data for testing the variable-length attention function.

    The torch (CPU and CUDA) and Python ``random`` generators are re-seeded
    with 42 on every call, so repeated calls with the same arguments return
    the same data.

    Args:
        batch_size: Number of sequences packed into the batch.
        total_seqlen: Total number of tokens summed over all sequences.
        num_heads: Size of the second dimension of q/k/v.
        head_dim: Size of the last dimension of q/k/v.
        dtype: torch dtype of q/k/v.
        device: Device on which all returned tensors are created.

    Returns:
        Tuple ``(q, k, v, cu_seqlens, max_seqlen)``: q/k/v have shape
        ``(total_seqlen, num_heads, head_dim)``; ``cu_seqlens`` is an int32
        tensor of ``len(seqlens) + 1`` cumulative offsets starting at 0 and
        ending at ``total_seqlen``; ``max_seqlen`` is the longest sequence.

    Raises:
        ValueError: If ``batch_size > 1`` and ``total_seqlen < batch_size``,
            since every sequence needs at least one token.
    """
    torch.manual_seed(42)
    random.seed(42)
    torch.cuda.manual_seed_all(42)

    # Generate sequence lengths for each item in the batch
    if batch_size > 1:
        if total_seqlen < batch_size:
            raise ValueError(
                f"total_seqlen ({total_seqlen}) must be >= batch_size ({batch_size})"
            )
        # Ensure sequence lengths are reasonably distributed
        avg_seqlen = total_seqlen // batch_size
        seqlens = []
        remaining_len = total_seqlen
        for idx in range(batch_size - 1):
            drawn = random.randint(avg_seqlen // 2, avg_seqlen + avg_seqlen // 2)
            # Reserve one token for every sequence still to come. The previous
            # clamp-then-readjust logic could leave the last sequence with a
            # zero or negative length once the random draws summed to
            # total_seqlen or more (possible from batch_size >= 3).
            sequences_left = batch_size - 1 - idx
            length = max(1, min(drawn, remaining_len - sequences_left))
            seqlens.append(length)
            remaining_len -= length
        # The last sequence takes whatever is left (always >= 1 here).
        seqlens.append(remaining_len)
    else:
        seqlens = [total_seqlen]

    # Cumulative offsets, accumulated in plain Python so the tensor is built
    # from ints rather than from a list of 0-d tensors.
    offsets = [0]
    for length in seqlens:
        offsets.append(offsets[-1] + length)
    cu_seqlens = torch.tensor(offsets, device=device, dtype=torch.int32)
    max_seqlen = max(seqlens) if seqlens else 0

    q = torch.randn(
        (total_seqlen, num_heads, head_dim),
        dtype=dtype,
        device=device,
        requires_grad=False,
    )
    k = torch.randn(
        (total_seqlen, num_heads, head_dim),
        dtype=dtype,
        device=device,
        requires_grad=False,
    )
    v = torch.randn(
        (total_seqlen, num_heads, head_dim),
        dtype=dtype,
        device=device,
        requires_grad=False,
    )

    return q, k, v, cu_seqlens, max_seqlen
"head_global", "overall"]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +def test_moba_attn_varlen_forward( + batch_size, + total_seqlen, + num_heads, + head_dim, + moba_chunk_size, + moba_topk, + select_mode, + threshold_type, + dtype, +): + """ + Tests the forward pass of moba_attn_varlen for basic correctness. + It checks output shape, dtype, and for the presence of NaNs/Infs. + """ + if dtype == torch.float32: + pytest.skip("float32 is not supported in flash attention") + + q, k, v, cu_seqlens, max_seqlen = generate_test_data( + batch_size, total_seqlen, num_heads, head_dim, dtype + ) + + # Ensure chunk size is not larger than the smallest sequence length + min_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).min().item() + if moba_chunk_size > min_seqlen: + pytest.skip( + "moba_chunk_size is larger than the minimum sequence length in the batch" + ) + + try: + output = moba_attn_varlen( + q=q, + k=k, + v=v, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + moba_chunk_size=moba_chunk_size, + moba_topk=moba_topk, + select_mode=select_mode, + threshold_type=threshold_type, + simsum_threshold=0.5, # A reasonable default for threshold mode + ) + except Exception as e: + pytest.fail(f"moba_attn_varlen forward pass failed with exception: {e}") + + # 1. Check output shape + assert ( + output.shape == q.shape + ), f"Expected output shape {q.shape}, but got {output.shape}" + + # 2. Check output dtype + assert ( + output.dtype == q.dtype + ), f"Expected output dtype {q.dtype}, but got {output.dtype}" + + # 3. 
Check for NaNs or Infs in the output + assert torch.all(torch.isfinite(output)), "Output contains NaN or Inf values" diff --git a/sglang/python/sglang/multimodal_gen/csrc/attn/vmoba_attn/vmoba/__init__.py b/sglang/python/sglang/multimodal_gen/csrc/attn/vmoba_attn/vmoba/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8119387c34287d39beeb5d88db8abaed15dc111c --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/csrc/attn/vmoba_attn/vmoba/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +from .vmoba import moba_attn_varlen, process_moba_input, process_moba_output diff --git a/sglang/python/sglang/multimodal_gen/csrc/attn/vmoba_attn/vmoba/vmoba.py b/sglang/python/sglang/multimodal_gen/csrc/attn/vmoba_attn/vmoba/vmoba.py new file mode 100644 index 0000000000000000000000000000000000000000..8a29360a98b88870e90be85e06af1646fee66fc6 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/csrc/attn/vmoba_attn/vmoba/vmoba.py @@ -0,0 +1,1086 @@ +# SPDX-License-Identifier: Apache-2.0 +# Adapt from https://github.com/KwaiVGI/VMoBA/blob/main/src/vmoba.py + +import random +import time +from typing import Tuple + +import torch + +try: + from flash_attn import ( # Use the new flash attention function + flash_attn_varlen_func, + ) + from flash_attn.flash_attn_interface import ( + _flash_attn_varlen_backward, + _flash_attn_varlen_forward, + ) +except ImportError: + + def _unsupported(*args, **kwargs): + raise ImportError( + "flash-attn is not installed. Please install it, e.g., `pip install flash-attn`." + ) + + _flash_attn_varlen_forward = _unsupported + _flash_attn_varlen_backward = _unsupported + flash_attn_varlen_func = _unsupported + +from functools import lru_cache + +from einops import rearrange + + +@lru_cache(maxsize=16) +def calc_chunks(cu_seqlen, moba_chunk_size): + """ + Calculate chunk boundaries. + + For vision tasks we include all chunks (even the last one which might be shorter) + so that every chunk can be selected. 
+ """ + batch_sizes = cu_seqlen[1:] - cu_seqlen[:-1] + batch_num_chunk = (batch_sizes + (moba_chunk_size - 1)) // moba_chunk_size + cu_num_chunk = torch.ones( + batch_num_chunk.numel() + 1, + device=cu_seqlen.device, + dtype=batch_num_chunk.dtype, + ) + cu_num_chunk[1:] = batch_num_chunk.cumsum(dim=0) + num_chunk = cu_num_chunk[-1] + chunk_sizes = torch.full( + (num_chunk + 1,), moba_chunk_size, dtype=torch.int32, device=cu_seqlen.device + ) + chunk_sizes[0] = 0 + batch_last_chunk_size = batch_sizes - (batch_num_chunk - 1) * moba_chunk_size + chunk_sizes[cu_num_chunk[1:]] = batch_last_chunk_size + cu_chunk = chunk_sizes.cumsum(dim=-1, dtype=torch.int32) + chunk_to_batch = torch.zeros( + (num_chunk,), dtype=torch.int32, device=cu_seqlen.device + ) + chunk_to_batch[cu_num_chunk[1:-1]] = 1 + chunk_to_batch = chunk_to_batch.cumsum(dim=0, dtype=torch.int32) + + # Do not filter out any chunk + filtered_chunk_indices = torch.arange( + num_chunk, device=cu_seqlen.device, dtype=torch.int32 + ) + num_filtered_chunk = num_chunk + + return cu_chunk, filtered_chunk_indices, num_filtered_chunk, chunk_to_batch + + +# --- Threshold Selection Helper Functions --- + + +def _select_threshold_query_head( + gate: torch.Tensor, + valid_gate_mask: torch.Tensor, + gate_self_chunk_mask: torch.Tensor, + simsum_threshold: float, +) -> torch.Tensor: + """ + Selects chunks for each pair based on threshold. + Normalization and sorting happen along the chunk dimension (dim=0). 
+ """ + C, H, S = gate.shape + eps = 1e-6 + + # LSE‐style normalization per (across chunks) + gate_masked = torch.where(valid_gate_mask, gate, -torch.inf) # Use -inf for max + gate_min_val = torch.where(valid_gate_mask, gate, torch.inf) # Use +inf for min + + row_min = gate_min_val.amin(dim=0) # (H, S) + row_max = gate_masked.amax(dim=0) # (H, S) + denom = row_max - row_min + denom = torch.where( + denom <= eps, torch.ones_like(denom), denom + ) # avoid divide‑by‑zero + + gate_norm = (gate - row_min.unsqueeze(0)) / denom.unsqueeze(0) + gate_norm = torch.where(valid_gate_mask, gate_norm, 0.0) # (C, H, S) + + # 1) pull out the self‐chunk’s normalized weight for each + self_norm = (gate_norm * gate_self_chunk_mask).sum(dim=0) # (H, S) + + # 2) compute how much more normalized weight we need beyond self + total_norm_sum = gate_norm.sum(dim=0) # (H, S) + remain_ratio = simsum_threshold - self_norm / (total_norm_sum + eps) # (H, S) + remain_ratio = torch.clamp( + remain_ratio, min=0.0 + ) # if already ≥ thresh, no extra needed + + # 3) zero out the self‐chunk in a copy, so we only sort “others” + others_norm = gate_norm.clone() + others_norm[gate_self_chunk_mask] = 0.0 + + # 4) sort the other chunks by descending norm, per + sorted_norm, sorted_idx = torch.sort( + others_norm, descending=True, dim=0 + ) # (C, H, S) + + # 5) cumulative‑sum the sorted norms per + cumsum_others = sorted_norm.cumsum(dim=0) # (C, H, S) + + # 6) for each , find the smallest k where cumsum_ratio ≥ remain_ratio + ratio = cumsum_others / (total_norm_sum.unsqueeze(0) + eps) # (C, H, S) + cond = ratio >= remain_ratio.unsqueeze(0) # (C, H, S) boolean mask + any_cond = cond.any(dim=0) # (H, S) + # Find the index of the first True value along dim 0. If none, use C-1. 
+ cutoff = torch.where( + any_cond, + cond.float().argmax(dim=0), + torch.full_like(any_cond, fill_value=C - 1), + ) # (H, S) + + # 7) build a mask in sorted order up to that cutoff + idx_range = torch.arange(C, device=gate.device).view(-1, 1, 1) # (C, 1, 1) + sorted_mask = idx_range <= cutoff.unsqueeze(0) # (C, H, S) + + # 8) scatter it back to original chunk order + others_mask = torch.zeros_like(gate, dtype=torch.bool) + others_mask.scatter_(0, sorted_idx, sorted_mask) + + # 9) finally, include every self‐chunk plus all selected others + final_gate_mask = valid_gate_mask & (others_mask | gate_self_chunk_mask) + + return final_gate_mask + + +def _select_threshold_block( + gate: torch.Tensor, + valid_gate_mask: torch.Tensor, + gate_self_chunk_mask: torch.Tensor, + simsum_threshold: float, +) -> torch.Tensor: + """ + Selects pairs for each block based on threshold. + Normalization and sorting happen across the head and sequence dimensions (dim=1, 2). + """ + C, H, S = gate.shape + HS = H * S + eps = 1e-6 + + # LSE‐style normalization per block (across heads and queries) + gate_masked = torch.where(valid_gate_mask, gate, -torch.inf) # Use -inf for max + gate_min_val = torch.where(valid_gate_mask, gate, torch.inf) # Use +inf for min + + block_max = gate_masked.amax(dim=(1, 2), keepdim=True) # (C, 1, 1) + block_min = gate_min_val.amin(dim=(1, 2), keepdim=True) # (C, 1, 1) + block_denom = block_max - block_min + block_denom = torch.where( + block_denom <= eps, torch.ones_like(block_denom), block_denom + ) # (C, 1, 1) + + gate_norm = (gate - block_min) / block_denom # (C, H, S) + gate_norm = torch.where(valid_gate_mask, gate_norm, 0.0) # (C, H, S) + + # 1) identify normalized weights of entries that *are* self-chunks (from query perspective) + self_norm_entries = gate_norm * gate_self_chunk_mask # (C, H, S) + # Sum these weights *per block* + self_norm_sum_per_block = self_norm_entries.sum(dim=(1, 2)) # (C,) + + # 2) compute how much more normalized weight each block 
needs beyond its self-chunk contributions + total_norm_sum_per_block = gate_norm.sum(dim=(1, 2)) # (C,) + remain_ratio = simsum_threshold - self_norm_sum_per_block / ( + total_norm_sum_per_block + eps + ) # (C,) + remain_ratio = torch.clamp(remain_ratio, min=0.0) # (C,) + + # 3) zero out the self‐chunk entries in a copy, so we only sort “others” + others_norm = gate_norm.clone() + others_norm[gate_self_chunk_mask] = 0.0 # Zero out self entries + + # 4) sort the other pairs by descending norm, per block + others_flat = others_norm.contiguous().view(C, HS) # (C, H*S) + sorted_others_flat, sorted_indices_flat = torch.sort( + others_flat, dim=1, descending=True + ) # (C, H*S) + + # 5) cumulative‑sum the sorted norms per block + cumsum_others_flat = sorted_others_flat.cumsum(dim=1) # (C, H*S) + + # 6) for each block, find the smallest k where cumsum_ratio ≥ remain_ratio + ratio_flat = cumsum_others_flat / ( + total_norm_sum_per_block.unsqueeze(1) + eps + ) # (C, H*S) + cond_flat = ratio_flat >= remain_ratio.unsqueeze(1) # (C, H*S) boolean mask + any_cond = cond_flat.any(dim=1) # (C,) + # Find the index of the first True value along dim 1. If none, use HS-1. 
+ cutoff_flat = torch.where( + any_cond, + cond_flat.float().argmax(dim=1), + torch.full_like(any_cond, fill_value=HS - 1), + ) # (C,) + + # 7) build a mask in sorted order up to that cutoff per block + idx_range_flat = torch.arange(HS, device=gate.device).unsqueeze(0) # (1, H*S) + sorted_mask_flat = idx_range_flat <= cutoff_flat.unsqueeze(1) # (C, H*S) + + # 8) scatter it back to original order per block + others_mask_flat = torch.zeros_like(others_flat, dtype=torch.bool) # (C, H*S) + others_mask_flat.scatter_(1, sorted_indices_flat, sorted_mask_flat) + others_mask = others_mask_flat.view(C, H, S) # (C, H, S) + + # 9) finally, include every self‐chunk entry plus all selected others + final_gate_mask = valid_gate_mask & (others_mask | gate_self_chunk_mask) + + return final_gate_mask + + +def _select_threshold_overall( + gate: torch.Tensor, + valid_gate_mask: torch.Tensor, + gate_self_chunk_mask: torch.Tensor, + simsum_threshold: float, +) -> torch.Tensor: + """ + Selects triplets globally based on threshold. + Normalization and sorting happen across all valid entries. 
+ """ + C, H, S = gate.shape + CHS = C * H * S + eps = 1e-6 + + # LSE‐style normalization globally across all valid entries + gate_masked = torch.where(valid_gate_mask, gate, -torch.inf) # Use -inf for max + gate_min_val = torch.where(valid_gate_mask, gate, torch.inf) # Use +inf for min + + overall_max = gate_masked.max() # scalar + overall_min = gate_min_val.min() # scalar + overall_denom = overall_max - overall_min + overall_denom = torch.where( + overall_denom <= eps, + torch.tensor(1.0, device=gate.device, dtype=gate.dtype), + overall_denom, + ) + + gate_norm = (gate - overall_min) / overall_denom # (C, H, S) + gate_norm = torch.where(valid_gate_mask, gate_norm, 0.0) # (C, H, S) + + # 1) identify normalized weights of entries that *are* self-chunks + self_norm_entries = gate_norm * gate_self_chunk_mask # (C, H, S) + # Sum these weights globally + self_norm_sum_overall = self_norm_entries.sum() # scalar + + # 2) compute how much more normalized weight is needed globally beyond self-chunk contributions + total_norm_sum_overall = gate_norm.sum() # scalar + remain_ratio = simsum_threshold - self_norm_sum_overall / ( + total_norm_sum_overall + eps + ) # scalar + remain_ratio = torch.clamp(remain_ratio, min=0.0) # scalar + + # 3) zero out the self‐chunk entries in a copy, so we only sort “others” + others_norm = gate_norm.clone() + others_norm[gate_self_chunk_mask] = 0.0 # Zero out self entries + + # 4) sort all other entries by descending norm, globally + others_flat = others_norm.flatten() # (C*H*S,) + valid_others_mask_flat = ( + valid_gate_mask.flatten() & ~gate_self_chunk_mask.flatten() + ) # Mask for valid, non-self entries + + # Only sort the valid 'other' entries + valid_others_indices = torch.where(valid_others_mask_flat)[0] + valid_others_values = others_flat[valid_others_indices] + + sorted_others_values, sort_perm = torch.sort( + valid_others_values, descending=True + ) # (N_valid_others,) + sorted_original_indices = valid_others_indices[ + sort_perm + ] 
# Original indices in C*H*S space, sorted by value + + # 5) cumulative‑sum the sorted valid 'other' norms globally + cumsum_others_values = sorted_others_values.cumsum(dim=0) # (N_valid_others,) + + # 6) find the smallest k where cumsum_ratio ≥ remain_ratio globally + ratio_values = cumsum_others_values / ( + total_norm_sum_overall + eps + ) # (N_valid_others,) + cond_values = ratio_values >= remain_ratio # (N_valid_others,) boolean mask + any_cond = cond_values.any() # scalar + + # Find the index of the first True value in the *sorted* list. If none, use all valid others. + cutoff_idx_in_sorted = torch.where( + any_cond, + cond_values.float().argmax(dim=0), + torch.tensor( + len(sorted_others_values) - 1, device=gate.device, dtype=torch.long + ), + ) + + # 7) build a mask selecting the top-k others based on the cutoff + # Select the original indices corresponding to the top entries in the sorted list + selected_other_indices = sorted_original_indices[: cutoff_idx_in_sorted + 1] + + # 8) create the mask in the original flat shape + others_mask_flat = torch.zeros_like(others_flat, dtype=torch.bool) # (C*H*S,) + if selected_other_indices.numel() > 0: # Check if any 'other' indices were selected + others_mask_flat[selected_other_indices] = True + others_mask = others_mask_flat.view(C, H, S) # (C, H, S) + + # 9) finally, include every self‐chunk entry plus all selected others + final_gate_mask = valid_gate_mask & (others_mask | gate_self_chunk_mask) + + return final_gate_mask + + +def _select_threshold_head_global( + gate: torch.Tensor, + valid_gate_mask: torch.Tensor, + gate_self_chunk_mask: torch.Tensor, + simsum_threshold: float, +) -> torch.Tensor: + """ + Selects globally for each head based on threshold. 
+ """ + C, H, S = gate.shape + eps = 1e-6 + + # 1) LSE‐style normalization per head (across chunks and sequence dims) + gate_masked = torch.where(valid_gate_mask, gate, -torch.inf) + gate_min_val = torch.where(valid_gate_mask, gate, torch.inf) + + max_per_head = gate_masked.amax(dim=(0, 2), keepdim=True) # (1, H, 1) + min_per_head = gate_min_val.amin(dim=(0, 2), keepdim=True) # (1, H, 1) + denom = max_per_head - min_per_head + denom = torch.where(denom <= eps, torch.ones_like(denom), denom) + + gate_norm = (gate - min_per_head) / denom + gate_norm = torch.where(valid_gate_mask, gate_norm, 0.0) # (C, H, S) + + # 2) sum normalized self‐chunk contributions per head + self_norm_sum = (gate_norm * gate_self_chunk_mask).sum(dim=(0, 2)) # (H,) + + # 3) total normalized sum per head + total_norm_sum = gate_norm.sum(dim=(0, 2)) # (H,) + + # 4) how much more normalized weight needed per head + remain_ratio = simsum_threshold - self_norm_sum / (total_norm_sum + eps) # (H,) + remain_ratio = torch.clamp(remain_ratio, min=0.0) + + # 5) zero out self‐chunk entries to focus on "others" + others_norm = gate_norm.clone() + others_norm[gate_self_chunk_mask] = 0.0 # (C, H, S) + + # 6) flatten chunk and sequence dims, per head + CS = C * S + others_flat = others_norm.permute(1, 0, 2).reshape(H, CS) # (H, C*S) + valid_flat = ( + (valid_gate_mask & ~gate_self_chunk_mask).permute(1, 0, 2).reshape(H, CS) + ) # (H, C*S) + + # 7) vectorized selection of “others” per head + masked_flat = torch.where(valid_flat, others_flat, torch.zeros_like(others_flat)) + sorted_vals, sorted_idx = torch.sort( + masked_flat, dim=1, descending=True + ) # (H, C*S) + + cumsum_vals = sorted_vals.cumsum(dim=1) # (H, C*S) + ratio_vals = cumsum_vals / (total_norm_sum.unsqueeze(1) + eps) # (H, C*S) + cond = ratio_vals >= remain_ratio.unsqueeze(1) # (H, C*S) + + has_cutoff = cond.any(dim=1) # (H,) + default = torch.full((H,), CS - 1, device=gate.device, dtype=torch.long) + cutoff = torch.where(has_cutoff, 
cond.float().argmax(dim=1), default) # (H,) + + idx_range = torch.arange(CS, device=gate.device).unsqueeze(0) # (1, C*S) + sorted_mask = idx_range <= cutoff.unsqueeze(1) # (H, C*S) + + selected_flat = torch.zeros_like(valid_flat) # (H, C*S) + selected_flat.scatter_(1, sorted_idx, sorted_mask) # (H, C*S) + + # 8) reshape selection mask back to (C, H, S) + others_mask = selected_flat.reshape(H, C, S).permute(1, 0, 2) # (C, H, S) + + # 9) include self‐chunks plus selected others, and obey valid mask + final_gate_mask = valid_gate_mask & (gate_self_chunk_mask | others_mask) + + return final_gate_mask + + +class MixedAttention(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + self_attn_cu_seqlen, + moba_q, + moba_kv, + moba_cu_seqlen_q, + moba_cu_seqlen_kv, + max_seqlen, + moba_chunk_size, + moba_q_sh_indices, + ): + ctx.max_seqlen = max_seqlen + ctx.moba_chunk_size = moba_chunk_size + ctx.softmax_scale = softmax_scale = q.shape[-1] ** (-0.5) + + # Non-causal self-attention branch + # return out, softmax_lse, S_dmask, rng_state + self_attn_out_sh, self_attn_lse_hs, _, _ = _flash_attn_varlen_forward( + q=q, + k=k, + v=v, + cu_seqlens_q=self_attn_cu_seqlen, + cu_seqlens_k=self_attn_cu_seqlen, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + softmax_scale=softmax_scale, + causal=False, + dropout_p=0.0, + ) + # MOBA attention branch (non-causal) + moba_attn_out, moba_attn_lse_hs, _, _ = _flash_attn_varlen_forward( + q=moba_q, + k=moba_kv[:, 0], + v=moba_kv[:, 1], + cu_seqlens_q=moba_cu_seqlen_q, + cu_seqlens_k=moba_cu_seqlen_kv, + max_seqlen_q=max_seqlen, + max_seqlen_k=moba_chunk_size, + softmax_scale=softmax_scale, + causal=False, + dropout_p=0.0, + ) + + self_attn_lse_sh = self_attn_lse_hs.t().contiguous() + moba_attn_lse = moba_attn_lse_hs.t().contiguous() + + output = torch.zeros( + (q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32 + ) + output_2d = output.view(-1, q.shape[2]) + + max_lse_1d = 
self_attn_lse_sh.view(-1) + max_lse_1d = max_lse_1d.index_reduce( + 0, moba_q_sh_indices, moba_attn_lse.view(-1), "amax" + ) + self_attn_lse_sh = self_attn_lse_sh - max_lse_1d.view_as(self_attn_lse_sh) + moba_attn_lse = ( + moba_attn_lse.view(-1) + .sub(max_lse_1d.index_select(0, moba_q_sh_indices)) + .reshape_as(moba_attn_lse) + ) + + mixed_attn_se_sh = self_attn_lse_sh.exp() + moba_attn_se = moba_attn_lse.exp() + + mixed_attn_se_sh.view(-1).index_add_( + 0, moba_q_sh_indices, moba_attn_se.view(-1) + ) + mixed_attn_lse_sh = mixed_attn_se_sh.log() + + # Combine self-attention output + factor = (self_attn_lse_sh - mixed_attn_lse_sh).exp() # [S, H] + self_attn_out_sh = self_attn_out_sh * factor.unsqueeze(-1) + output_2d += self_attn_out_sh.reshape_as(output_2d) + + # Combine MOBA attention output + mixed_attn_lse = ( + mixed_attn_lse_sh.view(-1) + .index_select(0, moba_q_sh_indices) + .view_as(moba_attn_lse) + ) + factor = (moba_attn_lse - mixed_attn_lse).exp() # [S, H] + moba_attn_out = moba_attn_out * factor.unsqueeze(-1) + raw_attn_out = moba_attn_out.view(-1, moba_attn_out.shape[-1]) + output_2d.index_add_(0, moba_q_sh_indices, raw_attn_out) + output = output.to(q.dtype) + mixed_attn_lse_sh = mixed_attn_lse_sh + max_lse_1d.view_as(mixed_attn_se_sh) + ctx.save_for_backward( + output, + mixed_attn_lse_sh, + q, + k, + v, + self_attn_cu_seqlen, + moba_q, + moba_kv, + moba_cu_seqlen_q, + moba_cu_seqlen_kv, + moba_q_sh_indices, + ) + + return output + + @staticmethod + def backward(ctx, d_output): + + max_seqlen = ctx.max_seqlen + moba_chunk_size = ctx.moba_chunk_size + softmax_scale = ctx.softmax_scale + + ( + output, + mixed_attn_vlse_sh, + q, + k, + v, + self_attn_cu_seqlen, + moba_q, + moba_kv, + moba_cu_seqlen_q, + moba_cu_seqlen_kv, + moba_q_sh_indices, + ) = ctx.saved_tensors + + d_output = d_output.contiguous() + + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + _ = _flash_attn_varlen_backward( + dout=d_output, + q=q, + k=k, + 
v=v, + out=output, + softmax_lse=mixed_attn_vlse_sh.t().contiguous(), + dq=dq, + dk=dk, + dv=dv, + cu_seqlens_q=self_attn_cu_seqlen, + cu_seqlens_k=self_attn_cu_seqlen, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + softmax_scale=softmax_scale, + causal=False, + dropout_p=0.0, + softcap=0.0, + alibi_slopes=None, + deterministic=True, + window_size_left=-1, + window_size_right=-1, + ) + + headdim = q.shape[-1] + d_moba_output = ( + d_output.view(-1, headdim).index_select(0, moba_q_sh_indices).unsqueeze(1) + ) + moba_output = ( + output.view(-1, headdim).index_select(0, moba_q_sh_indices).unsqueeze(1) + ) + + mixed_attn_vlse = ( + mixed_attn_vlse_sh.view(-1).index_select(0, moba_q_sh_indices).view(1, -1) + ) + + dmq = torch.empty_like(moba_q) + dmkv = torch.empty_like(moba_kv) + _ = _flash_attn_varlen_backward( + dout=d_moba_output, + q=moba_q, + k=moba_kv[:, 0], + v=moba_kv[:, 1], + out=moba_output, + softmax_lse=mixed_attn_vlse, + dq=dmq, + dk=dmkv[:, 0], + dv=dmkv[:, 1], + cu_seqlens_q=moba_cu_seqlen_q, + cu_seqlens_k=moba_cu_seqlen_kv, + max_seqlen_q=max_seqlen, + max_seqlen_k=moba_chunk_size, + softmax_scale=softmax_scale, + causal=False, + dropout_p=0.0, + softcap=0.0, + alibi_slopes=None, + deterministic=True, + window_size_left=-1, + window_size_right=-1, + ) + + return dq, dk, dv, None, dmq, dmkv, None, None, None, None, None + + +def moba_attn_varlen( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + moba_chunk_size: int, + moba_topk: int, + select_mode: str = "threshold", # "topk" or "threshold" + simsum_threshold: float = 0.25, + threshold_type: str = "query_head", +) -> torch.Tensor: + """ + Accelerated MOBA attention for vision tasks with proper LSE normalization. + + This version: + - Splits KV into chunks. + - For each query head, selects the top-k relevant KV chunks (including the self chunk) + by amplifying the diagonal (self-chunk) logits. 
+ - Aggregates the attention outputs from the selected chunks using a log-sum-exp + reduction so that attending to each query over the selected chunks is equivalent + to the original algorithm. + """ + # Stack keys and values. + kv = torch.stack((k, v), dim=1) + seqlen, num_head, head_dim = q.shape + + # Compute chunk boundaries. + cu_chunk, filtered_chunk_indices, num_filtered_chunk, chunk_to_batch = calc_chunks( + cu_seqlens, moba_chunk_size + ) + + self_attn_cu_seqlen = cu_chunk + + # Update top-k selection to include the self chunk. + moba_topk = min(moba_topk, num_filtered_chunk) + + # --- Build filtered KV from chunks --- + chunk_starts = cu_chunk[filtered_chunk_indices] # [num_filtered_chunk] + chunk_ends = cu_chunk[filtered_chunk_indices + 1] # [num_filtered_chunk] + chunk_lengths = chunk_ends - chunk_starts # [num_filtered_chunk] + max_chunk_len = int(chunk_lengths.max().item()) + + range_tensor = torch.arange( + max_chunk_len, device=kv.device, dtype=chunk_starts.dtype + ).unsqueeze(0) + indices = chunk_starts.unsqueeze(1) + range_tensor + indices = torch.clamp(indices, max=kv.shape[0] - 1) + valid_mask = range_tensor < chunk_lengths.unsqueeze(1) + gathered = kv[indices.view(-1)].view( + num_filtered_chunk, max_chunk_len, *kv.shape[1:] + ) + gathered = gathered * valid_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).type_as( + gathered + ) + + # Compute key_gate_weight over valid tokens. + key_values = gathered[ + :, :, 0 + ].float() # [num_filtered_chunk, max_chunk_len, num_head, head_dim] + valid_mask_exp = valid_mask.unsqueeze(-1).unsqueeze(-1) + key_sum = (key_values * valid_mask_exp).sum(dim=1) + divisor = valid_mask.sum(dim=1).unsqueeze(-1).unsqueeze(-1) + key_gate_weight = key_sum / divisor # [num_filtered_chunk, num_head, head_dim] + + # Compute gate logits between key_gate_weight and queries. 
+ q_float = q.float() + # gate = torch.einsum("nhd,shd->nhs", key_gate_weight, q_float) # [num_filtered_chunk, num_head, seqlen] + gate = torch.bmm( + key_gate_weight.permute(1, 0, 2), q_float.permute(1, 0, 2).transpose(1, 2) + ).permute(1, 0, 2) + + # Amplify the diagonal (self chunk) contributions. + gate_seq_idx = ( + torch.arange(seqlen, device=q.device, dtype=torch.int32) + .unsqueeze(0) + .expand(num_filtered_chunk, seqlen) + ) + chunk_start = cu_chunk[filtered_chunk_indices] # [num_filtered_chunk] + chunk_end = cu_chunk[filtered_chunk_indices + 1] # [num_filtered_chunk] + gate_self_chunk_mask = ( + ( + (gate_seq_idx >= chunk_start.unsqueeze(1)) + & (gate_seq_idx < chunk_end.unsqueeze(1)) + ) + .unsqueeze(1) + .expand(-1, num_head, -1) + ) + amplification_factor = 1e9 # Example factor; adjust as needed. + origin_gate = gate.clone() + gate = gate.clone() + if select_mode == "topk": + gate[gate_self_chunk_mask] += amplification_factor + + # Exclude positions that are outside the valid batch boundaries. + batch_starts = cu_seqlens[chunk_to_batch[filtered_chunk_indices]] + batch_ends = cu_seqlens[chunk_to_batch[filtered_chunk_indices] + 1] + gate_batch_start_mask = gate_seq_idx < batch_starts.unsqueeze(1) + gate_batch_end_mask = gate_seq_idx >= batch_ends.unsqueeze(1) + gate_inf_mask = gate_batch_start_mask | gate_batch_end_mask + gate.masked_fill_(gate_inf_mask.unsqueeze(1), -float("inf")) + + if select_mode == "topk": + # We amplify self‐chunk in gate already, so self entries will rank highest. 
+ valid_gate_mask = gate != -float("inf") + if threshold_type == "query_head": + # === per‐ top-k across chunks (original behavior) === + # gate: (C, H, S) + _, gate_topk_idx = torch.topk( + gate, k=moba_topk, dim=0, largest=True, sorted=False + ) + gate_idx_mask = torch.zeros_like(gate, dtype=torch.bool) + gate_idx_mask.scatter_(0, gate_topk_idx, True) + gate_mask = valid_gate_mask & gate_idx_mask + elif threshold_type == "overall": + # === global top-k across all (chunk, head, seq) entries === + C, H, S = gate.shape + flat_gate = gate.flatten() + flat_mask = valid_gate_mask.flatten() + flat_gate_masked = torch.where(flat_mask, flat_gate, -float("inf")) + # pick topk global entries + vals, idx = torch.topk( + flat_gate_masked, k=moba_topk * H * S, largest=True, sorted=False + ) + others_mask_flat = torch.zeros_like(flat_mask, dtype=torch.bool) + others_mask_flat[idx] = True + gate_mask = (valid_gate_mask.flatten() & others_mask_flat).view(gate.shape) + elif threshold_type == "head_global": + # per-head top-k across all chunks and sequence positions + C, H, S = gate.shape + CS = C * S + flat_gate = gate.permute(1, 0, 2).reshape(H, CS) + flat_valid = valid_gate_mask.permute(1, 0, 2).reshape(H, CS) + flat_gate_masked = torch.where( + flat_valid, flat_gate, torch.full_like(flat_gate, -float("inf")) + ) + # pick top-k indices per head + _, topk_idx = torch.topk( + flat_gate_masked, k=moba_topk * S, dim=1, largest=True, sorted=False + ) + gate_idx_flat = torch.zeros_like(flat_valid, dtype=torch.bool) + gate_idx_flat.scatter_(1, topk_idx, True) + gate_mask = gate_idx_flat.reshape(H, C, S).permute(1, 0, 2) + else: + raise ValueError( + f"Invalid threshold_type for topk: {threshold_type}. " + "Choose 'query_head', 'block', or 'overall'." 
+ ) + elif select_mode == "threshold": + # Delegate to the specific thresholding function + valid_gate_mask = gate != -float("inf") # (num_chunk, num_head, seqlen) + if threshold_type == "query_head": + gate_mask = _select_threshold_query_head( + gate, valid_gate_mask, gate_self_chunk_mask, simsum_threshold + ) + elif threshold_type == "block": + gate_mask = _select_threshold_block( + gate, valid_gate_mask, gate_self_chunk_mask, simsum_threshold + ) + elif threshold_type == "overall": + gate_mask = _select_threshold_overall( + gate, valid_gate_mask, gate_self_chunk_mask, simsum_threshold + ) + elif threshold_type == "head_global": + gate_mask = _select_threshold_head_global( + gate, valid_gate_mask, gate_self_chunk_mask, simsum_threshold + ) + else: + raise ValueError( + f"Invalid threshold_type: {threshold_type}. Choose 'query_head', 'block', or 'overall'." + ) + else: + raise ValueError( + f"Invalid select_mode: {select_mode}. Choose 'topk' or 'threshold'." + ) + + # eliminate self_chunk in MoBA branch + gate_mask = gate_mask & ~gate_self_chunk_mask + # if gate_mask is all false, perform flash_attn instead + if gate_mask.sum() == 0: + return flash_attn_varlen_func( + q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, causal=False + ) + + # Determine which query positions are selected. + # nonzero_indices has shape [N, 3] where each row is [chunk_index, head_index, seq_index]. + moba_q_indices = gate_mask.reshape(gate_mask.shape[0], -1).nonzero(as_tuple=True)[ + -1 + ] # [(h s k)] + moba_q_sh_indices = (moba_q_indices % seqlen) * num_head + ( + moba_q_indices // seqlen + ) + moba_q = ( + rearrange(q, "s h d -> (h s) d").index_select(0, moba_q_indices).unsqueeze(1) + ) + + # Build cumulative sequence lengths for the selected queries. 
+ moba_seqlen_q = gate_mask.sum(dim=-1).flatten() + q_zero_mask = moba_seqlen_q == 0 + valid_expert_mask = ~q_zero_mask + if q_zero_mask.sum() > 0: + moba_seqlen_q = moba_seqlen_q[valid_expert_mask] + moba_cu_seqlen_q = torch.cat( + ( + torch.tensor([0], device=q.device, dtype=moba_seqlen_q.dtype), + moba_seqlen_q.cumsum(dim=0), + ), + dim=0, + ).to(torch.int32) + + # Rearrange gathered KV for the MOBA branch. + experts_tensor = rearrange(gathered, "nc cl two h d -> (nc h) cl two d") + valid_expert_lengths = ( + chunk_lengths.unsqueeze(1) + .expand(num_filtered_chunk, num_head) + .reshape(-1) + .to(torch.int32) + ) + if q_zero_mask.sum() > 0: + experts_tensor = experts_tensor[valid_expert_mask] + valid_expert_lengths = valid_expert_lengths[valid_expert_mask] + + seq_range = torch.arange( + experts_tensor.shape[1], device=experts_tensor.device + ).unsqueeze(0) + mask = seq_range < valid_expert_lengths.unsqueeze(1) + moba_kv = experts_tensor[mask] # Shape: ((nc h cl_valid) two d) + moba_kv = moba_kv.unsqueeze(2) # Shape: ((nc h cl_valid) two 1 d) + + moba_cu_seqlen_kv = torch.cat( + [ + torch.zeros(1, device=experts_tensor.device, dtype=torch.int32), + valid_expert_lengths.cumsum(dim=0), + ], + dim=0, + ).to(torch.int32) + + assert ( + moba_cu_seqlen_kv.shape == moba_cu_seqlen_q.shape + ), f"Mismatch between moba_cu_seqlen_kv.shape and moba_cu_seqlen_q.shape: {moba_cu_seqlen_kv.shape} vs {moba_cu_seqlen_q.shape}" + + return MixedAttention.apply( + q, + k, + v, + self_attn_cu_seqlen, + moba_q, + moba_kv, + moba_cu_seqlen_q, + moba_cu_seqlen_kv, + max_seqlen, + moba_chunk_size, + moba_q_sh_indices, + ) + + +def process_moba_input( + x, + patch_resolution, + chunk_size, +): + """ + Process inputs for the attention function. + + Args: + x (torch.Tensor): Input tensor with shape [batch_size, num_patches, num_heads, head_dim]. + patch_resolution (tuple): Tuple containing the patch resolution (t, h, w). + chunk_size (int): Size of the chunk. 
(maybe tuple or int, according to chunk type) + + Returns: + torch.Tensor: Processed input tensor. + """ + if isinstance(chunk_size, float) or isinstance(chunk_size, int): + moba_chunk_size = int(chunk_size * patch_resolution[1] * patch_resolution[2]) + else: + assert isinstance( + chunk_size, (Tuple, list) + ), f"chunk_size should be a tuple, list, or int, now it is: {type(chunk_size)}" + if len(chunk_size) == 2: + assert ( + patch_resolution[1] % chunk_size[0] == 0 + and patch_resolution[2] % chunk_size[1] == 0 + ), f"spatial patch_resolution {patch_resolution[1:]} should be divisible by 2d chunk_size {chunk_size}" + nch, ncw = ( + patch_resolution[1] // chunk_size[0], + patch_resolution[2] // chunk_size[1], + ) + x = rearrange( + x, + "b (t nch ch ncw cw) n d -> b (nch ncw t ch cw) n d", + t=patch_resolution[0], + nch=nch, + ncw=ncw, + ch=chunk_size[0], + cw=chunk_size[1], + ) + moba_chunk_size = patch_resolution[0] * chunk_size[0] * chunk_size[1] + elif len(chunk_size) == 3: + assert ( + patch_resolution[0] % chunk_size[0] == 0 + and patch_resolution[1] % chunk_size[1] == 0 + and patch_resolution[2] % chunk_size[2] == 0 + ), f"patch_resolution {patch_resolution} should be divisible by 3d chunk_size {chunk_size}" + nct, nch, ncw = ( + patch_resolution[0] // chunk_size[0], + patch_resolution[1] // chunk_size[1], + patch_resolution[2] // chunk_size[2], + ) + x = rearrange( + x, + "b (nct ct nch ch ncw cw) n d -> b (nct nch ncw ct ch cw) n d", + nct=nct, + nch=nch, + ncw=ncw, + ct=chunk_size[0], + ch=chunk_size[1], + cw=chunk_size[2], + ) + moba_chunk_size = chunk_size[0] * chunk_size[1] * chunk_size[2] + else: + raise ValueError( + f"chunk_size should be a int, or a tuple of length 2 or 3, now it is: {len(chunk_size)}" + ) + + return x, moba_chunk_size + + +def process_moba_output( + x, + patch_resolution, + chunk_size, +): + if isinstance(chunk_size, float) or isinstance(chunk_size, int): + pass + elif len(chunk_size) == 2: + x = rearrange( + x, + "b (nch ncw t 
ch cw) n d -> b (t nch ch ncw cw) n d", + nch=patch_resolution[1] // chunk_size[0], + ncw=patch_resolution[2] // chunk_size[1], + t=patch_resolution[0], + ch=chunk_size[0], + cw=chunk_size[1], + ) + elif len(chunk_size) == 3: + x = rearrange( + x, + "b (nct nch ncw ct ch cw) n d -> b (nct ct nch ch ncw cw) n d", + nct=patch_resolution[0] // chunk_size[0], + nch=patch_resolution[1] // chunk_size[1], + ncw=patch_resolution[2] // chunk_size[2], + ct=chunk_size[0], + ch=chunk_size[1], + cw=chunk_size[2], + ) + + return x + + +# TEST +def generate_data(batch_size, seqlen, num_head, head_dim, dtype): + random.seed(0) + torch.manual_seed(0) + torch.cuda.manual_seed(0) + device = torch.cuda.current_device() + + q = torch.randn((batch_size, seqlen, num_head, head_dim), requires_grad=True).to( + dtype=dtype, device="cuda" + ) + k = torch.randn((batch_size, seqlen, num_head, head_dim), requires_grad=True).to( + dtype=dtype, device="cuda" + ) + v = torch.randn((batch_size, seqlen, num_head, head_dim), requires_grad=True).to( + dtype=dtype, device="cuda" + ) + print(f"q.shape: {q.shape}, k.shape: {k.shape}, v.shape: {v.shape}") + cu_seqlens = torch.arange( + 0, q.shape[0] * q.shape[1] + 1, q.shape[1], dtype=torch.int32, device="cuda" + ) + max_seqlen = q.shape[1] + q = rearrange(q, "b s ... -> (b s) ...") + k = rearrange(k, "b s ... -> (b s) ...") + v = rearrange(v, "b s ... 
-> (b s) ...") + + return q, k, v, cu_seqlens, max_seqlen + + +def test_attn_varlen_moba_speed( + batch, + head, + seqlen, + head_dim, + moba_chunk_size, + moba_topk, + dtype=torch.bfloat16, + select_mode="threshold", + simsum_threshold=0.25, + threshold_type="query_head", +): + """Speed test comparing flash_attn vs moba_attention""" + # Get data + q, k, v, cu_seqlen, max_seqlen = generate_data(batch, seqlen, head, head_dim, dtype) + print( + f"batch:{batch} head:{head} seqlen:{seqlen} chunk:{moba_chunk_size} topk:{moba_topk} select_mode: {select_mode} simsum_threshold:{simsum_threshold}" + ) + vo_grad = torch.randn_like(q) + + # Warmup + warmup_iters = 3 + perf_test_iters = 10 + + # Warmup + for _ in range(warmup_iters): + o = flash_attn_varlen_func( + q, k, v, cu_seqlen, cu_seqlen, max_seqlen, max_seqlen, causal=False + ) + torch.autograd.backward(o, vo_grad) + + torch.cuda.synchronize() + start_flash = time.perf_counter() + for _ in range(perf_test_iters): + o = flash_attn_varlen_func( + q, k, v, cu_seqlen, cu_seqlen, max_seqlen, max_seqlen, causal=False + ) + torch.autograd.backward(o, vo_grad) + + torch.cuda.synchronize() + time_flash = (time.perf_counter() - start_flash) / perf_test_iters * 1000 + + # Warmup + for _ in range(warmup_iters): + om = moba_attn_varlen( + q, + k, + v, + cu_seqlen, + max_seqlen, + moba_chunk_size=moba_chunk_size, + moba_topk=moba_topk, + select_mode=select_mode, + simsum_threshold=simsum_threshold, + threshold_type=threshold_type, + ) + torch.autograd.backward(om, vo_grad) + + torch.cuda.synchronize() + start_moba = time.perf_counter() + for _ in range(perf_test_iters): + om = moba_attn_varlen( + q, + k, + v, + cu_seqlen, + max_seqlen, + moba_chunk_size=moba_chunk_size, + moba_topk=moba_topk, + select_mode=select_mode, + simsum_threshold=simsum_threshold, + threshold_type=threshold_type, + ) + torch.autograd.backward(om, vo_grad) + + torch.cuda.synchronize() + time_moba = (time.perf_counter() - start_moba) / perf_test_iters * 1000 + 
+ print(f"Flash: {time_flash:.2f}ms, MoBA: {time_moba:.2f}ms") + print(f"Speedup: {time_flash / time_moba:.2f}x") + + +if __name__ == "__main__": + """ + CUDA_VISIBLE_DEVICES=1 \ + python -u csrc/attn/vmoba_attn/vmoba/vmoba.py + """ + test_attn_varlen_moba_speed( + batch=1, + head=12, + seqlen=32760, + head_dim=128, + moba_chunk_size=32760 // 3 // 6 // 4, + moba_topk=3, + select_mode="threshold", + simsum_threshold=0.3, + threshold_type="query_head", + ) diff --git a/sglang/python/sglang/multimodal_gen/csrc/render/hunyuan3d_rasterizer/__init__.py b/sglang/python/sglang/multimodal_gen/csrc/render/hunyuan3d_rasterizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0d778479e782b6a288a4e46e4f077ff32eb5de39 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/csrc/render/hunyuan3d_rasterizer/__init__.py @@ -0,0 +1,80 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Custom CUDA rasterizer for Hunyuan3D texture generation. + +This module provides JIT-compiled CUDA rasterization for fast mesh rendering. 
+Adapted from Hunyuan3D-2: https://github.com/Tencent/Hunyuan3D-2 +""" + +from __future__ import annotations + +import os +from typing import List, Tuple + +import torch + +_abs_path = os.path.dirname(os.path.abspath(__file__)) +_custom_rasterizer_kernel = None + + +def _load_custom_rasterizer(): + """JIT compile and load the custom rasterizer kernel.""" + global _custom_rasterizer_kernel + + if _custom_rasterizer_kernel is not None: + return _custom_rasterizer_kernel + + from torch.utils.cpp_extension import load + + _custom_rasterizer_kernel = load( + name="custom_rasterizer_kernel", + sources=[ + f"{_abs_path}/rasterizer.cpp", + f"{_abs_path}/rasterizer_gpu.cu", + ], + extra_cflags=["-O3"], + extra_cuda_cflags=["-O3", "--use_fast_math"], + verbose=False, + ) + return _custom_rasterizer_kernel + + +def rasterize( + pos: torch.Tensor, + tri: torch.Tensor, + resolution: Tuple[int, int], + clamp_depth: torch.Tensor = None, + use_depth_prior: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Rasterize mesh to get face indices and barycentric coordinates.""" + kernel = _load_custom_rasterizer() + + if clamp_depth is None: + clamp_depth = torch.zeros(0, device=pos.device) + + # pos should be [N, 4], remove batch dim if present + if pos.dim() == 3: + pos = pos[0] + + findices, barycentric = kernel.rasterize_image( + pos, tri, clamp_depth, resolution[1], resolution[0], 1e-6, use_depth_prior + ) + return findices, barycentric + + +def interpolate( + col: torch.Tensor, + findices: torch.Tensor, + barycentric: torch.Tensor, + tri: torch.Tensor, +) -> torch.Tensor: + """Interpolate vertex attributes using barycentric coordinates.""" + # Handle zero indices (background) + f = findices - 1 + (findices == 0) + vcol = col[0, tri.long()[f.long()]] + result = barycentric.view(*barycentric.shape, 1) * vcol + result = torch.sum(result, axis=-2) + return result.view(1, *result.shape) + + +__all__ = ["rasterize", "interpolate"] diff --git 
a/sglang/python/sglang/multimodal_gen/csrc/render/hunyuan3d_rasterizer/rasterizer.cpp b/sglang/python/sglang/multimodal_gen/csrc/render/hunyuan3d_rasterizer/rasterizer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..72aa005b351cf6b2cdc36b0a52265bf3813ac3cb --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/csrc/render/hunyuan3d_rasterizer/rasterizer.cpp @@ -0,0 +1,140 @@ +// SPDX-License-Identifier: Apache-2.0 +// Adapted from Hunyuan3D-2: https://github.com/Tencent/Hunyuan3D-2 +// Original license: TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT + +#include "rasterizer.h" + +void rasterizeTriangleCPU(int idx, float* vt0, float* vt1, float* vt2, int width, int height, INT64* zbuffer, float* d, float occlusion_truncation) { + float x_min = std::min(vt0[0], std::min(vt1[0],vt2[0])); + float x_max = std::max(vt0[0], std::max(vt1[0],vt2[0])); + float y_min = std::min(vt0[1], std::min(vt1[1],vt2[1])); + float y_max = std::max(vt0[1], std::max(vt1[1],vt2[1])); + + for (int px = x_min; px < x_max + 1; ++px) { + if (px < 0 || px >= width) + continue; + for (int py = y_min; py < y_max + 1; ++py) { + if (py < 0 || py >= height) + continue; + float vt[2] = {px + 0.5f, py + 0.5f}; + float baryCentricCoordinate[3]; + calculateBarycentricCoordinate(vt0, vt1, vt2, vt, baryCentricCoordinate); + if (isBarycentricCoordInBounds(baryCentricCoordinate)) { + int pixel = py * width + px; + if (zbuffer == 0) { + zbuffer[pixel] = (INT64)(idx + 1); + continue; + } + + float depth = baryCentricCoordinate[0] * vt0[2] + baryCentricCoordinate[1] * vt1[2] + baryCentricCoordinate[2] * vt2[2]; + float depth_thres = 0; + if (d) { + depth_thres = d[pixel] * 0.49999f + 0.5f + occlusion_truncation; + } + + int z_quantize = depth * (2<<17); + INT64 token = (INT64)z_quantize * MAXINT + (INT64)(idx + 1); + if (depth < depth_thres) + continue; + zbuffer[pixel] = std::min(zbuffer[pixel], token); + } + } + } +} + +void barycentricFromImgcoordCPU(float* V, int* F, int* 
findices, INT64* zbuffer, int width, int height, int num_vertices, int num_faces, + float* barycentric_map, int pix) +{ + INT64 f = zbuffer[pix] % MAXINT; + if (f == (MAXINT-1)) { + findices[pix] = 0; + barycentric_map[pix * 3] = 0; + barycentric_map[pix * 3 + 1] = 0; + barycentric_map[pix * 3 + 2] = 0; + return; + } + findices[pix] = f; + f -= 1; + float barycentric[3] = {0, 0, 0}; + if (f >= 0) { + float vt[2] = {float(pix % width) + 0.5f, float(pix / width) + 0.5f}; + float* vt0_ptr = V + (F[f * 3] * 4); + float* vt1_ptr = V + (F[f * 3 + 1] * 4); + float* vt2_ptr = V + (F[f * 3 + 2] * 4); + + float vt0[2] = {(vt0_ptr[0] / vt0_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt0_ptr[1] / vt0_ptr[3]) * (height - 1) + 0.5f}; + float vt1[2] = {(vt1_ptr[0] / vt1_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt1_ptr[1] / vt1_ptr[3]) * (height - 1) + 0.5f}; + float vt2[2] = {(vt2_ptr[0] / vt2_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt2_ptr[1] / vt2_ptr[3]) * (height - 1) + 0.5f}; + + calculateBarycentricCoordinate(vt0, vt1, vt2, vt, barycentric); + + barycentric[0] = barycentric[0] / vt0_ptr[3]; + barycentric[1] = barycentric[1] / vt1_ptr[3]; + barycentric[2] = barycentric[2] / vt2_ptr[3]; + float w = 1.0f / (barycentric[0] + barycentric[1] + barycentric[2]); + barycentric[0] *= w; + barycentric[1] *= w; + barycentric[2] *= w; + } + barycentric_map[pix * 3] = barycentric[0]; + barycentric_map[pix * 3 + 1] = barycentric[1]; + barycentric_map[pix * 3 + 2] = barycentric[2]; +} + +void rasterizeImagecoordsKernelCPU(float* V, int* F, float* d, INT64* zbuffer, float occlusion_trunc, int width, int height, int num_vertices, int num_faces, int f) +{ + float* vt0_ptr = V + (F[f * 3] * 4); + float* vt1_ptr = V + (F[f * 3 + 1] * 4); + float* vt2_ptr = V + (F[f * 3 + 2] * 4); + + float vt0[3] = {(vt0_ptr[0] / vt0_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt0_ptr[1] / vt0_ptr[3]) * (height - 1) + 0.5f, vt0_ptr[2] / 
vt0_ptr[3] * 0.49999f + 0.5f}; + float vt1[3] = {(vt1_ptr[0] / vt1_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt1_ptr[1] / vt1_ptr[3]) * (height - 1) + 0.5f, vt1_ptr[2] / vt1_ptr[3] * 0.49999f + 0.5f}; + float vt2[3] = {(vt2_ptr[0] / vt2_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt2_ptr[1] / vt2_ptr[3]) * (height - 1) + 0.5f, vt2_ptr[2] / vt2_ptr[3] * 0.49999f + 0.5f}; + + rasterizeTriangleCPU(f, vt0, vt1, vt2, width, height, zbuffer, d, occlusion_trunc); +} + +std::vector rasterize_image_cpu(torch::Tensor V, torch::Tensor F, torch::Tensor D, + int width, int height, float occlusion_truncation, int use_depth_prior) +{ + int num_faces = F.size(0); + int num_vertices = V.size(0); + auto options = torch::TensorOptions().dtype(torch::kInt32).requires_grad(false); + auto INT64_options = torch::TensorOptions().dtype(torch::kInt64).requires_grad(false); + auto findices = torch::zeros({height, width}, options); + INT64 maxint = (INT64)MAXINT * (INT64)MAXINT + (MAXINT - 1); + auto z_min = torch::ones({height, width}, INT64_options) * (int64_t)maxint; + + if (!use_depth_prior) { + for (int i = 0; i < num_faces; ++i) { + rasterizeImagecoordsKernelCPU(V.data_ptr(), F.data_ptr(), 0, + (INT64*)z_min.data_ptr(), occlusion_truncation, width, height, num_vertices, num_faces, i); + } + } else { + for (int i = 0; i < num_faces; ++i) + rasterizeImagecoordsKernelCPU(V.data_ptr(), F.data_ptr(), D.data_ptr(), + (INT64*)z_min.data_ptr(), occlusion_truncation, width, height, num_vertices, num_faces, i); + } + + auto float_options = torch::TensorOptions().dtype(torch::kFloat32).requires_grad(false); + auto barycentric = torch::zeros({height, width, 3}, float_options); + for (int i = 0; i < width * height; ++i) + barycentricFromImgcoordCPU(V.data_ptr(), F.data_ptr(), + findices.data_ptr(), (INT64*)z_min.data_ptr(), width, height, num_vertices, num_faces, barycentric.data_ptr(), i); + + return {findices, barycentric}; +} + +std::vector 
rasterize_image(torch::Tensor V, torch::Tensor F, torch::Tensor D, + int width, int height, float occlusion_truncation, int use_depth_prior) +{ + int device_id = V.get_device(); + if (device_id == -1) + return rasterize_image_cpu(V, F, D, width, height, occlusion_truncation, use_depth_prior); + else + return rasterize_image_gpu(V, F, D, width, height, occlusion_truncation, use_depth_prior); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("rasterize_image", &rasterize_image, "Custom image rasterization"); +} diff --git a/sglang/python/sglang/multimodal_gen/csrc/render/hunyuan3d_rasterizer/rasterizer.h b/sglang/python/sglang/multimodal_gen/csrc/render/hunyuan3d_rasterizer/rasterizer.h new file mode 100644 index 0000000000000000000000000000000000000000..bb1703cf08d0a936f64f953aa704249b95befd4e --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/csrc/render/hunyuan3d_rasterizer/rasterizer.h @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: Apache-2.0 +// Adapted from Hunyuan3D-2: https://github.com/Tencent/Hunyuan3D-2 +// Original license: TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT + +#ifndef RASTERIZER_H_ +#define RASTERIZER_H_ + +#include +#include +#include +#include + +#define INT64 unsigned long long +#define MAXINT 2147483647 + +__host__ __device__ inline float calculateSignedArea2(float* a, float* b, float* c) { + return ((c[0] - a[0]) * (b[1] - a[1]) - (b[0] - a[0]) * (c[1] - a[1])); +} + +__host__ __device__ inline void calculateBarycentricCoordinate(float* a, float* b, float* c, float* p, + float* barycentric) +{ + float beta_tri = calculateSignedArea2(a, p, c); + float gamma_tri = calculateSignedArea2(a, b, p); + float area = calculateSignedArea2(a, b, c); + if (area == 0) { + barycentric[0] = -1.0; + barycentric[1] = -1.0; + barycentric[2] = -1.0; + return; + } + float tri_inv = 1.0 / area; + float beta = beta_tri * tri_inv; + float gamma = gamma_tri * tri_inv; + float alpha = 1.0 - beta - gamma; + barycentric[0] = alpha; + barycentric[1] = 
beta; + barycentric[2] = gamma; +} + +__host__ __device__ inline bool isBarycentricCoordInBounds(float* barycentricCoord) { + return barycentricCoord[0] >= 0.0 && barycentricCoord[0] <= 1.0 && + barycentricCoord[1] >= 0.0 && barycentricCoord[1] <= 1.0 && + barycentricCoord[2] >= 0.0 && barycentricCoord[2] <= 1.0; +} + +std::vector rasterize_image_gpu(torch::Tensor V, torch::Tensor F, torch::Tensor D, + int width, int height, float occlusion_truncation, int use_depth_prior); + +#endif diff --git a/sglang/python/sglang/multimodal_gen/csrc/render/hunyuan3d_rasterizer/rasterizer_gpu.cu b/sglang/python/sglang/multimodal_gen/csrc/render/hunyuan3d_rasterizer/rasterizer_gpu.cu new file mode 100644 index 0000000000000000000000000000000000000000..f1317270d23d1b0e8d0b2a2669f53369a78ad127 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/csrc/render/hunyuan3d_rasterizer/rasterizer_gpu.cu @@ -0,0 +1,130 @@ +// SPDX-License-Identifier: Apache-2.0 +// Adapted from Hunyuan3D-2: https://github.com/Tencent/Hunyuan3D-2 +// Original license: TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT + +#include "rasterizer.h" + +__device__ void rasterizeTriangleGPU(int idx, float* vt0, float* vt1, float* vt2, int width, int height, INT64* zbuffer, float* d, float occlusion_truncation) { + float x_min = std::min(vt0[0], std::min(vt1[0],vt2[0])); + float x_max = std::max(vt0[0], std::max(vt1[0],vt2[0])); + float y_min = std::min(vt0[1], std::min(vt1[1],vt2[1])); + float y_max = std::max(vt0[1], std::max(vt1[1],vt2[1])); + + for (int px = x_min; px < x_max + 1; ++px) { + if (px < 0 || px >= width) + continue; + for (int py = y_min; py < y_max + 1; ++py) { + if (py < 0 || py >= height) + continue; + float vt[2] = {px + 0.5f, py + 0.5f}; + float baryCentricCoordinate[3]; + calculateBarycentricCoordinate(vt0, vt1, vt2, vt, baryCentricCoordinate); + if (isBarycentricCoordInBounds(baryCentricCoordinate)) { + int pixel = py * width + px; + if (zbuffer == 0) { + atomicExch(&zbuffer[pixel], 
(INT64)(idx + 1)); + continue; + } + float depth = baryCentricCoordinate[0] * vt0[2] + baryCentricCoordinate[1] * vt1[2] + baryCentricCoordinate[2] * vt2[2]; + float depth_thres = 0; + if (d) { + depth_thres = d[pixel] * 0.49999f + 0.5f + occlusion_truncation; + } + + int z_quantize = depth * (2<<17); + INT64 token = (INT64)z_quantize * MAXINT + (INT64)(idx + 1); + if (depth < depth_thres) + continue; + atomicMin(&zbuffer[pixel], token); + } + } + } +} + +__global__ void barycentricFromImgcoordGPU(float* V, int* F, int* findices, INT64* zbuffer, int width, int height, int num_vertices, int num_faces, + float* barycentric_map) +{ + int pix = blockIdx.x * blockDim.x + threadIdx.x; + if (pix >= width * height) + return; + INT64 f = zbuffer[pix] % MAXINT; + if (f == (MAXINT-1)) { + findices[pix] = 0; + barycentric_map[pix * 3] = 0; + barycentric_map[pix * 3 + 1] = 0; + barycentric_map[pix * 3 + 2] = 0; + return; + } + findices[pix] = f; + f -= 1; + float barycentric[3] = {0, 0, 0}; + if (f >= 0) { + float vt[2] = {float(pix % width) + 0.5f, float(pix / width) + 0.5f}; + float* vt0_ptr = V + (F[f * 3] * 4); + float* vt1_ptr = V + (F[f * 3 + 1] * 4); + float* vt2_ptr = V + (F[f * 3 + 2] * 4); + + float vt0[2] = {(vt0_ptr[0] / vt0_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt0_ptr[1] / vt0_ptr[3]) * (height - 1) + 0.5f}; + float vt1[2] = {(vt1_ptr[0] / vt1_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt1_ptr[1] / vt1_ptr[3]) * (height - 1) + 0.5f}; + float vt2[2] = {(vt2_ptr[0] / vt2_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt2_ptr[1] / vt2_ptr[3]) * (height - 1) + 0.5f}; + + calculateBarycentricCoordinate(vt0, vt1, vt2, vt, barycentric); + + barycentric[0] = barycentric[0] / vt0_ptr[3]; + barycentric[1] = barycentric[1] / vt1_ptr[3]; + barycentric[2] = barycentric[2] / vt2_ptr[3]; + float w = 1.0f / (barycentric[0] + barycentric[1] + barycentric[2]); + barycentric[0] *= w; + barycentric[1] *= w; + barycentric[2] *= w; + } + 
barycentric_map[pix * 3] = barycentric[0]; + barycentric_map[pix * 3 + 1] = barycentric[1]; + barycentric_map[pix * 3 + 2] = barycentric[2]; +} + +__global__ void rasterizeImagecoordsKernelGPU(float* V, int* F, float* d, INT64* zbuffer, float occlusion_trunc, int width, int height, int num_vertices, int num_faces) +{ + int f = blockIdx.x * blockDim.x + threadIdx.x; + if (f >= num_faces) + return; + + float* vt0_ptr = V + (F[f * 3] * 4); + float* vt1_ptr = V + (F[f * 3 + 1] * 4); + float* vt2_ptr = V + (F[f * 3 + 2] * 4); + + float vt0[3] = {(vt0_ptr[0] / vt0_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt0_ptr[1] / vt0_ptr[3]) * (height - 1) + 0.5f, vt0_ptr[2] / vt0_ptr[3] * 0.49999f + 0.5f}; + float vt1[3] = {(vt1_ptr[0] / vt1_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt1_ptr[1] / vt1_ptr[3]) * (height - 1) + 0.5f, vt1_ptr[2] / vt1_ptr[3] * 0.49999f + 0.5f}; + float vt2[3] = {(vt2_ptr[0] / vt2_ptr[3] * 0.5f + 0.5f) * (width - 1) + 0.5f, (0.5f + 0.5f * vt2_ptr[1] / vt2_ptr[3]) * (height - 1) + 0.5f, vt2_ptr[2] / vt2_ptr[3] * 0.49999f + 0.5f}; + + rasterizeTriangleGPU(f, vt0, vt1, vt2, width, height, zbuffer, d, occlusion_trunc); +} + +std::vector rasterize_image_gpu(torch::Tensor V, torch::Tensor F, torch::Tensor D, + int width, int height, float occlusion_truncation, int use_depth_prior) +{ + int device_id = V.get_device(); + cudaSetDevice(device_id); + int num_faces = F.size(0); + int num_vertices = V.size(0); + auto options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA, device_id).requires_grad(false); + auto INT64_options = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA, device_id).requires_grad(false); + auto findices = torch::zeros({height, width}, options); + INT64 maxint = (INT64)MAXINT * (INT64)MAXINT + (MAXINT - 1); + auto z_min = torch::ones({height, width}, INT64_options) * (int64_t)maxint; + + if (!use_depth_prior) { + 
rasterizeImagecoordsKernelGPU<<<(num_faces+255)/256,256,0,at::cuda::getCurrentCUDAStream()>>>(V.data_ptr(), F.data_ptr(), 0, + (INT64*)z_min.data_ptr(), occlusion_truncation, width, height, num_vertices, num_faces); + } else { + rasterizeImagecoordsKernelGPU<<<(num_faces+255)/256,256,0,at::cuda::getCurrentCUDAStream()>>>(V.data_ptr(), F.data_ptr(), D.data_ptr(), + (INT64*)z_min.data_ptr(), occlusion_truncation, width, height, num_vertices, num_faces); + } + + auto float_options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA, device_id).requires_grad(false); + auto barycentric = torch::zeros({height, width, 3}, float_options); + barycentricFromImgcoordGPU<<<(width * height + 255)/256, 256, 0, at::cuda::getCurrentCUDAStream()>>>(V.data_ptr(), F.data_ptr(), + findices.data_ptr(), (INT64*)z_min.data_ptr(), width, height, num_vertices, num_faces, barycentric.data_ptr()); + + return {findices, barycentric}; +} diff --git a/sglang/python/sglang/multimodal_gen/csrc/render/mesh_processor/__init__.py b/sglang/python/sglang/multimodal_gen/csrc/render/mesh_processor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..21fc15b96429a9464672f259826e76ab83e2d767 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/csrc/render/mesh_processor/__init__.py @@ -0,0 +1,62 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Mesh processor C++ extension for texture inpainting. + +This module provides JIT-compiled C++ mesh processing for fast texture inpainting. 
+Adapted from Hunyuan3D-2: https://github.com/Tencent/Hunyuan3D-2 +""" + +from __future__ import annotations + +import os +from typing import Tuple + +import numpy as np + +_abs_path = os.path.dirname(os.path.abspath(__file__)) +_mesh_processor_kernel = None + + +def _load_mesh_processor(): + """JIT compile and load the mesh processor kernel.""" + global _mesh_processor_kernel + + if _mesh_processor_kernel is not None: + return _mesh_processor_kernel + + from torch.utils.cpp_extension import load + + _mesh_processor_kernel = load( + name="mesh_processor_kernel", + sources=[ + f"{_abs_path}/mesh_processor.cpp", + ], + extra_cflags=["-O3"], + verbose=False, + ) + return _mesh_processor_kernel + + +def meshVerticeInpaint( + texture: np.ndarray, + mask: np.ndarray, + vtx_pos: np.ndarray, + vtx_uv: np.ndarray, + pos_idx: np.ndarray, + uv_idx: np.ndarray, + method: str = "smooth", +) -> Tuple[np.ndarray, np.ndarray]: + """Inpaint texture using mesh vertex connectivity.""" + kernel = _load_mesh_processor() + + texture = np.ascontiguousarray(texture, dtype=np.float32) + mask = np.ascontiguousarray(mask, dtype=np.uint8) + vtx_pos = np.ascontiguousarray(vtx_pos, dtype=np.float32) + vtx_uv = np.ascontiguousarray(vtx_uv, dtype=np.float32) + pos_idx = np.ascontiguousarray(pos_idx, dtype=np.int32) + uv_idx = np.ascontiguousarray(uv_idx, dtype=np.int32) + + return kernel.meshVerticeInpaint(texture, mask, vtx_pos, vtx_uv, pos_idx, uv_idx, method) + + +__all__ = ["meshVerticeInpaint"] diff --git a/sglang/python/sglang/multimodal_gen/csrc/render/mesh_processor/mesh_processor.cpp b/sglang/python/sglang/multimodal_gen/csrc/render/mesh_processor/mesh_processor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1ce0d35c28537a6b0c2b7b5467980fdb96ea91b4 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/csrc/render/mesh_processor/mesh_processor.cpp @@ -0,0 +1,163 @@ +// SPDX-License-Identifier: Apache-2.0 +// Adapted from Hunyuan3D-2: 
https://github.com/Tencent/Hunyuan3D-2 +// Original license: TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; +using namespace std; + +std::pair, + py::array_t> meshVerticeInpaint_smooth(py::array_t texture, +py::array_t mask, + py::array_t vtx_pos, py::array_t vtx_uv, + py::array_t pos_idx, py::array_t uv_idx) { + auto texture_buf = texture.request(); + auto mask_buf = mask.request(); + auto vtx_pos_buf = vtx_pos.request(); + auto vtx_uv_buf = vtx_uv.request(); + auto pos_idx_buf = pos_idx.request(); + auto uv_idx_buf = uv_idx.request(); + + int texture_height = texture_buf.shape[0]; + int texture_width = texture_buf.shape[1]; + int texture_channel = texture_buf.shape[2]; + float* texture_ptr = static_cast(texture_buf.ptr); + uint8_t* mask_ptr = static_cast(mask_buf.ptr); + + int vtx_num = vtx_pos_buf.shape[0]; + float* vtx_pos_ptr = static_cast(vtx_pos_buf.ptr); + float* vtx_uv_ptr = static_cast(vtx_uv_buf.ptr); + int* pos_idx_ptr = static_cast(pos_idx_buf.ptr); + int* uv_idx_ptr = static_cast(uv_idx_buf.ptr); + + vector vtx_mask(vtx_num, 0.0f); + vector> vtx_color(vtx_num, vector(texture_channel, 0.0f)); + vector uncolored_vtxs; + + vector> G(vtx_num); + + for (int i = 0; i < uv_idx_buf.shape[0]; ++i) { + for (int k = 0; k < 3; ++k) { + int vtx_uv_idx = uv_idx_ptr[i * 3 + k]; + int vtx_idx = pos_idx_ptr[i * 3 + k]; + int uv_v = round(vtx_uv_ptr[vtx_uv_idx * 2] * (texture_width - 1)); + int uv_u = round((1.0 - vtx_uv_ptr[vtx_uv_idx * 2 + 1]) * (texture_height - 1)); + + if (mask_ptr[uv_u * texture_width + uv_v] > 0) { + vtx_mask[vtx_idx] = 1.0f; + for (int c = 0; c < texture_channel; ++c) { + vtx_color[vtx_idx][c] = texture_ptr[(uv_u * texture_width + uv_v) * texture_channel + c]; + } + }else{ + uncolored_vtxs.push_back(vtx_idx); + } + + G[pos_idx_ptr[i * 3 + k]].push_back(pos_idx_ptr[i * 3 + (k + 1) % 3]); + } + } + + int smooth_count = 2; + int 
last_uncolored_vtx_count = 0; + while (smooth_count>0) { + int uncolored_vtx_count = 0; + + for (int vtx_idx : uncolored_vtxs) { + + vector sum_color(texture_channel, 0.0f); + float total_weight = 0.0f; + + array vtx_0 = {vtx_pos_ptr[vtx_idx * 3], +vtx_pos_ptr[vtx_idx * 3 + 1], vtx_pos_ptr[vtx_idx * 3 + 2]}; + for (int connected_idx : G[vtx_idx]) { + if (vtx_mask[connected_idx] > 0) { + array vtx1 = {vtx_pos_ptr[connected_idx * 3], + vtx_pos_ptr[connected_idx * 3 + 1], vtx_pos_ptr[connected_idx * 3 + 2]}; + float dist_weight = 1.0f / max(sqrt(pow(vtx_0[0] - vtx1[0], 2) + pow(vtx_0[1] - vtx1[1], 2) + \ + pow(vtx_0[2] - vtx1[2], 2)), 1E-4); + dist_weight = dist_weight * dist_weight; + for (int c = 0; c < texture_channel; ++c) { + sum_color[c] += vtx_color[connected_idx][c] * dist_weight; + } + total_weight += dist_weight; + } + } + + if (total_weight > 0.0f) { + for (int c = 0; c < texture_channel; ++c) { + vtx_color[vtx_idx][c] = sum_color[c] / total_weight; + } + vtx_mask[vtx_idx] = 1.0f; + } else { + uncolored_vtx_count++; + } + + } + + if(last_uncolored_vtx_count==uncolored_vtx_count){ + smooth_count--; + }else{ + smooth_count++; + } + last_uncolored_vtx_count = uncolored_vtx_count; + } + + py::array_t new_texture(texture_buf.size); + py::array_t new_mask(mask_buf.size); + + auto new_texture_buf = new_texture.request(); + auto new_mask_buf = new_mask.request(); + + float* new_texture_ptr = static_cast(new_texture_buf.ptr); + uint8_t* new_mask_ptr = static_cast(new_mask_buf.ptr); + std::copy(texture_ptr, texture_ptr + texture_buf.size, new_texture_ptr); + std::copy(mask_ptr, mask_ptr + mask_buf.size, new_mask_ptr); + + for (int face_idx = 0; face_idx < uv_idx_buf.shape[0]; ++face_idx) { + for (int k = 0; k < 3; ++k) { + int vtx_uv_idx = uv_idx_ptr[face_idx * 3 + k]; + int vtx_idx = pos_idx_ptr[face_idx * 3 + k]; + + if (vtx_mask[vtx_idx] == 1.0f) { + int uv_v = round(vtx_uv_ptr[vtx_uv_idx * 2] * (texture_width - 1)); + int uv_u = round((1.0 - vtx_uv_ptr[vtx_uv_idx 
* 2 + 1]) * (texture_height - 1)); + + for (int c = 0; c < texture_channel; ++c) { + new_texture_ptr[(uv_u * texture_width + uv_v) * texture_channel + c] = vtx_color[vtx_idx][c]; + } + new_mask_ptr[uv_u * texture_width + uv_v] = 255; + } + } + } + + new_texture.resize({texture_height, texture_width, 3}); + new_mask.resize({texture_height, texture_width}); + return std::make_pair(new_texture, new_mask); +} + + +std::pair, py::array_t> meshVerticeInpaint(py::array_t texture, + py::array_t mask, + py::array_t vtx_pos, py::array_t vtx_uv, + py::array_t pos_idx, py::array_t uv_idx, const std::string& method = "smooth") { + if (method == "smooth") { + return meshVerticeInpaint_smooth(texture, mask, vtx_pos, vtx_uv, pos_idx, uv_idx); + } else { + throw std::invalid_argument("Invalid method. Use 'smooth'."); + } +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("meshVerticeInpaint", &meshVerticeInpaint, "Mesh-aware texture inpainting", + py::arg("texture"), py::arg("mask"), + py::arg("vtx_pos"), py::arg("vtx_uv"), + py::arg("pos_idx"), py::arg("uv_idx"), + py::arg("method") = "smooth"); +} diff --git a/sglang/python/sglang/multimodal_gen/docs/quantization.md b/sglang/python/sglang/multimodal_gen/docs/quantization.md new file mode 100644 index 0000000000000000000000000000000000000000..b8588a46d9fca667d82df1a3af5275d9768b3a89 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/docs/quantization.md @@ -0,0 +1,169 @@ +# Quantization + +This document introduces the model quantization schemes supported in SGLang and how to use them to reduce memory usage and accelerate inference. + +## Nunchaku (SVDQuant) + +### Introduction + +**SVDQuant** is a Post-Training Quantization (PTQ) technique for diffusion models that quantizes model weights and activations to 4-bit precision (W4A4) while maintaining high visual quality. 
This method uses Singular Value Decomposition (SVD) to decompose the weight matrix into low-rank components and residuals, effectively absorbing outliers in activations, making 4-bit quantization possible. + +**Nunchaku** is a high-performance inference engine that implements SVDQuant, optimized for low-bit neural networks. It does not require Quantization-Aware Training (QAT): pre-trained models are quantized directly. + +Paper: [SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models](https://arxiv.org/abs/2411.05007) (ICLR 2025 Spotlight) + +### Key Features + +SVDQuant significantly reduces memory usage and accelerates inference while maintaining visual quality: + +- **Memory Optimization**: Reduces memory usage by **3.6×** compared to BF16 models. +- **Inference Acceleration**: + - **3.0×** faster than the NF4 (W4A16) baseline on desktop/laptop RTX 4090 GPUs. + - **8.7×** speedup over 16-bit models on a laptop RTX 4090, achieved by eliminating CPU offloading. + - **3.1×** faster than BF16 and NF4 models on RTX 5090 GPUs with NVFP4. + +### Supported Precisions + +Nunchaku supports two quantization precisions: + +- **INT4**: Standard INT4 quantization, supported on NVIDIA GPUs with Compute Capability 7.0+ (RTX 20 series and above). +- **NVFP4**: FP4 quantization, providing better image quality on newer cards like the RTX 5090. + +### Usage + +#### 1. Install Nunchaku + +```bash +pip install nunchaku +``` + +For more installation information, please refer to the [Nunchaku Official Documentation](https://nunchaku.tech/docs/nunchaku/installation/installation.html). + +#### 2. 
Download Quantized Models + +Nunchaku provides pre-quantized model weights available on Hugging Face: + +- [nunchaku-ai/nunchaku-qwen-image](https://huggingface.co/nunchaku-ai/nunchaku-qwen-image) +- [nunchaku-ai/nunchaku-flux](https://huggingface.co/nunchaku-ai/nunchaku-flux) + +Taking Qwen-Image as an example, several quantized models with different configurations are provided: + +| Filename | Precision | Rank | Usage | +|----------|-----------|------|-------| +| `svdq-int4_r32-qwen-image.safetensors` | INT4 | 32 | Standard Version | +| `svdq-int4_r128-qwen-image.safetensors` | INT4 | 128 | High-Quality Version | +| `svdq-fp4_r32-qwen-image.safetensors` | NVFP4 | 32 | RTX 5090 Standard Version | +| `svdq-fp4_r128-qwen-image.safetensors` | NVFP4 | 128 | RTX 5090 High-Quality Version | +| `svdq-int4_r32-qwen-image-lightningv1.0-4steps.safetensors` | INT4 | 32 | Lightning 4-Step Version | +| `svdq-int4_r128-qwen-image-lightningv1.1-8steps.safetensors` | INT4 | 128 | Lightning 8-Step Version | + +> **Note**: Higher Rank usually means better image quality, but with slightly increased memory usage and computation. + +#### 3. Run Quantized Models + +SGLang features **smart auto-detection** for Nunchaku models. In most cases, you only need to provide the path to the quantized weights, and the precision and rank will be automatically inferred from the filename. + +**Simplified Command (Recommended):** + +```bash +sglang generate \ + --model-path Qwen/Qwen-Image \ + --prompt "change the raccoon to a cute cat" \ + --save-output \ + --transformer-weights-path /path/to/svdq-int4_r32-qwen-image.safetensors +``` + +**Manual Override (If needed):** + +If your filename doesn't follow the standard naming convention, or you want to force specific settings: + +- `--enable-svdquant`: Manually enable SVDQuant. +- `--quantization-precision`: Set to `int4` or `nvfp4`. +- `--quantization-rank`: Set the SVD rank (e.g., 32, 128). 
+- `--quantization-act-unsigned` (Optional): Use unsigned activation quantization. + +Example with manual overrides: + +```bash +sglang generate \ + --model-path Qwen/Qwen-Image \ + --prompt "a beautiful sunset" \ + --enable-svdquant \ + --transformer-weights-path /path/to/custom_model.safetensors \ + --quantization-precision int4 \ + --quantization-rank 128 +``` + +#### 4. Configuration Recommendations + +Choose the appropriate configuration based on your hardware and requirements: + +| Scenario | Recommended Config | Description | +|----------|-------------------|-------------| +| Standard Use (20/30/40 Series GPU) | INT4 + Rank 32 | Balanced performance and quality | +| Quality Focus (Sufficient VRAM) | INT4 + Rank 128 | Better image quality | +| RTX 5090 Standard Use | NVFP4 + Rank 32 | Utilizes FP4 hardware acceleration | +| RTX 5090 Quality Focus | NVFP4 + Rank 128 | Best image quality | +| Fast Prototyping/Preview | Lightning 4-Step Version | Extremely fast generation, slightly reduced quality | + +### Notes + +1. Model Path Correspondence: `--model-path` should point to the original non-quantized model (used to load the config, tokenizer, etc.), while `--transformer-weights-path` points to the quantized weights, given as a file, a folder, or a Hugging Face repo ID. + +2. Auto-Detection Requirements: For auto-detection to work, the filename must contain the pattern `svdq-{precision}_r{rank}` (e.g., `svdq-int4_r32`). + +3. GPU Compatibility: + - INT4: Supports NVIDIA GPUs with Compute Capability 7.0+ (RTX 20 series and above). + - NVFP4: Optimized mainly for newer cards like the RTX 50 series that support FP4. + +4. Lightning Models: When using Lightning versions, adjust `--num-inference-steps` accordingly (usually 4 or 8 steps). + +### Custom Model Quantization + +If you want to quantize your own models, you can use the [DeepCompressor](https://github.com/mit-han-lab/deepcompressor) tool. For detailed instructions, please refer to the Nunchaku official documentation. 
+ +## Quantization + +### Usage + +#### Option 1: Pre-quantized folder (has `config.json`) + +For quantized checkpoints whose transformer `config.json` already contains a `quantization_config` field (e.g., models converted via `convert_hf_to_fp8.py`), use the component override: + +```bash +sglang generate \ + --model-path /path/to/FLUX.1-dev \ + --transformer-path /path/to/FLUX.1-dev/transformer-FP8 \ + --prompt "A Logo With Bold Large Text: SGL Diffusion" \ + --save-output +``` + + +If you need to convert a model to FP8 format yourself, use the provided conversion script: + +```bash +# convert transformer to FP8 with block quantization +python -m sglang.multimodal_gen.tools.convert_hf_to_fp8 \ + --model-dir /path/to/FLUX.1-dev/transformer \ + --save-dir /path/to/FLUX.1-dev/transformer-FP8 \ + --strategy block \ + --block-size 128 128 +``` + +#### Option 2: Pre-quantized single-file checkpoint (no `config.json`) + + + +Some providers (e.g., [black-forest-labs/FLUX.2-klein-9b-fp8](https://huggingface.co/black-forest-labs/FLUX.2-klein-9b-fp8)) distribute a single `.safetensors` file without a companion `config.json`. Use `--transformer-weights-path` to point to this file (or a Hugging Face repo ID) while keeping `--model-path` for the base model: + +```bash +sglang generate \ + --model-path black-forest-labs/FLUX.2-klein-9B \ + --transformer-weights-path black-forest-labs/FLUX.2-klein-9b-fp8 \ + --prompt "A Logo With Bold Large Text: SGL Diffusion" \ + --save-output +``` + +SGLang-Diffusion will automatically read the `quantization_config` metadata embedded in the safetensors file header (if present). For the quant config to be auto-detected, the file's metadata must contain a JSON-encoded `quantization_config` key with at least a `quant_method` field (e.g. `"fp8"`). 
+ +Note: this feature is a WIP diff --git a/sglang/python/sglang/multimodal_gen/envs.py b/sglang/python/sglang/multimodal_gen/envs.py new file mode 100644 index 0000000000000000000000000000000000000000..80ee489ad6f3d298dda653e8469dfe18c638b149 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/envs.py @@ -0,0 +1,332 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/envs.py + +import logging +import os +from typing import TYPE_CHECKING, Any, Callable + +from sglang.multimodal_gen.runtime.utils.common import get_bool_env_var + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + SGLANG_DIFFUSION_RINGBUFFER_WARNING_INTERVAL: int = 60 + SGLANG_DIFFUSION_NCCL_SO_PATH: str | None = None + LD_LIBRARY_PATH: str | None = None + LOCAL_RANK: int = 0 + CUDA_VISIBLE_DEVICES: str | None = None + SGLANG_DIFFUSION_CACHE_ROOT: str = os.path.expanduser("~/.cache/sgl_diffusion") + SGLANG_DIFFUSION_CONFIG_ROOT: str = os.path.expanduser("~/.config/sgl_diffusion") + SGLANG_DIFFUSION_CONFIGURE_LOGGING: int = 1 + SGLANG_DIFFUSION_LOGGING_LEVEL: str = "INFO" + SGLANG_DIFFUSION_LOGGING_PREFIX: str = "" + SGLANG_DIFFUSION_LOGGING_CONFIG_PATH: str | None = None + SGLANG_DIFFUSION_TRACE_FUNCTION: int = 0 + SGLANG_DIFFUSION_WORKER_MULTIPROC_METHOD: str = "fork" + SGLANG_DIFFUSION_TARGET_DEVICE: str = "cuda" + MAX_JOBS: str | None = None + NVCC_THREADS: str | None = None + CMAKE_BUILD_TYPE: str | None = None + VERBOSE: bool = False + SGLANG_DIFFUSION_SERVER_DEV_MODE: bool = False + SGLANG_DIFFUSION_STAGE_LOGGING: bool = False + # cache-dit env vars (primary transformer) + SGLANG_CACHE_DIT_ENABLED: bool = False + SGLANG_CACHE_DIT_FN: int = 1 + SGLANG_CACHE_DIT_BN: int = 0 + SGLANG_CACHE_DIT_WARMUP: int = 4 + SGLANG_CACHE_DIT_RDT: float = 0.24 + SGLANG_CACHE_DIT_MC: int = 3 + SGLANG_CACHE_DIT_TAYLORSEER: bool = False + SGLANG_CACHE_DIT_TS_ORDER: int 
= 1 + SGLANG_CACHE_DIT_SCM_PRESET: str = "none" + SGLANG_CACHE_DIT_SCM_COMPUTE_BINS: str | None = None + SGLANG_CACHE_DIT_SCM_CACHE_BINS: str | None = None + SGLANG_CACHE_DIT_SCM_POLICY: str = "dynamic" + # cache-dit env vars (secondary transformer, e.g., Wan2.2 low-noise expert) + SGLANG_CACHE_DIT_SECONDARY_FN: int = 1 + SGLANG_CACHE_DIT_SECONDARY_BN: int = 0 + SGLANG_CACHE_DIT_SECONDARY_WARMUP: int = 4 + SGLANG_CACHE_DIT_SECONDARY_RDT: float = 0.24 + SGLANG_CACHE_DIT_SECONDARY_MC: int = 3 + SGLANG_CACHE_DIT_SECONDARY_TAYLORSEER: bool = False + SGLANG_CACHE_DIT_SECONDARY_TS_ORDER: int = 1 + # model loading + SGLANG_USE_RUNAI_MODEL_STREAMER: bool = True + SGLANG_DIFFUSION_VAE_CHANNELS_LAST_3D: bool = False + + +def get_default_cache_root() -> str: + return os.getenv( + "XDG_CACHE_HOME", + os.path.join(os.path.expanduser("~"), ".cache"), + ) + + +def get_default_config_root() -> str: + return os.getenv( + "XDG_CONFIG_HOME", + os.path.join(os.path.expanduser("~"), ".config"), + ) + + +def maybe_convert_int(value: str | None) -> int | None: + return int(value) if value is not None else None + + +# helpers for environment variable definitions +def _lazy_str(key: str, default: str | None = None) -> Callable[[], str | None]: + return lambda: os.getenv(key, default) + + +def _lazy_int(key: str, default: str | int | None = None) -> Callable[[], int | None]: + def _getter(): + val = os.getenv(key) + if val is None: + return int(default) if default is not None else None + return int(val) + + return _getter + + +def _lazy_float(key: str, default: str | float) -> Callable[[], float]: + return lambda: float(os.getenv(key, str(default))) + + +def _lazy_bool(key: str, default: str = "false") -> Callable[[], bool]: + return lambda: get_bool_env_var(key, default) + + +def _lazy_bool_any(keys: list[str], default: str = "false") -> Callable[[], bool]: + def _getter(): + for key in keys: + if get_bool_env_var(key, "false"): + return True + return ( + get_bool_env_var("", default) + if 
not keys + else get_bool_env_var(keys[0], default) + ) + + return _getter + + +def _lazy_path( + key: str, default_func: Callable[[], str] | None = None +) -> Callable[[], str | None]: + def _getter(): + val = os.getenv(key) + if val is None: + if default_func is None: + return None + val = default_func() + return os.path.expanduser(val) + + return _getter + + +# The begin-* and end* here are used by the documentation generator +# to extract the used env vars. + +# begin-env-vars-definition + +environment_variables: dict[str, Callable[[], Any]] = { + # ================== Installation Time Env Vars ================== + # Target device of sglang-diffusion, supporting [cuda (by default), + # rocm, neuron, cpu, openvino] + "SGLANG_DIFFUSION_TARGET_DEVICE": _lazy_str( + "SGLANG_DIFFUSION_TARGET_DEVICE", "cuda" + ), + # Maximum number of compilation jobs to run in parallel. + # By default this is the number of CPUs + "MAX_JOBS": _lazy_str("MAX_JOBS"), + # Number of threads to use for nvcc + # By default this is 1. + # If set, `MAX_JOBS` will be reduced to avoid oversubscribing the CPU. + "NVCC_THREADS": _lazy_str("NVCC_THREADS"), + # If set, sgl_diffusion will use precompiled binaries (*.so) + "SGLANG_DIFFUSION_USE_PRECOMPILED": _lazy_bool_any( + [ + "SGLANG_DIFFUSION_USE_PRECOMPILED", + "SGLANG_DIFFUSION_PRECOMPILED_WHEEL_LOCATION", + ] + ), + # CMake build type + # If not set, defaults to "Debug" or "RelWithDebInfo" + # Available options: "Debug", "Release", "RelWithDebInfo" + "CMAKE_BUILD_TYPE": _lazy_str("CMAKE_BUILD_TYPE"), + # If set, sgl_diffusion will print verbose logs during installation + "VERBOSE": _lazy_bool("VERBOSE"), + # Root directory for SGL-diffusion configuration files + # Defaults to `~/.config/sgl_diffusion` unless `XDG_CONFIG_HOME` is set + # Note that this not only affects how sgl_diffusion finds its configuration files + # during runtime, but also affects how sgl_diffusion installs its configuration + # files during **installation**. 
+ "SGLANG_DIFFUSION_CONFIG_ROOT": _lazy_path( + "SGLANG_DIFFUSION_CONFIG_ROOT", + lambda: os.path.join(get_default_config_root(), "sgl_diffusion"), + ), + # ================== Runtime Env Vars ================== + # Root directory for SGL-diffusion cache files + # Defaults to `~/.cache/sgl_diffusion` unless `XDG_CACHE_HOME` is set + "SGLANG_DIFFUSION_CACHE_ROOT": _lazy_path( + "SGLANG_DIFFUSION_CACHE_ROOT", + lambda: os.path.join(get_default_cache_root(), "sgl_diffusion"), + ), + # Interval in seconds to log a warning message when the ring buffer is full + "SGLANG_DIFFUSION_RINGBUFFER_WARNING_INTERVAL": _lazy_int( + "SGLANG_DIFFUSION_RINGBUFFER_WARNING_INTERVAL", 60 + ), + # Path to the NCCL library file. It is needed because nccl>=2.19 brought + # by PyTorch contains a bug: https://github.com/NVIDIA/nccl/issues/1234 + "SGLANG_DIFFUSION_NCCL_SO_PATH": _lazy_str("SGLANG_DIFFUSION_NCCL_SO_PATH"), + # when `SGLANG_DIFFUSION_NCCL_SO_PATH` is not set, sgl_diffusion will try to find the nccl + # library file in the locations specified by `LD_LIBRARY_PATH` + "LD_LIBRARY_PATH": _lazy_str("LD_LIBRARY_PATH"), + # Internal flag to enable Dynamo fullgraph capture + "SGLANG_DIFFUSION_TEST_DYNAMO_FULLGRAPH_CAPTURE": _lazy_bool( + "SGLANG_DIFFUSION_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1" + ), + # local rank of the process in the distributed setting, used to determine + # the GPU device id + "LOCAL_RANK": _lazy_int("LOCAL_RANK", 0), + # used to control the visible devices in the distributed setting + "CUDA_VISIBLE_DEVICES": _lazy_str("CUDA_VISIBLE_DEVICES"), + # timeout for each iteration in the engine + "SGLANG_DIFFUSION_ENGINE_ITERATION_TIMEOUT_S": _lazy_int( + "SGLANG_DIFFUSION_ENGINE_ITERATION_TIMEOUT_S", 60 + ), + # Logging configuration + # If set to 0, sgl_diffusion will not configure logging + # If set to 1, sgl_diffusion will configure logging using the default configuration + # or the configuration file specified by SGLANG_DIFFUSION_LOGGING_CONFIG_PATH + 
"SGLANG_DIFFUSION_CONFIGURE_LOGGING": _lazy_int( + "SGLANG_DIFFUSION_CONFIGURE_LOGGING", 1 + ), + "SGLANG_DIFFUSION_LOGGING_CONFIG_PATH": _lazy_str( + "SGLANG_DIFFUSION_LOGGING_CONFIG_PATH" + ), + # this is used for configuring the default logging level + "SGLANG_DIFFUSION_LOGGING_LEVEL": _lazy_str( + "SGLANG_DIFFUSION_LOGGING_LEVEL", "INFO" + ), + # if set, SGLANG_DIFFUSION_LOGGING_PREFIX will be prepended to all log messages + "SGLANG_DIFFUSION_LOGGING_PREFIX": _lazy_str("SGLANG_DIFFUSION_LOGGING_PREFIX", ""), + # Trace function calls + # If set to 1, sgl_diffusion will trace function calls + # Useful for debugging + "SGLANG_DIFFUSION_TRACE_FUNCTION": _lazy_int("SGLANG_DIFFUSION_TRACE_FUNCTION", 0), + # Path to the attention configuration file. Only used for sliding tile + # attention for now. + "SGLANG_DIFFUSION_ATTENTION_CONFIG": _lazy_path( + "SGLANG_DIFFUSION_ATTENTION_CONFIG" + ), + # Optional override to force a specific attention backend (e.g. "aiter") + "SGLANG_DIFFUSION_ATTENTION_BACKEND": _lazy_str( + "SGLANG_DIFFUSION_ATTENTION_BACKEND" + ), + # Use dedicated multiprocess context for workers. + # Both spawn and fork work + "SGLANG_DIFFUSION_WORKER_MULTIPROC_METHOD": _lazy_str( + "SGLANG_DIFFUSION_WORKER_MULTIPROC_METHOD", "fork" + ), + # Enables torch profiler if set. Path to the directory where torch profiler + # traces are saved. Note that it must be an absolute path. + "SGLANG_DIFFUSION_TORCH_PROFILER_DIR": _lazy_path( + "SGLANG_DIFFUSION_TORCH_PROFILER_DIR" + ), + # If set, sgl_diffusion will run in development mode, which will enable + # some additional endpoints for developing and debugging, + # e.g. 
`/reset_prefix_cache` + "SGLANG_DIFFUSION_SERVER_DEV_MODE": _lazy_bool("SGLANG_DIFFUSION_SERVER_DEV_MODE"), + # If set, sgl_diffusion will enable stage logging, which will print the time + # taken for each stage + "SGLANG_DIFFUSION_STAGE_LOGGING": _lazy_bool("SGLANG_DIFFUSION_STAGE_LOGGING"), + "SGLANG_DIFFUSION_VAE_CHANNELS_LAST_3D": _lazy_bool( + "SGLANG_DIFFUSION_VAE_CHANNELS_LAST_3D", "false" + ), + # ================== cache-dit Env Vars ================== + # Enable cache-dit acceleration for DiT inference + "SGLANG_CACHE_DIT_ENABLED": _lazy_bool("SGLANG_CACHE_DIT_ENABLED"), + # Number of first blocks to always compute (DBCache F parameter) + "SGLANG_CACHE_DIT_FN": _lazy_int("SGLANG_CACHE_DIT_FN", 1), + # Number of last blocks to always compute (DBCache B parameter) + "SGLANG_CACHE_DIT_BN": _lazy_int("SGLANG_CACHE_DIT_BN", 0), + # Warmup steps before caching (DBCache W parameter) + "SGLANG_CACHE_DIT_WARMUP": _lazy_int("SGLANG_CACHE_DIT_WARMUP", 4), + # Residual difference threshold (DBCache R parameter) + "SGLANG_CACHE_DIT_RDT": _lazy_float("SGLANG_CACHE_DIT_RDT", 0.24), + # Maximum continuous cached steps (DBCache MC parameter) + "SGLANG_CACHE_DIT_MC": _lazy_int("SGLANG_CACHE_DIT_MC", 3), + # Enable TaylorSeer calibrator + "SGLANG_CACHE_DIT_TAYLORSEER": _lazy_bool("SGLANG_CACHE_DIT_TAYLORSEER", "false"), + # TaylorSeer order (1 or 2) + "SGLANG_CACHE_DIT_TS_ORDER": _lazy_int("SGLANG_CACHE_DIT_TS_ORDER", 1), + # SCM preset: none, slow, medium, fast, ultra + "SGLANG_CACHE_DIT_SCM_PRESET": _lazy_str("SGLANG_CACHE_DIT_SCM_PRESET", "none"), + # SCM custom compute bins (e.g., "8,3,3,2,2") + "SGLANG_CACHE_DIT_SCM_COMPUTE_BINS": _lazy_str("SGLANG_CACHE_DIT_SCM_COMPUTE_BINS"), + # SCM custom cache bins (e.g., "1,2,2,2,3") + "SGLANG_CACHE_DIT_SCM_CACHE_BINS": _lazy_str("SGLANG_CACHE_DIT_SCM_CACHE_BINS"), + # SCM policy: dynamic or static + "SGLANG_CACHE_DIT_SCM_POLICY": _lazy_str("SGLANG_CACHE_DIT_SCM_POLICY", "dynamic"), + # model loading + 
"SGLANG_USE_RUNAI_MODEL_STREAMER": _lazy_bool( + "SGLANG_USE_RUNAI_MODEL_STREAMER", "true" + ), +} + +# Add cache-dit Secondary Transformer Env Vars via programmatic generation to reduce duplication +_CACHE_DIT_SECONDARY_CONFIGS = [ + ("FN", int, "1"), + ("BN", int, "0"), + ("WARMUP", int, "4"), + ("RDT", float, "0.24"), + ("MC", int, "3"), + ("TS_ORDER", int, "1"), +] + + +def _create_secondary_getter(suffix, type_func, default_val): + primary_key = f"SGLANG_CACHE_DIT_{suffix}" + secondary_key = f"SGLANG_CACHE_DIT_SECONDARY_{suffix}" + + def _getter(): + val = os.getenv(secondary_key) + if val is not None: + return type_func(val) + return type_func(os.getenv(primary_key, str(default_val))) + + return secondary_key, _getter + + +for suffix, type_func, default_val in _CACHE_DIT_SECONDARY_CONFIGS: + key, getter = _create_secondary_getter(suffix, type_func, default_val) + environment_variables[key] = getter + + +# Special handling for boolean secondary var (TaylorSeer) +def _secondary_taylorseer_getter(): + return get_bool_env_var( + "SGLANG_CACHE_DIT_SECONDARY_TAYLORSEER", + default=os.getenv("SGLANG_CACHE_DIT_TAYLORSEER", "false"), + ) + + +environment_variables["SGLANG_CACHE_DIT_SECONDARY_TAYLORSEER"] = ( + _secondary_taylorseer_getter +) + + +# end-env-vars-definition +def __getattr__(name: str): + # lazy evaluation of environment variables + if name in environment_variables: + return environment_variables[name]() + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def __dir__(): + return list(environment_variables.keys()) diff --git a/sglang/python/sglang/multimodal_gen/registry.py b/sglang/python/sglang/multimodal_gen/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..6903272185c26ae52ed96189de95affeeefc1a49 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/registry.py @@ -0,0 +1,789 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Central registry for multimodal models. 
+ +This module provides a centralized registry for multimodal models, including pipelines +and sampling parameters. It allows for easy registration and retrieval of model +information based on model paths or other identifiers. +""" + +import dataclasses +import importlib +import os +import pkgutil +from functools import lru_cache +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Tuple, + Type, + Union, +) + +if TYPE_CHECKING: + from sglang.multimodal_gen.runtime.server_args import Backend + +from sglang.multimodal_gen.configs.pipeline_configs import ( + FastHunyuanConfig, + FluxPipelineConfig, + HeliosDistilledConfig, + HeliosMidConfig, + HeliosT2VConfig, + HunyuanConfig, + WanI2V480PConfig, + WanI2V720PConfig, + WanT2V480PConfig, + WanT2V720PConfig, + ZImagePipelineConfig, +) +from sglang.multimodal_gen.configs.pipeline_configs.base import PipelineConfig +from sglang.multimodal_gen.configs.pipeline_configs.flux import ( + Flux2KleinPipelineConfig, + Flux2PipelineConfig, +) +from sglang.multimodal_gen.configs.pipeline_configs.glm_image import ( + GlmImagePipelineConfig, +) +from sglang.multimodal_gen.configs.pipeline_configs.hunyuan3d import ( + Hunyuan3D2PipelineConfig, +) +from sglang.multimodal_gen.configs.pipeline_configs.ltx_2 import LTX2PipelineConfig +from sglang.multimodal_gen.configs.pipeline_configs.mova import ( + MOVA360PConfig, + MOVA720PConfig, +) +from sglang.multimodal_gen.configs.pipeline_configs.qwen_image import ( + QwenImageEditPipelineConfig, + QwenImageEditPlus_2511_PipelineConfig, + QwenImageEditPlusPipelineConfig, + QwenImageLayeredPipelineConfig, + QwenImagePipelineConfig, +) +from sglang.multimodal_gen.configs.pipeline_configs.wan import ( + FastWan2_1_T2V_480P_Config, + FastWan2_2_TI2V_5B_Config, + TurboWanI2V720Config, + TurboWanT2V480PConfig, + Wan2_2_I2V_A14B_Config, + Wan2_2_T2V_A14B_Config, + Wan2_2_TI2V_5B_Config, +) +from sglang.multimodal_gen.configs.sample.flux import ( + 
Flux2KleinSamplingParams, + FluxSamplingParams, +) +from sglang.multimodal_gen.configs.sample.glmimage import GlmImageSamplingParams +from sglang.multimodal_gen.configs.sample.helios import ( + HeliosDistilledSamplingParams, + HeliosMidSamplingParams, + HeliosT2VSamplingParams, +) +from sglang.multimodal_gen.configs.sample.hunyuan import ( + FastHunyuanSamplingParam, + HunyuanSamplingParams, +) +from sglang.multimodal_gen.configs.sample.hunyuan3d import Hunyuan3DSamplingParams +from sglang.multimodal_gen.configs.sample.ltx_2 import LTX2SamplingParams +from sglang.multimodal_gen.configs.sample.mova import ( + MOVA_360P_SamplingParams, + MOVA_720P_SamplingParams, +) +from sglang.multimodal_gen.configs.sample.qwenimage import ( + QwenImage2512SamplingParams, + QwenImageEditPlusSamplingParams, + QwenImageLayeredSamplingParams, + QwenImageSamplingParams, +) +from sglang.multimodal_gen.configs.sample.wan import ( + FastWanT2V480PConfig, + Turbo_Wan2_2_I2V_A14B_SamplingParam, + Wan2_1_Fun_1_3B_InP_SamplingParams, + Wan2_2_I2V_A14B_SamplingParam, + Wan2_2_T2V_A14B_SamplingParam, + Wan2_2_TI2V_5B_SamplingParam, + WanI2V_14B_480P_SamplingParam, + WanI2V_14B_720P_SamplingParam, + WanT2V_1_3B_SamplingParams, + WanT2V_14B_SamplingParams, +) +from sglang.multimodal_gen.configs.sample.zimage import ( + ZImageSamplingParams, + ZImageTurboSamplingParams, +) +from sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import ( + ComposedPipelineBase, +) +from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import ( + maybe_download_model_index, + verify_model_config_and_directory, +) +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + +# --- Part 1: Pipeline Discovery --- + +_PIPELINE_REGISTRY: Dict[str, Type[ComposedPipelineBase]] = {} + +# Registry for pipeline configuration classes (for safetensors files without model_index.json) +# Maps pipeline_class_name -> (PipelineConfig class, SamplingParams 
def _discover_and_register_pipelines():
    """
    Import every pipeline module and fill the pipeline registries.

    Walks the 'sglang.multimodal_gen.runtime.pipelines' package, imports each
    non-package module, and looks for a module-level 'EntryClass' attribute
    (a single class or a list of classes). Each entry that subclasses
    ComposedPipelineBase is stored in _PIPELINE_REGISTRY under its
    'pipeline_name'. Entries that also define both 'pipeline_config_cls' and
    'sampling_params_cls' are recorded in _PIPELINE_CONFIG_REGISTRY.

    The scan is skipped when _PIPELINE_REGISTRY is already non-empty.
    """
    # A populated registry means discovery has already run.
    if _PIPELINE_REGISTRY:
        return

    root_package = importlib.import_module("sglang.multimodal_gen.runtime.pipelines")

    for _, module_name, ispkg in pkgutil.walk_packages(
        root_package.__path__, root_package.__name__ + "."
    ):
        if ispkg:
            continue

        module = importlib.import_module(module_name)
        if not hasattr(module, "EntryClass"):
            continue

        entry = module.EntryClass
        candidates = entry if isinstance(entry, list) else [entry]

        for cls in candidates:
            if not issubclass(cls, ComposedPipelineBase):
                continue

            if cls.pipeline_name in _PIPELINE_REGISTRY:
                logger.warning(
                    f"Duplicate pipeline name '{cls.pipeline_name}' found. Overwriting."
                )
            _PIPELINE_REGISTRY[cls.pipeline_name] = cls

            # ComfyUI pipelines load a model from a single weight file, so
            # their config classes are taken from the pipeline class itself
            # and registered here.
            has_own_configs = hasattr(cls, "pipeline_config_cls") and hasattr(
                cls, "sampling_params_cls"
            )
            if not has_own_configs:
                continue

            _PIPELINE_CONFIG_REGISTRY[cls.pipeline_name] = (
                cls.pipeline_config_cls,
                cls.sampling_params_cls,
            )
            logger.debug(
                f"Auto-registered config classes for pipeline '{cls.pipeline_name}': "
                f"PipelineConfig={cls.pipeline_config_cls.__name__}, "
                f"SamplingParams={cls.sampling_params_cls.__name__}"
            )

    logger.debug(
        f"Registering pipelines complete, {len(_PIPELINE_REGISTRY)} pipelines registered"
    )
def get_model_short_name(model_id: str) -> str:
    """Return the last '/'-separated component of a model id or path.

    Trailing slashes are ignored, so "org/model/" yields "model". An id
    without any '/' is returned unchanged.
    """
    if "/" not in model_id:
        return model_id
    # rpartition()[2] is everything after the final remaining '/'.
    return model_id.rstrip("/").rpartition("/")[2]
@dataclasses.dataclass
class ModelInfo:
    """
    The resolved set of classes for one model, as returned by `get_model_info`.

    It bundles the pipeline implementation with the sampling-parameter and
    pipeline-configuration classes that go with it.
    """

    # Pipeline implementation class to run the model with.
    pipeline_cls: Type[ComposedPipelineBase]
    # Sampling-parameter class for the model (left untyped as Any here).
    sampling_param_cls: Any
    # Pipeline configuration class for the model.
    pipeline_config_cls: Type[PipelineConfig]
@lru_cache(maxsize=1)
def get_model_info(
    model_path: str,
    backend: Optional[Union[str, "Backend"]] = None,
    model_id: Optional[str] = None,
) -> Optional[ModelInfo]:
    """
    Resolves all necessary classes (pipeline, sampling, config) for a given model path.

    This function serves as the main entry point for model resolution. It performs two main tasks:
    1. Dynamically resolves the pipeline class by reading 'model_index.json' and matching
       '_class_name' against an auto-discovered registry of pipeline implementations.
    2. Resolves the associated configuration classes (for sampling and pipeline) using a
       manually registered mapping based on the model path.

    The result is memoized for the most recent argument combination only
    (``lru_cache(maxsize=1)``), so all arguments must be hashable.

    Args:
        model_path: Local directory or Hugging Face model id to resolve.
        backend: Backend to use ('auto', 'sglang', 'diffusers'). If None, uses 'auto'.
        model_id: Optional explicit model id forwarded to the config lookup.

    Returns:
        A ModelInfo for the native implementation when one can be resolved; the
        generic diffusers ModelInfo when diffusers is requested, when the path
        looks quantized, or when the 'auto' backend has to fall back; None when
        a non-'auto' backend cannot be resolved natively.
    """
    # import Backend enum here to avoid circular imports
    from sglang.multimodal_gen.runtime.server_args import Backend

    # Normalize backend
    if backend is None:
        backend = Backend.AUTO
    elif isinstance(backend, str):
        backend = Backend.from_string(backend)

    # Handle explicit diffusers backend
    if backend == Backend.DIFFUSERS:
        logger.info(
            "Using diffusers backend for model '%s' (explicitly requested)", model_path
        )
        return _get_diffusers_model_info()

    # For AUTO or SGLANG backend, try native implementation first
    # 1. Discover all available pipeline classes and cache them
    _discover_and_register_pipelines()

    # Detect quantized models and fallback to diffusers.
    # The backend cannot be DIFFUSERS here (it returned above), so only the
    # path markers need checking.
    is_quantized = any(q in model_path.lower() for q in ["-4bit", "-awq", "-gptq"])
    if is_quantized:
        logger.info(
            "Detected a quantized model format ('%s'). "
            "The native sglang-diffusion engine currently only supports BF16/FP16. "
            "Falling back to diffusers backend.",
            model_path,
        )
        # _get_diffusers_model_info takes no arguments; passing model_path
        # raised TypeError instead of performing the fallback.
        return _get_diffusers_model_info()

    # 2. Get pipeline class - check non-diffusers models first
    pipeline_class_name = get_non_diffusers_pipeline_name(model_path)
    if pipeline_class_name:
        # Known non-diffusers model, skip model_index.json download
        logger.debug(
            f"Using registered pipeline '{pipeline_class_name}' for non-diffusers model '{model_path}'"
        )
    else:
        # Try to get from model_index.json
        try:
            if os.path.exists(model_path):
                config = verify_model_config_and_directory(model_path)
            else:
                config = maybe_download_model_index(model_path)
        except Exception as e:
            logger.error(f"Could not read model config for '{model_path}': {e}")
            if backend == Backend.AUTO:
                logger.info("Falling back to diffusers backend")
                return _get_diffusers_model_info()
            return None

        pipeline_class_name = config.get("_class_name")
        if not pipeline_class_name:
            logger.error(
                f"'_class_name' not found in model_index.json for '{model_path}'"
            )
            if backend == Backend.AUTO:
                logger.info("Falling back to diffusers backend")
                return _get_diffusers_model_info()
            return None

    pipeline_cls = _PIPELINE_REGISTRY.get(pipeline_class_name)
    if not pipeline_cls:
        if backend == Backend.AUTO:
            logger.warning(
                f"Pipeline class '{pipeline_class_name}' specified in '{model_path}' has no native sglang support. "
                f"Falling back to diffusers backend."
            )
            return _get_diffusers_model_info()
        else:
            logger.error(
                f"Pipeline class '{pipeline_class_name}' specified in '{model_path}' is not a registered EntryClass in the framework. "
                f"Available pipelines: {list(_PIPELINE_REGISTRY.keys())}. "
                f"Consider using --backend diffusers to use vanilla diffusers pipeline."
            )
            return None

    # 3. Get configuration classes (sampling, pipeline config)
    config_info = _get_config_info(model_path, model_id=model_id)
    if not config_info:
        if backend == Backend.AUTO:
            logger.warning(
                f"Could not resolve native configuration for model '{model_path}'. "
                f"Falling back to diffusers backend."
            )
            return _get_diffusers_model_info()
        else:
            logger.error(
                f"Could not resolve configuration for model '{model_path}'. "
                "It is not a registered model path or detected by any registered model family detectors. "
                f"Known model paths: {list(_MODEL_HF_PATH_TO_NAME.keys())}. "
                f"Consider using --backend diffusers to use vanilla diffusers pipeline."
            )
            return None

    # 4. Combine and return the complete model info
    logger.debug("Using native sglang backend for model '%s'", model_path)
    model_info = ModelInfo(
        pipeline_cls=pipeline_cls,
        sampling_param_cls=config_info.sampling_param_cls,
        pipeline_config_cls=config_info.pipeline_config_cls,
    )
    logger.debug(f"Found model info: {model_info}")

    return model_info
sampling_param_cls=WanI2V_14B_480P_SamplingParam, + pipeline_config_cls=WanI2V480PConfig, + hf_model_paths=[ + "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", + ], + model_detectors=[lambda hf_id: "wanimagetovideo" in hf_id.lower()], + ) + register_configs( + sampling_param_cls=WanI2V_14B_720P_SamplingParam, + pipeline_config_cls=WanI2V720PConfig, + hf_model_paths=[ + "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers", + ], + ) + register_configs( + sampling_param_cls=Turbo_Wan2_2_I2V_A14B_SamplingParam, + pipeline_config_cls=TurboWanI2V720Config, + hf_model_paths=[ + "IPostYellow/TurboWan2.2-I2V-A14B-Diffusers", + ], + ) + register_configs( + sampling_param_cls=Wan2_1_Fun_1_3B_InP_SamplingParams, + pipeline_config_cls=WanI2V480PConfig, + hf_model_paths=[ + "weizhou03/Wan2.1-Fun-1.3B-InP-Diffusers", + ], + ) + register_configs( + sampling_param_cls=Wan2_2_TI2V_5B_SamplingParam, + pipeline_config_cls=Wan2_2_TI2V_5B_Config, + hf_model_paths=[ + "Wan-AI/Wan2.2-TI2V-5B-Diffusers", + ], + ) + register_configs( + sampling_param_cls=Wan2_2_TI2V_5B_SamplingParam, + pipeline_config_cls=FastWan2_2_TI2V_5B_Config, + hf_model_paths=[ + "FastVideo/FastWan2.2-TI2V-5B-FullAttn-Diffusers", + "FastVideo/FastWan2.2-TI2V-5B-Diffusers", + ], + ) + register_configs( + sampling_param_cls=Wan2_2_T2V_A14B_SamplingParam, + pipeline_config_cls=Wan2_2_T2V_A14B_Config, + hf_model_paths=["Wan-AI/Wan2.2-T2V-A14B-Diffusers"], + ) + register_configs( + sampling_param_cls=Wan2_2_I2V_A14B_SamplingParam, + pipeline_config_cls=Wan2_2_I2V_A14B_Config, + hf_model_paths=["Wan-AI/Wan2.2-I2V-A14B-Diffusers"], + ) + register_configs( + sampling_param_cls=FastWanT2V480PConfig, + pipeline_config_cls=FastWan2_1_T2V_480P_Config, + hf_model_paths=[ + "FastVideo/FastWan2.1-T2V-1.3B-Diffusers", + ], + ) + # MOVA + register_configs( + sampling_param_cls=MOVA_360P_SamplingParams, + pipeline_config_cls=MOVA360PConfig, + model_detectors=[ + lambda hf_id: "mova" in hf_id.lower() and "360p" in hf_id.lower() + ], + ) + register_configs( + 
sampling_param_cls=MOVA_720P_SamplingParams, + pipeline_config_cls=MOVA720PConfig, + model_detectors=[ + lambda hf_id: "mova" in hf_id.lower() and "720p" in hf_id.lower() + ], + ) + # FLUX + register_configs( + sampling_param_cls=FluxSamplingParams, + pipeline_config_cls=FluxPipelineConfig, + hf_model_paths=[ + "black-forest-labs/FLUX.1-dev", + ], + model_detectors=[lambda hf_id: "flux.1" in hf_id.lower()], + ) + register_configs( + sampling_param_cls=Flux2KleinSamplingParams, + pipeline_config_cls=Flux2KleinPipelineConfig, + hf_model_paths=[ + "black-forest-labs/FLUX.2-klein-4B", + "black-forest-labs/FLUX.2-klein-9B", + ], + model_detectors=[ + lambda hf_id: "flux.2-klein" in hf_id.lower() + or "flux2-klein" in hf_id.lower() + ], + ) + register_configs( + sampling_param_cls=FluxSamplingParams, + pipeline_config_cls=Flux2PipelineConfig, + hf_model_paths=[ + "black-forest-labs/FLUX.2-dev", + ], + model_detectors=[ + lambda hf_id: "flux.2" in hf_id.lower() and "klein" not in hf_id.lower() + ], + ) + register_configs( + sampling_param_cls=ZImageTurboSamplingParams, + pipeline_config_cls=ZImagePipelineConfig, + hf_model_paths=[ + "Tongyi-MAI/Z-Image-Turbo", + ], + model_detectors=[lambda hf_id: "z-image-turbo" in hf_id.lower()], + ) + register_configs( + sampling_param_cls=ZImageSamplingParams, + pipeline_config_cls=ZImagePipelineConfig, + hf_model_paths=[ + "Tongyi-MAI/Z-Image", + ], + model_detectors=[ + lambda hf_id: "z-image" in hf_id.lower() and "turbo" not in hf_id.lower() + ], + ) + # Qwen-Image + register_configs( + sampling_param_cls=QwenImageSamplingParams, + pipeline_config_cls=QwenImagePipelineConfig, + hf_model_paths=["Qwen/Qwen-Image"], + model_detectors=[ + lambda hf_id: "qwen-image" in hf_id.lower() + and "edit" not in hf_id.lower() + and "layered" not in hf_id.lower() + and "2512" not in hf_id.lower() + ], + ) + register_configs( + sampling_param_cls=QwenImage2512SamplingParams, + pipeline_config_cls=QwenImagePipelineConfig, + 
hf_model_paths=["Qwen/Qwen-Image-2512"], + model_detectors=[lambda hf_id: "qwen-image-2512" in hf_id.lower()], + ) + register_configs( + sampling_param_cls=QwenImageSamplingParams, + pipeline_config_cls=QwenImageEditPipelineConfig, + hf_model_paths=["Qwen/Qwen-Image-Edit"], + model_detectors=[ + lambda hf_id: "qwen-image-edit" in hf_id.lower() + and "2509" not in hf_id.lower() + and "2511" not in hf_id.lower() + ], + ) + + register_configs( + sampling_param_cls=QwenImageEditPlusSamplingParams, + pipeline_config_cls=QwenImageEditPlusPipelineConfig, + hf_model_paths=["Qwen/Qwen-Image-Edit-2509"], + model_detectors=[lambda hf_id: "qwen-image-edit-2509" in hf_id.lower()], + ) + + register_configs( + sampling_param_cls=QwenImageEditPlusSamplingParams, + pipeline_config_cls=QwenImageEditPlus_2511_PipelineConfig, + hf_model_paths=["Qwen/Qwen-Image-Edit-2511"], + model_detectors=[lambda hf_id: "qwen-image-edit-2511" in hf_id.lower()], + ) + + register_configs( + sampling_param_cls=QwenImageLayeredSamplingParams, + pipeline_config_cls=QwenImageLayeredPipelineConfig, + hf_model_paths=["Qwen/Qwen-Image-Layered"], + model_detectors=[lambda hf_id: "qwen-image-layered" in hf_id.lower()], + ) + + register_configs( + sampling_param_cls=GlmImageSamplingParams, + pipeline_config_cls=GlmImagePipelineConfig, + model_detectors=[lambda hf_id: "glm-image" in hf_id.lower()], + ) + register_configs( + sampling_param_cls=Hunyuan3DSamplingParams, + pipeline_config_cls=Hunyuan3D2PipelineConfig, + hf_model_paths=[ + "tencent/Hunyuan3D-2", + ], + model_detectors=[lambda hf_id: "hunyuan3d" in hf_id.lower()], + ) + + # Helios + register_configs( + sampling_param_cls=HeliosT2VSamplingParams, + pipeline_config_cls=HeliosT2VConfig, + hf_model_paths=[ + "BestWishYsh/Helios-Base", + ], + model_detectors=[ + lambda hf_id: "helios" in hf_id.lower() + and "mid" not in hf_id.lower() + and "distill" not in hf_id.lower() + ], + ) + register_configs( + sampling_param_cls=HeliosMidSamplingParams, + 
# Substring patterns identifying multimodal checkpoints that ship without a
# model_index.json, mapped to the name of the pipeline that handles them.
_NON_DIFFUSERS_MULTIMODAL_PATTERNS: Dict[str, str] = {
    "hunyuan3d": "Hunyuan3D2Pipeline",
}


def is_known_non_diffusers_multimodal_model(model_path: str) -> bool:
    """Tell whether ``model_path`` names a known non-diffusers multimodal model.

    The check is a case-insensitive substring match against the keys of
    ``_NON_DIFFUSERS_MULTIMODAL_PATTERNS``.

    Args:
        model_path: Model identifier or filesystem path to inspect.

    Returns:
        True if any registered pattern occurs in the lowercased path.
    """
    lowered = model_path.lower()
    for known_pattern in _NON_DIFFUSERS_MULTIMODAL_PATTERNS:
        if known_pattern in lowered:
            return True
    return False


def get_non_diffusers_pipeline_name(model_path: str) -> Optional[str]:
    """Look up the pipeline name registered for a non-diffusers model.

    Args:
        model_path: Model identifier or filesystem path to inspect.

    Returns:
        The pipeline name of the first registered pattern (in insertion order)
        that occurs in the lowercased path, or None when nothing matches.
    """
    lowered = model_path.lower()
    matches = (
        name
        for known_pattern, name in _NON_DIFFUSERS_MULTIMODAL_PATTERNS.items()
        if known_pattern in lowered
    )
    return next(matches, None)
def _patch_cache_dit_similarity():
    """Replace cache-dit's ``CachedContextManager.similarity`` with a group-aware one.

    The patch is installed at most once per process: the original method is
    stashed in the module-level ``_original_similarity`` and a second call
    returns immediately when that global is already set.

    The replacement defers to the original method unless ``parallelized`` is
    true AND the context manager carries one of the ``_sglang_*_group``
    attributes. In that case it all-reduces the difference statistics over the
    chosen group before comparing against ``threshold``, so every rank in the
    group evaluates the same ratio.
    """
    from cache_dit.caching.cache_contexts import cache_manager

    global _original_similarity
    # Already patched: keep the first saved original, do not wrap twice.
    if _original_similarity is not None:
        return

    _original_similarity = cache_manager.CachedContextManager.similarity

    def patched_similarity(self, t1, t2, *, threshold, parallelized=False, prefix="Fn"):
        # Non-parallel runs keep cache-dit's stock behavior.
        if not parallelized:
            return _original_similarity(
                self,
                t1,
                t2,
                threshold=threshold,
                parallelized=parallelized,
                prefix=prefix,
            )

        # Groups are attached to the context manager by the enable_cache_* helpers.
        # Preference order: combined TP+SP group, then SP, then TP.
        sp_group = getattr(self, "_sglang_sp_group", None)
        tp_group = getattr(self, "_sglang_tp_group", None)
        tp_sp_group = getattr(self, "_sglang_tp_sp_group", None)
        target_group = tp_sp_group or sp_group or tp_group

        # No SGLang group recorded: nothing to synchronize over, use the original.
        if target_group is None:
            return _original_similarity(
                self,
                t1,
                t2,
                threshold=threshold,
                parallelized=parallelized,
                prefix=prefix,
            )

        # Adapted from https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/caching/cache_contexts/cache_manager.py#L495-L523
        condition_thresh = self.get_important_condition_threshold()
        if condition_thresh > 0.0:
            # Restrict the statistics to positions whose relative change along
            # the last dim exceeds condition_thresh.
            raw_diff = (t1 - t2).abs()
            token_m_df = raw_diff.mean(dim=-1)
            token_m_t1 = t1.abs().mean(dim=-1)
            # NOTE(review): no guard against token_m_t1 == 0 here — confirm
            # upstream tolerates inf/nan entries in token_diff.
            token_diff = token_m_df / token_m_t1
            condition = token_diff > condition_thresh
            if condition.sum() > 0:
                condition = condition.unsqueeze(-1).expand_as(raw_diff)
                mean_diff = raw_diff[condition].mean()
                mean_t1 = t1[condition].abs().mean()
            else:
                # Nothing exceeded the threshold: fall back to global means.
                mean_diff = (t1 - t2).abs().mean()
                mean_t1 = t1.abs().mean()
        else:
            mean_diff = (t1 - t2).abs().mean()
            mean_t1 = t1.abs().mean()

        # Average both statistics over the group (in place) so all ranks see
        # identical values below.
        # NOTE(review): torch documents ReduceOp.AVG as NCCL-only — confirm the
        # groups passed in are always NCCL-backed.
        dist.all_reduce(mean_diff, op=dist.ReduceOp.AVG, group=target_group)
        dist.all_reduce(mean_t1, op=dist.ReduceOp.AVG, group=target_group)

        diff = (mean_diff / mean_t1).item()
        self.add_residual_diff(diff)
        return diff < threshold

    cache_manager.CachedContextManager.similarity = patched_similarity
backend=ParallelismBackend.NATIVE_PYTORCH, + ulysses_size=ulysses_size, + ring_size=ring_size, + tp_size=tp_size, + ) + + +def _mark_transformer_parallelized(transformer, config, sp_group, tp_group): + if config is None: + return + + transformer._is_parallelized = True + transformer._parallelism_config = config + + +def get_scm_mask( + preset: str, + num_inference_steps: int, + compute_bins: Optional[List[int]] = None, + cache_bins: Optional[List[int]] = None, +) -> Optional[List[int]]: + """ + Get SCM mask using cache-dit's steps_mask(). + + This is a thin wrapper that delegates to cache-dit's built-in + steps_mask() function which handles all presets and scaling logic. + + Args: + preset: Preset name ("none", "slow", "medium", "fast", "ultra"). + compute_bins: Custom compute bins (overrides preset). + cache_bins: Custom cache bins (overrides preset). + + Returns: + SCM mask list (1=compute, 0=cache), or None if disabled. + """ + if preset == "none" and not (compute_bins and cache_bins): + return None + + # Use cache-dit's steps_mask() directly + mask = steps_mask( + compute_bins=compute_bins, + cache_bins=cache_bins, + total_steps=num_inference_steps, + mask_policy=preset if preset != "none" else "medium", + ) + + compute_count = sum(mask) + cache_count = len(mask) - compute_count + logger.info( + "SCM: generated mask with %d compute steps, %d cache steps (preset=%s)", + compute_count, + cache_count, + preset, + ) + + return mask + + +@dataclass +class CacheDitConfig: + """Configuration for cache-dit integration. + + Attributes: + enabled: Whether to enable cache-dit acceleration. + Fn_compute_blocks: Number of first blocks to always compute (DBCache F). + Bn_compute_blocks: Number of last blocks to always compute (DBCache B). + max_warmup_steps: Number of warmup steps before caching starts (DBCache W). + residual_diff_threshold: Threshold for residual difference (DBCache R). + max_continuous_cached_steps: Maximum consecutive cached steps (DBCache MC). 
+ enable_taylorseer: Whether to enable TaylorSeer calibrator. + taylorseer_order: Order of Taylor expansion (1 or 2). + num_inference_steps: Total number of inference steps (required for transformer-only mode). + steps_computation_mask: Binary mask for step-level caching (1=compute, 0=cache). + Generated by get_scm_mask() (wrapper around cache_dit.steps_mask()). + steps_computation_policy: Caching policy for SCM ("dynamic" or "static"). + """ + + enabled: bool = False + Fn_compute_blocks: int = 1 + Bn_compute_blocks: int = 0 + # Use 4 as default warmup steps instead of 8 in cache-dit, thus making + # DBCache work for few steps distilled models, e.g., Z-Image w/ 8-steps. + max_warmup_steps: int = 4 + # Use a relatively higher residual diff threshold (namely, 0.24) as default + # to allow more aggressive caching due to we have already applied max continuous + # cached steps limit, otherwise, we should use a lower threshold here like 0.12. + residual_diff_threshold: float = 0.24 + max_continuous_cached_steps: int = 3 + # TaylorSeer is not suitable for few steps distilled models, so, we choose + # to disable it by default. 
def enable_cache_on_transformer(
    transformer: torch.nn.Module,
    config: CacheDitConfig,
    model_name: str = "transformer",
    sp_group: Optional[torch.distributed.ProcessGroup] = None,
    tp_group: Optional[torch.distributed.ProcessGroup] = None,
) -> torch.nn.Module:
    """Enable cache-dit on a single transformer module.

    Uses cache-dit's BlockAdapterRegister, so the transformer class must be one
    of the models pre-registered there.

    Args:
        transformer: Module to accelerate; it is modified in place.
        config: Cache settings. When ``config.enabled`` is False the module is
            returned untouched.
        model_name: Name of the model for logging purposes.
        sp_group: Sequence parallel process group (for Ulysses/Ring).
        tp_group: Tensor parallel process group.

    Returns:
        The same ``transformer`` object that was passed in.

    Raises:
        ValueError: If ``config.num_inference_steps`` is None, or if the
            transformer is not supported by cache-dit's BlockAdapterRegister.
    """
    if not config.enabled:
        return transformer

    if config.num_inference_steps is None:
        raise ValueError(
            "num_inference_steps is required for transformer-only mode. "
            "Please provide it in CacheDitConfig."
        )

    # Check if the transformer is pre-registered in cache-dit
    if not BlockAdapterRegister.is_supported(transformer):
        transformer_cls_name = transformer.__class__.__name__
        raise ValueError(
            f"{transformer_cls_name} is not officially supported by cache-dit. "
            "Supported cache-dit DiT families include Flux, QwenImage, HunyuanDiT, "
            "HunyuanVideo, Wan, CogVideoX, Mochi, and others. "
            "Please ensure your transformer belongs to one of these families or "
            "define a custom BlockAdapter."
        )

    # Build cache config (including SCM fields if provided)
    cache_config = DBCacheConfig(
        num_inference_steps=config.num_inference_steps,
        Fn_compute_blocks=config.Fn_compute_blocks,
        Bn_compute_blocks=config.Bn_compute_blocks,
        max_warmup_steps=config.max_warmup_steps,
        residual_diff_threshold=config.residual_diff_threshold,
        max_continuous_cached_steps=config.max_continuous_cached_steps,
        # SCM fields
        steps_computation_mask=config.steps_computation_mask,
        steps_computation_policy=config.steps_computation_policy,
    )

    # Build calibrator config if TaylorSeer is enabled
    calibrator_config = None
    if config.enable_taylorseer:
        calibrator_config = TaylorSeerCalibratorConfig(
            taylorseer_order=config.taylorseer_order,
        )

    # Enable cache-dit on the transformer
    logger.info(
        "Enabling cache-dit on %s with config: Fn=%d, Bn=%d, W=%d, R=%.2f, MC=%d, "
        "TaylorSeer=%s (order=%d), steps=%d",
        model_name,
        config.Fn_compute_blocks,
        config.Bn_compute_blocks,
        config.max_warmup_steps,
        config.residual_diff_threshold,
        config.max_continuous_cached_steps,
        config.enable_taylorseer,
        config.taylorseer_order,
        config.num_inference_steps,
    )

    # Log SCM configuration if enabled
    if config.steps_computation_mask:
        compute_steps = sum(config.steps_computation_mask)
        cache_steps = len(config.steps_computation_mask) - compute_steps
        logger.info(
            "SCM enabled: %d compute steps, %d cache steps, policy=%s",
            compute_steps,
            cache_steps,
            config.steps_computation_policy,
        )

    parallelism_config = _build_parallelism_config(sp_group, tp_group)
    if parallelism_config is not None:
        # The similarity check must be group-aware before caching is switched on.
        _patch_cache_dit_similarity()

    _mark_transformer_parallelized(transformer, parallelism_config, sp_group, tp_group)

    # cache-dit itself is given no parallelism config; the transformer is only
    # tagged above and the groups are attached to the context manager below.
    cache_dit.enable_cache(
        transformer,
        cache_config=cache_config,
        calibrator_config=calibrator_config,
        parallelism_config=None,
    )

    if parallelism_config is not None:
        context_manager = getattr(transformer, "_context_manager", None)
        if context_manager is not None:
            context_manager._sglang_sp_group = sp_group
            context_manager._sglang_tp_group = tp_group
            # In mixed TP + SP (Ulysses/Ring) mode, cache-dit decisions must be consistent
            # across the full TP×SP model-parallel slice. Prefer using SGLang's DIT group
            # as a conservative superset group; fallback to None.
            tp_sp_group = None
            if sp_group is not None and tp_group is not None:
                # Mirror enable_cache_on_dual_transformer: a failed lookup must not
                # abort after cache-dit has already been enabled on the module.
                try:
                    tp_sp_group = get_dit_group()
                except Exception:
                    logger.warning(
                        "Could not resolve the DIT group for %s; "
                        "falling back to the SP/TP group for cache-dit similarity.",
                        model_name,
                    )
                    tp_sp_group = None

            context_manager._sglang_tp_sp_group = tp_sp_group

    return transformer
" + "Please provide it in CacheDitConfig." + ) + + # Build DBCacheConfig for primary transformer + primary_cache_config = DBCacheConfig( + num_inference_steps=primary_config.num_inference_steps, + Fn_compute_blocks=primary_config.Fn_compute_blocks, + Bn_compute_blocks=primary_config.Bn_compute_blocks, + max_warmup_steps=primary_config.max_warmup_steps, + residual_diff_threshold=primary_config.residual_diff_threshold, + max_continuous_cached_steps=primary_config.max_continuous_cached_steps, + steps_computation_mask=primary_config.steps_computation_mask, + steps_computation_policy=primary_config.steps_computation_policy, + ) + + # Build DBCacheConfig for secondary transformer + secondary_cache_config = DBCacheConfig( + num_inference_steps=secondary_config.num_inference_steps, + Fn_compute_blocks=secondary_config.Fn_compute_blocks, + Bn_compute_blocks=secondary_config.Bn_compute_blocks, + max_warmup_steps=secondary_config.max_warmup_steps, + residual_diff_threshold=secondary_config.residual_diff_threshold, + max_continuous_cached_steps=secondary_config.max_continuous_cached_steps, + steps_computation_mask=secondary_config.steps_computation_mask, + steps_computation_policy=secondary_config.steps_computation_policy, + ) + + # Build calibrator configs if TaylorSeer is enabled + primary_calibrator = None + if primary_config.enable_taylorseer: + primary_calibrator = TaylorSeerCalibratorConfig( + taylorseer_order=primary_config.taylorseer_order, + ) + + secondary_calibrator = None + if secondary_config.enable_taylorseer: + secondary_calibrator = TaylorSeerCalibratorConfig( + taylorseer_order=secondary_config.taylorseer_order, + ) + + # Build ParamsModifier for each transformer + primary_modifier = ParamsModifier( + cache_config=primary_cache_config, + calibrator_config=primary_calibrator, + ) + secondary_modifier = ParamsModifier( + cache_config=secondary_cache_config, + calibrator_config=secondary_calibrator, + ) + + # Log configuration + logger.info( + "Enabling cache-dit 
on %s dual transformers with BlockAdapter", + model_name, + ) + logger.info( + " Primary (transformer): Fn=%d, Bn=%d, W=%d, R=%.2f, MC=%d, TaylorSeer=%s", + primary_config.Fn_compute_blocks, + primary_config.Bn_compute_blocks, + primary_config.max_warmup_steps, + primary_config.residual_diff_threshold, + primary_config.max_continuous_cached_steps, + primary_config.enable_taylorseer, + ) + logger.info( + " Secondary (transformer_2): Fn=%d, Bn=%d, W=%d, R=%.2f, MC=%d, TaylorSeer=%s", + secondary_config.Fn_compute_blocks, + secondary_config.Bn_compute_blocks, + secondary_config.max_warmup_steps, + secondary_config.residual_diff_threshold, + secondary_config.max_continuous_cached_steps, + secondary_config.enable_taylorseer, + ) + + # Log SCM configuration if enabled + if primary_config.steps_computation_mask: + compute_steps = sum(primary_config.steps_computation_mask) + cache_steps = len(primary_config.steps_computation_mask) - compute_steps + logger.info( + " SCM enabled for primary transformer: %d compute steps, %d cache steps, policy=%s", + compute_steps, + cache_steps, + primary_config.steps_computation_policy, + ) + if secondary_config.steps_computation_mask: + compute_steps = sum(secondary_config.steps_computation_mask) + cache_steps = len(secondary_config.steps_computation_mask) - compute_steps + logger.info( + " SCM enabled for secondary transformer: %d compute steps, %d cache steps, policy=%s", + compute_steps, + cache_steps, + secondary_config.steps_computation_policy, + ) + + parallelism_config = _build_parallelism_config(sp_group, tp_group) + if parallelism_config is not None: + _patch_cache_dit_similarity() + + _mark_transformer_parallelized(transformer, parallelism_config, sp_group, tp_group) + _mark_transformer_parallelized( + transformer_2, parallelism_config, sp_group, tp_group + ) + + # Get blocks attribute - Wan transformers use 'blocks' attribute + transformer_blocks = getattr(transformer, "blocks", None) + transformer_2_blocks = 
getattr(transformer_2, "blocks", None) + + if transformer_blocks is None or transformer_2_blocks is None: + raise ValueError( + "Dual transformers must have 'blocks' attribute for cache-dit. " + f"transformer has blocks: {transformer_blocks is not None}, " + f"transformer_2 has blocks: {transformer_2_blocks is not None}" + ) + + # Enable cache-dit using BlockAdapter for both transformers simultaneously + # This is required for Wan2.2 and similar dual-transformer architectures + if model_name == "wan2.2": + # Use Pattern_2 for Wan2.2 dual-transformer. We should check `model_name` + # to ensure we only apply this for supported models. Different models + # may require different ForwardPattern. + cache_dit.enable_cache( + BlockAdapter( + transformer=[transformer, transformer_2], + blocks=[transformer_blocks, transformer_2_blocks], + forward_pattern=[ForwardPattern.Pattern_2, ForwardPattern.Pattern_2], + params_modifiers=[primary_modifier, secondary_modifier], + has_separate_cfg=True, + ), + parallelism_config=None, + ) + else: + raise ValueError( + f"Dual-transformer is not implemented for model {model_name} yet." 
def refresh_context_on_transformer(
    transformer: torch.nn.Module,
    num_inference_steps: int,
    scm_preset: str | None = None,
    verbose: bool = False,
) -> None:
    """Refresh cache-dit context for transformer.

    Rebuilds the cache configuration for a new step count and hands it to
    ``cache_dit.refresh_context``.

    Args:
        transformer: Module whose cache-dit context is refreshed.
        num_inference_steps: Step count for the upcoming run; also used as
            ``total_steps`` when generating the SCM mask.
        scm_preset: SCM preset name forwarded to ``cache_dit.steps_mask``.
        verbose: Forwarded to ``cache_dit.refresh_context``.
    """
    cache_dit.refresh_context(
        transformer,
        cache_config=DBCacheConfig().reset(
            num_inference_steps=num_inference_steps,
            # NOTE(review): scm_preset defaults to None and is passed straight
            # through as mask_policy — confirm steps_mask() accepts None.
            steps_computation_mask=cache_dit.steps_mask(
                mask_policy=scm_preset, total_steps=num_inference_steps
            ),
            # NOTE(review): the preset name is reused as the computation policy,
            # while CacheDitConfig documents the policy as "dynamic" or
            # "static" — confirm this is intended.
            steps_computation_policy=scm_preset,
        ),
        verbose=verbose,
    )
    logger.debug(f"cache-dit refreshed on transformer (steps={num_inference_steps})")
+ verbose=verbose, + ) + logger.debug( + f"cache-dit refreshed on dual transformers (steps={num_high_noise_steps}, {num_low_noise_steps})" + ) diff --git a/sglang/python/sglang/multimodal_gen/runtime/cache/teacache.py b/sglang/python/sglang/multimodal_gen/runtime/cache/teacache.py new file mode 100644 index 0000000000000000000000000000000000000000..5cdafd08bc0449a9999f14c2c1a304e8e7d88978 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/cache/teacache.py @@ -0,0 +1,316 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +TeaCache: Temporal similarity-based caching for diffusion models. + +TeaCache accelerates diffusion inference by selectively skipping redundant +computation when consecutive diffusion steps are similar enough. This is +achieved by tracking the L1 distance between modulated inputs across timesteps. + +Key concepts: +- Modulated input: The input to transformer blocks after timestep conditioning +- L1 distance: Measures how different consecutive timesteps are +- Threshold: When accumulated L1 distance exceeds threshold, force computation +- CFG support: Separate caches for positive and negative branches + +References: +- TeaCache: Accelerating Diffusion Models with Temporal Similarity + https://arxiv.org/abs/2411.14324 +""" + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +import numpy as np +import torch + +from sglang.multimodal_gen.configs.models import DiTConfig + +if TYPE_CHECKING: + from sglang.multimodal_gen.configs.sample.teacache import TeaCacheParams + + +@dataclass +class TeaCacheContext: + """Common context extracted for TeaCache skip decision. + + This context is populated from the forward_batch and forward_context + during each denoising step, providing all information needed to make + cache decisions. + + Attributes: + current_timestep: Current denoising timestep index (0-indexed). + num_inference_steps: Total number of inference steps. + do_cfg: Whether classifier-free guidance is enabled. 
+ is_cfg_negative: True if currently processing negative CFG branch. + teacache_thresh: Threshold for accumulated L1 distance. + coefficients: Polynomial coefficients for L1 rescaling. + teacache_params: Full TeaCacheParams for model-specific access. + """ + + current_timestep: int + num_inference_steps: int + do_cfg: bool + is_cfg_negative: bool # For CFG branch selection + teacache_thresh: float + coefficients: list[float] + teacache_params: "TeaCacheParams" # Full params for model-specific access + + +class TeaCacheMixin: + """ + Mixin class providing TeaCache optimization functionality. + + TeaCache accelerates diffusion inference by selectively skipping redundant + computation when consecutive diffusion steps are similar enough. + + This mixin should be inherited by DiT model classes that want to support + TeaCache optimization. It provides: + - State management for tracking L1 distances + - CFG-aware caching (separate caches for positive/negative branches) + - Decision logic for when to compute vs. use cache + + Example usage in a DiT model: + class MyDiT(TeaCacheMixin, BaseDiT): + def __init__(self, config, **kwargs): + super().__init__(config, **kwargs) + self._init_teacache_state() + + def forward(self, hidden_states, timestep, ...): + ctx = self._get_teacache_context() + if ctx is not None: + # Compute modulated input (model-specific, e.g., after timestep embedding) + modulated_input = self._compute_modulated_input(hidden_states, timestep) + is_boundary = (ctx.current_timestep == 0 or + ctx.current_timestep >= ctx.num_inference_steps - 1) + + should_calc = self._compute_teacache_decision( + modulated_inp=modulated_input, + is_boundary_step=is_boundary, + coefficients=ctx.coefficients, + teacache_thresh=ctx.teacache_thresh, + ) + + if not should_calc: + # Use cached residual (must implement retrieve_cached_states) + return self.retrieve_cached_states(hidden_states) + + # Normal forward pass... 
+ output = self._transformer_forward(hidden_states, timestep, ...) + + # Cache states for next step + if ctx is not None: + self.maybe_cache_states(output, hidden_states) + + return output + + Subclass implementation notes: + - `_compute_modulated_input()`: Model-specific method to compute the input + after timestep conditioning (used for L1 distance calculation) + - `retrieve_cached_states()`: Must be overridden to return cached output + - `maybe_cache_states()`: Override to store states for cache retrieval + + Attributes: + cnt: Counter for tracking steps. + enable_teacache: Whether TeaCache is enabled. + previous_modulated_input: Cached modulated input for positive branch. + previous_residual: Cached residual for positive branch. + accumulated_rel_l1_distance: Accumulated L1 distance for positive branch. + is_cfg_negative: Whether currently processing negative CFG branch. + _supports_cfg_cache: Whether this model supports CFG cache separation. + + CFG-specific attributes (only when _supports_cfg_cache is True): + previous_modulated_input_negative: Cached input for negative branch. + previous_residual_negative: Cached residual for negative branch. + accumulated_rel_l1_distance_negative: L1 distance for negative branch. + """ + + # Models that support CFG cache separation (wan/hunyuan/zimage) + # Models not in this set (flux/qwen) auto-disable TeaCache when CFG is enabled + _CFG_SUPPORTED_PREFIXES: set[str] = {"wan", "hunyuan", "zimage"} + config: DiTConfig + + def _init_teacache_state(self) -> None: + """Initialize TeaCache state. 
Call this in subclass __init__.""" + # Common TeaCache state + self.cnt = 0 + self.enable_teacache = True + # Flag indicating if this model supports CFG cache separation + self._supports_cfg_cache = ( + self.config.prefix.lower() in self._CFG_SUPPORTED_PREFIXES + ) + + # Always initialize positive cache fields (used in all modes) + self.previous_modulated_input: torch.Tensor | None = None + self.previous_residual: torch.Tensor | None = None + self.accumulated_rel_l1_distance: float = 0.0 + + self.is_cfg_negative = False + # CFG-specific fields initialized to None (created when CFG is used) + # These are only used when _supports_cfg_cache is True AND do_cfg is True + if self._supports_cfg_cache: + self.previous_modulated_input_negative: torch.Tensor | None = None + self.previous_residual_negative: torch.Tensor | None = None + self.accumulated_rel_l1_distance_negative: float = 0.0 + + def reset_teacache_state(self) -> None: + """Reset all TeaCache state at the start of each generation task.""" + self.cnt = 0 + + # Primary cache fields (always present) + self.previous_modulated_input = None + self.previous_residual = None + self.accumulated_rel_l1_distance = 0.0 + self.is_cfg_negative = False + self.enable_teacache = True + # CFG negative cache fields (always reset, may be unused) + if self._supports_cfg_cache: + self.previous_modulated_input_negative = None + self.previous_residual_negative = None + self.accumulated_rel_l1_distance_negative = 0.0 + + def _compute_l1_and_decide( + self, + modulated_inp: torch.Tensor, + coefficients: list[float], + teacache_thresh: float, + ) -> tuple[float, bool]: + """ + Compute L1 distance and decide whether to calculate or use cache. + + Args: + modulated_inp: Current timestep's modulated input. + coefficients: Polynomial coefficients for L1 rescaling. + teacache_thresh: Threshold for cache decision. + + Returns: + Tuple of (new_accumulated_distance, should_calc). 
+ """ + prev_modulated_inp = ( + self.previous_modulated_input_negative + if self.is_cfg_negative + else self.previous_modulated_input + ) + + # Defensive check: if previous input is not set, force calculation + if prev_modulated_inp is None: + return 0.0, True + + # Compute relative L1 distance + diff = modulated_inp - prev_modulated_inp + rel_l1 = (diff.abs().mean() / prev_modulated_inp.abs().mean()).cpu().item() + + # Apply polynomial rescaling + rescale_func = np.poly1d(coefficients) + + accumulated_rel_l1_distance = ( + self.accumulated_rel_l1_distance_negative + if self.is_cfg_negative + else self.accumulated_rel_l1_distance + ) + accumulated_rel_l1_distance = accumulated_rel_l1_distance + rescale_func(rel_l1) + + if accumulated_rel_l1_distance >= teacache_thresh: + # Threshold exceeded: force compute and reset accumulator + return 0.0, True + # Cache hit: keep accumulated distance + return accumulated_rel_l1_distance, False + + def _compute_teacache_decision( + self, + modulated_inp: torch.Tensor, + is_boundary_step: bool, + coefficients: list[float], + teacache_thresh: float, + ) -> bool: + """ + Compute cache decision for TeaCache. + + Args: + modulated_inp: Current timestep's modulated input. + is_boundary_step: True for boundary timesteps that always compute. + coefficients: Polynomial coefficients for L1 rescaling. + teacache_thresh: Threshold for cache decision. + + Returns: + True if forward computation is needed, False to use cache. 
def _get_teacache_context(self) -> TeaCacheContext | None:
    """
    Gather the per-step values TeaCache needs from the active forward context.

    Returns:
        A TeaCacheContext when the current batch has TeaCache enabled and
        carries TeaCache parameters; None when TeaCache should be bypassed.
    """
    from sglang.multimodal_gen.runtime.managers.forward_context import (
        get_forward_context,
    )

    ctx = get_forward_context()
    batch = ctx.forward_batch

    # Bail out as soon as any precondition fails.
    if batch is None:
        return None
    if not batch.enable_teacache:
        return None
    if batch.teacache_params is None:
        return None

    params = batch.teacache_params

    step = ctx.current_timestep
    total_steps = batch.num_inference_steps
    uses_cfg = batch.do_classifier_free_guidance
    batch_is_negative = batch.is_cfg_negative

    # Start every generation from a clean cache.
    # NOTE(review): the guard reads the flag stored on `self`, not the one on
    # the current batch extracted just above — confirm this is intentional.
    is_first_step = step == 0
    if is_first_step and not self.is_cfg_negative:
        self.reset_teacache_state()

    return TeaCacheContext(
        current_timestep=step,
        num_inference_steps=total_steps,
        do_cfg=uses_cfg,
        is_cfg_negative=batch_is_negative,
        teacache_thresh=params.teacache_thresh,
        coefficients=params.coefficients,
        teacache_params=params,
    )
def should_skip_forward_for_cached_states(self, **kwargs: Any) -> bool:
    """Report whether the forward pass can be skipped in favour of cached states.

    This base implementation ignores every keyword argument and always
    answers False, i.e. it never skips the forward pass.

    Args:
        **kwargs: Arbitrary keyword arguments; unused here.

    Returns:
        Always False.
    """
    return False
"model_parallel_is_initialized", + "maybe_init_distributed_environment_and_model_parallel", + # World group + "get_world_group", + "get_world_rank", + "get_world_size", + # Data parallel group + "get_dp_group", + "get_dp_rank", + "get_dp_world_size", + # Sequence parallel group + "get_sp_group", + "get_sp_parallel_rank", + "get_sp_world_size", + # Tensor parallel group + "get_tp_group", + "get_tp_rank", + "get_tp_world_size", + # Get torch device + "get_local_torch_device", +] + + +def _get_folding_tp_group( + config: TextEncoderConfig, +) -> torch.distributed.ProcessGroup | None: + if config.parallel_folding: + if config.parallel_folding_mode == "sp": + return get_sp_group() + elif config.parallel_folding_mode == "ulysses": + return get_sp_group().ulysses_group + elif config.parallel_folding_mode == "ring": + return get_sp_group().ring_group + return get_tp_group() diff --git a/sglang/python/sglang/multimodal_gen/runtime/distributed/communication_op.py b/sglang/python/sglang/multimodal_gen/runtime/distributed/communication_op.py new file mode 100644 index 0000000000000000000000000000000000000000..2714d7ce5119a6283a1a6c568bd9486279aee690 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/distributed/communication_op.py @@ -0,0 +1,59 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/communication_op.py + +import torch +import torch.distributed as dist + +from sglang.multimodal_gen.runtime.distributed.parallel_state import ( + get_cfg_group, + get_sp_group, + get_tp_group, +) + + +def tensor_model_parallel_all_reduce( + input_: torch.Tensor, tp_group: dist.ProcessGroup = None +) -> torch.Tensor: + """All-reduce the input tensor across model parallel group.""" + tp_group = tp_group or get_tp_group() + return tp_group.all_reduce(input_) + + +def tensor_model_parallel_all_gather( + input_: torch.Tensor, dim: int = 
def cfg_model_parallel_all_gather(
    input_: torch.Tensor, dim: int = -1, separate_tensors: bool = False
) -> torch.Tensor:
    """All-gather the input tensor across the CFG parallel group.

    Thin wrapper that forwards to ``get_cfg_group().all_gather``.

    Args:
        input_: Tensor contributed by this rank.
        dim: Dimension passed to the group's ``all_gather`` (default -1).
        separate_tensors: Forwarded unchanged to the group's ``all_gather``.
            NOTE(review): presumably selects per-rank tensors instead of one
            concatenated tensor — confirm against the group implementation.

    Returns:
        Whatever the CFG group's ``all_gather`` returns.
    """
    return get_cfg_group().all_gather(input_, dim, separate_tensors)
b/sglang/python/sglang/multimodal_gen/runtime/distributed/device_communicators/base_device_communicator.py new file mode 100644 index 0000000000000000000000000000000000000000..01bdf1c293e6fb24e2a5691b783773108f3b6346 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/distributed/device_communicators/base_device_communicator.py @@ -0,0 +1,297 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/device_communicators/base_device_communicator.py + +from typing import Any + +import torch +import torch.distributed as dist +from torch import Tensor +from torch.distributed import ProcessGroup, ReduceOp + + +class DistributedAutograd: + """Collection of autograd functions for distributed operations. + + This class provides custom autograd functions for distributed operations like all_reduce, + all_gather, and all_to_all. Each operation is implemented as a static inner class with + proper forward and backward implementations. + """ + + class AllReduce(torch.autograd.Function): + """Differentiable all_reduce operation. + + The gradient of all_reduce is another all_reduce operation since the operation + combines values from all ranks equally. + """ + + @staticmethod + def forward( + ctx: Any, + group: ProcessGroup, + input_: Tensor, + op: dist.ReduceOp | None = None, + ) -> Tensor: + ctx.group = group + ctx.op = op + output = input_.clone() + dist.all_reduce(output, group=group, op=op) + return output + + @staticmethod + def backward(ctx: Any, grad_output: Tensor) -> tuple[None, Tensor, None]: + grad_output = grad_output.clone() + dist.all_reduce(grad_output, group=ctx.group, op=ctx.op) + return None, grad_output, None + + class AllGather(torch.autograd.Function): + """Differentiable all_gather operation. + + The operation gathers tensors from all ranks and concatenates them along a specified dimension. 
+ The backward pass uses reduce_scatter to efficiently distribute gradients back to source ranks. + """ + + @staticmethod + def forward( + ctx: Any, group: ProcessGroup, input_: Tensor, world_size: int, dim: int + ) -> Tensor: + ctx.group = group + ctx.world_size = world_size + ctx.dim = dim + ctx.input_shape = input_.shape + + input_size = input_.size() + output_size = (input_size[0] * world_size,) + input_size[1:] + output_tensor = torch.empty( + output_size, dtype=input_.dtype, device=input_.device + ) + + dist.all_gather_into_tensor(output_tensor, input_, group=group) + + output_tensor = output_tensor.reshape((world_size,) + input_size) + output_tensor = output_tensor.movedim(0, dim) + output_tensor = output_tensor.reshape( + input_size[:dim] + + (world_size * input_size[dim],) + + input_size[dim + 1 :] + ) + return output_tensor + + @staticmethod + def backward(ctx: Any, grad_output: Tensor) -> tuple[None, Tensor, None, None]: + # Split the gradient tensor along the gathered dimension + dim_size = grad_output.size(ctx.dim) // ctx.world_size + grad_chunks = grad_output.reshape( + grad_output.shape[: ctx.dim] + + (ctx.world_size, dim_size) + + grad_output.shape[ctx.dim + 1 :] + ) + grad_chunks = grad_chunks.movedim(ctx.dim, 0) + + # Each rank only needs its corresponding gradient + grad_input = torch.empty( + ctx.input_shape, dtype=grad_output.dtype, device=grad_output.device + ) + dist.reduce_scatter_tensor( + grad_input, grad_chunks.contiguous(), group=ctx.group + ) + + return None, grad_input, None, None + + class AllToAll4D(torch.autograd.Function): + """Differentiable all_to_all operation specialized for 4D tensors. + + This operation is particularly useful for attention operations where we need to + redistribute data across ranks for efficient parallel processing. + + The operation supports two modes: + 1. scatter_dim=2, gather_dim=1: Used for redistributing attention heads + 2. 
scatter_dim=1, gather_dim=2: Used for redistributing sequence dimensions + """ + + @staticmethod + def forward( + ctx: Any, + group: ProcessGroup, + input_: Tensor, + world_size: int, + scatter_dim: int, + gather_dim: int, + ) -> Tensor: + ctx.group = group + ctx.world_size = world_size + ctx.scatter_dim = scatter_dim + ctx.gather_dim = gather_dim + + if world_size == 1: + return input_ + + assert ( + input_.dim() == 4 + ), f"input must be 4D tensor, got {input_.dim()} and shape {input_.shape}" + + if scatter_dim == 2 and gather_dim == 1: + bs, shard_seqlen, hn, hd = input_.shape + seqlen = shard_seqlen * world_size + shard_hn = hn // world_size + + input_ = input_.transpose(0, 2).contiguous() # hn, shard_seqlen, bs, hd + output = torch.empty_like(input_) + + dist.all_to_all_single( + output, input_, group=group + ) # hn, shard_seqlen, bs, hd + + output = torch.cat( + output.split(shard_hn), dim=1 + ) # sharded hn, seqlen, bs, hd + + output = output.transpose( + 0, 2 + ).contiguous() # bs, seqlen, sharded_hn, hd + + return output + elif scatter_dim == 1 and gather_dim == 2: + bs, seqlen, shard_hn, hd = input_.shape + hn = shard_hn * world_size + shard_seqlen = seqlen // world_size + + input_ = input_.transpose(0, 2).contiguous() # shard_hn, seqlen, bs, hd + + input_ = ( + input_.reshape(shard_hn, world_size, shard_seqlen, bs, hd) + .transpose(0, 1) + .reshape(shard_hn * world_size, shard_seqlen, bs, hd) + .contiguous() + ) + + output = torch.empty_like(input_) + + dist.all_to_all_single(output, input_, group=group) + + output = output.transpose( + 0, 2 + ).contiguous() # bs, seqlen, sharded_hn, hd + + return output + else: + raise RuntimeError( + f"Invalid scatter_dim={scatter_dim}, gather_dim={gather_dim}. " + f"Only (scatter_dim=2, gather_dim=1) and (scatter_dim=1, gather_dim=2) are supported." 
+ ) + + @staticmethod + def backward( + ctx: Any, grad_output: Tensor + ) -> tuple[None, Tensor, None, None, None]: + if ctx.world_size == 1: + return None, grad_output, None, None, None + + # For backward pass, we swap scatter_dim and gather_dim + output = DistributedAutograd.AllToAll4D.apply( + ctx.group, grad_output, ctx.world_size, ctx.gather_dim, ctx.scatter_dim + ) + return None, output, None, None, None + + +class DeviceCommunicatorBase: + """ + Base class for device-specific communicator with autograd support. + It can use the `cpu_group` to initialize the communicator. + If the device has PyTorch integration (PyTorch can recognize its + communication backend), the `device_group` will also be given. + """ + + def __init__( + self, + cpu_group: ProcessGroup, + device: torch.device | None = None, + device_group: ProcessGroup | None = None, + unique_name: str = "", + ): + self.device = device or torch.device("cpu") + self.cpu_group = cpu_group + self.device_group = device_group + self.unique_name = unique_name + self.rank = dist.get_rank(cpu_group) + self.world_size = dist.get_world_size(cpu_group) + self.ranks = dist.get_process_group_ranks(cpu_group) + self.global_rank = dist.get_rank() + self.global_world_size = dist.get_world_size() + self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank) + + def all_reduce( + self, input_: torch.Tensor, op: dist.ReduceOp | None = ReduceOp.SUM + ) -> torch.Tensor: + """Performs an all_reduce operation with gradient support.""" + return DistributedAutograd.AllReduce.apply(self.device_group, input_, op) + + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + """Performs an all_gather operation with gradient support.""" + if dim < 0: + dim += input_.dim() + return DistributedAutograd.AllGather.apply( + self.device_group, input_, self.world_size, dim + ) + + def all_to_all_4D( + self, input_: torch.Tensor, scatter_dim: int = 2, gather_dim: int = 1 + ) -> torch.Tensor: + """Performs 
a 4D all-to-all operation with gradient support.""" + return DistributedAutograd.AllToAll4D.apply( + self.device_group, input_, self.world_size, scatter_dim, gather_dim + ) + + def gather( + self, input_: torch.Tensor, dst: int = 0, dim: int = -1 + ) -> torch.Tensor | None: + """ + NOTE: We assume that the input tensor is on the same device across + all the ranks. + NOTE: `dst` is the local rank of the destination rank. + """ + world_size = self.world_size + assert ( + -input_.dim() <= dim < input_.dim() + ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + + # Allocate output tensor. + if self.rank_in_group == dst: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + else: + gather_list = None + # Gather. + torch.distributed.gather( + input_, gather_list, dst=self.ranks[dst], group=self.device_group + ) + if self.rank_in_group == dst: + output_tensor = torch.cat(gather_list, dim=dim) + else: + output_tensor = None + return output_tensor + + def send(self, tensor: torch.Tensor, dst: int | None = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """NOTE: `dst` is the local rank of the destination rank.""" + if dst is None: + dst = (self.rank_in_group + 1) % self.world_size + torch.distributed.send(tensor, self.ranks[dst], self.device_group) + + def recv( + self, size: torch.Size, dtype: torch.dtype, src: int | None = None + ) -> torch.Tensor: + """Receives a tensor from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + if src is None: + src = (self.rank_in_group - 1) % self.world_size + + tensor = torch.empty(size, dtype=dtype, device=self.device) + torch.distributed.recv(tensor, self.ranks[src], self.device_group) + return tensor + + def destroy(self) -> None: + pass diff --git a/sglang/python/sglang/multimodal_gen/runtime/distributed/device_communicators/cpu_communicator.py 
b/sglang/python/sglang/multimodal_gen/runtime/distributed/device_communicators/cpu_communicator.py new file mode 100644 index 0000000000000000000000000000000000000000..434cf384de73e1deea164e94b62d4ace0b8c88e7 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/distributed/device_communicators/cpu_communicator.py @@ -0,0 +1,161 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from: https://github.com/vllm-project/vllm/blob/main/vllm/distributed/device_communicators/cpu_communicator.py + +import os + +import torch +from torch.distributed import ProcessGroup + +from .base_device_communicator import DeviceCommunicatorBase + + +class CpuCommunicator(DeviceCommunicatorBase): + + def __init__( + self, + cpu_group: ProcessGroup, + device: torch.device | None = None, + device_group: ProcessGroup | None = None, + unique_name: str = "", + ): + from sglang.multimodal_gen.runtime.platforms import current_platform + from sglang.multimodal_gen.runtime.platforms.interface import CpuArchEnum + + super().__init__(cpu_group, device, device_group, unique_name) + self.dist_module = torch.distributed + + if ( + (current_platform.get_cpu_architecture() == CpuArchEnum.X86) + and hasattr(torch.ops._C, "init_shm_manager") + and unique_name.startswith("tp") + ): + self.dist_module = _CPUSHMDistributed(self) + + def all_reduce( + self, + input_: torch.Tensor, + op: torch.distributed.ReduceOp | None = torch.distributed.ReduceOp.SUM, + ) -> torch.Tensor: + self.dist_module.all_reduce(input_, group=self.device_group, op=op) + return input_ + + def gather( + self, input_: torch.Tensor, dst: int = 0, dim: int = -1 + ) -> torch.Tensor | None: + """ + NOTE: We assume that the input tensor is on the same device across + all the ranks. + NOTE: `dst` is the local rank of the destination rank. 
+ """ + world_size = self.world_size + assert ( + -input_.dim() <= dim < input_.dim() + ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + + # Allocate output tensor. + if self.rank_in_group == dst: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + else: + gather_list = None + + # Gather. + self.dist_module.gather( + input_, gather_list, dst=self.ranks[dst], group=self.device_group + ) + + if self.rank_in_group == dst: + output_tensor = torch.cat(gather_list, dim=dim) + else: + output_tensor = None + return output_tensor + + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + input_size = input_.size() + # NOTE: we have to use concat-style all-gather here, + # stack-style all-gather has compatibility issues with + # torch.compile . see https://github.com/pytorch/pytorch/issues/138795 + output_size = (input_size[0] * self.world_size,) + input_size[1:] + # Allocate output tensor. + output_tensor = torch.empty( + output_size, dtype=input_.dtype, device=input_.device + ) + # All-gather. 
+ self.dist_module.all_gather_into_tensor( + output_tensor, input_, group=self.device_group + ) + + # Reshape + output_tensor = output_tensor.reshape((self.world_size,) + input_size) + output_tensor = output_tensor.movedim(0, dim) + output_tensor = output_tensor.reshape( + input_size[:dim] + + (self.world_size * input_size[dim],) + + input_size[dim + 1 :] + ) + return output_tensor + + +class _CPUSHMDistributed: + + def __init__(self, communicator: CpuCommunicator): + instance_identifier = os.environ["VLLM_DIST_IDENT"] + unique_name = communicator.unique_name + instance_identifier = f"{instance_identifier}-{unique_name}" + self.communicator = communicator + + group_ranks = [str(rank) for rank in self.communicator.ranks] + shm_group_identifier = f"[{'-'.join(group_ranks)}]" + self.group_name = f"{instance_identifier}-{shm_group_identifier}-cpushm" + + self.handle = self._init_cpu_shm() + + def _init_cpu_shm(self) -> int: + handle = torch.ops._C.init_shm_manager( + self.group_name, + self.communicator.world_size, + self.communicator.rank, + ) + torch.distributed.barrier(self.communicator.device_group) + torch.ops._C.join_shm_manager( + handle, + self.group_name, + ) + torch.distributed.barrier(self.communicator.device_group) + + return int(handle) + + def all_reduce( + self, input: torch.Tensor, group: ProcessGroup | None = None + ) -> None: + torch.ops._C.shm_allreduce(self.handle, input) + + def gather( + self, + input: torch.Tensor, + gather_list: list[torch.Tensor] | None, + dst: int = -1, + group: ProcessGroup | None = None, + ) -> None: + # Note: different from the torch gather, here we use local dst rank. 
+ torch.ops._C.shm_gather( + self.handle, + input, + gather_list, + torch.distributed.get_group_rank(group, dst), + ) + + def all_gather_into_tensor( + self, + output: torch.Tensor, + input: torch.Tensor, + group: ProcessGroup | None = None, + ) -> None: + torch.ops._C.shm_all_gather(self.handle, input, output) diff --git a/sglang/python/sglang/multimodal_gen/runtime/distributed/device_communicators/cuda_communicator.py b/sglang/python/sglang/multimodal_gen/runtime/distributed/device_communicators/cuda_communicator.py new file mode 100644 index 0000000000000000000000000000000000000000..c128c69fce13cbb6b87ac32742c91b452a4c9303 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/distributed/device_communicators/cuda_communicator.py @@ -0,0 +1,79 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/device_communicators/cuda_communicator.py + +import torch +from torch.distributed import ProcessGroup + +from sglang.multimodal_gen.runtime.distributed.device_communicators.base_device_communicator import ( + DeviceCommunicatorBase, +) + + +class CudaCommunicator(DeviceCommunicatorBase): + + def __init__( + self, + cpu_group: ProcessGroup, + device: torch.device | None = None, + device_group: ProcessGroup | None = None, + unique_name: str = "", + ): + super().__init__(cpu_group, device, device_group, unique_name) + + from sglang.multimodal_gen.runtime.distributed.device_communicators.pynccl import ( + PyNcclCommunicator, + ) + + self.pynccl_comm: PyNcclCommunicator | None = None + if self.world_size > 1: + self.pynccl_comm = PyNcclCommunicator( + group=self.cpu_group, + device=self.device, + ) + + def all_reduce(self, input_, op: torch.distributed.ReduceOp | None = None): + pynccl_comm = self.pynccl_comm + assert pynccl_comm is not None + out = pynccl_comm.all_reduce(input_, op=op) + if out is None: + # fall back to the 
default all-reduce using PyTorch. + # this usually happens during testing. + # when we run the model, allreduce only happens for the TP + # group, where we always have either custom allreduce or pynccl. + out = input_.clone() + torch.distributed.all_reduce(out, group=self.device_group, op=op) + return out + + def send(self, tensor: torch.Tensor, dst: int | None = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """NOTE: `dst` is the local rank of the destination rank.""" + if dst is None: + dst = (self.rank_in_group + 1) % self.world_size + + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and not pynccl_comm.disabled: + pynccl_comm.send(tensor, dst) + else: + torch.distributed.send(tensor, self.ranks[dst], self.device_group) + + def recv( + self, size: torch.Size, dtype: torch.dtype, src: int | None = None + ) -> torch.Tensor: + """Receives a tensor from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + if src is None: + src = (self.rank_in_group - 1) % self.world_size + + tensor = torch.empty(size, dtype=dtype, device=self.device) + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and not pynccl_comm.disabled: + pynccl_comm.recv(tensor, src) + else: + torch.distributed.recv(tensor, self.ranks[src], self.device_group) + return tensor + + def destroy(self) -> None: + if self.pynccl_comm is not None: + self.pynccl_comm = None diff --git a/sglang/python/sglang/multimodal_gen/runtime/distributed/device_communicators/pynccl.py b/sglang/python/sglang/multimodal_gen/runtime/distributed/device_communicators/pynccl.py new file mode 100644 index 0000000000000000000000000000000000000000..2d1ef558ad1249790e5df9d9c6a2f070725f08c5 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/distributed/device_communicators/pynccl.py @@ -0,0 +1,258 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from 
https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/device_communicators/pynccl.py + +# ===================== import region ===================== +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup, ReduceOp + +from sglang.multimodal_gen.runtime.distributed.device_communicators.pynccl_wrapper import ( + NCCLLibrary, + buffer_type, + cudaStream_t, + ncclComm_t, + ncclDataTypeEnum, + ncclRedOpTypeEnum, + ncclUniqueId, +) +from sglang.multimodal_gen.runtime.distributed.utils import StatelessProcessGroup +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.utils import current_stream + +logger = init_logger(__name__) + + +class PyNcclCommunicator: + + def __init__( + self, + group: ProcessGroup | StatelessProcessGroup, + device: int | str | torch.device, + library_path: str | None = None, + ): + """ + Args: + group: the process group to work on. If None, it will use the + default process group. + device: the device to bind the PyNcclCommunicator to. If None, + it will be bind to f"cuda:{local_rank}". + library_path: the path to the NCCL library. If None, it will + use the default library path. + It is the caller's responsibility to make sure each communicator + is bind to a unique device. + """ + if not isinstance(group, StatelessProcessGroup): + assert dist.is_initialized() + assert ( + dist.get_backend(group) != dist.Backend.NCCL + ), "PyNcclCommunicator should be attached to a non-NCCL group." 
+ # note: this rank is the rank in the group + self.rank = dist.get_rank(group) + self.world_size = dist.get_world_size(group) + else: + self.rank = group.rank + self.world_size = group.world_size + + self.group = group + + # if world_size == 1, no need to create communicator + if self.world_size == 1: + self.available = False + self.disabled = True + return + try: + self.nccl = NCCLLibrary(library_path) + except Exception: + # disable because of missing NCCL library + # e.g. in a non-GPU environment + self.available = False + self.disabled = True + return + + self.available = True + self.disabled = False + + logger.info("sglang-diffusion is using nccl==%s", self.nccl.ncclGetVersion()) + + if self.rank == 0: + # get the unique id from NCCL + self.unique_id = self.nccl.ncclGetUniqueId() + else: + # construct an empty unique id + self.unique_id = ncclUniqueId() + + if not isinstance(group, StatelessProcessGroup): + tensor = torch.ByteTensor(list(self.unique_id.internal)) + ranks = dist.get_process_group_ranks(group) + # arg `src` in `broadcast` is the global rank + dist.broadcast(tensor, src=ranks[0], group=group) + byte_list = tensor.tolist() + for i, byte in enumerate(byte_list): + self.unique_id.internal[i] = byte + else: + self.unique_id = group.broadcast_obj(self.unique_id, src=0) + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + # now `device` is a `torch.device` object + assert isinstance(device, torch.device) + self.device = device + # nccl communicator and stream will use this device + # `torch.cuda.device` is a context manager that changes the + # current cuda device to the specified one + with torch.cuda.device(device): + self.comm: ncclComm_t = self.nccl.ncclCommInitRank( + self.world_size, self.unique_id, self.rank + ) + + stream = current_stream() + # A small all_reduce for warmup. 
+ data = torch.zeros(1, device=device) + self.all_reduce(data) + if stream is not None: + stream.synchronize() + del data + + def all_reduce( + self, in_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None + ) -> torch.Tensor: + if self.disabled: + return None + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert in_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {in_tensor.device}" + ) + + out_tensor = torch.empty_like(in_tensor) + + if stream is None: + stream = current_stream() + self.nccl.ncclAllReduce( + buffer_type(in_tensor.data_ptr()), + buffer_type(out_tensor.data_ptr()), + in_tensor.numel(), + ncclDataTypeEnum.from_torch(in_tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), + self.comm, + cudaStream_t(stream.cuda_stream), + ) + return out_tensor + + def all_gather( + self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None + ): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert input_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {input_tensor.device}" + ) + if stream is None: + stream = current_stream() + self.nccl.ncclAllGather( + buffer_type(input_tensor.data_ptr()), + buffer_type(output_tensor.data_ptr()), + input_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), + self.comm, + cudaStream_t(stream.cuda_stream), + ) + + def reduce_scatter( + self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + op: ReduceOp = ReduceOp.SUM, + stream=None, + ): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will 
def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
    """Broadcast `tensor` from rank `src` to every rank via ncclBroadcast.

    Does nothing when the communicator is disabled. On the source rank the
    tensor's memory is passed as both send and receive buffer; on all other
    ranks the send buffer is an empty `buffer_type()` and the tensor is the
    receive buffer.
    """
    if self.disabled:
        return
    assert tensor.device == self.device, (
        f"this nccl communicator is created to work on {self.device}, "
        f"but the input tensor is on {tensor.device}"
    )
    nccl_stream = current_stream() if stream is None else stream

    # NCCL requires the sender also to have a receive buffer, so only the
    # send side differs between source and non-source ranks.
    i_am_source = src == self.rank
    send_ptr = buffer_type(tensor.data_ptr()) if i_am_source else buffer_type()
    recv_ptr = buffer_type(tensor.data_ptr())

    self.nccl.ncclBroadcast(
        send_ptr,
        recv_ptr,
        tensor.numel(),
        ncclDataTypeEnum.from_torch(tensor.dtype),
        src,
        self.comm,
        cudaStream_t(nccl_stream.cuda_stream),
    )
class ncclDataTypeEnum:
    """Integer codes used for ``ncclDataType_t`` plus a torch dtype translator.

    Several names are aliases that share one code (e.g. ``ncclInt8`` and
    ``ncclChar``).
    """

    ncclInt8 = 0
    ncclChar = 0
    ncclUint8 = 1
    ncclInt32 = 2
    ncclInt = 2
    ncclUint32 = 3
    ncclInt64 = 4
    ncclUint64 = 5
    ncclFloat16 = 6
    ncclHalf = 6
    ncclFloat32 = 7
    ncclFloat = 7
    ncclFloat64 = 8
    ncclDouble = 8
    ncclBfloat16 = 9
    ncclNumTypes = 10

    @classmethod
    def from_torch(cls, dtype: torch.dtype) -> int:
        """Return the NCCL code matching ``dtype``.

        Raises:
            ValueError: If ``dtype`` is not one of the supported torch dtypes.
        """
        # Scanned in order with ``==`` so the lookup never needs to hash the
        # argument.
        supported = (
            (torch.int8, cls.ncclInt8),
            (torch.uint8, cls.ncclUint8),
            (torch.int32, cls.ncclInt32),
            (torch.int64, cls.ncclInt64),
            (torch.float16, cls.ncclFloat16),
            (torch.float32, cls.ncclFloat32),
            (torch.float64, cls.ncclFloat64),
            (torch.bfloat16, cls.ncclBfloat16),
        )
        for torch_dtype, nccl_code in supported:
            if dtype == torch_dtype:
                return nccl_code
        raise ValueError(f"Unsupported dtype: {dtype}")
ncclAvg = 4 + ncclNumOps = 5 + + @classmethod + def from_torch(cls, op: ReduceOp) -> int: + if op == ReduceOp.SUM: + return cls.ncclSum + if op == ReduceOp.PRODUCT: + return cls.ncclProd + if op == ReduceOp.MAX: + return cls.ncclMax + if op == ReduceOp.MIN: + return cls.ncclMin + if op == ReduceOp.AVG: + return cls.ncclAvg + raise ValueError(f"Unsupported op: {op}") + + +@dataclass +class Function: + name: str + restype: Any + argtypes: list[Any] + + +class NCCLLibrary: + exported_functions = [ + # const char* ncclGetErrorString(ncclResult_t result) + Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]), + # ncclResult_t ncclGetVersion(int *version); + Function("ncclGetVersion", ncclResult_t, [ctypes.POINTER(ctypes.c_int)]), + # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); + Function("ncclGetUniqueId", ncclResult_t, [ctypes.POINTER(ncclUniqueId)]), + # ncclResult_t ncclCommInitRank( + # ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); + # note that ncclComm_t is a pointer type, so the first argument + # is a pointer to a pointer + Function( + "ncclCommInitRank", + ncclResult_t, + [ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, ctypes.c_int], + ), + # ncclResult_t ncclAllReduce( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function( + "ncclAllReduce", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ncclRedOp_t, + ncclComm_t, + cudaStream_t, + ], + ), + # ncclResult_t ncclAllGather( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function( + "ncclAllGather", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + 
ncclDataType_t, + ncclComm_t, + cudaStream_t, + ], + ), + # ncclResult_t ncclReduceScatter( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function( + "ncclReduceScatter", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ncclRedOp_t, + ncclComm_t, + cudaStream_t, + ], + ), + # ncclResult_t ncclSend( + # const void* sendbuff, size_t count, ncclDataType_t datatype, + # int dest, ncclComm_t comm, cudaStream_t stream); + Function( + "ncclSend", + ncclResult_t, + [ + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ctypes.c_int, + ncclComm_t, + cudaStream_t, + ], + ), + # ncclResult_t ncclRecv( + # void* recvbuff, size_t count, ncclDataType_t datatype, + # int src, ncclComm_t comm, cudaStream_t stream); + Function( + "ncclRecv", + ncclResult_t, + [ + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ctypes.c_int, + ncclComm_t, + cudaStream_t, + ], + ), + # ncclResult_t ncclBroadcast( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, int root, ncclComm_t comm, + # cudaStream_t stream); + Function( + "ncclBroadcast", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ctypes.c_int, + ncclComm_t, + cudaStream_t, + ], + ), + # be cautious! this is a collective call, it will block until all + # processes in the communicator have called this function. + # because Python object destruction can happen in random order, + # it is better not to call it at all. 
def ncclGetVersion(self) -> str:
    """Return the loaded NCCL library version as ``"major.minor.patch"``.

    The integer reported by ``ncclGetVersion`` is decoded arithmetically
    instead of by slicing its decimal string. Slicing and stripping zeros
    produced an empty component for values such as 22000 ("2.20." instead of
    "2.20.0").

    Returns:
        The dotted version string, e.g. 21903 -> "2.19.3".

    Raises:
        RuntimeError: Propagated from ``NCCL_CHECK`` when the NCCL call
            reports a non-zero result.
    """
    version = ctypes.c_int()
    self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
    code = version.value

    # nccl.h encodes the version as X*10000 + Y*100 + Z, except for releases
    # up to 2.8, which use X*1000 + Y*100 + Z (so their code is below 10000).
    major_scale = 10000 if code >= 10000 else 1000
    major, remainder = divmod(code, major_scale)
    minor, patch = divmod(remainder, 100)
    return f"{major}.{minor}.{patch}"
which is an aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK( + self._funcs["ncclAllGather"]( + sendbuff, recvbuff, count, datatype, comm, stream + ) + ) + + def ncclSend( + self, + sendbuff: buffer_type, + count: int, + datatype: int, + dest: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + self.NCCL_CHECK( + self._funcs["ncclSend"](sendbuff, count, datatype, dest, comm, stream) + ) + + def ncclRecv( + self, + recvbuff: buffer_type, + count: int, + datatype: int, + src: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + self.NCCL_CHECK( + self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream) + ) + + def ncclBroadcast( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + root: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + self.NCCL_CHECK( + self._funcs["ncclBroadcast"]( + sendbuff, recvbuff, count, datatype, root, comm, stream + ) + ) + + def ncclCommDestroy(self, comm: ncclComm_t) -> None: + self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) + + +__all__ = [ + "NCCLLibrary", + "ncclDataTypeEnum", + "ncclRedOpTypeEnum", + "ncclUniqueId", + "ncclComm_t", + "cudaStream_t", + "buffer_type", +] diff --git a/sglang/python/sglang/multimodal_gen/runtime/distributed/group_coordinator.py b/sglang/python/sglang/multimodal_gen/runtime/distributed/group_coordinator.py new file mode 100644 index 0000000000000000000000000000000000000000..cabd056ff7b7932a75f01dec987bd5348ad09f43 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/distributed/group_coordinator.py @@ -0,0 +1,1222 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# Copyright 2024 xDiT team. +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/distributed/parallel_state.py +# Copyright 2023 The vLLM team. +# Copyright (c) 2022, NVIDIA CORPORATION. 
All rights reserved. +import pickle +from collections import namedtuple +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.distributed +from torch.cuda import synchronize +from torch.distributed import Backend, ProcessGroup + +from sglang.multimodal_gen.runtime.distributed.device_communicators.base_device_communicator import ( + DeviceCommunicatorBase, +) +from sglang.multimodal_gen.runtime.distributed.device_communicators.cpu_communicator import ( + CpuCommunicator, +) +from sglang.multimodal_gen.runtime.platforms import current_platform +from sglang.multimodal_gen.runtime.utils.logging_utils import ( + init_logger, + suppress_stdout, +) + +try: + import torch_musa # noqa: F401 + from torch_musa.core.device import synchronize +except ModuleNotFoundError: + pass + +logger = init_logger(__name__) + +TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) + + +_group_name_counter: dict[str, int] = {} + + +def get_local_torch_device() -> torch.device: + """Return the torch device for the current rank.""" + + return current_platform.get_local_torch_device() + + +def _get_unique_name(name: str) -> str: + """Get a unique name for the group. + Example: + _get_unique_name("tp") -> "tp:0" + _get_unique_name("tp") -> "tp:1" + """ + if name not in _group_name_counter: + _group_name_counter[name] = 0 + newname = f"{name}:{_group_name_counter[name]}" + _group_name_counter[name] += 1 + return newname + + +def _split_tensor_dict( + tensor_dict: Dict[str, Union[torch.Tensor, Any]], prefix: str = "" +) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: + """Split the tensor dictionary into two parts: + 1. A list of (key, value) pairs. If the value is a tensor, it is replaced + by its metadata. + 2. A list of tensors. + + If the Tensor is nested under `tensor_dict["key1"]["key2"]`, the key of its + metadata will be "key1%key2". 
+ """ + metadata_list: List[Tuple[str, Any]] = [] + tensor_list = [] + for key, value in tensor_dict.items(): + assert "%" not in key, ( + "Avoid having '%' in key " + "as it is used as a separator for nested entries." + ) + if isinstance(value, torch.Tensor): + # Note: we cannot use `value.device` here, + # because it contains not only the device type but also the device + # index (e.g. "cuda:0"). We only need the device type. + # receiving side will set the device index. + device = value.device.type + metadata_list.append( + ( + prefix + key, + TensorMetadata(device, value.dtype, value.size()), + ) + ) + tensor_list.append(value) + elif isinstance(value, dict): + if len(value) == 0: + metadata_list.append((prefix + key, value)) + inner_metadata_list, inner_tensor_list = _split_tensor_dict( + value, prefix + key + "%" + ) + metadata_list.extend(inner_metadata_list) + tensor_list.extend(inner_tensor_list) + else: + metadata_list.append((prefix + key, value)) + return metadata_list, tensor_list + + +def _update_nested_dict(nested_dict, flattened_key, value): + key_splits = flattened_key.split("%") + cur_dict = nested_dict + for k in key_splits[:-1]: + if k not in cur_dict: + cur_dict[k] = {} + cur_dict = cur_dict[k] + cur_dict[key_splits[-1]] = value + + +@dataclass +class GraphCaptureContext: + stream: torch.cuda.Stream | None + + +class GroupCoordinator: + """ + PyTorch ProcessGroup wrapper for a group of processes. + PyTorch ProcessGroup is bound to one specific communication backend, + e.g. NCCL, Gloo, MPI, etc. + GroupCoordinator takes charge of all the communication operations among + the processes in the group. It can route the communication to + a specific implementation (e.g. switch allreduce implementation + based on the tensor size and cuda graph mode). 
+ """ + + # available attributes: + rank: int # global rank + ranks: List[int] # global ranks in the group + world_size: int # size of the group + # difference between `local_rank` and `rank_in_group`: + # if we have a group of size 4 across two nodes: + # Process | Node | Rank | Local Rank | Rank in Group + # 0 | 0 | 0 | 0 | 0 + # 1 | 0 | 1 | 1 | 1 + # 2 | 1 | 2 | 0 | 2 + # 3 | 1 | 3 | 1 | 3 + local_rank: int # local rank in the current node, used to assign devices + rank_in_group: int # rank inside the group + cpu_group: ProcessGroup # group for CPU communication + device_group: ProcessGroup # group for device communication + use_device_communicator: bool # whether to use device communicator + device_communicator: DeviceCommunicatorBase # device communicator + + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + use_device_communicator: bool = True, + use_message_queue_broadcaster: bool = False, + group_name: str | None = None, + ): + self.unique_name = _get_unique_name(group_name) + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + self.device_group = None + self.cpu_group = None + + for ranks in group_ranks: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend + ) + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. 
+ with suppress_stdout(): + cpu_group = torch.distributed.new_group(ranks, backend="gloo") + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_group = device_group + self.cpu_group = cpu_group + + assert self.cpu_group is not None, f"{group_ranks=}, {local_rank=}" + assert self.device_group is not None + + # TODO: fix it for other platforms + self.device = get_local_torch_device() + + self.use_device_communicator = use_device_communicator + + self.device_communicator: DeviceCommunicatorBase = None # type: ignore + if use_device_communicator and self.world_size > 1: + # Platform-aware device communicator selection + if current_platform.is_cuda_alike(): + from sglang.multimodal_gen.runtime.distributed.device_communicators.cuda_communicator import ( + CudaCommunicator, + ) + + self.device_communicator = CudaCommunicator( + cpu_group=self.cpu_group, + device=self.device, + device_group=self.device_group, + unique_name=self.unique_name, + ) + else: + # For MPS and CPU, use the CPU communicator + self.device_communicator = CpuCommunicator( + cpu_group=self.cpu_group, + device=self.device, + device_group=self.device_group, + unique_name=self.unique_name, + ) + + self.mq_broadcaster = None + + # TODO(will): check if this is needed + # self.use_custom_op_call = current_platform.is_cuda_alike() + self.use_custom_op_call = False + + @property + def first_rank(self): + """Return the global rank of the first process in the group""" + return self.ranks[0] + + @property + def last_rank(self): + """Return the global rank of the last process in the group""" + return self.ranks[-1] + + @property + def is_first_rank(self): + """Return whether the caller is the first process in the group""" + return self.rank == self.first_rank + + @property + def is_last_rank(self): + """Return whether the caller is the last process in the group""" + return self.rank == self.last_rank + + @property + def 
next_rank(self): + """Return the global rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group + 1) % world_size] + + @property + def prev_rank(self): + """Return the global rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group - 1) % world_size] + + @property + def group_next_rank(self): + """Return the group rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (rank_in_group + 1) % world_size + + @property + def group_prev_rank(self): + """Return the group rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (rank_in_group - 1) % world_size + + @property + def skip_rank(self): + """Return the global rank of the process that skip connects with the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(world_size - rank_in_group - 1) % world_size] + + @property + def group_skip_rank(self): + """Return the group rank of the process that skip connects with the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (world_size - rank_in_group - 1) % world_size + + @contextmanager + def graph_capture(self, graph_capture_context: GraphCaptureContext | None = None): + if current_platform.is_cuda_alike(): + if graph_capture_context is None: + stream = torch.cuda.Stream() + graph_capture_context = GraphCaptureContext(stream) + else: + stream = graph_capture_context.stream + + # ensure all initialization operations complete before attempting to + # capture the graph on another stream + curr_stream = torch.cuda.current_stream() + if curr_stream != stream: + stream.wait_stream(curr_stream) + + with torch.cuda.stream(stream): + yield graph_capture_context + else: + # For 
def all_gather(
    self, input_: torch.Tensor, dim: int = 0, separate_tensors: bool = False
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Gather ``input_`` from every rank of the group.

    Args:
        input_: Tensor contributed by the calling rank.
        dim: Dimension along which the gathered tensors are concatenated
            when a single tensor is returned. Negative values are allowed.
        separate_tensors: When True, return one tensor per rank instead of
            a concatenated tensor.

    Returns:
        * ``input_`` itself when the group has a single rank (also when
          ``separate_tensors`` is True).
        * A list of ``world_size`` tensors shaped like ``input_``, in rank
          order, when ``separate_tensors`` is True.
        * Otherwise one tensor whose ``dim`` is ``world_size`` times larger.

    Raises:
        AssertionError: If ``dim`` is out of range for ``input_``.
    """
    world_size = self.world_size
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_
    assert (
        -input_.dim() <= dim < input_.dim()
    ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()

    # The collective writes the per-rank tensors back to back along dim 0.
    per_rank_shape = list(input_.size())
    gathered_shape = list(per_rank_shape)
    gathered_shape[0] *= world_size
    gathered = torch.empty(gathered_shape, dtype=input_.dtype, device=input_.device)
    torch.distributed.all_gather_into_tensor(
        gathered, input_, group=self.device_group
    )

    if separate_tensors:
        # Slice the rank-major buffer before any dimension shuffling.
        # Slicing after ``movedim`` (as was done for dim != 0) interleaves
        # data from different ranks inside each chunk.
        chunk = input_.numel()
        flat = gathered.reshape(-1)
        return [
            flat.narrow(0, chunk * rank, chunk).view_as(input_)
            for rank in range(world_size)
        ]

    if dim != 0:
        # Expose the rank axis, then move it next to the target dimension
        # so the final reshape concatenates along ``dim``.
        gathered = gathered.reshape([world_size] + per_rank_shape).movedim(0, dim)

    output_shape = list(per_rank_shape)
    output_shape[dim] = output_shape[dim] * world_size
    return gathered.reshape(output_shape)
def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):
    """Broadcast a picklable Python object over the CPU group.

    Args:
        obj: Object to send; only meaningful on the source rank.
        src: Group-local rank (index into ``self.ranks``) of the sender.

    Returns:
        ``obj`` on the source rank (and whenever the group has a single
        rank); the received object on every other rank.

    Raises:
        AssertionError: If ``src`` is not smaller than the group size, or if
            a broadcaster object is installed and ``src`` is not 0.
    """
    assert src < self.world_size, f"Invalid src rank ({src})"

    # Bypass the function if we are using only 1 GPU.
    if self.world_size == 1:
        return obj

    # The attribute this class manages is ``mq_broadcaster``. The code used
    # to read ``shm_broadcaster`` unconditionally, which raised
    # AttributeError on instances that never define it. The legacy name is
    # still honoured first in case a subclass sets it.
    broadcaster = getattr(self, "shm_broadcaster", None)
    if broadcaster is None:
        broadcaster = getattr(self, "mq_broadcaster", None)
    if broadcaster is not None:
        assert src == 0, "Shared memory broadcaster only supports src=0"
        return broadcaster.broadcast_object(obj)

    if self.rank_in_group == src:
        torch.distributed.broadcast_object_list(
            [obj], src=self.ranks[src], group=self.cpu_group
        )
        return obj

    # Receivers pass a one-slot list that the collective fills in.
    recv = [None]
    torch.distributed.broadcast_object_list(
        recv, src=self.ranks[src], group=self.cpu_group
    )
    return recv[0]
def broadcast_tensor_dict(
    self,
    tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
    src: int = 0,
    group: Optional[ProcessGroup] = None,
    metadata_group: Optional[ProcessGroup] = None,
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
    """Broadcast a (possibly nested) dictionary of tensors and plain values.

    The source rank first broadcasts the metadata list produced by
    ``_split_tensor_dict`` through ``self.broadcast_object`` and then
    broadcasts every non-empty tensor. The other ranks rebuild the
    dictionary from that metadata and receive into newly allocated tensors.

    Args:
        tensor_dict: Dictionary to send; required on the source rank.
        src: Group-local rank (index into ``self.ranks``) of the sender.
        group: Ignored; ``self.device_group`` is always used for non-CPU
            tensors.
        metadata_group: Ignored; ``self.cpu_group`` is always used for CPU
            tensors.

    Returns:
        ``tensor_dict`` unchanged when torch.distributed is not initialized
        or the group has a single rank, and on the source rank; otherwise
        the reconstructed dictionary.

    Raises:
        AssertionError: If ``src`` is out of range, or the source rank was
            not given a dictionary.
    """
    # Bypass the function if we are using only 1 GPU.
    if not torch.distributed.is_initialized() or self.world_size == 1:
        return tensor_dict

    group = self.device_group
    metadata_group = self.cpu_group
    assert src < self.world_size, f"Invalid src rank ({src})"

    # ``src`` stays group-local because ``broadcast_object`` performs its own
    # translation. Passing the global rank there (as was done before) makes
    # it index ``self.ranks`` a second time. torch.distributed itself needs
    # the global rank.
    global_src = self.ranks[src]

    def _broadcast_async(tensor):
        # CPU tensors travel over the CPU group, everything else over the
        # device group.
        chosen_group = metadata_group if tensor.is_cpu else group
        return torch.distributed.broadcast(
            tensor, src=global_src, group=chosen_group, async_op=True
        )

    async_handles = []
    if self.rank == global_src:
        assert isinstance(
            tensor_dict, dict
        ), f"Expecting a dictionary, got {type(tensor_dict)}"
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # The metadata is plain Python data, so it is sent via the object
        # broadcast path.
        self.broadcast_object(metadata_list, src=src)
        for tensor in tensor_list:
            if tensor.numel() == 0:
                # Skip broadcasting empty tensors.
                continue
            async_handles.append(_broadcast_async(tensor))
    else:
        metadata_list = self.broadcast_object(None, src=src)
        tensor_dict = {}
        for key, value in metadata_list:
            if not isinstance(value, TensorMetadata):
                _update_nested_dict(tensor_dict, key, value)
                continue
            tensor = torch.empty(value.size, dtype=value.dtype, device=value.device)
            # Empty tensors are never sent, so only the placeholder is stored.
            if tensor.numel() != 0:
                async_handles.append(_broadcast_async(tensor))
            _update_nested_dict(tensor_dict, key, tensor)

    for async_handle in async_handles:
        async_handle.wait()
    return tensor_dict
+ continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.send( + tensor, dst=self.ranks[dst], group=metadata_group + ) + else: + # use group for GPU tensors + torch.distributed.send(tensor, dst=self.ranks[dst], group=group) + return None + + def recv_tensor_dict( + self, src: Optional[int] = None + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Recv the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return None + + group = self.device_group + metadata_group = self.cpu_group + + if src is None: + src = self.group_prev_rank + assert src < self.world_size, f"Invalid src rank ({src})" + + recv_metadata_list = self.recv_object(src=src) + tensor_dict: Dict[str, Any] = {} + for key, value in recv_metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, dtype=value.dtype, device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + _update_nested_dict(tensor_dict, key, tensor) + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.recv( + tensor, src=self.ranks[src], group=metadata_group + ) + else: + # use group for GPU tensors + torch.distributed.recv(tensor, src=self.ranks[src], group=group) + _update_nested_dict(tensor_dict, key, tensor) + else: + _update_nested_dict(tensor_dict, key, value) + return tensor_dict + + def barrier(self): + """Barrier synchronization among the group. + NOTE: don't use `device_group` here! `barrier` in NCCL is + terrible because it is internally a broadcast operation with + secretly created GPU tensors. It is easy to mess up the current + device. Use the CPU group instead. 
+ """ + torch.distributed.barrier(group=self.cpu_group) + + def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """NOTE: `dst` is the rank_in_group of the destination rank.""" + if dst is None: + dst = self.group_next_rank + + torch.distributed.send( + tensor, + self.ranks[dst], + group=( + self.device_groups[self.rank_in_group % 2] + if self.world_size == 2 + else self.device_group + ), + ) + + def recv( + self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None + ) -> torch.Tensor: + """Receives a tensor from the src rank.""" + """NOTE: `src` is the rank_in_group of the source rank.""" + if src is None: + src = self.group_prev_rank + + tensor = torch.empty(size, dtype=dtype, device=self.device) + torch.distributed.recv( + tensor, + self.ranks[src], + ( + self.device_groups[(self.rank_in_group + 1) % 2] + if self.world_size == 2 + else self.device_group + ), + ) + return tensor + + def destroy(self) -> None: + if self.device_group is not None: + torch.distributed.destroy_process_group(self.device_group) + self.device_group = None + if self.cpu_group is not None: + torch.distributed.destroy_process_group(self.cpu_group) + self.cpu_group = None + if self.device_communicator is not None: + self.device_communicator.destroy() + if self.mq_broadcaster is not None: + self.mq_broadcaster = None + + +class PipelineGroupCoordinator(GroupCoordinator): + """ + available attributes: + rank: int # global rank + ranks: List[int] # global ranks in the group + world_size: int # size of the group + difference between `local_rank` and `rank_in_group`: + if we have a group of size 4 across two nodes: + Process | Node | Rank | Local Rank | Rank in Group + 0 | 0 | 0 | 0 | 0 + 1 | 0 | 1 | 1 | 1 + 2 | 1 | 2 | 0 | 2 + 3 | 1 | 3 | 1 | 3 + local_rank: int # local rank used to assign devices + rank_in_group: int # rank inside the group + cpu_group: ProcessGroup # group for CPU 
communication + device_group: ProcessGroup # group for device communication + """ + + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + group_name: str | None = None, + ): + super().__init__( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=torch_distributed_backend, + group_name=group_name, + ) + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + self.device_group = None + self.cpu_group = None + self.cpu_groups = [] + self.device_groups = [] + if len(group_ranks[0]) > 2 or len(group_ranks[0]) == 1: + for ranks in group_ranks: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend + ) + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. + with suppress_stdout(): + cpu_group = torch.distributed.new_group(ranks, backend="gloo") + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_group = device_group + self.cpu_group = cpu_group + # when pipeline parallelism is 2, we need to create two groups to avoid + # communication stall. + # *_group_0_1 represents the group for communication from device 0 to + # device 1. + # *_group_1_0 represents the group for communication from device 1 to + # device 0. + elif len(group_ranks[0]) == 2: + for ranks in group_ranks: + device_group_0_1 = torch.distributed.new_group( + ranks, backend=torch_distributed_backend + ) + device_group_1_0 = torch.distributed.new_group( + ranks, backend=torch_distributed_backend + ) + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. 
+ with suppress_stdout(): + cpu_group_0_1 = torch.distributed.new_group(ranks, backend="gloo") + cpu_group_1_0 = torch.distributed.new_group(ranks, backend="gloo") + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_groups = [device_group_0_1, device_group_1_0] + self.cpu_groups = [cpu_group_0_1, cpu_group_1_0] + self.device_group = device_group_0_1 + self.cpu_group = cpu_group_0_1 + + assert self.cpu_group is not None + assert self.device_group is not None + + self.device = current_platform.get_device(local_rank) + + self.recv_buffer_set: bool = False + self.recv_tasks_queue: List[Tuple[str, int]] = [] + self.receiving_tasks: List[Tuple[torch.distributed.Work, str, int]] = [] + self.dtype: Optional[torch.dtype] = None + self.num_pipefusion_patches: Optional[int] = None + + self.recv_shape: Dict[str, Dict[int, torch.Size]] = {} + self.send_shape: Dict[str, Dict[int, torch.Size]] = {} + self.recv_buffer: Dict[str, Dict[int, torch.Size]] = {} + + self.skip_tensor_recv_buffer_set: bool = False + self.recv_skip_tasks_queue: List[Union[int, Tuple[str, int]]] = [] + self.receiving_skip_tasks: List[Tuple[torch.distributed.Work, str, int]] = [] + self.skip_tensor_recv_buffer: Optional[ + Union[List[torch.Tensor], torch.Tensor] + ] = None + self.skip_device_group = None + for ranks in group_ranks: + skip_device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend + ) + if self.rank in ranks: + self.skip_device_group = skip_device_group + assert self.skip_device_group is not None + + def reset_buffer(self): + self.recv_tasks_queue = [] + self.receiving_tasks = [] + self.recv_shape = {} + self.send_shape = {} + self.recv_buffer = {} + + self.recv_skip_tasks_queue = [] + self.receiving_skip_tasks = [] + self.skip_tensor_recv_buffer = {} + + def set_config(self, dtype: torch.dtype): + self.dtype = dtype + + def set_recv_buffer( + self, + num_pipefusion_patches: int, + 
patches_shape_list: List[List[int]], + feature_map_shape: List[int], + dtype: torch.dtype, + ): + assert isinstance(dtype, torch.dtype), "dtype must be a torch.dtype object" + assert ( + isinstance(num_pipefusion_patches, int) and num_pipefusion_patches >= 1 + ), "num_pipefusion_patches must be greater than or equal to 1" + self.dtype = dtype + self.num_pipefusion_patches = num_pipefusion_patches + self.recv_buffer = [ + torch.zeros(*shape, dtype=self.dtype, device=self.device) + for shape in patches_shape_list + ] + self.recv_buffer.append( + torch.zeros(*feature_map_shape, dtype=self.dtype, device=self.device) + ) + self.recv_buffer_set = True + + def set_extra_tensors_recv_buffer( + self, + name: str, + shape: List[int], + num_buffers: int = 1, + dtype: torch.dtype = torch.float16, + ): + self.extra_tensors_recv_buffer[name] = [ + torch.zeros(*shape, dtype=dtype, device=self.device) + for _ in range(num_buffers) + ] + + def _check_shape_and_buffer( + self, + tensor_send_to_next=None, + recv_prev=False, + name: Optional[str] = None, + segment_idx: int = 0, + ): + send_flag = False + name = name or "latent" + if tensor_send_to_next is not None: + shape_list = self.send_shape.get(name, None) + if shape_list is None: + self.send_shape[name] = {segment_idx: tensor_send_to_next.shape} + send_flag = True + elif shape_list.get(segment_idx, None) is None: + self.send_shape[name][segment_idx] = tensor_send_to_next.shape + send_flag = True + + recv_flag = False + if recv_prev: + shape_list = self.recv_shape.get(name, None) + if shape_list is None: + recv_flag = True + elif shape_list.get(segment_idx, None) is None: + recv_flag = True + + recv_prev_shape = self._communicate_shapes( + tensor_send_to_next=tensor_send_to_next if send_flag else None, + recv_prev=recv_flag, + ) + + if recv_flag: + if self.recv_shape.get(name, None) is None: + self.recv_shape[name] = {segment_idx: recv_prev_shape} + else: + self.recv_shape[name][segment_idx] = recv_prev_shape + + if 
self.recv_buffer.get(name, None) is None: + self.recv_buffer[name] = { + segment_idx: torch.zeros( + recv_prev_shape, device=self.device, dtype=self.dtype + ) + } + else: + if self.recv_buffer[name].get(segment_idx, None) is not None: + logger.warning( + f"Recv buffer [name: {name}, segment_idx: {segment_idx}] already exist. updating..." + ) + self.recv_buffer[name][segment_idx] = torch.zeros( + recv_prev_shape, device=self.device, dtype=self.dtype + ) + + def _communicate_shapes(self, tensor_send_to_next=None, recv_prev=False): + """Communicate tensor shapes between stages. Used to communicate + tensor shapes before the actual tensor communication happens. + + Args: + tensor_send_next: tensor to send to next rank (no tensor sent if + set to None). + recv_prev: boolean for whether tensor should be received from + previous rank. + """ + + ops = [] + if recv_prev: + recv_prev_dim_tensor = torch.empty( + (1), device=self.device, dtype=torch.int64 + ) + recv_prev_dim_op = torch.distributed.P2POp( + torch.distributed.irecv, + recv_prev_dim_tensor, + self.prev_rank, + self.device_group, + ) + ops.append(recv_prev_dim_op) + + if tensor_send_to_next is not None: + send_next_dim_tensor = torch.tensor( + tensor_send_to_next.dim(), device=self.device, dtype=torch.int64 + ) + send_next_dim_op = torch.distributed.P2POp( + torch.distributed.isend, + send_next_dim_tensor, + self.next_rank, + self.device_group, + ) + ops.append(send_next_dim_op) + + if len(ops) > 0: + reqs = torch.distributed.batch_isend_irecv(ops) + for req in reqs: + req.wait() + + # To protect against race condition when using batch_isend_irecv(). + # should take this out once the bug with batch_isend_irecv is resolved. 
+ synchronize() + + ops = [] + recv_prev_shape_tensor = None + if recv_prev: + recv_prev_shape_tensor = torch.empty( + torch.Size(recv_prev_dim_tensor), + device=self.device, + dtype=torch.int64, + ) + recv_prev_shape_op = torch.distributed.P2POp( + torch.distributed.irecv, + recv_prev_shape_tensor, + self.prev_rank, + self.device_group, + ) + ops.append(recv_prev_shape_op) + + if tensor_send_to_next is not None: + send_next_shape_tensor = torch.tensor( + tensor_send_to_next.size(), + device=self.device, + dtype=torch.int64, + ) + send_next_shape_op = torch.distributed.P2POp( + torch.distributed.isend, + send_next_shape_tensor, + self.next_rank, + self.device_group, + ) + ops.append(send_next_shape_op) + + if len(ops) > 0: + reqs = torch.distributed.batch_isend_irecv(ops) + for req in reqs: + req.wait() + + synchronize() + + recv_prev_shape = [0, 0, 0] + if recv_prev_shape_tensor is not None: + recv_prev_shape = recv_prev_shape_tensor + return torch.Size(recv_prev_shape) + + def pipeline_send( + self, tensor: torch.Tensor, name: str = "latent", segment_idx: int = -1 + ) -> None: + tensor = tensor.contiguous() + self._check_shape_and_buffer( + tensor_send_to_next=tensor, name=name, segment_idx=segment_idx + ) + self._pipeline_isend(tensor).wait() + + def pipeline_isend( + self, tensor: torch.Tensor, name: str = "latent", segment_idx: int = -1 + ) -> None: + tensor = tensor.contiguous() + self._check_shape_and_buffer( + tensor_send_to_next=tensor, name=name, segment_idx=segment_idx + ) + self._pipeline_isend(tensor) + + def pipeline_recv(self, idx: int = -1, name: str = "latent") -> torch.Tensor: + name = name or "latent" + self._check_shape_and_buffer(recv_prev=True, name=name, segment_idx=idx) + self._pipeline_irecv(self.recv_buffer[name][idx]).wait() + return self.recv_buffer[name][idx] + + def add_pipeline_recv_task(self, idx: int = -1, name: str = "latent"): + name = name or "latent" + self.recv_tasks_queue.append((name, idx)) + + def recv_next(self): + if 
len(self.recv_tasks_queue) == 0: + raise ValueError("No more tasks to receive") + elif len(self.recv_tasks_queue) > 0: + name, idx = self.recv_tasks_queue.pop(0) + self._check_shape_and_buffer(recv_prev=True, name=name, segment_idx=idx) + self.receiving_tasks.append( + (self._pipeline_irecv(self.recv_buffer[name][idx]), name, idx) + ) + + def get_pipeline_recv_data( + self, idx: int = -1, name: str = "latent" + ) -> torch.Tensor: + assert ( + len(self.receiving_tasks) > 0 + ), "No tasks to receive, call add_pipeline_recv_task first" + receiving_task = self.receiving_tasks.pop(0) + receiving_task[0].wait() + assert ( + receiving_task[1] == name and receiving_task[2] == idx + ), "Received tensor does not match the requested" + return self.recv_buffer[name][idx] + + def _pipeline_irecv(self, tensor: torch.tensor): + return torch.distributed.irecv( + tensor, + src=self.prev_rank, + group=( + self.device_groups[(self.rank_in_group + 1) % 2] + if self.world_size == 2 + else self.device_group + ), + ) + + def _pipeline_isend(self, tensor: torch.tensor): + return torch.distributed.isend( + tensor, + dst=self.next_rank, + group=( + self.device_groups[self.rank_in_group % 2] + if self.world_size == 2 + else self.device_group + ), + ) + + def set_skip_tensor_recv_buffer( + self, + patches_shape_list: List[List[int]], + feature_map_shape: List[int], + ): + self.skip_tensor_recv_buffer = [ + torch.zeros(*shape, dtype=self.dtype, device=self.device) + for shape in patches_shape_list + ] + self.skip_tensor_recv_buffer.append( + torch.zeros(*feature_map_shape, dtype=self.dtype, device=self.device) + ) + self.skip_tensor_recv_buffer_set = True + + def pipeline_send_skip(self, tensor: torch.Tensor) -> None: + tensor = tensor.contiguous() + self._pipeline_isend_skip(tensor).wait() + + def pipeline_isend_skip(self, tensor: torch.Tensor) -> None: + tensor = tensor.contiguous() + self._pipeline_isend_skip(tensor) + + def pipeline_recv_skip(self, idx: int = -1) -> torch.Tensor: + 
self._pipeline_irecv_skip(self.skip_tensor_recv_buffer[idx]).wait() + return self.skip_tensor_recv_buffer[idx] + + def add_pipeline_recv_skip_task(self, idx: int = -1): + self.recv_skip_tasks_queue.append(idx) + + def get_pipeline_recv_skip_data(self, idx: int = -1) -> torch.Tensor: + assert ( + len(self.receiving_skip_tasks) > 0 + ), "No tasks to receive, call add_pipeline_recv_skip_task first" + receiving_skip_task = self.receiving_skip_tasks.pop(0) + receiving_skip_task[0].wait() + assert ( + receiving_skip_task[2] == idx + ), "Received tensor does not match the requested" + return self.skip_tensor_recv_buffer[idx] + + def recv_skip_next(self): + if len(self.recv_skip_tasks_queue) == 0: + raise ValueError("No more tasks to receive") + elif len(self.recv_skip_tasks_queue) > 0: + task = self.recv_skip_tasks_queue.pop(0) + idx = task + self.receiving_skip_tasks.append( + ( + self._pipeline_irecv_skip(self.skip_tensor_recv_buffer[idx]), + None, + idx, + ) + ) + + def _pipeline_irecv_skip(self, tensor: torch.tensor): + return torch.distributed.irecv( + tensor, src=self.skip_rank, group=self.skip_device_group + ) + + def _pipeline_isend_skip(self, tensor: torch.tensor): + return torch.distributed.isend( + tensor, dst=self.skip_rank, group=self.skip_device_group + ) + + +class SequenceParallelGroupCoordinator(GroupCoordinator): + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + group_name: str | None = None, + **kwargs, + ): + super().__init__( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=torch_distributed_backend, + group_name=group_name, + ) + ulysses_group = kwargs.get("ulysses_group", None) + ring_group = kwargs.get("ring_group", None) + if ulysses_group is None: + raise RuntimeError( + f"Please pass argument 'ulysses_group' when calling init func of SequenceParallelGroupCoordinator" + ) + if ring_group is None: + raise RuntimeError( + f"Please pass 
class Singleton:
    """Base class giving each subclass at most one instance per process.

    Instantiating a subclass repeatedly always returns the object created by
    the first call.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        # Look only in this class' own namespace so that a subclass never
        # hands back an instance cached by one of its ancestors.
        if cls.__dict__.get("_instance") is None:
            # `object.__new__` rejects extra arguments; constructor arguments
            # are consumed by `__init__`, so they are not forwarded here.
            cls._instance = super().__new__(cls)
        return cls._instance


class ProcessGroupSingleton(Singleton):
    """Process-wide holder of the Ulysses and Ring process groups."""

    def __init__(self):
        # Python calls __init__ on every `ProcessGroupSingleton()` expression
        # even though __new__ returns the cached object; without this guard a
        # second call would reset the stored groups to None.
        if getattr(self, "_initialized", False):
            return
        self.ULYSSES_PG = None
        self.RING_PG = None
        self._initialized = True
+ """ + sp_degree = sp_ring_degree * sp_ulysses_degree + assert sp_degree > 0 + assert all( + len(g) == sp_degree for g in sp_groups + ), f"Each SP group must have size {sp_degree}, got sizes {[len(g) for g in sp_groups]}" + + ulyssess_pg = None + ring_pg = None + + num_ulysses_pgs = sp_ring_degree + num_ring_pgs = sp_ulysses_degree + + def _map_indices_to_ranks(ranks: list[int], indices: list[int]) -> list[int]: + return [ranks[i] for i in indices] + + # Important: call torch.distributed.new_group in the same order on all ranks. + for sp_ranks in sp_groups: + if use_ulysses_low: + for i in range(num_ulysses_pgs): + idx = list(range(i * sp_ulysses_degree, (i + 1) * sp_ulysses_degree)) + ulysses_ranks = _map_indices_to_ranks(sp_ranks, idx) + group = torch.distributed.new_group(ulysses_ranks) + if rank in ulysses_ranks: + ulyssess_pg = group + + for i in range(num_ring_pgs): + idx = list(range(i, sp_degree, num_ring_pgs)) + ring_ranks = _map_indices_to_ranks(sp_ranks, idx) + group = torch.distributed.new_group(ring_ranks) + if rank in ring_ranks: + ring_pg = group + else: + for i in range(num_ring_pgs): + idx = list(range(i * sp_ring_degree, (i + 1) * sp_ring_degree)) + ring_ranks = _map_indices_to_ranks(sp_ranks, idx) + group = torch.distributed.new_group(ring_ranks) + if rank in ring_ranks: + ring_pg = group + + for i in range(num_ulysses_pgs): + idx = list(range(i, sp_degree, num_ulysses_pgs)) + ulysses_ranks = _map_indices_to_ranks(sp_ranks, idx) + group = torch.distributed.new_group(ulysses_ranks) + if rank in ulysses_ranks: + ulyssess_pg = group + + PROCESS_GROUP.ULYSSES_PG = ulyssess_pg + PROCESS_GROUP.RING_PG = ring_pg diff --git a/sglang/python/sglang/multimodal_gen/runtime/distributed/parallel_state.py b/sglang/python/sglang/multimodal_gen/runtime/distributed/parallel_state.py new file mode 100644 index 0000000000000000000000000000000000000000..147cf1444f6670561bc829d586d20fceb37a0782 --- /dev/null +++ 
b/sglang/python/sglang/multimodal_gen/runtime/distributed/parallel_state.py @@ -0,0 +1,1184 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo +# SPDX-License-Identifier: Apache-2.0 +# Adapted from: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/parallel_state.py +# Copyright 2023 The vLLM team. +# Adapted from +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Adapted from +# Copyright 2024 xDiT team. +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/distributed/parallel_state.py +# Copyright 2023 The vLLM team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""sglang-diffusion distributed state. + +It takes over the control of the distributed environment from PyTorch. +The typical workflow is: + +- call `init_distributed_environment` to initialize the distributed environment. +- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to + initialize the model parallel groups. + +- any code dealing with the distributed stuff + +- call `destroy_model_parallel` to destroy the model parallel groups. +- call `destroy_distributed_environment` to destroy the distributed environment. + +If you only need to use the distributed environment without model parallelism, + you can skip the model parallel initialization and destruction steps. 
+""" + +import contextlib +import datetime +import os +import weakref +from collections import namedtuple +from collections.abc import Callable +from contextlib import contextmanager +from multiprocessing import shared_memory +from typing import Any, List, Optional +from unittest.mock import patch + +import torch +import torch.distributed +from torch.distributed import ProcessGroup + +import sglang.multimodal_gen.envs as envs +from sglang.multimodal_gen.runtime.distributed.utils import StatelessProcessGroup +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +from ..utils.distributed import RankGenerator +from .group_coordinator import ( + GroupCoordinator, + PipelineGroupCoordinator, + SequenceParallelGroupCoordinator, + get_local_torch_device, +) + +logger = init_logger(__name__) + +_WORLD: Optional[GroupCoordinator] = None +_TP: Optional[GroupCoordinator] = None +_SP: Optional[SequenceParallelGroupCoordinator] = None +_PP: Optional[PipelineGroupCoordinator] = None +_CFG: Optional[GroupCoordinator] = None +_DP: Optional[GroupCoordinator] = None +_DIT: Optional[GroupCoordinator] = None +_VAE: Optional[GroupCoordinator] = None + +TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) + + +def _split_tensor_dict( + tensor_dict: dict[str, torch.Tensor | Any], +) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]: + """Split the tensor dictionary into two parts: + 1. A list of (key, value) pairs. If the value is a tensor, it is replaced + by its metadata. + 2. A list of tensors. + """ + metadata_list: list[tuple[str, Any]] = [] + tensor_list: list[torch.Tensor] = [] + for key, value in tensor_dict.items(): + if isinstance(value, torch.Tensor): + # Note: we cannot use `value.device` here, + # because it contains not only the device type but also the device + # index (e.g. "cuda:0"). We only need the device type. + # receiving side will set the device index. 
def init_parallel_group_coordinator(
    group_ranks: List[List[int]],
    local_rank: int,
    backend: str,
    parallel_mode: str,
    **kwargs,
) -> GroupCoordinator:
    """Build the group coordinator that matches `parallel_mode`.

    Args:
        group_ranks: global ranks of every group of this parallel dimension.
        local_rank: local rank of the current process.
        backend: torch.distributed backend used for the device groups.
        parallel_mode: one of "data", "pipeline", "tensor", "sequence" or
            "classifier_free_guidance".
        **kwargs: forwarded only to `SequenceParallelGroupCoordinator`
            (e.g. `ulysses_group`, `ring_group`).

    Returns:
        A `PipelineGroupCoordinator` for "pipeline", a
        `SequenceParallelGroupCoordinator` for "sequence", and a plain
        `GroupCoordinator` for the remaining modes.
    """
    # Label used for each plain GroupCoordinator. Previously all three modes
    # were created as "cfg_group", which mislabelled the data- and
    # tensor-parallel coordinators.
    # NOTE(review): confirm nothing elsewhere looks these groups up by the
    # literal "cfg_group" name.
    plain_group_names = {
        "data": "dp_group",
        "tensor": "tp_group",
        "classifier_free_guidance": "cfg_group",
    }
    assert parallel_mode in [
        "data",
        "pipeline",
        "tensor",
        "sequence",
        "classifier_free_guidance",
    ], f"parallel_mode {parallel_mode} is not supported"

    if parallel_mode == "pipeline":
        return PipelineGroupCoordinator(
            group_ranks=group_ranks,
            local_rank=local_rank,
            torch_distributed_backend=backend,
            group_name="pp_group",
        )
    if parallel_mode == "sequence":
        return SequenceParallelGroupCoordinator(
            group_ranks=group_ranks,
            local_rank=local_rank,
            torch_distributed_backend=backend,
            group_name="sp_group",
            **kwargs,
        )
    # fallback to GroupCoordinator
    return GroupCoordinator(
        group_ranks=group_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        # Unknown modes only reach this point when assertions are disabled;
        # keep the historical label for them.
        group_name=plain_group_names.get(parallel_mode, "cfg_group"),
    )
def initialize_model_parallel(
    data_parallel_size: int = 1,
    classifier_free_guidance_degree: int = 1,
    sequence_parallel_degree: Optional[int] = None,
    ulysses_degree: int = 1,
    ring_degree: int = 1,
    tensor_parallel_degree: int = 1,
    pipeline_parallel_degree: int = 1,
    vae_parallel_size: int = 0,
    backend: Optional[str] = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        data_parallel_size: number of data parallelism groups.
        classifier_free_guidance_degree: number of GPUs used for Classifier Free Guidance (CFG)
        sequence_parallel_degree: number of GPUs used for sequence parallelism.
            sequence_parallel_degree = ulysses_degree * ring_degree; when left
            as None it is computed from those two values.
        ulysses_degree: number of GPUs used for ulysses sequence parallelism.
        ring_degree: number of GPUs used for ring sequence parallelism.
        tensor_parallel_degree: number of GPUs used for tensor parallelism.
        pipeline_parallel_degree: number of GPUs used for pipeline parallelism.
        vae_parallel_size: when greater than 0, `init_vae_group` is called with
            this size after the DiT groups are created.
        backend: distributed backend of pytorch collective comm. Defaults to
            the backend reported by the current platform.

    Raises:
        RuntimeError: if the world size is smaller than the product of all
            parallel degrees.

    Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
    use 2 groups to parallelize the batch dim(dp), 2 groups to parallelize
    split batch caused by CFG, and 2 GPUs to parallelize sequence.

    dp_degree (2) * cfg_degree (2) * sp_degree (2) * pp_degree (2) = 16.

    The present function will create 8 data-parallel groups,
    8 CFG group, 8 pipeline-parallel group, and
    8 sequence-parallel groups:
        8 data-parallel groups:
            [g0, g8], [g1, g9], [g2, g10], [g3, g11],
            [g4, g12], [g5, g13], [g6, g14], [g7, g15]
        8 CFG-parallel groups:
            [g0, g4], [g1, g5], [g2, g6], [g3, g7],
            [g8, g12], [g9, g13], [g10, g14], [g11, g15]
        8 sequence-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7],
            [g8, g9], [g10, g11], [g12, g13], [g14, g15]
        8 pipeline-parallel groups:
            [g0, g2], [g4, g6], [g8, g10], [g12, g14],
            [g1, g3], [g5, g7], [g9, g11], [g13, g15]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """

    if backend is None:
        from sglang.multimodal_gen.runtime.platforms import current_platform

        backend = current_platform.get_torch_distributed_backend_str()
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    backend = backend or torch.distributed.get_backend(get_world_group().device_group)

    # The signature defaults this to None; multiplying None below would raise
    # TypeError, so derive it from its documented definition.
    if sequence_parallel_degree is None:
        sequence_parallel_degree = ulysses_degree * ring_degree

    dit_parallel_size = (
        data_parallel_size
        * classifier_free_guidance_degree
        * sequence_parallel_degree
        * pipeline_parallel_degree
        * tensor_parallel_degree
    )

    if world_size < dit_parallel_size:
        raise RuntimeError(
            f"world_size ({world_size}) is less than "
            f"tensor_parallel_degree ({tensor_parallel_degree}) x "
            f"pipeline_parallel_degree ({pipeline_parallel_degree}) x "
            f"sequence_parallel_degree ({sequence_parallel_degree}) x "
            f"classifier_free_guidance_degree "
            f"({classifier_free_guidance_degree}) x "
            f"data_parallel_degree ({data_parallel_size})"
        )

    rank_generator: RankGenerator = RankGenerator(
        tensor_parallel_degree,
        sequence_parallel_degree,
        pipeline_parallel_degree,
        classifier_free_guidance_degree,
        data_parallel_size,
        "tp-sp-pp-cfg-dp",
    )
    # Coordinators are created in a fixed order (dp, cfg, pp, sp, tp); keep it
    # unchanged so the creation sequence stays the same as before.
    global _DP
    assert _DP is None, "data parallel group is already initialized"
    _DP = init_parallel_group_coordinator(
        group_ranks=rank_generator.get_ranks("dp"),
        local_rank=get_world_group().local_rank,
        backend=backend,
        parallel_mode="data",
    )

    global _CFG
    assert _CFG is None, "classifier_free_guidance group is already initialized"
    _CFG = init_parallel_group_coordinator(
        group_ranks=rank_generator.get_ranks("cfg"),
        local_rank=get_world_group().local_rank,
        backend=backend,
        parallel_mode="classifier_free_guidance",
    )
    global _PP
    assert _PP is None, "pipeline model parallel group is already initialized"
    _PP = init_parallel_group_coordinator(
        group_ranks=rank_generator.get_ranks("pp"),
        local_rank=get_world_group().local_rank,
        backend=backend,
        parallel_mode="pipeline",
    )

    global _SP
    assert _SP is None, "sequence parallel group is already initialized"

    try:
        from .parallel_groups import PROCESS_GROUP as _YC_PROCESS_GROUP
        from .parallel_groups import (
            set_seq_parallel_pg_by_sp_groups as _set_seq_parallel_pg_by_sp_groups,
        )
    except ImportError:
        _set_seq_parallel_pg_by_sp_groups = None

        # Without the helper module, both sub-groups fall back to the default
        # (world) process group.
        class _DummyProcessGroup:
            ULYSSES_PG = torch.distributed.group.WORLD
            RING_PG = torch.distributed.group.WORLD

        PROCESS_GROUP = _DummyProcessGroup()
    else:
        # Build SGLang Diffusion SP sub-groups based on the true SP groups. This is
        # critical when TP>1, because SP groups may be strided in global ranks
        # (e.g., tp-sp order).
        sp_groups = rank_generator.get_ranks("sp")
        _set_seq_parallel_pg_by_sp_groups(
            sp_ulysses_degree=ulysses_degree,
            sp_ring_degree=ring_degree,
            rank=get_world_group().rank,
            sp_groups=sp_groups,
        )
        PROCESS_GROUP = _YC_PROCESS_GROUP

    _SP = init_parallel_group_coordinator(
        group_ranks=rank_generator.get_ranks("sp"),
        local_rank=get_world_group().local_rank,
        backend=backend,
        parallel_mode="sequence",
        ulysses_group=PROCESS_GROUP.ULYSSES_PG,
        ring_group=PROCESS_GROUP.RING_PG,
    )

    global _TP
    assert _TP is None, "Tensor parallel group is already initialized"
    _TP = init_parallel_group_coordinator(
        group_ranks=rank_generator.get_ranks("tp"),
        local_rank=get_world_group().local_rank,
        backend=backend,
        parallel_mode="tensor",
    )

    if vae_parallel_size > 0:
        init_vae_group(dit_parallel_size, vae_parallel_size, backend)
    init_dit_group(dit_parallel_size, backend)
+# +# Arguments: +# tensor_model_parallel_size: number of GPUs used for tensor model +# parallelism (used for language encoder). +# sequence_model_parallel_size: number of GPUs used for sequence model +# parallelism (used for DiT). +# """ +# # Get world size and rank. Ensure some consistencies. +# assert ( +# _WORLD is not None +# ), "world group is not initialized, please call init_distributed_environment first" +# world_size: int = get_world_size() +# backend = backend or torch.distributed.get_backend(get_world_group().device_group) +# assert ( +# world_size >= tensor_model_parallel_size +# ), f"world_size({world_size}) must be greater than or equal to tensor_model_parallel_size({tensor_model_parallel_size})" +# num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size +# global _TP +# assert _TP is None, "tensor model parallel group is already initialized" +# group_ranks = [] +# for i in range(num_tensor_model_parallel_groups): +# ranks = list( +# range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) +# ) +# group_ranks.append(ranks) +# +# # message queue broadcaster is only used in tensor model parallel group +# _TP = init_parallel_group_coordinator( +# group_ranks, +# get_world_group().local_rank, +# backend, +# use_message_queue_broadcaster=True, +# group_name="tp", +# ) +# +# # Build the sequence model-parallel groups. 
def maybe_init_distributed_environment_and_model_parallel(
    tp_size: int,
    sp_size: int,
    enable_cfg_parallel: bool,
    ulysses_degree: int = 1,
    ring_degree: int = 1,
    dp_size: int = 1,
    distributed_init_method: str = "env://",
    dist_timeout: int | None = None,
):
    """Set up the distributed environment and model-parallel groups if needed.

    When the world group and the model-parallel groups already exist, this
    only asserts that the existing TP and SP world sizes match ``tp_size`` and
    ``sp_size`` and returns. Otherwise it reads ``LOCAL_RANK``, ``WORLD_SIZE``
    and ``RANK`` from the environment (defaults 0, 1 and 0), initializes the
    distributed environment, builds the model-parallel groups, and finally
    selects the CUDA-alike or NPU device that matches the local rank.

    Arguments:
        tp_size: tensor parallel degree.
        sp_size: sequence parallel degree.
        enable_cfg_parallel: use a CFG degree of 2 when true, otherwise 1.
        ulysses_degree: forwarded to ``initialize_model_parallel``.
        ring_degree: forwarded to ``initialize_model_parallel``.
        dp_size: data parallel size.
        distributed_init_method: init method handed to
            ``init_distributed_environment``.
        dist_timeout: timeout handed to ``init_distributed_environment``.
    """
    from sglang.multimodal_gen.runtime.platforms import current_platform

    already_initialized = _WORLD is not None and model_parallel_is_initialized()
    if already_initialized:
        # make sure the tp and sp sizes are correct
        assert (
            get_tp_world_size() == tp_size
        ), f"You are trying to initialize model parallel groups with size {tp_size}, but they are already initialized with size {get_tp_world_size()}"
        assert (
            get_sp_world_size() == sp_size
        ), f"You are trying to initialize model parallel groups with size {sp_size}, but they are already initialized with size {get_sp_world_size()}"
        return

    # Launcher-provided process coordinates, with single-process defaults.
    local_rank, world_size, rank = (
        int(os.environ.get(env_name, fallback))
        for env_name, fallback in (("LOCAL_RANK", 0), ("WORLD_SIZE", 1), ("RANK", 0))
    )
    device = get_local_torch_device()
    logger.info(
        "Initializing distributed environment with world_size=%d, device=%s, timeout=%s",
        world_size,
        device,
        dist_timeout,
        main_process_only=False,
    )

    init_distributed_environment(
        world_size=world_size,
        rank=rank,
        local_rank=local_rank,
        distributed_init_method=distributed_init_method,
        device_id=device,
        backend=current_platform.get_torch_distributed_backend_str(),
        timeout=dist_timeout,
    )

    cfg_degree = 2 if enable_cfg_parallel else 1
    initialize_model_parallel(
        data_parallel_size=dp_size,
        classifier_free_guidance_degree=cfg_degree,
        tensor_parallel_degree=tp_size,
        ulysses_degree=ulysses_degree,
        ring_degree=ring_degree,
        sequence_parallel_degree=sp_size,
    )

    # Only set CUDA device if we're on a CUDA platform
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(torch.device(f"cuda:{local_rank}"))
    elif current_platform.is_npu():
        torch.npu.set_device(torch.device(f"npu:{local_rank}"))
@contextmanager
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
    """Patch the tp group temporarily until this function ends.

    This method is for draft workers of speculative decoding to run draft model
    with different tp degree from that of target model workers.

    Arguments:
        tp_group: the group that ``_TP`` points to while the ``with`` body runs.

    Raises:
        AssertionError: if a patch is already active, or if the tensor
            parallel group has not been initialized yet.
    """
    global _TP, _TP_STATE_PATCHED
    assert not _TP_STATE_PATCHED, "Should not call when it's already patched"

    # Look up the current group BEFORE setting the flag: get_tp_group()
    # asserts when TP is uninitialized, and raising after the flag was set
    # would leave it stuck at True for the rest of the process.
    old_tp_group = get_tp_group()
    _TP_STATE_PATCHED = True
    _TP = tp_group
    try:
        yield
    finally:
        # restore the original state
        _TP_STATE_PATCHED = False
        _TP = old_tp_group
+ # local rank inside the group + rank = torch.distributed.get_rank(group=pg) + world_size = torch.distributed.get_world_size(group=pg) + + # global ranks of the processes in the group + ranks = torch.distributed.get_process_group_ranks(pg) + else: + rank = pg.rank + world_size = pg.world_size + ranks = list(range(world_size)) + + # local tensor in each process to store the result + is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32) + + magic_message = b"magic_message" + shm = None + + try: + with contextlib.suppress(OSError): + if rank == source_rank: + # create a shared memory segment + shm = shared_memory.SharedMemory(create=True, size=128) + shm.buf[: len(magic_message)] = magic_message + if isinstance(pg, ProcessGroup): + torch.distributed.broadcast_object_list( + [shm.name], src=ranks[source_rank], group=pg + ) + else: + pg.broadcast_obj(shm.name, src=source_rank) + is_in_the_same_node[rank] = 1 + else: + # try to open the shared memory segment + if isinstance(pg, ProcessGroup): + recv = [None] + torch.distributed.broadcast_object_list( + recv, src=ranks[source_rank], group=pg + ) + name = recv[0] + else: + name = pg.broadcast_obj(None, src=source_rank) + # fix to https://stackoverflow.com/q/62748654/9191338 + # Python incorrectly tracks shared memory even if it is not + # created by the process. The following patch is a workaround. 
def initialize_tensor_parallel_group(
    tensor_model_parallel_size: int = 1,
    backend: str | None = None,
    group_name_suffix: str = "",
) -> GroupCoordinator:
    """Create a standalone tensor parallel group for one model.

    The returned coordinator is meant to be handed to the
    patch_tensor_parallel_group context manager, which lets different models
    run with different tensor parallelism configurations. Ranks are grouped
    into consecutive chunks of ``tensor_model_parallel_size``.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model parallelism.
        backend: communication backend to use; defaults to the backend of the
            world group's device group.
        group_name_suffix: optional suffix to make the group name unique.

    Returns:
        A GroupCoordinator for tensor parallelism that can be used with
        the patch_tensor_parallel_group context manager.

    Example usage:
        ```python
        # Initialize tensor parallel group for model1
        tp_group_model1 = initialize_tensor_parallel_group(
            tensor_model_parallel_size=4,
            group_name_suffix="model1"
        )

        # Use tensor parallelism for model1
        with patch_tensor_parallel_group(tp_group_model1):
            # Run model1 with tensor parallelism
            output1 = model1(input1)
        ```
    """
    # The default process group must exist before sub-groups can be built.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    backend = backend or torch.distributed.get_backend(get_world_group().device_group)

    # Ensure the world size is compatible with the parallelism configuration
    assert (
        world_size % tensor_model_parallel_size == 0
    ), f"World size ({world_size}) must be divisible by tensor_model_parallel_size ({tensor_model_parallel_size})"

    # One list of consecutive ranks per tensor model-parallel group.
    group_count: int = world_size // tensor_model_parallel_size
    tp_group_ranks = [
        list(
            range(
                index * tensor_model_parallel_size,
                (index + 1) * tensor_model_parallel_size,
            )
        )
        for index in range(group_count)
    ]

    # Give the coordinator a unique name when a suffix was supplied.
    if group_name_suffix:
        group_name = f"tp_{group_name_suffix}"
    else:
        group_name = "tp"

    return init_parallel_group_coordinator(
        tp_group_ranks,
        get_world_group().local_rank,
        backend,
        use_message_queue_broadcaster=True,
        group_name=group_name,
    )
+ + Arguments: + sequence_model_parallel_size: number of GPUs used for sequence model parallelism. + backend: communication backend to use. + group_name_suffix: optional suffix to make the group name unique. + + Returns: + A GroupCoordinator for sequence parallelism that can be used with + the patch_sequence_parallel_group context manager. + + Example usage: + ```python + # Initialize sequence parallel group for model2 + sp_group_model2 = initialize_sequence_parallel_group( + sequence_model_parallel_size=2, + group_name_suffix="model2" + ) + + # Use sequence parallelism for model2 + with patch_sequence_parallel_group(sp_group_model2): + # Run model2 with sequence parallelism + output2 = model2(input2) + ``` + """ + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + backend = backend or torch.distributed.get_backend(get_world_group().device_group) + + # Ensure the world size is compatible with the parallelism configuration + assert ( + world_size % sequence_model_parallel_size == 0 + ), f"World size ({world_size}) must be divisible by sequence_model_parallel_size ({sequence_model_parallel_size})" + + # Build the sequence model-parallel groups. 
+ num_sequence_model_parallel_groups: int = world_size // sequence_model_parallel_size + sp_group_ranks = [] + + for i in range(num_sequence_model_parallel_groups): + # Create groups of consecutive ranks + ranks = list( + range( + i * sequence_model_parallel_size, (i + 1) * sequence_model_parallel_size + ) + ) + sp_group_ranks.append(ranks) + + # Create SP group coordinator with a unique name + group_name = f"sp_{group_name_suffix}" if group_name_suffix else "sp" + sp_group = init_parallel_group_coordinator( + sp_group_ranks, get_world_group().local_rank, backend, group_name=group_name + ) + + return sp_group + + +# * QUERY +def get_world_group() -> GroupCoordinator: + assert _WORLD is not None, "world group is not initialized" + return _WORLD + + +# TP +def get_tp_group() -> GroupCoordinator: + assert _TP is not None, "tensor model parallel group is not initialized" + return _TP + + +def get_tensor_model_parallel_world_size(): + """Return world size for the tensor model parallel group.""" + return get_tp_group().world_size + + +def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" + return get_tp_group().rank_in_group + + +def get_sequence_parallel_world_size(): + """Return world size for the sequence parallel group.""" + return get_sp_group().world_size + + +def get_sequence_parallel_rank(): + """Return my rank for the sequence parallel group.""" + return get_sp_group().rank_in_group + + +def get_ulysses_parallel_world_size(): + return get_sp_group().ulysses_world_size + + +def get_ulysses_parallel_rank(): + return get_sp_group().ulysses_rank + + +def get_ring_parallel_world_size(): + return get_sp_group().ring_world_size + + +def get_ring_parallel_rank(): + return get_sp_group().ring_rank + + +# PP +def get_pp_group() -> PipelineGroupCoordinator: + assert _PP is not None, "pipeline model parallel group is not initialized" + return _PP + + +def get_pipeline_parallel_world_size(): + """Return world size for the pipeline 
def get_dp_group() -> GroupCoordinator:
    """Return the data parallel group coordinator.

    Raises:
        AssertionError: if the data parallel group has not been initialized.
    """
    # The message must name the data parallel group so a failure here is not
    # mistaken for a missing pipeline group.
    assert _DP is not None, "data parallel group is not initialized"
    return _DP
def destroy_model_parallel() -> None:
    """Set the groups to none and destroy them.

    Tears down every group created by ``initialize_model_parallel``: the TP,
    SP, DP, CFG and PP coordinators plus the VAE and DiT process groups.
    Resetting all of them matters because ``initialize_model_parallel`` and
    ``init_vae_group`` assert that the corresponding globals are ``None``;
    leaving any of them set would make a later re-initialization fail.
    """
    global _TP
    if _TP:
        _TP.destroy()
    _TP = None

    global _SP
    if _SP:
        _SP.destroy()
    _SP = None

    global _DP
    if _DP:
        _DP.destroy()
    _DP = None

    # CFG and PP come from the same init_parallel_group_coordinator factory
    # as the groups above, so they are destroyed the same way.
    global _CFG
    if _CFG:
        _CFG.destroy()
    _CFG = None

    global _PP
    if _PP:
        _PP.destroy()
    _PP = None

    # _VAE and _DIT hold what torch.distributed.new_group returned, not
    # coordinators, so they are released through torch.distributed directly.
    # Skip that call once the default group is gone, since destroying it
    # already tears down the sub-groups.
    global _VAE
    if _VAE is not None and torch.distributed.is_initialized():
        torch.distributed.destroy_process_group(_VAE)
    _VAE = None

    # init_dit_group assigns _DIT without a visible module-level default
    # (NOTE(review): confirm), so read it defensively.
    global _DIT
    stale_dit_group = globals().get("_DIT")
    if stale_dit_group is not None and torch.distributed.is_initialized():
        torch.distributed.destroy_process_group(stale_dit_group)
    _DIT = None
def destroy_distributed_environment():
    """Destroy the world group, then the default torch process group.

    The world coordinator (if any) is destroyed and the ``_WORLD`` global is
    cleared; afterwards the default ``torch.distributed`` process group is
    destroyed when it is still initialized.
    """
    global _WORLD
    world_group = _WORLD
    if world_group:
        world_group.destroy()
    _WORLD = None

    if not torch.distributed.is_initialized():
        return
    torch.distributed.destroy_process_group()
+import dataclasses +import pickle +import time +from collections import deque +from collections.abc import Sequence +from typing import Any + +import torch +from torch.distributed import TCPStore + +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +def ensure_divisibility(numerator, denominator) -> None: + """Ensure that numerator is divisible by the denominator.""" + assert numerator % denominator == 0, "{} is not divisible by {}".format( + numerator, denominator + ) + + +def divide(numerator: int, denominator: int) -> int: + """Ensure that numerator is divisible by the denominator and return + the division value.""" + ensure_divisibility(numerator, denominator) + return numerator // denominator + + +def split_tensor_along_last_dim( + tensor: torch.Tensor, + num_partitions: int, + contiguous_split_chunks: bool = False, +) -> Sequence[torch.Tensor]: + """Split a tensor along its last dimension. + + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + + Returns: + A list of Tensors + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = divide(tensor.size()[last_dim], num_partitions) + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # NOTE: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tuple(tensor_list) + + +@dataclasses.dataclass +class StatelessProcessGroup: + """A dataclass to hold a metadata store, and the rank, world_size of the + group. Only use it to communicate metadata between processes. + For data-plane communication, create NCCL-related objects. 
+ """ + + rank: int + world_size: int + store: torch._C._distributed_c10d.Store + data_expiration_seconds: int = 3600 # 1 hour + + # dst rank -> counter + send_dst_counter: dict[int, int] = dataclasses.field(default_factory=dict) + # src rank -> counter + recv_src_counter: dict[int, int] = dataclasses.field(default_factory=dict) + broadcast_send_counter: int = 0 + broadcast_recv_src_counter: dict[int, int] = dataclasses.field(default_factory=dict) + + # A deque to store the data entries, with key and timestamp. + entries: deque[tuple[str, float]] = dataclasses.field(default_factory=deque) + + def __post_init__(self): + assert self.rank < self.world_size + self.send_dst_counter = {i: 0 for i in range(self.world_size)} + self.recv_src_counter = {i: 0 for i in range(self.world_size)} + self.broadcast_recv_src_counter = {i: 0 for i in range(self.world_size)} + + def send_obj(self, obj: Any, dst: int): + """Send an object to a destination rank.""" + self.expire_data() + key = f"send_to/{dst}/{self.send_dst_counter[dst]}" + self.store.set(key, pickle.dumps(obj)) + self.send_dst_counter[dst] += 1 + self.entries.append((key, time.perf_counter())) + + def expire_data(self) -> None: + """Expire data that is older than `data_expiration_seconds` seconds.""" + while self.entries: + # check the oldest entry + key, timestamp = self.entries[0] + if time.perf_counter() - timestamp > self.data_expiration_seconds: + self.store.delete_key(key) + self.entries.popleft() + else: + break + + def recv_obj(self, src: int) -> Any: + """Receive an object from a source rank.""" + obj = pickle.loads( + self.store.get(f"send_to/{self.rank}/{self.recv_src_counter[src]}") + ) + self.recv_src_counter[src] += 1 + return obj + + def broadcast_obj(self, obj: Any | None, src: int) -> Any: + """Broadcast an object from a source rank to all other ranks. + It does not clean up after all ranks have received the object. + Use it for limited times, e.g., for initialization. 
+ """ + if self.rank == src: + self.expire_data() + key = f"broadcast_from/{src}/" f"{self.broadcast_send_counter}" + self.store.set(key, pickle.dumps(obj)) + self.broadcast_send_counter += 1 + self.entries.append((key, time.perf_counter())) + return obj + else: + key = f"broadcast_from/{src}/" f"{self.broadcast_recv_src_counter[src]}" + recv_obj = pickle.loads(self.store.get(key)) + self.broadcast_recv_src_counter[src] += 1 + return recv_obj + + def all_gather_obj(self, obj: Any) -> list[Any]: + """All gather an object from all ranks.""" + gathered_objs = [] + for i in range(self.world_size): + if i == self.rank: + gathered_objs.append(obj) + self.broadcast_obj(obj, src=self.rank) + else: + recv_obj = self.broadcast_obj(None, src=i) + gathered_objs.append(recv_obj) + return gathered_objs + + def barrier(self): + """A barrier to synchronize all ranks.""" + for i in range(self.world_size): + if i == self.rank: + self.broadcast_obj(None, src=self.rank) + else: + self.broadcast_obj(None, src=i) + + @staticmethod + def create( + host: str, + port: int, + rank: int, + world_size: int, + data_expiration_seconds: int = 3600, + ) -> "StatelessProcessGroup": + """A replacement for `torch.distributed.init_process_group` that does not + pollute the global state. + + If we have process A and process B called `torch.distributed.init_process_group` + to form a group, and then we want to form another group with process A, B, C, + D, it is not possible in PyTorch, because process A and process B have already + formed a group, and process C and process D cannot join that group. This + function is a workaround for this issue. + + `torch.distributed.init_process_group` is a global call, while this function + is a stateless call. It will return a `StatelessProcessGroup` object that can be + used for exchanging metadata. 
With this function, process A and process B + can call `StatelessProcessGroup.create` to form a group, and then process A, B, + C, and D can call `StatelessProcessGroup.create` to form another group. + """ # noqa + store = TCPStore( + host_name=host, + port=port, + world_size=world_size, + is_master=(rank == 0), + ) + + return StatelessProcessGroup( + rank=rank, + world_size=world_size, + store=store, + data_expiration_seconds=data_expiration_seconds, + ) diff --git a/sglang/python/sglang/multimodal_gen/runtime/entrypoints/__init__.py b/sglang/python/sglang/multimodal_gen/runtime/entrypoints/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..71c2e4a938b0b0c3c0059e88938caf09405b7f24 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/entrypoints/__init__.py @@ -0,0 +1,4 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo +from sglang.multimodal_gen.runtime.utils.logging_utils import globally_suppress_loggers + +globally_suppress_loggers() diff --git a/sglang/python/sglang/multimodal_gen/runtime/entrypoints/cli/__init__.py b/sglang/python/sglang/multimodal_gen/runtime/entrypoints/cli/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..af2eb7d103a81378d30056cf536353bd24621a5e --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/entrypoints/cli/__init__.py @@ -0,0 +1 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo diff --git a/sglang/python/sglang/multimodal_gen/runtime/entrypoints/cli/cli_types.py b/sglang/python/sglang/multimodal_gen/runtime/entrypoints/cli/cli_types.py new file mode 100644 index 0000000000000000000000000000000000000000..16b9dd44b6ea22ef1b3076b115a34685f3ee23f5 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/entrypoints/cli/cli_types.py @@ -0,0 +1,30 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# adapted from vllm: 
https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/entrypoints/cli/types.py + +import argparse + +from sglang.multimodal_gen.utils import FlexibleArgumentParser + + +class CLISubcommand: + """Base class for CLI subcommands""" + + name: str + + def cmd( + self, args: argparse.Namespace, unknown_args: list[str] | None = None + ) -> None: + """Execute the command with the given arguments""" + raise NotImplementedError + + def validate(self, args: argparse.Namespace) -> None: + """Validate the arguments for this command""" + pass + + def subparser_init( + self, subparsers: argparse._SubParsersAction + ) -> FlexibleArgumentParser: + """Initialize the subparser for this command""" + raise NotImplementedError diff --git a/sglang/python/sglang/multimodal_gen/runtime/entrypoints/cli/generate.py b/sglang/python/sglang/multimodal_gen/runtime/entrypoints/cli/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..a38a9cfc76baa598ccd270d979384a18463c7fec --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/entrypoints/cli/generate.py @@ -0,0 +1,185 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/entrypoints/cli/serve.py + +import argparse +import dataclasses +import json +import os +from typing import cast + +from sglang.multimodal_gen import DiffGenerator +from sglang.multimodal_gen.configs.sample.sampling_params import ( + SamplingParams, + generate_request_id, +) +from sglang.multimodal_gen.runtime.entrypoints.cli.cli_types import CLISubcommand +from sglang.multimodal_gen.runtime.entrypoints.cli.utils import ( + RaiseNotImplementedAction, +) +from sglang.multimodal_gen.runtime.entrypoints.utils import GenerationResult +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from 
sglang.multimodal_gen.runtime.utils.perf_logger import ( + MemorySnapshot, + PerformanceLogger, + RequestMetrics, +) +from sglang.multimodal_gen.utils import FlexibleArgumentParser + +logger = init_logger(__name__) + + +def add_multimodal_gen_generate_args(parser: argparse.ArgumentParser): + """Add the arguments for the generate command.""" + parser.add_argument( + "--config", + type=str, + default="", + required=False, + help="Read CLI options from a config JSON or YAML file. If provided, --model-path and --prompt are optional.", + ) + parser.add_argument( + "--perf-dump-path", + type=str, + default=None, + required=False, + help="Path to dump the performance metrics (JSON) for the run.", + ) + + parser = ServerArgs.add_cli_args(parser) + parser = SamplingParams.add_cli_args(parser) + + parser.add_argument( + "--text-encoder-configs", + action=RaiseNotImplementedAction, + help="JSON array of text encoder configurations (NOT YET IMPLEMENTED)", + ) + + return parser + + +def maybe_dump_performance( + args: argparse.Namespace, + server_args, + prompt: str, + results: GenerationResult | list[GenerationResult] | None, +): + """dump performance if necessary""" + if not (args.perf_dump_path and results): + return + + if isinstance(results, list): + result = results[0] if results else None + else: + result = results + + metrics_dict = result.metrics + if not (args.perf_dump_path and metrics_dict): + return + + metrics = RequestMetrics(request_id=metrics_dict.get("request_id")) + metrics.stages = metrics_dict.get("stages", {}) + metrics.steps = metrics_dict.get("steps", []) + metrics.total_duration_ms = metrics_dict.get("total_duration_ms", 0) + + # restore memory snapshots from serialized dict + memory_snapshots_dict = metrics_dict.get("memory_snapshots", {}) + for checkpoint_name, snapshot_dict in memory_snapshots_dict.items(): + snapshot = MemorySnapshot( + allocated_mb=snapshot_dict.get("allocated_mb", 0.0), + reserved_mb=snapshot_dict.get("reserved_mb", 0.0), + 
peak_allocated_mb=snapshot_dict.get("peak_allocated_mb", 0.0), + peak_reserved_mb=snapshot_dict.get("peak_reserved_mb", 0.0), + ) + metrics.memory_snapshots[checkpoint_name] = snapshot + + PerformanceLogger.dump_benchmark_report( + file_path=args.perf_dump_path, + metrics=metrics, + meta={ + "prompt": prompt, + "model": server_args.model_path, + }, + tag="cli_generate", + ) + + +def generate_cmd(args: argparse.Namespace, unknown_args: list[str] | None = None): + """The entry point for the generate command.""" + args.request_id = "mocked_fake_id_for_offline_generate" + + server_args = ServerArgs.from_cli_args(args, unknown_args) + + sampling_params_kwargs = SamplingParams.get_cli_args(args) + sampling_params_kwargs["request_id"] = generate_request_id() + + # Handle diffusers-specific kwargs passed via CLI + if hasattr(args, "diffusers_kwargs") and args.diffusers_kwargs: + try: + sampling_params_kwargs["diffusers_kwargs"] = json.loads( + args.diffusers_kwargs + ) + logger.info( + "Parsed diffusers_kwargs: %s", + sampling_params_kwargs["diffusers_kwargs"], + ) + except json.JSONDecodeError as e: + logger.error("Failed to parse --diffusers-kwargs as JSON: %s", e) + raise ValueError( + f"--diffusers-kwargs must be valid JSON. 
Got: {args.diffusers_kwargs}" + ) from e + + generator = DiffGenerator.from_pretrained( + model_path=server_args.model_path, server_args=server_args, local_mode=True + ) + + results = generator.generate(sampling_params_kwargs=sampling_params_kwargs) + + prompt = sampling_params_kwargs.get("prompt") + maybe_dump_performance(args, server_args, prompt, results) + + +class GenerateSubcommand(CLISubcommand): + """The `generate` subcommand for the sglang-diffusion CLI""" + + def __init__(self) -> None: + self.name = "generate" + super().__init__() + self.init_arg_names = self._get_init_arg_names() + self.generation_arg_names = self._get_generation_arg_names() + + def _get_init_arg_names(self) -> list[str]: + """Get names of arguments for DiffGenerator initialization""" + return ["num_gpus", "tp_size", "sp_size", "model_path"] + + def _get_generation_arg_names(self) -> list[str]: + """Get the field names of the SamplingParams dataclass (the per-generation arguments)""" + return [field.name for field in dataclasses.fields(SamplingParams)] + + def cmd( + self, args: argparse.Namespace, unknown_args: list[str] | None = None + ) -> None: + generate_cmd(args, unknown_args) + + def validate(self, args: argparse.Namespace) -> None: + """Raise ValueError for a non-positive num_gpus or a config path that does not exist""" + if args.num_gpus is not None and args.num_gpus <= 0: + raise ValueError("Number of gpus must be positive") + + if args.config and not os.path.exists(args.config): + raise ValueError(f"Config file not found: {args.config}") + + def subparser_init( + self, subparsers: argparse._SubParsersAction + ) -> FlexibleArgumentParser: + generate_parser = subparsers.add_parser( + "generate", + help="Run inference on a model", + usage="sglang generate (--model-path MODEL_PATH_OR_ID --prompt PROMPT) | --config CONFIG_FILE [OPTIONS]", + ) + + generate_parser = add_multimodal_gen_generate_args(generate_parser) + + return cast(FlexibleArgumentParser, generate_parser) diff --git a/sglang/python/sglang/multimodal_gen/runtime/entrypoints/cli/main.py
b/sglang/python/sglang/multimodal_gen/runtime/entrypoints/cli/main.py new file mode 100644 index 0000000000000000000000000000000000000000..26b6a4f6707f1699770941b894defc71855a3e18 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/entrypoints/cli/main.py @@ -0,0 +1,44 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/entrypoints/cli/main.py + +from sglang.multimodal_gen.runtime.entrypoints.cli.cli_types import CLISubcommand +from sglang.multimodal_gen.runtime.entrypoints.cli.generate import GenerateSubcommand +from sglang.multimodal_gen.runtime.entrypoints.cli.serve import ServeSubcommand +from sglang.multimodal_gen.utils import FlexibleArgumentParser + + +def generate_cmd_init() -> list[CLISubcommand]: + return [GenerateSubcommand(), ServeSubcommand()] + + +def cmd_init() -> list[CLISubcommand]: + """Initialize all commands from separate modules""" + commands = [] + commands.extend(generate_cmd_init()) + return commands + + +def main() -> None: + parser = FlexibleArgumentParser(description="sglang-diffusion CLI") + parser.add_argument("-v", "--version", action="version", version="0.1.0") + + subparsers = parser.add_subparsers(required=False, dest="subparser") + + cmds = {} + for cmd in cmd_init(): + cmd.subparser_init(subparsers).set_defaults(dispatch_function=cmd.cmd) + cmds[cmd.name] = cmd + args, unknown_args = parser.parse_known_args() + if args.subparser in cmds: + cmds[args.subparser].validate(args) + + if hasattr(args, "dispatch_function"): + args.dispatch_function(args, unknown_args=unknown_args) + else: + parser.print_help() + + +if __name__ == "__main__": + main() diff --git a/sglang/python/sglang/multimodal_gen/runtime/entrypoints/cli/serve.py b/sglang/python/sglang/multimodal_gen/runtime/entrypoints/cli/serve.py new file mode 100644 index 
0000000000000000000000000000000000000000..baa908c774f4df445868dd03d329d5fefbd2b27a --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/entrypoints/cli/serve.py @@ -0,0 +1,72 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import os +from typing import cast + +from sglang.multimodal_gen.apps.webui import run_sgl_diffusion_webui +from sglang.multimodal_gen.runtime.entrypoints.cli.cli_types import CLISubcommand +from sglang.multimodal_gen.runtime.launch_server import launch_server +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.utils import FlexibleArgumentParser + +logger = init_logger(__name__) + + +def add_multimodal_gen_serve_args(parser: argparse.ArgumentParser): + """Add the arguments for the serve command.""" + parser.add_argument( + "--config", + type=str, + default="", + required=False, + help="Read CLI options from a config JSON or YAML file.", + ) + return ServerArgs.add_cli_args(parser) + + +def execute_serve_cmd(args: argparse.Namespace, unknown_args: list[str] | None = None): + """The entry point for the serve command.""" + server_args = ServerArgs.from_cli_args(args, unknown_args) + launch_server(server_args) + + if server_args.webui: + run_sgl_diffusion_webui(server_args) + + +class ServeSubcommand(CLISubcommand): + """The `serve` subcommand for the sglang-diffusion CLI""" + + def __init__(self) -> None: + self.name = "serve" + super().__init__() + + def cmd( + self, args: argparse.Namespace, unknown_args: list[str] | None = None + ) -> None: + execute_serve_cmd(args, unknown_args) + + def validate(self, args: argparse.Namespace) -> None: + """Validate the arguments for this command""" + if args.config and not os.path.exists(args.config): + raise ValueError(f"Config file not found: {args.config}") + + def subparser_init( + self, 
subparsers: argparse._SubParsersAction + ) -> FlexibleArgumentParser: + serve_parser = subparsers.add_parser( + "serve", + help="Launch the server and start FastAPI listener.", + usage="sglang serve --model-path MODEL_PATH_OR_ID [OPTIONS]", + ) + + serve_parser = add_multimodal_gen_serve_args(serve_parser) + + return cast(FlexibleArgumentParser, serve_parser) + + +def cmd_init() -> list[CLISubcommand]: + return [ServeSubcommand()] diff --git a/sglang/python/sglang/multimodal_gen/runtime/entrypoints/cli/utils.py b/sglang/python/sglang/multimodal_gen/runtime/entrypoints/cli/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1699fca947a3336621b0188357c966442abe98b6 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/entrypoints/cli/utils.py @@ -0,0 +1,75 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import os +import shlex +import subprocess +import sys + +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class RaiseNotImplementedAction(argparse.Action): + + def __call__(self, parser, namespace, values, option_string=None): + raise NotImplementedError(f"The {option_string} option is not yet implemented") + + +def launch_distributed( + num_gpus: int, args: list[str], master_port: int | None = None +) -> int: + """ + Launch a distributed job with the given arguments + + Args: + num_gpus: Number of GPUs to use + args: Arguments to pass to v1_sgl_diffusion_inference.py (defaults to sys.argv[1:]) + master_port: Port for the master process (default: random) + """ + + current_env = os.environ.copy() + python_executable = sys.executable + project_root = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../../../..") + ) + main_script = os.path.join( + project_root, "sgl_diffusion/sample/v1_sgl_diffusion_inference.py" + ) + + cmd = [ + python_executable, + "-m", + 
"torch.distributed.run", + f"--nproc_per_node={num_gpus}", + ] + + if master_port is not None: + cmd.append(f"--master_port={master_port}") + + cmd.append(main_script) + cmd.extend(args) + + logger.info("Running inference with %d GPU(s)", num_gpus) + logger.info("Launching command: %s", shlex.join(cmd)) + + current_env["PYTHONIOENCODING"] = "utf-8" + process = subprocess.Popen( + cmd, + env=current_env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + universal_newlines=True, + bufsize=1, + encoding="utf-8", + errors="replace", + ) + + if process.stdout: + for line in iter(process.stdout.readline, ""): + print(line.strip()) + + return process.wait() diff --git a/sglang/python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py b/sglang/python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..162df5f89d8d6f2f56b175ae99ef616348329d1f --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py @@ -0,0 +1,533 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +""" +DiffGenerator module for sglang-diffusion. + +This module provides a consolidated interface for generating images/videos using +diffusion models. 
+""" + +import dataclasses +import multiprocessing as mp +import os +import time +from typing import Any, List, Union + +from sglang.multimodal_gen.configs.sample.sampling_params import ( + DataType, + SamplingParams, +) +from sglang.multimodal_gen.runtime.entrypoints.utils import ( + GenerationResult, + ListLorasReq, + MergeLoraWeightsReq, + SetLoraReq, + ShutdownReq, + UnmergeLoraWeightsReq, + format_lora_message, + prepare_request, + save_outputs, +) +from sglang.multimodal_gen.runtime.launch_server import launch_server +from sglang.multimodal_gen.runtime.pipelines_core import Req +from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch +from sglang.multimodal_gen.runtime.scheduler_client import sync_scheduler_client +from sglang.multimodal_gen.runtime.server_args import PortArgs, ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import ( + GREEN, + RESET, + init_logger, + log_batch_completion, + log_generation_timer, +) + +logger = init_logger(__name__) + +# TODO: move to somewhere appropriate +try: + # Set the start method to 'spawn' to avoid CUDA errors in forked processes. + # This must be done at the top level of the module, before any CUDA context + # or other processes are initialized. + mp.set_start_method("spawn", force=True) +except RuntimeError: + # The start method can only be set once per program execution. + pass + + +class DiffGenerator: + """ + A unified class for generating images/videos using diffusion models. + + This class provides a simple interface for image/video generation with rich + customization options, similar to popular frameworks like HF Diffusers. + """ + + def __init__( + self, + server_args: ServerArgs, + ): + """ + Initialize the generator. 
+ + Args: + server_args: The inference arguments + """ + self.server_args = server_args + self.port_args = PortArgs.from_server_args(server_args) + + # The executor is now a client to the Scheduler service + self.local_scheduler_process: list[mp.Process] | None = None + self.owns_scheduler_client: bool = False + + @classmethod + def from_pretrained( + cls, + local_mode: bool = True, + **kwargs, + ) -> "DiffGenerator": + """ + Create a DiffGenerator from a pretrained model. + + Priority level: Default pipeline config < User's pipeline config < User's kwargs + """ + # If users also provide some kwargs, it will override the ServerArgs and PipelineConfig. + + if (server_args := kwargs.get("server_args", None)) is not None: + if isinstance(server_args, ServerArgs): + pass + elif isinstance(server_args, dict): + server_args = ServerArgs.from_kwargs(**server_args) + else: + server_args = ServerArgs.from_kwargs(**kwargs) + + return cls.from_server_args(server_args, local_mode=local_mode) + + @classmethod + def from_server_args( + cls, server_args: ServerArgs, local_mode: bool = True + ) -> "DiffGenerator": + """ + Create a DiffGenerator; local_mode=True launches a local server, False connects to a remote scheduler. + + Args: + server_args: The inference arguments + + Returns: + The created DiffGenerator + """ + instance = cls( + server_args=server_args, + ) + logger.info(f"Local mode: {local_mode}") + if local_mode: + instance.local_scheduler_process = instance._start_local_server_if_needed() + else: + # In remote mode, we just need to connect and check. + sync_scheduler_client.initialize(server_args) + instance._check_remote_scheduler() + + # In both modes, this DiffGenerator instance is responsible for the client's lifecycle. + instance.owns_scheduler_client = True + return instance + + def _start_local_server_if_needed( + self, + ) -> list[mp.Process]: + """Initialize the scheduler client, launch the local server processes and return their handles.""" + # The client is initialized first; launch_server is then called unconditionally (no "already running" check here).
+ sync_scheduler_client.initialize(self.server_args) + + processes = launch_server(self.server_args, launch_http_server=False) + + return processes + + def _check_remote_scheduler(self): + """Check if the remote scheduler is accessible.""" + if not sync_scheduler_client.ping(): + raise ConnectionError( + f"Could not connect to remote scheduler at " + f"{self.server_args.scheduler_endpoint} with `local mode` as False. " + "Please ensure the server is running." + ) + logger.info( + f"Successfully connected to remote scheduler at " + f"{self.server_args.scheduler_endpoint}." + ) + + def generate( + self, + sampling_params_kwargs: dict | None = None, + ) -> GenerationResult | list[GenerationResult] | None: + """Generate image(s)/video(s) based on the given prompt(s). + + Returns a single GenerationResult for a single prompt, a list for + multiple prompts, or None when every request failed. + """ + # 1. prepare requests + prompts = self._resolve_prompts(sampling_params_kwargs.get("prompt")) + user_output_file_name = sampling_params_kwargs.get("output_file_name") + + if len(prompts) > 1 and user_output_file_name is not None: + raise ValueError( + "Cannot use multiple prompts with a fixed output_file_name. " + "Either remove --output-file-name or use a single prompt." + ) + + sampling_params_orig = SamplingParams.from_user_sampling_params_args( + self.server_args.model_path, + server_args=self.server_args, + **sampling_params_kwargs, + ) + + requests: list[Req] = [] + for p in prompts: + sampling_params = dataclasses.replace( + sampling_params_orig, + prompt=p, + output_file_name=user_output_file_name, + ) + sampling_params._set_output_file_name() + req = prepare_request( + server_args=self.server_args, + sampling_params=sampling_params, + ) + requests.append(req) + + results: list[GenerationResult] = [] + total_start_time = time.perf_counter() + + # 2. 
send requests to scheduler one at a time + # TODO: send batch when supported + for request_idx, req in enumerate(requests): + try: + with log_generation_timer( + logger, req.prompt, request_idx + 1, len(requests) + ) as timer: + output_batch = self._send_to_scheduler_and_wait_for_response([req]) + if output_batch.error: + raise Exception(f"{output_batch.error}") + + if ( + output_batch.output is None + and output_batch.output_file_paths is None + ): + logger.error( + "Received empty output from scheduler for prompt %d", + request_idx + 1, + ) + continue + + common = dict( + prompt=req.prompt, + size=(req.height, req.width, req.num_frames), + generation_time=timer.duration, + peak_memory_mb=output_batch.peak_memory_mb, + metrics=( + output_batch.metrics.to_dict() + if output_batch.metrics + else {} + ), + trajectory_latents=output_batch.trajectory_latents, + trajectory_timesteps=output_batch.trajectory_timesteps, + trajectory_decoded=output_batch.trajectory_decoded, + ) + + if req.save_output and req.return_file_paths_only: + for idx, path in enumerate(output_batch.output_file_paths): + results.append( + GenerationResult( + **common, + prompt_index=idx, + output_file_path=path, + ) + ) + continue + + if req.data_type == DataType.MESH: + for output_idx, sample in enumerate( + output_batch.output_file_paths + ): + results.append( + GenerationResult( + **common, + prompt_index=output_idx, + output_file_path=sample, + ) + ) + continue + + samples_out: list[Any] = [] + audios_out: list[Any] = [] + frames_out: list[Any] = [] + num_outputs = len(output_batch.output) + save_outputs( + output_batch.output, + req.data_type, + req.fps, + req.save_output, + lambda idx: req.output_file_path(num_outputs, idx), + audio=output_batch.audio, + audio_sample_rate=output_batch.audio_sample_rate, + samples_out=samples_out, + audios_out=audios_out, + frames_out=frames_out, + output_compression=req.output_compression, + enable_frame_interpolation=req.enable_frame_interpolation, + 
frame_interpolation_exp=req.frame_interpolation_exp, + frame_interpolation_scale=req.frame_interpolation_scale, + frame_interpolation_model_path=req.frame_interpolation_model_path, + ) + + for idx in range(len(samples_out)): + results.append( + GenerationResult( + **common, + samples=samples_out[idx], + frames=frames_out[idx], + audio=audios_out[idx], + prompt_index=idx, + output_file_path=req.output_file_path(num_outputs, idx), + ) + ) + except Exception as e: + logger.error( + "Generation failed for prompt %d/%d: %s", + request_idx + 1, + len(requests), + e, + exc_info=True, + ) + continue + + total_gen_time = time.perf_counter() - total_start_time + log_batch_completion(logger, len(results), total_gen_time) + self._log_summary(results) + + if not results: + return None + return results[0] if len(results) == 1 else results + + def _resolve_prompts(self, prompt: str | list[str] | None) -> list[str]: + """Collect prompts from the argument or from a prompt file.""" + if self.server_args.prompt_file_path is not None: + path = self.server_args.prompt_file_path + if not os.path.exists(path): + raise FileNotFoundError(f"Prompt text file not found: {path}") + with open(path, encoding="utf-8") as f: + prompts = [line.strip() for line in f if line.strip()] + if not prompts: + raise ValueError(f"No prompts found in file: {path}") + logger.info("Found %d prompts in %s", len(prompts), path) + return prompts + + if prompt is None: + return [" "] + if isinstance(prompt, str): + return [prompt] + return list(prompt) + + def _log_summary(self, results: list[GenerationResult]) -> None: + if not results: + return + if self.server_args.warmup: + total_duration_ms = results[0].metrics.get("total_duration_ms", 0) + logger.info( + f"Warmed-up request processed in {GREEN}%.2f{RESET} seconds (with warmup excluded)", + total_duration_ms / 1000.0, + ) + + peak_memories = [r.peak_memory_mb for r in results if r.peak_memory_mb] + if peak_memories: + logger.info( + f"Memory usage - Max peak: 
{max(peak_memories):.2f} MB, " + f"Avg peak: {sum(peak_memories) / len(peak_memories):.2f} MB" + ) + + def _send_to_scheduler_and_wait_for_response(self, batch: list[Req]) -> OutputBatch: + """ + Sends a request to the scheduler and waits for a response. + """ + return sync_scheduler_client.forward(batch) + + # LoRA + def _send_lora_request(self, req: Any, success_msg: str, failure_msg: str): + response = sync_scheduler_client.forward(req) + if response.error is None: + logger.info(success_msg) + return response + else: + error_msg = response.error + raise RuntimeError(f"{failure_msg}: {error_msg}") + + def set_lora( + self, + lora_nickname: Union[str, List[str]], + lora_path: Union[str, None, List[Union[str, None]]] = None, + target: Union[str, List[str]] = "all", + strength: Union[float, List[float]] = 1.0, + ) -> None: + """ + Set LoRA adapter(s) for the specified transformer(s). + Supports both single LoRA (backward compatible) and multiple LoRA adapters. + + Args: + lora_nickname: The nickname(s) of the adapter(s). Can be a string or a list of strings. + lora_path: Path(s) to the LoRA adapter(s). Can be a string, None, or a list of strings/None. + target: Which transformer(s) to apply the LoRA to. Can be a string or a list of strings. + Valid values: + - "all": Apply to all transformers (default) + - "transformer": Apply only to the primary transformer (high noise for Wan2.2) + - "transformer_2": Apply only to transformer_2 (low noise for Wan2.2) + - "critic": Apply only to the critic model + strength: LoRA strength(s) for merge, default 1.0. Can be a float or a list of floats. 
+ """ + req = SetLoraReq( + lora_nickname=lora_nickname, + lora_path=lora_path, + target=target, + strength=strength, + ) + nickname_str, target_str, strength_str = format_lora_message( + lora_nickname, target, strength + ) + + self._send_lora_request( + req, + f"Successfully set LoRA adapter(s): {nickname_str} (target: {target_str}, strength: {strength_str})", + "Failed to set LoRA adapter", + ) + + def unmerge_lora_weights(self, target: str = "all") -> None: + """ + Unmerge LoRA weights from the base model. + + Args: + target: Which transformer(s) to unmerge. + """ + req = UnmergeLoraWeightsReq(target=target) + self._send_lora_request( + req, + f"Successfully unmerged LoRA weights (target: {target})", + "Failed to unmerge LoRA weights", + ) + + def merge_lora_weights(self, target: str = "all", strength: float = 1.0) -> None: + """ + Merge LoRA weights into the base model. + + Args: + target: Which transformer(s) to merge. + strength: LoRA strength for merge, default 1.0. + """ + req = MergeLoraWeightsReq(target=target, strength=strength) + self._send_lora_request( + req, + f"Successfully merged LoRA weights (target: {target}, strength: {strength})", + "Failed to merge LoRA weights", + ) + + def list_loras(self) -> dict: + """List loaded LoRA adapters and current application status per module.""" + output = self._send_lora_request( + req=ListLorasReq(), + success_msg="Successfully listed LoRA adapters", + failure_msg="Failed to list LoRA adapters", + ) + # _send_lora_request already raises on error, so output.error is always None here + return output.output or {} + + def _ensure_lora_state( + self, + lora_path: str | None, + lora_nickname: str | None = None, + merge_lora: bool = True, + ) -> None: + """ + Ensure the LoRA state matches the desired configuration. + + Note: This method does not cache client-side state. The server handles + idempotent operations, so redundant calls are safe but may have minor overhead. 
+ """ + if lora_path is None: + # Unmerge all LoRA weights when no lora_path is provided + self.unmerge_lora_weights() + return + + lora_nickname = lora_nickname or self.server_args.lora_nickname + + # Set the LoRA adapter (server handles idempotent logic) + self.set_lora(lora_nickname, lora_path) + + # Merge or unmerge based on the merge_lora flag + if merge_lora: + self.merge_lora_weights() + else: + self.unmerge_lora_weights() + + def generate_with_lora( + self, + prompt: str | list[str] | None = None, + sampling_params: SamplingParams | None = None, + *, + lora_path: str | None = None, + lora_nickname: str | None = None, + merge_lora: bool = True, + **kwargs, + ): + self._ensure_lora_state( + lora_path=lora_path, lora_nickname=lora_nickname, merge_lora=merge_lora + ) + return self.generate( + sampling_params_kwargs=dict( + prompt=prompt, + sampling_params=sampling_params, + **kwargs, + ) + ) + + def shutdown(self): + """ + Shutdown the generator. + If in local mode, it also shuts down the scheduler server. + """ + # sends the shutdown command to the server + if self.local_scheduler_process and self.owns_scheduler_client: + try: + sync_scheduler_client.forward(ShutdownReq()) + except Exception: + pass + + if self.local_scheduler_process: + for process in self.local_scheduler_process: + process.join(timeout=10) + if process.is_alive(): + logger.warning( + f"Local worker {process.name} did not terminate gracefully, forcing." + ) + process.terminate() + self.local_scheduler_process = None + + if self.owns_scheduler_client: + sync_scheduler_client.close() + self.owns_scheduler_client = False + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.shutdown() + + def __del__(self): + if self.owns_scheduler_client: + logger.warning( + "Generator was garbage collected without being shut down. " + "Attempting to shut down the local server and client." 
+ ) + self.shutdown() + elif self.local_scheduler_process: + logger.warning( + "Generator was garbage collected without being shut down. " + "Attempting to shut down the local server." + ) + self.shutdown() diff --git a/sglang/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py b/sglang/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py new file mode 100644 index 0000000000000000000000000000000000000000..4830411bfd3ac67e9deeb7b1de07b6a0c7432976 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py @@ -0,0 +1,284 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +import asyncio +import base64 +import os +import uuid +from contextlib import asynccontextmanager +from typing import TYPE_CHECKING + +import torch +from fastapi import APIRouter, FastAPI, Request +from fastapi.responses import ORJSONResponse + +from sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams +from sglang.multimodal_gen.runtime.entrypoints.openai import image_api, video_api +from sglang.multimodal_gen.runtime.entrypoints.openai.protocol import ( + VertexGenerateReqInput, +) +from sglang.multimodal_gen.runtime.entrypoints.openai.utils import build_sampling_params +from sglang.multimodal_gen.runtime.entrypoints.post_training import weights_api +from sglang.multimodal_gen.runtime.entrypoints.utils import ( + prepare_request, + save_outputs, +) +from sglang.multimodal_gen.runtime.scheduler_client import async_scheduler_client +from sglang.multimodal_gen.runtime.server_args import ServerArgs, get_global_server_args +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.version import __version__ + +if TYPE_CHECKING: + from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req + +logger = init_logger(__name__) + +DEFAULT_SEED = 1024 +VERTEX_ROUTE = os.environ.get("AIP_PREDICT_ROUTE", "/vertex_generate") + + +@asynccontextmanager +async def 
lifespan(app: FastAPI): + from sglang.multimodal_gen.runtime.scheduler_client import ( + async_scheduler_client, + run_zeromq_broker, + ) + + # 1. Initialize the singleton client that connects to the backend Scheduler + server_args = app.state.server_args + async_scheduler_client.initialize(server_args) + + # 2. Start the ZMQ Broker in the background to handle offline requests + broker_task = asyncio.create_task(run_zeromq_broker(server_args)) + + yield + + # On shutdown + logger.info("FastAPI app is shutting down...") + broker_task.cancel() + async_scheduler_client.close() + + +# Health router +health_router = APIRouter() + + +@health_router.get("/health") +async def health(): + return {"status": "ok"} + + +@health_router.get("/models", deprecated=True) +async def get_models(request: Request): + """ + Get information about the model served by this server. + + .. deprecated:: + Use /v1/models instead for OpenAI-compatible model discovery. + This endpoint will be removed in a future version. + """ + from sglang.multimodal_gen.registry import get_model_info + + server_args: ServerArgs = request.app.state.server_args + model_info = get_model_info(server_args.model_path, model_id=server_args.model_id) + + response = { + "model_path": server_args.model_path, + "num_gpus": server_args.num_gpus, + "task_type": server_args.pipeline_config.task_type.name, + "dit_precision": server_args.pipeline_config.dit_precision, + "vae_precision": server_args.pipeline_config.vae_precision, + } + + if model_info: + response["pipeline_name"] = model_info.pipeline_cls.pipeline_name + response["pipeline_class"] = model_info.pipeline_cls.__name__ + + return response + + +@health_router.get("/server_info") +async def server_info_endpoint(request: Request): + """Get server information. + + Returns fields compatible with the LLM engine's /server_info so that + the model gateway can discover diffusion workers. 
+ """ + server_args: ServerArgs = request.app.state.server_args + + return { + "model_path": server_args.model_path, + "served_model_name": server_args.model_id or server_args.model_path, + "tp_size": server_args.tp_size, + "dp_size": server_args.dp_size, + "version": __version__, + } + + +@health_router.get("/model_info") +async def model_info_endpoint(request: Request): + """Get model information. + + Returns fields compatible with the LLM engine's /model_info so that + the model gateway can detect capabilities for diffusion workers. + """ + from sglang.multimodal_gen.registry import get_model_info + + server_args: ServerArgs = request.app.state.server_args + task_type = server_args.pipeline_config.task_type + + try: + registry_info = get_model_info( + server_args.model_path, + backend=server_args.backend, + model_id=server_args.model_id, + ) + except Exception: + logger.warning("Failed to resolve model info from registry", exc_info=True) + registry_info = None + + return { + # Fields consumed by the model gateway for worker discovery + "model_path": server_args.model_path, + "is_generation": True, + "model_type": "diffusion", + "architectures": ( + [registry_info.pipeline_cls.__name__] if registry_info else None + ), + # Fields matching the LLM engine's /model_info shape + "has_image_understanding": task_type.accepts_image_input(), + "has_audio_understanding": False, + # Diffusion-specific fields + "task_type": task_type.name, + "is_image_gen": task_type.is_image_gen(), + } + + +@health_router.get("/health_generate") +async def health_generate(): + # TODO : health generate endpoint + return {"status": "ok"} + + +def make_serializable(obj): + """Recursively converts Tensors to None for JSON serialization.""" + if isinstance(obj, torch.Tensor): + return None + if isinstance(obj, dict): + return {k: make_serializable(v) for k, v in obj.items()} + if isinstance(obj, list): + return [make_serializable(v) for v in obj] + return obj + + +def 
encode_video_to_base64(file_path: str): + if not os.path.exists(file_path): + return None + with open(file_path, "rb") as f: + return base64.b64encode(f.read()).decode("utf-8") + + +async def forward_to_scheduler( + req_obj: "Req", + sp: SamplingParams, +): + """Forwards request to scheduler and processes the result.""" + try: + response = await async_scheduler_client.forward(req_obj) + if response.output is None and response.output_file_paths is None: + raise RuntimeError("Model generation returned no output.") + + if response.output_file_paths: + output_file_path = response.output_file_paths[0] + else: + output_file_path = sp.output_file_path() + save_outputs( + [response.output[0]], + sp.data_type, + sp.fps, + True, + lambda _idx: output_file_path, + audio=response.audio, + audio_sample_rate=response.audio_sample_rate, + enable_frame_interpolation=sp.enable_frame_interpolation, + frame_interpolation_exp=sp.frame_interpolation_exp, + frame_interpolation_scale=sp.frame_interpolation_scale, + frame_interpolation_model_path=sp.frame_interpolation_model_path, + ) + + if hasattr(response, "model_dump"): + data = response.model_dump() + else: + data = response if isinstance(response, dict) else vars(response) + + if output_file_path: + logger.info("Processing output file: %s", output_file_path) + b64_video = encode_video_to_base64(output_file_path) + + if b64_video: + data["output"] = b64_video + data.pop("video_data", None) + data.pop("video_tensor", None) + + return make_serializable(data) + + except Exception as e: + logger.error("Error during generation: %s", e, exc_info=True) + return {"error": str(e)} + + +vertex_router = APIRouter() + + +@vertex_router.post(VERTEX_ROUTE) +async def vertex_generate(vertex_req: VertexGenerateReqInput): + if not vertex_req.instances: + return ORJSONResponse({"predictions": []}) + + server_args = get_global_server_args() + params = vertex_req.parameters or {} + + futures = [] + + for inst in vertex_req.instances: + rid = 
f"vertex_{uuid.uuid4()}" + + sp = build_sampling_params( + rid, + prompt=inst.get("prompt") or inst.get("text"), + image_path=inst.get("image") or inst.get("image_url"), + seed=params.get("seed", DEFAULT_SEED), + num_frames=params.get("num_frames"), + fps=params.get("fps"), + width=params.get("width"), + height=params.get("height"), + guidance_scale=params.get("guidance_scale"), + save_output=params.get("save_output"), + ) + + backend_req = prepare_request(server_args, sampling_params=sp) + futures.append(forward_to_scheduler(backend_req, sp)) + + results = await asyncio.gather(*futures) + + return ORJSONResponse({"predictions": results}) + + +def create_app(server_args: ServerArgs): + """ + Create and configure the FastAPI application instance. + """ + app = FastAPI(lifespan=lifespan) + + app.include_router(health_router) + app.include_router(vertex_router) + + from sglang.multimodal_gen.runtime.entrypoints.openai import common_api, mesh_api + + app.include_router(common_api.router) + app.include_router(image_api.router) + app.include_router(video_api.router) + app.include_router(mesh_api.router) + app.include_router(weights_api.router) + + app.state.server_args = server_args + return app diff --git a/sglang/python/sglang/multimodal_gen/runtime/entrypoints/openai/common_api.py b/sglang/python/sglang/multimodal_gen/runtime/entrypoints/openai/common_api.py new file mode 100644 index 0000000000000000000000000000000000000000..921f64410a62567c5a47d101383cf9189bf74ac3 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/entrypoints/openai/common_api.py @@ -0,0 +1,249 @@ +import time +from typing import Any, List, Optional, Union + +from fastapi import APIRouter, Body, HTTPException +from fastapi.responses import ORJSONResponse +from pydantic import BaseModel, Field + +from sglang.multimodal_gen.registry import get_model_info +from sglang.multimodal_gen.runtime.entrypoints.utils import ( + ListLorasReq, + MergeLoraWeightsReq, + SetLoraReq, + 
UnmergeLoraWeightsReq, + format_lora_message, +) +from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch +from sglang.multimodal_gen.runtime.scheduler_client import async_scheduler_client +from sglang.multimodal_gen.runtime.server_args import get_global_server_args +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +router = APIRouter(prefix="/v1") +logger = init_logger(__name__) + + +class ModelCard(BaseModel): + """Model cards.""" + + id: str + object: str = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "sglang" + root: Optional[str] = None + parent: Optional[str] = None + max_model_len: Optional[int] = None + + +class DiffusionModelCard(ModelCard): + """Extended ModelCard with diffusion-specific fields.""" + + num_gpus: Optional[int] = None + task_type: Optional[str] = None + dit_precision: Optional[str] = None + vae_precision: Optional[str] = None + pipeline_name: Optional[str] = None + pipeline_class: Optional[str] = None + + +async def _handle_lora_request(req: Any, success_msg: str, failure_msg: str): + try: + output: OutputBatch = await async_scheduler_client.forward(req) + if output.error is None: + return {"status": "ok", "message": success_msg} + else: + error_msg = output.error + raise HTTPException(status_code=500, detail=f"{failure_msg}: {error_msg}") + except Exception as e: + if isinstance(e, HTTPException): + raise + logger.error(f"Error during '{failure_msg}': {e}", exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/set_lora") +async def set_lora( + lora_nickname: Union[str, List[str]] = Body(..., embed=True), + lora_path: Optional[Union[str, List[Optional[str]]]] = Body(None, embed=True), + target: Union[str, List[str]] = Body("all", embed=True), + strength: Union[float, List[float]] = Body(1.0, embed=True), +): + """ + Set LoRA adapter(s) for the specified transformer(s). 
+ Supports both single LoRA (backward compatible) and multiple LoRA adapters. + + Args: + lora_nickname: The nickname(s) of the adapter(s). Can be a string or a list of strings. + lora_path: Path(s) to the LoRA adapter(s) (local path or HF repo id). + Can be a string, None, or a list of strings/None. Must match the length of lora_nickname. + target: Which transformer(s) to apply the LoRA to. Can be a string or a list of strings. + If a list, must match the length of lora_nickname. Valid values: + - "all": Apply to all transformers (default) + - "transformer": Apply only to the primary transformer (high noise for Wan2.2) + - "transformer_2": Apply only to transformer_2 (low noise for Wan2.2) + - "critic": Apply only to the critic model + strength: LoRA strength(s) for merge, default 1.0. Can be a float or a list of floats. + If a list, must match the length of lora_nickname. Values < 1.0 reduce the effect, + values > 1.0 amplify the effect. + """ + req = SetLoraReq( + lora_nickname=lora_nickname, + lora_path=lora_path, + target=target, + strength=strength, + ) + nickname_str, target_str, strength_str = format_lora_message( + lora_nickname, target, strength + ) + + return await _handle_lora_request( + req, + f"Successfully set LoRA adapter(s): {nickname_str} (target: {target_str}, strength: {strength_str})", + "Failed to set LoRA adapter", + ) + + +@router.post("/merge_lora_weights") +async def merge_lora_weights( + target: str = Body("all", embed=True), + strength: float = Body(1.0, embed=True), +): + """ + Merge LoRA weights into the base model. + + Args: + target: Which transformer(s) to merge. One of "all", "transformer", + "transformer_2", "critic". + strength: LoRA strength for merge, default 1.0. Values < 1.0 reduce the effect, + values > 1.0 amplify the effect. 
+ """ + req = MergeLoraWeightsReq(target=target, strength=strength) + return await _handle_lora_request( + req, + f"Successfully merged LoRA weights (target: {target}, strength: {strength})", + "Failed to merge LoRA weights", + ) + + +@router.post("/unmerge_lora_weights") +async def unmerge_lora_weights( + target: str = Body("all", embed=True), +): + """ + Unmerge LoRA weights from the base model. + + Args: + target: Which transformer(s) to unmerge. One of "all", "transformer", + "transformer_2", "critic". + """ + req = UnmergeLoraWeightsReq(target=target) + return await _handle_lora_request( + req, + f"Successfully unmerged LoRA weights (target: {target})", + "Failed to unmerge LoRA weights", + ) + + +@router.get("/model_info") +async def model_info(): + """Get the model information.""" + server_args = get_global_server_args() + if not server_args: + raise HTTPException(status_code=500, detail="Server args not initialized") + + result = { + "model_path": server_args.model_path, + } + return result + + +@router.get("/list_loras") +async def list_loras(): + """List loaded LoRA adapters and current application status per module.""" + try: + req = ListLorasReq() + output: OutputBatch = await async_scheduler_client.forward(req) + if output.error is None: + return output.output or {} + else: + raise HTTPException(status_code=500, detail=output.error) + except Exception as e: + if isinstance(e, HTTPException): + raise + logger.error(f"Error during 'list_loras': {e}", exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/models", response_class=ORJSONResponse) +async def available_models(): + """Show available models. 
OpenAI-compatible endpoint with extended diffusion info.""" + server_args = get_global_server_args() + if not server_args: + raise HTTPException(status_code=500, detail="Server args not initialized") + + model_info = get_model_info( + server_args.model_path, + backend=server_args.backend, + model_id=server_args.model_id, + ) + + card_kwargs = { + "id": server_args.model_path, + "root": server_args.model_path, + # Extended diffusion-specific fields + "num_gpus": server_args.num_gpus, + "task_type": server_args.pipeline_config.task_type.name, + "dit_precision": server_args.pipeline_config.dit_precision, + "vae_precision": server_args.pipeline_config.vae_precision, + } + + if model_info: + card_kwargs["pipeline_name"] = model_info.pipeline_cls.pipeline_name + card_kwargs["pipeline_class"] = model_info.pipeline_cls.__name__ + + model_card = DiffusionModelCard(**card_kwargs) + + # Return dict directly to preserve extended fields (ModelList strips them) + return {"object": "list", "data": [model_card.model_dump()]} + + +@router.get("/models/{model:path}", response_class=ORJSONResponse) +async def retrieve_model(model: str): + """Retrieve a model instance. 
OpenAI-compatible endpoint with extended diffusion info.""" + server_args = get_global_server_args() + if not server_args: + raise HTTPException(status_code=500, detail="Server args not initialized") + + if model != server_args.model_path: + return ORJSONResponse( + status_code=404, + content={ + "error": { + "message": f"The model '{model}' does not exist", + "type": "invalid_request_error", + "param": "model", + "code": "model_not_found", + } + }, + ) + + model_info = get_model_info( + server_args.model_path, + backend=server_args.backend, + model_id=server_args.model_id, + ) + + card_kwargs = { + "id": model, + "root": model, + "num_gpus": server_args.num_gpus, + "task_type": server_args.pipeline_config.task_type.name, + "dit_precision": server_args.pipeline_config.dit_precision, + "vae_precision": server_args.pipeline_config.vae_precision, + } + + if model_info: + card_kwargs["pipeline_name"] = model_info.pipeline_cls.pipeline_name + card_kwargs["pipeline_class"] = model_info.pipeline_cls.__name__ + + # Return dict to preserve extended fields + return DiffusionModelCard(**card_kwargs).model_dump() diff --git a/sglang/python/sglang/multimodal_gen/runtime/entrypoints/openai/image_api.py b/sglang/python/sglang/multimodal_gen/runtime/entrypoints/openai/image_api.py new file mode 100644 index 0000000000000000000000000000000000000000..9d4cb410c18bb258276a056bea522cd86ba5685d --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/entrypoints/openai/image_api.py @@ -0,0 +1,344 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +import base64 +import contextlib +import os +import time +from typing import List, Optional + +from fastapi import APIRouter, File, Form, HTTPException, Path, Query, UploadFile +from fastapi.responses import FileResponse + +from sglang.multimodal_gen.configs.sample.sampling_params import generate_request_id +from sglang.multimodal_gen.runtime.entrypoints.openai.protocol import ( + ImageGenerationsRequest, + 
ImageResponse, + ImageResponseData, +) +from sglang.multimodal_gen.runtime.entrypoints.openai.storage import cloud_storage +from sglang.multimodal_gen.runtime.entrypoints.openai.stores import IMAGE_STORE +from sglang.multimodal_gen.runtime.entrypoints.openai.utils import ( + add_common_data_to_response, + build_sampling_params, + choose_output_image_ext, + merge_image_input_list, + process_generation_batch, + save_image_to_path, + temp_dir_if_disabled, +) +from sglang.multimodal_gen.runtime.entrypoints.utils import prepare_request +from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch +from sglang.multimodal_gen.runtime.scheduler_client import async_scheduler_client +from sglang.multimodal_gen.runtime.server_args import get_global_server_args +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +router = APIRouter(prefix="/v1/images", tags=["images"]) +logger = init_logger(__name__) + + +def _read_b64_for_paths(paths: list[str]) -> list[str]: + """Read and base64-encode each file. Must be called before cloud upload deletes them.""" + result = [] + for path in paths: + with open(path, "rb") as f: + result.append(base64.b64encode(f.read()).decode("utf-8")) + return result + + +def _build_image_response_kwargs( + save_file_path_list: list[str], + resp_format: str, + prompt: str, + request_id: str, + result: OutputBatch, + *, + b64_list: list[str] | None = None, + cloud_url: str | None = None, + fallback_url: str | None = None, + is_persistent: bool = True, +) -> dict: + """Build ImageResponse data list. + + For b64_json: uses pre-read b64_list (call _read_b64_for_paths first). + For url: uses cloud_url or fallback_url. + file_path is omitted when is_persistent=False to avoid exposing stale temp paths. 
+ """ + ret = None + if resp_format == "b64_json": + if not b64_list: + raise ValueError("b64_list required for b64_json response_format") + data = [ + ImageResponseData( + b64_json=b64, + revised_prompt=prompt, + file_path=os.path.abspath(path) if is_persistent else None, + ) + for b64, path in zip(b64_list, save_file_path_list) + ] + ret = {"data": data} + elif resp_format == "url": + url = cloud_url or fallback_url + if not url: + raise HTTPException( + status_code=400, + detail="response_format='url' requires cloud storage to be configured.", + ) + ret = { + "data": [ + ImageResponseData( + url=url, + revised_prompt=prompt, + file_path=( + os.path.abspath(save_file_path_list[0]) + if is_persistent + else None + ), + ) + ], + } + else: + raise HTTPException( + status_code=400, detail=f"response_format={resp_format} is not supported" + ) + + ret = add_common_data_to_response(ret, request_id=request_id, result=result) + + return ret + + +@router.post("/generations", response_model=ImageResponse) +async def generations( + request: ImageGenerationsRequest, +): + request_id = generate_request_id() + server_args = get_global_server_args() + ext = choose_output_image_ext(request.output_format, request.background) + + with temp_dir_if_disabled(server_args.output_path) as output_dir: + sampling = build_sampling_params( + request_id, + prompt=request.prompt, + size=request.size, + num_outputs_per_prompt=max(1, min(int(request.n or 1), 10)), + output_file_name=f"{request_id}.{ext}", + output_path=output_dir, + seed=request.seed, + generator_device=request.generator_device, + num_inference_steps=request.num_inference_steps, + guidance_scale=request.guidance_scale, + true_cfg_scale=request.true_cfg_scale, + negative_prompt=request.negative_prompt, + enable_teacache=request.enable_teacache, + output_compression=request.output_compression, + output_quality=request.output_quality, + ) + batch = prepare_request( + server_args=server_args, + sampling_params=sampling, + ) + # Add 
diffusers_kwargs if provided + if request.diffusers_kwargs: + batch.extra["diffusers_kwargs"] = request.diffusers_kwargs + + save_file_path_list, result = await process_generation_batch( + async_scheduler_client, batch + ) + save_file_path = save_file_path_list[0] + resp_format = (request.response_format or "b64_json").lower() + + # read b64 before cloud upload may delete the local file + b64_list = ( + _read_b64_for_paths(save_file_path_list) + if resp_format == "b64_json" + else None + ) + + cloud_url = await cloud_storage.upload_and_cleanup(save_file_path) + + is_persistent = server_args.output_path is not None + await IMAGE_STORE.upsert( + request_id, + { + "id": request_id, + "created_at": int(time.time()), + "file_path": None if cloud_url or not is_persistent else save_file_path, + "url": cloud_url, + }, + ) + + response_kwargs = _build_image_response_kwargs( + save_file_path_list, + resp_format, + request.prompt, + request_id, + result, + b64_list=b64_list, + cloud_url=cloud_url, + fallback_url=f"/v1/images/{request_id}/content" if is_persistent else None, + is_persistent=is_persistent, + ) + + return ImageResponse(**response_kwargs) + + +@router.post("/edits", response_model=ImageResponse) +async def edits( + image: Optional[List[UploadFile]] = File(None), + image_array: Optional[List[UploadFile]] = File(None, alias="image[]"), + url: Optional[List[str]] = Form(None), + url_array: Optional[List[str]] = Form(None, alias="url[]"), + prompt: str = Form(...), + mask: Optional[UploadFile] = File(None), + model: Optional[str] = Form(None), + n: Optional[int] = Form(1), + response_format: Optional[str] = Form(None), + size: Optional[str] = Form(None), + output_format: Optional[str] = Form(None), + background: Optional[str] = Form("auto"), + seed: Optional[int] = Form(1024), + generator_device: Optional[str] = Form("cuda"), + user: Optional[str] = Form(None), + negative_prompt: Optional[str] = Form(None), + guidance_scale: Optional[float] = Form(None), + 
true_cfg_scale: Optional[float] = Form(None), + num_inference_steps: Optional[int] = Form(None), + output_quality: Optional[str] = Form("default"), + output_compression: Optional[int] = Form(None), + enable_teacache: Optional[bool] = Form(False), + num_frames: int = Form(1), +): + request_id = generate_request_id() + server_args = get_global_server_args() + # Resolve images from either `image` or `image[]` (OpenAI SDK sends `image[]` when list is provided) + images = image or image_array + urls = url or url_array + + if (not images or len(images) == 0) and (not urls or len(urls) == 0): + raise HTTPException( + status_code=422, detail="Field 'image' or 'url' is required" + ) + + image_list = merge_image_input_list(images, urls) + + with contextlib.ExitStack() as stack: + uploads_dir = stack.enter_context( + temp_dir_if_disabled(server_args.input_save_path) + ) + output_dir = stack.enter_context(temp_dir_if_disabled(server_args.output_path)) + + input_paths = [] + try: + for idx, img in enumerate(image_list): + filename = img.filename if hasattr(img, "filename") else f"image_{idx}" + input_path = await save_image_to_path( + img, + os.path.join(uploads_dir, f"{request_id}_{idx}_{filename}"), + ) + input_paths.append(input_path) + except Exception as e: + raise HTTPException( + status_code=400, + detail=f"Failed to process image source: {str(e)}", + ) + + ext = choose_output_image_ext(output_format, background) + sampling = build_sampling_params( + request_id, + prompt=prompt, + size=size, + num_outputs_per_prompt=max(1, min(int(n or 1), 10)), + output_file_name=f"{request_id}.{ext}", + output_path=output_dir, + image_path=input_paths, + seed=seed, + generator_device=generator_device, + negative_prompt=negative_prompt, + guidance_scale=guidance_scale, + true_cfg_scale=true_cfg_scale, + num_inference_steps=num_inference_steps, + enable_teacache=enable_teacache, + num_frames=num_frames, + output_compression=output_compression, + output_quality=output_quality, + ) + batch 
= prepare_request( + server_args=server_args, + sampling_params=sampling, + ) + save_file_path_list, result = await process_generation_batch( + async_scheduler_client, batch + ) + save_file_path = save_file_path_list[0] + resp_format = (response_format or "b64_json").lower() + + # read b64 before cloud upload may delete the local file + b64_list = ( + _read_b64_for_paths(save_file_path_list) + if resp_format == "b64_json" + else None + ) + + cloud_url = await cloud_storage.upload_and_cleanup(save_file_path) + + is_persistent = server_args.output_path is not None + is_input_persistent = server_args.input_save_path is not None + await IMAGE_STORE.upsert( + request_id, + { + "id": request_id, + "created_at": int(time.time()), + "file_path": None if cloud_url or not is_persistent else save_file_path, + "url": cloud_url, + "input_image_paths": input_paths if is_input_persistent else None, + "num_input_images": len(input_paths), + }, + ) + + response_kwargs = _build_image_response_kwargs( + save_file_path_list, + resp_format, + prompt, + request_id, + result, + b64_list=b64_list, + cloud_url=cloud_url, + fallback_url=f"/v1/images/{request_id}/content" if is_persistent else None, + is_persistent=is_persistent, + ) + + return ImageResponse(**response_kwargs) + + +@router.get("/{image_id}/content") +async def download_image_content( + image_id: str = Path(...), variant: Optional[str] = Query(None) +): + item = await IMAGE_STORE.get(image_id) + if not item: + raise HTTPException(status_code=404, detail="Image not found") + + if item.get("url"): + raise HTTPException( + status_code=400, + detail=f"Image has been uploaded to cloud storage. Please use the cloud URL: {item.get('url')}", + ) + + file_path = item.get("file_path") + if not file_path: + raise HTTPException( + status_code=404, + detail="Image was not persisted on disk (output_path is disabled). 
Use b64_json response_format or configure cloud storage.", + ) + if not os.path.exists(file_path): + raise HTTPException(status_code=404, detail="Image is still being generated") + + ext = os.path.splitext(file_path)[1].lower() + media_type = "image/jpeg" + if ext == ".png": + media_type = "image/png" + elif ext == ".webp": + media_type = "image/webp" + + return FileResponse( + path=file_path, media_type=media_type, filename=os.path.basename(file_path) + ) diff --git a/sglang/python/sglang/multimodal_gen/runtime/entrypoints/openai/mesh_api.py b/sglang/python/sglang/multimodal_gen/runtime/entrypoints/openai/mesh_api.py new file mode 100644 index 0000000000000000000000000000000000000000..ab0b90468bf332bcae617492cade1e14a90df481 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/entrypoints/openai/mesh_api.py @@ -0,0 +1,296 @@ +import asyncio +import os +import time +from typing import Any, Dict, List, Optional + +from fastapi import ( + APIRouter, + File, + Form, + HTTPException, + Path, + Query, + Request, + UploadFile, +) +from fastapi.responses import FileResponse + +from sglang.multimodal_gen.configs.sample.sampling_params import ( + SamplingParams, + generate_request_id, +) +from sglang.multimodal_gen.runtime.entrypoints.openai.protocol import ( + MeshGenerationsRequest, + MeshListResponse, + MeshResponse, +) +from sglang.multimodal_gen.runtime.entrypoints.openai.storage import cloud_storage +from sglang.multimodal_gen.runtime.entrypoints.openai.stores import MESH_STORE +from sglang.multimodal_gen.runtime.entrypoints.openai.utils import ( + add_common_data_to_response, + merge_image_input_list, + process_generation_batch, + save_image_to_path, +) +from sglang.multimodal_gen.runtime.entrypoints.utils import prepare_request +from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req +from sglang.multimodal_gen.runtime.server_args import get_global_server_args +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + 
+logger = init_logger(__name__) +router = APIRouter(prefix="/v1/meshes", tags=["meshes"]) + + +def _normalize_format(fmt: Optional[str]) -> str: + fmt = (fmt or "glb").lower() + return fmt if fmt in ("glb", "obj") else "glb" + + +def _build_sampling_params_from_request( + request_id: str, req: MeshGenerationsRequest, image_path: Optional[str] = None +) -> SamplingParams: + ext = _normalize_format(req.output_format) + + server_args = get_global_server_args() + sampling_kwargs: Dict[str, Any] = { + "request_id": request_id, + "prompt": req.prompt, + "num_frames": 1, + "image_path": [image_path] if image_path else None, + "save_output": True, + "output_file_name": f"{request_id}.{ext}", + "seed": req.seed, + "generator_device": req.generator_device, + } + if req.num_inference_steps is not None: + sampling_kwargs["num_inference_steps"] = req.num_inference_steps + if req.guidance_scale is not None: + sampling_kwargs["guidance_scale"] = req.guidance_scale + if req.negative_prompt is not None: + sampling_kwargs["negative_prompt"] = req.negative_prompt + + return SamplingParams.from_user_sampling_params_args( + model_path=server_args.model_path, + server_args=server_args, + **sampling_kwargs, + ) + + +def _mesh_job_from_sampling( + request_id: str, req: MeshGenerationsRequest, sampling: SamplingParams +) -> Dict[str, Any]: + return { + "id": request_id, + "object": "mesh", + "model": req.model or "", + "status": "queued", + "progress": 0, + "created_at": int(time.time()), + "format": _normalize_format(req.output_format), + "file_path": os.path.abspath(sampling.output_file_path()), + } + + +async def _dispatch_job_async(job_id: str, batch: Req) -> None: + from sglang.multimodal_gen.runtime.scheduler_client import async_scheduler_client + + try: + save_file_path_list, result = await process_generation_batch( + async_scheduler_client, batch + ) + save_file_path = save_file_path_list[0] + + file_size = None + if os.path.exists(save_file_path): + file_size = 
os.path.getsize(save_file_path) + + cloud_url = await cloud_storage.upload_and_cleanup(save_file_path) + + update_fields: Dict[str, Any] = { + "status": "completed", + "progress": 100, + "completed_at": int(time.time()), + "url": cloud_url, + "file_path": save_file_path if not cloud_url else None, + "file_size_bytes": file_size, + } + update_fields = add_common_data_to_response( + update_fields, request_id=job_id, result=result + ) + await MESH_STORE.update_fields(job_id, update_fields) + except Exception as e: + logger.error(f"{e}") + await MESH_STORE.update_fields( + job_id, {"status": "failed", "error": {"message": str(e)}} + ) + + +@router.post("", response_model=MeshResponse) +async def create_mesh( + request: Request, + image: Optional[List[UploadFile]] = File(None), + image_array: Optional[List[UploadFile]] = File(None, alias="image[]"), + url: Optional[List[str]] = Form(None), + url_array: Optional[List[str]] = Form(None, alias="url[]"), + prompt: Optional[str] = Form("generate 3d mesh"), + model: Optional[str] = Form(None), + seed: Optional[int] = Form(None), + generator_device: Optional[str] = Form("cuda"), + guidance_scale: Optional[float] = Form(None), + num_inference_steps: Optional[int] = Form(None), + negative_prompt: Optional[str] = Form(None), + output_format: Optional[str] = Form("glb"), +): + content_type = request.headers.get("content-type", "").lower() + request_id = generate_request_id() + server_args = get_global_server_args() + + input_path = None + + if "multipart/form-data" in content_type: + images = image or image_array + urls = url or url_array + image_list = merge_image_input_list(images, urls) + + if not image_list: + raise HTTPException( + status_code=422, + detail="Field 'image' or 'url' is required for mesh generation", + ) + + uploads_dir = os.path.join("outputs", "uploads") + os.makedirs(uploads_dir, exist_ok=True) + img = image_list[0] + filename = img.filename if hasattr(img, "filename") else "input_image" + try: + input_path = 
await save_image_to_path( + img, os.path.join(uploads_dir, f"{request_id}_{filename}") + ) + except Exception as e: + raise HTTPException( + status_code=400, detail=f"Failed to process image source: {str(e)}" + ) + + req = MeshGenerationsRequest( + prompt=prompt or "generate 3d mesh", + model=model, + seed=seed, + generator_device=generator_device, + num_inference_steps=num_inference_steps, + negative_prompt=negative_prompt, + output_format=output_format, + **( + {"guidance_scale": guidance_scale} if guidance_scale is not None else {} + ), + ) + else: + try: + body = await request.json() + except Exception: + body = {} + try: + payload: Dict[str, Any] = dict(body or {}) + + if payload.get("input_image"): + img_src = payload.pop("input_image") + uploads_dir = os.path.join("outputs", "uploads") + os.makedirs(uploads_dir, exist_ok=True) + input_path = await save_image_to_path( + img_src, + os.path.join(uploads_dir, f"{request_id}_input_image"), + ) + + req = MeshGenerationsRequest(**payload) + except Exception as e: + raise HTTPException(status_code=400, detail=f"Invalid request body: {e}") + + if not input_path: + raise HTTPException( + status_code=422, + detail="An input image is required for mesh generation", + ) + + sampling_params = _build_sampling_params_from_request(request_id, req, input_path) + job = _mesh_job_from_sampling(request_id, req, sampling_params) + await MESH_STORE.upsert(request_id, job) + + batch = prepare_request( + server_args=server_args, + sampling_params=sampling_params, + ) + + asyncio.create_task(_dispatch_job_async(request_id, batch)) + return MeshResponse(**job) + + +@router.get("", response_model=MeshListResponse) +async def list_meshes( + after: Optional[str] = Query(None), + limit: Optional[int] = Query(None, ge=1, le=100), + order: Optional[str] = Query("desc"), +): + order = (order or "desc").lower() + if order not in ("asc", "desc"): + order = "desc" + jobs = await MESH_STORE.list_values() + + reverse = order != "asc" + 
@router.get("/{mesh_id}/content")
async def download_mesh_content(
    mesh_id: str = Path(...), variant: Optional[str] = Query(None)
):
    """Stream the generated mesh file for a job stored on the local disk.

    Args:
        mesh_id: Identifier of the mesh job in ``MESH_STORE``.
        variant: Accepted for API compatibility; not used by this handler.

    Raises:
        HTTPException(404): unknown job, or the file is not available. The
            detail says whether the job failed, is still running, or has
            completed but its file is gone.
        HTTPException(400): the mesh was uploaded to cloud storage and must be
            fetched from its cloud URL instead.
    """
    job = await MESH_STORE.get(mesh_id)
    if not job:
        raise HTTPException(status_code=404, detail="Mesh not found")

    if job.get("url"):
        raise HTTPException(
            status_code=400,
            detail=f"Mesh has been uploaded to cloud storage. Please use the cloud URL: {job.get('url')}",
        )

    file_path = job.get("file_path")
    if not file_path or not os.path.exists(file_path):
        # A missing file does not always mean "in progress": report the state
        # recorded on the job so clients do not poll a failed job forever.
        status = job.get("status")
        if status == "failed":
            message = (job.get("error") or {}).get("message")
            detail = (
                f"Mesh generation failed: {message}"
                if message
                else "Mesh generation failed"
            )
        elif status == "completed":
            detail = "Mesh file is no longer available"
        else:
            detail = "Generation is still in-progress"
        raise HTTPException(status_code=404, detail=detail)

    ext = os.path.splitext(file_path)[1].lower()
    media_type = {
        ".glb": "model/gltf-binary",
        ".obj": "text/plain",
    }.get(ext, "application/octet-stream")

    return FileResponse(
        path=file_path, media_type=media_type, filename=os.path.basename(file_path)
    )
# Job record returned by the video endpoints and stored in VIDEO_STORE.
class VideoResponse(BaseModel):
    # Request id generated when the job is created.
    id: str
    object: str = "video"
    # Model name echoed from the request; the job builder falls back to "sora-2".
    model: str = "sora-2"
    # "queued" on creation; the dispatcher sets "completed" or "failed".
    status: str = "queued"
    # 0 on creation; set to 100 when the job completes.
    progress: int = 0
    # Unix timestamp (seconds) of job creation.
    created_at: int = Field(default_factory=lambda: int(time.time()))
    # "<width>x<height>" taken from the resolved sampling params.
    size: str = ""
    # Duration as a string: rounded num_frames / fps.
    seconds: str = "4"
    quality: str = "standard"
    # Cloud storage URL when the output was uploaded, otherwise None.
    url: Optional[str] = None
    # NOTE(review): not populated by the code visible in this chunk.
    remixed_from_video_id: Optional[str] = None
    # Unix timestamp (seconds) set when the job completes.
    completed_at: Optional[int] = None
    # NOTE(review): not populated by the code visible in this chunk.
    expires_at: Optional[int] = None
    # {"message": ...} when the job failed.
    error: Optional[Dict[str, Any]] = None
    # Local output path; cleared on completion when the file was uploaded to
    # cloud storage or written to a temporary (non-persistent) directory.
    file_path: Optional[str] = None
    # Filled from the generation result when reported as a positive value.
    peak_memory_mb: Optional[float] = None
    inference_time_s: Optional[float] = None
Optional[str] = None + output_quality: Optional[str] = "default" + output_compression: Optional[int] = None + output_path: Optional[str] = None + diffusers_kwargs: Optional[Dict[str, Any]] = None # kwargs for diffusers backend + + +class VideoListResponse(BaseModel): + data: List[VideoResponse] + object: str = "list" + + +class VideoRemixRequest(BaseModel): + prompt: str + + +# Mesh API protocol models +class MeshResponse(BaseModel): + id: str + object: str = "mesh" + model: str = "" + status: str = "queued" + progress: int = 0 + created_at: int = Field(default_factory=lambda: int(time.time())) + format: str = "glb" + url: Optional[str] = None + completed_at: Optional[int] = None + expires_at: Optional[int] = None + error: Optional[Dict[str, Any]] = None + file_path: Optional[str] = None + file_size_bytes: Optional[int] = None + peak_memory_mb: Optional[float] = None + inference_time_s: Optional[float] = None + + +class MeshGenerationsRequest(BaseModel): + prompt: str = "generate 3d mesh" + input_image: Optional[str] = None + model: Optional[str] = None + seed: Optional[int] = None + generator_device: Optional[str] = "cuda" + num_inference_steps: Optional[int] = None + guidance_scale: Optional[float] = None + negative_prompt: Optional[str] = None + output_format: Optional[str] = "glb" + + +class MeshListResponse(BaseModel): + data: List[MeshResponse] + object: str = "list" + + +@dataclass +class BaseReq(ABC): + rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True) + http_worker_ipc: Optional[str] = field(default=None, kw_only=True) + + def regenerate_rid(self): + """Generate a new request ID and return it.""" + if isinstance(self.rid, list): + self.rid = [uuid.uuid4().hex for _ in range(len(self.rid))] + else: + self.rid = uuid.uuid4().hex + return self.rid + + +@dataclass +class VertexGenerateReqInput(BaseReq): + instances: List[dict] + parameters: Optional[dict] = None diff --git 
a/sglang/python/sglang/multimodal_gen/runtime/entrypoints/openai/storage.py b/sglang/python/sglang/multimodal_gen/runtime/entrypoints/openai/storage.py new file mode 100644 index 0000000000000000000000000000000000000000..68f66827c26393c6520b9481a1efd16d3a250af0 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/entrypoints/openai/storage.py @@ -0,0 +1,109 @@ +import asyncio +import os +from typing import Optional + +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class CloudStorage: + def __init__(self): + self.enabled = os.getenv("SGLANG_CLOUD_STORAGE_TYPE", "").lower() == "s3" + if not self.enabled: + return + + try: + import boto3 + except ImportError: + logger.error( + "boto3 is not installed. Please install it with `pip install boto3` to use cloud storage." + ) + self.enabled = False + return + + self.bucket_name = os.getenv("SGLANG_S3_BUCKET_NAME") + if not self.bucket_name: + self.enabled = False + return + + endpoint_url = os.getenv("SGLANG_S3_ENDPOINT_URL") or None + region_name = os.getenv("SGLANG_S3_REGION_NAME") or None + + self.client = boto3.client( + "s3", + aws_access_key_id=os.getenv("SGLANG_S3_ACCESS_KEY_ID"), + aws_secret_access_key=os.getenv("SGLANG_S3_SECRET_ACCESS_KEY"), + endpoint_url=endpoint_url, + region_name=region_name, + ) + self.endpoint_url = endpoint_url + self.region_name = region_name + + def is_enabled(self) -> bool: + return self.enabled + + async def upload_file(self, local_path: str, destination_key: str) -> Optional[str]: + if not self.is_enabled(): + return None + + def _sync_upload(): + """Synchronous part of the upload to run in a thread.""" + ext = os.path.splitext(local_path)[1].lower() + content_type = { + ".png": "image/png", + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".webp": "image/webp", + ".mp4": "video/mp4", + ".glb": "model/gltf-binary", + ".obj": "text/plain", + }.get(ext, "application/octet-stream") + + # Use the client created 
class AsyncDictStore:
    """In-memory key -> dict store whose operations are serialized by a lock.

    Every method takes the same ``asyncio.Lock`` before touching the backing
    dict, so request handlers and background tasks can call them concurrently.
    The stored dicts are returned by reference, not copied.
    """

    def __init__(self) -> None:
        self._lock = asyncio.Lock()
        self._items: Dict[str, Dict[str, Any]] = {}

    async def upsert(self, key: str, value: Dict[str, Any]) -> None:
        """Insert ``value`` under ``key``, replacing any existing entry."""
        async with self._lock:
            self._items[key] = value

    async def update_fields(
        self, key: str, updates: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """Merge ``updates`` into the entry for ``key``.

        Returns the updated entry, or ``None`` when there is no entry.
        """
        async with self._lock:
            entry = self._items.get(key)
            if entry is not None:
                entry.update(updates)
            return entry

    async def get(self, key: str) -> Optional[Dict[str, Any]]:
        """Return the entry for ``key`` or ``None``."""
        async with self._lock:
            return self._items.get(key)

    async def pop(self, key: str) -> Optional[Dict[str, Any]]:
        """Remove and return the entry for ``key``; ``None`` when absent."""
        async with self._lock:
            return self._items.pop(key, None)

    async def list_values(self) -> List[Dict[str, Any]]:
        """Return a new list holding all stored entries."""
        async with self._lock:
            return [*self._items.values()]
sglang.multimodal_gen.runtime.entrypoints.utils import ( + ListLorasReq, + MergeLoraWeightsReq, + SetLoraReq, + ShutdownReq, + UnmergeLoraWeightsReq, + format_lora_message, + save_outputs, +) +from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch +from sglang.multimodal_gen.runtime.scheduler_client import AsyncSchedulerClient +from sglang.multimodal_gen.runtime.server_args import get_global_server_args +from sglang.multimodal_gen.runtime.utils.logging_utils import ( + init_logger, + log_batch_completion, + log_generation_timer, +) + +# re-export LoRA protocol types for backward compatibility +__all__ = [ + "SetLoraReq", + "MergeLoraWeightsReq", + "UnmergeLoraWeightsReq", + "ListLorasReq", + "ShutdownReq", + "format_lora_message", +] + +logger = init_logger(__name__) + +OUTPUT_QUALITY_MAPPER = {"maximum": 100, "high": 90, "medium": 55, "low": 35} +DEFAULT_FPS = 24 +DEFAULT_VIDEO_SECONDS = 4 + + +@contextmanager +def temp_dir_if_disabled( + configured_path: str | None, +) -> Generator[str, None, None]: + """Yield *configured_path* when it is set, otherwise create a temporary + directory that is automatically removed when the context exits.""" + if configured_path is not None: + os.makedirs(configured_path, exist_ok=True) + yield configured_path + else: + tmp = tempfile.mkdtemp(prefix="sglang_") + try: + yield tmp + finally: + shutil.rmtree(tmp, ignore_errors=True) + + +def _parse_size(size: str) -> tuple[int, int] | tuple[None, None]: + try: + parts = size.lower().replace(" ", "").split("x") + if len(parts) != 2: + raise ValueError + w, h = int(parts[0]), int(parts[1]) + return w, h + except Exception: + return None, None + + +def choose_output_image_ext( + output_format: Optional[str], background: Optional[str] +) -> str: + fmt = (output_format or "").lower() + if fmt in {"png", "webp", "jpeg", "jpg"}: + return "jpg" if fmt == "jpeg" else fmt + if (background or "auto").lower() == "transparent": + return "png" + return "jpg" + + +def 
build_sampling_params(request_id: str, **kwargs) -> SamplingParams: + """Build SamplingParams from request parameters. + + Handles size parsing, output_quality resolution, and None filtering before + delegating to SamplingParams.from_user_sampling_params_args. Callers pass + only the parameters they have; None values are stripped automatically so + that SamplingParams defaults apply. + """ + server_args = get_global_server_args() + + # pop HTTP-layer params that aren't SamplingParams fields + output_quality = kwargs.pop("output_quality", None) + + has_explicit_compression = kwargs.get("output_compression") is not None + + # parse "WxH" size string if provided + size = kwargs.pop("size", None) + if size: + w, h = _parse_size(size) + if w is not None: + kwargs.setdefault("width", w) + kwargs.setdefault("height", h) + + # filter out None values to let SamplingParams defaults apply + kwargs = {k: v for k, v in kwargs.items() if v is not None} + kwargs.setdefault("save_output", True) + + sampling_params = SamplingParams.from_user_sampling_params_args( + model_path=server_args.model_path, + server_args=server_args, + request_id=request_id, + **kwargs, + ) + + # resolve output_quality → output_compression with the correct data_type. + # SamplingParams.__post_init__ may have resolved with the wrong data_type + # (default VIDEO) before _adjust() set the correct one. 
async def _save_upload_to_path(upload: UploadFile, target_path: str) -> str:
    """Write the full contents of an uploaded file to ``target_path``.

    Args:
        upload: Object exposing an awaitable ``read()`` that returns bytes.
        target_path: Destination file path; missing parent directories are
            created.

    Returns:
        ``target_path`` unchanged.
    """
    parent_dir = os.path.dirname(target_path)
    # A bare file name has an empty dirname, and os.makedirs("") raises.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    content = await upload.read()
    with open(target_path, "wb") as f:
        f.write(content)
    return target_path
".bmp"}: + ext = ".jpg" if url_ext == ".jpeg" else url_ext + elif content_type.startswith("image/"): + if "jpeg" in content_type or "jpg" in content_type: + ext = ".jpg" + elif "png" in content_type: + ext = ".png" + elif "webp" in content_type: + ext = ".webp" + else: + ext = ".jpg" # Default to jpg + elif content_type == "application/octet-stream": + # for octet-stream, if we couldn't get it from URL, default to jpg + ext = ".jpg" + else: + raise ValueError( + f"URL does not point to an image. Content-Type: {content_type}" + ) + target_path = f"{target_path}{ext}" + + with open(target_path, "wb") as f: + f.write(response.content) + + return target_path + except Exception as e: + raise Exception(f"Failed to download image from URL: {str(e)}") + + +async def _save_base64_image_to_path(base64_data: str, target_path: str) -> str: + """Decode base64 image data and save to target path.""" + + _B64_FMT_HINT = ( + "Failed to decode base64 image. " + "Expected format: `data:[];base64,`" + ) + + # split `data:[][;base64],` to media-type base64 data + pattern = r"data:(.*?)(;base64)?,(.*)" + match = re.match(pattern, base64_data) + if not match: + raise ValueError(_B64_FMT_HINT) + media_type = match.group(1) + is_base64 = match.group(2) + if not is_base64: + raise ValueError(f"{_B64_FMT_HINT} (missing ;base64 marker)") + data = match.group(3) + if not data: + raise ValueError(f"{_B64_FMT_HINT} (empty data payload)") + # get ext from url + if media_type.startswith("image/"): + ext = media_type.split("/")[-1].lower() + if ext == "jpeg": + ext = "jpg" + else: + ext = "jpg" + target_path = f"{target_path}.{ext}" + os.makedirs(os.path.dirname(target_path), exist_ok=True) + + try: + image_data = base64.b64decode(data) + with open(target_path, "wb") as f: + f.write(image_data) + + return target_path + except Exception as e: + raise Exception(f"Failed to decode base64 image: {str(e)}") + + +async def process_generation_batch( + scheduler_client: AsyncSchedulerClient, + batch, +) -> 
def merge_image_input_list(*inputs: Union[List, Any, None]) -> List:
    """
    Flatten several optional image sources into one list.

    Each positional argument may be ``None`` (skipped), a list (its elements
    are appended in order), or any other single value (appended as-is). Only
    one level of lists is flattened.

    Args:
        *inputs: Variable number of inputs, each can be None, single item, or list

    Returns:
        List: All non-None inputs in argument order

    Example:
        >>> merge_image_input_list(["img1", "img2"], "img3", None)
        ["img1", "img2", "img3"]
    """
    merged: List = []
    for source in inputs:
        if source is None:
            continue
        merged += source if isinstance(source, list) else [source]
    return merged
VideoGenerationsRequest, + VideoListResponse, + VideoResponse, +) +from sglang.multimodal_gen.runtime.entrypoints.openai.storage import cloud_storage +from sglang.multimodal_gen.runtime.entrypoints.openai.stores import VIDEO_STORE +from sglang.multimodal_gen.runtime.entrypoints.openai.utils import ( + DEFAULT_FPS, + DEFAULT_VIDEO_SECONDS, + add_common_data_to_response, + build_sampling_params, + merge_image_input_list, + process_generation_batch, + save_image_to_path, +) +from sglang.multimodal_gen.runtime.entrypoints.utils import prepare_request +from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req +from sglang.multimodal_gen.runtime.server_args import get_global_server_args +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) +router = APIRouter(prefix="/v1/videos", tags=["videos"]) + + +def _build_video_sampling_params(request_id: str, request: VideoGenerationsRequest): + """Resolve video-specific defaults (fps, seconds → num_frames) then + delegate to the shared build_sampling_params.""" + seconds = request.seconds if request.seconds is not None else DEFAULT_VIDEO_SECONDS + fps = request.fps if request.fps is not None else DEFAULT_FPS + num_frames = request.num_frames if request.num_frames is not None else fps * seconds + + return build_sampling_params( + request_id, + prompt=request.prompt, + size=request.size, + num_frames=num_frames, + fps=fps, + image_path=request.input_reference, + output_file_name=request_id, + seed=request.seed, + generator_device=request.generator_device, + num_inference_steps=request.num_inference_steps, + guidance_scale=request.guidance_scale, + guidance_scale_2=request.guidance_scale_2, + negative_prompt=request.negative_prompt, + enable_teacache=request.enable_teacache, + enable_frame_interpolation=request.enable_frame_interpolation, + frame_interpolation_exp=request.frame_interpolation_exp, + frame_interpolation_scale=request.frame_interpolation_scale, + 
async def _save_first_input_image(
    image_sources, request_id: str, uploads_dir: str
) -> str | None:
    """Save the first input image from ``image_sources`` and return its path.

    Args:
        image_sources: A single source or a list of sources (uploads or URL
            strings); ``None`` or an empty list means no image.
        request_id: Prefix for the saved file name.
        uploads_dir: Directory the image is written to; created if missing.

    Returns:
        The saved file path, or ``None`` when no image source was given.
    """
    image_list = merge_image_input_list(image_sources)
    if not image_list:
        return None
    image = image_list[0]

    os.makedirs(uploads_dir, exist_ok=True)

    # The upload filename is client-controlled: keep only its final path
    # component so it cannot escape uploads_dir, and fall back to a fixed name
    # when it is missing or empty (URL strings have no filename).
    raw_name = getattr(image, "filename", None) or "url_image"
    filename = os.path.basename(str(raw_name)) or "url_image"
    target_path = os.path.join(uploads_dir, f"{request_id}_{filename}")
    return await save_image_to_path(image, target_path)
# Strong references to in-flight dispatch tasks. The event loop only keeps weak
# references to tasks, so an otherwise unreferenced task can be garbage
# collected before it finishes.
_BACKGROUND_TASKS: set = set()


# TODO: support image to video generation
@router.post("", response_model=VideoResponse)
async def create_video(
    request: Request,
    # multipart/form-data fields (optional; used only when content-type is multipart)
    prompt: Optional[str] = Form(None),
    input_reference: Optional[UploadFile] = File(None),
    reference_url: Optional[str] = Form(None),
    model: Optional[str] = Form(None),
    seconds: Optional[int] = Form(None),
    size: Optional[str] = Form(None),
    fps: Optional[int] = Form(None),
    num_frames: Optional[int] = Form(None),
    seed: Optional[int] = Form(1024),
    generator_device: Optional[str] = Form("cuda"),
    negative_prompt: Optional[str] = Form(None),
    guidance_scale: Optional[float] = Form(None),
    num_inference_steps: Optional[int] = Form(None),
    enable_teacache: Optional[bool] = Form(False),
    enable_frame_interpolation: Optional[bool] = Form(False),
    frame_interpolation_exp: Optional[int] = Form(1),
    frame_interpolation_scale: Optional[float] = Form(1.0),
    frame_interpolation_model_path: Optional[str] = Form(None),
    output_quality: Optional[str] = Form("default"),
    output_compression: Optional[int] = Form(None),
    extra_body: Optional[str] = Form(None),
):
    """Create a video generation job and return it immediately.

    The request is read from the multipart form fields when the content type is
    ``multipart/form-data``; otherwise it is read from a JSON body, with
    ``extra_body`` / ``extra_json`` dicts shallow-merged into the top level.
    The job is written to ``VIDEO_STORE`` and generation is dispatched as a
    background task.

    Raises:
        HTTPException: 400 for a missing prompt, a missing or unreadable image
            input, an invalid JSON payload, or invalid sampling parameters.
    """
    content_type = request.headers.get("content-type", "").lower()
    request_id = generate_request_id()

    server_args = get_global_server_args()
    task_type = server_args.pipeline_config.task_type

    # Resolve input upload directory (may be a temp dir when saving is disabled)
    temp_dirs: list[str] = []
    if server_args.input_save_path is not None:
        uploads_dir = server_args.input_save_path
        os.makedirs(uploads_dir, exist_ok=True)
    else:
        uploads_dir = tempfile.mkdtemp(prefix="sglang_input_")
        temp_dirs.append(uploads_dir)

    output_persistent = True

    try:
        if "multipart/form-data" in content_type:
            if not prompt:
                raise HTTPException(status_code=400, detail="prompt is required")
            # Validate image input based on model task type
            image_sources = merge_image_input_list(input_reference, reference_url)
            if task_type.requires_image_input() and not image_sources:
                raise HTTPException(
                    status_code=400,
                    detail="input_reference or reference_url is required for image-to-video generation",
                )
            try:
                input_path = await _save_first_input_image(
                    image_sources, request_id, uploads_dir
                )
            except Exception as e:
                raise HTTPException(
                    status_code=400,
                    detail=f"Failed to process image source: {str(e)}",
                )

            # Parse extra_body JSON (if provided in multipart form) to get
            # fps/num_frames overrides. Only a JSON object can carry overrides;
            # anything else (invalid JSON, a list, a scalar) is ignored.
            extra_from_form: Dict[str, Any] = {}
            if extra_body:
                try:
                    parsed_extra = json.loads(extra_body)
                except Exception:
                    parsed_extra = None
                if isinstance(parsed_extra, dict):
                    extra_from_form = parsed_extra

            fps_val = fps if fps is not None else extra_from_form.get("fps")
            num_frames_val = (
                num_frames
                if num_frames is not None
                else extra_from_form.get("num_frames")
            )

            req = VideoGenerationsRequest(
                prompt=prompt,
                input_reference=input_path,
                model=model,
                seconds=seconds if seconds is not None else 4,
                size=size,
                fps=fps_val,
                num_frames=num_frames_val,
                seed=seed,
                generator_device=generator_device,
                negative_prompt=negative_prompt,
                num_inference_steps=num_inference_steps,
                enable_teacache=enable_teacache,
                enable_frame_interpolation=enable_frame_interpolation,
                frame_interpolation_exp=frame_interpolation_exp,
                frame_interpolation_scale=frame_interpolation_scale,
                frame_interpolation_model_path=frame_interpolation_model_path,
                output_compression=output_compression,
                output_quality=output_quality,
                **(
                    {"guidance_scale": guidance_scale}
                    if guidance_scale is not None
                    else {}
                ),
            )
        else:
            try:
                body = await request.json()
            except Exception:
                body = {}
            try:
                # If client uses extra_body, merge it into the top-level payload
                payload: Dict[str, Any] = dict(body or {})
                extra = payload.pop("extra_body", None)
                if isinstance(extra, dict):
                    # Shallow-merge: only keys like fps/num_frames are expected
                    payload.update(extra)
                # openai may turn extra_body to extra_json
                extra_json = payload.pop("extra_json", None)
                if isinstance(extra_json, dict):
                    payload.update(extra_json)
                # Validate image input based on model task type
                has_image_input = payload.get("reference_url") or payload.get(
                    "input_reference"
                )
                if task_type.requires_image_input() and not has_image_input:
                    raise HTTPException(
                        status_code=400,
                        detail="input_reference or reference_url is required for image-to-video generation",
                    )
                # for non-multipart/form-data type
                if payload.get("reference_url"):
                    try:
                        input_path = await _save_first_input_image(
                            payload.get("reference_url"), request_id, uploads_dir
                        )
                    except Exception as e:
                        raise HTTPException(
                            status_code=400,
                            detail=f"Failed to process image source: {str(e)}",
                        )
                    payload["input_reference"] = input_path
                req = VideoGenerationsRequest(**payload)
            except HTTPException:
                # Already a well-formed client error; re-wrapping it below would
                # bury its detail inside an "Invalid request body" message.
                raise
            except Exception as e:
                raise HTTPException(
                    status_code=400, detail=f"Invalid request body: {e}"
                )

        # Resolve per-request output_path override
        effective_output_path = req.output_path or server_args.output_path
        if effective_output_path is None:
            output_tmp = tempfile.mkdtemp(prefix="sglang_output_")
            temp_dirs.append(output_tmp)
            effective_output_path = output_tmp
            output_persistent = False

        # Inject resolved output_path so _build_video_sampling_params picks it up
        req.output_path = effective_output_path

        logger.debug(f"Server received from create_video endpoint: req={req}")

        try:
            sampling_params = _build_video_sampling_params(request_id, req)
        except (ValueError, TypeError) as e:
            raise HTTPException(status_code=400, detail=str(e))

        job = _video_job_from_sampling(request_id, req, sampling_params)
        await VIDEO_STORE.upsert(request_id, job)

        # Build Req for scheduler
        batch = prepare_request(
            server_args=server_args,
            sampling_params=sampling_params,
        )
        # Add diffusers_kwargs if provided
        if req.diffusers_kwargs:
            batch.extra["diffusers_kwargs"] = req.diffusers_kwargs
    except BaseException:
        # Nothing has been dispatched yet, so no background job will ever
        # remove the temporary directories created for this request.
        for td in temp_dirs:
            shutil.rmtree(td, ignore_errors=True)
        raise

    # Enqueue the job asynchronously and return immediately. From here on the
    # dispatched job owns temp_dirs.
    dispatch_task = asyncio.create_task(
        _dispatch_job_async(
            request_id,
            batch,
            temp_dirs=temp_dirs or None,
            output_persistent=output_persistent,
        )
    )
    _BACKGROUND_TASKS.add(dispatch_task)
    dispatch_task.add_done_callback(_BACKGROUND_TASKS.discard)
    return VideoResponse(**job)


@router.get("", response_model=VideoListResponse)
async def list_videos(
    after: Optional[str] = Query(None),
    limit: Optional[int] = Query(None, ge=1, le=100),
    order: Optional[str] = Query("desc"),
):
    """List stored video jobs ordered by ``created_at``.

    ``after`` is a job id: only jobs that follow it in the chosen order are
    returned (an unknown id yields an empty list). ``limit`` caps the result.
    Any ``order`` other than ``asc``/``desc`` falls back to ``desc``.
    """
    # Normalize order
    order = (order or "desc").lower()
    if order not in ("asc", "desc"):
        order = "desc"
    jobs = await VIDEO_STORE.list_values()

    reverse = order != "asc"
    jobs.sort(key=lambda j: j.get("created_at", 0), reverse=reverse)

    if after is not None:
        try:
            idx = next(i for i, j in enumerate(jobs) if j.get("id") == after)
            jobs = jobs[idx + 1 :]
        except StopIteration:
            jobs = []

    if limit is not None:
        jobs = jobs[:limit]
    items = [VideoResponse(**j) for j in jobs]
    return VideoListResponse(data=items)


@router.get("/{video_id}", response_model=VideoResponse)
async def retrieve_video(video_id: str = Path(...)):
    """Return the stored job for ``video_id`` or respond with 404."""
    job = await VIDEO_STORE.get(video_id)
    if not job:
        raise HTTPException(status_code=404, detail="Video not found")
    return VideoResponse(**job)


# TODO: support aborting a job.
@router.delete("/{video_id}", response_model=VideoResponse)
async def delete_video(video_id: str = Path(...)):
    """Remove a job from the store and return it with status ``deleted``.

    Responds with 404 when no job is stored under ``video_id``.
    """
    removed = await VIDEO_STORE.pop(video_id)
    if removed:
        # The record is already gone from the store; "deleted" only exists in
        # the response.
        removed["status"] = "deleted"
        return VideoResponse(**removed)
    raise HTTPException(status_code=404, detail="Video not found")


@router.get("/{video_id}/content")
async def download_video_content(
    video_id: str = Path(...), variant: Optional[str] = Query(None)
):
    """Serve the generated mp4 of a job from local disk.

    Responds with 404 for an unknown job, 400 when the job carries a cloud
    ``url`` instead of a local file, and 404 while no output file exists at the
    job's ``file_path``. ``variant`` is accepted but not used.
    """
    record = await VIDEO_STORE.get(video_id)
    if not record:
        raise HTTPException(status_code=404, detail="Video not found")

    cloud_url = record.get("url")
    if cloud_url:
        raise HTTPException(
            status_code=400,
            detail=f"Video has been uploaded to cloud storage. Please use the cloud URL: {cloud_url}",
        )

    local_path = record.get("file_path")
    output_ready = bool(local_path) and os.path.exists(local_path)
    if not output_ready:
        raise HTTPException(status_code=404, detail="Generation is still in-progress")

    # default variant
    return FileResponse(
        path=local_path,
        media_type="video/mp4",
        filename=os.path.basename(local_path),
    )
router = APIRouter()


async def _read_json_object(request: Request):
    """Return the request body as a dict, or None if it is not a JSON object."""
    try:
        body = await request.json()
    except Exception:
        return None
    return body if isinstance(body, dict) else None


@router.post("/update_weights_from_disk")
async def update_weights_from_disk(request: Request):
    """Update model weights from disk inplace without restarting the server.

    Expects a JSON object with ``model_path`` (required), ``flush_cache``
    (default True) and ``target_modules`` (optional). Responds with
    ``{"success", "message"}``: 200 on success, 400 for a bad request or a
    reported failure, 500 when the scheduler call itself raises.
    """
    body = await _read_json_object(request)
    if body is None:
        # Malformed or non-object JSON is a client error, not a server crash.
        return ORJSONResponse(
            {"success": False, "message": "Request body must be a JSON object"},
            status_code=400,
        )

    model_path = body.get("model_path")
    if not model_path:
        return ORJSONResponse(
            {"success": False, "message": "model_path is required"},
            status_code=400,
        )

    req = UpdateWeightFromDiskReqInput(
        model_path=model_path,
        flush_cache=body.get("flush_cache", True),
        target_modules=body.get("target_modules"),
    )

    try:
        response = await async_scheduler_client.forward(req)
    except Exception as e:
        return ORJSONResponse(
            {"success": False, "message": str(e)},
            status_code=500,
        )

    # Fall back to the "unknown" defaults below if the scheduler did not return
    # a dict.
    result = response.output if isinstance(response.output, dict) else {}
    success = result.get("success", False)
    message = result.get("message", "Unknown status")
    return ORJSONResponse(
        {"success": success, "message": message},
        status_code=200 if success else 400,
    )


@router.post("/get_weights_checksum")
async def get_weights_checksum(request: Request):
    """Return SHA-256 checksum of each requested module's weights.

    Expects a JSON object with an optional ``module_names`` list. Responds
    with the scheduler output on success, ``{"error": ...}`` with 400 for a
    malformed body and with 500 when the scheduler call raises.
    """
    body = await _read_json_object(request)
    if body is None:
        return ORJSONResponse(
            {"error": "Request body must be a JSON object"}, status_code=400
        )

    req = GetWeightsChecksumReqInput(
        module_names=body.get("module_names"),
    )

    try:
        response = await async_scheduler_client.forward(req)
    except Exception as e:
        return ORJSONResponse({"error": str(e)}, status_code=500)

    return ORJSONResponse(response.output, status_code=200)
+""" + +import os +import shutil +import subprocess +import tempfile +from dataclasses import dataclass, field +from typing import Any, Callable, List, Optional, Sequence, Union + +import imageio +import numpy as np +import torch + +try: + import scipy.io.wavfile as scipy_wavfile +except ImportError: # pragma: no cover + scipy_wavfile = None + +try: + import imageio_ffmpeg as _imageio_ffmpeg +except ImportError: # pragma: no cover + _imageio_ffmpeg = None + +from sglang.multimodal_gen.configs.sample.sampling_params import ( + DataType, + SamplingParams, +) +from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import CYAN, RESET, init_logger + +logger = init_logger(__name__) + + +@dataclass +class SetLoraReq: + lora_nickname: Union[str, List[str]] + lora_path: Optional[Union[str, List[Optional[str]]]] = None + target: Union[str, List[str]] = "all" + strength: Union[float, List[float]] = 1.0 + + +@dataclass +class MergeLoraWeightsReq: + target: str = "all" + strength: float = 1.0 + + +@dataclass +class UnmergeLoraWeightsReq: + target: str = "all" + + +@dataclass +class ListLorasReq: + pass + + +@dataclass +class ShutdownReq: + pass + + +def format_lora_message( + lora_nickname: Union[str, List[str]], + target: Union[str, List[str]], + strength: Union[float, List[float]], +) -> tuple[str, str, str]: + """Format success message for single or multiple LoRAs.""" + if isinstance(lora_nickname, list): + nickname_str = ", ".join(lora_nickname) + target_str = ", ".join(target) if isinstance(target, list) else target + strength_str = ( + ", ".join(f"{s:.2f}" for s in strength) + if isinstance(strength, list) + else f"{strength:.2f}" + ) + else: + nickname_str = lora_nickname + target_str = target if isinstance(target, str) else ", ".join(target) + strength_str = ( + f"{strength:.2f}" + if isinstance(strength, (int, float)) + else ", 
".join(f"{s:.2f}" for s in strength) + ) + return nickname_str, target_str, strength_str + + +@dataclass +class GenerationResult: + """Result of a single generation request from DiffGenerator.""" + + samples: Any = None + frames: Any = None + audio: Any = None + prompt: str | None = None + size: tuple | None = None # (height, width, num_frames) + generation_time: float = 0.0 + peak_memory_mb: float = 0.0 + metrics: dict = field(default_factory=dict) + trajectory_latents: Any = None + trajectory_timesteps: Any = None + trajectory_decoded: Any = None + prompt_index: int = 0 + output_file_path: str | None = None + + +def _normalize_audio_to_numpy(audio: Any) -> np.ndarray | None: + """Convert audio (torch / numpy) into a float32 numpy array in [-1, 1], best-effort.""" + if audio is None: + return None + if isinstance(audio, torch.Tensor): + audio_np = audio.detach().float().clamp(-1.0, 1.0).cpu().numpy() + elif isinstance(audio, np.ndarray): + audio_np = audio.astype(np.float32, copy=False) + audio_np = np.clip(audio_np, -1.0, 1.0) + else: + return None + + # 1. Squeeze leading singleton dimensions (Batch, etc.) + while audio_np.ndim > 1 and audio_np.shape[0] == 1: + audio_np = audio_np.squeeze(0) + + # 2. Handle (C, L) -> (L, C) + if audio_np.ndim == 2 and audio_np.shape[0] < audio_np.shape[1]: + audio_np = audio_np.transpose(1, 0) + + # 3. 
def _pick_audio_sample_rate(
    *,
    audio_np: np.ndarray,
    audio_sample_rate: Optional[int],
    fps: int,
    num_frames: int,
) -> int:
    """Pick a plausible sample rate, falling back to inferring from video duration.

    A given rate within 8000..192000 Hz is used as is. Otherwise the rate is
    inferred as ``audio length / (num_frames / fps)`` and accepted only if it
    falls in the same range; failing that, 24000 Hz is returned.
    """
    selected_sr = int(audio_sample_rate) if audio_sample_rate is not None else None
    if selected_sr is None or not (8000 <= selected_sr <= 192000):
        selected_sr = 24000
        try:
            duration_s = float(num_frames) / float(fps) if fps else 0.0
            if duration_s > 0:
                # 2D audio: samples are counted along dim 0; otherwise along
                # the last dim.
                audio_len = (
                    int(audio_np.shape[0])
                    if audio_np.ndim == 2
                    else int(audio_np.shape[-1])
                )
                inferred_sr = int(round(float(audio_len) / duration_s))
                if 8000 <= inferred_sr <= 192000:
                    selected_sr = inferred_sr
        except Exception:
            # Inference is only a heuristic; keep the 24 kHz default.
            pass
    return selected_sr


def _is_usable_ffmpeg(candidate: Optional[str]) -> bool:
    """Return True if ``candidate`` is an existing absolute path or resolvable on PATH."""
    if not candidate:
        return False
    if os.path.isabs(candidate):
        return os.path.exists(candidate)
    return shutil.which(candidate) is not None


def _resolve_ffmpeg_exe() -> str:
    """Locate an ffmpeg executable.

    The binary reported by ``imageio_ffmpeg`` (when that package is importable)
    is preferred; if it is unavailable or unusable, ``ffmpeg`` from PATH is
    used instead.

    Raises:
        RuntimeError: if no usable ffmpeg executable is found.
    """
    candidates = []
    try:
        if _imageio_ffmpeg is not None:
            candidates.append(_imageio_ffmpeg.get_ffmpeg_exe())
    except Exception:
        pass
    candidates.append(shutil.which("ffmpeg"))

    for candidate in candidates:
        if _is_usable_ffmpeg(candidate):
            return candidate
    raise RuntimeError("ffmpeg not found")


def _mux_audio_np_into_mp4(
    *,
    save_file_path: str,
    audio_np: np.ndarray,
    sample_rate: int,
    ffmpeg_exe: str,
) -> None:
    """Mux ``audio_np`` into the mp4 at ``save_file_path``, replacing it in place.

    The audio is written to a temporary wav, ffmpeg copies the video stream and
    encodes the audio as AAC into a sibling ``*.tmp_mux.mp4``, which then
    replaces the original file. Temporary files are removed in all cases.

    Raises:
        RuntimeError: if scipy is not installed.
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
    """
    # splitext only looks at the final path component, so a dot in a directory
    # name cannot move the temporary file elsewhere.
    merged_path = os.path.splitext(save_file_path)[0] + ".tmp_mux.mp4"
    tmp_wav_path = None
    try:
        if scipy_wavfile is None:
            raise RuntimeError(
                "scipy is required to mux audio into mp4 (pip install scipy)"
            )
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            tmp_wav_path = f.name
        scipy_wavfile.write(tmp_wav_path, sample_rate, audio_np)
        subprocess.run(
            [
                ffmpeg_exe,
                "-y",
                "-i",
                save_file_path,
                "-i",
                tmp_wav_path,
                "-c:v",
                "copy",
                "-c:a",
                "aac",
                "-strict",
                "experimental",
                merged_path,
            ],
            check=True,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        os.replace(merged_path, save_file_path)
    finally:
        if tmp_wav_path:
            try:
                os.remove(tmp_wav_path)
            except OSError:
                pass
        if os.path.exists(merged_path):
            try:
                os.remove(merged_path)
            except OSError:
                pass


def _maybe_mux_audio_into_mp4(
    *,
    save_file_path: str,
    audio: Any,
    frames: list,
    fps: int,
    audio_sample_rate: Optional[int],
) -> None:
    """Best-effort mux audio into an already-written mp4 at save_file_path.

    Any failure should keep the silent video and only log a warning.
    """
    audio_np = _normalize_audio_to_numpy(audio)
    if audio_np is None:
        return

    try:
        # Sample-rate selection sits inside the guard too, so a bad
        # audio_sample_rate cannot turn a saved silent video into an error.
        selected_sr = _pick_audio_sample_rate(
            audio_np=audio_np,
            audio_sample_rate=audio_sample_rate,
            fps=fps,
            num_frames=len(frames),
        )
        ffmpeg_exe = _resolve_ffmpeg_exe()
        _mux_audio_np_into_mp4(
            save_file_path=save_file_path,
            audio_np=audio_np,
            sample_rate=selected_sr,
            ffmpeg_exe=ffmpeg_exe,
        )
        logger.info(f"Merged video saved to {CYAN}{save_file_path}{RESET}")
    except Exception as e:
        logger.warning(
            "Failed to mux audio into mp4 (saved silent video): %s",
            str(e),
        )
def _select_audio_for_output(audio: Any, output_idx: int) -> Any:
    """Return the audio belonging to output ``output_idx``.

    Tensors / arrays with at least two dimensions are indexed along dim 0: the
    matching entry is returned, or ``None`` when dim 0 has no entry for that
    index. Anything else (including ``None``) is returned unchanged.
    """
    if isinstance(audio, (torch.Tensor, np.ndarray)) and audio.ndim >= 2:
        return audio[output_idx] if audio.shape[0] > output_idx else None
    return audio


def attach_audio_to_video_sample(
    sample: Any,
    audio: Any,
    output_idx: int,
) -> Any:
    """Attach per-sample audio for video outputs when available.

    Returns ``(sample, audio)`` when audio exists for ``output_idx`` and
    ``sample`` is not already a 2-element tuple/list; otherwise returns
    ``sample`` unchanged.
    """
    if audio is None:
        return sample
    audio = _select_audio_for_output(audio, output_idx)

    already_paired = isinstance(sample, (tuple, list)) and len(sample) == 2
    if audio is not None and not already_paired:
        return (sample, audio)
    return sample


def save_outputs(
    outputs: Sequence[Any],
    data_type: DataType,
    fps: int,
    save_output: bool,
    build_output_path: Callable[[int], str],
    *,
    audio: Any = None,
    audio_sample_rate: Optional[int] = None,
    samples_out: Optional[list[Any]] = None,
    audios_out: Optional[list[Any]] = None,
    frames_out: Optional[list[Any]] = None,
    output_compression: Optional[int] = None,
    enable_frame_interpolation: bool = False,
    frame_interpolation_exp: int = 1,
    frame_interpolation_scale: float = 1.0,
    frame_interpolation_model_path: Optional[str] = None,
) -> list[str]:
    """Save outputs to files and return the list of file paths.

    Each output is post-processed (and written when ``save_output`` is true) to
    ``build_output_path(idx)``. The optional ``samples_out`` / ``audios_out`` /
    ``frames_out`` lists are appended to in place, one entry per output.
    """
    output_paths: list[str] = []
    for idx, output in enumerate(outputs):
        save_file_path = build_output_path(idx)
        sample = output
        if data_type == DataType.VIDEO:
            sample = attach_audio_to_video_sample(sample, audio, idx)

        frames = post_process_sample(
            sample,
            data_type,
            fps,
            save_output,
            save_file_path,
            audio_sample_rate=audio_sample_rate,
            output_compression=output_compression,
            enable_frame_interpolation=enable_frame_interpolation,
            frame_interpolation_exp=frame_interpolation_exp,
            frame_interpolation_scale=frame_interpolation_scale,
            frame_interpolation_model_path=frame_interpolation_model_path,
        )

        if samples_out is not None:
            samples_out.append(sample)
        if audios_out is not None:
            if data_type == DataType.VIDEO:
                audios_out.append(_select_audio_for_output(audio, idx))
            else:
                audios_out.append(audio)
        if frames_out is not None:
            frames_out.append(frames)
        output_paths.append(save_file_path)
    return output_paths


def _tensor_to_uint8_frames(t: torch.Tensor) -> list:
    """Scale a tensor by 255, cast to uint8 and split it into frames.

    A 3-D tensor gets a singleton dim inserted at position 1; dim 1 of the
    resulting 4-D tensor becomes the frame axis and dim 0 is moved last.
    """
    if t.dim() == 3:
        t = t.unsqueeze(1)
    t = (t * 255).clamp(0, 255).to(torch.uint8)
    return list(t.permute(1, 2, 3, 0).cpu().numpy())


def post_process_sample(
    sample: Any,
    data_type: DataType,
    fps: int,
    save_output: bool = True,
    save_file_path: Optional[str] = None,
    audio_sample_rate: Optional[int] = None,
    output_compression: Optional[int] = None,
    enable_frame_interpolation: bool = False,
    frame_interpolation_exp: int = 1,
    frame_interpolation_scale: float = 1.0,
    frame_interpolation_model_path: Optional[str] = None,
):
    """
    Process sample output, optionally interpolate video frames, and save.

    ``sample`` is a torch tensor or numpy array, optionally paired with audio
    as a 2-element tuple/list ``(sample, audio)``.

    Returns:
        The list of uint8 frames (after interpolation, if enabled).

    Raises:
        TypeError: if the sample is neither a tensor nor a numpy array.
        ValueError: if a numpy sample is not 3- or 4-dimensional.
    """
    audio = None
    if isinstance(sample, (tuple, list)) and len(sample) == 2:
        sample, audio = sample

    # 1. Convert tensor / array to list of uint8 HWC frames
    if isinstance(sample, torch.Tensor):
        frames = _tensor_to_uint8_frames(sample)
    else:
        if not isinstance(sample, np.ndarray):
            raise TypeError(f"Unsupported sample type: {type(sample)}")

        arr = sample
        if arr.ndim == 3:
            if arr.shape[-1] in (1, 3, 4):
                arr = arr[None, ...]
            else:
                arr = arr[..., None]
        if arr.ndim != 4:
            raise ValueError(f"Unexpected numpy sample shape: {tuple(arr.shape)}")

        if arr.shape[-1] not in (1, 3, 4) and arr.shape[0] in (1, 3, 4):
            # Channel count found on dim 0 instead of the last dim: take the
            # same route as tensor samples.
            frames = _tensor_to_uint8_frames(torch.from_numpy(arr))
        else:
            if arr.dtype != np.uint8:
                arr = (np.clip(arr, 0.0, 1.0) * 255.0).astype(np.uint8)
            frames = list(arr)

    # 2. Frame interpolation (video only)
    if enable_frame_interpolation and data_type == DataType.VIDEO and len(frames) > 1:
        from sglang.multimodal_gen.runtime.postprocess import (
            interpolate_video_frames,
        )

        frames, multiplier = interpolate_video_frames(
            frames,
            exp=frame_interpolation_exp,
            scale=frame_interpolation_scale,
            model_path=frame_interpolation_model_path,
        )
        fps = fps * multiplier

    # 3. Save outputs if requested
    if save_output:
        if save_file_path:
            # A bare filename has an empty dirname, which os.makedirs rejects.
            output_dir = os.path.dirname(save_file_path)
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)
            if data_type == DataType.VIDEO:
                quality = (
                    output_compression / 10 if output_compression is not None else 5
                )
                imageio.mimsave(
                    save_file_path,
                    frames,
                    fps=fps,
                    format=data_type.get_default_extension(),
                    codec="libx264",
                    quality=quality,
                )

                _maybe_mux_audio_into_mp4(
                    save_file_path=save_file_path,
                    audio=audio,
                    frames=frames,
                    fps=fps,
                    audio_sample_rate=audio_sample_rate,
                )

            else:
                quality = output_compression if output_compression is not None else 75
                if len(frames) > 1:
                    # splitext only inspects the last path component, so dots
                    # in directory names do not end up inside the file name.
                    stem, ext = os.path.splitext(save_file_path)
                    for i, image in enumerate(frames):
                        imageio.imwrite(f"{stem}_{i}{ext}", image, quality=quality)
                else:
                    imageio.imwrite(save_file_path, frames[0], quality=quality)
            logger.info(f"Output saved to {CYAN}{save_file_path}{RESET}")
        else:
            logger.info("No output path provided, output not saved")

    return frames
def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None):
    """Kill the process and all its child processes.

    Args:
        parent_pid: Root of the tree to kill. ``None`` means the current
            process, in which case only its children are killed.
        include_parent: Whether to kill ``parent_pid`` itself after its children.
        skip_pid: PID of a child that must be left running.
    """
    # Remove sigchld handler to avoid spammy logs.
    # Only done on the main thread; signal.signal() is called nowhere else here.
    if threading.current_thread() is threading.main_thread():
        signal.signal(signal.SIGCHLD, signal.SIG_DFL)

    if parent_pid is None:
        parent_pid = os.getpid()
        include_parent = False

    try:
        itself = psutil.Process(parent_pid)
    except psutil.NoSuchProcess:
        return

    children = itself.children(recursive=True)
    for child in children:
        if child.pid == skip_pid:
            continue
        try:
            child.kill()
        except psutil.NoSuchProcess:
            pass

    if include_parent:
        try:
            if parent_pid == os.getpid():
                itself.kill()
                sys.exit(0)

            itself.kill()

            # Sometimes processes cannot be killed with SIGKILL (e.g. PID=1 launched by kubernetes),
            # so we send an additional signal to kill them.
            itself.send_signal(signal.SIGQUIT)
        except psutil.NoSuchProcess:
            pass


def launch_server(server_args: ServerArgs, launch_http_server: bool = True):
    """Start one scheduler worker process per GPU and, optionally, the HTTP server.

    Blocks until every worker has reported ``"ready"`` on its pipe. Unless
    ``server_args.webui`` is set, the HTTP server then runs in the calling
    process, so the function returns only after it stops.

    Args:
        server_args: Server configuration; ``num_gpus``, ``master_port`` and
            ``webui`` are read here.
        launch_http_server: False for offline local mode

    Returns:
        The list of started worker processes.

    Raises:
        EOFError: if a worker closes its pipe before reporting.
        RuntimeError: if a worker reports a status other than ``"ready"``.
    """
    configure_logger(server_args)

    # Start a new server with multiple worker processes
    logger.info("Starting server...")

    num_gpus = server_args.num_gpus
    processes = []

    # Pipes for master to talk to slaves
    task_pipes_to_slaves_w = []
    task_pipes_to_slaves_r = []
    for _ in range(num_gpus - 1):
        r, w = mp.Pipe(duplex=False)
        task_pipes_to_slaves_r.append(r)
        task_pipes_to_slaves_w.append(w)

    # Pipes for slaves to talk to master
    result_pipes_from_slaves_w = []
    result_pipes_from_slaves_r = []
    for _ in range(num_gpus - 1):
        r, w = mp.Pipe(duplex=False)
        result_pipes_from_slaves_r.append(r)
        result_pipes_from_slaves_w.append(w)

    # Launch all worker processes
    # NOTE(review): the fallback operand is evaluated only when master_port is
    # falsy, so it raises TypeError for None and yields 100 for 0. A different
    # base port was presumably intended; confirm against ServerArgs.
    master_port = server_args.master_port or (server_args.master_port + 100)
    scheduler_pipe_readers = []
    scheduler_pipe_writers = []

    for i in range(num_gpus):
        # One readiness pipe per worker: the worker gets the write end, the
        # parent keeps the read end.
        reader, writer = mp.Pipe(duplex=False)
        scheduler_pipe_writers.append(writer)
        if i == 0:  # Master worker
            process = mp.Process(
                target=run_scheduler_process,
                args=(
                    i,  # local_rank
                    i,  # rank
                    master_port,
                    server_args,
                    writer,
                    None,  # No task pipe to read from master
                    None,  # No result pipe to write to master
                    task_pipes_to_slaves_w,
                    result_pipes_from_slaves_r,
                ),
                name=f"sglang-diffusionWorker-{i}",
                daemon=True,
            )
        else:  # Slave workers
            # NOTE(review): the two None placeholders are the same as for the
            # master, while this slave's own pipe ends go into the last two
            # positions; confirm against run_scheduler_process's signature.
            process = mp.Process(
                target=run_scheduler_process,
                args=(
                    i,  # local_rank
                    i,  # rank
                    master_port,
                    server_args,
                    writer,
                    None,
                    None,
                    task_pipes_to_slaves_r[i - 1],
                    result_pipes_from_slaves_w[i - 1],
                ),
                name=f"sglang-diffusionWorker-{i}",
                daemon=True,
            )
        scheduler_pipe_readers.append(reader)
        process.start()
        processes.append(process)

    # Wait for all workers to be ready
    scheduler_infos = []
    # Close the parent's copies of the write ends, so that recv() below raises
    # EOFError if a worker dies before reporting.
    for writer in scheduler_pipe_writers:
        writer.close()

    # Close unused pipe ends in parent process
    for p in task_pipes_to_slaves_w:
        p.close()
    for p in task_pipes_to_slaves_r:
        p.close()
    for p in result_pipes_from_slaves_w:
        p.close()
    for p in result_pipes_from_slaves_r:
        p.close()

    for i, reader in enumerate(scheduler_pipe_readers):
        try:
            data = reader.recv()
        except EOFError:
            logger.error(
                f"Rank {i} scheduler is dead. Please check if there are relevant logs."
            )
            processes[i].join()
            logger.error(f"Exit code: {processes[i].exitcode}")
            raise

        if data["status"] != "ready":
            raise RuntimeError(
                "Initialization failed. Please see the error messages above."
            )
        # Collected but not read again within this function.
        scheduler_infos.append(data)
        reader.close()

    logger.debug("All workers are ready")

    if launch_http_server:
        logger.info("Starting FastAPI server.")
        if server_args.webui:
            logger.info("Launch FastAPI server in another process because of webui.")
            http_server_process = mp.Process(
                target=launch_http_server_only,
                args=(server_args,),
                name=f"sglang-diffusion-webui",
                daemon=True,
            )
            http_server_process.start()
        else:
            launch_http_server_only(server_args)

    return processes


def launch_http_server_only(server_args):
    """Publish ``server_args`` globally, build the app and run it with uvicorn."""
    # set for endpoints to access global_server_args
    set_global_server_args(server_args)
    app = create_app(server_args)
    uvicorn.run(
        app,
        use_colors=True,
        log_level=server_args.log_level,
        host=server_args.host,
        port=server_args.port,
        reload=False,
    )


if __name__ == "__main__":
    server_args = prepare_server_args(sys.argv[1:])

    try:
        launch_server(server_args)
    finally:
        # Always reap the worker processes, but leave this process running so
        # it can exit normally.
        kill_process_tree(os.getpid(), include_parent=False)
@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    """SwiGLU activation: ``silu(x[..., :d]) * x[..., d:]`` with ``d = x.shape[-1] // 2``.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self) -> None:
        super().__init__()

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # The fused kernel writes into a pre-allocated output whose last dim is
        # half of the input's.
        half = x.shape[-1] // 2
        result = torch.empty(x.shape[:-1] + (half,), dtype=x.dtype, device=x.device)
        silu_and_mul(x, result)
        return result

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """Pure PyTorch implementation of the same function."""
        half = x.shape[-1] // 2
        gate = x[..., :half]
        up = x[..., half:]
        return F.silu(gate) * up

    def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
        return torch_npu.npu_swiglu(x)

    def forward_musa(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE(review): presumably SwishGLU is provided by the MUSA torch build; confirm.
        return nn.SwishGLU()(x)


@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
    """GeGLU activation: ``GELU(x[..., :d]) * x[..., d:]`` with ``d = x.shape[-1] // 2``.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def __init__(self, approximate: str = "none"):
        super().__init__()
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")

    def forward_cuda(self, *args, **kwargs) -> Any:
        # No dedicated kernel is used here; defer to the PyTorch path.
        return self.forward_native(*args, **kwargs)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """Pure PyTorch implementation."""
        half = x.shape[-1] // 2
        activated = F.gelu(x[..., :half], approximate=self.approximate)
        return activated * x[..., half:]

    def extra_repr(self) -> str:
        return f"approximate={repr(self.approximate)}"


@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
    """Tanh-based GELU: ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))``."""

    def __init__(self):
        super().__init__()

    def forward_cuda(self, *args, **kwargs) -> Any:
        return self.forward_native(*args, **kwargs)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """Pure PyTorch implementation."""
        c = math.sqrt(2.0 / math.pi)
        inner = c * (x + 0.044715 * torch.pow(x, 3.0))
        return 0.5 * x * (1.0 + torch.tanh(inner))


@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
    """Sigmoid-based GELU: ``x * sigmoid(1.702 * x)``."""

    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
    def __init__(self):
        super().__init__()

    def forward_cuda(self, *args, **kwargs) -> Any:
        return self.forward_native(*args, **kwargs)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """Pure PyTorch implementation."""
        gate = torch.sigmoid(1.702 * x)
        return x * gate


# Lower-case activation name -> zero-argument factory returning the module.
_ACTIVATION_REGISTRY = {
    "gelu": nn.GELU,
    "gelu_new": NewGELU,
    "gelu_pytorch_tanh": lambda: nn.GELU(approximate="tanh"),
    "relu": nn.ReLU,
    "silu": nn.SiLU,
    "quick_gelu": QuickGELU,
}


def get_act_fn(act_fn_name: str) -> nn.Module:
    """Build the activation module registered under ``act_fn_name`` (case-insensitive).

    Raises:
        ValueError: if the name is not in the registry.
    """
    key = act_fn_name.lower()
    factory = _ACTIVATION_REGISTRY.get(key)
    if factory is None:
        raise ValueError(f"Activation function {key!r} is not supported.")
    return factory()


# Lower-case activation name -> fused activation-and-multiply module class.
_ACTIVATION_AND_MUL_REGISTRY = {
    "gelu": GeluAndMul,
    "silu": SiluAndMul,
}


def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
    """Build the activation-and-mul module (e.g. SiluAndMul) registered under ``act_fn_name``.

    Raises:
        ValueError: if the name is not in the registry.
    """
    key = act_fn_name.lower()
    factory = _ACTIVATION_AND_MUL_REGISTRY.get(key)
    if factory is None:
        raise ValueError(f"Activation function {key!r} is not supported.")
    return factory()
Options are: + - 'STA_searching': Generate a set of mask candidates for initial search + - 'STA_tuning': Select best mask strategy based on previously saved results + - 'STA_inference': Load and use a previously tuned mask strategy + layer_num: int, number of layers + time_step_num: int, number of timesteps + head_num: int, number of heads + + **kwargs : dict + Mode-specific parameters: + + For 'STA_searching': + - mask_candidates: list of str, optional, mask candidates to use + - mask_selected: list of int, optional, indices of selected masks + + For 'STA_tuning': + - mask_search_files_path: str, required, path to mask search results + - mask_candidates: list of str, optional, mask candidates to use + - mask_selected: list of int, optional, indices of selected masks + - skip_time_steps: int, optional, number of time steps to use full attention (default 12) + - save_dir: str, optional, directory to save mask strategy (default "mask_candidates") + + For 'STA_inference': + - load_path: str, optional, path to load mask strategy (default "mask_candidates/mask_strategy.json") + """ + valid_modes = ["STA_searching", "STA_tuning", "STA_inference", "STA_tuning_cfg"] + if mode not in valid_modes: + raise ValueError(f"Mode must be one of {valid_modes}, got {mode}") + + if mode == "STA_searching": + # Get parameters with defaults + mask_candidates: list[str] | None = kwargs.get("mask_candidates") + if mask_candidates is None: + raise ValueError("mask_candidates is required for STA_searching mode") + mask_selected: list[int] = kwargs.get( + "mask_selected", list(range(len(mask_candidates))) + ) + + # Parse selected masks + selected_masks: list[list[int]] = [] + for index in mask_selected: + mask = mask_candidates[index] + masks_list = [int(x) for x in mask.split(",")] + selected_masks.append(masks_list) + + # Create 3D mask structure with fixed dimensions (t=50, l=60) + masks_3d: list[list[list[list[int]]]] = [] + for i in range(time_step_num): # Fixed t dimension = 50 + row = 
[] + for j in range(layer_num): # Fixed l dimension = 60 + row.append(selected_masks) # Add all masks at each position + masks_3d.append(row) + + return masks_3d + + elif mode == "STA_tuning": + # Get required parameters + mask_search_files_path: str | None = kwargs.get("mask_search_files_path") + if not mask_search_files_path: + raise ValueError("mask_search_files_path is required for STA_tuning mode") + + # Get optional parameters with defaults + mask_candidates_tuning: list[str] | None = kwargs.get("mask_candidates") + if mask_candidates_tuning is None: + raise ValueError("mask_candidates is required for STA_tuning mode") + mask_selected_tuning: list[int] = kwargs.get( + "mask_selected", list(range(len(mask_candidates_tuning))) + ) + skip_time_steps_tuning: int | None = kwargs.get("skip_time_steps") + save_dir_tuning: str | None = kwargs.get("save_dir", "mask_candidates") + + # Parse selected masks + selected_masks_tuning: list[list[int]] = [] + for index in mask_selected_tuning: + mask = mask_candidates_tuning[index] + masks_list = [int(x) for x in mask.split(",")] + selected_masks_tuning.append(masks_list) + + # Read JSON results + results = read_specific_json_files(mask_search_files_path) + averaged_results = average_head_losses(results, selected_masks_tuning) + + # Add full attention mask for specific cases + full_attention_mask_tuning: list[int] | None = kwargs.get("full_attention_mask") + if full_attention_mask_tuning is not None: + selected_masks_tuning.append(full_attention_mask_tuning) + + # Select best mask strategy + timesteps_tuning: int = kwargs.get("timesteps", time_step_num) + if skip_time_steps_tuning is None: + skip_time_steps_tuning = 12 + mask_strategy, sparsity, strategy_counts = select_best_mask_strategy( + averaged_results, + selected_masks_tuning, + skip_time_steps_tuning, + timesteps_tuning, + head_num, + ) + + # Save mask strategy + if save_dir_tuning is not None: + os.makedirs(save_dir_tuning, exist_ok=True) + file_path = os.path.join( 
+ save_dir_tuning, f"mask_strategy_s{skip_time_steps_tuning}.json" + ) + with open(file_path, "w") as f: + json.dump(mask_strategy, f, indent=4) + print(f"Successfully saved mask_strategy to {file_path}") + + # Print sparsity and strategy counts for information + print(f"Overall sparsity: {sparsity:.4f}") + print("\nStrategy usage counts:") + total_heads = time_step_num * layer_num * head_num # Fixed dimensions + for strategy, count in strategy_counts.items(): + print(f"Strategy {strategy}: {count} heads ({count/total_heads*100:.2f}%)") + + # Convert dictionary to 3D list with fixed dimensions + mask_strategy_3d = dict_to_3d_list( + mask_strategy, t_max=time_step_num, l_max=layer_num, h_max=head_num + ) + + return mask_strategy_3d + elif mode == "STA_tuning_cfg": + # Get required parameters for both positive and negative paths + mask_search_files_path_pos: str | None = kwargs.get( + "mask_search_files_path_pos" + ) + mask_search_files_path_neg: str | None = kwargs.get( + "mask_search_files_path_neg" + ) + save_dir_cfg: str | None = kwargs.get("save_dir") + + if ( + not mask_search_files_path_pos + or not mask_search_files_path_neg + or not save_dir_cfg + ): + raise ValueError( + "mask_search_files_path_pos, mask_search_files_path_neg, and save_dir are required for STA_tuning_cfg mode" + ) + + # Get optional parameters with defaults + mask_candidates_cfg: list[str] | None = kwargs.get("mask_candidates") + if mask_candidates_cfg is None: + raise ValueError("mask_candidates is required for STA_tuning_cfg mode") + mask_selected_cfg: list[int] = kwargs.get( + "mask_selected", list(range(len(mask_candidates_cfg))) + ) + skip_time_steps_cfg: int | None = kwargs.get("skip_time_steps") + + # Parse selected masks + selected_masks_cfg: list[list[int]] = [] + for index in mask_selected_cfg: + mask = mask_candidates_cfg[index] + masks_list = [int(x) for x in mask.split(",")] + selected_masks_cfg.append(masks_list) + + # Read JSON results for both positive and negative paths + 
pos_results = read_specific_json_files(mask_search_files_path_pos) + neg_results = read_specific_json_files(mask_search_files_path_neg) + # Combine positive and negative results into one list + combined_results = pos_results + neg_results + + # Average the combined results + averaged_results = average_head_losses(combined_results, selected_masks_cfg) + + # Add full attention mask for specific cases + full_attention_mask_cfg: list[int] | None = kwargs.get("full_attention_mask") + if full_attention_mask_cfg is not None: + selected_masks_cfg.append(full_attention_mask_cfg) + + timesteps_cfg: int = kwargs.get("timesteps", time_step_num) + if skip_time_steps_cfg is None: + skip_time_steps_cfg = 12 + # Select best mask strategy using combined results + mask_strategy, sparsity, strategy_counts = select_best_mask_strategy( + averaged_results, + selected_masks_cfg, + skip_time_steps_cfg, + timesteps_cfg, + head_num, + ) + + # Save mask strategy + os.makedirs(save_dir_cfg, exist_ok=True) + file_path = os.path.join( + save_dir_cfg, f"mask_strategy_s{skip_time_steps_cfg}.json" + ) + with open(file_path, "w") as f: + json.dump(mask_strategy, f, indent=4) + print(f"Successfully saved mask_strategy to {file_path}") + + # Print sparsity and strategy counts for information + print(f"Overall sparsity: {sparsity:.4f}") + print("\nStrategy usage counts:") + total_heads = time_step_num * layer_num * head_num # Fixed dimensions + for strategy, count in strategy_counts.items(): + print(f"Strategy {strategy}: {count} heads ({count/total_heads*100:.2f}%)") + + # Convert dictionary to 3D list with fixed dimensions + mask_strategy_3d = dict_to_3d_list( + mask_strategy, t_max=time_step_num, l_max=layer_num, h_max=head_num + ) + + return mask_strategy_3d + + else: # STA_inference + # Get parameters with defaults + load_path: str | None = kwargs.get( + "load_path", "mask_candidates/mask_strategy.json" + ) + if load_path is None: + raise ValueError("load_path is required for STA_inference mode") 
+ + # Load previously saved mask strategy + with open(load_path) as f: + mask_strategy = json.load(f) + + # Convert dictionary to 3D list with fixed dimensions + mask_strategy_3d = dict_to_3d_list( + mask_strategy, t_max=time_step_num, l_max=layer_num, h_max=head_num + ) + + return mask_strategy_3d + + +# Helper functions + + +def read_specific_json_files(folder_path: str) -> list[dict[str, Any]]: + """Read and parse JSON files containing mask search results.""" + json_contents: list[dict[str, Any]] = [] + + # List files only in the current directory (no walk) + files = os.listdir(folder_path) + # Filter files + matching_files = [f for f in files if "mask" in f and f.endswith(".json")] + print(f"Found {len(matching_files)} matching files: {matching_files}") + + for file_name in matching_files: + file_path = os.path.join(folder_path, file_name) + with open(file_path) as file: + data = json.load(file) + json_contents.append(data) + + return json_contents + + +def average_head_losses( + results: list[dict[str, Any]], selected_masks: list[list[int]] +) -> dict[str, dict[str, np.ndarray]]: + """Average losses across all prompts for each mask strategy.""" + # Initialize a dictionary to store the averaged results + averaged_losses: dict[str, dict[str, np.ndarray]] = {} + loss_type = "L2_loss" + # Get all loss types (e.g., 'L2_loss') + averaged_losses[loss_type] = {} + + for mask in selected_masks: + mask_str = str(mask) + data_shape = np.array(results[0][loss_type][mask_str]).shape + accumulated_data = np.zeros(data_shape) + + # Sum across all prompts + for prompt_result in results: + accumulated_data += np.array(prompt_result[loss_type][mask_str]) + + # Average by dividing by number of prompts + averaged_data = accumulated_data / len(results) + averaged_losses[loss_type][mask_str] = averaged_data + + return averaged_losses + + +def select_best_mask_strategy( + averaged_results: dict[str, dict[str, np.ndarray]], + selected_masks: list[list[int]], + skip_time_steps: int = 
12, + timesteps: int = 50, + head_num: int = 40, +) -> tuple[dict[str, list[int]], float, dict[str, int]]: + """Select the best mask strategy for each head based on loss minimization.""" + best_mask_strategy: dict[str, list[int]] = {} + loss_type = "L2_loss" + # Get the shape of time steps and layers + layers = len(averaged_results[loss_type][str(selected_masks[0])][0]) + + # Counter for sparsity calculation + total_tokens = 0 # total number of masked tokens + total_length = 0 # total sequence length + + strategy_counts: dict[str, int] = {str(strategy): 0 for strategy in selected_masks} + full_attn_strategy = selected_masks[-1] # Last strategy is full attention + print(f"Strategy {full_attn_strategy}, skip first {skip_time_steps} steps ") + + for t in range(timesteps): + for layer_idx in range(layers): + for h in range(head_num): + if t < skip_time_steps: # First steps use full attention + strategy = full_attn_strategy + else: + # Get losses for this head across all strategies + head_losses = [] + for strategy in selected_masks[:-1]: # Exclude full attention + head_losses.append( + averaged_results[loss_type][str(strategy)][t][layer_idx][h] + ) + + # Find which strategy gives minimum loss + best_strategy_idx = np.argmin(head_losses) + strategy = selected_masks[best_strategy_idx] + + best_mask_strategy[f"{t}_{layer_idx}_{h}"] = strategy + + # Calculate sparsity + nums = strategy # strategy is already a list of numbers + total_tokens += ( + nums[0] * nums[1] * nums[2] + ) # masked tokens for chosen strategy + total_length += ( + full_attn_strategy[0] + * full_attn_strategy[1] + * full_attn_strategy[2] + ) + + # Count strategy usage + strategy_counts[str(strategy)] += 1 + + overall_sparsity = 1 - total_tokens / total_length + + return best_mask_strategy, overall_sparsity, strategy_counts + + +def save_mask_search_results( + mask_search_final_result: list[dict[str, list[float]]], + prompt: str, + mask_strategies: list[str], + output_dir: str = 
"output/mask_search_result/", +) -> str | None: + if not mask_search_final_result: + print("No mask search results to save") + return None + + # Create result dictionary with defaultdict for nested lists + mask_search_dict: dict[str, dict[str, list[list[float]]]] = { + "L2_loss": defaultdict(list), + "L1_loss": defaultdict(list), + } + + mask_selected = list(range(len(mask_strategies))) + selected_masks: list[list[int]] = [] + for index in mask_selected: + mask = mask_strategies[index] + masks_list = [int(x) for x in mask.split(",")] + selected_masks.append(masks_list) + + # Process each mask strategy + for i, mask_strategy in enumerate(selected_masks): + mask_strategy_str = str(mask_strategy) + # Process L2 loss + step_results: list[list[float]] = [] + for step_data in mask_search_final_result: + if isinstance(step_data, dict) and "L2_loss" in step_data: + layer_losses = [float(loss) for loss in step_data["L2_loss"]] + step_results.append(layer_losses) + mask_search_dict["L2_loss"][mask_strategy_str] = step_results + + step_results = [] + for step_data in mask_search_final_result: + if isinstance(step_data, dict) and "L1_loss" in step_data: + layer_losses = [float(loss) for loss in step_data["L1_loss"]] + step_results.append(layer_losses) + mask_search_dict["L1_loss"][mask_strategy_str] = step_results + + # Create the output directory if it doesn't exist + os.makedirs(output_dir, exist_ok=True) + + # Create a filename based on the first 20 characters of the prompt + filename = prompt[:50].replace(" ", "_") + filepath = os.path.join(output_dir, f"mask_search_{filename}.json") + + # Save the results to a JSON file + with open(filepath, "w") as f: + json.dump(mask_search_dict, f, indent=4) + + print(f"Successfully saved mask research results to {filepath}") + + return filepath diff --git a/sglang/python/sglang/multimodal_gen/runtime/layers/attention/__init__.py b/sglang/python/sglang/multimodal_gen/runtime/layers/attention/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..aacb7412be0b3e2acfcd587a9e2ceef660695ff9 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/layers/attention/__init__.py @@ -0,0 +1,30 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import ( + AttentionBackend, + AttentionMetadata, + AttentionMetadataBuilder, +) +from sglang.multimodal_gen.runtime.layers.attention.layer import ( + LocalAttention, + UlyssesAttention, + UlyssesAttention_VSA, + USPAttention, +) +from sglang.multimodal_gen.runtime.layers.attention.selector import get_attn_backend +from sglang.multimodal_gen.runtime.layers.attention.turbo_layer import MinimalA2AAttnOp + +__all__ = [ + "USPAttention", + "LocalAttention", + "UlyssesAttention", + "UlyssesAttention_VSA", + "MinimalA2AAttnOp", + "AttentionBackend", + "AttentionMetadata", + "AttentionMetadataBuilder", + # "AttentionState", + "get_attn_backend", +] diff --git a/sglang/python/sglang/multimodal_gen/runtime/layers/attention/backends/__init__.py b/sglang/python/sglang/multimodal_gen/runtime/layers/attention/backends/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..af2eb7d103a81378d30056cf536353bd24621a5e --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/layers/attention/backends/__init__.py @@ -0,0 +1 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo diff --git a/sglang/python/sglang/multimodal_gen/runtime/layers/attention/backends/aiter.py b/sglang/python/sglang/multimodal_gen/runtime/layers/attention/backends/aiter.py new file mode 100644 index 0000000000000000000000000000000000000000..efd69a1c6a2eed9a23a84d66ff5fd70bc1767189 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/layers/attention/backends/aiter.py @@ -0,0 +1,93 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# 
class AITerImpl(AttentionImpl):
    """
    Attention implementation that delegates to ``aiter.flash_attn_func``.

    Only ``causal`` and ``dropout_p`` are kept on the instance. Grouped-query
    attention (``num_kv_heads`` different from ``num_heads``) is rejected at
    construction time.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        softmax_scale: float,
        causal: bool = False,
        num_kv_heads: int | None = None,
        prefix: str = "",
        dropout_p: float = 0.0,
        **extra_impl_args,
    ) -> None:
        """Store the options forwarded to the kernel.

        Raises:
            NotImplementedError: if ``num_kv_heads`` is given and differs from
                ``num_heads``.
        """
        # NOTE(review): num_heads, head_size, softmax_scale, prefix and
        # extra_impl_args are accepted but neither stored nor forwarded, so the
        # kernel's own default scaling is used. Confirm this matches the
        # softmax_scale that callers pass.
        if num_kv_heads is not None and num_kv_heads != num_heads:
            raise NotImplementedError(
                "AITer backend does not support Grouped Query Attention yet."
            )
        self.causal = causal
        self.dropout_p = dropout_p

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AttentionMetadata | None = None,
    ) -> torch.Tensor:
        """
        Performs attention using aiter.flash_attn_func.

        Args:
            query: Query tensor of shape [batch_size, num_heads, seq_len, head_dim]
            key: Key tensor of shape [batch_size, num_heads, seq_len, head_dim]
            value: Value tensor of shape [batch_size, num_heads, seq_len, head_dim]
            attn_metadata: Metadata for the attention operation (unused).

        Returns:
            Output tensor of shape [batch_size, num_heads, seq_len, head_dim]

        NOTE(review): the shapes above are not validated in this method;
        confirm the layout against the aiter API.
        """
        # aiter.flash_attn_func expects tensors in [B, H, S, D] layout,
        # which is what ring_attn provides.
        # The call requests the log-sum-exp (return_lse=True) and returns a
        # pair; the second element is discarded and only the output is returned.
        output, _ = aiter.flash_attn_func(
            query,
            key,
            value,
            dropout_p=self.dropout_p,
            causal=self.causal,
            return_attn_probs=False,
            return_lse=True,
        )
        return output
+ accept_output_buffer: bool = False + + @staticmethod + @abstractmethod + def get_enum() -> AttentionBackendEnum: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_impl_cls() -> type["AttentionImpl"]: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + raise NotImplementedError + + # @staticmethod + # @abstractmethod + # def get_state_cls() -> Type["AttentionState"]: + # raise NotImplementedError + + # @classmethod + # def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata": + # return cls.get_metadata_cls()(*args, **kwargs) + + @staticmethod + @abstractmethod + def get_builder_cls() -> type["AttentionMetadataBuilder"]: + return None + + +@dataclass +class AttentionMetadata: + """Attention metadata for prefill and decode batched together.""" + + # Current step of diffusion process + current_timestep: int + + def asdict_zerocopy(self, skip_fields: set[str] | None = None) -> dict[str, Any]: + """Similar to dataclasses.asdict, but avoids deepcopying.""" + if skip_fields is None: + skip_fields = set() + # Note that if we add dataclasses as fields, they will need + # similar handling. 
+ return { + field.name: getattr(self, field.name) + for field in fields(self) + if field.name not in skip_fields + } + + +T = TypeVar("T", bound=AttentionMetadata) + + +class AttentionMetadataBuilder(ABC, Generic[T]): + """Abstract class for attention metadata builders.""" + + @abstractmethod + def __init__(self) -> None: + """Create the builder, remember some configuration and parameters.""" + raise NotImplementedError + + @abstractmethod + def prepare(self) -> None: + """Prepare for one batch.""" + raise NotImplementedError + + @abstractmethod + def build( + self, + **kwargs: dict[str, Any], + ) -> AttentionMetadata: + """Build attention metadata with on-device tensors.""" + raise NotImplementedError + + +class AttentionLayer(Protocol): + + _k_scale: torch.Tensor + _v_scale: torch.Tensor + _k_scale_float: float + _v_scale_float: float + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: ... + + +class AttentionImpl(ABC, Generic[T]): + + @abstractmethod + def __init__( + self, + num_heads: int, + head_size: int, + softmax_scale: float, + causal: bool = False, + num_kv_heads: int | None = None, + prefix: str = "", + **extra_impl_args, + ) -> None: + raise NotImplementedError + + def preprocess_qkv(self, qkv: torch.Tensor, attn_metadata: T) -> torch.Tensor: + """Preprocess QKV tensor before performing attention operation. + + Default implementation returns the tensor unchanged. + Subclasses can override this to implement custom preprocessing + like reshaping, tiling, scaling, or other transformations. + + Called AFTER all_to_all for distributed attention + + """ + return qkv + + def postprocess_output( + self, + output: torch.Tensor, + attn_metadata: T, + ) -> torch.Tensor: + """Postprocess the output tensor after the attention operation. + + Default implementation returns the tensor unchanged. 
def maybe_contiguous(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Return ``x`` with a unit-stride last dimension.

    ``None`` and tensors whose last-dimension stride is already 1 are returned
    unchanged (same object); any other tensor is replaced by ``x.contiguous()``.
    """
    if x is None or x.stride(-1) == 1:
        return x
    return x.contiguous()
def flash_attn_varlen_func_fake_out(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    qv: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    window_size: Optional[List[int]] = None,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
    return_softmax_lse: bool = False,
    sinks: Optional[torch.Tensor] = None,
    ver: int = 4,
) -> torch.Tensor:
    """Shape-only stand-in for the out-only flash-attention op.

    Validates the inputs with assertions and returns an uninitialised tensor
    created with ``q.new_empty``: ``(batch, seqlen_q, num_head, head_dim_v)``
    when ``cu_seqlens_q`` is None, otherwise ``(q.shape[0], num_head,
    head_dim_v)``. No attention is computed.
    """
    assert ver == 4, "only support flash attention v4"
    # Same normalisation as maybe_contiguous(): unit stride on the last dim.
    q, k, v = (
        t.contiguous() if t is not None and t.stride(-1) != 1 else t
        for t in (q, k, v)
    )
    num_head, head_dim = q.shape[-2:]
    is_varlen = cu_seqlens_q is not None
    if is_varlen:
        batch_size = cu_seqlens_q.shape[0] - 1
        seqlen_q = None
    else:
        batch_size, seqlen_q = q.shape[:2]
    head_dim_v = v.shape[-1]

    if is_varlen:
        assert cu_seqlens_q.shape == (
            batch_size + 1,
        ), "cu_seqlens_q must have shape (batch_size + 1,)"
        assert cu_seqlens_q.dtype == torch.int32, "cu_seqlens_q must be int32"
        assert cu_seqlens_q.stride(0) == 1, "cu_seqlens_q must be contiguous"

    assert q.dtype in [
        torch.float16,
        torch.bfloat16,
    ], "inputs must be float16 or bfloat16"
    assert q.dtype == k.dtype == v.dtype, "inputs must have the same dtype"
    assert head_dim <= 256, "head_dim must be less than or equal to 256"
    alignment = 16 // q.element_size()
    assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}"

    leading_dims = (q.shape[0],) if is_varlen else (batch_size, seqlen_q)
    return q.new_empty(*leading_dims, num_head, head_dim_v)
@register_custom_op(fake_impl=flash_attn_varlen_func_fake_out)
def flash_attn_varlen_func_op(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    qv: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    window_size: Optional[List[int]] = None,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
    return_softmax_lse: bool = False,
    sinks: Optional[torch.Tensor] = None,
    ver: int = 4,
) -> torch.Tensor:
    """Out-only custom op: forwards to ``flash_attn_func`` with LSE disabled.

    ``window_size=None`` is passed on as ``(-1, -1)``.

    Raises:
        ValueError: if ``return_softmax_lse`` is true (this op has a single
            tensor return; the ``_lse`` variant returns the pair).
    """
    if return_softmax_lse:
        raise ValueError(
            "flash_attn_varlen_func_op is out-only op; return_softmax_lse must be False. "
            "Use flash_attn_varlen_func_op_lse for (out, lse)."
        )
    window = (-1, -1) if window_size is None else tuple(window_size)
    forwarded = dict(
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        seqused_q=seqused_q,
        seqused_k=seqused_k,
        page_table=page_table,
        softmax_scale=softmax_scale,
        causal=causal,
        qv=qv,
        q_descale=q_descale,
        k_descale=k_descale,
        v_descale=v_descale,
        window_size=window,
        attention_chunk=attention_chunk,
        softcap=softcap,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        sm_margin=sm_margin,
        # Always False here so the return value is a single tensor.
        return_softmax_lse=False,
        sinks=sinks,
        ver=ver,
    )
    return flash_attn_func(q, k, v, **forwarded)
+ ) + return flash_attn_func( + q, + k, + v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + seqused_q=seqused_q, + seqused_k=seqused_k, + page_table=page_table, + softmax_scale=softmax_scale, + causal=causal, + qv=qv, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=tuple(window_size), + attention_chunk=attention_chunk, + softcap=softcap, + num_splits=num_splits, + pack_gqa=pack_gqa, + sm_margin=sm_margin, + return_softmax_lse=True, + sinks=sinks, + ver=ver, + ) + + +try: + if current_platform.is_hopper(): + from flash_attn_interface import ( + flash_attn_varlen_func as flash_attn_varlen_func_upstream, + ) + else: + flash_attn_varlen_func_upstream = None + +except Exception: + flash_attn_varlen_func_upstream = None + logger.warning( + "flash_attn 3 package is not installed. It's recommended to install flash_attn3 on hopper, otherwise performance is sub-optimal" + ) + +from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionMetadataBuilder, +) + +fa_ver = 3 + + +@lru_cache(maxsize=128) +def _get_cu_seqlens(device_index: int, bsz: int, seqlen: int) -> torch.Tensor: + return torch.arange( + 0, + (bsz + 1) * seqlen, + step=seqlen, + device=torch.device("cuda", device_index), + dtype=torch.int32, + ) + + +@lru_cache(maxsize=256) +def _should_use_upstream_flash_attention( + upstream_available: bool, + upstream_heads_ok: bool, + q_shape: tuple[int, ...], + k_shape: tuple[int, ...], + v_shape: tuple[int, ...], +) -> bool: + if not upstream_available or not upstream_heads_ok: + return False + + if len(q_shape) != 4 or len(k_shape) != 4 or len(v_shape) != 4: + return False + + bsz, seqlen, nheads_q, d = q_shape + bsz_k, seqlen_k, nheads_k, d_k = k_shape + bsz_v, seqlen_v, nheads_v, d_v = v_shape + + if ( + bsz != bsz_k + or bsz != bsz_v + or seqlen != seqlen_k + or seqlen 
!= seqlen_v + or d != d_k + or d != d_v + ): + return False + if nheads_k != nheads_v: + return False + if nheads_k == 0 or (nheads_q % nheads_k) != 0: + return False + return True + + +def set_fa_ver(ver: int) -> None: + global fa_ver + fa_ver = ver + + +@dataclass +class FlashAttentionMetadata: + # Sequence lengths for the forward batch + # Maximum sequence length for query + max_seqlen_q: int = 1 + # Maximum sequence length for key + max_seqlen_k: int = 0 + # Cumulative sequence lengths for query + cu_seqlens_q: torch.Tensor = None + # Cumulative sequence lengths for key + cu_seqlens_k: torch.Tensor = None + + +class FlashAttentionMetadataBuilder(AttentionMetadataBuilder): + def __init__(self) -> None: + pass + + def prepare(self) -> None: + pass + + def build( # type: ignore + self, + raw_latent_shape=list, + **kwargs: dict[str, Any], + ) -> FlashAttentionMetadata: + # TODO: put empty values here to be set at first-run, since the q_len calculation can be complicated + return FlashAttentionMetadata(max_seqlen_q=None, max_seqlen_k=None) + + +class FlashAttentionBackend(AttentionBackend): + accept_output_buffer: bool = True + + @staticmethod + def get_supported_head_sizes() -> list[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @staticmethod + def get_enum() -> AttentionBackendEnum: + return AttentionBackendEnum.FA + + @staticmethod + def get_impl_cls() -> type["FlashAttentionImpl"]: + return FlashAttentionImpl + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + raise NotImplementedError + + @staticmethod + def get_builder_cls() -> type["AttentionMetadataBuilder"]: + return FlashAttentionMetadataBuilder + + +class FlashAttentionImpl(AttentionImpl): + def __init__( + self, + num_heads: int, + head_size: int, + causal: bool, + softmax_scale: float, + num_kv_heads: int | None = None, + prefix: str = "", + **extra_impl_args, + ) -> None: + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_size = head_size + 
self.causal = causal + self.softmax_scale = softmax_scale + self.attention_metadata = FlashAttentionMetadata() + if self.num_kv_heads is None: + self._upstream_heads_ok = True + else: + # For gqa, the num_heads must be a multiple of num_kv_heads + self._upstream_heads_ok = ( + self.num_kv_heads > 0 and (self.num_heads % self.num_kv_heads) == 0 + ) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata = None, + *, + return_softmax_lse: bool = False, + ): + attn_metadata: FlashAttentionMetadata = get_forward_context().attn_metadata + if attn_metadata is not None and attn_metadata.max_seqlen_q is None: + attn_metadata.max_seqlen_q = query.shape[1] + attn_metadata.max_seqlen_k = key.shape[1] + max_seqlen_q = attn_metadata.max_seqlen_q + max_seqlen_k = attn_metadata.max_seqlen_k + else: + max_seqlen_q = query.shape[1] + max_seqlen_k = key.shape[1] + + q_shape = tuple(query.shape) + k_shape = tuple(key.shape) + v_shape = tuple(value.shape) + + use_upstream = _should_use_upstream_flash_attention( + flash_attn_varlen_func_upstream is not None, + self._upstream_heads_ok, + q_shape, + k_shape, + v_shape, + ) + + if use_upstream: + bsz, seqlen, nheads_q, d = q_shape + q_ = query.contiguous() + k_ = key.contiguous() + v_ = value.contiguous() + out = flash_attn_varlen_func_upstream( + q_, + k_, + v_, + None, + None, + seqlen, + seqlen, + softmax_scale=self.softmax_scale, + causal=self.causal, + return_attn_probs=return_softmax_lse, + ) + if return_softmax_lse: + out_tensor, softmax_lse = out + return out_tensor.reshape(bsz, seqlen, nheads_q, -1), softmax_lse + return out.reshape(bsz, seqlen, nheads_q, d) + + # FA version selection: + # - fa_ver == 3: call python function (can return Tensor or (Tensor, Tensor) depending on flag) + # - fa_ver == 4: call custom ops with FIXED return schema + if fa_ver == 3: + flash_attn_op = flash_attn_func + output = flash_attn_op( + q=query, + k=key, + v=value, + 
cu_seqlens_q=None, + cu_seqlens_k=None, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + softmax_scale=self.softmax_scale, + causal=self.causal, + return_softmax_lse=return_softmax_lse, + ver=fa_ver, + ) + return output + + if fa_ver == 4: + if return_softmax_lse: + out_tensor, softmax_lse = flash_attn_varlen_func_op_lse( + q=query, + k=key, + v=value, + cu_seqlens_q=None, + cu_seqlens_k=None, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + softmax_scale=self.softmax_scale, + causal=self.causal, + return_softmax_lse=True, + ver=fa_ver, + ) + return out_tensor, softmax_lse + out_tensor = flash_attn_varlen_func_op( + q=query, + k=key, + v=value, + cu_seqlens_q=None, + cu_seqlens_k=None, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + softmax_scale=self.softmax_scale, + causal=self.causal, + return_softmax_lse=False, + ver=fa_ver, + ) + return out_tensor + + raise ValueError(f"flash attention version {fa_ver} is not supported.") diff --git a/sglang/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn_2.py b/sglang/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn_2.py new file mode 100644 index 0000000000000000000000000000000000000000..62a1974adc4edc82746567d029c415f675feb24e --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn_2.py @@ -0,0 +1,79 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionMetadataBuilder, +) +from sglang.multimodal_gen.runtime.layers.attention.backends.flash_attn import ( + flash_attn_func, +) +from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class 
FlashAttention2Backend(AttentionBackend): + accept_output_buffer: bool = True + + @staticmethod + def get_supported_head_sizes() -> list[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @staticmethod + def get_enum() -> AttentionBackendEnum: + return AttentionBackendEnum.FA2 + + @staticmethod + def get_impl_cls() -> type["FlashAttention2Impl"]: + return FlashAttention2Impl + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + raise NotImplementedError + + @staticmethod + def get_builder_cls() -> type["AttentionMetadataBuilder"]: + raise NotImplementedError + + +class FlashAttention2Impl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + causal: bool, + softmax_scale: float, + num_kv_heads: int | None = None, + prefix: str = "", + **extra_impl_args, + ) -> None: + self.causal = causal + self.softmax_scale = softmax_scale + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata, + ): + output = flash_attn_func( + q=query, # type: ignore[no-untyped-call] + k=key, + v=value, + cu_seqlens_q=None, + cu_seqlens_k=None, + max_seqlen_q=None, + max_seqlen_k=None, + softmax_scale=self.softmax_scale, + causal=self.causal, + ) + return output diff --git a/sglang/python/sglang/multimodal_gen/runtime/layers/attention/backends/sage_attn.py b/sglang/python/sglang/multimodal_gen/runtime/layers/attention/backends/sage_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..973608f6487ab5cc64f253c11ffe5aeffeae9e3a --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/layers/attention/backends/sage_attn.py @@ -0,0 +1,74 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + + +import torch +from sageattention import sageattn + +from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import ( # FlashAttentionMetadata, + AttentionBackend, + 
AttentionImpl, + AttentionMetadata, +) +from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class SageAttentionBackend(AttentionBackend): + accept_output_buffer: bool = True + + @staticmethod + def get_supported_head_sizes() -> list[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @staticmethod + def get_enum() -> AttentionBackendEnum: + return AttentionBackendEnum.SAGE_ATTN + + @staticmethod + def get_impl_cls() -> type["SageAttentionImpl"]: + return SageAttentionImpl + + +class SageAttentionImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + causal: bool, + softmax_scale: float, + num_kv_heads: int | None = None, + prefix: str = "", + **extra_impl_args, + ) -> None: + self.causal = causal + self.softmax_scale = softmax_scale + self.dropout = extra_impl_args.get("dropout_p", 0.0) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata, + *, + return_softmax_lse: bool = False, + ) -> torch.Tensor: + output = sageattn( + query, + key, + value, + # since input is (batch_size, seq_len, head_num, head_dim) + tensor_layout="NHD", + is_causal=self.causal, + sm_scale=self.softmax_scale, + return_lse=return_softmax_lse, + ) + if return_softmax_lse: + output, softmax_lse = output + return output, softmax_lse + return output diff --git a/sglang/python/sglang/multimodal_gen/runtime/layers/attention/backends/sage_attn3.py b/sglang/python/sglang/multimodal_gen/runtime/layers/attention/backends/sage_attn3.py new file mode 100644 index 0000000000000000000000000000000000000000..ef78e80fe816a38edb60bed211e7f15140be2216 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/layers/attention/backends/sage_attn3.py @@ -0,0 +1,92 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: 
Apache-2.0 + +import torch +import torch.nn.functional as F +from sageattn3 import sageattn3_blackwell + +from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, +) +from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class SageAttention3Backend(AttentionBackend): + accept_output_buffer: bool = True + + @staticmethod + def get_supported_head_sizes() -> list[int]: + return [64, 128, 256] + + @staticmethod + def get_enum() -> AttentionBackendEnum: + return AttentionBackendEnum.SAGE_ATTN_3 + + @staticmethod + def get_impl_cls() -> type["SageAttention3Impl"]: + return SageAttention3Impl + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + raise NotImplementedError + + +class SageAttention3Impl(AttentionImpl): + _warned_gqa_fallback_global: bool = False + + def __init__( + self, + num_heads: int, + head_size: int, + causal: bool, + softmax_scale: float, + num_kv_heads: int | None = None, + prefix: str = "", + **extra_impl_args, + ) -> None: + self.causal = causal + self.softmax_scale = softmax_scale + self.dropout = extra_impl_args.get("dropout_p", 0.0) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + # SageAttention3's Blackwell kernel assumes MHA (Hq == Hkv). For GQA/MQA + # (Hq != Hkv), fall back to torch SDPA which supports GQA. 
+ if key.shape[1] != query.shape[1]: + if query.shape[1] % key.shape[1] != 0: + raise ValueError( + "GQA/MQA requires query heads to be a multiple of KV heads, " + f"got q_heads={query.shape[1]} and kv_heads={key.shape[1]}" + ) + if not type(self)._warned_gqa_fallback_global: + logger.warning( + "SageAttention3 does not support GQA/MQA (Hq != Hkv); falling back to torch SDPA." + ) + type(self)._warned_gqa_fallback_global = True + output = F.scaled_dot_product_attention( + query, + key, + value, + is_causal=self.causal, + dropout_p=self.dropout, + scale=self.softmax_scale, + enable_gqa=True, + ) + else: + output = sageattn3_blackwell(query, key, value, is_causal=self.causal) + output = output.transpose(1, 2) + return output diff --git a/sglang/python/sglang/multimodal_gen/runtime/layers/attention/backends/sdpa.py b/sglang/python/sglang/multimodal_gen/runtime/layers/attention/backends/sdpa.py new file mode 100644 index 0000000000000000000000000000000000000000..5523f655920c990f20a85d7281107a183ace715e --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/layers/attention/backends/sdpa.py @@ -0,0 +1,78 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import ( # FlashAttentionMetadata, + AttentionBackend, + AttentionImpl, + AttentionMetadata, +) +from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class SDPABackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_supported_head_sizes() -> list[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @staticmethod + def get_enum() -> AttentionBackendEnum: + return AttentionBackendEnum.TORCH_SDPA + + @staticmethod + def get_impl_cls() -> type["SDPAImpl"]: + return SDPAImpl + + # 
@staticmethod + # def get_metadata_cls() -> Type["AttentionMetadata"]: + # return FlashAttentionMetadata + + +class SDPAImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + causal: bool, + softmax_scale: float, + num_kv_heads: int | None = None, + prefix: str = "", + **extra_impl_args, + ) -> None: + self.causal = causal + self.softmax_scale = softmax_scale + self.dropout = extra_impl_args.get("dropout_p", 0.0) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + # transpose to bs, heads, seq_len, head_dim + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + attn_kwargs = { + "attn_mask": None, + "dropout_p": self.dropout, + "is_causal": self.causal, + "scale": self.softmax_scale, + } + if query.shape[1] != key.shape[1]: + attn_kwargs["enable_gqa"] = True + output = torch.nn.functional.scaled_dot_product_attention( + query, key, value, **attn_kwargs + ) + output = output.transpose(1, 2) + return output diff --git a/sglang/python/sglang/multimodal_gen/runtime/layers/attention/backends/sliding_tile_attn.py b/sglang/python/sglang/multimodal_gen/runtime/layers/attention/backends/sliding_tile_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..37a1acf3014bf9f420abf17c0d37dd973fb3fdb1 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/layers/attention/backends/sliding_tile_attn.py @@ -0,0 +1,316 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +import json +from dataclasses import dataclass +from typing import Any + +import torch +from einops import rearrange + +from sglang.multimodal_gen.runtime.distributed import get_sp_group +from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionMetadataBuilder, +) 
+from sglang.multimodal_gen.runtime.managers.forward_context import ( + ForwardContext, + get_forward_context, +) +from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum +from sglang.multimodal_gen.runtime.server_args import get_global_server_args +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.utils import dict_to_3d_list + +try: + from st_attn import sliding_tile_attention + + st_attn_backend_available = True +except Exception: + st_attn_backend_available = False + +logger = init_logger(__name__) + + +class RangeDict(dict): + + def __getitem__(self, item: int) -> str: + for key in self.keys(): + if isinstance(key, tuple): + low, high = key + if low <= item <= high: + return str(super().__getitem__(key)) + elif key == item: + return str(super().__getitem__(key)) + raise KeyError(f"seq_len {item} not supported for STA") + + +class SlidingTileAttentionBackend(AttentionBackend): + accept_output_buffer: bool = True + + @staticmethod + def get_supported_head_sizes() -> list[int]: + # TODO(will-refactor): check this + return [32, 64, 96, 128, 160, 192, 224, 256] + + @staticmethod + def get_enum() -> AttentionBackendEnum: + return AttentionBackendEnum.SLIDING_TILE_ATTN + + @staticmethod + def get_impl_cls() -> type["SlidingTileAttentionImpl"]: + return SlidingTileAttentionImpl + + @staticmethod + def get_metadata_cls() -> type["SlidingTileAttentionMetadata"]: + return SlidingTileAttentionMetadata + + @staticmethod + def get_builder_cls() -> type["SlidingTileAttentionMetadataBuilder"]: + return SlidingTileAttentionMetadataBuilder + + +@dataclass +class SlidingTileAttentionMetadata(AttentionMetadata): + current_timestep: int + STA_param: list[ + list[Any] + ] # each timestep with one metadata, shape [num_layers, num_heads] + + +class SlidingTileAttentionMetadataBuilder(AttentionMetadataBuilder): + + def __init__(self): + pass + + def prepare(self): + pass + + def build( # type: ignore + self, + 
STA_param: list[list[Any]], + current_timestep: int, + **kwargs: dict[str, Any], + ) -> SlidingTileAttentionMetadata: + param = STA_param + if param is None: + return SlidingTileAttentionMetadata( + current_timestep=current_timestep, STA_param=[] + ) + return SlidingTileAttentionMetadata( + current_timestep=current_timestep, STA_param=param[current_timestep] + ) + + +class SlidingTileAttentionImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + causal: bool, + softmax_scale: float, + num_kv_heads: int | None = None, + prefix: str = "", + **extra_impl_args, + ) -> None: + if not st_attn_backend_available: + raise ValueError("st attn not supported") + # TODO(will-refactor): for now this is the mask strategy, but maybe we should + # have a more general config for STA? + mask_strategy_file_path = ( + get_global_server_args().attention_backend_config.mask_strategy_file_path + ) + if mask_strategy_file_path is None: + raise ValueError("SGLANG_DIFFUSION_ATTENTION_CONFIG is not set") + + # TODO(kevin): get mask strategy for different STA modes + with open(mask_strategy_file_path) as f: + mask_strategy = json.load(f) + self.mask_strategy = dict_to_3d_list(mask_strategy) + + self.prefix = prefix + sp_group = get_sp_group() + self.sp_size = sp_group.world_size + # STA config + self.STA_base_tile_size = [6, 8, 8] + self.dit_seq_shape_mapping = RangeDict( + { + (115200, 115456): "30x48x80", + 82944: "36x48x48", + 69120: "18x48x80", + } + ) + self.full_window_mapping = { + "30x48x80": [5, 6, 10], + "36x48x48": [6, 6, 6], + "18x48x80": [3, 6, 10], + } + + def tile(self, x: torch.Tensor) -> torch.Tensor: + return rearrange( + x, + "b (n_t ts_t n_h ts_h n_w ts_w) h d -> b (n_t n_h n_w ts_t ts_h ts_w) h d", + n_t=self.full_window_size[0], + n_h=self.full_window_size[1], + n_w=self.full_window_size[2], + ts_t=self.STA_base_tile_size[0], + ts_h=self.STA_base_tile_size[1], + ts_w=self.STA_base_tile_size[2], + ) + + def untile(self, x: torch.Tensor) -> 
torch.Tensor: + x = rearrange( + x, + "b (n_t n_h n_w ts_t ts_h ts_w) h d -> b (n_t ts_t n_h ts_h n_w ts_w) h d", + n_t=self.full_window_size[0], + n_h=self.full_window_size[1], + n_w=self.full_window_size[2], + ts_t=self.STA_base_tile_size[0], + ts_h=self.STA_base_tile_size[1], + ts_w=self.STA_base_tile_size[2], + ) + return x + + def preprocess_qkv( + self, + qkv: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + img_sequence_length = qkv.shape[1] + self.dit_seq_shape_str = self.dit_seq_shape_mapping[img_sequence_length] + self.full_window_size = self.full_window_mapping[self.dit_seq_shape_str] + self.dit_seq_shape_int = list(map(int, self.dit_seq_shape_str.split("x"))) + self.img_seq_length = ( + self.dit_seq_shape_int[0] + * self.dit_seq_shape_int[1] + * self.dit_seq_shape_int[2] + ) + return self.tile(qkv) + + def postprocess_output( + self, + output: torch.Tensor, + attn_metadata: SlidingTileAttentionMetadata, + ) -> torch.Tensor: + return self.untile(output) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_metadata: SlidingTileAttentionMetadata, + ) -> torch.Tensor: + if self.mask_strategy is None: + raise ValueError("mask_strategy cannot be None for SlidingTileAttention") + if self.mask_strategy[0] is None: + raise ValueError("mask_strategy[0] cannot be None for SlidingTileAttention") + + timestep = attn_metadata.current_timestep + forward_context: ForwardContext = get_forward_context() + forward_batch = forward_context.forward_batch + if forward_batch is None: + raise ValueError("forward_batch cannot be None") + # pattern:'.double_blocks.0.attn.impl' or '.single_blocks.0.attn.impl' + layer_idx = int(self.prefix.split(".")[-3]) + if attn_metadata.STA_param is None or len(attn_metadata.STA_param) <= layer_idx: + raise ValueError("Invalid STA_param") + STA_param = attn_metadata.STA_param[layer_idx] + + text_length = q.shape[1] - self.img_seq_length + has_text = text_length > 0 + + query = 
q.transpose(1, 2).contiguous() + key = k.transpose(1, 2).contiguous() + value = v.transpose(1, 2).contiguous() + + head_num = query.size(1) + sp_group = get_sp_group() + current_rank = sp_group.rank_in_group + start_head = current_rank * head_num + + # searching or tuning mode + if len(STA_param) < head_num * sp_group.world_size: + sparse_attn_hidden_states_all = [] + full_mask_window = STA_param[-1] + for window_size in STA_param[:-1]: + sparse_hidden_states = sliding_tile_attention( + query, + key, + value, + [window_size] * head_num, + text_length, + has_text, + self.dit_seq_shape_str, + ).transpose(1, 2) + sparse_attn_hidden_states_all.append(sparse_hidden_states) + + hidden_states = sliding_tile_attention( + query, + key, + value, + [full_mask_window] * head_num, + text_length, + has_text, + self.dit_seq_shape_str, + ).transpose(1, 2) + + attn_L2_loss = [] + attn_L1_loss = [] + # average loss across all heads + for sparse_attn_hidden_states in sparse_attn_hidden_states_all: + # L2 loss + attn_L2_loss_ = ( + torch.mean( + (sparse_attn_hidden_states.float() - hidden_states.float()) + ** 2, + dim=[0, 1, 3], + ) + .cpu() + .numpy() + ) + attn_L2_loss_ = [round(float(x), 6) for x in attn_L2_loss_] + attn_L2_loss.append(attn_L2_loss_) + # L1 loss + attn_L1_loss_ = ( + torch.mean( + torch.abs( + sparse_attn_hidden_states.float() - hidden_states.float() + ), + dim=[0, 1, 3], + ) + .cpu() + .numpy() + ) + attn_L1_loss_ = [round(float(x), 6) for x in attn_L1_loss_] + attn_L1_loss.append(attn_L1_loss_) + + layer_loss_save = {"L2_loss": attn_L2_loss, "L1_loss": attn_L1_loss} + + if forward_batch.is_cfg_negative: + if forward_batch.mask_search_final_result_neg is not None: + forward_batch.mask_search_final_result_neg[timestep].append( + layer_loss_save + ) + else: + if forward_batch.mask_search_final_result_pos is not None: + forward_batch.mask_search_final_result_pos[timestep].append( + layer_loss_save + ) + else: + windows = [STA_param[head_idx + start_head] for head_idx 
in range(head_num)] + + hidden_states = sliding_tile_attention( + query, + key, + value, + windows, + text_length, + has_text, + self.dit_seq_shape_str, + ).transpose(1, 2) + + return hidden_states diff --git a/sglang/python/sglang/multimodal_gen/runtime/layers/attention/backends/sparse_linear_attn.py b/sglang/python/sglang/multimodal_gen/runtime/layers/attention/backends/sparse_linear_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..c6c25ebf2e87ad31deda2e18d0dbfd85c2ad0798 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/layers/attention/backends/sparse_linear_attn.py @@ -0,0 +1,695 @@ +""" +Copyright (c) 2025 by SLA team. + +Licensed under the Apache License, Version 2.0 (the "License"); + +This implementation is adapted from: from https://github.com/thu-ml/TurboDiffusion/blob/main/turbodiffusion/SLA/core.py and https://github.com/thu-ml/SLA/blob/main/SageSLA/core.py +Citation (please cite if you use this code): + +@article{zhang2025sla, + title={SLA: Beyond Sparsity in Diffusion Transformers via Fine-Tunable Sparse-Linear Attention}, + author={Jintao Zhang and Haoxu Wang and Kai Jiang and Shuo Yang and Kaiwen Zheng and Haocheng Xi and Ziteng Wang and Hongzhou Zhu and Min Zhao and Ion Stoica and Joseph E. 
Gonzalez and Jun Zhu and Jianfei Chen}, + journal={arXiv preprint arXiv:2509.24006}, + year={2025} +} +""" + +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton +import triton.language as tl + +from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionMetadataBuilder, +) +from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +# ==================================SLA Functions=================================== +def get_block_map(q, k, topk_ratio, BLKQ=64, BLKK=64): + arg_k = k - torch.mean( + k, dim=-2, keepdim=True + ) # smooth-k technique in SageAttention + pooled_qblocks = mean_pool(q, BLKQ) + pooled_kblocks = mean_pool(arg_k, BLKK) + pooled_score = pooled_qblocks @ pooled_kblocks.transpose(-1, -2) + + K = pooled_score.shape[-1] + topk = min(K, int(topk_ratio * K)) + lut = torch.topk(pooled_score, topk, dim=-1, sorted=False).indices + + sparse_map = torch.zeros_like(pooled_score, dtype=torch.int8) + sparse_map.scatter_(-1, lut, 1) + return sparse_map, lut, topk + + +def mean_pool(x, BLK): + assert x.is_contiguous() + + B, H, L, D = x.shape + L_BLOCKS = (L + BLK - 1) // BLK + x_mean = torch.empty((B, H, L_BLOCKS, D), device=x.device, dtype=x.dtype) + + grid = (L_BLOCKS, B * H) + compress_kernel[grid](x, x_mean, L, D, BLK) + return x_mean + + +@triton.jit +def compress_kernel( + X, + XM, + L: tl.constexpr, + D: tl.constexpr, + BLOCK_L: tl.constexpr, +): + idx_l = tl.program_id(0) + idx_bh = tl.program_id(1) + + offs_l = idx_l * BLOCK_L + tl.arange(0, BLOCK_L) + offs_d = tl.arange(0, D) + + x_offset = idx_bh * L * D + xm_offset = idx_bh * ((L + BLOCK_L - 1) // BLOCK_L) * D + x = tl.load( + X + x_offset + 
# Block-sparse attention forward kernel.
#
# Launch grid is (M_BLOCKS, B * H): program 0 indexes the query block, program 1
# the flattened batch*head. Q/K/V/OS are addressed as contiguous [B*H, L, D]
# buffers. For each query block the kernel walks the `topk` key-block indices
# stored in LUT and accumulates an online softmax in base 2 (scores are
# pre-multiplied by log2(e) so exp2/log2 can be used).
#
# Outputs:
#   OS  - normalised attention output, same layout as Q.
#   LSE - per-row log-sum-exp in base 2 (running max + log2 of the denominator).
@triton.jit
def _attn_fwd(
    Q,
    K,
    V,
    qk_scale: tl.constexpr,
    topk: tl.constexpr,
    LUT,
    LSE,
    OS,
    L: tl.constexpr,
    M_BLOCKS: tl.constexpr,
    D: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    idx_m = tl.program_id(0).to(tl.int64)
    idx_bh = tl.program_id(1).to(tl.int64)

    qkv_offset = idx_bh * L * D
    lut_offset = (idx_bh * M_BLOCKS + idx_m) * topk
    lse_offset = idx_bh * L
    offs_m = idx_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, D)

    Q_ptrs = Q + qkv_offset + offs_m[:, None] * D + offs_d[None, :]
    # K is addressed transposed ([D, BLOCK_N]) so tl.dot(q, k) yields [BLOCK_M, BLOCK_N].
    K_ptrs = K + qkv_offset + offs_n[None, :] * D + offs_d[:, None]
    V_ptrs = V + qkv_offset + offs_n[:, None] * D + offs_d[None, :]
    OS_ptrs = OS + qkv_offset + offs_m[:, None] * D + offs_d[None, :]
    LUT_ptr = LUT + lut_offset
    LSE_ptrs = LSE + lse_offset + offs_m

    m_i = tl.full([BLOCK_M], -float("inf"), dtype=tl.float32)  # running row max
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)  # running softmax denominator
    o_s = tl.zeros([BLOCK_M, D], dtype=tl.float32)  # unnormalised output

    # `other=0.0` pins masked-out lanes to zero instead of leaving them undefined.
    q = tl.load(Q_ptrs, mask=offs_m[:, None] < L, other=0.0)
    for block_idx in tl.range(topk):
        idx_n = tl.load(LUT_ptr + block_idx)
        n_mask = offs_n < L - idx_n * BLOCK_N

        k = tl.load(K_ptrs + idx_n * BLOCK_N * D, mask=n_mask[None, :], other=0.0)
        qk = tl.dot(q, k) * (qk_scale * 1.4426950408889634)  # = 1 / ln(2)
        # Only the ragged last key block needs its padded columns forced to -inf.
        if L - idx_n * BLOCK_N < BLOCK_N:
            qk = tl.where(n_mask[None, :], qk, float("-inf"))

        # Padded V rows are multiplied by p == 0 below; they must be finite
        # (zero) so that 0 * value cannot produce NaN.
        v = tl.load(V_ptrs + idx_n * BLOCK_N * D, mask=n_mask[:, None], other=0.0)
        local_m = tl.max(qk, 1)
        new_m = tl.maximum(m_i, local_m)
        qk = qk - new_m[:, None]

        p = tl.math.exp2(qk)
        l_ij = tl.sum(p, 1)
        # Rescale what was accumulated so far to the new running maximum.
        alpha = tl.math.exp2(m_i - new_m)
        o_s = o_s * alpha[:, None]
        o_s += tl.dot(p.to(v.dtype), v)

        l_i = l_i * alpha + l_ij
        m_i = new_m

    o_s = o_s / l_i[:, None]
    tl.store(OS_ptrs, o_s.to(OS.type.element_ty), mask=offs_m[:, None] < L)

    m_i += tl.math.log2(l_i)
    tl.store(LSE_ptrs, m_i, mask=offs_m < L)
def forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_metadata: SparseLinearAttentionMetadata | None = None,
) -> torch.Tensor:
    """Forward pass for sparse linear attention.

    Axes 1 and 2 of the inputs are swapped before any computation and the
    result is swapped back on return; the helpers called in between
    (``get_block_map`` / ``_attention``) unpack their input as (B, H, L, D).
    The inputs are therefore expected as (B, L, H, D).

    Args:
        query: query tensor of shape (B, L, H, D)
        key: key tensor of shape (B, L, H, D)
        value: value tensor of shape (B, L, H, D)
        attn_metadata: attention metadata; not read by this method (the
            top-k ratio and block sizes come from the module's own attributes)
    Returns:
        output tensor of shape (B, L, H, D), cast back to ``query``'s dtype
    """
    # Remember the caller's dtype so the result can be cast back at the end.
    dtype = query.dtype

    # Transpose for computation
    query = query.transpose(1, 2).contiguous()
    key = key.transpose(1, 2).contiguous()
    value = value.transpose(1, 2).contiguous()

    # Get sparse attention map
    sparse_map, lut, real_topk = get_block_map(
        query, key, topk_ratio=self.topk_ratio, BLKQ=self.BLKQ, BLKK=self.BLKK
    )

    # Convert to computation dtype
    query = query.to(self.dtype)
    key = key.to(self.dtype)
    value = value.to(self.dtype)

    # Sparse attention computation
    o_s = _attention.apply(
        query, key, value, sparse_map, lut, real_topk, self.BLKQ, self.BLKK
    )

    # Apply feature maps
    query = self.feature_map_q(query).contiguous().to(self.dtype)  # c_q
    key = self.feature_map_k(key).contiguous().to(self.dtype)  # c_k
    # Linear attention computation
    o_l = self._calc_linear_attention_with_torch(query, key, value)

    # Apply projection and combine results
    # proj_l is created in float32; autocast runs it in the compute dtype.
    with torch.amp.autocast("cuda", dtype=self.dtype):
        o_l = self.proj_l(o_l)

    # Combine sparse and linear attention
    output = (o_s + o_l).to(dtype).transpose(1, 2)

    return output
+ + o_s = torch.empty_like(v) + lse = torch.empty(q.shape[:-1], device=q.device, dtype=torch.float32) + + grid = (M_BLOCKS, B * H) + _attn_fwd[grid]( + q, + k, + v, + qk_scale, + topk, + lut, + lse, + o_s, + L, + M_BLOCKS, + D, + BLOCK_M, + BLOCK_N, + num_warps=4 if q.shape[-1] == 64 else 8, + num_stages=3, + ) + + ctx.save_for_backward(q, k, v, k_block_id, lut, lse, o_s) + ctx.qk_scale = qk_scale + ctx.topk = topk + ctx.BLOCK_M = BLOCK_M + ctx.BLOCK_N = BLOCK_N + return o_s + + +# ==================================SageSLA Class=================================== +SAGESLA_ENABLED = True +try: + import spas_sage_attn._fused as fused + import spas_sage_attn._qattn as qattn + from spas_sage_attn.utils import block_map_lut_triton, get_vanilla_qk_quant +except ImportError: + SAGESLA_ENABLED = False + +SAGE2PP_ENABLED = True +try: + from spas_sage_attn._qattn import ( + qk_int8_sv_f8_accum_f16_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold, + ) +except ImportError: + SAGE2PP_ENABLED = False + + +class SageSparseLinearAttentionBackend(AttentionBackend): + """Quantized Sparse-Linear Attention backend using SageAttention kernels.""" + + accept_output_buffer: bool = True + + @staticmethod + def get_supported_head_sizes() -> list[int]: + return [64, 128] + + @staticmethod + def get_enum() -> AttentionBackendEnum: + return AttentionBackendEnum.SAGE_SLA_ATTN + + @staticmethod + def get_impl_cls() -> type["SageSparseLinearAttentionImpl"]: + return SageSparseLinearAttentionImpl + + @staticmethod + def get_metadata_cls() -> type["SageSparseLinearAttentionMetadata"]: + return SageSparseLinearAttentionMetadata + + @staticmethod + def get_builder_cls() -> type["SageSparseLinearAttentionMetadataBuilder"]: + return SageSparseLinearAttentionMetadataBuilder + + +@dataclass +class SageSparseLinearAttentionMetadata(AttentionMetadata): + """Metadata for Sage Sparse Linear Attention computation.""" + + # Basic attention parameters + current_timestep: int + + # Sparse attention 
configuration + topk_ratio: float = 0.1 + + +class SageSparseLinearAttentionMetadataBuilder(AttentionMetadataBuilder): + """Builder for SageSparseLinearAttentionMetadata.""" + + def __init__(self) -> None: + pass + + def prepare(self) -> None: + pass + + def build( + self, + current_timestep: int, + topk_ratio: float = 0.1, + **kwargs: dict[str, Any], + ) -> SageSparseLinearAttentionMetadata: + return SageSparseLinearAttentionMetadata( + current_timestep=current_timestep, + topk_ratio=topk_ratio, + ) + + +class SageSparseLinearAttentionImpl(AttentionImpl, nn.Module): + def __init__( + self, + num_heads: int, + head_size: int, + causal: bool = False, + softmax_scale: float | None = None, + num_kv_heads: int | None = None, + prefix: str = "", + topk_ratio: float = 0.5, + feature_map: str = "softmax", + use_bf16: bool = True, + **extra_impl_args, + ) -> None: + nn.Module.__init__(self) + + assert ( + SAGESLA_ENABLED + ), "Install spas_sage_attn(pip install git+https://github.com/thu-ml/SpargeAttn.git --no-build-isolation) first to enable SageSLA." 
+ + self.num_heads = num_heads + self.head_size = head_size + self.softmax_scale = softmax_scale if softmax_scale else head_size**-0.5 + self.causal = causal + self.prefix = prefix + + self.topk_ratio = topk_ratio + self.dtype = torch.bfloat16 if use_bf16 else torch.float16 + + # Learnable linear projection for combining sparse + linear attention + self.proj_l = nn.Linear(head_size, head_size, dtype=torch.float32) + + # Feature map for linear attention + # Type annotation for callables + self.feature_map_q: Callable[[torch.Tensor], torch.Tensor] + self.feature_map_k: Callable[[torch.Tensor], torch.Tensor] + if feature_map == "elu": + self.feature_map_q = lambda x: F.elu(x) + 1 + self.feature_map_k = lambda x: F.elu(x) + 1 + elif feature_map == "relu": + self.feature_map_q = F.relu + self.feature_map_k = F.relu + elif feature_map == "softmax": + self.feature_map_q = lambda x: F.softmax(x, dim=-1) + self.feature_map_k = lambda x: F.softmax(x, dim=-1) + else: + raise ValueError(f"Unknown feature map: {feature_map}") + + self._init_weights() + + def _init_weights(self) -> None: + """Initialize projection weights to zero for residual-like behavior.""" + with torch.no_grad(): + nn.init.zeros_(self.proj_l.weight) + nn.init.zeros_(self.proj_l.bias) # type: ignore[arg-type] + + def _calc_linear_attention_with_torch( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + ): + kv = torch.matmul(k.transpose(-1, -2), v) + k_sum = torch.sum(k, dim=-2, keepdim=True) + return torch.matmul(q, kv) / (1e-5 + torch.matmul(q, k_sum.transpose(-1, -2))) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + """Forward pass for Sage Sparse Linear attention with quantized kernels. 
+ Args: + query: query tensor of shape (B, L, H, D) + key: key tensor of shape (B, L, H, D) + value: value tensor of shape (B, L, H, D) + attn_metadata: attention metadata containing configuration + Returns: + output tensor of shape (B, L, H, D) + """ + dtype = query.dtype + + # Transpose from (B, L, H, D) to SLA format (B, H, L, D) + q = query.transpose(1, 2).contiguous() + k = key.transpose(1, 2).contiguous() + v = value.transpose(1, 2).contiguous() + + # Determine block sizes based on GPU architecture + arch = _get_cuda_arch(q.device.index) + + if arch == "sm90": + BLKQ = 64 + BLKK = 128 + else: + BLKQ = 128 + BLKK = 64 + # Compute block-sparse attention pattern + sparse_map, lut, real_topk = get_block_map( + q, k, topk_ratio=self.topk_ratio, BLKQ=BLKQ, BLKK=BLKK + ) + + # Convert to compute dtype + q = q.to(self.dtype) + k = k.to(self.dtype) + v = v.to(self.dtype) + + ########## SPARGE BEGIN ########## + km = k.mean(dim=-2, keepdim=True) + headdim = q.size(-1) + assert headdim in [ + 64, + 128, + ], "headdim should be in [64, 128]. For other headdim, you can use padding and specify the softmax scale." 
+ + # Quantize Q, K to INT8 + q_int8, q_scale, k_int8, k_scale = get_vanilla_qk_quant(q, k, km, BLKQ, BLKK) + lut, valid_block_num = block_map_lut_triton(sparse_map) + scale = 1.0 / (headdim**0.5) + + o_s = torch.empty_like(q) + + if arch in ("sm80", "sm86", "sm87"): + pvthreshold = torch.full( + (q.shape[-3],), 1e6, dtype=torch.float32, device=q.device + ) + v_fp16 = v.to(torch.float16) + qattn.qk_int8_sv_f16_accum_f16_block_sparse_attn_inst_buf_with_pv_threshold( + q_int8, + k_int8, + v_fp16, + o_s, + lut, + valid_block_num, + pvthreshold, + q_scale, + k_scale, + 1, + False, + 1, + scale, + 0, + ) + else: + b, h_kv, kv_len, head_dim = v.shape + padded_len = (kv_len + 127) // 128 * 128 + v_transposed_permutted = torch.empty( + (b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device + ) + fused.transpose_pad_permute_cuda(v, v_transposed_permutted, 1) + v_fp8 = torch.empty( + v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device + ) + v_scale = torch.empty( + (b, h_kv, head_dim), dtype=torch.float32, device=v.device + ) + fused.scale_fuse_quant_cuda( + v_transposed_permutted, v_fp8, v_scale, kv_len, 2.25, 1 + ) + + if arch == "sm90": + qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_sm90( + q_int8, + k_int8, + v_fp8, + o_s, + lut, + valid_block_num, + q_scale, + k_scale, + v_scale, + 1, + False, + 1, + scale, + ) + else: + pvthreshold = torch.full( + (q.shape[-3],), 1e6, dtype=torch.float32, device=q.device + ) + if SAGE2PP_ENABLED: + qk_int8_sv_f8_accum_f16_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold( + q_int8, + k_int8, + v_fp8, + o_s, + lut, + valid_block_num, + pvthreshold, + q_scale, + k_scale, + v_scale, + 1, + False, + 1, + scale, + 0, + ) + else: + qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold( + q_int8, + k_int8, + v_fp8, + o_s, + lut, + valid_block_num, + pvthreshold, + q_scale, + k_scale, + v_scale, + 1, + False, + 1, + scale, + 0, + ) + + ########## 
SPARGE END ########## + + # Linear attention with feature maps + q_linear = self.feature_map_q(q).contiguous().to(self.dtype) + k_linear = self.feature_map_k(k).contiguous().to(self.dtype) + o_l = self._calc_linear_attention_with_torch(q_linear, k_linear, v) + + # Project linear attention output and combine + with torch.amp.autocast("cuda", dtype=self.dtype): + o_l = self.proj_l(o_l) + + # Combine sparse and linear outputs + output = (o_s + o_l).to(dtype).transpose(1, 2) + + return output diff --git a/sglang/python/sglang/multimodal_gen/runtime/layers/attention/backends/sparse_video_gen_2_attn.py b/sglang/python/sglang/multimodal_gen/runtime/layers/attention/backends/sparse_video_gen_2_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..0d07259c011348d33b2141d56c901e4c33b2ded7 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/layers/attention/backends/sparse_video_gen_2_attn.py @@ -0,0 +1,562 @@ +""" +Sparse Video Gen 2 (SAP) attention backend. + +This is a baseline integration that wires the backend into the +attention framework. 
+ +Adapted from https://github.com/svg-project/Sparse-VideoGen/blob/main/svg/models/wan/attention.py +""" + +from dataclasses import dataclass, field +from typing import Any + +import torch +import torch.nn.functional as F +from torch.nn.attention import SDPBackend, sdpa_kernel + +try: + from svg.kernels.triton.permute import ( + apply_inverse_permutation_triton, + permute_tensor_by_labels_triton, + ) + from svg.kmeans_utils import ( + batch_kmeans_Euclid, + dynamic_block_sparse_fwd_flashinfer, + identify_dynamic_map, + ) + + svg2_available = True +except ImportError: + svg2_available = False + +from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionMetadataBuilder, +) +from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class SparseVideoGen2AttentionBackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_supported_head_sizes() -> list[int]: + return [64, 128, 256] + + @staticmethod + def get_enum() -> AttentionBackendEnum: + return AttentionBackendEnum.SPARSE_VIDEO_GEN_2_ATTN + + @staticmethod + def get_impl_cls() -> type["SparseVideoGen2AttentionImpl"]: + return SparseVideoGen2AttentionImpl + + @staticmethod + def get_metadata_cls() -> type["SparseVideoGen2AttentionMetadata"]: + return SparseVideoGen2AttentionMetadata + + @staticmethod + def get_builder_cls() -> type["SparseVideoGen2AttentionMetadataBuilder"]: + return SparseVideoGen2AttentionMetadataBuilder + + +@dataclass +class Svg2LayerCache: + # centroids for kmeans clustering + q_centroids: torch.Tensor | None = None + k_centroids: torch.Tensor | None = None + centroids_initialized: bool = False + + +@dataclass +class Svg2Cache: + layers: dict[int, Svg2LayerCache] = field(default_factory=dict) + + def get_layer(self, layer_idx: int) -> 
Svg2LayerCache: + layer_cache = self.layers.get(layer_idx) + if layer_cache is None: + layer_cache = Svg2LayerCache() + self.layers[layer_idx] = layer_cache + return layer_cache + + +@dataclass +class SparseVideoGen2AttentionMetadata(AttentionMetadata): + current_timestep: int + num_q_centroids: int + num_k_centroids: int + top_p_kmeans: float + min_kc_ratio: float + kmeans_iter_init: int + kmeans_iter_step: int + zero_step_kmeans_init: bool + first_layers_fp: float + first_times_fp: float + context_length: int + num_frame: int + frame_size: int + cache: Svg2Cache + prompt_length: int | None = None + max_seqlen_q: int | None = None + max_seqlen_k: int | None = None + + +def _require_kwarg(kwargs: dict[str, Any], name: str) -> Any: + if name not in kwargs: + raise ValueError( + f"Missing required argument for SparseVideoGen2Attention: {name}" + ) + return kwargs[name] + + +class SparseVideoGen2AttentionMetadataBuilder(AttentionMetadataBuilder): + + def __init__(self) -> None: + pass + + def prepare(self) -> None: + pass + + def build( # type: ignore[override] + self, + current_timestep: int, + raw_latent_shape: tuple[int, ...], + patch_size: tuple[int, int, int], + cache: Svg2Cache, + num_q_centroids: int, + num_k_centroids: int, + top_p_kmeans: float, + min_kc_ratio: float, + kmeans_iter_init: int, + kmeans_iter_step: int, + zero_step_kmeans_init: bool, + first_layers_fp: float, + first_times_fp: float, + context_length: int = 0, + prompt_length: int | None = None, + **kwargs: dict[str, Any], + ) -> SparseVideoGen2AttentionMetadata: + raw_shape = tuple(raw_latent_shape) + if len(raw_shape) == 5: + t, h, w = raw_shape[2:5] + elif len(raw_shape) == 3: + t, h, w = raw_shape + else: + raise ValueError( + "raw_latent_shape must be (T, H, W) or (B, C, T, H, W) for SAP attention" + ) + pt, ph, pw = patch_size + if t % pt != 0 or h % ph != 0 or w % pw != 0: + raise ValueError( + "raw_latent_shape must be divisible by patch_size for SAP attention" + ) + + num_frame = t // 
pt + frame_size = (h // ph) * (w // pw) + + return SparseVideoGen2AttentionMetadata( + current_timestep=current_timestep, + num_q_centroids=num_q_centroids, + num_k_centroids=num_k_centroids, + top_p_kmeans=top_p_kmeans, + min_kc_ratio=min_kc_ratio, + kmeans_iter_init=kmeans_iter_init, + kmeans_iter_step=kmeans_iter_step, + zero_step_kmeans_init=zero_step_kmeans_init, + first_layers_fp=first_layers_fp, + first_times_fp=first_times_fp, + context_length=context_length, + prompt_length=prompt_length, + num_frame=num_frame, + frame_size=frame_size, + cache=cache, + ) + + +class SparseVideoGen2AttentionImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + causal: bool, + softmax_scale: float, + num_kv_heads: int | None = None, + prefix: str = "", + **extra_impl_args, + ) -> None: + if causal: + raise ValueError( + "Sparse Video Gen 2 attention does not support causal attention" + ) + if not svg2_available: + raise ImportError( + "Sparse Video Gen 2 attention backend requires svg package to be installed" + "Please install it by following the instructions at " + "https://github.com/svg-project/Sparse-VideoGen" + ) + self.prefix = prefix + self.layer_idx = self._get_layer_idx(prefix) + + def _get_layer_idx(self, prefix: str) -> int: + parts = prefix.split(".") + if len(parts) < 3: + raise ValueError( + f"Invalid prefix for SparseVideoGen2AttentionImpl: {prefix}" + ) + return int(parts[-3]) + + def kmeans_init( + self, + query: torch.Tensor, + key: torch.Tensor, + attn_metadata: SparseVideoGen2AttentionMetadata, + ): + cfg, num_heads, seq_len, dim = query.size() + qlabels, qcentroids, qcluster_sizes, qiter = batch_kmeans_Euclid( + query.reshape(cfg * num_heads, seq_len, dim), + n_clusters=attn_metadata.num_q_centroids, + max_iters=attn_metadata.kmeans_iter_init, + ) + klabels, kcentroids, kcluster_sizes, kiter = batch_kmeans_Euclid( + key.reshape(cfg * num_heads, seq_len, dim), + n_clusters=attn_metadata.num_k_centroids, + 
def kmeans_clustering(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    attn_metadata: SparseVideoGen2AttentionMetadata,
):
    """Cluster queries and keys for this layer, initialising centroids on first use.

    The layer's cache entry records whether centroids already exist: the
    first call runs ``kmeans_init`` and marks the cache as initialised, every
    later call runs ``kmeans_step``.

    Returns:
        The 8-tuple ``(qlabels, qcentroids, qcluster_sizes, qiter, klabels,
        kcentroids, kcluster_sizes, kiter)`` produced by the helper that ran.
    """
    layer_cache = attn_metadata.cache.get_layer(self.layer_idx)

    if layer_cache.centroids_initialized:
        clustering = self.kmeans_step(query, key, attn_metadata)
    else:
        clustering = self.kmeans_init(query, key, attn_metadata)
        layer_cache.centroids_initialized = True
        logger.debug(
            "Centroids initialized at layer %s (init iters: %s).",
            self.layer_idx,
            attn_metadata.kmeans_iter_init,
        )

    # Unpack and rebuild so the result is always a fresh tuple of exactly eight items.
    (
        qlabels,
        qcentroids,
        qcluster_sizes,
        qiter,
        klabels,
        kcentroids,
        kcluster_sizes,
        kiter,
    ) = clustering
    return (
        qlabels,
        qcentroids,
        qcluster_sizes,
        qiter,
        klabels,
        kcentroids,
        kcluster_sizes,
        kiter,
    )
Permute the query, key, value + q_permuted, q_sorted_indices = permute_tensor_by_labels_triton( + query, qlabels, dim=2 + ) + k_permuted, k_sorted_indices = permute_tensor_by_labels_triton( + key, klabels, dim=2 + ) + v_permuted, v_sorted_indices = permute_tensor_by_labels_triton( + value, klabels, dim=2, sorted_indices=k_sorted_indices + ) + + return ( + q_permuted, + k_permuted, + v_permuted, + dynamic_map, + q_cluster_sizes, + k_cluster_sizes, + q_sorted_indices, + ) + + def _hunyuan_dynamic_map_post_processing( + self, + q_perm: torch.Tensor, + k_perm: torch.Tensor, + v_perm: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dyn_map: torch.Tensor, + qc_sz_s: torch.Tensor, + kc_sz_s: torch.Tensor, + q_sorted_indices: torch.Tensor, + video_length: int, + context_length: int, + prompt_length: int, + unprompt_length: int, + ) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ]: + # Place the permuted video tokens back and keep text tokens at the tail. + query[:, :, :-context_length, :] = q_perm + key[:, :, :-context_length, :] = k_perm + value[:, :, :-context_length, :] = v_perm + + # Add prompt/unprompt clusters to the dynamic map. 
def forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_metadata: SparseVideoGen2AttentionMetadata,
) -> torch.Tensor:
    """Run dense or semantic-aware block-sparse attention for one layer.

    Inputs are transposed from (batch, seq, heads, dim) to
    (batch, heads, seq, dim) for the computation and the result is transposed
    back. The sequence must hold ``num_frame * frame_size`` video tokens
    followed by ``context_length`` context tokens.

    Dense attention is used when this layer's index is below
    ``first_layers_fp`` or the current timestep is above ``first_times_fp``;
    otherwise the video tokens are clustered, permuted and passed to the
    block-sparse kernel.

    The preferred CUDA linalg backend is switched to "magma" for the duration
    of the call and is reset to "default" on every exit path, including when
    an exception is raised.

    Raises:
        AssertionError: if the sequence length does not equal
            ``context_length + num_frame * frame_size``.
    """
    torch.backends.cuda.preferred_linalg_library(backend="magma")
    try:
        res = None
        # bshd -> bhsd
        query = query.transpose(1, 2).contiguous()
        key = key.transpose(1, 2).contiguous()
        value = value.transpose(1, 2).contiguous()
        batch_size, num_heads, seq_len, dim = query.size()

        context_length, num_frame, frame_size = (
            attn_metadata.context_length,
            attn_metadata.num_frame,
            attn_metadata.frame_size,
        )
        prompt_length = attn_metadata.prompt_length
        if prompt_length is None:
            prompt_length = context_length

        assert (
            seq_len == context_length + num_frame * frame_size
        ), f"Query Shape: {seq_len} is not equivalent to {context_length} + {num_frame} * {frame_size}"

        # Determine if we use Full Attention to calculate
        full_attention_flag = False

        if self.layer_idx < attn_metadata.first_layers_fp:
            full_attention_flag = True
        if attn_metadata.current_timestep > attn_metadata.first_times_fp:
            full_attention_flag = True

        if full_attention_flag:
            if attn_metadata.zero_step_kmeans_init:
                # Cluster only for the side effect of updating the layer's
                # cached centroids; the dense path ignores the result.
                video_length = attn_metadata.num_frame * attn_metadata.frame_size
                query_video = query[:, :, :video_length, :].contiguous()
                key_video = key[:, :, :video_length, :].contiguous()
                self.kmeans_clustering(query_video, key_video, attn_metadata)

            with sdpa_kernel(
                SDPBackend.CUDNN_ATTENTION
            ):  # not sure why we need to force cudnn here, but it's faster than flash attention
                output_hidden_states = torch.nn.functional.scaled_dot_product_attention(
                    query, key, value, dropout_p=0.0, is_causal=False
                )

            res = output_hidden_states.reshape(
                batch_size, num_heads, seq_len, dim
            ).transpose(1, 2)
        else:
            if context_length > 0:
                # Cluster and permute the video tokens only, then append the
                # context tokens as extra clusters.
                video_length = num_frame * frame_size
                unprompt_length = max(context_length - prompt_length, 0)
                query_video = query[:, :, :video_length, :].contiguous()
                key_video = key[:, :, :video_length, :].contiguous()
                value_video = value[:, :, :video_length, :].contiguous()

                (
                    q_perm,
                    k_perm,
                    v_perm,
                    dyn_map,
                    qc_sz_s,
                    kc_sz_s,
                    q_sorted_indices,
                ) = self.semantic_aware_permutation(
                    query_video, key_video, value_video, attn_metadata
                )
                (
                    q_perm,
                    k_perm,
                    v_perm,
                    dyn_map,
                    qc_sz_s,
                    kc_sz_s,
                    q_sorted_indices,
                ) = self._hunyuan_dynamic_map_post_processing(
                    q_perm,
                    k_perm,
                    v_perm,
                    query,
                    key,
                    value,
                    dyn_map,
                    qc_sz_s,
                    kc_sz_s,
                    q_sorted_indices,
                    video_length,
                    context_length,
                    prompt_length,
                    unprompt_length,
                )
            else:
                (
                    q_perm,
                    k_perm,
                    v_perm,
                    dyn_map,
                    qc_sz_s,
                    kc_sz_s,
                    q_sorted_indices,
                ) = self.semantic_aware_permutation(query, key, value, attn_metadata)

            output_permuted = dynamic_block_sparse_fwd_flashinfer(
                q_perm, k_perm, v_perm, dyn_map, qc_sz_s, kc_sz_s, is_cpu=False
            )

            # Undo the cluster ordering so rows line up with the input tokens.
            attn_output = apply_inverse_permutation_triton(
                output_permuted, q_sorted_indices, dim=2
            )

            res = attn_output.reshape(batch_size, num_heads, seq_len, dim).transpose(
                1, 2
            )

        return res.contiguous()
    finally:
        # Reset to default even on failure so the process-wide setting does
        # not leak out of this call.
        torch.backends.cuda.preferred_linalg_library(backend="default")
@functools.lru_cache(maxsize=10)
def get_tile_partition_indices(
    dit_seq_shape: tuple[int, int, int],
    tile_size: tuple[int, int, int],
    device: torch.device,
) -> torch.LongTensor:
    """Permutation that regroups a row-major (T, H, W) token grid tile by tile.

    Tiles are visited in (t, h, w) order and, within a tile, tokens keep their
    row-major order. Tiles on the upper border of an axis may be smaller than
    ``tile_size``. Results are memoised on the argument triple.

    Returns:
        1-D int64 tensor of length ``T * H * W`` holding the flat token index
        stored at each tiled position.
    """
    T, H, W = dit_seq_shape
    ts, hs, ws = tile_size
    token_ids = torch.arange(T * H * W, device=device, dtype=torch.long).reshape(
        T, H, W
    )
    # Slice ends beyond the grid are clamped by slicing itself, which yields
    # the partial border tiles.
    tiles = [
        token_ids[
            t * ts : t * ts + ts,
            h * hs : h * hs + hs,
            w * ws : w * ws + ws,
        ].flatten()
        for t in range(math.ceil(T / ts))
        for h in range(math.ceil(H / hs))
        for w in range(math.ceil(W / ws))
    ]
    return torch.cat(tiles, dim=0)
@functools.lru_cache(maxsize=10)
def get_non_pad_index(
    variable_block_sizes: torch.LongTensor,
    max_block_size: int,
):
    """Flat positions of the real (non-padding) slots in a padded block layout.

    Block ``i`` owns slots ``[i * max_block_size, (i + 1) * max_block_size)``;
    only its first ``variable_block_sizes[i]`` slots are kept.

    Returns:
        1-D tensor of the kept slot indices in ascending order, on the same
        device as ``variable_block_sizes``.
    """
    device = variable_block_sizes.device
    n_win = variable_block_sizes.shape[0]

    within_block = torch.arange(max_block_size, device=device).unsqueeze(0)
    block_starts = (torch.arange(n_win, device=device) * max_block_size).unsqueeze(1)

    slot_index = block_starts + within_block
    is_real = within_block < variable_block_sizes.unsqueeze(1)
    return slot_index[is_real]
@dataclass
class VideoSparseAttentionMetadata(AttentionMetadata):
    """Per-forward metadata consumed by the Video Sparse Attention backend.

    Instances are produced by ``VideoSparseAttentionMetadataBuilder.build``;
    the index tensors drive the tile/untile reordering performed by
    ``VideoSparseAttentionImpl``.
    """

    # Passed through unchanged from the builder's ``current_timestep`` argument.
    current_timestep: int
    # Token grid (T, H, W): the raw latent shape floor-divided by the patch size.
    dit_seq_shape: list[int]
    # The impl derives top-k as
    # ceil((1 - VSA_sparsity) * total_seq_length / prod(VSA_TILE_SIZE)).
    VSA_sparsity: float
    # Tiles per axis: ceil(dit_seq_shape[i] / VSA_TILE_SIZE[i]).
    num_tiles: list[int]
    # prod(dit_seq_shape), i.e. the number of tokens before padding.
    total_seq_length: int
    # Permutation from get_tile_partition_indices (row-major -> tile order).
    tile_partition_indices: torch.LongTensor
    # argsort of tile_partition_indices (tile order -> row-major).
    reverse_tile_partition_indices: torch.LongTensor
    # Number of real tokens in each tile, from construct_variable_block_sizes.
    variable_block_sizes: torch.LongTensor
    # Positions of the real tokens inside the padded, tile-ordered sequence.
    non_pad_index: torch.LongTensor

    # adaptation for FastWan2.1-T2V-1.3B-Diffusers
    # Sequence lengths for the forward batch
    # Maximum sequence length for query
    max_seqlen_q: int = 1
    # Maximum sequence length for key
    max_seqlen_k: int = 0
class VideoSparseAttentionImpl(AttentionImpl):
    """Attention implementation that runs ``video_sparse_attn`` on sequences
    reordered into tile order and padded to whole tiles."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        causal: bool,
        softmax_scale: float,
        num_kv_heads: int | None = None,
        prefix: str = "",
        **extra_impl_args,
    ) -> None:
        # Only the prefix and the sequence-parallel world size are retained;
        # the remaining constructor arguments are accepted but not stored.
        self.prefix = prefix
        self.sp_size = get_sp_group().world_size

    def tile(
        self,
        x: torch.Tensor,
        num_tiles: list[int],
        tile_partition_indices: torch.LongTensor,
        non_pad_index: torch.LongTensor,
    ) -> torch.Tensor:
        """Scatter ``x`` (indexed along dim 1) into a zero-padded tile layout.

        The padded length along dim 1 is the product of the per-axis padded
        sizes ``num_tiles[i] * VSA_TILE_SIZE[i]``; slots not listed in
        ``non_pad_index`` stay zero.
        """
        padded_t, padded_h, padded_w = (
            num_tiles[axis] * VSA_TILE_SIZE[axis] for axis in range(3)
        )
        padded_shape = (
            x.shape[0],
            padded_t * padded_h * padded_w,
            x.shape[-2],
            x.shape[-1],
        )
        tiled = torch.zeros(padded_shape, device=x.device, dtype=x.dtype)
        tiled[:, non_pad_index] = x[:, tile_partition_indices]
        return tiled

    def untile(
        self,
        x: torch.Tensor,
        reverse_tile_partition_indices: torch.LongTensor,
        non_pad_index: torch.LongTensor,
    ) -> torch.Tensor:
        """Drop the padding slots, then restore the original token order."""
        unpadded = x[:, non_pad_index]
        return unpadded[:, reverse_tile_partition_indices]

    def preprocess_qkv(
        self,
        qkv: torch.Tensor,
        attn_metadata: VideoSparseAttentionMetadata,
    ) -> torch.Tensor:
        """Tile ``qkv`` using the index tensors carried by ``attn_metadata``."""
        return self.tile(
            qkv,
            attn_metadata.num_tiles,
            attn_metadata.tile_partition_indices,
            attn_metadata.non_pad_index,
        )

    def postprocess_output(
        self,
        output: torch.Tensor,
        attn_metadata: VideoSparseAttentionMetadata,
    ) -> torch.Tensor:
        """Undo :meth:`preprocess_qkv` on the attention output."""
        return self.untile(
            output,
            attn_metadata.reverse_tile_partition_indices,
            attn_metadata.non_pad_index,
        )

    def forward(  # type: ignore[override]
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        gate_compress: torch.Tensor,
        attn_metadata: VideoSparseAttentionMetadata,
    ) -> torch.Tensor:
        """Run the sparse kernel on tiled inputs.

        Each input has dims 1 and 2 swapped (and is made contiguous) before
        the kernel call; the kernel output has dims 1 and 2 swapped back.

        Raises:
            NotImplementedError: if the optional ``vsa`` package is missing.
        """
        query, key, value, gate_compress = (
            tensor.transpose(1, 2).contiguous()
            for tensor in (query, key, value, gate_compress)
        )

        # Number of blocks each query keeps, derived from the sparsity ratio.
        kept_fraction = 1 - attn_metadata.VSA_sparsity
        block_count = attn_metadata.total_seq_length / math.prod(VSA_TILE_SIZE)
        cur_topk = math.ceil(kept_fraction * block_count)

        if video_sparse_attn is None:
            raise NotImplementedError("video_sparse_attn is not installed")

        kernel_out = video_sparse_attn(
            query,
            key,
            value,
            variable_block_sizes=attn_metadata.variable_block_sizes,
            topk=cur_topk,
            block_size=VSA_TILE_SIZE,
            compress_attn_weight=gate_compress,
        )
        return kernel_out.transpose(1, 2)
@dataclass
class VideoMobaAttentionMetadata(AttentionMetadata):
    """Per-forward configuration for the V-MoBA attention backend.

    Built by ``VideoMobaAttentionMetadataBuilder.build`` and read by
    ``VMOBAAttentionImpl.forward`` to choose the chunking scheme for a layer.
    """

    # Passed through unchanged from the builder's ``current_timestep`` argument.
    current_timestep: int

    # Chunk size / top-k used on layers that fall in the temporal phase.
    temporal_chunk_size: int
    temporal_topk: int
    # Chunk size / top-k used on layers that fall in the spatial phase.
    spatial_chunk_size: tuple[int, int]
    spatial_topk: int
    # Chunk size / top-k used on layers that fall in the spatio-temporal phase.
    st_chunk_size: tuple[int, int, int]
    st_topk: int

    # Forwarded to moba_attn_varlen as select_mode / simsum_threshold /
    # threshold_type respectively.
    moba_select_mode: str
    moba_threshold: float
    moba_threshold_type: str
    # raw_latent_shape floor-divided element-wise by patch_size.
    patch_resolution: list[int]

    # NOTE(review): first_full_step is stored but not read by the forward pass
    # visible in this file - confirm where it is consumed.
    first_full_step: int = 12
    # Subtracted from the layer index before the phase is selected.
    first_full_layer: int = 0
    # temporal_layer -> spatial_layer -> st_layer
    # Number of consecutive layers per phase; the three phases repeat
    # cyclically over the layer index.
    temporal_layer: int = 1
    spatial_layer: int = 1
    st_layer: int = 1
class VMOBAAttentionImpl(AttentionImpl):
    """Attention implementation backed by ``moba_attn_varlen``.

    The chunking scheme (temporal, spatial or spatio-temporal) is chosen per
    layer from the layer index parsed out of ``prefix``.
    """

    def __init__(
        self,
        num_heads,
        head_size,
        softmax_scale,
        causal=False,
        num_kv_heads=None,
        prefix="",
        **extra_impl_args,
    ) -> None:
        # Only ``prefix`` is used; the other arguments are accepted for
        # interface compatibility and not stored.
        self.prefix = prefix
        self.layer_idx = self._get_layer_idx(prefix)

        self.pad_input = pad_input

    def _get_layer_idx(self, prefix: str) -> int:
        """Extract the integer that follows ``blocks.`` in ``prefix``.

        Raises:
            ValueError: if ``prefix`` contains no ``blocks.<digits>`` segment.
        """
        match = re.search(r"blocks\.(\d+)", prefix)
        if not match:
            raise ValueError(f"Invalid prefix: {prefix}")
        return int(match.group(1))

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """
        query: [B, L, H, D]
        key: [B, L, H, D]
        value: [B, L, H, D]
        attn_metadata: AttentionMetadata

        Raises:
            ValueError: if the configured layer counts do not sum to a
                positive cycle length.
        """
        # Unpacking also enforces that ``query`` is 4-D.
        batch_size, sequence_length, _, _ = query.shape

        # select chunk type according to layer idx:
        temporal_layers = attn_metadata.temporal_layer
        spatial_layers = attn_metadata.spatial_layer
        loop_layer_num = temporal_layers + spatial_layers + attn_metadata.st_layer
        if loop_layer_num <= 0:
            # Previously a zero cycle surfaced as a bare ZeroDivisionError.
            raise ValueError(
                "temporal_layer + spatial_layer + st_layer must be positive, "
                f"got {loop_layer_num}"
            )
        layer_phase = (
            self.layer_idx - attn_metadata.first_full_layer
        ) % loop_layer_num
        if layer_phase < temporal_layers:
            moba_chunk_size = attn_metadata.temporal_chunk_size
            moba_topk = attn_metadata.temporal_topk
        elif layer_phase < temporal_layers + spatial_layers:
            moba_chunk_size = attn_metadata.spatial_chunk_size
            moba_topk = attn_metadata.spatial_topk
        else:
            # layer_phase < loop_layer_num always holds here, so the remaining
            # layers belong to the spatio-temporal phase.
            moba_chunk_size = attn_metadata.st_chunk_size
            moba_topk = attn_metadata.st_topk

        query, chunk_size = process_moba_input(
            query, attn_metadata.patch_resolution, moba_chunk_size
        )
        key, chunk_size = process_moba_input(
            key, attn_metadata.patch_resolution, moba_chunk_size
        )
        value, chunk_size = process_moba_input(
            value, attn_metadata.patch_resolution, moba_chunk_size
        )

        # Every sequence in the batch has the same length, so the cumulative
        # boundaries are evenly spaced multiples of that length.
        max_seqlen = query.shape[1]
        total_tokens = query.shape[0] * max_seqlen
        indices_q = torch.arange(0, total_tokens, device=query.device)
        cu_seqlens = torch.arange(
            0,
            total_tokens + 1,
            max_seqlen,
            dtype=torch.int32,
            device=query.device,
        )
        query = rearrange(query, "b s ... -> (b s) ...")
        key = rearrange(key, "b s ... -> (b s) ...")
        value = rearrange(value, "b s ... -> (b s) ...")

        hidden_states = moba_attn_varlen(
            query,
            key,
            value,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            moba_chunk_size=chunk_size,
            moba_topk=moba_topk,
            select_mode=attn_metadata.moba_select_mode,
            simsum_threshold=attn_metadata.moba_threshold,
            threshold_type=attn_metadata.moba_threshold_type,
        )
        # NOTE(review): the padded buffer is sized from the *input* batch and
        # sequence length while ``indices_q`` is sized from the processed
        # query - this presumably relies on process_moba_input preserving the
        # token count; confirm.
        hidden_states = self.pad_input(
            hidden_states, indices_q, batch_size, sequence_length
        )
        hidden_states = process_moba_output(
            hidden_states, attn_metadata.patch_resolution, moba_chunk_size
        )

        return hidden_states
sglang.multimodal_gen.runtime.layers.attention.selector import get_attn_backend +from sglang.multimodal_gen.runtime.layers.usp import ( + _usp_input_all_to_all, + _usp_output_all_to_all, + ring_attn, +) +from sglang.multimodal_gen.runtime.managers.forward_context import ( + ForwardContext, + get_forward_context, +) +from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum +from sglang.multimodal_gen.utils import get_compute_dtype + + +class UlyssesAttention(nn.Module): + """Ulysses-style SequenceParallelism attention layer.""" + + def __init__( + self, + num_heads: int, + head_size: int, + num_kv_heads: int | None = None, + softmax_scale: float | None = None, + causal: bool = False, + supported_attention_backends: set[AttentionBackendEnum] | None = None, + prefix: str = "", + **extra_impl_args, + ) -> None: + super().__init__() + if softmax_scale is None: + self.softmax_scale = head_size**-0.5 + else: + self.softmax_scale = softmax_scale + + if num_kv_heads is None: + num_kv_heads = num_heads + + dtype = get_compute_dtype() + attn_backend = get_attn_backend( + head_size, dtype, supported_attention_backends=supported_attention_backends + ) + impl_cls = attn_backend.get_impl_cls() + + self.attn_impl = impl_cls( + num_heads=num_heads, + head_size=head_size, + causal=causal, + softmax_scale=self.softmax_scale, + num_kv_heads=num_kv_heads, + prefix=f"{prefix}.impl", + **extra_impl_args, + ) + self.num_heads = num_heads + self.head_size = head_size + self.num_kv_heads = num_kv_heads + self.backend = attn_backend.get_enum() + self.dtype = dtype + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + replicated_q: torch.Tensor | None = None, + replicated_k: torch.Tensor | None = None, + replicated_v: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Forward pass for distributed attention. 
+ + Args: + q (torch.Tensor): Query tensor [batch_size, seq_len, num_heads, head_dim] + k (torch.Tensor): Key tensor [batch_size, seq_len, num_heads, head_dim] + v (torch.Tensor): Value tensor [batch_size, seq_len, num_heads, head_dim] + replicated_q (Optional[torch.Tensor]): Replicated query tensor, typically for text tokens + replicated_k (Optional[torch.Tensor]): Replicated key tensor + replicated_v (Optional[torch.Tensor]): Replicated value tensor + + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: A tuple containing: + - o (torch.Tensor): Output tensor after attention for the main sequence + - replicated_o (Optional[torch.Tensor]): Output tensor for replicated tokens, if provided + """ + # Check input shapes + assert q.dim() == 4 and k.dim() == 4 and v.dim() == 4, "Expected 4D tensors" + batch_size, seq_len, num_heads, head_dim = q.shape + local_rank = get_sp_parallel_rank() + world_size = get_sp_world_size() + + forward_context: ForwardContext = get_forward_context() + ctx_attn_metadata = forward_context.attn_metadata + + # Stack QKV + qkv = torch.cat([q, k, v], dim=0) # [3, seq_len, num_heads, head_dim] + + # Redistribute heads across sequence dimension + qkv = sequence_model_parallel_all_to_all_4D(qkv, scatter_dim=2, gather_dim=1) + # Apply backend-specific preprocess_qkv + qkv = self.attn_impl.preprocess_qkv(qkv, ctx_attn_metadata) + + # Concatenate with replicated QKV if provided + if replicated_q is not None: + assert replicated_k is not None and replicated_v is not None + replicated_qkv = torch.cat( + [replicated_q, replicated_k, replicated_v], dim=0 + ) # [3, seq_len, num_heads, head_dim] + heads_per_rank = num_heads // world_size + replicated_qkv = replicated_qkv[ + :, :, local_rank * heads_per_rank : (local_rank + 1) * heads_per_rank + ] + qkv = torch.cat([qkv, replicated_qkv], dim=1) + + q, k, v = qkv.chunk(3, dim=0) + + output = self.attn_impl.forward(q, k, v, ctx_attn_metadata) + + # Redistribute back if using sequence parallelism + 
class UlyssesAttention_VSA(UlyssesAttention):
    """Distributed attention layer with VSA support."""

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        replicated_q: torch.Tensor | None = None,
        replicated_k: torch.Tensor | None = None,
        replicated_v: torch.Tensor | None = None,
        gate_compress: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass for distributed attention.

        Args:
            q (torch.Tensor): Query tensor [batch_size, seq_len, num_heads, head_dim]
            k (torch.Tensor): Key tensor [batch_size, seq_len, num_heads, head_dim]
            v (torch.Tensor): Value tensor [batch_size, seq_len, num_heads, head_dim]
            replicated_q (Optional[torch.Tensor]): Must be None; replicated
                (e.g. text) tokens are not supported for VSA.
            replicated_k (Optional[torch.Tensor]): Must be None.
            replicated_v (Optional[torch.Tensor]): Must be None.
            gate_compress (torch.Tensor): Gate compress tensor
                [batch_size, seq_len, num_heads, head_dim]. Required; it is
                only declared with a ``None`` default because it follows the
                optional ``replicated_*`` parameters.

        Returns:
            torch.Tensor: Output tensor after attention for the main sequence.
            Unlike the parent class, no tuple is returned.
        """
        # Check text tokens are not supported for VSA now
        assert (
            replicated_q is None and replicated_k is None and replicated_v is None
        ), "Replicated QKV is not supported for VSA now"
        # Fail with a clear message instead of an opaque torch.cat error.
        assert gate_compress is not None, "gate_compress is required for VSA"
        # Check input shapes
        assert q.dim() == 4 and k.dim() == 4 and v.dim() == 4, "Expected 4D tensors"

        forward_context: ForwardContext = get_forward_context()
        ctx_attn_metadata = forward_context.attn_metadata

        # Stack QKV and the gate along the batch dimension so one all-to-all
        # moves all four tensors.
        qkvg = torch.cat(
            [q, k, v, gate_compress], dim=0
        )  # [4 * batch_size, seq_len, num_heads, head_dim]

        # Redistribute heads across sequence dimension
        qkvg = sequence_model_parallel_all_to_all_4D(qkvg, scatter_dim=2, gather_dim=1)

        qkvg = self.attn_impl.preprocess_qkv(qkvg, ctx_attn_metadata)

        q, k, v, gate_compress = qkvg.chunk(4, dim=0)
        output = self.attn_impl.forward(
            q, k, v, gate_compress=gate_compress, attn_metadata=ctx_attn_metadata
        )  # type: ignore[call-arg]

        # Apply backend-specific postprocess_output
        output = self.attn_impl.postprocess_output(output, ctx_attn_metadata)

        # Inverse of the first all-to-all: scatter sequence, gather heads.
        output = sequence_model_parallel_all_to_all_4D(
            output, scatter_dim=1, gather_dim=2
        )

        return output
class USPAttention(nn.Module):
    """
    Ulysses Sequence Parallelism with Ring Attention.

    This class implements the USP algorithm, which is a combination of
    Ulysses-style all-to-all communication for sequence-head dimension sharding
    and Ring Attention for fine-grained sequence parallelism within subgroups.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        num_kv_heads: int | None = None,
        softmax_scale: float | None = None,
        causal: bool = False,
        supported_attention_backends: set[AttentionBackendEnum] | None = None,
        prefix: str = "",
        dropout_rate: float = 0.0,
        skip_sequence_parallel: bool = False,
        **extra_impl_args,
    ) -> None:
        """
        Args:
            num_heads: forwarded to the backend implementation.
            head_size: forwarded to the backend implementation and used to
                select the backend; also sets the default softmax scale.
            num_kv_heads: defaults to ``num_heads`` when None.
            softmax_scale: defaults to ``head_size ** -0.5`` when None.
            causal: forwarded to the backend implementation and to ring
                attention.
            supported_attention_backends: restricts backend selection.
            prefix: the backend implementation receives ``f"{prefix}.impl"``.
            dropout_rate: forwarded to ring attention as ``dropout_p``.
            skip_sequence_parallel:
                when KV is replicated across all SP ranks (e.g. cross-attention to
                text/image encoder outputs), the full USP pipeline is redundant:
                each rank's local Q shard can attend directly to the locally-held
                full KV without any collective communication.
            **extra_impl_args: forwarded to the backend implementation.
        """
        super().__init__()
        if softmax_scale is None:
            self.softmax_scale = head_size**-0.5
        else:
            self.softmax_scale = softmax_scale

        if num_kv_heads is None:
            num_kv_heads = num_heads

        dtype = get_compute_dtype()
        attn_backend = get_attn_backend(
            head_size, dtype, supported_attention_backends=supported_attention_backends
        )
        impl_cls: Type["AttentionImpl"] = attn_backend.get_impl_cls()
        self.attn_impl = impl_cls(
            num_heads=num_heads,
            head_size=head_size,
            causal=causal,
            softmax_scale=self.softmax_scale,
            num_kv_heads=num_kv_heads,
            prefix=f"{prefix}.impl",
            **extra_impl_args,
        )
        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads
        self.backend = attn_backend.get_enum()
        self.dtype = dtype
        self.causal = causal
        self.dropout_p = dropout_rate

        self.skip_sequence_parallel = skip_sequence_parallel

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        num_replicated_prefix: int = 0,
    ) -> torch.Tensor:
        """
        Forward pass for USPAttention.

        q, k, v: [B, S_local, H, D]
        num_replicated_prefix: number of leading tokens in q/k/v that are
            replicated (identical) across all SP ranks, e.g. text tokens
            in FLUX joint attention. These tokens are excluded from the
            Ulysses all-to-all so they appear exactly once in the gathered
            sequence, preserving correct attention weights.

        Dispatch, in order:
          1. If ``skip_sequence_parallel`` was set at construction time or the
             sequence-parallel world size is 1, the backend runs directly on
             the inputs with no collective communication. Use
             ``skip_sequence_parallel`` for cross-attention where the KV
             content is replicated across ranks.
          2. If the Ulysses world size is > 1 and ``num_replicated_prefix`` is
             > 0, the call is delegated to ``_forward_with_replicated_prefix``.
          3. Otherwise the Ulysses all-to-all is applied (when its world size
             is > 1), followed by ring attention (when the ring world size is
             > 1) or the plain backend, and the all-to-all is undone.

        NOTE(review): ``num_replicated_prefix`` is only consulted in step 2, so
        it has no effect when the Ulysses world size is 1, and step 2 does not
        go through ``ring_attn`` - confirm callers never combine a replicated
        prefix with ring parallelism.
        """
        forward_context: ForwardContext = get_forward_context()
        ctx_attn_metadata = forward_context.attn_metadata
        if self.skip_sequence_parallel or get_sequence_parallel_world_size() == 1:
            # No sequence parallelism, just run local attention.
            out = self.attn_impl.forward(q, k, v, ctx_attn_metadata)
            return out

        sp_size = get_ulysses_parallel_world_size()
        if sp_size > 1 and num_replicated_prefix > 0:
            return self._forward_with_replicated_prefix(
                q, k, v, ctx_attn_metadata, num_replicated_prefix
            )

        # Ulysses-style All-to-All for sequence/head sharding
        if sp_size > 1:
            # -> [B, S, H_local, D]
            q = _usp_input_all_to_all(q, head_dim=2)
            k = _usp_input_all_to_all(k, head_dim=2)
            v = _usp_input_all_to_all(v, head_dim=2)

        # Ring Attention within subgroups or local attention
        if get_ring_parallel_world_size() > 1:
            out = ring_attn(
                q,
                k,
                v,
                attn_impl=self.attn_impl,
                is_causal=self.causal,
                dropout_p=self.dropout_p,
            )
        else:
            # -> [B, S, H_local, D]
            out = self.attn_impl.forward(q, k, v, ctx_attn_metadata)

        # Ulysses-style All-to-All to restore original sharding
        if sp_size > 1:
            # -> [B, S_local, H, D]
            out = _usp_output_all_to_all(out, head_dim=2)

        return out

    def _forward_with_replicated_prefix(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        ctx_attn_metadata,
        num_rep: int,
    ) -> torch.Tensor:
        """Ulysses attention where the first *num_rep* tokens are replicated
        across SP ranks (e.g. text tokens) and should NOT be duplicated by the
        all-to-all.

        Strategy:
          1. Split q/k/v into replicated prefix and SP-sharded suffix.
          2. All-to-all only the sharded suffix (gathers sequence, shards heads).
          3. Locally slice the replicated prefix to the same head shard.
          4. Concatenate [prefix_h_local, gathered_suffix] and run attention.
          5. Split output, all-to-all back the suffix, all-gather prefix heads.

        Returns:
            The prefix output (heads re-assembled along dim 2) concatenated
            with the suffix output along dim 1.
        """
        sp_size = get_ulysses_parallel_world_size()
        # NOTE(review): the head offset below is derived from
        # get_sp_parallel_rank() while the final gather uses the Ulysses
        # group - confirm the two ranks coincide whenever this path is taken.
        sp_rank = get_sp_parallel_rank()

        # Step 1: split along the sequence dimension (dim 1).
        q_rep, q_shard = q[:, :num_rep], q[:, num_rep:]
        k_rep, k_shard = k[:, :num_rep], k[:, num_rep:]
        v_rep, v_shard = v[:, :num_rep], v[:, num_rep:]

        # Step 2: only the sharded suffix goes through the collective.
        q_shard = _usp_input_all_to_all(q_shard, head_dim=2)
        k_shard = _usp_input_all_to_all(k_shard, head_dim=2)
        v_shard = _usp_input_all_to_all(v_shard, head_dim=2)

        # Step 3: take the local head count from the post-all-to-all suffix
        # and slice the same head range out of the replicated prefix.
        h_local = q_shard.shape[2]
        h_start = sp_rank * h_local
        h_end = h_start + h_local
        q_rep = q_rep[:, :, h_start:h_end, :].contiguous()
        k_rep = k_rep[:, :, h_start:h_end, :].contiguous()
        v_rep = v_rep[:, :, h_start:h_end, :].contiguous()

        # Step 4: prefix first, so it can be sliced off the output by position.
        q = torch.cat([q_rep, q_shard], dim=1)
        k = torch.cat([k_rep, k_shard], dim=1)
        v = torch.cat([v_rep, v_shard], dim=1)

        out = self.attn_impl.forward(q, k, v, ctx_attn_metadata)

        # Step 5: undo the layout separately for prefix and suffix.
        out_rep = out[:, :num_rep]
        out_shard = out[:, num_rep:]

        out_shard = _usp_output_all_to_all(out_shard, head_dim=2)

        # Each rank holds a different head slice of the prefix output; gather
        # them and concatenate along the head dimension.
        gathered = [torch.empty_like(out_rep) for _ in range(sp_size)]
        torch.distributed.all_gather(
            gathered,
            out_rep.contiguous(),
            group=get_sp_group().ulysses_group,
        )
        out_rep = torch.cat(gathered, dim=2)

        return torch.cat([out_rep, out_shard], dim=1)
def backend_name_to_enum(backend_name: str) -> AttentionBackendEnum | None:
    """
    Look up an ``AttentionBackendEnum`` member by its exact member name.

    Returns:
    * the matching ``AttentionBackendEnum`` member if ``backend_name`` is one
      of the enum's member names (the comparison is case-sensitive)
    * None: otherwise, i.e. the name is not a known in-tree backend
    """
    assert backend_name is not None
    if backend_name not in AttentionBackendEnum.__members__:
        return None
    return AttentionBackendEnum[backend_name]
@cache
def _cached_get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    supported_attention_backends: tuple[AttentionBackendEnum, ...],
) -> type[AttentionBackend]:
    """Resolve the attention backend class for ``head_size`` / ``dtype``.

    Selection order:
      1. the backend forced via ``global_force_attn_backend()``, if any;
      2. otherwise the ``attention_backend`` server argument, if set;
      3. otherwise ``None``, which leaves the choice to the current platform.

    When ``supported_attention_backends`` is non-empty and the selected
    backend is not in it, the selection is dropped (logged at debug level)
    and the platform chooses.

    Args:
        head_size: attention head size, forwarded to the platform.
        dtype: compute dtype, forwarded to the platform.
        supported_attention_backends: allowed backends as a hashable tuple;
            an empty tuple means every backend is allowed.

    Returns:
        The backend class resolved from the platform's qualified name.

    Raises:
        ValueError: if the server argument names an unknown backend, or the
            platform returns no backend class.

    NOTE(review): results are memoised on the three arguments only, while the
    body also reads the globally forced backend and the server arguments. A
    change to either after the first call with the same arguments is not
    reflected here - confirm this is acceptable for
    ``global_force_attn_backend_context_manager`` users.
    """
    # Imported lazily, as in the original implementation.
    from sglang.multimodal_gen.runtime.platforms import current_platform

    # Keep the tuple for deterministic logging; use a set for membership.
    allowed_backends = set(supported_attention_backends)

    # Check whether a particular choice of backend was
    # previously forced via global_force_attn_backend() or --attention-backend CLI arg.
    selected_backend: AttentionBackendEnum | None = get_global_forced_attn_backend()
    if selected_backend is None:
        # Check the server arguments for a backend override
        server_args = get_global_server_args()
        if server_args.attention_backend is not None:
            try:
                selected_backend = AttentionBackendEnum[
                    server_args.attention_backend.upper()
                ]
            except KeyError as err:
                raise ValueError(
                    f"Invalid attention backend '{server_args.attention_backend}' specified via command line. "
                    f"Available options are: {[e.name.lower() for e in AttentionBackendEnum]}"
                ) from err

    # get device-specific attn_backend
    if len(allowed_backends) == 0:
        # all attention backends are allowed
        pass
    elif selected_backend is None:
        logger.debug("Attention backend not specified")
    elif selected_backend not in allowed_backends:
        supported_attention_backends_str = [
            str(backend) for backend in supported_attention_backends
        ]
        logger.debug(
            f"Selected attention backend: '{selected_backend}' not in supported attention backends: {supported_attention_backends_str}"
        )
        selected_backend = None

    attention_cls = current_platform.get_attn_backend_cls_str(
        selected_backend, head_size, dtype
    )
    if not attention_cls:
        raise ValueError(
            f"Invalid attention backend for {current_platform.device_name}"
        )
    return cast(type[AttentionBackend], resolve_obj_by_qualname(attention_cls))
+ + Arguments: + * attn_backend: attention backend to force + + Returns: + + * Generator + """ + + # Save the current state of the global backend override (if any) + original_value = get_global_forced_attn_backend() + + # Globally force the new backend override + global_force_attn_backend(attn_backend) + + # Yield control back to the enclosed code block + try: + yield + finally: + # Revert the original global backend override, if any + global_force_attn_backend(original_value) diff --git a/sglang/python/sglang/multimodal_gen/runtime/layers/attention/turbo_layer.py b/sglang/python/sglang/multimodal_gen/runtime/layers/attention/turbo_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..3a1642559608da67f1a035904dfa4c3354f63810 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/layers/attention/turbo_layer.py @@ -0,0 +1,272 @@ +# copy and modify from https://github.com/thu-ml/TurboDiffusion/blob/main/turbodiffusion/rcm/utils/a2a_cp.py and https://github.com/thu-ml/TurboDiffusion/blob/main/turbodiffusion/SLA/core.py + +from typing import Any, Callable, List, Tuple, Type, Union + +import torch +import torch.distributed as dist +from einops import rearrange +from torch import Tensor +from torch.distributed import ProcessGroup +from torch.nn import Module + +from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import ( + AttentionImpl, +) +from sglang.multimodal_gen.runtime.layers.attention.backends.sparse_linear_attn import ( + SageSparseLinearAttentionBackend, + SparseLinearAttentionBackend, +) +from sglang.multimodal_gen.runtime.layers.attention.selector import get_attn_backend +from sglang.multimodal_gen.runtime.managers.forward_context import ( + ForwardContext, + get_forward_context, +) +from sglang.multimodal_gen.runtime.platforms.interface import AttentionBackendEnum +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.utils import get_compute_dtype + 
+logger = init_logger(__name__) + + +def post_all2all(local_seq_2_local_head, seq_world_size): + def post_func(input): + # b, s, n, h + if local_seq_2_local_head: + output = rearrange(input, "w bs seq h d -> bs (w seq) h d") + else: + output = rearrange(input, "w bs s h d -> bs s (w h) d", w=seq_world_size) + + return output + + return post_func + + +def single_all_to_all(input, local_seq_2_local_head, group, async_op=False): + seq_world_size = dist.get_world_size(group) + + # b, s, n, h + if local_seq_2_local_head: + bs, local_seq_len, num_total_head, head_dim = input.shape + assert ( + num_total_head % seq_world_size == 0 + ), f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!" + input_t = rearrange( + input, + "bs seq_len (w h) d -> w bs seq_len h d", + w=seq_world_size, + h=num_total_head // seq_world_size, + ).contiguous() + post_all2all_fun = post_all2all(local_seq_2_local_head, seq_world_size) + else: + bs, global_seq_len, num_local_head, head_dim = input.shape + input_t = rearrange( + input, + "bs (w s) h d -> w bs s h d", + w=seq_world_size, + s=global_seq_len // seq_world_size, + ).contiguous() + post_all2all_fun = post_all2all(local_seq_2_local_head, seq_world_size) + + output = torch.empty_like(input_t) + dist.all_to_all_single(output, input_t, group=group, async_op=async_op) + + res = post_all2all_fun(output) + return res + + +def async_a2a_communicate( + a2a_inputs: Union[torch.Tensor, List[torch.Tensor]], + cp_size: int, + cp_group: ProcessGroup, + cp_stream: torch.get_device_module().Stream, + local_seq_2_local_head: bool, +) -> Union[torch.Tensor, List[torch.Tensor]]: + """ + A2A communication for context parallelism. best used in communicate qkv + Modified from Nvidia Transformer Engine. 
+ """ + a2a_inputs = [a2a_inputs] if not isinstance(a2a_inputs, list) else a2a_inputs + a2a_outputs, a2a_reqs = [None] * len(a2a_inputs), [None] * len(a2a_inputs) + a2a_post_fns = [None] * len(a2a_inputs) + if local_seq_2_local_head: + for i in range(len(a2a_inputs) + 2): + if 0 < i < len(a2a_inputs) + 1: + a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1]) + a2a_reqs[i - 1] = torch.distributed.all_to_all_single( + a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True + ) + a2a_post_fns[i - 1] = post_all2all(local_seq_2_local_head, cp_size) + if i > 1: + with torch.get_device_module().stream(cp_stream): + a2a_reqs[i - 2].wait() + a2a_outputs[i - 2] = a2a_post_fns[i - 2](a2a_outputs[i - 2]) + if i < len(a2a_inputs): + a2a_inputs[i] = rearrange( + a2a_inputs[i], "bs seq_len (w h) d -> w bs seq_len h d", w=cp_size + ).contiguous() + else: + for i in range(len(a2a_inputs) + 2): + if 0 < i < len(a2a_inputs) + 1: + a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1]) + a2a_reqs[i - 1] = torch.distributed.all_to_all_single( + a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True + ) + a2a_post_fns[i - 1] = post_all2all(local_seq_2_local_head, cp_size) + if i < len(a2a_inputs): + a2a_inputs[i] = rearrange( + a2a_inputs[i], "bs (w s) h d -> w bs s h d", w=cp_size + ).contiguous() + if i > 1: + with torch.get_device_module().stream(cp_stream): + a2a_reqs[i - 2].wait() + a2a_outputs[i - 2] = a2a_post_fns[i - 2](a2a_outputs[i - 2]) + torch.get_device_module().current_stream().wait_stream(cp_stream) + return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs + + +class _SeqAllToAll(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, group: dist.ProcessGroup, input: Tensor, local_seq_2_local_head: bool + ) -> Tensor: + ctx.group = group + res = single_all_to_all(input, local_seq_2_local_head, group, False) + ctx.local_seq_2_local_head = local_seq_2_local_head + return res + + @staticmethod + def backward(ctx: Any, 
*grad_output: Tensor) -> Tuple[None, Tensor, None]: + return ( + None, + _SeqAllToAll.apply(ctx.group, *grad_output, not ctx.local_seq_2_local_head), + None, + ) + + +class _SeqAllToAllQKV(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + group: dist.ProcessGroup, + q: Tensor, + k: Tensor, + v: Tensor, + cp_size: int, + cp_stream: torch.get_device_module().Stream, + local_seq_2_local_head: bool, + ) -> Tuple[Tensor, Tensor, Tensor]: + ctx.group = group + ctx.cp_size = cp_size + ctx.cp_stream = cp_stream + ctx.local_seq_2_local_head = local_seq_2_local_head + q, k, v = async_a2a_communicate( + [q, k, v], cp_size, group, cp_stream, local_seq_2_local_head + ) + return q, k, v + + @staticmethod + def backward( + ctx: Any, *grad_output: Tensor + ) -> Tuple[None, Tensor, Tensor, Tensor, None, None, None]: + q_grad, k_grad, v_grad = _SeqAllToAllQKV.apply( + ctx.group, + *grad_output, + ctx.cp_size, + ctx.cp_stream, + not ctx.local_seq_2_local_head, + ) + return (None, q_grad, k_grad, v_grad, None, None, None) + + +class DistributedAttention(torch.nn.Module): + """Initialization. 
+ + Arguments: + local_attention (Module): local attention with q,k,v + sequence_process_group (ProcessGroup): sequence parallel process group + """ + + def __init__(self, local_attention: Union[Module, Callable]) -> None: + super(DistributedAttention, self).__init__() + self.local_attn = local_attention + self.pg = None + self.stream = None + + def forward( + self, query: Tensor, key: Tensor, value: Tensor, ctx_attn_metadata + ) -> Tensor: + """forward + + Arguments: + query (Tensor): query input to the layer + key (Tensor): key input to the layer + value (Tensor): value input to the layer + + Returns: + * output (Tensor): context output + """ + if self.pg is None: + return self.local_attn(query, key, value, ctx_attn_metadata) + pg_size = dist.get_world_size(self.pg) + if pg_size < 2: + return self.local_attn(query, key, value, ctx_attn_metadata) + + query_layer, key_layer, value_layer = _SeqAllToAllQKV.apply( + self.pg, query, key, value, pg_size, self.stream, True + ) + context_layer = self.local_attn( + query_layer, key_layer, value_layer, ctx_attn_metadata + ) + + output = _SeqAllToAll.apply(self.pg, context_layer, False) + return output + + def set_context_parallel_group(self, group, stream): + self.pg = group + self.stream = stream + + +class MinimalA2AAttnOp(DistributedAttention): + def __init__( + self, + num_heads: int, + head_size: int, + attention_type: str, + topk: float, + supported_attention_backends: set[AttentionBackendEnum] | None = None, + ): + dtype = get_compute_dtype() + attn_backend = get_attn_backend( + head_size, dtype, supported_attention_backends=supported_attention_backends + ) + # Maintained for compatibility purposes; can be removed when CI allows setting Attention_backend or when TurboWan supports FA. + if attn_backend not in ( + SparseLinearAttentionBackend, + SageSparseLinearAttentionBackend, + ): + logger.warning_once( + "TurboWan now only supports `sla_attn` or `sage_sla_attn` and has been automatically set to attention_type. 
Please set --attention-backend to `sla_attn` or `sage_sla_attn`." + ) + if attention_type == "sagesla": + attn_backend = SageSparseLinearAttentionBackend + else: + attn_backend = SparseLinearAttentionBackend + impl_cls: Type["AttentionImpl"] = attn_backend.get_impl_cls() + local_attn = impl_cls( + num_heads=num_heads, + head_size=head_size, + topk_ratio=topk, + ) + super(MinimalA2AAttnOp, self).__init__(local_attn) + + def set_context_parallel_group(self, process_group, ranks, stream): + del ranks + super().set_context_parallel_group(process_group, stream) + + def forward( + self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs + ) -> Tensor: + forward_context: ForwardContext = get_forward_context() + ctx_attn_metadata = forward_context.attn_metadata + results = super().forward(query, key, value, ctx_attn_metadata) + return rearrange(results, "b ... h l -> b ... (h l)") diff --git a/sglang/python/sglang/multimodal_gen/runtime/layers/custom_op.py b/sglang/python/sglang/multimodal_gen/runtime/layers/custom_op.py new file mode 100644 index 0000000000000000000000000000000000000000..c4abecb446acf8ea2f23f91d1f678247cad02b57 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/layers/custom_op.py @@ -0,0 +1,114 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/custom_op.py + +from collections.abc import Callable +from typing import Any + +import torch.nn as nn + +from sglang.multimodal_gen.runtime.platforms import current_platform +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) +_is_cuda = current_platform.is_cuda() + + +class CustomOp(nn.Module): + """ + Base class for custom ops. + Dispatches the forward method to the appropriate backend. 
+ """ + + def __init__(self) -> None: + super().__init__() + self._forward_method = self.dispatch_forward() + + def forward(self, *args, **kwargs) -> Any: + return self._forward_method(*args, **kwargs) + + def forward_native(self, *args, **kwargs) -> Any: + """PyTorch-native implementation of the forward method. + This method is optional. If implemented, it can be used with compilers + such as torch.compile or PyTorch XLA. Also, it can be used for testing + purposes. + """ + raise NotImplementedError + + def forward_cuda(self, *args, **kwargs) -> Any: + raise NotImplementedError + + def forward_hip(self, *args, **kwargs) -> Any: + # ROCm kernels follow the CUDA path by default. + return self.forward_cuda(*args, **kwargs) + + def forward_cpu(self, *args, **kwargs) -> Any: + # By default, we assume that CPU ops are compatible with CUDA ops. + return self.forward_cuda(*args, **kwargs) + + def forward_tpu(self, *args, **kwargs) -> Any: + # By default, we assume that TPU ops are compatible with the + # PyTorch-native implementation. + # NOTE(woosuk): This is a placeholder for future extensions. + return self.forward_native(*args, **kwargs) + + def forward_musa(self, *args, **kwargs) -> Any: + # MUSA kernels follow the CUDA path by default. + return self.forward_cuda(*args, **kwargs) + + def forward_oot(self, *args, **kwargs) -> Any: + # By default, we assume that OOT ops are compatible with the + # PyTorch-native implementation. + return self.forward_native(*args, **kwargs) + + def forward_npu(self, *args, **kwargs) -> Any: + # By default, we assume that NPU ops are compatible with the + # PyTorch-native implementation. 
+ return self.forward_native(*args, **kwargs) + + def dispatch_forward(self) -> Callable: + """Pick the forward implementation for the current platform (called once from __init__).""" + if _is_cuda: + return self.forward_cuda + elif current_platform.is_hip(): + return self.forward_hip + elif current_platform.is_npu(): + return self.forward_npu + elif current_platform.is_xpu(): + # CustomOp itself defines no forward_xpu; use a subclass override when + # present and otherwise fall back to the PyTorch-native path instead of + # raising AttributeError while the module is being constructed. + return getattr(self, "forward_xpu", self.forward_native) + elif current_platform.is_musa(): + return self.forward_musa + else: + return self.forward_native + + @classmethod + def enabled(cls) -> bool: + # since we are not using Inductor, we always return True + return True + + @staticmethod + def default_on() -> bool: + """ + On by default if level < CompilationLevel.PIECEWISE + Specifying 'all' or 'none' in custom_op takes precedence. + """ + raise NotImplementedError + + # Dictionary of all custom ops (classes, indexed by registered name). + # To check if an op with a name is enabled, call .enabled() on the class. + # Examples: + # - MyOp.enabled() + # - op_registry["my_op"].enabled() + op_registry: dict[str, type["CustomOp"]] = {} + + # Decorator to register custom ops. 
+ @classmethod + def register(cls, name: str) -> Callable: + + def decorator(op_cls): + assert name not in cls.op_registry, f"Duplicate op name: {name}" + op_cls.name = name + cls.op_registry[name] = op_cls + return op_cls + + return decorator diff --git a/sglang/python/sglang/multimodal_gen/runtime/layers/elementwise.py b/sglang/python/sglang/multimodal_gen/runtime/layers/elementwise.py new file mode 100644 index 0000000000000000000000000000000000000000..a9cafb70e75cdf0b23d810ea6bda6fdf72c115cb --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/layers/elementwise.py @@ -0,0 +1,35 @@ +import torch + +from sglang.jit_kernel.diffusion.triton.scale_shift import fuse_scale_shift_kernel +from sglang.multimodal_gen.runtime.layers.custom_op import CustomOp + + +class MulAdd(CustomOp): + """ + Fuse elementwise mul and add + Input: a, b, c, OptionalInt[k] + Output: a * (k + b) + c + """ + + def __init__(self, prefix: str = ""): + super().__init__() + + def forward_native( + self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, k: int = 0 + ) -> torch.Tensor: + # a.shape: [batch_size, seq_len, inner_dim] + if b.dim() == 4: + # b.shape: [batch_size, num_frames, 1, inner_dim] + num_frames = b.shape[1] + frame_seqlen = a.shape[1] // num_frames + return c + ( + a.unflatten(dim=1, sizes=(num_frames, frame_seqlen)) * (k + b) + ).flatten(1, 2) + else: + # b.shape: [batch_size, 1, inner_dim] + return c + a * (k + b) + + def forward_cuda( + self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, k: int = 0 + ): + return fuse_scale_shift_kernel(a, b, c, scale_constant=k) diff --git a/sglang/python/sglang/multimodal_gen/runtime/layers/layernorm.py b/sglang/python/sglang/multimodal_gen/runtime/layers/layernorm.py new file mode 100644 index 0000000000000000000000000000000000000000..120730556574c05b26a8e57ac4f853bb776b9ed1 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/layers/layernorm.py @@ -0,0 +1,593 @@ +# Copied and adapted from: 
https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/layers/layernorm.py +"""Custom normalization layers.""" + +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from sglang.multimodal_gen.runtime.platforms import current_platform + +_is_cuda = current_platform.is_cuda() +_is_npu = current_platform.is_npu() +_is_musa = current_platform.is_musa() +if _is_cuda: + from sgl_kernel import fused_add_rmsnorm, rmsnorm + +if _is_npu: + import torch_npu + +if _is_musa: + from sgl_kernel import fused_add_rmsnorm + +from sglang.jit_kernel.diffusion.triton.norm import norm_infer, rms_norm_fn +from sglang.jit_kernel.diffusion.triton.rmsnorm_onepass import triton_one_pass_rms_norm +from sglang.jit_kernel.diffusion.triton.scale_shift import fuse_scale_shift_kernel +from sglang.jit_kernel.norm import can_use_fused_inplace_qknorm, fused_inplace_qknorm +from sglang.multimodal_gen.runtime.distributed.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + get_tp_group, +) +from sglang.multimodal_gen.runtime.layers.custom_op import CustomOp +from sglang.multimodal_gen.runtime.utils.common import get_bool_env_var + + +# Copied and adapted from sglang +@CustomOp.register("rms_norm") +class RMSNorm(CustomOp): + """Root mean square normalization. + + Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight. 
+ Refer to https://arxiv.org/abs/1910.07467 + """ + + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + dtype: torch.dtype = torch.float32, + var_hidden_size: Optional[int] = None, + ) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + self.hidden_size = hidden_size + self.variance_size_override = ( + None if var_hidden_size == hidden_size else var_hidden_size + ) + if get_bool_env_var("SGLANG_ENABLE_DETERMINISTIC_INFERENCE"): + self._forward_method = self.forward_native + + def forward_triton(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None): + return rms_norm_fn( + x, self.weight, bias=None, residual=residual, eps=self.variance_epsilon + ) + + def forward_cuda( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + shape = x.shape + device = x.device + x = x.reshape(-1, shape[-1]) + if residual is not None: + residual_shape = residual.shape + residual = residual.view(-1, shape[-1]) + + if x.dtype == torch.float: + # fp32 + out = self.forward_triton(x, residual) + if residual is not None: + return out[0].view(shape), out[1].view(residual_shape) + out = out.view(shape) + return out + elif self.variance_size_override is not None: + return self.forward_native(x, residual) + elif residual is not None: + fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon) + return x.view(shape), residual.view(residual_shape) + else: + if x.shape[-1] <= 128: + out = triton_one_pass_rms_norm( + x, self.weight.data, self.variance_epsilon + ) + else: + out = rmsnorm(x, self.weight.data, self.variance_epsilon) + out = out.view(shape) + + return out + + def forward_native( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if not x.is_contiguous(): + x = x.contiguous() + orig_dtype = x.dtype + x = x.to(torch.float32) + 
if residual is not None: + x = x + residual.to(torch.float32) + residual = x.to(orig_dtype) + + hidden_size = x.shape[-1] + if hidden_size != self.hidden_size: + raise ValueError( + "Expected hidden_size to be " + f"{self.hidden_size}, but found: {hidden_size}" + ) + + if self.variance_size_override is None: + x_var = x + else: + if hidden_size < self.variance_size_override: + raise ValueError( + "Expected hidden_size to be at least " + f"{self.variance_size_override}, but found: {hidden_size}" + ) + + x_var = x[..., : self.variance_size_override] + + variance = x_var.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + x = (x * self.weight).to(orig_dtype) + if residual is None: + return x + else: + return x, residual + + def forward_cpu( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + return self.forward_native(x, residual) + + def forward_npu( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if residual is not None: + out, _, residual_out = torch_npu.npu_add_rms_norm( + residual, x, self.weight.data, self.variance_epsilon + ) + return out, residual_out + return torch_npu.npu_rms_norm(x, self.weight.data, self.variance_epsilon)[0] + + def forward_hip( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + # ROCm builds of sgl-kernel do not expose rmsnorm custom ops yet. + return self.forward_native(x, residual) + + def _get_weight(self, dtype: torch.dtype) -> torch.Tensor: + """Return weight matched to *dtype*. + + MUSA kernels require input and weight to share the same dtype, + unlike CUDA kernels which may handle mixed dtypes internally. 
+ """ + weight = self.weight.data + if weight.dtype != dtype: + weight = weight.to(dtype=dtype) + return weight + + def forward_musa( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + shape = x.shape + x = x.reshape(-1, shape[-1]) + if residual is not None: + residual_shape = residual.shape + residual = residual.view(-1, shape[-1]) + + if self.variance_size_override is not None: + return self.forward_native(x, residual) + elif residual is not None: + # fused_add_rmsnorm requires contiguous inputs. + if not x.is_contiguous(): + x = x.contiguous() + if not residual.is_contiguous(): + residual = residual.contiguous() + weight = self._get_weight(x.dtype) + fused_add_rmsnorm(x, residual, weight, self.variance_epsilon) + return x.view(shape), residual.view(residual_shape) + else: + weight = self._get_weight(x.dtype) + out = F.rms_norm(x, (self.hidden_size,), weight, self.variance_epsilon) + out = out.view(shape) + return out + + def extra_repr(self) -> str: + s = f"hidden_size={self.weight.data.size(0)}" + s += f", eps={self.variance_epsilon}" + return s + + +# Copied and adapted from sglang +@CustomOp.register("layer_norm") +class LayerNorm(CustomOp): + def __init__( + self, + hidden_size: int, + eps=1e-5, + bias: bool = True, + elementwise_affine=True, + device=None, + dtype=None, + ) -> None: + super().__init__() + self.eps = eps + factory_kwargs = {"device": device, "dtype": dtype} + self.hidden_size = hidden_size + if elementwise_affine: + self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.bias = ( + torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + if bias + else None + ) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + # Lazy cache for ones vector (not a registered buffer to avoid FSDP/meta issues) + self._weight_fallback_cache = None + + def _get_weight_fallback(self, x: torch.Tensor) -> 
torch.Tensor: + wf = getattr(self, "_weight_fallback_cache", None) + if ( + wf is None + or wf.device != x.device + or wf.dtype != x.dtype + or wf.numel() != self.hidden_size + ): + wf = torch.ones(self.hidden_size, device=x.device, dtype=x.dtype) + self._weight_fallback_cache = wf + return wf + + def forward_triton(self, x: torch.Tensor): + # Fast inference kernel without residual/dropout branches + return norm_infer( + x.view(-1, self.hidden_size), + self.weight, + self.bias, + eps=self.eps, + is_rms_norm=False, + ).view(x.shape) + + def forward_cuda( + self, + x: torch.Tensor, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + shape = x.shape + x = x.view(-1, self.hidden_size) + return self.forward_triton(x).view(shape) + + @torch.compile(backend="inductor", disable=current_platform.is_npu()) + def forward_native( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + input_dtype = x.dtype + mean = x.mean(-1, keepdim=True) + variance = (x - mean).pow(2).mean(-1, keepdim=True) + x = (x - mean) * torch.rsqrt(variance + self.eps) + if self.weight is not None: + x = self.weight * x + # if no affine, this is a no-op + if self.bias is not None: + x = x + self.bias + return x.to(input_dtype) + + def forward_cpu( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + return self.forward_native(x, residual) + + def forward_musa(self, x: torch.Tensor): + return F.layer_norm(x, (self.hidden_size,), self.weight, self.bias, self.eps) + + def extra_repr(self) -> str: + """Return the hidden size and epsilon shown in this module's repr().""" + # Use self.hidden_size / self.eps: this class never sets variance_epsilon, + # and self.weight is None when elementwise_affine=False. + s = f"hidden_size={self.hidden_size}" + s += f", eps={self.eps}" + return s + + +# adapted from Diffusers: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/normalization.py +# NOTE(will): Needed to match behavior of diffusers and wan2.1 even while using +# FSDP's MixedPrecisionPolicy +class
FP32LayerNorm(nn.LayerNorm): + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + origin_dtype = inputs.dtype + device = inputs.device + return F.layer_norm( + inputs.float(), + self.normalized_shape, + self.weight.float().to(device=device) if self.weight is not None else None, + self.bias.float().to(device=device) if self.bias is not None else None, + self.eps, + ).to(origin_dtype) + + +################################################################################ +# Fused norm kernel +################################################################################ +def _ensure_contiguous(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + return tensor.contiguous() if tensor is not None else None + + +class _ScaleResidualNormScaleShift(CustomOp): + """ + Fused kernel that combines: + 1. residual_out = residual + gate * x + 2. normed = layernorm(residual_out) or rmsnorm(residual_out) + 3. out = normed * (1 + scale) + shift + compute_dtype is always fp32 for higher precision. 
+ """ + + norm_type: str + + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + elementwise_affine: bool = False, + dtype: torch.dtype = torch.float32, + prefix: str = "", + ): + super().__init__() + self.eps = eps + self.dtype = dtype + if self.norm_type == "rms": + self.norm = RMSNorm(hidden_size, eps=eps, dtype=dtype) + elif self.norm_type == "layer": + self.norm = FP32LayerNorm( + hidden_size, elementwise_affine=elementwise_affine, eps=eps, dtype=dtype + ) + else: + raise NotImplementedError(f"Norm type {self.norm_type} not implemented") + + def forward_cuda( + self, + residual: torch.Tensor, + x: torch.Tensor, + gate: torch.Tensor | int, + shift: torch.Tensor, + scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + if x.shape[-1] % 256 != 0 and x.shape[-1] <= 8192: + import warnings + + warnings.warn( + "FusedScaleResidualNormScaleShift cuda not available, using native fallback", + stacklevel=2, + ) + return self.forward_native(residual, x, gate, shift, scale) + + from sglang.jit_kernel.diffusion.cutedsl.scale_residual_norm_scale_shift import ( + fused_scale_residual_norm_scale_shift, + ) + + if isinstance(gate, int) and gate != 1: + raise ValueError( + f"Only gate value of 1 is supported for int type, but got {gate}" + ) + + return fused_scale_residual_norm_scale_shift( + residual.contiguous(), + x.contiguous(), + gate.contiguous() if isinstance(gate, torch.Tensor) else None, + _ensure_contiguous(getattr(self.norm, "weight", None)), + _ensure_contiguous(getattr(self.norm, "bias", None)), + scale.contiguous(), + shift.contiguous(), + self.norm_type, + self.eps, + ) + + def forward_hip(self, *args, **kwargs): + # ROCm does not support CUDA/CUTLASS-based fused kernels yet, + # so we fall back to the native PyTorch implementation. 
+ return self.forward_native(*args, **kwargs) + + def forward_musa(self, *args, **kwargs): + # MUSA does not support CUDA/CUTLASS-based fused kernels yet, + # so we fall back to the native PyTorch implementation. + return self.forward_native(*args, **kwargs) + + def forward_native( + self, + residual: torch.Tensor, + x: torch.Tensor, + gate: torch.Tensor | int, + shift: torch.Tensor, + scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + # x.shape: [batch_size, seq_len, inner_dim] + if isinstance(gate, int): + # used by cross-attention, should be 1 + assert gate == 1 + residual_output = residual + x + elif isinstance(gate, torch.Tensor): + if gate.dim() == 4: + # gate.shape: [batch_size, num_frames, 1, inner_dim] + num_frames = gate.shape[1] + frame_seqlen = x.shape[1] // num_frames + residual_output = residual + ( + x.unflatten(dim=1, sizes=(num_frames, frame_seqlen)) * gate + ).flatten(1, 2) + else: + # gate.shape: [batch_size, 1, inner_dim] + residual_output = residual + x * gate + else: + raise ValueError(f"Gate type {type(gate)} not supported") + normalized = self.norm(residual_output) + modulated = fuse_scale_shift_kernel(normalized, scale, shift) + return modulated, residual_output + + +class ScaleResidualLayerNormScaleShift(_ScaleResidualNormScaleShift): + norm_type = "layer" + + +class ScaleResidualRMSNormScaleShift(_ScaleResidualNormScaleShift): + norm_type = "rms" + + +class _NormScaleShift(CustomOp): + """ + Fused kernel that combines: + 1. normed = layernorm(x) or rmsnorm(x) + 2. out = normed * (1 + scale) + shift + compute_dtype is always fp32 for higher precision. 
+ """ + + norm_type: str + + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + elementwise_affine: bool = False, + dtype: torch.dtype = torch.float32, + prefix: str = "", + ): + super().__init__() + self.eps = eps + if self.norm_type == "rms": + self.norm = RMSNorm(hidden_size, eps=eps, dtype=dtype) + elif self.norm_type == "layer": + self.norm = FP32LayerNorm( + hidden_size, elementwise_affine=elementwise_affine, eps=eps, dtype=dtype + ) + else: + raise NotImplementedError(f"Norm type {self.norm_type} not implemented") + + def forward_cuda( + self, x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor + ) -> torch.Tensor: + if x.shape[-1] % 256 != 0 and x.shape[-1] <= 8192: + import warnings + + warnings.warn( + "FusedNormScaleShift cuda not available, using native fallback", + stacklevel=2, + ) + return self.forward_native(x, shift, scale) + + from sglang.jit_kernel.diffusion.cutedsl.scale_residual_norm_scale_shift import ( + fused_norm_scale_shift, + ) + + return fused_norm_scale_shift( + x.contiguous(), + _ensure_contiguous(getattr(self.norm, "weight", None)), + _ensure_contiguous(getattr(self.norm, "bias", None)), + scale.contiguous(), + shift.contiguous(), + self.norm_type, + self.eps, + ) + + def forward_hip(self, *args, **kwargs): + # ROCm does not support CUDA/CUTLASS-based fused kernels yet, + # so we fall back to the native PyTorch implementation. + return self.forward_native(*args, **kwargs) + + def forward_musa(self, *args, **kwargs): + # MUSA does not support CUDA/CUTLASS-based fused kernels yet, + # so we fall back to the native PyTorch implementation. 
+ return self.forward_native(*args, **kwargs) + + def forward_native( + self, x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor + ) -> torch.Tensor: + normalized = self.norm(x) + modulated = fuse_scale_shift_kernel(normalized, scale, shift) + return modulated.to(x.dtype) + + +class LayerNormScaleShift(_NormScaleShift): + norm_type = "layer" + + +class RMSNormScaleShift(_NormScaleShift): + norm_type = "rms" + + +def apply_qk_norm( + q: torch.Tensor, + k: torch.Tensor, + q_norm: "RMSNorm", + k_norm: "RMSNorm", + head_dim: int, + allow_inplace: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply QK normalization for query and key tensors. + + Uses JIT fused inplace kernel when available, falls back to standard RMSNorm. + """ + + batch_size = q.size(0) + q_eps = q_norm.variance_epsilon + k_eps = k_norm.variance_epsilon + # Only try fused path on CUDA and when it won't introduce implicit copies. + if ( + _is_cuda + and allow_inplace + and (q_eps == k_eps) + and can_use_fused_inplace_qknorm(head_dim, q.dtype) + ): + fused_inplace_qknorm( + q=q.view(batch_size, -1, head_dim), + k=k.view(batch_size, -1, head_dim), + q_weight=q_norm.weight, + k_weight=k_norm.weight, + head_dim=head_dim, + eps=q_eps, + ) + return q, k + + q_shape = q.shape + k_shape = k.shape + q_out = q_norm(q.view(-1, head_dim)).view(q_shape) + k_out = k_norm(k.view(-1, head_dim)).view(k_shape) + return q_out, k_out + + +def tensor_parallel_rms_norm(x: torch.Tensor, norm: "RMSNorm") -> torch.Tensor: + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + src_dtype = x.dtype + weight = norm.weight.tensor_split(tp_size)[tp_rank].float() + x_fp32 = x.float() + variance = x_fp32.pow(2).mean(dim=-1, keepdim=True) + variance = get_tp_group().all_reduce( + variance, op=torch._C._distributed_c10d.ReduceOp.AVG + ) + output = x_fp32 * torch.rsqrt(variance + norm.variance_epsilon) * weight + return output.to(dtype=src_dtype) + + +# TODO: Workaround, fuse 
norm with new select01 kernel +def apply_layernorm_only(x: torch.Tensor, layernorm_scale_shift: LayerNormScaleShift): + return norm_infer( + x.view(-1, x.shape[-1]), + layernorm_scale_shift.norm.weight, + layernorm_scale_shift.norm.bias, + eps=layernorm_scale_shift.eps, + is_rms_norm=False, + ).view(x.shape) diff --git a/sglang/python/sglang/multimodal_gen/runtime/layers/linear.py b/sglang/python/sglang/multimodal_gen/runtime/layers/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..f74309a53405c84f6187d100e05fc17ccd42e6a6 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/layers/linear.py @@ -0,0 +1,1071 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/layers/linear.py + +from abc import abstractmethod + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch.nn.parameter import Parameter + +from sglang.multimodal_gen.runtime.distributed import ( + divide, + get_tp_group, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, +) +from sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from sglang.multimodal_gen.runtime.layers.utils import get_group_rank, get_group_size + +# yapf: disable +from sglang.multimodal_gen.runtime.models.parameter import ( + BasevLLMParameter, + BlockQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + PerTensorScaleParameter, + RowvLLMParameter, +) + +# yapf: enable +from sglang.multimodal_gen.runtime.models.utils import set_weight_attrs +from sglang.multimodal_gen.runtime.platforms import current_platform +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + +WEIGHT_LOADER_V2_SUPPORTED = [ + 
"CompressedTensorsLinearMethod", + "AWQMarlinLinearMethod", + "AWQLinearMethod", + "GPTQMarlinLinearMethod", + "Fp8LinearMethod", + "MarlinLinearMethod", + "QQQLinearMethod", + "GPTQMarlin24LinearMethod", + "TPUInt8LinearMethod", + "GPTQLinearMethod", + "FBGEMMFp8LinearMethod", + "ModelOptFp8LinearMethod", + "IPEXAWQLinearMethod", + "IPEXGPTQLinearMethod", + "HQQMarlinMethod", + "QuarkLinearMethod", +] + + +def adjust_scalar_to_fused_array( + param: torch.Tensor, loaded_weight: torch.Tensor, shard_id: str | int +) -> tuple[torch.Tensor, torch.Tensor]: + """For fused modules (QKV and MLP) we have an array of length + N that holds 1 scale for each "logical" matrix. So the param + is an array of length N. The loaded_weight corresponds to + one of the shards on disk. Here, we slice the param based on + the shard_id for loading. + """ + qkv_idxs = {"q": 0, "k": 1, "v": 2} + + if isinstance(shard_id, str): + shard_id = qkv_idxs[shard_id] + elif not isinstance(shard_id, int): + raise ValueError(f"Unknown Shard Id {shard_id}") + + # AutoFP8 scales do not have a shape + # compressed-tensors scales do have a shape + if len(loaded_weight.shape) != 0: + assert loaded_weight.shape[0] == 1 + loaded_weight = loaded_weight[0] + + return param[shard_id], loaded_weight + + +class LinearMethodBase(QuantizeMethodBase): + """Base class for different (maybe quantized) linear methods.""" + + @abstractmethod + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + """Create weights for a linear layer. + The weights will be set as attributes of the layer. + + Args: + layer: The layer that is using the LinearMethodBase factory. + input_size_per_partition: Size of the weight input dim on rank X. + output_partition_sizes: Sizes of the output dim of each logical + weight on rank X. 
E.g., output_partition_sizes for QKVLinear + is a list contains the width of Wq, Wk, Wv on rank X. + input_size: Size of the input dim of the weight across all ranks. + output_size: Size of the output dim of the weight across all ranks. + params_dtype: Datatype of the parameters. + """ + raise NotImplementedError + + @abstractmethod + def apply( + self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None + ) -> torch.Tensor: + """Apply the weights in layer to the input tensor. + Expects create_weights to have been called before on the layer.""" + raise NotImplementedError + + +class UnquantizedLinearMethod(LinearMethodBase): + """Linear method without quantization.""" + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + weight = Parameter( + torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + + def apply( + self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None + ) -> torch.Tensor: + output = ( + F.linear(x, layer.weight, bias) + if current_platform.is_amp_supported() or bias is None + else F.linear(x, layer.weight, bias.to(x.dtype)) + ) # NOTE: this line assumes that we are using amp when using cuda and is needed to account for the fact that amp isn't supported in mps + return output + + +class LinearBase(torch.nn.Module): + """Base linear layer. + + Args: + input_size: input dimension of the linear layer. + output_size: output dimension of the linear layer. + skip_bias_add: If true, skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. 
+ """ + + def __init__( + self, + input_size: int, + output_size: int, + skip_bias_add: bool = False, + params_dtype: torch.dtype | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + super().__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.skip_bias_add = skip_bias_add + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + self.quant_config = quant_config + self.prefix = prefix + if quant_config is None: + self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod() + else: + self.quant_method = quant_config.get_quant_method(self, prefix=prefix) + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, Parameter | None]: + raise NotImplementedError + + +class ReplicatedLinear(LinearBase): + """Replicated linear layer. + + Args: + input_size: input dimension of the linear layer. + output_size: output dimension of the linear layer. + bias: If true, add bias. + skip_bias_add: If true, skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: torch.dtype | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + super().__init__( + input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix=prefix, + ) + + # All the linear layer supports quant method. 
+ assert self.quant_method is not None + self.quant_method.create_weights( + self, + self.input_size, + [self.output_size], + self.input_size, + self.output_size, + self.params_dtype, + weight_loader=self.weight_loader, + ) + + if bias: + self.bias = Parameter( + torch.empty( + self.output_size, + dtype=self.params_dtype, + ) + ) + set_weight_attrs( + self.bias, + { + "output_dim": 0, + "weight_loader": self.weight_loader, + }, + ) + else: + self.register_parameter("bias", None) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor) -> None: + # If the weight on disk does not have a shape, give it one + # (such scales for AutoFp8). + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert param.size() == loaded_weight.size(), ( + f"Tried to load weights of size {loaded_weight.size()}" + f"to a parameter of size {param.size()}" + ) + param.data.copy_(loaded_weight) + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, Parameter | None]: + bias = self.bias if not self.skip_bias_add else None + assert self.quant_method is not None + output = self.quant_method.apply(self, x, bias) + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + def extra_repr(self) -> str: + s = f"in_features={self.input_size}" + s += f", output_features={self.output_size}" + s += f", bias={self.bias is not None}" + return s + + +class ColumnParallelLinear(LinearBase): + """Linear layer with column parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its second dimension as A = [A_1, ..., A_p]. + + Args: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias. 
+ gather_output: If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is Y_i = XA_i + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + output_sizes: list of output sizes packed into one output, like for QKV + the list would be size 3. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: torch.dtype | None = None, + quant_config: QuantizationConfig | None = None, + output_sizes: list[int] | None = None, + prefix: str = "", + tp_group: dist.ProcessGroup = None, + ): + # Divide the weight matrix along the last dimension. + self.tp_group = tp_group or get_tp_group() + self.tp_size = get_group_size(self.tp_group) + self.tp_rank = get_group_rank(self.tp_group) + self.input_size_per_partition = input_size + self.output_size_per_partition = divide(output_size, self.tp_size) + self.output_partition_sizes = [self.output_size_per_partition] + # If QKV or MergedColumn, use output size of each partition. 
+ if hasattr(self, "output_sizes"): + self.output_partition_sizes = [ + divide(output_size, self.tp_size) for output_size in self.output_sizes + ] + + super().__init__( + input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix + ) + + self.gather_output = gather_output + + if output_sizes is None: + output_sizes = [output_size] + + assert self.quant_method is not None + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size_per_partition, + output_partition_sizes=self.output_partition_sizes, + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=( + self.weight_loader_v2 + if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED + else self.weight_loader + ), + ) + if bias: + self.bias = Parameter( + torch.empty( + self.output_size_per_partition, + dtype=params_dtype, + ) + ) + set_weight_attrs( + self.bias, + { + "output_dim": 0, + "weight_loader": self.weight_loader, + }, + ) + else: + self.register_parameter("bias", None) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor) -> None: + tp_rank = self.tp_rank + output_dim = getattr(param, "output_dim", None) + + is_sharded_weight = getattr(param, "is_sharded_weight", False) + is_sharded_weight = is_sharded_weight + + param_data = param.data + if output_dim is not None and not is_sharded_weight: + shard_size = param_data.shape[output_dim] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). 
+ if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor) -> None: + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). + if len(loaded_weight.shape) == 0: + assert loaded_weight.numel() == 1 + loaded_weight = loaded_weight.reshape(1) + param.load_column_parallel_weight(loaded_weight=loaded_weight) + + def forward(self, input_: torch.Tensor) -> tuple[torch.Tensor, Parameter | None]: + bias = self.bias if not self.skip_bias_add else None + + # Matrix multiply. + assert self.quant_method is not None + output_parallel = self.quant_method.apply(self, input_, bias) + if self.gather_output: + # All-gather across the partitions. + output = tensor_model_parallel_all_gather( + output_parallel, tp_group=self.tp_group + ) + else: + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + def extra_repr(self) -> str: + s = f"in_features={self.input_size}" + s += f", output_features={self.output_size_per_partition}" + s += f", bias={self.bias is not None}" + s += f", tp_size={self.tp_size}" + s += f", gather_output={self.gather_output}" + return s + + +class MergedColumnParallelLinear(ColumnParallelLinear): + """Packed linear layers with column parallelism. + + Similar to ColumnParallelLinear, but the weight matrix is concatenated + along the output dimension. When the weight matrix is loaded, the + different partitions are sharded separately. + + Args: + input_size: input dimension of the linear layer. + output_sizes: list of output dimensions of the linear layer. + bias: If true, add bias. + gather_output: If true, call all-gather on output and make the output + available to all GPUs, otherwise, every GPU will have + its own output. 
+ skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) + """ + + def __init__( + self, + input_size: int, + output_sizes: list[int], + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: torch.dtype | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + tp_group: dist.ProcessGroup = None, + ): + super().__init__( + input_size=input_size, + output_size=sum(output_sizes), + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + prefix=prefix, + tp_group=tp_group, + ) + self.output_sizes = output_sizes + assert all(output_size % self.tp_size == 0 for output_size in output_sizes) + + def weight_loader( + self, + param: Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: int | None = None, + ) -> None: + + param_data = param.data + output_dim = getattr(param, "output_dim", None) + # Special case for AQLM codebooks. + is_metadata = getattr(param, "is_metadata", False) + # Special case for per-tensor scale to load scalar into fused array. + needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) + + if loaded_shard_id is None: + # Loaded weight is already fused on disk (mlp). + # (e.g., Phi-3's gate_up_proj). 
+ if output_dim is None: + if needs_scalar_to_array: + param_data, loaded_weight = adjust_scalar_to_fused_array( + param_data, loaded_weight, 0 + ) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + return + current_shard_offset = 0 + shard_offsets: list[tuple[int, int, int]] = [] + for i, output_size in enumerate(self.output_sizes): + shard_offsets.append((i, current_shard_offset, output_size)) + current_shard_offset += output_size + for shard_id, shard_offset, shard_size in shard_offsets: + loaded_weight_shard = loaded_weight.narrow( + output_dim, shard_offset, shard_size + ) + self.weight_loader(param, loaded_weight_shard, shard_id) + return + + assert loaded_shard_id < len(self.output_sizes) + tp_rank = self.tp_rank + tp_size = self.tp_size + if output_dim is not None: + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size + shard_size = self.output_sizes[loaded_shard_id] // tp_size + + is_sharded_weight = getattr(param, "is_sharded_weight", False) + # bitsandbytes loads the weights of the specific portion + # no need to narrow + is_sharded_weight = is_sharded_weight + + param_data = param_data.narrow(output_dim, shard_offset, shard_size) + start_idx = tp_rank * shard_size + if not is_sharded_weight: + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + # Special case for AQLM codebooks. + elif is_metadata: + # metadata indicates fixed size concatenated along dim 0 + shard_size = loaded_weight.shape[0] + shard_offset = loaded_shard_id * shard_size + param_data = param_data.narrow(0, shard_offset, shard_size) + + # Special case for per-tensor scales in fused case. 
+ elif needs_scalar_to_array: + param_data, loaded_weight = adjust_scalar_to_fused_array( + param_data, loaded_weight, loaded_shard_id + ) + + else: + ignore_warning = getattr(param, "ignore_warning", False) + if not ignore_warning: + logger.warning( + "Loading a weight without `output_dim` attribute in " + "MergedColumnParallelLinear, assume the weight is " + "the same for all partitions." + ) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def _load_fused_module_from_checkpoint( + self, param: BasevLLMParameter, loaded_weight: torch.Tensor + ) -> None: + """ + Handle special case for models where MLP layers are already + fused on disk. In this case, we have no shard id. This function + determmines the shard id by splitting these layers and then calls + the weight loader using the shard id. + + An example of a model with these fused layers: + https://huggingface.co/microsoft/Phi-3-mini-4k-instruct + """ + + current_shard_offset = 0 + shard_offsets: list[tuple[int, int, int]] = [] + for i, output_size in enumerate(self.output_sizes): + shard_offsets.append((i, current_shard_offset, output_size)) + current_shard_offset += output_size + + for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantization. + # If quantized, we need to adjust the offset and size to account + # for the packing. 
+ if ( + isinstance(param, PackedColumnParameter | PackedvLLMParameter) + and param.packed_dim == param.output_dim + ): + shard_size, shard_offset = param.adjust_shard_indexes_for_packing( + shard_size=shard_size, shard_offset=shard_offset + ) + + loaded_weight_shard = loaded_weight.narrow( + param.output_dim, shard_offset, shard_size + ) + self.weight_loader_v2(param, loaded_weight_shard, shard_id) + + def weight_loader_v2( + self, + param: BasevLLMParameter, + loaded_weight: torch.Tensor, + loaded_shard_id: int | None = None, + ) -> None: + if loaded_shard_id is None: + if isinstance(param, PerTensorScaleParameter): + param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0) + return + elif type(param) in (RowvLLMParameter, BasevLLMParameter): + param.load_merged_column_weight(loaded_weight=loaded_weight) + return + # TODO: @dsikka - move to parameter.py + self._load_fused_module_from_checkpoint(param, loaded_weight) + return + + assert loaded_shard_id < len(self.output_sizes) + + tp_size = self.tp_size + + if isinstance(param, BlockQuantScaleParameter): + raise NotImplementedError("FP8 is not implemented yet") + # FIXME(will): add fp8 support + # from vllm.model_executor.layers.quantization.fp8 import ( + # Fp8LinearMethod, Fp8MoEMethod) + # assert self.quant_method is not None + # assert isinstance(self.quant_method, + # (Fp8LinearMethod, Fp8MoEMethod)) + # weight_block_size = self.quant_method.quant_config.weight_block_size + # assert weight_block_size is not None + # block_n, _ = weight_block_size[0], weight_block_size[1] + # shard_offset = ( + # (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // + # block_n) // tp_size + # shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) // + # block_n // tp_size) + else: + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size + shard_size = self.output_sizes[loaded_shard_id] // tp_size + + param.load_merged_column_weight( + loaded_weight=loaded_weight, + 
shard_id=loaded_shard_id, + shard_offset=shard_offset, + shard_size=shard_size, + ) + + +class QKVParallelLinear(ColumnParallelLinear): + """Linear layers for the attention's QKV transformation. + + Linear layers for the linear transformation of the query, key, and value + vectors in the attention layer. The weight matrix is concatenated along + the output dimension. The layer is parallelized along the head dimension. + When the number of key/value heads is smaller than the number of query + heads (e.g., multi-query/grouped-query attention), the key/value head may + be replicated while the query heads are partitioned. + + Args: + hidden_size: input hidden state size of the transformer. + head_size: size of each attention head. + total_num_heads: total number of attention query heads. + total_num_kv_heads: total number of attention key/value heads. If + None, assume total_num_kv_heads = total_num_heads. + bias: If true, add bias. + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) + """ + + def __init__( + self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: int | None = None, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: torch.dtype | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + tp_group: dist.ProcessGroup = None, + ): + self.hidden_size = hidden_size + self.head_size = head_size + self.total_num_heads = total_num_heads + if total_num_kv_heads is None: + total_num_kv_heads = total_num_heads + self.total_num_kv_heads = total_num_kv_heads + # Divide the weight matrix along the last dimension. 
+ tp_group = tp_group or get_tp_group() + tp_size = get_group_size(tp_group) + self.num_heads = divide(self.total_num_heads, tp_size) + if tp_size >= self.total_num_kv_heads: + self.num_kv_heads = 1 + self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads) + else: + self.num_kv_heads = divide(self.total_num_kv_heads, tp_size) + self.num_kv_head_replicas = 1 + input_size = self.hidden_size + output_size = ( + (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size + ) + self.output_sizes = [ + self.num_heads * self.head_size * tp_size, # q_proj + self.num_kv_heads * self.head_size * tp_size, # k_proj + self.num_kv_heads * self.head_size * tp_size, # v_proj + ] + + super().__init__( + input_size=input_size, + output_size=output_size, + bias=bias, + gather_output=False, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + prefix=prefix, + tp_group=tp_group, + ) + + def _get_shard_offset_mapping(self, loaded_shard_id: str) -> int | None: + shard_offset_mapping = { + "q": 0, + "k": self.num_heads * self.head_size, + "v": (self.num_heads + self.num_kv_heads) * self.head_size, + "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size, + } + return shard_offset_mapping.get(loaded_shard_id) + + def _get_shard_size_mapping(self, loaded_shard_id: str) -> int | None: + shard_size_mapping = { + "q": self.num_heads * self.head_size, + "k": self.num_kv_heads * self.head_size, + "v": self.num_kv_heads * self.head_size, + } + return shard_size_mapping.get(loaded_shard_id) + + def _load_fused_module_from_checkpoint( + self, param: BasevLLMParameter, loaded_weight: torch.Tensor + ): + """ + Handle special case for models where QKV layers are already + fused on disk. In this case, we have no shard id. This function + determmines the shard id by splitting these layers and then calls + the weight loader using the shard id. 
+ + An example of a model with these fused layers: + https://huggingface.co/microsoft/Phi-3-mini-4k-instruct + """ + shard_offsets = [ + # (shard_id, shard_offset, shard_size) + ("q", 0, self.total_num_heads * self.head_size), + ( + "k", + self.total_num_heads * self.head_size, + self.total_num_kv_heads * self.head_size, + ), + ( + "v", + (self.total_num_heads + self.total_num_kv_heads) * self.head_size, + self.total_num_kv_heads * self.head_size, + ), + ] + + for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantization. + # If quantized, we need to adjust the offset and size to account + # for the packing. + if ( + isinstance(param, PackedColumnParameter | PackedvLLMParameter) + and param.packed_dim == param.output_dim + ): + shard_size, shard_offset = param.adjust_shard_indexes_for_packing( + shard_size=shard_size, shard_offset=shard_offset + ) + + loaded_weight_shard = loaded_weight.narrow( + param.output_dim, shard_offset, shard_size + ) + self.weight_loader_v2(param, loaded_weight_shard, shard_id) + + def weight_loader_v2( + self, + param: BasevLLMParameter, + loaded_weight: torch.Tensor, + loaded_shard_id: str | None = None, + ): + if loaded_shard_id is None: # special case for certain models + if isinstance(param, PerTensorScaleParameter): + param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0) + return + elif type(param) in (RowvLLMParameter, BasevLLMParameter): + param.load_qkv_weight(loaded_weight=loaded_weight) + return + # TODO: @dsikka - move to parameter.py + self._load_fused_module_from_checkpoint(param, loaded_weight) + return + + assert loaded_shard_id in ["q", "k", "v"] + + shard_offset = self._get_shard_offset_mapping(loaded_shard_id) + shard_size = self._get_shard_size_mapping(loaded_shard_id) + + param.load_qkv_weight( + loaded_weight=loaded_weight, + num_heads=self.num_kv_head_replicas, + shard_id=loaded_shard_id, + shard_offset=shard_offset, + shard_size=shard_size, + ) + + def weight_loader( + self, + 
param: Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: str | None = None, + ): + + param_data = param.data + output_dim = getattr(param, "output_dim", None) + # Special case for AQLM codebooks. + is_metadata = getattr(param, "is_metadata", False) + + # Special case for per-tensor scales in fused case. + needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) + + if loaded_shard_id is None: + # Loaded weight is already fused on disk (qkv). + # (e.g., Phi-3's qkv_proj). + if output_dim is None: + if needs_scalar_to_array: + param_data, loaded_weight = adjust_scalar_to_fused_array( + param_data, loaded_weight, 0 + ) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + return + shard_offsets = [ + # (shard_id, shard_offset, shard_size) + ("q", 0, self.total_num_heads * self.head_size), + ( + "k", + self.total_num_heads * self.head_size, + self.total_num_kv_heads * self.head_size, + ), + ( + "v", + (self.total_num_heads + self.total_num_kv_heads) * self.head_size, + self.total_num_kv_heads * self.head_size, + ), + ] + + for shard_id, shard_offset, shard_size in shard_offsets: + + loaded_weight_shard = loaded_weight.narrow( + output_dim, shard_offset, shard_size + ) + self.weight_loader(param, loaded_weight_shard, shard_id) + return + + tp_rank = self.tp_rank + assert loaded_shard_id in ["q", "k", "v"] + + # If output dim is defined, use the default loading process. 
+ if output_dim is not None: + if loaded_shard_id == "q": + shard_offset = 0 + shard_size = self.num_heads * self.head_size + elif loaded_shard_id == "k": + shard_offset = self.num_heads * self.head_size + shard_size = self.num_kv_heads * self.head_size + elif loaded_shard_id == "v": + shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size + shard_size = self.num_kv_heads * self.head_size + + is_sharded_weight = getattr(param, "is_sharded_weight", False) + # bitsandbytes loads the weights of the specific portion + # no need to narrow + is_sharded_weight = is_sharded_weight + + shard_idx = 0 + param_data = param_data.narrow(output_dim, shard_offset, shard_size) + if loaded_shard_id == "q": + shard_idx = tp_rank + else: + shard_idx = tp_rank // self.num_kv_head_replicas + start_idx = shard_idx * shard_size + + if not is_sharded_weight: + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + + # Special case for for AQLM codebooks. + elif is_metadata: + # metadata indicates fixed size concatenated along dim 0 + shard_size = loaded_weight.shape[0] + shard_index = ["q", "k", "v"].index(loaded_shard_id) + param_data = param_data.narrow(0, shard_index * shard_size, shard_size) + # Special case for per-tensor scales in fused case. + elif needs_scalar_to_array: + param_data, loaded_weight = adjust_scalar_to_fused_array( + param_data, loaded_weight, loaded_shard_id + ) + else: + ignore_warning = getattr(param, "ignore_warning", False) + if not ignore_warning: + logger.warning( + "Loading a weight without `output_dim` attribute in " + "QKVParallelLinear, assume the weight is the same " + "for all partitions." + ) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +class RowParallelLinear(LinearBase): + """Linear layer with row parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its first dimension and X along its second dimension as: + - - + | A_1 | + | . | + A = | . 
| X = [X_1, ..., X_p] + | . | + | A_p | + - - + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias. Note that bias is not parallelized. + input_is_parallel: If true, we assume that the input is already + split across the GPUs and we do not split + again. + skip_bias_add: This was added to enable performance optimization where + bias can be fused with other element-wise operations. + We skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + input_is_parallel: bool = True, + skip_bias_add: bool = False, + params_dtype: torch.dtype | None = None, + reduce_results: bool = True, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + tp_group: dist.ProcessGroup = None, + ): + # Divide the weight matrix along the first dimension. + self.tp_group = tp_group or get_tp_group() + self.tp_rank = get_group_rank(self.tp_group) + self.tp_size = get_group_size(self.tp_group) + self.input_size_per_partition = divide(input_size, self.tp_size) + self.output_size_per_partition = output_size + self.output_partition_sizes = [output_size] + + super().__init__( + input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix + ) + + self.input_is_parallel = input_is_parallel + self.reduce_results = reduce_results + + assert self.quant_method is not None + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size_per_partition, + output_partition_sizes=self.output_partition_sizes, + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=( + self.weight_loader_v2 + if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED + else self.weight_loader + ), + ) + if not reduce_results and (bias and not skip_bias_add): + raise 
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
    """Copy a checkpoint tensor into ``param``, slicing out this rank's shard.

    Args:
        param: Destination parameter. Its optional ``input_dim`` attribute names
            the dimension along which the checkpoint tensor is split, and its
            optional ``is_sharded_weight`` attribute marks a tensor that already
            holds only this rank's portion (the original comment cites
            bitsandbytes as the producer of such tensors).
        loaded_weight: Tensor read from the checkpoint.

    Raises:
        AssertionError: If the (possibly sliced) tensor does not match the
            parameter's shape.
    """
    rank = self.tp_rank
    target = param.data
    shard_dim = getattr(param, "input_dim", None)
    already_sharded = getattr(param, "is_sharded_weight", False)

    if shard_dim is not None and not already_sharded:
        # The parameter already has the per-rank extent, so its own size along
        # ``shard_dim`` is the width of the slice to take.
        width = target.shape[shard_dim]
        loaded_weight = loaded_weight.narrow(shard_dim, rank * width, width)

    # Scales stored on disk can be 0-dim tensors (the original comment mentions
    # AutoFP8); promote them to shape (1,) before the shape check.
    if loaded_weight.dim() == 0:
        loaded_weight = loaded_weight.reshape(1)

    assert target.shape == loaded_weight.shape
    target.copy_(loaded_weight)
def extra_repr(self) -> str:
    """Return the comma-separated summary of this layer's configuration.

    The string lists the per-partition input size, the full output size,
    whether a bias is present, the tensor-parallel size and the
    ``reduce_results`` flag, in that order.
    """
    fields = (
        f"input_features={self.input_size_per_partition}",
        f"output_features={self.output_size}",
        f"bias={self.bias is not None}",
        f"tp_size={self.tp_size}",
        f"reduce_results={self.reduce_results}",
    )
    return ", ".join(fields)
RowParallelLinear, +) +from sglang.multimodal_gen.runtime.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, +) +from sglang.multimodal_gen.utils import get_mixed_precision_state + +torch._dynamo.config.recompile_limit = 16 + + +class BaseLayerWithLoRA(nn.Module): + + def __init__( + self, + base_layer: nn.Module, + lora_rank: int | None = None, + lora_alpha: int | None = None, + ): + super().__init__() + self.base_layer: nn.Module = base_layer + + self.merged: bool = False + # Immutable base-weight snapshot; `to("cpu")` may alias CPU storage. + # Use `clone()` so merge updates cannot mutate this backup tensor. + self.cpu_weight = base_layer.weight.detach().to("cpu").clone() + # indicates adapter weights don't contain this layer + # (which shouldn't normally happen, but we want to separate it from the case of erroneous merging) + # Default to True to prevent using uninitialized weights; set to False when weights are loaded + self.disable_lora: bool = True + self.lora_rank = lora_rank + self.lora_alpha = lora_alpha + self.lora_weights_list: list[ + tuple[torch.nn.Parameter, torch.nn.Parameter, str | None, float] + ] = [] + self.lora_path: str | None = None + self.strength: float = 1.0 + + self.lora_A = None + self.lora_B = None + + @property + def weight(self): + return self.base_layer.weight + + @property + def bias(self): + return getattr(self.base_layer, "bias", None) + + @torch.compile() + def forward(self, x: torch.Tensor) -> torch.Tensor: + lora_A = self.lora_A + lora_B = self.lora_B + if isinstance(self.lora_B, DTensor): + lora_B = self.lora_B.to_local() + lora_A = self.lora_A.to_local() + + # TODO: Support multiple LoRA adapters when use not merged mode + if not self.merged and not self.disable_lora: + lora_A_sliced = self.slice_lora_a_weights(lora_A.to(x, non_blocking=True)) + lora_B_sliced = self.slice_lora_b_weights(lora_B.to(x, non_blocking=True)) + delta = x @ lora_A_sliced.T @ lora_B_sliced.T + if self.lora_alpha != self.lora_rank: + 
delta = delta * ( + self.lora_alpha / self.lora_rank # type: ignore + ) # type: ignore + delta = delta * self.strength + if delta.dim() > 2: + delta = delta.reshape(-1, delta.shape[-1]) + out, output_bias = self.base_layer(x) + return out + delta, output_bias + else: + out, output_bias = self.base_layer(x) + return out, output_bias + + def slice_lora_a_weights(self, A: torch.Tensor) -> torch.Tensor: + return A + + def slice_lora_b_weights(self, B: torch.Tensor) -> torch.Tensor: + return B + + def set_lora_weights( + self, + A: torch.Tensor, + B: torch.Tensor, + lora_path: str | None = None, + strength: float = 1.0, + clear_existing: bool = False, + ) -> None: + """ + Set LoRA weights. Supports multiple LoRA adapters. + + Args: + A: LoRA A weight tensor + B: LoRA B weight tensor + lora_path: Path to the LoRA adapter (for logging) + strength: LoRA strength + clear_existing: If True, clear existing LoRA weights before adding new one. + If False, append to existing list (for multi-LoRA support). 
+ """ + lora_A_param = torch.nn.Parameter( + A + ) # share storage with weights in the pipeline + lora_B_param = torch.nn.Parameter(B) + + if clear_existing: + self.lora_weights_list.clear() + # Also clear backward compatibility attributes + self.lora_A = None + self.lora_B = None + self.lora_path = None + self.strength = 1.0 + + # Add to list for multi-LoRA support + self.lora_weights_list.append((lora_A_param, lora_B_param, lora_path, strength)) + + # Set backward compatibility attributes to point to the last LoRA (for single LoRA case) + # This ensures backward compatibility while supporting multiple LoRA + self.lora_A = lora_A_param + self.lora_B = lora_B_param + self.lora_path = lora_path + self.strength = strength + + self.disable_lora = False + self.merge_lora_weights() + + @torch.no_grad() + def _merge_lora_into_data( + self, + data: torch.Tensor, + lora_list: list[ + tuple[torch.nn.Parameter, torch.nn.Parameter, str | None, float] + ], + ) -> None: + """ + Merge all LoRA adapters into the data tensor in-place. 
+ + Args: + data: The base weight tensor to merge LoRA into (modified in-place) + lora_list: List of (lora_A, lora_B, lora_path, lora_strength) tuples + """ + # Merge all LoRA adapters in order + for lora_A, lora_B, _, lora_strength in lora_list: + lora_delta = self.slice_lora_b_weights( + lora_B.to(data) + ) @ self.slice_lora_a_weights(lora_A.to(data)) + # Apply lora_alpha / lora_rank scaling for consistency with forward() + if self.lora_alpha is not None and self.lora_rank is not None: + if self.lora_alpha != self.lora_rank: + lora_delta = lora_delta * (self.lora_alpha / self.lora_rank) + if lora_delta.dim() > 2: + lora_delta = lora_delta.reshape(-1, lora_delta.shape[-1]) + data += lora_strength * lora_delta + + @torch.no_grad() + def merge_lora_weights(self, strength: float | None = None) -> None: + if strength is not None: + self.strength = strength + + if self.disable_lora: + return + + if self.merged: + self.unmerge_lora_weights() + + # Use lora_weights_list if available, otherwise fall back to single LoRA for backward compatibility + lora_list = self.lora_weights_list if self.lora_weights_list else [] + if not lora_list and self.lora_A is not None and self.lora_B is not None: + lora_list = [(self.lora_A, self.lora_B, self.lora_path, self.strength)] + + if not lora_list: + raise ValueError("LoRA weights not set. 
Please set them first.") + + if isinstance(self.base_layer.weight, DTensor): + mesh = self.base_layer.weight.data.device_mesh + unsharded_base_layer = ReplicatedLinear( + input_size=self.base_layer.input_size, + output_size=self.base_layer.output_size, + bias=getattr(self.base_layer, "bias", None) is not None, + skip_bias_add=self.base_layer.skip_bias_add, + params_dtype=self.base_layer.params_dtype, + quant_config=self.base_layer.quant_config, + prefix=self.base_layer.prefix, + ) + # Using offload param is on CPU, so current_device is for "CPU -> GPU -> merge -> CPU" + current_device = self.base_layer.weight.data.device + data = self.base_layer.weight.data.to( + get_local_torch_device() + ).full_tensor() + + self._merge_lora_into_data(data, lora_list) + + unsharded_base_layer.weight = nn.Parameter(data.to(current_device)) + if isinstance(getattr(self.base_layer, "bias", None), DTensor): + unsharded_base_layer.bias = nn.Parameter( + self.base_layer.bias.to(get_local_torch_device(), non_blocking=True) + .full_tensor() + .to(current_device) + ) + + offload_policy = ( + CPUOffloadPolicy() if "cpu" in str(current_device) else OffloadPolicy() + ) + mp_policy = get_mixed_precision_state().mp_policy + + self.base_layer = fully_shard( + unsharded_base_layer, + mesh=mesh, + mp_policy=mp_policy, + offload_policy=offload_policy, + ) + else: + current_device = self.base_layer.weight.data.device + data = self.base_layer.weight.data.to(get_local_torch_device()) + + self._merge_lora_into_data(data, lora_list) + + self.base_layer.weight.data = data.to(current_device, non_blocking=True) + + self.merged = True + + @torch.no_grad() + # @torch.compile(dynamic=True) + def unmerge_lora_weights(self) -> None: + if self.disable_lora: + return + + if not self.merged: + raise ValueError( + "LoRA weights not merged. Please merge them first before unmerging." 
+ ) + + # avoid precision loss + if isinstance(self.base_layer.weight, DTensor): + device = self.base_layer.weight.data.device + old_weight = self.base_layer.weight + new_weight_data = self.cpu_weight.to(device, non_blocking=True) + self.base_layer.weight = nn.Parameter(new_weight_data) + del old_weight + else: + current_device = self.base_layer.weight.data.device + cpu_weight_on_device = self.cpu_weight.to(current_device, non_blocking=True) + self.base_layer.weight.data.copy_(cpu_weight_on_device) + if ( + cpu_weight_on_device.data_ptr() + != self.base_layer.weight.data.data_ptr() + ): + del cpu_weight_on_device + + self.merged = False + + +class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): + """ + Vocab parallel embedding layer with support for LoRA (Low-Rank Adaptation). + + Note: The current version does not yet implement the LoRA functionality. + This class behaves exactly the same as the base VocabParallelEmbedding. + Future versions will integrate LoRA functionality to support efficient parameter fine-tuning. + """ + + def __init__( + self, + base_layer: VocabParallelEmbedding, + ) -> None: + super().__init__(base_layer) + + def forward(self, input_: torch.Tensor) -> torch.Tensor: + raise NotImplementedError( + "We don't support VocabParallelEmbeddingWithLoRA yet." 
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper for a ``ColumnParallelLinear`` base layer.

    ``forward`` re-implements the base layer's computation (per the original
    comment, it duplicates the logic in ``ColumnParallelLinear``) and does not
    add a low-rank delta itself. ``slice_lora_b_weights`` selects this rank's
    rows of the B matrix.
    """

    def __init__(
        self,
        base_layer: ColumnParallelLinear,
        lora_rank: int | None = None,
        lora_alpha: int | None = None,
    ) -> None:
        super().__init__(base_layer, lora_rank, lora_alpha)

    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        """Apply the base layer and return ``(output, output_bias)``.

        The bias is either fused into the matmul or handed back to the caller,
        depending on the base layer's ``skip_bias_add`` flag.
        """
        layer = self.base_layer
        if layer.skip_bias_add:
            fused_bias, returned_bias = None, layer.bias
        else:
            fused_bias, returned_bias = layer.bias, None

        local_out = layer.quant_method.apply(layer, input_, fused_bias)
        if layer.gather_output:
            return tensor_model_parallel_all_gather(local_out), returned_bias
        return local_out, returned_bias

    def slice_lora_a_weights(self, A: torch.Tensor) -> torch.Tensor:
        """Return ``A`` unchanged; A is not sliced for this layer type."""
        return A

    def slice_lora_b_weights(self, B: torch.Tensor) -> torch.Tensor:
        """Return the rows of ``B`` that belong to the current TP rank."""
        rows_per_rank = self.base_layer.output_partition_sizes[0]
        first_row = get_tp_rank() * rows_per_rank
        return B[first_row : first_row + rows_per_rank, :]
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper for a ``RowParallelLinear`` base layer.

    The tensor-parallel rank and process group are taken from the wrapped
    layer, so the wrapper splits inputs, slices the A matrix and reduces
    outputs over the same group the base layer was built with.
    """

    def __init__(
        self,
        base_layer: RowParallelLinear,
        lora_rank: int | None = None,
        lora_alpha: int | None = None,
    ) -> None:
        super().__init__(base_layer, lora_rank, lora_alpha)

    def forward(self, input_: torch.Tensor):
        """Apply the base layer and return ``(output, output_bias)``.

        This duplicates the logic in ``RowParallelLinear`` (per the original
        comment), except that the bias is added after the optional all-reduce
        rather than being fused into the matmul.
        """
        layer = self.base_layer
        if layer.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=layer.tp_size
            )
            # Use the base layer's own rank so the chosen slice matches the
            # weight shard it loaded.
            input_parallel = splitted_input[layer.tp_rank].contiguous()

        output_parallel = layer.quant_method.apply(layer, input_parallel)

        if layer.reduce_results and layer.tp_size > 1:
            # Reduce over the base layer's group, as RowParallelLinear does.
            output_ = tensor_model_parallel_all_reduce(
                output_parallel, tp_group=layer.tp_group
            )
        else:
            output_ = output_parallel

        if layer.skip_bias_add:
            return output_, layer.bias

        output = output_ + layer.bias if layer.bias is not None else output_
        return output, None

    def slice_lora_a_weights(self, A: torch.Tensor) -> torch.Tensor:
        """Return the columns of ``A`` that match this rank's input shard."""
        shard_size = self.base_layer.input_size_per_partition
        start_idx = self.base_layer.tp_rank * shard_size
        return A[:, start_idx : start_idx + shard_size].contiguous()

    def slice_lora_b_weights(self, B: torch.Tensor) -> torch.Tensor:
        """Return ``B`` unchanged; B is not sliced for this layer type."""
        return B
# source: https://github.com/vllm-project/vllm/blob/93b38bea5dd03e1b140ca997dfaadef86f8f1855/vllm/lora/utils.py#L9
def replace_submodule(
    model: nn.Module, module_name: str, new_module: nn.Module
) -> nn.Module:
    """Swap the submodule at dotted path ``module_name`` for ``new_module``.

    Args:
        model: Root module that contains the submodule.
        module_name: Dotted path of the submodule relative to ``model``.
        new_module: Module to install at that path.

    Returns:
        ``new_module``, after it has been attached to its parent.
    """
    # Everything before the last dot addresses the parent; an empty parent
    # path makes ``get_submodule`` return ``model`` itself.
    parent_path, _, leaf_name = module_name.rpartition(".")
    setattr(model.get_submodule(parent_path), leaf_name, new_module)
    return new_module
class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
            One of `"gelu"`, `"gelu-approximate"`, `"geglu"`, `"geglu-approximate"`, `"swiglu"`, `"linear-silu"`.
        inner_dim (`int`, *optional*): Hidden dimension. If not given, defaults to `int(dim * mult)`.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.

    Raises:
        ValueError: If `activation_fn` is not one of the supported names.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        activation_fn: str = "geglu",
        inner_dim=None,
        bias: bool = True,
    ):
        super().__init__()
        if inner_dim is None:
            inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        # Single chain so that an unknown name is rejected here with a clear
        # message instead of leaving ``act_fn`` unbound further down.
        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim, bias=bias)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim, bias=bias)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
        elif activation_fn == "swiglu":
            act_fn = SwiGLU(dim, inner_dim, bias=bias)
        elif activation_fn == "linear-silu":
            act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu")
        else:
            raise ValueError(
                f"Unsupported activation_fn: {activation_fn!r}. Expected one of "
                "'gelu', 'gelu-approximate', 'geglu', 'geglu-approximate', "
                "'swiglu', 'linear-silu'."
            )

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # dummy dropout layer to match with checkpoints compatible with diffusers
        self.net.append(nn.Dropout(0.0))
        # project out
        self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Pass ``hidden_states`` through each module of ``self.net`` in order."""
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states
def register_quantization_config(quantization: str):
    """Return a class decorator that registers a custom quantization config.

    The decorated class is stored in ``_CUSTOMIZED_METHOD_TO_QUANT_CONFIG``
    under ``quantization`` and the name is appended to
    ``QUANTIZATION_METHODS``.

    Args:
        quantization (str): The quantization method name.

    Raises:
        ValueError: When the decorator is applied, if the name is already
            registered or the class is not a ``QuantizationConfig`` subclass.
    """  # noqa: E501

    def _register(config_cls):
        name_taken = quantization in QUANTIZATION_METHODS
        if name_taken:
            raise ValueError(
                f"The quantization method `{quantization}` is already exists."
            )

        is_config = issubclass(config_cls, QuantizationConfig)
        if not is_config:
            raise ValueError(
                "The quantization config must be a subclass of `QuantizationConfig`."
            )

        # Record the class first, then publish the name as a known method.
        _CUSTOMIZED_METHOD_TO_QUANT_CONFIG[quantization] = config_cls
        QUANTIZATION_METHODS.append(quantization)
        return config_cls

    return _register
+ method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG) + + return method_to_config[quantization] + + +__all__ = [ + "QuantizationMethods", + "QuantizationConfig", + "get_quantization_config", + "QUANTIZATION_METHODS", +] diff --git a/sglang/python/sglang/multimodal_gen/runtime/layers/quantization/configs/base_config.py b/sglang/python/sglang/multimodal_gen/runtime/layers/quantization/configs/base_config.py new file mode 100644 index 0000000000000000000000000000000000000000..f4afe7d015f5736f9c7e877fde21d1ae62cf4daa --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/layers/quantization/configs/base_config.py @@ -0,0 +1,155 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/layers/quantization/base_config.py + +import inspect +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +import torch +from torch import nn + +if TYPE_CHECKING: + from sglang.multimodal_gen.runtime.layers.quantization import QuantizationMethods +else: + QuantizationMethods = str + + +class QuantizeMethodBase(ABC): + """Base class for different quantized methods.""" + + @abstractmethod + def create_weights( + self, layer: torch.nn.Module, *weight_args, **extra_weight_attrs + ): + """Create weights for a layer. + + The weights will be set as attributes of the layer.""" + raise NotImplementedError + + @abstractmethod + def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor: + """Apply the weights in layer to the input tensor. + + Expects create_weights to have been called before on the layer.""" + raise NotImplementedError + + # Not required functions + def embedding(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor: + """Gather embeddings in the layer based on indices in the input tensor. 
def method_has_implemented_embedding(method_class: type[QuantizeMethodBase]) -> bool:
    """Tell whether ``method_class`` overrides ``QuantizeMethodBase.embedding``.

    The attribute is looked up statically on both the base class and the given
    class; the method counts as implemented only when the class exposes an
    ``embedding`` attribute that is a different object from the base one.
    """
    inherited = inspect.getattr_static(QuantizeMethodBase, "embedding", None)
    own = inspect.getattr_static(method_class, "embedding", None)

    if own is None:
        return False
    return own is not inherited
@staticmethod
def get_from_keys(config: dict[str, Any], keys: list[str]) -> Any:
    """Return the value of the first entry in ``keys`` that exists in ``config``.

    Args:
        config: The model's quantization config mapping.
        keys: Candidate key names, tried in order.

    Raises:
        ValueError: If none of the candidate keys is present.
    """
    # A private sentinel distinguishes "no key matched" from a stored None.
    not_found = object()
    value = next((config[name] for name in keys if name in config), not_found)
    if value is not_found:
        raise ValueError(
            f"Cannot find any of {keys} in the model's quantization config."
        )
    return value
+ """ + raise NotImplementedError + + def get_cache_scale(self, name: str) -> str | None: + return None diff --git a/sglang/python/sglang/multimodal_gen/runtime/layers/quantization/configs/nunchaku_config.py b/sglang/python/sglang/multimodal_gen/runtime/layers/quantization/configs/nunchaku_config.py new file mode 100644 index 0000000000000000000000000000000000000000..a8c22b40709419e2280838c8a15c6ff329d05ef1 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/layers/quantization/configs/nunchaku_config.py @@ -0,0 +1,283 @@ +# SPDX-License-Identifier: Apache-2.0 +import json +import os +from dataclasses import dataclass +from functools import lru_cache +from typing import Any, Optional + +import torch +from safetensors.torch import load_file as safetensors_load_file +from torch import nn + +from sglang.multimodal_gen.runtime.layers.linear import LinearBase +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +from .base_config import QuantizationConfig, QuantizeMethodBase + +logger = init_logger(__name__) + + +@lru_cache(maxsize=1) +def is_nunchaku_available() -> bool: + try: + import nunchaku # noqa + + logger.debug("Nunchaku package detected") + return True + except Exception: + return False + + +@dataclass +class NunchakuConfig(QuantizationConfig): + """ + Configuration for Nunchaku (SVDQuant) W4A4-style quantization. + + Attributes: + precision: Quantization precision type. 
Options: + - "int4": Standard INT4 quantization + - "nvfp4": FP4 quantization + rank: SVD low-rank dimension for absorbing outliers + group_size: Quantization group size (automatically set based on precision) + act_unsigned: Use unsigned activation quantization + transformer_weights_path: Path to pre-quantized transformer weights (.safetensors) + model_cls: DiT model class that provides quantization rules via get_nunchaku_quant_rules() + """ + + precision: str = "int4" + rank: int = 32 + group_size: Optional[int] = None + act_unsigned: bool = False + transformer_weights_path: Optional[str] = None + model_cls: Optional[type] = None + + @classmethod + def get_name(cls) -> str: + return "svdquant" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16, torch.float16] + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + @staticmethod + def get_config_filenames() -> list[str]: + return ["quantization_config.json", "quant_config.json"] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "NunchakuConfig": + + return cls( + precision=config.get("precision", "int4"), + rank=int(config.get("rank", 32)), + group_size=config.get("group_size"), + act_unsigned=bool(config.get("act_unsigned", False)), + transformer_weights_path=config.get("transformer_weights_path"), + ) + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional[QuantizeMethodBase]: + if not isinstance(layer, LinearBase): + return None + + # get quantization rules from model class + quant_rules = self._get_quant_rules() + + # priority: skip > awq_w4a16 > svdq_w4a4 > default + skip_patterns = quant_rules.get("skip", []) + for pattern in skip_patterns: + if pattern in prefix.lower(): + return None + + awq_patterns = quant_rules.get("awq_w4a16", []) + for pattern in awq_patterns: + if pattern in prefix: + from ..nunchaku_linear import NunchakuAWQLinearMethod + + return NunchakuAWQLinearMethod(group_size=64) + + 
svdq_patterns = quant_rules.get("svdq_w4a4", []) + for pattern in svdq_patterns: + if pattern in prefix: + from ..nunchaku_linear import NunchakuSVDQLinearMethod + + return NunchakuSVDQLinearMethod( + precision=self.precision, + rank=self.rank, + act_unsigned=self.act_unsigned, + ) + + # default: apply svdq_w4a4 to all remaining linear layers + from ..nunchaku_linear import NunchakuSVDQLinearMethod + + return NunchakuSVDQLinearMethod( + precision=self.precision, + rank=self.rank, + act_unsigned=self.act_unsigned, + ) + + def _get_quant_rules(self) -> dict[str, list[str]]: + if self.model_cls is not None and hasattr( + self.model_cls, "get_nunchaku_quant_rules" + ): + return self.model_cls.get_nunchaku_quant_rules() + return {} + + def __post_init__(self): + if self.group_size is None: + if self.precision == "nvfp4": + self.group_size = 16 + elif self.precision == "int4": + self.group_size = 64 + else: + raise ValueError( + f"Invalid precision: {self.precision}. Must be 'int4' or 'nvfp4'" + ) + + if self.precision not in ["int4", "nvfp4"]: + raise ValueError( + f"Invalid precision: {self.precision}. 
Must be 'int4' or 'nvfp4'" + ) + + if self.rank <= 0: + raise ValueError(f"Rank must be positive, got {self.rank}") + + @classmethod + def from_dict(cls, config_dict: dict) -> "NunchakuConfig": + """Create configuration from dictionary.""" + return cls(**config_dict) + + def to_dict(self) -> dict: + """Convert configuration to dictionary.""" + return { + "precision": self.precision, + "rank": self.rank, + "group_size": self.group_size, + "act_unsigned": self.act_unsigned, + "transformer_weights_path": self.transformer_weights_path, + } + + @classmethod + def from_pretrained(cls, model_path: str) -> Optional["NunchakuConfig"]: + for filename in cls.get_config_filenames(): + config_path = os.path.join(model_path, filename) + if os.path.exists(config_path): + with open(config_path, "r") as f: + config_dict = json.load(f) + if config_dict.get("quant_method") == cls.get_name(): + return cls.from_config(config_dict) + return None + + +def _patch_native_svdq_linear( + module: nn.Module, tensor: Any, svdq_linear_cls: type +) -> bool: + if ( + isinstance(module, svdq_linear_cls) + and getattr(module, "wtscale", None) is not None + ): + module.wtscale = tensor + return True + return False + + +def _patch_sglang_svdq_linear( + module: nn.Module, tensor: Any, svdq_method_cls: type +) -> bool: + quant_method = getattr(module, "quant_method", None) + if not isinstance(quant_method, svdq_method_cls): + return False + + existing = getattr(module, "wtscale", None) + if isinstance(existing, nn.Parameter): + with torch.no_grad(): + existing.data.copy_(tensor.to(existing.data.dtype)) + else: + module.wtscale = tensor + + # Keep alpha in sync (kernel reads `layer._nunchaku_alpha`) + try: + module._nunchaku_alpha = float(tensor.detach().cpu().item()) + except Exception: + module._nunchaku_alpha = None + return True + + +def _patch_sglang_svdq_wcscales( + module: nn.Module, tensor: Any, svdq_method_cls: type +) -> bool: + quant_method = getattr(module, "quant_method", None) + if not 
isinstance(quant_method, svdq_method_cls): + return False + + existing = getattr(module, "wcscales", None) + if isinstance(existing, nn.Parameter): + with torch.no_grad(): + existing.data.copy_(tensor.to(existing.data.dtype)) + else: + module.wcscales = tensor + return True + + +def _patch_nunchaku_scales( + model: nn.Module, + safetensors_list: list[str], +) -> None: + """Patch transformer module with Nunchaku scale tensors from safetensors weights. + + For NVFP4 checkpoints, correctness depends on `wtscale` and attention + `wcscales`. The FSDP loader may skip some of these metadata tensors. + """ + + if not safetensors_list: + return + + if len(safetensors_list) != 1: + logger.warning( + "Nunchaku scale patch expects a single safetensors file, " + "but got %d files. Skipping.", + len(safetensors_list), + ) + return + + from nunchaku.models.linear import SVDQW4A4Linear # type: ignore[import] + + state_dict = safetensors_load_file(safetensors_list[0]) + if state_dict is None: + return + + num_wtscale = 0 + num_wcscales = 0 + + from ..nunchaku_linear import NunchakuSVDQLinearMethod + + for name, module in model.named_modules(): + wt = state_dict.get(f"{name}.wtscale") + if wt is not None: + if _patch_native_svdq_linear(module, wt, SVDQW4A4Linear): + num_wtscale += 1 + elif _patch_sglang_svdq_linear(module, wt, NunchakuSVDQLinearMethod): + num_wtscale += 1 + + wc = state_dict.get(f"{name}.wcscales") + if wc is not None: + # Some modules may have wcscales as a direct attribute/Parameter. 
+ existing = getattr(module, "wcscales", None) + if isinstance(existing, nn.Parameter): + with torch.no_grad(): + existing.data.copy_(wc.to(existing.data.dtype)) + num_wcscales += 1 + elif existing is not None: + setattr(module, "wcscales", wc) + num_wcscales += 1 + elif _patch_sglang_svdq_wcscales(module, wc, NunchakuSVDQLinearMethod): + num_wcscales += 1 + + if num_wtscale > 0: + logger.info("Patched wtscale for %d layers", num_wtscale) + if num_wcscales > 0: + logger.info("Patched wcscales for %d layers", num_wcscales) diff --git a/sglang/python/sglang/multimodal_gen/runtime/layers/quantization/fp8.py b/sglang/python/sglang/multimodal_gen/runtime/layers/quantization/fp8.py new file mode 100644 index 0000000000000000000000000000000000000000..fcde0ab8821a94f05ce82e9a60a9f14f95b1470d --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/layers/quantization/fp8.py @@ -0,0 +1,498 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +import torch +from torch.nn import Module +from torch.nn.parameter import Parameter + +from sglang.multimodal_gen.runtime.distributed.parallel_state import ( + get_tensor_model_parallel_world_size, +) +from sglang.multimodal_gen.runtime.layers.linear import ( + LinearMethodBase, + UnquantizedLinearMethod, +) +from sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from sglang.multimodal_gen.runtime.models.parameter import ( + BlockQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter, +) +from sglang.multimodal_gen.runtime.platforms import current_platform +from sglang.multimodal_gen.runtime.utils.common import ( + cpu_has_amx_support, + get_bool_env_var, + use_intel_amx_backend, +) +from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading +from sglang.srt.layers.quantization.fp8_kernel import ( + is_fp8_fnuz, + per_token_group_quant_fp8, +) +from 
sglang.srt.layers.quantization.fp8_utils import ( + apply_fp8_linear, + can_auto_enable_marlin_fp8, + cutlass_fp8_supported, + dispatch_w8a8_block_fp8_linear, + input_to_float8, + normalize_e4m3fn_to_e4m3fnuz, + requant_weight_ue8m0_inplace, +) +from sglang.srt.layers.quantization.marlin_utils_fp8 import ( + apply_fp8_marlin_linear, + prepare_fp8_layer_for_marlin, +) +from sglang.srt.layers.quantization.utils import ( + convert_to_channelwise, + is_layer_skipped, + requantize_with_max_scale, +) + +if TYPE_CHECKING: + from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config + +_is_hip = current_platform.is_hip() +_is_cuda = current_platform.is_cuda() +_is_npu = current_platform.is_npu() +_is_cpu_amx_available = cpu_has_amx_support() +_is_cpu = current_platform.is_cpu() +_is_fp8_fnuz = is_fp8_fnuz() +_use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT") and _is_hip +_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip + +if _use_aiter or _use_hip_int4: + pass + + +ACTIVATION_SCHEMES = ["static", "dynamic"] + +logger = logging.getLogger(__name__) + + +class Fp8Config(QuantizationConfig): + """Config class for FP8.""" + + def __init__( + self, + is_checkpoint_fp8_serialized: bool = False, + activation_scheme: str = "dynamic", + ignored_layers: Optional[List[str]] = None, + weight_block_size: List[int] = None, + ) -> None: + self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized + if is_checkpoint_fp8_serialized: + logger.info("Detected fp8 checkpoint.") + if activation_scheme not in ACTIVATION_SCHEMES: + raise ValueError(f"Unsupported activation scheme {activation_scheme}") + self.activation_scheme = activation_scheme + self.ignored_layers = ignored_layers or [] + if weight_block_size is not None: + if not is_checkpoint_fp8_serialized: + raise ValueError( + f"The block-wise quantization only supports fp8-serialized checkpoint for now." 
+ ) + if len(weight_block_size) != 2: + raise ValueError( + f"The quantization block size of weight must have 2 dimensions, but got {len(weight_block_size)} dimensions." + ) + if activation_scheme != "dynamic": + raise ValueError( + f"The block-wise quantization only supports dynamic activation scheme for now, but got {activation_scheme} activation scheme." + ) + self.weight_block_size = weight_block_size + + @classmethod + def get_name(cls) -> str: + return "fp8" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> Fp8Config: + quant_method = cls.get_from_keys(config, ["quant_method"]) + is_checkpoint_fp8_serialized = "fp8" in quant_method + activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) + ignored_layers = cls.get_from_keys_or( + config, ["ignored_layers", "modules_to_not_convert"], None + ) + if ignored_layers: + # hacking ministral + ignored_layers = [layer.replace("model.", "") for layer in ignored_layers] + weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None) + return cls( + is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, + activation_scheme=activation_scheme, + ignored_layers=ignored_layers, + weight_block_size=weight_block_size, + ) + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional[QuantizeMethodBase]: + from sglang.multimodal_gen.runtime.layers.linear import LinearBase + + if isinstance(layer, LinearBase): + if is_layer_skipped(prefix, self.ignored_layers): + return UnquantizedLinearMethod() + return Fp8LinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class Fp8LinearMethod(LinearMethodBase): + """Linear method for FP8. 
+ Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. + + Limitations: + 1. Only support per-tensor quantization due to torch._scaled_mm support. + 2. Only support float8_e4m3fn data type due to the limitation of + torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856) + + Args: + quant_config: The quantization config. + """ + + def __init__(self, quant_config: Union[Fp8Config, W4AFp8Config]): + self.quant_config = quant_config + self.cutlass_fp8_supported = cutlass_fp8_supported() + + # For GPUs that lack FP8 hardware support, we can leverage the Marlin + # kernel for fast weight-only FP8 quantization + self.use_marlin = False + if _is_cuda: + force_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN") + auto_enable = can_auto_enable_marlin_fp8() + self.use_marlin = force_marlin or auto_enable + + self.block_quant = self.quant_config.weight_block_size is not None + + self.w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear() + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + tp_size = get_tensor_model_parallel_world_size() + if self.block_quant: + block_n, block_k = ( + self.quant_config.weight_block_size[0], + self.quant_config.weight_block_size[1], + ) + # Required by row parallel + if tp_size > 1 and input_size // input_size_per_partition == tp_size: + if input_size_per_partition % block_k != 0: + raise ValueError( + f"Weight input_size_per_partition = " + 
f"{input_size_per_partition} is not divisible by " + f"weight quantization block_k = {block_k}." + ) + # Required by column parallel or enabling merged weights + if ( + tp_size > 1 and output_size // output_size_per_partition == tp_size + ) or len(output_partition_sizes) > 1: + for output_partition_size in output_partition_sizes: + if output_partition_size % block_n != 0: + raise ValueError( + f"Weight output_partition_size = " + f"{output_partition_size} is not divisible by " + f"weight quantization block_n = {block_n}." + ) + + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + + # WEIGHT + weight_dtype = ( + torch.float8_e4m3fn + if self.quant_config.is_checkpoint_fp8_serialized + else params_dtype + ) + + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, input_size_per_partition, dtype=weight_dtype + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + # If checkpoint is serialized fp8, load them. + # Otherwise, wait until process_weights_after_loading. 
+ if self.quant_config.is_checkpoint_fp8_serialized: + # WEIGHT SCALE + if self.block_quant: + if hasattr(self.quant_config, "activation_scheme"): + assert self.quant_config.activation_scheme == "dynamic" + elif hasattr(self.quant_config, "linear_activation_scheme"): + assert self.quant_config.linear_activation_scheme == "dynamic" + scale = BlockQuantScaleParameter( + data=torch.empty( + (output_size_per_partition + block_n - 1) // block_n, + (input_size_per_partition + block_k - 1) // block_k, + dtype=torch.float32, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + scale.format_ue8m0 = False + scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale_inv", scale) + else: + scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", scale) + + # INPUT ACTIVATION SCALE + if ( + hasattr(self.quant_config, "activation_scheme") + and self.quant_config.activation_scheme == "static" + ) or ( + hasattr(self.quant_config, "linear_activation_scheme") + and self.quant_config.linear_activation_scheme == "static" + ): + scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + + scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("input_scale", scale) + else: + layer.register_parameter("input_scale", None) + + def process_weights_after_loading(self, layer: Module) -> None: + if self.block_quant: + # If ROCm, normalize the weights and scales to e4m3fnuz + if _is_fp8_fnuz: + # activation_scheme: dynamic + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=layer.weight, + weight_scale=layer.weight_scale_inv, + input_scale=None, + ) + layer.input_scale = None + elif _is_cpu: + assert ( + _is_cpu_amx_available + ), "Fp8LinearMethod on CPU requires that CPU has 
AMX support" + _amx_process_weight_after_loading(layer, ["weight"]) + layer.weight_scale_inv = torch.nn.Parameter( + layer.weight_scale_inv.data, requires_grad=False + ) + return + else: + # For fp8 linear weights run with deepgemm, the weights and scales need be requantized to ue8m0 + from sglang.srt.layers.quantization.fp8_utils import ( + deepgemm_w8a8_block_fp8_linear_with_fallback, + ) + from sglang.srt.model_loader.utils import ( + should_deepgemm_weight_requant_ue8m0, + ) + + if ( + should_deepgemm_weight_requant_ue8m0( + weight_block_size=getattr( + self.quant_config, "weight_block_size", None + ), + ) + and ( + self.w8a8_block_fp8_linear + is deepgemm_w8a8_block_fp8_linear_with_fallback + ) + and (not layer.weight_scale_inv.format_ue8m0) + ): + requant_weight_ue8m0_inplace( + layer.weight, + layer.weight_scale_inv, + self.quant_config.weight_block_size, + ) + layer.weight_scale_inv.format_ue8m0 = True + weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data + + layer.weight.data = weight.data + layer.weight_scale_inv.data = weight_scale.data + else: + layer.weight = Parameter(layer.weight.data, requires_grad=False) + + # If checkpoint not serialized fp8, quantize the weights. + if not self.quant_config.is_checkpoint_fp8_serialized: + if self.cutlass_fp8_supported or self.use_marlin: + # apply per-channel quantization default as + # cutlass sgl-kernel and marlin only support per-channel scale + qweight, weight_scale = per_token_group_quant_fp8( + layer.weight, layer.weight.shape[-1] + ) + weight_scale = weight_scale.t().contiguous() + else: + # per-tensor quantization + qweight, weight_scale = input_to_float8(layer.weight) + + # Update the layer with the new values. 
+ layer.weight = Parameter(qweight.t(), requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + layer.input_scale = None + + # If checkpoint is fp8, handle that there are N scales for N + # shards in a fused module + else: + layer.weight_scale = Parameter( + layer.weight_scale.data, requires_grad=False + ) + if ( + hasattr(self.quant_config, "activation_scheme") + and self.quant_config.activation_scheme == "static" + ) or ( + hasattr(self.quant_config, "linear_activation_scheme") + and self.quant_config.linear_activation_scheme == "static" + ): + layer.input_scale = Parameter( + layer.input_scale.data, requires_grad=False + ) + + # cutlass sgl-kernel and marlin only support per-channel scale + if self.cutlass_fp8_supported or self.use_marlin: + weight = layer.weight + weight_scale = convert_to_channelwise( + layer.weight_scale, layer.logical_widths + ) + else: + # Dequant -> Quant with max scale so we can run per tensor. + weight = layer.weight + weight_scale = layer.weight_scale + # If ROCm, normalize the weights and scales to e4m3fnuz + if _is_fp8_fnuz: + weight, weight_scale, input_scale = ( + normalize_e4m3fn_to_e4m3fnuz( + weight=weight, + weight_scale=weight_scale, + input_scale=layer.input_scale, + ) + ) + if input_scale is not None: + layer.input_scale = Parameter( + input_scale, requires_grad=False + ) + + weight_scale, weight = requantize_with_max_scale( + weight=weight, + weight_scale=weight_scale, + logical_widths=layer.logical_widths, + ) + + # Update layer with new values. 
+ layer.weight = Parameter(weight.t(), requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + if ( + hasattr(self.quant_config, "activation_scheme") + and self.quant_config.activation_scheme == "static" + ) or ( + hasattr(self.quant_config, "linear_activation_scheme") + and self.quant_config.linear_activation_scheme == "static" + ): + layer.input_scale = Parameter( + layer.input_scale.max(), requires_grad=False + ) + + if self.use_marlin: + if self.block_quant: + layer.weight_block_size = self.quant_config.weight_block_size + prepare_fp8_layer_for_marlin(layer, not self.block_quant) + # Activations not quantized for marlin. + del layer.input_scale + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if self.use_marlin: + return apply_fp8_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias, + ) + + if self.block_quant: + if use_intel_amx_backend(layer): + return torch.ops.sgl_kernel.fp8_scaled_mm_cpu( + x, + layer.weight, + layer.weight_scale_inv, + self.quant_config.weight_block_size, + bias, + x.dtype, + True, # is_vnni + ) + + if isinstance(x, tuple): + return self.w8a8_block_fp8_linear( + input=x[0], + weight=layer.weight, + block_size=self.quant_config.weight_block_size, + weight_scale=layer.weight_scale_inv, + input_scale=x[1], + bias=bias, + ) + + return self.w8a8_block_fp8_linear( + input=x, + weight=layer.weight, + block_size=self.quant_config.weight_block_size, + weight_scale=layer.weight_scale_inv, + input_scale=None, + bias=bias, + ) + + return apply_fp8_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, + bias=bias, + cutlass_fp8_supported=self.cutlass_fp8_supported, + use_per_token_if_dynamic=False, + ) diff --git 
a/sglang/python/sglang/multimodal_gen/runtime/layers/quantization/nunchaku_linear.py b/sglang/python/sglang/multimodal_gen/runtime/layers/quantization/nunchaku_linear.py new file mode 100644 index 0000000000000000000000000000000000000000..516f7669992f65726a987b147aeb7011ec1538b7 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/layers/quantization/nunchaku_linear.py @@ -0,0 +1,291 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import List, Optional + +import torch +import torch.nn as nn +from torch.nn.parameter import Parameter + +from sglang.multimodal_gen.runtime.layers.linear import LinearMethodBase +from sglang.multimodal_gen.runtime.models.utils import set_weight_attrs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + +try: + from nunchaku.ops.gemm import svdq_gemm_w4a4_cuda + from nunchaku.ops.gemv import awq_gemv_w4a16_cuda + from nunchaku.ops.quantize import svdq_quantize_w4a4_act_fuse_lora_cuda +except ImportError: + svdq_gemm_w4a4_cuda = None + awq_gemv_w4a16_cuda = None + svdq_quantize_w4a4_act_fuse_lora_cuda = None + + +class NunchakuSVDQLinearMethod(LinearMethodBase): + def __init__( + self, + precision: str = "int4", + rank: int = 32, + act_unsigned: bool = False, + ): + self.precision = precision + self.rank = rank + self.act_unsigned = act_unsigned + + if precision == "nvfp4": + self.group_size = 16 + else: + self.group_size = 64 + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + output_size_per_partition = sum(output_partition_sizes) + + qweight = Parameter( + torch.empty( + output_size_per_partition, + input_size_per_partition // 2, + dtype=torch.int8, + ), + requires_grad=False, + ) + set_weight_attrs(qweight, {"input_dim": 1, "output_dim": 0}) + + num_groups = 
input_size_per_partition // self.group_size + if self.precision == "nvfp4": + scale_dtype = torch.float8_e4m3fn + else: + scale_dtype = params_dtype + wscales = Parameter( + torch.empty(num_groups, output_size_per_partition, dtype=scale_dtype), + requires_grad=False, + ) + + smooth_factor = Parameter( + torch.empty(input_size_per_partition, dtype=params_dtype), + requires_grad=False, + ) + + smooth_factor_orig = Parameter( + torch.empty(input_size_per_partition, dtype=params_dtype), + requires_grad=False, + ) + + proj_down = Parameter( + torch.empty(input_size_per_partition, self.rank, dtype=params_dtype), + requires_grad=False, + ) + proj_up = Parameter( + torch.empty(output_size_per_partition, self.rank, dtype=params_dtype), + requires_grad=False, + ) + + if self.precision == "nvfp4": + wcscales = Parameter( + torch.empty( + output_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + wtscale = Parameter( + torch.empty(1, dtype=params_dtype), + requires_grad=False, + ) + else: + wcscales = None + wtscale = None + + layer.register_parameter("qweight", qweight) + layer.register_parameter("wscales", wscales) + layer.register_parameter("smooth_factor", smooth_factor) + layer.register_parameter("smooth_factor_orig", smooth_factor_orig) + layer.register_parameter("proj_down", proj_down) + layer.register_parameter("proj_up", proj_up) + if wcscales is not None: + layer.register_parameter("wcscales", wcscales) + if wtscale is not None: + layer.register_parameter("wtscale", wtscale) + + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.precision = self.precision + layer.rank = self.rank + layer.group_size = self.group_size + layer.act_unsigned = self.act_unsigned + + weight_loader = extra_weight_attrs.get("weight_loader") + if weight_loader is not None: + set_weight_attrs(qweight, {"weight_loader": weight_loader}) + set_weight_attrs(wscales, {"weight_loader": weight_loader}) 
+ set_weight_attrs(smooth_factor, {"weight_loader": weight_loader}) + set_weight_attrs(smooth_factor_orig, {"weight_loader": weight_loader}) + set_weight_attrs(proj_down, {"weight_loader": weight_loader}) + set_weight_attrs(proj_up, {"weight_loader": weight_loader}) + if wcscales is not None: + set_weight_attrs(wcscales, {"weight_loader": weight_loader}) + if wtscale is not None: + set_weight_attrs(wtscale, {"weight_loader": weight_loader}) + + def process_weights_after_loading(self, layer: nn.Module) -> None: + layer.qweight = Parameter(layer.qweight.data, requires_grad=False) + layer.wscales = Parameter(layer.wscales.data, requires_grad=False) + layer.smooth_factor = Parameter(layer.smooth_factor.data, requires_grad=False) + layer.smooth_factor_orig = Parameter( + layer.smooth_factor_orig.data, requires_grad=False + ) + layer.proj_down = Parameter(layer.proj_down.data, requires_grad=False) + layer.proj_up = Parameter(layer.proj_up.data, requires_grad=False) + if hasattr(layer, "wcscales") and layer.wcscales is not None: + layer.wcscales = Parameter(layer.wcscales.data, requires_grad=False) + if hasattr(layer, "wtscale") and layer.wtscale is not None: + layer.wtscale = Parameter(layer.wtscale.data, requires_grad=False) + + alpha: float | None = None + wtscale = getattr(layer, "wtscale", None) + if wtscale is not None: + if isinstance(wtscale, Parameter): + wtscale = wtscale.data + if isinstance(wtscale, torch.Tensor): + alpha = float(wtscale.detach().cpu().item()) + else: + alpha = float(wtscale) + layer._nunchaku_alpha = alpha + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + orig_shape = x.shape + x_2d = x.reshape(-1, orig_shape[-1]) + quantized_x, ascales, lora_act_out = svdq_quantize_w4a4_act_fuse_lora_cuda( + x_2d, + lora_down=layer.proj_down, + smooth=layer.smooth_factor, + fp4=layer.precision == "nvfp4", + pad_size=256, + ) + out_2d = torch.empty( + x_2d.shape[0], + 
layer.output_size_per_partition, + dtype=x_2d.dtype, + device=x_2d.device, + ) + alpha: float | None = getattr(layer, "_nunchaku_alpha", None) + wcscales = getattr(layer, "wcscales", None) + + svdq_gemm_w4a4_cuda( + act=quantized_x, + wgt=layer.qweight, + out=out_2d, + ascales=ascales, + wscales=layer.wscales, + lora_act_in=lora_act_out, + lora_up=layer.proj_up, + bias=bias, + fp4=layer.precision == "nvfp4", + alpha=alpha, + wcscales=wcscales, + act_unsigned=getattr(layer, "act_unsigned", False), + ) + out = out_2d.reshape(*orig_shape[:-1], layer.output_size_per_partition) + return out + + +class NunchakuAWQLinearMethod(LinearMethodBase): + def __init__(self, group_size: int = 64): + self.group_size = group_size + self.pack_factor = 8 + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + output_size_per_partition = sum(output_partition_sizes) + + qweight = Parameter( + torch.empty( + output_size_per_partition // 4, + input_size_per_partition // 2, + dtype=torch.int32, + ), + requires_grad=False, + ) + set_weight_attrs(qweight, {"input_dim": 1, "output_dim": 0}) + + num_groups = input_size_per_partition // self.group_size + wscales = Parameter( + torch.empty(num_groups, output_size_per_partition, dtype=params_dtype), + requires_grad=False, + ) + + wzeros = Parameter( + torch.empty(num_groups, output_size_per_partition, dtype=params_dtype), + requires_grad=False, + ) + + layer.register_parameter("qweight", qweight) + layer.register_parameter("wscales", wscales) + layer.register_parameter("wzeros", wzeros) + + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.group_size = self.group_size + layer.pack_factor = self.pack_factor + + weight_loader = extra_weight_attrs.get("weight_loader") + if weight_loader is not 
None: + set_weight_attrs(qweight, {"weight_loader": weight_loader}) + set_weight_attrs(wscales, {"weight_loader": weight_loader}) + set_weight_attrs(wzeros, {"weight_loader": weight_loader}) + + def process_weights_after_loading(self, layer: nn.Module) -> None: + layer.qweight = Parameter(layer.qweight.data, requires_grad=False) + layer.wscales = Parameter(layer.wscales.data, requires_grad=False) + layer.wzeros = Parameter(layer.wzeros.data, requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + orig_shape = x.shape + x_2d = x.reshape(-1, orig_shape[-1]) + + in_features = layer.input_size_per_partition + out_features = layer.output_size_per_partition + out_2d = awq_gemv_w4a16_cuda( + in_feats=x_2d, + kernel=layer.qweight, + scaling_factors=layer.wscales, + zeros=layer.wzeros, + m=x_2d.shape[0], + n=out_features, + k=in_features, + group_size=layer.group_size, + ) + if bias is not None: + view_shape = [1] * (out_2d.ndim - 1) + [-1] + out_2d.add_(bias.view(view_shape)) + + out = out_2d.reshape(*orig_shape[:-1], out_features) + return out diff --git a/sglang/python/sglang/multimodal_gen/runtime/layers/rotary_embedding/__init__.py b/sglang/python/sglang/multimodal_gen/runtime/layers/rotary_embedding/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..977c34cc348b3cc5ebe1be82101f07e62b5b9da7 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/layers/rotary_embedding/__init__.py @@ -0,0 +1,48 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/layers/rotary_embedding.py + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. 
All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Rotary Positional Embeddings — unified public API (drop-in replacement).""" + +from .base import RotaryEmbedding +from .factory import get_rope, get_rotary_pos_embed +from .mrope import NDRotaryEmbedding +from .utils import ( + _apply_rotary_emb, + apply_flashinfer_rope_qk_inplace, +) + +__all__ = [ + # _utils + "_apply_rotary_emb", + "apply_flashinfer_rope_qk_inplace", + # _base + "RotaryEmbedding", + # _mrope + "NDRotaryEmbedding", + # _factory + "get_rope", + "get_rotary_pos_embed", +] diff --git a/sglang/python/sglang/multimodal_gen/runtime/layers/rotary_embedding/base.py b/sglang/python/sglang/multimodal_gen/runtime/layers/rotary_embedding/base.py new file mode 100644 index 0000000000000000000000000000000000000000..6c42ca5d05bc13c9654ef1cbc9b0a6d96f76cf38 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/layers/rotary_embedding/base.py @@ -0,0 +1,130 @@ +"""RotaryEmbedding base class and LinearScalingRotaryEmbedding variant.""" + +import torch + +from sglang.multimodal_gen.runtime.layers.custom_op import CustomOp + +from .utils import _apply_rotary_emb + + 
+@CustomOp.register("rotary_embedding") +class RotaryEmbedding(CustomOp): + """Original rotary positional embedding.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int | float, + is_neox_style: bool, + dtype: torch.dtype, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.dtype = dtype + + cache = self._compute_cos_sin_cache() + cache = cache.to(dtype) + self.cos_sin_cache: torch.Tensor + self.register_buffer("cos_sin_cache", cache, persistent=False) + + def _compute_inv_freq(self, base: int | float) -> torch.Tensor: + """Compute the inverse frequency.""" + # NOTE(woosuk): To exactly match the HF implementation, we need to + # use CPU to compute the cache and then move it to GPU. However, we + # create the cache on GPU for faster initialization. This may cause + # a slight numerical difference between the HF implementation and ours. 
+ inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim + ) + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + t = torch.arange(self.max_position_embeddings, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward_cuda(self, *args, **kwargs): + return self.forward_native(*args, **kwargs) + + def forward_native( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """A PyTorch-native implementation of forward().""" + if offsets is not None: + positions = positions + offsets + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions) + cos, sin = cos_sin.chunk(2, dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + def extra_repr(self) -> str: + s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" + s += f", max_position_embeddings={self.max_position_embeddings}" + s += f", base={self.base}, is_neox_style={self.is_neox_style}" + return s + + +class LinearScalingRotaryEmbedding(RotaryEmbedding): + def __init__( + self, + 
head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int | float, + is_neox_style: bool, + dtype: torch.dtype, + scaling_factor: float, + ) -> None: + self.scaling_factor = float(scaling_factor) + super().__init__( + head_size=head_size, + rotary_dim=rotary_dim, + max_position_embeddings=max_position_embeddings, + base=base, + is_neox_style=is_neox_style, + dtype=dtype, + ) + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.base) + t = torch.arange(self.max_position_embeddings, dtype=torch.float) + t = t / self.scaling_factor + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache diff --git a/sglang/python/sglang/multimodal_gen/runtime/layers/rotary_embedding/factory.py b/sglang/python/sglang/multimodal_gen/runtime/layers/rotary_embedding/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..807660ea0285ffc8021d6dc4b8829827d4ffe0a4 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/layers/rotary_embedding/factory.py @@ -0,0 +1,171 @@ +"""get_rope / get_rotary_pos_embed factory functions and module-level caches.""" + +from collections import OrderedDict +from typing import Any + +import torch + +from .base import LinearScalingRotaryEmbedding, RotaryEmbedding +from .mrope import NDRotaryEmbedding, _to_tuple + +_ROPE_DICT: dict[tuple, RotaryEmbedding] = {} +_ND_ROPE_CACHE: "OrderedDict[tuple, NDRotaryEmbedding]" = OrderedDict() +_ROPE_3D_CACHE: "OrderedDict[tuple, tuple[torch.Tensor, torch.Tensor]]" = OrderedDict() + + +def get_rope( + head_size: int, + rotary_dim: int, + max_position: int, + base: int | float, + is_neox_style: bool = True, + rope_scaling: dict[str, Any] | None = None, + dtype: torch.dtype | None = None, + partial_rotary_factor: float = 1.0, +) -> RotaryEmbedding: + if dtype is None: + dtype = torch.get_default_dtype() + if rope_scaling is not None: + # 
Transforms every value that is a list into a tuple for caching calls + rope_scaling_tuple = { + k: tuple(v) if isinstance(v, list) else v for k, v in rope_scaling.items() + } + rope_scaling_args = tuple(rope_scaling_tuple.items()) + else: + rope_scaling_args = None + if partial_rotary_factor < 1.0: + rotary_dim = int(rotary_dim * partial_rotary_factor) + max_position_embeddings = max_position + rope_type = None + if rope_scaling is not None: + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) + if rope_type in (None, "default"): + rope_scaling = None + elif rope_type == "linear": + factor = float(rope_scaling.get("factor", 1.0)) + original_max = rope_scaling.get("original_max_position_embeddings", None) + if original_max is not None: + max_position_embeddings = max( + max_position_embeddings, int(float(original_max) * factor) + ) + key = ( + head_size, + rotary_dim, + max_position_embeddings, + base, + is_neox_style, + rope_scaling_args, + dtype, + ) + if key in _ROPE_DICT: + return _ROPE_DICT[key] + + if rope_scaling is None: + rotary_emb = RotaryEmbedding( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) + else: + if rope_type == "linear": + factor = float(rope_scaling.get("factor", 1.0)) + rotary_emb = LinearScalingRotaryEmbedding( + head_size=head_size, + rotary_dim=rotary_dim, + max_position_embeddings=max_position_embeddings, + base=base, + is_neox_style=is_neox_style, + dtype=dtype, + scaling_factor=factor, + ) + else: + raise ValueError(f"Unknown RoPE scaling {rope_scaling}") + _ROPE_DICT[key] = rotary_emb + return rotary_emb + + +def get_rotary_pos_embed( + rope_sizes, + hidden_size, + heads_num, + rope_dim_list, + rope_theta, + theta_rescale_factor=1.0, + interpolation_factor=1.0, + shard_dim: int = 0, + dtype: torch.dtype = torch.float32, + start_frame: int = 0, + device: torch.device | str | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Generate rotary positional embeddings for the 
given sizes. + + Args: + rope_sizes: Tuple of dimensions (t, h, w) + hidden_size: Hidden dimension size + heads_num: Number of attention heads + rope_dim_list: List of dimensions for each axis, or None + rope_theta: Base for frequency calculations + theta_rescale_factor: Rescale factor for theta. Defaults to 1.0 + interpolation_factor: Factor to scale positions. Defaults to 1.0 + shard_dim: Which dimension to shard for sequence parallelism. Defaults to 0. + + Returns: + Tuple of (cos, sin) tensors for rotary embeddings + """ + + target_ndim = 3 + head_dim = hidden_size // heads_num + + if rope_dim_list is None: + rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)] + + assert ( + sum(rope_dim_list) == head_dim + ), "sum(rope_dim_list) should equal to head_dim of attention layer" + + # Get SP info - now handled within NDRotaryEmbedding + # sp_group = get_sp_group() + # sp_rank = sp_group.rank_in_group + # sp_world_size = sp_group.world_size + + # Simple LRU cache keyed by parameters + global _ND_ROPE_CACHE + key = ( + tuple(rope_dim_list), + float(rope_theta), + ( + tuple(theta_rescale_factor) + if isinstance(theta_rescale_factor, list) + else float(theta_rescale_factor) + ), + ( + tuple(interpolation_factor) + if isinstance(interpolation_factor, list) + else float(interpolation_factor) + ), + dtype, + ) + + cache_hit = key in _ND_ROPE_CACHE + if cache_hit: + rope_emb = _ND_ROPE_CACHE.pop(key) + _ND_ROPE_CACHE[key] = rope_emb # move to end (most-recent) + else: + rope_emb = NDRotaryEmbedding( + rope_dim_list=rope_dim_list, + rope_theta=rope_theta, + theta_rescale_factor=theta_rescale_factor, + interpolation_factor=interpolation_factor, + dtype=dtype, + ) + _ND_ROPE_CACHE[key] = rope_emb + if len(_ND_ROPE_CACHE) > 16: + # pop least-recently-used + _ND_ROPE_CACHE.pop(next(iter(_ND_ROPE_CACHE))) + + freqs_cos, freqs_sin = rope_emb.forward_from_grid( + grid_size=_to_tuple(rope_sizes, dim=3), + shard_dim=shard_dim, + start_frame=start_frame, + 
device=device, + ) + return freqs_cos, freqs_sin diff --git a/sglang/python/sglang/multimodal_gen/runtime/layers/rotary_embedding/mrope.py b/sglang/python/sglang/multimodal_gen/runtime/layers/rotary_embedding/mrope.py new file mode 100644 index 0000000000000000000000000000000000000000..27211332bf75ad3cdd120324064258e07a13141f --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/layers/rotary_embedding/mrope.py @@ -0,0 +1,392 @@ +"""MRotaryEmbedding, YaRNScalingMRotaryEmbedding, NDRotaryEmbedding, OneDRotaryEmbedding.""" + +import functools + +import torch + +from sglang.multimodal_gen.runtime.distributed.parallel_state import get_sp_group + + +def _to_tuple(x: int | tuple[int, ...], dim: int = 2) -> tuple[int, ...]: + if isinstance(x, int): + return (x,) * dim + elif len(x) == dim: + return x + else: + raise ValueError(f"Expected length {dim} or int, but got {x}") + + +def get_1d_rotary_pos_embed( + dim: int, + pos: torch.FloatTensor | int, + theta: float = 10000.0, + theta_rescale_factor: float = 1.0, + interpolation_factor: float = 1.0, + dtype: torch.dtype = torch.float32, + device: torch.device | str | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Precompute the frequency tensor for complex exponential (cis) with given dimensions. + (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.) + + This function calculates a frequency tensor with complex exponential using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + + Args: + dim (int): Dimension of the frequency tensor. + pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0. + interpolation_factor (float, optional): Factor to scale positions. Defaults to 1.0. 
+ + Returns: + freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D] + """ + if isinstance(pos, int): + pos = torch.arange(pos, dtype=dtype, device=device) + elif ( + isinstance(pos, torch.Tensor) + and device is not None + and pos.device != torch.device(device) + ): + # Ensure positions are on the requested device to avoid implicit CPU ops. + pos = pos.to(device) + + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + if theta_rescale_factor != 1.0: + theta *= theta_rescale_factor ** (dim / (dim - 2)) + + freqs = 1.0 / ( + theta + ** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].to(dtype) / dim).to( + device=device + ) + ) # [D/2] + freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2] + freqs_cos = freqs.cos() # [S, D/2] + freqs_sin = freqs.sin() # [S, D/2] + return freqs_cos, freqs_sin + + +class OneDRotaryEmbedding(torch.nn.Module): + """1D rotary positional embedding with caching.""" + + def __init__( + self, + dim: int, + theta: float = 10000.0, + theta_rescale_factor: float = 1.0, + interpolation_factor: float = 1.0, + dtype: torch.dtype = torch.float32, + use_real: bool = False, + repeat_interleave_real: bool = False, + ): + super().__init__() + assert dim % 2 == 0 + self.dim = dim + self.theta = theta + self.theta_rescale_factor = theta_rescale_factor + self.interpolation_factor = interpolation_factor + # dtype of freqs + self.dtype = dtype + self.use_real = use_real + self.repeat_interleave_real = repeat_interleave_real + + def build_freqs(self, device): + freqs = 1.0 / ( + self.theta + ** ( + torch.arange(0, self.dim, 2, dtype=self.dtype, device=device)[ + : (self.dim // 2) + ] + / self.dim + ).to(device=device) + ) + return freqs + + def build_freqs_outer(self, pos: torch.Tensor, device): + theta = self.theta + # proposed by reddit user bloc97, to rescale rotary embeddings to longer 
sequence length without fine-tuning + # has some connection to NTK literature + if self.theta_rescale_factor != 1.0: + theta *= self.theta_rescale_factor ** (self.dim / (self.dim - 2)) + + freqs = self.build_freqs(device) + + freqs = torch.outer(pos * self.interpolation_factor, freqs) + freqs_cos = freqs.cos() + freqs_sin = freqs.sin() + + if self.use_real and self.repeat_interleave_real: + freqs_cos = freqs_cos.repeat_interleave(2, dim=1) + freqs_sin = freqs_sin.repeat_interleave(2, dim=1) + + return freqs_cos.float(), freqs_sin.float() + + @functools.lru_cache(maxsize=16) + def forward_from_grid( + self, seq_len: int, start_pos: int, device_str: str + ) -> tuple[torch.Tensor, torch.Tensor]: + device = torch.device(device_str) + pos = torch.arange( + start_pos, start_pos + seq_len, dtype=self.dtype, device=device + ) + + freqs_cos, freqs_sin = self.build_freqs_outer(pos, device) + return freqs_cos, freqs_sin + + def forward(self, pos: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Calculates 1D rotary embeddings for the given positions. + + This method converts the input tensor to a hashable representation + and calls a cached helper method to perform the computation. + """ + pos_tuple = tuple(pos.tolist()) + device_str = str(pos.device) + return self._forward_cached(pos_tuple, device_str) + + @functools.lru_cache(maxsize=16) + def _forward_cached( + self, pos_tuple: tuple, device_str: str + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + The core implementation that computes 1D rotary embeddings. + This method is wrapped by an LRU cache. 
+ """ + device = torch.device(device_str) + pos = torch.as_tensor(pos_tuple, dtype=self.dtype, device=device) + freqs_cos, freqs_sin = self.build_freqs_outer(pos, device) + return freqs_cos, freqs_sin + + +class NDRotaryEmbedding(torch.nn.Module): + """N-dimensional rotary positional embedding.""" + + def __init__( + self, + rope_dim_list: list[int], + rope_theta: float, + theta_rescale_factor: float | list[float] = 1.0, + interpolation_factor: float | list[float] = 1.0, + use_real: bool = False, + repeat_interleave_real: bool = False, + dtype: torch.dtype = torch.float32, + ): + super().__init__() + self.rope_dim_list = rope_dim_list + self.ndim = len(rope_dim_list) + self.rope_theta = rope_theta + # dtype of freqs + # does not control the output dtype + self.dtype = dtype + + if isinstance(theta_rescale_factor, (int, float)): + self.theta_rescale_factor = [theta_rescale_factor] * self.ndim + elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1: + self.theta_rescale_factor = [theta_rescale_factor[0]] * self.ndim + else: + self.theta_rescale_factor = theta_rescale_factor + assert ( + len(self.theta_rescale_factor) == self.ndim + ), "len(theta_rescale_factor) should equal to len(rope_dim_list)" + + if isinstance(interpolation_factor, (int, float)): + self.interpolation_factor = [interpolation_factor] * self.ndim + elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1: + self.interpolation_factor = [interpolation_factor[0]] * self.ndim + else: + self.interpolation_factor = interpolation_factor + assert ( + len(self.interpolation_factor) == self.ndim + ), "len(interpolation_factor) should equal to len(rope_dim_list)" + + self.rope_generators: list[OneDRotaryEmbedding] = torch.nn.ModuleList() + _config_to_gen_idx: dict[tuple, int] = {} + self.dim_idx_to_gen_idx: list[int] = [] + + for i in range(self.ndim): + dim = self.rope_dim_list[i] + rescale = self.theta_rescale_factor[i] + interp = self.interpolation_factor[i] + 
+ config_key = (dim, rescale, interp, use_real, repeat_interleave_real) + if config_key not in _config_to_gen_idx: + generator = OneDRotaryEmbedding( + dim=dim, + theta=self.rope_theta, + theta_rescale_factor=rescale, + interpolation_factor=interp, + dtype=self.dtype, + use_real=use_real, + repeat_interleave_real=repeat_interleave_real, + ) + _config_to_gen_idx[config_key] = len(self.rope_generators) + self.rope_generators.append(generator) + + gen_idx = _config_to_gen_idx[config_key] + self.dim_idx_to_gen_idx.append(gen_idx) + + def forward(self, positions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Calculates n-d rotary embeddings for given absolute positions. + + Args: + positions (torch.Tensor): A tensor of shape `[num_tokens, ndim]` + containing the integer coordinates for each token. + + Returns: + A tuple of (cos, sin) tensors. + """ + # Caching wrapper: convert tensor to a hashable tuple of tuples. + pos_tuple = tuple(map(tuple, positions.tolist())) + device_str = str(positions.device) + return self._forward_cached(pos_tuple, device_str) + + @functools.lru_cache(maxsize=16) + def _forward_cached( + self, pos_tuple: tuple[tuple[int, ...], ...], device_str: str + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + The core implementation that computes embeddings from a position tensor. + This method is wrapped by an LRU cache. + """ + device = torch.device(device_str) + positions = torch.tensor(pos_tuple, dtype=torch.long, device=device) + return self.forward_uncached(pos=positions) + + def forward_uncached(self, pos: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + The core implementation that computes embeddings from a position tensor. + This method is wrapped by an LRU cache. + """ + device = pos.device + + # Pre-allocate the final tensors for efficiency. 
+ num_tokens = pos.shape[0] + first_generator = self.rope_generators[0] + if first_generator.use_real and first_generator.repeat_interleave_real: + head_dim = sum(self.rope_dim_list) + else: + head_dim = sum(self.rope_dim_list) // 2 + + cos = torch.empty((num_tokens, head_dim), device=device, dtype=self.dtype) + sin = torch.empty((num_tokens, head_dim), device=device, dtype=self.dtype) + + col_offset = 0 + for i in range(self.ndim): + # Extract position coordinates for the current dimension for all tokens. + pos_i = pos[:, i].to(self.dtype) + + # Get the appropriate 1D generator. + gen_idx = self.dim_idx_to_gen_idx[i] + generator = self.rope_generators[gen_idx] + + # Calculate 1D embeddings. + cos_1d, sin_1d = generator(pos_i) + + slice_width = cos_1d.shape[1] + cos[:, col_offset : col_offset + slice_width] = cos_1d + sin[:, col_offset : col_offset + slice_width] = sin_1d + col_offset += slice_width + + return cos.float(), sin.float() + + def forward_from_grid( + self, + grid_size: tuple[int, ...], + shard_dim: int = 0, + start_frame: int = 0, + device: torch.device | str | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Handles sp internally + """ + # Caching wrapper: use grid parameters directly as the key. + # grid_tuple = _to_tuple(grid_size, dim=self.ndim) + device_str = str(device) if device is not None else "cpu" + return self._forward_cached_from_grid( + grid_size, shard_dim, start_frame, device_str + ) + + @functools.lru_cache(maxsize=16) + def _forward_cached_from_grid( + self, + grid_size: tuple[int, ...], + shard_dim: int, + start_frame: int, + device_str: str, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Computes embeddings for a structured grid, using a highly efficient + implementation that avoids materializing the full position tensor. + This method is wrapped by an LRU cache. 
+ """ + device = torch.device(device_str) + sp_group = get_sp_group() + sp_rank = sp_group.rank_in_group + sp_world_size = sp_group.world_size + + sizes = _to_tuple(grid_size, dim=self.ndim) + starts = (0,) * self.ndim + + # Apply sequence parallel sharding to the sizes and compute shard offset + shard_sizes = list(sizes) + shard_offsets = [0] * self.ndim + if sp_world_size > 1: + assert sizes[shard_dim] % sp_world_size == 0, ( + f"Dimension {shard_dim} with size {sizes[shard_dim]} is not divisible " + f"by sequence parallel world size {sp_world_size}" + ) + shard_size = sizes[shard_dim] // sp_world_size + shard_offsets[shard_dim] = sp_rank * shard_size + shard_sizes[shard_dim] = shard_size + + # Pre-allocate outputs on the requested device to avoid CPU ops and extra cats + num_tokens = 1 + for s in shard_sizes: + num_tokens *= int(s) + head_dim_half = sum(self.rope_dim_list) // 2 + cos = torch.empty((num_tokens, head_dim_half), device=device, dtype=self.dtype) + sin = torch.empty((num_tokens, head_dim_half), device=device, dtype=self.dtype) + + # Compute per-axis 1D embeddings once and expand via repeats to [N, d_i/2] + col_offset = 0 + for i in range(self.ndim): + dim_i = self.rope_dim_list[i] + dim_i_half = dim_i // 2 + size_i = int(shard_sizes[i]) + + # Starting position for this axis, with optional frame offset for time axis (i==0) + base_offset = starts[i] + if i == 0 and start_frame > 0: + base_offset += start_frame + if sp_world_size > 1 and i == shard_dim: + base_offset += shard_offsets[i] + + gen_idx = self.dim_idx_to_gen_idx[i] + generator = self.rope_generators[gen_idx] + cos_1d, sin_1d = generator.forward_from_grid( + size_i, base_offset, device_str + ) + + # Expand to [num_tokens, dim_i/2] matching flatten order (last dims vary fastest) + repeats_per_entry = 1 + for j in range(i + 1, self.ndim): + repeats_per_entry *= int(shard_sizes[j]) + tile_count = 1 + for j in range(0, i): + tile_count *= int(shard_sizes[j]) + + cos_expanded = 
cos_1d.repeat_interleave(repeats_per_entry, dim=0) + sin_expanded = sin_1d.repeat_interleave(repeats_per_entry, dim=0) + if tile_count > 1: + cos_expanded = cos_expanded.repeat(tile_count, 1) + sin_expanded = sin_expanded.repeat(tile_count, 1) + + cos[:, col_offset : col_offset + dim_i_half] = cos_expanded + sin[:, col_offset : col_offset + dim_i_half] = sin_expanded + col_offset += dim_i_half + + return cos.float(), sin.float() diff --git a/sglang/python/sglang/multimodal_gen/runtime/layers/rotary_embedding/utils.py b/sglang/python/sglang/multimodal_gen/runtime/layers/rotary_embedding/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4eaa106e8742e4f00146553d91c54e47f40c16ca --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/layers/rotary_embedding/utils.py @@ -0,0 +1,121 @@ +"""Primitive RoPE ops: rotate helpers and apply_rotary_emb utilities.""" + +from typing import Optional, Tuple + +import torch + +from sglang.jit_kernel.diffusion.triton.rotary import apply_rotary_embedding + + +def _apply_rotary_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, + interleaved: bool = False, +) -> torch.Tensor: + """ + Args: + x: [num_tokens, num_heads, head_size] or [num_tokens, head_size] + cos: [num_tokens, head_size // 2] + sin: [num_tokens, head_size // 2] + is_neox_style: Whether to use the Neox-style or GPT-J-style rotary + positional embeddings. 
+ """ + # cos = cos.unsqueeze(-2).to(x.dtype) + # sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + cos = cos.unsqueeze(-2) + sin = sin.unsqueeze(-2) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] + o1 = (x1.float() * cos - x2.float() * sin).type_as(x) + o2 = (x2.float() * cos + x1.float() * sin).type_as(x) + return torch.cat((o1, o2), dim=-1) + else: + return apply_rotary_embedding(x, cos, sin, interleaved) + + +def apply_flashinfer_rope_qk_inplace( + q: torch.Tensor, + k: torch.Tensor, + cos_sin_cache: torch.Tensor, + *, + head_size: Optional[int] = None, + is_neox: bool = False, + positions: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + if q.dim() != 4 or k.dim() != 4: + raise ValueError( + f"Expected q/k to be 4D [bsz, seqlen, nheads, head_size], " + f"got q:{tuple(q.shape)} k:{tuple(k.shape)}" + ) + if q.shape != k.shape: + raise ValueError( + f"q and k must have the same shape, got {q.shape} vs {k.shape}" + ) + + if not (isinstance(cos_sin_cache, torch.Tensor) and cos_sin_cache.dim() == 2): + raise ValueError("cos_sin_cache must be a 2D torch.Tensor") + + bsz, seqlen, nheads, d = q.shape + if head_size is None: + head_size = d + if head_size != d: + raise ValueError(f"head_size mismatch: inferred {d}, but head_size={head_size}") + + try: + from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace + except ImportError: + # Triton fallback for AMD/ROCm where FlashInfer is not available + import warnings + + warnings.warn( + "FlashInfer not available, using Triton fallback for RoPE", + stacklevel=2, + ) + half_size = cos_sin_cache.shape[-1] // 2 + if positions is None: + cos = cos_sin_cache[:seqlen, :half_size].to(q.dtype) + sin = cos_sin_cache[:seqlen, half_size:].to(q.dtype) + cos = cos.unsqueeze(0).expand(bsz, -1, -1).reshape(bsz * seqlen, -1) + sin = sin.unsqueeze(0).expand(bsz, -1, -1).reshape(bsz * seqlen, -1) + else: + positions = 
positions.to(cos_sin_cache.device).view(-1) + cos = cos_sin_cache[positions, :half_size].to(q.dtype) + sin = cos_sin_cache[positions, half_size:].to(q.dtype) + q_flat = q.reshape(bsz * seqlen, nheads, d) + k_flat = k.reshape(bsz * seqlen, nheads, d) + q_rot = apply_rotary_embedding(q_flat, cos, sin, interleaved=not is_neox) + k_rot = apply_rotary_embedding(k_flat, cos, sin, interleaved=not is_neox) + return q_rot.view(bsz, seqlen, nheads, d), k_rot.view(bsz, seqlen, nheads, d) + + if positions is None: + pos_1d = torch.arange(seqlen, device=q.device, dtype=torch.long) + positions = pos_1d if bsz == 1 else pos_1d.repeat(bsz) + else: + if not ( + isinstance(positions, torch.Tensor) + and positions.dtype == torch.long + and positions.dim() == 1 + ): + raise ValueError("positions must be a 1D torch.long Tensor") + if positions.numel() != bsz * seqlen: + raise ValueError( + f"positions length must be bsz*seqlen={bsz*seqlen}, got {positions.numel()}" + ) + + q_flat = q.reshape(bsz * seqlen, nheads * d).contiguous() + k_flat = k.reshape(bsz * seqlen, nheads * d).contiguous() + apply_rope_with_cos_sin_cache_inplace( + positions=positions, + query=q_flat, + key=k_flat, + head_size=d, + cos_sin_cache=cos_sin_cache, + is_neox=is_neox, + ) + return q_flat.view(bsz, seqlen, nheads, d), k_flat.view(bsz, seqlen, nheads, d) diff --git a/sglang/python/sglang/multimodal_gen/runtime/layers/usp.py b/sglang/python/sglang/multimodal_gen/runtime/layers/usp.py new file mode 100644 index 0000000000000000000000000000000000000000..e822350091ae30a284ab89e89c1a822649f3b995 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/layers/usp.py @@ -0,0 +1,252 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +import logging +from typing import TYPE_CHECKING + +import torch +import torch.distributed._functional_collectives as ft_c +from torch.distributed.tensor.experimental._attention import _cp_options + +from sglang.multimodal_gen.runtime.distributed.parallel_state 
import ( + get_sp_group, + get_ulysses_parallel_world_size, +) +from sglang.srt.utils.common import torch_release + +_cp_options.enable_load_balance = False + +if TYPE_CHECKING: + from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import ( + AttentionImpl, + ) + +logger = logging.getLogger(__name__) + + +def _maybe_wait(tensor: torch.Tensor) -> torch.Tensor: + """ + When tracing the code, the result tensor is not an AsyncCollectiveTensor, + so we cannot call ``wait()``. + """ + if isinstance(tensor, ft_c.AsyncCollectiveTensor): + return tensor.wait() + return tensor + + +def _usp_all_to_all_single(x: torch.Tensor) -> torch.Tensor: + ulysses_pg = get_sp_group().ulysses_group + assert ulysses_pg is not None, "Ulysses process group is not initialized." + x_shape = x.shape + x = x.flatten() + x = ft_c.all_to_all_single( + x, output_split_sizes=None, input_split_sizes=None, group=ulysses_pg + ) + x = _maybe_wait(x) + x = x.reshape(x_shape) + return x + + +def _usp_input_all_to_all(x: torch.Tensor, head_dim: int = 1) -> torch.Tensor: + """ + Perform Ulysses-style input all-to-all over the head dimension. + + Default layout expects heads at dim=1 and sequence at dim=2: + [b, h, s_local, d] -> [b, h_local, s_global, d] + + If heads are at dim=2 (input is [b, s_local, h, d]), set head_dim=2, and the + function returns [b, s_global, h_local, d], preserving the original + head/sequence dim ordering. + + Args: + x: A 4D tensor with layout [b, *, *, d] where '*' are sequence and heads + head_dim: Which dimension index corresponds to heads (1 or 2) + + Returns: + Tensor with the same dim order as input, with heads sharded and sequence gathered. 
def _usp_output_all_to_all(x: torch.Tensor, head_dim: int = 1) -> torch.Tensor:
    """
    Ulysses-style output all-to-all: gather heads, shard the sequence.

    This is the inverse of the input all-to-all. The dimension order of the
    input is preserved in the output:

    * ``head_dim=1``: ``[b, h_local, s_global, d] -> [b, h_global, s_local, d]``
    * ``head_dim=2``: ``[b, s_global, h_local, d] -> [b, s_local, h_global, d]``

    Args:
        x: A 4D tensor laid out as ``[b, *, *, d]`` where the two middle
            dimensions are heads and sequence.
        head_dim: Index of the heads dimension (1 or 2).

    Returns:
        ``x`` unchanged when the Ulysses world size is <= 1; otherwise a
        tensor with heads gathered and the sequence sharded.
    """
    world_size = get_ulysses_parallel_world_size()
    if world_size <= 1:
        return x

    assert x.ndim == 4, f"x must have 4 dimensions, got {x.ndim}"
    assert head_dim in (1, 2), f"head_dim must be 1 or 2, got {head_dim}"

    heads_first = head_dim == 1
    seq_axis = 2 if heads_first else 1

    b, d = x.shape[0], x.shape[3]
    h_local = x.shape[head_dim]
    s_global = x.shape[seq_axis]

    assert (
        s_global % world_size == 0
    ), f"s_global ({s_global}) must be divisible by world_size ({world_size})"

    s_local = s_global // world_size
    h_global = h_local * world_size

    # all_to_all_single splits along dim 0, so the sequence axis goes first:
    # -> [s_global, b, h_local, d] for either input layout.
    seq_major_order = (2, 0, 1, 3) if heads_first else (1, 0, 2, 3)
    exchanged = _usp_all_to_all_single(x.permute(seq_major_order).contiguous())
    exchanged = exchanged.reshape(world_size, s_local, b, h_local, d)

    # Put 'world_size' next to 'h_local' so the two merge into 'h_global'.
    if heads_first:
        # [world_size, s_local, b, h_local, d] -> [b, world_size, h_local, s_local, d]
        return (
            exchanged.permute(2, 0, 3, 1, 4)
            .contiguous()
            .reshape(b, h_global, s_local, d)
        )
    # [world_size, s_local, b, h_local, d] -> [b, s_local, world_size, h_local, d]
    return (
        exchanged.permute(2, 1, 0, 3, 4)
        .contiguous()
        .reshape(b, s_local, h_global, d)
    )
+ + This function implements Ring Attention, a strategy for distributed attention + computation that reduces peak memory usage. It accepts a generic attention + implementation (`attn_impl`) which is called by the underlying PyTorch + distributed attention primitive. + + Args: + query, key, value: The input tensors for attention. + attn_impl: An instance of an attention implementation backend + (e.g., FlashAttentionImpl) whose `forward` method will be + used as the computational kernel. + is_causal: Whether to apply causal masking. + dropout_p: Dropout probability. + """ + # torch.distributed.tensor.experimental._attention is not a public API, + from torch.distributed.tensor.experimental._attention import ( + _templated_ring_attention, + ) + + ring_pg = get_sp_group().ring_group + assert ring_pg is not None, "Ring process group is not initialized." + + # Ring attention primitives expect tensors in [B, H, S, D] layout. + # We permute the inputs here. + query = torch.permute(query, [0, 2, 1, 3]).contiguous() + key = torch.permute(key, [0, 2, 1, 3]).contiguous() + value = torch.permute(value, [0, 2, 1, 3]).contiguous() + + # Create an adapter function that matches the signature expected by + # _templated_ring_attention. The `attn_impl` already has dropout and + # causal settings configured during its initialization. + + # Note: Please be aware that Attention Backend and Ring Attention may require different QKV tensor shapes. + # For example, FlashAttention expects the format to be BSHD. + def attn_callable_adapter(q, k, v, *args, **kwargs): + # We ignore the dropout_p and is_causal passed by _templated_ring_attention + # and rely on the pre-configured attn_impl. + # The `attn_metadata` is not available here, so we pass None. + # This is a limitation we must accept when using this experimental API. 
+ q = torch.permute(q, [0, 2, 1, 3]) + k = torch.permute(k, [0, 2, 1, 3]) + v = torch.permute(v, [0, 2, 1, 3]) + # logger.warning(f"Warning: return_s·oftmax_lse is only supported for FlashAttentionImpl") + output, softmax_lse, *rest = attn_impl.forward( + q, + k, + v, + attn_metadata=None, + return_softmax_lse=True, + ) + output = torch.permute(output, [0, 2, 1, 3]) + return output, softmax_lse, *rest + + # Starting from torch 2.6.0, _templated_ring_attention expects an integer + # segment_id for the attention function. + use_segment_id = torch_release >= (2, 6) + + attn_kwargs = dict( + op=attn_callable_adapter, + dropout_p=dropout_p, + is_causal=is_causal, + query=query, + key=key, + value=value, + group=ring_pg, # https://github.com/pytorch/pytorch/blob/c907c778f42ba2fdaf25b733dd25baf9779c6a12/torch/distributed/tensor/experimental/_context_parallel/_attention.py#L309 + ) + + if use_segment_id: + # For torch >= 2.6, segment_id is required. The value '1' is a placeholder + # as we are not using complex segmentation features. + out, *_ = _templated_ring_attention( + seq_dim=1, # segment_id + **attn_kwargs, + ) + else: + out, *_ = _templated_ring_attention( + **attn_kwargs, + ) + + # Permute the output back to [B, S, H, D] layout. 
def get_group_size(group) -> int:
    """
    Return the number of ranks in ``group``.

    Two kinds of group are accepted: objects exposing a ``world_size``
    attribute (GroupCoordinator) and objects exposing a callable ``size``
    (ProcessGroup). ``world_size`` wins when both are present.

    Raises:
        ValueError: if ``group`` offers neither interface.
    """
    if hasattr(group, "world_size"):
        return group.world_size  # GroupCoordinator

    size_fn = getattr(group, "size", None)
    if callable(size_fn):
        return size_fn()  # ProcessGroup

    raise ValueError(f"Unsupported group type: {type(group)}")
+ bin_counts = torch.zeros( + (num_seqs, vocab_size + 1), dtype=torch.long, device=tokens.device + ) + bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens)) + bin_counts = bin_counts[:, :vocab_size] + mask = bin_counts > 0 + + return bin_counts, mask + + +sglang_lib = Library("sglang", "FRAGMENT") # noqa + + +def direct_register_custom_op( + op_name: str, + op_func: Callable, + mutates_args: List[str], + fake_impl: Optional[Callable] = None, + target_lib: Optional[Library] = None, +): + """ + `torch.library.custom_op` can have significant overhead because it + needs to consider complicated dispatching logic. This function + directly registers a custom op and dispatches it to the CUDA backend. + See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5 + for more details. + + By default, the custom op is registered to the vLLM library. If you + want to register it to a different library, you can pass the library + object to the `target_lib` argument. + + IMPORTANT: the lifetime of the operator is tied to the lifetime of the + library object. If you want to bind the operator to a different library, + make sure the library object is alive when the operator is used. + + Note: This function will silently skip registration if the operator + with the same name is already registered to avoid RuntimeError in + multi-engine scenarios (e.g., VERL framework). 
+ """ + import torch.library + + my_lib = target_lib or sglang_lib + + # Check if operator is already registered to avoid duplicate registration + # This is important for scenarios where multiple SGLang engines run in the same process + try: + # Try to access the operator to see if it's already registered + lib_name = my_lib.m.name if hasattr(my_lib.m, "name") else "sglang" + if hasattr(torch.ops, lib_name) and hasattr( + getattr(torch.ops, lib_name), op_name + ): + # Operator already exists, skip registration + return + except (AttributeError, RuntimeError): + # Operator doesn't exist, proceed with registration + pass + + if hasattr(torch.library, "infer_schema"): + schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args) + else: + # for pytorch 2.4 + import torch._custom_op.impl + + schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args) + + try: + my_lib.define(op_name + schema_str) + my_lib.impl( + op_name, op_func, "CUDA" if not current_platform.is_npu() else "PrivateUse1" + ) + if fake_impl is not None: + my_lib._register_fake(op_name, fake_impl) + except RuntimeError as error: + if "Tried to register an operator" in str(error) and "multiple times" in str( + error + ): + # Silently ignore duplicate registration errors + # This can happen in multi-engine scenarios + pass + else: + # Re-raise other RuntimeErrors + raise error + except AttributeError as error: + # Always re-raise AttributeError as it indicates missing dependencies + raise error + + +class CustomOpWrapper: + def __init__( + self, + op_name: str, + op_func: Callable, + mutates_args: List[str], + **extra_kwargs, + ): + self.op_name = op_name + self.op_func = op_func + self.mutates_args = mutates_args + self.extra_kwargs = extra_kwargs + self._impl: Optional[Callable] = None + + def __call__(self, *args, **kwargs): + return self.real_impl(*args, **kwargs) + + @property + def real_impl(self) -> Callable: + if self._impl is None: + if not hasattr(torch.ops.sglang, 
def register_custom_op(
    fn: Optional[Callable] = None,
    *,
    op_name: Optional[str] = None,
    mutates_args: Optional[List[str]] = None,
    eager: bool = True,
    **extra_kwargs,
) -> Any:
    """
    Decorator that turns a Python function into a custom operator.

    It can be applied bare (``@register_custom_op``) or with arguments
    (``@register_custom_op(mutates_args=[...])``).

    Example usage:
    ```python
    # in-place operator: with neither out_shape nor fake_impl, out_shape is None
    @register_custom_op(mutates_args=["x"])
    def add_1_(x: torch.Tensor) -> None:
        x.add_(1)

    # operator with an output: out_shape names the argument the output mirrors
    @register_custom_op(mutates_args=["x"], out_shape=0)
    def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x.add_(y)
    ```

    :param fn: The function to register. If None, a decorator is returned.
    :type fn: Optional[Callable]
    :param op_name: Operator name. If None, the function's ``__name__`` is used.
    :type op_name: Optional[str]
    :param mutates_args: Names of arguments mutated in place (defaults to ``[]``).
    :type mutates_args: Optional[List[str]]
    :param out_shape: Position (int) or keyword name (str) of the argument whose
        shape the output mirrors; used to build a fake implementation for
        torch.compile. None means the operator is in-place with no output.
    :type out_shape: Optional[Union[int, str]]
    :param fake_impl: An explicit fake implementation.
        Only one of `out_shape` or `fake_impl` may be given.
    :type fake_impl: Optional[Callable]
    :param eager: If True (the default), the wrapper's ``real_impl`` is resolved
        at decoration time and returned. If False, the wrapper itself is
        returned and resolution is deferred until its first call.
    :type eager: bool
    :return: The custom operator (or the lazy wrapper when ``eager=False``),
        or a decorator when ``fn`` is None.
    :rtype: Callable
    """
    extra_kwarg_keys = set(extra_kwargs.keys())
    expected_kwarg_keys = set({"out_shape", "fake_impl"})
    assert (
        expected_kwarg_keys >= extra_kwarg_keys
    ), f"Unexpected extra kwargs: {extra_kwarg_keys - expected_kwarg_keys}"

    out_shape_given = "out_shape" in extra_kwargs
    fake_impl_given = "fake_impl" in extra_kwargs
    assert not (
        out_shape_given and fake_impl_given
    ), "Only one of `out_shape` or `fake_impl` should be provided."

    # With neither hint present the operator is treated as in-place.
    if not out_shape_given and not fake_impl_given:
        extra_kwargs["out_shape"] = None

    def decorator(op_func: Callable) -> Callable:
        wrapper = CustomOpWrapper(
            op_name=op_name or op_func.__name__,
            op_func=op_func,
            mutates_args=mutates_args or [],
            **extra_kwargs,
        )
        if eager:
            return wrapper.real_impl
        return wrapper

    return decorator if fn is None else decorator(fn)
class Timesteps(_Timesteps):
    """
    Drop-in variant of the base ``_Timesteps`` module that picks the embedding
    kernel from the module-level ``_is_cuda`` flag.
    """

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        # Both implementations are called with the same arguments, so choose
        # the callable once and keep a single call site.
        embed_fn = (
            timestep_embedding_cuda if _is_cuda else timestep_embedding_diffusers
        )
        return embed_fn(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            scale=self.scale,
        )
def timestep_embedding(
    t: torch.Tensor,
    dim: int,
    max_period: int = 10000,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Create sinusoidal timestep embeddings.

    Args:
        t: Tensor of shape [B] with timesteps
        dim: Embedding dimension
        max_period: Controls the minimum frequency of the embeddings
        dtype: dtype used to build the frequency table

    Returns:
        Tensor of shape [B, dim] with embeddings: ``dim // 2`` cosine columns,
        then ``dim // 2`` sine columns, then one zero column when ``dim`` is odd.
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=dtype, device=t.device)
        / half
    )
    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Build the padding column explicitly. Slicing `embedding[:, :1]` as a
        # template yields zero columns when half == 0 (dim == 1), which would
        # return [B, 0] instead of the documented [B, dim].
        zero_column = embedding.new_zeros((embedding.shape[0], 1))
        embedding = torch.cat([embedding, zero_column], dim=-1)
    return embedding
+ + Args: + x: Tensor of shape [B, T*H*W, C*P_t*P_h*P_w] + t, h, w: Temporal and spatial dimensions + + Returns: + Unpatchified tensor of shape [B, C, T*P_t, H*P_h, W*P_w] + """ + assert x.ndim == 3, f"x.ndim: {x.ndim}" + assert len(patch_size) == 3, f"patch_size: {patch_size}" + assert t * h * w == x.shape[1], f"t * h * w: {t * h * w}, x.shape[1]: {x.shape[1]}" + c = channels + pt, ph, pw = patch_size + + x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw)) + x = torch.einsum("nthwcopq->nctohpwq", x) + imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw)) + + return imgs diff --git a/sglang/python/sglang/multimodal_gen/runtime/layers/vocab_parallel_embedding.py b/sglang/python/sglang/multimodal_gen/runtime/layers/vocab_parallel_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..fecb4245fd024d009aa143abb8dee092b786b4b8 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/layers/vocab_parallel_embedding.py @@ -0,0 +1,490 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Sequence +from dataclasses import dataclass + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch.nn.parameter import Parameter, UninitializedParameter + +from sglang.multimodal_gen.runtime.distributed import ( + divide, + get_tp_group, + tensor_model_parallel_all_reduce, +) +from sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import ( + QuantizationConfig, + QuantizeMethodBase, + method_has_implemented_embedding, +) +from sglang.multimodal_gen.runtime.layers.utils import get_group_rank, get_group_size +from sglang.multimodal_gen.runtime.models.parameter import BasevLLMParameter +from sglang.multimodal_gen.runtime.models.utils import set_weight_attrs +from sglang.multimodal_gen.runtime.platforms import current_platform + +DEFAULT_VOCAB_PADDING_SIZE = 64 + + +class 
def vocab_range_from_per_partition_vocab_size(
    per_partition_vocab_size: int, rank: int, offset: int = 0
) -> Sequence[int]:
    """
    Return the half-open ``(start, end)`` vocab index range owned by ``rank``.

    Each rank owns ``per_partition_vocab_size`` consecutive indices; both
    bounds are shifted by ``offset``.
    """
    start = rank * per_partition_vocab_size + offset
    return start, start + per_partition_vocab_size
@torch.compile(
    dynamic=True,
    backend=current_platform.simple_compile_backend,
    disable=current_platform.is_npu(),
)
def get_masked_input_and_mask(
    input_: torch.Tensor,
    org_vocab_start_index: int,
    org_vocab_end_index: int,
    num_org_vocab_padding: int,
    added_vocab_start_index: int,
    added_vocab_end_index: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Map global token ids to shard-local indices.

    Ids inside ``[org_vocab_start_index, org_vocab_end_index)`` are shifted
    down by ``org_vocab_start_index``. Ids inside
    ``[added_vocab_start_index, added_vocab_end_index)`` are shifted so that
    they land after the original-vocab entries and their padding. Every other
    id becomes 0.

    Returns:
        ``(local_ids, out_of_shard)`` where ``out_of_shard`` is True for the
        ids that fall in neither range.
    """
    # Only pointwise ops below, so torch.compile can fuse them into a single
    # kernel.
    in_org = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index)
    in_added = (input_ >= added_vocab_start_index) & (
        input_ < added_vocab_end_index
    )

    num_org_elements = org_vocab_end_index - org_vocab_start_index
    added_shift = added_vocab_start_index - num_org_elements - num_org_vocab_padding

    # Per-element amount to subtract; zero where the id is in neither range.
    shift = (org_vocab_start_index * in_org) + (added_shift * in_added)

    in_shard = in_org | in_added
    local_ids = in_shard * (input_ - shift)
    return local_ids, ~in_shard
| 1087 | + + TP2, rank 0: + |< --------------------BASE--------------------- >|< -----LORA------ >|< -LORA PADDING- >| + corresponding token_id: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 1000 | ... | 1015 | -1 | ... | -1 | + index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 527 | 520 | ... | 543 | + TP2, rank 1: + |< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >| + corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1 | ... | -1 | -1 | ... | -1 | -1 | ... | -1 | + index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 519 | 520 | ... | 543 | + + Args: + num_embeddings: vocabulary size. + embedding_dim: size of hidden state. + params_dtype: type of the parameters. + org_num_embeddings: original vocabulary size (without LoRA). + padding_size: padding size for the vocabulary. + quant_config: quant config for the layer + prefix: full name of the layer in the state dict + """ # noqa: E501 + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + params_dtype: torch.dtype | None = None, + org_num_embeddings: int | None = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + tp_group: dist.ProcessGroup = None, + ): + super().__init__() + + # Keep the input dimensions. 
+ tp_group = tp_group or get_tp_group() + tp_rank = get_group_rank(tp_group) + self.tp_size = get_group_size(tp_group) + self.tp_group = tp_group + self.num_embeddings = num_embeddings + self.padding_size = padding_size + self.org_vocab_size = org_num_embeddings or num_embeddings + num_added_embeddings = num_embeddings - self.org_vocab_size + self.org_vocab_size_padded = pad_vocab_size( + self.org_vocab_size, self.padding_size + ) + self.num_embeddings_padded = pad_vocab_size( + self.org_vocab_size_padded + num_added_embeddings, self.padding_size + ) + assert self.org_vocab_size_padded <= self.num_embeddings_padded + + self.shard_indices = self._get_indices( + self.num_embeddings_padded, + self.org_vocab_size_padded, + self.num_embeddings, + self.org_vocab_size, + tp_rank, + self.tp_size, + ) + self.embedding_dim = embedding_dim + + quant_method = None + if quant_config is not None: + quant_method = quant_config.get_quant_method(self, prefix=prefix) + if quant_method is None: + quant_method = UnquantizedEmbeddingMethod() + + # If we are making an embedding layer, then our quantization linear + # method must implement the embedding operation. If we are another + # layer type like ParallelLMHead, this is not important. + is_embedding_layer = type(self) is VocabParallelEmbedding + quant_method_implements_embedding = method_has_implemented_embedding( + type(quant_method) + ) + if is_embedding_layer and not quant_method_implements_embedding: + raise NotImplementedError( + f"The class {type(quant_method).__name__} must implement " + "the 'embedding' method, see UnquantizedEmbeddingMethod." + ) + + self.quant_method: QuantizeMethodBase = quant_method + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + # Divide the weight matrix along the vocabulary dimension.
+ self.num_added_embeddings = self.num_embeddings - self.org_vocab_size + self.num_embeddings_per_partition = divide( + self.num_embeddings_padded, self.tp_size + ) + assert ( + self.shard_indices.num_elements_padded == self.num_embeddings_per_partition + ) + self.num_org_embeddings_per_partition = ( + self.shard_indices.org_vocab_end_index + - self.shard_indices.org_vocab_start_index + ) + self.num_added_embeddings_per_partition = ( + self.shard_indices.added_vocab_end_index + - self.shard_indices.added_vocab_start_index + ) + + self.quant_method.create_weights( + self, + self.embedding_dim, + [self.num_embeddings_per_partition], + self.embedding_dim, + self.num_embeddings_padded, + params_dtype=params_dtype, + weight_loader=self.weight_loader, + ) + + @classmethod + def _get_indices( + cls, + vocab_size_padded: int, + org_vocab_size_padded: int, + vocab_size: int, + org_vocab_size: int, + tp_rank: int, + tp_size: int, + ) -> VocabParallelEmbeddingShardIndices: + """Get start and end indices for vocab parallel embedding, following the + layout outlined in the class docstring, based on the given tp_rank and + tp_size.""" + num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded + padded_org_vocab_start_index, padded_org_vocab_end_index = ( + vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank, tp_size) + ) + padded_added_vocab_start_index, padded_added_vocab_end_index = ( + vocab_range_from_global_vocab_size( + num_added_embeddings_padded, tp_rank, tp_size, offset=org_vocab_size + ) + ) + # remove padding + org_vocab_start_index = min(padded_org_vocab_start_index, org_vocab_size) + org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size) + added_vocab_start_index = min(padded_added_vocab_start_index, vocab_size) + added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size) + return VocabParallelEmbeddingShardIndices( + padded_org_vocab_start_index, + padded_org_vocab_end_index, + padded_added_vocab_start_index, + 
padded_added_vocab_end_index, + org_vocab_start_index, + org_vocab_end_index, + added_vocab_start_index, + added_vocab_end_index, + ) + + def get_sharded_to_full_mapping(self) -> list[int] | None: + """Get a mapping that can be used to reindex the gathered + logits for sampling. + + During sampling, we gather logits from all ranks. The relationship + of index->token_id will follow the same format as outlined in the class + docstring. However, after the gather, we want to reindex the final + logits tensor to map index->token_id one-to-one (the index is always + equal the token_id it corresponds to). The indices returned by this + method allow us to do that. + """ + if self.tp_size < 2: + return None + + base_embeddings: list[int] = [] + added_embeddings: list[int] = [] + padding: list[int] = [] + for tp_rank in range(self.tp_size): + shard_indices = self._get_indices( + self.num_embeddings_padded, + self.org_vocab_size_padded, + self.num_embeddings, + self.org_vocab_size, + tp_rank, + self.tp_size, + ) + range_start = self.num_embeddings_per_partition * tp_rank + range_end = self.num_embeddings_per_partition * (tp_rank + 1) + base_embeddings.extend( + range(range_start, range_start + shard_indices.num_org_elements) + ) + padding.extend( + range( + range_start + shard_indices.num_org_elements, + range_start + shard_indices.num_org_elements_padded, + ) + ) + added_embeddings.extend( + range( + range_start + shard_indices.num_org_elements_padded, + range_start + + shard_indices.num_org_elements_padded + + shard_indices.num_added_elements, + ) + ) + padding.extend( + range( + range_start + + shard_indices.num_org_elements_padded + + shard_indices.num_added_elements, + range_start + + shard_indices.num_org_elements_padded + + shard_indices.num_added_elements_padded, + ) + ) + assert ( + range_start + + shard_indices.num_org_elements_padded + + shard_indices.num_added_elements_padded + == range_end + ) + ret = base_embeddings + added_embeddings + padding + assert len(ret) 
== self.num_embeddings_padded + return ret + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + output_dim = getattr(param, "output_dim", None) + packed_dim = getattr(param, "packed_dim", None) + + # If the parameter is a gguf weight, then load it directly. + if getattr(param, "is_gguf_weight_type", None): + param.data.copy_(loaded_weight) + param.weight_type = loaded_weight.item() + return + elif isinstance(param, UninitializedParameter): + shape = list(loaded_weight.shape) + if output_dim is not None: + shape[output_dim] = self.num_embeddings_per_partition + param.materialize(tuple(shape), dtype=loaded_weight.dtype) + + # If parameter does not have output dim, then it should + # be copied onto all gpus (e.g. g_idx for act_order gptq). + if output_dim is None: + assert param.data.shape == loaded_weight.shape + param.data.copy_(loaded_weight) + return + + # Shard indexes for loading the weight + start_idx = self.shard_indices.org_vocab_start_index + shard_size = self.shard_indices.org_vocab_end_index - start_idx + + # If param packed on the same dim we are sharding on, then + # need to adjust offsets of loaded weight by packed_factor. + if packed_dim is not None and packed_dim == output_dim: + packed_factor = ( + param.packed_factor + if isinstance(param, BasevLLMParameter) + else param.pack_factor + ) + assert loaded_weight.shape[output_dim] == ( + self.org_vocab_size // packed_factor + ) + start_idx = start_idx // packed_factor + shard_size = shard_size // packed_factor + else: + assert loaded_weight.shape[output_dim] == self.org_vocab_size + + # Copy the data. Select chunk corresponding to current shard. + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + + param[: loaded_weight.shape[0]].data.copy_(loaded_weight) + param[loaded_weight.shape[0] :].data.fill_(0) + + def forward(self, input_): + if self.tp_size > 1: + # Build the mask.
+ masked_input, input_mask = get_masked_input_and_mask( + input_, + self.shard_indices.org_vocab_start_index, + self.shard_indices.org_vocab_end_index, + self.shard_indices.num_org_vocab_padding, + self.shard_indices.added_vocab_start_index, + self.shard_indices.added_vocab_end_index, + ) + else: + masked_input = input_ + # Get the embeddings. + output_parallel = self.quant_method.embedding(self, masked_input.long()) + # Mask the output embedding. + if self.tp_size > 1: + output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) + # Reduce across all the model parallel GPUs. + output = tensor_model_parallel_all_reduce( + output_parallel, tp_group=self.tp_group + ) + return output + + def extra_repr(self) -> str: + s = f"num_embeddings={self.num_embeddings_per_partition}" + s += f", embedding_dim={self.embedding_dim}" + s += f", org_vocab_size={self.org_vocab_size}" + s += f", num_embeddings_padded={self.num_embeddings_padded}" + s += f", tp_size={self.tp_size}" + return s diff --git a/sglang/python/sglang/multimodal_gen/runtime/loader/component_loaders/adapter_loader.py b/sglang/python/sglang/multimodal_gen/runtime/loader/component_loaders/adapter_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..7d073d5a014d5d71d4ea50941c5badd848b9b424 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/loader/component_loaders/adapter_loader.py @@ -0,0 +1,72 @@ +from safetensors.torch import load_file as safetensors_load_file + +from sglang.multimodal_gen.runtime.distributed import get_local_torch_device +from sglang.multimodal_gen.runtime.loader.component_loaders.component_loader import ( + ComponentLoader, +) +from sglang.multimodal_gen.runtime.loader.utils import ( + _list_safetensors_files, + set_default_torch_dtype, + skip_init_modules, +) +from sglang.multimodal_gen.runtime.models.registry import ModelRegistry +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from 
sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import ( + get_diffusers_component_config, +) +from sglang.multimodal_gen.utils import PRECISION_TO_TYPE + + +class AdapterLoader(ComponentLoader): + """Loader for small adapter-style modules (e.g., LTX-2 connectors). + + This loader intentionally avoids FSDP sharding and just: + 1) Instantiates the module from `config.json`. + 2) Loads a single safetensors state_dict. + """ + + component_names = ["connectors"] + expected_library = "diffusers" + + def load_customized( + self, component_model_path: str, server_args: ServerArgs, *args + ): + config = get_diffusers_component_config(component_path=component_model_path) + + cls_name = config.pop("_class_name", None) + if cls_name is None: + raise ValueError( + "Model config does not contain a _class_name attribute. " + "Only diffusers format is supported." + ) + + config.pop("_diffusers_version", None) + config.pop("_name_or_path", None) + + server_args.model_paths["connectors"] = component_model_path + + model_cls, _ = ModelRegistry.resolve_model_cls(cls_name) + + target_device = get_local_torch_device() + default_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.dit_precision] + + from types import SimpleNamespace + + with set_default_torch_dtype(default_dtype), skip_init_modules(): + connector_cfg = SimpleNamespace(**config) + model = model_cls(connector_cfg).to( + device=target_device, dtype=default_dtype + ) + + safetensors_list = _list_safetensors_files(component_model_path) + if not safetensors_list: + raise ValueError(f"No safetensors files found in {component_model_path}") + if len(safetensors_list) != 1: + raise ValueError( + f"Found {len(safetensors_list)} safetensors files in {component_model_path}, expected 1" + ) + + loaded = safetensors_load_file(safetensors_list[0]) + model.load_state_dict(loaded, strict=False) + + return model diff --git a/sglang/python/sglang/multimodal_gen/runtime/loader/component_loaders/bridge_loader.py 
b/sglang/python/sglang/multimodal_gen/runtime/loader/component_loaders/bridge_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..646c232837338a933452009e8b70149e32bafbff --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/loader/component_loaders/bridge_loader.py @@ -0,0 +1,106 @@ +from copy import deepcopy + +import torch + +from sglang.multimodal_gen.runtime.distributed import get_local_torch_device +from sglang.multimodal_gen.runtime.loader.component_loaders.component_loader import ( + ComponentLoader, +) +from sglang.multimodal_gen.runtime.loader.fsdp_load import maybe_load_fsdp_model +from sglang.multimodal_gen.runtime.loader.utils import _list_safetensors_files +from sglang.multimodal_gen.runtime.models.registry import ModelRegistry +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import ( + get_diffusers_component_config, +) +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.utils import PRECISION_TO_TYPE + +logger = init_logger(__name__) + + +class BridgeLoader(ComponentLoader): + """Loader for MOVA dual tower bridge with FSDP support.""" + + pipeline_bridge_config_attr: str = "bridge_config" + + component_names = ["dual_tower_bridge"] + expected_library = "diffusers" + + def load_customized( + self, component_model_path: str, server_args: ServerArgs, component_name: str + ): + config = get_diffusers_component_config(component_path=component_model_path) + hf_config = deepcopy(config) + class_name = config.pop("_class_name", None) + if class_name is None: + raise ValueError( + "Model config does not contain a _class_name attribute. " + "Only diffusers format is supported." 
+ ) + server_args.model_paths[component_name] = component_model_path + + # Try to get bridge config from pipeline config, fallback to creating one + bridge_config = getattr( + server_args.pipeline_config, self.pipeline_bridge_config_attr, None + ) + if bridge_config is not None: + bridge_config.update_model_arch(config) + else: + # Create a minimal config from hf_config + from sglang.multimodal_gen.configs.models.bridges.mova_dual_tower import ( + MOVADualTowerConfig, + ) + + bridge_config = MOVADualTowerConfig() + bridge_config.update_model_arch(config) + + model_cls, _ = ModelRegistry.resolve_model_cls(class_name) + + # Find all safetensors files + safetensors_list = _list_safetensors_files(component_model_path) + if not safetensors_list: + raise ValueError(f"No safetensors files found in {component_model_path}") + + default_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.dit_precision] + + logger.info( + "Loading %s from %s safetensors files, default_dtype: %s", + class_name, + len(safetensors_list), + default_dtype, + ) + + # Check if FSDP loading is available + if ( + server_args.hsdp_shard_dim is not None + and hasattr(model_cls, "_fsdp_shard_conditions") + and model_cls._fsdp_shard_conditions + ): + # Load with FSDP support + model = maybe_load_fsdp_model( + model_cls=model_cls, + init_params={"config": bridge_config, "hf_config": hf_config}, + weight_dir_list=safetensors_list, + device=get_local_torch_device(), + hsdp_replicate_dim=server_args.hsdp_replicate_dim, + hsdp_shard_dim=server_args.hsdp_shard_dim, + cpu_offload=server_args.dit_cpu_offload, + pin_cpu_memory=server_args.pin_cpu_memory, + fsdp_inference=server_args.use_fsdp_inference, + param_dtype=default_dtype, + reduce_dtype=torch.float32, + output_dtype=None, + strict=False, + ) + else: + # Fallback to simple loading (for non-FSDP or legacy models) + model = model_cls.from_pretrained( + component_model_path, torch_dtype=default_dtype + ) + model = model.to(device=get_local_torch_device(), 
dtype=default_dtype) + + total_params = sum(p.numel() for p in model.parameters()) + logger.info("Loaded bridge model with %.2fM parameters", total_params / 1e6) + + return model diff --git a/sglang/python/sglang/multimodal_gen/runtime/loader/component_loaders/component_loader.py b/sglang/python/sglang/multimodal_gen/runtime/loader/component_loaders/component_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..b416e6bbdaa6f953fe2ffdbe726dbf2385bcb854 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/loader/component_loaders/component_loader.py @@ -0,0 +1,350 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +import importlib +import os +import pkgutil +import traceback +from abc import ABC +from typing import Any, Type + +import torch +from diffusers import AutoModel +from torch import nn +from transformers import AutoImageProcessor, AutoProcessor, AutoTokenizer + +from sglang.multimodal_gen.configs.models import ModelConfig +from sglang.multimodal_gen.runtime.distributed import get_local_torch_device +from sglang.multimodal_gen.runtime.loader.utils import ( + _normalize_component_type, + component_name_to_loader_cls, + get_memory_usage_of_component, +) +from sglang.multimodal_gen.runtime.platforms import current_platform +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import get_hf_config +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class ComponentLoader(ABC): + """Base class for loading a specific type of model component.""" + + # the list of possible name of the component in model_index.json, e.g., scheduler + component_names: list[str] = [] + + # diffusers or transformers + expected_library: str = "" + + _loaders_registered = False + + def __init_subclass__(cls, **kwargs): + """ + register loaders, called when 
subclass is imported + """ + super().__init_subclass__(**kwargs) + for component_name in cls.component_names: + component_name_to_loader_cls[component_name] = cls + + def __init__(self, device=None) -> None: + self.device = device + + def should_offload( + self, server_args: ServerArgs, model_config: ModelConfig | None = None + ): + # not offload by default + return False + + def target_device(self, should_offload): + if should_offload: + return ( + torch.device("mps") + if current_platform.is_mps() + else torch.device("cpu") + ) + else: + return get_local_torch_device() + + def load( + self, + component_model_path: str, + server_args: ServerArgs, + component_name: str, + transformers_or_diffusers: str, + ) -> tuple[AutoModel, float]: + """ + Template method that standardizes logging around the core load implementation. + The priority of loading method is: + 1. load customized component + 2. load native diffusers/transformers component + If all of the above methods failed, an error will be thrown + + """ + gpu_mem_before_loading = current_platform.get_available_gpu_memory() + logger.info( + "Loading %s from %s. 
avail mem: %.2f GB", + component_name, + component_model_path, + gpu_mem_before_loading, + ) + try: + component = self.load_customized( + component_model_path, server_args, component_name + ) + source = "sgl-diffusion" + except Exception as e: + if "Unsupported model architecture" in str(e): + logger.info( + f"Component: {component_name} doesn't have a customized version yet, using native version" + ) + else: + traceback.print_exc() + logger.error( + f"Error while loading customized {component_name}, falling back to native version" + ) + # fallback to native version + component = self.load_native( + component_model_path, server_args, transformers_or_diffusers + ) + should_offload = self.should_offload(server_args) + target_device = self.target_device(should_offload) + component = component.to(device=target_device) + source = "native" + logger.warning( + "Native component %s: %s is loaded, performance may be sub-optimal", + component_name, + component.__class__.__name__, + ) + + if component is None: + logger.error("Load %s failed", component_name) + consumed = 0.0 + else: + if isinstance(component, nn.Module): + component = component.eval() + current_gpu_mem = current_platform.get_available_gpu_memory() + model_size = get_memory_usage_of_component(component) or "NA" + consumed = gpu_mem_before_loading - current_gpu_mem + logger.info( + f"Loaded %s: %s ({source} version). model size: %s GB, consumed GPU mem: %.2f GB, avail GPU mem: %.2f GB", + component_name, + component.__class__.__name__, + model_size, + consumed, + current_gpu_mem, + ) + return component, consumed + + def load_native( + self, + component_model_path: str, + server_args: ServerArgs, + transformers_or_diffusers: str, + ) -> AutoModel: + """ + Load the component using the native library (transformers/diffusers). 
+ """ + if transformers_or_diffusers == "transformers": + from transformers import AutoModel + + config = get_hf_config( + component_model_path, + trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, + ) + return AutoModel.from_pretrained( + component_model_path, + config=config, + trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, + ) + elif transformers_or_diffusers == "diffusers": + from diffusers import AutoModel + + return AutoModel.from_pretrained( + component_model_path, + revision=server_args.revision, + trust_remote_code=server_args.trust_remote_code, + ) + else: + raise ValueError(f"Unsupported library: {transformers_or_diffusers}") + + def load_customized( + self, component_model_path: str, server_args: ServerArgs, component_name: str + ): + """ + Load the customized version component, implemented and optimized in SGL-diffusion + """ + raise NotImplementedError( + f"load_customized not implemented for {self.__class__.__name__}" + ) + + @classmethod + def _ensure_loaders_registered(cls): + """ + Import every sibling loader module once so its subclasses get registered; later calls are no-ops + """ + if cls._loaders_registered: + return + + package_dir = os.path.dirname(__file__) + package_name = ( + __package__ + or "sglang.multimodal_gen.runtime.loader.component_loaders" + ) + + for _, name, _ in pkgutil.iter_modules([package_dir]): + # skip importing self to avoid circular dependency issues + if name == "component_loader": + continue + try: + importlib.import_module(f".{name}", package=package_name) + except ImportError as e: + logger.warning(f"Failed to import loader component {name}: {e}") + + cls._loaders_registered = True + + @classmethod + def for_component_type( + cls, component_name: str, transformers_or_diffusers: str + ) -> "ComponentLoader": + """ + Factory method to create a component loader for a specific component type.
+ + Args: + component_name: Type of component (e.g., "vae", "text_encoder", "transformer", "scheduler") + transformers_or_diffusers: Whether the component is from transformers or diffusers + """ + cls._ensure_loaders_registered() + + # Map of component types to their loader classes and expected library + component_name = _normalize_component_type(component_name) + + # NOTE(FlamingoPg): special for LTX-2 models + if component_name == "vocoder" or component_name == "connectors": + transformers_or_diffusers = "diffusers" + + # NOTE(CloudRipple): special for MOVA models + # TODO(CloudRipple): remove most of these special cases after unifying the loading logic + if component_name in [ + "audio_vae", + "audio_dit", + "dual_tower_bridge", + "video_dit", + ]: + transformers_or_diffusers = "diffusers" + + if ( + component_name == "scheduler" + and transformers_or_diffusers == "mova.diffusion.schedulers.flow_match_pair" + ): + transformers_or_diffusers = "diffusers" + + if component_name in component_name_to_loader_cls: + loader_cls: Type[ComponentLoader] = component_name_to_loader_cls[ + component_name + ] + expected_library = loader_cls.expected_library + # Assert that the library matches what's expected for this component type + assert ( + transformers_or_diffusers == expected_library + ), f"{component_name} must be loaded from {expected_library}, got {transformers_or_diffusers}" + return loader_cls() + + # For unknown component types, use a generic loader + logger.warning( + "No specific loader found for component type: %s. 
Using generic loader.", + component_name, + ) + return GenericComponentLoader(transformers_or_diffusers) + + +class ImageProcessorLoader(ComponentLoader): + """Loader for image processor.""" + + component_names = ["image_processor"] + expected_library = "transformers" + + def load_customized( + self, component_model_path: str, server_args: ServerArgs, component_name: str + ) -> Any: + return AutoImageProcessor.from_pretrained(component_model_path, use_fast=True) + + +class AutoProcessorLoader(ComponentLoader): + """Loader for auto processor.""" + + component_names = ["processor"] + expected_library = "transformers" + + def load_customized( + self, component_model_path: str, server_args: ServerArgs, component_name: str + ) -> Any: + return AutoProcessor.from_pretrained(component_model_path) + + +class TokenizerLoader(ComponentLoader): + """Loader for tokenizers.""" + + component_names = ["tokenizer"] + expected_library = "transformers" + + def load_customized( + self, component_model_path: str, server_args: ServerArgs, component_name: str + ) -> Any: + return AutoTokenizer.from_pretrained( + component_model_path, + padding_size="right", + ) + + +class GenericComponentLoader(ComponentLoader): + """Generic loader for components that don't have a specific loader.""" + + def __init__(self, library="transformers") -> None: + super().__init__() + self.library = library + + +class PipelineComponentLoader: + """ + Utility class for loading the components in a pipeline. + """ + + @staticmethod + def load_component( + component_name: str, + component_model_path: str, + transformers_or_diffusers: str, + server_args: ServerArgs, + ): + """ + Load a pipeline component. 
+ + Args: + component_name: Name of the component (e.g., "vae", "text_encoder", "transformer", "scheduler") + component_model_path: Path to the component model + transformers_or_diffusers: Whether the component is from transformers or diffusers + + """ + + # Get the appropriate loader for this component type + loader = ComponentLoader.for_component_type( + component_name, transformers_or_diffusers + ) + + try: + # Load the component + return loader.load( + component_model_path, + server_args, + component_name, + transformers_or_diffusers, + ) + except Exception as e: + logger.error( + f"Error while loading component: {component_name}, {component_model_path=}" + ) + raise e diff --git a/sglang/python/sglang/multimodal_gen/runtime/loader/component_loaders/image_encoder_loader.py b/sglang/python/sglang/multimodal_gen/runtime/loader/component_loaders/image_encoder_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..18a33b3bbb54217679ae679a25528f6d0802b20f --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/loader/component_loaders/image_encoder_loader.py @@ -0,0 +1,57 @@ +from sglang.multimodal_gen.configs.models import ModelConfig +from sglang.multimodal_gen.runtime.loader.component_loaders.text_encoder_loader import ( + TextEncoderLoader, +) +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import ( + get_diffusers_component_config, +) +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class ImageEncoderLoader(TextEncoderLoader): + component_names = ["image_encoder"] + expected_library = "transformers" + + def should_offload(self, server_args, model_config: ModelConfig | None = None): + should_offload = server_args.image_encoder_cpu_offload + if not should_offload: + return False + # _fsdp_shard_conditions is in arch_config, not directly on model_config + arch_config = ( + 
getattr(model_config, "arch_config", model_config) if model_config else None + ) + fsdp_shard_conditions = ( + getattr(arch_config, "_fsdp_shard_conditions", []) if arch_config else [] + ) + use_cpu_offload = should_offload and len(fsdp_shard_conditions) > 0 + return use_cpu_offload + + def load_customized( + self, component_model_path: str, server_args: ServerArgs, *args + ): + """Load the image encoder based on the model path and inference args.""" + # model_config: PretrainedConfig = get_hf_config( + # model=model_path, + # trust_remote_code=server_args.trust_remote_code, + # revision=server_args.revision, + # model_override_args=None, + # ) + model_config = get_diffusers_component_config( + component_path=component_model_path + ) + + encoder_config = server_args.pipeline_config.image_encoder_config + encoder_config.update_model_arch(model_config) + + # Always start with local device; load_model will adjust for offload if needed + # TODO(will): add support for other dtypes + return self.load_model( + component_model_path, + encoder_config, + server_args, + server_args.pipeline_config.image_encoder_precision, + cpu_offload_flag=server_args.image_encoder_cpu_offload, + ) diff --git a/sglang/python/sglang/multimodal_gen/runtime/loader/component_loaders/scheduler_loader.py b/sglang/python/sglang/multimodal_gen/runtime/loader/component_loaders/scheduler_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..a82c1047c2e667264ff2d8874b16ab849e262559 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/loader/component_loaders/scheduler_loader.py @@ -0,0 +1,37 @@ +from sglang.multimodal_gen.runtime.loader.component_loaders.component_loader import ( + ComponentLoader, +) +from sglang.multimodal_gen.runtime.models.registry import ModelRegistry +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import ( + get_diffusers_component_config, +) +from
sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class SchedulerLoader(ComponentLoader): + """Loader for scheduler.""" + + component_names = ["scheduler"] + expected_library = "diffusers" + + def load_customized( + self, component_model_path: str, server_args: ServerArgs, *args + ): + """Load the scheduler based on the model path, and inference args.""" + config = get_diffusers_component_config(component_path=component_model_path) + + class_name = config.pop("_class_name") + assert ( + class_name is not None + ), "Model config does not contain a _class_name attribute. Only diffusers format is supported." + + scheduler_cls, _ = ModelRegistry.resolve_model_cls(class_name) + + scheduler = scheduler_cls(**config) + if server_args.pipeline_config.flow_shift is not None: + scheduler.set_shift(server_args.pipeline_config.flow_shift) + + return scheduler diff --git a/sglang/python/sglang/multimodal_gen/runtime/loader/component_loaders/text_encoder_loader.py b/sglang/python/sglang/multimodal_gen/runtime/loader/component_loaders/text_encoder_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..35c4102a70051a2a9966246feac33945345f6830 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/loader/component_loaders/text_encoder_loader.py @@ -0,0 +1,305 @@ +import dataclasses +import glob +import os +from collections.abc import Generator, Iterable +from typing import Generator, Iterable, cast + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch import nn +from torch.distributed import init_device_mesh +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME + +from sglang.multimodal_gen.configs.models import EncoderConfig, ModelConfig +from sglang.multimodal_gen.configs.pipeline_configs.qwen_image import ( + QwenImageEditPipelineConfig, +) +from sglang.multimodal_gen.runtime.distributed import get_local_torch_device +from 
class TextEncoderLoader(ComponentLoader):
    """Loader for text encoders."""

    component_names = ["text_encoder"]
    expected_library = "transformers"

    @dataclasses.dataclass
    class Source:
        """A source for weights."""

        model_or_path: str
        """The model ID or path."""

        prefix: str = ""
        """A prefix to prepend to all weights."""

        fall_back_to_pt: bool = True
        """Whether .pt weights can be used."""

        allow_patterns_overrides: list[str] | None = None
        """If defined, weights will load exclusively using these patterns."""

    def should_offload(self, server_args, model_config: ModelConfig | None = None):
        """Return True only when text-encoder CPU offload is requested AND the
        model's arch config declares at least one FSDP shard condition."""
        should_offload = server_args.text_encoder_cpu_offload
        if not should_offload:
            return False
        # _fsdp_shard_conditions is in arch_config, not directly on model_config
        arch_config = (
            getattr(model_config, "arch_config", model_config) if model_config else None
        )
        fsdp_shard_conditions = (
            getattr(arch_config, "_fsdp_shard_conditions", []) if arch_config else []
        )
        use_cpu_offload = should_offload and len(fsdp_shard_conditions) > 0
        return use_cpu_offload

    def _prepare_weights(
        self,
        model_name_or_path: str,
        fall_back_to_pt: bool,
        allow_patterns_overrides: list[str] | None,
    ) -> tuple[str, list[str], bool]:
        """Locate the weight files of a local model directory.

        Nothing is downloaded: ``model_name_or_path`` must already be a local
        directory. Patterns are tried in order and the search stops at the first
        pattern that matches at least one file.

        Returns:
            ``(folder, weight_files, use_safetensors)``.

        Raises:
            AssertionError: If ``model_name_or_path`` is not a directory.
            RuntimeError: If no weight file is left after filtering.
        """
        is_local = os.path.isdir(model_name_or_path)
        assert is_local, "Model path must be a local directory"

        use_safetensors = False
        index_file = SAFE_WEIGHTS_INDEX_NAME
        allow_patterns = ["*.safetensors", "*.bin"]

        if fall_back_to_pt:
            allow_patterns += ["*.pt"]

        if allow_patterns_overrides is not None:
            allow_patterns = allow_patterns_overrides

        hf_folder = model_name_or_path

        hf_weights_files: list[str] = []
        for pattern in allow_patterns:
            hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
            if len(hf_weights_files) > 0:
                if pattern == "*.safetensors":
                    use_safetensors = True
                break

        if use_safetensors:
            hf_weights_files = filter_duplicate_safetensors_files(
                hf_weights_files, hf_folder, index_file
            )
        else:
            hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)

        if len(hf_weights_files) == 0:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`"
            )

        return hf_folder, hf_weights_files, use_safetensors

    def _get_weights_iterator(
        self, source: "Source", to_cpu: bool
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        """get an iterator for the model weights based on the load format."""
        hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
            source.model_or_path,
            source.fall_back_to_pt,
            source.allow_patterns_overrides,
        )
        if use_safetensors:
            weights_iterator = safetensors_weights_iterator(
                hf_weights_files, to_cpu=to_cpu
            )
        else:
            weights_iterator = pt_weights_iterator(hf_weights_files, to_cpu=to_cpu)

        # apply the prefix.
        return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator)

    def _get_all_weights(
        self,
        model: nn.Module,
        model_path: str,
        to_cpu: bool,
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Yield ``(name, tensor)`` pairs from the primary checkpoint, followed by
        any extra sources the model lists in its ``secondary_weights`` attribute."""
        primary_weights = TextEncoderLoader.Source(
            model_path,
            prefix="",
            fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True),
            allow_patterns_overrides=getattr(model, "allow_patterns_overrides", None),
        )
        yield from self._get_weights_iterator(primary_weights, to_cpu)

        secondary_weights = cast(
            Iterable[TextEncoderLoader.Source],
            getattr(model, "secondary_weights", ()),
        )
        for source in secondary_weights:
            yield from self._get_weights_iterator(source, to_cpu)

    def load_customized(
        self, component_model_path: str, server_args: ServerArgs, component_name: str
    ):
        """Load the text encoders based on the model path, and inference args.

        A component whose name contains ``"2"`` is treated as the second text
        encoder and uses index 1 of the pipeline's encoder configs/precisions;
        any other name uses index 0.
        """
        diffusers_pretrained_config = get_config(
            component_model_path, trust_remote_code=True
        )
        model_config = get_diffusers_component_config(
            component_path=component_model_path
        )

        def is_not_first_encoder(module_name):
            return "2" in module_name

        # TODO(mick): had to throw an exception for different text-encoder arch
        if not is_not_first_encoder(component_name):
            encoder_config = server_args.pipeline_config.text_encoder_configs[0]
            encoder_config.update_model_arch(model_config)
            # Mirror every attribute of the pretrained config onto the arch config.
            for key, value in diffusers_pretrained_config.__dict__.items():
                setattr(encoder_config.arch_config, key, value)
            encoder_dtype = server_args.pipeline_config.text_encoder_precisions[0]
        else:
            assert len(server_args.pipeline_config.text_encoder_configs) == 2
            encoder_config = server_args.pipeline_config.text_encoder_configs[1]
            encoder_config.update_model_arch(model_config)
            encoder_dtype = server_args.pipeline_config.text_encoder_precisions[1]
        # TODO(will): add support for other dtypes
        return self.load_model(
            component_model_path,
            encoder_config,
            server_args,
            encoder_dtype,
        )

    def load_model(
        self,
        model_path: str,
        model_config: EncoderConfig,
        server_args: ServerArgs,
        dtype: str = "fp16",
        cpu_offload_flag: bool | None = None,
    ):
        """Instantiate the encoder, load its weights and place it on a device.

        Args:
            model_path: Local directory containing the weights.
            model_config: Encoder config passed to the model constructor.
            server_args: Server arguments (offload flags, pipeline config, ...).
            dtype: Key into ``PRECISION_TO_TYPE`` selecting the default dtype
                used while constructing the model.
            cpu_offload_flag: Accepted for call-site compatibility; this
                implementation does not read it. The offload decision comes
                from ``self.should_offload``.

        Returns:
            The loaded model.

        Raises:
            ValueError: If parameters are missing from the checkpoint and are not
                covered by the model's ``_allowed_missing_weights_patterns``.
        """
        # Determine CPU offload behavior and target device

        local_torch_device = get_local_torch_device()
        should_offload = self.should_offload(server_args, model_config)

        if should_offload and not current_platform.is_mps():
            model_device = torch.device("cpu")
        else:
            model_device = local_torch_device

        with set_default_torch_dtype(PRECISION_TO_TYPE[dtype]):
            with model_device, skip_init_modules():
                architectures = getattr(model_config, "architectures", [])
                model_cls, _ = ModelRegistry.resolve_model_cls(architectures)
                model_config.enable_image_understanding = isinstance(
                    server_args.pipeline_config, QwenImageEditPipelineConfig
                )
                model = model_cls(model_config)

            # Snapshot parameter names before loading so missing ones can be reported.
            weights_to_load = {name for name, _ in model.named_parameters()}
            loaded_weights = model.load_weights(
                self._get_all_weights(model, model_path, to_cpu=should_offload)
            )

            if not should_offload:
                # Explicitly move model to target device after loading weights
                model = model.to(local_torch_device)
            elif current_platform.is_mps():
                # Disable FSDP for MPS as it's not compatible
                logger.info(
                    "Disabling FSDP sharding for MPS platform as it's not compatible"
                )
                model = model.to(local_torch_device)
            else:
                mesh = init_device_mesh(
                    current_platform.device_type,
                    mesh_shape=(1, dist.get_world_size()),
                    mesh_dim_names=("offload", "replicate"),
                )
                shard_model(
                    model,
                    cpu_offload=True,
                    reshard_after_forward=True,
                    mesh=mesh["offload"],
                    fsdp_shard_conditions=model_config.arch_config._fsdp_shard_conditions
                    or getattr(model, "_fsdp_shard_conditions", None),
                    pin_cpu_memory=server_args.pin_cpu_memory,
                )
            # We only enable strict check for non-quantized models
            # that have loaded weights tracking currently.
            # if loaded_weights is not None:
            weights_not_loaded = weights_to_load - loaded_weights
            if weights_not_loaded:
                # NOTE:
                # If we silently continue with uninitialized weights, the text encoder can
                # produce NaNs/garbage embeddings that later fail stage verification in a
                # hard-to-debug way (e.g., `prompt_embeds` fails the NaN check).
                #
                # We allow a small set of known-optional parameters to be missing, but
                # default to strict behavior for the rest.
                allowed_missing_patterns = (
                    getattr(model, "_allowed_missing_weights_patterns", []) or []
                )
                unexpected_missing = {
                    n
                    for n in weights_not_loaded
                    if not any(pat in n for pat in allowed_missing_patterns)
                }
                if unexpected_missing:
                    raise ValueError(
                        "Following text encoder weights were not initialized from checkpoint: "
                        f"{sorted(unexpected_missing)}. "
                        "This usually indicates a checkpoint/model-arch mismatch or a broken "
                        "weight-name mapping. If these are truly optional, set "
                        "`model._allowed_missing_weights_patterns` to whitelist patterns."
                    )
                logger.warning(
                    "Following (allowed) text encoder weights were not initialized from "
                    "checkpoint: %s (allowed patterns: %s)",
                    sorted(weights_not_loaded),
                    allowed_missing_patterns,
                )

        return model
class TransformerLoader(ComponentLoader):
    """Shared loader for (video/audio) DiT transformers."""

    component_names = ["transformer", "audio_dit", "video_dit"]
    expected_library = "diffusers"

    def get_list_of_safetensors_to_load(
        self, server_args: ServerArgs, component_model_path: str
    ) -> list[str]:
        """
        get list of safetensors to load.

        If --transformer-weights-path is provided, load weights from that path
        instead of the base model's component directory.

        Raises:
            ValueError: If no safetensors file is found.
        """
        quantized_path = server_args.transformer_weights_path

        if quantized_path:
            quantized_path = maybe_download_model(quantized_path)
            logger.info("using quantized transformer weights from: %s", quantized_path)
            if os.path.isfile(quantized_path) and quantized_path.endswith(
                ".safetensors"
            ):
                # A single checkpoint file was given rather than a directory.
                safetensors_list = [quantized_path]
            else:
                safetensors_list = _list_safetensors_files(quantized_path)
        else:
            safetensors_list = _list_safetensors_files(component_model_path)

        if not safetensors_list:
            raise ValueError(
                f"no safetensors files found in "
                f"{quantized_path or component_model_path}"
            )

        return safetensors_list

    def _resolve_quant_config(
        self,
        hf_config: Dict[str, List[str]],
        server_args: ServerArgs,
        safetensors_list: list[str],
    ) -> Optional[dict]:
        """Return the quantization config from ``hf_config`` or, failing that and
        when ``transformer_weights_path`` is set, from the first safetensors file
        whose metadata provides one. Returns ``None`` when neither has one."""
        # priority: model config.json → safetensors metadata → nunchaku config
        quant_config = get_quant_config(hf_config)
        if quant_config is None and server_args.transformer_weights_path:
            # try to read quantization_config from the safetensors metadata header
            for safetensors_file in safetensors_list:
                quant_config = get_quant_config_from_safetensors_metadata(
                    safetensors_file
                )
                if quant_config:
                    break
        return quant_config

    def _resolve_target_param_dtype(
        self,
        quant_config: Optional[dict],
        nunchaku_config: Optional[NunchakuConfig],
        model_cls,
        server_args: ServerArgs,
    ) -> Optional[torch.dtype]:
        """Pick the parameter dtype to load with.

        Returns ``None`` (keep the checkpoint's dtype) when any quantization
        config is present, otherwise the dtype named by ``dit_precision``. As a
        side effect, sets ``nunchaku_config.model_cls`` and validates the
        checkpoint's recorded class name against ``model_cls``.

        Raises:
            ValueError: If the nunchaku checkpoint was produced for a different
                DiT class than ``model_cls``.
        """
        if quant_config is not None or nunchaku_config is not None:
            # TODO: improve the condition
            # respect dtype from checkpoint
            param_dtype = None
        else:
            param_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.dit_precision]

        if nunchaku_config is not None:
            nunchaku_config.model_cls = model_cls
            # verify that the nunchaku checkpoint matches the selected model class
            original_dit_cls_name = json.loads(
                get_metadata_from_safetensors_file(
                    nunchaku_config.transformer_weights_path
                ).get("config")
            )["_class_name"]
            specified_dit_cls_name = str(model_cls.__name__)
            if original_dit_cls_name != specified_dit_cls_name:
                # ValueError is still an Exception, so existing handlers keep working.
                raise ValueError(
                    f"Class name of DiT specified in nunchaku transformer_weights_path: {original_dit_cls_name} does not match that of specified DiT name: {specified_dit_cls_name}"
                )

        return param_dtype

    def load_customized(
        self, component_model_path: str, server_args: ServerArgs, component_name: str
    ):
        """Load the transformer based on the model path, and inference args.

        Raises:
            ValueError: If ``component_name`` does not normalize to one of
                ``transformer``, ``video_dit`` or ``audio_dit``.
        """
        # 1. hf config
        config = get_diffusers_component_config(component_path=component_model_path)

        # 2. quant config
        safetensors_list = self.get_list_of_safetensors_to_load(
            server_args, component_model_path
        )
        quant_config = self._resolve_quant_config(config, server_args, safetensors_list)

        # 3. dit config
        # Config from Diffusers supersedes sgl_diffusion's model config
        component_name = _normalize_component_type(component_name)
        server_args.model_paths[component_name] = component_model_path
        if component_name in ("transformer", "video_dit"):
            pipeline_dit_config_attr = "dit_config"
        elif component_name in ("audio_dit",):
            pipeline_dit_config_attr = "audio_dit_config"
        else:
            raise ValueError(f"Invalid module name: {component_name}")
        dit_config = getattr(server_args.pipeline_config, pipeline_dit_config_attr)
        dit_config.update_model_arch(config)

        cls_name = config.pop("_class_name")
        model_cls, _ = ModelRegistry.resolve_model_cls(cls_name)

        nunchaku_config = server_args.nunchaku_config
        param_dtype = self._resolve_target_param_dtype(
            quant_config, nunchaku_config, model_cls, server_args
        )

        logger.info(
            "Loading %s from %s safetensors file(s) %s, param_dtype: %s",
            cls_name,
            len(safetensors_list),
            f": {safetensors_list}" if get_log_level() == logging.DEBUG else "",
            param_dtype,
        )

        # prepare init_param
        init_params: dict[str, Any] = {
            "config": dit_config,
            "hf_config": config,
            "quant_config": (quant_config if quant_config else nunchaku_config),
        }
        if (
            init_params["quant_config"] is None
            and server_args.transformer_weights_path is not None
        ):
            logger.warning(
                "transformer_weights_path provided, but quantization config not resolved, which is unexpected and likely to cause errors"
            )
        else:
            logger.debug("quantization config: %s", init_params["quant_config"])

        # Load the model using FSDP loader
        model = maybe_load_fsdp_model(
            model_cls=model_cls,
            init_params=init_params,
            weight_dir_list=safetensors_list,
            device=get_local_torch_device(),
            hsdp_replicate_dim=server_args.hsdp_replicate_dim,
            hsdp_shard_dim=server_args.hsdp_shard_dim,
            cpu_offload=server_args.dit_cpu_offload,
            pin_cpu_memory=server_args.pin_cpu_memory,
            fsdp_inference=server_args.use_fsdp_inference,
            # TODO(will): make these configurable
            param_dtype=param_dtype,
            reduce_dtype=torch.float32,
            output_dtype=None,
            strict=False,
        )

        if nunchaku_config is not None:
            _patch_nunchaku_scales(model, safetensors_list)

        total_params = sum(p.numel() for p in model.parameters())
        logger.info("Loaded model with %.2fB parameters", total_params / 1e9)

        # considering the existent of mixed-precision models (e.g., nunchaku)
        # Only inspect the model when a dtype was actually requested, and tolerate
        # a model with no parameters instead of raising StopIteration.
        if param_dtype:
            first_param = next(model.parameters(), None)
            if first_param is not None and first_param.dtype != param_dtype:
                logger.warning(
                    "Model dtype does not match expected param dtype, %s vs %s",
                    first_param.dtype,
                    param_dtype,
                )

        return model
def _convert_conv3d_weights_to_channels_last_3d(module: nn.Module) -> int:
    """
    Convert Conv3d weights to channels_last_3d (NDHWC) memory format.
    Returns the number of Conv3d modules converted.
    """
    if not hasattr(torch, "channels_last_3d"):
        return 0
    num_converted = 0
    for m in module.modules():
        if isinstance(m, nn.Conv3d):
            try:
                m.weight.data = m.weight.data.to(memory_format=torch.channels_last_3d)
                num_converted += 1
            except Exception:
                # Best-effort; skip unsupported cases.
                continue
    return num_converted


class VAELoader(ComponentLoader):
    """Shared loader for (video/audio) VAE modules."""

    component_names = ["vae", "audio_vae", "video_vae"]
    expected_library = "diffusers"

    def should_offload(
        self, server_args: ServerArgs, model_config: ModelConfig | None = None
    ):
        """Return the ``vae_cpu_offload`` server flag; ``model_config`` is not consulted."""
        return server_args.vae_cpu_offload

    @staticmethod
    def _maybe_convert_channels_last_3d(vae: nn.Module, component_name: str) -> None:
        """Convert Conv3d weights of a video VAE to channels_last_3d when CUDA is
        available and ``SGLANG_DIFFUSION_VAE_CHANNELS_LAST_3D`` is enabled."""
        if (
            component_name in ("vae", "video_vae")
            and torch.cuda.is_available()
            and getattr(envs, "SGLANG_DIFFUSION_VAE_CHANNELS_LAST_3D", False)
        ):
            n = _convert_conv3d_weights_to_channels_last_3d(vae)
            if n > 0:
                logger.info("VAE: converted %d Conv3d weights to channels_last_3d", n)

    def load_customized(
        self, component_model_path: str, server_args: ServerArgs, component_name: str
    ):
        """Load the VAE based on the model path, and inference args.

        Two paths exist: if the component config has ``auto_map["AutoModel"]``,
        the referenced class is imported from a ``.py`` file inside the component
        directory and loaded via ``from_pretrained``; otherwise the class is
        resolved through ``ModelRegistry`` and weights are read from the single
        safetensors file in the directory.

        Raises:
            AssertionError: If ``_class_name`` is missing, or (registry path) the
                directory does not contain exactly one safetensors file.
            ValueError: If ``component_name`` is not a supported VAE component.
            ImportError: If no import spec can be built for the custom module file.
        """
        config = get_diffusers_component_config(component_path=component_model_path)
        class_name = config.pop("_class_name", None)
        assert (
            class_name is not None
        ), "Model config does not contain a _class_name attribute. Only diffusers format is supported."

        server_args.model_paths[component_name] = component_model_path

        if component_name in ("vae", "video_vae"):
            pipeline_vae_config_attr = "vae_config"
            pipeline_vae_precision = "vae_precision"
        elif component_name in ("audio_vae",):
            pipeline_vae_config_attr = "audio_vae_config"
            pipeline_vae_precision = "audio_vae_precision"
        else:
            raise ValueError(
                f"Unsupported module name for VAE loader: {component_name}"
            )
        vae_config = getattr(server_args.pipeline_config, pipeline_vae_config_attr)
        vae_precision = getattr(server_args.pipeline_config, pipeline_vae_precision)
        vae_config.update_model_arch(config)
        if hasattr(vae_config, "post_init"):
            # NOTE: some post init logics are only available after updated with config
            vae_config.post_init()

        should_offload = self.should_offload(server_args)
        target_device = self.target_device(should_offload)

        # Check for auto_map first (custom VAE classes)
        auto_map = config.get("auto_map", {})
        auto_model_map = auto_map.get("AutoModel")
        if auto_model_map:
            module_path, cls_name = auto_model_map.rsplit(".", 1)
            custom_module_file = os.path.join(component_model_path, f"{module_path}.py")
            spec = importlib.util.spec_from_file_location("_custom", custom_module_file)
            if spec is None or spec.loader is None:
                # Fail with a clear message instead of an AttributeError on None.
                raise ImportError(
                    f"Cannot load custom VAE module from {custom_module_file}"
                )
            custom_module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(custom_module)
            vae_cls = getattr(custom_module, cls_name)
            vae_dtype = PRECISION_TO_TYPE[vae_precision]
            with set_default_torch_dtype(vae_dtype):
                vae = vae_cls.from_pretrained(
                    component_model_path,
                    revision=server_args.revision,
                    trust_remote_code=server_args.trust_remote_code,
                )
            vae = vae.to(device=target_device, dtype=vae_dtype)
            self._maybe_convert_channels_last_3d(vae, component_name)
            return vae

        # Load from ModelRegistry (standard VAE classes)
        with (
            set_default_torch_dtype(PRECISION_TO_TYPE[vae_precision]),
            skip_init_modules(),
        ):
            vae_cls, _ = ModelRegistry.resolve_model_cls(class_name)
            vae = vae_cls(vae_config).to(target_device)

        safetensors_list = _list_safetensors_files(component_model_path)
        assert (
            len(safetensors_list) == 1
        ), f"Found {len(safetensors_list)} safetensors files in {component_model_path}"
        loaded = safetensors_load_file(safetensors_list[0])
        # Non-strict load: mismatches are reported below rather than raised.
        vae.load_state_dict(loaded, strict=False)

        state_keys = set(vae.state_dict().keys())
        loaded_keys = set(loaded.keys())
        missing_keys = sorted(state_keys - loaded_keys)
        unexpected_keys = sorted(loaded_keys - state_keys)
        if missing_keys:
            logger.warning("VAE missing keys: %s", missing_keys)
        if unexpected_keys:
            logger.warning("VAE unexpected keys: %s", unexpected_keys)

        self._maybe_convert_channels_last_3d(vae, component_name)

        return vae
class VisionLanguageEncoderLoader(ComponentLoader):
    """Loader for vision language encoder (typically Causal LM or Vision2Seq)."""

    component_names = ["vision_language_encoder"]
    expected_library = "transformers"

    def load_customized(
        self,
        component_model_path: str,
        server_args: ServerArgs,
        transformers_or_diffusers: str = "vision_language_encoder",
    ) -> Any:
        """Load a ``GlmImageForConditionalGeneration`` model onto the local device.

        Args:
            component_model_path: Path handed to ``from_pretrained``.
            server_args: Supplies ``trust_remote_code`` and ``revision``.
            transformers_or_diffusers: Must equal ``"vision_language_encoder"``.

        Raises:
            ValueError: For any other value of ``transformers_or_diffusers``.
        """
        # Reject unsupported selectors up front.
        if transformers_or_diffusers != "vision_language_encoder":
            raise ValueError(
                f"Unsupported library for VisionLanguageEncoder: {transformers_or_diffusers}"
            )

        # Imported lazily, only when this loader is actually used.
        from transformers import GlmImageForConditionalGeneration

        hf_config = get_hf_config(
            component_model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
        )
        encoder = GlmImageForConditionalGeneration.from_pretrained(
            component_model_path,
            config=hf_config,
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
        )
        return encoder.to(get_local_torch_device())
class VocoderLoader(ComponentLoader):
    """Loader for the ``vocoder`` component (diffusers-format config)."""

    component_names = ["vocoder"]
    expected_library = "diffusers"

    def should_offload(
        self, server_args: ServerArgs, model_config: ModelConfig | None = None
    ):
        """Return the ``vae_cpu_offload`` server flag; ``model_config`` is not consulted."""
        return server_args.vae_cpu_offload

    def load_customized(
        self, component_model_path: str, server_args: ServerArgs, component_name: str
    ):
        """Build the vocoder and load its weights from the single safetensors
        file in ``component_model_path``.

        Precision comes from ``pipeline_config.audio_vae_precision`` and falls
        back to ``"fp32"`` when that attribute lookup raises ``AttributeError``.
        Key mismatches are logged, not raised.

        Raises:
            AssertionError: If ``_class_name`` is missing from the config or the
                directory does not hold exactly one safetensors file.
        """
        hf_config = get_diffusers_component_config(component_path=component_model_path)
        class_name = hf_config.pop("_class_name", None)
        assert (
            class_name is not None
        ), "Model config does not contain a _class_name attribute. Only diffusers format is supported."

        server_args.model_paths[component_name] = component_model_path

        from sglang.multimodal_gen.configs.models.vocoder.ltx_vocoder import (
            LTXVocoderConfig,
        )

        arch = LTXVocoderConfig()
        arch.update_model_arch(hf_config)

        try:
            precision_key = server_args.pipeline_config.audio_vae_precision
        except AttributeError:
            precision_key = "fp32"
        torch_dtype = PRECISION_TO_TYPE[precision_key]

        device = self.target_device(self.should_offload(server_args))

        with set_default_torch_dtype(torch_dtype), skip_init_modules():
            vocoder_cls, _ = ModelRegistry.resolve_model_cls(class_name)
            vocoder = vocoder_cls(arch).to(device)

        weight_files = _list_safetensors_files(component_model_path)
        assert (
            len(weight_files) == 1
        ), f"Found {len(weight_files)} safetensors files in {component_model_path}"
        state = safetensors_load_file(weight_files[0])
        load_result = vocoder.load_state_dict(state, strict=False)

        missing_keys = []
        unexpected_keys = []
        try:
            missing_keys = load_result.missing_keys
            unexpected_keys = load_result.unexpected_keys
        except AttributeError:
            # Best-effort fallback in case older torch returns a tuple-like.
            try:
                missing_keys = load_result[0]
                unexpected_keys = load_result[1]
            except Exception:
                pass

        if missing_keys or unexpected_keys:
            logger.warning(
                "Loaded vocoder with missing_keys=%d unexpected_keys=%d",
                len(missing_keys),
                len(unexpected_keys),
            )
        return vocoder
# TODO(PY): add compile option
def maybe_load_fsdp_model(
    model_cls: type[nn.Module],
    init_params: dict[str, Any],
    weight_dir_list: list[str],
    device: torch.device,
    hsdp_replicate_dim: int,
    hsdp_shard_dim: int,
    param_dtype: torch.dtype | None,
    reduce_dtype: torch.dtype,
    cpu_offload: bool = False,
    fsdp_inference: bool = False,
    output_dtype: torch.dtype | None = None,
    pin_cpu_memory: bool = True,
    strict: bool = True,
) -> torch.nn.Module:
    """Load a model with optional FSDP (Fully Sharded Data Parallel) support.

    The model is first built on the ``meta`` device, optionally wrapped with
    FSDP, and then populated from the safetensors files.

    Args:
        model_cls: Model class; instantiated as ``model_cls(**init_params)``.
        init_params: Keyword arguments for the model constructor.
        weight_dir_list: Safetensors files to stream weights from.
        device: Device passed to the state-dict loader.
        hsdp_replicate_dim: Size of the "replicate" mesh dimension.
        hsdp_shard_dim: Size of the "shard" mesh dimension.
        param_dtype: Data type for model parameters, also used for:
            - Model initialization context (set_default_torch_dtype)
            - FSDP mixed precision policy
            - Weight loading and casting
            ``None`` falls back to ``torch.bfloat16`` for the initialization
            context and the mixed-precision policy, while ``None`` itself is
            forwarded to the weight loader.
        reduce_dtype: Data type for gradient reduction in FSDP mixed precision.
        cpu_offload: Forwarded to FSDP sharding and to the weight loader.
        fsdp_inference: Wrap the model with FSDP (ignored on MPS).
        output_dtype: Output dtype of the mixed-precision policy.
        pin_cpu_memory: Forwarded to FSDP sharding.
        strict: If True, enforce strict state dict loading (all keys must match).

    Returns:
        The loaded model with ``requires_grad`` disabled on every parameter.

    Raises:
        RuntimeError: If any parameter or buffer is still on the meta device
            after loading.
    """
    # NOTE(will): cast_forward_inputs=True shouldn't be needed as we are
    # manually casting the inputs to the model
    default_torch_dtype = param_dtype if param_dtype else torch.bfloat16
    mp_policy = MixedPrecisionPolicy(
        default_torch_dtype, reduce_dtype, output_dtype, cast_forward_inputs=False
    )

    set_mixed_precision_policy(
        param_dtype=default_torch_dtype,
        reduce_dtype=reduce_dtype,
        output_dtype=output_dtype,
        mp_policy=mp_policy,
    )

    # Build on the meta device so no real memory is allocated before loading.
    with set_default_torch_dtype(default_torch_dtype), torch.device("meta"):
        model = model_cls(**init_params)

    # Check if we should use FSDP
    use_fsdp = fsdp_inference

    # Disable FSDP for MPS as it's not compatible
    if current_platform.is_mps():
        use_fsdp = False
        logger.info("Disabling FSDP for MPS platform as it's not compatible")

    if use_fsdp:
        # use_fsdp implies fsdp_inference is True, so the mesh dimensions are
        # used exactly as given.
        device_mesh = init_device_mesh(
            current_platform.device_type,
            # (Replicate(), Shard(dim=0))
            mesh_shape=(hsdp_replicate_dim, hsdp_shard_dim),
            mesh_dim_names=("replicate", "shard"),
        )
        shard_model(
            model,
            cpu_offload=cpu_offload,
            reshard_after_forward=True,
            mp_policy=mp_policy,
            mesh=device_mesh,
            fsdp_shard_conditions=model._fsdp_shard_conditions,
            pin_cpu_memory=pin_cpu_memory,
        )

    weight_iterator = safetensors_weights_iterator(weight_dir_list)
    param_names_mapping_fn = get_param_names_mapping(model.param_names_mapping)
    load_model_from_full_model_state_dict(
        model,
        weight_iterator,
        device,
        param_dtype,
        strict=strict,
        cpu_offload=cpu_offload,
        param_names_mapping=param_names_mapping_fn,
    )

    # Give quantization methods a chance to post-process the loaded weights.
    for _, module in model.named_modules():
        quant_method = getattr(module, "quant_method", None)
        if quant_method is not None and hasattr(
            quant_method, "process_weights_after_loading"
        ):
            quant_method.process_weights_after_loading(module)

    for n, p in chain(model.named_parameters(), model.named_buffers()):
        if p.is_meta:
            raise RuntimeError(f"Unexpected param or buffer {n} on meta device.")
        # Avoid unintended computation graph accumulation during inference
        if isinstance(p, torch.nn.Parameter):
            p.requires_grad = False
    return model
No modules will be sharded in %s", + type(model).__name__, + ) + return + + fsdp_kwargs = { + "reshard_after_forward": reshard_after_forward, + "mesh": mesh, + "mp_policy": mp_policy, + } + if cpu_offload: + fsdp_kwargs["offload_policy"] = CPUOffloadPolicy(pin_memory=pin_cpu_memory) + + # iterating in reverse to start with + # lowest-level modules first + num_layers_sharded = 0 + # TODO(will): don't reshard after forward for the last layer to save on the + # all-gather that will immediately happen Shard the model with FSDP, + for n, m in reversed(list(model.named_modules())): + if any([shard_condition(n, m) for shard_condition in fsdp_shard_conditions]): # type: ignore + fully_shard(m, **fsdp_kwargs) + num_layers_sharded += 1 + + if num_layers_sharded == 0: + raise ValueError( + "No layer modules were sharded. Please check if shard conditions are working as expected." + ) + + # Finally shard the entire model to account for any stragglers + fully_shard(model, **fsdp_kwargs) + + +# TODO(PY): device mesh for cfg parallel +def load_model_from_full_model_state_dict( + model: FSDPModule | torch.nn.Module, + full_sd_iterator: Generator[tuple[str, torch.Tensor], None, None], + device: torch.device, + param_dtype: torch.dtype | None, + strict: bool = False, + cpu_offload: bool = False, + param_names_mapping: Callable[[str], tuple[str, Any, Any]] | None = None, +) -> _IncompatibleKeys: + """ + Converting full state dict into a sharded state dict + and loading it into FSDP model (if training) or normal huggingface model + Args: + model (Union[FSDPModule, torch.nn.Module]): Model to generate fully qualified names for cpu_state_dict + full_sd_iterator (Generator): an iterator yielding (param_name, tensor) pairs + device (torch.device): device used to move full state dict tensors + param_dtype (torch.dtype): dtype used to move full state dict tensors. 
If none, respect original dtype from checkpoint + strict (bool): flag to check if to load the model in strict mode + cpu_offload (bool): flag to check if FSDP offload is enabled + param_names_mapping (Optional[Callable[[str], str]]): a function that maps full param name to sharded param name + Returns: + ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: + * **missing_keys** is a list of str containing the missing keys + * **unexpected_keys** is a list of str containing the unexpected keys + + """ + meta_sd = model.state_dict() + param_dict = dict(model.named_parameters()) + + # map names from checkpoint to customized names + custom_param_sd, reverse_param_names_mapping = hf_to_custom_state_dict( + full_sd_iterator, param_names_mapping + ) # type: ignore + + is_fsdp_model = isinstance(model, FSDPModule) or any( + hasattr(p, "device_mesh") for p in meta_sd.values() + ) + + # sort parameter names to ensure all ranks process parameters in the same order + sorted_param_names = sorted(custom_param_sd.keys()) + + sharded_sd = {} + skipped_checkpoint_keys: list[str] = [] + + # shard from loaded state_dict, custom_param_sd -> sharded_sd + for target_param_name in sorted_param_names: + full_tensor = custom_param_sd[target_param_name] + meta_sharded_param = meta_sd.get(target_param_name) + + if meta_sharded_param is None: + # For FSDP models, ensure all ranks process parameters consistently + if strict or is_fsdp_model: + raise ValueError( + f"Parameter {target_param_name} not found in custom model state dict. The hf to custom mapping may be incorrect." + ) + else: + skipped_checkpoint_keys.append(target_param_name) + continue + + # use meta param dtype so quantized params (e.g. 
FP8) keep their dtype; + # for non-quantized models meta dtype equals param_dtype anyway + if meta_sharded_param is None: + # for nunchaku, some scales are patched later + target_dtype = full_tensor.dtype + else: + target_dtype = meta_sharded_param.dtype + + if not hasattr(meta_sharded_param, "device_mesh"): + full_tensor = full_tensor.to(device=device, dtype=target_dtype) + actual_param = param_dict.get(target_param_name) + weight_loader = ( + getattr(actual_param, "weight_loader", None) + if actual_param is not None + else None + ) + if weight_loader is not None: + assert actual_param is not None + sharded_tensor = torch.empty_like( + meta_sharded_param, device=device, dtype=target_dtype + ) + # Preserve requires_grad flag to avoid errors with non-floating dtypes + requires_grad = getattr(meta_sharded_param, "requires_grad", False) + temp_param = _make_param_like(actual_param, sharded_tensor) + if not ( + sharded_tensor.is_floating_point() or sharded_tensor.is_complex() + ): + requires_grad = False + temp_param.requires_grad = requires_grad + weight_loader(temp_param, full_tensor) + sharded_tensor = temp_param.data + else: + # In cases where parts of the model aren't sharded, some parameters will be plain tensors + sharded_tensor = full_tensor + + # Important: `cpu_offload` is intended for FSDP-managed parameter movement. + # If a parameter is not sharded into a DTensor (i.e., no `device_mesh`), FSDP + # will NOT manage it. Offloading it here would leave CPU parameters that + # later participate in GPU kernels (e.g., conv/embedding), causing device/dtype + # mismatches like "Input type (CUDABFloat16Type) and weight type (CPUBFloat16Type)". + # + # Therefore: + # - For non-FSDP models, keep the historical behavior (allow CPU offload). + # - For FSDP models, do NOT offload non-sharded parameters here. 
+ if cpu_offload and not is_fsdp_model: + sharded_tensor = sharded_tensor.cpu() + else: + full_tensor = full_tensor.to(device=device, dtype=target_dtype) + sharded_tensor = distribute_tensor( + full_tensor, + meta_sharded_param.device_mesh, + meta_sharded_param.placements, + ) + if cpu_offload: + sharded_tensor = sharded_tensor.to("cpu") + + requires_grad = False + sharded_sd[target_param_name] = nn.Parameter( + sharded_tensor, requires_grad=requires_grad + ) + + model.reverse_param_names_mapping = reverse_param_names_mapping + + if skipped_checkpoint_keys: + logger.warning( + "Checkpoint keys not loaded (no matching model parameter) %s", + ( + skipped_checkpoint_keys[:20] + if len(skipped_checkpoint_keys) > 20 + else skipped_checkpoint_keys + ), + ) + if len(skipped_checkpoint_keys) > 20: + logger.warning( + "... and %d more skipped keys.", + len(skipped_checkpoint_keys) - 20, + ) + + # parameters in nn.Module that doesn't exist in safetensor files + unused_keys = set(meta_sd.keys()) - set(sharded_sd.keys()) + if unused_keys: + logger.warning("Found unloaded parameters in meta state dict: %s", unused_keys) + + # for nunchaku + ALLOWED_NEW_PARAM_PATTERNS = [ + "gate_compress", + "wcscales", + "wtscale", + "bias", + ] + for new_param_name in unused_keys: + # check unallowed missing params + if not any(pattern in new_param_name for pattern in ALLOWED_NEW_PARAM_PATTERNS): + logger.error( + "Unsupported new parameter: %s. Allowed patterns: %s", + new_param_name, + ALLOWED_NEW_PARAM_PATTERNS, + ) + raise ValueError( + f"New parameter '{new_param_name}' is not supported. " + f"Currently only parameters containing {ALLOWED_NEW_PARAM_PATTERNS} are allowed." 
+ ) + + meta_sharded_param = meta_sd.get(new_param_name) + meta_sharded_param_dtype = meta_sharded_param.dtype + + if "wcscales" in new_param_name or "wtscale" in new_param_name: + init_like = torch.ones_like + else: + init_like = torch.zeros_like + + if not hasattr(meta_sharded_param, "device_mesh"): + sharded_tensor = init_like( + meta_sharded_param, device=device, dtype=meta_sharded_param_dtype + ) + if cpu_offload and not is_fsdp_model: + sharded_tensor = sharded_tensor.cpu() + else: + full_tensor = init_like( + meta_sharded_param, device=device, dtype=meta_sharded_param_dtype + ) + sharded_tensor = distribute_tensor( + full_tensor, + meta_sharded_param.device_mesh, + meta_sharded_param.placements, + ) + if cpu_offload: + sharded_tensor = sharded_tensor.cpu() + sharded_sd[new_param_name] = nn.Parameter(sharded_tensor) + + # choose `assign=True` since we cannot call `copy_` on meta tensor + return model.load_state_dict(sharded_sd, strict=strict, assign=True) diff --git a/sglang/python/sglang/multimodal_gen/runtime/loader/utils.py b/sglang/python/sglang/multimodal_gen/runtime/loader/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b18603b57ff5af25452397241c620956fa9b717c --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/loader/utils.py @@ -0,0 +1,203 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +"""Utilities for selecting and loading models.""" + +import contextlib +import glob +import os +import re +from collections import defaultdict +from collections.abc import Callable, Iterator +from typing import Any, Dict, Type + +import torch +from torch import nn + +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +@contextlib.contextmanager +def set_default_torch_dtype(dtype: torch.dtype): + """Sets the default torch dtype to the given dtype.""" + old_dtype = torch.get_default_dtype() + 
torch.set_default_dtype(dtype) + try: + yield + finally: + torch.set_default_dtype(old_dtype) + + +def get_param_names_mapping( + mapping_dict: dict[str, str | tuple[str, int, int]], +) -> Callable[[str], tuple[str, Any, Any]]: + """ + Creates a mapping function that transforms parameter names using regex patterns. + + Args: + mapping_dict (Dict[str, str]): Dictionary mapping regex patterns to replacement patterns + + Returns: + Callable[[str], str]: A function that maps parameter names from source to target format + """ + + def mapping_fn(name: str) -> tuple[str, Any, Any]: + # support chained conversions, e.g.: + # transformer.xxx.lora_down -> xxx.lora_down -> xxx.proj_down + merge_index = None + total_split_params = None + max_steps = max(8, len(mapping_dict) * 2) + applied_patterns: set[str] = set() + visited_names: set[str] = {name} + + for _ in range(max_steps): + transformed = False + for pattern, replacement in mapping_dict.items(): + # avoid re-applying the same rule on its own output + if pattern in applied_patterns: + continue + if re.match(pattern, name) is None: + continue + + curr_merge_index = None + curr_total_split_params = None + if isinstance(replacement, tuple): + curr_merge_index = replacement[1] + curr_total_split_params = replacement[2] + replacement = replacement[0] + + new_name = re.sub(pattern, replacement, name) + + if new_name != name: + if curr_merge_index is not None: + merge_index = curr_merge_index + total_split_params = curr_total_split_params + + name = new_name + applied_patterns.add(pattern) + if name in visited_names: + transformed = False + break + visited_names.add(name) + transformed = True + break + + if not transformed: + break + + return name, merge_index, total_split_params + + return mapping_fn + + +def hf_to_custom_state_dict( + hf_param_sd: dict[str, torch.Tensor] | Iterator[tuple[str, torch.Tensor]], + param_names_mapping: Callable[[str], tuple[str, Any, Any]], +) -> tuple[dict[str, torch.Tensor], dict[str, tuple[str, 
Any, Any]]]: + """ + Converts a Hugging Face parameter state dictionary to a custom parameter state dictionary. + + Args: + hf_param_sd (Dict[str, torch.Tensor]): The Hugging Face parameter state dictionary + param_names_mapping (Callable[[str], tuple[str, Any, Any]]): A function that maps parameter names from source to target format + + Returns: + custom_param_sd (Dict[str, torch.Tensor]): The custom formatted parameter state dict + reverse_param_names_mapping (Dict[str, Tuple[str, Any, Any]]): Maps back from custom to hf + """ + custom_param_sd = {} + to_merge_params = defaultdict(dict) # type: ignore + reverse_param_names_mapping = {} + if isinstance(hf_param_sd, dict): + hf_param_sd = hf_param_sd.items() # type: ignore + for source_param_name, full_tensor in hf_param_sd: # type: ignore + target_param_name, merge_index, num_params_to_merge = param_names_mapping( + source_param_name + ) + if target_param_name == "" or target_param_name is None: # type: ignore[comparison-overlap] + continue + reverse_param_names_mapping[target_param_name] = ( + source_param_name, + merge_index, + num_params_to_merge, + ) + if merge_index is not None: + to_merge_params[target_param_name][merge_index] = full_tensor + if len(to_merge_params[target_param_name]) == num_params_to_merge: + # cat at output dim according to the merge_index order + sorted_tensors = [ + to_merge_params[target_param_name][i] + for i in range(num_params_to_merge) + ] + full_tensor = torch.cat(sorted_tensors, dim=0) + del to_merge_params[target_param_name] + else: + continue + custom_param_sd[target_param_name] = full_tensor + return custom_param_sd, reverse_param_names_mapping + + +class skip_init_modules: + def __enter__(self): + # Save originals + self._orig_reset = {} + for cls in (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d): + self._orig_reset[cls] = cls.reset_parameters + cls.reset_parameters = lambda self: None # skip init + + def __exit__(self, exc_type, exc_value, traceback): + # restore originals + 
for cls, orig in self._orig_reset.items(): + cls.reset_parameters = orig + + +def _normalize_component_type(module_type: str) -> str: + """Normalize module types like 'text_encoder_2' -> 'text_encoder'.""" + if module_type.endswith("_2"): + return module_type[:-2] + return module_type + + +def _clean_hf_config_inplace(model_config: dict) -> None: + """Remove common extraneous HF fields if present.""" + for key in ( + "_name_or_path", + "transformers_version", + "model_type", + "tokenizer_class", + "torch_dtype", + ): + model_config.pop(key, None) + + +def _list_safetensors_files(model_path: str) -> list[str]: + """List all .safetensors files under a directory.""" + return sorted(glob.glob(os.path.join(str(model_path), "*.safetensors"))) + + +BYTES_PER_GB = 1024**3 + + +def get_memory_usage_of_component(module) -> float | None: + """ + returned value is in GB, rounded to 2 decimal digits + """ + if not isinstance(module, nn.Module): + return None + if hasattr(module, "get_memory_footprint"): + usage = module.get_memory_footprint() / BYTES_PER_GB + else: + # manually + param_size = sum(p.numel() * p.element_size() for p in module.parameters()) + buffer_size = sum(b.numel() * b.element_size() for b in module.buffers()) + + total_size_bytes = param_size + buffer_size + usage = total_size_bytes / (1024**3) + + return round(usage, 2) + + +# component name -> ComponentLoader class +component_name_to_loader_cls: Dict[str, Type[Any]] = {} diff --git a/sglang/python/sglang/multimodal_gen/runtime/loader/weight_utils.py b/sglang/python/sglang/multimodal_gen/runtime/loader/weight_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7507dc10833d666d45e1a3672b41292ec4e2979f --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/loader/weight_utils.py @@ -0,0 +1,359 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vllm: 
https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/model_loader/weight_utils.py +"""Utilities for downloading, loading, initializing and verifying model weights.""" + +import hashlib +import json +import os +import tempfile +from collections.abc import Generator, Iterable +from pathlib import Path + +import filelock +import huggingface_hub.constants +import torch +from safetensors.torch import safe_open +from torch.distributed.tensor import DTensor +from tqdm.auto import tqdm + +try: + from runai_model_streamer import SafetensorsStreamer + + HAS_RUNAI_MODEL_STREAMER = True +except ImportError: + HAS_RUNAI_MODEL_STREAMER = False + +from sglang.multimodal_gen.runtime.distributed import get_local_torch_device +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + +# use system-level temp directory for file locks, so that multiple users +# can share the same lock without error. +# lock files in the temp directory will be automatically deleted when the +# system reboots, so users will not complain about annoying lock files +temp_dir = tempfile.gettempdir() + + +def enable_hf_transfer() -> None: + """automatically activates hf_transfer""" + if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ: + try: + # enable hf hub transfer if available + import hf_transfer # type: ignore # noqa + + huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True + except ImportError: + pass + + +enable_hf_transfer() + + +class DisabledTqdm(tqdm): + + def __init__(self, *args, **kwargs): + kwargs["disable"] = True + super().__init__(*args, **kwargs) + + +def get_lock(model_name_or_path: str | Path, cache_dir: str | None = None): + lock_dir = cache_dir or temp_dir + model_name_or_path = str(model_name_or_path) + os.makedirs(os.path.dirname(lock_dir), exist_ok=True) + model_name = model_name_or_path.replace("/", "-") + hash_name = hashlib.sha256(model_name.encode()).hexdigest() + # add hash to avoid conflict with old users' 
lock files + lock_file_name = hash_name + model_name + ".lock" + # mode 0o666 is required for the filelock to be shared across users + lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666) + return lock + + +# For models like Mistral-7B-v0.3, there are both sharded +# safetensors files and a consolidated safetensors file. +# Passing both of these to the weight loader functionality breaks. +# So, we use the index_file to +# look up which safetensors files should be used. +def filter_duplicate_safetensors_files( + hf_weights_files: list[str], hf_folder: str, index_file: str +) -> list[str]: + # model.safetensors.index.json is a mapping from keys in the + # torch state_dict to safetensors file holding that weight. + index_file_name = os.path.join(hf_folder, index_file) + if not os.path.isfile(index_file_name): + return hf_weights_files + + # Iterate through the weight_map (weight_name: safetensors files) + # to identify weights that we should use. + with open(index_file_name) as f: + weight_map = json.load(f)["weight_map"] + weight_files_in_index = set() + for weight_name in weight_map: + weight_files_in_index.add(os.path.join(hf_folder, weight_map[weight_name])) + # Filter out any fields that are not found in the index file. + hf_weights_files = [f for f in hf_weights_files if f in weight_files_in_index] + return hf_weights_files + + +def filter_files_not_needed_for_inference(hf_weights_files: list[str]) -> list[str]: + """ + Exclude files that are not needed for inference. 
+ + See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233 + """ + blacklist = [ + "training_args.bin", + "optimizer.bin", + "optimizer.pt", + "scheduler.pt", + "scaler.pt", + ] + hf_weights_files = [ + f for f in hf_weights_files if not any(f.endswith(x) for x in blacklist) + ] + return hf_weights_files + + +# explicitly use pure text format, with a newline at the end +# this makes it impossible to see the animation in the progress bar +# but will avoid messing up with ray or multiprocessing, which wraps +# each line of output with some prefix. +_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501 + + +def _validate_safetensors_file(file_path: str) -> bool: + """ + Validate that a safetensors file is readable and not corrupted. + + Args: + file_path: Path to the safetensors file + + Returns: + True if file is valid, False if corrupted + """ + try: + with safe_open(file_path, framework="pt", device="cpu") as f: + _ = list(f.keys()) + return True + except Exception as e: + logger.error( + "Corrupted safetensors file detected: %s - %s: %s", + file_path, + type(e).__name__, + str(e), + ) + return False + + +def safetensors_weights_iterator( + hf_weights_files: list[str], + to_cpu: bool = True, + use_runai_model_streamer: bool = HAS_RUNAI_MODEL_STREAMER, +) -> Generator[tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model safetensor files.""" + enable_tqdm = ( + not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 + ) + device = "cpu" if to_cpu else str(get_local_torch_device()) + + # Validate files before loading + corrupted_files = [ + st_file + for st_file in hf_weights_files + if not _validate_safetensors_file(st_file) + ] + + if corrupted_files: + # Delete corrupted files (both symlink and blob if applicable) + for file_path in corrupted_files: + try: + if os.path.islink(file_path): + blob_path = 
os.path.realpath(file_path) + os.remove(file_path) + logger.info( + "Removed corrupted symlink: %s", os.path.basename(file_path) + ) + if os.path.exists(blob_path): + os.remove(blob_path) + logger.info( + "Removed corrupted blob: %s", os.path.basename(blob_path) + ) + elif os.path.isfile(file_path): + os.remove(file_path) + logger.info( + "Removed corrupted file: %s", os.path.basename(file_path) + ) + except Exception as e: + logger.warning("Failed to remove corrupted file %s: %s", file_path, e) + + raise RuntimeError( + f"Found {len(corrupted_files)} corrupted safetensors file(s). " + f"Files have been removed: {[os.path.basename(f) for f in corrupted_files]}. " + "Please retry - the files will be re-downloaded automatically." + ) + + if use_runai_model_streamer: + with SafetensorsStreamer() as streamer: + streamer.stream_files(hf_weights_files) + for name, tensor in streamer.get_tensors(): + if to_cpu: + yield name, tensor.clone().detach() + else: + yield name, tensor.to(device) + else: + for st_file in tqdm( + hf_weights_files, + desc="Loading safetensors checkpoint shards", + disable=not enable_tqdm, + bar_format=_BAR_FORMAT, + ): + with safe_open(st_file, framework="pt", device=device) as f: + for name in f.keys(): # noqa: SIM118 + param = f.get_tensor(name) + yield name, param + + +def _load_pt_file(bin_file: str, device: str) -> dict: + """Load a PyTorch checkpoint file, handling legacy tar format. + + PyTorch 2.6 changed the default of weights_only from False to True. + Legacy tar format files cannot be loaded with weights_only=True. + This function tries weights_only=True first, then falls back to False + for legacy tar format files from trusted sources (HuggingFace Hub). 
+ """ + try: + return torch.load(bin_file, map_location=device, weights_only=True) + except RuntimeError as e: + if "legacy .tar format" in str(e): + logger.warning( + "Loading %s with weights_only=False (legacy tar format)", + os.path.basename(bin_file), + ) + return torch.load(bin_file, map_location=device, weights_only=False) + raise + + +def pt_weights_iterator( + hf_weights_files: list[str], + to_cpu: bool = True, +) -> Generator[tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model bin/pt files.""" + device = "cpu" if to_cpu else str(get_local_torch_device()) + enable_tqdm = ( + not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 + ) + for bin_file in tqdm( + hf_weights_files, + desc="Loading pt checkpoint shards", + disable=not enable_tqdm, + bar_format=_BAR_FORMAT, + ): + state = _load_pt_file(bin_file, device) + yield from state.items() + del state + + +def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + """Default weight loader.""" + try: + if param.numel() == 1 and loaded_weight.numel() == 1: + # Sometimes scalar values aren't considered tensors with shapes + # so if both param and loaded_weight are a scalar, + # "broadcast" instead of copy + param.data.fill_(loaded_weight.item()) + else: + assert param.size() == loaded_weight.size(), ( + f"Attempted to load weight ({loaded_weight.size()}) " + f"into parameter ({param.size()})" + ) + + param.data.copy_(loaded_weight) + except Exception: + # NOTE: This exception is added for the purpose of setting breakpoint to + # debug weight loading issues. + raise + + +def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None: + """Remap the name of FP8 k/v_scale parameters. + + This function handles the remapping of FP8 k/v_scale parameter names. + It detects if the given name ends with a suffix and attempts to remap + it to the expected name format in the model. 
If the remapped name is not + found in the params_dict, a warning is printed and None is returned. + + Args: + name (str): The original loaded checkpoint parameter name. + params_dict (dict): Dictionary containing the model's named parameters. + + Returns: + str: The remapped parameter name if successful, or the original name + if no remapping is needed. + None: If the remapped name is not found in params_dict. + """ + if name.endswith(".kv_scale"): + logger.warning_once( + "DEPRECATED. Found kv_scale in the checkpoint. " + "This format is deprecated in favor of separate k_scale and " + "v_scale tensors and will be removed in a future release. " + "Functionally, we will remap kv_scale to k_scale and duplicate " + "k_scale to v_scale" + ) + # NOTE: we remap the deprecated kv_scale to k_scale + remapped_name = name.replace(".kv_scale", ".attn.k_scale") + if remapped_name not in params_dict: + logger.warning_once( + f"Found kv_scale in the checkpoint (e.g. {name}), " + "but not found the expected name in the model " + f"(e.g. {remapped_name}). kv_scale is " + "not loaded." + ) + return None + return remapped_name + + possible_scale_names = [".k_scale", ".v_scale"] + modelopt_scale_names = [".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"] + for scale_name in possible_scale_names: + if name.endswith(scale_name): + if any(mo_scale_name in name for mo_scale_name in modelopt_scale_names): + remapped_name = name.replace( + f".self_attn.{scale_name[1]}_proj{scale_name}", + f".self_attn.attn{scale_name}", + ) + else: + remapped_name = name.replace(scale_name, f".attn{scale_name}") + if remapped_name not in params_dict: + logger.warning_once( + f"Found {scale_name} in the checkpoint (e.g. {name}), " + "but not found the expected name in the model " + f"(e.g. {remapped_name}). {scale_name} is " + "not loaded." 
+ ) + return None + return remapped_name + + # If there were no matches, return the untouched param name + return name + + +def compute_weights_checksum( + named_params: Iterable[tuple[str, torch.Tensor]], +) -> str: + """Compute a SHA-256 checksum for a set of (name, tensor) pairs. + + Used to verify the correctness of weight refitting. After a refit, + compare the checksum of the in-GPU model weights against the checksum + of the on-disk tensors or the tensors in the training engine. + """ + hasher = hashlib.sha256() + for name, tensor in sorted(named_params, key=lambda x: x[0]): + hasher.update(name.encode()) + t = tensor.detach() + # DTensor doesn't support .numpy(); extract the local tensor. + if isinstance(t, DTensor): + t = t._local_tensor + hasher.update(t.cpu().contiguous().reshape(-1).view(torch.uint8).numpy().data) + return hasher.hexdigest() diff --git a/sglang/python/sglang/multimodal_gen/runtime/loader/weights_updater.py b/sglang/python/sglang/multimodal_gen/runtime/loader/weights_updater.py new file mode 100644 index 0000000000000000000000000000000000000000..f170809a738e32cf753ec28384a85ab06e290ea3 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/loader/weights_updater.py @@ -0,0 +1,293 @@ +""" +In-place weight updates for diffusion pipeline modules. + +This module provides WeightsUpdater, which swaps model weights at runtime +without restarting the server. It is the diffusion-engine counterpart of the +LLM engine's ModelRunner.update_weights_from_disk. + +Detailed usage of higher level API can be found in + +/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py + +Key design decisions: + +- All-or-nothing with rollback: modules are updated sequentially. If + any module fails (shape mismatch, corrupted file, etc.), every module + that was already updated is rolled back by reloading its weights from + pipeline.model_path (the last successfully-loaded checkpoint). 
On + success, pipeline.model_path is updated to the new model_path so + that future rollbacks target the latest good checkpoint, not the + originally-launched model. + +- Rollback failures propagate: if rollback itself fails, the exception is + not caught so the caller knows the model is in an inconsistent state. + This matches the LLM engine behaviour. + +- Offload-aware: the diffusion LayerwiseOffloadManager replaces GPU + parameters with torch.empty((1,)) placeholders while real weights live + in consolidated pinned CPU buffers. A naive param.data.copy_() would + fail with a shape mismatch. Instead, the updater dynamically detects + active offload managers and writes new weights directly into their CPU + buffers via update_cpu_weights(), bypassing the placeholders entirely. + For any layer that happens to be prefetched on GPU at update time, the + live GPU tensor is also updated so the change takes effect immediately. + This requires no extra GPU memory and does not disturb the offload state. + +- DTensor-aware: parameters that have been distributed via + torch.distributed.tensor are updated through distribute_tensor + so that each shard is correctly placed on the right device mesh. 
+""" + +from __future__ import annotations + +import gc +from pathlib import Path + +import torch +from torch.distributed.tensor import DTensor, distribute_tensor + +from sglang.multimodal_gen.runtime.cache.teacache import TeaCacheMixin +from sglang.multimodal_gen.runtime.loader.utils import ( + _list_safetensors_files, +) +from sglang.multimodal_gen.runtime.loader.weight_utils import ( + safetensors_weights_iterator, +) +from sglang.multimodal_gen.runtime.pipelines.diffusers_pipeline import DiffusersPipeline +from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import maybe_download_model +from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +def get_updatable_modules(pipeline) -> dict[str, torch.nn.Module]: + """Return updatable nn.Module components for the given pipeline. + + Works with both the native ComposedPipelineBase backend and the + DiffusersPipeline wrapper. + """ + if isinstance(pipeline, DiffusersPipeline): + diffusers_pipe = pipeline.get_module("diffusers_pipeline") + if diffusers_pipe is not None and diffusers_pipe.components is not None: + raw = diffusers_pipe.components + else: + raw = {} + else: + raw = pipeline.modules + return {n: m for n, m in raw.items() if isinstance(m, torch.nn.Module)} + + +def _get_weights_iter(weights_dir: str): + """Return a (name, tensor) iterator over safetensors in weights_dir.""" + safetensors_files = _list_safetensors_files(weights_dir) + if not safetensors_files: + raise FileNotFoundError(f"No safetensors files found in {weights_dir}") + return safetensors_weights_iterator(safetensors_files) + + +def _validate_weight_files( + local_model_path: str, + modules_to_update: list[tuple[str, torch.nn.Module]], +) -> tuple[dict[str, str], list[str]]: + """Check that every module has a weights directory with safetensors files. 
+ + Returns: + (weights_map, missing) where weights_map maps module name to its + weights directory and missing lists modules without weight files. + """ + weights_map: dict[str, str] = {} + missing: list[str] = [] + for module_name, _ in modules_to_update: + weights_dir = Path(local_model_path) / module_name + if weights_dir.exists() and _list_safetensors_files(str(weights_dir)): + weights_map[module_name] = str(weights_dir) + else: + missing.append(module_name) + return weights_map, missing + + +def _load_weights_into_module(module: torch.nn.Module, weights_iter) -> None: + """Load weights into a module, handling offload-managed parameters. + + For offloaded modules, updates CPU buffers directly via + update_cpu_weights(); non-offloaded parameters use in-place copy. + """ + offload_managers: list = [] + if isinstance(module, OffloadableDiTMixin) and module.layerwise_offload_managers: + offload_managers = [m for m in module.layerwise_offload_managers if m.enabled] + + if offload_managers: + weight_dict = dict(weights_iter) + offloaded_names: set[str] = set() + for manager in offload_managers: + offloaded_names.update(manager.update_cpu_weights(weight_dict)) + remaining = ((n, w) for n, w in weight_dict.items() if n not in offloaded_names) + load_weights_into_model(remaining, dict(module.named_parameters())) + else: + load_weights_into_model(weights_iter, dict(module.named_parameters())) + + +def load_weights_into_model(weights_iter, model_params: dict) -> None: + """Copy weights from weights_iter into model_params in-place.""" + for name, loaded_weight in weights_iter: + if name not in model_params: + continue + param = model_params[name] + if param.shape != loaded_weight.shape: + raise ValueError( + f"Shape mismatch for {name}: model={param.shape}, loaded={loaded_weight.shape}" + ) + if isinstance(param, DTensor): + distributed_weight = distribute_tensor( + loaded_weight.to(param.dtype), + param.device_mesh, + param.placements, + ) + 
param._local_tensor.copy_(distributed_weight._local_tensor) + else: + param.data.copy_(loaded_weight.to(param.dtype)) + + +class WeightsUpdater: + """In-place weight updates for diffusion pipeline modules. + + Args: + pipeline: A ComposedPipelineBase (or DiffusersPipeline) instance + whose modules will be updated. The pipeline's model_path + attribute is used for rollback on failure. + """ + + def __init__(self, pipeline): + self.pipeline = pipeline + + def update_weights_from_disk( + self, + model_path: str, + flush_cache: bool = True, + target_modules: list[str] | None = None, + ) -> tuple[bool, str]: + """Update model weights from disk without restarting the server.""" + logger.info(f"Updating weights from disk: {model_path}") + + try: + modules_to_update = self._collect_modules(target_modules) + except ValueError as e: + logger.error(str(e)) + return False, str(e) + + if not modules_to_update: + error_msg = ( + f"No matching modules found for update. " + f"Requested: {target_modules}. " + f"Available nn.Module(s): {list(get_updatable_modules(self.pipeline).keys())}" + ) + logger.error(error_msg) + return False, error_msg + + try: + local_model_path = maybe_download_model(model_path) + except Exception as e: + return False, f"Failed to download model: {e}" + + weights_map, missing = _validate_weight_files( + local_model_path, modules_to_update + ) + if missing: + error_msg = ( + f"Cannot update weights: missing weight files for modules: {missing}. " + f"No partial updates allowed." 
+ ) + logger.error(error_msg) + return False, error_msg + + logger.info( + f"Updating {len(weights_map)} modules: " + + ", ".join(f"{n} <- {p}" for n, p in weights_map.items()) + ) + + success, message = self._apply_weights(modules_to_update, weights_map) + + gc.collect() + torch.cuda.empty_cache() + + if success and flush_cache: + for _, module in modules_to_update: + if isinstance(module, TeaCacheMixin): + module.reset_teacache_state() + + logger.info(message) + return success, message + + def _collect_modules( + self, target_modules: list[str] | None + ) -> list[tuple[str, torch.nn.Module]]: + """Resolve target_modules to (name, module) pairs. + + Raises: + ValueError: If target_modules contains names not found in the pipeline. + """ + components = get_updatable_modules(self.pipeline) + + if target_modules is None: + names = list(components.keys()) + else: + unknown = [n for n in target_modules if n not in components] + if unknown: + raise ValueError( + f"Module(s) requested for update not found in pipeline: {unknown}. " + f"Available Module(s): {list(components.keys())}" + ) + names = target_modules + + return [(name, components[name]) for name in names] + + def _apply_weights( + self, + modules_to_update: list[tuple[str, torch.nn.Module]], + weights_map: dict[str, str], + ) -> tuple[bool, str]: + """Load weights into each module; rollback on first failure.""" + updated_modules: list[str] = [] + + for module_name, module in modules_to_update: + try: + weights_iter = _get_weights_iter(weights_map[module_name]) + _load_weights_into_module(module, weights_iter) + updated_modules.append(module_name) + except Exception as e: + rollback_list = updated_modules + [module_name] + logger.error( + f"Weight update failed for module '{module_name}': {e}. 
" + f"Rolling back {len(rollback_list)} module(s) " + f"(including partially-loaded '{module_name}'): " + f"{rollback_list}.", + exc_info=True, + ) + self._rollback(rollback_list) + return False, ( + f"Failed to update module '{module_name}': {e}. " + f"All modules rolled back to original weights." + ) + + names = ", ".join(updated_modules) + return True, f"Updated {len(updated_modules)} modules ({names})." + + def _rollback(self, updated_modules: list[str]) -> None: + """Restore updated_modules to original weights. + + If rollback itself fails the exception propagates so the caller + knows the model is in an inconsistent state. + """ + if not updated_modules: + return + original_path = maybe_download_model(self.pipeline.model_path) + for name in updated_modules: + module = self.pipeline.get_module(name) + if module is None: + continue + weights_dir = Path(original_path) / name + if not weights_dir.exists(): + continue + weights_iter = _get_weights_iter(str(weights_dir)) + _load_weights_into_module(module, weights_iter) diff --git a/sglang/python/sglang/multimodal_gen/runtime/managers/forward_context.py b/sglang/python/sglang/multimodal_gen/runtime/managers/forward_context.py new file mode 100644 index 0000000000000000000000000000000000000000..e506929c6fdbd7c2ebed9ea73c136080478308bb --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/managers/forward_context.py @@ -0,0 +1,120 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/forward_context.py +import time +from collections import defaultdict +from contextlib import contextmanager +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional, Type + +import torch + +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +if TYPE_CHECKING: + from sglang.multimodal_gen.runtime.layers.attention import AttentionMetadata + from 
sglang.multimodal_gen.runtime.pipelines_core import Req + +logger = init_logger(__name__) + +# TODO(will): check if this is needed +# track_batchsize: bool = envs.SGLANG_DIFFUSION_LOG_BATCHSIZE_INTERVAL >= 0 +track_batchsize: bool = False +last_logging_time: float = 0 +forward_start_time: float = 0 +# batchsize_logging_interval: float = envs.SGLANG_DIFFUSION_LOG_BATCHSIZE_INTERVAL +batchsize_logging_interval: float = 1000 +batchsize_forward_time: defaultdict = defaultdict(list) + + +@dataclass +class ForwardContext: + current_timestep: int + # TODO(will): check this arg + # copy from vllm_config.compilation_config.static_forward_context + # attn_layers: Dict[str, Any] + # TODO: extend to support per-layer dynamic forward context + attn_metadata: "AttentionMetadata" # set dynamically for each forward pass + forward_batch: Optional["Req"] = None + attention_backend_cls: Optional[Type] = None + + def set_attn_backend_cls(self, attention_backend_cls: Type): + if self.attention_backend_cls: + if self.attention_backend_cls != attention_backend_cls: + raise RuntimeError( + f"Different types of attention backend in a same context detected, previous: {self.attention_backend_cls}, new: {attention_backend_cls}" + ) + else: + self.attention_backend_cls = attention_backend_cls + + +_forward_context: Optional["ForwardContext"] = None + + +def get_forward_context() -> "ForwardContext": + """Get the current forward context.""" + assert _forward_context is not None, ( + "Forward context is not set. " + "Please use `set_forward_context` to set the forward context." + ) + return _forward_context + + +# TODO(will): finalize the interface +@contextmanager +def set_forward_context( + current_timestep, attn_metadata, forward_batch: Optional["Req"] = None +): + """A context manager that stores the current forward context, + can be attention metadata, etc. + Here we can inject common logic for every model forward pass. 
+ """ + global forward_start_time + need_to_track_batchsize = track_batchsize and attn_metadata is not None + if need_to_track_batchsize: + forward_start_time = time.perf_counter() + global _forward_context + prev_context = _forward_context + _forward_context = ForwardContext( + current_timestep=current_timestep, + attn_metadata=attn_metadata, + forward_batch=forward_batch, + ) + + try: + yield + finally: + global last_logging_time, batchsize_logging_interval + if need_to_track_batchsize: + if hasattr(attn_metadata, "num_prefill_tokens"): + # for v0 attention backends + batchsize = ( + attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens + ) + else: + # for v1 attention backends + batchsize = attn_metadata.num_input_tokens + now = time.perf_counter() + # time measurement is in milliseconds + batchsize_forward_time[batchsize].append((now - forward_start_time) * 1000) + if now - last_logging_time > batchsize_logging_interval: + last_logging_time = now + forward_stats = [] + for bs, times in batchsize_forward_time.items(): + if len(times) <= 1: + # can be cudagraph / profiling run + continue + medium = torch.quantile(torch.tensor(times), q=0.5).item() + medium = round(medium, 2) + forward_stats.append((bs, len(times), medium)) + forward_stats.sort(key=lambda x: x[1], reverse=True) + if forward_stats: + logger.info( + ( + "Batchsize forward time stats " + "(batchsize, count, median_time(ms)): %s" + ), + forward_stats, + ) + _forward_context = prev_context diff --git a/sglang/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py b/sglang/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..7b357330fc4a01a9e25f6f5f257334b57ba08f9c --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py @@ -0,0 +1,514 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +import gc +import multiprocessing 
as mp +import os +import time +from typing import List, Union + +import torch +from setproctitle import setproctitle + +from sglang.multimodal_gen import envs +from sglang.multimodal_gen.runtime.distributed import ( + get_sp_group, + get_tp_rank, + get_tp_world_size, + maybe_init_distributed_environment_and_model_parallel, + model_parallel_is_initialized, +) +from sglang.multimodal_gen.runtime.distributed.parallel_state import ( + get_cfg_group, + get_classifier_free_guidance_rank, + get_classifier_free_guidance_world_size, + get_ring_parallel_rank, + get_ring_parallel_world_size, + get_tp_group, + get_ulysses_parallel_rank, + get_ulysses_parallel_world_size, +) +from sglang.multimodal_gen.runtime.entrypoints.utils import save_outputs +from sglang.multimodal_gen.runtime.loader.weight_utils import compute_weights_checksum +from sglang.multimodal_gen.runtime.loader.weights_updater import ( + WeightsUpdater, + get_updatable_modules, +) +from sglang.multimodal_gen.runtime.pipelines_core import ( + ComposedPipelineBase, + LoRAPipeline, + Req, + build_pipeline, +) +from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch +from sglang.multimodal_gen.runtime.platforms import current_platform +from sglang.multimodal_gen.runtime.server_args import PortArgs, ServerArgs +from sglang.multimodal_gen.runtime.utils.common import set_cuda_arch, set_musa_arch +from sglang.multimodal_gen.runtime.utils.layerwise_offload import ( + OffloadableDiTMixin, + iter_materialized_weights, +) +from sglang.multimodal_gen.runtime.utils.logging_utils import ( + configure_logger, + globally_suppress_loggers, + init_logger, +) +from sglang.multimodal_gen.runtime.utils.perf_logger import ( + PerformanceLogger, + capture_memory_snapshot, +) + +logger = init_logger(__name__) + + +class GPUWorker: + """ + A worker that executes the model on a single GPU. 
+ """ + + def __init__( + self, + local_rank: int, + rank: int, + master_port: int, + server_args: ServerArgs, + ): + self.local_rank = local_rank + self.rank = rank + self.master_port = master_port + # FIXME: should we use tcp as distribute init method? + self.server_args = server_args + self.pipeline: ComposedPipelineBase = None + + self.init_device_and_model() + self.sp_group = get_sp_group() + self.sp_cpu_group = self.sp_group.cpu_group + self.tp_group = get_tp_group() + self.tp_cpu_group = self.tp_group.cpu_group + + self.cfg_group = get_cfg_group() + self.cfg_cpu_group = self.cfg_group.cpu_group + + def init_device_and_model(self) -> None: + """Initialize the device and load the model.""" + torch.get_device_module().set_device(self.local_rank) + # Set environment variables for distributed initialization + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(self.master_port) + os.environ["LOCAL_RANK"] = str(self.local_rank) + os.environ["RANK"] = str(self.rank) + os.environ["WORLD_SIZE"] = str(self.server_args.num_gpus) + # initialize the distributed environment + maybe_init_distributed_environment_and_model_parallel( + tp_size=self.server_args.tp_size, + enable_cfg_parallel=self.server_args.enable_cfg_parallel, + ulysses_degree=self.server_args.ulysses_degree, + ring_degree=self.server_args.ring_degree, + sp_size=self.server_args.sp_degree, + dp_size=self.server_args.dp_size, + distributed_init_method=f"tcp://127.0.0.1:{self.master_port}", + dist_timeout=self.server_args.dist_timeout, + ) + + # set proc title + if model_parallel_is_initialized(): + suffix = "" + if get_tp_world_size() != 1: + tp_rank = get_tp_rank() + suffix += f"_TP{tp_rank}" + if get_ulysses_parallel_world_size() != 1: + u_rank = get_ulysses_parallel_rank() + suffix += f"_U{u_rank}" + if get_ring_parallel_world_size() != 1: + r_rank = get_ring_parallel_rank() + suffix += f"_R{r_rank}" + if get_classifier_free_guidance_world_size() != 1: + c_rank = 
get_classifier_free_guidance_rank() + suffix += f"_C{c_rank}" + setproctitle(f"sgl_diffusion::scheduler{suffix}") + else: + setproctitle(f"sgl_diffusion::scheduler_{self.local_rank}") + + self.pipeline = build_pipeline(self.server_args) + + # apply layerwise offload after lora is applied while building LoRAPipeline + # otherwise empty offloaded weights could fail lora converting + if self.server_args.dit_layerwise_offload: + # enable layerwise offload if possible + for module_name in [ + "transformer", + "transformer_2", + "video_dit", + "video_dit_2", + "audio_dit", + ]: + dit = self.pipeline.get_module(module_name) + if dit: + if isinstance(dit, OffloadableDiTMixin): + dit.configure_layerwise_offload(self.server_args) + else: + logger.info( + f"Module {type(dit).__name__} does not support layerwise offload. Skipping." + ) + + logger.info( + f"Worker {self.rank}: Initialized device, model, and distributed environment." + ) + + def do_mem_analysis(self, output_batch: OutputBatch): + final_snapshot = capture_memory_snapshot() + if output_batch.metrics: + output_batch.metrics.record_memory_snapshot("mem_analysis", final_snapshot) + + # for details on max_memory_reserved: https://docs.pytorch.org/docs/stable/generated/torch.cuda.memory.max_memory_reserved.html + peak_reserved_bytes = torch.get_device_module().max_memory_reserved() + peak_allocated_bytes = torch.get_device_module().max_memory_allocated() + + output_batch.peak_memory_mb = peak_reserved_bytes / (1024**2) + peak_reserved_gb = peak_reserved_bytes / (1024**3) + peak_allocated_gb = peak_allocated_bytes / (1024**3) + + remaining_gpu_mem_gb = ( + current_platform.get_device_total_memory() / (1024**3) - peak_reserved_gb + ) + can_stay_resident = self.get_can_stay_resident_components(remaining_gpu_mem_gb) + suggested_args = set() + component_to_arg = { + "vae": "--vae-cpu-offload", + "text_encoder": "--text-encoder-cpu-offload", + "text_encoder_2": "--text-encoder-cpu-offload", + "image_encoder": 
"--image-encoder-cpu-offload", + } + + for component in can_stay_resident: + if component == "transformer": + if self.server_args.dit_layerwise_offload: + suggested_args.add("--dit-layerwise-offload") + elif self.server_args.dit_cpu_offload: + suggested_args.add("--dit-cpu-offload") + elif component in component_to_arg: + suggested_args.add(component_to_arg[component]) + + suggested_args_str = ( + ", ".join(sorted(suggested_args)) if suggested_args else "None" + ) + + pool_overhead_gb = peak_reserved_gb - peak_allocated_gb + + logger.info( + f"Peak GPU memory: {peak_reserved_gb:.2f} GB, " + f"Peak allocated: {peak_allocated_gb:.2f} GB, " + f"Memory pool overhead: {pool_overhead_gb:.2f} GB ({pool_overhead_gb / peak_reserved_gb * 100:.1f}%), " + f"Remaining GPU memory at peak: {remaining_gpu_mem_gb:.2f} GB. " + f"Components that could stay resident (based on the last request workload): {can_stay_resident}. " + f"Related offload server args to disable: {suggested_args_str}" + ) + + def execute_forward(self, batch: List[Req]) -> OutputBatch: + """ + Execute a forward pass. 
+ """ + assert self.pipeline is not None + req = batch[0] + output_batch = None + try: + if self.rank == 0: + torch.get_device_module().reset_peak_memory_stats() + + start_time = time.monotonic() + + # capture memory baseline before forward + if self.rank == 0 and req.metrics: + baseline_snapshot = capture_memory_snapshot() + req.metrics.record_memory_snapshot("before_forward", baseline_snapshot) + + req.log(server_args=self.server_args) + result = self.pipeline.forward(req, self.server_args) + + if isinstance(result, Req): + output_batch = OutputBatch( + output=result.output, + audio=getattr(result, "audio", None), + audio_sample_rate=getattr(result, "audio_sample_rate", None), + metrics=result.metrics, + trajectory_timesteps=getattr(result, "trajectory_timesteps", None), + trajectory_latents=getattr(result, "trajectory_latents", None), + noise_pred=getattr(result, "noise_pred", None), + trajectory_decoded=getattr(result, "trajectory_decoded", None), + ) + else: + output_batch = result + + # capture memory after forward (peak) + if self.rank == 0 and output_batch.metrics: + peak_snapshot = capture_memory_snapshot() + output_batch.metrics.record_memory_snapshot( + "after_forward", peak_snapshot + ) + + if self.rank == 0 and not req.suppress_logs: + self.do_mem_analysis(output_batch) + + duration_ms = (time.monotonic() - start_time) * 1000 + output_batch.metrics.total_duration_ms = duration_ms + + # Save output to file and return file path only if requested. Avoid the serialization + # and deserialization overhead between scheduler_client and gpu_worker. 
+ if req.save_output and req.return_file_paths_only and self.rank == 0: + if output_batch.output is not None: + output_paths = save_outputs( + output_batch.output, + req.data_type, + req.fps, + True, + lambda idx: req.output_file_path(len(output_batch.output), idx), + audio=output_batch.audio, + audio_sample_rate=output_batch.audio_sample_rate, + output_compression=req.output_compression, + enable_frame_interpolation=req.enable_frame_interpolation, + frame_interpolation_exp=req.frame_interpolation_exp, + frame_interpolation_scale=req.frame_interpolation_scale, + frame_interpolation_model_path=req.frame_interpolation_model_path, + ) + output_batch.output_file_paths = output_paths + output_batch.output = None + + # TODO: extract to avoid duplication + if req.perf_dump_path is not None or envs.SGLANG_DIFFUSION_STAGE_LOGGING: + # Avoid logging warmup perf records that share the same request_id. + if not req.is_warmup: + PerformanceLogger.log_request_summary(metrics=output_batch.metrics) + except Exception as e: + logger.error( + f"Error executing request {req.request_id}: {e}", exc_info=True + ) + if isinstance(e, _oom_exceptions()): + logger.warning(OOM_MSG) + if output_batch is None: + output_batch = OutputBatch() + output_batch.error = f"Error executing request {req.request_id}: {e}" + return output_batch + + def get_can_stay_resident_components( + self, remaining_gpu_mem_gb: float + ) -> List[str]: + """ + Calculate which components can stay resident on GPU without being offloaded. + """ + can_stay_resident = [] + if not self.pipeline: + return can_stay_resident + + # Map memory_usage keys to server_args offload flags + # If the flag is False, the component is ALREADY resident, so we don't suggest it. + # If the flag is True, it is currently offloaded, so it's a candidate to "stay resident". 
+ offload_flags = { + "transformer": self.server_args.dit_cpu_offload + or self.server_args.dit_layerwise_offload, + "vae": self.server_args.vae_cpu_offload, + "text_encoder": self.server_args.text_encoder_cpu_offload, + "text_encoder_2": self.server_args.text_encoder_cpu_offload, + "image_encoder": self.server_args.image_encoder_cpu_offload, + } + + for name, usage in self.pipeline.memory_usages.items(): + # Only consider components that are currently configured to be offloaded + is_offload_configured = offload_flags.get(name, False) + if not is_offload_configured: + continue + + if usage <= remaining_gpu_mem_gb: + can_stay_resident.append(name) + remaining_gpu_mem_gb -= usage + + return can_stay_resident + + def set_lora( + self, + lora_nickname: Union[str, List[str]], + lora_path: Union[str, None, List[Union[str, None]]] = None, + target: Union[str, List[str]] = "all", + strength: Union[float, List[float]] = 1.0, + ) -> OutputBatch: + """ + Set the LoRA adapter(s) for the pipeline. + Supports both single LoRA (backward compatible) and multiple LoRA adapters. + + Args: + lora_nickname: The nickname(s) of the adapter(s). Can be a string or a list of strings. + lora_path: Path(s) to the LoRA adapter(s). Can be a string, None, or a list of strings/None. + target: Which transformer(s) to apply the LoRA to. Can be a string or a list of strings. + strength: LoRA strength(s) for merge, default 1.0. Can be a float or a list of floats. + """ + if not isinstance(self.pipeline, LoRAPipeline): + return OutputBatch(error="Lora is not enabled") + self.pipeline.set_lora(lora_nickname, lora_path, target, strength) + return OutputBatch() + + def merge_lora_weights( + self, target: str = "all", strength: float = 1.0 + ) -> OutputBatch: + """ + Merge LoRA weights. + + Args: + target: Which transformer(s) to merge. + strength: LoRA strength for merge, default 1.0. 
+ """ + if not isinstance(self.pipeline, LoRAPipeline): + return OutputBatch(error="Lora is not enabled") + self.pipeline.merge_lora_weights(target, strength) + return OutputBatch() + + def unmerge_lora_weights(self, target: str = "all") -> OutputBatch: + """ + Unmerge LoRA weights. + + Args: + target: Which transformer(s) to unmerge. + """ + if not isinstance(self.pipeline, LoRAPipeline): + return OutputBatch(error="Lora is not enabled") + self.pipeline.unmerge_lora_weights(target) + return OutputBatch() + + def list_loras(self) -> OutputBatch: + """ + List loaded LoRA adapters and current application status per module. + """ + from sglang.multimodal_gen.runtime.pipelines_core.lora_pipeline import ( + LoRAPipeline, + ) + + if not isinstance(self.pipeline, LoRAPipeline): + return OutputBatch(error="Lora is not enabled") + status = self.pipeline.get_lora_status() + return OutputBatch(output=status) + + def update_weights_from_disk( + self, + model_path: str, + flush_cache: bool = True, + target_modules: list[str] | None = None, + ) -> tuple[bool, str]: + """Update model weights from disk inplace without restarting the server.""" + if not self.pipeline: + return False, "Pipeline is not initialized" + + updater = WeightsUpdater(self.pipeline) + success, message = updater.update_weights_from_disk( + model_path, + flush_cache=flush_cache, + target_modules=target_modules, + ) + if success: + self.server_args.model_path = model_path + self.pipeline.model_path = model_path + return success, message + + def get_weights_checksum( + self, module_names: list[str] | None = None + ) -> dict[str, str]: + """Compute SHA-256 checksum of each module's weights.""" + if not self.pipeline: + return {"error": "Pipeline is not initialized"} + + all_modules = get_updatable_modules(self.pipeline) + names = module_names if module_names is not None else list(all_modules.keys()) + + checksums: dict[str, str] = {} + for name in names: + module = all_modules.get(name) + if module is None: + 
checksums[name] = "not_found" + continue + checksums[name] = compute_weights_checksum( + iter_materialized_weights(module) + ) + return checksums + + +OOM_MSG = f""" +OOM detected. Possible solutions: + - If the OOM occurs during loading: + 1. Enable CPU offload for memory-intensive components, or use `--dit-layerwise-offload` for DiT + - If the OOM occurs during runtime: + 1. Enable SP and/or TP (in a multi-GPU setup) + 2. Reduce the number of output tokens by lowering resolution or decreasing `--num-frames` + 3. Opt for a sparse-attention backend + 4. Enable FSDP by `--use-fsdp-inference` (in a multi-GPU setup) + 5. Enable quantization (e.g. nunchaku) + Or, open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose +""" + + +def _oom_exceptions(): + # torch.OutOfMemoryError exists only in some PyTorch builds + types = [torch.cuda.OutOfMemoryError] + if hasattr(torch, "OutOfMemoryError"): + types.append(torch.OutOfMemoryError) + return tuple(types) + + +def run_scheduler_process( + local_rank: int, + rank: int, + master_port: int, + server_args: ServerArgs, + pipe_writer: mp.connection.Connection, + # For all workers: pipe to receive tasks from rank 0 + task_pipe_r: mp.connection.Connection, + # For slave workers: pipe to send results back to rank 0 + result_pipe_w: mp.connection.Connection | None, + # For rank 0 worker only: pipes to send tasks to slaves + task_pipes_to_slaves: list[mp.connection.Connection] | None = None, + # For rank 0 worker only: pipes to receive results from slaves + result_pipes_from_slaves: list[mp.connection.Connection] | None = None, +) -> None: + """ + The entry point for the worker process. + Rank 0 acts as the master, handling ZMQ requests and coordinating slaves. + Ranks > 0 act as slaves, waiting for tasks from the master. 
+ """ + configure_logger(server_args) + globally_suppress_loggers() + if current_platform.is_cuda(): + set_cuda_arch() + elif current_platform.is_musa(): + set_musa_arch() + + port_args = PortArgs.from_server_args(server_args) + + # start the scheduler event loop + assert task_pipes_to_slaves is not None + assert result_pipes_from_slaves is not None + from sglang.multimodal_gen.runtime.managers.scheduler import Scheduler + + try: + scheduler = Scheduler( + server_args, + gpu_id=rank, + port_args=port_args, + task_pipes_to_slaves=task_pipes_to_slaves, + result_pipes_from_slaves=result_pipes_from_slaves, + ) + logger.info(f"Worker {rank}: Scheduler loop started.") + pipe_writer.send( + { + "status": "ready", + } + ) + scheduler.event_loop() + except _oom_exceptions() as _e: + logger.warning(OOM_MSG) + raise + finally: + # Clean up resources to speed up shutdown + if "scheduler" in locals(): + del scheduler + gc.collect() + if torch.cuda.is_initialized(): + torch.cuda.empty_cache() + if torch.distributed.is_available() and torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + logger.info(f"Worker {rank}: Shutdown complete.") diff --git a/sglang/python/sglang/multimodal_gen/runtime/managers/scheduler.py b/sglang/python/sglang/multimodal_gen/runtime/managers/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..c11c2c8502243d97d2390c367f1e3d57826e2675 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/managers/scheduler.py @@ -0,0 +1,435 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +import asyncio +import os +import pickle +from collections import deque +from typing import Any, List + +import zmq + +from sglang.multimodal_gen.configs.pipeline_configs.base import ModelTaskType +from sglang.multimodal_gen.runtime.entrypoints.openai.utils import ( + _parse_size, + save_image_to_path, +) +from 
class Scheduler:
    """
    Runs the main event loop for the rank 0 worker.
    It listens for external requests via ZMQ and coordinates with other workers.
    This class does NOT manage worker processes.

    Only the instance created with ``gpu_id == 0`` binds the ZMQ ROUTER socket;
    every other instance keeps ``receiver = None`` and receives its requests
    through ``broadcast_pyobj`` inside :meth:`recv_reqs`.
    """

    def __init__(
        self,
        server_args: ServerArgs,
        gpu_id: int,
        port_args: PortArgs,
        task_pipes_to_slaves: list = None,
        result_pipes_from_slaves: list = None,
    ):
        """
        Args:
            server_args: Server configuration; also installed as the global server args.
            gpu_id: Used as both local rank and rank of the owned ``GPUWorker``.
            port_args: Supplies ``master_port`` for the worker.
            task_pipes_to_slaves: Optional pipes used by ``_broadcast_task``.
            result_pipes_from_slaves: Optional pipes used by ``_collect_slave_results``.
        """
        self.server_args = server_args
        self.port_args = port_args

        set_global_server_args(server_args=server_args)

        # Inter-process Communication
        self.context = zmq.Context(io_threads=2)
        endpoint = server_args.scheduler_endpoint
        if gpu_id == 0:
            # router allocates identify (envelope) for each connection
            self.receiver, actual_endpoint = get_zmq_socket(
                self.context, zmq.ROUTER, endpoint, True
            )
            logger.info(f"Scheduler bind at endpoint: {actual_endpoint}")
        else:
            self.receiver = None

        worker = GPUWorker(
            local_rank=gpu_id,
            master_port=port_args.master_port,
            rank=gpu_id,
            server_args=server_args,
        )
        self.worker = worker
        self.task_pipes_to_slaves = task_pipes_to_slaves
        self.result_pipes_from_slaves = result_pipes_from_slaves
        self.gpu_id = gpu_id
        self._running = True

        # Dispatch table keyed by the concrete type of the first request of a batch.
        # NOTE: the ``List[Req]`` key is never matched by ``type(req)`` lookups; it is
        # kept only so that code inspecting this table keeps working.
        self.request_handlers = {
            SetLoraReq: self._handle_set_lora,
            MergeLoraWeightsReq: self._handle_merge_lora,
            UnmergeLoraWeightsReq: self._handle_unmerge_lora,
            Req: self._handle_generation,
            List[Req]: self._handle_generation,
            ListLorasReq: self._handle_list_loras,
            ShutdownReq: self._handle_shutdown,
            UpdateWeightFromDiskReqInput: self._handle_update_weights_from_disk,
            GetWeightsChecksumReqInput: self._handle_get_weights_checksum,
        }

        # FIFO, new reqs are appended. Warmup entries carry ``None`` as identity.
        self.waiting_queue: deque[tuple[bytes | None, Any]] = deque()

        # whether we've sent the necessary warmup reqs
        self.warmed_up = False
        # warmup progress tracking
        self._warmup_total = 0
        self._warmup_processed = 0

        self.prepare_server_warmup_reqs()

        # Maximum consecutive errors before terminating the event loop
        self._max_consecutive_errors = 3
        self._consecutive_error_count = 0

    def _handle_set_lora(self, reqs: List[Any]) -> OutputBatch:
        """Forward the first request's LoRA settings to the worker."""
        # TODO: return set status
        # TODO: return with SetLoRAResponse or something more appropriate
        req = reqs[0]
        return self.worker.set_lora(
            req.lora_nickname, req.lora_path, req.target, req.strength
        )

    def _handle_merge_lora(self, reqs: List[Any]):
        """Ask the worker to merge LoRA weights for the first request."""
        req = reqs[0]
        return self.worker.merge_lora_weights(req.target, req.strength)

    def _handle_unmerge_lora(self, reqs: List[Any]) -> OutputBatch:
        """Ask the worker to unmerge LoRA weights for the first request."""
        req = reqs[0]
        return self.worker.unmerge_lora_weights(req.target)

    def _handle_list_loras(self, _reqs: List[Any]) -> OutputBatch:
        """Return whatever the worker reports from ``list_loras``."""
        return self.worker.list_loras()

    def _handle_shutdown(self, _reqs: List[Any]) -> OutputBatch:
        """Stop the event loop after the current iteration; replies with an empty batch."""
        self._running = False
        return OutputBatch()

    def _handle_update_weights_from_disk(self, reqs: List[Any]) -> OutputBatch:
        """Handle update_weights_from_disk request for RL workflows."""
        req = reqs[0]
        success, message = self.worker.update_weights_from_disk(
            model_path=req.model_path,
            flush_cache=req.flush_cache,
            target_modules=req.target_modules,
        )
        return OutputBatch(
            output={"success": success, "message": message},
            error=None if success else message,
        )

    def _handle_get_weights_checksum(self, reqs: List[Any]) -> OutputBatch:
        """Handle get_weights_checksum request."""
        req = reqs[0]
        checksums = self.worker.get_weights_checksum(module_names=req.module_names)
        return OutputBatch(output=checksums)

    def _handle_generation(self, reqs: List[Req]):
        """Run a forward pass for ``reqs`` and update the warmup progress counter."""
        warmup_reqs = [req for req in reqs if req.is_warmup]
        if warmup_reqs:
            self._warmup_processed += len(warmup_reqs)
            if self._warmup_total > 0:
                logger.info(
                    f"Processing warmup req... ({self._warmup_processed}/{self._warmup_total})"
                )
            else:
                logger.info("Processing warmup req...")
        return self.worker.execute_forward(reqs)

    def return_result(
        self,
        output_batch: OutputBatch,
        identity: bytes | None = None,
        is_warmup: bool = False,
    ):
        """
        replies to client, only on rank 0

        Nothing is sent for warmup requests, when this instance has no receiver
        socket, or when ``identity`` is ``None``.
        """
        if not is_warmup and self.receiver is not None and identity is not None:
            self.receiver.send_multipart([identity, b"", pickle.dumps(output_batch)])

    def get_next_batch_to_run(self) -> list[tuple[bytes, Req]] | None:
        """pull a req from waiting_queue (earliest first); ``None`` when it is empty"""
        if not self.waiting_queue:
            return None

        # pop the first (earliest)
        item = self.waiting_queue.popleft()

        return [item]

    def prepare_server_warmup_reqs(self):
        """
        Enqueue one warmup request per configured warmup resolution.

        Does nothing unless warmup is enabled, has not happened yet and
        ``warmup_resolutions`` is set. Image-conditioned task types get a tiny
        placeholder input image, which is written to disk once and reused.
        """
        if (
            not self.server_args.warmup
            or self.warmed_up
            or self.server_args.warmup_resolutions is None
        ):
            return

        # insert warmup reqs constructed with each warmup-resolution
        self._warmup_total = len(self.server_args.warmup_resolutions)
        self._warmup_processed = 0

        # The placeholder image does not depend on the resolution: save it lazily, once.
        input_path = None

        for resolution in self.server_args.warmup_resolutions:
            width, height = _parse_size(resolution)
            task_type = self.server_args.pipeline_config.task_type

            if task_type in (
                ModelTaskType.I2I,
                ModelTaskType.TI2I,
                ModelTaskType.I2V,
                ModelTaskType.TI2V,
            ):
                if input_path is None:
                    uploads_dir = os.path.join("outputs", "uploads")
                    os.makedirs(uploads_dir, exist_ok=True)
                    input_path = asyncio.run(
                        save_image_to_path(
                            MINIMUM_PICTURE_BASE64_FOR_WARMUP,
                            os.path.join(uploads_dir, "warmup_image.jpg"),
                        )
                    )
                req = Req(
                    data_type=task_type.data_type(),
                    width=width,
                    height=height,
                    prompt="",
                    negative_prompt="",
                    image_path=[input_path],
                )
            else:
                req = Req(
                    data_type=task_type.data_type(),
                    width=width,
                    height=height,
                    prompt="",
                )
            req.set_as_warmup(self.server_args.warmup_steps)
            self.waiting_queue.append((None, req))
        # if server is warmed-up, set this flag to avoid req-based warmup
        self.warmed_up = True

    def process_received_reqs_with_req_based_warmup(
        self, recv_reqs: List[tuple[bytes, Any]]
    ) -> List[tuple[bytes, Any]]:
        """
        Prepend a warmup copy of the very first generation request, if needed.

        Applies only when warmup is enabled, no resolution-based warmup is
        configured and the server has not warmed up yet. ``recv_reqs`` is
        modified in place and returned.
        """
        if (
            self.warmed_up
            or not self.server_args.warmup
            or not recv_reqs
            or self.server_args.warmup_resolutions is not None
        ):
            return recv_reqs

        # handle server req-based warmup by inserting an identical req to the beginning of the waiting queue
        # only the very first req through server's lifetime will be warmed up
        identity, req = recv_reqs[0]
        if isinstance(req, Req):
            warmup_req = req.copy_as_warmup(self.server_args.warmup_steps)
            recv_reqs.insert(0, (identity, warmup_req))
            self._warmup_total = 1
            self._warmup_processed = 0
            self.warmed_up = True
        return recv_reqs

    def recv_reqs(self) -> List[tuple[bytes, Any]]:
        """
        Poll for new requests without blocking.

        The instance owning the receiver socket reads one multipart message and
        pairs every decoded request with the sender identity. For non-main
        schedulers, reqs are broadcasted from main using broadcast_pyobj.

        Raises:
            zmq.ZMQError: for socket failures other than "no message available".
        """
        if self.receiver is not None:
            try:
                parts = self.receiver.recv_multipart(zmq.NOBLOCK)
            except zmq.Again:
                # Nothing pending; any other ZMQError propagates to the caller.
                parts = []

            recv_reqs = []
            # Accept valid REQ envelopes only (identity, delimiter, payload);
            # shorter messages are malformed/probe frames and are ignored.
            if len(parts) > 2:
                identity, payload = parts[0], parts[-1]
                try:
                    # SECURITY NOTE(review): pickle.loads executes arbitrary code from
                    # the payload; the endpoint must only be reachable by trusted peers.
                    recv_reqs = pickle.loads(payload)
                except (
                    pickle.UnpicklingError,
                    IndexError,
                    EOFError,
                    AttributeError,
                    ImportError,
                ) as e:
                    # Non-pickle or truncated data: drop the frame instead of letting
                    # it count towards the consecutive-error limit of the event loop.
                    logger.debug(f"Ignoring undecodable frame: {e!r}")
                    recv_reqs = []

                if recv_reqs:
                    # Ensure recv_reqs is a list
                    if not isinstance(recv_reqs, list):
                        recv_reqs = [recv_reqs]

                    # Pack with identity for rank 0
                    recv_reqs = [(identity, req) for req in recv_reqs]
                else:
                    # Falsy payloads (None, empty list, ...) carry no requests.
                    recv_reqs = []
        else:
            recv_reqs = None

        # TODO: fix this condition
        if self.server_args.sp_degree != 1:
            recv_reqs = broadcast_pyobj(
                recv_reqs,
                self.worker.sp_group.rank,
                self.worker.sp_cpu_group,
                src=self.worker.sp_group.ranks[0],
            )

        if self.server_args.enable_cfg_parallel:
            recv_reqs = broadcast_pyobj(
                recv_reqs,
                self.worker.cfg_group.rank,
                self.worker.cfg_cpu_group,
                src=self.worker.cfg_group.ranks[0],
            )

        if self.server_args.tp_size > 1:
            recv_reqs = broadcast_pyobj(
                recv_reqs,
                self.worker.tp_group.rank,
                self.worker.tp_cpu_group,
                src=self.worker.tp_group.ranks[0],
            )

        assert recv_reqs is not None

        return recv_reqs

    def _log_warmup_result(self, output_batch: Any) -> None:
        """Log the outcome of a warmup request without ever raising on missing fields."""
        progress = (
            f" ({self._warmup_processed}/{self._warmup_total})"
            if self._warmup_total > 0
            else ""
        )
        if getattr(output_batch, "error", None) is not None:
            logger.info(f"Warmup req{progress} processing failed")
            return

        metrics = getattr(output_batch, "metrics", None)
        duration = getattr(metrics, "total_duration_s", None)
        if duration is None:
            logger.info(f"Warmup req{progress} processed")
        else:
            logger.info(
                f"Warmup req{progress} processed in {GREEN}%.2f{RESET} seconds",
                duration,
            )

    def event_loop(self) -> None:
        """
        The main event loop that listens for ZMQ requests.
        Handles abortion

        Each iteration receives requests, runs at most one queued request through
        its handler and sends the reply. The receiver socket and ZMQ context are
        released on every exit path.

        Raises:
            RuntimeError: after ``_max_consecutive_errors`` consecutive failures
                while receiving requests.
        """

        logger.debug(
            f"Rank 0 scheduler listening on tcp://*:{self.server_args.scheduler_port}"
        )

        try:
            while self._running:
                # 1: receive requests
                try:
                    new_reqs = self.recv_reqs()
                    new_reqs = self.process_received_reqs_with_req_based_warmup(
                        new_reqs
                    )
                    self.waiting_queue.extend(new_reqs)
                    # Reset error count on success
                    self._consecutive_error_count = 0
                except Exception as e:
                    self._consecutive_error_count += 1
                    logger.error(
                        f"Error receiving requests in scheduler event loop "
                        f"(attempt {self._consecutive_error_count}/{self._max_consecutive_errors}): {e}",
                        exc_info=True,
                    )
                    if self._consecutive_error_count >= self._max_consecutive_errors:
                        logger.error(
                            f"Maximum consecutive errors ({self._max_consecutive_errors}) reached. "
                            "Terminating scheduler event loop."
                        )
                        raise RuntimeError(
                            f"Scheduler terminated after {self._max_consecutive_errors} "
                            f"consecutive errors. Last error: {e}"
                        ) from e
                    continue

                # 2: execute, make sure a reply is always sent
                items = self.get_next_batch_to_run()
                if not items:
                    continue

                identities = [item[0] for item in items]
                reqs = [item[1] for item in items]
                processed_req = reqs[0]

                try:
                    handler = self.request_handlers.get(type(processed_req))
                    if handler:
                        output_batch = handler(reqs)
                    else:
                        output_batch = OutputBatch(
                            error=f"Unknown request type: {type(processed_req)}"
                        )
                except Exception as e:
                    logger.error(
                        f"Error executing request in scheduler event loop: {e}",
                        exc_info=True,
                    )
                    output_batch = OutputBatch(error=str(e))

                # 3. return results
                is_warmup = (
                    processed_req.is_warmup if isinstance(processed_req, Req) else False
                )
                if is_warmup:
                    self._log_warmup_result(output_batch)

                try:
                    # TODO: Support sending back to multiple identities if batched
                    self.return_result(output_batch, identities[0], is_warmup=is_warmup)
                except zmq.ZMQError as e:
                    # Reply failed; log and keep loop alive to accept future requests
                    logger.error(f"ZMQ error sending reply: {e}")
                    continue
        finally:
            # Release IPC resources whether the loop ended normally or by exception.
            if self.receiver is not None:
                self.receiver.close()
            self.context.destroy(linger=0)

    def _broadcast_task(self, payload: dict[str, Any]) -> None:
        """Broadcast a task to all slave worker processes (no-op without pipes)."""
        method = payload["method"]
        kwargs = {k: v for k, v in payload.items() if k != "method"}
        task = {"method": method, "kwargs": kwargs}
        for pipe in self.task_pipes_to_slaves or []:
            pipe.send(task)

    def _collect_slave_results(self) -> List[dict[str, Any]]:
        """Collect results from all slave worker processes (empty list without pipes)."""
        return [pipe.recv() for pipe in self.result_pipes_from_slaves or []]
def apply_split_rotary_emb(
    x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]
) -> torch.Tensor:
    """
    Apply "split" rotary embeddings: the last dimension of ``x`` is treated as two
    contiguous halves that are rotated against each other.

    Args:
        x: Input tensor. When it is not 4-D while ``cos`` is 4-D, it is reshaped to
            ``(batch, heads, tokens, dim_per_head)`` using the head/token counts of
            ``cos`` and reshaped back to ``(batch, tokens, -1)`` before returning.
        freqs: ``(cos, sin)`` tables whose last dimension is half of ``x``'s last
            dimension (after the optional reshape).

    Returns:
        The rotated tensor, cast back to the dtype ``x`` had on entry.

    Raises:
        ValueError: if the last dimension of ``x`` is odd.
    """
    cos, sin = freqs
    input_dtype = x.dtype

    # cos is (#b, h, t, r): the cos/sin batch dim may only be broadcastable,
    # so the batch size is always taken from x.
    folded = x.ndim != 4 and cos.ndim == 4
    if folded:
        batch = x.shape[0]
        _, num_heads, num_tokens, _ = cos.shape
        x = x.reshape(batch, num_tokens, num_heads, -1).transpose(1, 2)

    width = x.shape[-1]
    if width % 2 != 0:
        raise ValueError(
            f"Expected x.shape[-1] to be even for split rotary, got {width}."
        )
    half = width // 2

    # View the last dim (2*r) as (2, r): index 0 is the lower half, 1 the upper half.
    halves = x.reshape(*x.shape[:-1], 2, half)
    lower = halves[..., :1, :]
    upper = halves[..., 1:, :]

    # Insert a singleton axis so (…, r) tables broadcast against (…, 2, r).
    cos_b = cos.unsqueeze(-2)
    sin_b = sin.unsqueeze(-2)

    rotated = halves * cos_b
    # In-place updates on views of `rotated`:
    #   lower' = lower * cos - upper * sin
    #   upper' = upper * cos + lower * sin
    rotated[..., :1, :].addcmul_(-sin_b, upper)
    rotated[..., 1:, :].addcmul_(sin_b, lower)

    rotated = rotated.reshape(*rotated.shape[:-2], width)
    if folded:
        rotated = rotated.transpose(1, 2).reshape(batch, num_tokens, -1)

    return rotated.to(dtype=input_dtype)
+ """ + + def __init__( + self, + query_dim: int, + heads: int = 8, + kv_heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = True, + cross_attention_dim: Optional[int] = None, + out_bias: bool = True, + qk_norm: str = "rms_norm_across_heads", + norm_eps: float = 1e-6, + norm_elementwise_affine: bool = True, + rope_type: str = "interleaved", + processor=None, + ): + super().__init__() + if qk_norm != "rms_norm_across_heads": + raise NotImplementedError( + "Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`." + ) + + self.head_dim = dim_head + self.inner_dim = dim_head * heads + self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads + self.query_dim = query_dim + self.cross_attention_dim = ( + cross_attention_dim if cross_attention_dim is not None else query_dim + ) + self.use_bias = bias + self.dropout = dropout + self.out_dim = query_dim + self.heads = heads + self.rope_type = rope_type + + self.norm_q = torch.nn.RMSNorm( + dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine + ) + self.norm_k = torch.nn.RMSNorm( + dim_head * kv_heads, + eps=norm_eps, + elementwise_affine=norm_elementwise_affine, + ) + self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = torch.nn.Linear( + self.cross_attention_dim, self.inner_kv_dim, bias=bias + ) + self.to_v = torch.nn.Linear( + self.cross_attention_dim, self.inner_kv_dim, bias=bias + ) + self.to_out = torch.nn.ModuleList([]) + self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(torch.nn.Dropout(dropout)) + + # Scaled dot product attention + self.attn = USPAttention( + num_heads=heads, + head_size=self.head_dim, + dropout_rate=0, + softmax_scale=None, + causal=False, + supported_attention_backends={ + AttentionBackendEnum.FA, + AttentionBackendEnum.AITER, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.SAGE_ATTN, + AttentionBackendEnum.SAGE_ATTN_3, + }, 
class LTX2RotaryPosEmbed1d(nn.Module):
    """
    1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors.

    The module holds no parameters or buffers; the cos/sin tables are rebuilt on
    every call from the stored hyper-parameters.
    """

    def __init__(
        self,
        dim: int,
        base_seq_len: int = 4096,
        theta: float = 10000.0,
        double_precision: bool = True,
        rope_type: str = "interleaved",
        num_attention_heads: int = 32,
    ):
        """
        Args:
            dim: Embedding width the tables are built for.
            base_seq_len: Positions are divided by this value before use.
            theta: Base of the geometric frequency progression.
            double_precision: Build the frequency exponents in float64 if True.
            rope_type: Either ``"interleaved"`` or ``"split"``.
            num_attention_heads: Head count used to reshape the ``"split"`` tables.

        Raises:
            ValueError: if ``rope_type`` is not one of the supported values.
        """
        super().__init__()
        if rope_type not in ["interleaved", "split"]:
            raise ValueError(
                f"{rope_type=} not supported. Choose between 'interleaved' and 'split'."
            )

        self.dim = dim
        self.base_seq_len = base_seq_len
        self.theta = theta
        self.double_precision = double_precision
        self.rope_type = rope_type
        self.num_attention_heads = num_attention_heads

    def _angles(
        self, batch_size: int, pos: int, device: Union[str, torch.device]
    ) -> torch.Tensor:
        """Return the rotation angles with shape [batch_size, pos, dim // 2]."""
        # Fractional position ids relative to base_seq_len, repeated per batch item.
        positions = torch.arange(pos, dtype=torch.float32, device=device)
        positions = positions / self.base_seq_len
        positions = positions.unsqueeze(0).repeat(batch_size, 1)

        # 1 (because 1D) * 2 (for cos, sin) = 2 elements per frequency.
        exponent_dtype = torch.float64 if self.double_precision else torch.float32
        exponents = torch.linspace(
            start=0.0,
            end=1.0,
            steps=self.dim // 2,
            dtype=exponent_dtype,
            device=device,
        )
        base_freqs = (torch.pow(self.theta, exponents) * torch.pi / 2.0).to(
            dtype=torch.float32
        )

        # Map positions from [0, 1) to [-1, 1) and take the outer product with freqs.
        return (positions.unsqueeze(-1) * 2 - 1) * base_freqs

    def _interleaved_tables(
        self, angles: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Duplicate every frequency for adjacent channel pairs, left-padded to dim."""
        cos_table = angles.cos().repeat_interleave(2, dim=-1)
        sin_table = angles.sin().repeat_interleave(2, dim=-1)

        leftover = self.dim % 2
        if leftover != 0:
            # Identity rotation (cos=1, sin=0) for the channel without a partner.
            cos_pad = torch.ones_like(cos_table[:, :, :leftover])
            sin_pad = torch.zeros_like(sin_table[:, :, :leftover])
            cos_table = torch.cat([cos_pad, cos_table], dim=-1)
            sin_table = torch.cat([sin_pad, sin_table], dim=-1)
        return cos_table, sin_table

    def _split_tables(self, angles: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build per-head tables of shape (B, H, T, D // 2 // H)."""
        cos_table = angles.cos()
        sin_table = angles.sin()

        missing = self.dim // 2 - angles.shape[-1]
        if missing != 0:
            cos_pad = torch.ones_like(cos_table[:, :, :missing])
            sin_pad = torch.zeros_like(sin_table[:, :, :missing])
            cos_table = torch.cat([cos_pad, cos_table], dim=-1)
            sin_table = torch.cat([sin_pad, sin_table], dim=-1)

        # Reshape freqs to be compatible with multi-head attention.
        batch, tokens = cos_table.shape[0], cos_table.shape[1]
        cos_table = cos_table.reshape(batch, tokens, self.num_attention_heads, -1)
        sin_table = sin_table.reshape(batch, tokens, self.num_attention_heads, -1)
        return torch.swapaxes(cos_table, 1, 2), torch.swapaxes(sin_table, 1, 2)

    def forward(
        self,
        batch_size: int,
        pos: int,
        device: Union[str, torch.device],
        dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Build the (cos, sin) tables for ``pos`` positions.

        Args:
            batch_size: Number of identical rows to produce.
            pos: Sequence length (number of positions).
            device: Device on which the tables are created.
            dtype: If given, both tables are cast to it before returning.

        Returns:
            ``(cos, sin)``: ``[B, pos, dim]`` for ``"interleaved"``; for ``"split"``
            the frequency axis is divided across heads: ``[B, H, pos, -1]``.
        """
        angles = self._angles(batch_size, pos, device)

        if self.rope_type == "interleaved":
            cos_freqs, sin_freqs = self._interleaved_tables(angles)
        elif self.rope_type == "split":
            cos_freqs, sin_freqs = self._split_tables(angles)

        if dtype is not None:
            cos_freqs = cos_freqs.to(dtype)
            sin_freqs = sin_freqs.to(dtype)
        return cos_freqs, sin_freqs
+ + self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False) + self.ff = FeedForward(dim, activation_fn=activation_fn) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + norm_hidden_states = self.norm1(hidden_states) + attn_hidden_states = self.attn1( + norm_hidden_states, + attention_mask=attention_mask, + query_rotary_emb=rotary_emb, + ) + hidden_states = hidden_states + attn_hidden_states + + norm_hidden_states = self.norm2(hidden_states) + ff_hidden_states = self.ff(norm_hidden_states) + hidden_states = hidden_states + ff_hidden_states + + return hidden_states + + +class LTX2ConnectorTransformer1d(nn.Module): + """ + A 1D sequence transformer for modalities such as text. + In LTX 2.0, this is used to process the text encoder hidden states for each of the video and audio streams. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + num_attention_heads: int = 30, + attention_head_dim: int = 128, + num_layers: int = 2, + num_learnable_registers: int | None = 128, + rope_base_seq_len: int = 4096, + rope_theta: float = 10000.0, + rope_double_precision: bool = True, + eps: float = 1e-6, + causal_temporal_positioning: bool = False, + rope_type: str = "interleaved", + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.inner_dim = num_attention_heads * attention_head_dim + self.causal_temporal_positioning = causal_temporal_positioning + + self.num_learnable_registers = num_learnable_registers + self.learnable_registers = None + if num_learnable_registers is not None: + init_registers = ( + torch.rand(num_learnable_registers, self.inner_dim) * 2.0 - 1.0 + ) + self.learnable_registers = torch.nn.Parameter(init_registers) + + self.rope = LTX2RotaryPosEmbed1d( + self.inner_dim, + base_seq_len=rope_base_seq_len, + theta=rope_theta, + double_precision=rope_double_precision, + 
rope_type=rope_type, + num_attention_heads=num_attention_heads, + ) + + self.transformer_blocks = torch.nn.ModuleList( + [ + LTX2TransformerBlock1d( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + rope_type=rope_type, + ) + for _ in range(num_layers) + ] + ) + + self.norm_out = torch.nn.RMSNorm( + self.inner_dim, eps=eps, elementwise_affine=False + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + attn_mask_binarize_threshold: float = -9000.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # hidden_states shape: [batch_size, seq_len, hidden_dim] + # attention_mask shape: [batch_size, seq_len] or [batch_size, 1, 1, seq_len] + batch_size, seq_len, _ = hidden_states.shape + + # 1. Replace padding with learned registers, if using + if self.learnable_registers is not None: + if seq_len % self.num_learnable_registers != 0: + raise ValueError( + f"The `hidden_states` sequence length {hidden_states.shape[1]} should be divisible by the number" + f" of learnable registers {self.num_learnable_registers}" + ) + + num_register_repeats = seq_len // self.num_learnable_registers + registers = torch.tile( + self.learnable_registers, (num_register_repeats, 1) + ) # [seq_len, inner_dim] + + binary_attn_mask = (attention_mask >= attn_mask_binarize_threshold).int() + if binary_attn_mask.ndim == 4: + binary_attn_mask = binary_attn_mask.squeeze(1).squeeze( + 1 + ) # [B, 1, 1, L] --> [B, L] + + hidden_states_non_padded = [ + hidden_states[i, binary_attn_mask[i].bool(), :] + for i in range(batch_size) + ] + valid_seq_lens = [x.shape[0] for x in hidden_states_non_padded] + pad_lengths = [seq_len - valid_seq_len for valid_seq_len in valid_seq_lens] + padded_hidden_states = [ + F.pad(x, pad=(0, 0, 0, p), value=0) + for x, p in zip(hidden_states_non_padded, pad_lengths) + ] + padded_hidden_states = torch.cat( + [x.unsqueeze(0) for x in 
padded_hidden_states], dim=0 + ) # [B, L, D] + + flipped_mask = torch.flip(binary_attn_mask, dims=[1]).unsqueeze( + -1 + ) # [B, L, 1] + hidden_states = ( + flipped_mask * padded_hidden_states + (1 - flipped_mask) * registers + ) + + # Overwrite attention_mask with an all-zeros mask if using registers. + attention_mask = torch.zeros_like(attention_mask) + + # 2. Calculate 1D RoPE positional embeddings + rotary_emb = self.rope( + batch_size, seq_len, device=hidden_states.device, dtype=hidden_states.dtype + ) + + # 3. Run 1D transformer blocks + for block in self.transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, attention_mask, rotary_emb + ) + else: + hidden_states = block( + hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb + ) + + hidden_states = self.norm_out(hidden_states) + + return hidden_states, attention_mask + + +class LTX2TextConnectors(nn.Module): + """ + Text connector stack used by LTX 2.0 to process the packed text encoder hidden states for both the video and audio + streams. 
+ """ + + def __init__( + self, + config: LTX2ConnectorConfig, + ): + super().__init__() + caption_channels = config.caption_channels + text_proj_in_factor = config.text_proj_in_factor + video_connector_num_attention_heads = config.video_connector_num_attention_heads + video_connector_attention_head_dim = config.video_connector_attention_head_dim + video_connector_num_layers = config.video_connector_num_layers + video_connector_num_learnable_registers = ( + config.video_connector_num_learnable_registers + ) + audio_connector_num_attention_heads = config.audio_connector_num_attention_heads + audio_connector_attention_head_dim = config.audio_connector_attention_head_dim + audio_connector_num_layers = config.audio_connector_num_layers + audio_connector_num_learnable_registers = ( + config.audio_connector_num_learnable_registers + ) + connector_rope_base_seq_len = config.connector_rope_base_seq_len + rope_theta = config.rope_theta + rope_double_precision = config.rope_double_precision + causal_temporal_positioning = config.causal_temporal_positioning + rope_type = config.rope_type + + self.text_proj_in = nn.Linear( + caption_channels * text_proj_in_factor, caption_channels, bias=False + ) + self.video_connector = LTX2ConnectorTransformer1d( + num_attention_heads=video_connector_num_attention_heads, + attention_head_dim=video_connector_attention_head_dim, + num_layers=video_connector_num_layers, + num_learnable_registers=video_connector_num_learnable_registers, + rope_base_seq_len=connector_rope_base_seq_len, + rope_theta=rope_theta, + rope_double_precision=rope_double_precision, + causal_temporal_positioning=causal_temporal_positioning, + rope_type=rope_type, + ) + self.audio_connector = LTX2ConnectorTransformer1d( + num_attention_heads=audio_connector_num_attention_heads, + attention_head_dim=audio_connector_attention_head_dim, + num_layers=audio_connector_num_layers, + num_learnable_registers=audio_connector_num_learnable_registers, + 
rope_base_seq_len=connector_rope_base_seq_len, + rope_theta=rope_theta, + rope_double_precision=rope_double_precision, + causal_temporal_positioning=causal_temporal_positioning, + rope_type=rope_type, + ) + + def forward( + self, + text_encoder_hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + additive_mask: bool = False, + ): + # Convert to additive attention mask, if necessary + if not additive_mask: + text_dtype = text_encoder_hidden_states.dtype + attention_mask = (attention_mask - 1).reshape( + attention_mask.shape[0], 1, -1, attention_mask.shape[-1] + ) + attention_mask = attention_mask.to(text_dtype) * torch.finfo(text_dtype).max + + # Ensure input dtype matches the layer's weight dtype + if text_encoder_hidden_states.dtype != self.text_proj_in.weight.dtype: + text_encoder_hidden_states = text_encoder_hidden_states.to( + self.text_proj_in.weight.dtype + ) + + # Ensure sequence length is divisible by num_learnable_registers (128) + seq_len = text_encoder_hidden_states.shape[1] + num_learnable_registers = self.video_connector.num_learnable_registers + if ( + num_learnable_registers is not None + and seq_len % num_learnable_registers != 0 + ): + pad_len = num_learnable_registers - (seq_len % num_learnable_registers) + text_encoder_hidden_states = F.pad( + text_encoder_hidden_states, (0, 0, 0, pad_len), value=0.0 + ) + + if attention_mask.shape[-1] == seq_len: + # Pad with a large negative value to mask out the new tokens + attention_mask = F.pad(attention_mask, (0, pad_len), value=-1000000.0) + + text_encoder_hidden_states = self.text_proj_in(text_encoder_hidden_states) + + video_text_embedding, new_attn_mask = self.video_connector( + text_encoder_hidden_states, attention_mask + ) + + attn_mask = (new_attn_mask < 1e-6).to(torch.int64) + attn_mask = attn_mask.reshape( + video_text_embedding.shape[0], video_text_embedding.shape[1], 1 + ) + video_text_embedding = video_text_embedding * attn_mask + new_attn_mask = attn_mask.squeeze(-1) + + 
audio_text_embedding, _ = self.audio_connector( + text_encoder_hidden_states, attention_mask + ) + + return video_text_embedding, audio_text_embedding, new_attn_mask + + +EntryClass = LTX2TextConnectors diff --git a/sglang/python/sglang/multimodal_gen/runtime/models/bridges/__init__.py b/sglang/python/sglang/multimodal_gen/runtime/models/bridges/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bee1e907a532583a75fbd26cb869f0a7bec4d8f7 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/models/bridges/__init__.py @@ -0,0 +1,7 @@ +# SPDX-License-Identifier: Apache-2.0 + +from sglang.multimodal_gen.runtime.models.bridges.mova_dual_tower import ( + DualTowerConditionalBridge, +) + +__all__ = ["DualTowerConditionalBridge"] diff --git a/sglang/python/sglang/multimodal_gen/runtime/models/bridges/mova_dual_tower.py b/sglang/python/sglang/multimodal_gen/runtime/models/bridges/mova_dual_tower.py new file mode 100644 index 0000000000000000000000000000000000000000..afc2e0e53ed47d27ddfcc41c97f4cdff22ff8c73 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/models/bridges/mova_dual_tower.py @@ -0,0 +1,672 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copied and adapted from: mossVG/mova/diffusion/models/interactionv2.py + + +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from sglang.multimodal_gen.configs.models.bridges.mova_dual_tower import ( + MOVADualTowerConfig, +) +from sglang.multimodal_gen.runtime.distributed import get_tp_world_size +from sglang.multimodal_gen.runtime.layers.attention import USPAttention +from sglang.multimodal_gen.runtime.layers.layernorm import ( + RMSNorm, + tensor_parallel_rms_norm, +) +from sglang.multimodal_gen.runtime.layers.linear import ( + ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from sglang.multimodal_gen.runtime.layers.rotary_embedding import ( + 
@torch.no_grad()
def compute_rope_cos_sin(
    position_ids: torch.Tensor,
    head_dim: int,
    base: float = 10000.0,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute RoPE cos/sin embeddings for given position IDs.

    This is a functional implementation that doesn't require storing buffers,
    making it compatible with FSDP meta device initialization.

    Args:
        position_ids: Position IDs tensor [B, L] or [1, L]
        head_dim: Dimension of each attention head
        base: RoPE base frequency (default: 10000.0)
        device: Target device; defaults to the device of ``position_ids``.
            ``position_ids`` is moved there if it lives elsewhere.
        dtype: Output dtype; defaults to ``torch.float32``.

    Returns:
        (cos, sin): Each with shape [B, L, head_dim]
    """
    if device is None:
        device = position_ids.device
    if dtype is None:
        dtype = torch.float32

    # Inverse frequencies, one per channel pair: [head_dim / 2]
    inv_freq = 1.0 / (
        base
        ** (torch.arange(0, head_dim, 2, dtype=torch.float32, device=device) / head_dim)
    )

    # Bring the positions onto the target device so both operands agree.
    positions = position_ids.to(device=device, dtype=torch.float32)

    # Outer product via broadcasting: [B, L, 1] * [1, 1, head_dim/2] -> [B, L, head_dim/2].
    # An element-wise product stays in float32 even under autocast, unlike a matmul.
    freqs = positions[:, :, None] * inv_freq[None, None, :]

    # Double the frequencies for full head_dim: [B, L, head_dim]
    emb = torch.cat((freqs, freqs), dim=-1)

    cos = emb.cos().to(dtype=dtype)
    sin = emb.sin().to(dtype=dtype)

    return cos, sin
class CrossModalInteractionController:
    """Strategy class to control dual-tower interaction.

    Manages the interaction mapping between Visual DiT (e.g., 30 layers)
    and Audio DiT (e.g., 30 layers). Only layer indices below
    ``min(visual_layers, audio_layers)`` are ever selected, and every selected
    layer ``i`` is paired with itself as ``(i, i)`` in both directions.
    """

    # Layer indices used by the "custom" strategy when the caller supplies none.
    DEFAULT_CUSTOM_LAYERS = (0, 2, 4, 6, 8, 12, 16, 20)

    def __init__(self, visual_layers: int = 30, audio_layers: int = 30):
        self.visual_layers = visual_layers
        self.audio_layers = audio_layers
        self.min_layers = min(visual_layers, audio_layers)

    def get_interaction_layers(
        self,
        strategy: str = "shallow_focus",
        custom_layers: Optional[List[int]] = None,
    ) -> Dict[str, List[Tuple[int, int]]]:
        """Gets the mapping relationship of interaction layers.

        Args:
            strategy: One of
                - "shallow_focus": the first ``min(10, min_layers // 3)`` layers.
                - "distributed": every third layer starting at 0.
                - "progressive": layers 0-7, then every third layer from 8 on.
                - "custom": ``custom_layers`` (or ``DEFAULT_CUSTOM_LAYERS``),
                  restricted to valid indices.
                - "full": every layer below ``min_layers``.
            custom_layers: Layer indices for the "custom" strategy. Ignored by
                every other strategy. Indices outside ``[0, min_layers)`` are
                dropped; order is preserved.

        Returns:
            ``{"v2a": [(i, i), ...], "a2v": [(i, i), ...]}`` with two
            independent lists holding the same pairs.

        Raises:
            ValueError: If ``strategy`` is not one of the names above.
        """
        if strategy == "shallow_focus":
            num_interact = min(10, self.min_layers // 3)
            interact_layers = list(range(0, num_interact))
        elif strategy == "distributed":
            interact_layers = list(range(0, self.min_layers, 3))
        elif strategy == "progressive":
            # Dense in the first 8 layers, sparse (stride 3) afterwards.
            interact_layers = list(range(0, min(8, self.min_layers)))
            if self.min_layers > 8:
                interact_layers += list(range(8, self.min_layers, 3))
        elif strategy == "custom":
            requested = (
                self.DEFAULT_CUSTOM_LAYERS if custom_layers is None else custom_layers
            )
            interact_layers = [i for i in requested if 0 <= i < self.min_layers]
        elif strategy == "full":
            interact_layers = list(range(0, self.min_layers))
        else:
            raise ValueError(f"Unknown interaction strategy: {strategy}")

        mapping = {
            "v2a": [(i, i) for i in interact_layers],
            "a2v": [(i, i) for i in interact_layers],
        }
        return mapping

    def should_interact(
        self, layer_idx: int, direction: str, interaction_mapping: Dict
    ) -> bool:
        """Determines if the specified layer needs to interact.

        Returns False for a ``direction`` that is not a key of
        ``interaction_mapping``; otherwise True when ``layer_idx`` appears as
        the first element of any pair registered for that direction.
        """
        if direction not in interaction_mapping:
            return False
        return any(src == layer_idx for src, _ in interaction_mapping[direction])
+ """ + + def __init__(self, dim: int, kv_dim: int, num_heads: int, eps: float = 1e-6): + super().__init__() + self.q_dim = dim + self.kv_dim = kv_dim + self.num_heads = num_heads + self.head_dim = self.q_dim // num_heads + + self.tp_size = get_tp_world_size() + if self.num_heads % self.tp_size != 0: + raise ValueError( + f"num_heads ({self.num_heads}) must be divisible by tp_size ({self.tp_size})." + ) + self.num_heads_per_rank = self.num_heads // self.tp_size + + # TP strategy: shard Q/K/V over heads (column-parallel), then row-parallel output. + self.q = ColumnParallelLinear(dim, dim, bias=True, gather_output=False) + self.k = ColumnParallelLinear(kv_dim, dim, bias=True, gather_output=False) + self.v = ColumnParallelLinear(kv_dim, dim, bias=True, gather_output=False) + self.o = RowParallelLinear(dim, dim, bias=True, input_is_parallel=True) + self.norm_q = RMSNorm(dim, eps=eps) + self.norm_k = RMSNorm(dim, eps=eps) + + self.attn = USPAttention( + num_heads=self.num_heads_per_rank, + head_size=self.head_dim, + causal=False, + softmax_scale=None, + # is_cross_attention=True, + ) + + def forward( + self, + x: torch.Tensor, + y: torch.Tensor, + x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ): + ctx = y + q, _ = self.q(x) + k, _ = self.k(ctx) + v, _ = self.v(ctx) + + # RMSNorm over sharded hidden dimension + if self.tp_size > 1: + q = tensor_parallel_rms_norm(q, self.norm_q) + k = tensor_parallel_rms_norm(k, self.norm_k) + else: + q = self.norm_q(q) + k = self.norm_k(k) + + if x_freqs is not None: + x_cos, x_sin = x_freqs + q_view = rearrange(q, "b l (h d) -> b l h d", d=self.head_dim) + x_cos = x_cos.to(q_view.dtype).to(q_view.device).squeeze(0) + x_sin = x_sin.to(q_view.dtype).to(q_view.device).squeeze(0) + # FlashInfer expects cos_sin_cache with shape [seqlen, head_dim], + # where the first half is cos and the second half is sin, each with + # head_dim//2 elements. 
class AdaLayerNorm(nn.Module):
    """
    LayerNorm whose output is modulated by a conditioning embedding.

    The conditioning vector is passed through SiLU and a linear projection,
    split into two halves (scale and shift), and applied as
    ``norm(x) * (1 + scale) + shift``.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_embeddings: Optional[int] = None,
        output_dim: Optional[int] = None,
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-5,
        chunk_dim: int = 0,
    ):
        super().__init__()

        self.chunk_dim = chunk_dim
        # The projection yields both halves, so it is twice the normalized width
        # unless the caller asks for something else.
        if not output_dim:
            output_dim = embedding_dim * 2

        # Optional lookup table turning integer timesteps into embeddings.
        self.emb = (
            nn.Embedding(num_embeddings, embedding_dim)
            if num_embeddings is not None
            else None
        )

        self.silu = nn.SiLU()
        self.linear = ReplicatedLinear(embedding_dim, output_dim)
        self.norm = nn.LayerNorm(
            output_dim // 2,
            eps=norm_eps,
            elementwise_affine=norm_elementwise_affine,
        )

    def forward(
        self,
        x: torch.Tensor,
        timestep: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Normalize ``x`` and modulate it with the conditioning signal.

        When the module owns an embedding table, ``timestep`` is looked up and
        replaces ``temb``; otherwise ``temb`` is used as given.
        """
        if self.emb is not None:
            temb = self.emb(timestep)

        modulation, _ = self.linear(self.silu(temb))

        if self.chunk_dim == 1:
            # Note the (shift, scale) order in this branch, and the extra axis
            # inserted at position 1 so both broadcast against x.
            shift, scale = (
                half[:, None, :] for half in modulation.chunk(2, dim=1)
            )
        else:
            # chunk_dim == 2 splits the last axis of a 3-D projection; any
            # other value splits along axis 0.
            split_axis = 2 if self.chunk_dim == 2 else 0
            scale, shift = modulation.chunk(2, dim=split_axis)

        return self.norm(x) * (1 + scale) + shift
class DualTowerConditionalBridge(
    CachableDiT,
    OffloadableDiTMixin,
):
    """Dual-tower conditional bridge module v2 (SGLang optimized version).

    Implements the correct architecture:
    1. Audio latents -> Audio DiT -> Audio hidden states [B, L, 1536].
    2. Visual latents -> Visual DiT -> Visual hidden states [B, L, 5120].
    3. Cross-attention interaction between the hidden states of the two DiTs.

    The bridge owns one ConditionalCrossAttentionBlock per interacting layer
    and direction ("a2v": audio conditions video, "v2a": video conditions
    audio). Its forward pass is called once per layer index and returns both
    towers' hidden states with the residual conditioning added.
    """

    # Class-level metadata is read from a default-constructed config; each
    # line below builds its own MOVADualTowerConfig instance.
    _fsdp_shard_conditions = MOVADualTowerConfig()._fsdp_shard_conditions
    _compile_conditions = MOVADualTowerConfig()._compile_conditions
    _supported_attention_backends = MOVADualTowerConfig()._supported_attention_backends
    param_names_mapping = MOVADualTowerConfig().param_names_mapping
    reverse_param_names_mapping = MOVADualTowerConfig().reverse_param_names_mapping
    lora_param_names_mapping = MOVADualTowerConfig().lora_param_names_mapping

    def __init__(
        self,
        config: MOVADualTowerConfig | None = None,
        hf_config: dict[str, Any] | None = None,
        # Fallback parameters for from_pretrained compatibility
        visual_layers: int = 40,
        audio_layers: int = 30,
        visual_hidden_dim: int = 5120,
        audio_hidden_dim: int = 1536,
        audio_fps: float = 50.0,
        head_dim: int = 128,
        interaction_strategy: str = "full",
        apply_cross_rope: bool = True,
        apply_first_frame_bias_in_rope: bool = False,
        trainable_condition_scale: bool = False,
        pooled_adaln: bool = False,
    ):
        """Build the per-layer cross-attention conditioners.

        Args:
            config: When given, every architectural value below is taken from
                it and the corresponding keyword argument is ignored.
            hf_config: Forwarded unchanged to the base class.
            visual_layers / audio_layers: Depth of each tower; passed to the
                interaction controller, which selects layers below the
                smaller of the two.
            visual_hidden_dim / audio_hidden_dim: Hidden width of each tower.
            audio_fps: Audio steps per second, used to align video frame
                positions with audio positions in build_aligned_freqs.
            head_dim: Attention head size; head counts are derived as
                ``hidden_dim // head_dim``.
            interaction_strategy: Name understood by
                CrossModalInteractionController.get_interaction_layers.
            apply_cross_rope: Stored on the instance; not read elsewhere in
                this class.
            apply_first_frame_bias_in_rope: Selects the video position formula
                in build_aligned_freqs.
            trainable_condition_scale: If True the residual scale is a
                learnable parameter initialised to 1.0, otherwise the float 1.0.
            pooled_adaln: Enables the pooled AdaLN path, but only for the
                video-to-audio conditioners.
        """
        super().__init__(config=config, hf_config=hf_config)

        # Use config if provided, otherwise use individual parameters
        if config is not None:
            visual_layers = config.visual_layers
            audio_layers = config.audio_layers
            visual_hidden_dim = config.visual_hidden_dim
            audio_hidden_dim = config.audio_hidden_dim
            audio_fps = config.audio_fps
            head_dim = config.head_dim
            interaction_strategy = config.interaction_strategy
            apply_cross_rope = config.apply_cross_rope
            apply_first_frame_bias_in_rope = config.apply_first_frame_bias_in_rope
            trainable_condition_scale = config.trainable_condition_scale
            pooled_adaln = config.pooled_adaln

        self.visual_hidden_dim = visual_hidden_dim
        self.audio_hidden_dim = audio_hidden_dim
        self.audio_fps = audio_fps
        self.head_dim = head_dim
        self.apply_cross_rope = apply_cross_rope
        self.apply_first_frame_bias_in_rope = apply_first_frame_bias_in_rope
        self.trainable_condition_scale = trainable_condition_scale
        self.pooled_adaln = pooled_adaln

        # Residual scale applied to the conditioning features: a registered
        # parameter when trainable, a plain Python float otherwise.
        if self.trainable_condition_scale:
            self.condition_scale = nn.Parameter(
                torch.tensor([1.0], dtype=torch.float32)
            )
        else:
            self.condition_scale = 1.0

        self.controller = CrossModalInteractionController(visual_layers, audio_layers)
        self.interaction_mapping = self.controller.get_interaction_layers(
            interaction_strategy
        )

        # Cross-modal attention modules - interaction at DiT hidden states level
        # Both dicts are keyed by the layer index converted to a string.
        self.audio_to_video_conditioners = nn.ModuleDict()
        self.video_to_audio_conditioners = nn.ModuleDict()

        self.rope_base = 10000.0  # RoPE base frequency hardcode. adapted from original mova implementation.

        # Audio DiT hidden states conditioning Video DiT
        # (queries have the visual width, keys/values the audio width;
        # pooled AdaLN is always off in this direction)
        for v_layer, _ in self.interaction_mapping["a2v"]:
            self.audio_to_video_conditioners[str(v_layer)] = (
                ConditionalCrossAttentionBlock(
                    dim=visual_hidden_dim,
                    kv_dim=audio_hidden_dim,
                    num_heads=visual_hidden_dim // head_dim,
                    pooled_adaln=False,
                )
            )

        # Visual DiT hidden states conditioning Audio DiT
        # (queries have the audio width, keys/values the visual width)
        for a_layer, _ in self.interaction_mapping["v2a"]:
            self.video_to_audio_conditioners[str(a_layer)] = (
                ConditionalCrossAttentionBlock(
                    dim=audio_hidden_dim,
                    kv_dim=visual_hidden_dim,
                    num_heads=audio_hidden_dim // head_dim,
                    pooled_adaln=self.pooled_adaln,
                )
            )

        # Required attributes for CachableDiT/BaseDiT
        self.hidden_size = visual_hidden_dim
        self.num_attention_heads = visual_hidden_dim // head_dim
        self.num_channels_latents = (
            visual_hidden_dim  # Bridge doesn't output latents, but required by BaseDiT
        )
        self.layer_names = [
            "audio_to_video_conditioners",
            "video_to_audio_conditioners",
        ]
        self.__post_init__()

    @torch.no_grad()
    def build_aligned_freqs(
        self,
        video_fps: float,
        grid_size: Tuple[int, int, int],
        audio_steps: int,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        """Generates aligned RoPE (cos, sin) based on video FPS, grid size, and audio length.

        Uses functional RoPE computation to avoid FSDP meta device issues.

        Audio tokens sit at integer positions 0..audio_steps-1. Every video
        token of a frame shares that frame's position, expressed in audio-step
        units so both modalities share one time axis.

        Args:
            video_fps: FPS of the video.
            grid_size: Tuple of (f_v, h, w).
            audio_steps: Length of the audio sequence.
            device: Target device. Defaults to the device of the first
                module parameter.
            dtype: Output dtype. Defaults to torch.float32.

        Returns:
            A tuple of ((cos_v, sin_v), (cos_a, sin_a)).
        """
        f_v, h, w = grid_size
        L_v = f_v * h * w
        L_a = int(audio_steps)

        device = device or next(self.parameters()).device
        dtype = dtype or torch.float32

        # Audio positions: 0, 1, 2, ..., L_a-1
        audio_pos = torch.arange(L_a, device=device, dtype=torch.float32).unsqueeze(0)

        # Video positions: Align video frames to audio step units
        # NOTE(review): the division by 4.0 looks like a temporal compression
        # factor between raw video frames and latent frames -- confirm.
        if self.apply_first_frame_bias_in_rope:
            # Frame 0 starts at t=0; frame i>=1 starts at
            # 1/video_fps + (i-1)/video_effective_fps seconds.
            video_effective_fps = float(video_fps) / 4.0
            if f_v > 0:
                t_starts = torch.zeros((f_v,), device=device, dtype=torch.float32)
                if f_v > 1:
                    t_starts[1:] = (1.0 / float(video_fps)) + torch.arange(
                        f_v - 1, device=device, dtype=torch.float32
                    ) * (1.0 / video_effective_fps)
            else:
                t_starts = torch.zeros((0,), device=device, dtype=torch.float32)
            # Seconds -> audio steps.
            video_pos_per_frame = t_starts * float(self.audio_fps)
        else:
            # Uniform spacing: frame i sits at i * audio_fps / (video_fps / 4).
            scale = float(self.audio_fps) / float(video_fps / 4.0)
            video_pos_per_frame = (
                torch.arange(f_v, device=device, dtype=torch.float32) * scale
            )

        # Repeat each frame position for its h*w spatial tokens -> [1, f_v*h*w]
        video_pos = video_pos_per_frame.repeat_interleave(h * w).unsqueeze(0)

        # Use functional RoPE to compute cos/sin
        cos_v, sin_v = compute_rope_cos_sin(
            video_pos, self.head_dim, base=self.rope_base, device=device, dtype=dtype
        )
        cos_a, sin_a = compute_rope_cos_sin(
            audio_pos, self.head_dim, base=self.rope_base, device=device, dtype=dtype
        )

        return (cos_v, sin_v), (cos_a, sin_a)

    def should_interact(self, layer_idx: int, direction: str) -> bool:
        """Return whether ``layer_idx`` has a conditioner for ``direction``."""
        return self.controller.should_interact(
            layer_idx, direction, self.interaction_mapping
        )

    def apply_conditional_control(
        self,
        layer_idx: int,
        direction: str,
        primary_hidden_states: torch.Tensor,
        condition_hidden_states: torch.Tensor,
        x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        condition_scale: Optional[float] = None,
        video_grid_size: Optional[Tuple[int, int, int]] = None,
    ) -> torch.Tensor:
        """Applies conditional control at the DiT hidden states level.

        Args:
            layer_idx: Index of the DiT layer being processed.
            direction: "a2v" or "v2a".
            primary_hidden_states: Hidden states that receive the residual
                update (attention queries).
            condition_hidden_states: Hidden states of the other tower
                (attention keys/values).
            x_freqs: Optional RoPE (cos, sin) for the primary sequence.
            y_freqs: Optional RoPE (cos, sin) for the condition sequence.
            condition_scale: Optional external residual scale; when given it
                overrides the module's own ``condition_scale``.
            video_grid_size: (T, H, W) grid, forwarded to the conditioner.

        Returns:
            ``primary_hidden_states`` unchanged when the layer does not
            interact, otherwise
            ``primary_hidden_states + conditioner(...) * scale``.
        """
        # Layers without a conditioner for this direction pass through.
        # Because the interaction mapping only has "a2v" and "v2a" keys, an
        # unknown direction also returns here, before the ValueError below.
        if not self.controller.should_interact(
            layer_idx, direction, self.interaction_mapping
        ):
            return primary_hidden_states

        if direction == "a2v":
            conditioner = self.audio_to_video_conditioners[str(layer_idx)]
        elif direction == "v2a":
            conditioner = self.video_to_audio_conditioners[str(layer_idx)]
        else:
            raise ValueError(f"Invalid direction: {direction}")

        conditioned_features = conditioner(
            x=primary_hidden_states,
            y=condition_hidden_states,
            x_freqs=x_freqs,
            y_freqs=y_freqs,
            video_grid_size=video_grid_size,
        )

        # An explicit scale wins over the learned one; the warning is emitted
        # on every call where both are present.
        if self.trainable_condition_scale and condition_scale is not None:
            logger.warning(
                "The current model has a trainable condition_scale, but condition_scale "
                "was passed externally. Ignoring the trainable condition_scale and "
                "using the external condition_scale=%s.",
                condition_scale,
            )

        scale = condition_scale if condition_scale is not None else self.condition_scale

        primary_hidden_states = primary_hidden_states + conditioned_features * scale

        return primary_hidden_states

    def forward(
        self,
        layer_idx: int,
        visual_hidden_states: torch.Tensor,
        audio_hidden_states: torch.Tensor,
        *,
        x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        a2v_condition_scale: Optional[float] = None,
        v2a_condition_scale: Optional[float] = None,
        condition_scale: Optional[float] = None,
        video_grid_size: Optional[Tuple[int, int, int]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Performs bidirectional conditional control for both visual and audio towers.

        Both directions read the *input* hidden states, so the audio update
        does not see the visual update computed in the same call.

        Args:
            layer_idx: Index of the DiT layer being processed.
            visual_hidden_states: Hidden states of the visual tower.
            audio_hidden_states: Hidden states of the audio tower.
            x_freqs: RoPE (cos, sin) for the visual sequence.
            y_freqs: RoPE (cos, sin) for the audio sequence.
            a2v_condition_scale: Scale for the audio-to-video update; falls
                back to ``condition_scale`` when None.
            v2a_condition_scale: Scale for the video-to-audio update; falls
                back to ``condition_scale`` when None.
            condition_scale: Shared fallback scale for both directions.
            video_grid_size: (T, H, W) grid of the visual tokens.

        Returns:
            ``(visual_conditioned, audio_conditioned)``.
        """
        visual_conditioned = self.apply_conditional_control(
            layer_idx=layer_idx,
            direction="a2v",
            primary_hidden_states=visual_hidden_states,
            condition_hidden_states=audio_hidden_states,
            x_freqs=x_freqs,
            y_freqs=y_freqs,
            condition_scale=(
                a2v_condition_scale
                if a2v_condition_scale is not None
                else condition_scale
            ),
            video_grid_size=video_grid_size,
        )

        # Roles are reversed here, so the RoPE tables are swapped: the audio
        # frequencies go with the queries, the visual ones with keys/values.
        audio_conditioned = self.apply_conditional_control(
            layer_idx=layer_idx,
            direction="v2a",
            primary_hidden_states=audio_hidden_states,
            condition_hidden_states=visual_hidden_states,
            x_freqs=y_freqs,
            y_freqs=x_freqs,
            condition_scale=(
                v2a_condition_scale
                if v2a_condition_scale is not None
                else condition_scale
            ),
            video_grid_size=video_grid_size,
        )

        return visual_conditioned, audio_conditioned
class CachableDiT(TeaCacheMixin, BaseDiT):
    """
    Intermediate base class for DiT models that want TeaCache support.

    Cache bookkeeping comes from TeaCacheMixin; everything else comes from
    BaseDiT. Concrete models are expected to override the class-level
    mappings below and to set the annotated instance attributes.
    """

    # Placeholders for the class attributes concrete implementations override
    _fsdp_shard_conditions = list()
    param_names_mapping = dict()
    reverse_param_names_mapping = dict()
    lora_param_names_mapping: dict = dict()
    # Instance attributes every subclass has to assign
    hidden_size: int
    num_attention_heads: int
    num_channels_latents: int
    # torch_sdpa is always supported
    _supported_attention_backends: set[AttentionBackendEnum] = (
        DiTConfig()._supported_attention_backends
    )

    def __init__(self, config: DiTConfig, **kwargs) -> None:
        """Run the base-class initialisation, then set up the TeaCache state."""
        super().__init__(config, **kwargs)
        self._init_teacache_state()

    @classmethod
    def get_nunchaku_quant_rules(cls) -> dict[str, dict[str, Any]]:
        """
        Describe how Nunchaku quantization should treat this model's layers.

        The expected layout maps rule names to layer-name patterns:
            {
                "skip": [list of patterns to skip quantization],
                "svdq_w4a4": [list of patterns for SVDQ W4A4],
                "awq_w4a16": [list of patterns for AWQ W4A16],
            }

        This base implementation declares no rules and returns an empty dict.
        """
        no_rules: dict[str, dict[str, Any]] = dict()
        return no_rules
class CausalWanSelfAttention(nn.Module):
    """Causal self-attention core used by the causal Wan transformer block.

    Two execution paths exist:

    * no KV cache: the whole sequence is attended at once through the compiled
      ``flex_attention`` kernel, with causality supplied by ``block_mask``;
    * with KV cache: new keys/values are written into a rolling cache and the
      query attends to the most recent ``max_attention_size`` cached tokens
      through ``LocalAttention``.

    The module has no projection weights; callers pass already-projected and
    already-normalized q/k/v.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        local_attn_size: int = -1,
        sink_size: int = 0,
        qk_norm=True,
        eps=1e-6,
        parallel_attention=False,
    ) -> None:
        """
        Args:
            dim: Model width; must be divisible by ``num_heads``.
            num_heads: Number of attention heads.
            local_attn_size: Attention window measured in frames, or -1 for
                the global default window.
            sink_size: Number of leading frames kept in the cache when older
                tokens are evicted.
            qk_norm, eps, parallel_attention: Stored on the instance; not
                used by this class's forward pass.
        """
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.local_attn_size = local_attn_size
        self.sink_size = sink_size
        self.qk_norm = qk_norm
        self.eps = eps
        self.parallel_attention = parallel_attention
        # NOTE(review): 32760 == 21 * 1560, so 1560 is presumably the token
        # count of one latent frame and 21 the default frame window -- confirm.
        self.max_attention_size = (
            32760 if local_attn_size == -1 else local_attn_size * 1560
        )

        # Scaled dot product attention
        self.attn = LocalAttention(
            num_heads=num_heads,
            head_size=self.head_dim,
            dropout_rate=0,
            softmax_scale=None,
            causal=False,
            supported_attention_backends=(
                AttentionBackendEnum.FA,
                AttentionBackendEnum.AITER,
                AttentionBackendEnum.TORCH_SDPA,
            ),
        )

    @staticmethod
    def _pad_sequence(
        tensor: torch.Tensor, pad_len: int, dtype: torch.dtype
    ) -> torch.Tensor:
        """Append ``pad_len`` zero positions along dim 1 of a 4-D tensor."""
        if pad_len == 0:
            return tensor
        padding = torch.zeros(
            [tensor.shape[0], pad_len, tensor.shape[2], tensor.shape[3]],
            device=tensor.device,
            dtype=dtype,
        )
        return torch.cat([tensor, padding], dim=1)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        freqs_cis: tuple[torch.Tensor, torch.Tensor],
        block_mask: BlockMask,
        kv_cache: dict | None = None,
        current_start: int = 0,
        cache_start: int | None = None,
    ):
        r"""
        Args:
            q, k, v: 4-D tensors with the sequence on dim 1
                (presumably [B, L, num_heads, head_dim] -- confirm with caller).
            freqs_cis: ``(cos, sin)`` rotary tables applied to q and k.
            block_mask: Mask handed to ``flex_attention``; only used when
                ``kv_cache`` is None.
            kv_cache: Optional dict with tensors ``"k"``, ``"v"`` (cache
                storage, sequence on dim 1) and scalar tensors
                ``"global_end_index"`` / ``"local_end_index"``. It is updated
                in place.
            current_start: Global index of the first token in this call.
            cache_start: Accepted for interface compatibility; not used.

        Returns:
            Attention output with the same sequence length as ``q``.
        """
        cos, sin = freqs_cis
        roped_query = _apply_rotary_emb(q, cos, sin, is_neox_style=False).type_as(v)
        roped_key = _apply_rotary_emb(k, cos, sin, is_neox_style=False).type_as(v)

        if kv_cache is None:
            # Padding for flex attention: round the sequence up to a multiple
            # of 128.
            seq_len = q.shape[1]
            padded_length = math.ceil(seq_len / 128) * 128 - seq_len
            padded_roped_query = self._pad_sequence(roped_query, padded_length, v.dtype)
            padded_roped_key = self._pad_sequence(roped_key, padded_length, v.dtype)
            padded_v = self._pad_sequence(v, padded_length, v.dtype)

            # Slice back to the original length with an explicit end index.
            # `[:-padded_length]` would return an empty tensor whenever the
            # sequence is already a multiple of 128 (padded_length == 0).
            x = flex_attention(
                query=padded_roped_query.transpose(2, 1),
                key=padded_roped_key.transpose(2, 1),
                value=padded_v.transpose(2, 1),
                block_mask=block_mask,
            )[:, :, :seq_len].transpose(2, 1)
            return x

        frame_seqlen = q.shape[1]
        num_new_tokens = roped_query.shape[1]
        current_end = current_start + num_new_tokens
        sink_tokens = self.sink_size * frame_seqlen
        kv_cache_size = kv_cache["k"].shape[1]

        # Read the cache cursors once; each .item() is a host read of a tensor
        # and the values do not change until the fill_ calls at the end.
        prev_local_end = kv_cache["local_end_index"].item()
        prev_global_end = kv_cache["global_end_index"].item()

        # If we are using local attention and the new tokens would overflow
        # the cache, the oldest non-sink tokens have to be dropped first.
        needs_eviction = (
            self.local_attn_size != -1
            and current_end > prev_global_end
            and num_new_tokens + prev_local_end > kv_cache_size
        )
        if needs_eviction:
            num_evicted_tokens = num_new_tokens + prev_local_end - kv_cache_size
            num_rolled_tokens = prev_local_end - num_evicted_tokens - sink_tokens
            keep_dst = slice(sink_tokens, sink_tokens + num_rolled_tokens)
            keep_src = slice(
                sink_tokens + num_evicted_tokens,
                sink_tokens + num_evicted_tokens + num_rolled_tokens,
            )
            # Shift existing cache content left to discard oldest tokens.
            # Clone the source slice to avoid overlapping memory error.
            kv_cache["k"][:, keep_dst] = kv_cache["k"][:, keep_src].clone()
            kv_cache["v"][:, keep_dst] = kv_cache["v"][:, keep_src].clone()
            local_end_index = (
                prev_local_end + current_end - prev_global_end - num_evicted_tokens
            )
        else:
            local_end_index = prev_local_end + current_end - prev_global_end
            kv_cache["k"] = kv_cache["k"].detach()
            kv_cache["v"] = kv_cache["v"].detach()

        # Insert the new keys/values at the end of the valid region.
        local_start_index = local_end_index - num_new_tokens
        kv_cache["k"][:, local_start_index:local_end_index] = roped_key
        kv_cache["v"][:, local_start_index:local_end_index] = v

        # Attend to at most the last max_attention_size cached tokens.
        window_start = max(0, local_end_index - self.max_attention_size)
        x = self.attn(
            roped_query,
            kv_cache["k"][:, window_start:local_end_index],
            kv_cache["v"][:, window_start:local_end_index],
        )
        kv_cache["global_end_index"].fill_(current_end)
        kv_cache["local_end_index"].fill_(local_end_index)

        return x
Self-attention + self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) + self.to_q = ReplicatedLinear(dim, dim, bias=True, quant_config=quant_config) + self.to_k = ReplicatedLinear(dim, dim, bias=True, quant_config=quant_config) + self.to_v = ReplicatedLinear(dim, dim, bias=True, quant_config=quant_config) + + self.to_out = ReplicatedLinear(dim, dim, bias=True, quant_config=quant_config) + self.attn1 = CausalWanSelfAttention( + dim, + num_heads, + local_attn_size=local_attn_size, + sink_size=sink_size, + qk_norm=qk_norm, + eps=eps, + ) + self.hidden_dim = dim + self.num_attention_heads = num_heads + self.local_attn_size = local_attn_size + dim_head = dim // num_heads + if qk_norm == "rms_norm": + self.norm_q = RMSNorm(dim_head, eps=eps) + self.norm_k = RMSNorm(dim_head, eps=eps) + elif qk_norm == "rms_norm_across_heads": + # LTX applies qk norm across all heads + self.norm_q = RMSNorm(dim, eps=eps) + self.norm_k = RMSNorm(dim, eps=eps) + else: + print("QK Norm type not supported") + raise Exception + assert cross_attn_norm is True + self.self_attn_residual_norm = ScaleResidualLayerNormScaleShift( + dim, eps=eps, elementwise_affine=True, dtype=torch.float32 + ) + + # 2. Cross-attention + # Only T2V for now + cross_attn_backends = { + b for b in supported_attention_backends if not b.is_sparse + } + self.attn2 = WanT2VCrossAttention( + dim, + num_heads, + qk_norm=qk_norm, + eps=eps, + supported_attention_backends=cross_attn_backends, + quant_config=quant_config, + ) + self.cross_attn_residual_norm = ScaleResidualLayerNormScaleShift( + dim, eps=eps, elementwise_affine=False, dtype=torch.float32 + ) + + # 3. 
Feed-forward + self.ffn = MLP( + dim, ffn_dim, act_type="gelu_pytorch_tanh", quant_config=quant_config + ) + self.mlp_residual = MulAdd() + + self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + freqs_cis: tuple[torch.Tensor, torch.Tensor], + block_mask: BlockMask, + kv_cache: dict | None = None, + crossattn_cache: dict | None = None, + current_start: int = 0, + cache_start: int | None = None, + ) -> torch.Tensor: + # hidden_states.shape: [batch_size, seq_length, inner_dim] + # temb.shape: [batch_size, num_frames, 6, inner_dim] + if hidden_states.dim() == 4: + hidden_states = hidden_states.squeeze(1) + num_frames = temb.shape[1] + frame_seqlen = hidden_states.shape[1] // num_frames + bs, seq_length, _ = hidden_states.shape + orig_dtype = hidden_states.dtype + # assert orig_dtype != torch.float32 + e = self.scale_shift_table + temb.float() + # e.shape: [batch_size, num_frames, 6, inner_dim] + assert e.shape == (bs, num_frames, 6, self.hidden_dim) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = e.chunk( + 6, dim=2 + ) + # *_msa.shape: [batch_size, num_frames, 1, inner_dim] + assert shift_msa.dtype == torch.float32 + + # 1. 
Self-attention + norm_hidden_states = ( + ( + self.norm1(hidden_states.float()).unflatten( + dim=1, sizes=(num_frames, frame_seqlen) + ) + * (1 + scale_msa) + + shift_msa + ) + .flatten(1, 2) + .to(orig_dtype) + ) + query, _ = self.to_q(norm_hidden_states) + key, _ = self.to_k(norm_hidden_states) + value, _ = self.to_v(norm_hidden_states) + + if self.norm_q is not None: + query = self.norm_q(query) + if self.norm_k is not None: + key = self.norm_k(key) + + query = query.squeeze(1).unflatten(2, (self.num_attention_heads, -1)) + key = key.squeeze(1).unflatten(2, (self.num_attention_heads, -1)) + value = value.squeeze(1).unflatten(2, (self.num_attention_heads, -1)) + + attn_output = self.attn1( + query, + key, + value, + freqs_cis, + block_mask, + kv_cache, + current_start, + cache_start, + ) + attn_output = attn_output.flatten(2) + attn_output, _ = self.to_out(attn_output) + attn_output = attn_output.squeeze(1) + + null_shift = null_scale = torch.zeroes( + (1,), device=hidden_states.device, dtype=hidden_states.dtype + ) + norm_hidden_states, hidden_states = self.self_attn_residual_norm( + hidden_states, attn_output, gate_msa, null_shift, null_scale + ) + norm_hidden_states, hidden_states = norm_hidden_states.to( + orig_dtype + ), hidden_states.to(orig_dtype) + + # 2. Cross-attention + attn_output = self.attn2( + norm_hidden_states, + context=encoder_hidden_states, + context_lens=None, + crossattn_cache=crossattn_cache, + ) + norm_hidden_states, hidden_states = self.cross_attn_residual_norm( + hidden_states, attn_output, 1, c_shift_msa, c_scale_msa + ) + norm_hidden_states, hidden_states = norm_hidden_states.to( + orig_dtype + ), hidden_states.to(orig_dtype) + + # 3. 
Feed-forward + ff_output = self.ffn(norm_hidden_states) + hidden_states = self.mlp_residual(ff_output, c_gate_msa, hidden_states) + hidden_states = hidden_states.to(orig_dtype) + + return hidden_states + + +class CausalWanTransformer3DModel(BaseDiT, OffloadableDiTMixin): + _fsdp_shard_conditions = WanVideoConfig()._fsdp_shard_conditions + _compile_conditions = WanVideoConfig()._compile_conditions + _supported_attention_backends = WanVideoConfig()._supported_attention_backends + param_names_mapping = WanVideoConfig().param_names_mapping + reverse_param_names_mapping = WanVideoConfig().reverse_param_names_mapping + lora_param_names_mapping = WanVideoConfig().lora_param_names_mapping + + def __init__( + self, + config: WanVideoConfig, + hf_config: dict[str, Any], + quant_config: QuantizationConfig | None = None, + ) -> None: + super().__init__(config=config, hf_config=hf_config) + + inner_dim = config.num_attention_heads * config.attention_head_dim + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.attention_head_dim = config.attention_head_dim + self.in_channels = config.in_channels + self.out_channels = config.out_channels + self.num_channels_latents = config.num_channels_latents + self.patch_size = config.patch_size + self.text_len = config.text_len + self.local_attn_size = config.local_attn_size + + # 1. Patch & position embedding + self.patch_embedding = PatchEmbed( + in_chans=config.in_channels, + embed_dim=inner_dim, + patch_size=config.patch_size, + flatten=False, + ) + + # 2. Condition embeddings + self.condition_embedder = WanTimeTextImageEmbedding( + dim=inner_dim, + time_freq_dim=config.freq_dim, + text_embed_dim=config.text_dim, + image_embed_dim=config.image_dim, + ) + + # 3. 
Transformer blocks + self.blocks = nn.ModuleList( + [ + CausalWanTransformerBlock( + inner_dim, + config.ffn_dim, + config.num_attention_heads, + config.local_attn_size, + config.sink_size, + config.qk_norm, + config.cross_attn_norm, + config.eps, + config.added_kv_proj_dim, + self._supported_attention_backends, + prefix=f"{config.prefix}.blocks.{i}", + quant_config=quant_config, + ) + for i in range(config.num_layers) + ] + ) + + # 4. Output norm & projection + self.norm_out = LayerNormScaleShift( + inner_dim, + eps=config.eps, + elementwise_affine=False, + dtype=torch.float32, + ) + self.proj_out = nn.Linear( + inner_dim, config.out_channels * math.prod(config.patch_size) + ) + self.scale_shift_table = nn.Parameter( + torch.randn(1, 2, inner_dim) / inner_dim**0.5 + ) + + self.gradient_checkpointing = False + + # Causal-specific + self.block_mask = None + self.num_frame_per_block = config.arch_config.num_frames_per_block + assert self.num_frame_per_block <= 3 + self.independent_first_frame = False + + self.__post_init__() + + self.layer_names = [ + "blocks", + ] + + @staticmethod + def _prepare_blockwise_causal_attn_mask( + device: torch.device | str, + num_frames: int = 21, + frame_seqlen: int = 1560, + num_frame_per_block=1, + local_attn_size=-1, + ) -> BlockMask: + """ + we will divide the token sequence into the following format + [1 latent frame] [1 latent frame] ... 
[1 latent frame] + We use flexattention to construct the attention mask + """ + total_length = num_frames * frame_seqlen + + # we do right padding to get to a multiple of 128 + padded_length = math.ceil(total_length / 128) * 128 - total_length + + ends = torch.zeros( + total_length + padded_length, device=device, dtype=torch.long + ) + + # Block-wise causal mask will attend to all elements that are before the end of the current chunk + frame_indices = torch.arange( + start=0, + end=total_length, + step=frame_seqlen * num_frame_per_block, + device=device, + ) + + for tmp in frame_indices: + ends[tmp : tmp + frame_seqlen * num_frame_per_block] = ( + tmp + frame_seqlen * num_frame_per_block + ) + + def attention_mask(b, h, q_idx, kv_idx): + if local_attn_size == -1: + return (kv_idx < ends[q_idx]) | (q_idx == kv_idx) + else: + return ( + (kv_idx < ends[q_idx]) + & (kv_idx >= (ends[q_idx] - local_attn_size * frame_seqlen)) + ) | (q_idx == kv_idx) + # return ((kv_idx < total_length) & (q_idx < total_length)) | (q_idx == kv_idx) # bidirectional mask + + block_mask = create_block_mask( + attention_mask, + B=None, + H=None, + Q_LEN=total_length + padded_length, + KV_LEN=total_length + padded_length, + _compile=False, + device=device, + ) + + if not dist.is_initialized() or dist.get_rank() == 0: + print( + f" cache a block wise causal mask with block size of {num_frame_per_block} frames" + ) + print(block_mask) + + # import imageio + # import numpy as np + # from torch.nn.attention.flex_attention import create_mask + + # mask = create_mask(attention_mask, B=None, H=None, Q_LEN=total_length + + # padded_length, KV_LEN=total_length + padded_length, device=device) + # import cv2 + # mask = cv2.resize(mask[0, 0].cpu().float().numpy(), (1024, 1024)) + # imageio.imwrite("mask_%d.jpg" % (0), np.uint8(255. 
* mask)) + + return block_mask + + def _forward_inference( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | list[torch.Tensor], + timestep: torch.LongTensor, + encoder_hidden_states_image: torch.Tensor | list[torch.Tensor] | None = None, + kv_cache: dict = None, + crossattn_cache: dict = None, + current_start: int = 0, + cache_start: int = 0, + start_frame: int = 0, + **kwargs, + ) -> torch.Tensor: + r""" + Run the diffusion model with kv caching. + See Algorithm 2 of CausVid paper https://arxiv.org/abs/2412.07772 for details. + This function will be run for num_frame times. + Process the latent frames one by one (1560 tokens each) + """ + + orig_dtype = hidden_states.dtype + if not isinstance(encoder_hidden_states, torch.Tensor): + encoder_hidden_states = encoder_hidden_states[0] + if ( + isinstance(encoder_hidden_states_image, list) + and len(encoder_hidden_states_image) > 0 + ): + encoder_hidden_states_image = encoder_hidden_states_image[0] + else: + encoder_hidden_states_image = None + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + # Get rotary embeddings + d = self.hidden_size // self.num_attention_heads + rope_dim_list = [d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)] + freqs_cos, freqs_sin = get_rotary_pos_embed( + ( + post_patch_num_frames * get_sp_world_size(), + post_patch_height, + post_patch_width, + ), + self.hidden_size, + self.num_attention_heads, + rope_dim_list, + dtype=( + torch.float32 + if current_platform.is_mps() or current_platform.is_musa() + else torch.float64 + ), + rope_theta=10000, + start_frame=start_frame, # Assume that start_frame is 0 when kv_cache is None + ) + freqs_cos = freqs_cos.to(hidden_states.device) + freqs_sin = freqs_sin.to(hidden_states.device) + freqs_cis = ( + (freqs_cos.float(), freqs_sin.float()) if freqs_cos is 
not None else None + ) + + hidden_states = self.patch_embedding(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) + + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = ( + self.condition_embedder( + timestep.flatten(), encoder_hidden_states, encoder_hidden_states_image + ) + ) + timestep_proj = timestep_proj.unflatten(1, (6, self.hidden_size)).unflatten( + dim=0, sizes=timestep.shape + ) + + if encoder_hidden_states_image is not None: + encoder_hidden_states = torch.concat( + [encoder_hidden_states_image, encoder_hidden_states], dim=1 + ) + + encoder_hidden_states = ( + encoder_hidden_states.to(orig_dtype) + if current_platform.is_mps() + else encoder_hidden_states + ) # cast to orig_dtype for MPS + + assert encoder_hidden_states.dtype == orig_dtype + + # 4. Transformer blocks + for block_index, block in enumerate(self.blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + causal_kwargs = { + "kv_cache": kv_cache[block_index], + "current_start": current_start, + "cache_start": cache_start, + "block_mask": self.block_mask, + } + hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + timestep_proj, + freqs_cis, + **causal_kwargs, + ) + else: + causal_kwargs = { + "kv_cache": kv_cache[block_index], + "crossattn_cache": crossattn_cache[block_index], + "current_start": current_start, + "cache_start": cache_start, + "block_mask": self.block_mask, + } + hidden_states = block( + hidden_states, + encoder_hidden_states, + timestep_proj, + freqs_cis, + **causal_kwargs, + ) + + # 5. 
def _forward_train(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor | list[torch.Tensor],
    timestep: torch.LongTensor,
    encoder_hidden_states_image: torch.Tensor | list[torch.Tensor] | None = None,
    start_frame: int = 0,
    **kwargs,
) -> torch.Tensor:
    """Run the model without a KV cache, using the block-wise causal mask.

    Steps: rotary tables, lazily built causal mask, patch embedding,
    condition embedding, transformer blocks, output norm / projection and
    un-patchify.

    Args:
        hidden_states: 5-D input unpacked as (batch, channels, frames, height, width).
        encoder_hidden_states: Text embedding tensor, or a sequence whose first
            element is used.
        timestep: Timesteps; flattened for the embedder and its shape is used
            to unflatten the resulting embeddings.
        encoder_hidden_states_image: Only a non-empty ``list`` is used (its
            first element); every other value is treated as ``None``.
        start_frame: Forwarded to ``get_rotary_pos_embed``.
        **kwargs: Ignored.

    Returns:
        The un-patchified model output.
    """
    input_dtype = hidden_states.dtype

    if not isinstance(encoder_hidden_states, torch.Tensor):
        encoder_hidden_states = encoder_hidden_states[0]

    # NOTE(review): a plain tensor passed here is discarded as well - confirm
    # this is intended given the type hint allows a tensor.
    has_image_list = (
        isinstance(encoder_hidden_states_image, list)
        and len(encoder_hidden_states_image) > 0
    )
    encoder_hidden_states_image = (
        encoder_hidden_states_image[0] if has_image_list else None
    )

    batch_size, _, num_frames, height, width = hidden_states.shape
    p_t, p_h, p_w = self.patch_size
    grid_t = num_frames // p_t
    grid_h = height // p_h
    grid_w = width // p_w

    # Rotary embeddings: split the head dim over the (t, h, w) axes.
    head_dim = self.hidden_size // self.num_attention_heads
    sixth = head_dim // 6
    rope_dim_list = [head_dim - 4 * sixth, 2 * sixth, 2 * sixth]
    rope_grid = (grid_t * get_sp_world_size(), grid_h, grid_w)
    if current_platform.is_mps() or current_platform.is_musa():
        rope_dtype = torch.float32
    else:
        rope_dtype = torch.float64
    freqs_cos, freqs_sin = get_rotary_pos_embed(
        rope_grid,
        self.hidden_size,
        self.num_attention_heads,
        rope_dim_list,
        dtype=rope_dtype,
        rope_theta=10000,
        start_frame=start_frame,
    )
    freqs_cos = freqs_cos.to(hidden_states.device)
    freqs_sin = freqs_sin.to(hidden_states.device)
    if freqs_cos is not None:
        freqs_cis = (freqs_cos.float(), freqs_sin.float())
    else:
        freqs_cis = None

    # The causal mask is built on first use and then reused.
    # NOTE(review): it is not rebuilt if later inputs have another shape - confirm.
    if self.block_mask is None:
        self.block_mask = self._prepare_blockwise_causal_attn_mask(
            device=hidden_states.device,
            num_frames=num_frames,
            frame_seqlen=grid_h * grid_w,
            num_frame_per_block=self.num_frame_per_block,
            local_attn_size=self.local_attn_size,
        )

    # Patchify and flatten everything after dim 1 into a token axis.
    hidden_states = self.patch_embedding(hidden_states)
    hidden_states = hidden_states.flatten(2).transpose(1, 2)

    embedded = self.condition_embedder(
        timestep.flatten(), encoder_hidden_states, encoder_hidden_states_image
    )
    temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = embedded
    timestep_proj = timestep_proj.unflatten(1, (6, self.hidden_size))
    timestep_proj = timestep_proj.unflatten(dim=0, sizes=timestep.shape)

    if encoder_hidden_states_image is not None:
        encoder_hidden_states = torch.concat(
            [encoder_hidden_states_image, encoder_hidden_states], dim=1
        )

    # cast to orig_dtype for MPS
    if current_platform.is_mps():
        encoder_hidden_states = encoder_hidden_states.to(input_dtype)

    assert encoder_hidden_states.dtype == input_dtype

    # 4. Transformer blocks
    use_checkpointing = torch.is_grad_enabled() and self.gradient_checkpointing
    for blk in self.blocks:
        if use_checkpointing:
            hidden_states = self._gradient_checkpointing_func(
                blk,
                hidden_states,
                encoder_hidden_states,
                timestep_proj,
                freqs_cis,
                block_mask=self.block_mask,
            )
        else:
            hidden_states = blk(
                hidden_states,
                encoder_hidden_states,
                timestep_proj,
                freqs_cis,
                block_mask=self.block_mask,
            )

    # 5. Output norm, projection & unpatchify
    temb = temb.unflatten(dim=0, sizes=timestep.shape).unsqueeze(2)
    modulation = self.scale_shift_table.unsqueeze(1) + temb
    shift, scale = modulation.chunk(2, dim=2)
    hidden_states = self.proj_out(self.norm_out(hidden_states, shift, scale))

    patches = hidden_states.reshape(
        batch_size, grid_t, grid_h, grid_w, p_t, p_h, p_w, -1
    )
    # Channels first, then interleave each grid axis with its patch axis.
    patches = patches.permute(0, 7, 1, 4, 2, 5, 3, 6)
    return patches.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+ +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from diffusers.models.attention import AttentionModuleMixin +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.normalization import ( + AdaLayerNormContinuous, + AdaLayerNormZero, + AdaLayerNormZeroSingle, +) +from torch.nn import LayerNorm as LayerNorm + +from sglang.multimodal_gen.configs.models.dits.flux import FluxConfig +from sglang.multimodal_gen.runtime.layers.attention import USPAttention +from sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm, apply_qk_norm +from sglang.multimodal_gen.runtime.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, +) +from sglang.multimodal_gen.runtime.layers.mlp import FeedForward +from sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import ( + QuantizationConfig, +) +from sglang.multimodal_gen.runtime.layers.quantization.configs.nunchaku_config import ( + NunchakuConfig, + is_nunchaku_available, +) +from sglang.multimodal_gen.runtime.layers.rotary_embedding import ( + NDRotaryEmbedding, + apply_flashinfer_rope_qk_inplace, +) +from sglang.multimodal_gen.runtime.layers.visual_embedding import ( + CombinedTimestepGuidanceTextProjEmbeddings, + CombinedTimestepTextProjEmbeddings, +) +from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT +from sglang.multimodal_gen.runtime.platforms import current_platform +from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) # pylint: disable=invalid-name + +try: + from nunchaku.models.attention import NunchakuFeedForward # type: ignore[import] + from nunchaku.models.normalization import ( # type: ignore[import] + NunchakuAdaLayerNormZero, + NunchakuAdaLayerNormZeroSingle, + ) + from nunchaku.ops.gemm import ( + svdq_gemm_w4a4_cuda as 
def _fused_gelu_mlp(
    x: torch.Tensor,
    fc1,
    fc2,
    pad_size: int = 256,
) -> torch.Tensor:
    """Evaluate ``fc2(gelu(fc1(x)))`` through nunchaku's fused W4A4 kernels.

    The single-block MLP checkpoint is calibrated for nunchaku's fused path:

    1. one kernel call performs the fc1 GEMM, GELU, the 0.171875 shift, the
       unsigned re-quantization and the ``fc2`` LoRA down-projection;
    2. the fc2 GEMM then consumes unsigned INT4 activations
       (``act_unsigned=True``).

    Running fc1 -> GELU -> fc2 sequentially with symmetric quantization does
    not match these weight scales and yields visually wrong outputs.

    Args:
        x: Input unpacked as (batch, sequence, channels).
        fc1: First quantized linear layer.
        fc2: Second quantized linear layer.
        pad_size: Row padding granularity passed to the quantizer; intermediate
            buffers use the token count rounded up to a multiple of it.

    Returns:
        Tensor of shape (batch, sequence, -1) holding the fc2 output.
    """
    bsz, tokens, feat = x.shape
    num_rows = bsz * tokens
    flat = x.view(num_rows, feat)
    dev = flat.device

    act_q, act_scales, lora_mid = _svdq_quantize_w4a4(
        flat,
        lora_down=fc1.proj_down,
        smooth=fc1.smooth_factor,
        fp4=fc1.precision == "nvfp4",
        pad_size=pad_size,
    )

    padded_rows = (num_rows + pad_size - 1) // pad_size * pad_size
    use_fp4 = fc2.precision == "nvfp4"

    # Buffers written by the fused fc1 kernel.
    hidden_q = torch.empty(
        padded_rows,
        fc1.output_size_per_partition // 2,
        dtype=torch.uint8,
        device=dev,
    )
    # The scale buffer differs per precision in both group size and dtype.
    if use_fp4:
        scale_rows = fc1.output_size_per_partition // 16
        scale_dtype = torch.float8_e4m3fn
    else:
        scale_rows = fc1.output_size_per_partition // 64
        scale_dtype = flat.dtype
    hidden_scales = torch.empty(
        scale_rows, padded_rows, dtype=scale_dtype, device=dev
    )
    hidden_lora = torch.empty(
        padded_rows, fc2.proj_down.shape[1], dtype=torch.float32, device=dev
    )

    # fused: fc1 GEMM + GELU + shift + unsigned quantize + fc2.lora_down
    _svdq_gemm_w4a4(
        act=act_q,
        wgt=fc1.qweight,
        qout=hidden_q,
        ascales=act_scales,
        wscales=fc1.wscales,
        oscales=hidden_scales,
        lora_act_in=lora_mid,
        lora_up=fc1.proj_up,
        lora_down=fc2.proj_down,
        lora_act_out=hidden_lora,
        bias=fc1.bias,
        smooth_factor=fc2.smooth_factor,
        fp4=use_fp4,
        alpha=getattr(fc1, "_nunchaku_alpha", None),
        wcscales=getattr(fc1, "wcscales", None),
    )

    result = torch.empty(
        num_rows,
        fc2.output_size_per_partition,
        dtype=flat.dtype,
        device=dev,
    )
    # fc2 GEMM with unsigned INT4 activations (fused kernel shifted by 0.171875)
    _svdq_gemm_w4a4(
        act=hidden_q,
        wgt=fc2.qweight,
        out=result,
        ascales=hidden_scales,
        wscales=fc2.wscales,
        lora_act_in=hidden_lora,
        lora_up=fc2.proj_up,
        bias=fc2.bias,
        fp4=use_fp4,
        alpha=getattr(fc2, "_nunchaku_alpha", None),
        wcscales=getattr(fc2, "wcscales", None),
        act_unsigned=True,
    )

    return result.view(bsz, tokens, -1)
def _get_qkv_projections(
    attn: "FluxAttention", hidden_states, encoder_hidden_states=None
):
    """Project the image stream, and optionally the context stream, to Q/K/V.

    Args:
        attn: Attention module owning the projection layers.
        hidden_states: Input of the main stream.
        encoder_hidden_states: Optional context input.

    Returns:
        ``(query, key, value, encoder_query, encoder_key, encoder_value)``.
        The three encoder entries are ``None`` unless ``encoder_hidden_states``
        is given and ``attn.added_kv_proj_dim`` is not ``None``.
    """
    if getattr(attn, "use_fused_qkv", False):
        fused, _ = attn.to_qkv(hidden_states)
        query, key, value = (part.contiguous() for part in fused.chunk(3, dim=-1))
    else:
        query, _ = attn.to_q(hidden_states)
        key, _ = attn.to_k(hidden_states)
        value, _ = attn.to_v(hidden_states)

    if encoder_hidden_states is None or attn.added_kv_proj_dim is None:
        return query, key, value, None, None, None

    if getattr(attn, "use_fused_added_qkv", False):
        fused_ctx, _ = attn.to_added_qkv(encoder_hidden_states)
        encoder_query, encoder_key, encoder_value = (
            part.contiguous() for part in fused_ctx.chunk(3, dim=-1)
        )
    else:
        encoder_query, _ = attn.add_q_proj(encoder_hidden_states)
        encoder_key, _ = attn.add_k_proj(encoder_hidden_states)
        encoder_value, _ = attn.add_v_proj(encoder_hidden_states)

    return query, key, value, encoder_query, encoder_key, encoder_value


class FluxAttention(torch.nn.Module, AttentionModuleMixin):
    """Flux attention over the image stream, optionally joined with a context stream.

    With ``added_kv_proj_dim`` set, the context tokens get their own Q/K/V
    projections and are concatenated in front of the image tokens before
    attention.
    """

    def __init__(
        self,
        query_dim: int,
        num_heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        added_kv_proj_dim: Optional[int] = None,
        added_proj_bias: Optional[bool] = True,
        out_bias: bool = True,
        eps: float = 1e-5,
        out_dim: int = None,
        context_pre_only: Optional[bool] = None,
        pre_only: bool = False,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Create projections, norms and the attention backend.

        Args:
            query_dim: Input width of the main stream.
            num_heads: Head count; used when ``out_dim`` is not given.
            dim_head: Width of one head.
            dropout: If non-zero, a ``Dropout`` is appended to ``to_out``.
            bias: Bias flag of the main Q/K/V projections.
            added_kv_proj_dim: Context width; ``None`` disables the context path.
            added_proj_bias: Bias flag of the context Q/K/V projections.
            out_bias: Bias flag of the output projections.
            eps: Epsilon of the RMSNorm layers.
            out_dim: If given, sets the inner width and ``heads = out_dim // dim_head``.
            context_pre_only: Stored on the module; not read here.
            pre_only: If ``True`` no ``to_out`` projection is created or applied.
            quant_config: Quantisation config; a ``NunchakuConfig`` selects the
                fused (merged) QKV layers.
            prefix: Name prefix for the layers that receive one.
        """
        super().__init__()

        self.head_dim = dim_head
        self.inner_dim = out_dim if out_dim is not None else dim_head * num_heads
        self.query_dim = query_dim
        self.use_bias = bias
        self.dropout = dropout
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.context_pre_only = context_pre_only
        self.pre_only = pre_only
        self.heads = out_dim // dim_head if out_dim is not None else num_heads
        self.added_kv_proj_dim = added_kv_proj_dim
        self.added_proj_bias = added_proj_bias

        # Both streams use fused QKV exactly when running under Nunchaku.
        fused = isinstance(quant_config, NunchakuConfig)
        self.use_fused_qkv = fused
        self.use_fused_added_qkv = fused

        self.norm_q = RMSNorm(dim_head, eps=eps)
        self.norm_k = RMSNorm(dim_head, eps=eps)

        def separate_proj(in_features: int, use_bias) -> ColumnParallelLinear:
            # Un-prefixed single Q/K/V projection shared by both streams.
            return ColumnParallelLinear(
                in_features,
                self.inner_dim,
                bias=use_bias,
                gather_output=True,
                quant_config=quant_config,
            )

        if self.use_fused_qkv:
            self.to_qkv = MergedColumnParallelLinear(
                query_dim,
                [self.inner_dim] * 3,
                bias=bias,
                gather_output=True,
                quant_config=quant_config,
                prefix=f"{prefix}.to_qkv" if prefix else "to_qkv",
            )
        else:
            self.to_q = separate_proj(query_dim, bias)
            self.to_k = separate_proj(query_dim, bias)
            self.to_v = separate_proj(query_dim, bias)
        if not self.pre_only:
            self.to_out = torch.nn.ModuleList([])
            self.to_out.append(
                ColumnParallelLinear(
                    self.inner_dim,
                    self.out_dim,
                    bias=out_bias,
                    gather_output=True,
                    quant_config=quant_config,
                    prefix=f"{prefix}.to_out.0" if prefix else "",
                )
            )
            if dropout != 0.0:
                self.to_out.append(torch.nn.Dropout(dropout))

        if added_kv_proj_dim is not None:
            self.norm_added_q = RMSNorm(dim_head, eps=eps)
            self.norm_added_k = RMSNorm(dim_head, eps=eps)
            if self.use_fused_added_qkv:
                self.to_added_qkv = MergedColumnParallelLinear(
                    added_kv_proj_dim,
                    [self.inner_dim] * 3,
                    bias=added_proj_bias,
                    gather_output=True,
                    quant_config=quant_config,
                    prefix=f"{prefix}.to_added_qkv" if prefix else "to_added_qkv",
                )
            else:
                self.add_q_proj = separate_proj(added_kv_proj_dim, added_proj_bias)
                self.add_k_proj = separate_proj(added_kv_proj_dim, added_proj_bias)
                self.add_v_proj = separate_proj(added_kv_proj_dim, added_proj_bias)
            self.to_add_out = ColumnParallelLinear(
                self.inner_dim,
                query_dim,
                bias=out_bias,
                gather_output=True,
                quant_config=quant_config,
                prefix=f"{prefix}.to_add_out" if prefix else "",
            )

        # NOTE(review): receives ``num_heads`` while tensors are split with
        # ``self.heads``; the two differ if ``out_dim != num_heads * dim_head``.
        self.attn = USPAttention(
            num_heads=num_heads,
            head_size=self.head_dim,
            dropout_rate=0,
            softmax_scale=None,
            causal=False,
        )

    def forward(
        self,
        x: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        freqs_cis=None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Attend over ``x``, jointly with ``encoder_hidden_states`` when given.

        Args:
            x: Main-stream input.
            encoder_hidden_states: Optional context input.
            freqs_cis: Optional ``(cos, sin)`` rotary tables.

        Returns:
            ``x`` alone when ``encoder_hidden_states`` is ``None``; otherwise the
            tuple ``(x, encoder_hidden_states)`` of per-stream outputs.
        """
        query, key, value, encoder_query, encoder_key, encoder_value = (
            _get_qkv_projections(self, x, encoder_hidden_states)
        )

        query = query.unflatten(-1, (self.heads, -1))
        key = key.unflatten(-1, (self.heads, -1))
        value = value.unflatten(-1, (self.heads, -1))
        query, key = apply_qk_norm(
            q=query,
            k=key,
            q_norm=self.norm_q,
            k_norm=self.norm_k,
            head_dim=self.head_dim,
            allow_inplace=True,
        )

        # Test the projections themselves: they are None when no context was
        # passed, even on a module built with ``added_kv_proj_dim``.  The old
        # ``self.added_kv_proj_dim is not None`` test crashed on ``None.unflatten``.
        if encoder_query is not None:
            encoder_query = encoder_query.unflatten(-1, (self.heads, -1))
            encoder_key = encoder_key.unflatten(-1, (self.heads, -1))
            encoder_value = encoder_value.unflatten(-1, (self.heads, -1))

            encoder_query, encoder_key = apply_qk_norm(
                q=encoder_query,
                k=encoder_key,
                q_norm=self.norm_added_q,
                k_norm=self.norm_added_k,
                head_dim=self.head_dim,
                allow_inplace=True,
            )

            # Context tokens go first; the split further down relies on that order.
            query = torch.cat([encoder_query, query], dim=1)
            key = torch.cat([encoder_key, key], dim=1)
            value = torch.cat([encoder_value, value], dim=1)

        if freqs_cis is not None:
            cos, sin = freqs_cis
            cos_sin_cache = torch.cat(
                [
                    cos.to(dtype=torch.float32).contiguous(),
                    sin.to(dtype=torch.float32).contiguous(),
                ],
                dim=-1,
            )
            query, key = apply_flashinfer_rope_qk_inplace(
                query, key, cos_sin_cache, is_neox=False
            )

        x = self.attn(query, key, value)
        x = x.flatten(2, 3)
        x = x.to(query.dtype)

        if encoder_hidden_states is not None:
            context_len = encoder_hidden_states.shape[1]
            encoder_hidden_states, x = x.split_with_sizes(
                [context_len, x.shape[1] - context_len],
                dim=1,
            )
            if not self.pre_only:
                x, _ = self.to_out[0](x)
                if len(self.to_out) == 2:
                    x = self.to_out[1](x)
                encoder_hidden_states, _ = self.to_add_out(encoder_hidden_states)

            return x, encoder_hidden_states

        if not self.pre_only:
            x, _ = self.to_out[0](x)
            if len(self.to_out) == 2:
                x = self.to_out[1](x)
        return x
+ self.mlp_hidden_dim, + dim, + bias=True, + gather_output=True, + quant_config=quant_config, + ) + self.attn = FluxAttention( + query_dim=dim, + dim_head=attention_head_dim, + num_heads=num_attention_heads, + out_dim=dim, + bias=True, + eps=1e-6, + pre_only=True, + quant_config=quant_config, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + text_seq_len = encoder_hidden_states.shape[1] + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + residual = hidden_states + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + joint_attention_kwargs = joint_attention_kwargs or {} + + if self.use_nunchaku_structure: + if _nunchaku_fused_ops_available: + mlp_hidden_states = _fused_gelu_mlp( + norm_hidden_states, self.mlp_fc1, self.mlp_fc2 + ) + else: + mlp_out, _ = self.mlp_fc1(norm_hidden_states) + mlp_hidden_states = self.act_mlp(mlp_out) + mlp_hidden_states, _ = self.mlp_fc2(mlp_hidden_states) + + attn_output = self.attn( + x=norm_hidden_states, + freqs_cis=freqs_cis, + **joint_attention_kwargs, + ) + if isinstance(attn_output, tuple): + attn_output = attn_output[0] + + hidden_states = attn_output + mlp_hidden_states + gate = gate.unsqueeze(1) + hidden_states = gate * hidden_states + hidden_states = residual + hidden_states + else: + proj_hidden_states, _ = self.proj_mlp(norm_hidden_states) + mlp_hidden_states = self.act_mlp(proj_hidden_states) + + attn_output = self.attn( + x=norm_hidden_states, + freqs_cis=freqs_cis, + **joint_attention_kwargs, + ) + + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + gate = gate.unsqueeze(1) + proj_out, _ = self.proj_out(hidden_states) + hidden_states = gate * proj_out + hidden_states = residual + hidden_states + + if hidden_states.dtype == 
torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + encoder_hidden_states, hidden_states = ( + hidden_states[:, :text_seq_len], + hidden_states[:, text_seq_len:], + ) + return encoder_hidden_states, hidden_states + + +class FluxTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + qk_norm: str = "rms_norm", + eps: float = 1e-6, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + self.norm1 = AdaLayerNormZero(dim) + self.norm1_context = AdaLayerNormZero(dim) + + self.attn = FluxAttention( + query_dim=dim, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + num_heads=num_attention_heads, + out_dim=dim, + context_pre_only=False, + bias=True, + eps=eps, + quant_config=quant_config, + prefix=f"{prefix}.attn" if prefix else "attn", + ) + + self.norm2 = LayerNorm(dim, eps=1e-6, elementwise_affine=False) + self.norm2_context = LayerNorm(dim, eps=1e-6, elementwise_affine=False) + + nunchaku_enabled = ( + quant_config is not None + and hasattr(quant_config, "get_name") + and quant_config.get_name() == "svdquant" + and is_nunchaku_available() + ) + self.use_nunchaku_structure = nunchaku_enabled + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + self.ff_context = FeedForward( + dim=dim, + dim_out=dim, + activation_fn="gelu-approximate", + ) + if nunchaku_enabled: + nunchaku_kwargs = { + "precision": quant_config.precision, + "rank": quant_config.rank, + "act_unsigned": quant_config.act_unsigned, + } + self.ff = NunchakuFeedForward(self.ff, **nunchaku_kwargs) + self.ff_context = NunchakuFeedForward(self.ff_context, **nunchaku_kwargs) + self.norm1 = NunchakuAdaLayerNormZero(self.norm1, scale_shift=0) + self.norm1_context = NunchakuAdaLayerNormZero( + self.norm1_context, scale_shift=0 + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + 
freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, emb=temb + ) + + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = ( + self.norm1_context(encoder_hidden_states, emb=temb) + ) + + joint_attention_kwargs = joint_attention_kwargs or {} + # Attention. + attention_outputs = self.attn( + x=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + freqs_cis=freqs_cis, + **joint_attention_kwargs, + ) + + if len(attention_outputs) == 2: + attn_output, context_attn_output = attention_outputs + elif len(attention_outputs) == 3: + attn_output, context_attn_output, ip_attn_output = attention_outputs + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + norm_hidden_states = self.norm2(hidden_states) + if self.use_nunchaku_structure: + norm_hidden_states = ( + norm_hidden_states * scale_mlp[:, None] + shift_mlp[:, None] + ) + else: + norm_hidden_states = ( + norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ) + + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + + if len(attention_outputs) == 3: + hidden_states = hidden_states + ip_attn_output + # Process attention outputs for the `encoder_hidden_states`. 
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + if self.use_nunchaku_structure: + norm_encoder_hidden_states = ( + norm_encoder_hidden_states * c_scale_mlp[:, None] + c_shift_mlp[:, None] + ) + else: + norm_encoder_hidden_states = ( + norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + + c_shift_mlp[:, None] + ) + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = ( + encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + ) + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class FluxPosEmbed(nn.Module): + # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 + def __init__(self, theta: int, axes_dim: List[int]): + super().__init__() + self.rope = NDRotaryEmbedding( + rope_dim_list=axes_dim, + rope_theta=theta, + use_real=False, + repeat_interleave_real=False, + dtype=( + torch.float32 + if current_platform.is_mps() or current_platform.is_musa() + else torch.float64 + ), + ) + + def forward(self, ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + pos = ids.float() + # TODO: potential error: flux use n_axes = ids.shape[-1] + # see: https://github.com/huggingface/diffusers/blob/17c0e79dbdf53fb6705e9c09cc1a854b84c39249/src/diffusers/models/transformers/transformer_flux.py#L509 + freqs_cos, freqs_sin = self.rope.forward_uncached(pos=pos) + return freqs_cos.contiguous().float(), freqs_sin.contiguous().float() + + +class FluxTransformer2DModel(CachableDiT, OffloadableDiTMixin): + """ + The Transformer model introduced in Flux. 
+ + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + """ + + param_names_mapping = FluxConfig().arch_config.param_names_mapping + + @classmethod + def get_nunchaku_quant_rules(cls) -> dict[str, list[str]]: + return { + "skip": [ + "norm", + "embed", + "rotary", + "pos_embed", + ], + "svdq_w4a4": [ + "attn.to_qkv", + "attn.to_out", + "attn.add_qkv_proj", + "attn.to_added_qkv", + "attn.to_add_out", + "img_mlp", + "txt_mlp", + "attention.to_qkv", + "attention.to_out", + "proj_mlp", + "proj_out", + "mlp_fc1", + "mlp_fc2", + "ff.net", + "ff_context.net", + ], + "awq_w4a16": [ + "img_mod", + "txt_mod", + ], + } + + def __init__( + self, + config: FluxConfig, + hf_config: dict[str, Any], + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__(config=config, hf_config=hf_config) + self.config = config.arch_config + + self.out_channels = ( + getattr(self.config, "out_channels", None) or self.config.in_channels + ) + self.inner_dim = ( + self.config.num_attention_heads * self.config.attention_head_dim + ) + + self.rotary_emb = FluxPosEmbed(theta=10000, axes_dim=self.config.axes_dims_rope) + + text_time_guidance_cls = ( + CombinedTimestepGuidanceTextProjEmbeddings + if self.config.guidance_embeds + else CombinedTimestepTextProjEmbeddings + ) + self.time_text_embed = text_time_guidance_cls( + embedding_dim=self.inner_dim, + pooled_projection_dim=self.config.pooled_projection_dim, + ) + + self.context_embedder = ColumnParallelLinear( + self.config.joint_attention_dim, + self.inner_dim, + bias=True, + gather_output=True, + ) + self.x_embedder = ColumnParallelLinear( + self.config.in_channels, self.inner_dim, bias=True, gather_output=True + ) + self.transformer_blocks = nn.ModuleList( + [ + FluxTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + quant_config=quant_config, + prefix=f"transformer_blocks.{i}", + ) + for i in 
range(self.config.num_layers) + ] + ) + + self.single_transformer_blocks = nn.ModuleList( + [ + FluxSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + quant_config=quant_config, + prefix=f"single_transformer_blocks.{i}", + ) + for i in range(self.config.num_single_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous( + self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6 + ) + self.proj_out = ColumnParallelLinear( + self.inner_dim, + self.config.patch_size * self.config.patch_size * self.out_channels, + bias=True, + gather_output=True, + ) + + self.layer_names = [ + "transformer_blocks", + "single_transformer_blocks", + ] + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + pooled_projections: torch.Tensor = None, + timestep: torch.LongTensor = None, + guidance: torch.Tensor = None, + freqs_cis: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + """ + The [`FluxTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): + Input `hidden_states`. + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + guidance (`torch.Tensor`): + Guidance embeddings. 
+ joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + """ + if ( + joint_attention_kwargs is not None + and joint_attention_kwargs.get("scale", None) is not None + ): + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) + hidden_states, _ = self.x_embedder(hidden_states) + + # Only pass guidance to time_text_embed if the model supports it + if self.config.guidance_embeds and guidance is not None: + temb = self.time_text_embed(timestep, guidance, pooled_projections) + else: + temb = self.time_text_embed(timestep, pooled_projections) + + encoder_hidden_states, _ = self.context_embedder(encoder_hidden_states) + + if ( + joint_attention_kwargs is not None + and "ip_adapter_image_embeds" in joint_attention_kwargs + ): + ip_adapter_image_embeds = joint_attention_kwargs.pop( + "ip_adapter_image_embeds" + ) + ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds) + joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states}) + + for block in self.transformer_blocks: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + freqs_cis=freqs_cis, + joint_attention_kwargs=joint_attention_kwargs, + ) + for block in self.single_transformer_blocks: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + freqs_cis=freqs_cis, + joint_attention_kwargs=joint_attention_kwargs, + ) + + hidden_states = self.norm_out(hidden_states, temb) + + output, _ = self.proj_out(hidden_states) + + return output + + +EntryClass = FluxTransformer2DModel diff --git 
a/sglang/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py b/sglang/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py new file mode 100644 index 0000000000000000000000000000000000000000..5cd447f35635d54aa32555c0092703f825a95023 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py @@ -0,0 +1,901 @@ +# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from diffusers.models.attention import AttentionModuleMixin +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.normalization import AdaLayerNormContinuous + +from sglang.multimodal_gen.configs.models.dits.flux import FluxConfig +from sglang.multimodal_gen.runtime.layers.attention import USPAttention +from sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm, apply_qk_norm +from sglang.multimodal_gen.runtime.layers.linear import ColumnParallelLinear +from sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import ( + QuantizationConfig, +) +from sglang.multimodal_gen.runtime.layers.rotary_embedding import ( + NDRotaryEmbedding, + apply_flashinfer_rope_qk_inplace, +) +from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT +from sglang.multimodal_gen.runtime.platforms import current_platform +from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) # pylint: disable=invalid-name + + +def _get_qkv_projections( + attn: "Flux2Attention", hidden_states, encoder_hidden_states=None +): + query, _ = attn.to_q(hidden_states) + key, _ = attn.to_k(hidden_states) + value, _ = attn.to_v(hidden_states) + + encoder_query = encoder_key = encoder_value = None + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_query, _ = attn.add_q_proj(encoder_hidden_states) + encoder_key, _ = attn.add_k_proj(encoder_hidden_states) + encoder_value, _ = attn.add_v_proj(encoder_hidden_states) + + return query, key, value, encoder_query, encoder_key, encoder_value + + +class Flux2SwiGLU(nn.Module): + """ + Flux 2 uses a SwiGLU-style activation in the transformer feedforward sub-blocks, but with the linear projection + layer fused into the first linear 
layer of the FF sub-block. Thus, this module has no trainable parameters. + """ + + def __init__(self): + super().__init__() + self.gate_fn = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x1, x2 = x.chunk(2, dim=-1) + x = self.gate_fn(x1) * x2 + return x + + +class Flux2FeedForward(nn.Module): + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: float = 3.0, + inner_dim: Optional[int] = None, + bias: bool = False, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out or dim + + # Flux2SwiGLU will reduce the dimension by half + self.linear_in = ColumnParallelLinear( + dim, inner_dim * 2, bias=bias, gather_output=True, quant_config=quant_config + ) + self.act_fn = Flux2SwiGLU() + self.linear_out = ColumnParallelLinear( + inner_dim, dim_out, bias=bias, gather_output=True, quant_config=quant_config + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.linear_in(x) + x = self.act_fn(x) + x, _ = self.linear_out(x) + return x + + +class Flux2Attention(torch.nn.Module, AttentionModuleMixin): + def __init__( + self, + query_dim: int, + num_heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + elementwise_affine: bool = True, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * num_heads + self.query_dim = query_dim + self.out_dim = out_dim if out_dim is not None else query_dim + self.heads = out_dim // dim_head if out_dim is not None else num_heads + + self.use_bias = bias + self.dropout = dropout + + self.added_kv_proj_dim = added_kv_proj_dim + self.added_proj_bias = added_proj_bias + + self.to_q = 
ColumnParallelLinear( + query_dim, + self.inner_dim, + bias=bias, + gather_output=True, + quant_config=quant_config, + ) + self.to_k = ColumnParallelLinear( + query_dim, + self.inner_dim, + bias=bias, + gather_output=True, + quant_config=quant_config, + ) + self.to_v = ColumnParallelLinear( + query_dim, + self.inner_dim, + bias=bias, + gather_output=True, + quant_config=quant_config, + ) + + # QK Norm + self.norm_q = RMSNorm(dim_head, eps=eps) + self.norm_k = RMSNorm(dim_head, eps=eps) + + self.to_out = torch.nn.ModuleList([]) + self.to_out.append( + ColumnParallelLinear( + self.inner_dim, + self.out_dim, + bias=out_bias, + gather_output=True, + quant_config=quant_config, + ) + ) + self.to_out.append(torch.nn.Dropout(dropout)) + + if added_kv_proj_dim is not None: + self.norm_added_q = RMSNorm(dim_head, eps=eps) + self.norm_added_k = RMSNorm(dim_head, eps=eps) + self.add_q_proj = ColumnParallelLinear( + added_kv_proj_dim, + self.inner_dim, + bias=added_proj_bias, + gather_output=True, + quant_config=quant_config, + ) + self.add_k_proj = ColumnParallelLinear( + added_kv_proj_dim, + self.inner_dim, + bias=added_proj_bias, + gather_output=True, + quant_config=quant_config, + ) + self.add_v_proj = ColumnParallelLinear( + added_kv_proj_dim, + self.inner_dim, + bias=added_proj_bias, + gather_output=True, + quant_config=quant_config, + ) + self.to_add_out = ColumnParallelLinear( + self.inner_dim, + query_dim, + bias=out_bias, + gather_output=True, + quant_config=quant_config, + ) + + self.attn = USPAttention( + num_heads=num_heads, + head_size=self.head_dim, + dropout_rate=0, + softmax_scale=None, + causal=False, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + query, key, value, encoder_query, encoder_key, encoder_value = ( + _get_qkv_projections(self, hidden_states, encoder_hidden_states) + ) + + query = 
query.unflatten(-1, (self.heads, -1)) + key = key.unflatten(-1, (self.heads, -1)) + value = value.unflatten(-1, (self.heads, -1)) + + query, key = apply_qk_norm( + q=query, + k=key, + q_norm=self.norm_q, + k_norm=self.norm_k, + head_dim=self.head_dim, + allow_inplace=True, + ) + + if self.added_kv_proj_dim is not None: + encoder_query = encoder_query.unflatten(-1, (self.heads, -1)) + encoder_key = encoder_key.unflatten(-1, (self.heads, -1)) + encoder_value = encoder_value.unflatten(-1, (self.heads, -1)) + + encoder_query, encoder_key = apply_qk_norm( + q=encoder_query, + k=encoder_key, + q_norm=self.norm_added_q, + k_norm=self.norm_added_k, + head_dim=self.head_dim, + allow_inplace=True, + ) + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + if freqs_cis is not None: + cos, sin = freqs_cis + cos_sin_cache = torch.cat( + [ + cos.to(dtype=torch.float32).contiguous(), + sin.to(dtype=torch.float32).contiguous(), + ], + dim=-1, + ) + query, key = apply_flashinfer_rope_qk_inplace( + query, key, cos_sin_cache, is_neox=False + ) + + num_rep = ( + encoder_hidden_states.shape[1] if encoder_hidden_states is not None else 0 + ) + hidden_states = self.attn(query, key, value, num_replicated_prefix=num_rep) + + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [ + encoder_hidden_states.shape[1], + hidden_states.shape[1] - encoder_hidden_states.shape[1], + ], + dim=1, + ) + encoder_hidden_states, _ = self.to_add_out(encoder_hidden_states) + + hidden_states, _ = self.to_out[0](hidden_states) + hidden_states = self.to_out[1](hidden_states) + + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class Flux2ParallelSelfAttention(torch.nn.Module, 
AttentionModuleMixin): + """ + Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks. + + This implements a parallel transformer block, where the attention QKV projections are fused to the feedforward (FF) + input projections, and the attention output projections are fused to the FF output projections. See the [ViT-22B + paper](https://arxiv.org/abs/2302.05442) for a visual depiction of this type of transformer block. + """ + + # Does not support QKV fusion as the QKV projections are always fused + _supports_qkv_fusion = False + + def __init__( + self, + query_dim: int, + num_heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + elementwise_affine: bool = True, + mlp_ratio: float = 4.0, + mlp_mult_factor: int = 2, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * num_heads + self.query_dim = query_dim + self.out_dim = out_dim if out_dim is not None else query_dim + self.heads = out_dim // dim_head if out_dim is not None else num_heads + + self.use_bias = bias + self.dropout = dropout + + self.mlp_ratio = mlp_ratio + self.mlp_hidden_dim = int(query_dim * self.mlp_ratio) + self.mlp_mult_factor = mlp_mult_factor + + # Fused QKV projections + MLP input projection + self.to_qkv_mlp_proj = ColumnParallelLinear( + self.query_dim, + self.inner_dim * 3 + self.mlp_hidden_dim * self.mlp_mult_factor, + bias=bias, + gather_output=True, + quant_config=quant_config, + ) + self.mlp_act_fn = Flux2SwiGLU() + + # QK Norm + self.norm_q = RMSNorm(dim_head, eps=eps) + self.norm_k = RMSNorm(dim_head, eps=eps) + + # Fused attention output projection + MLP output projection + self.to_out = ColumnParallelLinear( + self.inner_dim + self.mlp_hidden_dim, + self.out_dim, + bias=out_bias, + gather_output=True, + quant_config=quant_config, + ) + + 
self.attn = USPAttention( + num_heads=num_heads, + head_size=self.head_dim, + dropout_rate=0, + softmax_scale=None, + causal=False, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + num_replicated_prefix: int = 0, + **kwargs, + ) -> torch.Tensor: + # Parallel in (QKV + MLP in) projection + hidden_states, _ = self.to_qkv_mlp_proj(hidden_states) + qkv, mlp_hidden_states = torch.split( + hidden_states, + [3 * self.inner_dim, self.mlp_hidden_dim * self.mlp_mult_factor], + dim=-1, + ) + + # Handle the attention logic + query, key, value = qkv.chunk(3, dim=-1) + + query = query.unflatten(-1, (self.heads, -1)) + key = key.unflatten(-1, (self.heads, -1)) + value = value.unflatten(-1, (self.heads, -1)) + + query = self.norm_q(query) + key = self.norm_k(key) + + if freqs_cis is not None: + cos, sin = freqs_cis + cos_sin_cache = torch.cat( + [ + cos.to(dtype=torch.float32).contiguous(), + sin.to(dtype=torch.float32).contiguous(), + ], + dim=-1, + ) + query, key = apply_flashinfer_rope_qk_inplace( + query, key, cos_sin_cache, is_neox=False + ) + hidden_states = self.attn( + query, key, value, num_replicated_prefix=num_replicated_prefix + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + # Handle the feedforward (FF) logic + mlp_hidden_states = self.mlp_act_fn(mlp_hidden_states) + + # Concatenate and parallel output projection + hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1) + hidden_states, _ = self.to_out(hidden_states) + + return hidden_states + + +class Flux2SingleTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 3.0, + eps: float = 1e-6, + bias: bool = False, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + + self.norm = nn.LayerNorm(dim, 
elementwise_affine=False, eps=eps) + + # Note that the MLP in/out linear layers are fused with the attention QKV/out projections, respectively; this + # is often called a "parallel" transformer block. See the [ViT-22B paper](https://arxiv.org/abs/2302.05442) + # for a visual depiction of this type of transformer block. + self.attn = Flux2ParallelSelfAttention( + query_dim=dim, + dim_head=attention_head_dim, + num_heads=num_attention_heads, + out_dim=dim, + bias=bias, + out_bias=bias, + eps=eps, + mlp_ratio=mlp_ratio, + mlp_mult_factor=2, + quant_config=quant_config, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor], + temb_mod_params: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + split_hidden_states: bool = False, + text_seq_len: Optional[int] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # If encoder_hidden_states is None, hidden_states is assumed to have encoder_hidden_states already + # concatenated + if encoder_hidden_states is not None: + text_seq_len = encoder_hidden_states.shape[1] + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + mod_shift, mod_scale, mod_gate = temb_mod_params + + norm_hidden_states = self.norm(hidden_states) + norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift + + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=norm_hidden_states, + freqs_cis=freqs_cis, + num_replicated_prefix=text_seq_len or 0, + **joint_attention_kwargs, + ) + + hidden_states = hidden_states + mod_gate * attn_output + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + if split_hidden_states: + encoder_hidden_states, hidden_states = ( + hidden_states[:, :text_seq_len], + hidden_states[:, text_seq_len:], + ) + return encoder_hidden_states, 
hidden_states + else: + return hidden_states + + +class Flux2TransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 3.0, + eps: float = 1e-6, + bias: bool = False, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.mlp_hidden_dim = int(dim * mlp_ratio) + + self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.norm1_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + + self.attn = Flux2Attention( + query_dim=dim, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + num_heads=num_attention_heads, + out_dim=dim, + bias=bias, + added_proj_bias=bias, + out_bias=bias, + eps=eps, + quant_config=quant_config, + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.ff = Flux2FeedForward( + dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias, quant_config=quant_config + ) + + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.ff_context = Flux2FeedForward( + dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias, quant_config=quant_config + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb_mod_params_img: Tuple[ + Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ... + ], + temb_mod_params_txt: Tuple[ + Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ... 
+ ], + freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + joint_attention_kwargs = joint_attention_kwargs or {} + + # Modulation parameters shape: [1, 1, self.dim] + (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = ( + temb_mod_params_img + ) + (c_shift_msa, c_scale_msa, c_gate_msa), ( + c_shift_mlp, + c_scale_mlp, + c_gate_mlp, + ) = temb_mod_params_txt + + # Img stream + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = (1 + scale_msa) * norm_hidden_states + shift_msa + + # Conditioning txt stream + norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states) + norm_encoder_hidden_states = ( + 1 + c_scale_msa + ) * norm_encoder_hidden_states + c_shift_msa + + # Attention on concatenated img + txt stream + attention_outputs = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + freqs_cis=freqs_cis, + **joint_attention_kwargs, + ) + + attn_output, context_attn_output = attention_outputs + + # Process attention outputs for the image stream (`hidden_states`). + attn_output = gate_msa * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + ff_output = self.ff(norm_hidden_states) + hidden_states = hidden_states + gate_mlp * ff_output + + # Process attention outputs for the text stream (`encoder_hidden_states`). 
+ context_attn_output = c_gate_msa * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = ( + norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp + ) + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class Flux2TimestepGuidanceEmbeddings(nn.Module): + def __init__( + self, + in_channels: int = 256, + embedding_dim: int = 6144, + bias: bool = False, + guidance_embeds: bool = True, + ): + super().__init__() + + self.time_proj = Timesteps( + num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0 + ) + self.timestep_embedder = TimestepEmbedding( + in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias + ) + + if guidance_embeds: + self.guidance_embedder = TimestepEmbedding( + in_channels=in_channels, + time_embed_dim=embedding_dim, + sample_proj_bias=bias, + ) + else: + self.guidance_embedder = None + + def forward( + self, timestep: torch.Tensor, guidance: Optional[torch.Tensor] = None + ) -> torch.Tensor: + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder( + timesteps_proj.to(timestep.dtype) + ) # (N, D) + + if guidance is not None and self.guidance_embedder is not None: + guidance_proj = self.time_proj(guidance) + guidance_emb = self.guidance_embedder( + guidance_proj.to(guidance.dtype) + ) # (N, D) + time_guidance_emb = timesteps_emb + guidance_emb + return time_guidance_emb + else: + return timesteps_emb + + +class Flux2Modulation(nn.Module): + def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False): + super().__init__() + self.mod_param_sets = mod_param_sets + 
+ self.linear = ColumnParallelLinear( + dim, dim * 3 * self.mod_param_sets, bias=bias, gather_output=True + ) + self.act_fn = nn.SiLU() + + def forward( + self, temb: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]: + mod = self.act_fn(temb) + mod, _ = self.linear(mod) + + if mod.ndim == 2: + mod = mod.unsqueeze(1) + mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1) + # Return tuple of 3-tuples of modulation params shift/scale/gate + return tuple( + mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets) + ) + + +class Flux2PosEmbed(nn.Module): + def __init__(self, theta: int, axes_dim: List[int]): + super().__init__() + self.rope = NDRotaryEmbedding( + rope_dim_list=axes_dim, + rope_theta=theta, + use_real=False, + repeat_interleave_real=False, + dtype=( + torch.float32 + if current_platform.is_mps() or current_platform.is_musa() + else torch.float64 + ), + ) + + def forward(self, ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + pos = ids.float() + # TODO: potential error: flux use n_axes = ids.shape[-1] + # see: https://github.com/huggingface/diffusers/blob/17c0e79dbdf53fb6705e9c09cc1a854b84c39249/src/diffusers/models/transformers/transformer_flux.py#L509 + freqs_cos, freqs_sin = self.rope.forward_uncached(pos=pos) + return freqs_cos.contiguous().float(), freqs_sin.contiguous().float() + + +class Flux2Transformer2DModel(CachableDiT, OffloadableDiTMixin): + """ + The Transformer model introduced in Flux 2. 
+ + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + """ + + param_names_mapping = FluxConfig().arch_config.param_names_mapping + + def __init__( + self, + config: FluxConfig, + hf_config: dict[str, Any], + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__(config=config, hf_config=hf_config) + patch_size: int = config.patch_size + in_channels: int = config.in_channels + out_channels: Optional[int] = config.out_channels + num_layers: int = config.num_layers + num_single_layers: int = config.num_single_layers + attention_head_dim: int = config.attention_head_dim + num_attention_heads: int = config.num_attention_heads + joint_attention_dim: int = config.joint_attention_dim + timestep_guidance_channels: int = config.timestep_guidance_channels + mlp_ratio: float = config.mlp_ratio + axes_dims_rope: Tuple[int, ...] = config.axes_dims_rope + rope_theta: int = config.rope_theta + eps: float = config.eps + guidance_embeds: bool = getattr(config, "guidance_embeds", True) + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + self.guidance_embeds = guidance_embeds + + # 1. Sinusoidal positional embedding for RoPE on image and text tokens + self.rotary_emb = Flux2PosEmbed(theta=rope_theta, axes_dim=axes_dims_rope) + + # 2. Combined timestep + guidance embedding + self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings( + in_channels=timestep_guidance_channels, + embedding_dim=self.inner_dim, + bias=False, + guidance_embeds=guidance_embeds, + ) + + # 3. Modulation (double stream and single stream blocks share modulation parameters, resp.) 
+ # Two sets of shift/scale/gate modulation parameters for the double stream attn and FF sub-blocks + self.double_stream_modulation_img = Flux2Modulation( + self.inner_dim, mod_param_sets=2, bias=False + ) + self.double_stream_modulation_txt = Flux2Modulation( + self.inner_dim, mod_param_sets=2, bias=False + ) + # Only one set of modulation parameters as the attn and FF sub-blocks are run in parallel for single stream + self.single_stream_modulation = Flux2Modulation( + self.inner_dim, mod_param_sets=1, bias=False + ) + + # 4. Input projections + self.x_embedder = ColumnParallelLinear( + in_channels, self.inner_dim, bias=False, gather_output=True + ) + self.context_embedder = ColumnParallelLinear( + joint_attention_dim, self.inner_dim, bias=False, gather_output=True + ) + + # 5. Double Stream Transformer Blocks + self.transformer_blocks = nn.ModuleList( + [ + Flux2TransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_ratio=mlp_ratio, + eps=eps, + bias=False, + quant_config=quant_config, + ) + for _ in range(num_layers) + ] + ) + + # 6. Single Stream Transformer Blocks + self.single_transformer_blocks = nn.ModuleList( + [ + Flux2SingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_ratio=mlp_ratio, + eps=eps, + bias=False, + quant_config=quant_config, + ) + for _ in range(num_single_layers) + ] + ) + + # 7. 
Output layers + self.norm_out = AdaLayerNormContinuous( + self.inner_dim, + self.inner_dim, + elementwise_affine=False, + eps=eps, + bias=False, + ) + self.proj_out = ColumnParallelLinear( + self.inner_dim, + patch_size * patch_size * self.out_channels, + bias=False, + gather_output=True, + ) + + self.layer_names = ["transformer_blocks", "single_transformer_blocks"] + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + timestep: torch.LongTensor = None, + guidance: torch.Tensor = None, + freqs_cis: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> torch.Tensor: + """ + The [`FluxTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): + Input `hidden_states`. + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + """ + # 0. Handle input arguments + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + num_txt_tokens = encoder_hidden_states.shape[1] + + # 1. 
Calculate timestep embedding and modulation parameters + timestep = timestep.to(hidden_states.dtype) + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) + + temb = self.time_guidance_embed(timestep, guidance) + + double_stream_mod_img = self.double_stream_modulation_img(temb) + double_stream_mod_txt = self.double_stream_modulation_txt(temb) + single_stream_mod = self.single_stream_modulation(temb)[0] + + # 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states) + hidden_states, _ = self.x_embedder(hidden_states) + encoder_hidden_states, _ = self.context_embedder(encoder_hidden_states) + + # 3. Calculate RoPE embeddings from image and text tokens + # NOTE: the below logic means that we can't support batched inference with images of different resolutions or + # text prompts of different lengths. Is this a use case we want to support? + # 4. Double Stream Transformer Blocks + for index_block, block in enumerate(self.transformer_blocks): + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb_mod_params_img=double_stream_mod_img, + temb_mod_params_txt=double_stream_mod_txt, + freqs_cis=freqs_cis, + joint_attention_kwargs=joint_attention_kwargs, + ) + # Concatenate text and image streams for single-block inference + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + # 5. Single Stream Transformer Blocks + for index_block, block in enumerate(self.single_transformer_blocks): + hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=None, + temb_mod_params=single_stream_mod, + freqs_cis=freqs_cis, + joint_attention_kwargs=joint_attention_kwargs, + text_seq_len=num_txt_tokens, + ) + # Remove text tokens from concatenated stream + hidden_states = hidden_states[:, num_txt_tokens:, ...] + + # 6. 
Output layers + hidden_states = self.norm_out(hidden_states, temb) + output, _ = self.proj_out(hidden_states) + + return output + + +EntryClass = Flux2Transformer2DModel diff --git a/sglang/python/sglang/multimodal_gen/runtime/models/dits/glm_image.py b/sglang/python/sglang/multimodal_gen/runtime/models/dits/glm_image.py new file mode 100644 index 0000000000000000000000000000000000000000..f7f223dec196dbeb89cfe24bf2c3b4ddc94c6445 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/models/dits/glm_image.py @@ -0,0 +1,877 @@ +# Copyright 2025 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class GlmImageLayerKVCache:
    """Key/value cache for a single GlmImage attention layer.

    ``GlmImageAttention.forward`` hands this cache key/value tensors that were
    produced by ``unflatten(2, (heads, -1))``, i.e. laid out as
    ``(batch, seq, heads, head_dim)``, and in "read" mode it prepends the cached
    tensors along ``dim=1``. The cache therefore accumulates along the sequence
    axis (``dim=1``) as well.
    """

    def __init__(self):
        self.k_cache = None
        self.v_cache = None
        # One of "write", "read", "skip"; None means the cache is inactive.
        self.mode: Optional[str] = None

    def store(self, k: torch.Tensor, v: torch.Tensor):
        """Append ``k``/``v`` to the cache along the sequence axis.

        The first call stores the tensors as-is; later calls concatenate on
        ``dim=1`` so the result lines up with the ``dim=1`` concatenation done
        by the attention layer when reading the cache back.
        """
        if self.k_cache is None:
            self.k_cache = k
            self.v_cache = v
            return
        # dim=1 is the sequence axis of (batch, seq, heads, head_dim); using
        # dim=2 here would stack along the heads axis instead.
        self.k_cache = torch.cat([self.k_cache, k], dim=1)
        self.v_cache = torch.cat([self.v_cache, v], dim=1)

    def get(self):
        """Return ``(k_cache, v_cache)``; both are None while the cache is empty."""
        return self.k_cache, self.v_cache

    def clear(self):
        """Drop the cached tensors and reset the mode."""
        self.k_cache = None
        self.v_cache = None
        self.mode = None
class GlmImageTimestepEmbedding(nn.Module):
    """
    Stand-in for the diffusers TimestepEmbedding, built on ReplicatedLinear.
    Pipeline: linear_1 -> activation -> linear_2
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
    ):
        super().__init__()
        # The output width defaults to the hidden width.
        output_width = time_embed_dim if out_dim is None else out_dim

        self.linear_1 = ReplicatedLinear(in_channels, time_embed_dim, bias=True)
        # Only "gelu" selects tanh-approximated GELU; "silu" and any other
        # value use SiLU.
        self.act = nn.GELU(approximate="tanh") if act_fn == "gelu" else nn.SiLU()
        self.linear_2 = ReplicatedLinear(time_embed_dim, output_width, bias=True)

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        """Apply the two-layer MLP; the second value returned by each linear is discarded."""
        hidden, _ = self.linear_1(sample)
        projected, _ = self.linear_2(self.act(hidden))
        return projected
+ Structure: linear_1 -> act_1 -> linear_2 + """ + + def __init__( + self, + in_features: int, + hidden_size: int, + out_features: int = None, + act_fn: str = "silu", + ): + super().__init__() + if out_features is None: + out_features = hidden_size + self.linear_1 = ReplicatedLinear(in_features, hidden_size, bias=True) + if act_fn == "silu": + self.act_1 = nn.SiLU() + elif act_fn == "gelu_tanh": + self.act_1 = nn.GELU(approximate="tanh") + else: + self.act_1 = nn.SiLU() + self.linear_2 = ReplicatedLinear(hidden_size, out_features, bias=True) + + def forward(self, caption: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.linear_1(caption) + hidden_states = self.act_1(hidden_states) + hidden_states, _ = self.linear_2(hidden_states) + return hidden_states + + +class GlmImageCombinedTimestepSizeEmbeddings(nn.Module): + def __init__( + self, + embedding_dim: int, + condition_dim: int, + pooled_projection_dim: int, + timesteps_dim: int = 256, + ): + super().__init__() + + self.time_proj = Timesteps( + num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0 + ) + self.condition_proj = Timesteps( + num_channels=condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0 + ) + self.timestep_embedder = GlmImageTimestepEmbedding( + in_channels=timesteps_dim, time_embed_dim=embedding_dim + ) + self.condition_embedder = GlmImageTextProjection( + pooled_projection_dim, embedding_dim, act_fn="silu" + ) + + def forward( + self, + timestep: torch.Tensor, + target_size: torch.Tensor, + crop_coords: torch.Tensor, + hidden_dtype: torch.dtype, + ) -> torch.Tensor: + timesteps_proj = self.time_proj(timestep) + + crop_coords_proj = self.condition_proj(crop_coords.flatten()).view( + crop_coords.size(0), -1 + ) + target_size_proj = self.condition_proj(target_size.flatten()).view( + target_size.size(0), -1 + ) + + # (B, 2 * condition_dim) + condition_proj = torch.cat([crop_coords_proj, target_size_proj], dim=1) + + timesteps_emb = self.timestep_embedder( + 
class GlmImageImageProjector(nn.Module):
    """Patchify a ``(B, C, H, W)`` tensor and linearly project every patch.

    Each ``patch_size x patch_size`` patch is flattened in ``(C, p_h, p_w)``
    order and the patches are emitted row-major, giving a
    ``(B, (H/p)*(W/p), hidden_size)`` token sequence.
    """

    def __init__(
        self,
        in_channels: int = 16,
        hidden_size: int = 2560,
        patch_size: int = 2,
    ):
        super().__init__()
        self.patch_size = patch_size

        self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        p = self.patch_size
        b, c, h, w = hidden_states.shape
        rows, cols = h // p, w // p

        # (B, C, rows, p, cols, p) -> (B, rows, cols, C, p, p)
        patches = hidden_states.reshape(b, c, rows, p, cols, p).permute(
            0, 2, 4, 1, 3, 5
        )
        # Merge (C, p, p) into the feature axis and (rows, cols) into the token axis.
        tokens = patches.reshape(b, rows * cols, c * p * p)
        return self.proj(tokens)
class GlmImageAttention(torch.nn.Module):
    """Joint text+image self-attention for a GlmImage transformer block.

    Text (``encoder_hidden_states``) and image (``hidden_states``) tokens are
    concatenated along the sequence axis, projected to Q/K/V, optionally
    layer-normed per head, rotary-embedded on the image part only, attended
    with ``USPAttention`` and projected back. The result is split again into
    its image and text parts.
    """

    def __init__(
        self,
        query_dim,
        heads,
        dim_head,
        out_dim,
        bias,
        qk_norm,
        elementwise_affine,
        eps,
        supported_attention_backends: set[AttentionBackendEnum] | None = None,
        prefix: str = "",
        quant_config: QuantizationConfig | None = None,
    ):
        super().__init__()

        self.k_cache = None
        self.v_cache = None

        self.heads = out_dim // dim_head if out_dim is not None else heads
        self.dim_head = dim_head
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.inner_kv_dim = self.inner_dim
        self.out_dim = out_dim if out_dim is not None else query_dim

        # Number of KV heads = KV projection width / per-head width. The
        # operands were previously swapped, which evaluated to 0 whenever
        # inner_kv_dim > dim_head.
        self.num_kv_heads = self.inner_kv_dim // self.dim_head

        self.to_q = ReplicatedLinear(
            query_dim, self.inner_dim, bias=bias, quant_config=quant_config
        )
        self.to_k = ReplicatedLinear(
            query_dim, self.inner_kv_dim, bias=bias, quant_config=quant_config
        )
        self.to_v = ReplicatedLinear(
            query_dim, self.inner_kv_dim, bias=bias, quant_config=quant_config
        )

        # (dropout omitted)
        self.to_out = nn.ModuleList(
            [
                ReplicatedLinear(
                    self.inner_dim, self.out_dim, bias=True, quant_config=quant_config
                )
            ]
        )

        if qk_norm is None:
            self.norm_q = None
            self.norm_k = None
        elif qk_norm == "layer_norm":
            self.norm_q = nn.LayerNorm(
                dim_head, eps=eps, elementwise_affine=elementwise_affine
            )
            self.norm_k = nn.LayerNorm(
                dim_head, eps=eps, elementwise_affine=elementwise_affine
            )
        else:
            raise ValueError(
                f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'fp32_layer_norm', 'layer_norm_across_heads', 'rms_norm', 'rms_norm_across_heads', 'l2'."
            )

        self.attn = USPAttention(
            num_heads=self.heads,
            head_size=dim_head,
            num_kv_heads=self.num_kv_heads,
            dropout_rate=0,
            softmax_scale=None,
            causal=False,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        kv_cache: Optional[GlmImageLayerKVCache] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run joint attention and return ``(image_states, text_states)``.

        Args:
            hidden_states: image tokens, unpacked as ``(batch, image_seq, dim)``.
            encoder_hidden_states: text tokens, unpacked as ``(batch, text_seq, dim)``.
            attention_mask: optional 2-D text mask. Only its rank is validated;
                the mask is NOT forwarded to the attention kernel (see the note
                in step 4).
            image_rotary_emb: optional ``(cos, sin)`` pair applied to the image
                part of query and key.
            kv_cache: optional per-layer cache; behaviour depends on its ``mode``
                ("write" stores K/V, "read" prepends cached K/V, "skip" is a no-op).
        """
        dtype = encoder_hidden_states.dtype

        # Unpacking doubles as a rank check: both inputs must be 3-D.
        _, text_seq_length, _ = encoder_hidden_states.shape
        _, _, _ = hidden_states.shape
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        # 1. QKV projections -> (batch, seq, heads, head_dim)
        query, _ = self.to_q(hidden_states)
        key, _ = self.to_k(hidden_states)
        value, _ = self.to_v(hidden_states)

        query = query.unflatten(2, (self.heads, -1))
        key = key.unflatten(2, (self.heads, -1))
        value = value.unflatten(2, (self.heads, -1))

        # 2. QK normalization
        if self.norm_q is not None:
            query = self.norm_q(query).to(dtype=dtype)
        if self.norm_k is not None:
            key = self.norm_k(key).to(dtype=dtype)

        # 3. Rotary position embeddings, applied to the image tokens only
        if image_rotary_emb is not None:
            cos, sin = image_rotary_emb

            if _is_cuda and cos.dim() == 2:
                q_img = query[:, text_seq_length:, :, :]
                k_img = key[:, text_seq_length:, :, :]
                cos_sin_cache = torch.cat(
                    [
                        cos.to(dtype=torch.float32).contiguous(),
                        sin.to(dtype=torch.float32).contiguous(),
                    ],
                    dim=-1,
                )
                # The kernel works in place and q_img/k_img are views of
                # query/key, so nothing has to be copied back and the return
                # values are not needed.
                apply_flashinfer_rope_qk_inplace(
                    q_img, k_img, cos_sin_cache, is_neox=True
                )
            else:
                query[:, text_seq_length:, :, :] = _apply_rotary_emb(
                    query[:, text_seq_length:, :, :], cos, sin, is_neox_style=True
                )
                key[:, text_seq_length:, :, :] = _apply_rotary_emb(
                    key[:, text_seq_length:, :, :], cos, sin, is_neox_style=True
                )

        if kv_cache is not None:
            if kv_cache.mode == "write":
                kv_cache.store(key, value)
            elif kv_cache.mode == "read":
                k_cache, v_cache = kv_cache.get()
                if k_cache is not None:
                    key = torch.cat([k_cache, key], dim=1)
                if v_cache is not None:
                    value = torch.cat([v_cache, value], dim=1)
            # mode == "skip" (or None): leave key/value untouched.

        # 4. Attention
        if attention_mask is not None:
            # NOTE(review): the mask is validated but not applied. self.attn is
            # called with query/key/value only, so the dense (B, S, S) mask
            # matrix that used to be built here was never consumed; that dead
            # computation was removed. Wire the mask into the attention call if
            # masking of padded text tokens is required.
            assert (
                attention_mask.dim() == 2
            ), "the shape of text_attn_mask should be (batch_size, text_seq_length)"
        hidden_states = self.attn(query, key, value)
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)

        # 5. Output projection
        hidden_states, _ = self.to_out[0](hidden_states)
        # hidden_states = self.to_out[1](hidden_states)  # (dropout omitted)

        # Text tokens come first in the concatenated sequence.
        encoder_hidden_states, hidden_states = hidden_states.split(
            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
        )
        return hidden_states, encoder_hidden_states
class GlmImageRotaryPosEmbed(nn.Module):
    """2-D rotary position tables over the patch grid of an image latent.

    Half of ``dim`` is assigned to the row axis and half to the column axis.
    ``forward`` returns ``(cos, sin)``, each of shape
    ``(grid_h * grid_w, dim // 2)`` with tokens in row-major order.
    """

    def __init__(self, dim: int, patch_size: int, theta: float = 10000.0) -> None:
        super().__init__()

        self.dim = dim
        self.patch_size = patch_size
        self.theta = theta

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        _, _, latent_h, latent_w = hidden_states.shape
        grid_h = latent_h // self.patch_size
        grid_w = latent_w // self.patch_size
        device = hidden_states.device

        # Rows and columns each get dim // 2 channels, so a single
        # inverse-frequency vector serves both axes.
        axis_dim = self.dim // 2
        exponent = torch.arange(0, axis_dim, 2, dtype=torch.float32, device=device)
        exponent = exponent[: (axis_dim // 2)].float() / axis_dim
        inv_freq = 1.0 / (self.theta**exponent)

        row_angles = torch.outer(torch.arange(grid_h, device=device), inv_freq)
        col_angles = torch.outer(torch.arange(grid_w, device=device), inv_freq)

        # Broadcast both to [grid_h, grid_w, dim//4] ...
        row_angles = row_angles.unsqueeze(1).expand(grid_h, grid_w, -1)
        col_angles = col_angles.unsqueeze(0).expand(grid_h, grid_w, -1)

        # ... then join to [grid_h, grid_w, dim//2] and flatten the grid.
        angles = torch.cat([row_angles, col_angles], dim=-1).reshape(
            grid_h * grid_w, -1
        )
        return (angles.cos(), angles.sin())
class GlmImageTransformer2DModel(CachableDiT, OffloadableDiTMixin):
    r"""
    GlmImage diffusion transformer. All hyper-parameters below are read from
    ``config.arch_config``.

    Args:
        patch_size (`int`, defaults to `2`):
            The size of the patches to use in the patch embedding layer.
        in_channels (`int`, defaults to `16`):
            The number of channels in the input.
        num_layers (`int`, defaults to `30`):
            The number of layers of Transformer blocks to use.
        attention_head_dim (`int`, defaults to `40`):
            The number of channels in each head.
        num_attention_heads (`int`, defaults to `64`):
            The number of heads to use for multi-head attention.
        out_channels (`int`, defaults to `16`):
            The number of channels in the output.
        text_embed_dim (`int`, defaults to `1472`):
            Input dimension of text embeddings from the text encoder.
        time_embed_dim (`int`, defaults to `512`):
            Output dimension of timestep embeddings.
        condition_dim (`int`, defaults to `256`):
            The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size,
            crop_coords).
        pos_embed_max_size (`int`, defaults to `128`):
            The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and added
            to input patched latents, where `H` and `W` are the latent height and width respectively. A value of 128
            means that the maximum supported height and width for image generation is `128 * vae_scale_factor *
            patch_size => 128 * 8 * 2 => 2048`.
        sample_size (`int`, defaults to `128`):
            The base resolution of input latents. If height/width is not provided during generation, this value is used
            to determine the resolution as `sample_size * vae_scale_factor => 128 * 8 => 1024`
    """

    def __init__(
        self,
        config: GlmImageDitConfig,
        hf_config: dict[str, Any],
        quant_config: QuantizationConfig | None = None,
    ):
        super().__init__(config=config, hf_config=hf_config)

        self.config_data = config  # Store config
        arch_config = config.arch_config

        self.in_channels = arch_config.in_channels
        self.out_channels = arch_config.out_channels
        self.patch_size = arch_config.patch_size
        self.num_layers = arch_config.num_layers
        self.attention_head_dim = arch_config.attention_head_dim
        self.num_attention_heads = arch_config.num_attention_heads
        self.text_embed_dim = arch_config.text_embed_dim
        self.time_embed_dim = arch_config.time_embed_dim

        # GlmImage uses 2 additional SDXL-like conditions - target_size, crop_coords
        # Each of these are sincos embeddings of shape 2 * condition_dim
        pooled_projection_dim = 2 * 2 * arch_config.condition_dim
        inner_dim = arch_config.num_attention_heads * arch_config.attention_head_dim

        # 1. RoPE
        self.rotary_emb = GlmImageRotaryPosEmbed(
            arch_config.attention_head_dim, arch_config.patch_size, theta=10000.0
        )

        # 2. Patch & Text-timestep embedding
        self.image_projector = GlmImageImageProjector(
            arch_config.in_channels, inner_dim, arch_config.patch_size
        )
        self.glyph_projector = FeedForward(
            arch_config.text_embed_dim,
            inner_dim,
            inner_dim=inner_dim,
            activation_fn="gelu",
        )
        self.prior_token_embedding = nn.Embedding(
            arch_config.prior_vq_quantizer_codebook_size, inner_dim
        )
        self.prior_projector = FeedForward(
            inner_dim, inner_dim, inner_dim=inner_dim, activation_fn="linear-silu"
        )

        self.time_condition_embed = GlmImageCombinedTimestepSizeEmbeddings(
            embedding_dim=arch_config.time_embed_dim,
            condition_dim=arch_config.condition_dim,
            pooled_projection_dim=pooled_projection_dim,
            timesteps_dim=arch_config.time_embed_dim,
        )

        # 3. Transformer blocks
        self._supported_attention_backends = arch_config._supported_attention_backends
        self.transformer_blocks = nn.ModuleList(
            [
                GlmImageTransformerBlock(
                    inner_dim,
                    arch_config.num_attention_heads,
                    arch_config.attention_head_dim,
                    arch_config.time_embed_dim,
                    supported_attention_backends=self._supported_attention_backends,
                    prefix=f"transformer_blocks.{i}",
                    quant_config=quant_config,
                )
                for i in range(arch_config.num_layers)
            ]
        )

        # 4. Output projection
        self.norm_out = GlmImageAdaLayerNormContinuous(
            inner_dim, arch_config.time_embed_dim, elementwise_affine=False
        )
        self.proj_out = nn.Linear(
            inner_dim,
            arch_config.patch_size * arch_config.patch_size * arch_config.out_channels,
            bias=True,
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        prior_token_id: torch.Tensor,
        prior_token_drop: torch.Tensor,
        timestep: torch.LongTensor,
        target_size: torch.Tensor,
        crop_coords: torch.Tensor,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        kv_caches: Optional[GlmImageKVCache] = None,
        kv_caches_mode: Optional[str] = None,
        freqs_cis: Optional[
            Union[
                Tuple[torch.Tensor, torch.Tensor],
                List[Tuple[torch.Tensor, torch.Tensor]],
            ]
        ] = None,
        ###
        guidance: torch.Tensor = None,  # TODO: this should probably be removed
    ) -> torch.Tensor:
        """Denoise one step and return the un-patchified prediction as float32.

        Args:
            hidden_states: latent input, unpacked as ``(batch, channels, height, width)``.
            encoder_hidden_states: text/glyph embeddings; if a list is passed,
                only its first element is used.
            prior_token_id: indices looked up in ``prior_token_embedding``.
            prior_token_drop: index/mask selecting prior embeddings to zero out.
            timestep: current timestep; the model uses ``timestep - 1`` internally.
                The caller's tensor is left unmodified.
            target_size, crop_coords: SDXL-style size conditions.
            attention_kwargs: extra keyword arguments forwarded to each block.
            attention_mask: forwarded to each block.
            kv_caches: optional per-layer KV caches; ``kv_caches_mode`` is applied
                to all of them before the blocks run.
            freqs_cis: pre-computed rotary ``(cos, sin)``; computed from
                ``hidden_states`` when None.
            guidance: unused.

        Returns:
            Tensor of shape ``(batch, C_out, height, width)`` cast to float32,
            where ``C_out`` is inferred from the projection width.
        """
        if kv_caches is not None:
            kv_caches.set_mode(kv_caches_mode)

        batch_size, num_channels, height, width = hidden_states.shape

        # Shift out-of-place: an in-place `-=` would mutate the caller's tensor
        # (so reusing it across forward passes shifts it again on every call)
        # and raises for integer tensors, which cannot absorb a float result.
        timestep = timestep - 1.0

        if isinstance(encoder_hidden_states, list):
            encoder_hidden_states = encoder_hidden_states[0]

        # 1. RoPE
        image_rotary_emb = freqs_cis
        if image_rotary_emb is None:
            image_rotary_emb = self.rotary_emb(hidden_states)
        # 2. Patch & Timestep embeddings
        p = self.config.patch_size
        post_patch_height = height // p
        post_patch_width = width // p

        hidden_states = self.image_projector(hidden_states)
        encoder_hidden_states = self.glyph_projector(encoder_hidden_states)
        prior_embedding = self.prior_token_embedding(prior_token_id)
        # Zero the prior embedding wherever the prior token is dropped.
        prior_embedding[prior_token_drop] *= 0.0
        prior_hidden_states = self.prior_projector(prior_embedding)
        # SP: when latents are H-sharded, hidden_states has fewer patches than prior_hidden_states.
        # Shard prior_hidden_states along seq dim to match (prior is row-major, same as latent patches).
        if (
            get_sp_world_size() > 1
            and prior_hidden_states.shape[1] != hidden_states.shape[1]
        ):
            rank = get_sp_parallel_rank()
            sp_world_size = get_sp_world_size()
            chunk = prior_hidden_states.shape[1] // sp_world_size
            prior_hidden_states = prior_hidden_states[
                :, rank * chunk : (rank + 1) * chunk, :
            ]
        hidden_states = hidden_states + prior_hidden_states

        temb = self.time_condition_embed(
            timestep, target_size, crop_coords, hidden_states.dtype
        )
        temb = F.silu(temb)

        # 3. Transformer blocks
        for idx, block in enumerate(self.transformer_blocks):
            hidden_states, encoder_hidden_states = block(
                hidden_states,
                encoder_hidden_states,
                temb,
                image_rotary_emb,
                attention_mask,
                attention_kwargs,
                kv_cache=kv_caches[idx] if kv_caches is not None else None,
            )

        # 4. Output norm & projection
        hidden_states = self.norm_out(hidden_states, temb)
        hidden_states = self.proj_out(hidden_states)

        # 5. Unpatchify
        hidden_states = hidden_states.reshape(
            batch_size, post_patch_height, post_patch_width, -1, p, p
        )
        output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)

        return output.float()
        # float()
        # reference: https://github.com/zRzRzRzRzRzRzR/diffusers/blob/6cfc83b4abc5b083fef56a18ec4700f48ba3aaba/src/diffusers/pipelines/glm_image/pipeline_glm_image.py#L737
def pad_for_3d_conv(x, kernel_size):
    """Right-pad a 5-D tensor so T, H and W become multiples of ``kernel_size``.

    Args:
        x: Tensor of shape ``(b, c, t, h, w)``.
        kernel_size: ``(k_t, k_h, k_w)`` the trailing three dims must divide by.

    Returns:
        ``x`` padded at the end of each of the last three dims using
        replicate (edge-value) padding; dims already divisible get no padding.
    """
    _, _, frames, rows, cols = x.shape
    k_t, k_h, k_w = kernel_size
    # (-n) % k is the smallest non-negative amount that rounds n up to a
    # multiple of k.
    extra_t = -frames % k_t
    extra_h = -rows % k_h
    extra_w = -cols % k_w
    # F.pad takes (before, after) pairs starting from the last dim: W, H, T.
    pad_spec = (0, extra_w, 0, extra_h, 0, extra_t)
    return F.pad(x, pad_spec, mode="replicate")


def center_down_sample_3d(x, kernel_size):
    """Downsample by average pooling with non-overlapping ``kernel_size`` windows."""
    return F.avg_pool3d(x, kernel_size=kernel_size, stride=kernel_size)
def apply_rotary_emb_transposed(hidden_states, freqs_cis):
    """Rotate interleaved feature pairs of ``hidden_states`` by ``freqs_cis``.

    The last dim of ``hidden_states`` is read as consecutive (even, odd)
    pairs. ``freqs_cis`` carries the cosine table in the first half of its
    last dim and the sine table in the second half; a singleton axis is
    inserted before that dim so the tables broadcast over the second-to-last
    axis of ``hidden_states``.

    Returns:
        A tensor with the same shape and dtype as ``hidden_states``.
    """
    pairs = hidden_states.unflatten(-1, (-1, 2))
    even_part = pairs[..., 0]
    odd_part = pairs[..., 1]

    cos_table, sin_table = freqs_cis.unsqueeze(-2).chunk(2, dim=-1)
    # One cos/sin value per pair: cos from the even slots, sin from the odd slots.
    cos_per_pair = cos_table[..., 0::2]
    sin_per_pair = sin_table[..., 1::2]

    rotated = torch.empty_like(hidden_states)
    rotated[..., 0::2] = even_part * cos_per_pair - odd_part * sin_per_pair
    rotated[..., 1::2] = even_part * sin_per_pair + odd_part * cos_per_pair
    return rotated.type_as(hidden_states)
+ self._freqs_base_t = None + self._freqs_base_y = None + self._freqs_base_x = None + + def _get_freqs_base(self, dim): + return 1.0 / ( + self.theta + ** (torch.arange(0, dim, 2, dtype=torch.float32)[: (dim // 2)] / dim) + ) + + def _ensure_freqs_base(self, device): + """Lazily create frequency bases on the correct device.""" + if self._freqs_base_t is None or self._freqs_base_t.device != device: + self._freqs_base_t = self._get_freqs_base(self.DT).to(device) + self._freqs_base_y = self._get_freqs_base(self.DY).to(device) + self._freqs_base_x = self._get_freqs_base(self.DX).to(device) + + @torch.no_grad() + def get_frequency_batched(self, freqs_base, pos): + freqs = torch.einsum("d,bthw->dbthw", freqs_base, pos) + freqs = freqs.repeat_interleave(2, dim=0) + return freqs.cos(), freqs.sin() + + @torch.no_grad() + @lru_cache(maxsize=32) + def _get_spatial_meshgrid(self, height, width, device_str): + device = torch.device(device_str) + grid_y_coords = torch.arange(height, device=device, dtype=torch.float32) + grid_x_coords = torch.arange(width, device=device, dtype=torch.float32) + grid_y, grid_x = torch.meshgrid(grid_y_coords, grid_x_coords, indexing="ij") + return grid_y, grid_x + + @torch.no_grad() + def forward(self, frame_indices, height, width, device): + self._ensure_freqs_base(device) + batch_size = frame_indices.shape[0] + num_frames = frame_indices.shape[1] + + frame_indices = frame_indices.to(device=device, dtype=torch.float32) + grid_y, grid_x = self._get_spatial_meshgrid(height, width, str(device)) + + grid_t = frame_indices[:, :, None, None].expand( + batch_size, num_frames, height, width + ) + grid_y_batch = grid_y[None, None, :, :].expand(batch_size, num_frames, -1, -1) + grid_x_batch = grid_x[None, None, :, :].expand(batch_size, num_frames, -1, -1) + + freqs_cos_t, freqs_sin_t = self.get_frequency_batched( + self._freqs_base_t, grid_t + ) + freqs_cos_y, freqs_sin_y = self.get_frequency_batched( + self._freqs_base_y, grid_y_batch + ) + freqs_cos_x, 
class HeliosTimeTextEmbedding(nn.Module):
    """Produces the timestep embedding, its 6-way modulation projection, and
    (optionally) the projected text embedding."""

    def __init__(self, dim, time_freq_dim, time_proj_dim, text_embed_dim):
        # NOTE(review): `time_proj_dim` is accepted but never read here; the
        # modulation width comes from `factor=6` below.
        super().__init__()
        self.time_embedder = TimestepEmbedder(
            dim,
            frequency_embedding_size=time_freq_dim,
            act_layer="silu",
        )
        self.time_modulation = ModulateProjection(
            dim,
            factor=6,
            act_layer="silu",
        )
        self.text_embedder = MLP(
            text_embed_dim,
            dim,
            dim,
            bias=True,
            act_type="gelu_pytorch_tanh",
        )

    def forward(
        self, timestep, encoder_hidden_states, is_return_encoder_hidden_states=True
    ):
        """Embed ``timestep`` and, when requested, the text states.

        Returns:
            ``(temb, timestep_proj, encoder_hidden_states)``. The third item
            is passed through untouched when it is ``None`` or when
            ``is_return_encoder_hidden_states`` is false; otherwise it is the
            output of ``text_embedder``.
        """
        time_embedding = self.time_embedder(timestep)
        modulation = self.time_modulation(time_embedding)

        if encoder_hidden_states is None or not is_return_encoder_hidden_states:
            return time_embedding, modulation, encoder_hidden_states

        text_embedding = self.text_embedder(encoder_hidden_states)
        return time_embedding, modulation, text_embedding
def forward(self, hidden_states, rotary_emb=None, original_context_length=None):
    """Self-attention over ``hidden_states`` with optional history-key amplification.

    Args:
        hidden_states: Input sequence; axis 1 is the token axis.
        rotary_emb: Rotary table applied to queries and keys when given.
        original_context_length: Number of trailing tokens that belong to the
            current chunk. When amplification is enabled and this is set, the
            keys of all *earlier* tokens are multiplied by a learned gain in
            ``(1, max_scale)``.

    Returns:
        The attention output after the ``to_out`` projection.
    """
    query, _ = self.to_q(hidden_states)
    key, _ = self.to_k(hidden_states)
    value, _ = self.to_v(hidden_states)

    if not self.tp_rmsnorm:
        query = self.norm_q(query)
        key = self.norm_k(key)
    else:
        query = tensor_parallel_rms_norm(query, self.norm_q)
        key = tensor_parallel_rms_norm(key, self.norm_k)

    # Split the feature axis into (heads on this rank, head_dim).
    head_shape = (self.local_num_heads, self.head_dim)
    query = query.unflatten(2, head_shape)
    key = key.unflatten(2, head_shape)
    value = value.unflatten(2, head_shape)

    if rotary_emb is not None:
        query = apply_rotary_emb_transposed(query, rotary_emb)
        key = apply_rotary_emb_transposed(key, rotary_emb)

    if self.is_amplify_history and original_context_length is not None:
        num_history_tokens = hidden_states.shape[1] - original_context_length
        if num_history_tokens > 0:
            # sigmoid keeps the gain strictly between 1 and max_scale.
            gain = 1.0 + torch.sigmoid(self.history_key_scale) * (
                self.max_scale - 1.0
            )
            if self.history_scale_mode == "per_head":
                # NOTE(review): the parameter has `num_heads` entries while the
                # keys carry `local_num_heads` heads -- presumably these only
                # line up when tensor parallelism is off; confirm.
                gain = gain.view(1, 1, -1, 1)
            amplified_history = key[:, :num_history_tokens] * gain
            current_keys = key[:, num_history_tokens:]
            key = torch.cat([amplified_history, current_keys], dim=1)

    attended = self.attn(query, key, value)
    projected, _ = self.to_out(attended.flatten(2))
    return projected
(self.local_num_heads, self.head_dim)) + k = k.unflatten(2, (self.local_num_heads, self.head_dim)) + v = v.unflatten(2, (self.local_num_heads, self.head_dim)) + + x = self.attn(q, k, v) + x = x.flatten(2) + x, _ = self.to_out(x) + return x + + +# --------------------------------------------------------------------------- +# Transformer Block +# --------------------------------------------------------------------------- + + +class HeliosTransformerBlock(nn.Module): + """ + Single transformer block with self-attention, cross-attention, FFN, + and scale-shift modulation from timestep embeddings. + """ + + def __init__( + self, + dim: int, + ffn_dim: int, + num_heads: int, + cross_attn_norm: bool = True, + eps: float = 1e-6, + guidance_cross_attn: bool = True, + is_amplify_history: bool = False, + history_scale_mode: str = "per_head", + quant_config: QuantizationConfig | None = None, + ): + super().__init__() + + # 1. Self-attention + self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) + self.attn1 = HeliosSelfAttention( + dim=dim, + num_heads=num_heads, + eps=eps, + is_amplify_history=is_amplify_history, + history_scale_mode=history_scale_mode, + quant_config=quant_config, + ) + + # 2. Cross-attention + self.attn2 = HeliosCrossAttention( + dim=dim, + num_heads=num_heads, + eps=eps, + quant_config=quant_config, + ) + self.self_attn_residual_norm = ( + FP32LayerNorm(dim, eps, elementwise_affine=True) + if cross_attn_norm + else nn.Identity() + ) + + # 3. Feed-forward + self.ffn = MLP( + dim, ffn_dim, act_type="gelu_pytorch_tanh", quant_config=quant_config + ) + self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False) + + self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + # 4. 
def forward(
    self,
    hidden_states,
    encoder_hidden_states,
    temb,
    rotary_emb,
    original_context_length=None,
):
    """Apply self-attention, cross-attention and the feed-forward sub-layer.

    Args:
        hidden_states: Token sequence; axis 1 is the token axis, laid out as
            ``[history tokens, current tokens]``.
        encoder_hidden_states: Context attended to by the cross-attention.
        temb: Modulation input, either 4-D (six modulation vectors on axis 2)
            or 3-D (six modulation vectors on axis 1). It is added to
            ``scale_shift_table`` and split into shift/scale/gate for the
            self-attention and feed-forward sub-layers.
        rotary_emb: Rotary table forwarded to the self-attention.
        original_context_length: Number of trailing "current" tokens.
            ``None`` means the sequence has no history prefix.

    Returns:
        The updated ``hidden_states`` in its original dtype.
    """
    if temb.ndim == 4:
        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
            self.scale_shift_table.unsqueeze(0) + temb.float()
        ).chunk(6, dim=2)
        shift_msa = shift_msa.squeeze(2)
        scale_msa = scale_msa.squeeze(2)
        gate_msa = gate_msa.squeeze(2)
        c_shift_msa = c_shift_msa.squeeze(2)
        c_scale_msa = c_scale_msa.squeeze(2)
        c_gate_msa = c_gate_msa.squeeze(2)
    else:
        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
            self.scale_shift_table + temb.float()
        ).chunk(6, dim=1)

    # 1. Self-attention
    norm_hidden_states = (
        self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa
    ).type_as(hidden_states)
    attn_output = self.attn1(
        norm_hidden_states, rotary_emb, original_context_length
    )
    hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(
        hidden_states
    )

    # 2. Cross-attention
    if self.guidance_cross_attn:
        # Only the current tokens attend to the text; history tokens pass
        # through. The parameter defaults to None, which previously crashed
        # in the subtraction below -- treat it as "every token is current".
        context_length = (
            hidden_states.shape[1]
            if original_context_length is None
            else original_context_length
        )
        history_seq_len = hidden_states.shape[1] - context_length
        history_hidden_states, current_hidden_states = torch.split(
            hidden_states, [history_seq_len, context_length], dim=1
        )
        norm_hidden_states = self.self_attn_residual_norm(
            current_hidden_states.float()
        ).type_as(current_hidden_states)
        attn_output = self.attn2(norm_hidden_states, encoder_hidden_states)
        current_hidden_states = current_hidden_states + attn_output
        hidden_states = torch.cat(
            [history_hidden_states, current_hidden_states], dim=1
        )
    else:
        norm_hidden_states = self.self_attn_residual_norm(
            hidden_states.float()
        ).type_as(hidden_states)
        attn_output = self.attn2(norm_hidden_states, encoder_hidden_states)
        hidden_states = hidden_states + attn_output

    # 3. Feed-forward
    norm_hidden_states = (
        self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa
    ).type_as(hidden_states)
    ff_output = self.ffn(norm_hidden_states)
    hidden_states = (
        hidden_states.float() + ff_output.float() * c_gate_msa
    ).type_as(hidden_states)

    return hidden_states
Patch & position embedding + self.patch_embedding = PatchEmbed( + in_chans=config.in_channels, + embed_dim=inner_dim, + patch_size=config.patch_size, + flatten=False, + ) + + # 2. Rotary position embeddings + self.rope = HeliosRotaryPosEmbed( + rope_dim=config.rope_dim, theta=config.rope_theta + ) + + # 3. Multi-term memory patches + if self.has_multi_term_memory_patch: + self.patch_short = nn.Conv3d( + config.in_channels, + inner_dim, + kernel_size=config.patch_size, + stride=config.patch_size, + ) + self.patch_mid = nn.Conv3d( + config.in_channels, + inner_dim, + kernel_size=tuple(2 * p for p in config.patch_size), + stride=tuple(2 * p for p in config.patch_size), + ) + self.patch_long = nn.Conv3d( + config.in_channels, + inner_dim, + kernel_size=tuple(4 * p for p in config.patch_size), + stride=tuple(4 * p for p in config.patch_size), + ) + + # 4. Condition embeddings + self.condition_embedder = HeliosTimeTextEmbedding( + dim=inner_dim, + time_freq_dim=config.freq_dim, + time_proj_dim=inner_dim * 6, + text_embed_dim=config.text_dim, + ) + + # 5. Transformer blocks + self.blocks = nn.ModuleList( + [ + HeliosTransformerBlock( + dim=inner_dim, + ffn_dim=config.ffn_dim, + num_heads=config.num_attention_heads, + cross_attn_norm=config.cross_attn_norm, + eps=config.eps, + guidance_cross_attn=config.guidance_cross_attn, + is_amplify_history=config.is_amplify_history, + history_scale_mode=config.history_scale_mode, + quant_config=quant_config, + ) + for _ in range(config.num_layers) + ] + ) + + # 6. 
Output norm & projection + self.norm_out = HeliosOutputNorm(inner_dim, config.eps) + self.proj_out = ColumnParallelLinear( + inner_dim, + config.out_channels * math.prod(config.patch_size), + bias=True, + gather_output=True, + quant_config=quant_config, + ) + + self.cnt = 0 + self.__post_init__() + self.layer_names = ["blocks"] + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | list[torch.Tensor], + timestep: torch.LongTensor, + # Stage 1 history inputs + indices_hidden_states=None, + indices_latents_history_short=None, + indices_latents_history_mid=None, + indices_latents_history_long=None, + latents_history_short=None, + latents_history_mid=None, + latents_history_long=None, + **kwargs, + ) -> torch.Tensor: + orig_dtype = hidden_states.dtype + if not isinstance(encoder_hidden_states, torch.Tensor): + encoder_hidden_states = encoder_hidden_states[0] + + batch_size = hidden_states.shape[0] + p_t, p_h, p_w = self.patch_size + + # 1. Patch embed the noisy latents + hidden_states = self.patch_embedding(hidden_states) + _, _, post_patch_num_frames, post_patch_height, post_patch_width = ( + hidden_states.shape + ) + + if indices_hidden_states is None: + indices_hidden_states = ( + torch.arange(0, post_patch_num_frames) + .unsqueeze(0) + .expand(batch_size, -1) + ) + + hidden_states = hidden_states.flatten(2).transpose(1, 2) + + # 2. Compute rotary embeddings + rotary_emb = self.rope( + frame_indices=indices_hidden_states, + height=post_patch_height, + width=post_patch_width, + device=hidden_states.device, + ) + rotary_emb = rotary_emb.flatten(2).transpose(1, 2) + original_context_length = hidden_states.shape[1] + + # 3. 
Process short history + if ( + latents_history_short is not None + and indices_latents_history_short is not None + ): + latents_history_short = latents_history_short.to(hidden_states) + latents_history_short = self.patch_short(latents_history_short) + _, _, _, H1, W1 = latents_history_short.shape + latents_history_short = latents_history_short.flatten(2).transpose(1, 2) + + rotary_emb_history_short = self.rope( + frame_indices=indices_latents_history_short, + height=H1, + width=W1, + device=latents_history_short.device, + ) + rotary_emb_history_short = rotary_emb_history_short.flatten(2).transpose( + 1, 2 + ) + hidden_states = torch.cat([latents_history_short, hidden_states], dim=1) + rotary_emb = torch.cat([rotary_emb_history_short, rotary_emb], dim=1) + + # 4. Process mid history + if latents_history_mid is not None and indices_latents_history_mid is not None: + latents_history_mid = latents_history_mid.to(hidden_states) + latents_history_mid = pad_for_3d_conv(latents_history_mid, (2, 4, 4)) + latents_history_mid = self.patch_mid(latents_history_mid) + latents_history_mid = latents_history_mid.flatten(2).transpose(1, 2) + + rotary_emb_history_mid = self.rope( + frame_indices=indices_latents_history_mid, + height=H1, + width=W1, + device=latents_history_mid.device, + ) + rotary_emb_history_mid = pad_for_3d_conv(rotary_emb_history_mid, (2, 2, 2)) + rotary_emb_history_mid = center_down_sample_3d( + rotary_emb_history_mid, (2, 2, 2) + ) + rotary_emb_history_mid = rotary_emb_history_mid.flatten(2).transpose(1, 2) + + hidden_states = torch.cat([latents_history_mid, hidden_states], dim=1) + rotary_emb = torch.cat([rotary_emb_history_mid, rotary_emb], dim=1) + + # 5. 
Process long history + if ( + latents_history_long is not None + and indices_latents_history_long is not None + ): + latents_history_long = latents_history_long.to(hidden_states) + latents_history_long = pad_for_3d_conv(latents_history_long, (4, 8, 8)) + latents_history_long = self.patch_long(latents_history_long) + latents_history_long = latents_history_long.flatten(2).transpose(1, 2) + + rotary_emb_history_long = self.rope( + frame_indices=indices_latents_history_long, + height=H1, + width=W1, + device=latents_history_long.device, + ) + rotary_emb_history_long = pad_for_3d_conv( + rotary_emb_history_long, (4, 4, 4) + ) + rotary_emb_history_long = center_down_sample_3d( + rotary_emb_history_long, (4, 4, 4) + ) + rotary_emb_history_long = rotary_emb_history_long.flatten(2).transpose(1, 2) + + hidden_states = torch.cat([latents_history_long, hidden_states], dim=1) + rotary_emb = torch.cat([rotary_emb_history_long, rotary_emb], dim=1) + + history_context_length = hidden_states.shape[1] - original_context_length + + # 6. 
Compute condition embeddings + if indices_hidden_states is not None and self.zero_history_timestep: + timestep_t0 = torch.zeros( + (1,), dtype=timestep.dtype, device=timestep.device + ) + temb_t0, timestep_proj_t0, _ = self.condition_embedder( + timestep_t0, + encoder_hidden_states, + is_return_encoder_hidden_states=False, + ) + temb_t0 = temb_t0.unsqueeze(1).expand( + batch_size, history_context_length, -1 + ) + timestep_proj_t0 = ( + timestep_proj_t0.unflatten(-1, (6, -1)) + .view(1, 6, 1, -1) + .expand(batch_size, -1, history_context_length, -1) + ) + + temb, timestep_proj, encoder_hidden_states = self.condition_embedder( + timestep, encoder_hidden_states + ) + timestep_proj = timestep_proj.unflatten(-1, (6, -1)) + + if indices_hidden_states is not None and not self.zero_history_timestep: + main_repeat_size = hidden_states.shape[1] + else: + main_repeat_size = original_context_length + temb = temb.view(batch_size, 1, -1).expand(batch_size, main_repeat_size, -1) + timestep_proj = timestep_proj.view(batch_size, 6, 1, -1).expand( + batch_size, 6, main_repeat_size, -1 + ) + + if indices_hidden_states is not None and self.zero_history_timestep: + temb = torch.cat([temb_t0, temb], dim=1) + timestep_proj = torch.cat([timestep_proj_t0, timestep_proj], dim=2) + + if timestep_proj.ndim == 4: + timestep_proj = timestep_proj.permute(0, 2, 1, 3) + + # 7. Transformer blocks + hidden_states = hidden_states.contiguous() + encoder_hidden_states = encoder_hidden_states.contiguous() + rotary_emb = rotary_emb.contiguous() + + for block in self.blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states, + timestep_proj, + rotary_emb, + original_context_length, + ) + + self.cnt += 1 + + # 8. Output norm & projection + hidden_states = self.norm_out(hidden_states, temb, original_context_length) + hidden_states, _ = self.proj_out(hidden_states) + + # 9. 
Unpatchify + hidden_states = hidden_states.reshape( + batch_size, + post_patch_num_frames, + post_patch_height, + post_patch_width, + p_t, + p_h, + p_w, + -1, + ) + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + return output + + +EntryClass = HeliosTransformer3DModel diff --git a/sglang/python/sglang/multimodal_gen/runtime/models/dits/hunyuan3d.py b/sglang/python/sglang/multimodal_gen/runtime/models/dits/hunyuan3d.py new file mode 100644 index 0000000000000000000000000000000000000000..034b7718b3c40f12551cb9986e79f28d3255acdf --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/models/dits/hunyuan3d.py @@ -0,0 +1,1401 @@ +# Copied and adapted from: https://github.com/Tencent-Hunyuan/Hunyuan3D-2 +from __future__ import annotations + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from sglang.multimodal_gen.configs.models.dits.hunyuan3d import ( + Hunyuan3DDiTArchConfig, + Hunyuan3DDiTConfig, +) +from sglang.multimodal_gen.runtime.distributed import divide +from sglang.multimodal_gen.runtime.distributed.parallel_state import get_tp_world_size +from sglang.multimodal_gen.runtime.layers.attention import LocalAttention +from sglang.multimodal_gen.runtime.layers.linear import ( + MergedColumnParallelLinear, + RowParallelLinear, +) +from sglang.multimodal_gen.runtime.layers.mlp import MLP +from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT +from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum +from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class MixedRowParallelLinear(RowParallelLinear): + """RowParallel for inputs concatenated from multiple 
separately-sharded sources.""" + + def __init__(self, input_sizes: list[int], output_size: int, **kwargs): + self.input_sizes = input_sizes + super().__init__(sum(input_sizes), output_size, **kwargs) + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): + input_dim = getattr(param, "input_dim", None) + if input_dim is not None: + shards = [] + offset = 0 + for sz in self.input_sizes: + part = loaded_weight.narrow(input_dim, offset, sz) + per_rank = sz // self.tp_size + shard = part.narrow(input_dim, self.tp_rank * per_rank, per_rank) + shards.append(shard) + offset += sz + param.data.copy_(torch.cat(shards, dim=input_dim)) + else: + param.data.copy_(loaded_weight) + + +def _flux_timestep_embedding( + t: torch.Tensor, dim, max_period=10000, time_factor: float = 1000.0 +): + """Create sinusoidal timestep embeddings for Flux-style model.""" + t = time_factor * t + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(t.device) + + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(t) + return embedding + + +class _FluxGELU(nn.Module): + def __init__(self, approximate="tanh"): + super().__init__() + self.approximate = approximate + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return F.gelu(x, approximate=self.approximate) + + +class _FluxMLPEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + +class _FluxRMSNorm(nn.Module): + def __init__(self, dim: 
class _FluxQKNorm(nn.Module):
    """Applies separate RMS norms to queries and keys, then matches them to ``v``."""

    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = _FluxRMSNorm(dim)
        self.key_norm = _FluxRMSNorm(dim)

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(q, k)`` normalised and cast to the dtype/device of ``v``.

        ``v`` is used only as the cast target; it is not modified or returned.
        """
        normed_q = self.query_norm(q)
        normed_k = self.key_norm(k)
        return normed_q.to(v), normed_k.to(v)
class _FluxModulation(nn.Module):
    """Project a conditioning vector into shift/scale/gate modulation triples.

    A "double" modulation produces two triples (six chunks); a single
    modulation produces one triple (three chunks).
    """

    def __init__(self, dim: int, double: bool):
        """Create the projection ``dim -> multiplier * dim``."""
        super().__init__()
        self.is_double = double
        if double:
            self.multiplier = 6
        else:
            self.multiplier = 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(
        self, vec: torch.Tensor
    ) -> Tuple[_FluxModulationOut, Optional[_FluxModulationOut]]:
        """Return ``(first, second)``; ``second`` is ``None`` unless double.

        The projection of ``silu(vec)`` gets a singleton axis inserted at
        position 1 and is then split into ``multiplier`` equal chunks along
        the last axis. The chunks fill ``_FluxModulationOut`` positionally.
        """
        projected = self.lin(F.silu(vec))
        # NOTE(review): the added axis presumably lets each chunk broadcast
        # over the token dimension of the modulated tensor; confirm at callers.
        chunks = projected[:, None, :].chunk(self.multiplier, dim=-1)

        first = _FluxModulationOut(*chunks[:3])
        if not self.is_double:
            return first, None
        return first, _FluxModulationOut(*chunks[3:])
elementwise_affine=False, eps=1e-6) + self.txt_mlp = MLP(hidden_size, mlp_hidden_dim, act_type="gelu_pytorch_tanh") + + if supported_attention_backends is None: + supported_attention_backends = { + AttentionBackendEnum.FA, + AttentionBackendEnum.TORCH_SDPA, + } + self.local_attn_joint = LocalAttention( + num_heads=self.local_num_heads, + head_size=self.head_dim, + causal=False, + supported_attention_backends=supported_attention_backends, + ) + + def forward( + self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, pe: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + + img_mod1, img_mod2 = self.img_mod(vec) + txt_mod1, txt_mod2 = self.txt_mod(vec) + + img_modulated = self.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + + B, img_L, _ = img_modulated.shape + img_qkv, _ = self.img_attn.qkv(img_modulated) + img_qkv = img_qkv.view(B, img_L, 3, self.local_num_heads, self.head_dim) + img_q, img_k, img_v = img_qkv[:, :, 0], img_qkv[:, :, 1], img_qkv[:, :, 2] + img_q_t = img_q.transpose(1, 2) + img_k_t = img_k.transpose(1, 2) + img_v_t = img_v.transpose(1, 2) + img_q_t, img_k_t = self.img_attn.norm(img_q_t, img_k_t, img_v_t) + img_q = img_q_t.transpose(1, 2) + img_k = img_k_t.transpose(1, 2) + + txt_modulated = self.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_L = txt_modulated.shape[1] + txt_qkv, _ = self.txt_attn.qkv(txt_modulated) + txt_qkv = txt_qkv.view(B, txt_L, 3, self.local_num_heads, self.head_dim) + txt_q, txt_k, txt_v = txt_qkv[:, :, 0], txt_qkv[:, :, 1], txt_qkv[:, :, 2] + txt_q_t = txt_q.transpose(1, 2) + txt_k_t = txt_k.transpose(1, 2) + txt_v_t = txt_v.transpose(1, 2) + txt_q_t, txt_k_t = self.txt_attn.norm(txt_q_t, txt_k_t, txt_v_t) + txt_q = txt_q_t.transpose(1, 2) + txt_k = txt_k_t.transpose(1, 2) + + q = torch.cat((txt_q, img_q), dim=1) + k = torch.cat((txt_k, img_k), dim=1) + v = torch.cat((txt_v, img_v), dim=1) + + attn = self.local_attn_joint(q, k, 
v) + attn = attn.flatten(2) + + txt_attn, img_attn = attn[:, :txt_L], attn[:, txt_L:] + + img_proj, _ = self.img_attn.proj(img_attn) + img = img + img_mod1.gate * img_proj + img = img + img_mod2.gate * self.img_mlp( + (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift + ) + + txt_proj, _ = self.txt_attn.proj(txt_attn) + txt = txt + txt_mod1.gate * txt_proj + txt = txt + txt_mod2.gate * self.txt_mlp( + (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift + ) + return img, txt + + +class _FluxSingleStreamBlock(nn.Module): + """ + A DiT block with parallel linear layers as described in + https://arxiv.org/abs/2302.05442 and adapted modulation interface. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qk_scale: Optional[float] = None, + supported_attention_backends: set[AttentionBackendEnum] | None = None, + ): + super().__init__() + + tp_size = get_tp_world_size() + self.hidden_dim = hidden_size + self.num_heads = num_heads + self.local_num_heads = divide(num_heads, tp_size) + self.head_dim = hidden_size // num_heads + self.tp_size = tp_size + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.linear1 = MergedColumnParallelLinear( + hidden_size, + [hidden_size, hidden_size, hidden_size, self.mlp_hidden_dim], + bias=True, + gather_output=False, + ) + self.linear2 = MixedRowParallelLinear( + [hidden_size, self.mlp_hidden_dim], + hidden_size, + bias=True, + input_is_parallel=True, + ) + + self.norm = _FluxQKNorm(self.head_dim) + + self.hidden_size = hidden_size + self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.mlp_act = _FluxGELU(approximate="tanh") + self.modulation = _FluxModulation(hidden_size, double=False) + + if supported_attention_backends is None: + supported_attention_backends = { + AttentionBackendEnum.FA, + AttentionBackendEnum.TORCH_SDPA, + } + self.local_attn = LocalAttention( + num_heads=self.local_num_heads, + head_size=self.head_dim, + 
causal=False, + supported_attention_backends=supported_attention_backends, + ) + + def forward( + self, x: torch.Tensor, vec: torch.Tensor, pe: torch.Tensor + ) -> torch.Tensor: + mod, _ = self.modulation(vec) + + x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift + linear1_out, _ = self.linear1(x_mod) + local_qkv_dim = 3 * self.head_dim * self.local_num_heads + local_mlp_dim = self.mlp_hidden_dim // self.tp_size + qkv, mlp = torch.split(linear1_out, [local_qkv_dim, local_mlp_dim], dim=-1) + + B, L, _ = qkv.shape + qkv = qkv.view(B, L, 3, self.local_num_heads, self.head_dim) + q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2] + q_t = q.transpose(1, 2) + k_t = k.transpose(1, 2) + v_t = v.transpose(1, 2) + q_t, k_t = self.norm(q_t, k_t, v_t) + q = q_t.transpose(1, 2) + k = k_t.transpose(1, 2) + + attn = self.local_attn(q, k, v) + attn = attn.flatten(2) + + output, _ = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) + return x + mod.gate * output + + +class _FluxLastLayer(nn.Module): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear( + hidden_size, patch_size * patch_size * out_channels, bias=True + ) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True) + ) + + def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor: + shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) + x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.linear(x) + return x + + +class Hunyuan3D2DiT(CachableDiT, OffloadableDiTMixin): + """Hunyuan3D DiT model (Flux-style architecture for Hunyuan3D-2.0).""" + + _aliases = ["hy3dgen.shapegen.models.Hunyuan3DDiT"] + + param_names_mapping = Hunyuan3DDiTConfig().param_names_mapping + + @classmethod + def build_config_from_params(cls, params: dict) -> Hunyuan3DDiTConfig: + """Build a DiTConfig 
from YAML-style parameter dict.""" + field_mapping = { + "num_heads": "num_attention_heads", + "depth": "num_layers", + "depth_single_blocks": "num_single_layers", + } + arch_kwargs = {} + for k, v in params.items(): + if k in ("ckpt_path", "supported_attention_backends"): + continue + mapped = field_mapping.get(k, k) + if k == "axes_dim" and isinstance(v, list): + v = tuple(v) + arch_kwargs[mapped] = v + return Hunyuan3DDiTConfig(arch_config=Hunyuan3DDiTArchConfig(**arch_kwargs)) + + def __init__( + self, + config: Hunyuan3DDiTConfig, + hf_config: dict | None = None, + **kwargs, + ): + super().__init__(config=config, hf_config=hf_config or {}, **kwargs) + arch = config.arch_config + + in_channels = arch.in_channels + context_in_dim = arch.context_in_dim + hidden_size = arch.hidden_size + mlp_ratio = arch.mlp_ratio + num_heads = arch.num_attention_heads + depth = arch.num_layers + depth_single_blocks = arch.num_single_layers + axes_dim = list(arch.axes_dim) + theta = arch.theta + qkv_bias = arch.qkv_bias + time_factor = arch.time_factor + guidance_embed = arch.guidance_embed + supported_attention_backends = arch._supported_attention_backends + + self.in_channels = in_channels + self.context_in_dim = context_in_dim + self.hidden_size = hidden_size + self.mlp_ratio = mlp_ratio + self.num_heads = num_heads + self.num_attention_heads = num_heads + self.depth = depth + self.depth_single_blocks = depth_single_blocks + self.axes_dim = axes_dim + self.theta = theta + self.qkv_bias = qkv_bias + self.time_factor = time_factor + self.out_channels = self.in_channels + self.num_channels_latents = self.in_channels + self.guidance_embed = guidance_embed + + if hidden_size % num_heads != 0: + raise ValueError( + f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}" + ) + pe_dim = hidden_size // num_heads + if sum(axes_dim) != pe_dim: + raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}") + self.latent_in = nn.Linear(self.in_channels, 
self.hidden_size, bias=True) + self.time_in = _FluxMLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.cond_in = nn.Linear(context_in_dim, self.hidden_size) + self.guidance_in = ( + _FluxMLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + if guidance_embed + else nn.Identity() + ) + + self.double_blocks = nn.ModuleList( + [ + _FluxDoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + supported_attention_backends=supported_attention_backends, + ) + for _ in range(depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + _FluxSingleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=mlp_ratio, + supported_attention_backends=supported_attention_backends, + ) + for _ in range(depth_single_blocks) + ] + ) + + self.final_layer = _FluxLastLayer(self.hidden_size, 1, self.out_channels) + + # OffloadableDiTMixin + self.layer_names = ["double_blocks", "single_blocks"] + + def forward( + self, + x, + t, + contexts, + **kwargs, + ) -> torch.Tensor: + """Forward pass for denoising.""" + + cond = contexts["main"] + + latent = self.latent_in(x) + + t_emb = _flux_timestep_embedding(t, 256, self.time_factor).to( + dtype=latent.dtype + ) + + vec = self.time_in(t_emb) + + if self.guidance_embed: + guidance = kwargs.get("guidance", None) + if guidance is None: + raise ValueError( + "Didn't get guidance strength for guidance distilled model." + ) + vec = vec + self.guidance_in( + _flux_timestep_embedding(guidance, 256, self.time_factor) + ) + + cond = self.cond_in(cond) + + pe = None + + # Double blocks + for i, block in enumerate(self.double_blocks): + latent, cond = block(img=latent, txt=cond, vec=vec, pe=pe) + latent = torch.cat((cond, latent), 1) + + # Single blocks + for i, block in enumerate(self.single_blocks): + latent = block(latent, vec=vec, pe=pe) + + latent = latent[:, cond.shape[1] :, ...] 
+ latent = self.final_layer(latent, vec) + return latent + + +import copy +import json +import os as _os + +from diffusers.models import UNet2DConditionModel +from diffusers.models.attention_processor import Attention as DiffusersAttention +from diffusers.models.transformers.transformer_2d import BasicTransformerBlock + + +def _chunked_feed_forward( + ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int +): + """Feed forward with chunking to save memory.""" + if hidden_states.shape[chunk_dim] % chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]}" + f"has to be divisible by chunk size: {chunk_size}." + f" Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." + ) + + num_chunks = hidden_states.shape[chunk_dim] // chunk_size + ff_output = torch.cat( + [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], + dim=chunk_dim, + ) + return ff_output + + +class SGLangAttentionWrapper(torch.nn.Module): + """Drop-in replacement for DiffusersAttention that uses sglang's attention backend.""" + + _SUPPORTED_BACKENDS = {AttentionBackendEnum.FA, AttentionBackendEnum.TORCH_SDPA} + + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + cross_attention_dim: int | None = None, + out_bias: bool = True, + ) -> None: + super().__init__() + self.inner_dim = dim_head * heads + self.heads = heads + self.dim_head = dim_head + self.query_dim = query_dim + cross_attention_dim = cross_attention_dim or query_dim + + self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = nn.Linear(cross_attention_dim, self.inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim, self.inner_dim, bias=bias) + self.to_out = nn.ModuleList( + [nn.Linear(self.inner_dim, query_dim, bias=out_bias), nn.Dropout(dropout)] + ) + + from 
sglang.multimodal_gen.runtime.layers.attention.selector import ( + get_attn_backend, + ) + + attn_backend = get_attn_backend( + dim_head, torch.float16, self._SUPPORTED_BACKENDS + ) + impl_cls = attn_backend.get_impl_cls() + self.attn_impl = impl_cls( + num_heads=heads, + head_size=dim_head, + softmax_scale=dim_head**-0.5, + num_kv_heads=heads, + causal=False, + ) + self._attn_backend_name = attn_backend.get_enum().name + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + B, N_q, _ = hidden_states.shape + _, N_kv, _ = encoder_hidden_states.shape + + q = self.to_q(hidden_states).view(B, N_q, self.heads, self.dim_head) + k = self.to_k(encoder_hidden_states).view(B, N_kv, self.heads, self.dim_head) + v = self.to_v(encoder_hidden_states).view(B, N_kv, self.heads, self.dim_head) + + from sglang.multimodal_gen.runtime.managers.forward_context import ( + get_forward_context, + ) + + ctx = get_forward_context() + out = self.attn_impl.forward(q, k, v, attn_metadata=ctx.attn_metadata) + out = out.reshape(B, N_q, self.inner_dim) + + out = self.to_out[0](out) + out = self.to_out[1](out) + return out + + +class Basic2p5DTransformerBlock(torch.nn.Module): + """2.5D Transformer block with Multiview Attention (MVA) and Reference View Attention (RVA).""" + + def __init__( + self, + transformer: BasicTransformerBlock, + layer_name: str, + use_ma: bool = True, + use_ra: bool = True, + is_turbo: bool = False, + use_sglang_attn: bool = True, + ) -> None: + super().__init__() + self.transformer = transformer + self.layer_name = layer_name + self.use_ma = use_ma + self.use_ra = use_ra + self.is_turbo = is_turbo + self.use_sglang_attn = use_sglang_attn and not is_turbo + + attn_cls = ( + SGLangAttentionWrapper if self.use_sglang_attn else DiffusersAttention + ) + attn_kwargs = 
dict( + query_dim=self.dim, + heads=self.num_attention_heads, + dim_head=self.attention_head_dim, + dropout=self.dropout, + bias=self.attention_bias, + cross_attention_dim=None, + upcast_attention=self.attn1.upcast_attention, + out_bias=True, + ) + if self.use_sglang_attn: + attn_kwargs.pop("upcast_attention") + + if self.use_ma: + self.attn_multiview = attn_cls(**attn_kwargs) + + if self.use_ra: + self.attn_refview = attn_cls(**attn_kwargs) + + if self.is_turbo: + self._initialize_attn_weights() + + def _initialize_attn_weights(self): + """Initialize attention weights for turbo mode.""" + if self.use_ma: + self.attn_multiview.load_state_dict(self.attn1.state_dict()) + with torch.no_grad(): + for layer in self.attn_multiview.to_out: + for param in layer.parameters(): + param.zero_() + if self.use_ra: + self.attn_refview.load_state_dict(self.attn1.state_dict()) + with torch.no_grad(): + for layer in self.attn_refview.to_out: + for param in layer.parameters(): + param.zero_() + + def __getattr__(self, name: str): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.transformer, name) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: dict = None, + class_labels: Optional[torch.LongTensor] = None, + added_cond_kwargs: Optional[dict] = None, + ) -> torch.Tensor: + """Forward pass with MVA and RVA support.""" + batch_size = hidden_states.shape[0] + + cross_attention_kwargs = ( + cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + ) + num_in_batch = cross_attention_kwargs.pop("num_in_batch", 1) + mode = cross_attention_kwargs.pop("mode", None) + + if not self.is_turbo: + mva_scale = cross_attention_kwargs.pop("mva_scale", 1.0) + ref_scale = 
cross_attention_kwargs.pop("ref_scale", 1.0) + else: + position_attn_mask = cross_attention_kwargs.pop("position_attn_mask", None) + position_voxel_indices = cross_attention_kwargs.pop( + "position_voxel_indices", None + ) + mva_scale = 1.0 + ref_scale = 1.0 + + condition_embed_dict = cross_attention_kwargs.pop("condition_embed_dict", None) + + # Normalization + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.norm_type == "ada_norm_zero": + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm1(hidden_states) + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm1( + hidden_states, added_cond_kwargs["pooled_text_emb"] + ) + elif self.norm_type == "ada_norm_single": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + else: + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + # Prepare GLIGEN inputs + cross_attention_kwargs = ( + cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + ) + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + # Self-attention + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=( + encoder_hidden_states if self.only_cross_attention else None + ), + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + if self.norm_type == "ada_norm_zero": + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.norm_type == "ada_norm_single": + attn_output = gate_msa * attn_output + + hidden_states = 
attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # Reference Attention - Write mode + if mode is not None and "w" in mode: + condition_embed_dict[self.layer_name] = rearrange( + norm_hidden_states, "(b n) l c -> b (n l) c", n=num_in_batch + ) + + # Reference Attention - Read mode + if mode is not None and "r" in mode and self.use_ra: + condition_embed = ( + condition_embed_dict[self.layer_name] + .unsqueeze(1) + .repeat(1, num_in_batch, 1, 1) + ) + condition_embed = rearrange(condition_embed, "b n l c -> (b n) l c") + + attn_output = self.attn_refview( + norm_hidden_states, + encoder_hidden_states=condition_embed, + attention_mask=None, + **cross_attention_kwargs, + ) + + if not self.is_turbo: + ref_scale_timing = ref_scale + if isinstance(ref_scale, torch.Tensor): + ref_scale_timing = ( + ref_scale.unsqueeze(1).repeat(1, num_in_batch).view(-1) + ) + for _ in range(attn_output.ndim - 1): + ref_scale_timing = ref_scale_timing.unsqueeze(-1) + + hidden_states = ref_scale_timing * attn_output + hidden_states + + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # Multiview Attention + if num_in_batch > 1 and self.use_ma: + multivew_hidden_states = rearrange( + norm_hidden_states, "(b n) l c -> b (n l) c", n=num_in_batch + ) + + if self.is_turbo: + position_mask = None + if position_attn_mask is not None: + if multivew_hidden_states.shape[1] in position_attn_mask: + position_mask = position_attn_mask[ + multivew_hidden_states.shape[1] + ] + position_indices = None + if position_voxel_indices is not None: + if multivew_hidden_states.shape[1] in position_voxel_indices: + position_indices = position_voxel_indices[ + multivew_hidden_states.shape[1] + ] + attn_output = self.attn_multiview( + multivew_hidden_states, + encoder_hidden_states=multivew_hidden_states, + attention_mask=position_mask, + position_indices=position_indices, + **cross_attention_kwargs, + ) + else: + attn_output = 
self.attn_multiview( + multivew_hidden_states, + encoder_hidden_states=multivew_hidden_states, + **cross_attention_kwargs, + ) + + attn_output = rearrange( + attn_output, "b (n l) c -> (b n) l c", n=num_in_batch + ) + + hidden_states = mva_scale * attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + + # Cross-Attention + if self.attn2 is not None: + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm2(hidden_states, timestep) + elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm2(hidden_states) + elif self.norm_type == "ada_norm_single": + norm_hidden_states = hidden_states + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm2( + hidden_states, added_cond_kwargs["pooled_text_emb"] + ) + else: + raise ValueError("Incorrect norm") + + if self.pos_embed is not None and self.norm_type != "ada_norm_single": + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + + hidden_states = attn_output + hidden_states + + # Feed-forward + if self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm3( + hidden_states, added_cond_kwargs["pooled_text_emb"] + ) + elif not self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm3(hidden_states) + + if self.norm_type == "ada_norm_zero": + norm_hidden_states = ( + norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ) + + if self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + if self._chunk_size is not None: + ff_output = _chunked_feed_forward( + 
@torch.no_grad()
def compute_voxel_grid_mask(position: torch.Tensor, grid_resolution: int = 8):
    """Compute voxel grid mask for position-aware attention.

    The position maps are divided into ``grid_resolution x grid_resolution``
    cells per view. Each cell is reduced to the mean of its valid pixels, and
    every pair of cells (across all views) is compared by Euclidean distance.

    Args:
        position: 5-D tensor unpacked as ``(B, N, C, H, W)``. ``H`` and ``W``
            must be divisible by ``grid_resolution``. The tensor is never
            modified.
        grid_resolution: Number of cells along each spatial axis.

    Returns:
        Boolean tensor of shape ``(B, N, N, L, L)`` with
        ``L = grid_resolution ** 2``. An entry is True where the two cells'
        mean positions are closer than ``1.73 / grid_resolution``.

    Raises:
        AssertionError: If ``H`` or ``W`` is not divisible by
            ``grid_resolution``.
    """
    position = position.half()
    _, _, _, H, W = position.shape
    assert H % grid_resolution == 0 and W % grid_resolution == 0

    # A pixel is valid only when none of its channels equals exactly 1.
    # NOTE(review): presumably 1 marks background; confirm with the renderer.
    valid_mask = (position != 1).all(dim=2, keepdim=True)
    valid_mask = valid_mask.expand_as(position)

    # Zero invalid pixels out of place. ``half()`` returns the input tensor
    # itself when it is already fp16, so an in-place write would corrupt the
    # caller's data and change the validity mask on any later call.
    position = position.masked_fill(~valid_mask, 0)

    cell_pattern = (
        "b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w"
    )
    position = rearrange(
        position,
        cell_pattern,
        num_h=grid_resolution,
        num_w=grid_resolution,
    )
    valid_mask = rearrange(
        valid_mask,
        cell_pattern,
        num_h=grid_resolution,
        num_w=grid_resolution,
    )

    # Per-cell mean over valid pixels; the clamp avoids division by zero.
    grid_position = position.sum(dim=(-2, -1))
    count_masked = valid_mask.sum(dim=(-2, -1))

    grid_position = grid_position / count_masked.clamp(min=1)
    # Cells with fewer than 5 valid pixels are treated as empty.
    grid_position[count_masked < 5] = 0

    grid_position = grid_position.permute(0, 1, 4, 2, 3)
    grid_position = rearrange(grid_position, "b n c h w -> b n (h w) c")

    # Broadcast to every (view_i, view_j, cell_i, cell_j) pair.
    grid_position_expanded_1 = grid_position.unsqueeze(2).unsqueeze(4)
    grid_position_expanded_2 = grid_position.unsqueeze(1).unsqueeze(3)

    distances = torch.norm(grid_position_expanded_1 - grid_position_expanded_2, dim=-1)

    # NOTE(review): 1.73 is presumably sqrt(3), the diagonal of a unit cube,
    # scaled to one grid cell; confirm.
    grid_distance = 1.73 / grid_resolution

    weights = distances < grid_distance

    return weights
@torch.no_grad()
def compute_discrete_voxel_indice(
    position: torch.Tensor, grid_resolution: int = 8, voxel_resolution: int = 128
):
    """Compute discrete voxel indices for position encoding.

    The position maps are divided into ``grid_resolution x grid_resolution``
    cells per view. Each cell is reduced to the mean of its valid pixels,
    clamped to ``[0, 1]``, and quantized to an integer voxel coordinate.

    Args:
        position: 5-D tensor unpacked as ``(B, N, C, H, W)``. ``H`` and ``W``
            must be divisible by ``grid_resolution``. The tensor is never
            modified.
        grid_resolution: Number of cells along each spatial axis.
        voxel_resolution: Number of quantization levels; indices fall in
            ``[0, voxel_resolution - 1]``.

    Returns:
        Integer (``long``) tensor of shape
        ``(B, N, C, grid_resolution, grid_resolution)``.

    Raises:
        AssertionError: If ``H`` or ``W`` is not divisible by
            ``grid_resolution``.
    """
    position = position.half()
    _, _, _, H, W = position.shape
    assert H % grid_resolution == 0 and W % grid_resolution == 0

    # A pixel is valid only when none of its channels equals exactly 1.
    # NOTE(review): presumably 1 marks background; confirm with the renderer.
    valid_mask = (position != 1).all(dim=2, keepdim=True)
    valid_mask = valid_mask.expand_as(position)

    # Zero invalid pixels out of place. ``half()`` returns the input tensor
    # itself when it is already fp16, so an in-place write would corrupt the
    # caller's data and change the validity mask on any later call.
    position = position.masked_fill(~valid_mask, 0)

    cell_pattern = (
        "b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w"
    )
    position = rearrange(
        position,
        cell_pattern,
        num_h=grid_resolution,
        num_w=grid_resolution,
    )
    valid_mask = rearrange(
        valid_mask,
        cell_pattern,
        num_h=grid_resolution,
        num_w=grid_resolution,
    )

    # Per-cell mean over valid pixels; the clamp avoids division by zero.
    grid_position = position.sum(dim=(-2, -1))
    count_masked = valid_mask.sum(dim=(-2, -1))

    grid_position = grid_position / count_masked.clamp(min=1)
    # Cells with fewer than 5 valid pixels are treated as empty.
    grid_position[count_masked < 5] = 0

    # Move channels ahead of the grid axes and quantize to voxel coordinates.
    grid_position = grid_position.permute(0, 1, 4, 2, 3).clamp(0, 1)
    voxel_indices = grid_position * (voxel_resolution - 1)
    voxel_indices = torch.round(voxel_indices).long()
    return voxel_indices
voxel_indice = compute_discrete_voxel_indice( + position_maps, grid_resolution, voxel_resolution + ) + voxel_indice = rearrange(voxel_indice, "b n c h w -> b (n h w) c") + voxel_indices[voxel_indice.shape[1]] = { + "voxel_indices": voxel_indice, + "voxel_resolution": voxel_resolution, + } + return voxel_indices + + +class UNet2p5DConditionModel(torch.nn.Module): + """2.5D UNet for multi-view texture generation.""" + + def __init__(self, unet: UNet2DConditionModel) -> None: + super().__init__() + self.unet = unet + + self.use_ma = True + self.use_ra = True + self.use_camera_embedding = True + self.use_dual_stream = True + self.is_turbo = False + + if self.use_dual_stream: + self.unet_dual = copy.deepcopy(unet) + self.init_attention(self.unet_dual) + self.init_attention( + self.unet, use_ma=self.use_ma, use_ra=self.use_ra, is_turbo=self.is_turbo + ) + self.init_condition() + self.init_camera_embedding() + + @staticmethod + def from_pretrained(pretrained_model_name_or_path: str, **kwargs): + """Load a pretrained UNet2p5DConditionModel.""" + torch_dtype = kwargs.pop("dtype", kwargs.pop("torch_dtype", torch.float32)) + config_path = _os.path.join(pretrained_model_name_or_path, "config.json") + unet_ckpt_path = _os.path.join( + pretrained_model_name_or_path, "diffusion_pytorch_model.bin" + ) + + with open(config_path, "r", encoding="utf-8") as file: + config = json.load(file) + + unet = UNet2DConditionModel(**config) + unet = UNet2p5DConditionModel(unet) + unet_ckpt = torch.load(unet_ckpt_path, map_location="cpu", weights_only=True) + unet.load_state_dict(unet_ckpt, strict=True) + unet = unet.to(torch_dtype) + return unet + + def init_condition(self): + """Initialize condition-related modules.""" + self.unet.conv_in = torch.nn.Conv2d( + 12, # 4 (latent) + 4 (normal) + 4 (position) + self.unet.conv_in.out_channels, + kernel_size=self.unet.conv_in.kernel_size, + stride=self.unet.conv_in.stride, + padding=self.unet.conv_in.padding, + dilation=self.unet.conv_in.dilation, + 
groups=self.unet.conv_in.groups, + bias=self.unet.conv_in.bias is not None, + ) + + self.unet.learned_text_clip_gen = nn.Parameter(torch.randn(1, 77, 1024)) + self.unet.learned_text_clip_ref = nn.Parameter(torch.randn(1, 77, 1024)) + + def init_camera_embedding(self): + """Initialize camera embedding module.""" + if self.use_camera_embedding: + time_embed_dim = 1280 + self.max_num_ref_image = 5 + self.max_num_gen_image = 12 * 3 + 4 * 2 + self.unet.class_embedding = nn.Embedding( + self.max_num_ref_image + self.max_num_gen_image, time_embed_dim + ) + + def init_attention( + self, + unet: UNet2DConditionModel, + use_ma: bool = False, + use_ra: bool = False, + is_turbo: bool = False, + use_sglang_attn: bool = True, + ): + """Initialize attention blocks with MVA and RVA support.""" + block_kwargs = dict( + use_ma=use_ma, + use_ra=use_ra, + is_turbo=is_turbo, + use_sglang_attn=use_sglang_attn, + ) + + # Down blocks + for down_block_i, down_block in enumerate(unet.down_blocks): + if ( + hasattr(down_block, "has_cross_attention") + and down_block.has_cross_attention + ): + for attn_i, attn in enumerate(down_block.attentions): + for transformer_i, transformer in enumerate( + attn.transformer_blocks + ): + if isinstance(transformer, BasicTransformerBlock): + attn.transformer_blocks[transformer_i] = ( + Basic2p5DTransformerBlock( + transformer, + f"down_{down_block_i}_{attn_i}_{transformer_i}", + **block_kwargs, + ) + ) + + # Mid block + if ( + hasattr(unet.mid_block, "has_cross_attention") + and unet.mid_block.has_cross_attention + ): + for attn_i, attn in enumerate(unet.mid_block.attentions): + for transformer_i, transformer in enumerate(attn.transformer_blocks): + if isinstance(transformer, BasicTransformerBlock): + attn.transformer_blocks[transformer_i] = ( + Basic2p5DTransformerBlock( + transformer, + f"mid_{attn_i}_{transformer_i}", + **block_kwargs, + ) + ) + + # Up blocks + for up_block_i, up_block in enumerate(unet.up_blocks): + if ( + hasattr(up_block, 
"has_cross_attention") + and up_block.has_cross_attention + ): + for attn_i, attn in enumerate(up_block.attentions): + for transformer_i, transformer in enumerate( + attn.transformer_blocks + ): + if isinstance(transformer, BasicTransformerBlock): + attn.transformer_blocks[transformer_i] = ( + Basic2p5DTransformerBlock( + transformer, + f"up_{up_block_i}_{attn_i}_{transformer_i}", + **block_kwargs, + ) + ) + + if use_sglang_attn and (use_ma or use_ra): + backend = "unknown" + for block in self._iter_2p5d_blocks(unet): + for attr in ("attn_multiview", "attn_refview"): + wrapper = getattr(block, attr, None) + if isinstance(wrapper, SGLangAttentionWrapper): + backend = wrapper._attn_backend_name + break + if backend != "unknown": + break + count = sum(1 for _ in self._iter_2p5d_blocks(unet)) + logger.info( + "Initialized %d Basic2p5DTransformerBlocks with sglang %s attention", + count, + backend, + ) + + @staticmethod + def _iter_2p5d_blocks(unet): + """Yield all Basic2p5DTransformerBlock instances in a UNet.""" + for block_group in (unet.down_blocks, [unet.mid_block], unet.up_blocks): + for block in block_group: + if not hasattr(block, "attentions"): + continue + for attn in block.attentions: + for tb in attn.transformer_blocks: + if isinstance(tb, Basic2p5DTransformerBlock): + yield tb + + def __getattr__(self, name: str): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.unet, name) + + def forward( + self, + sample: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + *args, + down_intrablock_additional_residuals=None, + down_block_res_samples=None, + mid_block_res_sample=None, + **cached_condition, + ): + """Forward pass for multi-view texture generation.""" + B, N_gen, _, H, W = sample.shape + assert H == W + + if self.use_camera_embedding: + camera_info_gen = ( + cached_condition["camera_info_gen"] + self.max_num_ref_image + ) + camera_info_gen = rearrange(camera_info_gen, "b n -> (b n)") + 
else: + camera_info_gen = None + + # Concatenate latents with normal and position maps + sample = [sample] + if "normal_imgs" in cached_condition: + sample.append(cached_condition["normal_imgs"]) + if "position_imgs" in cached_condition: + sample.append(cached_condition["position_imgs"]) + sample = torch.cat(sample, dim=2) + + sample = rearrange(sample, "b n c h w -> (b n) c h w") + + encoder_hidden_states_gen = encoder_hidden_states.unsqueeze(1).repeat( + 1, N_gen, 1, 1 + ) + encoder_hidden_states_gen = rearrange( + encoder_hidden_states_gen, "b n l c -> (b n) l c" + ) + + # Process reference images for RVA + if self.use_ra: + if "condition_embed_dict" in cached_condition: + condition_embed_dict = cached_condition["condition_embed_dict"] + else: + condition_embed_dict = {} + ref_latents = cached_condition["ref_latents"] + N_ref = ref_latents.shape[1] + + if self.use_camera_embedding: + camera_info_ref = cached_condition["camera_info_ref"] + camera_info_ref = rearrange(camera_info_ref, "b n -> (b n)") + else: + camera_info_ref = None + + ref_latents = rearrange(ref_latents, "b n c h w -> (b n) c h w") + + encoder_hidden_states_ref = self.unet.learned_text_clip_ref.unsqueeze( + 1 + ).repeat(B, N_ref, 1, 1) + encoder_hidden_states_ref = rearrange( + encoder_hidden_states_ref, "b n l c -> (b n) l c" + ) + + noisy_ref_latents = ref_latents + timestep_ref = 0 + + if self.use_dual_stream: + unet_ref = self.unet_dual + else: + unet_ref = self.unet + + unet_ref( + noisy_ref_latents, + timestep_ref, + encoder_hidden_states=encoder_hidden_states_ref, + class_labels=camera_info_ref, + return_dict=False, + cross_attention_kwargs={ + "mode": "w", + "num_in_batch": N_ref, + "condition_embed_dict": condition_embed_dict, + }, + ) + cached_condition["condition_embed_dict"] = condition_embed_dict + else: + condition_embed_dict = None + + mva_scale = cached_condition.get("mva_scale", 1.0) + ref_scale = cached_condition.get("ref_scale", 1.0) + + if self.is_turbo: + position_attn_mask 
= cached_condition.get("position_attn_mask", None) + position_voxel_indices = cached_condition.get( + "position_voxel_indices", None + ) + cross_attention_kwargs_ = { + "mode": "r", + "num_in_batch": N_gen, + "condition_embed_dict": condition_embed_dict, + "position_attn_mask": position_attn_mask, + "position_voxel_indices": position_voxel_indices, + "mva_scale": mva_scale, + "ref_scale": ref_scale, + } + else: + cross_attention_kwargs_ = { + "mode": "r", + "num_in_batch": N_gen, + "condition_embed_dict": condition_embed_dict, + "mva_scale": mva_scale, + "ref_scale": ref_scale, + } + + return self.unet( + sample, + timestep, + encoder_hidden_states_gen, + *args, + class_labels=camera_info_gen, + down_intrablock_additional_residuals=( + [ + s.to(dtype=self.unet.dtype) + for s in down_intrablock_additional_residuals + ] + if down_intrablock_additional_residuals is not None + else None + ), + down_block_additional_residuals=( + [s.to(dtype=self.unet.dtype) for s in down_block_res_samples] + if down_block_res_samples is not None + else None + ), + mid_block_additional_residual=( + mid_block_res_sample.to(dtype=self.unet.dtype) + if mid_block_res_sample is not None + else None + ), + return_dict=False, + cross_attention_kwargs=cross_attention_kwargs_, + ) + + +# Entry class for model registry +EntryClass = [Hunyuan3D2DiT, UNet2p5DConditionModel] diff --git a/sglang/python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py b/sglang/python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py new file mode 100644 index 0000000000000000000000000000000000000000..09a233ec917686dc592d85eb2d33238c1f753595 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py @@ -0,0 +1,1002 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +import numpy as np +import torch +import torch.nn as nn + +from sglang.multimodal_gen.configs.models.dits import 
HunyuanVideoConfig +from sglang.multimodal_gen.configs.sample.teacache import TeaCacheParams +from sglang.multimodal_gen.runtime.distributed.parallel_state import get_sp_world_size +from sglang.multimodal_gen.runtime.layers.attention import ( + LocalAttention, + UlyssesAttention, +) +from sglang.multimodal_gen.runtime.layers.elementwise import MulAdd +from sglang.multimodal_gen.runtime.layers.layernorm import ( + LayerNormScaleShift, + RMSNorm, + ScaleResidualLayerNormScaleShift, +) +from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear +from sglang.multimodal_gen.runtime.layers.mlp import MLP +from sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import ( + QuantizationConfig, +) +from sglang.multimodal_gen.runtime.layers.rotary_embedding import ( + _apply_rotary_emb, + get_rotary_pos_embed, +) +from sglang.multimodal_gen.runtime.layers.visual_embedding import ( + ModulateProjection, + PatchEmbed, + TimestepEmbedder, + unpatchify, +) +from sglang.multimodal_gen.runtime.managers.forward_context import get_forward_context +from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT +from sglang.multimodal_gen.runtime.models.utils import modulate +from sglang.multimodal_gen.runtime.platforms import ( + AttentionBackendEnum, + current_platform, +) +from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin + + +class MMDoubleStreamBlock(nn.Module): + """ + A multimodal DiT block with separate modulation for text and image/video, + using distributed attention and linear layers. 
+ """ + + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + mlp_ratio: float, + dtype: torch.dtype | None = None, + supported_attention_backends: set[AttentionBackendEnum] | None = None, + prefix: str = "", + quant_config: QuantizationConfig | None = None, + ): + super().__init__() + + self.deterministic = False + self.num_attention_heads = num_attention_heads + head_dim = hidden_size // num_attention_heads + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + # Image modulation components + self.img_mod = ModulateProjection( + hidden_size, + factor=6, + act_layer="silu", + dtype=dtype, + prefix=f"{prefix}.img_mod", + ) + + # Fused operations for image stream + self.img_attn_norm = LayerNormScaleShift( + hidden_size, elementwise_affine=False, dtype=dtype + ) + self.img_attn_residual_mlp_norm = ScaleResidualLayerNormScaleShift( + hidden_size, elementwise_affine=False, dtype=dtype + ) + self.img_mlp_residual = MulAdd() + + # Image attention components + self.img_attn_qkv = ReplicatedLinear( + hidden_size, + hidden_size * 3, + bias=True, + params_dtype=dtype, + prefix=f"{prefix}.img_attn_qkv", + quant_config=quant_config, + ) + + self.img_attn_q_norm = RMSNorm(head_dim, eps=1e-6, dtype=dtype) + self.img_attn_k_norm = RMSNorm(head_dim, eps=1e-6, dtype=dtype) + + self.img_attn_proj = ReplicatedLinear( + hidden_size, + hidden_size, + bias=True, + params_dtype=dtype, + prefix=f"{prefix}.img_attn_proj", + quant_config=quant_config, + ) + + self.img_mlp = MLP( + hidden_size, + mlp_hidden_dim, + bias=True, + dtype=dtype, + prefix=f"{prefix}.img_mlp", + quant_config=quant_config, + ) + + # Text modulation components + self.txt_mod = ModulateProjection( + hidden_size, + factor=6, + act_layer="silu", + dtype=dtype, + prefix=f"{prefix}.txt_mod", + ) + + # Fused operations for text stream + self.txt_attn_norm = LayerNormScaleShift( + hidden_size, elementwise_affine=False, dtype=dtype + ) + self.txt_attn_residual_mlp_norm = ScaleResidualLayerNormScaleShift( + 
hidden_size, elementwise_affine=False, dtype=dtype + ) + self.txt_mlp_residual = MulAdd() + + # Text attention components + self.txt_attn_qkv = ReplicatedLinear( + hidden_size, + hidden_size * 3, + bias=True, + params_dtype=dtype, + quant_config=quant_config, + ) + + # QK norm layers for text + self.txt_attn_q_norm = RMSNorm(head_dim, eps=1e-6, dtype=dtype) + self.txt_attn_k_norm = RMSNorm(head_dim, eps=1e-6, dtype=dtype) + + self.txt_attn_proj = ReplicatedLinear( + hidden_size, + hidden_size, + bias=True, + params_dtype=dtype, + quant_config=quant_config, + ) + + self.txt_mlp = MLP( + hidden_size, + mlp_hidden_dim, + bias=True, + dtype=dtype, + quant_config=quant_config, + ) + + # Use UlyssesAttention to replace Distributed attention + self.attn = UlyssesAttention( + num_heads=num_attention_heads, + head_size=head_dim, + causal=False, + supported_attention_backends=supported_attention_backends, + prefix=f"{prefix}.attn", + ) + + def forward( + self, + img: torch.Tensor, + txt: torch.Tensor, + vec: torch.Tensor, + freqs_cis: tuple, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Process modulation vectors + img_mod_outputs = self.img_mod(vec) + ( + img_attn_shift, + img_attn_scale, + img_attn_gate, + img_mlp_shift, + img_mlp_scale, + img_mlp_gate, + ) = torch.chunk(img_mod_outputs, 6, dim=-1) + + txt_mod_outputs = self.txt_mod(vec) + ( + txt_attn_shift, + txt_attn_scale, + txt_attn_gate, + txt_mlp_shift, + txt_mlp_scale, + txt_mlp_gate, + ) = torch.chunk(txt_mod_outputs, 6, dim=-1) + + # Prepare image for attention using fused operation + img_attn_input = self.img_attn_norm(img, img_attn_shift, img_attn_scale) + # Get QKV for image + img_qkv, _ = self.img_attn_qkv(img_attn_input) + batch_size, image_seq_len = img_qkv.shape[0], img_qkv.shape[1] + + # Split QKV + img_qkv = img_qkv.view( + batch_size, image_seq_len, 3, self.num_attention_heads, -1 + ) + img_q, img_k, img_v = img_qkv[:, :, 0], img_qkv[:, :, 1], img_qkv[:, :, 2] + + # Apply QK-Norm if needed + + img_q = 
self.img_attn_q_norm(img_q.contiguous()).to(img_v) + img_k = self.img_attn_k_norm(img_k.contiguous()).to(img_v) + # Apply rotary embeddings + cos, sin = freqs_cis + img_q, img_k = _apply_rotary_emb( + img_q, cos, sin, is_neox_style=False + ), _apply_rotary_emb(img_k, cos, sin, is_neox_style=False) + # Prepare text for attention using fused operation + txt_attn_input = self.txt_attn_norm(txt, txt_attn_shift, txt_attn_scale) + + # Get QKV for text + txt_qkv, _ = self.txt_attn_qkv(txt_attn_input) + batch_size, text_seq_len = txt_qkv.shape[0], txt_qkv.shape[1] + + # Split QKV + txt_qkv = txt_qkv.view( + batch_size, text_seq_len, 3, self.num_attention_heads, -1 + ) + txt_q, txt_k, txt_v = txt_qkv[:, :, 0], txt_qkv[:, :, 1], txt_qkv[:, :, 2] + + # Apply QK-Norm if needed + txt_q = self.txt_attn_q_norm(txt_q.contiguous()).to(txt_q.dtype) + txt_k = self.txt_attn_k_norm(txt_k.contiguous()).to(txt_k.dtype) + + # Run distributed attention + img_attn, txt_attn = self.attn(img_q, img_k, img_v, txt_q, txt_k, txt_v) + img_attn_out, _ = self.img_attn_proj( + img_attn.view(batch_size, image_seq_len, -1) + ) + # Use fused operation for residual connection, normalization, and modulation + img_mlp_input, img_residual = self.img_attn_residual_mlp_norm( + img, img_attn_out, img_attn_gate, img_mlp_shift, img_mlp_scale + ) + + # Process image MLP + img_mlp_out = self.img_mlp(img_mlp_input) + img = self.img_mlp_residual(img_mlp_out, img_mlp_gate, img_residual) + + # Process text attention output + txt_attn_out, _ = self.txt_attn_proj( + txt_attn.reshape(batch_size, text_seq_len, -1) + ) + + # Use fused operation for residual connection, normalization, and modulation + txt_mlp_input, txt_residual = self.txt_attn_residual_mlp_norm( + txt, txt_attn_out, txt_attn_gate, txt_mlp_shift, txt_mlp_scale + ) + + # Process text MLP + txt_mlp_out = self.txt_mlp(txt_mlp_input) + txt = self.txt_mlp_residual(txt_mlp_out, txt_mlp_gate, txt_residual) + + return img, txt + + +class 
MMSingleStreamBlock(nn.Module): + """ + A DiT block with parallel linear layers using distributed attention + and tensor parallelism. + """ + + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + mlp_ratio: float = 4.0, + dtype: torch.dtype | None = None, + supported_attention_backends: set[AttentionBackendEnum] | None = None, + prefix: str = "", + quant_config: QuantizationConfig | None = None, + ): + super().__init__() + + self.deterministic = False + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + head_dim = hidden_size // num_attention_heads + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp_hidden_dim = mlp_hidden_dim + + # Combined QKV and MLP input projection + self.linear1 = ReplicatedLinear( + hidden_size, + hidden_size * 3 + mlp_hidden_dim, + bias=True, + params_dtype=dtype, + prefix=f"{prefix}.linear1", + quant_config=quant_config, + ) + + # Combined projection and MLP output + self.linear2 = ReplicatedLinear( + hidden_size + mlp_hidden_dim, + hidden_size, + bias=True, + params_dtype=dtype, + prefix=f"{prefix}.linear2", + quant_config=quant_config, + ) + + # QK norm layers + self.q_norm = RMSNorm(head_dim, eps=1e-6, dtype=dtype) + self.k_norm = RMSNorm(head_dim, eps=1e-6, dtype=dtype) + + # Fused operations with better naming + self.input_norm_scale_shift = LayerNormScaleShift( + hidden_size, + eps=1e-6, + elementwise_affine=False, + dtype=dtype, + ) + self.output_residual = MulAdd() + + # Activation function + self.mlp_act = nn.GELU(approximate="tanh") + + # Modulation + self.modulation = ModulateProjection( + hidden_size, + factor=3, + act_layer="silu", + dtype=dtype, + prefix=f"{prefix}.modulation", + ) + + # Use UlyssesAttention to replace Distributed attention + self.attn = UlyssesAttention( + num_heads=num_attention_heads, + head_size=head_dim, + causal=False, + supported_attention_backends=supported_attention_backends, + prefix=f"{prefix}.attn", + ) + + def forward( + self, + x: 
torch.Tensor, + vec: torch.Tensor, + txt_len: int, + freqs_cis: tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + # Process modulation + mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1) + + # Apply pre-norm and modulation using fused operation + x_mod = self.input_norm_scale_shift(x, mod_shift, mod_scale) + + # Get combined projections + linear1_out, _ = self.linear1(x_mod) + + # Split into QKV and MLP parts + qkv, mlp = torch.split( + linear1_out, [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1 + ) + + # Process QKV + batch_size, seq_len = qkv.shape[0], qkv.shape[1] + qkv = qkv.view(batch_size, seq_len, 3, self.num_attention_heads, -1) + q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2] + + # Apply QK-Norm + q = self.q_norm(q.contiguous()).to(v.dtype) + k = self.k_norm(k.contiguous()).to(v.dtype) + + # Split into image and text parts + img_q, txt_q = q[:, :-txt_len], q[:, -txt_len:] + img_k, txt_k = k[:, :-txt_len], k[:, -txt_len:] + img_v, txt_v = v[:, :-txt_len], v[:, -txt_len:] + # Apply rotary embeddings to image parts + cos, sin = freqs_cis + img_q, img_k = _apply_rotary_emb( + img_q, cos, sin, is_neox_style=False + ), _apply_rotary_emb(img_k, cos, sin, is_neox_style=False) + + # Run distributed attention + img_attn_output, txt_attn_output = self.attn( + img_q, img_k, img_v, txt_q, txt_k, txt_v + ) + attn_output = torch.cat((img_attn_output, txt_attn_output), dim=1).view( + batch_size, seq_len, -1 + ) + # Process MLP activation + mlp_output = self.mlp_act(mlp) + + # Combine attention and MLP outputs + combined = torch.cat((attn_output, mlp_output), dim=-1) + + # Final projection + output, _ = self.linear2(combined) + + # Apply residual connection with gating using fused operation + return self.output_residual(output, mod_gate, x) + + +class HunyuanVideoTransformer3DModel(CachableDiT, OffloadableDiTMixin): + """ + HunyuanVideo Transformer backbone adapted for distributed training. 
+ + This implementation uses distributed attention and linear layers for efficient + parallel processing across multiple GPUs. + + Based on the architecture from: + - Flux.1: https://github.com/black-forest-labs/flux + - MMDiT: http://arxiv.org/abs/2403.03206 + """ + + # PY: we make the input args the same as HF config + + # shard single stream, double stream blocks, and refiner_blocks + _fsdp_shard_conditions = HunyuanVideoConfig()._fsdp_shard_conditions + _compile_conditions = HunyuanVideoConfig()._compile_conditions + _supported_attention_backends = HunyuanVideoConfig()._supported_attention_backends + param_names_mapping = HunyuanVideoConfig().param_names_mapping + reverse_param_names_mapping = HunyuanVideoConfig().reverse_param_names_mapping + lora_param_names_mapping = HunyuanVideoConfig().lora_param_names_mapping + + def __init__( + self, + config: HunyuanVideoConfig, + hf_config: dict[str, Any], + quant_config: QuantizationConfig | None = None, + ): + super().__init__(config=config, hf_config=hf_config) + + self.patch_size = [config.patch_size_t, config.patch_size, config.patch_size] + self.in_channels = config.in_channels + self.num_channels_latents = config.num_channels_latents + self.out_channels = ( + config.in_channels if config.out_channels is None else config.out_channels + ) + self.unpatchify_channels = self.out_channels + self.guidance_embeds = config.guidance_embeds + self.rope_dim_list = list(config.rope_axes_dim) + self.rope_theta = config.rope_theta + self.text_states_dim = config.text_embed_dim + self.text_states_dim_2 = config.pooled_projection_dim + # TODO(will): hack? 
+ self.dtype = config.dtype + + pe_dim = config.hidden_size // config.num_attention_heads + if sum(config.rope_axes_dim) != pe_dim: + raise ValueError( + f"Got {config.rope_axes_dim} but expected positional dim {pe_dim}" + ) + + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.num_channels_latents = config.num_channels_latents + + # Image projection + self.img_in = PatchEmbed( + self.patch_size, + self.in_channels, + self.hidden_size, + dtype=config.dtype, + prefix=f"{config.prefix}.img_in", + ) + + self.txt_in = SingleTokenRefiner( + self.text_states_dim, + config.hidden_size, + config.num_attention_heads, + depth=config.num_refiner_layers, + dtype=config.dtype, + prefix=f"{config.prefix}.txt_in", + ) + + # Time modulation + self.time_in = TimestepEmbedder( + self.hidden_size, + act_layer="silu", + dtype=config.dtype, + prefix=f"{config.prefix}.time_in", + ) + + # Text modulation + self.vector_in = MLP( + self.text_states_dim_2, + self.hidden_size, + self.hidden_size, + act_type="silu", + dtype=config.dtype, + prefix=f"{config.prefix}.vector_in", + ) + + # Guidance modulation + self.guidance_in = ( + TimestepEmbedder( + self.hidden_size, + act_layer="silu", + dtype=config.dtype, + prefix=f"{config.prefix}.guidance_in", + ) + if self.guidance_embeds + else None + ) + + # Double blocks + self.double_blocks = nn.ModuleList( + [ + MMDoubleStreamBlock( + config.hidden_size, + config.num_attention_heads, + mlp_ratio=config.mlp_ratio, + dtype=config.dtype, + supported_attention_backends=self._supported_attention_backends, + prefix=f"{config.prefix}.double_blocks.{i}", + quant_config=quant_config, + ) + for i in range(config.num_layers) + ] + ) + + # Single blocks + self.single_blocks = nn.ModuleList( + [ + MMSingleStreamBlock( + config.hidden_size, + config.num_attention_heads, + mlp_ratio=config.mlp_ratio, + dtype=config.dtype, + supported_attention_backends=self._supported_attention_backends, + 
prefix=f"{config.prefix}.single_blocks.{i + config.num_layers}", + quant_config=quant_config, + ) + for i in range(config.num_single_layers) + ] + ) + + self.final_layer = FinalLayer( + config.hidden_size, + self.patch_size, + self.out_channels, + dtype=config.dtype, + prefix=f"{config.prefix}.final_layer", + ) + + self.__post_init__() + + self.layer_names = ["double_blocks", "single_blocks"] + + # TODO: change the input the FORWARD_BATCH Dict + # TODO: change output to a dict + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | list[torch.Tensor], + timestep: torch.LongTensor, + encoder_hidden_states_image: torch.Tensor | list[torch.Tensor] | None = None, + guidance=None, + **kwargs, + ): + """ + Forward pass of the HunyuanDiT model. + + Args: + hidden_states: Input image/video latents [B, C, T, H, W] + encoder_hidden_states: Text embeddings [B, L, D] + timestep: Diffusion timestep + guidance: Guidance scale for CFG + + Returns: + Tuple of (output) + """ + forward_context = get_forward_context() + forward_batch = forward_context.forward_batch + enable_teacache = forward_batch is not None and forward_batch.enable_teacache + + if guidance is None: + guidance = torch.tensor( + [6016.0], device=hidden_states.device, dtype=hidden_states.dtype + ) + + img = x = hidden_states + t = timestep + + # Split text embeddings - first token is global, rest are per-token + if isinstance(encoder_hidden_states, torch.Tensor): + txt = encoder_hidden_states[:, 1:] + text_states_2 = encoder_hidden_states[:, 0, : self.text_states_dim_2] + else: + txt = encoder_hidden_states[0] + text_states_2 = encoder_hidden_states[1] + + # Get spatial dimensions + _, _, ot, oh, ow = x.shape # codespell:ignore + tt, th, tw = ( + ot // self.patch_size[0], # codespell:ignore + oh // self.patch_size[1], + ow // self.patch_size[2], + ) + + # Get rotary embeddings + freqs_cos, freqs_sin = get_rotary_pos_embed( + (tt * get_sp_world_size(), th, tw), + self.hidden_size, 
+ self.num_attention_heads, + self.rope_dim_list, + self.rope_theta, + ) + freqs_cos = freqs_cos.to(x.device) + freqs_sin = freqs_sin.to(x.device) + # Prepare modulation vectors + vec = self.time_in(t) + + # Add text modulation + vec = vec + self.vector_in(text_states_2) + + # Add guidance modulation if needed + if self.guidance_in and guidance is not None: + vec = vec + self.guidance_in(guidance) + # Embed image and text + img = self.img_in(img) + txt = self.txt_in(txt, t) + txt_seq_len = txt.shape[1] + img_seq_len = img.shape[1] + + freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None + + should_skip_forward = self.should_skip_forward_for_cached_states( + img=img, vec=vec + ) + + if should_skip_forward: + img = self.retrieve_cached_states(img) + else: + if enable_teacache: + original_img = img.clone() + + # Process through double stream blocks + for index, block in enumerate(self.double_blocks): + double_block_args = [img, txt, vec, freqs_cis] + img, txt = block(*double_block_args) + # Merge txt and img to pass through single stream blocks + x = torch.cat((img, txt), 1) + + # Process through single stream blocks + if len(self.single_blocks) > 0: + for index, block in enumerate(self.single_blocks): + single_block_args = [ + x, + vec, + txt_seq_len, + freqs_cis, + ] + x = block(*single_block_args) + + # Extract image features + img = x[:, :img_seq_len, ...] 
+ + if enable_teacache: + self.maybe_cache_states(img, original_img) + + # Final layer processing + img = self.final_layer(img, vec) + # Unpatchify to get original shape + img = unpatchify(img, tt, th, tw, self.patch_size, self.out_channels) + + return img + + def maybe_cache_states( + self, hidden_states: torch.Tensor, original_hidden_states: torch.Tensor + ) -> None: + self.previous_residual = hidden_states - original_hidden_states + + def should_skip_forward_for_cached_states(self, **kwargs) -> bool: + + forward_context = get_forward_context() + forward_batch = forward_context.forward_batch + if forward_batch is None: + return False + current_timestep = forward_context.current_timestep + enable_teacache = forward_batch.enable_teacache + + if not enable_teacache: + return False + raise NotImplementedError("teacache is not supported yet for HunyuanVideo") + + teacache_params = forward_batch.teacache_params + assert teacache_params is not None, "teacache_params is not initialized" + assert isinstance( + teacache_params, TeaCacheParams + ), "teacache_params is not a TeaCacheParams" + num_inference_steps = forward_batch.num_inference_steps + teache_thresh = teacache_params.teacache_thresh + + coefficients = teacache_params.coefficients + + if current_timestep == 0: + self.cnt = 0 + + inp = kwargs["img"].clone() + vec_ = kwargs["vec"].clone() + # convert to DTensor + vec_ = torch.distributed.tensor.DTensor.from_local( + vec_, + torch.distributed.DeviceMesh( + current_platform.device_type, + list(range(get_sp_world_size())), + mesh_dim_names=("dp",), + ), + [torch.distributed.tensor.Replicate()], + ) + + inp = torch.distributed.tensor.DTensor.from_local( + inp, + torch.distributed.DeviceMesh( + current_platform.device_type, + list(range(get_sp_world_size())), + mesh_dim_names=("dp",), + ), + [torch.distributed.tensor.Replicate()], + ) + + # txt_ = kwargs["txt"].clone() + + # inp = img.clone() + # vec_ = vec.clone() + # txt_ = txt.clone() + ( + img_mod1_shift, + 
img_mod1_scale, + img_mod1_gate, + img_mod2_shift, + img_mod2_scale, + img_mod2_gate, + ) = ( + self.double_blocks[0].img_mod(vec_).chunk(6, dim=-1) + ) + normed_inp = self.double_blocks[0].img_attn_norm.norm(inp) + modulated_inp = modulate(normed_inp, shift=img_mod1_shift, scale=img_mod1_scale) + if self.cnt == 0 or self.cnt == num_inference_steps - 1: + should_calc = True + self.accumulated_rel_l1_distance = 0 + else: + coefficients = [ + 7.33226126e02, + -4.01131952e02, + 6.75869174e01, + -3.14987800e00, + 9.61237896e-02, + ] + rescale_func = np.poly1d(coefficients) + assert ( + self.previous_modulated_input is not None + ), "previous_modulated_input is not initialized" + self.accumulated_rel_l1_distance += rescale_func( + ( + (modulated_inp - self.previous_modulated_input).abs().mean() + / self.previous_modulated_input.abs().mean() + ) + .cpu() + .item() + ) + if self.accumulated_rel_l1_distance < teache_thresh: + should_calc = False + else: + should_calc = True + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = modulated_inp + self.cnt += 1 + + return not should_calc + + def retrieve_cached_states(self, hidden_states: torch.Tensor) -> torch.Tensor: + return hidden_states + self.previous_residual + + +class SingleTokenRefiner(nn.Module): + """ + A token refiner that processes text embeddings with attention to improve + their representation for cross-attention with image features. 
+ """ + + def __init__( + self, + in_channels, + hidden_size, + num_attention_heads, + depth=2, + qkv_bias=True, + dtype=None, + prefix: str = "", + ) -> None: + super().__init__() + + # Input projection + self.input_embedder = ReplicatedLinear( + in_channels, + hidden_size, + bias=True, + params_dtype=dtype, + prefix=f"{prefix}.input_embedder", + ) + + # Timestep embedding + self.t_embedder = TimestepEmbedder( + hidden_size, act_layer="silu", dtype=dtype, prefix=f"{prefix}.t_embedder" + ) + + # Context embedding + self.c_embedder = MLP( + in_channels, + hidden_size, + hidden_size, + act_type="silu", + dtype=dtype, + prefix=f"{prefix}.c_embedder", + ) + + # Refiner blocks + self.refiner_blocks = nn.ModuleList( + [ + IndividualTokenRefinerBlock( + hidden_size, + num_attention_heads, + qkv_bias=qkv_bias, + dtype=dtype, + prefix=f"{prefix}.refiner_blocks.{i}", + ) + for i in range(depth) + ] + ) + + def forward(self, x, t): + # Get timestep embeddings + timestep_aware_representations = self.t_embedder(t) + + # Get context-aware representations + + context_aware_representations = torch.mean(x, dim=1) + + context_aware_representations = self.c_embedder(context_aware_representations) + c = timestep_aware_representations + context_aware_representations + # Project input + x, _ = self.input_embedder(x) + # Process through refiner blocks + for block in self.refiner_blocks: + x = block(x, c) + return x + + +class IndividualTokenRefinerBlock(nn.Module): + """ + A transformer block for refining individual tokens with self-attention. 
+ """ + + def __init__( + self, + hidden_size, + num_attention_heads, + mlp_ratio=4.0, + qkv_bias=True, + dtype=None, + prefix: str = "", + ) -> None: + super().__init__() + self.num_attention_heads = num_attention_heads + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + # Normalization and attention + self.norm1 = nn.LayerNorm( + hidden_size, eps=1e-6, elementwise_affine=True, dtype=dtype + ) + + self.self_attn_qkv = ReplicatedLinear( + hidden_size, + hidden_size * 3, + bias=qkv_bias, + params_dtype=dtype, + prefix=f"{prefix}.self_attn_qkv", + ) + + self.self_attn_proj = ReplicatedLinear( + hidden_size, + hidden_size, + bias=qkv_bias, + params_dtype=dtype, + prefix=f"{prefix}.self_attn_proj", + ) + + # MLP + self.norm2 = nn.LayerNorm( + hidden_size, eps=1e-6, elementwise_affine=True, dtype=dtype + ) + self.mlp = MLP( + hidden_size, + mlp_hidden_dim, + bias=True, + act_type="silu", + dtype=dtype, + prefix=f"{prefix}.mlp", + ) + + # Modulation + self.adaLN_modulation = ModulateProjection( + hidden_size, + factor=2, + act_layer="silu", + dtype=dtype, + prefix=f"{prefix}.adaLN_modulation", + ) + + # Scaled dot product attention + self.attn = LocalAttention( + num_heads=num_attention_heads, + head_size=hidden_size // num_attention_heads, + # TODO: remove hardcode; remove STA + supported_attention_backends=( + AttentionBackendEnum.FA, + AttentionBackendEnum.AITER, + AttentionBackendEnum.TORCH_SDPA, + ), + ) + + def forward(self, x, c): + # Get modulation parameters + gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=-1) + # Self-attention + norm_x = self.norm1(x) + qkv, _ = self.self_attn_qkv(norm_x) + + batch_size, seq_len = qkv.shape[0], qkv.shape[1] + qkv = qkv.view(batch_size, seq_len, 3, self.num_attention_heads, -1) + q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2] + + # Run scaled dot product attention + attn_output = self.attn(q, k, v) # [B, L, H, D] + attn_output = attn_output.reshape(batch_size, seq_len, -1) # [B, L, H*D] + + # Project and apply 
residual connection with gating + attn_out, _ = self.self_attn_proj(attn_output) + x = x + attn_out * gate_msa.unsqueeze(1) + + # MLP + mlp_out = self.mlp(self.norm2(x)) + x = x + mlp_out * gate_mlp.unsqueeze(1) + + return x + + +class FinalLayer(nn.Module): + """ + The final layer of DiT that projects features to pixel space. + """ + + def __init__( + self, hidden_size, patch_size, out_channels, dtype=None, prefix: str = "" + ) -> None: + super().__init__() + + # Normalization + self.norm_final = nn.LayerNorm( + hidden_size, eps=1e-6, elementwise_affine=False, dtype=dtype + ) + + output_dim = patch_size[0] * patch_size[1] * patch_size[2] * out_channels + + self.linear = ReplicatedLinear( + hidden_size, + output_dim, + bias=True, + params_dtype=dtype, + prefix=f"{prefix}.linear", + ) + + # Modulation + self.adaLN_modulation = ModulateProjection( + hidden_size, + factor=2, + act_layer="silu", + dtype=dtype, + prefix=f"{prefix}.adaLN_modulation", + ) + + def forward(self, x, c): + # What the heck HF? Why you change the scale and shift order here??? + scale, shift = self.adaLN_modulation(c).chunk(2, dim=-1) + x = self.norm_final(x) * (1.0 + scale.unsqueeze(1)) + shift.unsqueeze(1) + x, _ = self.linear(x) + return x + + +EntryClass = HunyuanVideoTransformer3DModel diff --git a/sglang/python/sglang/multimodal_gen/runtime/models/dits/ltx_2.py b/sglang/python/sglang/multimodal_gen/runtime/models/dits/ltx_2.py new file mode 100644 index 0000000000000000000000000000000000000000..f8c4935780dc9baed49e93805efe7f8431708f5d --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/models/dits/ltx_2.py @@ -0,0 +1,1473 @@ +# Copied and adapted from LTX-2 and WanVideo implementations. 
+# +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import Any, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from sglang.multimodal_gen.configs.models.dits.ltx_2 import LTX2ArchConfig, LTX2Config +from sglang.multimodal_gen.runtime.distributed import ( + get_sp_parallel_rank, + get_sp_world_size, + get_tp_rank, + get_tp_world_size, + model_parallel_is_initialized, +) +from sglang.multimodal_gen.runtime.distributed.communication_op import ( + tensor_model_parallel_all_reduce, +) +from sglang.multimodal_gen.runtime.layers.attention import USPAttention +from sglang.multimodal_gen.runtime.layers.linear import ( + ColumnParallelLinear, + RowParallelLinear, +) +from sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import ( + QuantizationConfig, +) +from sglang.multimodal_gen.runtime.layers.visual_embedding import timestep_embedding +from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT +from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum +from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +def apply_interleaved_rotary_emb( + x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor] +) -> torch.Tensor: + cos, sin = freqs + x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2) + return (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + +def apply_split_rotary_emb( + x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor] +) -> torch.Tensor: + cos, sin = freqs + x_dtype = x.dtype + needs_reshape = False + if x.ndim != 4 and cos.ndim == 4: + b = x.shape[0] + _, h, t, _ = cos.shape + x = x.reshape(b, t, h, -1).swapaxes(1, 2) + needs_reshape = True + + last = x.shape[-1] + if last % 2 != 0: + raise 
ValueError( + f"Expected x.shape[-1] to be even for split rotary, got {last}." + ) + r = last // 2 + + split_x = x.reshape(*x.shape[:-1], 2, r) + first_x = split_x[..., :1, :] + second_x = split_x[..., 1:, :] + + cos_u = cos.unsqueeze(-2) + sin_u = sin.unsqueeze(-2) + + out = split_x * cos_u + first_out = out[..., :1, :] + second_out = out[..., 1:, :] + first_out.addcmul_(-sin_u, second_x) + second_out.addcmul_(sin_u, first_x) + + out = out.reshape(*out.shape[:-2], last) + if needs_reshape: + out = out.swapaxes(1, 2).reshape(b, t, -1) + return out.to(dtype=x_dtype) + + +# ============================================================================== +# Layers and Embeddings +# ============================================================================== + + +class LTX2AudioVideoRotaryPosEmbed(nn.Module): + def __init__( + self, + dim: int, + patch_size: int = 1, + patch_size_t: int = 1, + base_num_frames: int = 20, + base_height: int = 2048, + base_width: int = 2048, + sampling_rate: int = 16000, + hop_length: int = 160, + scale_factors: Tuple[int, ...] = (8, 32, 32), + theta: float = 10000.0, + causal_offset: int = 1, + modality: str = "video", + double_precision: bool = True, + rope_type: str = "interleaved", + num_attention_heads: int = 32, + ) -> None: + super().__init__() + self.dim = int(dim) + self.patch_size = int(patch_size) + self.patch_size_t = int(patch_size_t) + + if rope_type not in ["interleaved", "split"]: + raise ValueError( + f"{rope_type=} not supported. Choose between 'interleaved' and 'split'." 
+ ) + self.rope_type = rope_type + + self.base_num_frames = int(base_num_frames) + self.num_attention_heads = int(num_attention_heads) + + self.base_height = int(base_height) + self.base_width = int(base_width) + + self.sampling_rate = int(sampling_rate) + self.hop_length = int(hop_length) + self.audio_latents_per_second = ( + float(self.sampling_rate) / float(self.hop_length) / float(scale_factors[0]) + ) + + self.scale_factors = tuple(int(x) for x in scale_factors) + self.theta = float(theta) + self.causal_offset = int(causal_offset) + + self.modality = modality + self.coords_dtype = torch.bfloat16 if modality == "video" else torch.float32 + if self.modality not in ["video", "audio"]: + raise ValueError( + f"Modality {modality} is not supported. Supported modalities are `video` and `audio`." + ) + self.double_precision = bool(double_precision) + + def prepare_video_coords( + self, + batch_size: int, + num_frames: int, + height: int, + width: int, + device: torch.device, + fps: float = 24.0, + *, + start_frame: int = 0, + ) -> torch.Tensor: + grid_f = torch.arange( + start=int(start_frame), + end=int(num_frames) + int(start_frame), + step=self.patch_size_t, + dtype=torch.float32, + device=device, + ) + grid_h = torch.arange( + start=0, + end=height, + step=self.patch_size, + dtype=torch.float32, + device=device, + ) + grid_w = torch.arange( + start=0, + end=width, + step=self.patch_size, + dtype=torch.float32, + device=device, + ) + grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij") + grid = torch.stack(grid, dim=0) + + patch_size = (self.patch_size_t, self.patch_size, self.patch_size) + patch_size_delta = torch.tensor( + patch_size, dtype=grid.dtype, device=grid.device + ) + patch_ends = grid + patch_size_delta.view(3, 1, 1, 1) + + latent_coords = torch.stack([grid, patch_ends], dim=-1) + latent_coords = latent_coords.flatten(1, 3) + latent_coords = latent_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1) + + scale_tensor = 
torch.tensor(self.scale_factors, device=latent_coords.device) + broadcast_shape = [1] * latent_coords.ndim + broadcast_shape[1] = -1 + pixel_coords = latent_coords * scale_tensor.view(*broadcast_shape) + pixel_coords[:, 0, ...] = ( + pixel_coords[:, 0, ...] + self.causal_offset - self.scale_factors[0] + ).clamp(min=0) + pixel_coords[:, 0, ...] = pixel_coords[:, 0, ...] / fps + return pixel_coords + + def prepare_audio_coords( + self, + batch_size: int, + num_frames: int, + device: torch.device, + *, + start_frame: int = 0, + ) -> torch.Tensor: + grid_f = torch.arange( + start=int(start_frame), + end=int(num_frames) + int(start_frame), + step=self.patch_size_t, + dtype=torch.float32, + device=device, + ) + + audio_scale_factor = self.scale_factors[0] + grid_start_mel = grid_f * audio_scale_factor + grid_start_mel = ( + grid_start_mel + self.causal_offset - audio_scale_factor + ).clip(min=0) + grid_start_s = grid_start_mel * self.hop_length / self.sampling_rate + + grid_end_mel = (grid_f + self.patch_size_t) * audio_scale_factor + grid_end_mel = (grid_end_mel + self.causal_offset - audio_scale_factor).clip( + min=0 + ) + grid_end_s = grid_end_mel * self.hop_length / self.sampling_rate + + audio_coords = torch.stack([grid_start_s, grid_end_s], dim=-1) + audio_coords = audio_coords.unsqueeze(0).expand(batch_size, -1, -1) + audio_coords = audio_coords.unsqueeze(1) + return audio_coords + + def prepare_coords(self, *args, **kwargs): + if self.modality == "video": + return self.prepare_video_coords(*args, **kwargs) + return self.prepare_audio_coords(*args, **kwargs) + + def forward( + self, coords: torch.Tensor, device: Optional[Union[str, torch.device]] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + device = device or coords.device + num_pos_dims = coords.shape[1] + + coords = coords.to(self.coords_dtype) + if coords.ndim == 4: + coords_start, coords_end = coords.chunk(2, dim=-1) + coords = (coords_start + coords_end) / 2.0 + coords = coords.squeeze(-1) + + if 
self.modality == "video": + max_positions = (self.base_num_frames, self.base_height, self.base_width) + else: + max_positions = (self.base_num_frames,) + + grid = torch.stack( + [coords[:, i] / max_positions[i] for i in range(num_pos_dims)], dim=-1 + ).to(device) + + num_rope_elems = num_pos_dims * 2 + freqs_dtype = torch.float64 if self.double_precision else torch.float32 + pow_indices = torch.pow( + self.theta, + torch.linspace( + start=0.0, + end=1.0, + steps=self.dim // num_rope_elems, + dtype=freqs_dtype, + device=device, + ), + ) + freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32) + + freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs + freqs = freqs.transpose(-1, -2).flatten(2) + + if self.rope_type == "interleaved": + cos_freqs = freqs.cos().repeat_interleave(2, dim=-1) + sin_freqs = freqs.sin().repeat_interleave(2, dim=-1) + + if self.dim % num_rope_elems != 0: + cos_padding = torch.ones_like( + cos_freqs[:, :, : self.dim % num_rope_elems] + ) + sin_padding = torch.zeros_like( + cos_freqs[:, :, : self.dim % num_rope_elems] + ) + cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1) + sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1) + else: + expected_freqs = self.dim // 2 + current_freqs = freqs.shape[-1] + pad_size = expected_freqs - current_freqs + cos_freq = freqs.cos() + sin_freq = freqs.sin() + + if pad_size != 0: + cos_padding = torch.ones_like(cos_freq[:, :, :pad_size]) + sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size]) + cos_freq = torch.cat([cos_padding, cos_freq], dim=-1) + sin_freq = torch.cat([sin_padding, sin_freq], dim=-1) + + b = cos_freq.shape[0] + t = cos_freq.shape[1] + cos_freq = cos_freq.reshape(b, t, self.num_attention_heads, -1) + sin_freq = sin_freq.reshape(b, t, self.num_attention_heads, -1) + cos_freqs = torch.swapaxes(cos_freq, 1, 2) + sin_freqs = torch.swapaxes(sin_freq, 1, 2) + + # Cast to bf16 to match model weights dtype. 
coords_dtype controls + # intermediate coordinate precision (fp32 for audio) and differs. + return cos_freqs.to(torch.bfloat16), sin_freqs.to(torch.bfloat16) + + +def rms_norm(x: torch.Tensor, eps: float) -> torch.Tensor: + return F.rms_norm(x, normalized_shape=(x.shape[-1],), eps=eps) + + +class LTX2TextProjection(nn.Module): + def __init__( + self, + in_features: int, + hidden_size: int, + out_features: int | None = None, + act_fn: str = "gelu_tanh", + ) -> None: + super().__init__() + if out_features is None: + out_features = hidden_size + + self.linear_1 = ColumnParallelLinear( + in_features, hidden_size, bias=True, gather_output=True + ) + if act_fn == "gelu_tanh": + self.act_1 = nn.GELU(approximate="tanh") + elif act_fn == "silu": + self.act_1 = nn.SiLU() + else: + raise ValueError(f"Unknown activation function: {act_fn}") + + self.linear_2 = ColumnParallelLinear( + hidden_size, out_features, bias=True, gather_output=True + ) + + def forward(self, caption: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.linear_1(caption) + hidden_states = self.act_1(hidden_states) + hidden_states, _ = self.linear_2(hidden_states) + return hidden_states + + +class LTX2TimestepEmbedder(nn.Module): + def __init__(self, embedding_dim: int, in_channels: int = 256) -> None: + super().__init__() + self.linear_1 = ColumnParallelLinear( + in_channels, embedding_dim, bias=True, gather_output=True + ) + self.linear_2 = ColumnParallelLinear( + embedding_dim, embedding_dim, bias=True, gather_output=True + ) + + def forward(self, t_emb: torch.Tensor) -> torch.Tensor: + x, _ = self.linear_1(t_emb) + x = F.silu(x) + x, _ = self.linear_2(x) + return x + + +class LTX2PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module): + def __init__(self, embedding_dim: int) -> None: + super().__init__() + self.timestep_embedder = LTX2TimestepEmbedder(embedding_dim, in_channels=256) + + def forward( + self, timestep: torch.Tensor, hidden_dtype: torch.dtype | None = None + ) -> torch.Tensor: + t = 
timestep.reshape(-1).to(dtype=torch.float32) + t_emb = timestep_embedding(t, dim=256, max_period=10000, dtype=torch.float32) + if hidden_dtype is not None: + t_emb = t_emb.to(dtype=hidden_dtype) + return self.timestep_embedder(t_emb) + + +class LTX2AdaLayerNormSingle(nn.Module): + def __init__(self, embedding_dim: int, embedding_coefficient: int = 6) -> None: + super().__init__() + self.emb = LTX2PixArtAlphaCombinedTimestepSizeEmbeddings(embedding_dim) + self.silu = nn.SiLU() + self.linear = ColumnParallelLinear( + embedding_dim, + embedding_coefficient * embedding_dim, + bias=True, + gather_output=True, + ) + + def forward( + self, timestep: torch.Tensor, hidden_dtype: torch.dtype | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + embedded_timestep = self.emb(timestep, hidden_dtype=hidden_dtype).to( + dtype=self.linear.weight.dtype + ) + out, _ = self.linear(self.silu(embedded_timestep)) + return out, embedded_timestep + + +class LTX2TPRMSNormAcrossHeads(nn.Module): + def __init__( + self, full_hidden_size: int, local_hidden_size: int, eps: float + ) -> None: + super().__init__() + self.full_hidden_size = full_hidden_size + self.local_hidden_size = local_hidden_size + self.eps = eps + self.weight = nn.Parameter(torch.ones(local_hidden_size)) + + tp_rank = get_tp_rank() + + def _weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + shard = loaded_weight.narrow( + 0, tp_rank * local_hidden_size, local_hidden_size + ) + param.data.copy_(shard.to(dtype=param.dtype, device=param.device)) + + setattr(self.weight, "weight_loader", _weight_loader) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Keep track of the original dtype. 
We do the statistics in fp32 for + # numerical stability, but cast the output back to the input dtype to + orig_dtype = x.dtype + if get_tp_world_size() == 1: + var = x.float().pow(2).mean(dim=-1, keepdim=True) + else: + local_sumsq = x.float().pow(2).sum(dim=-1, keepdim=True) + global_sumsq = tensor_model_parallel_all_reduce(local_sumsq) + var = global_sumsq / float(self.full_hidden_size) + + inv_rms_fp32 = torch.rsqrt(var + self.eps) + y = (x.float() * inv_rms_fp32).to(dtype=orig_dtype) + return y * self.weight.to(dtype=orig_dtype) + + +class LTX2Attention(nn.Module): + def __init__( + self, + query_dim: int, + context_dim: int | None = None, + heads: int = 8, + dim_head: int = 64, + norm_eps: float = 1e-6, + qk_norm: bool = True, + supported_attention_backends: set[AttentionBackendEnum] | None = None, + prefix: str = "", + quant_config: QuantizationConfig | None = None, + ) -> None: + super().__init__() + + self.query_dim = int(query_dim) + self.context_dim = int(query_dim if context_dim is None else context_dim) + self.heads = int(heads) + self.dim_head = int(dim_head) + self.inner_dim = self.heads * self.dim_head + self.norm_eps = float(norm_eps) + self.qk_norm = bool(qk_norm) + + tp_size = get_tp_world_size() + if tp_size <= 0: + raise ValueError(f"Invalid {tp_size=}. Expected tp_size >= 1.") + if self.heads % tp_size != 0: + raise ValueError( + f"LTX2Attention requires heads divisible by tp_size, got " + f"{self.heads=} {tp_size=}." + ) + if self.inner_dim % tp_size != 0: + # This should follow from heads % tp_size, but keep explicit for clarity. + raise ValueError( + f"LTX2Attention requires inner_dim divisible by tp_size, got " + f"{self.inner_dim=} {tp_size=}." 
+ ) + self.local_heads = self.heads // tp_size + + self.to_q = ColumnParallelLinear( + self.query_dim, + self.inner_dim, + bias=True, + gather_output=False, + quant_config=quant_config, + ) + self.to_k = ColumnParallelLinear( + self.context_dim, + self.inner_dim, + bias=True, + gather_output=False, + quant_config=quant_config, + ) + self.to_v = ColumnParallelLinear( + self.context_dim, + self.inner_dim, + bias=True, + gather_output=False, + quant_config=quant_config, + ) + + self.q_norm: nn.Module | None = None + self.k_norm: nn.Module | None = None + if self.qk_norm: + if tp_size == 1: + self.q_norm = torch.nn.RMSNorm(self.inner_dim, eps=self.norm_eps) + self.k_norm = torch.nn.RMSNorm(self.inner_dim, eps=self.norm_eps) + else: + self.q_norm = LTX2TPRMSNormAcrossHeads( + full_hidden_size=self.inner_dim, + local_hidden_size=self.inner_dim // tp_size, + eps=self.norm_eps, + ) + self.k_norm = LTX2TPRMSNormAcrossHeads( + full_hidden_size=self.inner_dim, + local_hidden_size=self.inner_dim // tp_size, + eps=self.norm_eps, + ) + + self.to_out = nn.Sequential( + RowParallelLinear( + self.inner_dim, + self.query_dim, + bias=True, + input_is_parallel=True, + quant_config=quant_config, + ), + nn.Identity(), + ) + + self.attn = USPAttention( + num_heads=self.local_heads, + head_size=self.dim_head, + num_kv_heads=self.local_heads, + dropout_rate=0, + softmax_scale=None, + causal=False, + supported_attention_backends=supported_attention_backends, + prefix=f"{prefix}.attn", + ) + + def forward( + self, + x: torch.Tensor, + context: torch.Tensor | None = None, + mask: torch.Tensor | None = None, + pe: tuple[torch.Tensor, torch.Tensor] | None = None, + k_pe: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + q, _ = self.to_q(x) + context_ = x if context is None else context + k, _ = self.to_k(context_) + v, _ = self.to_v(context_) + + if self.qk_norm: + assert self.q_norm is not None and self.k_norm is not None + q = self.q_norm(q) + k = self.k_norm(k) + + if 
pe is not None: + cos, sin = pe + k_cos, k_sin = pe if k_pe is None else k_pe + tp_size = get_tp_world_size() + if tp_size > 1: + tp_rank = get_tp_rank() + cos, sin = self._slice_rope_for_tp( + cos, sin, tp_rank=tp_rank, tp_size=tp_size + ) + k_cos, k_sin = self._slice_rope_for_tp( + k_cos, k_sin, tp_rank=tp_rank, tp_size=tp_size + ) + if cos.dim() == 3: + q = apply_interleaved_rotary_emb(q, (cos, sin)) + k = apply_interleaved_rotary_emb(k, (k_cos, k_sin)) + else: + q = apply_split_rotary_emb(q, (cos, sin)) + k = apply_split_rotary_emb(k, (k_cos, k_sin)) + + q = q.view(*q.shape[:-1], self.local_heads, self.dim_head) + k = k.view(*k.shape[:-1], self.local_heads, self.dim_head) + v = v.view(*v.shape[:-1], self.local_heads, self.dim_head) + + if mask is not None: + # Fallback to SDPA for masked attention + q_ = q.transpose(1, 2) + k_ = k.transpose(1, 2) + v_ = v.transpose(1, 2) + + if torch.is_floating_point(mask): + m = mask + if m.dim() == 2: + m = m[:, None, None, :] + elif m.dim() == 3: + m = m[:, None, :, :] + sdpa_mask = m.to(dtype=q_.dtype, device=q_.device) + else: + m = mask.to(dtype=q_.dtype, device=q_.device) + if m.dim() == 2: + m = m[:, None, None, :] + elif m.dim() == 3: + m = m[:, None, :, :] + sdpa_mask = (m - 1.0) * torch.finfo(q_.dtype).max + + out = torch.nn.functional.scaled_dot_product_attention( + q_, k_, v_, attn_mask=sdpa_mask, dropout_p=0.0, is_causal=False + ).transpose(1, 2) + else: + out = self.attn(q, k, v) + + out = out.flatten(2) + out, _ = self.to_out[0](out) + return out + + def _slice_rope_for_tp( + self, + cos: torch.Tensor, + sin: torch.Tensor, + *, + tp_rank: int, + tp_size: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Slice RoPE tensors to the local TP shard. + + - split-rope: cos/sin are shaped [B, H, T, R] (head-major), slice by heads. + - interleaved-rope: cos/sin are shaped [B, T, D], where D matches the projected + feature dimension and is sharded by TP. 
+ """ + if cos.ndim == 4: + # [B, H, T, R] + start = tp_rank * self.local_heads + end = start + self.local_heads + return cos[:, start:end, :, :], sin[:, start:end, :, :] + elif cos.ndim == 3: + # [B, T, D] + d = cos.shape[-1] + if d % tp_size != 0: + raise ValueError( + f"RoPE dim must be divisible by tp_size, got {d=} {tp_size=}." + ) + local_d = d // tp_size + start = tp_rank * local_d + end = start + local_d + return cos[:, :, start:end], sin[:, :, start:end] + raise ValueError(f"Unexpected RoPE tensor rank: {cos.ndim}. Expected 3 or 4.") + + +class LTX2FeedForward(nn.Module): + def __init__( + self, + dim: int, + dim_out: int | None = None, + mult: int = 4, + quant_config: QuantizationConfig | None = None, + ) -> None: + super().__init__() + if dim_out is None: + dim_out = dim + inner_dim = int(dim * mult) + + self.proj_in = ColumnParallelLinear( + dim, inner_dim, bias=True, gather_output=True, quant_config=quant_config + ) + self.act = nn.GELU(approximate="tanh") + self.proj_out = ColumnParallelLinear( + inner_dim, dim_out, bias=True, gather_output=True, quant_config=quant_config + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.proj_in(x) + x = self.act(x) + x, _ = self.proj_out(x) + return x + + +class LTX2TransformerBlock(nn.Module): + def __init__( + self, + idx: int, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + cross_attention_dim: int, + audio_dim: int, + audio_num_attention_heads: int, + audio_attention_head_dim: int, + audio_cross_attention_dim: int, + qk_norm: bool = True, + norm_eps: float = 1e-6, + supported_attention_backends: set[AttentionBackendEnum] | None = None, + prefix: str = "", + quant_config: QuantizationConfig | None = None, + ): + super().__init__() + self.idx = idx + self.norm_eps = norm_eps + + # 1. 
Self-Attention (video and audio) + self.attn1 = LTX2Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + norm_eps=norm_eps, + qk_norm=qk_norm, + supported_attention_backends=supported_attention_backends, + prefix=f"{prefix}.attn1", + quant_config=quant_config, + ) + self.audio_attn1 = LTX2Attention( + query_dim=audio_dim, + heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + norm_eps=norm_eps, + qk_norm=qk_norm, + supported_attention_backends=supported_attention_backends, + prefix=f"{prefix}.audio_attn1", + quant_config=quant_config, + ) + + # 2. Prompt Cross-Attention + self.attn2 = LTX2Attention( + query_dim=dim, + context_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + norm_eps=norm_eps, + qk_norm=qk_norm, + supported_attention_backends=supported_attention_backends, + prefix=f"{prefix}.attn2", + quant_config=quant_config, + ) + self.audio_attn2 = LTX2Attention( + query_dim=audio_dim, + context_dim=audio_cross_attention_dim, + heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + norm_eps=norm_eps, + qk_norm=qk_norm, + supported_attention_backends=supported_attention_backends, + prefix=f"{prefix}.audio_attn2", + quant_config=quant_config, + ) + + # 3. 
Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention + self.audio_to_video_attn = LTX2Attention( + query_dim=dim, + context_dim=audio_dim, + heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + norm_eps=norm_eps, + qk_norm=qk_norm, + supported_attention_backends=supported_attention_backends, + prefix=f"{prefix}.audio_to_video_attn", + quant_config=quant_config, + ) + self.video_to_audio_attn = LTX2Attention( + query_dim=audio_dim, + context_dim=dim, + heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + norm_eps=norm_eps, + qk_norm=qk_norm, + supported_attention_backends=supported_attention_backends, + prefix=f"{prefix}.video_to_audio_attn", + quant_config=quant_config, + ) + + # 4. Feedforward layers + self.ff = LTX2FeedForward(dim, dim_out=dim, quant_config=quant_config) + self.audio_ff = LTX2FeedForward( + audio_dim, dim_out=audio_dim, quant_config=quant_config + ) + + # 5. Modulation Parameters + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + self.audio_scale_shift_table = nn.Parameter( + torch.randn(6, audio_dim) / audio_dim**0.5 + ) + self.video_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, dim)) + self.audio_a2v_cross_attn_scale_shift_table = nn.Parameter( + torch.randn(5, audio_dim) + ) + + def get_ada_values( + self, + scale_shift_table: torch.Tensor, + batch_size: int, + timestep: torch.Tensor, + indices: slice, + ) -> tuple[torch.Tensor, ...]: + num_ada_params = int(scale_shift_table.shape[0]) + ada_values = ( + scale_shift_table[indices] + .unsqueeze(0) + .unsqueeze(0) + .to(device=timestep.device, dtype=timestep.dtype) + + timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[ + :, :, indices, : + ] + ).unbind(dim=2) + return [t.squeeze(2) for t in ada_values] + + def forward( + self, + hidden_states: torch.Tensor, + audio_hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + audio_encoder_hidden_states: torch.Tensor, + temb: 
torch.Tensor, + temb_audio: torch.Tensor, + temb_ca_scale_shift: torch.Tensor, + temb_ca_audio_scale_shift: torch.Tensor, + temb_ca_gate: torch.Tensor, + temb_ca_audio_gate: torch.Tensor, + video_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + audio_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ca_video_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ca_audio_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + audio_encoder_attention_mask: Optional[torch.Tensor] = None, + a2v_cross_attention_mask: Optional[torch.Tensor] = None, + v2a_cross_attention_mask: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + + batch_size = hidden_states.size(0) + + # 1. Video and Audio Self-Attention + vshift_msa, vscale_msa, vgate_msa = self.get_ada_values( + self.scale_shift_table, batch_size, temb, slice(0, 3) + ) + norm_hidden_states = ( + rms_norm(hidden_states, self.norm_eps) * (1 + vscale_msa) + vshift_msa + ) + attn_hidden_states = self.attn1(norm_hidden_states, pe=video_rotary_emb) + hidden_states = hidden_states + attn_hidden_states * vgate_msa + + ashift_msa, ascale_msa, agate_msa = self.get_ada_values( + self.audio_scale_shift_table, batch_size, temb_audio, slice(0, 3) + ) + norm_audio_hidden_states = ( + rms_norm(audio_hidden_states, self.norm_eps) * (1 + ascale_msa) + ashift_msa + ) + attn_audio_hidden_states = self.audio_attn1( + norm_audio_hidden_states, pe=audio_rotary_emb + ) + audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * agate_msa + + # 2. 
Prompt Cross-Attention + norm_hidden_states = rms_norm(hidden_states, self.norm_eps) + attn_hidden_states = self.attn2( + norm_hidden_states, + context=encoder_hidden_states, + mask=encoder_attention_mask, + ) + hidden_states = hidden_states + attn_hidden_states + + norm_audio_hidden_states = rms_norm(audio_hidden_states, self.norm_eps) + attn_audio_hidden_states = self.audio_attn2( + norm_audio_hidden_states, + context=audio_encoder_hidden_states, + mask=audio_encoder_attention_mask, + ) + audio_hidden_states = audio_hidden_states + attn_audio_hidden_states + + # 3. Audio-to-Video and Video-to-Audio Cross-Attention + norm_hidden_states = rms_norm(hidden_states, self.norm_eps) + norm_audio_hidden_states = rms_norm(audio_hidden_states, self.norm_eps) + + # Compute combined ada params + video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[ + :4, : + ] + video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :] + + video_ca_scale_shift_table = ( + video_per_layer_ca_scale_shift[None, None, :, :].to( + dtype=temb_ca_scale_shift.dtype, device=temb_ca_scale_shift.device + ) + + temb_ca_scale_shift.reshape( + batch_size, temb_ca_scale_shift.shape[1], 4, -1 + ) + ).unbind(dim=2) + video_ca_gate = ( + video_per_layer_ca_gate[None, None, :, :].to( + dtype=temb_ca_gate.dtype, device=temb_ca_gate.device + ) + + temb_ca_gate.reshape(batch_size, temb_ca_gate.shape[1], 1, -1) + ).unbind(dim=2) + + ( + video_a2v_ca_scale, + video_a2v_ca_shift, + video_v2a_ca_scale, + video_v2a_ca_shift, + ) = [t.squeeze(2) for t in video_ca_scale_shift_table] + a2v_gate = video_ca_gate[0].squeeze(2) + + audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[ + :4, : + ] + audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :] + + audio_ca_scale_shift_table = ( + audio_per_layer_ca_scale_shift[None, None, :, :].to( + dtype=temb_ca_audio_scale_shift.dtype, + device=temb_ca_audio_scale_shift.device, + ) + + 
temb_ca_audio_scale_shift.reshape( + batch_size, temb_ca_audio_scale_shift.shape[1], 4, -1 + ) + ).unbind(dim=2) + audio_ca_gate = ( + audio_per_layer_ca_gate[None, None, :, :].to( + dtype=temb_ca_audio_gate.dtype, device=temb_ca_audio_gate.device + ) + + temb_ca_audio_gate.reshape(batch_size, temb_ca_audio_gate.shape[1], 1, -1) + ).unbind(dim=2) + + ( + audio_a2v_ca_scale, + audio_a2v_ca_shift, + audio_v2a_ca_scale, + audio_v2a_ca_shift, + ) = [t.squeeze(2) for t in audio_ca_scale_shift_table] + v2a_gate = audio_ca_gate[0].squeeze(2) + + # A2V + mod_norm_hidden_states = ( + norm_hidden_states * (1 + video_a2v_ca_scale) + video_a2v_ca_shift + ) + mod_norm_audio_hidden_states = ( + norm_audio_hidden_states * (1 + audio_a2v_ca_scale) + audio_a2v_ca_shift + ) + + a2v_attn_hidden_states = self.audio_to_video_attn( + mod_norm_hidden_states, + context=mod_norm_audio_hidden_states, + pe=ca_video_rotary_emb, + k_pe=ca_audio_rotary_emb, + mask=a2v_cross_attention_mask, + ) + hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states + + # V2A + mod_norm_hidden_states = ( + norm_hidden_states * (1 + video_v2a_ca_scale) + video_v2a_ca_shift + ) + mod_norm_audio_hidden_states = ( + norm_audio_hidden_states * (1 + audio_v2a_ca_scale) + audio_v2a_ca_shift + ) + + v2a_attn_hidden_states = self.video_to_audio_attn( + mod_norm_audio_hidden_states, + context=mod_norm_hidden_states, + pe=ca_audio_rotary_emb, + k_pe=ca_video_rotary_emb, + mask=v2a_cross_attention_mask, + ) + audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states + + # 4. 
Feedforward + vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values( + self.scale_shift_table, batch_size, temb, slice(3, None) + ) + norm_hidden_states = ( + rms_norm(hidden_states, self.norm_eps) * (1 + vscale_mlp) + vshift_mlp + ) + ff_output = self.ff(norm_hidden_states) + hidden_states = hidden_states + ff_output * vgate_mlp + + ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values( + self.audio_scale_shift_table, batch_size, temb_audio, slice(3, None) + ) + norm_audio_hidden_states = ( + rms_norm(audio_hidden_states, self.norm_eps) * (1 + ascale_mlp) + ashift_mlp + ) + audio_ff_output = self.audio_ff(norm_audio_hidden_states) + audio_hidden_states = audio_hidden_states + audio_ff_output * agate_mlp + + return hidden_states, audio_hidden_states + + +class LTX2VideoTransformer3DModel(CachableDiT, OffloadableDiTMixin): + _fsdp_shard_conditions = LTX2ArchConfig()._fsdp_shard_conditions + _compile_conditions = LTX2ArchConfig()._compile_conditions + _supported_attention_backends = LTX2ArchConfig()._supported_attention_backends + param_names_mapping = LTX2ArchConfig().param_names_mapping + reverse_param_names_mapping = LTX2ArchConfig().reverse_param_names_mapping + lora_param_names_mapping = LTX2ArchConfig().lora_param_names_mapping + + def _validate_tp_config(self, *, arch: LTX2ArchConfig, tp_size: int) -> None: + """Validate TP-related dimension constraints (fail-fast).""" + if tp_size < 1: + raise ValueError(f"Invalid tp_size={tp_size}. Expected tp_size >= 1.") + + if self.hidden_size % self.num_attention_heads != 0: + raise ValueError( + "video hidden_size must be divisible by num_attention_heads, got " + f"{self.hidden_size=} {self.num_attention_heads=}." + ) + if self.audio_hidden_size % self.audio_num_attention_heads != 0: + raise ValueError( + "audio_hidden_size must be divisible by audio_num_attention_heads, got " + f"{self.audio_hidden_size=} {self.audio_num_attention_heads=}." 
+ ) + + if tp_size == 1: + return + + if self.num_attention_heads % tp_size != 0: + raise ValueError( + "num_attention_heads must be divisible by tp_size, got " + f"{self.num_attention_heads=} {tp_size=}." + ) + if self.audio_num_attention_heads % tp_size != 0: + raise ValueError( + "audio_num_attention_heads must be divisible by tp_size, got " + f"{self.audio_num_attention_heads=} {tp_size=}." + ) + if self.hidden_size % tp_size != 0: + raise ValueError( + "hidden_size must be divisible by tp_size for TP-sharded projections, got " + f"{self.hidden_size=} {tp_size=}." + ) + if self.audio_hidden_size % tp_size != 0: + raise ValueError( + "audio_hidden_size must be divisible by tp_size for TP-sharded projections, got " + f"{self.audio_hidden_size=} {tp_size=}." + ) + if int(arch.out_channels) % tp_size != 0: + raise ValueError( + "out_channels must be divisible by tp_size for TP-sharded output projection, got " + f"{arch.out_channels=} {tp_size=}." + ) + if int(arch.audio_out_channels) % tp_size != 0: + raise ValueError( + "audio_out_channels must be divisible by tp_size for TP-sharded output projection, got " + f"{arch.audio_out_channels=} {tp_size=}." + ) + + def __init__( + self, + config: LTX2Config, + hf_config: dict[str, Any], + quant_config: QuantizationConfig | None = None, + ) -> None: + super().__init__(config=config, hf_config=hf_config) + + arch = config.arch_config + self.hidden_size = arch.hidden_size + self.num_attention_heads = arch.num_attention_heads + self.audio_hidden_size = arch.audio_hidden_size + self.audio_num_attention_heads = arch.audio_num_attention_heads + self.norm_eps = arch.norm_eps + + tp_size = get_tp_world_size() + self._validate_tp_config(arch=arch, tp_size=tp_size) + + # 1. 
Patchification input projections + # Matches LTX2Config().param_names_mapping + self.patchify_proj = ColumnParallelLinear( + arch.in_channels, + self.hidden_size, + bias=True, + gather_output=True, + quant_config=quant_config, + ) + self.audio_patchify_proj = ColumnParallelLinear( + arch.audio_in_channels, + self.audio_hidden_size, + bias=True, + gather_output=True, + quant_config=quant_config, + ) + + # 2. Prompt embeddings + self.caption_projection = LTX2TextProjection( + in_features=arch.caption_channels, hidden_size=self.hidden_size + ) + self.audio_caption_projection = LTX2TextProjection( + in_features=arch.caption_channels, hidden_size=self.audio_hidden_size + ) + + # 3. Timestep Modulation Params and Embedding + self.adaln_single = LTX2AdaLayerNormSingle( + self.hidden_size, embedding_coefficient=6 + ) + self.audio_adaln_single = LTX2AdaLayerNormSingle( + self.audio_hidden_size, embedding_coefficient=6 + ) + + # Global Cross Attention Modulation Parameters + self.av_ca_video_scale_shift_adaln_single = LTX2AdaLayerNormSingle( + self.hidden_size, embedding_coefficient=4 + ) + self.av_ca_a2v_gate_adaln_single = LTX2AdaLayerNormSingle( + self.hidden_size, embedding_coefficient=1 + ) + self.av_ca_audio_scale_shift_adaln_single = LTX2AdaLayerNormSingle( + self.audio_hidden_size, embedding_coefficient=4 + ) + self.av_ca_v2a_gate_adaln_single = LTX2AdaLayerNormSingle( + self.audio_hidden_size, embedding_coefficient=1 + ) + + # Output Layer Scale/Shift Modulation parameters + self.scale_shift_table = nn.Parameter( + torch.randn(2, self.hidden_size) / self.hidden_size**0.5 + ) + self.audio_scale_shift_table = nn.Parameter( + torch.randn(2, self.audio_hidden_size) / self.audio_hidden_size**0.5 + ) + + hf_patch_size = int(hf_config.get("patch_size", 1)) + hf_patch_size_t = int(hf_config.get("patch_size_t", 1)) + self.patch_size = (hf_patch_size_t, hf_patch_size, hf_patch_size) + + hf_audio_patch_size = int(hf_config.get("audio_patch_size", 1)) + hf_audio_patch_size_t = 
int(hf_config.get("audio_patch_size_t", 1)) + + rope_type = ( + arch.rope_type.value + if hasattr(arch.rope_type, "value") + else str(arch.rope_type) + ) + rope_double_precision = bool( + hf_config.get("rope_double_precision", arch.double_precision_rope) + ) + causal_offset = int(hf_config.get("causal_offset", 1)) + + pos_embed_max_pos = int(arch.positional_embedding_max_pos[0]) + base_height = int(arch.positional_embedding_max_pos[1]) + base_width = int(arch.positional_embedding_max_pos[2]) + + audio_pos_embed_max_pos = int(arch.audio_positional_embedding_max_pos[0]) + + self.video_scale_factors = (8, 32, 32) + self.audio_scale_factors = (4,) + + self.rope = LTX2AudioVideoRotaryPosEmbed( + dim=self.hidden_size, + patch_size=hf_patch_size, + patch_size_t=hf_patch_size_t, + base_num_frames=pos_embed_max_pos, + base_height=base_height, + base_width=base_width, + scale_factors=self.video_scale_factors, + theta=float(arch.positional_embedding_theta), + causal_offset=causal_offset, + modality="video", + double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=self.num_attention_heads, + ) + self.audio_rope = LTX2AudioVideoRotaryPosEmbed( + dim=self.audio_hidden_size, + patch_size=hf_audio_patch_size, + patch_size_t=hf_audio_patch_size_t, + base_num_frames=audio_pos_embed_max_pos, + sampling_rate=16000, + hop_length=160, + scale_factors=self.audio_scale_factors, + theta=float(arch.positional_embedding_theta), + causal_offset=causal_offset, + modality="audio", + double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=self.audio_num_attention_heads, + ) + + cross_attn_pos_embed_max_pos = max(pos_embed_max_pos, audio_pos_embed_max_pos) + self.cross_attn_rope = LTX2AudioVideoRotaryPosEmbed( + dim=int(arch.audio_cross_attention_dim), + patch_size=hf_patch_size, + patch_size_t=hf_patch_size_t, + base_num_frames=cross_attn_pos_embed_max_pos, + base_height=base_height, + base_width=base_width, + 
theta=float(arch.positional_embedding_theta), + causal_offset=causal_offset, + modality="video", + double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=self.num_attention_heads, + ) + self.cross_attn_audio_rope = LTX2AudioVideoRotaryPosEmbed( + dim=int(arch.audio_cross_attention_dim), + patch_size=hf_audio_patch_size, + patch_size_t=hf_audio_patch_size_t, + base_num_frames=cross_attn_pos_embed_max_pos, + sampling_rate=16000, + hop_length=160, + theta=float(arch.positional_embedding_theta), + causal_offset=causal_offset, + modality="audio", + double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=self.audio_num_attention_heads, + ) + + self.cross_pe_max_pos = cross_attn_pos_embed_max_pos + + # 5. Transformer Blocks + self.transformer_blocks = nn.ModuleList( + [ + LTX2TransformerBlock( + idx=idx, + dim=self.hidden_size, + num_attention_heads=self.num_attention_heads, + attention_head_dim=self.hidden_size // self.num_attention_heads, + cross_attention_dim=arch.cross_attention_dim, + audio_dim=self.audio_hidden_size, + audio_num_attention_heads=self.audio_num_attention_heads, + audio_attention_head_dim=self.audio_hidden_size + // self.audio_num_attention_heads, + audio_cross_attention_dim=arch.audio_cross_attention_dim, + norm_eps=self.norm_eps, + qk_norm=True, # Always True in LTX2 + supported_attention_backends=self._supported_attention_backends, + prefix=config.prefix, + quant_config=quant_config, + ) + for idx in range(arch.num_layers) + ] + ) + + # 6. 
Output layers + self.norm_out = nn.LayerNorm( + self.hidden_size, eps=self.norm_eps, elementwise_affine=False + ) + self.proj_out = ColumnParallelLinear( + self.hidden_size, + arch.out_channels, + bias=True, + gather_output=True, + quant_config=quant_config, + ) + + self.audio_norm_out = nn.LayerNorm( + self.audio_hidden_size, eps=self.norm_eps, elementwise_affine=False + ) + self.audio_proj_out = ColumnParallelLinear( + self.audio_hidden_size, + arch.audio_out_channels, + bias=True, + gather_output=True, + quant_config=quant_config, + ) + + self.out_channels_raw = arch.out_channels // ( + self.patch_size[0] * self.patch_size[1] * self.patch_size[2] + ) + self.audio_out_channels = arch.audio_out_channels + self.timestep_scale_multiplier = arch.timestep_scale_multiplier + self.av_ca_timestep_scale_multiplier = arch.av_ca_timestep_scale_multiplier + + self.layer_names = ["transformer_blocks"] + + def forward( + self, + hidden_states: torch.Tensor, + audio_hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + audio_encoder_hidden_states: torch.Tensor, + timestep: torch.LongTensor, + audio_timestep: Optional[torch.LongTensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + audio_encoder_attention_mask: Optional[torch.Tensor] = None, + num_frames: Optional[int] = None, + height: Optional[int] = None, + width: Optional[int] = None, + fps: float = 24.0, + audio_num_frames: Optional[int] = None, + video_coords: Optional[torch.Tensor] = None, + audio_coords: Optional[torch.Tensor] = None, + **kwargs, + ) -> tuple[torch.Tensor | None, torch.Tensor | None]: + + batch_size = hidden_states.size(0) + audio_timestep = audio_timestep if audio_timestep is not None else timestep + + if num_frames is None or height is None or width is None: + raise ValueError( + "num_frames/height/width must be provided for RoPE coordinate generation." 
+ ) + if audio_num_frames is None: + raise ValueError( + "audio_num_frames must be provided for RoPE coordinate generation." + ) + + if video_coords is None: + # Wan-style SP-RoPE: when SP is enabled, each rank runs on its local + # time shard but RoPE positions must be offset to global time. + # + # We assume equal time sharding across SP ranks. + if model_parallel_is_initialized(): + sp_world_size = get_sp_world_size() + sp_rank = get_sp_parallel_rank() + else: + sp_world_size = 1 + sp_rank = 0 + + video_shift = int(sp_rank) * int(num_frames) if sp_world_size > 1 else 0 + video_coords = self.rope.prepare_video_coords( + batch_size=batch_size, + num_frames=num_frames, + height=height, + width=width, + device=hidden_states.device, + fps=fps, + start_frame=video_shift, + ) + if audio_coords is None: + audio_coords = self.audio_rope.prepare_audio_coords( + batch_size=batch_size, + num_frames=audio_num_frames, + device=audio_hidden_states.device, + ) + + video_rotary_emb = self.rope(video_coords, device=hidden_states.device) + audio_rotary_emb = self.audio_rope( + audio_coords, device=audio_hidden_states.device + ) + ca_video_rotary_emb = self.cross_attn_rope( + video_coords[:, 0:1, :], device=hidden_states.device + ) + ca_audio_rotary_emb = self.cross_attn_audio_rope( + audio_coords[:, 0:1, :], device=audio_hidden_states.device + ) + + # 2. Patchify input projections + hidden_states, _ = self.patchify_proj(hidden_states) + audio_hidden_states, _ = self.audio_patchify_proj(audio_hidden_states) + + # 3. Prepare timestep embeddings + # 3.1. 
Prepare global modality (video and audio) timestep embedding and modulation parameters + temb, embedded_timestep = self.adaln_single( + timestep.flatten(), + ) + temb = temb.view(batch_size, -1, temb.size(-1)) + embedded_timestep = embedded_timestep.view( + batch_size, -1, embedded_timestep.size(-1) + ) + + temb_audio, audio_embedded_timestep = self.audio_adaln_single( + audio_timestep.flatten() + ) + temb_audio = temb_audio.view(batch_size, -1, temb_audio.size(-1)) + audio_embedded_timestep = audio_embedded_timestep.view( + batch_size, -1, audio_embedded_timestep.size(-1) + ) + + # 3.2. Prepare global modality cross attention modulation parameters + ts_ca_mult = ( + self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier + ) + + hidden_dtype = hidden_states.dtype + temb_ca_scale_shift, _ = self.av_ca_video_scale_shift_adaln_single( + timestep.flatten(), hidden_dtype=hidden_dtype + ) + temb_ca_scale_shift = temb_ca_scale_shift.view( + batch_size, -1, temb_ca_scale_shift.shape[-1] + ) + + temb_ca_gate, _ = self.av_ca_a2v_gate_adaln_single( + timestep.flatten() * self.av_ca_timestep_scale_multiplier, + hidden_dtype=hidden_dtype, + ) + temb_ca_gate = temb_ca_gate.view(batch_size, -1, temb_ca_gate.shape[-1]) + + temb_ca_audio_scale_shift, _ = self.av_ca_audio_scale_shift_adaln_single( + audio_timestep.flatten(), hidden_dtype=audio_hidden_states.dtype + ) + temb_ca_audio_scale_shift = temb_ca_audio_scale_shift.view( + batch_size, -1, temb_ca_audio_scale_shift.shape[-1] + ) + + temb_ca_audio_gate, _ = self.av_ca_v2a_gate_adaln_single( + audio_timestep.flatten() * self.av_ca_timestep_scale_multiplier, + hidden_dtype=audio_hidden_states.dtype, + ) + temb_ca_audio_gate = temb_ca_audio_gate.view( + batch_size, -1, temb_ca_audio_gate.shape[-1] + ) + + # 4. Prepare prompt embeddings + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + audio_encoder_hidden_states = self.audio_caption_projection( + audio_encoder_hidden_states + ) + + # 5. 
Run blocks + for block in self.transformer_blocks: + hidden_states, audio_hidden_states = block( + hidden_states, + audio_hidden_states, + encoder_hidden_states, + audio_encoder_hidden_states, + # Keep the first 4 args positional to stay compatible with cache-dit's + # LTX2 adapter, which treats `audio_hidden_states` as `encoder_hidden_states` + # under ForwardPattern.Pattern_0. + temb=temb, + temb_audio=temb_audio, + temb_ca_scale_shift=temb_ca_scale_shift, + temb_ca_audio_scale_shift=temb_ca_audio_scale_shift, + temb_ca_gate=temb_ca_gate, + temb_ca_audio_gate=temb_ca_audio_gate, + video_rotary_emb=video_rotary_emb, + audio_rotary_emb=audio_rotary_emb, + ca_video_rotary_emb=ca_video_rotary_emb, + ca_audio_rotary_emb=ca_audio_rotary_emb, + encoder_attention_mask=encoder_attention_mask, + audio_encoder_attention_mask=audio_encoder_attention_mask, + ) + + # 6. Output layers + # Video + scale_shift_values = self.scale_shift_table[None, None].to( + device=hidden_states.device, dtype=hidden_states.dtype + ) + embedded_timestep[:, :, None].to(dtype=hidden_states.dtype) + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + with torch.autocast(device_type=hidden_states.device.type, enabled=False): + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + hidden_states, _ = self.proj_out(hidden_states) + + # Audio + audio_scale_shift_values = self.audio_scale_shift_table[None, None].to( + device=audio_hidden_states.device, dtype=audio_hidden_states.dtype + ) + audio_embedded_timestep[:, :, None].to(dtype=audio_hidden_states.dtype) + audio_shift, audio_scale = ( + audio_scale_shift_values[:, :, 0], + audio_scale_shift_values[:, :, 1], + ) + with torch.autocast(device_type=audio_hidden_states.device.type, enabled=False): + audio_hidden_states = self.audio_norm_out(audio_hidden_states) + audio_hidden_states = audio_hidden_states * (1 + audio_scale) + audio_shift + audio_hidden_states, _ = 
self.audio_proj_out(audio_hidden_states) + + # Unpatchify if requested (default True for pipeline compatibility) + return_latents = kwargs.get("return_latents", True) + + if return_latents: + # Unpatchify Video + # [B, N, C_out_raw*patch_vol] -> [B, C_out_raw, T, H, W] + # Requires num_frames, height, width to be known + if num_frames is not None and height is not None and width is not None: + p_t, p_h, p_w = self.patch_size + post_t, post_h, post_w = num_frames // p_t, height // p_h, width // p_w + b = batch_size + hidden_states = hidden_states.reshape( + b, post_t, post_h, post_w, self.out_channels_raw, p_t, p_h, p_w + ) + hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7).reshape( + b, self.out_channels_raw, num_frames, height, width + ) + + # Unpatchify Audio + # [B, N, C_out] -> [B, C_out, T] (or 4D/5D) + if audio_num_frames is not None: + b = batch_size + # simple reshape for 1D patch + audio_hidden_states = audio_hidden_states.permute(0, 2, 1) # [B, C, T] + + return hidden_states, audio_hidden_states + + +# Backward-compatible alias (older internal name). +LTXModel = LTX2VideoTransformer3DModel +EntryClass = LTX2VideoTransformer3DModel diff --git a/sglang/python/sglang/multimodal_gen/runtime/models/dits/mova_audio_dit.py b/sglang/python/sglang/multimodal_gen/runtime/models/dits/mova_audio_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..f3deaf649e9a5f9ad79750f2bb6da190a58a5a4c --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/models/dits/mova_audio_dit.py @@ -0,0 +1,267 @@ +# Copied and adapted from: mossVG/mova/diffusion/models/wan_audio_dit.py +# SPDX-License-Identifier: Apache-2.0 +# +# NOTE: This module reuses common functions from mova_video_dit.py to reduce code duplication. +# Audio-specific functions (precompute_freqs_cis_1d, legacy_precompute_freqs_cis_1d) are kept here. 
+
+import math
+from typing import Any, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+from torch.distributed.tensor import DTensor
+
+from sglang.multimodal_gen.configs.models.dits.mova_audio import MOVAAudioConfig
+from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear
+from sglang.multimodal_gen.runtime.layers.mlp import MLP
+from sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import (
+    QuantizationConfig,
+)
+from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT
+from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin
+
+# Reuse common functions and classes from mova_video_dit
+from .mova_video_dit import DiTBlock, precompute_freqs_cis, sinusoidal_embedding_1d
+
+
+# Audio-specific positional encoding functions
+def legacy_precompute_freqs_cis_1d(
+    dim: int,
+    end: int = 16384,
+    theta: float = 10000.0,
+    base_tps=4.0,
+    target_tps=44100 / 2048,
+):
+    """Build the legacy 1D RoPE tables, rotating only part of the head dim.
+
+    Args:
+        dim: Per-head channel count to cover.
+        end: Number of positions to precompute.
+        theta: RoPE base.
+        base_tps: Numerator of the position scale.
+        target_tps: Denominator of the position scale.
+
+    Returns:
+        A 3-tuple ``(f, ones, ones)``: ``f`` is the table built for
+        ``dim - 2 * (dim // 3)`` channels with positions scaled by
+        ``base_tps / target_tps``; ``ones`` is an all-ones table shaped like
+        the one for ``dim // 3`` channels (the same object is returned twice).
+    """
+    s = float(base_tps) / float(target_tps)
+    # 1d rope precompute
+    f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta, s)
+    # No positional encoding is applied to the remaining dimensions:
+    # multiplying by one leaves those channels unrotated.
+    no_freqs_cis = precompute_freqs_cis(dim // 3, end, theta, s)
+    no_freqs_cis = torch.ones_like(no_freqs_cis)
+    return f_freqs_cis, no_freqs_cis, no_freqs_cis
+
+
+def precompute_freqs_cis_1d(dim: int, end: int = 16384, theta: float = 10000.0):
+    """Build one RoPE table for ``dim`` and split it in three along the last axis.
+
+    Returns:
+        The tuple produced by ``Tensor.chunk(3, dim=-1)`` on the table from
+        ``precompute_freqs_cis(dim, end, theta)``.
+    """
+    f_freqs_cis = precompute_freqs_cis(dim, end, theta)
+    return f_freqs_cis.chunk(3, dim=-1)
+
+
+class Head(nn.Module):
+    """Output head: affine-free LayerNorm, shift/scale modulation, linear projection."""
+
+    def __init__(
+        self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float
+    ):
+        """
+        Args:
+            dim: Hidden size of the incoming tokens.
+            out_dim: Output channels per patch element; the projection width is
+                ``out_dim * prod(patch_size)``.
+            patch_size: Patch extents. NOTE(review): annotated as a 3-tuple, but
+                the audio model in this module indexes only ``patch_size[0]``
+                -- confirm the intended length.
+            eps: LayerNorm epsilon.
+        """
+        super().__init__()
+        self.dim = dim
+        self.patch_size = patch_size
+        self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
+        self.head = ReplicatedLinear(dim, out_dim * math.prod(patch_size))
+        # Learned (1, 2, dim) table holding the base shift and scale.
+        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
+
+    def forward(self, x, t_mod):
+        """Modulate ``norm(x)`` with shift/scale derived from ``t_mod`` and project.
+
+        Args:
+            x: Token tensor fed to the LayerNorm.
+            t_mod: Conditioning tensor. A 3-D input takes the first branch;
+                anything else takes the second branch
+                (presumably ``[B, C]`` -- confirm against callers).
+
+        Returns:
+            The first element of the tuple returned by the linear head.
+        """
+        if len(t_mod.shape) == 3:
+            # Insert a size-1 axis at position 2 of t_mod and a leading axis on
+            # the table so the two broadcast; the (shift, scale) pair then
+            # lives on dim 2.
+            shift, scale = (
+                self.modulation.unsqueeze(0).to(dtype=t_mod.dtype, device=t_mod.device)
+                + t_mod.unsqueeze(2)
+            ).chunk(2, dim=2)
+            x, _ = self.head(self.norm(x) * (1 + scale.squeeze(2)) + shift.squeeze(2))
+        else:
+            # NOTE: t_mod gets a size-1 axis at position 1 before it is added to
+            # the (1, 2, dim) table. Adding a raw [B, C] tensor would try to
+            # broadcast B against the table's size-2 axis, which only lines up
+            # for B == 1.
+            shift, scale = (
+                self.modulation.to(dtype=t_mod.dtype, device=t_mod.device)
+                + t_mod.unsqueeze(1)
+            ).chunk(2, dim=1)
+            x, _ = self.head(self.norm(x) * (1 + scale) + shift)
+        return x
+
+
+class Conv1dLocalIsland(nn.Conv1d):
+    """``nn.Conv1d`` whose forward runs on local shards when given a DTensor.
+
+    - Parameters are left untouched; if they are DTensors they stay DTensors.
+    - For a DTensor input, the input (and any DTensor weight/bias) is converted
+      with ``to_local()`` and the convolution runs on those local tensors.
+    - The result of that path is the plain tensor returned by
+      ``_conv_forward``; it is NOT wrapped back into a DTensor here.
+    - Any other input goes through the regular ``nn.Conv1d.forward``.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def forward(self, input):
+        if not isinstance(input, DTensor):
+            return super().forward(input)
+
+        x_local = input.to_local()  # type: ignore[attr-defined]
+        # Only unwrap parameters that really are DTensors: a plain Parameter has
+        # no to_local(), so calling it unconditionally raises AttributeError.
+        weight = self.weight
+        w_local = weight.to_local() if isinstance(weight, DTensor) else weight
+        bias = self.bias
+        b_local = bias.to_local() if isinstance(bias, DTensor) else bias
+
+        return self._conv_forward(x_local, w_local, b_local)
+
+
+class WanAudioModel(CachableDiT, OffloadableDiTMixin):
+    """MOVA audio DiT: 1D patch embedding, a stack of ``DiTBlock`` layers, and a head."""
+
+    _fsdp_shard_conditions = MOVAAudioConfig()._fsdp_shard_conditions
+    _compile_conditions = MOVAAudioConfig()._compile_conditions
+    _supported_attention_backends = MOVAAudioConfig()._supported_attention_backends
+    param_names_mapping = MOVAAudioConfig().param_names_mapping
+    reverse_param_names_mapping = MOVAAudioConfig().reverse_param_names_mapping
+    lora_param_names_mapping = MOVAAudioConfig().lora_param_names_mapping
+
+    def __init__(
+        self,
+        config: MOVAAudioConfig,
+        hf_config: dict[str, Any],
+        quant_config: QuantizationConfig | None = None,
+    ) -> None:
+        """
+        Args:
+            config: Architecture config; every ``config.*`` field read below is required.
+            hf_config: Forwarded unchanged to the base-class constructor.
+            quant_config: Optional quantization config passed to the MLPs, the
+                time projection and every ``DiTBlock``.
+        """
+        super().__init__(config=config, hf_config=hf_config)
+
+        # Extract parameters from config
+        dim = config.dim
+        in_dim = config.in_dim
+        ffn_dim = config.ffn_dim
+        out_dim = config.out_dim
+        text_dim = config.text_dim
+        freq_dim = config.freq_dim
+        eps = config.eps
+        patch_size = config.patch_size
+        num_heads = config.num_heads
+        num_layers = config.num_layers
+        has_image_pos_emb = config.has_image_pos_emb
+        has_ref_conv = config.has_ref_conv
+        separated_timestep = config.separated_timestep
+        require_vae_embedding = config.require_vae_embedding
+        require_clip_embedding = config.require_clip_embedding
+        fuse_vae_embedding_in_latents = config.fuse_vae_embedding_in_latents
+        vae_type = config.vae_type
+
+        self.dim = dim
+        self.freq_dim = freq_dim
+        self.patch_size = patch_size
+        self.separated_timestep = separated_timestep
+        self.require_vae_embedding = require_vae_embedding
+        self.require_clip_embedding = require_clip_embedding
+        self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents
+        self.vae_type = vae_type
+        # self.patch_embedding = nn.Conv3d(
+        #     in_dim, dim, kernel_size=patch_size, stride=patch_size)
+        self.patch_embedding = Conv1dLocalIsland(
+            in_dim, dim, kernel_size=patch_size, stride=patch_size
+        )
+        self.text_embedding = MLP(
+            text_dim,
+            dim,
+            output_dim=dim,
+            act_type="gelu_pytorch_tanh",
+            quant_config=quant_config,
+        )
+        self.time_embedding = MLP(
+            freq_dim, dim, output_dim=dim, act_type="silu", quant_config=quant_config
+        )
+        # Preserve state_dict keys (time_projection.1.weight/bias).
+        self.time_projection = nn.Sequential(
+            nn.SiLU(), ReplicatedLinear(dim, dim * 6, quant_config=quant_config)
+        )
+        self.blocks = nn.ModuleList(
+            [
+                DiTBlock(dim, num_heads, ffn_dim, eps, quant_config=quant_config)
+                for _ in range(num_layers)
+            ]
+        )
+        self.head = Head(dim, out_dim, patch_size, eps)
+        self.num_heads = num_heads
+        # RoPE tables are built lazily by _init_freqs().
+        self.freqs = None
+        self.img_pos_emb = None
+        if has_ref_conv:
+            self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2))
+        self.has_image_pos_emb = has_image_pos_emb
+        self.has_ref_conv = has_ref_conv
+        self.hidden_size = dim
+        self.num_attention_heads = num_heads
+        self.num_channels_latents = out_dim
+        self.layer_names = ["blocks"]
+        # Cache bookkeeping fields. The attribute names (including the
+        # `previous_resiual` spelling) are kept as-is because code outside this
+        # module may read them -- TODO confirm before renaming.
+        self.cnt = 0
+        self.teacache_thresh = 0
+        self.coefficients = []
+        self.accumulated_rel_l1_distance = 0
+        self.previous_modulated_input = None
+        self.previous_resiual = None
+        self.previous_e0_even = None
+        self.previous_e0_odd = None
+        self.previous_residual_even = None
+        self.previous_residual_odd = None
+        self.is_even = False
+        self.should_calc_even = True
+        self.should_calc_odd = True
+        self.accumulated_rel_l1_distance_even = 0
+        self.accumulated_rel_l1_distance_odd = 0
+        self.__post_init__()
+
+    def _init_freqs(self):
+        """Build ``self.freqs`` once; later calls return immediately.
+
+        Raises:
+            ValueError: If ``self.vae_type`` is not ``"dac"``.
+        """
+        if self.freqs is not None:
+            return
+        head_dim = self.dim // self.num_heads
+        if self.vae_type == "dac":
+            self.freqs = precompute_freqs_cis_1d(head_dim)
+        else:
+            raise ValueError(f"Invalid VAE type: {self.vae_type}")
+
+    def patchify(
+        self,
+        x: torch.Tensor,
+        control_camera_latents_input: Optional[torch.Tensor] = None,
+    ):
+        """Embed ``x`` with the 1D patch convolution and move channels last.
+
+        Args:
+            x: Input passed to ``patch_embedding``.
+            control_camera_latents_input: Accepted but not used by this method.
+
+        Returns:
+            ``(tokens, grid_size)``: ``tokens`` is the conv output rearranged
+            from ``b c f`` to ``b f c``; ``grid_size`` is the conv output's
+            shape after the first two axes, i.e. ``(f,)``.
+        """
+        x = self.patch_embedding(x)
+        grid_size = x.shape[2:]
+        x = rearrange(x, "b c f -> b f c").contiguous()
+        return x, grid_size  # x, grid_size: (f)
+
+    def unpatchify(self, x: torch.Tensor, grid_size: tuple[int]):
+        """Inverse layout change of ``patchify``: ``b f (p c) -> b c (f p)``.
+
+        ``f`` is taken from ``grid_size[0]`` and ``p`` from ``self.patch_size[0]``.
+        """
+        return rearrange(
+            x, "b f (p c) -> b c (f p)", f=grid_size[0], p=self.patch_size[0]
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor | list[torch.Tensor],
+        timestep: torch.LongTensor,
+    ) -> torch.Tensor:
+        """Run the audio DiT.
+
+        Args:
+            hidden_states: Audio latents fed to ``patchify``.
+            encoder_hidden_states: Text context; if a list is given only its
+                first element is used.
+            timestep: Diffusion timestep(s) fed to ``sinusoidal_embedding_1d``.
+
+        Returns:
+            The head output after ``unpatchify``.
+
+        Raises:
+            ValueError: If ``self.vae_type`` is unsupported, or if the patchified
+                sequence is longer than the precomputed RoPE table.
+        """
+        # __init__ leaves self.freqs as None; make sure the tables exist before
+        # they are indexed below. This is a no-op when they were already built.
+        self._init_freqs()
+
+        # MOVA audio uses x/context naming historically.
+        x = hidden_states
+        context = (
+            encoder_hidden_states[0]
+            if isinstance(encoder_hidden_states, list)
+            else encoder_hidden_states
+        )
+
+        t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
+        t_proj, _ = self.time_projection(t)
+        # Six modulation vectors per sample, consumed by every DiTBlock.
+        t_mod = t_proj.unflatten(1, (6, self.dim))
+        context = self.text_embedding(context)
+
+        x, (f,) = self.patchify(x)
+
+        max_positions = self.freqs[0].shape[0]
+        if f > max_positions:
+            raise ValueError(
+                f"Audio sequence length {f} exceeds the precomputed RoPE table "
+                f"length {max_positions}."
+            )
+        # Take the first f positions of each table and concatenate the three
+        # chunks back along the channel axis.
+        freqs = (
+            torch.cat([table[:f] for table in self.freqs], dim=-1)
+            .reshape(f, 1, -1)
+            .to(x.device)
+        )
+
+        for block in self.blocks:
+            x = block(x, context, t_mod, freqs)
+
+        # The head receives the un-projected time embedding `t`, not `t_mod`.
+        x = self.head(x, t)
+        x = self.unpatchify(x, (f,))
+        return x
+
+
+EntryClass = WanAudioModel
diff --git a/sglang/python/sglang/multimodal_gen/runtime/models/dits/mova_video_dit.py b/sglang/python/sglang/multimodal_gen/runtime/models/dits/mova_video_dit.py
new file mode 100644
index 0000000000000000000000000000000000000000..21de5b9b37bb7601a698e14907894f6476a8e264
--- /dev/null
+++ b/sglang/python/sglang/multimodal_gen/runtime/models/dits/mova_video_dit.py
@@ -0,0 +1,584 @@
+# Copied and adapted from: mossVG/mova/diffusion/models/wan_video_dit.py
+# SPDX-License-Identifier: Apache-2.0
+#
+# NOTE: This module shares common functions (sinusoidal_embedding_1d, precompute_freqs_cis, etc.)
+# with wanvideo.py. These functions are kept here for MOVA-specific model architecture,
+# but could be refactored to a common module in the future.
+ +import math +from typing import Any, Tuple + +import torch +import torch.nn as nn +from einops import rearrange +from torch.distributed.tensor import DTensor + +from sglang.multimodal_gen.configs.models.dits.mova_video import MOVAVideoConfig +from sglang.multimodal_gen.runtime.distributed import get_tp_world_size +from sglang.multimodal_gen.runtime.layers.attention import LocalAttention, USPAttention + +# Reuse SGLang's optimized RMSNorm instead of torch.nn.RMSNorm or custom SlowRMSNorm +from sglang.multimodal_gen.runtime.layers.layernorm import ( + LayerNormScaleShift, + RMSNorm, + ScaleResidualLayerNormScaleShift, + tensor_parallel_rms_norm, +) +from sglang.multimodal_gen.runtime.layers.linear import ( + ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from sglang.multimodal_gen.runtime.layers.mlp import MLP +from sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import ( + QuantizationConfig, +) +from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT +from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +# @torch.compile(fullgraph=True) +def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor): + return x * (1 + scale) + shift + + +def sinusoidal_embedding_1d(dim, position): + sinusoid = torch.outer( + position.type(torch.float64), + torch.pow( + 10000, + -torch.arange(dim // 2, dtype=torch.float64, device=position.device).div( + dim // 2 + ), + ), + ) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x.to(position.dtype) + + +def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0): + # 3d rope precompute + f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta) + h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta) + w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta) + return 
f_freqs_cis, h_freqs_cis, w_freqs_cis + + +def precompute_freqs_cis( + dim: int, end: int = 1024, theta: float = 10000.0, s: float = 1.0 +): + # 1d rope precompute + # Note: s parameter is used for audio-specific scaling (e.g., tps adjustment) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].double() / dim)) + pos = torch.arange(end, dtype=torch.float64, device=freqs.device) * s + freqs = torch.outer(pos, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def rope_apply(x, freqs, num_heads): + x = rearrange(x, "b s (n d) -> b s n d", n=num_heads) + x_out = torch.view_as_complex( + x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2) + ) + x_out = torch.view_as_real(x_out * freqs).flatten(2) + return x_out.to(x.dtype) + + +def rope_apply_head_dim(x, freqs, head_dim): + x = rearrange(x, "b s (n d) -> b s n d", d=head_dim) + x_out = torch.view_as_complex( + x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2) + ) + # print(f"{x_out.shape = }, {freqs.shape = }") + x_out = torch.view_as_real(x_out * freqs).flatten(2) + return x_out.to(x.dtype) + + +class SelfAttention(nn.Module): + """ + Self-Attention module for MOVA DiT with Sequence Parallelism support. + + SP is handled at the pipeline level (latents are pre-sharded before DiT forward). + USPAttention internally handles the all-to-all communication for distributed attention. + Input x should already be the local shard [B, S_local, D] when SP is enabled. + """ + + def __init__( + self, + dim: int, + num_heads: int, + eps: float = 1e-6, + quant_config: QuantizationConfig | None = None, + ): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.tp_size = get_tp_world_size() + if self.num_heads % self.tp_size != 0: + raise ValueError( + f"num_heads ({self.num_heads}) must be divisible by tp_size ({self.tp_size})." 
+ ) + self.num_heads_per_rank = self.num_heads // self.tp_size + + # TP strategy: shard Q/K/V over heads (column-parallel), then row-parallel output. + self.q = ColumnParallelLinear( + dim, dim, bias=True, gather_output=False, quant_config=quant_config + ) + self.k = ColumnParallelLinear( + dim, dim, bias=True, gather_output=False, quant_config=quant_config + ) + self.v = ColumnParallelLinear( + dim, dim, bias=True, gather_output=False, quant_config=quant_config + ) + self.o = RowParallelLinear( + dim, dim, bias=True, input_is_parallel=True, quant_config=quant_config + ) + self.norm_q = RMSNorm(dim, eps=eps) + self.norm_k = RMSNorm(dim, eps=eps) + + self.attn = USPAttention( + # Local heads per TP rank. + num_heads=self.num_heads_per_rank, + head_size=self.head_dim, + causal=False, + softmax_scale=None, + ) + + def forward(self, x, freqs): + """ + Forward pass for self-attention. + + Args: + x: Input tensor [B, S_local, D] - already sharded by SP when SP > 1 + freqs: RoPE frequencies [S_local, 1, head_dim] - should match x's sequence length + + Returns: + Output tensor [B, S_local, D] + """ + if isinstance(freqs, DTensor): + freqs = freqs.to_local() + + # Compute Q, K, V on local sequence + q, _ = self.q(x) + k, _ = self.k(x) + v, _ = self.v(x) + + # RMSNorm over sharded hidden dimension. 
+ if self.tp_size > 1: + q = tensor_parallel_rms_norm(q, self.norm_q) + k = tensor_parallel_rms_norm(k, self.norm_k) + else: + q = self.norm_q(q) + k = self.norm_k(k) + + # Apply RoPE + q = rope_apply_head_dim(q, freqs, self.head_dim) + k = rope_apply_head_dim(k, freqs, self.head_dim) + + # USPAttention expects [B, S_local, H, D] format + q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads_per_rank) + k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads_per_rank) + v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads_per_rank) + + # USPAttention handles SP communication internally + out = self.attn(q, k, v) + out = rearrange(out, "b s n d -> b s (n d)") + + out, _ = self.o(out) + return out + + +class CrossAttention(nn.Module): + """ + Cross-Attention module for MOVA DiT. + + Cross-attention does NOT require SP communication because: + - Query comes from the main sequence (already sharded by SP) + - Key/Value come from context (text embeddings, which are replicated across all ranks) + + Uses LocalAttention instead of USPAttention for efficiency. + """ + + def __init__( + self, + dim: int, + num_heads: int, + eps: float = 1e-6, + quant_config: QuantizationConfig | None = None, + ): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.tp_size = get_tp_world_size() + if self.num_heads % self.tp_size != 0: + raise ValueError( + f"num_heads ({self.num_heads}) must be divisible by tp_size ({self.tp_size})." 
+ ) + self.num_heads_per_rank = self.num_heads // self.tp_size + + self.q = ColumnParallelLinear( + dim, dim, bias=True, gather_output=False, quant_config=quant_config + ) + self.k = ColumnParallelLinear( + dim, dim, bias=True, gather_output=False, quant_config=quant_config + ) + self.v = ColumnParallelLinear( + dim, dim, bias=True, gather_output=False, quant_config=quant_config + ) + self.o = RowParallelLinear( + dim, dim, bias=True, input_is_parallel=True, quant_config=quant_config + ) + self.norm_q = RMSNorm(dim, eps=eps) + self.norm_k = RMSNorm(dim, eps=eps) + + # Use LocalAttention for cross-attention (no SP communication needed) + self.attn = LocalAttention( + num_heads=self.num_heads_per_rank, + head_size=self.head_dim, + causal=False, + softmax_scale=None, + ) + + def forward(self, x: torch.Tensor, y: torch.Tensor): + """ + Forward pass for cross-attention. + + Args: + x: Query tensor [B, S_local, D] - the main sequence (sharded by SP) + y: Context tensor [B, S_ctx, D] - text/image embeddings (replicated) + + Returns: + Output tensor [B, S_local, D] + """ + ctx = y + + q, _ = self.q(x) + k, _ = self.k(ctx) + v, _ = self.v(ctx) + + if self.tp_size > 1: + q = tensor_parallel_rms_norm(q, self.norm_q) + k = tensor_parallel_rms_norm(k, self.norm_k) + else: + q = self.norm_q(q) + k = self.norm_k(k) + + q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads_per_rank) + k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads_per_rank) + v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads_per_rank) + x = self.attn(q, k, v) + x = rearrange(x, "b s n d -> b s (n d)") + x, _ = self.o(x) + return x + + +class MulAdd(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, gate, residual): + return residual + gate * x + + +class DiTBlock(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + ffn_dim: int, + eps: float = 1e-6, + quant_config: QuantizationConfig | None = None, + ): + super().__init__() + self.dim = dim + 
self.num_heads = num_heads + self.ffn_dim = ffn_dim + + self.self_attn = SelfAttention(dim, num_heads, eps, quant_config=quant_config) + self.cross_attn = CrossAttention(dim, num_heads, eps, quant_config=quant_config) + self.norm1 = LayerNormScaleShift( + dim, eps=eps, elementwise_affine=False, dtype=torch.float32 + ) + self.self_attn_norm = nn.LayerNorm(dim, eps=eps) + # Fused: residual + 1 * cross_attn_out → layernorm + scale/shift + # Replaces the old norm2 (LayerNormScaleShift) + residual add for cross-attention + self.cross_attn_residual_norm = ScaleResidualLayerNormScaleShift( + dim, eps=eps, elementwise_affine=False, dtype=torch.float32 + ) + self.ffn = MLP( + dim, + ffn_dim, + output_dim=dim, + act_type="gelu_pytorch_tanh", + quant_config=quant_config, + ) + self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + self.mlp_residual = MulAdd() + + def forward(self, x, context, t_mod, freqs): + has_seq = len(t_mod.shape) == 4 + chunk_dim = 2 if has_seq else 1 + # msa: multi-head self-attention mlp: multi-layer perceptron + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod + ).chunk(6, dim=chunk_dim) + if has_seq: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + shift_msa.squeeze(2), + scale_msa.squeeze(2), + gate_msa.squeeze(2), + shift_mlp.squeeze(2), + scale_mlp.squeeze(2), + gate_mlp.squeeze(2), + ) + orig_dtype = x.dtype + # 1. Self-attention, fuse: + # - layernorm(x) * (1 + scale_msa) + shift_msa + input_x = self.norm1(x, shift_msa, scale_msa) + # 2. torch.compile may fuse mlp_residual and self_attn_norm + x = self.mlp_residual(self.self_attn(input_x, freqs), gate_msa, x) + norm_x = self.self_attn_norm(x) + # 3. 
Cross-attention, fuse: + # - x = x + 1 * cross_output + # - input_x = layernorm(x) * (1 + scale_mlp) + shift_mlp + cross_output = self.cross_attn(norm_x, context) + input_x, x = self.cross_attn_residual_norm( + x, cross_output, 1, shift_mlp, scale_mlp + ) + # 4. Feed-forward + x = self.mlp_residual(self.ffn(input_x), gate_mlp, x) + x = x.to(orig_dtype) + return x + + +class Head(nn.Module): + def __init__( + self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float + ): + super().__init__() + self.dim = dim + self.patch_size = patch_size + self.norm = LayerNormScaleShift( + dim, eps=eps, elementwise_affine=False, dtype=torch.float32 + ) + # Output dim is small for MOVA; replicate to avoid TP shape coupling. + self.head = ReplicatedLinear(dim, out_dim * math.prod(patch_size)) + self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) + + def forward(self, x, t_mod): + if len(t_mod.shape) == 3: + shift, scale = ( + self.modulation.unsqueeze(0).to(dtype=t_mod.dtype, device=t_mod.device) + + t_mod.unsqueeze(2) + ).chunk(2, dim=2) + x, _ = self.head(self.norm(x, shift.squeeze(2), scale.squeeze(2))) + else: + shift, scale = ( + self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod + ).chunk(2, dim=1) + x, _ = self.head(self.norm(x, shift, scale)) + return x + + +class Conv3dLocalIsland(nn.Conv3d): + """ + Inherits from Conv3d and overrides the forward method. + + Key behaviors: + - Parameters are kept as DTensor to maintain optimizer consistency. + - The forward pass aggregates input, weight, and bias into a Replicate state, + then performs the convolution locally using to_local(). + - The output is then redistributed as a DTensor (defaults to Replicate, + but placements can be customized). 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + if isinstance(input, DTensor): + # NOTE: DTensor typing stubs are incomplete; at runtime DTensor has + # to_local() and parameters may also be DTensor. + x_local = input.to_local() # type: ignore[attr-defined] + w_local = self.weight.to_local() # type: ignore[attr-defined] + b_local = ( + self.bias.to_local() if self.bias is not None else None # type: ignore[attr-defined] + ) + + return self._conv_forward(x_local, w_local, b_local) + else: + return super().forward(input) + + +class WanModel(CachableDiT, OffloadableDiTMixin): + _fsdp_shard_conditions = MOVAVideoConfig()._fsdp_shard_conditions + _compile_conditions = MOVAVideoConfig()._compile_conditions + _supported_attention_backends = MOVAVideoConfig()._supported_attention_backends + param_names_mapping = MOVAVideoConfig().param_names_mapping + reverse_param_names_mapping = MOVAVideoConfig().reverse_param_names_mapping + lora_param_names_mapping = MOVAVideoConfig().lora_param_names_mapping + + def __init__( + self, + config: MOVAVideoConfig, + hf_config: dict[str, Any], + quant_config: QuantizationConfig | None = None, + ) -> None: + super().__init__(config=config, hf_config=hf_config) + + # Extract parameters from config + dim = config.dim + in_dim = config.in_dim + ffn_dim = config.ffn_dim + out_dim = config.out_dim + text_dim = config.text_dim + freq_dim = config.freq_dim + eps = config.eps + patch_size = config.patch_size + num_heads = config.num_heads + num_layers = config.num_layers + has_image_pos_emb = config.has_image_pos_emb + has_ref_conv = config.has_ref_conv + separated_timestep = config.separated_timestep + require_vae_embedding = config.require_vae_embedding + require_clip_embedding = config.require_clip_embedding + fuse_vae_embedding_in_latents = config.fuse_vae_embedding_in_latents + + self.dim = dim + self.freq_dim = freq_dim + self.patch_size = patch_size + self.separated_timestep = 
separated_timestep + self.require_vae_embedding = require_vae_embedding + self.require_clip_embedding = require_clip_embedding + self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents + + self.patch_embedding = Conv3dLocalIsland( + in_dim, dim, kernel_size=patch_size, stride=patch_size + ) + self.text_embedding = MLP( + text_dim, + dim, + output_dim=dim, + act_type="gelu_pytorch_tanh", + quant_config=quant_config, + ) + self.time_embedding = MLP( + freq_dim, dim, output_dim=dim, act_type="silu", quant_config=quant_config + ) + # Preserve state_dict keys (time_projection.1.weight/bias). + self.time_projection = nn.Sequential( + nn.SiLU(), ReplicatedLinear(dim, dim * 6, quant_config=quant_config) + ) + self.blocks = nn.ModuleList( + [ + DiTBlock(dim, num_heads, ffn_dim, eps, quant_config=quant_config) + for _ in range(num_layers) + ] + ) + self.head = Head(dim, out_dim, patch_size, eps) + self.num_heads = num_heads + self.freqs = None + + if has_ref_conv: + self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2)) + self.has_image_pos_emb = has_image_pos_emb + self.has_ref_conv = has_ref_conv + self.hidden_size = dim + self.num_attention_heads = num_heads + self.num_channels_latents = out_dim + self.layer_names = ["blocks"] + self.cnt = 0 + self.teacache_thresh = 0 + self.coefficients = [] + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = None + self.previous_resiual = None + self.previous_e0_even = None + self.previous_e0_odd = None + self.previous_residual_even = None + self.previous_residual_odd = None + self.is_even = False + self.should_calc_even = True + self.should_calc_odd = True + self.accumulated_rel_l1_distance_even = 0 + self.accumulated_rel_l1_distance_odd = 0 + self.__post_init__() + + def _init_freqs(self): + if self.freqs is not None: + return + head_dim = self.dim // self.num_heads + self.freqs = precompute_freqs_cis_3d(head_dim) + + def patchify( + self, x: torch.Tensor, control_camera_latents_input: 
torch.Tensor | None = None + ): + # NOTE(dhyu): avoid slow_conv + x = x.contiguous(memory_format=torch.channels_last_3d) + x = self.patch_embedding(x) + grid_size = x.shape[2:] + x = rearrange(x, "b c f h w -> b (f h w) c").contiguous() + return x, grid_size # x, grid_size: (f, h, w) + + def unpatchify(self, x: torch.Tensor, grid_size: tuple[int, int, int]): + return rearrange( + x, + "b (f h w) (x y z c) -> b c (f x) (h y) (w z)", + f=grid_size[0], + h=grid_size[1], + w=grid_size[2], + x=self.patch_size[0], + y=self.patch_size[1], + z=self.patch_size[2], + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | list[torch.Tensor], + timestep: torch.LongTensor, + ) -> torch.Tensor: + # MOVA code historically uses x/context/y/clip_feature naming. + x = hidden_states + context = ( + encoder_hidden_states[0] + if isinstance(encoder_hidden_states, list) + else encoder_hidden_states + ) + t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep)) + t_proj, _ = self.time_projection(t) + t_mod = t_proj.unflatten(1, (6, self.dim)) + context = self.text_embedding(context) + + x, (f, h, w) = self.patchify(x) + + freqs = ( + torch.cat( + [ + self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1), + ], + dim=-1, + ) + .reshape(f * h * w, 1, -1) + .to(x.device) + ) + + for block in self.blocks: + x = block(x, context, t_mod, freqs) + + x = self.head(x, t) + x = self.unpatchify(x, (f, h, w)) + return x + + +EntryClass = WanModel diff --git a/sglang/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py b/sglang/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py new file mode 100644 index 0000000000000000000000000000000000000000..37afc24588a11b95239cbaa72e827e18a88fa981 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py @@ -0,0 +1,1211 @@ +# Copied 
and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +import functools +from typing import Any, Dict, List, Optional, Tuple, Union + +import diffusers +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.models.attention import FeedForward +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.normalization import AdaLayerNormContinuous + +from sglang.jit_kernel.diffusion.triton.scale_shift import ( + fuse_scale_shift_gate_select01_kernel, +) +from sglang.multimodal_gen.configs.models.dits.qwenimage import QwenImageDitConfig +from sglang.multimodal_gen.runtime.distributed import get_local_torch_device +from sglang.multimodal_gen.runtime.layers.attention import USPAttention +from sglang.multimodal_gen.runtime.layers.elementwise import MulAdd +from sglang.multimodal_gen.runtime.layers.layernorm import ( + LayerNormScaleShift, + RMSNorm, + ScaleResidualLayerNormScaleShift, + apply_layernorm_only, + apply_qk_norm, +) +from sglang.multimodal_gen.runtime.layers.linear import ( + MergedColumnParallelLinear, + ReplicatedLinear, +) +from sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import ( + QuantizationConfig, +) +from sglang.multimodal_gen.runtime.layers.quantization.configs.nunchaku_config import ( + NunchakuConfig, + is_nunchaku_available, +) +from sglang.multimodal_gen.runtime.layers.rotary_embedding import ( + apply_flashinfer_rope_qk_inplace, +) +from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT +from sglang.multimodal_gen.runtime.platforms import ( + AttentionBackendEnum, + current_platform, +) +from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) # pylint: 
disable=invalid-name +_is_cuda = current_platform.is_cuda() + +try: + from nunchaku.models.attention import NunchakuFeedForward # type: ignore[import] +except Exception: + NunchakuFeedForward = None + + +def _get_qkv_projections( + attn: "QwenImageCrossAttention", hidden_states, encoder_hidden_states=None +): + if attn.use_fused_qkv: + img_qkv, _ = attn.to_qkv(hidden_states) + img_query, img_key, img_value = [ + x.contiguous() for x in img_qkv.chunk(3, dim=-1) + ] + else: + img_query, _ = attn.to_q(hidden_states) + img_key, _ = attn.to_k(hidden_states) + img_value, _ = attn.to_v(hidden_states) + + txt_query = txt_key = txt_value = None + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + if attn.use_fused_added_qkv: + txt_qkv, _ = attn.to_added_qkv(encoder_hidden_states) + txt_query, txt_key, txt_value = [ + x.contiguous() for x in txt_qkv.chunk(3, dim=-1) + ] + else: + txt_query, _ = attn.add_q_proj(encoder_hidden_states) + txt_key, _ = attn.add_k_proj(encoder_hidden_states) + txt_value, _ = attn.add_v_proj(encoder_hidden_states) + + return img_query, img_key, img_value, txt_query, txt_key, txt_value + + +class QwenTimestepProjEmbeddings(nn.Module): + def __init__(self, embedding_dim, use_additional_t_cond=False): + super().__init__() + + self.time_proj = Timesteps( + num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000 + ) + self.timestep_embedder = TimestepEmbedding( + in_channels=256, time_embed_dim=embedding_dim + ) + self.use_additional_t_cond = use_additional_t_cond + if use_additional_t_cond: + self.addition_t_embedding = nn.Embedding(2, embedding_dim) + + def forward(self, timestep, hidden_states, addition_t_cond=None): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder( + timesteps_proj.to(dtype=hidden_states.dtype) + ) # (N, D) + + conditioning = timesteps_emb + if self.use_additional_t_cond: + if addition_t_cond is None: + raise ValueError( + "When additional_t_cond is 
True, addition_t_cond must be provided." + ) + addition_t_emb = self.addition_t_embedding(addition_t_cond) + addition_t_emb = addition_t_emb.to(dtype=hidden_states.dtype) + conditioning = conditioning + addition_t_emb + + return conditioning + + +class QwenEmbedRope(nn.Module): + def __init__(self, theta: int, axes_dim: List[int], scale_rope=False): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + pos_index = torch.arange(4096) + neg_index = torch.arange(4096).flip(0) * -1 - 1 + self.pos_freqs = torch.cat( + [ + self.rope_params(pos_index, self.axes_dim[0], self.theta), + self.rope_params(pos_index, self.axes_dim[1], self.theta), + self.rope_params(pos_index, self.axes_dim[2], self.theta), + ], + dim=1, + ) + self.neg_freqs = torch.cat( + [ + self.rope_params(neg_index, self.axes_dim[0], self.theta), + self.rope_params(neg_index, self.axes_dim[1], self.theta), + self.rope_params(neg_index, self.axes_dim[2], self.theta), + ], + dim=1, + ) + + # DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART + self.scale_rope = scale_rope + + def rope_params(self, index, dim, theta=10000): + """ + Args: + index: [0, 1, 2, 3] 1D Tensor representing the position index of the token + """ + device = index.device + assert dim % 2 == 0 + freqs = torch.outer( + index, + ( + 1.0 + / torch.pow( + theta, + torch.arange(0, dim, 2, device=device).to(torch.float32).div(dim), + ) + ).to(device=device), + ) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + def forward( + self, + video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]], + txt_seq_lens: List[int], + device: torch.device, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`): + A list of 3 integers [frame, height, width] representing the shape of the video. 
+ txt_seq_lens (`List[int]`): + A list of integers of length batch_size representing the length of each text prompt. + device: (`torch.device`): + The device on which to perform the RoPE computation. + """ + # When models are initialized under a "meta" device context (e.g. init_empty_weights), + # tensors created during __init__ become meta tensors. Calling .to(...) on a meta tensor + # raises "Cannot copy out of meta tensor". Rebuild the frequencies on the target device + # in that case; otherwise move them if just on a different device. + if getattr(self.pos_freqs, "device", torch.device("meta")).type == "meta": + pos_index = torch.arange(4096, device=device) + neg_index = torch.arange(4096, device=device).flip(0) * -1 - 1 + self.pos_freqs = torch.cat( + [ + self.rope_params(pos_index, self.axes_dim[0], self.theta), + self.rope_params(pos_index, self.axes_dim[1], self.theta), + self.rope_params(pos_index, self.axes_dim[2], self.theta), + ], + dim=1, + ).to(device=device) + self.neg_freqs = torch.cat( + [ + self.rope_params(neg_index, self.axes_dim[0], self.theta), + self.rope_params(neg_index, self.axes_dim[1], self.theta), + self.rope_params(neg_index, self.axes_dim[2], self.theta), + ], + dim=1, + ).to(device=device) + elif self.pos_freqs.device != device: + self.pos_freqs = self.pos_freqs.to(device) + self.neg_freqs = self.neg_freqs.to(device) + + if isinstance(video_fhw, list): + video_fhw = video_fhw[0] + if not isinstance(video_fhw, list): + video_fhw = [video_fhw] + + vid_freqs = [] + max_vid_index = 0 + for idx, fhw in enumerate(video_fhw): + frame, height, width = fhw + # RoPE frequencies are cached via a lru_cache decorator on _compute_video_freqs + video_freq = self._compute_video_freqs(frame, height, width, idx) + video_freq = video_freq.to(device) + vid_freqs.append(video_freq) + + if self.scale_rope: + max_vid_index = max(height // 2, width // 2, max_vid_index) + else: + max_vid_index = max(height, width, max_vid_index) + + max_len = 
max(txt_seq_lens) + txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] + vid_freqs = torch.cat(vid_freqs, dim=0).to(device=device) + return vid_freqs, txt_freqs + + @functools.lru_cache(maxsize=128) + def _compute_video_freqs( + self, frame: int, height: int, width: int, idx: int = 0 + ) -> torch.Tensor: + seq_lens = frame * height * width + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + + freqs_frame = ( + freqs_pos[0][idx : idx + frame] + .view(frame, 1, 1, -1) + .expand(frame, height, width, -1) + ) + if self.scale_rope: + freqs_height = torch.cat( + [freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], + dim=0, + ) + freqs_height = freqs_height.view(1, height, 1, -1).expand( + frame, height, width, -1 + ) + freqs_width = torch.cat( + [freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], + dim=0, + ) + freqs_width = freqs_width.view(1, 1, width, -1).expand( + frame, height, width, -1 + ) + else: + freqs_height = ( + freqs_pos[1][:height] + .view(1, height, 1, -1) + .expand(frame, height, width, -1) + ) + freqs_width = ( + freqs_pos[2][:width] + .view(1, 1, width, -1) + .expand(frame, height, width, -1) + ) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape( + seq_lens, -1 + ) + return freqs.clone().contiguous() + + +class QwenEmbedLayer3DRope(nn.Module): + def __init__(self, theta: int, axes_dim: List[int], scale_rope=False): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + pos_index = torch.arange(4096) + neg_index = torch.arange(4096).flip(0) * -1 - 1 + self.pos_freqs = torch.cat( + [ + self.rope_params(pos_index, self.axes_dim[0], self.theta), + self.rope_params(pos_index, self.axes_dim[1], self.theta), + self.rope_params(pos_index, self.axes_dim[2], self.theta), + ], + dim=1, + ) + self.neg_freqs = torch.cat( + [ + self.rope_params(neg_index, 
self.axes_dim[0], self.theta), + self.rope_params(neg_index, self.axes_dim[1], self.theta), + self.rope_params(neg_index, self.axes_dim[2], self.theta), + ], + dim=1, + ) + + self.scale_rope = scale_rope + + def rope_params(self, index, dim, theta=10000): + """ + Args: + index: [0, 1, 2, 3] 1D Tensor representing the position index of the token + """ + device = index.device + assert dim % 2 == 0 + freqs = torch.outer( + index, + ( + 1.0 + / torch.pow( + theta, + torch.arange(0, dim, 2, device=device).to(torch.float32).div(dim), + ) + ).to(device=device), + ) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + def forward(self, video_fhw, txt_seq_lens, device): + """ + Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args: + txt_length: [bs] a list of 1 integers representing the length of the text + """ + + # When models are initialized under a "meta" device context (e.g. init_empty_weights), + # tensors created during __init__ become meta tensors. Calling .to(...) on a meta tensor + # raises "Cannot copy out of meta tensor". Rebuild the frequencies on the target device + # in that case; otherwise move them if just on a different device. 
+ if getattr(self.pos_freqs, "device", torch.device("meta")).type == "meta": + pos_index = torch.arange(4096, device=device) + neg_index = torch.arange(4096, device=device).flip(0) * -1 - 1 + self.pos_freqs = torch.cat( + [ + self.rope_params(pos_index, self.axes_dim[0], self.theta), + self.rope_params(pos_index, self.axes_dim[1], self.theta), + self.rope_params(pos_index, self.axes_dim[2], self.theta), + ], + dim=1, + ).to(device=device) + self.neg_freqs = torch.cat( + [ + self.rope_params(neg_index, self.axes_dim[0], self.theta), + self.rope_params(neg_index, self.axes_dim[1], self.theta), + self.rope_params(neg_index, self.axes_dim[2], self.theta), + ], + dim=1, + ).to(device=device) + elif self.pos_freqs.device != device: + self.pos_freqs = self.pos_freqs.to(device) + self.neg_freqs = self.neg_freqs.to(device) + + if isinstance(video_fhw, list): + video_fhw = video_fhw[0] + if not isinstance(video_fhw, list): + video_fhw = [video_fhw] + + vid_freqs = [] + max_vid_index = 0 + layer_num = len(video_fhw) - 1 + for idx, fhw in enumerate(video_fhw): + frame, height, width = fhw + if idx != layer_num: + video_freq = self._compute_video_freqs(frame, height, width, idx) + else: + # For the condition image, we set the layer index to -1 + video_freq = self._compute_condition_freqs(frame, height, width) + video_freq = video_freq.to(device) + vid_freqs.append(video_freq) + + if self.scale_rope: + max_vid_index = max(height // 2, width // 2, max_vid_index) + else: + max_vid_index = max(height, width, max_vid_index) + + max_vid_index = max(max_vid_index, layer_num) + max_len = max(txt_seq_lens) + txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] 
+ vid_freqs = torch.cat(vid_freqs, dim=0) + + return vid_freqs, txt_freqs + + @functools.lru_cache(maxsize=None) + def _compute_video_freqs(self, frame, height, width, idx=0): + seq_lens = frame * height * width + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + + freqs_frame = ( + freqs_pos[0][idx : idx + frame] + .view(frame, 1, 1, -1) + .expand(frame, height, width, -1) + ) + if self.scale_rope: + freqs_height = torch.cat( + [freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], + dim=0, + ) + freqs_height = freqs_height.view(1, height, 1, -1).expand( + frame, height, width, -1 + ) + freqs_width = torch.cat( + [freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], + dim=0, + ) + freqs_width = freqs_width.view(1, 1, width, -1).expand( + frame, height, width, -1 + ) + else: + freqs_height = ( + freqs_pos[1][:height] + .view(1, height, 1, -1) + .expand(frame, height, width, -1) + ) + freqs_width = ( + freqs_pos[2][:width] + .view(1, 1, width, -1) + .expand(frame, height, width, -1) + ) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape( + seq_lens, -1 + ) + return freqs.clone().contiguous() + + @functools.lru_cache(maxsize=None) + def _compute_condition_freqs(self, frame, height, width): + seq_lens = frame * height * width + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + + freqs_frame = ( + freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1) + ) + if self.scale_rope: + freqs_height = torch.cat( + [freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], + dim=0, + ) + freqs_height = freqs_height.view(1, height, 1, -1).expand( + frame, height, width, -1 + ) + freqs_width = torch.cat( + [freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], + dim=0, + ) + 
freqs_width = freqs_width.view(1, 1, width, -1).expand( + frame, height, width, -1 + ) + else: + freqs_height = ( + freqs_pos[1][:height] + .view(1, height, 1, -1) + .expand(frame, height, width, -1) + ) + freqs_width = ( + freqs_pos[2][:width] + .view(1, 1, width, -1) + .expand(frame, height, width, -1) + ) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape( + seq_lens, -1 + ) + return freqs.clone().contiguous() + + +class QwenImageCrossAttention(nn.Module): + def __init__( + self, + dim: int, # query_dim + num_heads: int, + head_dim: int, + window_size=(-1, -1), + added_kv_proj_dim: int = None, + out_bias: bool = True, + qk_norm=True, # rmsnorm + eps=1e-6, + pre_only=False, + context_pre_only: bool = False, + parallel_attention=False, + out_dim: int = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.eps = eps + self.parallel_attention = parallel_attention + self.added_kv_proj_dim = added_kv_proj_dim + self.prefix = prefix + + self.use_fused_qkv = isinstance(quant_config, NunchakuConfig) + + self.inner_dim = out_dim if out_dim is not None else head_dim * num_heads + self.inner_kv_dim = self.inner_dim + + if self.use_fused_qkv: + # Use fused QKV projection for nunchaku quantization + self.to_qkv = MergedColumnParallelLinear( + dim, + [self.inner_dim] * 3, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.to_qkv", + ) + else: + # Use separate Q/K/V projections for non-quantized models + self.to_q = ReplicatedLinear(dim, self.inner_dim, bias=True) + self.to_k = ReplicatedLinear(dim, self.inner_dim, bias=True) + self.to_v = ReplicatedLinear(dim, self.inner_dim, bias=True) + + if self.qk_norm: + self.norm_q = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() + self.norm_k = 
RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() + + if added_kv_proj_dim is not None: + self.use_fused_added_qkv = isinstance(quant_config, NunchakuConfig) + if self.use_fused_added_qkv: + self.to_added_qkv = MergedColumnParallelLinear( + added_kv_proj_dim, + [self.inner_dim] * 3, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.to_added_qkv", + ) + else: + # Use separate Q/K/V projections for non-quantized models + self.add_q_proj = ReplicatedLinear( + added_kv_proj_dim, self.inner_dim, bias=True + ) + self.add_k_proj = ReplicatedLinear( + added_kv_proj_dim, self.inner_dim, bias=True + ) + self.add_v_proj = ReplicatedLinear( + added_kv_proj_dim, self.inner_dim, bias=True + ) + + if context_pre_only is not None and not context_pre_only: + self.to_add_out = ReplicatedLinear(self.inner_dim, self.dim, bias=out_bias) + else: + self.to_add_out = None + + if not pre_only: + self.to_out = nn.ModuleList([]) + self.to_out.append( + ReplicatedLinear(self.inner_dim, self.dim, bias=out_bias) + ) + else: + self.to_out = None + + self.norm_added_q = RMSNorm(head_dim, eps=eps) + self.norm_added_k = RMSNorm(head_dim, eps=eps) + + # Scaled dot product attention + self.attn = USPAttention( + num_heads=num_heads, + head_size=self.head_dim, + dropout_rate=0, + softmax_scale=None, + causal=False, + supported_attention_backends={ + AttentionBackendEnum.FA, + AttentionBackendEnum.AITER, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.SAGE_ATTN, + AttentionBackendEnum.SAGE_ATTN_3, + }, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor], + **cross_attention_kwargs, + ): + seq_len_txt = encoder_hidden_states.shape[1] + + img_query, img_key, img_value, txt_query, txt_key, txt_value = ( + _get_qkv_projections(self, hidden_states, encoder_hidden_states) + ) + + # Reshape for multi-head attention + img_query = img_query.unflatten(-1, (self.num_heads, -1)) + 
img_key = img_key.unflatten(-1, (self.num_heads, -1)) + img_value = img_value.unflatten(-1, (self.num_heads, -1)) + + txt_query = txt_query.unflatten(-1, (self.num_heads, -1)) + txt_key = txt_key.unflatten(-1, (self.num_heads, -1)) + txt_value = txt_value.unflatten(-1, (self.num_heads, -1)) + + # Apply QK normalization + if self.qk_norm: + img_query, img_key = apply_qk_norm( + q=img_query, + k=img_key, + q_norm=self.norm_q, + k_norm=self.norm_k, + head_dim=img_query.shape[-1], + allow_inplace=True, + ) + txt_query, txt_key = apply_qk_norm( + q=txt_query, + k=txt_key, + q_norm=self.norm_added_q, + k_norm=self.norm_added_k, + head_dim=txt_query.shape[-1], + allow_inplace=True, + ) + + # Apply RoPE + if image_rotary_emb is not None: + if not ( + isinstance(image_rotary_emb[0], torch.Tensor) + and image_rotary_emb[0].dim() == 2 + ): + raise RuntimeError("image_rotary_emb must be cos_sin_cache tensors") + + img_cache, txt_cache = image_rotary_emb + + img_query, img_key = apply_flashinfer_rope_qk_inplace( + img_query, img_key, img_cache, is_neox=False + ) + txt_query, txt_key = apply_flashinfer_rope_qk_inplace( + txt_query, txt_key, txt_cache, is_neox=False + ) + + # Concatenate for joint attention + # Order: [text, image] + joint_query = torch.cat([txt_query, img_query], dim=1) + joint_key = torch.cat([txt_key, img_key], dim=1) + joint_value = torch.cat([txt_value, img_value], dim=1) + + # Compute joint attention + joint_hidden_states = self.attn( + joint_query, + joint_key, + joint_value, + ) + + # Reshape back + joint_hidden_states = joint_hidden_states.flatten(2, 3) + joint_hidden_states = joint_hidden_states.to(joint_query.dtype) + + # Split attention outputs back + txt_attn_output = joint_hidden_states[:, :seq_len_txt, :] # Text part + img_attn_output = joint_hidden_states[:, seq_len_txt:, :] # Image part + + # Apply output projections + img_attn_output, _ = self.to_out[0](img_attn_output) + if len(self.to_out) > 1: + (img_attn_output,) = 
self.to_out[1](img_attn_output) # dropout + + txt_attn_output, _ = self.to_add_out(txt_attn_output) + + return img_attn_output, txt_attn_output + + +class QwenImageTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + qk_norm: str = "rms_norm", + eps: float = 1e-6, + quant_config: Optional[QuantizationConfig] | NunchakuConfig = None, + prefix: str = "", + zero_cond_t: bool = False, + ): + super().__init__() + self.prefix = prefix + + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.quant_config = quant_config + self.zero_cond_t = zero_cond_t + + mod_quant_config = ( + quant_config + if (quant_config is not None and quant_config.get_name() == "svdquant") + else None + ) + # Image processing modules + self.img_mod = nn.Sequential( + nn.SiLU(), + nn.Linear( + dim, 6 * dim, bias=True + ), # For scale, shift, gate for norm1 and norm2 + ) + self.img_norm1 = LayerNormScaleShift( + hidden_size=dim, eps=eps, elementwise_affine=False + ) + + self.attn = QwenImageCrossAttention( + dim=dim, + num_heads=num_attention_heads, + added_kv_proj_dim=dim, + context_pre_only=False, + head_dim=attention_head_dim, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + self.img_norm2 = ScaleResidualLayerNormScaleShift( + dim, eps=eps, elementwise_affine=False + ) + + # Text processing modules + self.txt_mod = nn.Sequential( + nn.SiLU(), + nn.Linear( + dim, 6 * dim, bias=True + ), # For scale, shift, gate for norm1 and norm2 + ) + self.txt_norm1 = LayerNormScaleShift( + hidden_size=dim, eps=eps, elementwise_affine=False + ) + # Text doesn't need separate attention - it's handled by img_attn joint computation + self.txt_norm2 = ScaleResidualLayerNormScaleShift( + hidden_size=dim, eps=eps, elementwise_affine=False + ) + # Utils + self.fuse_mul_add = MulAdd() + + nunchaku_enabled = ( + quant_config is not None + and hasattr(quant_config, "get_name") + 
and quant_config.get_name() == "svdquant" + and is_nunchaku_available() + ) + if nunchaku_enabled: + ff_class = diffusers.models.attention.FeedForward + self.img_mlp = ff_class( + dim=dim, + dim_out=dim, + activation_fn="gelu-approximate", + ) + self.txt_mlp = ff_class( + dim=dim, + dim_out=dim, + activation_fn="gelu-approximate", + ) + else: + self.img_mlp = FeedForward( + dim=dim, + dim_out=dim, + activation_fn="gelu-approximate", + ) + self.txt_mlp = FeedForward( + dim=dim, + dim_out=dim, + activation_fn="gelu-approximate", + ) + + if nunchaku_enabled: + nunchaku_kwargs = { + "precision": quant_config.precision, + "rank": quant_config.rank, + "act_unsigned": quant_config.act_unsigned, + } + self.img_mlp = NunchakuFeedForward(self.img_mlp, **nunchaku_kwargs) + self.txt_mlp = NunchakuFeedForward(self.txt_mlp, **nunchaku_kwargs) + + def _modulate( + self, + x: torch.Tensor, + mod_params: torch.Tensor, + norm_module: Union[LayerNormScaleShift, ScaleResidualLayerNormScaleShift], + index: Optional[torch.Tensor] = None, + gate_x: Optional[torch.Tensor] = None, + residual_x: Optional[torch.Tensor] = None, + ) -> Union[ + Tuple[torch.Tensor, torch.Tensor], + Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + ]: + # Apply attention gates and add residual (like in Megatron) + # - residual_out = gate_x * x + residual_x + # - x = norm(residual_out) * (1 + scale) + shift + # TODO: clean code here + is_scale_residual = isinstance(norm_module, ScaleResidualLayerNormScaleShift) + + shift, scale, gate = mod_params.chunk(3, dim=-1) + if index is not None: + actual_batch = x.shape[0] + shift0, shift1 = ( + shift[:actual_batch], + shift[actual_batch : 2 * actual_batch], + ) + scale0, scale1 = ( + scale[:actual_batch], + scale[actual_batch : 2 * actual_batch], + ) + gate0, gate1 = gate[:actual_batch], gate[actual_batch : 2 * actual_batch] + if _is_cuda: + if is_scale_residual: + x = gate_x * x + residual_x + residual_out = x + if not x.is_contiguous(): + x = x.contiguous() + if not 
index.is_contiguous(): + index = index.contiguous() + # TODO: fuse norm with above select01 kernel, workaround now + x = apply_layernorm_only(x, norm_module) + x, gate_result = fuse_scale_shift_gate_select01_kernel( + x, + scale0=scale0.contiguous(), + shift0=shift0.contiguous(), + gate0=gate0.contiguous(), + scale1=scale1.contiguous(), + shift1=shift1.contiguous(), + gate1=gate1.contiguous(), + index=index, + ) + if is_scale_residual: + return x, residual_out, gate_result + else: + return x, gate_result + else: + mask = (index == 0).unsqueeze(-1) + shift_result = torch.where( + mask, shift0.unsqueeze(1), shift1.unsqueeze(1) + ) + scale_result = torch.where( + mask, scale0.unsqueeze(1), scale1.unsqueeze(1) + ) + gate_result = torch.where(mask, gate0.unsqueeze(1), gate1.unsqueeze(1)) + if is_scale_residual: + modulated, residual_out = norm_module( + residual=residual_x, + x=x, + gate=gate_x, + shift=shift_result, + scale=scale_result, + ) + return modulated, residual_out, gate_result + else: + modulated = norm_module(x=x, shift=shift_result, scale=scale_result) + return modulated, gate_result + else: + shift_result = shift.unsqueeze(1) + scale_result = scale.unsqueeze(1) + gate_result = gate.unsqueeze(1) + if is_scale_residual: + modulated, residual_out = norm_module( + residual=residual_x, + x=x, + gate=gate_x, + shift=shift_result, + scale=scale_result, + ) + return modulated, residual_out, gate_result + else: + modulated = norm_module(x=x, shift=shift_result, scale=scale_result) + return modulated, gate_result + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_mask: torch.Tensor, + temb_img_silu: torch.Tensor, + temb_txt_silu: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + modulate_index: Optional[List[int]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Get modulation parameters for both 
streams + img_mod_params = self.img_mod[1](temb_img_silu) # [B, 6*dim] + txt_mod_params = self.txt_mod[1](temb_txt_silu) # [B, 6*dim] + + if ( + self.quant_config is not None + and hasattr(self.quant_config, "get_name") + and self.quant_config.get_name() == "svdquant" + ): + # When NOT using nunchaku, reshape mod_params from [B, 6*dim] to [B, dim*6] + # When using nunchaku (svdquant), keep original format + img_mod_params = ( + img_mod_params.view(img_mod_params.shape[0], -1, 6) + .transpose(1, 2) + .reshape(img_mod_params.shape[0], -1) + ) + txt_mod_params = ( + txt_mod_params.view(txt_mod_params.shape[0], -1, 6) + .transpose(1, 2) + .reshape(txt_mod_params.shape[0], -1) + ) + + # Split modulation parameters for norm1 and norm2 + img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) # Each [B, 3*dim] + txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim] + + # Process image stream - norm1 + modulation + img_modulated, img_gate1 = self._modulate( + hidden_states, img_mod1, self.img_norm1, modulate_index + ) + # Process text stream - norm1 + modulation + txt_shift1, txt_scale1, txt_gate1_raw = txt_mod1.chunk(3, dim=-1) + txt_modulated = self.txt_norm1( + encoder_hidden_states, shift=txt_shift1, scale=txt_scale1 + ) + txt_gate1 = txt_gate1_raw.unsqueeze(1) + + # Use QwenAttnProcessor2_0 for joint attention computation + # This directly implements the DoubleStreamLayerMegatron logic: + # 1. Computes QKV for both streams + # 2. Applies QK normalization and RoPE + # 3. Concatenates and runs joint attention + # 4. 
Splits results back to separate streams + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + # Image stream (will be processed as "sample") + hidden_states=img_modulated, + # Text stream (will be processed as "context") + encoder_hidden_states=txt_modulated, + encoder_hidden_states_mask=encoder_hidden_states_mask, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + # QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided + img_attn_output, txt_attn_output = attn_output + # Process image stream - norm2 + MLP + img_modulated2, hidden_states, img_gate2 = self._modulate( + img_attn_output, + img_mod2, + self.img_norm2, + modulate_index, + gate_x=img_gate1, + residual_x=hidden_states, + ) + img_mlp_output = self.img_mlp(img_modulated2) + + if img_mlp_output.dim() == 2: + img_mlp_output = img_mlp_output.unsqueeze(0) + hidden_states = self.fuse_mul_add(img_mlp_output, img_gate2, hidden_states) + + # Process text stream - norm2 + MLP + txt_shift2, txt_scale2, txt_gate2_raw = txt_mod2.chunk(3, dim=-1) + txt_modulated2, encoder_hidden_states = self.txt_norm2( + residual=encoder_hidden_states, + x=txt_attn_output, + gate=txt_gate1, + shift=txt_shift2, + scale=txt_scale2, + ) + txt_gate2 = txt_gate2_raw.unsqueeze(1) + txt_mlp_output = self.txt_mlp(txt_modulated2) + + if txt_mlp_output.dim() == 2: + txt_mlp_output = txt_mlp_output.unsqueeze(0) + encoder_hidden_states = self.fuse_mul_add( + txt_mlp_output, txt_gate2, encoder_hidden_states + ) + + # Clip to prevent overflow for fp16 + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +def to_hashable(obj): + if isinstance(obj, list): + return tuple(to_hashable(x) for x in obj) + return obj + + +class 
QwenImageTransformer2DModel(CachableDiT, OffloadableDiTMixin): + """ + The Transformer model introduced in Qwen. + + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["QwenImageTransformerBlock"] + _skip_layerwise_casting_patterns = ["pos_embed", "norm"] + _repeated_blocks = ["QwenImageTransformerBlock"] + + param_names_mapping = QwenImageDitConfig().arch_config.param_names_mapping + + @classmethod + def get_nunchaku_quant_rules(cls) -> dict[str, list[str]]: + return { + "skip": [ + "norm", + "embed", + "rotary", + "pos_embed", + ], + "svdq_w4a4": [ + "attn.to_qkv", + "attn.to_out", + "attn.add_qkv_proj", + "attn.to_add_out", + "img_mlp", + "txt_mlp", + ], + "awq_w4a16": [ + "img_mod", + "txt_mod", + ], + } + + def __init__( + self, + config: QwenImageDitConfig, + hf_config: dict[str, Any], + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__(config=config, hf_config=hf_config) + patch_size = config.arch_config.patch_size + in_channels = config.arch_config.in_channels + out_channels = config.arch_config.out_channels + num_layers = config.arch_config.num_layers + attention_head_dim = config.arch_config.attention_head_dim + num_attention_heads = config.arch_config.num_attention_heads + joint_attention_dim = config.arch_config.joint_attention_dim + axes_dims_rope = config.arch_config.axes_dims_rope + self.zero_cond_t = getattr(config.arch_config, "zero_cond_t", False) + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + + self.use_additional_t_cond: bool = getattr( + config.arch_config, "use_additional_t_cond", False + ) # For qwen-image-layered now + self.use_layer3d_rope: bool = getattr( + config.arch_config, "use_layer3d_rope", False + ) # For qwen-image-layered now + + if not self.use_layer3d_rope: + self.rotary_emb = QwenEmbedRope( + theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True + ) + else: + self.rotary_emb = QwenEmbedLayer3DRope( + theta=10000, 
axes_dim=list(axes_dims_rope), scale_rope=True + ) + + self.time_text_embed = QwenTimestepProjEmbeddings( + embedding_dim=self.inner_dim, + use_additional_t_cond=self.use_additional_t_cond, + ) + + self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6) + + self.img_in = nn.Linear(in_channels, self.inner_dim) + self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + QwenImageTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + quant_config=quant_config, + prefix=f"transformer_blocks.{layer_idx}", + zero_cond_t=self.zero_cond_t, + ) + for layer_idx in range(num_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous( + self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6 + ) + self.proj_out = nn.Linear( + self.inner_dim, patch_size * patch_size * self.out_channels, bias=True + ) + + self.timestep_zero = torch.zeros( + (1,), dtype=torch.int, device=get_local_torch_device() + ) + + self.layer_names = ["transformer_blocks"] + + @functools.lru_cache(maxsize=50) + def build_modulate_index(self, img_shapes: tuple[int, int, int], device): + modulate_index_list = [] + for sample in img_shapes: + first_size = sample[0][0] * sample[0][1] * sample[0][2] + total_size = sum(s[0] * s[1] * s[2] for s in sample) + idx = (torch.arange(total_size, device=device) >= first_size).int() + modulate_index_list.append(idx) + + modulate_index = torch.stack(modulate_index_list) + return modulate_index + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + encoder_hidden_states_mask: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_shapes: Optional[List[Tuple[int, int, int]]] = None, + txt_seq_lens: Optional[List[int]] = None, + freqs_cis: tuple[torch.Tensor, torch.Tensor] = None, + additional_t_cond: Optional[torch.Tensor] = None, + guidance: torch.Tensor = None, # TODO: this should 
probably be removed + attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_block_samples=None, + return_dict: bool = True, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + """ + The [`QwenTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): + Input `hidden_states`. + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`): + Mask of the input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if ( + attention_kwargs is not None + and attention_kwargs.get("scale", None) is not None + ): + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 
+ ) + + if isinstance(encoder_hidden_states, list): + encoder_hidden_states = encoder_hidden_states[0] + + hidden_states = self.img_in(hidden_states) + + timestep = (timestep / 1000).to(hidden_states.dtype) + + if self.zero_cond_t: + timestep = torch.cat([timestep, self.timestep_zero], dim=0) + device = timestep.device + modulate_index = self.build_modulate_index(to_hashable(img_shapes), device) + else: + modulate_index = None + + encoder_hidden_states = self.txt_norm(encoder_hidden_states) + encoder_hidden_states = self.txt_in(encoder_hidden_states) + + temb = self.time_text_embed(timestep, hidden_states, additional_t_cond) + + temb_img_silu = F.silu(temb) + if self.zero_cond_t: + temb_txt = temb.chunk(2, dim=0)[0] + temb_txt_silu = temb_img_silu.chunk(2, dim=0)[0] + else: + temb_txt = temb + temb_txt_silu = temb_img_silu + + image_rotary_emb = freqs_cis + for index_block, block in enumerate(self.transformer_blocks): + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb_img_silu=temb_img_silu, + temb_txt_silu=temb_txt_silu, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=attention_kwargs, + modulate_index=modulate_index, + ) + + # controlnet residual + if controlnet_block_samples is not None: + interval_control = len(self.transformer_blocks) / len( + controlnet_block_samples + ) + interval_control = int(np.ceil(interval_control)) + hidden_states = ( + hidden_states + + controlnet_block_samples[index_block // interval_control] + ) + # Use only the image part (hidden_states) from the dual-stream blocks + hidden_states = self.norm_out(hidden_states, temb_txt) + + output = self.proj_out(hidden_states) + return output + + +EntryClass = QwenImageTransformer2DModel diff --git a/sglang/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py b/sglang/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py new file 
mode 100644 index 0000000000000000000000000000000000000000..c8331f9a69c2e17ae53f637a836892ec8d35e175 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py @@ -0,0 +1,1128 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +import math +from functools import lru_cache +from typing import Any + +import torch +import torch.nn as nn + +from sglang.multimodal_gen.configs.models.dits import WanVideoConfig +from sglang.multimodal_gen.configs.sample.wan import WanTeaCacheParams +from sglang.multimodal_gen.runtime.distributed import ( + divide, + get_sp_group, + get_sp_world_size, + get_tp_world_size, + sequence_model_parallel_all_gather, +) +from sglang.multimodal_gen.runtime.layers.attention import ( + MinimalA2AAttnOp, + UlyssesAttention_VSA, + USPAttention, +) +from sglang.multimodal_gen.runtime.layers.elementwise import MulAdd +from sglang.multimodal_gen.runtime.layers.layernorm import ( + FP32LayerNorm, + LayerNormScaleShift, + RMSNorm, + ScaleResidualLayerNormScaleShift, + tensor_parallel_rms_norm, +) +from sglang.multimodal_gen.runtime.layers.linear import ( + ColumnParallelLinear, + RowParallelLinear, +) +from sglang.multimodal_gen.runtime.layers.mlp import MLP +from sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import ( + QuantizationConfig, +) +from sglang.multimodal_gen.runtime.layers.rotary_embedding import ( + NDRotaryEmbedding, + _apply_rotary_emb, + apply_flashinfer_rope_qk_inplace, +) +from sglang.multimodal_gen.runtime.layers.visual_embedding import ( + ModulateProjection, + PatchEmbed, + TimestepEmbedder, +) +from sglang.multimodal_gen.runtime.managers.forward_context import get_forward_context +from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT +from sglang.multimodal_gen.runtime.platforms import ( + AttentionBackendEnum, + current_platform, +) +from sglang.multimodal_gen.runtime.server_args import 
get_global_server_args +from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) +_is_cuda = current_platform.is_cuda() + + +class WanImageEmbedding(torch.nn.Module): + + def __init__(self, in_features: int, out_features: int): + super().__init__() + + self.norm1 = FP32LayerNorm(in_features) + self.ff = MLP(in_features, in_features, out_features, act_type="gelu") + self.norm2 = FP32LayerNorm(out_features) + + def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor: + dtype = encoder_hidden_states_image.dtype + hidden_states = self.norm1(encoder_hidden_states_image) + hidden_states = self.ff(hidden_states) + hidden_states = self.norm2(hidden_states).to(dtype) + return hidden_states + + +class WanTimeTextImageEmbedding(nn.Module): + + def __init__( + self, + dim: int, + time_freq_dim: int, + text_embed_dim: int, + image_embed_dim: int | None = None, + ): + super().__init__() + + self.time_embedder = TimestepEmbedder( + dim, frequency_embedding_size=time_freq_dim, act_layer="silu" + ) + self.time_modulation = ModulateProjection(dim, factor=6, act_layer="silu") + self.text_embedder = MLP( + text_embed_dim, dim, dim, bias=True, act_type="gelu_pytorch_tanh" + ) + + self.image_embedder = None + if image_embed_dim is not None: + self.image_embedder = WanImageEmbedding(image_embed_dim, dim) + + def forward( + self, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: torch.Tensor | None = None, + timestep_seq_len: int | None = None, + ): + temb = self.time_embedder(timestep, timestep_seq_len) + timestep_proj = self.time_modulation(temb) + + encoder_hidden_states = self.text_embedder(encoder_hidden_states) + if encoder_hidden_states_image is not None: + assert self.image_embedder is not None + encoder_hidden_states_image = self.image_embedder( + encoder_hidden_states_image + ) 
+ + return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image + + +class WanSelfAttention(nn.Module): + + def __init__( + self, + dim: int, + num_heads: int, + window_size=(-1, -1), + qk_norm=True, + eps=1e-6, + parallel_attention=False, + supported_attention_backends: set[AttentionBackendEnum] | None = None, + is_cross_attention: bool = False, + quant_config: QuantizationConfig | None = None, + ) -> None: + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.eps = eps + self.parallel_attention = parallel_attention + tp_size = get_tp_world_size() + + # layers + self.to_q = ColumnParallelLinear( + dim, dim, gather_output=False, quant_config=quant_config + ) + self.to_k = ColumnParallelLinear( + dim, dim, gather_output=False, quant_config=quant_config + ) + self.to_v = ColumnParallelLinear( + dim, dim, gather_output=False, quant_config=quant_config + ) + self.to_out = RowParallelLinear( + dim, dim, input_is_parallel=True, quant_config=quant_config + ) + self.norm_q = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + self.norm_k = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + self.tp_rmsnorm = tp_size > 1 and qk_norm + self.local_num_heads = divide(num_heads, tp_size) + + # Scaled dot product attention + self.attn = USPAttention( + num_heads=self.local_num_heads, + head_size=self.head_dim, + dropout_rate=0, + softmax_scale=None, + causal=False, + supported_attention_backends=supported_attention_backends, + skip_sequence_parallel=is_cross_attention, + ) + + def forward(self, x: torch.Tensor, context: torch.Tensor, context_lens: int): + r""" + Args: + x(Tensor): Shape [B, L, num_heads, C / num_heads] + """ + pass + + +class WanT2VCrossAttention(WanSelfAttention): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs, is_cross_attention=True) + + def forward(self, x, context, 
context_lens): + r""" + Args: + x(Tensor): Shape [B, L1, C] + context(Tensor): Shape [B, L2, C] + context_lens(Tensor): Shape [B] + """ + q, _ = self.to_q(x) + if self.tp_rmsnorm: + q = tensor_parallel_rms_norm(q, self.norm_q) + else: + q = self.norm_q(q) + q = q.unflatten(2, (self.local_num_heads, self.head_dim)) + + k, _ = self.to_k(context) + if self.tp_rmsnorm: + k = tensor_parallel_rms_norm(k, self.norm_k) + else: + k = self.norm_k(k) + k = k.unflatten(2, (self.local_num_heads, self.head_dim)) + + v, _ = self.to_v(context) + v = v.unflatten(2, (self.local_num_heads, self.head_dim)) + + # compute attention + x = self.attn(q, k, v) + + # output + x = x.flatten(2) + x, _ = self.to_out(x) + return x + + +class WanI2VCrossAttention(WanSelfAttention): + + def __init__( + self, + dim: int, + num_heads: int, + window_size=(-1, -1), + qk_norm=True, + eps=1e-6, + supported_attention_backends: set[AttentionBackendEnum] | None = None, + quant_config: QuantizationConfig | None = None, + ) -> None: + super().__init__( + dim, + num_heads, + window_size, + qk_norm, + eps, + supported_attention_backends=supported_attention_backends, + is_cross_attention=True, + quant_config=quant_config, + ) + + self.add_k_proj = ColumnParallelLinear( + dim, dim, gather_output=False, quant_config=quant_config + ) + self.add_v_proj = ColumnParallelLinear( + dim, dim, gather_output=False, quant_config=quant_config + ) + self.norm_added_k = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + + def forward(self, x, context, context_lens): + r""" + Args: + x(Tensor): Shape [B, L1, C] + context(Tensor): Shape [B, L2, C] + context_lens(Tensor): Shape [B] + """ + context_img = context[:, :257] + context = context[:, 257:] + + q, _ = self.to_q(x) + if self.tp_rmsnorm: + q = tensor_parallel_rms_norm(q, self.norm_q) + else: + q = self.norm_q(q) + q = q.unflatten(2, (self.local_num_heads, self.head_dim)) + + k, _ = self.to_k(context) + if self.tp_rmsnorm: + k = tensor_parallel_rms_norm(k, self.norm_k) + 
else: + k = self.norm_k(k) + k = k.unflatten(2, (self.local_num_heads, self.head_dim)) + + v, _ = self.to_v(context) + v = v.unflatten(2, (self.local_num_heads, self.head_dim)) + + k_img, _ = self.add_k_proj(context_img) + if self.tp_rmsnorm: + k_img = tensor_parallel_rms_norm(k_img, self.norm_added_k) + else: + k_img = self.norm_added_k(k_img) + k_img = k_img.unflatten(2, (self.local_num_heads, self.head_dim)) + + v_img, _ = self.add_v_proj(context_img) + v_img = v_img.unflatten(2, (self.local_num_heads, self.head_dim)) + + img_x = self.attn(q, k_img, v_img) + x = self.attn(q, k, v) + + # output + x = x.flatten(2) + img_x = img_x.flatten(2) + x = x + img_x + x, _ = self.to_out(x) + return x + + +class WanTransformerBlock(nn.Module): + + def __init__( + self, + dim: int, + ffn_dim: int, + num_heads: int, + qk_norm: str = "rms_norm_across_heads", + cross_attn_norm: bool = False, + eps: float = 1e-6, + added_kv_proj_dim: int | None = None, + supported_attention_backends: set[AttentionBackendEnum] | None = None, + prefix: str = "", + attention_type: str = "original", + sla_topk: float = 0.1, + quant_config: QuantizationConfig | None = None, + ): + super().__init__() + + # 1. 
Self-attention + self.norm1 = LayerNormScaleShift( + dim, + eps=eps, + elementwise_affine=False, + dtype=torch.float32, + ) + self.to_q = ColumnParallelLinear( + dim, dim, bias=True, gather_output=False, quant_config=quant_config + ) + self.to_k = ColumnParallelLinear( + dim, dim, bias=True, gather_output=False, quant_config=quant_config + ) + self.to_v = ColumnParallelLinear( + dim, dim, bias=True, gather_output=False, quant_config=quant_config + ) + + self.to_out = RowParallelLinear( + dim, dim, bias=True, reduce_results=True, quant_config=quant_config + ) + tp_size = get_tp_world_size() + self.local_num_heads = divide(num_heads, tp_size) + self_attn_backends = supported_attention_backends + + if attention_type in ("sla", "sagesla"): + self.attn1 = MinimalA2AAttnOp( + num_heads=self.local_num_heads, + head_size=dim // num_heads, + attention_type=attention_type, + topk=sla_topk, + supported_attention_backends={ + AttentionBackendEnum.SLA_ATTN, + AttentionBackendEnum.SAGE_SLA_ATTN, + }, + ) + else: + self.attn1 = USPAttention( + num_heads=self.local_num_heads, + head_size=dim // num_heads, + causal=False, + supported_attention_backends=self_attn_backends, + is_cross_attention=False, + prefix=f"{prefix}.attn1", + ) + + self.hidden_dim = dim + self.num_attention_heads = num_heads + self.dim_head = dim // num_heads + if qk_norm == "rms_norm": + self.norm_q = RMSNorm(self.dim_head, eps=eps) + self.norm_k = RMSNorm(self.dim_head, eps=eps) + elif qk_norm == "rms_norm_across_heads": + # LTX applies qk norm across all heads + self.norm_q = RMSNorm(dim, eps=eps) + self.norm_k = RMSNorm(dim, eps=eps) + else: + logger.error("QK Norm type not supported") + raise Exception + assert cross_attn_norm is True + self.qk_norm = qk_norm + self.tp_rmsnorm = qk_norm == "rms_norm_across_heads" and tp_size > 1 + self.self_attn_residual_norm = ScaleResidualLayerNormScaleShift( + dim, + eps=eps, + elementwise_affine=True, + dtype=torch.float32, + ) + + # 2. 
Cross-attention + cross_attn_backends = { + b for b in supported_attention_backends if not b.is_sparse + } + if added_kv_proj_dim is not None: + # I2V + self.attn2 = WanI2VCrossAttention( + dim, + num_heads, + qk_norm=qk_norm, + eps=eps, + supported_attention_backends=cross_attn_backends, + quant_config=quant_config, + ) + else: + # T2V + self.attn2 = WanT2VCrossAttention( + dim, + num_heads, + qk_norm=qk_norm, + eps=eps, + supported_attention_backends=cross_attn_backends, + quant_config=quant_config, + ) + self.cross_attn_residual_norm = ScaleResidualLayerNormScaleShift( + dim, + eps=eps, + elementwise_affine=False, + dtype=torch.float32, + ) + + # 3. Feed-forward + self.ffn = MLP( + dim, ffn_dim, act_type="gelu_pytorch_tanh", quant_config=quant_config + ) + self.mlp_residual = MulAdd() + + self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + freqs_cis: tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + if hidden_states.dim() == 4: + hidden_states = hidden_states.squeeze(1) + bs, seq_length, _ = hidden_states.shape + orig_dtype = hidden_states.dtype + if temb.dim() == 4: + # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table.unsqueeze(0) + temb.float() + ).chunk(6, dim=2) + # batch_size, seq_len, 1, inner_dim + shift_msa = shift_msa.squeeze(2) + scale_msa = scale_msa.squeeze(2) + gate_msa = gate_msa.squeeze(2) + c_shift_msa = c_shift_msa.squeeze(2) + c_scale_msa = c_scale_msa.squeeze(2) + c_gate_msa = c_gate_msa.squeeze(2) + else: + # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B) + e = self.scale_shift_table + temb.float() + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + e.chunk(6, dim=1) + ) + + assert shift_msa.dtype == torch.float32 + + # 1. 
Self-attention + norm_hidden_states = self.norm1(hidden_states, shift_msa, scale_msa) + query, _ = self.to_q(norm_hidden_states) + key, _ = self.to_k(norm_hidden_states) + value, _ = self.to_v(norm_hidden_states) + + if self.norm_q is not None: + if self.tp_rmsnorm: + query = tensor_parallel_rms_norm(query, self.norm_q) + else: + query = self.norm_q(query) + if self.norm_k is not None: + if self.tp_rmsnorm: + key = tensor_parallel_rms_norm(key, self.norm_k) + else: + key = self.norm_k(key) + query = query.squeeze(1).unflatten(2, (self.local_num_heads, self.dim_head)) + key = key.squeeze(1).unflatten(2, (self.local_num_heads, self.dim_head)) + value = value.squeeze(1).unflatten(2, (self.local_num_heads, self.dim_head)) + + # Apply rotary embeddings + cos, sin = freqs_cis + if _is_cuda and query.shape == key.shape: + cos_sin_cache = torch.cat( + [ + cos.to(dtype=torch.float32).contiguous(), + sin.to(dtype=torch.float32).contiguous(), + ], + dim=-1, + ) + query, key = apply_flashinfer_rope_qk_inplace( + query, key, cos_sin_cache, is_neox=False + ) + else: + query, key = _apply_rotary_emb( + query, cos, sin, is_neox_style=False + ), _apply_rotary_emb(key, cos, sin, is_neox_style=False) + attn_output = self.attn1(query, key, value) + attn_output = attn_output.flatten(2) + attn_output, _ = self.to_out(attn_output) + attn_output = attn_output.squeeze(1) + + null_shift = null_scale = torch.zeros( + (1,), device=hidden_states.device, dtype=hidden_states.dtype + ) + norm_hidden_states, hidden_states = self.self_attn_residual_norm( + hidden_states, attn_output, gate_msa, null_shift, null_scale + ) + norm_hidden_states, hidden_states = norm_hidden_states.to( + orig_dtype + ), hidden_states.to(orig_dtype) + + # 2. 
Cross-attention + attn_output = self.attn2( + norm_hidden_states, context=encoder_hidden_states, context_lens=None + ) + norm_hidden_states, hidden_states = self.cross_attn_residual_norm( + hidden_states, attn_output, 1, c_shift_msa, c_scale_msa + ) + norm_hidden_states, hidden_states = norm_hidden_states.to( + orig_dtype + ), hidden_states.to(orig_dtype) + + # 3. Feed-forward + ff_output = self.ffn(norm_hidden_states) + hidden_states = self.mlp_residual(ff_output, c_gate_msa, hidden_states) + hidden_states = hidden_states.to(orig_dtype) + + return hidden_states + + +class WanTransformerBlock_VSA(nn.Module): + + def __init__( + self, + dim: int, + ffn_dim: int, + num_heads: int, + qk_norm: str = "rms_norm_across_heads", + cross_attn_norm: bool = False, + eps: float = 1e-6, + added_kv_proj_dim: int | None = None, + supported_attention_backends: set[AttentionBackendEnum] | None = None, + prefix: str = "", + quant_config: QuantizationConfig | None = None, + ): + super().__init__() + + # 1. 
Self-attention + self.norm1 = LayerNormScaleShift( + dim, + eps=eps, + elementwise_affine=False, + dtype=torch.float32, + ) + self.to_q = ColumnParallelLinear( + dim, dim, bias=True, gather_output=True, quant_config=quant_config + ) + self.to_k = ColumnParallelLinear( + dim, dim, bias=True, gather_output=True, quant_config=quant_config + ) + self.to_v = ColumnParallelLinear( + dim, dim, bias=True, gather_output=True, quant_config=quant_config + ) + self.to_gate_compress = ColumnParallelLinear( + dim, dim, bias=True, gather_output=True, quant_config=quant_config + ) + + self.to_out = ColumnParallelLinear( + dim, dim, bias=True, gather_output=True, quant_config=quant_config + ) + self.attn1 = UlyssesAttention_VSA( + num_heads=num_heads, + head_size=dim // num_heads, + causal=False, + supported_attention_backends=supported_attention_backends, + prefix=f"{prefix}.attn1", + ) + self.hidden_dim = dim + self.num_attention_heads = num_heads + dim_head = dim // num_heads + if qk_norm == "rms_norm": + self.norm_q = RMSNorm(dim_head, eps=eps) + self.norm_k = RMSNorm(dim_head, eps=eps) + elif qk_norm == "rms_norm_across_heads": + # LTX applies qk norm across all heads + self.norm_q = RMSNorm(dim, eps=eps) + self.norm_k = RMSNorm(dim, eps=eps) + else: + logger.error("QK Norm type not supported") + raise Exception + assert cross_attn_norm is True + self.self_attn_residual_norm = ScaleResidualLayerNormScaleShift( + dim, + eps=eps, + elementwise_affine=True, + dtype=torch.float32, + ) + + # 2. 
def forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    temb: torch.Tensor,
    freqs_cis: tuple[torch.Tensor, torch.Tensor],
) -> torch.Tensor:
    """Run one block: gated self-attention, cross-attention, then feed-forward.

    Args:
        hidden_states: Block input. If it is 4-D, ``squeeze(1)`` is applied
            first; it must be 3-D afterwards (the shape unpacking below
            fails otherwise).
        encoder_hidden_states: Passed to ``self.attn2`` as ``context``.
        temb: Modulation input. It is cast to float32, added to
            ``self.scale_shift_table`` and split into six chunks along dim 1.
        freqs_cis: ``(cos, sin)`` rotary tables applied to query and key.

    Returns:
        The updated hidden states, cast back to the dtype of the input.
    """
    if hidden_states.dim() == 4:
        hidden_states = hidden_states.squeeze(1)
    # bs / seq_length are not used below; the unpacking still requires a 3-D input.
    bs, seq_length, _ = hidden_states.shape
    # Remember the caller's dtype: intermediate results are cast back to it below.
    orig_dtype = hidden_states.dtype
    # assert orig_dtype != torch.float32
    # Six modulation terms: three for self-attention, three (c_*) used by the
    # cross-attention norm and the feed-forward residual.
    e = self.scale_shift_table + temb.float()
    shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = e.chunk(
        6, dim=1
    )
    assert shift_msa.dtype == torch.float32

    # 1. Self-attention
    norm_hidden_states = self.norm1(hidden_states, shift_msa, scale_msa)
    # Each projection returns a pair; only the first element is used here.
    query, _ = self.to_q(norm_hidden_states)
    key, _ = self.to_k(norm_hidden_states)
    value, _ = self.to_v(norm_hidden_states)
    gate_compress, _ = self.to_gate_compress(norm_hidden_states)

    # Optional q/k normalisation (the norm modules may be None).
    if self.norm_q is not None:
        query = self.norm_q(query)
    if self.norm_k is not None:
        key = self.norm_k(key)

    # Split the last dim into (num_attention_heads, head_dim).
    query = query.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
    key = key.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
    value = value.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
    gate_compress = gate_compress.squeeze(1).unflatten(
        2, (self.num_attention_heads, -1)
    )

    # Apply rotary embeddings
    cos, sin = freqs_cis
    if _is_cuda and query.shape == key.shape:
        # CUDA path, taken only when q and k have identical shapes: build a
        # float32 [cos | sin] table and hand it to the flashinfer helper
        # (in-place, judging by the helper's name).
        cos_sin_cache = torch.cat(
            [
                cos.to(dtype=torch.float32).contiguous(),
                sin.to(dtype=torch.float32).contiguous(),
            ],
            dim=-1,
        )
        query, key = apply_flashinfer_rope_qk_inplace(
            query, key, cos_sin_cache, is_neox=False
        )
    else:
        # Fallback: rotate query and key separately.
        query, key = _apply_rotary_emb(
            query, cos, sin, is_neox_style=False
        ), _apply_rotary_emb(key, cos, sin, is_neox_style=False)

    attn_output = self.attn1(query, key, value, gate_compress=gate_compress)
    # Merge heads back into the channel dim before the output projection.
    attn_output = attn_output.flatten(2)
    attn_output, _ = self.to_out(attn_output)
    attn_output = attn_output.squeeze(1)

    # Zero shift/scale placeholders for the fused residual + norm call; only
    # gate_msa carries modulation in this step.
    null_shift = null_scale = torch.zeros((1,), device=hidden_states.device)
    norm_hidden_states, hidden_states = self.self_attn_residual_norm(
        hidden_states, attn_output, gate_msa, null_shift, null_scale
    )
    norm_hidden_states, hidden_states = norm_hidden_states.to(
        orig_dtype
    ), hidden_states.to(orig_dtype)

    # 2. Cross-attention
    attn_output = self.attn2(
        norm_hidden_states, context=encoder_hidden_states, context_lens=None
    )
    # The literal 1 sits in the argument slot that held gate_msa in the
    # self-attention call (presumably an identity gate -- confirm against
    # ScaleResidualLayerNormScaleShift).
    norm_hidden_states, hidden_states = self.cross_attn_residual_norm(
        hidden_states, attn_output, 1, c_shift_msa, c_scale_msa
    )
    norm_hidden_states, hidden_states = norm_hidden_states.to(
        orig_dtype
    ), hidden_states.to(orig_dtype)

    # 3. Feed-forward
    ff_output = self.ffn(norm_hidden_states)
    # Combine the FFN output, c_gate_msa and the running residual via MulAdd.
    hidden_states = self.mlp_residual(ff_output, c_gate_msa, hidden_states)
    hidden_states = hidden_states.to(orig_dtype)

    return hidden_states
Condition embeddings + self.condition_embedder = WanTimeTextImageEmbedding( + dim=inner_dim, + time_freq_dim=config.freq_dim, + text_embed_dim=config.text_dim, + image_embed_dim=config.image_dim, + ) + + # 3. Transformer blocks + attn_backend = get_global_server_args().attention_backend + transformer_block = ( + WanTransformerBlock_VSA + if (attn_backend and attn_backend.lower() == "video_sparse_attn") + else WanTransformerBlock + ) + self.blocks = nn.ModuleList( + [ + transformer_block( + inner_dim, + config.ffn_dim, + config.num_attention_heads, + config.qk_norm, + config.cross_attn_norm, + config.eps, + config.added_kv_proj_dim, + self._supported_attention_backends + | {AttentionBackendEnum.VIDEO_SPARSE_ATTN}, + prefix=f"{config.prefix}.blocks.{i}", + attention_type=config.attention_type, + sla_topk=config.sla_topk, + quant_config=quant_config, + ) + for i in range(config.num_layers) + ] + ) + + # 4. Output norm & projection + self.norm_out = LayerNormScaleShift( + inner_dim, + eps=config.eps, + elementwise_affine=False, + dtype=torch.float32, + ) + self.proj_out = ColumnParallelLinear( + inner_dim, + config.out_channels * math.prod(config.patch_size), + bias=True, + gather_output=True, + quant_config=quant_config, + ) + self.scale_shift_table = nn.Parameter( + torch.randn(1, 2, inner_dim) / inner_dim**0.5 + ) + + # For type checking + + self.cnt = 0 + self.__post_init__() + + # misc + self.sp_size = get_sp_world_size() + + # Get rotary embeddings + d = self.hidden_size // self.num_attention_heads + self.rope_dim_list = [d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)] + + self.rotary_emb = NDRotaryEmbedding( + rope_dim_list=self.rope_dim_list, + rope_theta=10000, + dtype=( + torch.float32 + if current_platform.is_mps() or current_platform.is_musa() + else torch.float64 + ), + ) + + self.layer_names = ["blocks"] + + @lru_cache(maxsize=1) + def _compute_rope_for_sequence_shard( + self, + local_len: int, + rank: int, + frame_stride_local: int, + width_local: int, + 
def forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor | list[torch.Tensor],
    timestep: torch.LongTensor,
    encoder_hidden_states_image: torch.Tensor | list[torch.Tensor] | None = None,
    guidance=None,
    **kwargs,
) -> torch.Tensor:
    """Denoise a 5-D latent video tensor with the Wan transformer stack.

    Args:
        hidden_states: Latents unpacked as
            ``(batch, channels, frames, height, width)``.
        encoder_hidden_states: Text conditioning; when it is not a tensor its
            first element is used.
        timestep: 1-D, or 2-D ``(batch, seq_len)`` (the "wan 2.2 ti2v" case
            noted in the code), in which case it is flattened before
            embedding.
        encoder_hidden_states_image: Optional image conditioning. A non-empty
            list contributes its first element, a tensor is used as-is, and
            anything else (``None``, empty list) disables it.
        guidance: Accepted for interface compatibility; not read here.
        **kwargs: Ignored.

    Returns:
        A tensor laid out as ``(batch, C, frames, height, width)`` after
        un-patchifying the projected tokens.
    """
    forward_batch = get_forward_context().forward_batch
    if forward_batch is not None:
        sequence_shard_enabled = (
            forward_batch.enable_sequence_shard and self.sp_size > 1
        )
    else:
        sequence_shard_enabled = False
    self.enable_teacache = (
        forward_batch is not None and forward_batch.enable_teacache
    )

    orig_dtype = hidden_states.dtype
    if not isinstance(encoder_hidden_states, torch.Tensor):
        encoder_hidden_states = encoder_hidden_states[0]

    # Normalise the image conditioning. Previously a tensor argument, which the
    # signature explicitly allows, fell into the "else" branch and was
    # silently replaced by None; it is now kept.
    if isinstance(encoder_hidden_states_image, list):
        encoder_hidden_states_image = (
            encoder_hidden_states_image[0]
            if len(encoder_hidden_states_image) > 0
            else None
        )
    elif not isinstance(encoder_hidden_states_image, torch.Tensor):
        encoder_hidden_states_image = None

    batch_size, _, num_frames, height, width = hidden_states.shape

    p_t, p_h, p_w = self.patch_size
    post_patch_num_frames = num_frames // p_t
    post_patch_height = height // p_h
    post_patch_width = width // p_w

    if not sequence_shard_enabled:
        # The rotary embedding layer correctly handles SP offsets internally.
        freqs_cos, freqs_sin = self.rotary_emb.forward_from_grid(
            (
                post_patch_num_frames * self.sp_size,
                post_patch_height,
                post_patch_width,
            ),
            shard_dim=0,
            start_frame=0,
            device=hidden_states.device,
        )
        assert freqs_cos.dtype == torch.float32
        assert freqs_cos.device == hidden_states.device
        # The asserts above already dereference freqs_cos, so the former
        # "if freqs_cos is not None" fallback could never be taken.
        freqs_cis = (freqs_cos.float(), freqs_sin.float())

    hidden_states = self.patch_embedding(hidden_states)
    hidden_states = hidden_states.flatten(2).transpose(1, 2)

    # shape is [B, T' * H' * W', C]
    seq_len_orig = hidden_states.shape[1]
    seq_shard_pad = 0
    if sequence_shard_enabled:
        # Zero-pad the token axis up to a multiple of sp_size so it can be
        # split evenly; the padding is dropped again after the all-gather.
        if seq_len_orig % self.sp_size != 0:
            seq_shard_pad = self.sp_size - (seq_len_orig % self.sp_size)
            pad = torch.zeros(
                (batch_size, seq_shard_pad, hidden_states.shape[2]),
                dtype=hidden_states.dtype,
                device=hidden_states.device,
            )
            hidden_states = torch.cat([hidden_states, pad], dim=1)
        sp_rank = get_sp_group().rank_in_group
        local_seq_len = hidden_states.shape[1] // self.sp_size
        hidden_states = hidden_states.view(
            batch_size, self.sp_size, local_seq_len, hidden_states.shape[2]
        )
        # Keep only this rank's contiguous slice of tokens.
        hidden_states = hidden_states[:, sp_rank, :, :]

        # RoPE must match the token slice held by this rank.
        frame_stride = post_patch_height * post_patch_width
        freqs_cos, freqs_sin = self._compute_rope_for_sequence_shard(
            local_seq_len,
            sp_rank,
            frame_stride,
            post_patch_width,
            hidden_states.device,
        )
        freqs_cis = (freqs_cos.float(), freqs_sin.float())

    # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v)
    if timestep.dim() == 2:
        # ti2v
        ts_seq_len = timestep.shape[1]
        timestep = timestep.flatten()  # batch_size * seq_len
    else:
        ts_seq_len = None

    temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = (
        self.condition_embedder(
            timestep,
            encoder_hidden_states,
            encoder_hidden_states_image,
            timestep_seq_len=ts_seq_len,
        )
    )
    if ts_seq_len is not None:
        # batch_size, seq_len, 6, inner_dim
        timestep_proj = timestep_proj.unflatten(2, (6, -1))
    else:
        # batch_size, 6, inner_dim
        timestep_proj = timestep_proj.unflatten(1, (6, -1))

    if sequence_shard_enabled and ts_seq_len is not None:
        # Per-token modulation has to be padded and sliced exactly like the
        # hidden states so both stay aligned on this rank.
        if seq_shard_pad > 0:
            pad = torch.zeros(
                (
                    batch_size,
                    seq_shard_pad,
                    timestep_proj.shape[2],
                    timestep_proj.shape[3],
                ),
                dtype=timestep_proj.dtype,
                device=timestep_proj.device,
            )
            timestep_proj = torch.cat([timestep_proj, pad], dim=1)
        timestep_proj = timestep_proj.view(
            batch_size,
            self.sp_size,
            local_seq_len,
            timestep_proj.shape[2],
            timestep_proj.shape[3],
        )
        timestep_proj = timestep_proj[:, sp_rank, :, :, :]

    if encoder_hidden_states_image is not None:
        # Image tokens are placed in front of the text tokens.
        encoder_hidden_states = torch.concat(
            [encoder_hidden_states_image, encoder_hidden_states], dim=1
        )

    encoder_hidden_states = (
        encoder_hidden_states.to(orig_dtype)
        if not current_platform.is_amp_supported()
        else encoder_hidden_states
    )  # cast to orig_dtype for MPS

    assert encoder_hidden_states.dtype == orig_dtype

    # 4. Transformer blocks
    # if caching is enabled, we might be able to skip the forward pass
    should_skip_forward = self.should_skip_forward_for_cached_states(
        timestep_proj=timestep_proj, temb=temb
    )

    if should_skip_forward:
        hidden_states = self.retrieve_cached_states(hidden_states)
    else:
        # Snapshot the block input so the residual can be cached afterwards.
        if self.enable_teacache:
            original_hidden_states = hidden_states.clone()

        for block in self.blocks:
            hidden_states = block(
                hidden_states, encoder_hidden_states, timestep_proj, freqs_cis
            )
        # Store (output - input) for reuse on a later skipped step.
        if self.enable_teacache:
            self.maybe_cache_states(hidden_states, original_hidden_states)
    # NOTE(review): the step counter is assumed to advance on every call,
    # skipped or not (it is compared against step thresholds in
    # should_skip_forward_for_cached_states) -- confirm the nesting.
    self.cnt += 1

    if sequence_shard_enabled:
        hidden_states = hidden_states.contiguous()
        hidden_states = sequence_model_parallel_all_gather(hidden_states, dim=1)
        if seq_shard_pad > 0:
            # Drop the tokens that were only added for even sharding.
            hidden_states = hidden_states[:, :seq_len_orig, :]

    # 5. Output norm, projection & unpatchify
    if temb.dim() == 3:
        # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
        shift, scale = (
            self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)
        ).chunk(2, dim=2)
        shift = shift.squeeze(2)
        scale = scale.squeeze(2)
    else:
        # batch_size, inner_dim
        shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)

    hidden_states = self.norm_out(hidden_states, shift, scale)
    hidden_states, _ = self.proj_out(hidden_states)

    # Undo the patch embedding: tokens -> (T', H', W') grid of (p_t, p_h, p_w)
    # patches, then interleave patch and grid axes back into full resolution.
    hidden_states = hidden_states.reshape(
        batch_size,
        post_patch_num_frames,
        post_patch_height,
        post_patch_width,
        p_t,
        p_h,
        p_w,
        -1,
    )
    hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
    output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

    return output
def retrieve_cached_states(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Add the cached residual for the current CFG branch to ``hidden_states``.

    The negative-prompt branch (``self.is_cfg_negative``) uses
    ``self.previous_residual_negative``; the positive branch uses
    ``self.previous_residual``.
    """
    if self.is_cfg_negative:
        cached_residual = self.previous_residual_negative
    else:
        cached_residual = self.previous_residual
    return hidden_states + cached_residual
class SelectFirstElement(nn.Module):
    """Parameter-free module that returns element 0 of its input."""

    def forward(self, x):
        # Plain indexing: works for tuples/lists as well as tensors.
        first = x[0]
        return first
class FeedForward(nn.Module):
    """Gated MLP: fused gate/up projection, ``SiluAndMul``, then down projection."""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        # Gate and up projections share one merged column-parallel matmul.
        fused_widths = [hidden_dim, hidden_dim]
        self.w13 = MergedColumnParallelLinear(
            dim, fused_widths, bias=False, gather_output=False
        )
        # The down projection consumes the still-sharded activation.
        self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True)
        self.act = SiluAndMul()

    def forward(self, x):
        # Both linear layers return a pair; only the first element is used.
        gate_up, _ = self.w13(x)
        activated = self.act(gate_up)
        projected, _ = self.w2(activated)
        return projected
def forward(
    self,
    hidden_states: torch.Tensor,
    freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
):
    """Project to q/k/v, optionally normalise and rotate q/k, attend, project out.

    Args:
        hidden_states: Input activations; the last dim is projected to
            q/k/v and then split into ``(heads, head_dim)``.
        freqs_cis: Optional ``(cos, sin)`` rotary tables. When ``None`` no
            rotary embedding is applied.

    Returns:
        The output of ``self.to_out[0]`` applied to the flattened attention
        result.
    """
    if self.use_fused_qkv:
        # Single fused projection, split into the per-rank q / k / v widths.
        qkv, _ = self.to_qkv(hidden_states)
        q, k, v = qkv.split(
            [
                self.local_num_heads * self.head_dim,
                self.local_num_kv_heads * self.head_dim,
                self.local_num_kv_heads * self.head_dim,
            ],
            dim=-1,
        )
        # split() yields views of qkv; make them contiguous before .view().
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
    else:
        q, _ = self.to_q(hidden_states)
        k, _ = self.to_k(hidden_states)
        v, _ = self.to_v(hidden_states)
    # Reshape the last dim into (local heads, head_dim).
    q = q.view(*q.shape[:-1], self.local_num_heads, self.head_dim)
    k = k.view(*k.shape[:-1], self.local_num_kv_heads, self.head_dim)
    v = v.view(*v.shape[:-1], self.local_num_kv_heads, self.head_dim)

    if self.qk_norm:
        # allow_inplace=True: the helper is permitted to overwrite q and k.
        q, k = apply_qk_norm(
            q=q,
            k=k,
            q_norm=self.norm_q,
            k_norm=self.norm_k,
            head_dim=self.head_dim,
            allow_inplace=True,
        )

    if freqs_cis is not None:
        cos, sin = freqs_cis
        if _is_cuda and q.shape == k.shape:
            # CUDA path, taken only when q and k have identical shapes
            # (i.e. the same number of local q and kv heads).
            cos_sin_cache = torch.cat(
                [
                    cos.to(dtype=torch.float32).contiguous(),
                    sin.to(dtype=torch.float32).contiguous(),
                ],
                dim=-1,
            )
            q, k = apply_flashinfer_rope_qk_inplace(
                q, k, cos_sin_cache, is_neox=False
            )
        else:
            # Fallback: rotate q and k separately.
            q = _apply_rotary_emb(q, cos, sin, is_neox_style=False)
            k = _apply_rotary_emb(k, cos, sin, is_neox_style=False)

    hidden_states = self.attn(q, k, v)
    # Merge heads back into the channel dim.
    hidden_states = hidden_states.flatten(2)

    hidden_states, _ = self.to_out[0](hidden_states)

    return hidden_states
ff = diffusers.models.attention.FeedForward( + dim=dim, + dim_out=dim, + activation_fn="swiglu", + inner_dim=hidden_dim, + bias=False, + ) + nunchaku_kwargs = { + "precision": quant_config.precision, + "rank": quant_config.rank, + "act_unsigned": quant_config.act_unsigned, + } + self.feed_forward = NunchakuFeedForward(ff, **nunchaku_kwargs) + # NunchakuFeedForward overrides net[2].act_unsigned=True for int4 (GELU-specific + # optimization for non-negative activations). Z-Image uses SwiGLU whose output + # can be negative, so we must restore the original act_unsigned value. + if hasattr(self.feed_forward, "net") and len(self.feed_forward.net) > 2: + self.feed_forward.net[2].act_unsigned = quant_config.act_unsigned + else: + self.feed_forward = FeedForward(dim=dim, hidden_dim=hidden_dim) + + self.attention_norm1 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) + + self.attention_norm2 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm2 = RMSNorm(dim, eps=norm_eps) + + if modulation: + self.adaLN_modulation = nn.Sequential( + ReplicatedLinear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True) + ) + + def forward( + self, + x: torch.Tensor, + freqs_cis: Tuple[torch.Tensor, torch.Tensor], + adaln_input: Optional[torch.Tensor] = None, + ): + if self.modulation: + assert adaln_input is not None + scale_msa_gate, _ = self.adaLN_modulation(adaln_input) + scale_msa, gate_msa, scale_mlp, gate_mlp = scale_msa_gate.unsqueeze( + 1 + ).chunk(4, dim=2) + gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() + scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp + + # Attention block + attn_out = self.attention( + self.attention_norm1(x) * scale_msa, + freqs_cis=freqs_cis, + ) + x = x + gate_msa * self.attention_norm2(attn_out) + + # FFN block + x = x + gate_mlp * self.ffn_norm2( + self.feed_forward( + self.ffn_norm1(x) * scale_mlp, + ) + ) + else: + # Attention block + attn_out = self.attention( + self.attention_norm1(x), + freqs_cis=freqs_cis, + ) + x = x + 
class RopeEmbedder:
    """Multi-axis rotary table lookup.

    One cos/sin table is built per axis (``axes_lens[i]`` positions by
    ``axes_dims[i] // 2`` frequencies). Calling the embedder gathers a row
    from every table using the matching column of ``ids`` and concatenates
    the rows along the last dim.
    """

    def __init__(
        self,
        theta: float = 256.0,
        axes_dims: List[int] = (16, 56, 56),
        axes_lens: List[int] = (64, 128, 128),
    ):
        self.theta = theta
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens
        assert len(axes_dims) == len(
            axes_lens
        ), "axes_dims and axes_lens must have the same length"

        # Tables are built lazily on the first call.
        self.cos_cached = None
        self.sin_cached = None

    @staticmethod
    def precompute_freqs(dim: List[int], end: List[int], theta: float = 256.0):
        """Return ``(cos_tables, sin_tables)``, one float32 CPU table per axis."""
        with torch.device("cpu"):
            cos_tables, sin_tables = [], []
            for axis_dim, axis_len in zip(dim, end):
                # Frequencies and positions are computed in float64 and only
                # cast to float32 once the angles have been formed.
                exponents = (
                    torch.arange(0, axis_dim, 2, dtype=torch.float64, device="cpu")
                    / axis_dim
                )
                inv_freq = 1.0 / (theta**exponents)
                positions = torch.arange(
                    axis_len, device=inv_freq.device, dtype=torch.float64
                )
                angles = torch.outer(positions, inv_freq).float()

                cos_tables.append(torch.cos(angles))
                sin_tables.append(torch.sin(angles))

            return cos_tables, sin_tables

    def __call__(self, ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            ids: [batch, len(axes_dims)] or [seq_len, len(axes_dims)]
        Returns:
            cos: [batch/seq, head_dim // 2]
            sin: [batch/seq, head_dim // 2]
        """
        assert ids.ndim == 2
        assert ids.shape[-1] == len(self.axes_dims)
        target = ids.device

        # Build on first use; afterwards only move when the device changed.
        must_move = True
        if self.cos_cached is None:
            self.cos_cached, self.sin_cached = self.precompute_freqs(
                self.axes_dims, self.axes_lens, theta=self.theta
            )
        else:
            must_move = self.cos_cached[0].device != target
        if must_move:
            self.cos_cached = [table.to(target) for table in self.cos_cached]
            self.sin_cached = [table.to(target) for table in self.sin_cached]

        cos_rows, sin_rows = [], []
        for axis in range(len(self.axes_dims)):
            axis_ids = ids[:, axis]
            cos_rows.append(self.cos_cached[axis][axis_ids])
            sin_rows.append(self.sin_cached[axis][axis_ids])

        return torch.cat(cos_rows, dim=-1), torch.cat(sin_rows, dim=-1)
arch_config.num_attention_heads + + self.rope_theta = arch_config.rope_theta + self.t_scale = arch_config.t_scale + self.gradient_checkpointing = False + + assert len(self.all_patch_size) == len(self.all_f_patch_size) + + all_x_embedder = {} + all_final_layer = {} + for patch_idx, (patch_size, f_patch_size) in enumerate( + zip(self.all_patch_size, self.all_f_patch_size) + ): + x_embedder = ColumnParallelLinear( + f_patch_size * patch_size * patch_size * self.in_channels, + self.dim, + bias=True, + gather_output=True, + ) + all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder + + final_layer = FinalLayer( + self.dim, patch_size * patch_size * f_patch_size * self.out_channels + ) + all_final_layer[f"{patch_size}-{f_patch_size}"] = final_layer + + self.all_x_embedder = nn.ModuleDict(all_x_embedder) + self.all_final_layer = nn.ModuleDict(all_final_layer) + + self.noise_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + 1000 + layer_id, + self.dim, + self.n_heads, + arch_config.n_kv_heads, + arch_config.norm_eps, + arch_config.qk_norm, + modulation=True, + quant_config=quant_config, + prefix=f"noise_refiner.{layer_id}", + ) + for layer_id in range(arch_config.n_refiner_layers) + ] + ) + self.context_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + layer_id, + self.dim, + self.n_heads, + arch_config.n_kv_heads, + arch_config.norm_eps, + arch_config.qk_norm, + modulation=False, + quant_config=quant_config, + prefix=f"context_refiner.{layer_id}", + ) + for layer_id in range(arch_config.n_refiner_layers) + ] + ) + self.t_embedder = TimestepEmbedder( + min(self.dim, ADALN_EMBED_DIM), mid_size=1024 + ) + + self.cap_embedder = nn.Sequential( + RMSNorm(arch_config.cap_feat_dim, eps=arch_config.norm_eps), + ReplicatedLinear(arch_config.cap_feat_dim, self.dim, bias=True), + ) + + self.x_pad_token = nn.Parameter(torch.empty((1, self.dim))) + self.cap_pad_token = nn.Parameter(torch.empty((1, self.dim))) + + self.layers = nn.ModuleList( + [ + 
ZImageTransformerBlock( + layer_id, + self.dim, + self.n_heads, + arch_config.n_kv_heads, + arch_config.norm_eps, + arch_config.qk_norm, + quant_config=quant_config, + prefix=f"layers.{layer_id}", + ) + for layer_id in range(arch_config.num_layers) + ] + ) + head_dim = self.dim // self.n_heads + assert head_dim == sum(arch_config.axes_dims) + self.axes_dims = arch_config.axes_dims + self.axes_lens = arch_config.axes_lens + + self.rotary_emb = RopeEmbedder( + theta=self.rope_theta, axes_dims=self.axes_dims, axes_lens=self.axes_lens + ) + self.layer_names = ["layers"] + + def unpatchify( + self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size + ) -> List[torch.Tensor]: + pH = pW = patch_size + pF = f_patch_size + bsz = len(x) + assert len(size) == bsz + for i in range(bsz): + F, H, W = size[i] + ori_len = (F // pF) * (H // pH) * (W // pW) + # "f h w pf ph pw c -> c (f pf) (h ph) (w pw)" + x[i] = ( + x[i][:ori_len] + .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels) + .permute(6, 0, 3, 1, 4, 2, 5) + .reshape(self.out_channels, F, H, W) + ) + return x + + @staticmethod + def create_coordinate_grid(size, start=None, device=None): + if start is None: + start = (0 for _ in size) + + axes = [ + torch.arange(x0, x0 + span, dtype=torch.int32, device=device) + for x0, span in zip(start, size) + ] + grids = torch.meshgrid(axes, indexing="ij") + return torch.stack(grids, dim=-1) + + def patchify_and_embed( + self, + all_image: List[torch.Tensor], + all_cap_feats: List[torch.Tensor], + patch_size: int, + f_patch_size: int, + ): + assert len(all_image) == len(all_cap_feats) == 1 + + image = all_image[0] # C, F, H, W + cap_feat = all_cap_feats[0] # L, D + pH = pW = patch_size + pF = f_patch_size + device = image.device + + all_image_out = [] + all_image_size = [] + all_cap_feats_out = [] + + # ------------ Process Caption ------------ + cap_ori_len = cap_feat.size(0) + cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF + + # padded feature + 
def forward(
    self,
    hidden_states: List[torch.Tensor],
    encoder_hidden_states: List[torch.Tensor],
    timestep,
    guidance=0,
    patch_size=2,
    f_patch_size=1,
    freqs_cis=None,
    **kwargs,
):
    """Run the Z-Image transformer on a single image / caption pair.

    Args:
        hidden_states: One image latent, indexable as ``hidden_states[0]``
            with layout ``(C, F, H, W)`` (a one-element list or a tensor with
            a leading dim of 1).
        encoder_hidden_states: One caption feature tensor, indexable the
            same way.
        timestep: Timestep tensor; ``1000.0 - timestep`` is what gets embedded.
        guidance: Unused; kept for interface compatibility.
        patch_size: Spatial patch size; must be in ``self.all_patch_size``.
        f_patch_size: Frame patch size; must be in ``self.all_f_patch_size``.
        freqs_cis: Required despite the ``None`` default. Indexed as
            ``freqs_cis[0]`` (caption) and ``freqs_cis[1]`` (image), each a
            ``(cos, sin)`` pair.
        **kwargs: Ignored.

    Returns:
        The negated un-patchified output for the single sample.

    Raises:
        TypeError: If ``freqs_cis`` is ``None``.
    """
    assert patch_size in self.all_patch_size
    assert f_patch_size in self.all_f_patch_size
    if freqs_cis is None:
        # Previously this surfaced later as an opaque "'NoneType' object is
        # not subscriptable"; same exception type, clearer message.
        raise TypeError("freqs_cis must be provided as (caption, image) rotary tables")

    x = hidden_states
    cap_feats = encoder_hidden_states
    timestep = 1000.0 - timestep
    t = self.t_embedder(timestep)
    # Match the dtype of the image latent. x may be a Python list here (as the
    # signature declares), and Tensor.type_as() needs a tensor, so use x[0];
    # for a tensor input x[0] has the same dtype/device as x.
    adaln_input = t.type_as(x[0])
    (
        x,
        cap_feats,
        x_size,
    ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)

    embed_key = f"{patch_size}-{f_patch_size}"

    # Image tokens: embed, add a batch dim, refine with modulation.
    x = torch.cat(x, dim=0)
    x, _ = self.all_x_embedder[embed_key](x)
    x_freqs_cis = freqs_cis[1]

    x = x.unsqueeze(0)
    for layer in self.noise_refiner:
        x = layer(x, x_freqs_cis, adaln_input)

    # Caption tokens: embed, add a batch dim, refine without modulation.
    cap_feats = torch.cat(cap_feats, dim=0)

    cap_feats, _ = self.cap_embedder(cap_feats)

    cap_freqs_cis = freqs_cis[0]

    cap_feats = cap_feats.unsqueeze(0)
    for layer in self.context_refiner:
        cap_feats = layer(cap_feats, cap_freqs_cis)

    # Joint sequence: image tokens first, then caption tokens; the rotary
    # tables are concatenated in the same order.
    unified = torch.cat([x, cap_feats], dim=1)
    unified_freqs_cis = (
        torch.cat([x_freqs_cis[0], cap_freqs_cis[0]], dim=0),
        torch.cat([x_freqs_cis[1], cap_freqs_cis[1]], dim=0),
    )

    for layer in self.layers:
        unified = layer(unified, unified_freqs_cis, adaln_input)

    unified = self.all_final_layer[embed_key](unified, adaln_input)
    unified = list(unified.unbind(dim=0))
    # unpatchify keeps only the leading image tokens of each sequence.
    x = self.unpatchify(unified, x_size, patch_size, f_patch_size)

    return -x[0]
@abstractmethod + def forward( + self, + input_ids: torch.Tensor | None, + position_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, + output_hidden_states: bool | None = None, + **kwargs, + ) -> BaseEncoderOutput: + pass + + @property + def supported_attention_backends(self) -> set[AttentionBackendEnum]: + return self._supported_attention_backends + + +class ImageEncoder(nn.Module, ABC): + _supported_attention_backends: set[AttentionBackendEnum] = ( + ImageEncoderConfig()._supported_attention_backends + ) + + def __init__(self, config: ImageEncoderConfig) -> None: + super().__init__() + self.config = config + if not self.supported_attention_backends: + raise ValueError( + f"Subclass {self.__class__.__name__} must define _supported_attention_backends" + ) + + @abstractmethod + def forward(self, pixel_values: torch.Tensor, **kwargs) -> BaseEncoderOutput: + pass + + @property + def supported_attention_backends(self) -> set[AttentionBackendEnum]: + return self._supported_attention_backends diff --git a/sglang/python/sglang/multimodal_gen/runtime/models/encoders/bert.py b/sglang/python/sglang/multimodal_gen/runtime/models/encoders/bert.py new file mode 100644 index 0000000000000000000000000000000000000000..5a423e51b8965f76c2cbad23abd3bbbc6c08dfbd --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/models/encoders/bert.py @@ -0,0 +1,46 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# type: ignore +import os + +import torch +import torch.nn as nn +from transformers import BertModel, BertTokenizer + + +class HunyuanClip(nn.Module): + """ + Hunyuan clip code copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py + hunyuan's clip used BertModel and BertTokenizer, so we copy it. 
+ """ + + def __init__(self, model_dir, max_length=77): + super().__init__() + + self.max_length = max_length + self.tokenizer = BertTokenizer.from_pretrained( + os.path.join(model_dir, "tokenizer") + ) + self.text_encoder = BertModel.from_pretrained( + os.path.join(model_dir, "clip_text_encoder") + ) + + @torch.no_grad + def forward(self, prompts, with_mask=True): + self.device = next(self.text_encoder.parameters()).device + text_inputs = self.tokenizer( + prompts, + padding="max_length", + max_length=self.max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + prompt_embeds = self.text_encoder( + text_inputs.input_ids.to(self.device), + attention_mask=( + text_inputs.attention_mask.to(self.device) if with_mask else None + ), + ) + return prompt_embeds.last_hidden_state, prompt_embeds.pooler_output diff --git a/sglang/python/sglang/multimodal_gen/runtime/models/encoders/clip.py b/sglang/python/sglang/multimodal_gen/runtime/models/encoders/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..83fdefd8cb3eaa7b2dfa5a882b5636ec77d4273a --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/models/encoders/clip.py @@ -0,0 +1,758 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/clip.py +# Adapted from transformers: https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py +"""Minimal implementation of CLIPVisionModel intended to be only used +within a vision language model.""" + +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.nn as nn + +from sglang.multimodal_gen.configs.models.encoders import ( + BaseEncoderOutput, + CLIPTextConfig, + CLIPVisionConfig, +) +from sglang.multimodal_gen.runtime.distributed import divide, get_tp_world_size +from 
sglang.multimodal_gen.runtime.layers.activation import get_act_fn +from sglang.multimodal_gen.runtime.layers.attention import LocalAttention +from sglang.multimodal_gen.runtime.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from sglang.multimodal_gen.runtime.layers.quantization import QuantizationConfig + +# TODO: support quantization +# from vllm.model_executor.layers.quantization import QuantizationConfig +from sglang.multimodal_gen.runtime.loader.weight_utils import default_weight_loader +from sglang.multimodal_gen.runtime.models.encoders.base import ImageEncoder, TextEncoder +from sglang.multimodal_gen.runtime.models.encoders.vision import ( + resolve_visual_encoder_outputs, +) +from sglang.multimodal_gen.runtime.platforms import ( + AttentionBackendEnum, + current_platform, +) +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa +class CLIPVisionEmbeddings(nn.Module): + + def __init__(self, config: CLIPVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + assert self.image_size % self.patch_size == 0 + + self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer( + "position_ids", + torch.arange(self.num_positions).expand((1, -1)), + persistent=False, + ) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + 
batch_size = pixel_values.shape[0] + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding( + pixel_values.to(dtype=target_dtype) + ) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + self.position_embedding(self.position_ids) + + return embeddings + + +class CLIPTextEmbeddings(nn.Module): + + def __init__(self, config: CLIPTextConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) + self.position_embedding = nn.Embedding( + config.max_position_embeddings, embed_dim + ) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", + torch.arange(config.max_position_embeddings).expand((1, -1)), + persistent=False, + ) + + def forward( + self, + input_ids: torch.LongTensor | None = None, + position_ids: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + ) -> torch.Tensor: + if input_ids is not None: + seq_length = input_ids.shape[-1] + elif inputs_embeds is not None: + seq_length = inputs_embeds.shape[-2] + else: + raise ValueError("Either input_ids or inputs_embeds must be provided.") + + max_position_embedding = self.position_embedding.weight.shape[0] + + if seq_length > max_position_embedding: + raise ValueError( + f"Sequence length must be less than max_position_embeddings (got `sequence length`: " + f"{seq_length} and max_position_embeddings: {max_position_embedding}" + ) + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + 
position_embeddings + + return embeddings + + +class CLIPAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: CLIPVisionConfig | CLIPTextConfig, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + "embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.qkv_proj = QKVParallelLinear( + hidden_size=self.embed_dim, + head_size=self.head_dim, + total_num_heads=self.num_heads, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.out_proj = RowParallelLinear( + input_size=self.embed_dim, + output_size=self.embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + + self.tp_size = get_tp_world_size() + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + + self.attn = LocalAttention( + self.num_heads_per_partition, + self.head_dim, + self.num_heads_per_partition, + softmax_scale=self.scale, + causal=True, + supported_attention_backends=config._supported_attention_backends, + ) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + ): + """Input shape: Batch x Time x Channel""" + + qkv_states, _ = self.qkv_proj(hidden_states) + query_states, key_states, value_states = qkv_states.chunk(3, dim=-1) + # use flash_attn_func + query_states = query_states.reshape( + query_states.shape[0], + 
query_states.shape[1], + self.num_heads_per_partition, + self.head_dim, + ) + key_states = key_states.reshape( + key_states.shape[0], + key_states.shape[1], + self.num_heads_per_partition, + self.head_dim, + ) + value_states = value_states.reshape( + value_states.shape[0], + value_states.shape[1], + self.num_heads_per_partition, + self.head_dim, + ) + + if self.attn.backend == AttentionBackendEnum.TORCH_SDPA: + query_states = query_states.transpose(1, 2) # [B, H, S, D] + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if current_platform.is_rocm() or current_platform.is_musa(): + # ROCm: Using both is_causal=True and attn_mask causes NaN. + # Use is_causal=True alone (padding mask not needed for CLIP + # since pooler_output comes from EOS token before padding). + # XXX (MUSA): Torch SDPA on MUSA currently does not support + # using both `attn_mask` and `is_causal=True` simultaneously. + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=None, + is_causal=True, + scale=self.scale, + ) + else: + if attention_mask is not None: + # SDPA requires [B, 1, 1, S] or [B, S, S] format mask + if attention_mask.dim() == 2: + attn_mask = attention_mask[:, None, None, :].to( + dtype=query_states.dtype + ) + attn_mask = (1.0 - attn_mask) * torch.finfo( + query_states.dtype + ).min + else: + attn_mask = attention_mask + else: + attn_mask = None + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attn_mask, + is_causal=attention_mask is None, + scale=self.scale, + ) + attn_output = attn_output.transpose(1, 2) + else: + # Use LocalAttention (doesn't support attention_mask, but maintains compatibility) + attn_output = self.attn(query_states, key_states, value_states) + + attn_output = attn_output.reshape( + attn_output.shape[0], + attn_output.shape[1], + self.num_heads_per_partition * self.head_dim, + 
) + attn_output, _ = self.out_proj(attn_output) + + return attn_output, None + + +class CLIPMLP(nn.Module): + + def __init__( + self, + config: CLIPVisionConfig | CLIPTextConfig, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + + return hidden_states + + +class CLIPEncoderLayer(nn.Module): + + def __init__( + self, + config: CLIPTextConfig | CLIPVisionConfig, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.self_attn = CLIPAttention( + config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = CLIPMLP(config, quant_config=quant_config, prefix=f"{prefix}.mlp") + self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return 
hidden_states + + +class CLIPEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self + attention layers. Each layer is a [`CLIPEncoderLayer`]. + + Args: + config: CLIPConfig + """ + + def __init__( + self, + config: CLIPVisionConfig | CLIPTextConfig, + quant_config: QuantizationConfig | None = None, + num_hidden_layers_override: int | None = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + + if num_hidden_layers_override is None: + num_hidden_layers = config.num_hidden_layers + else: + num_hidden_layers = num_hidden_layers_override + self.layers = nn.ModuleList( + [ + CLIPEncoderLayer( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}", + ) + for layer_idx in range(num_hidden_layers) + ] + ) + + def forward( + self, + inputs_embeds: torch.Tensor, + return_all_hidden_states: bool, + attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor | list[torch.Tensor]: + hidden_states_pool = [inputs_embeds] + hidden_states = inputs_embeds + + for idx, encoder_layer in enumerate(self.layers): + hidden_states = encoder_layer( + hidden_states, + attention_mask=attention_mask, + ) + if return_all_hidden_states: + hidden_states_pool.append(hidden_states) + # If we have multiple feature sample layers, we return all hidden + # states in order and grab the ones we need by index. 
+ if return_all_hidden_states: + return hidden_states_pool + return [hidden_states] + + +class CLIPTextTransformer(nn.Module): + + def __init__( + self, + config: CLIPTextConfig, + quant_config: QuantizationConfig | None = None, + num_hidden_layers_override: int | None = None, + prefix: str = "", + ): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = CLIPTextEmbeddings(config) + + self.encoder = CLIPEncoder( + config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_override, + prefix=prefix, + ) + + self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + # For `pooled_output` computation + self.eos_token_id = config.eos_token_id + + def forward( + self, + input_ids: torch.Tensor | None, + position_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, + output_hidden_states: bool | None = None, + ) -> BaseEncoderOutput: + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) + + # CLIP's text model uses causal mask, prepare it here. 
+ # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 + # causal_attention_mask = _create_4d_causal_attention_mask( + # input_shape, hidden_states.dtype, device=hidden_states.device + # ) + + # # expand attention_mask + # if attention_mask is not None and not self._use_flash_attention_2: + # raise NotImplementedError("attention_mask is not supported for CLIPTextTransformer") + # # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + # attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + return_all_hidden_states=output_hidden_states, + attention_mask=attention_mask, + ) + + last_hidden_state = encoder_outputs[-1] + last_hidden_state = self.final_layer_norm(last_hidden_state) + + if self.eos_token_id == 2: + # The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here. + # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added + # ------------------------------------------------------------ + # text_embeds.shape = [batch_size, sequence_length, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 + pooled_output = last_hidden_state[ + torch.arange( + last_hidden_state.shape[0], device=last_hidden_state.device + ), + input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax( + dim=-1 + ), + ] + else: + # The config gets updated `eos_token_id` from PR #24773 (so the use of exta new tokens is possible) + pooled_output = last_hidden_state[ + torch.arange( + last_hidden_state.shape[0], device=last_hidden_state.device + ), + # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`) + # Note: we assume each sequence (along batch dim.) 
contains an `eos_token_id` (e.g. prepared by the tokenizer) + ( + input_ids.to(dtype=torch.int, device=last_hidden_state.device) + == self.eos_token_id + ) + .int() + .argmax(dim=-1), + ] + + return BaseEncoderOutput( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs, + # attentions=encoder_outputs.attentions, + ) + + +class CLIPTextModel(TextEncoder): + + def __init__( + self, + config: CLIPTextConfig, + ) -> None: + super().__init__(config) + self.text_model = CLIPTextTransformer( + config=config, quant_config=config.quant_config, prefix=config.prefix + ) + + def forward( + self, + input_ids: torch.Tensor | None, + position_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, + output_hidden_states: bool | None = None, + **kwargs, + ) -> BaseEncoderOutput: + + outputs: BaseEncoderOutput = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_hidden_states=output_hidden_states, + ) + return outputs + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + + # Define mapping for stacked parameters + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + # Handle q_proj, k_proj, v_proj -> qkv_proj mapping + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name in name: + # Replace the weight name with the parameter name + model_param_name = name.replace(weight_name, param_name) + + if model_param_name in params_dict: + param = params_dict[model_param_name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + loaded_params.add(model_param_name) + break + else: + # Use default 
weight loader for all other parameters + if name in params_dict: + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params + + +class CLIPVisionTransformer(nn.Module): + + def __init__( + self, + config: CLIPVisionConfig, + quant_config: QuantizationConfig | None = None, + num_hidden_layers_override: int | None = None, + require_post_norm: bool | None = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + embed_dim = config.hidden_size + + self.embeddings = CLIPVisionEmbeddings(config) + + # NOTE: This typo of "layrnorm" is not fixed on purpose to match + # the original transformers code and name of the model weights. + self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + self.encoder = CLIPEncoder( + config=config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_override, + prefix=f"{prefix}.encoder", + ) + + num_hidden_layers = config.num_hidden_layers + if len(self.encoder.layers) > config.num_hidden_layers: + raise ValueError( + f"The original encoder only has {num_hidden_layers} " + f"layers, but you requested {len(self.encoder.layers)} layers." 
+ ) + + # If possible, skip post_layernorm to conserve memory + if require_post_norm is None: + require_post_norm = len(self.encoder.layers) == num_hidden_layers + + if require_post_norm: + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + else: + self.post_layernorm = None + + def forward( + self, + pixel_values: torch.Tensor, + output_hidden_states: Optional[bool] = None, + feature_sample_layers: list[int] | None = None, + ) -> BaseEncoderOutput: + + hidden_states = self.embeddings(pixel_values) + hidden_states = self.pre_layrnorm(hidden_states) + + return_all_hidden_states = output_hidden_states or ( + feature_sample_layers is not None + ) + + # Produces either the last layer output or all of the hidden states, + # depending on if we have feature_sample_layers or not + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + return_all_hidden_states=return_all_hidden_states, + ) + + if not return_all_hidden_states: + encoder_outputs = encoder_outputs[0] + + # Handle post-norm (if applicable) and stacks feature layers if needed + encoder_outputs = resolve_visual_encoder_outputs( + encoder_outputs, + feature_sample_layers, + self.post_layernorm, + self.config.num_hidden_layers, + ) + + if return_all_hidden_states: + return BaseEncoderOutput(hidden_states=encoder_outputs) + + return BaseEncoderOutput(last_hidden_state=encoder_outputs) + + +class CLIPVisionModel(ImageEncoder): + config_class = CLIPVisionConfig + main_input_name = "pixel_values" + packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} + + def __init__(self, config: CLIPVisionConfig) -> None: + super().__init__(config) + self.vision_model = CLIPVisionTransformer( + config=config, + quant_config=config.quant_config, + num_hidden_layers_override=config.num_hidden_layers_override, + require_post_norm=config.require_post_norm, + prefix=f"{config.prefix}.vision_model", + ) + + def forward( + self, + pixel_values: torch.Tensor, + feature_sample_layers: 
list[int] | None = None, + output_hidden_states: Optional[bool] = None, + **kwargs, + ) -> BaseEncoderOutput: + base_encoder_output = self.vision_model( + pixel_values, + output_hidden_states=output_hidden_states, + feature_sample_layers=feature_sample_layers, + ) + + return base_encoder_output + + @property + def device(self): + return next(self.parameters()).device + + # (TODO) Add prefix argument for filtering out weights to be loaded + # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986 + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + layer_count = len(self.vision_model.encoder.layers) + + for name, loaded_weight in weights: + if name.startswith("visual_projection"): + continue + # post_layernorm is not needed in CLIPVisionModel + if ( + name.startswith("vision_model.post_layernorm") + and self.vision_model.post_layernorm is None + ): + continue + + # omit layers when num_hidden_layers_override is set + if name.startswith("vision_model.encoder.layers"): + layer_idx = int(name.split(".")[3]) + if layer_idx >= layer_count: + continue + + for ( + param_name, + weight_name, + shard_id, + ) in self.config.arch_config.stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class BertModel(CLIPTextModel): + pass + + +EntryClass = [CLIPTextModel, CLIPVisionModel] diff --git a/sglang/python/sglang/multimodal_gen/runtime/models/encoders/gemma_3.py b/sglang/python/sglang/multimodal_gen/runtime/models/encoders/gemma_3.py new file mode 100644 index 
0000000000000000000000000000000000000000..0927645fbfe656e8628d196d5ed5abddcdf89d17 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/models/encoders/gemma_3.py @@ -0,0 +1,1187 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from sglang: python/sglang/srt/models/gemma3_causal.py + +import logging +from functools import partial +from typing import Any, Iterable, Optional, Set, Tuple + +import torch +from torch import nn + +from sglang.multimodal_gen.configs.models.encoders.base import BaseEncoderOutput +from sglang.multimodal_gen.configs.models.encoders.gemma_3 import Gemma3Config +from sglang.multimodal_gen.runtime.distributed import get_tp_world_size +from sglang.multimodal_gen.runtime.layers.activation import GeluAndMul +from sglang.multimodal_gen.runtime.layers.attention import LocalAttention +from sglang.multimodal_gen.runtime.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from sglang.multimodal_gen.runtime.layers.quantization import QuantizationConfig +from sglang.multimodal_gen.runtime.layers.rotary_embedding import get_rope +from sglang.multimodal_gen.runtime.loader.weight_utils import default_weight_loader +from sglang.multimodal_gen.runtime.utils.common import add_prefix + +logger = logging.getLogger(__name__) + + +def get_attention_sliding_window_size(config): + return config.sliding_window - 1 + + +class Gemma3RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + output = output * (1.0 + self.weight.float()) + return output.type_as(x) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + +class 
Gemma3MLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "gelu_pytorch_tanh": + raise ValueError( + "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation " + "function. Please set `hidden_activation` to " + "`gelu_pytorch_tanh`." + ) + self.act_fn = GeluAndMul(approximate="tanh") + + def forward(self, x): + x, _ = self.gate_up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + + +def _rotate_half(x: torch.Tensor) -> torch.Tensor: + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +class Gemma3Attention(nn.Module): + def __init__( + self, + layer_id: int, + config: Gemma3Config, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.layer_id = layer_id + self.hidden_size = hidden_size + tp_size = get_tp_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + + self.head_dim = getattr( + config.text_config, "head_dim", self.hidden_size // self.total_num_heads + ) + + self.q_size = self.num_heads * 
self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = config.text_config.query_pre_attn_scalar**-0.5 + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=config.text_config.attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=config.text_config.attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.is_sliding = ( + config.text_config.layer_types[layer_id] == "sliding_attention" + ) + + # Initialize the rotary embedding. + if self.is_sliding: + # Local attention. + self.rope_theta = config.text_config.rope_local_base_freq + rope_scaling = None # Default + # sliding window + self.sliding_window = get_attention_sliding_window_size(config.text_config) + # (left, right) = (window, 0) effectively for causal + self.window_size = (self.sliding_window, 0) + else: + # Global attention. + self.rope_theta = config.text_config.rope_theta + rope_scaling = config.text_config.rope_scaling + self.sliding_window = None + self.window_size = (-1, -1) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=config.text_config.max_position_embeddings, + base=self.rope_theta, + rope_scaling=rope_scaling, + is_neox_style=True, + ) + + # NOTE(gmixiaojin): The shared RotaryEmbedding above computes inv_freq on + # GPU and uses the x1*cos - x2*sin formula, which causes slight + # numerical differences vs HuggingFace (see the NOTE in + # rotary_embedding.py:_compute_inv_freq). For HF-exact alignment we + # precompute inv_freq on CPU and use rotate_half in self.rotary_emb(). 
+ freq_indices = ( + torch.arange(0, self.head_dim, 2, dtype=torch.int64).float() / self.head_dim + ) + inv_freq = 1.0 / (self.rope_theta**freq_indices) + if rope_scaling and rope_scaling.get("factor"): + inv_freq = inv_freq / float(rope_scaling["factor"]) + self.register_buffer("_hf_inv_freq", inv_freq, persistent=False) + + # Local Attention not support attention mask, we use global attention instead. + # self.attn = LocalAttention( + # self.num_heads, + # self.head_dim, + # self.num_kv_heads, + # softmax_scale=self.scaling, + # causal=True, + # supported_attention_backends=config._supported_attention_backends, + # window_size=self.window_size, + # ) + + # Gemma3 adds normalization for q and k + self.q_norm = Gemma3RMSNorm( + dim=self.head_dim, eps=config.text_config.rms_norm_eps + ) + self.k_norm = Gemma3RMSNorm( + dim=self.head_dim, eps=config.text_config.rms_norm_eps + ) + + def rotary_emb(self, positions, q, k): + """Apply RoPE using HF-exact formula with precomputed inv_freq.""" + positions_flat = positions.flatten().float() + num_tokens = positions_flat.shape[0] + + with torch.autocast(device_type=q.device.type, enabled=False): + freqs = torch.outer(positions_flat, self._hf_inv_freq.float()) + emb = freqs.repeat(1, 2) + cos = emb.cos().to(q.dtype).unsqueeze(1) + sin = emb.sin().to(q.dtype).unsqueeze(1) + + q = q.reshape(num_tokens, -1, self.head_dim) + k = k.reshape(num_tokens, -1, self.head_dim) + q = q * cos + _rotate_half(q) * sin + k = k * cos + _rotate_half(k) * sin + return q, k + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + batch_size, seq_len, _ = q.shape + q = q.view(batch_size, seq_len, self.num_heads, self.head_dim) + k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim) + v = v.view(batch_size, seq_len, 
self.num_kv_heads, self.head_dim) + + # Apply QK Norm + q = self.q_norm(q) + k = self.k_norm(k) + + # Apply RoPE + q, k = self.rotary_emb(positions, q, k) + q = q.reshape(batch_size, seq_len, self.num_heads, self.head_dim) + k = k.reshape(batch_size, seq_len, self.num_kv_heads, self.head_dim) + + # TODO(FlamingoPg): Support LocalAttention + query = q.transpose(1, 2) + key = k.transpose(1, 2) + value = v.transpose(1, 2) + + min_val = torch.finfo(query.dtype).min + attn_mask = torch.zeros( + (seq_len, seq_len), + device=hidden_states.device, + dtype=query.dtype, + ) + causal = torch.triu( + torch.ones( + (seq_len, seq_len), device=hidden_states.device, dtype=torch.bool + ), + diagonal=1, + ) + attn_mask = attn_mask.masked_fill(causal, min_val) + if self.is_sliding and self.sliding_window is not None: + idx = torch.arange(seq_len, device=hidden_states.device) + dist = idx[None, :] - idx[:, None] + too_far = dist > self.sliding_window + attn_mask = attn_mask.masked_fill(too_far, min_val) + + key_pad = ~attention_mask.to(torch.bool) + attn_mask = attn_mask[None, None, :, :].expand(batch_size, 1, seq_len, seq_len) + attn_mask = attn_mask.masked_fill( + key_pad[:, None, None, :].expand(batch_size, 1, seq_len, seq_len), + min_val, + ) + + attn_kwargs = { + "attn_mask": attn_mask, + "dropout_p": 0.0, + "is_causal": False, + "scale": self.scaling, + } + if query.shape[1] != key.shape[1]: + attn_kwargs["enable_gqa"] = True + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, key, value, **attn_kwargs + ) + attn_output = attn_output.transpose(1, 2) + + attn_output = attn_output.reshape( + batch_size, seq_len, self.num_heads * self.head_dim + ) + + output, _ = self.o_proj(attn_output) + return output + + +class Gemma3DecoderLayer(nn.Module): + def __init__( + self, + layer_id: int, + config: Gemma3Config, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = 
config.text_config.hidden_size + self.self_attn = Gemma3Attention( + layer_id=layer_id, + config=config, + hidden_size=self.hidden_size, + num_heads=config.text_config.num_attention_heads, + num_kv_heads=getattr( + config.text_config, + "num_key_value_heads", + config.text_config.num_attention_heads, + ), + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = Gemma3MLP( + hidden_size=self.hidden_size, + intermediate_size=config.text_config.intermediate_size, + hidden_act=config.text_config.hidden_activation, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = Gemma3RMSNorm( + config.text_config.hidden_size, eps=config.text_config.rms_norm_eps + ) + self.post_attention_layernorm = Gemma3RMSNorm( + config.text_config.hidden_size, eps=config.text_config.rms_norm_eps + ) + self.pre_feedforward_layernorm = Gemma3RMSNorm( + config.text_config.hidden_size, eps=config.text_config.rms_norm_eps + ) + self.post_feedforward_layernorm = Gemma3RMSNorm( + config.text_config.hidden_size, eps=config.text_config.rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + attention_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + # Gemma3 uses "sandwich norm": + # x = x + norm(attn(norm(x))) + # So we treat input hidden_states as the residual base. 
+ + if residual is not None: + hidden_states = hidden_states + residual + residual = None + + residual_input = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + attention_mask=attention_mask, + ) + + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual_input + hidden_states + + # MLP + residual_mlp = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual_mlp + hidden_states + + return hidden_states, None + + +class Gemma3TextScaledWordEmbedding(nn.Embedding): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int, + embed_scale: Optional[float] = 1.0, + ): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor): + return super().forward(input_ids) * self.embed_scale + + +# --- Siglip Vision Model Implementation --- + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x * torch.sigmoid(1.702 * x) + + +class SiglipVisionEmbeddings(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + # Use simple Embedding for position embeddings (usually small enough) + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer( + "position_ids", + 
torch.arange(self.num_positions).expand((1, -1)), + persistent=False, + ) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding( + pixel_values.to(dtype=target_dtype) + ) # shape = [*, width, grid, grid] + embeddings = patch_embeds.flatten(2).transpose(1, 2) + embeddings = embeddings + self.position_embedding(self.position_ids) + + return embeddings + + +class SiglipMLP(nn.Module): + def __init__( + self, + config, + act_layer: type[nn.Module] = QuickGELU, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + quant_config=quant_config, + prefix=add_prefix("fc1", prefix), + ) + self.act = act_layer() + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("fc2", prefix), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_parallel, _ = self.fc1(x) + x_parallel = self.act(x_parallel) + x, _ = self.fc2(x_parallel) + return x + + +class SiglipAttention(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + tp_size = get_tp_world_size() + self.head_dim = hidden_size // num_heads + self.num_heads_per_partition = num_heads // tp_size + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=num_heads, + total_num_kv_heads=num_heads, + bias=True, + quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), + ) + + self.out_proj = RowParallelLinear( + input_size=hidden_size, + output_size=hidden_size, + bias=True, + quant_config=quant_config, + prefix=add_prefix("out_proj", prefix), 
+ ) + + self.attn = LocalAttention( + num_heads=self.num_heads_per_partition, + head_size=self.head_dim, + num_kv_heads=self.num_heads_per_partition, + softmax_scale=self.scaling, + causal=False, # Bidirectional for Vision + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.hidden_size // get_tp_world_size()] * 3, dim=-1) + + batch_size, seq_len, _ = q.shape + q = q.view(batch_size, seq_len, self.num_heads_per_partition, self.head_dim) + k = k.view(batch_size, seq_len, self.num_heads_per_partition, self.head_dim) + v = v.view(batch_size, seq_len, self.num_heads_per_partition, self.head_dim) + + attn_output = self.attn(q, k, v) + + attn_output = attn_output.reshape( + batch_size, seq_len, self.hidden_size // get_tp_world_size() + ) + + output, _ = self.out_proj(attn_output) + return output + + +class SiglipEncoderLayer(nn.Module): + def __init__( + self, + config, + act_layer: type[nn.Module] = QuickGELU, + norm_layer: type[nn.Module] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=config.layer_norm_eps) + self.layer_norm1 = norm_layer(config.hidden_size) + self.layer_norm2 = norm_layer(config.hidden_size) + self.self_attn = SiglipAttention( + hidden_size=config.hidden_size, + num_heads=config.num_attention_heads, + quant_config=quant_config, + prefix=add_prefix("self_attn", prefix), + ) + self.mlp = SiglipMLP( + config, + act_layer=act_layer, + quant_config=quant_config, + prefix=add_prefix("mlp", prefix), + ) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn(hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = 
self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class SiglipEncoder(nn.Module): + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + num_hidden_layers = config.num_hidden_layers + norm_layer = partial(nn.LayerNorm, eps=config.layer_norm_eps) + self.layers = nn.ModuleList( + [ + SiglipEncoderLayer( + config=config, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=add_prefix(f"layers.{layer_idx}", prefix), + ) + for layer_idx in range(num_hidden_layers) + ] + ) + + def forward( + self, + inputs_embeds: torch.Tensor, + ) -> torch.Tensor: + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states = encoder_layer(hidden_states) + return hidden_states + + +class SiglipVisionTransformer(nn.Module): + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + embed_dim = config.hidden_size + self.embeddings = SiglipVisionEmbeddings(config) + self.encoder = SiglipEncoder( + config=config, + quant_config=quant_config, + prefix=add_prefix("encoder", prefix), + ) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + @property + def device(self) -> torch.device: + return self.encoder.layers[0].layer_norm1.weight.device + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + hidden_states = self.embeddings(pixel_values.to(self.device)) + last_hidden_state = self.encoder(inputs_embeds=hidden_states) + last_hidden_state = self.post_layernorm(last_hidden_state) + return last_hidden_state + + +class SiglipVisionModel(nn.Module): + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.vision_model = SiglipVisionTransformer( + config, quant_config, 
prefix=add_prefix("vision_model", prefix) + ) + + @property + def device(self) -> torch.device: + return self.vision_model.device + + def forward(self, pixel_values: torch.Tensor): + return self.vision_model(pixel_values) + + +class Gemma3MultiModalProjector(nn.Module): + """Projector for Gemma3 multimodal.""" + + def __init__(self, config: Gemma3Config): + super().__init__() + + self.mm_input_projection_weight = nn.Parameter( + torch.zeros( + config.vision_config.hidden_size, config.text_config.hidden_size + ) + ) + + self.mm_soft_emb_norm = Gemma3RMSNorm( + config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps + ) + + self.patches_per_image = int( + config.vision_config.image_size // config.vision_config.patch_size + ) + self.tokens_per_side = int(config.mm_tokens_per_image**0.5) + self.kernel_size = self.patches_per_image // self.tokens_per_side + self.avg_pool = nn.AvgPool2d( + kernel_size=self.kernel_size, stride=self.kernel_size + ) + + def forward(self, vision_outputs: torch.Tensor) -> torch.Tensor: + batch_size, seq_length, hidden_size = vision_outputs.shape + + # Reshape for pooling + reshaped_vision_outputs = vision_outputs.transpose(1, 2) + reshaped_vision_outputs = reshaped_vision_outputs.reshape( + batch_size, hidden_size, self.patches_per_image, self.patches_per_image + ) + reshaped_vision_outputs = reshaped_vision_outputs.contiguous() + + # Apply pooling + pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs) + pooled_vision_outputs = pooled_vision_outputs.flatten(2) + pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2) + + # Apply normalization + normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs) + + # Project to text embedding space + projected_vision_outputs = torch.matmul( + normed_vision_outputs, self.mm_input_projection_weight + ) + + return projected_vision_outputs.type_as(vision_outputs) + + +class Gemma3TextModel(nn.Module): + def __init__(self, config: Gemma3Config): + 
super().__init__() + self.config = config + # TODO(yinfan.1024) support text encoding model quant later + self.quant_config = None + + # Use VocabParallelEmbedding + from sglang.multimodal_gen.runtime.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, + ) + + self.vocab_size = config.text_config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.text_config.hidden_size, + org_num_embeddings=config.text_config.vocab_size, + quant_config=self.quant_config, + ) + self.embed_scale = config.text_config.hidden_size**0.5 + + self.layers = nn.ModuleList( + [ + Gemma3DecoderLayer( + layer_id=i, + config=config, + quant_config=self.quant_config, + prefix=f"{config.text_config.prefix}.layers.{i}", + ) + for i in range(config.text_config.num_hidden_layers) + ] + ) + + self.norm = Gemma3RMSNorm( + config.text_config.hidden_size, eps=config.text_config.rms_norm_eps + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + out = self.embed_tokens(input_ids) + return out * torch.tensor(self.embed_scale, device=out.device, dtype=out.dtype) + + def forward( + self, + input_ids: torch.Tensor | None, + position_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, + output_hidden_states: bool | None = None, + **kwargs, + ) -> BaseEncoderOutput: + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + + residual = None + + if position_ids is None: + position_ids = torch.arange( + 0, hidden_states.shape[1], device=hidden_states.device + ).unsqueeze(0) + + all_hidden_states: tuple[Any, ...] 
| None = () if output_hidden_states else None + + for layer in self.layers: + if all_hidden_states is not None: + all_hidden_states += (hidden_states,) + + hidden_states, residual = layer( + position_ids, + hidden_states, + residual, + attention_mask=attention_mask, + ) + + hidden_states = self.norm(hidden_states) + + if all_hidden_states is not None: + all_hidden_states += (hidden_states,) + + output = BaseEncoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + ) + + return output + + def load_weights(self, weights: Any) -> set[str]: + # Copied from LlamaModel.load_weights but adapted + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + def _load_with_shard_id( + weight_loader, param, loaded_weight: torch.Tensor, shard_id + ) -> None: + """Call param.weight_loader with best-effort shard_id normalization. + + Different fused-QKV implementations expect different shard_id types: + - Some expect strings: "q"/"k"/"v" + - Some expect integer indices: 0/1/2 + We try the provided shard_id first, then fall back between str/int forms. + """ + try: + weight_loader(param, loaded_weight, shard_id) + return + except (AssertionError, TypeError): + pass + + # Fall back between common representations. + if isinstance(shard_id, str): + mapping = {"q": 0, "k": 1, "v": 2} + if shard_id in mapping: + weight_loader(param, loaded_weight, mapping[shard_id]) + return + if shard_id.isdigit(): + weight_loader(param, loaded_weight, int(shard_id)) + return + elif isinstance(shard_id, int): + mapping = {0: "q", 1: "k", 2: "v"} + if shard_id in mapping: + weight_loader(param, loaded_weight, mapping[shard_id]) + return + + # Re-raise with a clearer message. + raise TypeError( + f"Unsupported shard_id={shard_id!r} for weight_loader={weight_loader} " + f"(param={getattr(param, 'name', '')})." 
+ ) + + stacked_params_mapping = getattr( + getattr(self.config, "arch_config", object()), + "stacked_params_mapping", + None, + ) + if stacked_params_mapping is None: + stacked_params_mapping = [ + # Fused QKV shards; downstream loaders may want "q/k/v" or 0/1/2. + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + # The config has stacked_params_mapping + for ( + param_name, + weight_name, + shard_id, + ) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + _load_with_shard_id(weight_loader, param, loaded_weight, shard_id) + break + else: + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + loaded_params.add(name) + return loaded_params + + +class Gemma3ForConditionalGeneration(nn.Module): + def __init__( + self, + config: Gemma3Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.quant_config = quant_config + self.text_config = config.text_config + + # Vision Tower + self.vision_tower = SiglipVisionModel( + config=config.vision_config, + quant_config=quant_config, + prefix=add_prefix("vision_tower", prefix), + ) + + # Projector + self.multi_modal_projector = Gemma3MultiModalProjector(config) + + # Text Model + self.language_model = Gemma3TextModel(config) + + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: torch.FloatTensor, + ) -> torch.Tensor: + image_token_index = 
int(getattr(self.config, "image_token_index", -1)) + if image_token_index < 0: + image_token_index = int(getattr(self.text_config, "image_token_index", -1)) + special_image_mask = input_ids == image_token_index + n_image_tokens = int(special_image_mask.sum().item()) + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds) + n_image_features = int(image_features.shape[0] * image_features.shape[1]) + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + + def forward( + self, + input_ids: torch.LongTensor, + pixel_values: torch.FloatTensor | None = None, + **kwargs, + ): + vocab_size = int(self.language_model.vocab_size) + image_token_index = int(getattr(self.config, "image_token_index", -1)) + if image_token_index < 0: + image_token_index = int(getattr(self.text_config, "image_token_index", -1)) + + if input_ids is not None and image_token_index >= vocab_size: + special_image_mask = input_ids == image_token_index + llm_input_ids = input_ids.clone() + llm_input_ids[special_image_mask] = 0 + else: + llm_input_ids = input_ids + + inputs_embeds = self.language_model.get_input_embeddings(llm_input_ids) + + if pixel_values is not None: + if pixel_values.dim() == 5: + pixel_values = pixel_values.reshape( + -1, + pixel_values.shape[2], + pixel_values.shape[3], + pixel_values.shape[4], + ) + elif pixel_values.dim() == 3: + pixel_values = pixel_values.unsqueeze(0) + elif pixel_values.dim() != 4: + raise ValueError(f"Unexpected pixel_values shape: {pixel_values.shape}") + + vision_outputs = self.vision_tower(pixel_values) + image_features = self.multi_modal_projector(vision_outputs) + image_features = image_features.to( + device=inputs_embeds.device, dtype=inputs_embeds.dtype + ) + special_image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, 
image_features=image_features + ) + inputs_embeds = inputs_embeds.masked_scatter( + special_image_mask, image_features + ) + + return self.language_model.forward( + llm_input_ids, inputs_embeds=inputs_embeds, **kwargs + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: + loaded_params: Set[str] = set() + params_dict = dict(self.named_parameters()) + + def _load_with_shard_id( + weight_loader, param, loaded_weight: torch.Tensor, shard_id + ) -> None: + """Call param.weight_loader with best-effort shard_id normalization. + + Different fused-QKV implementations expect different shard_id types: + - Some expect strings: "q"/"k"/"v" + - Some expect integer indices: 0/1/2 + We try the provided shard_id first, then fall back between str/int forms. + """ + try: + weight_loader(param, loaded_weight, shard_id) + return + except (AssertionError, TypeError): + pass + + # Fall back between common representations. + if isinstance(shard_id, str): + mapping = {"q": 0, "k": 1, "v": 2} + if shard_id in mapping: + weight_loader(param, loaded_weight, mapping[shard_id]) + return + if shard_id.isdigit(): + weight_loader(param, loaded_weight, int(shard_id)) + return + elif isinstance(shard_id, int): + mapping = {0: "q", 1: "k", 2: "v"} + if shard_id in mapping: + weight_loader(param, loaded_weight, mapping[shard_id]) + return + + raise TypeError( + f"Unsupported shard_id={shard_id!r} for weight_loader={weight_loader} " + f"(param={getattr(param, 'name', '')})." + ) + + # Separate weights + language_model_weights: list[tuple[str, torch.Tensor]] = [] + other_weights: list[tuple[str, torch.Tensor]] = [] + + for name, loaded_weight in weights: + # Handle prefix mapping if needed + # HF weights might be "model.vision_tower...", "model.language_model..." 
+ + if "vision_tower" in name or "vision_model" in name: + # Load vision tower weights + # Map name to local name + local_name = name + if "model.vision_tower" in name: + local_name = name.replace("model.vision_tower", "vision_tower") + elif "vision_tower" in name: + pass # already correct prefix if matching self.vision_tower + elif local_name.startswith("vision_model."): + local_name = ( + "vision_tower.vision_model." + + local_name[len("vision_model.") :] + ) + + # We need to map HF Siglip names to our Siglip implementation + # Our Siglip: vision_tower.vision_model.encoder.layers... + # HF Siglip: vision_model.encoder.layers... + + # If loading from Gemma3 checkpoint, it usually has "model.vision_tower.vision_model..." + + if local_name in params_dict: + param = params_dict[local_name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(local_name) + else: + qkv_shard_id = None + fused_name = None + if ".self_attn.q_proj." in local_name: + fused_name = local_name.replace( + ".self_attn.q_proj.", ".self_attn.qkv_proj." + ) + qkv_shard_id = "q" + elif ".self_attn.k_proj." in local_name: + fused_name = local_name.replace( + ".self_attn.k_proj.", ".self_attn.qkv_proj." + ) + qkv_shard_id = "k" + elif ".self_attn.v_proj." in local_name: + fused_name = local_name.replace( + ".self_attn.v_proj.", ".self_attn.qkv_proj." + ) + qkv_shard_id = "v" + + if fused_name is not None and fused_name in params_dict: + param = params_dict[fused_name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + _load_with_shard_id( + weight_loader, param, loaded_weight, qkv_shard_id + ) + loaded_params.add(fused_name) + continue + + if ".self_attn.proj." in local_name: + candidate = local_name.replace( + ".self_attn.proj.", ".self_attn.out_proj." 
+ ) + if candidate in params_dict: + param = params_dict[candidate] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(candidate) + continue + if ".self_attn.out_proj." in local_name: + candidate = local_name.replace( + ".self_attn.out_proj.", ".self_attn.proj." + ) + if candidate in params_dict: + param = params_dict[candidate] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(candidate) + continue + + # Try to find match + suffix = local_name.split("vision_tower.")[-1] + # Try adding vision_model + candidate = f"vision_tower.vision_model.{suffix}" + if candidate in params_dict: + param = params_dict[candidate] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(candidate) + + elif "multi_modal_projector" in name: + local_name = name + if "model.multi_modal_projector" in name: + local_name = name.replace( + "model.multi_modal_projector", "multi_modal_projector" + ) + + if local_name in params_dict: + param = params_dict[local_name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(local_name) + + elif "language_model" in name or "model.language_model" in name: + # Strip prefix for language model + # If name is "model.language_model.model.layers.0...", we want "model.layers.0..." for Gemma3ForCausalLM + # Gemma3ForCausalLM has .model (Gemma3TextModel) and .lm_head + + # HF: model.language_model.model.layers... + # Ours: language_model.model.layers... + + # We pass (name, weight) to language_model.load_weights + # We should strip "model.language_model." or "language_model." + + suffix = name + if "model.language_model." in name: + suffix = name.replace("model.language_model.", "") + elif "language_model." 
in name: + suffix = name.replace("language_model.", "") + if suffix.startswith("model."): + suffix = suffix[len("model.") :] + + language_model_weights.append((suffix, loaded_weight)) + + else: + # Fallback for other weights (maybe direct lm_head if not nested?) + other_weights.append((name, loaded_weight)) + + if language_model_weights: + lm_loaded = self.language_model.load_weights(language_model_weights) + loaded_params.update({f"language_model.{n}" for n in lm_loaded}) + + return loaded_params + + def get_attention_sliding_window_size(self): + if self.text_config is not None and hasattr( + self.text_config, "get_attention_sliding_window_size" + ): + return self.text_config.get_attention_sliding_window_size() + sliding_window = getattr(self.text_config, "sliding_window", None) + if sliding_window is None: + sliding_window = getattr(self.config, "sliding_window", None) + if sliding_window is None: + return None + return int(sliding_window) - 1 + + +EntryClass = Gemma3ForConditionalGeneration diff --git a/sglang/python/sglang/multimodal_gen/runtime/models/encoders/hunyuan3d.py b/sglang/python/sglang/multimodal_gen/runtime/models/encoders/hunyuan3d.py new file mode 100644 index 0000000000000000000000000000000000000000..448ffaa8bdc30ea590799f98479566543d60044f --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/models/encoders/hunyuan3d.py @@ -0,0 +1,264 @@ +# Copied and adapted from: https://github.com/Tencent-Hunyuan/Hunyuan3D-2 + +import numpy as np +import torch +import torch.nn as nn +from torchvision import transforms +from transformers import ( + CLIPVisionConfig, + CLIPVisionModelWithProjection, + Dinov2Config, + Dinov2Model, +) + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = 
np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + return np.concatenate([emb_sin, emb_cos], axis=1) + + +class ImageEncoder(nn.Module): + MODEL_CLASS = None + MODEL_CONFIG_CLASS = None + mean = [] + std = [] + + def __init__( + self, + version=None, + config=None, + use_cls_token=True, + image_size=224, + **kwargs, + ): + super().__init__() + + if config is None: + self.model = self.MODEL_CLASS.from_pretrained(version) + else: + self.model = self.MODEL_CLASS(self.MODEL_CONFIG_CLASS.from_dict(config)) + self.model.eval() + self.model.requires_grad_(False) + self.use_cls_token = use_cls_token + self.size = image_size // 14 + self.num_patches = (image_size // 14) ** 2 + if self.use_cls_token: + self.num_patches += 1 + + self.transform = transforms.Compose( + [ + transforms.Resize( + image_size, transforms.InterpolationMode.BILINEAR, antialias=True + ), + transforms.CenterCrop(image_size), + transforms.Normalize( + mean=self.mean, + std=self.std, + ), + ] + ) + + def forward(self, image, mask=None, value_range=(-1, 1), **kwargs): + if value_range is not None: + low, high = value_range + image = (image - low) / (high - low) + + image = image.to(self.model.device, dtype=self.model.dtype) + inputs = self.transform(image) + outputs = self.model(inputs) + + last_hidden_state = outputs.last_hidden_state + if not self.use_cls_token: + last_hidden_state = last_hidden_state[:, 1:, :] + + return last_hidden_state + + def unconditional_embedding(self, batch_size, **kwargs): + device = next(self.model.parameters()).device + dtype = next(self.model.parameters()).dtype + zero = torch.zeros( + batch_size, + self.num_patches, + self.model.config.hidden_size, + device=device, + dtype=dtype, + ) + + return zero + + +class CLIPImageEncoder(ImageEncoder): + MODEL_CLASS = CLIPVisionModelWithProjection + MODEL_CONFIG_CLASS = CLIPVisionConfig + mean = [0.48145466, 0.4578275, 0.40821073] + std = [0.26862954, 0.26130258, 0.27577711] + + +class DinoImageEncoder(ImageEncoder): + 
class DinoImageEncoderMV(DinoImageEncoder):
    """Multi-view image encoder.

    Encodes every view with the shared backbone, adds a per-view sinusoidal
    embedding to all tokens of that view, and concatenates the views along the
    token axis.
    """

    _aliases = [
        "hy3dshape.models.conditioner.DinoImageEncoderMV",
    ]

    def __init__(
        self,
        version=None,
        config=None,
        use_cls_token=True,
        image_size=224,
        view_num=4,
        **kwargs,
    ):
        """Build the backbone and precompute the view embedding table.

        Args:
            version, config, use_cls_token, image_size, **kwargs: Forwarded to
                the parent constructor.
            view_num: Number of distinct view slots in the embedding table.
        """
        super().__init__(version, config, use_cls_token, image_size, **kwargs)
        self.view_num = view_num
        pos = np.arange(self.view_num, dtype=np.float32)
        view_embedding = torch.from_numpy(
            get_1d_sincos_pos_embed_from_grid(self.model.config.hidden_size, pos)
        ).float()

        # (view_num, hidden) -> (1, view_num, num_patches, hidden): the same
        # vector is repeated for every token of a view.
        view_embedding = view_embedding.unsqueeze(1).repeat(1, self.num_patches, 1)
        self.view_embed = view_embedding.unsqueeze(0)

    def forward(self, image, mask=None, value_range=(-1, 1), view_idxs=None, **kwargs):
        """Encode a batch of multi-view images.

        Args:
            image: Tensor of shape ``(bs, num_views, c, h, w)``.
            mask: Unused by this implementation.
            value_range: ``(low, high)`` of the incoming pixel values, or
                ``None`` to skip rescaling to [0, 1].
            view_idxs: Optional per-sample sequences of view-slot indices (one
                sequence of length ``num_views`` per batch element). When
                omitted, the first ``num_views`` slots are used.

        Returns:
            Tensor of shape ``(bs, num_views * tokens_per_view, hidden)``.

        Raises:
            AssertionError: If ``view_idxs`` does not match the batch size or
                the number of views.
        """
        if value_range is not None:
            low, high = value_range
            image = (image - low) / (high - low)

        image = image.to(self.model.device, dtype=self.model.dtype)

        bs, num_views, c, h, w = image.shape
        image = image.view(bs * num_views, c, h, w)

        outputs = self.model(self.transform(image))

        last_hidden_state = outputs.last_hidden_state
        if not self.use_cls_token:
            # Match ImageEncoder.forward: the view embedding (and
            # unconditional_embedding) are sized with num_patches, which
            # excludes the CLS token in this mode.
            last_hidden_state = last_hidden_state[:, 1:, :]
        last_hidden_state = last_hidden_state.reshape(
            bs, num_views, last_hidden_state.shape[-2], last_hidden_state.shape[-1]
        )

        if view_idxs is None:
            view_embedding = self.view_embed
        else:
            assert len(view_idxs) == bs
            selected = []
            for view_idx in view_idxs:
                assert num_views == len(view_idx)
                selected.append(self.view_embed[:, view_idx, ...])
            view_embedding = torch.cat(selected, 0)
        # Convert only the table that is actually used.
        view_embedding = view_embedding.to(
            device=last_hidden_state.device, dtype=last_hidden_state.dtype
        )

        if num_views != self.view_num:
            view_embedding = view_embedding[:, :num_views, ...]
        last_hidden_state = last_hidden_state + view_embedding
        return last_hidden_state.reshape(
            bs, num_views * last_hidden_state.shape[-2], last_hidden_state.shape[-1]
        )

    def unconditional_embedding(self, batch_size, view_idxs, **kwargs):
        """Return zeros shaped like a conditional embedding for
        ``len(view_idxs[0])`` views, on the backbone's device and dtype."""
        reference = next(self.model.parameters())
        return torch.zeros(
            batch_size,
            self.num_patches * len(view_idxs[0]),
            self.model.config.hidden_size,
            device=reference.device,
            dtype=reference.dtype,
        )
self.main_image_encoder(image, mask=mask, **kwargs), + } + return outputs + + def unconditional_embedding(self, batch_size, **kwargs): + outputs = { + "main": self.main_image_encoder.unconditional_embedding( + batch_size, **kwargs + ), + } + return outputs + + +# Entry class for model registry +EntryClass = [ + SingleImageEncoder, + DualImageEncoder, + DinoImageEncoder, + DinoImageEncoderMV, +] diff --git a/sglang/python/sglang/multimodal_gen/runtime/models/encoders/llama.py b/sglang/python/sglang/multimodal_gen/runtime/models/encoders/llama.py new file mode 100644 index 0000000000000000000000000000000000000000..a9d209231fc987c415771ebe2882b4aaf0ff83bb --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/models/encoders/llama.py @@ -0,0 +1,460 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/llama.py + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
class LlamaMLP(nn.Module):
    """Feed-forward block: merged gate/up projection, ``SiluAndMul``
    activation, then a down projection back to ``hidden_size``."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        """
        Args:
            hidden_size: Input and output width of the block.
            intermediate_size: Width of each of the two merged projections.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Passed through to both linear layers.
            bias: Whether both linear layers carry a bias.
            prefix: Dotted name prefix for the sub-layers.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        shared = {"bias": bias, "quant_config": quant_config}
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size, intermediate_size],
            prefix=f"{prefix}.gate_up_proj",
            **shared,
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            prefix=f"{prefix}.down_proj",
            **shared,
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply projection, activation and down projection to ``x``."""
        # Both linear layers return a pair; only the first element is used.
        gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(gate_up)
        projected, _ = self.down_proj(activated)
        return projected
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + self.head_dim = getattr( + config, "head_dim", self.hidden_size // self.total_num_heads + ) + # Phi models introduced a partial_rotary_factor parameter in the config + partial_rotary_factor = getattr(config, "partial_rotary_factor", 1) + self.rotary_dim = int(partial_rotary_factor * self.head_dim) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias_o_proj, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + is_neox_style = True + is_gguf = ( + quant_config + and hasattr(quant_config, "get_name") + and quant_config.get_name() == "gguf" + ) + if is_gguf and config.model_type == "llama": + is_neox_style = False + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.rotary_dim, + max_position=max_position_embeddings, + base=int(rope_theta), + rope_scaling=rope_scaling, + is_neox_style=is_neox_style, + ) + + self.attn = LocalAttention( + self.num_heads, + self.head_dim, + self.num_kv_heads, + softmax_scale=self.scaling, + causal=True, + supported_attention_backends=config._supported_attention_backends, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) 
class LlamaDecoderLayer(nn.Module):
    """One decoder block: norm, self-attention, norm, MLP, with the residual
    stream carried through the two norm layers."""

    def __init__(
        self,
        config: LlamaConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """
        Args:
            config: Model configuration; optional fields fall back to the
                defaults used below.
            quant_config: Passed through to the attention and MLP sub-layers.
            prefix: Dotted name prefix for the sub-layers.
        """
        super().__init__()
        self.hidden_size = config.hidden_size

        rope_scaling = getattr(config, "rope_scaling", None)
        original_max_pos = getattr(config, "original_max_position_embeddings", None)
        if rope_scaling is not None and original_max_pos:
            # NOTE(review): this writes into the dict held by `config`, so the
            # config object itself is modified here.
            rope_scaling["original_max_position_embeddings"] = original_max_pos

        # `attention_bias` (abacusai/Smaug-72B-v0.1) or `bias`
        # (internlm/internlm-7b) enables bias on the output projection ...
        o_proj_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False
        )
        # ... while `qkv_bias` (internlm/internlm3-8b), when present, decides
        # the QKV projection bias on its own.
        qkv_bias = config.qkv_bias if hasattr(config, "qkv_bias") else o_proj_bias

        kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=kv_heads,
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=rope_scaling,
            max_position_embeddings=getattr(config, "max_position_embeddings", 8192),
            quant_config=quant_config,
            bias=qkv_bias,
            bias_o_proj=o_proj_bias,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the block and return ``(hidden_states, residual)``.

        ``residual`` is ``None`` for the first layer; the input then becomes
        the residual and the norm is called with a single argument.
        """
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(positions=positions, hidden_states=hidden_states)

        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual
quant_config=config.quant_config, + ) + + self.layers = nn.ModuleList( + [ + LlamaDecoderLayer( + config=config, + quant_config=config.quant_config, + prefix=f"{config.prefix}.layers.{i}", + ) + for i in range(config.num_hidden_layers) + ] + ) + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor | None, + position_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, + output_hidden_states: bool | None = None, + **kwargs, + ) -> BaseEncoderOutput: + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + + if position_ids is None: + position_ids = torch.arange( + 0, hidden_states.shape[1], device=hidden_states.device + ).unsqueeze(0) + + all_hidden_states: tuple[Any, ...] 
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Copy checkpoint tensors into this module's parameters.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of names recorded as loaded. A name is recorded when the
            stacked-parameter loop ends with ``break``, or when the ``else``
            branch finishes without hitting a ``continue``.
        """

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # Rotary-embedding tensors stored in the checkpoint are never loaded.
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            # if (self.quant_config is not None and
            #         (scale_name := self.quant_config.get_cache_scale(name))):
            #     # Loading kv cache quantization scales
            #     param = params_dict[scale_name]
            #     weight_loader = getattr(param, "weight_loader",
            #                             default_weight_loader)
            #     loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
            #                      loaded_weight[0])
            #     weight_loader(param, loaded_weight)
            #     loaded_params.add(scale_name)
            #     continue
            if "scale" in name:
                # Remapping the name of FP8 kv-scale.
                kv_scale_name: str | None = maybe_remap_kv_scale_name(name, params_dict)
                if kv_scale_name is None:
                    continue
                else:
                    name = kv_scale_name
            # First try the stacked-parameter table: each entry is
            # (param_name, weight_name, shard_id), and a checkpoint name
            # containing weight_name is rewritten to param_name and loaded as
            # the given shard.
            for (
                param_name,
                weight_name,
                shard_id,
            ) in self.config.arch_config.stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                # NOTE: these `continue`s advance the inner mapping loop (not
                # the outer weight loop), and `name` keeps its rewritten value.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if name not in params_dict:
                    continue

                param = params_dict[name]
                # No default here: the parameter is expected to provide its
                # own shard-aware `weight_loader`.
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Reached only when the mapping loop finished without `break`.
                # The `continue`s below skip to the next checkpoint tensor.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head ``n_rep`` times along the head axis; equivalent
    to ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``.

    Input is (batch, num_key_value_heads, seqlen, head_dim); output is
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). When ``n_rep`` is 1
    the input tensor itself is returned.
    """
    n_batch, n_kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a singleton axis after the head axis, broadcast it to n_rep, then
    # fold it into the head axis so the copies of one head stay adjacent.
    tiled = hidden_states.unsqueeze(2).expand(n_batch, n_kv_heads, n_rep, seq_len, dim)
    return tiled.reshape(n_batch, n_kv_heads * n_rep, seq_len, dim)
class MistralDecoderLayer(nn.Module):
    """Pre-norm decoder block: self-attention and MLP, each wrapped in a
    residual connection."""

    def __init__(self, config: MistralConfig, layer_idx: int):
        """
        Args:
            config: Text-model configuration.
            layer_idx: Index of this layer, forwarded to the attention module.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = MistralAttention(config=config, layer_idx=layer_idx)
        self.mlp = MistralMLP(config)
        norm_eps = config.rms_norm_eps
        self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[
            tuple[torch.Tensor, torch.Tensor]
        ] = None,  # necessary, but kept here for BC
        **kwargs,
    ) -> torch.Tensor:
        """Run the block on ``hidden_states`` and return the updated states.

        All arguments other than ``hidden_states`` are handed to the attention
        module unchanged, together with ``**kwargs``.
        """
        # Attention sub-block
        attn_out = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        after_attn = hidden_states + attn_out

        # Feed-forward sub-block
        mlp_out = self.mlp(self.post_attention_layernorm(after_attn))
        return after_attn + mlp_out
class Mistral3Model(nn.Module):
    """Text-only wrapper that exposes a ``MistralModel`` as ``language_model``."""

    _checkpoint_conversion_mapping = {"language_model.model": "language_model"}

    def __init__(self, config: Mistral3Config):
        """Build the language model from ``config.text_config``."""
        super().__init__()
        self.language_model = MistralModel(config.text_config)
        self.config = config

    def get_input_embeddings(self):
        """Return the language model's token-embedding module."""
        return self.language_model.embed_tokens

    def set_decoder(self, decoder):
        """Replace the wrapped language model."""
        self.language_model = decoder

    def get_decoder(self):
        """Return the wrapped language model."""
        return self.language_model

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidoutput_hidden_statesden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[tuple, Mistral3ModelOutputWithPast]:
        """Embed the tokens if needed and run the language model.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given.
        ``output_hidden_states`` and ``output_attentions`` are overridden
        (hidden states on, attentions off) regardless of what the caller
        passes. ``return_dict`` and ``image_sizes`` are not read.

        NOTE(review): ``output_hidoutput_hidden_statesden_states`` looks like a
        garbled duplicate of ``output_hidden_states``. It is never read and is
        kept only so the signature stays unchanged.

        Raises:
            ValueError: If both or neither of ``input_ids`` / ``inputs_embeds``
                are provided.
        """
        both_or_neither = (input_ids is None) == (inputs_embeds is None)
        if both_or_neither:
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds"
            )

        if inputs_embeds is None:
            embed = self.get_input_embeddings()
            inputs_embeds = embed(input_ids)

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )

        return Mistral3ModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
"^multi_modal_projector": "model.multi_modal_projector", + "^language_model.lm_head": "lm_head", + } + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: LlavaConfig): + super().__init__() + self.model = Mistral3Model(config.arch_config) + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_decoder(self, decoder): + self.model.set_decoder(decoder) + + def get_decoder(self): + return self.model.get_decoder() + + # Make modules available through conditional class for BC + @property + def language_model(self): + return self.model.language_model + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_hidden_states: Optional[bool] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + image_sizes: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[tuple, Mistral3CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+ + Example: + + """ + output_hidden_states = True + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + image_sizes=image_sizes, + **kwargs, + ) + + return Mistral3CausalLMOutputWithPast( + hidden_states=outputs.hidden_states, + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + # Define mapping for stacked parameters + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + name_lower = name.lower() + if ( + "vision" in name_lower + or "multi" in name_lower + or "lm_head" in name_lower + ): + continue + final_name = name.replace("language_model.model.", "model.language_model.") + + if final_name in params_dict: + param = params_dict[final_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(final_name) + else: + logger.warning(f"Param {name=} {final_name=} from weight is not loaded") + + return loaded_params + + +EntryClass = Mistral3ForConditionalGeneration diff --git a/sglang/python/sglang/multimodal_gen/runtime/models/encoders/qwen2_5vl.py b/sglang/python/sglang/multimodal_gen/runtime/models/encoders/qwen2_5vl.py new file mode 100644 index 0000000000000000000000000000000000000000..364b72d59fa528ebaad49b85ffe8e4c715190812 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/models/encoders/qwen2_5vl.py @@ -0,0 +1,1151 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +from transformers import ( + Cache, + DynamicCache, + PretrainedConfig, + Qwen2_5_VLTextConfig, + Qwen2RMSNorm, +) +from transformers.masking_utils import ( + create_causal_mask, + create_sliding_window_causal_mask, +) +from 
transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.utils import TransformersKwargs, is_torchdynamo_compiling + +from sglang.multimodal_gen.configs.models.encoders.qwen_image import Qwen2_5VLConfig +from sglang.multimodal_gen.runtime.layers.attention import LocalAttention +from sglang.multimodal_gen.runtime.layers.linear import ( + MergedColumnParallelLinear, + RowParallelLinear, +) +from sglang.multimodal_gen.runtime.layers.quantization import QuantizationConfig +from sglang.multimodal_gen.runtime.loader.weight_utils import default_weight_loader +from sglang.multimodal_gen.runtime.models.encoders.base import TextEncoder +from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum +from sglang.multimodal_gen.runtime.utils.common import add_prefix + +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Qwen2-VL model compatible with HuggingFace weights.""" +import logging +from typing import Callable, Iterable, Optional, Tuple, Union + +try: + from typing import Unpack # type: ignore[attr-defined] +except ImportError: + # Python 3.10 and below + from typing_extensions import Unpack + +import torch +import torch.nn as nn +from transformers.activations import ACT2FN +from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VisionTransformerPretrainedModel, + Qwen2_5_VLAttention, + Qwen2_5_VLCausalLMOutputWithPast, + Qwen2_5_VLModelOutputWithPast, + Qwen2_5_VLRotaryEmbedding, + Qwen2MLP, + apply_multimodal_rotary_pos_emb, + eager_attention_forward, +) + +logger = logging.getLogger(__name__) + + +class Qwen2_5_VLAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: Qwen2_5_VLTextConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warn( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.is_causal = True + self.attention_dropout = config.attention_dropout + self.rope_scaling = config.rope_scaling + self.scaling = self.head_dim**-0.5 + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=True + ) + self.k_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True + ) + self.v_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True + ) + self.o_proj = nn.Linear( + self.num_heads * self.head_dim, self.hidden_size, bias=False + ) + self.sliding_window = ( + config.sliding_window + if config.layer_types[layer_idx] == "sliding_attention" + else None + ) + + self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config) + self.attn = LocalAttention( + num_heads=self.num_heads, + head_size=self.head_dim, + num_kv_heads=self.num_key_value_heads, + softmax_scale=self.scaling, + causal=True, + supported_attention_backends=( + AttentionBackendEnum.FA, + AttentionBackendEnum.TORCH_SDPA, + ), + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + tuple[torch.Tensor, torch.Tensor] + ] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, 
Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_values is not None: + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + } # Specific to RoPE models + key_states, value_states = past_key_values.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + attention_interface: Callable = eager_attention_forward + # if self.config._attn_implementation != "eager": + # attention_interface = ALL_ATTENTION_FUNCTIONS["sdpa"] + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + attn_output = self.attn(query_states, key_states, value_states) + # + # attn_output, attn_weights = attention_interface( + # self, + # query_states, + # key_states, + # value_states, + # attention_mask, + # dropout=0.0 if not self.training else self.attention_dropout, + # scaling=self.scaling, + # sliding_window=self.sliding_window, + # position_ids=position_ids, # pass positions for FA2 + # **kwargs, + # ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output + + +class Qwen2_5_VLDecoderLayer(nn.Module): + def __init__(self, config: Qwen2_5_VLTextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + if ( + config.use_sliding_window + and config._attn_implementation 
!= "flash_attention_2" + ): + logger.warning( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." + ) + self.self_attn = Qwen2_5_VLAttention(config, layer_idx) + + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.attention_type = config.layer_types[layer_idx] + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + tuple[torch.Tensor, torch.Tensor] + ] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[ + torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. 
+ position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class Qwen2_5_VLMLP(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: int = None, + bias: bool = True, + hidden_act="silu", + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=in_features, + output_sizes=[hidden_features] * 2, # [gate_proj, up_proj] + bias=bias, + quant_config=quant_config, + prefix=add_prefix("gate_up_proj", prefix), + ) + self.down_proj = RowParallelLinear( + hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + prefix=add_prefix("down_proj", prefix), + ) + self.act = ACT2FN[hidden_act] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + gate, up = gate_up.chunk(2, dim=-1) + x = self.act(gate) * up + x_down, _ = self.down_proj(x) + return x_down + + +class 
Qwen2_5_VLTextModel(nn.Module): + def __init__(self, config: PretrainedConfig): + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + Qwen2_5_VLDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self._attn_implementation = config._attn_implementation + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config) + self.has_sliding_layers = "sliding_attention" in self.config.layer_types + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + # self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, BaseModelOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + # torch.jit.trace() doesn't 
support cache objects in the output + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache(config=self.config) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + # the hard coded `3` is for temporal, height and width. + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand( + 3, inputs_embeds.shape[0], -1 + ) + elif position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + # NOTE: we need to pass text position ids for packing. Qwen2-VL uses 3D positions + # where each dim indicates visual spatial positions for temporal/height/width grids. + # There are two scenarios when FA2-like packed masking might be activated. + # 1. User specifically passed packed `position_ids` and no attention mask. + # In this case we expect the user to create correct position ids for all 3 grids + # and prepend text-only position ids to it. The final tensor will be [4, bs, seq-len] + # 2. User runs forward with no attention mask and no position ids. In this case, position ids + # are prepared by the model (`get_rope_index`) as `[4, bs, seq-len]` tensor. Text-only positions are + # prepended by us when creating positions so that the mask is constructed correctly. NOTE: failing to pass + # text-only positions will cause incorrect mask construction, do not change `prepare_input_for_generation` + if position_ids.ndim == 3 and position_ids.shape[0] == 4: + text_position_ids = position_ids[0] + position_ids = position_ids[1:] + else: + text_position_ids = position_ids[0] + + # It may already have been prepared by e.g. 
`generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "position_ids": text_position_ids, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + } + # The sliding window alternating layers are not always activated depending on the config + if self.has_sliding_layers: + causal_mask_mapping["sliding_attention"] = ( + create_sliding_window_causal_mask(**mask_kwargs) + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=text_position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + past_key_values, + all_hidden_states, + all_self_attns, + ] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class Qwen2_5_VLModel(nn.Module): + base_model_prefix = "" + 
_checkpoint_conversion_mapping = {"^model": "language_model"} + # Reference: fix gemma3 grad acc #37208 + accepts_loss_kwargs = False + _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"] + + def __init__(self, config, enable_image_understanding: bool = False): + super().__init__() + self.language_model = Qwen2_5_VLTextModel(config.text_config) + + if enable_image_understanding: + self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config( + config.vision_config + ) + self.visual.to(torch.get_default_dtype()) + self.rope_deltas = None # cache rope_deltas here + self.config = config + # Initialize weights and apply final processing + # self.post_init() + + def get_input_embeddings(self): + return self.language_model.embed_tokens + + def set_input_embeddings(self, value): + self.language_model.embed_tokens = value + + def set_decoder(self, decoder): + self.language_model = decoder + + def get_decoder(self): + return self.language_model + + def get_rope_index( + self, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Calculate the 3D rope index based on image and video's temporal, height and width in LLM. + + Explanation: + Each embedding sequence contains vision embedding and text embedding or just contains text embedding. + + For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. + Examples: + input_ids: [T T T T T], here T is for text. + temporal position_ids: [0, 1, 2, 3, 4] + height position_ids: [0, 1, 2, 3, 4] + width position_ids: [0, 1, 2, 3, 4] + + For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part + and 1D rotary position embedding for text part. 
+ Examples: + Temporal (Time): 3 patches, representing different segments of the video in time. + Height: 2 patches, dividing each frame vertically. + Width: 2 patches, dividing each frame horizontally. + We also have some important parameters: + fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second. + tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity. + temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames. + interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will be have a difference of 50 in the temporal position IDs. + input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. + vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100] + vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + text temporal position_ids: [101, 102, 103, 104, 105] + text height position_ids: [101, 102, 103, 104, 105] + text width position_ids: [101, 102, 103, 104, 105] + Here we calculate the text start position_ids as the max vision position_ids plus 1. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. 
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): + The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + if input_ids is not None and ( + image_grid_thw is not None or video_grid_thw is not None + ): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere( + input_ids == vision_start_token_id + ).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + 
video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + second_per_grid_t = 0 + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + if second_per_grid_ts is not None: + second_per_grid_t = second_per_grid_ts[video_index] + else: + second_per_grid_t = 1.0 + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + range_tensor = torch.arange(llm_grid_t).view(-1, 1) + expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) + + ## normalize type, send to device. 
+ second_per_grid_t = torch.as_tensor( + second_per_grid_t, + dtype=range_tensor.dtype, + device=range_tensor.device, + ) + + time_tensor = ( + expanded_range + * second_per_grid_t + * self.config.vision_config.tokens_per_second + ) + + time_tensor_long = time_tensor.long() + t_index = time_tensor_long.flatten() + + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to( + position_ids.device + ) + mrope_position_deltas.append( + llm_positions.max() + 1 - len(total_input_ids[i]) + ) + mrope_position_deltas = torch.tensor( + mrope_position_deltas, device=input_ids.device + ).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = ( + position_ids.unsqueeze(0) + .expand(3, -1, -1) + .to(attention_mask.device) + ) + max_position_ids = position_ids.max(0, keepdim=False)[0].max( + -1, keepdim=True + )[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + 
dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + def get_video_features( + self, + pixel_values_videos: torch.FloatTensor, + video_grid_thw: Optional[torch.LongTensor] = None, + ): + """ + Encodes videos into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + split_sizes = ( + video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2 + ).tolist() + video_embeds = torch.split(video_embeds, split_sizes) + return video_embeds + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + image_grid_thw: Optional[torch.LongTensor] = None, + ): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. 
+ """ + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + split_sizes = ( + image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2 + ).tolist() + image_embeds = torch.split(image_embeds, split_sizes) + return image_embeds + + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: torch.FloatTensor = None, + video_features: torch.FloatTensor = None, + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor( + self.config.image_token_id, + dtype=torch.long, + device=inputs_embeds.device, + ) + ) + special_image_mask = special_image_mask.all(-1) + special_video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor( + self.config.video_token_id, + dtype=torch.long, + device=inputs_embeds.device, + ) + ) + special_video_mask = special_video_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + special_video_mask = input_ids == self.config.video_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = ( + special_image_mask.unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + if ( + image_features is not None + and inputs_embeds[special_image_mask].numel() != image_features.numel() + ): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}" + ) + + n_video_tokens = special_video_mask.sum() + special_video_mask = ( + special_video_mask.unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + if ( + video_features is not None + and 
inputs_embeds[special_video_mask].numel() != video_features.numel() + ): + raise ValueError( + f"Videos features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}" + ) + + return special_image_mask, special_video_mask + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Qwen2_5_VLModelOutputWithPast]: + r""" + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): + The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. 
+ """ + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_embeds = self.get_image_features(pixel_values, image_grid_thw) + image_embeds = torch.cat(image_embeds, dim=0).to( + inputs_embeds.device, inputs_embeds.dtype + ) + image_mask, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) + video_embeds = torch.cat(video_embeds, dim=0).to( + inputs_embeds.device, inputs_embeds.dtype + ) + _, video_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if position_ids is None: + # Calculate RoPE index once per generation in the pre-fill stage only. 
+ # When compiling, we can't check tensor values thus we check only input length + # It is safe to assume that `length!=1` means we're in pre-fill because compiled + # models currently cannot do asssisted decoding + prefill_compiled_stage = is_torchdynamo_compiling() and ( + (input_ids is not None and input_ids.shape[1] != 1) + or (inputs_embeds is not None and inputs_embeds.shape[1] != 1) + ) + prefill_noncompiled_stage = not is_torchdynamo_compiling() and ( + (cache_position is not None and cache_position[0] == 0) + or (past_key_values is None or past_key_values.get_seq_length() == 0) + ) + if ( + prefill_compiled_stage or prefill_noncompiled_stage + ) or self.rope_deltas is None: + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + attention_mask=attention_mask, + ) + self.rope_deltas = rope_deltas + else: + batch_size, seq_length, _ = inputs_embeds.shape + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1) + if cache_position is not None: + delta = (cache_position[0] + self.rope_deltas).to( + inputs_embeds.device + ) + else: + delta = torch.zeros( + (batch_size, seq_length), device=inputs_embeds.device + ) + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=1) + position_ids += delta.to(position_ids.device) + + outputs = self.language_model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + + output = Qwen2_5_VLModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + 
rope_deltas=self.rope_deltas, + ) + return output if return_dict else output.to_tuple() + + +class Qwen2_5_VLForConditionalGeneration(TextEncoder): + # BitandBytes specific attributes + default_bitsandbytes_target_modules = [ + ".gate_up_proj.", + ".down_proj.", + ".q_proj.", + ".k_proj.", + ".v_proj.", + ".o_proj.", + ] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + + def __init__( + self, + config: Qwen2_5VLConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config) + enable_image_understanding = config.enable_image_understanding + config = config.arch_config + self.model = Qwen2_5_VLModel( + config, enable_image_understanding=enable_image_understanding + ) + self.lm_head = nn.Linear( + config.text_config.hidden_size, config.text_config.vocab_size, bias=False + ) + + self.enable_image_understanding = enable_image_understanding + + self.config = config + + def get_input_embeddings(self): + return self.model.embed_tokens + + @torch.no_grad() + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] 
= None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ): + """Run forward pass for Qwen2_5-VL. + + Args: + input_ids: Flattened (concatenated) input_ids corresponding to a + batch. + positions: Flattened (concatenated) position ids corresponding to a + batch. + **NOTE**: If mrope is enabled (default setting for Qwen2-VL + opensource models), the shape will be `(3, seq_len)`, + otherwise it will be `(seq_len,). + (Use input_metadata.mrope_positions to replace it) + """ + output_attentions = False + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + logits = self.lm_head(hidden_states[:, slice_indices, :]) + return Qwen2_5_VLCausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=outputs.rope_deltas, + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + loaded_params: set[str] = set() + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + 
continue + + name = name.replace("model.", "model.language_model.") + if "visual." in name: + if not self.enable_image_understanding: + continue + name = name.replace("visual.", "model.visual.") + try: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + except KeyError: + raise + + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = loaded_weight.to(param.dtype) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + def get_embed_and_head(self): + return self.model.embed_tokens.weight, self.lm_head.weight + + +EntryClass = Qwen2_5_VLForConditionalGeneration diff --git a/sglang/python/sglang/multimodal_gen/runtime/models/encoders/qwen3.py b/sglang/python/sglang/multimodal_gen/runtime/models/encoders/qwen3.py new file mode 100644 index 0000000000000000000000000000000000000000..b8132e4041c1075914fb6a7813c47e85ea8ae4a2 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/models/encoders/qwen3.py @@ -0,0 +1,422 @@ +from collections.abc import Iterable +from typing import Any + +import torch +from torch import nn + +from sglang.multimodal_gen.configs.models.encoders import BaseEncoderOutput +from sglang.multimodal_gen.configs.models.encoders.qwen3 import Qwen3TextConfig +from sglang.multimodal_gen.runtime.distributed import get_tp_world_size +from sglang.multimodal_gen.runtime.layers.activation import SiluAndMul +from sglang.multimodal_gen.runtime.layers.attention import LocalAttention +from sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm +from sglang.multimodal_gen.runtime.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from sglang.multimodal_gen.runtime.layers.quantization import QuantizationConfig +from sglang.multimodal_gen.runtime.layers.rotary_embedding import get_rope +from 
sglang.multimodal_gen.runtime.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, +) +from sglang.multimodal_gen.runtime.loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) +from sglang.multimodal_gen.runtime.models.encoders.base import TextEncoder + + +class Qwen3MLP(nn.Module): + """Qwen3 MLP with SwiGLU activation and tensor parallelism.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: QuantizationConfig | None = None, + bias: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported." + ) + self.act_fn = SiluAndMul() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.gate_up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + + +class Qwen3Attention(nn.Module): + """Qwen3 attention with QK-Norm and tensor parallelism. + + Key difference from LLaMA: RMSNorm is applied to Q and K before attention. 
+ """ + + def __init__( + self, + config: Qwen3TextConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 1000000.0, + rope_scaling: dict[str, Any] | None = None, + max_position_embeddings: int = 40960, + quant_config: QuantizationConfig | None = None, + bias: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tp_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + + self.head_dim = getattr( + config, "head_dim", self.hidden_size // self.total_num_heads + ) + self.rotary_dim = self.head_dim + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + # QKV projection with tensor parallelism + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + # Output projection + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + # QK-Norm: Key difference from LLaMA + rms_norm_eps = getattr(config, "rms_norm_eps", 1e-6) + self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + + # Rotary embeddings + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.rotary_dim, + 
max_position=max_position_embeddings, + base=int(rope_theta), + rope_scaling=rope_scaling, + is_neox_style=True, + ) + + # Attention with FlashAttention/SageAttn support + self.attn = LocalAttention( + self.num_heads, + self.head_dim, + self.num_kv_heads, + softmax_scale=self.scaling, + causal=True, + supported_attention_backends=config._supported_attention_backends, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + # QKV projection + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + # Reshape for QK-norm + batch_size, seq_len = q.shape[0], q.shape[1] + q = q.reshape(batch_size, seq_len, self.num_heads, self.head_dim) + k = k.reshape(batch_size, seq_len, self.num_kv_heads, self.head_dim) + v = v.reshape(batch_size, seq_len, self.num_kv_heads, self.head_dim) + + # Apply QK-Norm (key difference from LLaMA) + q = self.q_norm(q) + k = self.k_norm(k) + + # Reshape back for rotary embeddings + q = q.reshape(batch_size, seq_len, -1) + k = k.reshape(batch_size, seq_len, -1) + + # Apply rotary embeddings + q, k = self.rotary_emb(positions, q, k) + + # Reshape for attention + q = q.reshape(batch_size, seq_len, self.num_heads, self.head_dim) + k = k.reshape(batch_size, seq_len, self.num_kv_heads, self.head_dim) + + # Attention + attn_output = self.attn(q, k, v) + attn_output = attn_output.reshape(batch_size, seq_len, -1) + + # Output projection + output, _ = self.o_proj(attn_output) + return output + + +class Qwen3DecoderLayer(nn.Module): + """Qwen3 transformer decoder layer.""" + + def __init__( + self, + config: Qwen3TextConfig, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 1000000.0) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", 
40960) + attention_bias = getattr(config, "attention_bias", False) + + self.self_attn = Qwen3Attention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + prefix=f"{prefix}.self_attn", + ) + self.mlp = Qwen3MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states) + + # MLP + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class Qwen3ForCausalLM(TextEncoder): + """Qwen3 causal language model for text encoding in diffusion models. 
+ + Features: + - Tensor parallelism support + - FlashAttention/SageAttn/SDPA support via LocalAttention + - QK-Norm for better training stability + - FSDP sharding for CPU offload + """ + + def __init__(self, config: Qwen3TextConfig) -> None: + super().__init__(config) + + self.config = config + self.quant_config = config.quant_config + + # Embedding layer with tensor parallelism + if config.lora_config is not None: + max_loras = getattr(config.lora_config, "max_loras", 1) + lora_vocab_size = getattr(config.lora_config, "lora_extra_vocab_size", 1) + lora_vocab = lora_vocab_size * max_loras + else: + lora_vocab = 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=config.quant_config, + ) + + # Transformer layers + self.layers = nn.ModuleList( + [ + Qwen3DecoderLayer( + config=config, + quant_config=config.quant_config, + prefix=f"{config.prefix}.layers.{i}", + ) + for i in range(config.num_hidden_layers) + ] + ) + + # Final layer norm + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, + output_hidden_states: bool | None = None, + **kwargs, + ) -> BaseEncoderOutput: + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + + residual = None + + if position_ids is None: + position_ids = torch.arange( + 0, hidden_states.shape[1], device=hidden_states.device 
+ ).unsqueeze(0) + + all_hidden_states: tuple[Any, ...] | None = () if output_hidden_states else None + + for layer in self.layers: + if all_hidden_states is not None: + all_hidden_states += ( + (hidden_states,) + if residual is None + else (hidden_states + residual,) + ) + hidden_states, residual = layer(position_ids, hidden_states, residual) + + hidden_states, _ = self.norm(hidden_states, residual) + + # Add hidden states from the last decoder layer + if all_hidden_states is not None: + all_hidden_states += (hidden_states,) + + return BaseEncoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load weights with support for tensor parallelism and weight remapping.""" + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + # Strip 'model.' prefix from HuggingFace Qwen3 weights + if name.startswith("model."): + name = name[6:] # len("model.") == 6 + + # Skip rotary embedding weights + if "rotary_emb.inv_freq" in name: + continue + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: + continue + + # Handle KV scale remapping + if "scale" in name: + kv_scale_name: str | None = maybe_remap_kv_scale_name(name, params_dict) + if kv_scale_name is None: + continue + else: + name = kv_scale_name + + # Handle stacked params mapping (qkv_proj, gate_up_proj) + for ( + param_name, + weight_name, + shard_id, + ) in self.config.arch_config.stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + # Skip loading extra bias for GPTQ models + if name.endswith(".bias") and name not in params_dict: + continue + + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models 
+ if name.endswith(".bias") and name not in params_dict: + continue + + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + loaded_params.add(name) + + return loaded_params + + +EntryClass = Qwen3ForCausalLM diff --git a/sglang/python/sglang/multimodal_gen/runtime/models/encoders/t5.py b/sglang/python/sglang/multimodal_gen/runtime/models/encoders/t5.py new file mode 100644 index 0000000000000000000000000000000000000000..5058de389017f092b78032cf4089b25f48581d05 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/models/encoders/t5.py @@ -0,0 +1,747 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from transformers: https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/t5/modeling_t5.py + +# Derived from T5 implementation posted on HuggingFace; license below: +# +# coding=utf-8 +# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch T5 & UMT5 model.""" + +import math +from collections.abc import Iterable +from dataclasses import dataclass + +import torch +import torch.nn.functional as F +from torch import nn + +from sglang.multimodal_gen.configs.models.encoders import BaseEncoderOutput, T5Config +from sglang.multimodal_gen.runtime.distributed import _get_folding_tp_group +from sglang.multimodal_gen.runtime.layers.activation import get_act_fn +from sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm +from sglang.multimodal_gen.runtime.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from sglang.multimodal_gen.runtime.layers.quantization import QuantizationConfig +from sglang.multimodal_gen.runtime.layers.utils import get_group_rank, get_group_size +from sglang.multimodal_gen.runtime.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, +) +from sglang.multimodal_gen.runtime.loader.weight_utils import default_weight_loader +from sglang.multimodal_gen.runtime.models.encoders.base import TextEncoder +from sglang.multimodal_gen.runtime.platforms import current_platform + + +class AttentionType: + """ + Attention type. + Use string to be compatible with `torch.compile`. + """ + + # Decoder attention between previous layer Q/K/V + DECODER = "decoder" + # Encoder attention between previous layer Q/K/V for encoder-decoder + ENCODER = "encoder" + # Encoder attention between previous layer Q/K/V + ENCODER_ONLY = "encoder_only" + # Attention between dec. Q and enc. 
K/V for encoder-decoder + ENCODER_DECODER = "encoder_decoder" + + +@dataclass +class AttentionMetadata: + attn_bias: torch.Tensor + + +class T5DenseActDense(nn.Module): + + def __init__( + self, config: T5Config, quant_config: QuantizationConfig | None = None + ): + super().__init__() + tp_group = _get_folding_tp_group(config) + self.wi = MergedColumnParallelLinear( + config.d_model, [config.d_ff], bias=False, tp_group=tp_group + ) + self.wo = RowParallelLinear( + config.d_ff, + config.d_model, + bias=False, + quant_config=quant_config, + tp_group=tp_group, + ) + self.act = get_act_fn(config.dense_act_fn) + + def forward(self, hidden_states) -> torch.Tensor: + hidden_states, _ = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.wo(hidden_states) + return hidden_states + + +class T5DenseGatedActDense(nn.Module): + + def __init__( + self, config: T5Config, quant_config: QuantizationConfig | None = None + ): + super().__init__() + tp_group = _get_folding_tp_group(config) + self.wi_0 = MergedColumnParallelLinear( + config.d_model, + [config.d_ff], + bias=False, + quant_config=quant_config, + tp_group=tp_group, + ) + self.wi_1 = MergedColumnParallelLinear( + config.d_model, + [config.d_ff], + bias=False, + quant_config=quant_config, + tp_group=tp_group, + ) + # Should not run in fp16 unless mixed-precision is used, + # see https://github.com/huggingface/transformers/issues/20287. 
+ self.wo = RowParallelLinear( + config.d_ff, + config.d_model, + bias=False, + quant_config=quant_config, + tp_group=tp_group, + ) + self.act = get_act_fn(config.dense_act_fn) + + def forward(self, hidden_states) -> torch.Tensor: + hidden_gelu = self.act(self.wi_0(hidden_states)[0]) + hidden_linear, _ = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states, _ = self.wo(hidden_states) + return hidden_states + + +class T5LayerFF(nn.Module): + + def __init__( + self, config: T5Config, quant_config: QuantizationConfig | None = None + ): + super().__init__() + if config.is_gated_act: + self.DenseReluDense = T5DenseGatedActDense( + config, quant_config=quant_config + ) + else: + self.DenseReluDense = T5DenseActDense(config, quant_config=quant_config) + + self.layer_norm = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + + def forward(self, hidden_states) -> torch.Tensor: + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + forwarded_states + return hidden_states + + +# T5 has attn_bias and does not use softmax scaling +class T5MultiHeadAttention(nn.Module): + + def __init__(self) -> None: + super().__init__() + + def forward(self, q, k, v, attn_bias=None): + b, _, n, c = q.shape + attn = torch.einsum("binc,bjnc->bnij", q, k) + if attn_bias is not None: + attn += attn_bias + + attn = F.softmax(attn.float(), dim=-1).type_as(attn) + x = torch.einsum("bnij,bjnc->binc", attn, v) + x = x.reshape(b, -1, n * c) + return x + + +class T5Attention(nn.Module): + + def __init__( + self, + config: T5Config, + attn_type: str, + has_relative_attention_bias=False, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + super().__init__() + self.attn_type = attn_type + # Cross-attention has no relative pos encoding anyway + self.is_decoder = attn_type == AttentionType.DECODER + self.has_relative_attention_bias = 
has_relative_attention_bias + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.relative_attention_max_distance = config.relative_attention_max_distance + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.total_num_heads = self.total_num_kv_heads = config.num_heads + + # Partition heads across multiple tensor parallel GPUs. + self.tp_group = _get_folding_tp_group(config) + self.tp_world_size = get_group_size(self.tp_group) + assert config.num_heads % self.tp_world_size == 0 + self.n_heads = config.num_heads // self.tp_world_size + + self.inner_dim = self.n_heads * self.key_value_proj_dim + # No GQA in t5. + # self.n_kv_heads = self.n_heads + + self.qkv_proj = QKVParallelLinear( + self.d_model, + self.key_value_proj_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + tp_group=self.tp_group, + ) + + self.attn = T5MultiHeadAttention() + + if self.has_relative_attention_bias: + self.relative_attention_bias = VocabParallelEmbedding( + self.relative_attention_num_buckets, + self.total_num_heads, + org_num_embeddings=self.relative_attention_num_buckets, + padding_size=self.relative_attention_num_buckets, + quant_config=quant_config, + tp_group=self.tp_group, + ) + self.o = RowParallelLinear( + self.total_num_heads * self.key_value_proj_dim, + self.d_model, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + tp_group=self.tp_group, + ) + + @staticmethod + def _relative_position_bucket( + relative_position, bidirectional=True, num_buckets=32, max_distance=128 + ) -> torch.Tensor: + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + Translate relative position to a bucket number for relative attention. + The relative position is defined as memory_position - query_position, + i.e. 
the distance in tokens from the attending position to the + attended-to position. If bidirectional=False, then positive relative + positions are invalid. We use smaller buckets for small absolute + relative_position and larger buckets for larger absolute + relative_positions. All relative positions >=max_distance map to the + same bucket. All relative positions <=-max_distance map to the same + bucket. This should allow for more graceful generalization to longer + sequences than the model has been trained on + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + Returns: + a Tensor with the same shape as relative_position, containing int32 + values in the range [0, num_buckets) + """ # noqa: E501 + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min( + relative_position, torch.zeros_like(relative_position) + ) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins + # in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, + torch.full_like(relative_position_if_large, num_buckets - 1), + ) + + relative_buckets += torch.where( + is_small, relative_position, relative_position_if_large + ) + return relative_buckets + + def compute_bias(self, query_length, key_length, device=None) -> torch.Tensor: + """Compute binned relative position bias""" + if device 
is None: + device = self.relative_attention_bias.weight.device + context_position = torch.arange(query_length, dtype=torch.long, device=device)[ + :, None + ] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[ + None, : + ] + # max_seq_len, nh + relative_position = memory_position - context_position + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias( + relative_position_bucket + ) # shape (query_length, key_length, num_heads) + x = values.permute([2, 0, 1]).unsqueeze( + 0 + ) # shape (1, num_heads, query_length, key_length) + return x + + def forward( + self, + hidden_states: torch.Tensor, # (num_tokens, d_model) + attention_mask: torch.Tensor, + attn_metadata: AttentionMetadata | None = None, + ) -> torch.Tensor: + bs, seq_len, _ = hidden_states.shape + num_seqs = bs + n, c = ( + self.n_heads, + self.key_value_proj_dim, + ) + qkv, _ = self.qkv_proj(hidden_states) + # Projection of 'own' hidden state (self-attention). No GQA here. + q, k, v = qkv.split(self.inner_dim, dim=-1) + q = q.reshape(bs, seq_len, n, c) + k = k.reshape(bs, seq_len, n, c) + v = v.reshape(bs, seq_len, n, c) + + assert attn_metadata is not None + attn_bias = attn_metadata.attn_bias + # Not compatible with CP here (as all encoder-decoder models), + # as it assumes homogeneous batch (prefills or decodes). + if self.has_relative_attention_bias: + # Self-attention. Compute T5 relative positional encoding. + # The bias term is computed on longest sequence in batch. Biases + # for shorter sequences are slices of the longest. 
+ assert self.attn_type == AttentionType.ENCODER + attn_bias = self.compute_bias(seq_len, seq_len).repeat(num_seqs, 1, 1, 1) + attn_metadata.attn_bias = attn_bias + else: + # Encoder/Decoder Self-Attention Layer, attn bias already cached. + assert attn_bias is not None + + if attention_mask is not None: + attention_mask = ( + attention_mask.view(bs, 1, 1, -1) + if attention_mask.ndim == 2 + else attention_mask.unsqueeze(1) + ) + mask_val = -1e4 if current_platform.is_mps() else torch.finfo(q.dtype).min + attn_bias.masked_fill_(attention_mask == 0, mask_val) + + if self.tp_world_size > 1: + rank = get_group_rank(self.tp_group) + attn_bias = attn_bias[ + :, rank * self.n_heads : (rank + 1) * self.n_heads, :, : + ] + + attn_output = self.attn(q, k, v, attn_bias) + output, _ = self.o(attn_output) + return output + + +class T5LayerSelfAttention(nn.Module): + + def __init__( + self, + config, + has_relative_attention_bias=False, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + super().__init__() + self.SelfAttention = T5Attention( + config, + AttentionType.DECODER if "decoder" in prefix else AttentionType.ENCODER, + has_relative_attention_bias=has_relative_attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.SelfAttention", + ) + self.layer_norm = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + attn_metadata: AttentionMetadata | None = None, + ) -> torch.Tensor: + normed_hidden_states = self.layer_norm(hidden_states) + + attention_output = self.SelfAttention( + hidden_states=normed_hidden_states, + attention_mask=attention_mask, + attn_metadata=attn_metadata, + ) + + hidden_states = hidden_states + attention_output + + return hidden_states + + +class T5LayerCrossAttention(nn.Module): + + def __init__( + self, config, quant_config: QuantizationConfig | None = None, prefix: str = "" + ): + super().__init__() + self.EncDecAttention = 
T5Attention( + config, + AttentionType.ENCODER_DECODER, + has_relative_attention_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.EncDecAttention", + ) + self.layer_norm = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata | None = None, + ) -> torch.Tensor: + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + hidden_states=normed_hidden_states, + attn_metadata=attn_metadata, + ) + hidden_states = hidden_states + attention_output + return hidden_states + + +class T5Block(nn.Module): + + def __init__( + self, + config: T5Config, + is_decoder: bool, + has_relative_attention_bias=False, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + super().__init__() + self.is_decoder = is_decoder + self.layer = nn.ModuleList() + self.layer.append( + T5LayerSelfAttention( + config, + has_relative_attention_bias=has_relative_attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + ) + + if self.is_decoder: + self.layer.append( + T5LayerCrossAttention( + config, quant_config=quant_config, prefix=f"{prefix}.cross_attn" + ) + ) + + self.layer.append(T5LayerFF(config, quant_config=quant_config)) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + attn_metadata: AttentionMetadata | None = None, + ) -> torch.Tensor: + + if attention_mask is None: + attention_mask = torch.ones( + hidden_states.shape[:2], device=hidden_states.device + ) + + hidden_states = self.layer[0]( + hidden_states=hidden_states, + attention_mask=attention_mask, + attn_metadata=attn_metadata, + ) + + if self.is_decoder: + hidden_states = self.layer[1]( + hidden_states=hidden_states, attn_metadata=attn_metadata + ) + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + + return hidden_states + + +class T5Stack(nn.Module): + + def __init__( + self, + 
config: T5Config, + is_decoder: bool, + n_layers: int, + embed_tokens=None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + is_umt5: bool = False, + ): + super().__init__() + self.embed_tokens = embed_tokens + self.is_umt5 = is_umt5 + if is_umt5: + self.block = nn.ModuleList( + [ + T5Block( + config, + is_decoder=is_decoder, + has_relative_attention_bias=True, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{i}", + ) + for i in range(n_layers) + ] + ) + else: + # Only the first block has relative positional encoding. + self.block = nn.ModuleList( + [ + T5Block( + config, + is_decoder=is_decoder, + has_relative_attention_bias=i == 0, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{i}", + ) + for i in range(n_layers) + ] + ) + self.final_layer_norm = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + for idx, block in enumerate(self.block): + hidden_states = block( + hidden_states=hidden_states, + attention_mask=attention_mask, + attn_metadata=attn_metadata, + ) + + hidden_states = self.final_layer_norm(hidden_states) + return hidden_states + + +class T5EncoderModel(TextEncoder): + + def __init__(self, config: T5Config, prefix: str = ""): + super().__init__(config) + + quant_config = None + tp_group = _get_folding_tp_group(config) + self.shared = VocabParallelEmbedding( + config.vocab_size, + config.d_model, + org_num_embeddings=config.vocab_size, + tp_group=tp_group, + ) + + self.encoder = T5Stack( + config, + False, + config.num_layers, + self.shared, + quant_config=quant_config, + prefix=f"{prefix}.encoder", + is_umt5=False, + ) + + def get_input_embeddings(self): + return self.shared + + def forward( + self, + input_ids: torch.Tensor | None, + position_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, 
+ inputs_embeds: torch.Tensor | None = None, + output_hidden_states: bool | None = None, + **kwargs, + ) -> BaseEncoderOutput: + attn_metadata = AttentionMetadata(None) + hidden_states = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + attn_metadata=attn_metadata, + ) + + return BaseEncoderOutput(last_hidden_state=hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q", "q"), + (".qkv_proj", ".k", "k"), + (".qkv_proj", ".v", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + loaded = False + if "decoder" in name or "lm_head" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + loaded = True + break + if not loaded: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class UMT5EncoderModel(TextEncoder): + + def __init__(self, config: T5Config, prefix: str = ""): + super().__init__(config) + + quant_config = None + tp_group = _get_folding_tp_group(config) + self.shared = VocabParallelEmbedding( + config.vocab_size, + config.d_model, + org_num_embeddings=config.vocab_size, + tp_group=tp_group, + ) + + self.encoder = T5Stack( + config, + False, + config.num_layers, + self.shared, + quant_config=quant_config, + prefix=f"{prefix}.encoder", + is_umt5=True, + ) + + def get_input_embeddings(self): + return self.shared + + def forward( + self, + input_ids: torch.Tensor | None, + position_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, + output_hidden_states: bool | None = None, + **kwargs, + ) -> BaseEncoderOutput: + attn_metadata = AttentionMetadata(None) + hidden_states = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + attn_metadata=attn_metadata, + ) + + return BaseEncoderOutput( + last_hidden_state=hidden_states, + attention_mask=attention_mask, + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + loaded = False + if "decoder" in name or "lm_head" in name: + continue + for ( + param_name, + weight_name, + shard_id, + ) in self.config.arch_config.stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + loaded = True + break + if not loaded: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +EntryClass = [T5EncoderModel, UMT5EncoderModel] diff --git a/sglang/python/sglang/multimodal_gen/runtime/models/encoders/vision.py b/sglang/python/sglang/multimodal_gen/runtime/models/encoders/vision.py new file mode 100644 index 0000000000000000000000000000000000000000..3150abf1cb6f3dcb6da687365556264417e2c9c8 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/models/encoders/vision.py @@ -0,0 +1,96 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/vision.py + +from abc import ABC, abstractmethod +from typing import Generic, TypeVar + +import torch +from transformers import PretrainedConfig + +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + +_C = TypeVar("_C", bound=PretrainedConfig) + + +class VisionEncoderInfo(ABC, Generic[_C]): + + def __init__(self, vision_config: _C) -> None: + super().__init__() + + self.vision_config = vision_config + + @abstractmethod + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + raise NotImplementedError + + @abstractmethod + def get_max_image_tokens(self) -> int: + raise NotImplementedError + + @abstractmethod + def get_image_size(self) -> int: + 
raise NotImplementedError + + @abstractmethod + def get_patch_size(self) -> int: + raise NotImplementedError + + @abstractmethod + def get_patch_grid_length(self) -> int: + raise NotImplementedError + + +def resolve_visual_encoder_outputs( + encoder_outputs: torch.Tensor | list[torch.Tensor], + feature_sample_layers: list[int] | None, + post_layer_norm: torch.nn.LayerNorm | None, + max_possible_layers: int, +) -> torch.Tensor: + """Given the outputs a visual encoder module that may correspond to the + output of the last layer, or a list of hidden states to be stacked, + handle post normalization and resolve it into a single output tensor. + + Args: + encoder_outputs: Output of encoder's last layer or all hidden states. + feature_sample_layers: Optional layer indices to grab from the encoder + outputs; if provided, encoder outputs must be a list. + post_layer_norm: Post norm to apply to the output of the encoder. + max_possible_layers: Total layers in the fully loaded visual encoder. + + """ + if feature_sample_layers is None: + if post_layer_norm is not None: + return post_layer_norm(encoder_outputs) + return encoder_outputs + + # Get the hidden states corresponding to the layer indices. + # Negative values are relative to the full visual encoder, + # so offset them depending on how many layers were loaded. + # NOTE: this assumes that encoder_outputs is a list containing + # the inputs to the visual encoder, followed by the hidden states + # of each layer. 
+ num_loaded_layers = len(encoder_outputs) - 1 + offset = max_possible_layers - num_loaded_layers + hs_pool = [ + ( + encoder_outputs[layer_idx] + if layer_idx >= 0 + else encoder_outputs[layer_idx + offset] + ) + for layer_idx in feature_sample_layers + ] + + # Apply post-norm on the final hidden state if we are using it + uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1) + if post_layer_norm is not None and uses_last_layer: + hs_pool[-1] = post_layer_norm(encoder_outputs) + return torch.cat(hs_pool, dim=-1) diff --git a/sglang/python/sglang/multimodal_gen/runtime/models/parameter.py b/sglang/python/sglang/multimodal_gen/runtime/models/parameter.py new file mode 100644 index 0000000000000000000000000000000000000000..ba9b42c664a87e5c35da2e0d272a7d9b72bb9c93 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/models/parameter.py @@ -0,0 +1,423 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/parameter.py + +from collections.abc import Callable +from fractions import Fraction +from typing import Any + +import torch +from torch.nn import Parameter + +from sglang.multimodal_gen.runtime.distributed import get_tp_rank +from sglang.multimodal_gen.runtime.models.utils import _make_synced_weight_loader +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class BasevLLMParameter(Parameter): + """ + Base parameter for vLLM linear layers. Extends the torch.nn.parameter + by taking in a linear weight loader. Will copy the loaded weight + into the parameter when the provided weight loader is called. 
+ """ + + def __new__(cls, data: torch.Tensor, **kwargs): + + return super().__new__(cls, data=data, requires_grad=False) + + def __init__(self, data: torch.Tensor, weight_loader: Callable): + """ + Initialize the BasevLLMParameter + + :param data: torch tensor with the parameter data + :param weight_loader: weight loader callable + + :returns: a torch.nn.parameter + """ + + # During weight loading, we often do something like: + # narrowed_tensor = param.data.narrow(0, offset, len) + # narrowed_tensor.copy_(real_weight) + # expecting narrowed_tensor and param.data to share the same storage. + # However, on TPUs, narrowed_tensor will lazily propagate to the base + # tensor, which is param.data, leading to the redundant memory usage. + # This sometimes causes OOM errors during model loading. To avoid this, + # we sync the param tensor after its weight loader is called. + from sglang.multimodal_gen.runtime.platforms import current_platform + + if current_platform.is_tpu(): + weight_loader = _make_synced_weight_loader(weight_loader) + + self._weight_loader = weight_loader + + @property + def weight_loader(self): + return self._weight_loader + + def _is_1d_and_scalar(self, loaded_weight: torch.Tensor): + cond1 = self.data.ndim == 1 and self.data.numel() == 1 + cond2 = loaded_weight.ndim == 0 and loaded_weight.numel() == 1 + return cond1 and cond2 + + def _assert_and_load(self, loaded_weight: torch.Tensor) -> None: + assert self.data.shape == loaded_weight.shape or self._is_1d_and_scalar( + loaded_weight + ) + self.data.copy_(loaded_weight) + + def load_column_parallel_weight(self, loaded_weight: torch.Tensor) -> None: + self._assert_and_load(loaded_weight) + + def load_row_parallel_weight(self, loaded_weight: torch.Tensor) -> None: + self._assert_and_load(loaded_weight) + + def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs) -> None: + self._assert_and_load(loaded_weight) + + def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs) -> 
None: + self._assert_and_load(loaded_weight) + + +class _ColumnvLLMParameter(BasevLLMParameter): + """ + Private class defining weight loading functionality + (load_merged_column_weight, load_qkv_weight) + for parameters being loaded into linear layers with column + parallelism. This includes QKV and MLP layers which are + not already fused on disk. Requires an output dimension + to be defined. Called within the weight loader of + each of the column parallel linear layers. + """ + + def __init__(self, output_dim: int, **kwargs): + self._output_dim = output_dim + super().__init__(**kwargs) + + @property + def output_dim(self): + return self._output_dim + + def load_column_parallel_weight(self, loaded_weight: torch.Tensor) -> None: + tp_rank = get_tp_rank() + shard_size = self.data.shape[self.output_dim] + loaded_weight = loaded_weight.narrow( + self.output_dim, tp_rank * shard_size, shard_size + ) + assert self.data.shape == loaded_weight.shape + self.data.copy_(loaded_weight) + + def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs) -> None: + + shard_offset = kwargs.get("shard_offset") + shard_size = kwargs.get("shard_size") + if shard_offset is None or shard_size is None: + raise ValueError("shard_offset and shard_size must be provided") + if ( + isinstance(self, PackedColumnParameter | PackedvLLMParameter) + and self.packed_dim == self.output_dim + ): + shard_size, shard_offset = self.adjust_shard_indexes_for_packing( + shard_offset=shard_offset, shard_size=shard_size + ) + + param_data = self.data + + tp_rank = get_tp_rank() + param_data = param_data.narrow(self.output_dim, shard_offset, shard_size) + loaded_weight = loaded_weight.narrow( + self.output_dim, tp_rank * shard_size, shard_size + ) + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs) -> None: + + shard_offset = kwargs.get("shard_offset") + shard_size = kwargs.get("shard_size") + 
shard_id = kwargs.get("shard_id") + num_heads = kwargs.get("num_heads") + + assert shard_offset is not None + assert shard_size is not None + assert shard_id is not None + assert num_heads is not None + + if ( + isinstance(self, PackedColumnParameter | PackedvLLMParameter) + and self.output_dim == self.packed_dim + ): + shard_size, shard_offset = self.adjust_shard_indexes_for_packing( + shard_offset=shard_offset, shard_size=shard_size + ) + + param_data = self.data + tp_rank = get_tp_rank() + shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads + param_data = param_data.narrow(self.output_dim, shard_offset, shard_size) + loaded_weight = loaded_weight.narrow( + self.output_dim, shard_id * shard_size, shard_size + ) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +class RowvLLMParameter(BasevLLMParameter): + """ + Parameter class defining weight_loading functionality + (load_row_parallel_weight) for parameters being loaded + into linear layers with row parallel functionality. + Requires an input_dim to be defined. + """ + + def __init__(self, input_dim: int, **kwargs): + self._input_dim = input_dim + super().__init__(**kwargs) + + @property + def input_dim(self): + return self._input_dim + + def load_row_parallel_weight(self, loaded_weight: torch.Tensor) -> None: + tp_rank = get_tp_rank() + shard_size = self.data.shape[self.input_dim] + loaded_weight = loaded_weight.narrow( + self.input_dim, tp_rank * shard_size, shard_size + ) + + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert self.data.shape == loaded_weight.shape + self.data.copy_(loaded_weight) + + +class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter): + """ + Parameter class for linear layer weights. Uses both column and + row parallelism. 
+ """ + + pass + + +class GroupQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter): + """ + Parameter class for weight scales loaded for weights with + grouped quantization. Uses both column and row parallelism. + """ + + pass + + +class ChannelQuantScaleParameter(_ColumnvLLMParameter): + """ + Parameter class for weight scales loaded for weights with + channel-wise quantization. Equivalent to _ColumnvLLMParameter. + """ + + pass + + +class PerTensorScaleParameter(BasevLLMParameter): + """ + Parameter class for scales where the number of scales is + equivalent to the number of logical matrices in fused linear + layers (e.g. for QKV, there are 3 scales loaded from disk). + This is relevant to weights with per-tensor quantization. + Adds functionality to map the scalers to a shard during + weight loading. + + Note: additional parameter manipulation may be handled + for each quantization config specifically, within + process_weights_after_loading + """ + + def __init__(self, **kwargs): + self.qkv_idxs = {"q": 0, "k": 1, "v": 2} + super().__init__(**kwargs) + + def _shard_id_as_int(self, shard_id: str | int) -> int: + if isinstance(shard_id, int): + return shard_id + + # if not int, assume shard_id for qkv + # map to int and return + assert isinstance(shard_id, str) + assert shard_id in self.qkv_idxs + return self.qkv_idxs[shard_id] + + # For row parallel layers, no sharding needed + # load weight into parameter as is + def load_row_parallel_weight(self, *args, **kwargs) -> None: + super().load_row_parallel_weight(*args, **kwargs) + + def load_merged_column_weight(self, *args, **kwargs) -> None: + self._load_into_shard_id(*args, **kwargs) + + def load_qkv_weight(self, *args, **kwargs) -> None: + self._load_into_shard_id(*args, **kwargs) + + def load_column_parallel_weight(self, *args, **kwargs) -> None: + super().load_row_parallel_weight(*args, **kwargs) + + def _load_into_shard_id( + self, loaded_weight: torch.Tensor, shard_id: str | int, **kwargs + ): + """ + 
Slice the parameter data based on the shard id for + loading. + """ + + param_data = self.data + shard_id = self._shard_id_as_int(shard_id) + + # AutoFP8 scales do not have a shape + # compressed-tensors scales do have a shape + if len(loaded_weight.shape) != 0: + assert loaded_weight.shape[0] == 1 + loaded_weight = loaded_weight[0] + + param_data = param_data[shard_id] + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +class PackedColumnParameter(_ColumnvLLMParameter): + """ + Parameter for model parameters which are packed on disk + and support column parallelism only. See PackedvLLMParameter + for more details on the packed properties. + """ + + def __init__(self, packed_factor: int | Fraction, packed_dim: int, **kwargs): + self._packed_factor = packed_factor + self._packed_dim = packed_dim + super().__init__(**kwargs) + + @property + def packed_dim(self): + return self._packed_dim + + @property + def packed_factor(self): + return self._packed_factor + + def adjust_shard_indexes_for_packing( + self, shard_size, shard_offset + ) -> tuple[Any, Any]: + return _adjust_shard_indexes_for_packing( + shard_size=shard_size, + shard_offset=shard_offset, + packed_factor=self.packed_factor, + ) + + +class PackedvLLMParameter(ModelWeightParameter): + """ + Parameter for model weights which are packed on disk. + Example: GPTQ Marlin weights are int4 or int8, packed into int32. + Extends the ModelWeightParameter to take in the + packed factor, the packed dimension, and optionally, marlin + tile size for marlin kernels. Adjusts the shard_size and + shard_offset for fused linear layers model weight loading + by accounting for packing and optionally, marlin tile size. 
+ """ + + def __init__(self, packed_factor: int | Fraction, packed_dim: int, **kwargs): + self._packed_factor = packed_factor + self._packed_dim = packed_dim + super().__init__(**kwargs) + + @property + def packed_dim(self): + return self._packed_dim + + @property + def packed_factor(self): + return self._packed_factor + + def adjust_shard_indexes_for_packing(self, shard_size, shard_offset): + return _adjust_shard_indexes_for_packing( + shard_size=shard_size, + shard_offset=shard_offset, + packed_factor=self.packed_factor, + ) + + +class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter): + """ + Parameter class for weight scales loaded for weights with + block-wise quantization. Uses both column and row parallelism. + """ + + pass + + +def permute_param_layout_( + param: BasevLLMParameter, input_dim: int, output_dim: int, **kwargs +) -> BasevLLMParameter: + """ + Permute a parameter's layout to the specified input and output dimensions, + useful for forcing the parameter into a known layout, for example, if I need + a packed (quantized) weight matrix to be in the layout + {input_dim = 0, output_dim = 1, packed_dim = 0} + then I can call: + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + to ensure x is in the correct layout (permuting it to the correct layout if + required, asserting if it cannot get it to the correct layout) + """ + + curr_input_dim = getattr(param, "input_dim", None) + curr_output_dim = getattr(param, "output_dim", None) + + if curr_input_dim is None or curr_output_dim is None: + assert param.data.dim() == 2, ( + "permute_param_layout_ only supports 2D parameters when either " + "input_dim or output_dim is not set" + ) + + # if one of the dimensions is not set, set it to the opposite of the other + # we can only do this since we asserted the parameter is 2D above + if curr_input_dim is None: + assert curr_output_dim is not None, "either input or output dim must be set" + curr_input_dim = (curr_output_dim + 1) % 2 
+ if curr_output_dim is None: + assert curr_input_dim is not None, "either input or output dim must be set" + curr_output_dim = (curr_input_dim + 1) % 2 + + # create permutation from the current layout to the layout with + # self.input_dim at input_dim and self.output_dim at output_dim preserving + # other dimensions + perm = [ + i for i in range(param.data.dim()) if i not in [curr_input_dim, curr_output_dim] + ] + perm.insert(input_dim, curr_input_dim) + perm.insert(output_dim, curr_output_dim) + + if "packed_dim" in kwargs: + assert ( + hasattr(param, "packed_dim") + and param.packed_dim == perm[kwargs["packed_dim"]] + ), "permute_param_layout_ currently doesn't support repacking" + + param.data = param.data.permute(*perm) + if hasattr(param, "_input_dim"): + param._input_dim = input_dim + if hasattr(param, "_output_dim"): + param._output_dim = output_dim + if "packed_dim" in kwargs and hasattr(param, "_packed_dim"): + param._packed_dim = kwargs["packed_dim"] + + return param + + +def _adjust_shard_indexes_for_packing( + shard_size, shard_offset, packed_factor +) -> tuple[Any, Any]: + shard_size = shard_size // packed_factor + shard_offset = shard_offset // packed_factor + return shard_size, shard_offset diff --git a/sglang/python/sglang/multimodal_gen/runtime/models/registry.py b/sglang/python/sglang/multimodal_gen/runtime/models/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..9eb0a0a325ee8d30327f9d6b8488af57732796c2 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/models/registry.py @@ -0,0 +1,422 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/registry.py + +import ast +import importlib +import os +import pickle +import subprocess +import sys +import tempfile +from abc import ABC, abstractmethod +from collections.abc import Callable, Set +from 
dataclasses import dataclass, field +from functools import lru_cache +from typing import NoReturn, TypeVar, cast + +import cloudpickle +from torch import nn + +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + +MODELS_PATH = os.path.dirname(__file__) +COMPONENT_DIRS = [ + d + for d in os.listdir(MODELS_PATH) + if os.path.isdir(os.path.join(MODELS_PATH, d)) + and not d.startswith("__") + and not d.startswith(".") +] + +_IMAGE_ENCODER_MODELS: dict[str, tuple] = { + # "HunyuanVideoTransformer3DModel": ("image_encoder", "hunyuanvideo", "HunyuanVideoImageEncoder"), + "CLIPVisionModelWithProjection": ("encoders", "clip", "CLIPVisionModel"), +} + +# Global alias mapping: external_path -> canonical_class_name +_ALIAS_TO_MODEL: dict[str, str] = {} + + +def _parse_aliases_from_ast(value_node: ast.expr) -> list[str]: + """Parse _aliases list from AST node.""" + aliases = [] + if isinstance(value_node, (ast.List, ast.Tuple)): + for elt in value_node.elts: + if isinstance(elt, ast.Constant) and isinstance(elt.value, str): + aliases.append(elt.value) + return aliases + + +@lru_cache(maxsize=None) +def _discover_and_register_models() -> dict[str, tuple[str, str, str]]: + discovered_models = dict(_IMAGE_ENCODER_MODELS) + + # Collect class definitions with their _aliases + class_aliases: dict[str, list[str]] = {} + + for component in COMPONENT_DIRS: + component_path = os.path.join(MODELS_PATH, component) + for filename in os.listdir(component_path): + if not filename.endswith(".py"): + continue + + mod_relname = filename[:-3] + filepath = os.path.join(component_path, filename) + try: + with open(filepath, "r", encoding="utf-8") as f: + source = f.read() + tree = ast.parse(source, filename=filename) + + entry_class_node = None + first_class_def = None + + # Collect all class definitions and their _aliases + file_class_aliases: dict[str, list[str]] = {} + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef): + if 
first_class_def is None: + first_class_def = node + # Look for _aliases in the class body + for class_body_node in node.body: + if isinstance(class_body_node, ast.Assign): + for target in class_body_node.targets: + if ( + isinstance(target, ast.Name) + and target.id == "_aliases" + ): + aliases = _parse_aliases_from_ast( + class_body_node.value + ) + if aliases: + file_class_aliases[node.name] = aliases + if isinstance(node, ast.Assign): + for target in node.targets: + if ( + isinstance(target, ast.Name) + and target.id == "EntryClass" + ): + entry_class_node = node + break + + if entry_class_node and first_class_def: + model_cls_name_list = [] + value_node = entry_class_node.value + + # EntryClass = ClassName + if isinstance(value_node, ast.Name): + model_cls_name_list.append(value_node.id) + # EntryClass = ["...", ClassName, ...] + elif isinstance(value_node, (ast.List, ast.Tuple)): + for elt in value_node.elts: + if isinstance(elt, ast.Constant): + model_cls_name_list.append(elt.value) + elif isinstance(elt, ast.Name): + model_cls_name_list.append(elt.id) + + if model_cls_name_list: + for model_cls_str in model_cls_name_list: + if model_cls_str in discovered_models: + logger.warning( + f"Duplicate architecture found: {model_cls_str}. It will be overwritten." 
+ ) + model_arch = model_cls_str + discovered_models[model_arch] = ( + component, + mod_relname, + model_cls_str, + ) + # Collect aliases for this class + if model_cls_str in file_class_aliases: + class_aliases[model_cls_str] = file_class_aliases[ + model_cls_str + ] + + except Exception as e: + logger.warning(f"Could not parse {filepath} to find models: {e}") + + # Build alias -> canonical class name mapping + for class_name, aliases in class_aliases.items(): + for alias in aliases: + if alias in _ALIAS_TO_MODEL: + logger.warning( + f"Alias '{alias}' already registered for '{_ALIAS_TO_MODEL[alias]}', " + f"will be overwritten by '{class_name}'" + ) + _ALIAS_TO_MODEL[alias] = class_name + + return discovered_models + + +_SGLANG_DIFFUSION_MODELS = _discover_and_register_models() + +_SUBPROCESS_COMMAND = [ + sys.executable, + "-m", + "sglang.multimodal_gen.runtime.models.dits.registry", +] + +_T = TypeVar("_T") + + +@dataclass(frozen=True) +class _ModelInfo: + architecture: str + + @staticmethod + def from_model_cls(model: type[nn.Module]) -> "_ModelInfo": + return _ModelInfo( + architecture=model.__name__, + ) + + +class _BaseRegisteredModel(ABC): + + @abstractmethod + def inspect_model_cls(self) -> _ModelInfo: + raise NotImplementedError + + @abstractmethod + def load_model_cls(self) -> type[nn.Module]: + raise NotImplementedError + + +@dataclass(frozen=True) +class _RegisteredModel(_BaseRegisteredModel): + """ + Represents a model that has already been imported in the main process. 
+ """ + + interfaces: _ModelInfo + model_cls: type[nn.Module] + + @staticmethod + def from_model_cls(model_cls: type[nn.Module]): + return _RegisteredModel( + interfaces=_ModelInfo.from_model_cls(model_cls), + model_cls=model_cls, + ) + + def inspect_model_cls(self) -> _ModelInfo: + return self.interfaces + + def load_model_cls(self) -> type[nn.Module]: + return self.model_cls + + +def _run_in_subprocess(fn: Callable[[], _T]) -> _T: + # NOTE: We use a temporary directory instead of a temporary file to avoid + # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file + with tempfile.TemporaryDirectory() as tempdir: + output_filepath = os.path.join(tempdir, "registry_output.tmp") + + # `cloudpickle` allows pickling lambda functions directly + input_bytes = cloudpickle.dumps((fn, output_filepath)) + + # cannot use `sys.executable __file__` here because the script + # contains relative imports + returned = subprocess.run( + _SUBPROCESS_COMMAND, input=input_bytes, capture_output=True + ) + + # check if the subprocess is successful + try: + returned.check_returncode() + except Exception as e: + # wrap raised exception to provide more information + raise RuntimeError( + f"Error raised in subprocess:\n" f"{returned.stderr.decode()}" + ) from e + + with open(output_filepath, "rb") as f: + return cast(_T, pickle.load(f)) + + +@dataclass(frozen=True) +class _LazyRegisteredModel(_BaseRegisteredModel): + """ + Represents a model that has not been imported in the main process. 
@dataclass
class _ModelRegistry:
    """Registry mapping architecture names to registered model entries.

    Entries are ``_BaseRegisteredModel`` objects, so a model may either be
    held directly as a class or be imported lazily on first use.
    """

    # Keyed by model_arch
    registered_models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return the registered architecture names (a view over the dict keys)."""
        return self.registered_models.keys()

    def resolve_by_alias(self, alias: str) -> type[nn.Module] | None:
        """Resolve a model class by its alias (external module path).

        Returns ``None`` when the alias is unknown, or when the canonical
        architecture it maps to is not registered / fails to load.
        """
        if alias in _ALIAS_TO_MODEL:
            canonical_name = _ALIAS_TO_MODEL[alias]
            return self._try_load_model_cls(canonical_name)
        return None

    def register_model(
        self,
        model_arch: str,
        model_cls: type[nn.Module] | str,
    ) -> None:
        """
        Register an external model to be used in SGLang.

        :code:`model_cls` can be either:

        - A :class:`torch.nn.Module` class directly referencing the model.
        - A string in the format :code:`<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          :code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        An existing registration under ``model_arch`` is overwritten (a
        warning is logged).

        Raises:
            ValueError: If ``model_cls`` is a string that is not of the form
                ``<module>:<class>``.
        """
        if model_arch in self.registered_models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            module_name, class_name = split_str
            # _LazyRegisteredModel has three required fields (module_name,
            # component_name, class_name). Unpacking the two-element split
            # positionally would raise TypeError, so the fields are passed by
            # keyword. An externally registered module has no component
            # directory, hence the empty component name.
            model = _LazyRegisteredModel(
                module_name=module_name,
                component_name="",
                class_name=class_name,
            )
        else:
            model = _RegisteredModel.from_model_cls(model_cls)

        self.registered_models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]) -> NoReturn:
        """Raise ``ValueError`` explaining why none of ``architectures`` could be used."""
        all_supported_archs = self.get_supported_archs()

        # At least one name is registered, so the failure came from
        # inspecting/loading it rather than from an unknown name.
        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details."
            )

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}"
        )

    def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:
        """Load the class registered under ``model_arch``; ``None`` if unregistered or loading fails."""
        if model_arch not in self.registered_models:
            return None

        # Delegates to the cached module-level helper of the same name.
        return _try_load_model_cls(model_arch, self.registered_models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
        """Inspect the class registered under ``model_arch``; ``None`` if unregistered or inspection fails."""
        if model_arch not in self.registered_models:
            return None

        # Delegates to the cached module-level helper of the same name.
        return _try_inspect_model_cls(model_arch, self.registered_models[model_arch])

    def _normalize_archs(
        self,
        architectures: str | list[str],
    ) -> list[str]:
        """Return ``architectures`` as a list after checking every name is registered.

        Raises:
            Exception: If any name is not a registered architecture.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        normalized_arch = []
        for arch in architectures:
            if arch not in self.registered_models:
                registered_models = list(self.registered_models.keys())
                raise Exception(
                    f"Unsupported model architecture: {arch}. Registered architectures: {registered_models}"
                )
            normalized_arch.append(arch)
        return normalized_arch

    def inspect_model_cls(
        self,
        architectures: str | list[str],
    ) -> tuple[_ModelInfo, str]:
        """Return ``(model_info, arch)`` for the first architecture that can be inspected.

        Raises:
            ValueError: If none of the architectures can be inspected.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: str | list[str],
    ) -> tuple[type[nn.Module], str]:
        """Return ``(model_cls, arch)`` for the first architecture that can be loaded.

        Raises:
            ValueError: If none of the architectures can be loaded.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        return self._raise_for_unsupported(architectures)
component_name, + mod_relname, + cls_name, + ) in _SGLANG_DIFFUSION_MODELS.items() + } +) diff --git a/sglang/python/sglang/multimodal_gen/runtime/models/schedulers/base.py b/sglang/python/sglang/multimodal_gen/runtime/models/schedulers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..eb4e3bdda8dc83ccdf649b9e30db46e2817f9af0 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/models/schedulers/base.py @@ -0,0 +1,37 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod + +import torch + + +class BaseScheduler(ABC): + timesteps: torch.Tensor + order: int + num_train_timesteps: int + + def __init__(self, *args, **kwargs) -> None: + # Check if subclass has defined all required properties + required_attributes = ["timesteps", "order", "num_train_timesteps"] + + for attr in required_attributes: + if not hasattr(self, attr): + raise AttributeError( + f"Subclasses of BaseScheduler must define '{attr}' property" + ) + + @abstractmethod + def set_shift(self, shift: float) -> None: + pass + + @abstractmethod + def set_timesteps(self, *args, **kwargs) -> None: + pass + + @abstractmethod + def scale_model_input( + self, sample: torch.Tensor, timestep: int | None = None + ) -> torch.Tensor: + pass diff --git a/sglang/python/sglang/multimodal_gen/runtime/models/schedulers/flow_match_pair.py b/sglang/python/sglang/multimodal_gen/runtime/models/schedulers/flow_match_pair.py new file mode 100644 index 0000000000000000000000000000000000000000..1d7e09d2b9e13a4d4c91a847f6b512ebd991c5df --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/models/schedulers/flow_match_pair.py @@ -0,0 +1,529 @@ +# Copied and adapted from: https://github.com/OpenMOSS/MOVA/tree/main/mova/diffusion/schedulers/flow_match.py and flow_match_pair.py +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import math + +import torch + +from 
class FlowMatchScheduler(BaseScheduler):
    """Flow-matching scheduler driven by a shifted linear sigma ramp.

    ``set_timesteps`` builds ``self.sigmas`` from a linear ramp between a
    start sigma and ``sigma_min``, applies either the rational ``shift``
    transform or the exponential one, and exposes
    ``self.timesteps = self.sigmas * num_train_timesteps``.
    """

    def __init__(
        self,
        num_inference_steps=100,
        num_train_timesteps=1000,
        shift=3.0,
        sigma_max=1.0,
        sigma_min=0.003 / 1.002,
        inverse_timesteps=False,
        extra_one_step=False,
        reverse_sigmas=False,
        exponential_shift=False,
        exponential_shift_mu=None,
        shift_terminal=None,
    ):
        """Store the configuration and build the training and inference schedules."""
        self.order = 1
        self.num_train_timesteps = num_train_timesteps
        self.shift = shift
        self.sigma_max = sigma_max
        self.sigma_min = sigma_min
        self.inverse_timesteps = inverse_timesteps
        self.extra_one_step = extra_one_step
        self.reverse_sigmas = reverse_sigmas
        self.exponential_shift = exponential_shift
        self.exponential_shift_mu = exponential_shift_mu
        self.shift_terminal = shift_terminal
        # Populated by the first set_timesteps() call right below.
        self.train_timesteps = None
        self.train_sigmas = None
        # First call records the training schedule, second one installs the
        # inference schedule.
        self.set_timesteps(num_train_timesteps)
        self.set_timesteps(num_inference_steps)
        BaseScheduler.__init__(self)

    def set_shift(self, shift: float) -> None:
        """Replace the shift used by subsequent ``set_timesteps`` calls."""
        self.shift = shift

    def set_timesteps(
        self,
        num_inference_steps=100,
        denoising_strength=1.0,
        training=False,
        shift=None,
        dynamic_shift_len=None,
    ):
        """Rebuild ``self.sigmas`` and ``self.timesteps``.

        Args:
            num_inference_steps: Number of schedule points to generate.
            denoising_strength: Scales the distance between ``sigma_min`` and
                ``sigma_max`` to obtain the starting sigma.
            training: When truthy, also computes
                ``self.linear_timesteps_weights``.
            shift: Optional new value for ``self.shift``.
            dynamic_shift_len: When given (and exponential shift is enabled),
                ``mu`` is derived from it via ``calculate_shift`` instead of
                using ``exponential_shift_mu``.
        """
        if shift is not None:
            self.shift = shift

        sigma_start = (
            self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
        )

        # With extra_one_step, one additional point is generated and the last
        # one dropped.
        point_count = (
            num_inference_steps + 1 if self.extra_one_step else num_inference_steps
        )
        ramp = torch.linspace(sigma_start, self.sigma_min, point_count)
        self.sigmas = ramp[:-1] if self.extra_one_step else ramp

        if self.inverse_timesteps:
            self.sigmas = torch.flip(self.sigmas, dims=[0])

        if self.exponential_shift:
            if dynamic_shift_len is not None:
                mu = self.calculate_shift(dynamic_shift_len)
            else:
                mu = self.exponential_shift_mu
            exp_mu = math.exp(mu)
            self.sigmas = exp_mu / (exp_mu + (1 / self.sigmas - 1))
        else:
            self.sigmas = (
                self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
            )

        if self.shift_terminal is not None:
            # Rescale so that the final sigma equals shift_terminal.
            one_minus_z = 1 - self.sigmas
            scale_factor = one_minus_z[-1] / (1 - self.shift_terminal)
            self.sigmas = 1 - (one_minus_z / scale_factor)

        if self.reverse_sigmas:
            self.sigmas = 1 - self.sigmas

        self.timesteps = self.sigmas * self.num_train_timesteps

        # Initialize train_timesteps on first set.
        if self.train_timesteps is None:
            self.train_timesteps = self.timesteps
            self.train_sigmas = self.sigmas

        if training:
            steps = self.timesteps
            bell = torch.exp(
                -2 * ((steps - num_inference_steps / 2) / num_inference_steps) ** 2
            )
            bell_shifted = bell - bell.min()
            self.linear_timesteps_weights = bell_shifted * (
                num_inference_steps / bell_shifted.sum()
            )
            self.training = True
        else:
            self.training = False

    def scale_model_input(self, sample: torch.Tensor, timestep: int | None = None):
        """Return ``sample`` unchanged."""
        return sample

    def _nearest_timestep_index(self, timestep):
        """Index of the entry of ``self.timesteps`` closest to ``timestep``.

        Tensor timesteps are moved to the CPU before the comparison.
        """
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        return torch.argmin((self.timesteps - timestep).abs())

    def step(self, model_output, timestep, sample, to_final=False, **kwargs):
        """Advance ``sample`` from ``timestep`` to the next schedule point.

        When ``to_final`` is set or ``timestep`` is the last schedule point,
        the target sigma is the terminal value (1 for inverse/reversed
        schedules, 0 otherwise).
        """
        idx = self._nearest_timestep_index(timestep)
        sigma = self.sigmas[idx]
        if to_final or idx + 1 >= len(self.timesteps):
            sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
        else:
            sigma_ = self.sigmas[idx + 1]
        return sample + model_output * (sigma_ - sigma)

    def return_to_timestep(self, timestep, sample, sample_stablized):
        """Return ``(sample - sample_stablized) / sigma`` for the sigma nearest ``timestep``."""
        idx = self._nearest_timestep_index(timestep)
        sigma = self.sigmas[idx]
        return (sample - sample_stablized) / sigma

    def add_noise(self, original_samples, noise, timestep):
        """Blend ``original_samples`` and ``noise`` using the sigma nearest ``timestep``."""
        idx = self._nearest_timestep_index(timestep)
        sigma = self.sigmas[idx]
        return (1 - sigma) * original_samples + sigma * noise

    def training_target(self, sample, noise, timestep):
        """Return the training target ``noise - sample``."""
        return noise - sample

    def training_weight(self, timestep):
        """Return the training weight of the schedule point nearest ``timestep``.

        Reads ``self.linear_timesteps_weights``, which is only set by
        ``set_timesteps(..., training=True)``.
        """
        idx = torch.argmin(
            (self.timesteps - timestep.to(self.timesteps.device)).abs()
        )
        return self.linear_timesteps_weights[idx]

    def calculate_shift(
        self,
        image_seq_len,
        base_seq_len: int = 256,
        max_seq_len: int = 8192,
        base_shift: float = 0.5,
        max_shift: float = 0.9,
    ):
        """Linearly interpolate ``mu`` from the sequence length.

        The line passes through ``(base_seq_len, base_shift)`` and
        ``(max_seq_len, max_shift)``.
        """
        m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
        b = base_shift - m * base_seq_len
        return image_seq_len * m + b
+ + Args: + fn: Callable with signature fn(pairs: torch.Tensor) -> torch.Tensor. + The returned tensor must have the same shape as input pairs. + + Raises: + TypeError: If fn is not callable or None. + RuntimeError: If scheduler is not initialized. + """ + if fn is not None and not callable(fn): + raise TypeError("pair_postprocess must be callable or None") + self._pair_postprocess_fn = fn + self._pair_postprocess_requires_source = ( + False if fn is None else bool(getattr(fn, "_requires_source", False)) + ) + if self.timesteps is None or self.sigmas is None: + raise RuntimeError("Scheduler not initialized; call set_timesteps() first") + self._refresh_pair_cache() + + def set_pair_postprocess_by_name(self, name: str | None, **kwargs): + """Configure a postprocess function by name. + + Supported names: + - None/"none"/"off"/"false"/"no": disable + - "quadratic_perp_bulge_swap": x2=x+d, y2=x-d, where d=4*amp*s*(1-s), s=t/T + - "v2a_sequential": assume pairs are (t,t); sample half sequence from column 0 + with stride 2, then let column 0 follow this sequence first, followed by column 1 + - "a2v_sequential": same as above, but column 1 first then column 0 + - "dual_sigma_shift": use only timestep count; rebuild two columns independently using + FlowMatchScheduler sigma transform logic; configurable visual_shift/audio_shift + + Args: + name: Postprocess name or None to disable. + **kwargs: Extra parameters for the named postprocess. For example: + - amp: Float amplitude, default 150.0. + + Raises: + ValueError: If name is unknown. 
+ """ + + if name is None or str(name).lower() in ("none", "off", "false", "no"): + self.set_pair_postprocess(None) + return + if name == "quadratic_perp_bulge_swap": + amp = float(kwargs.get("amp", 150.0)) + + def _quadratic_perp_bulge_swap(pairs: torch.Tensor): + if ( + not isinstance(pairs, torch.Tensor) + or pairs.ndim != 2 + or pairs.shape[1] != 2 + ): + raise ValueError("pairs must be a torch.Tensor of shape [N, 2]") + x = pairs[:, 0] + T = float(self.num_train_timesteps) + s = x / T + d = 4.0 * amp * s * (1.0 - s) + x2 = x + d + y2 = x - d + return torch.stack([x2, y2], dim=1) + + self.set_pair_postprocess(_quadratic_perp_bulge_swap) + return + if name == "v2a_sequential": + + def _v2a(pairs: torch.Tensor): + if ( + not isinstance(pairs, torch.Tensor) + or pairs.ndim != 2 + or pairs.shape[1] != 2 + ): + raise ValueError("pairs must be a torch.Tensor of shape [N, 2]") + N = pairs.shape[0] + base = pairs[:, 0] + seq_half = base[::2] + m = int(seq_half.shape[0]) + col0 = torch.cat([seq_half, seq_half[-1:].repeat(m)], dim=0)[:N] + col1 = torch.cat([seq_half[0:1].repeat(m), seq_half], dim=0)[:N] + return torch.stack( + [ + col0.to(dtype=pairs.dtype, device=pairs.device), + col1.to(dtype=pairs.dtype, device=pairs.device), + ], + dim=1, + ) + + self.set_pair_postprocess(_v2a) + return + if name == "a2v_sequential": + + def _a2v(pairs: torch.Tensor): + if ( + not isinstance(pairs, torch.Tensor) + or pairs.ndim != 2 + or pairs.shape[1] != 2 + ): + raise ValueError("pairs must be a torch.Tensor of shape [N, 2]") + N = pairs.shape[0] + base = pairs[:, 0] + seq_half = base[::2] + m = int(seq_half.shape[0]) + col0 = torch.cat([seq_half[0:1].repeat(m), seq_half], dim=0)[:N] + col1 = torch.cat([seq_half, seq_half[-1:].repeat(m)], dim=0)[:N] + return torch.stack( + [ + col0.to(dtype=pairs.dtype, device=pairs.device), + col1.to(dtype=pairs.dtype, device=pairs.device), + ], + dim=1, + ) + + self.set_pair_postprocess(_a2v) + return + if name == "v2a": + + def 
_v2a_classic(pairs: torch.Tensor): + if ( + not isinstance(pairs, torch.Tensor) + or pairs.ndim != 2 + or pairs.shape[1] != 2 + ): + raise ValueError("pairs must be a torch.Tensor of shape [N, 2]") + zeros = torch.zeros_like(pairs[:, 0]) + return torch.stack([zeros, pairs[:, 1]], dim=1) + + self.set_pair_postprocess(_v2a_classic) + return + if name == "a2v": + + def _a2v_classic(pairs: torch.Tensor): + if ( + not isinstance(pairs, torch.Tensor) + or pairs.ndim != 2 + or pairs.shape[1] != 2 + ): + raise ValueError("pairs must be a torch.Tensor of shape [N, 2]") + zeros = torch.zeros_like(pairs[:, 1]) + return torch.stack([pairs[:, 0], zeros], dim=1) + + self.set_pair_postprocess(_a2v_classic) + return + if name == "dual_sigma_shift": + visual_shift = float(kwargs.get("visual_shift", self.shift)) + audio_shift = float(kwargs.get("audio_shift", self.shift)) + visual_denoising_strength = float( + kwargs.get("visual_denoising_strength", 1.0) + ) + audio_denoising_strength = float( + kwargs.get("audio_denoising_strength", 1.0) + ) + visual_mu = kwargs.get( + "visual_exponential_shift_mu", self.exponential_shift_mu + ) + audio_mu = kwargs.get( + "audio_exponential_shift_mu", self.exponential_shift_mu + ) + + def _dual_sigma_shift(pairs: torch.Tensor, *, source: str): + if not isinstance(pairs, torch.Tensor): + raise TypeError("pairs must be a torch.Tensor") + if pairs.ndim != 2 or pairs.shape[1] != 2: + raise ValueError("pairs must be a torch.Tensor of shape [N, 2]") + if pairs.shape[0] == 0: + raise ValueError("pairs length must be greater than 0") + if source not in ("timesteps", "sigmas"): + raise ValueError("source must be 'timesteps' or 'sigmas'") + + num_steps = pairs.shape[0] + device = pairs.device + dtype = pairs.dtype + + def _build_column( + shift_value: float, denoising_strength: float, mu_override + ): + if shift_value <= 0: + raise ValueError("shift must be positive") + if denoising_strength <= 0: + raise ValueError("denoising_strength must be positive") + + 
sigma_start = ( + self.sigma_min + + (self.sigma_max - self.sigma_min) * denoising_strength + ) + if self.extra_one_step: + base = torch.linspace( + sigma_start, + self.sigma_min, + num_steps + 1, + device=device, + dtype=dtype, + )[:-1] + else: + base = torch.linspace( + sigma_start, + self.sigma_min, + num_steps, + device=device, + dtype=dtype, + ) + + if self.inverse_timesteps: + base = torch.flip(base, dims=[0]) + + if self.exponential_shift: + mu_value = mu_override + if mu_value is None: + raise RuntimeError( + "exponential_shift enabled but exponential_shift_mu is missing" + ) + exp_mu = math.exp(float(mu_value)) + base = exp_mu / (exp_mu + (1 / base - 1)) + else: + base = shift_value * base / (1 + (shift_value - 1) * base) + + if self.shift_terminal is not None: + one_minus_z = 1 - base + scale_factor = one_minus_z[-1] / (1 - self.shift_terminal) + base = 1 - (one_minus_z / scale_factor) + + if self.reverse_sigmas: + base = 1 - base + + if source == "timesteps": + return base * self.num_train_timesteps + return base + + col0 = _build_column(visual_shift, visual_denoising_strength, visual_mu) + col1 = _build_column(audio_shift, audio_denoising_strength, audio_mu) + return torch.stack([col0, col1], dim=1) + + _dual_sigma_shift._requires_source = True + self.set_pair_postprocess(_dual_sigma_shift) + return + raise ValueError(f"Unknown pair_postprocess name: {name}") + + def _make_pairs_from_vector(self, vec: torch.Tensor) -> torch.Tensor: + if vec.ndim != 1: + raise ValueError("vec must be 1D") + return torch.stack([vec, vec], dim=1) + + def get_pairs(self, source: str = "timesteps") -> torch.Tensor: + if source == "timesteps": + if self.pair_timesteps is None: + self._refresh_pair_cache() + return self.pair_timesteps + if source == "sigmas": + if self.pair_sigmas is None: + self._refresh_pair_cache() + return self.pair_sigmas + raise ValueError("source must be 'timesteps' or 'sigmas'") + + def timestep_to_sigma(self, timestep: torch.Tensor | float) -> 
torch.Tensor: + """Return sigma for a scalar timestep via nearest neighbor lookup. + + Args: + timestep: Scalar timestep value. + + Returns: + Sigma corresponding to the nearest timestep. + """ + t_value = float(timestep) + t_cpu = torch.tensor(t_value) + idx = torch.argmin((self.train_timesteps - t_cpu).abs()) + return self.train_sigmas[idx] + + def step_from_to( + self, + model_output: torch.Tensor, + timestep_from: torch.Tensor, + timestep_to: torch.Tensor | None, + sample: torch.Tensor, + ) -> torch.Tensor: + """Advance one step using an explicit (from, to) timestep pair. + + The update rule is: + x_{to} = x_{from} + model_output * (sigma(to) - sigma(from)) + + Args: + model_output: Predicted model output. + timestep_from: Source timestep. + timestep_to: Target timestep or None for terminal. + sample: Current sample at timestep_from. + + Returns: + Updated sample at timestep_to. + """ + sigma_from = self.timestep_to_sigma(timestep_from) + if timestep_to is None: + sigma_to = torch.tensor( + 1.0 if (self.inverse_timesteps or self.reverse_sigmas) else 0.0, + device=sigma_from.device, + dtype=sigma_from.dtype, + ) + else: + sigma_to = self.timestep_to_sigma(timestep_to) + prev_sample = sample + model_output * (sigma_to - sigma_from) + return prev_sample + + def _refresh_pair_cache(self) -> None: + if self.timesteps is None or self.sigmas is None: + raise RuntimeError("Scheduler not initialized; call set_timesteps() first") + + def _apply_postprocess(pairs: torch.Tensor, source: str) -> torch.Tensor: + if self._pair_postprocess_fn is None: + return pairs + if self._pair_postprocess_requires_source: + modified = self._pair_postprocess_fn(pairs, source=source) + else: + modified = self._pair_postprocess_fn(pairs) + if not isinstance(modified, torch.Tensor): + raise TypeError("pair_postprocess must return a torch.Tensor") + if modified.shape != pairs.shape: + raise ValueError("pair_postprocess must return the same shape as input") + return modified + + 
base_pairs_timesteps = self._make_pairs_from_vector(self.timesteps) + base_pairs_sigmas = self._make_pairs_from_vector(self.sigmas) + + self.pair_timesteps = _apply_postprocess(base_pairs_timesteps, "timesteps") + self.pair_sigmas = _apply_postprocess(base_pairs_sigmas, "sigmas") + + +EntryClass = FlowMatchPairScheduler diff --git a/sglang/python/sglang/multimodal_gen/runtime/models/schedulers/hunyuan3d_scheduler.py b/sglang/python/sglang/multimodal_gen/runtime/models/schedulers/hunyuan3d_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..259c6c52c04970e8e08547d94c62b33eeed0fc04 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/models/schedulers/hunyuan3d_scheduler.py @@ -0,0 +1,371 @@ +# Copied and adapted from: https://github.com/Tencent-Hunyuan/Hunyuan3D-2 +from __future__ import annotations + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.utils import BaseOutput + + +@dataclass +class Hunyuan3DFlowMatchSchedulerOutput(BaseOutput): + """Output class for the scheduler's step function.""" + + prev_sample: torch.FloatTensor + + +class Hunyuan3DFlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): + """Euler discrete scheduler for flow matching.""" + + # External module path aliases for compatibility with Hunyuan3D configs + _aliases = [ + "hy3dgen.shapegen.schedulers.FlowMatchEulerDiscreteScheduler", + "hy3dshape.schedulers.FlowMatchEulerDiscreteScheduler", + ] + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + use_dynamic_shifting: bool = False, + ): + timesteps = np.linspace( + 1, num_train_timesteps, num_train_timesteps, dtype=np.float32 + ).copy() + timesteps = 
torch.from_numpy(timesteps).to(dtype=torch.float32) + + sigmas = timesteps / num_train_timesteps + if not use_dynamic_shifting: + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + self.timesteps = sigmas * num_train_timesteps + self._step_index = None + self._begin_index = None + + self.sigmas = sigmas.to("cpu") + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self) -> Optional[int]: + """The index counter for current timestep.""" + return self._step_index + + @property + def begin_index(self) -> Optional[int]: + """The index for the first timestep.""" + return self._begin_index + + def set_begin_index(self, begin_index: int = 0): + """Set the begin index for the scheduler. + + Args: + begin_index: The begin index for the scheduler. + """ + self._begin_index = begin_index + + def scale_model_input( + self, + sample: torch.FloatTensor, + timestep: Optional[Union[float, torch.FloatTensor]] = None, + ) -> torch.FloatTensor: + """Identity operation for flow matching (no input scaling needed).""" + return sample + + def scale_noise( + self, + sample: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + noise: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """Forward process in flow-matching (add noise to sample).""" + sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype) + + if sample.device.type == "mps" and torch.is_floating_point(timestep): + schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32) + timestep = timestep.to(sample.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(sample.device) + timestep = timestep.to(sample.device) + + if self.begin_index is None: + step_indices = [ + self.index_for_timestep(t, schedule_timesteps) for t in timestep + ] + elif self.step_index is not None: + step_indices = [self.step_index] * timestep.shape[0] + else: + step_indices = [self.begin_index] * timestep.shape[0] + + 
def set_timesteps(
    self,
    num_inference_steps: int = None,
    device: Union[str, torch.device] = None,
    sigmas: Optional[List[float]] = None,
    mu: Optional[float] = None,
):
    """Set the discrete timesteps for the diffusion chain.

    Args:
        num_inference_steps: Number of schedule points to generate. Required
            when ``sigmas`` is not given.
        device: Device the resulting ``timesteps`` / ``sigmas`` are placed on.
        sigmas: Optional custom (unshifted) sigma schedule. Any sequence
            accepted by ``numpy.asarray`` works; its length becomes
            ``self.num_inference_steps``.
        mu: Shift parameter passed to ``time_shift``; required when the
            config enables ``use_dynamic_shifting``.

    Raises:
        ValueError: If dynamic shifting is enabled without ``mu``, or if
            neither ``num_inference_steps`` nor ``sigmas`` is provided.
    """
    if self.config.use_dynamic_shifting and mu is None:
        raise ValueError(
            "Must pass a value for `mu` when `use_dynamic_shifting` is True"
        )

    if sigmas is None:
        if num_inference_steps is None:
            raise ValueError("Must pass either `num_inference_steps` or `sigmas`")
        self.num_inference_steps = num_inference_steps
        timesteps = np.linspace(
            self._sigma_to_t(self.sigma_max),
            self._sigma_to_t(self.sigma_min),
            num_inference_steps,
        )
        sigmas = timesteps / self.config.num_train_timesteps
    else:
        # A plain Python list cannot be multiplied by a float shift nor be
        # passed to torch.from_numpy, so convert it to an array first and
        # keep the step count in sync with the custom schedule.
        sigmas = np.asarray(sigmas)
        self.num_inference_steps = len(sigmas)

    if self.config.use_dynamic_shifting:
        sigmas = self.time_shift(mu, 1.0, sigmas)
    else:
        sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)

    sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
    timesteps = sigmas * self.config.num_train_timesteps

    self.timesteps = timesteps.to(device=device)
    # A terminal sigma of 1 is appended so step() can always read
    # sigmas[step_index + 1].
    self.sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)])

    # A new schedule invalidates any step bookkeeping.
    self._step_index = None
    self._begin_index = None
_init_step_index(self, timestep: Union[float, torch.Tensor]): + """Initialize step index from timestep.""" + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[Hunyuan3DFlowMatchSchedulerOutput, Tuple]: + """Predict the sample from the previous timestep.""" + if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)): + raise ValueError( + "Passing integer indices as timesteps is not supported. " + "Pass one of `scheduler.timesteps` as a timestep." + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Upcast to avoid precision issues + sample = sample.to(torch.float32) + + sigma = self.sigmas[self.step_index] + sigma_next = self.sigmas[self.step_index + 1] + + prev_sample = sample + (sigma_next - sigma) * model_output + prev_sample = prev_sample.to(model_output.dtype) + + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return Hunyuan3DFlowMatchSchedulerOutput(prev_sample=prev_sample) + + def __len__(self) -> int: + return self.config.num_train_timesteps + + +@dataclass +class Hunyuan3DConsistencyFlowMatchSchedulerOutput(BaseOutput): + """Output for consistency flow matching scheduler.""" + + prev_sample: torch.FloatTensor + pred_original_sample: torch.FloatTensor + + +class Hunyuan3DConsistencyFlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): + """Consistency Flow Matching Euler Discrete Scheduler.""" + + # External module path aliases for compatibility with Hunyuan3D configs + _aliases = [ + 
"hy3dshape.schedulers.Hunyuan3DConsistencyFlowMatchEulerDiscreteScheduler", + ] + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + pcm_timesteps: int = 50, + ): + sigmas = np.linspace(0, 1, num_train_timesteps) + step_ratio = num_train_timesteps // pcm_timesteps + + euler_timesteps = (np.arange(1, pcm_timesteps) * step_ratio).round().astype( + np.int64 + ) - 1 + euler_timesteps = np.asarray([0] + euler_timesteps.tolist()) + + self.euler_timesteps = euler_timesteps + self.sigmas = sigmas[self.euler_timesteps] + self.sigmas = torch.from_numpy(self.sigmas.copy()).to(dtype=torch.float32) + self.timesteps = self.sigmas * num_train_timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") + + @property + def step_index(self) -> Optional[int]: + return self._step_index + + @property + def begin_index(self) -> Optional[int]: + return self._begin_index + + def set_begin_index(self, begin_index: int = 0): + self._begin_index = begin_index + + def scale_model_input( + self, + sample: torch.FloatTensor, + timestep: Optional[Union[float, torch.FloatTensor]] = None, + ) -> torch.FloatTensor: + """Identity operation for flow matching (no input scaling needed).""" + return sample + + def _sigma_to_t(self, sigma: float) -> float: + return sigma * self.config.num_train_timesteps + + def set_timesteps( + self, + num_inference_steps: int = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + ): + """Set timesteps for inference.""" + self.num_inference_steps = ( + num_inference_steps if num_inference_steps is not None else len(sigmas) + ) + inference_indices = np.linspace( + 0, self.config.pcm_timesteps, num=self.num_inference_steps, endpoint=False + ) + inference_indices = np.floor(inference_indices).astype(np.int64) + inference_indices = torch.from_numpy(inference_indices).long() + + self.sigmas_ = self.sigmas[inference_indices] + 
timesteps = self.sigmas_ * self.config.num_train_timesteps + self.timesteps = timesteps.to(device=device) + self.sigmas_ = torch.cat( + [self.sigmas_, torch.ones(1, device=self.sigmas_.device)] + ) + + self._step_index = None + self._begin_index = None + + def index_for_timestep( + self, timestep: float, schedule_timesteps: Optional[torch.Tensor] = None + ) -> int: + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + indices = (schedule_timesteps == timestep).nonzero() + pos = 1 if len(indices) > 1 else 0 + return indices[pos].item() + + def _init_step_index(self, timestep: Union[float, torch.Tensor]): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[Hunyuan3DConsistencyFlowMatchSchedulerOutput, Tuple]: + """Perform one step of the consistency flow matching scheduler.""" + if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)): + raise ValueError("Passing integer indices as timesteps is not supported.") + + if self.step_index is None: + self._init_step_index(timestep) + + sample = sample.to(torch.float32) + + sigma = self.sigmas_[self.step_index] + sigma_next = self.sigmas_[self.step_index + 1] + + prev_sample = sample + (sigma_next - sigma) * model_output + prev_sample = prev_sample.to(model_output.dtype) + + pred_original_sample = sample + (1.0 - sigma) * model_output + pred_original_sample = pred_original_sample.to(model_output.dtype) + + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return Hunyuan3DConsistencyFlowMatchSchedulerOutput( + prev_sample=prev_sample, pred_original_sample=pred_original_sample + ) + + def 
__len__(self) -> int: + return self.config.num_train_timesteps + + +# Entry class for model registry +EntryClass = [ + Hunyuan3DFlowMatchEulerDiscreteScheduler, + Hunyuan3DConsistencyFlowMatchEulerDiscreteScheduler, +] diff --git a/sglang/python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_comfyui_passthrough.py b/sglang/python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_comfyui_passthrough.py new file mode 100644 index 0000000000000000000000000000000000000000..e87f558b8b106f087cc331bb4c06e863540f5c5c --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_comfyui_passthrough.py @@ -0,0 +1,193 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo +# SPDX-License-Identifier: Apache-2.0 +""" +Pass-through scheduler for ComfyUI integration. + +This scheduler does not modify latents - it simply returns the input sample unchanged. +The actual denoising logic is handled by ComfyUI. +""" + +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.utils import BaseOutput + +from sglang.multimodal_gen.runtime.models.schedulers.base import BaseScheduler +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class ComfyUIPassThroughSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor`): The input sample unchanged (pass-through). + """ + + prev_sample: torch.FloatTensor + + +class ComfyUIPassThroughScheduler(BaseScheduler, ConfigMixin, SchedulerMixin): + """ + Pass-through scheduler for ComfyUI integration. + + This scheduler does not modify latents. It is used when the denoising logic + is handled externally by ComfyUI. The scheduler simply returns the input + sample unchanged, allowing ComfyUI to manage the denoising process. 
+ + Usage: + - num_inference_steps is always 1 (each step is handled separately) + - timesteps are provided externally by ComfyUI + - step() returns the input sample unchanged + """ + + config_name = "scheduler_config.json" + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps=1000, + *args, + **kwargs, + ): + self.num_train_timesteps = num_train_timesteps + # Initialize timesteps as empty - will be set externally + self.timesteps = torch.tensor([], dtype=torch.long) + self.shift = 0.0 + self._step_index = 0 # Track current step index + self._begin_index: int | None = None # For compatibility with DenoisingStage + + def set_timesteps( + self, + num_inference_steps=1, # Always 1 for ComfyUI + timesteps=None, # Can be provided externally + device=None, + **kwargs, + ): + """ + Set timesteps. For ComfyUI, timesteps are provided externally. + + Args: + num_inference_steps: Ignored (always 1 for ComfyUI) + timesteps: External timesteps provided by ComfyUI + device: Device to place timesteps on + """ + if timesteps is not None: + # Use externally provided timesteps + if isinstance(timesteps, torch.Tensor): + self.timesteps = timesteps + else: + self.timesteps = torch.tensor(timesteps, dtype=torch.long) + if device is not None: + self.timesteps = self.timesteps.to(device) + else: + # Create a single timestep if none provided + if device is None: + device = torch.device("cpu") + self.timesteps = torch.tensor([0], dtype=torch.long, device=device) + + def step( + self, + model_output: torch.FloatTensor, + timestep: torch.FloatTensor | int, + sample: torch.FloatTensor, + return_dict: bool = False, + **kwargs, + ) -> tuple | ComfyUIPassThroughSchedulerOutput: + """ + Pass-through step: returns the input sample unchanged. + + This scheduler does not modify latents. The actual denoising is handled + by ComfyUI, so we simply return the input sample as-is. 
+ + Args: + model_output: Predicted noise (ignored, but kept for API compatibility) + timestep: Current timestep (ignored, but kept for API compatibility) + sample: Input latents (returned unchanged) + return_dict: Whether to return a dict or tuple + + Returns: + The input sample unchanged (prev_sample = sample) + """ + # Increment step index for tracking + self._step_index += 1 + + # Simply return the input sample unchanged + prev_sample = sample + + if not return_dict: + return (prev_sample,) + + return ComfyUIPassThroughSchedulerOutput(prev_sample=prev_sample) + + def scale_model_input( + self, sample: torch.Tensor, timestep: int | None = None + ) -> torch.Tensor: + """ + Scale model input. For pass-through scheduler, returns input unchanged. + + Args: + sample: Input sample + timestep: Timestep (ignored) + + Returns: + Input sample unchanged + """ + return sample + + def set_shift(self, shift: float) -> None: + """ + Set shift parameter (no-op for pass-through scheduler). + + Args: + shift: Shift value (ignored) + """ + self.shift = shift + + def set_begin_index(self, begin_index: int = 0) -> None: + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index: The begin index for the scheduler. + """ + self._begin_index = begin_index + + @property + def begin_index(self) -> int | None: + """ + The index for the first timestep. + """ + return self._begin_index + + @property + def step_index(self) -> int: + """ + The index counter for current timestep. + """ + return self._step_index + + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timestep: torch.Tensor, + ) -> torch.Tensor: + """ + Add noise to samples. For pass-through scheduler, returns original samples. 
+ + Args: + original_samples: Original clean samples + noise: Noise to add (ignored) + timestep: Timestep (ignored) + + Returns: + Original samples unchanged + """ + return original_samples + + +EntryClass = ComfyUIPassThroughScheduler diff --git a/sglang/python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_flow_match_euler_discrete.py b/sglang/python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_flow_match_euler_discrete.py new file mode 100644 index 0000000000000000000000000000000000000000..980ff50f91d95f1e4c99876273c15a3d7ca89a69 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_flow_match_euler_discrete.py @@ -0,0 +1,688 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+#
+# Modified from diffusers==0.29.2
+#
+# ==============================================================================
+import math
+from dataclasses import dataclass
+from typing import Any
+
+import numpy as np
+import scipy.stats
+import torch
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
+from diffusers.utils import BaseOutput
+
+from sglang.multimodal_gen.runtime.models.schedulers.base import BaseScheduler
+from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+@dataclass
+class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
+    """
+    Output class for the scheduler's `step` function output.
+
+    Args:
+        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+    """
+
+    prev_sample: torch.FloatTensor
+
+
+class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin, BaseScheduler):
+    """
+    Euler scheduler.
+
+    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+    methods the library implements for all schedulers such as loading and saving.
+
+    Args:
+        num_train_timesteps (`int`, defaults to 1000):
+            The number of diffusion steps to train the model.
+        shift (`float`, defaults to 1.0):
+            The shift value for the timestep schedule.
+        use_dynamic_shifting (`bool`, defaults to False):
+            Whether to apply timestep shifting on-the-fly based on the image resolution.
+        base_shift (`float`, defaults to 0.5):
+            Value to stabilize image generation. Increasing `base_shift` reduces variation and image is more consistent
+            with desired output.
+        max_shift (`float`, defaults to 1.15):
+            Value change allowed to latent vectors. Increasing `max_shift` encourages more variation and image may be
+            more exaggerated or stylized.
+        base_image_seq_len (`int`, defaults to 256):
+            The base image sequence length.
+        max_image_seq_len (`int`, defaults to 4096):
+            The maximum image sequence length.
+        invert_sigmas (`bool`, defaults to False):
+            Whether to invert the sigmas.
+        shift_terminal (`float`, defaults to None):
+            The end value of the shifted timestep schedule.
+        use_karras_sigmas (`bool`, defaults to False):
+            Whether to use Karras sigmas for step sizes in the noise schedule during sampling.
+        use_exponential_sigmas (`bool`, defaults to False):
+            Whether to use exponential sigmas for step sizes in the noise schedule during sampling.
+        use_beta_sigmas (`bool`, defaults to False):
+            Whether to use beta sigmas for step sizes in the noise schedule during sampling.
+        time_shift_type (`str`, defaults to "exponential"):
+            The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear".
+        stochastic_sampling (`bool`, defaults to False):
+            Whether to use stochastic sampling.
+    """
+
+    _compatibles: list[Any] = []
+    order = 1
+
+    @register_to_config
+    def __init__(
+        self,
+        num_train_timesteps: int = 1000,
+        shift: float = 1.0,
+        use_dynamic_shifting: bool = False,
+        base_shift: float | None = 0.5,
+        max_shift: float | None = 1.15,
+        base_image_seq_len: int | None = 256,
+        max_image_seq_len: int | None = 4096,
+        invert_sigmas: bool = False,
+        shift_terminal: float | None = None,
+        use_karras_sigmas: bool | None = False,
+        use_exponential_sigmas: bool | None = False,
+        use_beta_sigmas: bool | None = False,
+        time_shift_type: str = "exponential",
+        stochastic_sampling: bool = False,
+    ):
+        # The three flags are typed `bool | None`; coerce to bool before counting so
+        # that a `None` flag cannot make the sum raise a TypeError.
+        enabled_sigma_schedules = sum(
+            bool(flag)
+            for flag in (use_beta_sigmas, use_exponential_sigmas, use_karras_sigmas)
+        )
+        if enabled_sigma_schedules > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
+        if time_shift_type not in {"exponential", "linear"}:
+            raise ValueError(
+                "`time_shift_type` must either be 'exponential' or 'linear'."
+            )
+
+        # Training schedule: num_train_timesteps values from num_train_timesteps down to 1.
+        timesteps = np.linspace(
+            1, num_train_timesteps, num_train_timesteps, dtype=np.float32
+        )[::-1].copy()
+        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
+
+        sigmas = timesteps / num_train_timesteps
+        if not use_dynamic_shifting:
+            # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
+            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
+
+        self.timesteps = sigmas * num_train_timesteps
+        self.num_train_timesteps = num_train_timesteps
+
+        self._step_index: int | None = None
+        self._begin_index: int | None = None
+
+        self._shift = shift
+
+        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
+        self.sigma_min = self.sigmas[-1].item()
+        self.sigma_max = self.sigmas[0].item()
+        BaseScheduler.__init__(self)
+
+    @property
+    def shift(self) -> float:
+        """
+        The value used for shifting.
+        """
+        return self._shift
+
+    @property
+    def step_index(self) -> int | None:
+        """
+        The index counter for current timestep. It will increase 1 after each scheduler step.
+        """
+        return self._step_index
+
+    @property
+    def begin_index(self) -> int | None:
+        """
+        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+        """
+        return self._begin_index
+
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+    def set_begin_index(self, begin_index: int = 0) -> None:
+        """
+        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+        Args:
+            begin_index (`int`):
+                The begin index for the scheduler.
+        """
+        self._begin_index = begin_index
+
+    def set_shift(self, shift: float) -> None:
+        """
+        Replace the shift value returned by the `shift` property.
+
+        Only the stored value changes; sigmas that were already computed are left as they are. The new value is
+        read by the next `set_timesteps` call when `config.use_dynamic_shifting` is disabled.
+
+        Args:
+            shift (`float`):
+                The new shift value.
+        """
+        self._shift = shift
+
+    def scale_noise(
+        self,
+        sample: torch.FloatTensor,
+        timestep: float | torch.FloatTensor,
+        noise: torch.FloatTensor | None = None,
+    ) -> torch.FloatTensor:
+        """
+        Forward process in flow-matching
+
+        Args:
+            sample (`torch.FloatTensor`):
+                The input sample.
+            timestep (`torch.FloatTensor`):
+                The timestep(s) to noise to. Although the annotation allows a float, the body asserts that a
+                tensor is passed.
+            noise (`torch.FloatTensor`):
+                The noise tensor. It is multiplied unconditionally, so it has to be provided despite the `None`
+                default.
+
+        Returns:
+            `torch.FloatTensor`:
+                `sigma * noise + (1.0 - sigma) * sample`.
+        """
+        # Make sure sigmas and timesteps have the same device and dtype as original_samples
+        sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)
+
+        if sample.device.type == "mps" and torch.is_floating_point(timestep):
+            # mps does not support float64
+            schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
+            assert isinstance(timestep, torch.Tensor)
+            timestep = timestep.to(sample.device, dtype=torch.float32)
+        else:
+            schedule_timesteps = self.timesteps.to(sample.device)
+            assert isinstance(timestep, torch.Tensor)
+            timestep = timestep.to(sample.device)
+
+        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
+        if self.begin_index is None:
+            step_indices = [
+                self.index_for_timestep(t, schedule_timesteps) for t in timestep
+            ]
+        elif self.step_index is not None:
+            # add_noise is called after first denoising step (for inpainting)
+            step_indices = [self.step_index] * timestep.shape[0]
+        else:
+            # add noise is called before first denoising step to create initial latent(img2img)
+            step_indices = [self.begin_index] * timestep.shape[0]
+
+        # Append trailing singleton dimensions until sigma has as many dimensions as `sample`.
+        sigma = sigmas[step_indices].flatten()
+        while len(sigma.shape) < len(sample.shape):
+            sigma = sigma.unsqueeze(-1)
+
+        sample = sigma * noise + (1.0 - sigma) * sample
+
+        return sample
+
+    def _sigma_to_t(self, sigma: float) -> float:
+        """Convert a sigma to a timestep by multiplying with `config.num_train_timesteps`."""
+        return sigma * self.config.num_train_timesteps
+
+    def time_shift(
+        self, mu: float, sigma: float, t: torch.Tensor | np.ndarray
+    ) -> torch.Tensor | np.ndarray:
+        """
+        Apply the shifting function selected by `config.time_shift_type` to `t`.
+
+        Raises:
+            ValueError: If `config.time_shift_type` is neither "exponential" nor "linear".
+        """
+        if self.config.time_shift_type == "exponential":
+            return self._time_shift_exponential(mu, sigma, t)
+        elif self.config.time_shift_type == "linear":
+            return self._time_shift_linear(mu, sigma, t)
+        else:
+            raise ValueError(f"Unknown time_shift_type: {self.config.time_shift_type}")
+
+    def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
+        r"""
+        Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config
+        value.
+
+        Reference:
+        https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51
+
+        Args:
+            t (`torch.Tensor`):
+                A tensor of timesteps to be stretched and shifted.
+
+        Returns:
+            `torch.Tensor`:
+                A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.
+        """
+        one_minus_z = 1 - t
+        scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal)
+        stretched_t = 1 - (one_minus_z / scale_factor)
+        return stretched_t
+
+    def set_timesteps(
+        self,
+        num_inference_steps: int | None = None,
+        device: str | torch.device = None,
+        sigmas: list[float] | None = None,
+        mu: float | None = None,
+        timesteps: list[float] | None = None,
+    ) -> None:
+        """
+        Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+        Args:
+            num_inference_steps (`int`, *optional*):
+                The number of diffusion steps used when generating samples with a pre-trained model.
+            device (`str` or `torch.device`, *optional*):
+                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+            sigmas (`List[float]`, *optional*):
+                Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed
+                automatically.
+            mu (`float`, *optional*):
+                Determines the amount of shifting applied to sigmas when performing resolution-dependent timestep
+                shifting.
+            timesteps (`List[float]`, *optional*):
+                Custom values for timesteps to be used for each diffusion step. If `None`, the timesteps are computed
+                automatically.
+
+        Raises:
+            ValueError: If `mu` is missing while `config.use_dynamic_shifting` is set, if the lengths of `sigmas`,
+                `timesteps` and `num_inference_steps` disagree, or if none of the three is provided.
+        """
+
+        if self.config.use_dynamic_shifting and mu is None:
+            raise ValueError(
+                "`mu` must be passed when `use_dynamic_shifting` is set to be `True`"
+            )
+
+        if (
+            sigmas is not None
+            and timesteps is not None
+            and len(sigmas) != len(timesteps)
+        ):
+            raise ValueError("`sigmas` and `timesteps` should have the same length")
+
+        if num_inference_steps is not None:
+            if (sigmas is not None and len(sigmas) != num_inference_steps) or (
+                timesteps is not None and len(timesteps) != num_inference_steps
+            ):
+                raise ValueError(
+                    "`sigmas` and `timesteps` should have the same length as num_inference_steps, if `num_inference_steps` is provided"
+                )
+        else:
+            if sigmas is not None:
+                num_inference_steps = len(sigmas)
+            elif timesteps is not None:
+                num_inference_steps = len(timesteps)
+            else:
+                raise ValueError(
+                    "Either num_inference_steps, sigmas, or timesteps must be provided"
+                )
+
+        self.num_inference_steps = num_inference_steps
+
+        # 1. Prepare default sigmas
+        is_timesteps_provided = timesteps is not None
+
+        timesteps_array: np.ndarray | None = None
+        if is_timesteps_provided:
+            assert timesteps is not None
+            timesteps_array = np.array(timesteps).astype(np.float32)
+
+        sigmas_array: np.ndarray
+        if sigmas is None:
+            if timesteps_array is None:
+                timesteps_array = np.linspace(
+                    self._sigma_to_t(self.sigma_max),
+                    self._sigma_to_t(self.sigma_min),
+                    num_inference_steps,
+                )
+            sigmas_array = timesteps_array / self.config.num_train_timesteps
+        else:
+            sigmas_array = np.array(sigmas).astype(np.float32)
+            num_inference_steps = len(sigmas_array)
+
+        # 2. Perform timestep shifting. Either no shifting is applied, or resolution-dependent shifting of
+        #    "exponential" or "linear" type is applied
+        if self.config.use_dynamic_shifting:
+            assert mu is not None, "mu cannot be None when use_dynamic_shifting is True"
+            sigmas_array = self.time_shift(mu, 1.0, sigmas_array)
+        else:
+            sigmas_array = (
+                self.shift * sigmas_array / (1 + (self.shift - 1) * sigmas_array)
+            )
+
+        # 3. If required, stretch the sigmas schedule to terminate at the configured `shift_terminal` value
+        if self.config.shift_terminal:
+            sigmas_tensor = torch.from_numpy(sigmas_array).to(dtype=torch.float32)
+            sigmas_tensor = self.stretch_shift_to_terminal(sigmas_tensor)
+            sigmas_array = sigmas_tensor.numpy()
+
+        # 4. If required, convert sigmas to one of karras, exponential, or beta sigma schedules
+        convert_sigmas = None
+        if self.config.use_karras_sigmas:
+            convert_sigmas = self._convert_to_karras
+        elif self.config.use_exponential_sigmas:
+            convert_sigmas = self._convert_to_exponential
+        elif self.config.use_beta_sigmas:
+            convert_sigmas = self._convert_to_beta
+
+        if convert_sigmas is not None:
+            sigmas_tensor = torch.from_numpy(sigmas_array).to(dtype=torch.float32)
+            # The `_convert_to_*` helpers build their result with NumPy, so the value they
+            # return has no `.numpy()` method. `np.asarray` accepts an ndarray as well as a
+            # CPU tensor, which keeps this call site working for either return type.
+            sigmas_array = np.asarray(
+                convert_sigmas(
+                    in_sigmas=sigmas_tensor, num_inference_steps=num_inference_steps
+                ),
+                dtype=np.float32,
+            )
+
+        # 5. Convert sigmas and timesteps to tensors and move to specified device
+        sigmas_tensor = torch.from_numpy(sigmas_array).to(
+            dtype=torch.float32, device=device
+        )
+        if not is_timesteps_provided:
+            timesteps_tensor = sigmas_tensor * self.config.num_train_timesteps
+        else:
+            assert timesteps_array is not None
+            timesteps_tensor = torch.from_numpy(timesteps_array).to(
+                dtype=torch.float32, device=device
+            )
+
+        # 6. Append the terminal sigma value.
+        #    If a model requires inverted sigma schedule for denoising but timesteps without inversion, the
+        #    `invert_sigmas` flag can be set to `True`. This case is only required in Mochi
+        if self.config.invert_sigmas:
+            sigmas_tensor = 1.0 - sigmas_tensor
+            timesteps_tensor = sigmas_tensor * self.config.num_train_timesteps
+            sigmas_tensor = torch.cat(
+                [sigmas_tensor, torch.ones(1, device=sigmas_tensor.device)]
+            )
+        else:
+            sigmas_tensor = torch.cat(
+                [sigmas_tensor, torch.zeros(1, device=sigmas_tensor.device)]
+            )
+
+        self.timesteps = timesteps_tensor
+        self.sigmas = sigmas_tensor
+        # A new schedule invalidates any position tracked from a previous run.
+        self._step_index = None
+        self._begin_index = None
+
+    def index_for_timestep(
+        self,
+        timestep: float | torch.FloatTensor,
+        schedule_timesteps: torch.Tensor | None = None,
+    ) -> int:
+        """
+        Return the position of `timestep` inside `schedule_timesteps`.
+
+        Args:
+            timestep (`float` or `torch.FloatTensor`):
+                The timestep to look up; it is matched with exact equality.
+            schedule_timesteps (`torch.Tensor`, *optional*):
+                The schedule to search. Defaults to `self.timesteps`.
+
+        Returns:
+            `int`: The second matching index when the timestep occurs more than once, otherwise the first.
+        """
+        if schedule_timesteps is None:
+            schedule_timesteps = self.timesteps
+
+        indices = (schedule_timesteps == timestep).nonzero()
+
+        # The sigma index that is taken for the **very** first `step`
+        # is always the second index (or the last index if there is only 1)
+        # This way we can ensure we don't accidentally skip a sigma in
+        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+        pos = 1 if len(indices) > 1 else 0
+
+        return indices[pos].item()
+
+    def _init_step_index(self, timestep: float | torch.FloatTensor) -> None:
+        """Initialise `_step_index` from `begin_index` when it is set, otherwise by looking up `timestep`."""
+        if self.begin_index is None:
+            if isinstance(timestep, torch.Tensor):
+                timestep = timestep.to(self.timesteps.device)
+            self._step_index = self.index_for_timestep(timestep)
+        else:
+            self._step_index = self._begin_index
+
+    def step(
+        self,
+        model_output: torch.FloatTensor,
+        timestep: int | torch.Tensor,
+        sample: torch.FloatTensor,
+        s_churn: float = 0.0,
+        s_tmin: float = 0.0,
+        s_tmax: float = float("inf"),
+        s_noise: float = 1.0,
+        generator: torch.Generator | None = None,
+        per_token_timesteps: torch.Tensor | None = None,
+        return_dict: bool = True,
+    ) -> FlowMatchEulerDiscreteSchedulerOutput | tuple[torch.FloatTensor, ...]:
+        """
+        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Args:
+            model_output (`torch.FloatTensor`):
+                The direct output from learned diffusion model.
+            timestep (`int` or `torch.Tensor`):
+                The current discrete timestep in the diffusion chain. Integer values and integer tensors are
+                rejected; pass one of `scheduler.timesteps`.
+            sample (`torch.FloatTensor`):
+                A current instance of a sample created by the diffusion process.
+            s_churn (`float`):
+                Not read by this implementation.
+            s_tmin (`float`):
+                Not read by this implementation.
+            s_tmax (`float`):
+                Not read by this implementation.
+            s_noise (`float`, defaults to 1.0):
+                Scaling factor for noise added to the sample. Not read by this implementation.
+            generator (`torch.Generator`, *optional*):
+                A random number generator. Not read by this implementation; stochastic sampling draws its noise
+                with `torch.randn_like`.
+            per_token_timesteps (`torch.Tensor`, *optional*):
+                The timesteps for each token in the sample.
+            return_dict (`bool`):
+                Whether or not to return a
+                [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple.
+
+        Returns:
+            [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:
+                If return_dict is `True`,
+                [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] is returned,
+                otherwise a tuple is returned where the first element is the sample tensor.
+
+        Raises:
+            ValueError: If `timestep` is an `int`, a `torch.IntTensor` or a `torch.LongTensor`.
+        """
+
+        if isinstance(timestep, int | torch.IntTensor | torch.LongTensor):
+            raise ValueError(
+                (
+                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+                    " `FlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass"
+                    " one of the `scheduler.timesteps` as a timestep."
+                ),
+            )
+
+        if self.step_index is None:
+            self._init_step_index(timestep)
+
+        # Upcast to avoid precision issues when computing prev_sample
+        sample = sample.to(torch.float32)
+
+        if per_token_timesteps is not None:
+            per_token_sigmas = per_token_timesteps / self.config.num_train_timesteps
+
+            # For every token keep the largest scheduled sigma that is below the token's own sigma.
+            sigmas = self.sigmas[:, None, None]
+            lower_mask = sigmas < per_token_sigmas[None] - 1e-6
+            lower_sigmas = lower_mask * sigmas
+            lower_sigmas, _ = lower_sigmas.max(dim=0)
+
+            current_sigma = per_token_sigmas[..., None]
+            next_sigma = lower_sigmas[..., None]
+            dt = current_sigma - next_sigma
+        else:
+            assert self.step_index is not None, "step_index should not be None"
+            sigma_idx = self.step_index
+            sigma = self.sigmas[sigma_idx]
+            sigma_next = self.sigmas[sigma_idx + 1]
+
+            current_sigma = sigma
+            next_sigma = sigma_next
+            dt = sigma_next - sigma
+
+        if self.config.stochastic_sampling:
+            x0 = sample - current_sigma * model_output
+            noise = torch.randn_like(sample)
+            prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise
+        else:
+            prev_sample = sample + dt * model_output
+
+        # upon completion increase step index by one
+        assert self._step_index is not None, "_step_index should not be None"
+        self._step_index += 1
+        if per_token_timesteps is None:
+            # Cast sample back to model compatible dtype
+            prev_sample = prev_sample.to(model_output.dtype)
+
+        if isinstance(prev_sample, torch.Tensor | float) and not return_dict:
+            return (prev_sample,)
+
+        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
+    def _convert_to_karras(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int
+    ) -> np.ndarray:
+        """
+        Constructs the noise schedule of Karras et al. (2022).
+
+        Only the first and last entries of `in_sigmas` are read, and only when the config does not define
+        `sigma_max` / `sigma_min`. The result is a NumPy array of length `num_inference_steps`.
+        """
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        rho = 7.0  # 7.0 is the value used in the paper
+        ramp = np.linspace(0, 1, num_inference_steps)
+        min_inv_rho = sigma_min ** (1 / rho)
+        max_inv_rho = sigma_max ** (1 / rho)
+        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
+    def _convert_to_exponential(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int
+    ) -> np.ndarray:
+        """
+        Constructs an exponential noise schedule.
+
+        The result is a NumPy array of `num_inference_steps` values spaced evenly in log-space from `sigma_max`
+        down to `sigma_min`.
+        """
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.exp(
+            np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps)
+        )
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+    def _convert_to_beta(
+        self,
+        in_sigmas: torch.Tensor,
+        num_inference_steps: int,
+        alpha: float = 0.6,
+        beta: float = 0.6,
+    ) -> np.ndarray:
+        """
+        From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et al., 2024)
+
+        The result is a NumPy array of `num_inference_steps` values obtained by mapping the Beta(`alpha`, `beta`)
+        percent-point function onto the range `[sigma_min, sigma_max]`.
+        """
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.array(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
+    def _time_shift_exponential(
+        self, mu: float, sigma: float, t: torch.Tensor | np.ndarray
+    ) -> torch.Tensor | np.ndarray:
+        """Return `exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)`, using NumPy's `exp` when `t` is an ndarray."""
+        if isinstance(t, np.ndarray):
+            return np.exp(mu) / (np.exp(mu) + (1 / t - 1) ** sigma)
+        else:
+            return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+    def _time_shift_linear(
+        self, mu: float, sigma: float, t: torch.Tensor | np.ndarray
+    ) -> torch.Tensor | np.ndarray:
+        """Return `mu / (mu + (1 / t - 1) ** sigma)`."""
+        return mu / (mu + (1 / t - 1) ** sigma)
+
+    def add_noise(
+        self,
+        clean_latent: torch.Tensor,
+        noise: torch.Tensor,
+        timestep: torch.IntTensor,
+    ) -> torch.Tensor:
+        """
+        Blend `noise` into `clean_latent` using the sigma of the scheduled timestep closest to `timestep`.
+
+        As a side effect `self.sigmas` and `self.timesteps` are moved to the device of `noise`.
+
+        Args:
+            clean_latent: the clean latent with shape [B, C, H, W],
+                where B is batch_size or batch_size * num_frames
+            noise: the noise with shape [B, C, H, W]
+            timestep: the timestep with shape [1] or [bs * num_frames] or [bs, num_frames]
+
+        Returns:
+            the corrupted latent with shape [B, C, H, W]
+
+        Raises:
+            ValueError: If `timestep` is neither 1- nor 2-dimensional.
+        """
+        # If timestep is [bs, num_frames]
+        if timestep.ndim == 2:
+            timestep = timestep.flatten(0, 1)
+            assert timestep.numel() == clean_latent.shape[0]
+        elif timestep.ndim == 1:
+            # If timestep is [1]
+            if timestep.shape[0] == 1:
+                timestep = timestep.expand(clean_latent.shape[0])
+            else:
+                assert timestep.numel() == clean_latent.shape[0]
+        else:
+            raise ValueError(f"[add_noise] Invalid timestep shape: {timestep.shape}")
+        # timestep shape should be [B]
+        self.sigmas = self.sigmas.to(noise.device)
+        self.timesteps = self.timesteps.to(noise.device)
+        # Nearest-neighbour match between every requested timestep and the schedule.
+        timestep_id = torch.argmin(
+            (self.timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1
+        )
+        sigma = self.sigmas[timestep_id].reshape(-1, 1, 1, 1)
+        sample = (1 - sigma) * clean_latent + sigma * noise
+        return sample.type_as(noise)
+
+    def scale_model_input(
+        self, sample: torch.Tensor, timestep: int | None = None
+    ) -> torch.Tensor:
+        """Return `sample` unchanged; `timestep` is not read."""
+        return sample
+
+    def __len__(self) -> int:
+        """
+        Return the number of training timesteps.
+
+        A length of 0 would make every scheduler instance evaluate as falsy in boolean context.
+        """
+        return self.config.num_train_timesteps
+
+
+EntryClass = FlowMatchEulerDiscreteScheduler
diff --git a/sglang/python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_flow_unipc_multistep.py b/sglang/python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_flow_unipc_multistep.py
new file mode 100644
index 0000000000000000000000000000000000000000..4874f69665550820a0cd9c4824d9e29d2ea5e15c
--- /dev/null
+++ b/sglang/python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_flow_unipc_multistep.py
@@ -0,0 +1,843 @@
+# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
+
+# SPDX-License-Identifier: Apache-2.0
+# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py
+# Convert unipc for flow matching
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+ +import math +from typing import Any + +import numpy as np +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import ( + KarrasDiffusionSchedulers, + SchedulerMixin, + SchedulerOutput, +) +from diffusers.utils import deprecate + +from sglang.multimodal_gen.runtime.models.schedulers.base import BaseScheduler + + +class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin, BaseScheduler): + """ + `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + solver_order (`int`, default `2`): + The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1` + due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for + unconditional sampling. + prediction_type (`str`, defaults to "flow_prediction"): + Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts + the flow of the diffusion process. + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. + predict_x0 (`bool`, defaults to `True`): + Whether to use the updating algorithm on the predicted x0. 
+ solver_type (`str`, default `bh2`): + Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2` + otherwise. + lower_order_final (`bool`, default `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + disable_corrector (`list`, default `[]`): + Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)` + and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is + usually disabled during the first few steps. + solver_p (`SchedulerMixin`, default `None`): + Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + use_exponential_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. 
+ """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "flow_prediction", + shift: float | None = 1.0, + use_dynamic_shifting=False, + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + predict_x0: bool = True, + solver_type: str = "bh2", + lower_order_final: bool = True, + disable_corrector: tuple = (), + solver_p: SchedulerMixin = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + final_sigmas_type: str | None = "zero", # "zero", "sigma_min" + **kwargs, + ): + + if solver_type not in ["bh1", "bh2"]: + if solver_type in ["midpoint", "heun", "logrho"]: + self.register_to_config(solver_type="bh2") + else: + raise NotImplementedError( + f"{solver_type} is not implemented for {self.__class__}" + ) + + self.predict_x0 = predict_x0 + # setable values + self.num_inference_steps: int | None = None + alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[ + ::-1 + ].copy() + sigmas = 1.0 - alphas + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) + + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + assert shift is not None + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # pyright: ignore + + self.sigmas = sigmas + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + self.timesteps = sigmas * num_train_timesteps + self.num_train_timesteps = num_train_timesteps + + self.model_outputs = [None] * solver_order + self.timestep_list: list[Any | None] = [None] * solver_order + self.lower_order_nums = 0 + self.disable_corrector = list(disable_corrector) + self.solver_p = solver_p + self.last_sample = None + self._step_index: int | None = None + self._begin_index: int | None = None + + 
BaseScheduler.__init__(self) + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + def set_shift(self, shift: float) -> None: + self.config.shift = shift + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps + def set_timesteps( + self, + num_inference_steps: int | None = None, + device: str | torch.device = None, + sigmas: list[float] | None = None, + mu: float | None | None = None, + shift: float | None | None = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + Args: + num_inference_steps (`int`): + Total number of the spacing of the time steps. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
+ """ + + if self.config.use_dynamic_shifting and mu is None: + raise ValueError( + " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`" + ) + + if sigmas is None: + assert num_inference_steps is not None + sigmas = np.linspace( + self.sigma_max, self.sigma_min, num_inference_steps + 1 + ).copy()[ + :-1 + ] # pyright: ignore + + if self.config.use_dynamic_shifting: + assert mu is not None + sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore + else: + if shift is None: + shift = self.config.shift + assert isinstance(sigmas, np.ndarray) + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # pyright: ignore + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + timesteps = sigmas * self.config.num_train_timesteps + sigmas = np.concatenate([sigmas, [sigma_last]]).astype( + np.float32 + ) # pyright: ignore + + self.sigmas = torch.from_numpy(sigmas).to(device=device) + self.timesteps = torch.from_numpy(timesteps).to( + device=device, dtype=torch.int64 + ) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + self.last_sample = None + if self.solver_p: + self.solver_p.set_timesteps(self.num_inference_steps, device=device) + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, 
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
    """Shift ``t`` with ``exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)``.

    Args:
        mu: Shift offset; used only through ``math.exp(mu)``.
        sigma: Exponent applied to ``(1 / t - 1)``.
        t: Values to shift; the arithmetic is applied element-wise.

    Returns:
        The shifted values.
    """
    exp_mu = math.exp(mu)
    inverse_odds = (1 / t - 1) ** sigma
    return exp_mu / (exp_mu + inverse_odds)
+ + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyword argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if self.predict_x0: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + else: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + epsilon = sample - (1 - sigma_t) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." 
def multistep_uni_p_bh_update(
    self,
    model_output: torch.Tensor,
    *args,
    sample: torch.Tensor = None,
    order: int | None = None,  # pyright: ignore
    **kwargs,
) -> torch.Tensor:
    """
    One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified.

    `sample` and `order` may also be supplied positionally as `args[1]` and `args[2]`; `args[0]` (or the
    `prev_timestep` keyword) is accepted only for backward compatibility and triggers a deprecation notice.

    Args:
        model_output (`torch.Tensor`):
            The direct output from the learned diffusion model at the current timestep.
        prev_timestep (`int`):
            The previous discrete timestep in the diffusion chain. Deprecated and unused.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        order (`int`):
            The order of UniP at this timestep (corresponds to the *p* in UniPC-p).

    Returns:
        `torch.Tensor`:
            The sample tensor at the previous timestep.

    Raises:
        ValueError: If `sample` or `order` is supplied neither by keyword nor positionally.
        NotImplementedError: If `config.solver_type` is neither `"bh1"` nor `"bh2"`.
    """
    # Legacy calling convention: (model_output, prev_timestep, sample, order).
    prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
    if sample is None:
        if len(args) > 1:
            sample = args[1]
        else:
            raise ValueError(" missing `sample` as a required keyword argument")
    if order is None:
        if len(args) > 2:
            order = args[2]
        else:
            raise ValueError(" missing `order` as a required keyword argument")
    if prev_timestep is not None:
        deprecate(
            "prev_timestep",
            "1.0.0",
            "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )
    model_output_list = self.model_outputs

    # Most recent entries of the multistep history.
    s0 = self.timestep_list[-1]
    m0 = model_output_list[-1]
    x = sample

    # When an external predictor is configured, delegate to it and skip UniP entirely.
    if self.solver_p:
        x_t = self.solver_p.step(model_output, s0, x).prev_sample
        return x_t

    # Sigmas of the target step (index + 1) and of the current step (index).
    sigma_t, sigma_s0 = (
        self.sigmas[self.step_index + 1],
        self.sigmas[self.step_index],
    )  # pyright: ignore
    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)

    # lambda = log(alpha) - log(sigma); h is the step size measured in lambda.
    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)

    h = lambda_t - lambda_s0
    device = sample.device

    # For every older history entry, store its lambda offset relative to h (rk)
    # and the scaled difference of its model output against m0 (D1s).
    rks = []
    D1s: list[Any] | None = []
    sigmas = self.sigmas.to(device=device)
    for i in range(1, order):
        si = self.step_index - i  # pyright: ignore
        mi = model_output_list[-(i + 1)]
        alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(sigmas[si])
        lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
        rk = (lambda_si - lambda_s0) / h
        rks.append(rk)
        assert mi is not None
        D1s.append((mi - m0) / rk)  # pyright: ignore

    # rks always ends with a trailing 1; for order 1 it is just [1].
    if len(rks) > 0:
        rks = torch.stack(rks)
        one = torch.ones(1, device=device, dtype=rks.dtype)
        rks = torch.cat([rks, one])
    else:
        rks = torch.ones(1, device=device, dtype=h.dtype)

    R = []
    b = []

    # The sign of the step is flipped when the update runs on the predicted x0.
    hh = -h if self.predict_x0 else h
    h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
    h_phi_k = h_phi_1 / hh - 1

    factorial_i = 1

    if self.config.solver_type == "bh1":
        B_h = hh
    elif self.config.solver_type == "bh2":
        B_h = torch.expm1(hh)
    else:
        raise NotImplementedError()

    # Build the rows of R (powers of rks) and the right-hand side b.
    for i in range(1, order + 1):
        R.append(torch.pow(rks, i - 1))
        b.append(h_phi_k * factorial_i / B_h)
        factorial_i *= i + 1
        h_phi_k = h_phi_k / hh - 1 / factorial_i

    R = torch.stack(R)
    b = torch.stack(b)

    if D1s is not None and len(D1s) > 0:
        D1s = torch.stack(D1s, dim=1)  # (B, K)
        # for order 2, we use a simplified version
        if order == 2:
            rhos_p = 0.5 * torch.ones(1, dtype=x.dtype, device=device)
        else:
            assert isinstance(R, torch.Tensor)
            # The predictor only uses the leading (order - 1) x (order - 1) sub-system.
            rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
    else:
        # No history differences available (order 1): no higher-order term below.
        D1s = None

    if self.predict_x0:
        x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
        if D1s is not None:
            pred_res = torch.einsum(
                "k,bkc...->bc...", rhos_p, D1s
            )  # pyright: ignore
        else:
            pred_res = 0
        x_t = x_t_ - alpha_t * B_h * pred_res
    else:
        x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
        if D1s is not None:
            pred_res = torch.einsum(
                "k,bkc...->bc...", rhos_p, D1s
            )  # pyright: ignore
        else:
            pred_res = 0
        x_t = x_t_ - sigma_t * B_h * pred_res

    # Return in the dtype of the incoming sample.
    x_t = x_t.to(x.dtype)
    return x_t
+ """ + this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None) + if last_sample is None: + if len(args) > 1: + last_sample = args[1] + else: + raise ValueError(" missing`last_sample` as a required keyword argument") + if this_sample is None: + if len(args) > 2: + this_sample = args[2] + else: + raise ValueError(" missing`this_sample` as a required keyword argument") + if order is None: + if len(args) > 3: + order = args[3] + else: + raise ValueError(" missing`order` as a required keyword argument") + if this_timestep is not None: + deprecate( + "this_timestep", + "1.0.0", + "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + model_output_list = self.model_outputs + + m0 = model_output_list[-1] + x = last_sample + x_t = this_sample + model_t = this_model_output + + sigma_t, sigma_s0 = ( + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = this_sample.device + + # Build rks and D1s fully on device to avoid any host-device sync + # Fast paths for small orders (common cases: 1 or 2) + if order == 1: + rks = torch.ones(1, device=device, dtype=h.dtype) + D1s = None + elif order == 2: + # order == 2 -> only one historical point is used + si = self.step_index - 2 # i = 1 + mi = model_output_list[-2] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h # 0-dim tensor on device + # rks = [rk, 1.0] but keep it on device without list->tensor sync + rks = torch.stack((rk, torch.ones_like(rk))) + assert mi is not None + # D1s shape: (B, K=1, C, ...) 
to match later einsum over K + D1s = ((mi - m0) / rk).unsqueeze(1) # pyright: ignore + else: + rks_list = [] + D1s_list = [] + for i in range(1, order): + si = self.step_index - (i + 1) + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks_list.append(rk) + assert mi is not None + D1s_list.append((mi - m0) / rk) # pyright: ignore + + # Append 1.0 as a device tensor to rks + rks = torch.stack(rks_list + [torch.ones_like(rks_list[0])]) + D1s = torch.stack(D1s_list, dim=1) if len(D1s_list) > 0 else None + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + # Avoid torch.tensor(list_of_gpu_scalars) which syncs to host + b = torch.stack(b) + + # D1s is already prepared above for order==2; remains None for order==1 + + # for order 1, we use a simplified version + if order == 1: + rhos_c = 0.5 * torch.ones(1, dtype=x.dtype, device=device) + elif order == 2: + # Manually solve the 2x2 linear system to avoid device synchronization from torch.linalg.solve + # R = [[1, 1], [rk, 1]], where rk = rks[0] + rk = rks[0] + det = 1 - rk + # Using Cramer's rule to solve for rhos_c = [x0, x1] + # x0 = (b0 - b1) / det + # x1 = (b1 - rk * b0) / det + rhos_c_0 = (b[0] - b[1]) / det + rhos_c_1 = (b[1] - rk * b[0]) / det + rhos_c = torch.stack([rhos_c_0, rhos_c_1]) + else: + rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype) + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * 
def index_for_timestep(self, timestep, schedule_timesteps=None) -> int:
    """Locate ``timestep`` inside a timestep schedule.

    Args:
        timestep: The timestep value to look up.
        schedule_timesteps: Schedule to search; ``self.timesteps`` when ``None``.

    Returns:
        The index of the match. When the value occurs more than once the
        second occurrence is returned, otherwise the only one.
    """
    schedule = self.timesteps if schedule_timesteps is None else schedule_timesteps
    matches = (schedule == timestep).nonzero()

    # For the very first `step` the second matching index is preferred (or the
    # only one if there is a single match), so that no sigma is skipped when
    # starting in the middle of the denoising schedule (e.g. image-to-image).
    if len(matches) > 1:
        chosen = matches[1]
    else:
        chosen = matches[0]
    step_index: int = chosen.item()
    return step_index
+ timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to call 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + use_corrector = ( + self.step_index > 0 + and self.step_index - 1 not in self.disable_corrector + and self.last_sample is not None # pyright: ignore + ) + + sample = sample.to(model_output.device) + model_output_convert = self.convert_model_output(model_output, sample=sample) + + if use_corrector: + sample = self.multistep_uni_c_bh_update( + this_model_output=model_output_convert, + last_sample=self.last_sample, + this_sample=sample, + order=self.this_order, + ) + + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep # pyright: ignore + + if self.config.lower_order_final: + this_order = min( + self.config.solver_order, len(self.timesteps) - self.step_index + ) # pyright: ignore + else: + this_order = self.config.solver_order + + self.this_order: int = min( + this_order, self.lower_order_nums + 1 + ) # warmup for multistep + assert self.this_order > 0 + + self.last_sample = sample + prev_sample = self.multistep_uni_p_bh_update( + model_output=model_output, # pass the original non-converted model output, in case solver-p is used 
+ sample=sample, + order=self.this_order, + ) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # upon completion increase step index by one + assert self._step_index is not None + self._step_index += 1 # pyright: ignore + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to( + device=original_samples.device, dtype=original_samples.dtype + ) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to( + original_samples.device, dtype=torch.float32 + ) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [ + self.index_for_timestep(t, schedule_timesteps) for t in timesteps + ] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising 
class HeliosSchedulerConfig:
    """Lightweight attribute bag that mimics the diffusers config interface.

    Every keyword argument given to the constructor becomes an instance
    attribute; ``get`` offers dict-style access with a fallback value.
    """

    def __init__(self, **kwargs):
        for name in kwargs:
            setattr(self, name, kwargs[name])

    def get(self, key, default=None):
        """Return attribute ``key``, or ``default`` when it is not set."""
        try:
            return getattr(self, key)
        except AttributeError:
            return default
+ """ + + order = 1 + + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + stages: int = 1, + stage_range: list | None = None, + gamma: float = 1 / 3, + thresholding: bool = False, + prediction_type: str = "flow_prediction", + solver_order: int = 2, + predict_x0: bool = True, + solver_type: str = "bh2", + lower_order_final: bool = True, + disable_corrector: list[int] | None = None, + use_flow_sigmas: bool = True, + scheduler_type: str = "unipc", + use_dynamic_shifting: bool = False, + time_shift_type: str = "linear", + **kwargs, + ): + if stage_range is None: + # Evenly divide [0, 1] into 3 stages for pyramid SR + stage_range = [0, 1 / 3, 2 / 3, 1] + if disable_corrector is None: + disable_corrector = [] + + self.config = HeliosSchedulerConfig( + num_train_timesteps=num_train_timesteps, + shift=shift, + stages=stages, + stage_range=stage_range, + gamma=gamma, + thresholding=thresholding, + prediction_type=prediction_type, + solver_order=solver_order, + predict_x0=predict_x0, + solver_type=solver_type, + lower_order_final=lower_order_final, + disable_corrector=disable_corrector, + use_flow_sigmas=use_flow_sigmas, + scheduler_type=scheduler_type, + use_dynamic_shifting=use_dynamic_shifting, + time_shift_type=time_shift_type, + ) + + self.timestep_ratios = {} + self.timesteps_per_stage = {} + self.sigmas_per_stage = {} + self.start_sigmas = {} + self.end_sigmas = {} + self.ori_start_sigmas = {} + + self.init_sigmas_for_each_stage() + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + self.gamma = gamma + + if solver_type not in ["bh1", "bh2"]: + raise NotImplementedError(f"{solver_type} is not implemented") + + self.predict_x0 = predict_x0 + self.model_outputs = [None] * solver_order + self.timestep_list = [None] * solver_order + self.lower_order_nums = 0 + self.disable_corrector = disable_corrector + self.solver_p = None + self.last_sample = None + self._step_index = None + self._begin_index = None + 
def init_sigmas(self):
    """Rebuild the training-time sigma and timestep tables from the config.

    Reads ``config.num_train_timesteps`` and ``config.shift``, stores the
    resulting tensors in ``self.sigmas`` / ``self.timesteps`` and clears the
    step and begin indices.
    """
    n_train = self.config.num_train_timesteps
    shift_factor = self.config.shift

    # Linear alpha ramp 1 -> 1/n_train over n_train + 1 points; sigma = 1 - alpha.
    base_sigmas = 1.0 - np.linspace(1, 1 / n_train, n_train + 1)
    shifted = shift_factor * base_sigmas / (1 + (shift_factor - 1) * base_sigmas)
    # Reverse to descending order, then drop the trailing entry (sigma == 0).
    descending = shifted[::-1][:-1].copy()
    sigma_table = torch.from_numpy(descending)

    self.timesteps = (sigma_table * n_train).clone()
    self.sigmas = sigma_table.to("cpu")
    self._step_index = None
    self._begin_index = None
(num_train_timesteps - 1) + timestep_max = min( + self.timesteps[int(timestep_ratio[0] * training_steps)], 999 + ) + timestep_min = self.timesteps[ + min(int(timestep_ratio[1] * training_steps), training_steps - 1) + ] + timesteps = np.linspace(timestep_max, timestep_min, training_steps + 1) + self.timesteps_per_stage[i_s] = ( + timesteps[:-1] + if isinstance(timesteps, torch.Tensor) + else torch.from_numpy(timesteps[:-1]) + ) + # Sigma range [0.999, 0]: start just below 1.0 to avoid singularity + stage_sigmas = np.linspace(0.999, 0, training_steps + 1) + self.sigmas_per_stage[i_s] = torch.from_numpy(stage_sigmas[:-1]) + + @property + def step_index(self): + return self._step_index + + @property + def begin_index(self): + return self._begin_index + + def set_begin_index(self, begin_index: int = 0): + self._begin_index = begin_index + + def time_shift(self, mu, sigma, t): + if self.config.time_shift_type == "exponential": + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + elif self.config.time_shift_type == "linear": + return mu / (mu + (1 / t - 1) ** sigma) + + def set_timesteps( + self, + num_inference_steps: int, + stage_index: int | None = None, + device: str | torch.device = None, + sigmas=None, + mu=None, + is_amplify_first_chunk: bool = False, + ): + if self.config.scheduler_type == "dmd": + if is_amplify_first_chunk: + num_inference_steps = num_inference_steps * 2 + 1 + else: + num_inference_steps = num_inference_steps + 1 + + self.num_inference_steps = num_inference_steps + self.init_sigmas() + + if self.config.stages == 1: + if sigmas is None: + sigmas = np.linspace( + 1, + 1 / self.config.num_train_timesteps, + num_inference_steps + 1, + )[:-1].astype(np.float32) + if self.config.shift != 1.0: + assert not self.config.use_dynamic_shifting + sigmas = self.time_shift(self.config.shift, 1.0, sigmas) + timesteps = (sigmas * self.config.num_train_timesteps).copy() + sigmas = torch.from_numpy(sigmas) + else: + stage_timesteps = 
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """Return the position of ``timestep`` inside a timestep schedule.

    Args:
        timestep: Value compared with ``==`` against every schedule entry.
        schedule_timesteps: Tensor to search. Defaults to ``self.timesteps``.

    Returns:
        int: Index of the matching entry. When the value occurs more than
        once the second occurrence is returned, otherwise the only one.

    Raises:
        IndexError: If ``timestep`` does not occur in the schedule.
    """
    if schedule_timesteps is None:
        schedule_timesteps = self.timesteps

    indices = (schedule_timesteps == timestep).nonzero()
    if len(indices) == 0:
        # Same exception type the bare tensor indexing used to raise, but
        # with a message that names the offending timestep.
        raise IndexError(f"timestep {timestep} is not present in the schedule")

    # NOTE(review): picking the second of several matches presumably mirrors
    # the diffusers convention for schedules with repeated timesteps - confirm.
    pos = 1 if len(indices) > 1 else 0
    return indices[pos].item()
def convert_model_output(self, model_output, sample=None, sigma=None, **kwargs):
    """Convert the raw model output into the form the solver updates use.

    Args:
        model_output: Raw output of the denoising model.
        sample: Current noisy sample. Read by every branch except the two
            that return ``model_output`` unchanged.
        sigma: Noise level to convert with. When ``None`` it is read from
            ``self.sigmas[self.step_index]``.
        **kwargs: Accepted and ignored.

    Returns:
        With ``self.predict_x0`` set, the x0 prediction derived from
        ``config.prediction_type`` (``flow_prediction``, ``epsilon``,
        ``sample`` or ``v_prediction``). Otherwise the output converted for
        ``epsilon``, ``sample`` or ``v_prediction``.

    Raises:
        ValueError: If ``config.prediction_type`` is not handled by the
            active branch (this includes ``flow_prediction`` when
            ``self.predict_x0`` is false).
    """
    if sigma is None:
        sigma = self.sigmas[self.step_index]
    prediction_type = self.config.prediction_type

    if self.predict_x0 and prediction_type == "flow_prediction":
        # The flow branch uses the raw sigma, not the clamped sigma_t from
        # _sigma_to_alpha_sigma_t, so that helper is not needed on this path.
        return sample - sigma * model_output

    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)

    if self.predict_x0:
        if prediction_type == "epsilon":
            return (sample - sigma_t * model_output) / alpha_t
        if prediction_type == "sample":
            return model_output
        if prediction_type == "v_prediction":
            return alpha_t * sample - sigma_t * model_output
    else:
        if prediction_type == "epsilon":
            return model_output
        if prediction_type == "sample":
            return (sample - alpha_t * model_output) / sigma_t
        if prediction_type == "v_prediction":
            return alpha_t * model_output + sigma_t * sample

    raise ValueError(
        f"prediction_type {self.config.prediction_type} not supported"
    )
model_output_list = self.model_outputs + m0 = model_output_list[-1] + x = sample + + if sigma_next is None and sigma is None: + sigma_t, sigma_s0 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + ) + else: + sigma_t, sigma_s0 = sigma_next, sigma + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + h = lambda_t - lambda_s0 + device = sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - i + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) + h_phi_k = h_phi_1 / hh - 1 + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) + if order == 2: + rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype) + else: + D1s = None + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + pred_res = ( + torch.einsum("k,bkc...->bc...", rhos_p, D1s) if D1s is not None else 0 + ) + x_t = x_t_ - alpha_t * B_h * pred_res + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + pred_res = ( + 
torch.einsum("k,bkc...->bc...", rhos_p, D1s) if D1s is not None else 0 + ) + x_t = x_t_ - sigma_t * B_h * pred_res + + return x_t.to(x.dtype) + + def multistep_uni_c_bh_update( + self, + this_model_output, + last_sample=None, + this_sample=None, + order=None, + sigma_before=None, + sigma=None, + ): + model_output_list = self.model_outputs + m0 = model_output_list[-1] + x = last_sample + model_t = this_model_output + + if sigma_before is None and sigma is None: + sigma_t, sigma_s0 = ( + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) + else: + sigma_t, sigma_s0 = sigma, sigma_before + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + h = lambda_t - lambda_s0 + device = this_sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - (i + 1) + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) + h_phi_k = h_phi_1 / hh - 1 + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) + else: + D1s = None + + if order == 1: + rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_c = 
torch.linalg.solve(R, b).to(device).to(x.dtype) + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + corr_res = ( + torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + if D1s is not None + else 0 + ) + D1_t = model_t - m0 + x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + corr_res = ( + torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + if D1s is not None + else 0 + ) + D1_t = model_t - m0 + x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) + + return x_t.to(x.dtype) + + def step_unipc( + self, + model_output, + timestep=None, + sample=None, + return_dict: bool = True, + **kwargs, + ) -> HeliosSchedulerOutput | tuple: + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', run 'set_timesteps' first" + ) + + if self.step_index is None: + self._step_index = 0 + + use_corrector = ( + self.step_index > 0 + and self.step_index - 1 not in self.disable_corrector + and self.last_sample is not None + ) + + model_output_convert = self.convert_model_output(model_output, sample=sample) + + if use_corrector: + sample = self.multistep_uni_c_bh_update( + this_model_output=model_output_convert, + last_sample=self.last_sample, + this_sample=sample, + order=self.this_order, + ) + + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep + + if self.config.lower_order_final: + this_order = min( + self.config.solver_order, len(self.timesteps) - self.step_index + ) + else: + this_order = self.config.solver_order + self.this_order = min(this_order, self.lower_order_nums + 1) + assert self.this_order > 0 + + self.last_sample = sample + prev_sample = self.multistep_uni_p_bh_update( + model_output=model_output, + sample=sample, + order=self.this_order, + ) 
def add_noise(self, original_samples, noise, timestep, sigmas, timesteps):
    """Blend clean samples with noise at the sigma matching each timestep.

    Args:
        original_samples: Clean samples with the batch on dim 0.
        noise: Noise mixed into the samples; the result is cast like it.
        timestep: 1-D tensor holding one timestep per batch element.
        sigmas: 1-D tensor of sigma values aligned with ``timesteps``.
        timesteps: 1-D schedule. For each entry of ``timestep`` the schedule
            value with the smallest absolute difference selects the sigma.

    Returns:
        ``(1 - sigma) * original_samples + sigma * noise`` converted with
        ``type_as(noise)``.
    """
    sigmas = sigmas.to(noise.device)
    timesteps = timesteps.to(noise.device)

    # Nearest schedule entry for every batch element.
    timestep_id = torch.argmin(
        (timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1
    )

    # One trailing singleton axis per non-batch dimension of the samples, so
    # sigma broadcasts per batch element for any rank. For 5-D input this is
    # the same (-1, 1, 1, 1, 1) shape as before.
    trailing_axes = [1] * (original_samples.ndim - 1)
    sigma = sigmas[timestep_id].reshape(-1, *trailing_axes)

    sample = (1 - sigma) * original_samples + sigma * noise
    return sample.type_as(noise)
def step(
    self,
    model_output,
    timestep=None,
    sample=None,
    return_dict: bool = True,
    **kwargs,
) -> HeliosSchedulerOutput | tuple:
    """Run one denoising step with the solver named by ``config.scheduler_type``.

    Args:
        model_output: Output of the denoising model for the current step.
        timestep: Current timestep, forwarded unchanged to the solver.
        sample: Current sample, forwarded unchanged to the solver.
        return_dict: Forwarded unchanged to the solver.
        **kwargs: Extra arguments; only the ``"dmd"`` solver receives them.

    Returns:
        Whatever the selected ``step_euler`` / ``step_unipc`` / ``step_dmd``
        call returns.

    Raises:
        NotImplementedError: If ``config.scheduler_type`` is none of
            ``"euler"``, ``"unipc"`` or ``"dmd"``.
    """
    # Arguments every solver takes.
    shared = {
        "model_output": model_output,
        "timestep": timestep,
        "sample": sample,
        "return_dict": return_dict,
    }
    scheduler_type = self.config.scheduler_type

    if scheduler_type == "dmd":
        # DMD is the only solver that consumes the extra keyword arguments.
        return self.step_dmd(**shared, **kwargs)
    if scheduler_type == "unipc":
        return self.step_unipc(**shared)
    if scheduler_type == "euler":
        return self.step_euler(**shared)

    raise NotImplementedError(
        f"Scheduler type '{self.config.scheduler_type}' not implemented"
    )
file mode 100644 index 0000000000000000000000000000000000000000..9a4749a6d9e2192aae41f935bd56878502b2b54b --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_self_forcing_flow_match.py @@ -0,0 +1,142 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.utils import BaseOutput + +from sglang.multimodal_gen.runtime.models.schedulers.base import BaseScheduler +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class SelfForcingFlowMatchSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. 
+ """ + + prev_sample: torch.FloatTensor + + +class SelfForcingFlowMatchScheduler(BaseScheduler, ConfigMixin, SchedulerMixin): + config_name = "scheduler_config.json" + order = 1 + + @register_to_config + def __init__( + self, + num_inference_steps=100, + num_train_timesteps=1000, + shift=3.0, + sigma_max=1.0, + sigma_min=0.003 / 1.002, + inverse_timesteps=False, + extra_one_step=False, + reverse_sigmas=False, + *args, + **kwargs, + ): + self.num_train_timesteps = num_train_timesteps + self.shift = shift + self.sigma_max = sigma_max + self.sigma_min = sigma_min + self.inverse_timesteps = inverse_timesteps + self.extra_one_step = extra_one_step + self.reverse_sigmas = reverse_sigmas + self.set_timesteps(num_inference_steps) + + def set_timesteps( + self, + num_inference_steps=100, + denoising_strength=1.0, + return_dict=False, + **kwargs, + ): + sigma_start = ( + self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength + ) + if self.extra_one_step: + self.sigmas = torch.linspace( + sigma_start, self.sigma_min, num_inference_steps + 1 + )[:-1] + else: + self.sigmas = torch.linspace( + sigma_start, self.sigma_min, num_inference_steps + ) + if self.inverse_timesteps: + self.sigmas = torch.flip(self.sigmas, dims=[0]) + self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas) + if self.reverse_sigmas: + self.sigmas = 1 - self.sigmas + self.timesteps = self.sigmas * self.num_train_timesteps + + def step( + self, + model_output: torch.FloatTensor, + timestep: torch.FloatTensor, + sample: torch.FloatTensor, + to_final=False, + return_dict=False, + **kwargs, + ): + if timestep.ndim == 2: + timestep = timestep.flatten(0, 1) + elif timestep.ndim == 0: + # handles the case where timestep is a scalar, this occurs when we + # use this scheduler for ODE trajectory + timestep = timestep.unsqueeze(0) + + self.sigmas = self.sigmas.to(model_output.device) + self.timesteps = self.timesteps.to(model_output.device) + timestep = 
timestep.to(model_output.device) + + timestep_id = torch.argmin( + (self.timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1 + ) + sigma = self.sigmas[timestep_id].reshape(-1, 1, 1, 1) + if to_final or (timestep_id + 1 >= len(self.timesteps)).any(): + sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0 + else: + sigma_ = self.sigmas[timestep_id + 1].reshape(-1, 1, 1, 1) + prev_sample = sample + model_output * (sigma_ - sigma) + if isinstance(prev_sample, torch.Tensor | float) and not return_dict: + return (prev_sample,) + return SelfForcingFlowMatchSchedulerOutput(prev_sample=prev_sample) + + def add_noise(self, original_samples, noise, timestep): + """ + Diffusion forward corruption process. + Input: + - clean_latent: the clean latent with shape [B*T, C, H, W] + - noise: the noise with shape [B*T, C, H, W] + - timestep: the timestep with shape [B*T] + Output: the corrupted latent with shape [B*T, C, H, W] + """ + if timestep.ndim == 2: + timestep = timestep.flatten(0, 1) + self.sigmas = self.sigmas.to(noise.device) + self.timesteps = self.timesteps.to(noise.device) + timestep_id = torch.argmin( + (self.timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1 + ) + sigma = self.sigmas[timestep_id].reshape(-1, 1, 1, 1) + sample = (1 - sigma) * original_samples + sigma * noise + return sample.type_as(noise) + + def scale_model_input( + self, sample: torch.Tensor, timestep: int | None = None + ) -> torch.Tensor: + return sample + + def set_shift(self, shift: float) -> None: + self.shift = shift + + +EntryClass = SelfForcingFlowMatchScheduler diff --git a/sglang/python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_unipc_multistep.py b/sglang/python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_unipc_multistep.py new file mode 100644 index 0000000000000000000000000000000000000000..df5e9b834b3fbedc2722bddd507f21c44680fd53 --- /dev/null +++ 
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Discretize an ``alpha_bar(t)`` curve, defined for t in [0, 1], into a beta schedule.

    For step ``i`` the beta is ``1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)`` with
    ``N = num_diffusion_timesteps``, capped at ``max_beta``.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): upper cap applied to every beta; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): shape of the
                     ``alpha_bar`` curve. Choose from `cosine` or `exp`.

    Returns:
        betas (`torch.Tensor`): float32 tensor holding `num_diffusion_timesteps` betas.

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type not in ("cosine", "exp"):
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    cosine_curve = alpha_transform_type == "cosine"

    def alpha_bar(t):
        # Cumulative product of (1 - beta) at normalized time t.
        if cosine_curve:
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
        return math.exp(t * -12.0)

    # (start, end) of every discretization interval on the [0, 1] axis.
    intervals = (
        (step / num_diffusion_timesteps, (step + 1) / num_diffusion_timesteps)
        for step in range(num_diffusion_timesteps)
    )
    capped_betas = [
        min(1 - alpha_bar(t_end) / alpha_bar(t_start), max_beta)
        for t_start, t_end in intervals
    ]
    return torch.tensor(capped_betas, dtype=torch.float32)
+ alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + +class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin, BaseScheduler): + """ + `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + solver_order (`int`, default `2`): + The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1` + due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for + unconditional sampling. 
+ prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. + predict_x0 (`bool`, defaults to `True`): + Whether to use the updating algorithm on the predicted x0. + solver_type (`str`, default `bh2`): + Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2` + otherwise. + lower_order_final (`bool`, default `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + disable_corrector (`list`, default `[]`): + Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)` + and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is + usually disabled during the first few steps. + solver_p (`SchedulerMixin`, default `None`): + Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. 
+ use_exponential_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. + use_beta_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta + Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information. + use_flow_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use flow sigmas for step sizes in the noise schedule during the sampling process. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). 
+ """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: np.ndarray | list[float] | None = None, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + predict_x0: bool = True, + solver_type: str = "bh2", + lower_order_final: bool = True, + disable_corrector: list[int] = [], + solver_p: SchedulerMixin = None, + use_karras_sigmas: bool | None = False, + use_exponential_sigmas: bool | None = False, + use_beta_sigmas: bool | None = False, + use_flow_sigmas: bool | None = False, + flow_shift: float | None = 1.0, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + final_sigmas_type: str | None = "zero", # "zero", "sigma_min" + rescale_betas_zero_snr: bool = False, + use_dynamic_shifting: bool = False, + time_shift_type: str = "exponential", + ): + if self.config.use_beta_sigmas and not is_scipy_available(): + raise ImportError( + "Make sure to install scipy if you want to use beta sigmas." + ) + if ( + sum( + [ + self.config.use_beta_sigmas, + self.config.use_exponential_sigmas, + self.config.use_karras_sigmas, + ] + ) + > 1 + ): + raise ValueError( + "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." + ) + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace( + beta_start, beta_end, num_train_timesteps, dtype=torch.float32 + ) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. 
+ self.betas = ( + torch.linspace( + beta_start**0.5, + beta_end**0.5, + num_train_timesteps, + dtype=torch.float32, + ) + ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError( + f"{beta_schedule} is not implemented for {self.__class__}" + ) + + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + if rescale_betas_zero_snr: + # Close to 0 without being 0 so first sigma is not inf + # FP16 smallest positive subnormal works well here + self.alphas_cumprod[-1] = 2**-24 + + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + if solver_type not in ["bh1", "bh2"]: + if solver_type in ["midpoint", "heun", "logrho"]: + self.register_to_config(solver_type="bh2") + else: + raise NotImplementedError( + f"{solver_type} is not implemented for {self.__class__}" + ) + + self.predict_x0 = predict_x0 + # setable values + self.num_inference_steps = None + timesteps = np.linspace( + 0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32 + )[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.num_train_timesteps = num_train_timesteps + self.model_outputs = [None] * solver_order + self.timestep_list = [None] * solver_order + self.lower_order_nums = 0 + self.disable_corrector = disable_corrector + self.solver_p = solver_p + self.last_sample = None + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + 
BaseScheduler.__init__(self) + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + def set_shift(self, shift: float) -> None: + self.config.flow_shift = shift + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def set_timesteps( + self, + num_inference_steps: int, + device: str | torch.device = None, + mu: float | None = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. 
of https://huggingface.co/papers/2305.08891 + if mu is not None: + assert ( + self.config.use_dynamic_shifting + and self.config.time_shift_type == "exponential" + ) + self.config.flow_shift = np.exp(mu) + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace( + 0, self.config.num_train_timesteps - 1, num_inference_steps + 1 + ) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = ( + (np.arange(0, num_inference_steps + 1) * step_ratio) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = ( + np.arange(self.config.num_train_timesteps, 0, -step_ratio) + .round() + .copy() + .astype(np.int64) + ) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 
+ ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + if self.config.use_karras_sigmas: + log_sigmas = np.log(sigmas) + sigmas = np.flip(sigmas).copy() + sigmas = self._convert_to_karras( + in_sigmas=sigmas, num_inference_steps=num_inference_steps + ) + timesteps = np.array( + [self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas] + ).round() + if self.config.final_sigmas_type == "sigma_min": + sigma_last = sigmas[-1] + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + elif self.config.use_exponential_sigmas: + log_sigmas = np.log(sigmas) + sigmas = np.flip(sigmas).copy() + sigmas = self._convert_to_exponential( + in_sigmas=sigmas, num_inference_steps=num_inference_steps + ) + timesteps = np.array( + [self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas] + ) + if self.config.final_sigmas_type == "sigma_min": + sigma_last = sigmas[-1] + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + elif self.config.use_beta_sigmas: + log_sigmas = np.log(sigmas) + sigmas = np.flip(sigmas).copy() + sigmas = self._convert_to_beta( + in_sigmas=sigmas, num_inference_steps=num_inference_steps + ) + timesteps = np.array( + [self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas] + ) + if self.config.final_sigmas_type == "sigma_min": + sigma_last = sigmas[-1] + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + sigmas = np.concatenate([sigmas, 
[sigma_last]]).astype(np.float32) + elif self.config.use_flow_sigmas: + alphas = np.linspace( + 1, 1 / self.config.num_train_timesteps, num_inference_steps + 1 + ) + sigmas = 1.0 - alphas + sigmas = np.flip( + self.config.flow_shift + * sigmas + / (1 + (self.config.flow_shift - 1) * sigmas) + )[:-1].copy() + timesteps = (sigmas * self.config.num_train_timesteps).copy() + if self.config.final_sigmas_type == "sigma_min": + sigma_last = sigmas[-1] + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + else: + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ( + (1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0] + ) ** 0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to( + device=device, dtype=torch.int64 + ) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + self.last_sample = None + if self.solver_p: + self.solver_p.set_timesteps(self.num_inference_steps, device=device) + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling 
step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://huggingface.co/papers/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = ( + sample.float() + ) # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = ( + torch.clamp(sample, -s, s) / s + ) # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(np.maximum(sigma, 1e-10)) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = ( + np.cumsum((dists >= 0), axis=0) + .argmax(axis=0) + .clip(max=log_sigmas.shape[0] - 2) + ) + high_idx = low_idx + 1 + + low = 
log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t + def _sigma_to_alpha_sigma_t(self, sigma): + if self.config.use_flow_sigmas: + alpha_t = 1 - sigma + sigma_t = sigma + else: + alpha_t = 1 / ((sigma**2 + 1) ** 0.5) + sigma_t = sigma * alpha_t + + return alpha_t, sigma_t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras( + self, in_sigmas: torch.Tensor, num_inference_steps + ) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential + def _convert_to_exponential( + self, in_sigmas: torch.Tensor, num_inference_steps: int + ) -> torch.Tensor: + """Constructs an exponential noise schedule.""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add 
this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = np.exp( + np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps) + ) + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta + def _convert_to_beta( + self, + in_sigmas: torch.Tensor, + num_inference_steps: int, + alpha: float = 0.6, + beta: float = 0.6, + ) -> torch.Tensor: + """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = np.array( + [ + sigma_min + (ppf * (sigma_max - sigma_min)) + for ppf in [ + scipy.stats.beta.ppf(timestep, alpha, beta) + for timestep in 1 - np.linspace(0, 1, num_inference_steps) + ] + ] + ) + return sigmas + + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + r""" + Convert the model output to the corresponding type the UniPC algorithm needs. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. 
+ sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyword argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + + if self.predict_x0: + if self.config.prediction_type == "epsilon": + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + x0_pred = alpha_t * sample - sigma_t * model_output + elif self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, " + "`v_prediction`, or `flow_prediction` for the UniPCMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + else: + if self.config.prediction_type == "epsilon": + return model_output + elif self.config.prediction_type == "sample": + epsilon = (sample - alpha_t * model_output) / sigma_t + return epsilon + elif self.config.prediction_type == "v_prediction": + epsilon = alpha_t * model_output + sigma_t * sample + return epsilon + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the UniPCMultistepScheduler." 
+ ) + + def multistep_uni_p_bh_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + order: int = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model at the current timestep. + prev_timestep (`int`): + The previous discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + order (`int`): + The order of UniP at this timestep (corresponds to the *p* in UniPC-p). + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyword argument") + if order is None: + if len(args) > 2: + order = args[2] + else: + raise ValueError("missing `order` as a required keyword argument") + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + model_output_list = self.model_outputs + + s0 = self.timestep_list[-1] + m0 = model_output_list[-1] + x = sample + + if self.solver_p: + x_t = self.solver_p.step(model_output, s0, x).prev_sample + return x_t + + sigma_t, sigma_s0 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + ) + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - i + 
mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) # (B, K) + # for order 2, we use a simplified version + if order == 2: + rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype) + else: + D1s = None + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) + else: + pred_res = 0 + x_t = x_t_ - alpha_t * B_h * pred_res + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) + else: + pred_res = 0 + x_t = x_t_ - sigma_t * B_h * pred_res + + x_t = x_t.to(x.dtype) + return x_t + + def multistep_uni_c_bh_update( + self, + this_model_output: torch.Tensor, + *args, + last_sample: torch.Tensor = None, + this_sample: torch.Tensor = None, + order: int = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the UniC (B(h) version). + + Args: + this_model_output (`torch.Tensor`): + The model outputs at `x_t`. + this_timestep (`int`): + The current timestep `t`. 
+ last_sample (`torch.Tensor`): + The generated sample before the last predictor `x_{t-1}`. + this_sample (`torch.Tensor`): + The generated sample after the last predictor `x_{t}`. + order (`int`): + The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`. + + Returns: + `torch.Tensor`: + The corrected sample tensor at the current timestep. + """ + this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None) + if last_sample is None: + if len(args) > 1: + last_sample = args[1] + else: + raise ValueError("missing `last_sample` as a required keyword argument") + if this_sample is None: + if len(args) > 2: + this_sample = args[2] + else: + raise ValueError("missing `this_sample` as a required keyword argument") + if order is None: + if len(args) > 3: + order = args[3] + else: + raise ValueError("missing `order` as a required keyword argument") + if this_timestep is not None: + deprecate( + "this_timestep", + "1.0.0", + "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + model_output_list = self.model_outputs + + m0 = model_output_list[-1] + x = last_sample + x_t = this_sample + model_t = this_model_output + + sigma_t, sigma_s0 = ( + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = this_sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - (i + 1) + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) + + 
rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) + else: + D1s = None + + # for order 1, we use a simplified version + if order == 1: + rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype) + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) + x_t = x_t.to(x.dtype) + return x_t + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + index_candidates = (schedule_timesteps == timestep).nonzero() + + if len(index_candidates) == 0: + step_index = len(self.timesteps) - 1 + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we 
don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + + return step_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.Tensor, + timestep: int | torch.Tensor, + sample: torch.Tensor, + return_dict: bool = True, + ) -> SchedulerOutput | tuple: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep UniPC. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. 
+ + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to call 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + use_corrector = ( + self.step_index > 0 + and self.step_index - 1 not in self.disable_corrector + and self.last_sample is not None + ) + + model_output_convert = self.convert_model_output(model_output, sample=sample) + if use_corrector: + sample = self.multistep_uni_c_bh_update( + this_model_output=model_output_convert, + last_sample=self.last_sample, + this_sample=sample, + order=self.this_order, + ) + + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep + + if self.config.lower_order_final: + this_order = min( + self.config.solver_order, len(self.timesteps) - self.step_index + ) + else: + this_order = self.config.solver_order + + self.this_order = min( + this_order, self.lower_order_nums + 1 + ) # warmup for multistep + assert self.this_order > 0 + + self.last_sample = sample + prev_sample = self.multistep_uni_p_bh_update( + model_output=model_output, # pass the original non-converted model output, in case solver-p is used + sample=sample, + order=self.this_order, + ) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. 
+ + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to( + device=original_samples.device, dtype=original_samples.dtype + ) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to( + original_samples.device, dtype=torch.float32 + ) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [ + self.index_for_timestep(t, schedule_timesteps) for t in timesteps + ] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps + + +EntryClass = UniPCMultistepScheduler diff --git a/sglang/python/sglang/multimodal_gen/runtime/models/utils.py b/sglang/python/sglang/multimodal_gen/runtime/models/utils.py new file mode 100644 index 
0000000000000000000000000000000000000000..0199bbb3d453f77766040705eb6d5ffc6e4a8c8e --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/models/utils.py @@ -0,0 +1,140 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/utils.py +"""Utils for model executor.""" + +from typing import Any + +import torch + + +def set_weight_attrs( + weight: torch.Tensor, + weight_attrs: dict[str, Any] | None, +): + """Set attributes on a weight tensor. + + This method is used to set attributes on a weight tensor. This method + will not overwrite existing attributes. + + Args: + weight: The weight tensor. + weight_attrs: A dictionary of attributes to set on the weight tensor. + """ + if weight_attrs is None: + return + for key, value in weight_attrs.items(): + assert not hasattr(weight, key), f"Overwriting existing tensor attribute: {key}" + + # NOTE(woosuk): During weight loading, we often do something like: + # narrowed_tensor = param.data.narrow(0, offset, len) + # narrowed_tensor.copy_(real_weight) + # expecting narrowed_tensor and param.data to share the same storage. + # However, on TPUs, narrowed_tensor will lazily propagate to the base + # tensor, which is param.data, leading to the redundant memory usage. + # This sometimes causes OOM errors during model loading. To avoid this, + # we sync the param tensor after its weight loader is called. + # TODO(woosuk): Remove this hack once we have a better solution. 
+ from sglang.multimodal_gen.runtime.platforms import current_platform + + if current_platform.is_tpu() and key == "weight_loader": + value = _make_synced_weight_loader(value) + setattr(weight, key, value) + + +def _make_synced_weight_loader(original_weight_loader) -> Any: + + def _synced_weight_loader(param, *args, **kwargs): + original_weight_loader(param, *args, **kwargs) + torch._sync(param) + + return _synced_weight_loader + + +def extract_layer_index(layer_name: str) -> int: + """ + Extract the layer index from the module name. + Examples: + - "encoder.layers.0" -> 0 + - "encoder.layers.1.self_attn" -> 1 + - "2.self_attn" -> 2 + - "model.encoder.layers.0.sub.1" -> AssertionError + """ + subnames = layer_name.split(".") + int_vals: list[int] = [] + for subname in subnames: + try: + int_vals.append(int(subname)) + except ValueError: + continue + assert len(int_vals) == 1, ( + f"layer name {layer_name} should" " only contain one integer" + ) + return int_vals[0] + + +def modulate( + x: torch.Tensor, + shift: torch.Tensor | None = None, + scale: torch.Tensor | None = None, +) -> torch.Tensor: + """modulate by shift and scale""" + if scale is None and shift is None: + return x + elif shift is None: + return x * (1 + scale.unsqueeze(1)) # type: ignore[union-attr] + elif scale is None: + return x + shift.unsqueeze(1) # type: ignore[union-attr] + else: + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze( + 1 + ) # type: ignore[union-attr] + + +def pred_noise_to_pred_video( + pred_noise: torch.Tensor, + noise_input_latent: torch.Tensor, + timestep: torch.Tensor, + scheduler: Any, +) -> torch.Tensor: + """ + Convert predicted noise to clean latent.
+ + Args: + pred_noise: the predicted noise with shape [B, C, H, W] + where B is batch_size or batch_size * num_frames + noise_input_latent: the noisy latent with shape [B, C, H, W], + timestep: the timestep with shape [1] or [bs * num_frames] or [bs, num_frames] + scheduler: the scheduler + + Returns: + the predicted video with shape [B, C, H, W] + """ + # If timestep is [bs, num_frames] + if timestep.ndim == 2: + timestep = timestep.flatten(0, 1) + assert timestep.numel() == noise_input_latent.shape[0] + elif timestep.ndim == 1: + # If timestep is [1] + if timestep.shape[0] == 1: + timestep = timestep.expand(noise_input_latent.shape[0]) + else: + assert timestep.numel() == noise_input_latent.shape[0] + else: + raise ValueError( + f"[pred_noise_to_pred_video] Invalid timestep shape: {timestep.shape}" + ) + # timestep shape should be [B] + dtype = pred_noise.dtype + device = pred_noise.device + pred_noise = pred_noise.double().to(device) + noise_input_latent = noise_input_latent.double().to(device) + sigmas = scheduler.sigmas.double().to(device) + timesteps = scheduler.timesteps.double().to(device) + timestep_id = torch.argmin( + (timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1 + ) + sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1) + pred_video = noise_input_latent - sigma_t * pred_noise + return pred_video.to(dtype) diff --git a/sglang/python/sglang/multimodal_gen/runtime/models/vaes/autoencoder.py b/sglang/python/sglang/multimodal_gen/runtime/models/vaes/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c6a49b95c9acc2fd7ff4df61994e0f87557d3be2 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/models/vaes/autoencoder.py @@ -0,0 +1,583 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +from typing import Dict, Optional, Tuple, Union + +import torch +from diffusers.models.attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + Attention, + 
AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, + FusedAttnProcessor2_0, +) +from diffusers.models.autoencoders.vae import ( + Decoder, + DecoderOutput, + DiagonalGaussianDistribution, + Encoder, +) +from diffusers.models.modeling_outputs import AutoencoderKLOutput +from torch import nn + +from sglang.multimodal_gen.configs.models.vaes.flux import FluxVAEConfig + + +class AutoencoderKL(nn.Module): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): Sample input size. + scaling_factor (`float`, *optional*, defaults to 0.18215): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. 
For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper. + force_upcast (`bool`, *optional*, default to `True`): + If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE + can be fine-tuned / trained to a lower range without losing too much precision in which case `force_upcast` + can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix + mid_block_add_attention (`bool`, *optional*, default to `True`): + If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the + mid_block will only have resnet blocks + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"] + + def __init__( + self, + config: FluxVAEConfig, + ): + super().__init__() + self.config = config + arch_config = config.arch_config + + in_channels = arch_config.in_channels + out_channels = arch_config.out_channels + down_block_types = arch_config.down_block_types + up_block_types = arch_config.up_block_types + block_out_channels = arch_config.block_out_channels + layers_per_block = arch_config.layers_per_block + act_fn = arch_config.act_fn + latent_channels = arch_config.latent_channels + norm_num_groups = arch_config.norm_num_groups + sample_size = arch_config.sample_size + use_quant_conv = arch_config.use_quant_conv + use_post_quant_conv = arch_config.use_post_quant_conv + mid_block_add_attention = arch_config.mid_block_add_attention + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + mid_block_add_attention=mid_block_add_attention, + ) + + # pass init params to Decoder 
+ self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + mid_block_add_attention=mid_block_add_attention, + ) + + self.quant_conv = ( + nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) + if use_quant_conv + else None + ) + self.post_quant_conv = ( + nn.Conv2d(latent_channels, latent_channels, 1) + if use_post_quant_conv + else None + ) + + self.use_slicing = False + self.use_tiling = False + + # only relevant if vae tiling is enabled + self.tile_sample_min_size = sample_size + sample_size = ( + self.config.sample_size[0] + if isinstance(self.config.sample_size, (list, tuple)) + else self.config.sample_size + ) + self.tile_latent_min_size = int( + sample_size / (2 ** (len(self.config.block_out_channels) - 1)) + ) + self.tile_overlap_factor = 0.25 + + def enable_tiling(self, use_tiling: bool = True): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.use_tiling = use_tiling + + def disable_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.enable_tiling(False) + + def enable_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. 
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
    r"""
    Collect every attention processor reachable from this module.

    Returns:
        `dict` mapping ``"<dotted.module.path>.processor"`` to the object returned by that module's
        ``get_processor()``. Modules are visited depth-first in ``named_children()`` order, parents before
        their children.
    """
    found = {}

    def _collect(prefix: str, node: torch.nn.Module) -> None:
        # A module exposes a processor iff it implements ``get_processor``.
        if hasattr(node, "get_processor"):
            found[f"{prefix}.processor"] = node.get_processor()
        for child_name, child in node.named_children():
            _collect(f"{prefix}.{child_name}", child)

    for top_name, top_module in self.named_children():
        _collect(top_name, top_module)

    return found
+ ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all( + proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS + for proc in self.attn_processors.values() + ): + processor = AttnAddedKVProcessor() + elif all( + proc.__class__ in CROSS_ATTENTION_PROCESSORS + for proc in self.attn_processors.values() + ): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, height, width = x.shape + + if self.use_tiling and ( + width > self.tile_sample_min_size or height > self.tile_sample_min_size + ): + return self._tiled_encode(x) + + enc = self.encoder(x) + if self.quant_conv is not None: + enc = self.quant_conv(enc) + + return enc + + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. 
def decode(self, z: torch.FloatTensor) -> torch.Tensor:
    """
    Decode a batch of latents into images.

    Unlike ``_decode``, this method takes no ``return_dict`` argument and always
    returns the decoded tensor itself (the ``.sample`` field of what ``_decode``
    returns), never a ``DecoderOutput`` or a tuple.

    Args:
        z (`torch.Tensor`): Input batch of latent vectors.

    Returns:
        `torch.Tensor`: The decoded batch. When slicing is enabled and the batch holds more than one
        element, each element is decoded separately and the results are concatenated along dim 0.
    """
    if not (self.use_slicing and z.shape[0] > 1):
        return self._decode(z).sample

    # Sliced path: decode one batch element at a time.
    decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
    return torch.cat(decoded_slices)
+ rows = [] + for i in range(0, x.shape[2], overlap_size): + row = [] + for j in range(0, x.shape[3], overlap_size): + tile = x[ + :, + :, + i : i + self.tile_sample_min_size, + j : j + self.tile_sample_min_size, + ] + tile = self.encoder(tile) + if self.config.use_quant_conv: + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + enc = torch.cat(result_rows, dim=2) + return enc + + def tiled_encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`: + If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain + `tuple` is returned. 
+ """ + deprecation_message = ( + "The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the " + "implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able " + "to pass `return_dict`. You will also have to create a `DiagonalGaussianDistribution()` from the returned value." + ) + # deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False) + + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. + rows = [] + for i in range(0, x.shape[2], overlap_size): + row = [] + for j in range(0, x.shape[3], overlap_size): + tile = x[ + :, + :, + i : i + self.tile_sample_min_size, + j : j + self.tile_sample_min_size, + ] + tile = self.encoder(tile) + if self.config.use_quant_conv: + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + moments = torch.cat(result_rows, dim=2) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def tiled_decode( + self, z: torch.Tensor, return_dict: bool = True + ) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. 
def forward(
    self,
    sample: torch.Tensor,
    sample_posterior: bool = False,
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    r"""
    Encode ``sample`` and decode the resulting latent again.

    Args:
        sample (`torch.Tensor`): Input sample.
        sample_posterior (`bool`, *optional*, defaults to `False`):
            Whether to sample from the posterior. When `False` the posterior mode is used.
        generator (`torch.Generator`, *optional*):
            Forwarded to ``posterior.sample`` when ``sample_posterior`` is set.

    Returns:
        `torch.Tensor`: The reconstruction produced by ``self.decode``.
    """
    posterior = self.encode(sample).latent_dist
    if sample_posterior:
        z = posterior.sample(generator=generator)
    else:
        z = posterior.mode()

    dec = self.decode(z)
    # ``decode`` in this class already returns the decoded tensor, so calling
    # ``.sample`` on it unconditionally raised AttributeError. Only unwrap when
    # an output object (e.g. from an overriding ``decode``) is handed back.
    if not isinstance(dec, torch.Tensor):
        dec = dec.sample
    return dec
sglang.multimodal_gen.configs.models.vaes.flux import Flux2VAEConfig +from sglang.multimodal_gen.runtime.models.vaes.common import ParallelTiledVAE + + +class AutoencoderKLFlux2(ParallelTiledVAE): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"] + + def __init__( + self, + config: Flux2VAEConfig, + ): + super().__init__(config=config) + + self.config = config + arch_config = config.arch_config + + in_channels: int = arch_config.in_channels + out_channels: int = arch_config.out_channels + down_block_types: Tuple[str, ...] = arch_config.down_block_types + up_block_types: Tuple[str, ...] = arch_config.up_block_types + block_out_channels: Tuple[int, ...] 
= arch_config.block_out_channels + layers_per_block: int = arch_config.layers_per_block + act_fn: str = arch_config.act_fn + latent_channels: int = arch_config.latent_channels + norm_num_groups: int = arch_config.norm_num_groups + sample_size: int = arch_config.sample_size + force_upcast: bool = arch_config.force_upcast + use_quant_conv: bool = arch_config.use_quant_conv + use_post_quant_conv: bool = arch_config.use_post_quant_conv + mid_block_add_attention: bool = arch_config.mid_block_add_attention + batch_norm_eps: float = arch_config.batch_norm_eps + batch_norm_momentum: float = arch_config.batch_norm_momentum + patch_size: Tuple[int, int] = arch_config.patch_size + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + mid_block_add_attention=mid_block_add_attention, + ) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + mid_block_add_attention=mid_block_add_attention, + ) + + self.quant_conv = ( + nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) + if use_quant_conv + else None + ) + self.post_quant_conv = ( + nn.Conv2d(latent_channels, latent_channels, 1) + if use_post_quant_conv + else None + ) + + self.bn = nn.BatchNorm2d( + math.prod(patch_size) * latent_channels, + eps=batch_norm_eps, + momentum=batch_norm_momentum, + affine=False, + track_running_stats=True, + ) + + self.use_slicing = False + self.use_tiling = False + + # only relevant if vae tiling is enabled + self.tile_sample_min_size = self.config.sample_size + sample_size = ( + self.config.sample_size[0] + if 
isinstance(self.config.sample_size, (list, tuple)) + else self.config.sample_size + ) + self.tile_latent_min_size = int( + sample_size / (2 ** (len(self.config.block_out_channels) - 1)) + ) + self.tile_overlap_factor = 0.25 + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors( + name: str, + module: torch.nn.Module, + processors: Dict[str, AttentionProcessor], + ): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor( + self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]] + ): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. 
+ + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. 
def encode(
    self, x: torch.Tensor, return_dict: bool = True
) -> DiagonalGaussianDistribution:
    """
    Encode a batch of images into a latent posterior.

    Args:
        x (`torch.Tensor`): Input batch of images. A 5-D input is accepted only when its dim 2 has
            size 1; that axis is squeezed away before encoding.
        return_dict (`bool`, *optional*, defaults to `True`):
            Accepted but ignored: the posterior is returned directly in both cases.

    Returns:
        `DiagonalGaussianDistribution`: Built from the encoder output. It is NOT wrapped in an
        `AutoencoderKLOutput` or a tuple.

    Raises:
        AssertionError: if `x` is 5-D and its dim 2 is not of size 1.
    """
    if x.ndim == 5:
        assert x.shape[2] == 1, (
            f"5-D input must have size 1 along dim 2, got shape {tuple(x.shape)}"
        )
        x = x.squeeze(2)

    if self.use_slicing and x.shape[0] > 1:
        # Encode one batch element at a time.
        encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
        h = torch.cat(encoded_slices)
    else:
        h = self._encode(x)

    return DiagonalGaussianDistribution(h)
+ + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + return decoded + + def blend_v( + self, a: torch.Tensor, b: torch.Tensor, blend_extent: int + ) -> torch.Tensor: + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + for y in range(blend_extent): + b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[ + :, :, y, : + ] * (y / blend_extent) + return b + + def blend_h( + self, a: torch.Tensor, b: torch.Tensor, blend_extent: int + ) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for x in range(blend_extent): + b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[ + :, :, :, x + ] * (x / blend_extent) + return b + + def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.Tensor`): Input batch of images. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. 
def tiled_encode(
    self, x: torch.Tensor, return_dict: bool = True
) -> AutoencoderKLOutput:
    r"""Encode a batch of images tile by tile and wrap the result as a posterior.

    The tiling, blending and cropping are done by `_tiled_encode`; this method
    only builds a `DiagonalGaussianDistribution` from the moments it returns.
    The result differs from non-tiled encoding because each tile is encoded
    independently of its neighbours; overlapping tiles are blended to hide
    the seams.

    Deprecated: the `return_dict` flavour of this method is slated to be
    replaced by `_tiled_encode`, after which callers will have to create the
    `DiagonalGaussianDistribution` from the returned tensor themselves.

    Args:
        x (`torch.Tensor`): Input batch of images.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

    Returns:
        [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
            If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a
            1-tuple holding the posterior is returned.
    """
    # Single source of truth for the tile loop; it used to be copied here verbatim.
    moments = self._tiled_encode(x)
    posterior = DiagonalGaussianDistribution(moments)

    if not return_dict:
        return (posterior,)

    return AutoencoderKLOutput(latent_dist=posterior)
def forward(
    self,
    sample: torch.Tensor,
    sample_posterior: bool = False,
    return_dict: bool = True,
    generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
    r"""
    Encode `sample`, pick a latent from the posterior, and decode it again.

    Args:
        sample (`torch.Tensor`): Input sample.
        sample_posterior (`bool`, *optional*, defaults to `False`):
            Whether to sample from the posterior (otherwise its mode is used).
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        generator (`torch.Generator`, *optional*):
            Forwarded to `posterior.sample` when `sample_posterior` is set.

    Returns:
        [`DecoderOutput`] or `tuple`: the reconstruction, wrapped according to `return_dict`.
    """
    x = sample

    # The encode path in this file returns the distribution itself, not an
    # output wrapper, so only unwrap `.latent_dist` when it is present.
    encoded = self.encode(x)
    posterior = getattr(encoded, "latent_dist", encoded)

    if sample_posterior:
        z = posterior.sample(generator=generator)
    else:
        z = posterior.mode()

    # `decode` returns the decoded tensor directly; reading `.sample` from a
    # tensor raises AttributeError. Accept a wrapper as well for safety.
    decoded = self.decode(z)
    dec = decoded if isinstance(decoded, torch.Tensor) else decoded.sample

    if not return_dict:
        return (dec,)

    return DecoderOutput(sample=dec)
class QwenImageRMS_norm(nn.Module):
    r"""
    RMS-style normalisation with a learnable per-channel gain and optional bias.

    Args:
        dim (int): Size of the normalised (channel) dimension.
        channel_first (bool, optional): Normalise over dim 1 when true, over the last dim otherwise.
            Default is True.
        images (bool, optional): With `channel_first`, selects two (true) or three (false) trailing
            broadcast dims for the parameters. Default is True.
        bias (bool, optional): Whether to include a learnable bias term. Default is False.
    """

    def __init__(
        self,
        dim: int,
        channel_first: bool = True,
        images: bool = True,
        bias: bool = False,
    ) -> None:
        super().__init__()
        if channel_first:
            trailing = (1, 1) if images else (1, 1, 1)
            param_shape = (dim,) + trailing
        else:
            param_shape = (dim,)

        self.channel_first = channel_first
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(param_shape))
        if bias:
            self.bias = nn.Parameter(torch.zeros(param_shape))
        else:
            # Plain float so the addition in `forward` is a no-op shift.
            self.bias = 0.0

    def forward(self, x):
        norm_axis = 1 if self.channel_first else -1
        unit = F.normalize(x, dim=norm_axis)
        return unit * self.scale * self.gamma + self.bias


class QwenImageUpsample(nn.Upsample):
    r"""
    `nn.Upsample` that computes in float32 and casts the result back to the input's dtype.

    Returns:
        torch.Tensor: Upsampled tensor with the same data type as the input.
    """

    def forward(self, x):
        upsampled = super().forward(x.float())
        return upsampled.type_as(x)
+ """ + + def __init__(self, dim: int, mode: str) -> None: + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim // 2, 3, padding=1), + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim // 2, 3, padding=1), + ) + self.time_conv = QwenImageCausalConv3d( + dim, dim * 2, (3, 1, 1), padding=(1, 0, 0) + ) + + elif mode == "downsample2d": + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)) + ) + elif mode == "downsample3d": + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)) + ) + self.time_conv = QwenImageCausalConv3d( + dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0) + ) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if ( + cache_x.shape[2] < 2 + and feat_cache[idx] is not None + and feat_cache[idx] != "Rep" + ): + # cache last frame of last two chunk + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :] + .unsqueeze(2) + .to(cache_x.device), + cache_x, + ], + dim=2, + ) + if ( + cache_x.shape[2] < 2 + and feat_cache[idx] is not None + and feat_cache[idx] == "Rep" + ): + cache_x = torch.cat( + [torch.zeros_like(cache_x).to(cache_x.device), cache_x], + dim=2, + ) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) 
class QwenImageResidualBlock(nn.Module):
    r"""
    Residual block: two norm -> activation -> causal 3D conv stages plus a shortcut.

    Args:
        in_dim (int): Number of input channels.
        out_dim (int): Number of output channels.
        dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.
        non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        dropout: float = 0.0,
        non_linearity: str = "silu",
    ) -> None:
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.nonlinearity = get_activation(non_linearity)

        # layers
        self.norm1 = QwenImageRMS_norm(in_dim, images=False)
        self.conv1 = QwenImageCausalConv3d(in_dim, out_dim, 3, padding=1)
        self.norm2 = QwenImageRMS_norm(out_dim, images=False)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = QwenImageCausalConv3d(out_dim, out_dim, 3, padding=1)
        # A 1x1x1 conv only when the channel count changes.
        self.conv_shortcut = (
            QwenImageCausalConv3d(in_dim, out_dim, 1)
            if in_dim != out_dim
            else nn.Identity()
        )

    @staticmethod
    def _cached_conv(conv, x, feat_cache, feat_idx):
        """Run `conv` on `x`, reading and refreshing slot `feat_idx[0]` of `feat_cache`.

        Without a cache this is just `conv(x)`. With one, the last `CACHE_T`
        frames of `x` become the new cache entry (topped up with the previous
        entry's last frame when `x` has fewer than two frames) and the slot
        counter is advanced.
        """
        if feat_cache is None:
            return conv(x)

        idx = feat_idx[0]
        previous = feat_cache[idx]
        cache_x = x[:, :, -CACHE_T:, :, :].clone()
        if cache_x.shape[2] < 2 and previous is not None:
            cache_x = torch.cat(
                [previous[:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x],
                dim=2,
            )

        out = conv(x, previous)
        feat_cache[idx] = cache_x
        feat_idx[0] += 1
        return out

    def forward(self, x, feat_cache=None, feat_idx=None):
        """Apply the block.

        Args:
            x (torch.Tensor): Input indexed as five dims with frames on dim 2.
            feat_cache (list, optional): Per-conv cache slots; `None` disables caching.
            feat_idx (list, optional): One-element list holding the next cache slot.
                It is advanced in place. A fresh `[0]` is used when omitted, so
                calls never share a counter through the default argument.
        """
        if feat_idx is None:
            feat_idx = [0]

        # Shortcut branch sees the raw input.
        h = self.conv_shortcut(x)

        x = self.nonlinearity(self.norm1(x))
        x = self._cached_conv(self.conv1, x, feat_cache, feat_idx)

        x = self.dropout(self.nonlinearity(self.norm2(x)))
        x = self._cached_conv(self.conv2, x, feat_cache, feat_idx)

        return x + h
class QwenImageEncoder3d(nn.Module):
    r"""
    A 3D encoder module.

    Args:
        dim (int): The base number of channels in the first layer.
        z_dim (int): Number of channels produced by the output convolution.
        dim_mult (list of int): Multipliers for the number of channels in each block.
        num_res_blocks (int): Number of residual blocks in each block.
        attn_scales (list of float): Scales at which to apply attention mechanisms.
        temperal_downsample (list of bool): Whether to downsample temporally in each block.
        dropout (float): Dropout rate for the dropout layers.
        non_linearity (str): Type of non-linearity to use.
        input_channels (int): Number of channels of the input tensor.
    """

    def __init__(
        self,
        dim=128,
        z_dim=4,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[True, True, False],
        dropout=0.0,
        non_linearity: str = "silu",
        input_channels: int = 3,
    ):
        super().__init__()
        # Copy the sequences so neither the shared default lists nor a
        # caller's list is aliased through the attributes stored below.
        dim_mult = list(dim_mult)
        attn_scales = list(attn_scales)
        temperal_downsample = list(temperal_downsample)

        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        self.nonlinearity = get_activation(non_linearity)

        # dimensions
        dims = [dim * u for u in [1] + dim_mult]
        scale = 1.0

        # init block
        self.conv_in = QwenImageCausalConv3d(input_channels, dims[0], 3, padding=1)

        # downsample blocks
        self.down_blocks = nn.ModuleList([])
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            for _ in range(num_res_blocks):
                self.down_blocks.append(
                    QwenImageResidualBlock(in_dim, out_dim, dropout)
                )
                if scale in attn_scales:
                    self.down_blocks.append(QwenImageAttentionBlock(out_dim))
                in_dim = out_dim

            # downsample block (none after the last stage)
            if i != len(dim_mult) - 1:
                mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
                self.down_blocks.append(QwenImageResample(out_dim, mode=mode))
                scale /= 2.0

        # middle blocks; `out_dim` is the channel count of the last stage
        self.mid_block = QwenImageMidBlock(
            out_dim, dropout, non_linearity, num_layers=1
        )

        # output blocks
        self.norm_out = QwenImageRMS_norm(out_dim, images=False)
        self.conv_out = QwenImageCausalConv3d(out_dim, z_dim, 3, padding=1)

        self.gradient_checkpointing = False

    @staticmethod
    def _cached_conv(conv, x, feat_cache, feat_idx):
        """Run `conv` on `x`, reading and refreshing slot `feat_idx[0]` of `feat_cache`.

        Without a cache this is just `conv(x)`. With one, the last `CACHE_T`
        frames of `x` become the new cache entry (topped up with the previous
        entry's last frame when `x` has fewer than two frames) and the slot
        counter is advanced.
        """
        if feat_cache is None:
            return conv(x)

        idx = feat_idx[0]
        previous = feat_cache[idx]
        cache_x = x[:, :, -CACHE_T:, :, :].clone()
        if cache_x.shape[2] < 2 and previous is not None:
            # cache last frame of last two chunk
            cache_x = torch.cat(
                [previous[:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x],
                dim=2,
            )

        out = conv(x, previous)
        feat_cache[idx] = cache_x
        feat_idx[0] += 1
        return out

    def forward(self, x, feat_cache=None, feat_idx=None):
        """Encode `x`.

        Args:
            x (torch.Tensor): Input indexed as five dims with frames on dim 2.
            feat_cache (list, optional): Per-conv cache slots; `None` disables caching.
            feat_idx (list, optional): One-element list holding the next cache slot.
                It is advanced in place. A fresh `[0]` is used when omitted, so
                calls never share a counter through the default argument.
        """
        if feat_idx is None:
            feat_idx = [0]

        x = self._cached_conv(self.conv_in, x, feat_cache, feat_idx)

        ## downsamples
        for layer in self.down_blocks:
            # Attention blocks take only `x`; handing them the cache
            # arguments would raise a TypeError.
            if feat_cache is None or isinstance(layer, QwenImageAttentionBlock):
                x = layer(x)
            else:
                x = layer(x, feat_cache, feat_idx)

        ## middle
        x = self.mid_block(x, feat_cache, feat_idx)

        ## head
        x = self.nonlinearity(self.norm_out(x))
        return self._cached_conv(self.conv_out, x, feat_cache, feat_idx)
class QwenImageDecoder3d(nn.Module):
    r"""
    A 3D decoder module.

    Args:
        dim (int): The base number of channels in the first layer.
        z_dim (int): Number of channels of the latent fed to the input convolution.
        dim_mult (list of int): Multipliers for the number of channels in each block.
        num_res_blocks (int): Number of residual blocks in each block.
        attn_scales (list of float): Scales at which to apply attention mechanisms.
            Stored on the module; no attention layer is built from it here.
        temperal_upsample (list of bool): Whether to upsample temporally in each block.
        dropout (float): Dropout rate for the dropout layers.
        non_linearity (str): Type of non-linearity to use.
        input_channels (int): Number of channels produced by the output convolution.
    """

    def __init__(
        self,
        dim=128,
        z_dim=4,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_upsample=[False, True, True],
        dropout=0.0,
        non_linearity: str = "silu",
        input_channels=3,
    ):
        super().__init__()
        # Copy the sequences so neither the shared default lists nor a
        # caller's list is aliased through the attributes stored below.
        dim_mult = list(dim_mult)
        attn_scales = list(attn_scales)
        temperal_upsample = list(temperal_upsample)

        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_upsample = temperal_upsample

        self.nonlinearity = get_activation(non_linearity)

        # dimensions: widest stage first, mirroring the encoder
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]

        # init block
        self.conv_in = QwenImageCausalConv3d(z_dim, dims[0], 3, padding=1)

        # middle blocks
        self.mid_block = QwenImageMidBlock(
            dims[0], dropout, non_linearity, num_layers=1
        )

        # upsample blocks
        self.up_blocks = nn.ModuleList([])
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # Every stage after the first receives half the channels of the
            # previous stage's output.
            if i > 0:
                in_dim = in_dim // 2

            # The last stage has no upsampler.
            upsample_mode = None
            if i != len(dim_mult) - 1:
                upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"

            self.up_blocks.append(
                QwenImageUpBlock(
                    in_dim=in_dim,
                    out_dim=out_dim,
                    num_res_blocks=num_res_blocks,
                    dropout=dropout,
                    upsample_mode=upsample_mode,
                    non_linearity=non_linearity,
                )
            )

        # output blocks
        self.norm_out = QwenImageRMS_norm(out_dim, images=False)
        self.conv_out = QwenImageCausalConv3d(out_dim, input_channels, 3, padding=1)

        self.gradient_checkpointing = False

    @staticmethod
    def _cached_conv(conv, x, feat_cache, feat_idx):
        """Run `conv` on `x`, reading and refreshing slot `feat_idx[0]` of `feat_cache`.

        Without a cache this is just `conv(x)`. With one, the last `CACHE_T`
        frames of `x` become the new cache entry (topped up with the previous
        entry's last frame when `x` has fewer than two frames) and the slot
        counter is advanced.
        """
        if feat_cache is None:
            return conv(x)

        idx = feat_idx[0]
        previous = feat_cache[idx]
        cache_x = x[:, :, -CACHE_T:, :, :].clone()
        if cache_x.shape[2] < 2 and previous is not None:
            # cache last frame of last two chunk
            cache_x = torch.cat(
                [previous[:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x],
                dim=2,
            )

        out = conv(x, previous)
        feat_cache[idx] = cache_x
        feat_idx[0] += 1
        return out

    def forward(self, x, feat_cache=None, feat_idx=None):
        """Decode `x`.

        Args:
            x (torch.Tensor): Latent indexed as five dims with frames on dim 2.
            feat_cache (list, optional): Per-conv cache slots; `None` disables caching.
            feat_idx (list, optional): One-element list holding the next cache slot.
                It is advanced in place. A fresh `[0]` is used when omitted, so
                calls never share a counter through the default argument.
        """
        if feat_idx is None:
            feat_idx = [0]

        ## conv1
        x = self._cached_conv(self.conv_in, x, feat_cache, feat_idx)

        ## middle
        x = self.mid_block(x, feat_cache, feat_idx)

        ## upsamples
        for up_block in self.up_blocks:
            x = up_block(x, feat_cache, feat_idx)

        ## head
        x = self.nonlinearity(self.norm_out(x))
        return self._cached_conv(self.conv_out, x, feat_cache, feat_idx)
+ """ + + _supports_gradient_checkpointing = False + + # fmt: off + def __init__( + self, + config: QwenImageVAEConfig, + ) -> None: + # fmt: on + super().__init__(config=config) + base_dim = config.arch_config.base_dim + z_dim = config.arch_config.z_dim + dim_mult = config.arch_config.dim_mult + num_res_blocks = config.arch_config.num_res_blocks + attn_scales = config.arch_config.attn_scales + temperal_downsample = config.arch_config.temperal_downsample + dropout = config.arch_config.dropout + # non_linearity = config.arch_config.non_linearity + self.z_dim = z_dim + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + self.input_channels = config.arch_config.input_channels + self.latents_mean = config.arch_config.latents_mean + self.config = config.arch_config + + self.encoder = QwenImageEncoder3d( + base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout, + input_channels=self.input_channels + ) + self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1) + self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1) + + self.decoder = QwenImageDecoder3d( + base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout, + input_channels=self.input_channels + ) + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. 
def enable_tiling(
    self,
    tile_sample_min_height: Optional[int] = None,
    tile_sample_min_width: Optional[int] = None,
    tile_sample_stride_height: Optional[int] = None,
    tile_sample_stride_width: Optional[int] = None,
) -> None:
    r"""
    Enable tiled VAE processing and optionally override the tile geometry.

    Any argument left as `None` (or passed as another falsy value such as
    `0`) keeps the value currently stored on the model.

    Args:
        tile_sample_min_height (`int`, *optional*):
            The minimum height required for a sample to be separated into tiles across the height dimension.
        tile_sample_min_width (`int`, *optional*):
            The minimum width required for a sample to be separated into tiles across the width dimension.
        tile_sample_stride_height (`int`, *optional*):
            The stride between two consecutive vertical tiles.
        tile_sample_stride_width (`int`, *optional*):
            The stride between two consecutive horizontal tiles.
    """
    self.use_tiling = True
    self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
    self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
    self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
    self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width


def disable_tiling(self) -> None:
    r"""
    Turn tiling off again; the stored tile geometry is left untouched.
    """
    self.use_tiling = False


def enable_slicing(self) -> None:
    r"""
    Turn slicing on, so batches with more than one element are processed one
    element at a time by the methods that check `use_slicing`.
    """
    self.use_slicing = True
def clear_cache(self):
    """Reset the causal-convolution feature caches used by the encoder and decoder.

    After this call:
      * `_conv_num` / `_enc_conv_num` hold the number of `QwenImageCausalConv3d` layers in the
        decoder / encoder,
      * `_feat_map` / `_enc_feat_map` are fresh lists of `None` with one slot per such layer,
      * `_conv_idx` / `_enc_conv_idx` are reset to `[0]`.

    The layer counts are read from `_cached_conv_counts` (precomputed in `__init__`) so the module
    trees are not re-walked on every call; this method runs once per tile in the tiled paths.
    If the cache attribute is missing, the counts are recomputed, treating a missing
    encoder/decoder as having zero layers.
    """

    def _count_conv3d(model):
        # Same None handling as the precomputed counts built in __init__.
        if model is None:
            return 0
        return sum(isinstance(m, QwenImageCausalConv3d) for m in model.modules())

    counts = getattr(self, "_cached_conv_counts", None)
    if counts is None:
        counts = {
            "decoder": _count_conv3d(self.decoder),
            "encoder": _count_conv3d(self.encoder),
        }

    # decoder cache
    self._conv_num = counts["decoder"]
    self._conv_idx = [0]
    self._feat_map = [None] * self._conv_num
    # encoder cache
    self._enc_conv_num = counts["encoder"]
    self._enc_conv_idx = [0]
    self._enc_feat_map = [None] * self._enc_conv_num
+ """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + + return posterior + + def _decode(self, z: torch.Tensor, return_dict: bool = True): + _, _, num_frame, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): + return self.tiled_decode(z, return_dict=return_dict) + + self.clear_cache() + x = self.post_quant_conv(z) + for i in range(num_frame): + self._conv_idx = [0] + if i == 0: + out = self.decoder(x[:, :, i: i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + else: + out_ = self.decoder(x[:, :, i: i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) + + out = torch.clamp(out, min=-1.0, max=1.0) + self.clear_cache() + if not return_dict: + return (out,) + + return DecoderOutput(sample=out) + + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. 
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    """Cross-fade the bottom rows of `a` into the top rows of `b` (in place on `b`).

    The fade spans `min(a.shape[-2], b.shape[-2], blend_extent)` rows; row `r` of `b` becomes
    `a_row * (1 - r / n) + b_row * (r / n)`. Returns `b`.
    """
    overlap = min(a.shape[-2], b.shape[-2], blend_extent)
    for row in range(overlap):
        upper = a[:, :, :, -overlap + row, :]
        lower = b[:, :, :, row, :]
        b[:, :, :, row, :] = upper * (1 - row / overlap) + lower * (row / overlap)
    return b


def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    """Cross-fade the right-most columns of `a` into the left-most columns of `b` (in place on `b`).

    The fade spans `min(a.shape[-1], b.shape[-1], blend_extent)` columns; column `c` of `b`
    becomes `a_col * (1 - c / n) + b_col * (c / n)`. Returns `b`.
    """
    overlap = min(a.shape[-1], b.shape[-1], blend_extent)
    for col in range(overlap):
        left = a[:, :, :, :, -overlap + col]
        right = b[:, :, :, :, col]
        b[:, :, :, :, col] = left * (1 - col / overlap) + right * (col / overlap)
    return b
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
    r"""
    Decode latents tile by tile and stitch the decoded tiles back together.

    Args:
        z (`torch.Tensor`): Input batch of latent vectors.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

    Returns:
        [`~models.vae.DecoderOutput`] or `tuple`:
            If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
            returned.
    """
    _, _, num_frames, height, width = z.shape
    ratio = self.spatial_compression_ratio
    sample_height = height * ratio
    sample_width = width * ratio

    # Tile size and stride expressed in latent units.
    tile_latent_min_height = self.tile_sample_min_height // ratio
    tile_latent_min_width = self.tile_sample_min_width // ratio
    tile_latent_stride_height = self.tile_sample_stride_height // ratio
    tile_latent_stride_width = self.tile_sample_stride_width // ratio

    # Overlap between neighbouring decoded tiles, in sample (pixel) units.
    blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
    blend_width = self.tile_sample_min_width - self.tile_sample_stride_width

    # Decode overlapping latent tiles independently; the overlap is blended below to avoid seams.
    rows = []
    for top in range(0, height, tile_latent_stride_height):
        decoded_row = []
        for left in range(0, width, tile_latent_stride_width):
            # Each tile starts from an empty feature cache.
            self.clear_cache()
            frame_chunks = []
            for k in range(num_frames):
                self._conv_idx = [0]
                latent_tile = z[
                    :,
                    :,
                    k : k + 1,
                    top : top + tile_latent_min_height,
                    left : left + tile_latent_min_width,
                ]
                latent_tile = self.post_quant_conv(latent_tile)
                frame_chunks.append(
                    self.decoder(latent_tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
                )
            decoded_row.append(torch.cat(frame_chunks, dim=2))
        rows.append(decoded_row)
    self.clear_cache()

    # Blend every tile with its upper and left neighbour, then keep only the stride-sized core.
    stitched_rows = []
    for i, decoded_row in enumerate(rows):
        pieces = []
        for j, tile in enumerate(decoded_row):
            if i > 0:
                tile = self.blend_v(rows[i - 1][j], tile, blend_height)
            if j > 0:
                tile = self.blend_h(decoded_row[j - 1], tile, blend_width)
            pieces.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
        stitched_rows.append(torch.cat(pieces, dim=-1))

    dec = torch.cat(stitched_rows, dim=3)[:, :, :, :sample_height, :sample_width]

    if return_dict:
        return DecoderOutput(sample=dec)
    return (dec,)
def forward(
    self,
    sample: torch.Tensor,
    sample_posterior: bool = False,
    return_dict: bool = True,
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    """
    Encode `sample`, pick a latent from the posterior, and decode it again.

    Args:
        sample (`torch.Tensor`): Input sample.
        sample_posterior (`bool`, *optional*, defaults to `False`):
            If `True`, draw the latent with `posterior.sample(generator=generator)`; otherwise use
            `posterior.mode()`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Forwarded to `decode`.
        generator (`torch.Generator`, *optional*):
            Random generator used when `sample_posterior` is `True`.

    Returns:
        Whatever `self.decode` returns for the chosen latent.
    """
    x = sample
    encoded = self.encode(x)
    # `encode` returns the posterior distribution itself, so unconditionally reading
    # `.latent_dist` raised AttributeError. Output-style wrappers that expose `latent_dist`
    # are still accepted.
    posterior = getattr(encoded, "latent_dist", encoded)
    if sample_posterior:
        z = posterior.sample(generator=generator)
    else:
        z = posterior.mode()
    dec = self.decode(z, return_dict=return_dict)
    return dec
def decode(self, z: torch.Tensor) -> torch.Tensor:
    """Decode latents, choosing the tiling strategy from the current configuration.

    Dispatch order:
      1. `parallel_tiled_decode` when tiling and parallel tiling are enabled and the
         sequence-parallel world size is greater than 1;
      2. `tiled_decode` when tiling and temporal tiling are enabled and the latent has more
         frames than one latent temporal tile;
      3. `spatial_tiled_decode` when tiling is enabled and the latent is wider or taller than one
         latent tile;
      4. `_decode` otherwise.

    In every case the result is truncated along dim 2 to
    `(num_frames - 1) * temporal_compression_ratio + 1` frames.

    Args:
        z: Latent tensor whose shape unpacks as `(batch, channels, num_frames, height, width)`.

    Returns:
        The decoded tensor.
    """
    batch_size, num_channels, num_frames, height, width = z.shape
    tile_latent_min_height = (
        self.tile_sample_min_height // self.spatial_compression_ratio
    )
    # The width threshold is the *minimum tile width*, matching the height threshold above and
    # the other tiled methods; the stride is not a size threshold.
    tile_latent_min_width = (
        self.tile_sample_min_width // self.spatial_compression_ratio
    )
    tile_latent_min_num_frames = (
        self.tile_sample_min_num_frames // self.temporal_compression_ratio
    )
    num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1

    if self.use_tiling and self.use_parallel_tiling and get_sp_world_size() > 1:
        return self.parallel_tiled_decode(z)[:, :, :num_sample_frames]
    if (
        self.use_tiling
        and self.use_temporal_tiling
        and num_frames > tile_latent_min_num_frames
    ):
        return self.tiled_decode(z)[:, :, :num_sample_frames]

    if self.use_tiling and (
        width > tile_latent_min_width or height > tile_latent_min_height
    ):
        return self.spatial_tiled_decode(z)[:, :, :num_sample_frames]

    return self._decode(z)[:, :, :num_sample_frames]
spatial_tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + _, _, _, height, width = x.shape + # latent_height = height // self.spatial_compression_ratio + # latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = ( + self.tile_sample_min_height // self.spatial_compression_ratio + ) + tile_latent_min_width = ( + self.tile_sample_min_width // self.spatial_compression_ratio + ) + tile_latent_stride_height = ( + self.tile_sample_stride_height // self.spatial_compression_ratio + ) + tile_latent_stride_width = ( + self.tile_sample_stride_width // self.spatial_compression_ratio + ) + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. 
def parallel_tiled_decode(self, z: torch.FloatTensor) -> torch.FloatTensor:
    """
    Parallel version of tiled_decode that distributes both temporal and spatial tiles across the
    sequence-parallel ranks.

    The (time, height, width) tile grid is flattened and split into contiguous chunks of
    `ceil(total_tiles / world_size)` tiles per rank. Each rank decodes its chunk, the flattened
    results are padded to a common length and exchanged with `all_gather_into_tensor`, the
    per-tile shapes are exchanged with `all_gather_object`, and every rank then rebuilds the
    grid and blends it spatially and temporally.

    A rank whose chunk is empty (for example 5 tiles on 4 ranks) contributes a zero-length
    buffer and still takes part in every collective.

    Args:
        z: Latent tensor whose shape unpacks as `(B, C, T, H, W)`.

    Returns:
        The decoded tensor.
    """
    world_size, rank = get_sp_world_size(), get_sp_parallel_rank()
    _, _, T, H, W = z.shape

    # Tile sizes and strides in latent units.
    tile_latent_min_height = (
        self.tile_sample_min_height // self.spatial_compression_ratio
    )
    tile_latent_min_width = (
        self.tile_sample_min_width // self.spatial_compression_ratio
    )
    tile_latent_min_num_frames = (
        self.tile_sample_min_num_frames // self.temporal_compression_ratio
    )
    tile_latent_stride_height = (
        self.tile_sample_stride_height // self.spatial_compression_ratio
    )
    tile_latent_stride_width = (
        self.tile_sample_stride_width // self.spatial_compression_ratio
    )
    tile_latent_stride_num_frames = (
        self.tile_sample_stride_num_frames // self.temporal_compression_ratio
    )

    # Spatial overlap between decoded tiles, in sample units.
    blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
    blend_width = self.tile_sample_min_width - self.tile_sample_stride_width

    # Tile grid dimensions (ceil division).
    num_t_tiles = (
        T + tile_latent_stride_num_frames - 1
    ) // tile_latent_stride_num_frames
    num_h_tiles = (H + tile_latent_stride_height - 1) // tile_latent_stride_height
    num_w_tiles = (W + tile_latent_stride_width - 1) // tile_latent_stride_width
    total_spatial_tiles = num_h_tiles * num_w_tiles
    total_tiles = num_t_tiles * total_spatial_tiles

    # Contiguous chunk of flattened tile indices owned by this rank (may be empty).
    tiles_per_rank = (total_tiles + world_size - 1) // world_size
    start_tile_idx = rank * tiles_per_rank
    end_tile_idx = min((rank + 1) * tiles_per_rank, total_tiles)

    local_results = []
    local_dim_metadata = []
    for global_idx in range(start_tile_idx, end_tile_idx):
        t_idx = global_idx // total_spatial_tiles
        spatial_idx = global_idx % total_spatial_tiles
        h_idx = spatial_idx // num_w_tiles
        w_idx = spatial_idx % num_w_tiles

        t_start = t_idx * tile_latent_stride_num_frames
        h_start = h_idx * tile_latent_stride_height
        w_start = w_idx * tile_latent_stride_width

        tile = z[
            :,
            :,
            t_start : t_start + tile_latent_min_num_frames + 1,
            h_start : h_start + tile_latent_min_height,
            w_start : w_start + tile_latent_min_width,
        ]
        tile = self._decode(tile)

        # Every temporal tile except the first drops its leading frame.
        if t_start > 0:
            tile = tile[:, :, 1:, :, :]

        # Keep the shape so the flattened data can be restored after the gather.
        local_dim_metadata.append(tile.shape)
        local_results.append(tile.reshape(-1))

    if local_results:
        results = torch.cat(local_results, dim=0).contiguous()
    else:
        # torch.cat rejects an empty list; an idle rank must still join the collectives below,
        # otherwise the remaining ranks would block waiting for it.
        results = z.new_zeros(0)
    del local_results

    # Exchange flat sizes first so every rank can pad to the same length.
    local_size = torch.tensor(
        [results.size(0)], device=results.device, dtype=torch.int64
    )
    all_sizes = [
        torch.zeros(1, device=results.device, dtype=torch.int64)
        for _ in range(world_size)
    ]
    dist.all_gather(all_sizes, local_size)
    max_size = max(size.item() for size in all_sizes)
    # NOTE(review): the buffer uses the default dtype, as before, so gathered tiles come back in
    # that dtype even if `_decode` produced a lower-precision tensor — confirm this is intended.
    padded_results = torch.zeros(max_size, device=results.device)
    padded_results[: results.size(0)] = results
    del results

    # Gather the padded data and the per-tile shapes from all ranks.
    gathered_dim_metadata = [None] * world_size
    gathered_results = padded_results.new_zeros((world_size, max_size))
    # TODO (PY): use sgl_diffusion distributed methods
    dist.all_gather_into_tensor(gathered_results, padded_results)
    dist.all_gather_object(gathered_dim_metadata, local_dim_metadata)

    # Rebuild the [t][h][w] grid of decoded tiles.
    data: list = [
        [[[] for _ in range(num_w_tiles)] for _ in range(num_h_tiles)]
        for _ in range(num_t_tiles)
    ]
    for current_data, global_idx in self._parallel_data_generator(
        gathered_results, gathered_dim_metadata
    ):
        t_idx = global_idx // total_spatial_tiles
        spatial_idx = global_idx % total_spatial_tiles
        h_idx = spatial_idx // num_w_tiles
        w_idx = spatial_idx % num_w_tiles
        data[t_idx][h_idx][w_idx] = current_data

    # Merge spatially per temporal slice, then blend consecutive slices in time.
    result_slices = []
    last_slice_data = None
    for i, tem_data in enumerate(data):
        slice_data = self._merge_spatial_tiles(
            tem_data,
            blend_height,
            blend_width,
            self.tile_sample_stride_height,
            self.tile_sample_stride_width,
        )
        if i > 0:
            slice_data = self.blend_t(
                last_slice_data, slice_data, self.blend_num_frames
            )
            result_slices.append(
                slice_data[:, :, : self.tile_sample_stride_num_frames, :, :]
            )
        else:
            result_slices.append(
                slice_data[:, :, : self.tile_sample_stride_num_frames + 1, :, :]
            )
        last_slice_data = slice_data
    dec = torch.cat(result_slices, dim=2)

    return dec
def spatial_tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
    r"""
    Decode latents in overlapping spatial tiles and merge the decoded tiles.

    Args:
        z (`torch.Tensor`): Input batch of latent vectors.

    Returns:
        `torch.Tensor`:
            The decoded images.
    """
    _, _, _, height, width = z.shape
    ratio = self.spatial_compression_ratio

    # Tile window and step in latent units.
    window_h = self.tile_sample_min_height // ratio
    window_w = self.tile_sample_min_width // ratio
    step_h = self.tile_sample_stride_height // ratio
    step_w = self.tile_sample_stride_width // ratio

    # Overlap between neighbouring decoded tiles, in sample units.
    blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
    blend_width = self.tile_sample_min_width - self.tile_sample_stride_width

    # Decode tiles in row-major order; overlaps are blended by `_merge_spatial_tiles`.
    rows = [
        [
            self._decode(z[:, :, :, top : top + window_h, left : left + window_w])
            for left in range(0, width, step_w)
        ]
        for top in range(0, height, step_h)
    ]

    return self._merge_spatial_tiles(
        rows,
        blend_height,
        blend_width,
        self.tile_sample_stride_height,
        self.tile_sample_stride_width,
    )
self.temporal_compression_ratio + ) + + row = [] + for i in range(0, num_frames, tile_latent_stride_num_frames): + tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :] + if self.use_tiling and ( + tile.shape[-1] > tile_latent_min_width + or tile.shape[-2] > tile_latent_min_height + ): + decoded = self.spatial_tiled_decode(tile) + else: + decoded = self._decode(tile) + if i > 0: + decoded = decoded[:, :, 1:, :, :] + row.append(decoded) + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, self.blend_num_frames) + result_row.append( + tile[:, :, : self.tile_sample_stride_num_frames, :, :] + ) + else: + result_row.append( + tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :] + ) + + dec = torch.cat(result_row, dim=2) + return dec + + def enable_tiling( + self, + tile_sample_min_height: int | None = None, + tile_sample_min_width: int | None = None, + tile_sample_min_num_frames: int | None = None, + tile_sample_stride_height: int | None = None, + tile_sample_stride_width: int | None = None, + tile_sample_stride_num_frames: int | None = None, + blend_num_frames: int | None = None, + use_tiling: bool | None = None, + use_temporal_tiling: bool | None = None, + use_parallel_tiling: bool | None = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. 
+ tile_sample_min_num_frames (`int`, *optional*): + The minimum number of frames required for a sample to be separated into tiles across the frame + dimension. + tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling + artifacts produced across the width dimension. + tile_sample_stride_num_frames (`int`, *optional*): + The stride between two consecutive frame tiles. This is to ensure that there are no tiling artifacts + produced across the frame dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = ( + tile_sample_min_height or self.tile_sample_min_height + ) + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_min_num_frames = ( + tile_sample_min_num_frames or self.tile_sample_min_num_frames + ) + self.tile_sample_stride_height = ( + tile_sample_stride_height or self.tile_sample_stride_height + ) + self.tile_sample_stride_width = ( + tile_sample_stride_width or self.tile_sample_stride_width + ) + self.tile_sample_stride_num_frames = ( + tile_sample_stride_num_frames or self.tile_sample_stride_num_frames + ) + if blend_num_frames is not None: + self.blend_num_frames = blend_num_frames + else: + self.blend_num_frames = ( + self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames + ) + self.use_tiling = use_tiling or self.use_tiling + self.use_temporal_tiling = use_temporal_tiling or self.use_temporal_tiling + self.use_parallel_tiling = use_parallel_tiling or self.use_parallel_tiling + + def disable_tiling(self) -> None: + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. 
+ """ + self.use_tiling = False + + +# adapted from https://github.com/huggingface/diffusers/blob/e7ffeae0a191f710881d1fbde00cd6ff025e81f2/src/diffusers/models/autoencoders/vae.py#L691 +class DiagonalGaussianDistribution: + + def __init__(self, parameters: torch.Tensor, deterministic: bool = False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like( + self.mean, device=self.parameters.device, dtype=self.parameters.dtype + ) + + def sample(self, generator: torch.Generator | None = None) -> torch.Tensor: + # make sure sample is on the same device as the parameters and has same dtype + sample = randn_tensor( + self.mean.shape, + generator=generator, + device=self.parameters.device, + dtype=self.parameters.dtype, + ) + x = self.mean + self.std * sample + return x + + def kl( + self, + other: Optional["DiagonalGaussianDistribution"] = None, + dims: tuple[int, ...] = (1, 2, 3), + ) -> torch.Tensor: + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=dims, + ) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=dims, + ) + + def nll( + self, sample: torch.Tensor, dims: tuple[int, ...] 
= (1, 2, 3) + ) -> torch.Tensor: + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self) -> torch.Tensor: + return self.mean diff --git a/sglang/python/sglang/multimodal_gen/runtime/models/vaes/dac.py b/sglang/python/sglang/multimodal_gen/runtime/models/vaes/dac.py new file mode 100644 index 0000000000000000000000000000000000000000..3d6d821ab8303253624d7beefe9bc438c39346dc --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/models/vaes/dac.py @@ -0,0 +1,627 @@ +# Copied and adapted from: https://github.com/descriptinc/descript-audio-codec + +# SPDX-License-Identifier: MIT + +import math +from bisect import bisect_right +from typing import Union + +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import nn + +from sglang.multimodal_gen.configs.models.vaes.dac import DacVAEConfig +from sglang.multimodal_gen.runtime.models.vaes.common import ( + DiagonalGaussianDistribution, +) + + +# Scripting this brings model speed up 1.4x +@torch.jit.script +def snake(x, alpha): + shape = x.shape + x = x.reshape(shape[0], shape[1], -1) + x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) + x = x.reshape(shape) + return x + + +class Snake1d(nn.Module): + def __init__(self, channels): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1, channels, 1)) + + def forward(self, x): + return snake(x, self.alpha) + + +class VectorQuantize(nn.Module): + """ + Implementation of VQ similar to Karpathy's repo: + https://github.com/karpathy/deep-vector-quantization + Additionally uses following tricks from Improved VQGAN + (https://arxiv.org/pdf/2110.04627.pdf): + 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space + for improved codebook usage + 2. 
l2-normalized codes: Converts euclidean distance to cosine similarity which + improves training stability + """ + + def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int): + super().__init__() + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + + self.in_proj = nn.Conv1d(input_dim, codebook_dim, kernel_size=1) + self.out_proj = nn.Conv1d(codebook_dim, input_dim, kernel_size=1) + self.codebook = nn.Embedding(codebook_size, codebook_dim) + + def forward(self, z): + """Quantize the input tensor using a fixed codebook and return the corresponding codebook vectors. + + Args: + z (torch.Tensor): Input tensor with shape ``[B, D, T]``. + + Returns: + tuple: A tuple containing: + - z_q (torch.Tensor): Quantized continuous representation with shape ``[B, D, T]``. + - commitment_loss (torch.Tensor): Commitment loss scalar to train encoder to predict + vectors closer to codebook entries. + - codebook_loss (torch.Tensor): Codebook loss scalar to update the codebook. + - indices (torch.Tensor): Codebook indices (quantized discrete representation) with shape ``[B, T]``. + - z_e (torch.Tensor): Projected latents (continuous representation before quantization) with shape ``[B, D, T]``. 
+ """ + + # Factorized codes (ViT-VQGAN) Project input into low-dimensional space + z_e = self.in_proj(z) # z_e : (B x D x T) + z_q, indices = self.decode_latents(z_e) + + commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) + codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) + + z_q = ( + z_e + (z_q - z_e).detach() + ) # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = self.out_proj(z_q) + + return z_q, commitment_loss, codebook_loss, indices, z_e + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def decode_latents(self, latents): + encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight # codebook: (N x D) + + # L2 normalize encodings and codebook (ViT-VQGAN) + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance with codebook + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + z_q = self.decode_code(indices) + return z_q, indices + + +class ResidualVectorQuantize(nn.Module): + """ + Introduced in SoundStream: An end2end neural audio codec + https://arxiv.org/abs/2107.03312 + """ + + def __init__( + self, + input_dim: int = 512, + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: float = 0.0, + ): + super().__init__() + if isinstance(codebook_dim, int): + codebook_dim = [codebook_dim for _ in range(n_codebooks)] + + self.n_codebooks = n_codebooks + self.codebook_dim = codebook_dim + self.codebook_size = codebook_size + dim_offsets = [0] + for dim in self.codebook_dim: + dim_offsets.append(dim_offsets[-1] + dim) + self._codebook_dim_offsets = 
tuple(dim_offsets) + + self.quantizers = nn.ModuleList( + [ + VectorQuantize(input_dim, codebook_size, codebook_dim[i]) + for i in range(n_codebooks) + ] + ) + self.quantizer_dropout = quantizer_dropout + + def forward(self, z, n_quantizers: int = None): + """Quantize the input tensor using a fixed set of codebooks and return the corresponding codebook vectors. + + Args: + z (torch.Tensor): Input tensor with shape ``[B, D, T]``. + n_quantizers (int, optional): Number of quantizers to use. If ``None``, + all quantizers are used. When ``n_quantizers`` < ``self.n_codebooks``, + quantizer dropout is applied. Note: if ``self.quantizer_dropout`` > 0 + and in training mode, this argument is ignored and a random number of + quantizers is used. + + Returns: + tuple: A tuple containing: + - z_q (torch.Tensor): Quantized continuous representation with shape ``[B, D, T]``. + - codes (torch.Tensor): Codebook indices for each codebook with shape ``[B, N, T]`` + (quantized discrete representation of input). + - latents (torch.Tensor): Projected latents with shape ``[B, N*D, T]`` + (continuous representation before quantization). + - commitment_loss (torch.Tensor): Commitment loss scalar to train encoder to predict + vectors closer to codebook entries. + - codebook_loss (torch.Tensor): Codebook loss scalar to update the codebook. 
+ """ + z_q = 0 + residual = z + commitment_loss = 0 + codebook_loss = 0 + + codebook_indices = [] + latents = [] + + if n_quantizers is None: + n_quantizers = self.n_codebooks + quantizers = self.quantizers + if self.training: + batch_size = z.shape[0] + device = z.device + n_quantizers = torch.full( + (batch_size,), + self.n_codebooks + 1, + device=device, + dtype=torch.long, + ) + if self.quantizer_dropout > 0: + dropout = torch.randint( + 1, + self.n_codebooks + 1, + (batch_size,), + device=device, + ) + n_dropout = int(batch_size * self.quantizer_dropout) + if n_dropout > 0: + n_quantizers[:n_dropout] = dropout[:n_dropout] + + for i, quantizer in enumerate(quantizers): + z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( + residual + ) + + # Create mask to apply quantizer dropout + mask = i < n_quantizers + z_q = z_q + z_q_i * mask[:, None, None] + residual = residual - z_q_i + + # Sum losses + commitment_loss += (commitment_loss_i * mask).mean() + codebook_loss += (codebook_loss_i * mask).mean() + + codebook_indices.append(indices_i) + latents.append(z_e_i) + else: + for i, quantizer in enumerate(quantizers): + if i >= n_quantizers: + break + z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( + residual + ) + z_q = z_q + z_q_i + residual = residual - z_q_i + + commitment_loss += commitment_loss_i.mean() + codebook_loss += codebook_loss_i.mean() + + codebook_indices.append(indices_i) + latents.append(z_e_i) + + codes = torch.stack(codebook_indices, dim=1) + latents = torch.cat(latents, dim=1) + + return z_q, codes, latents, commitment_loss, codebook_loss + + def from_codes(self, codes: torch.Tensor): + """Reconstruct the continuous representation from quantized codes. + + Args: + codes (torch.Tensor): Quantized discrete representation with shape ``[B, N, T]``. + + Returns: + tuple: A tuple containing: + - z_q (torch.Tensor): Quantized continuous representation with shape ``[B, D, T]``. 
class ResidualUnit(nn.Module):
    """Residual branch: Snake1d -> dilated k=7 Conv1d -> Snake1d -> k=1 Conv1d.

    The branch output is added to the input. If the branch output is shorter
    than the input along the last axis, the input is trimmed symmetrically
    before the addition.
    """

    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()
        # Padding of (k - 1) * dilation / 2 for the k=7 convolution.
        dilated_padding = ((7 - 1) * dilation) // 2
        layers = [
            Snake1d(dim),
            nn.Conv1d(
                dim, dim, kernel_size=7, dilation=dilation, padding=dilated_padding
            ),
            Snake1d(dim),
            nn.Conv1d(dim, dim, kernel_size=1),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        branch = self.block(x)
        # Half of the length difference between input and branch output.
        trim = (x.shape[-1] - branch.shape[-1]) // 2
        skip = x[..., trim:-trim] if trim > 0 else x
        return skip + branch
class Encoder(nn.Module):
    """Convolutional encoder: stem conv, one EncoderBlock per stride, output conv.

    The channel width doubles before every ``EncoderBlock``; the final width
    is stored in ``enc_dim`` and projected to ``d_latent`` channels.
    """

    def __init__(
        self,
        d_model: int = 64,
        strides: list = [2, 4, 8, 8],
        d_latent: int = 64,
    ):
        super().__init__()
        width = d_model

        # Stem convolution from a single input channel.
        stages = [nn.Conv1d(1, width, kernel_size=7, padding=3)]

        # One EncoderBlock per stride; channels double at each stage.
        for stride in strides:
            width = width * 2
            stages.append(EncoderBlock(width, stride=stride))

        # Output activation and projection to the latent width.
        stages.append(Snake1d(width))
        stages.append(nn.Conv1d(width, d_latent, kernel_size=3, padding=1))

        # Wrap the stages into nn.Sequential.
        self.block = nn.Sequential(*stages)
        self.enc_dim = width

    def forward(self, x):
        return self.block(x)
output_dim, stride)] + + # Add final conv layer + layers += [ + Snake1d(output_dim), + nn.Conv1d(output_dim, d_out, kernel_size=7, padding=3), + nn.Tanh(), + ] + + self.model = nn.Sequential(*layers) + + def forward(self, x): + return self.model(x) + + +class DAC(nn.Module): + def __init__( + self, + config: DacVAEConfig, + ): + super().__init__() + + self.continuous = config.continuous + self.decoder_dim = config.decoder_dim + self.decoder_rates = config.decoder_rates + self.encoder_dim = config.encoder_dim + self.encoder_rates = config.encoder_rates + self.hop_length = math.prod(config.encoder_rates) + self.sample_rate = config.sample_rate + + if config.latent_dim is None: + latent_dim = config.encoder_dim * (2 ** len(config.encoder_rates)) + else: + latent_dim = config.latent_dim + + self.latent_dim = latent_dim + + if config.load_encoder: + self.encoder = Encoder(config.encoder_dim, config.encoder_rates, latent_dim) + + if not config.continuous: + self.n_codebooks = config.n_codebooks + self.codebook_size = config.codebook_size + self.codebook_dim = config.codebook_dim + self.quantizer = ResidualVectorQuantize( + input_dim=latent_dim, + n_codebooks=config.n_codebooks, + codebook_size=config.codebook_size, + codebook_dim=config.codebook_dim, + quantizer_dropout=config.quantizer_dropout, + ) + else: + self.quant_conv = torch.nn.Conv1d(latent_dim, 2 * latent_dim, 1) + self.post_quant_conv = torch.nn.Conv1d(latent_dim, latent_dim, 1) + + if config.load_decoder: + self.decoder = Decoder( + latent_dim, + config.decoder_dim, + config.decoder_rates, + ) + + self.apply(self.init_weights) + + @staticmethod + def init_weights(m): + if isinstance(m, nn.Conv1d): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + @property + def dtype(self): + return next(self.parameters()).dtype + + @property + def device(self): + return next(self.parameters()).device + + def preprocess(self, audio_data, sample_rate): + if sample_rate is None: + sample_rate = 
self.sample_rate + assert sample_rate == self.sample_rate + + length = audio_data.shape[-1] + right_pad = math.ceil(length / self.hop_length) * self.hop_length - length + audio_data = nn.functional.pad(audio_data, (0, right_pad)) + + return audio_data + + def encode( + self, + audio_data: torch.Tensor, + n_quantizers: int = None, + ): + """Encode audio data into latent representations. + + This method processes audio through the encoder network and optionally applies + vector quantization (in VQ mode) or projects to a Gaussian distribution (in + continuous mode) to produce latent representations. + + Args: + audio_data (torch.Tensor): Audio data to encode, with shape ``[B, 1, T]``. + n_quantizers (int, optional): Number of quantizers to use. If ``None``, + all quantizers are used. Only applicable in VQ mode (``continuous=False``). + + Returns: + tuple: A tuple containing: + - z (torch.Tensor): Encoded representation. In VQ mode, this is the + quantized continuous representation with shape ``[B, D, T]``. In + continuous mode, this is a ``DiagonalGaussianDistribution`` object. + - codes (torch.Tensor or None): Codebook indices with shape ``[B, N, T]`` + in VQ mode, ``None`` in continuous mode. + - latents (torch.Tensor or None): Projected latents with shape ``[B, N*D, T]`` + in VQ mode, ``None`` in continuous mode. + - commitment_loss (torch.Tensor): Commitment loss scalar. + - codebook_loss (torch.Tensor): Codebook loss scalar. + + Note: + In continuous mode, the encoded representation is projected through a + quantization convolution layer and wrapped in a ``DiagonalGaussianDistribution`` + for VAE training. 
def decode(self, z: torch.Tensor):
    """Turn a latent tensor back into audio with ``self.decoder``.

    Args:
        z (torch.Tensor): Latent representation with shape ``[B, D, T]``. In
            VQ mode (``continuous=False``) this is the quantized continuous
            representation; in continuous mode it is a sample from the
            posterior.

    Returns:
        torch.Tensor: Decoded audio with shape ``[B, 1, T']``; ``T'`` depends
        on the decoder's upsampling rates and may differ from ``T``.

    Note:
        In continuous mode the latent first goes through
        ``self.post_quant_conv``; in VQ mode it is passed to the decoder as is.
    """
    decoder_input = self.post_quant_conv(z) if self.continuous else z
    return self.decoder(decoder_input)
+ + Returns: + dict: A dictionary containing different keys depending on the mode: + + **VQ Mode (``continuous=False``):** + - "audio" (torch.Tensor): Decoded audio, shape [B, 1, length]. + - "z" (torch.Tensor): Quantized continuous representation, shape [B, D, T]. + - "codes" (torch.Tensor): Codebook indices, shape [B, N, T]. + - "latents" (torch.Tensor): Projected latents, shape [B, N*D, T]. + - "vq/commitment_loss" (torch.Tensor): Commitment loss. + - "vq/codebook_loss" (torch.Tensor): Codebook loss. + + **Continuous Mode (``continuous=True``):** + - "audio" (torch.Tensor): Decoded audio, shape [B, 1, length]. + - "z" (torch.Tensor): Latent representation, shape [B, D, T]. + - "kl_loss" (torch.Tensor): KL divergence loss (for VAE training). + """ + length = audio_data.shape[-1] + audio_data = self.preprocess(audio_data, sample_rate) + if not self.continuous: + z, codes, latents, commitment_loss, codebook_loss = self.encode( + audio_data, n_quantizers + ) + + x = self.decode(z) + return { + "audio": x[..., :length], + "z": z, + "codes": codes, + "latents": latents, + "vq/commitment_loss": commitment_loss, + "vq/codebook_loss": codebook_loss, + } + else: + posterior, _, _, _, _ = self.encode(audio_data, n_quantizers) + z = posterior.sample() + x = self.decode(z) + + kl_loss = posterior.kl(dims=(1, 2)) + kl_loss = kl_loss.mean() + + return { + "audio": x[..., :length], + "z": z, + "kl_loss": kl_loss, + } + + +EntryClass = DAC diff --git a/sglang/python/sglang/multimodal_gen/runtime/models/vaes/hunyuan3d_vae.py b/sglang/python/sglang/multimodal_gen/runtime/models/vaes/hunyuan3d_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..c2792ff18018017ddb41dbc711c8400e87e24fba --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/models/vaes/hunyuan3d_vae.py @@ -0,0 +1,1224 @@ +# Copied and adapted from: https://github.com/Tencent-Hunyuan/Hunyuan3D-2 + + +from __future__ import annotations + +from typing import Callable, List, Optional, 
Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from tqdm import tqdm + +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + +# Attention backend selection +scaled_dot_product_attention = F.scaled_dot_product_attention + + +class CrossAttentionProcessor: + def __call__(self, attn, q, k, v): + out = scaled_dot_product_attention(q, k, v) + return out + + +class FlashVDMCrossAttentionProcessor: + def __init__(self, topk=None): + self.topk = topk + + def __call__(self, attn, q, k, v): + if k.shape[-2] == 3072: + topk = 1024 + elif k.shape[-2] == 512: + topk = 256 + else: + topk = k.shape[-2] // 3 + + if self.topk is True: + q1 = q[:, :, ::100, :] + sim = q1 @ k.transpose(-1, -2) + sim = torch.mean(sim, -2) + topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1) + topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1]) + v0 = torch.gather(v, dim=-2, index=topk_ind) + k0 = torch.gather(k, dim=-2, index=topk_ind) + out = scaled_dot_product_attention(q, k0, v0) + elif self.topk is False: + out = scaled_dot_product_attention(q, k, v) + else: + idx, counts = self.topk + start = 0 + outs = [] + for grid_coord, count in zip(idx, counts): + end = start + count + q_chunk = q[:, :, start:end, :] + k0, v0 = self.select_topkv(q_chunk, k, v, topk) + out = scaled_dot_product_attention(q_chunk, k0, v0) + outs.append(out) + start += count + out = torch.cat(outs, dim=-2) + self.topk = False + return out + + def select_topkv(self, q_chunk, k, v, topk): + q1 = q_chunk[:, :, ::50, :] + sim = q1 @ k.transpose(-1, -2) + sim = torch.mean(sim, -2) + topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1) + topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1]) + v0 = torch.gather(v, dim=-2, index=topk_ind) + k0 = torch.gather(k, dim=-2, index=topk_ind) + return k0, v0 + + +class 
class FourierEmbedder(nn.Module):
    """Sin/cos embedding of the last axis of a tensor.

    Each input channel is multiplied by ``num_freqs`` frequencies and mapped
    through ``sin`` and ``cos``. The frequencies are powers of two
    (``logspace=True``) or linearly spaced between ``1`` and
    ``2 ** (num_freqs - 1)``, optionally scaled by pi. They are stored as a
    non-persistent buffer. ``out_dim`` is the embedding width for
    ``input_dim`` channels.
    """

    def __init__(
        self,
        num_freqs: int = 6,
        logspace: bool = True,
        input_dim: int = 3,
        include_input: bool = True,
        include_pi: bool = True,
    ) -> None:
        """Build the frequency buffer and record the output width."""

        super().__init__()

        if logspace:
            freq_bands = 2.0 ** torch.arange(num_freqs, dtype=torch.float32)
        else:
            highest = 2.0 ** (num_freqs - 1)
            freq_bands = torch.linspace(
                1.0, highest, num_freqs, dtype=torch.float32
            )

        if include_pi:
            freq_bands *= torch.pi

        self.register_buffer("frequencies", freq_bands, persistent=False)
        self.include_input = include_input
        self.num_freqs = num_freqs

        self.out_dim = self.get_dims(input_dim)

    def get_dims(self, input_dim):
        """Return the embedding width produced for ``input_dim`` channels."""
        # The raw input is kept when requested, and always when there are no
        # frequencies (forward then returns the input unchanged).
        keeps_raw = self.include_input or self.num_freqs == 0
        per_channel = self.num_freqs * 2 + (1 if keeps_raw else 0)
        return input_dim * per_channel

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x``; returns ``x`` unchanged when ``num_freqs`` is not positive."""

        if not self.num_freqs > 0:
            return x

        scaled = x[..., None].contiguous() * self.frequencies
        embed = scaled.view(*x.shape[:-1], -1)
        parts = [embed.sin(), embed.cos()]
        if self.include_input:
            parts.insert(0, x)
        return torch.cat(tuple(parts), dim=-1)
class MLP(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Linear -> drop-path.

    The hidden width is ``width * expand_ratio``. The output width is
    ``output_width`` when given, otherwise ``width``. Drop-path is an
    ``nn.Identity`` unless ``drop_path_rate`` is positive.
    """

    def __init__(
        self,
        *,
        width: int,
        expand_ratio: int = 4,
        output_width: int = None,
        drop_path_rate: float = 0.0,
    ):
        super().__init__()
        hidden_width = width * expand_ratio
        out_width = width if output_width is None else output_width

        self.width = width
        self.c_fc = nn.Linear(width, hidden_width)
        self.c_proj = nn.Linear(hidden_width, out_width)
        self.gelu = nn.GELU()
        if drop_path_rate > 0.0:
            self.drop_path = DropPath(drop_path_rate)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        hidden = self.gelu(self.c_fc(x))
        projected = self.c_proj(hidden)
        return self.drop_path(projected)
class MultiheadCrossAttention(nn.Module):
    """Cross-attention layer: project queries and key/value data, attend, project out.

    ``c_q`` maps ``x`` to ``width`` channels and ``c_kv`` maps ``data`` from
    ``data_width`` (defaults to ``width``) to ``2 * width`` channels. With
    ``kv_cache=True`` the first projected ``data`` is stored in ``self.data``
    and reused by every later call; the ``data`` argument is then ignored.
    """

    def __init__(
        self,
        *,
        width: int,
        heads: int,
        qkv_bias: bool = True,
        n_data: Optional[int] = None,
        data_width: Optional[int] = None,
        norm_layer=nn.LayerNorm,
        qk_norm: bool = False,
        kv_cache: bool = False,
    ):
        super().__init__()
        self.n_data = n_data
        self.width = width
        self.heads = heads
        if data_width is None:
            self.data_width = width
        else:
            self.data_width = data_width
        self.c_q = nn.Linear(width, width, bias=qkv_bias)
        self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias)
        self.c_proj = nn.Linear(width, width)
        self.attention = QKVMultiheadCrossAttention(
            heads=heads,
            n_data=n_data,
            width=width,
            norm_layer=norm_layer,
            qk_norm=qk_norm,
        )
        self.kv_cache = kv_cache
        # Cached key/value projection; filled on first use when kv_cache is on.
        self.data = None

    def forward(self, x, data):
        queries = self.c_q(x)
        if not self.kv_cache:
            projected_kv = self.c_kv(data)
        else:
            if self.data is None:
                self.data = self.c_kv(data)
                logger.info(
                    "Save kv cache,this should be called only once for one mesh"
                )
            projected_kv = self.data
        attended = self.attention(queries, projected_kv)
        return self.c_proj(attended)
class MultiheadAttention(nn.Module):
    """Self-attention layer: fused QKV projection, attention, output projection.

    ``c_qkv`` maps ``width`` to ``3 * width`` channels for
    ``QKVMultiheadAttention``. The projected output passes through drop-path,
    which is an ``nn.Identity`` unless ``drop_path_rate`` is positive.
    """

    def __init__(
        self,
        *,
        n_ctx: int,
        width: int,
        heads: int,
        qkv_bias: bool,
        norm_layer=nn.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0,
    ):
        super().__init__()
        self.n_ctx = n_ctx
        self.width = width
        self.heads = heads
        self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias)
        self.c_proj = nn.Linear(width, width)
        self.attention = QKVMultiheadAttention(
            heads=heads,
            n_ctx=n_ctx,
            width=width,
            norm_layer=norm_layer,
            qk_norm=qk_norm,
        )
        if drop_path_rate > 0.0:
            self.drop_path = DropPath(drop_path_rate)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        fused_qkv = self.c_qkv(x)
        attended = self.attention(fused_qkv)
        projected = self.c_proj(attended)
        return self.drop_path(projected)
class Transformer(nn.Module):
    """Stack of ``layers`` identically configured ``ResidualAttentionBlock`` modules.

    The blocks are stored in ``resblocks`` and applied one after another.
    """

    def __init__(
        self,
        *,
        n_ctx: int,
        width: int,
        layers: int,
        heads: int,
        qkv_bias: bool = True,
        norm_layer=nn.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0,
    ):
        super().__init__()
        self.n_ctx = n_ctx
        self.width = width
        self.layers = layers

        # Every block receives the same configuration.
        block_kwargs = dict(
            n_ctx=n_ctx,
            width=width,
            heads=heads,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            qk_norm=qk_norm,
            drop_path_rate=drop_path_rate,
        )
        blocks = []
        for _ in range(layers):
            blocks.append(ResidualAttentionBlock(**block_kwargs))
        self.resblocks = nn.ModuleList(blocks)

    def forward(self, x: torch.Tensor):
        hidden = x
        for layer in self.resblocks:
            hidden = layer(hidden)
        return hidden
def generate_dense_grid_points(
    bbox_min: np.ndarray,
    bbox_max: np.ndarray,
    octree_resolution: int,
    indexing: str = "ij",
):
    """Build a dense 3-D grid of float32 points spanning a bounding box.

    Each axis is sampled with ``int(octree_resolution) + 1`` evenly spaced
    points between the matching entries of ``bbox_min`` and ``bbox_max``,
    both ends included.

    Returns:
        tuple: ``(xyz, grid_size, length)`` where ``xyz`` stacks the meshgrid
        coordinates on a trailing axis of size 3, ``grid_size`` is the list of
        points per axis, and ``length`` is ``bbox_max - bbox_min``.
    """
    length = bbox_max - bbox_min
    points_per_axis = int(octree_resolution) + 1

    axes = [
        np.linspace(bbox_min[axis], bbox_max[axis], points_per_axis, dtype=np.float32)
        for axis in range(3)
    ]
    xs, ys, zs = np.meshgrid(axes[0], axes[1], axes[2], indexing=indexing)
    xyz = np.stack((xs, ys, zs), axis=-1)

    return xyz, [points_per_axis] * 3, length
class VanillaVolumeDecoder:
    """Volume decoder that queries ``geo_decoder`` on every point of a dense grid."""

    @torch.no_grad()
    def __call__(
        self,
        latents: torch.FloatTensor,
        geo_decoder: Callable,
        bounds: Union[Tuple[float], List[float], float] = 1.01,
        num_chunks: int = 10000,
        octree_resolution: int = None,
        enable_pbar: bool = True,
        **kwargs,
    ):
        """Decode ``latents`` into a float grid of logits.

        A float ``bounds`` is expanded to a symmetric box
        ``[-b, -b, -b, b, b, b]``; otherwise the first three entries are the
        minimum corner and the next three the maximum corner. The grid points
        are sent to ``geo_decoder`` in slices of ``num_chunks`` points,
        repeated over the batch. The concatenated result is viewed as
        ``(batch, *grid_size)`` and cast to float.
        """
        target_device = latents.device
        target_dtype = latents.dtype
        n_batch = latents.shape[0]

        if isinstance(bounds, float):
            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]

        corner_min, corner_max = np.array(bounds[0:3]), np.array(bounds[3:6])
        # The third return value (box side lengths) is not needed here.
        grid_points, grid_size, _ = generate_dense_grid_points(
            bbox_min=corner_min,
            bbox_max=corner_max,
            octree_resolution=octree_resolution,
            indexing="ij",
        )
        flat_points = torch.from_numpy(grid_points)
        flat_points = flat_points.to(target_device, dtype=target_dtype)
        flat_points = flat_points.contiguous().reshape(-1, 3)

        chunk_starts = range(0, flat_points.shape[0], num_chunks)
        progress = tqdm(
            chunk_starts,
            desc="Volume Decoding",
            disable=not enable_pbar,
        )
        per_chunk_logits = []
        for offset in progress:
            points = flat_points[offset : offset + num_chunks, :]
            batched_points = repeat(points, "p c -> b p c", b=n_batch)
            per_chunk_logits.append(
                geo_decoder(queries=batched_points, latents=latents)
            )

        grid_logits = torch.cat(per_chunk_logits, dim=1)
        return grid_logits.view((n_batch, *grid_size)).float()
generate query points + if isinstance(bounds, float): + bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds] + bbox_min = np.array(bounds[0:3]) + bbox_max = np.array(bounds[3:6]) + bbox_size = bbox_max - bbox_min + + xyz_samples, grid_size, length = generate_dense_grid_points( + bbox_min=bbox_min, + bbox_max=bbox_max, + octree_resolution=resolutions[0], + indexing="ij", + ) + + dilate = nn.Conv3d(1, 1, 3, padding=1, bias=False, device=device, dtype=dtype) + dilate.weight = torch.nn.Parameter( + torch.ones(dilate.weight.shape, dtype=dtype, device=device) + ) + + grid_size = np.array(grid_size) + xyz_samples = ( + torch.from_numpy(xyz_samples) + .to(device, dtype=dtype) + .contiguous() + .reshape(-1, 3) + ) + + # 2. latents to 3d volume + batch_logits = [] + batch_size = latents.shape[0] + for start in tqdm( + range(0, xyz_samples.shape[0], num_chunks), + desc=f"Hierarchical Volume Decoding [r{resolutions[0] + 1}]", + disable=not enable_pbar, + ): + queries = xyz_samples[start : start + num_chunks, :] + batch_queries = repeat(queries, "p c -> b p c", b=batch_size) + logits = geo_decoder(queries=batch_queries, latents=latents) + batch_logits.append(logits) + + grid_logits = torch.cat(batch_logits, dim=1).view( + (batch_size, grid_size[0], grid_size[1], grid_size[2]) + ) + + for octree_depth_now in resolutions[1:]: + grid_size = np.array([octree_depth_now + 1] * 3) + resolution = bbox_size / octree_depth_now + next_index = torch.zeros(tuple(grid_size), dtype=dtype, device=device) + next_logits = torch.full( + next_index.shape, -10000.0, dtype=dtype, device=device + ) + curr_points = extract_near_surface_volume_fn( + grid_logits.squeeze(0), mc_level + ) + curr_points += grid_logits.squeeze(0).abs() < 0.95 + + if octree_depth_now == resolutions[-1]: + expand_num = 0 + else: + expand_num = 1 + for i in range(expand_num): + curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0) + cidx_x, cidx_y, cidx_z = torch.where(curr_points > 0) + next_index[cidx_x 
* 2, cidx_y * 2, cidx_z * 2] = 1 + for i in range(2 - expand_num): + next_index = dilate(next_index.unsqueeze(0)).squeeze(0) + nidx = torch.where(next_index > 0) + + next_points = torch.stack(nidx, dim=1) + next_points = next_points * torch.tensor( + resolution, dtype=next_points.dtype, device=device + ) + torch.tensor(bbox_min, dtype=next_points.dtype, device=device) + + # Check if next_points is empty + if next_points.shape[0] == 0: + logger.warning( + f"No valid surface points found at resolution {octree_depth_now}, " + f"skipping this level" + ) + continue + + batch_logits = [] + for start in tqdm( + range(0, next_points.shape[0], num_chunks), + desc=f"Hierarchical Volume Decoding [r{octree_depth_now + 1}]", + disable=not enable_pbar, + ): + queries = next_points[start : start + num_chunks, :] + batch_queries = repeat(queries, "p c -> b p c", b=batch_size) + logits = geo_decoder( + queries=batch_queries.to(latents.dtype), latents=latents + ) + batch_logits.append(logits) + grid_logits = torch.cat(batch_logits, dim=1) + next_logits[nidx] = grid_logits[0, ..., 0] + grid_logits = next_logits.unsqueeze(0) + grid_logits[grid_logits == -10000.0] = float("nan") + + return grid_logits + + +class FlashVDMVolumeDecoding: + """Flash VDM volume decoder with adaptive KV selection.""" + + def __init__(self, topk_mode="mean"): + if topk_mode not in ["mean", "merge"]: + raise ValueError(f"Unsupported topk_mode {topk_mode}") + + if topk_mode == "mean": + self.processor = FlashVDMCrossAttentionProcessor() + else: + self.processor = FlashVDMTopMCrossAttentionProcessor() + + @torch.no_grad() + def __call__( + self, + latents: torch.FloatTensor, + geo_decoder: CrossAttentionDecoder, + bounds: Union[Tuple[float], List[float], float] = 1.01, + num_chunks: int = 10000, + mc_level: float = 0.0, + octree_resolution: int = None, + min_resolution: int = 63, + mini_grid_num: int = 4, + enable_pbar: bool = True, + **kwargs, + ): + processor = self.processor + 
geo_decoder.set_cross_attention_processor(processor) + + device = latents.device + dtype = latents.dtype + + resolutions = [] + orig_resolution = octree_resolution + if octree_resolution < min_resolution: + resolutions.append(octree_resolution) + while octree_resolution >= min_resolution: + resolutions.append(octree_resolution) + octree_resolution = octree_resolution // 2 + resolutions.reverse() + resolutions[0] = round(resolutions[0] / mini_grid_num) * mini_grid_num - 1 + for i, resolution in enumerate(resolutions[1:]): + resolutions[i + 1] = resolutions[0] * 2 ** (i + 1) + + if isinstance(bounds, float): + bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds] + bbox_min = np.array(bounds[0:3]) + bbox_max = np.array(bounds[3:6]) + bbox_size = bbox_max - bbox_min + + xyz_samples, grid_size, length = generate_dense_grid_points( + bbox_min=bbox_min, + bbox_max=bbox_max, + octree_resolution=resolutions[0], + indexing="ij", + ) + + logger.info(f"FlashVDMVolumeDecoding Resolution: {resolutions}") + + dilate = nn.Conv3d(1, 1, 3, padding=1, bias=False, device=device, dtype=dtype) + dilate.weight = torch.nn.Parameter( + torch.ones(dilate.weight.shape, dtype=dtype, device=device) + ) + + grid_size = np.array(grid_size) + + xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype) + batch_size = latents.shape[0] + mini_grid_size = xyz_samples.shape[0] // mini_grid_num + xyz_samples = ( + xyz_samples.view( + mini_grid_num, + mini_grid_size, + mini_grid_num, + mini_grid_size, + mini_grid_num, + mini_grid_size, + 3, + ) + .permute(0, 2, 4, 1, 3, 5, 6) + .reshape(-1, mini_grid_size * mini_grid_size * mini_grid_size, 3) + ) + + batch_logits = [] + num_batchs = max(num_chunks // xyz_samples.shape[1], 1) + for start in tqdm( + range(0, xyz_samples.shape[0], num_batchs), + desc="FlashVDM Volume Decoding", + disable=not enable_pbar, + ): + queries = xyz_samples[start : start + num_batchs, :] + batch = queries.shape[0] + batch_latents = repeat(latents.squeeze(0), "p c 
-> b p c", b=batch) + processor.topk = True + logits = geo_decoder(queries=queries, latents=batch_latents) + batch_logits.append(logits) + + grid_logits = ( + torch.cat(batch_logits, dim=0) + .reshape( + mini_grid_num, + mini_grid_num, + mini_grid_num, + mini_grid_size, + mini_grid_size, + mini_grid_size, + ) + .permute(0, 3, 1, 4, 2, 5) + .contiguous() + .view((batch_size, grid_size[0], grid_size[1], grid_size[2])) + ) + + for octree_depth_now in resolutions[1:]: + grid_size = np.array([octree_depth_now + 1] * 3) + resolution = bbox_size / octree_depth_now + next_index = torch.zeros(tuple(grid_size), dtype=dtype, device=device) + next_logits = torch.full( + next_index.shape, -10000.0, dtype=dtype, device=device + ) + curr_points = extract_near_surface_volume_fn( + grid_logits.squeeze(0), mc_level + ) + curr_points += grid_logits.squeeze(0).abs() < 0.95 + + expand_num = 0 if octree_depth_now == resolutions[-1] else 1 + for _ in range(expand_num): + curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0) + + cidx_x, cidx_y, cidx_z = torch.where(curr_points > 0) + next_index[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1 + for _ in range(2 - expand_num): + next_index = dilate(next_index.unsqueeze(0)).squeeze(0) + nidx = torch.where(next_index > 0) + + next_points = torch.stack(nidx, dim=1) + next_points = next_points * torch.tensor( + resolution, dtype=torch.float32, device=device + ) + torch.tensor(bbox_min, dtype=torch.float32, device=device) + + # Check if next_points is empty (no valid surface points found) + if next_points.shape[0] == 0: + # Skip this resolution level if no points found + # Use the previous grid_logits as fallback + logger.warning( + f"No valid surface points found at resolution {octree_depth_now}, " + f"skipping this level and using previous resolution grid_logits" + ) + continue + + query_grid_num = 6 + min_val = next_points.min(axis=0).values + max_val = next_points.max(axis=0).values + vol_queries_index = ( + (next_points - min_val) / 
(max_val - min_val) * (query_grid_num - 0.001) + ) + index = torch.floor(vol_queries_index).long() + index = ( + index[..., 0] * (query_grid_num**2) + + index[..., 1] * query_grid_num + + index[..., 2] + ) + index = index.sort() + next_points = next_points[index.indices].unsqueeze(0).contiguous() + unique_values = torch.unique(index.values, return_counts=True) + grid_logits_flat = torch.zeros( + (next_points.shape[1]), dtype=latents.dtype, device=latents.device + ) + input_grid = [[], []] + logits_grid_list = [] + start_num = 0 + sum_num = 0 + for grid_index, count in zip( + unique_values[0].cpu().tolist(), unique_values[1].cpu().tolist() + ): + if sum_num + count < num_chunks or sum_num == 0: + sum_num += count + input_grid[0].append(grid_index) + input_grid[1].append(count) + else: + processor.topk = input_grid + logits_grid = geo_decoder( + queries=next_points[:, start_num : start_num + sum_num], + latents=latents, + ) + start_num = start_num + sum_num + logits_grid_list.append(logits_grid) + input_grid = [[grid_index], [count]] + sum_num = count + if sum_num > 0: + processor.topk = input_grid + logits_grid = geo_decoder( + queries=next_points[:, start_num : start_num + sum_num], + latents=latents, + ) + logits_grid_list.append(logits_grid) + logits_grid = torch.cat(logits_grid_list, dim=1) + grid_logits_flat[index.indices] = logits_grid.squeeze(0).squeeze(-1) + next_logits[nidx] = grid_logits_flat + grid_logits = next_logits.unsqueeze(0) + + grid_logits[grid_logits == -10000.0] = float("nan") + return grid_logits + + +class Latent2MeshOutput: + """Container for mesh output from VAE decoder.""" + + def __init__(self, mesh_v=None, mesh_f=None): + self.mesh_v = mesh_v + self.mesh_f = mesh_f + + +def center_vertices(vertices): + """Translate vertices so bounding box is centered at zero.""" + vert_min = vertices.min(dim=0)[0] + vert_max = vertices.max(dim=0)[0] + vert_center = 0.5 * (vert_min + vert_max) + return vertices - vert_center + + +class SurfaceExtractor: + 
"""Base class for surface extraction algorithms.""" + + def _compute_box_stat( + self, bounds: Union[Tuple[float], List[float], float], octree_resolution: int + ): + if isinstance(bounds, float): + bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds] + + bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6]) + bbox_size = bbox_max - bbox_min + grid_size = [ + int(octree_resolution) + 1, + int(octree_resolution) + 1, + int(octree_resolution) + 1, + ] + return grid_size, bbox_min, bbox_size + + def run(self, *args, **kwargs): + raise NotImplementedError + + def __call__(self, grid_logits, **kwargs): + outputs = [] + for i in range(grid_logits.shape[0]): + try: + vertices, faces = self.run(grid_logits[i], **kwargs) + vertices = vertices.astype(np.float32) + faces = np.ascontiguousarray(faces) + outputs.append(Latent2MeshOutput(mesh_v=vertices, mesh_f=faces)) + except Exception: + import traceback + + traceback.print_exc() + outputs.append(None) + return outputs + + +class MCSurfaceExtractor(SurfaceExtractor): + """Marching Cubes surface extractor.""" + + def run(self, grid_logit, *, mc_level, bounds, octree_resolution, **kwargs): + from skimage import measure + + vertices, faces, normals, _ = measure.marching_cubes( + grid_logit.cpu().numpy(), mc_level, method="lewiner" + ) + grid_size, bbox_min, bbox_size = self._compute_box_stat( + bounds, octree_resolution + ) + vertices = vertices / grid_size * bbox_size + bbox_min + return vertices, faces + + +class DMCSurfaceExtractor(SurfaceExtractor): + """Differentiable Marching Cubes surface extractor.""" + + def run(self, grid_logit, *, octree_resolution, **kwargs): + device = grid_logit.device + if not hasattr(self, "dmc"): + try: + from diso import DiffDMC + + self.dmc = DiffDMC(dtype=torch.float32).to(device) + except ImportError: + raise ImportError( + "Please install diso via `pip install diso`, or set mc_algo to 'mc'" + ) + sdf = -grid_logit / octree_resolution + sdf = 
sdf.to(torch.float32).contiguous() + verts, faces = self.dmc(sdf, deform=None, return_quads=False, normalize=True) + verts = center_vertices(verts) + vertices = verts.detach().cpu().numpy() + faces = faces.detach().cpu().numpy()[:, ::-1] + return vertices, faces + + +SurfaceExtractors = { + "mc": MCSurfaceExtractor, + "dmc": DMCSurfaceExtractor, +} + + +class VectsetVAE(nn.Module): + """Base VAE class for vector set encoding.""" + + def __init__(self, volume_decoder=None, surface_extractor=None): + super().__init__() + if volume_decoder is None: + volume_decoder = VanillaVolumeDecoder() + if surface_extractor is None: + surface_extractor = MCSurfaceExtractor() + self.volume_decoder = volume_decoder + self.surface_extractor = surface_extractor + + def latents2mesh(self, latents: torch.FloatTensor, **kwargs): + """Convert latents to mesh.""" + grid_logits = self.volume_decoder(latents, self.geo_decoder, **kwargs) + outputs = self.surface_extractor(grid_logits, **kwargs) + return outputs + + def enable_flashvdm_decoder( + self, + enabled: bool = True, + adaptive_kv_selection=True, + topk_mode="mean", + mc_algo="dmc", + ): + """Enable or disable FlashVDM decoder for faster inference.""" + if enabled: + if adaptive_kv_selection: + self.volume_decoder = FlashVDMVolumeDecoding(topk_mode) + else: + self.volume_decoder = HierarchicalVolumeDecoding() + if mc_algo not in SurfaceExtractors: + raise ValueError( + f"Unsupported mc_algo {mc_algo}, available: {list(SurfaceExtractors.keys())}" + ) + self.surface_extractor = SurfaceExtractors[mc_algo]() + else: + self.volume_decoder = VanillaVolumeDecoder() + self.surface_extractor = MCSurfaceExtractor() + + +class ShapeVAE(VectsetVAE): + """Shape VAE for 3D mesh generation from latent codes.""" + + _aliases = ["hy3dgen.shapegen.models.ShapeVAE"] + + def __init__( + self, + *, + num_latents: int, + embed_dim: int, + width: int, + heads: int, + num_decoder_layers: int, + num_encoder_layers: int = 8, + pc_size: int = 5120, + 
pc_sharpedge_size: int = 5120, + point_feats: int = 3, + downsample_ratio: int = 20, + geo_decoder_downsample_ratio: int = 1, + geo_decoder_mlp_expand_ratio: int = 4, + geo_decoder_ln_post: bool = True, + num_freqs: int = 8, + include_pi: bool = True, + qkv_bias: bool = True, + qk_norm: bool = False, + label_type: str = "binary", + drop_path_rate: float = 0.0, + scale_factor: float = 1.0, + use_ln_post: bool = True, + ckpt_path=None, + ): + super().__init__() + self.geo_decoder_ln_post = geo_decoder_ln_post + self.downsample_ratio = downsample_ratio + + self.fourier_embedder = FourierEmbedder( + num_freqs=num_freqs, include_pi=include_pi + ) + + self.post_kl = nn.Linear(embed_dim, width) + + self.transformer = Transformer( + n_ctx=num_latents, + width=width, + layers=num_decoder_layers, + heads=heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + drop_path_rate=drop_path_rate, + ) + + self.geo_decoder = CrossAttentionDecoder( + fourier_embedder=self.fourier_embedder, + out_channels=1, + num_latents=num_latents, + mlp_expand_ratio=geo_decoder_mlp_expand_ratio, + downsample_ratio=geo_decoder_downsample_ratio, + enable_ln_post=self.geo_decoder_ln_post, + width=width // geo_decoder_downsample_ratio, + heads=heads // geo_decoder_downsample_ratio, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + label_type=label_type, + ) + + self.scale_factor = scale_factor + self.latent_shape = (num_latents, embed_dim) + + def forward(self, latents): + latents = self.post_kl(latents) + latents = self.transformer(latents) + return latents + + def decode(self, latents): + """Decode latents to features.""" + latents = self.post_kl(latents) + latents = self.transformer(latents) + return latents + + +# Entry class for model registry +EntryClass = ShapeVAE diff --git a/sglang/python/sglang/multimodal_gen/runtime/models/vaes/hunyuanvae.py b/sglang/python/sglang/multimodal_gen/runtime/models/vaes/hunyuanvae.py new file mode 100644 index 
0000000000000000000000000000000000000000..972967fa1a40630b890c1e7be30184544254e864 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/models/vaes/hunyuanvae.py @@ -0,0 +1,852 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from diffusers + +# Copyright 2024 The Hunyuan Team, The HuggingFace Team and The sglang-diffusion Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from sglang.multimodal_gen.configs.models.vaes import HunyuanVAEConfig +from sglang.multimodal_gen.runtime.layers.activation import get_act_fn +from sglang.multimodal_gen.runtime.models.vaes.common import ParallelTiledVAE + + +def prepare_causal_attention_mask( + num_frames: int, + height_width: int, + dtype: torch.dtype, + device: torch.device, + batch_size: int | None = None, +) -> torch.Tensor: + indices = torch.arange(1, num_frames + 1, dtype=torch.int32, device=device) + indices_blocks = indices.repeat_interleave(height_width) + x, y = torch.meshgrid(indices_blocks, indices_blocks, indexing="xy") + mask = torch.where(x <= y, 0, -float("inf")).to(dtype=dtype) + + if batch_size is not None: + mask = mask.unsqueeze(0).expand(batch_size, -1, -1) + return mask + + +class HunyuanVAEAttention(nn.Module): + + def __init__( + self, in_channels, heads, dim_head, eps, norm_num_groups, bias + ) -> None: + 
super().__init__() + self.in_channels = in_channels + self.heads = heads + self.dim_head = dim_head + self.eps = eps + self.norm_num_groups = norm_num_groups + self.bias = bias + + inner_dim = heads * dim_head + + # Define the projection layers + self.to_q = nn.Linear(in_channels, inner_dim, bias=bias) + self.to_k = nn.Linear(in_channels, inner_dim, bias=bias) + self.to_v = nn.Linear(in_channels, inner_dim, bias=bias) + self.to_out = nn.Sequential(nn.Linear(inner_dim, in_channels, bias=bias)) + + # Optional normalization layers + self.group_norm = nn.GroupNorm( + norm_num_groups, in_channels, eps=eps, affine=True + ) + + def forward( + self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None + ) -> torch.Tensor: + residual = hidden_states + + batch_size, sequence_length, _ = hidden_states.shape + + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + # Project to query, key, value + query = self.to_q(hidden_states) + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + + # Reshape for multi-head attention + head_dim = self.dim_head + + query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + + # Perform scaled dot-product attention + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + # Reshape back + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, self.heads * head_dim + ) + hidden_states = hidden_states.to(query.dtype) + + # Linear projection + hidden_states = self.to_out(hidden_states) + + # Residual connection and rescale + hidden_states = hidden_states + residual + + return hidden_states + + +class HunyuanVideoCausalConv3d(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int | 
tuple[int, int, int] = 3, + stride: int | tuple[int, int, int] = 1, + padding: int | tuple[int, int, int] = 0, + dilation: int | tuple[int, int, int] = 1, + bias: bool = True, + pad_mode: str = "replicate", + ) -> None: + super().__init__() + + kernel_size = ( + (kernel_size, kernel_size, kernel_size) + if isinstance(kernel_size, int) + else kernel_size + ) + + self.pad_mode = pad_mode + self.time_causal_padding = ( + kernel_size[0] // 2, + kernel_size[0] // 2, + kernel_size[1] // 2, + kernel_size[1] // 2, + kernel_size[2] - 1, + 0, + ) + + self.conv = nn.Conv3d( + in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = F.pad( + hidden_states, self.time_causal_padding, mode=self.pad_mode + ) + return self.conv(hidden_states) + + +class HunyuanVideoUpsampleCausal3D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int | None = None, + kernel_size: int = 3, + stride: int = 1, + bias: bool = True, + upsample_factor: tuple[int, ...] = (2, 2, 2), + ) -> None: + super().__init__() + + out_channels = out_channels or in_channels + self.upsample_factor = upsample_factor + + self.conv = HunyuanVideoCausalConv3d( + in_channels, out_channels, kernel_size, stride, bias=bias + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_frames = hidden_states.size(2) + + first_frame, other_frames = hidden_states.split((1, num_frames - 1), dim=2) + first_frame = F.interpolate( + first_frame.squeeze(2), + scale_factor=self.upsample_factor[1:], + mode="nearest", + ).unsqueeze(2) + + if num_frames > 1: + # See: https://github.com/pytorch/pytorch/issues/81665 + # Unless you have a version of pytorch where non-contiguous implementation of F.interpolate + # is fixed, this will raise either a runtime error, or fail silently with bad outputs. 
+ # If you are encountering an error here, make sure to try running encoding/decoding with + # `vae.enable_tiling()` first. If that doesn't work, open an issue at: + # https://github.com/huggingface/diffusers/issues + other_frames = other_frames.contiguous() + other_frames = F.interpolate( + other_frames, scale_factor=self.upsample_factor, mode="nearest" + ) + hidden_states = torch.cat((first_frame, other_frames), dim=2) + else: + hidden_states = first_frame + + hidden_states = self.conv(hidden_states) + return hidden_states + + +class HunyuanVideoDownsampleCausal3D(nn.Module): + + def __init__( + self, + channels: int, + out_channels: int | None = None, + padding: int = 1, + kernel_size: int = 3, + bias: bool = True, + stride=2, + ) -> None: + super().__init__() + out_channels = out_channels or channels + + self.conv = HunyuanVideoCausalConv3d( + channels, out_channels, kernel_size, stride, padding, bias=bias + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.conv(hidden_states) + return hidden_states + + +class HunyuanVideoResnetBlockCausal3D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int | None = None, + dropout: float = 0.0, + groups: int = 32, + eps: float = 1e-6, + non_linearity: str = "silu", + ) -> None: + super().__init__() + out_channels = out_channels or in_channels + + self.nonlinearity = get_act_fn(non_linearity) + + self.norm1 = nn.GroupNorm(groups, in_channels, eps=eps, affine=True) + self.conv1 = HunyuanVideoCausalConv3d(in_channels, out_channels, 3, 1, 0) + + self.norm2 = nn.GroupNorm(groups, out_channels, eps=eps, affine=True) + self.dropout = nn.Dropout(dropout) + self.conv2 = HunyuanVideoCausalConv3d(out_channels, out_channels, 3, 1, 0) + + self.conv_shortcut = None + if in_channels != out_channels: + self.conv_shortcut = HunyuanVideoCausalConv3d( + in_channels, out_channels, 1, 1, 0 + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = 
hidden_states.contiguous() + residual = hidden_states + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + residual = self.conv_shortcut(residual) + + hidden_states = hidden_states + residual + return hidden_states + + +class HunyuanVideoMidBlock3D(nn.Module): + + def __init__( + self, + in_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "silu", + resnet_groups: int = 32, + add_attention: bool = True, + attention_head_dim: int = 1, + ) -> None: + super().__init__() + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) + self.add_attention = add_attention + + # There is always at least one resnet + resnets = [ + HunyuanVideoResnetBlockCausal3D( + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + ) + ] + attentions: list[HunyuanVAEAttention | None] = [] + + for _ in range(num_layers): + if self.add_attention: + attentions.append( + HunyuanVAEAttention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + eps=resnet_eps, + norm_num_groups=resnet_groups, + bias=True, + ) + ) + else: + attentions.append(None) + + resnets.append( + HunyuanVideoResnetBlockCausal3D( + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if 
torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + self.resnets[0], hidden_states + ) + + for attn, resnet in zip(self.attentions, self.resnets[1:], strict=True): + if attn is not None: + batch_size, num_channels, num_frames, height, width = ( + hidden_states.shape + ) + hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3) + attention_mask = prepare_causal_attention_mask( + num_frames, + height * width, + hidden_states.dtype, + hidden_states.device, + batch_size=batch_size, + ) + hidden_states = attn(hidden_states, attention_mask=attention_mask) + hidden_states = hidden_states.unflatten( + 1, (num_frames, height, width) + ).permute(0, 4, 1, 2, 3) + + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states) + + else: + hidden_states = self.resnets[0](hidden_states) + + for attn, resnet in zip(self.attentions, self.resnets[1:], strict=True): + if attn is not None: + batch_size, num_channels, num_frames, height, width = ( + hidden_states.shape + ) + hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3) + attention_mask = prepare_causal_attention_mask( + num_frames, + height * width, + hidden_states.dtype, + hidden_states.device, + batch_size=batch_size, + ) + hidden_states = attn(hidden_states, attention_mask=attention_mask) + hidden_states = hidden_states.unflatten( + 1, (num_frames, height, width) + ).permute(0, 4, 1, 2, 3) + + hidden_states = resnet(hidden_states) + + return hidden_states + + +class HunyuanVideoDownBlock3D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "silu", + resnet_groups: int = 32, + add_downsample: bool = True, + downsample_stride: tuple[int, ...] 
| int = 2, + downsample_padding: int = 1, + ) -> None: + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + HunyuanVideoResnetBlockCausal3D( + in_channels=in_channels, + out_channels=out_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + HunyuanVideoDownsampleCausal3D( + out_channels, + out_channels=out_channels, + padding=downsample_padding, + stride=downsample_stride, + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if torch.is_grad_enabled() and self.gradient_checkpointing: + for resnet in self.resnets: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states) + else: + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class HunyuanVideoUpBlock3D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "silu", + resnet_groups: int = 32, + add_upsample: bool = True, + upsample_scale_factor: tuple[int, ...] 
= (2, 2, 2), + ) -> None: + super().__init__() + resnets = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + HunyuanVideoResnetBlockCausal3D( + in_channels=input_channels, + out_channels=out_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [ + HunyuanVideoUpsampleCausal3D( + out_channels, + out_channels=out_channels, + upsample_factor=upsample_scale_factor, + ) + ] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if torch.is_grad_enabled() and self.gradient_checkpointing: + for resnet in self.resnets: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states) + + else: + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class HunyuanVideoEncoder3D(nn.Module): + r""" + Causal encoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603). + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: tuple[str, ...] = ( + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + ), + block_out_channels: tuple[int, ...] 
= (128, 256, 512, 512), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + double_z: bool = True, + mid_block_add_attention=True, + temporal_compression_ratio: int = 4, + spatial_compression_ratio: int = 8, + ) -> None: + super().__init__() + + self.conv_in = HunyuanVideoCausalConv3d( + in_channels, block_out_channels[0], kernel_size=3, stride=1 + ) + self.mid_block: HunyuanVideoMidBlock3D | None = None + self.down_blocks = nn.ModuleList([]) + + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + if down_block_type != "HunyuanVideoDownBlock3D": + raise ValueError(f"Unsupported down_block_type: {down_block_type}") + + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio)) + num_time_downsample_layers = int(np.log2(temporal_compression_ratio)) + + if temporal_compression_ratio == 4: + add_spatial_downsample = bool(i < num_spatial_downsample_layers) + add_time_downsample = bool( + i >= (len(block_out_channels) - 1 - num_time_downsample_layers) + and not is_final_block + ) + elif temporal_compression_ratio == 8: + add_spatial_downsample = bool(i < num_spatial_downsample_layers) + add_time_downsample = bool(i < num_time_downsample_layers) + else: + raise ValueError( + f"Unsupported time_compression_ratio: {temporal_compression_ratio}" + ) + + downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1) + downsample_stride_T = (2,) if add_time_downsample else (1,) + downsample_stride = tuple(downsample_stride_T + downsample_stride_HW) + + down_block = HunyuanVideoDownBlock3D( + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=bool(add_spatial_downsample or add_time_downsample), + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + downsample_stride=downsample_stride, 
class HunyuanVideoDecoder3D(nn.Module):
    r"""
    Causal decoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603).

    The forward pass is ``conv_in -> mid_block -> up_blocks -> GroupNorm -> SiLU -> conv_out``.
    Up blocks walk ``block_out_channels`` in reverse order, and each one is built
    with ``layers_per_block + 1`` resnet layers.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        up_block_types: tuple[str, ...] = (
            "HunyuanVideoUpBlock3D",
            "HunyuanVideoUpBlock3D",
            "HunyuanVideoUpBlock3D",
            "HunyuanVideoUpBlock3D",
        ),
        block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        mid_block_add_attention=True,
        time_compression_ratio: int = 4,
        spatial_compression_ratio: int = 8,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        widest = block_out_channels[-1]
        narrowest = block_out_channels[0]

        self.conv_in = HunyuanVideoCausalConv3d(
            in_channels, widest, kernel_size=3, stride=1
        )
        self.up_blocks = nn.ModuleList([])

        # Middle block operates at the widest channel count.
        self.mid_block = HunyuanVideoMidBlock3D(
            in_channels=widest,
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            attention_head_dim=widest,
            resnet_groups=norm_num_groups,
            add_attention=mid_block_add_attention,
        )

        # Up path: channel widths are consumed from widest to narrowest.
        channels_high_to_low = list(reversed(block_out_channels))
        stage_out = channels_high_to_low[0]
        last_index = len(block_out_channels) - 1

        for idx, up_block_type in enumerate(up_block_types):
            if up_block_type != "HunyuanVideoUpBlock3D":
                raise ValueError(f"Unsupported up_block_type: {up_block_type}")

            stage_in, stage_out = stage_out, channels_high_to_low[idx]
            spatial_stages = int(np.log2(spatial_compression_ratio))
            temporal_stages = int(np.log2(time_compression_ratio))

            if time_compression_ratio != 4:
                raise ValueError(
                    f"Unsupported time_compression_ratio: {time_compression_ratio}"
                )

            # Spatial upsampling happens in the first `spatial_stages` blocks;
            # temporal upsampling in the last `temporal_stages` blocks before
            # the final one (the final block never upsamples in time).
            upsample_hw = idx < spatial_stages
            upsample_t = idx >= last_index - temporal_stages and idx != last_index

            factor_t = 2 if upsample_t else 1
            factor_hw = 2 if upsample_hw else 1

            self.up_blocks.append(
                HunyuanVideoUpBlock3D(
                    num_layers=self.layers_per_block + 1,
                    in_channels=stage_in,
                    out_channels=stage_out,
                    add_upsample=upsample_hw or upsample_t,
                    upsample_scale_factor=(factor_t, factor_hw, factor_hw),
                    resnet_eps=1e-6,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                )
            )

        # Output head.
        self.conv_norm_out = nn.GroupNorm(
            num_channels=narrowest, num_groups=norm_num_groups, eps=1e-6
        )
        self.conv_act = nn.SiLU()
        self.conv_out = HunyuanVideoCausalConv3d(
            narrowest, out_channels, kernel_size=3
        )

        self.gradient_checkpointing = False

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Decode ``hidden_states`` through the mid block, up blocks and output head."""
        hidden_states = self.conv_in(hidden_states)

        if torch.is_grad_enabled() and self.gradient_checkpointing:

            def run(module, tensor):
                return self._gradient_checkpointing_func(module, tensor)

        else:

            def run(module, tensor):
                return module(tensor)

        hidden_states = run(self.mid_block, hidden_states)
        for up_block in self.up_blocks:
            hidden_states = run(up_block, hidden_states)

        # post-process
        hidden_states = self.conv_norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        return self.conv_out(hidden_states)
We do this by manually defining a + # config for hunyuan vae + self.block_out_channels = config.block_out_channels + + if config.load_encoder: + self.encoder = HunyuanVideoEncoder3D( + in_channels=config.in_channels, + out_channels=config.latent_channels, + down_block_types=config.down_block_types, + block_out_channels=config.block_out_channels, + layers_per_block=config.layers_per_block, + norm_num_groups=config.norm_num_groups, + act_fn=config.act_fn, + double_z=True, + mid_block_add_attention=config.mid_block_add_attention, + temporal_compression_ratio=config.temporal_compression_ratio, + spatial_compression_ratio=config.spatial_compression_ratio, + ) + self.quant_conv = nn.Conv3d( + 2 * config.latent_channels, 2 * config.latent_channels, kernel_size=1 + ) + + if config.load_decoder: + self.decoder = HunyuanVideoDecoder3D( + in_channels=config.latent_channels, + out_channels=config.out_channels, + up_block_types=config.up_block_types, + block_out_channels=config.block_out_channels, + layers_per_block=config.layers_per_block, + norm_num_groups=config.norm_num_groups, + act_fn=config.act_fn, + time_compression_ratio=config.temporal_compression_ratio, + spatial_compression_ratio=config.spatial_compression_ratio, + mid_block_add_attention=config.mid_block_add_attention, + ) + self.post_quant_conv = nn.Conv3d( + config.latent_channels, config.latent_channels, kernel_size=1 + ) + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + x = self.encoder(x) + enc = self.quant_conv(x) + return enc + + def _decode(self, z: torch.Tensor) -> torch.Tensor: + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + generator: torch.Generator | None = None, + ) -> torch.Tensor: + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. 
class LTX2AudioCausalConv2d(nn.Module):
    """
    2D convolution with explicit, possibly one-sided zero padding.

    The padding is applied with ``F.pad`` in ``forward`` (the wrapped ``nn.Conv2d``
    uses ``padding=0``). Along the causal axis the whole ``(kernel - 1) * dilation``
    amount is placed on the leading side (top for ``"height"``, left for
    ``"width"`` / ``"width-compatibility"``); the other axis, and both axes for
    ``"none"``, are padded on both sides.

    Raises:
        ValueError: If ``causality_axis`` is not one of the values above.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: int = 1,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
        causality_axis: str = "height",
    ) -> None:
        super().__init__()
        self.causality_axis = causality_axis

        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        if isinstance(dilation, int):
            dilation = (dilation, dilation)

        # Total padding needed per axis to keep the stride-1 output size.
        span_h = (kernel_size[0] - 1) * dilation[0]
        span_w = (kernel_size[1] - 1) * dilation[1]
        balanced_h = (span_h // 2, span_h - span_h // 2)
        balanced_w = (span_w // 2, span_w - span_w // 2)

        # F.pad order for the last two dims: (left, right, top, bottom).
        if causality_axis == "none":
            padding = balanced_w + balanced_h
        elif causality_axis in {"width", "width-compatibility"}:
            padding = (span_w, 0) + balanced_h
        elif causality_axis == "height":
            padding = balanced_w + (span_h, 0)
        else:
            raise ValueError(f"Invalid causality_axis: {causality_axis}")

        self.padding = padding
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Zero-pad ``x`` with the precomputed padding, then convolve."""
        return self.conv(F.pad(x, self.padding))


class LTX2AudioPixelNorm(nn.Module):
    """
    Per-pixel (per-location) RMS normalization layer.

    Divides the input by ``sqrt(mean(x ** 2, dim) + eps)``; there are no
    learnable parameters.
    """

    def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled to unit RMS along ``self.dim``."""
        mean_square = (x**2).mean(dim=self.dim, keepdim=True)
        return x / torch.sqrt(mean_square + self.eps)


class LTX2AudioAttnBlock(nn.Module):
    """
    Single-head self-attention over all spatial positions with a residual connection.

    Queries, keys and values come from 1x1 convolutions of the normalized input;
    attention logits are scaled by ``channels ** -0.5``.

    Raises:
        ValueError: If ``norm_type`` is neither ``"group"`` nor ``"pixel"``.
    """

    def __init__(
        self,
        in_channels: int,
        norm_type: str = "group",
    ) -> None:
        super().__init__()
        self.in_channels = in_channels

        if norm_type == "group":
            self.norm = nn.GroupNorm(
                num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
            )
        elif norm_type == "pixel":
            self.norm = LTX2AudioPixelNorm(dim=1, eps=1e-6)
        else:
            raise ValueError(f"Invalid normalization type: {norm_type}")

        def pointwise() -> nn.Conv2d:
            return nn.Conv2d(
                in_channels, in_channels, kernel_size=1, stride=1, padding=0
            )

        # Creation order is kept as q, k, v, proj_out.
        self.q = pointwise()
        self.k = pointwise()
        self.v = pointwise()
        self.proj_out = pointwise()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the projected attention output."""
        normed = self.norm(x)
        query = self.q(normed)
        key = self.k(normed)
        value = self.v(normed)

        batch, channels, height, width = query.shape
        positions = height * width

        # (batch, positions, channels) @ (batch, channels, positions)
        query = query.reshape(batch, channels, positions).permute(0, 2, 1).contiguous()
        key = key.reshape(batch, channels, positions).contiguous()
        weights = torch.bmm(query, key) * (int(channels) ** (-0.5))
        weights = F.softmax(weights, dim=2)

        # Transpose so that bmm(value, weights) sums over the key positions.
        value = value.reshape(batch, channels, positions)
        weights = weights.permute(0, 2, 1).contiguous()
        attended = torch.bmm(value, weights).reshape(batch, channels, height, width)

        return x + self.proj_out(attended)
None: + super().__init__() + self.causality_axis = causality_axis + + if ( + self.causality_axis is not None + and self.causality_axis != "none" + and norm_type == "group" + ): + raise ValueError("Causal ResnetBlock with GroupNorm is not supported.") + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + if norm_type == "group": + self.norm1 = nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + elif norm_type == "pixel": + self.norm1 = LTX2AudioPixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {norm_type}") + self.non_linearity = nn.SiLU() + if causality_axis is not None: + self.conv1 = LTX2AudioCausalConv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + causality_axis=causality_axis, + ) + else: + self.conv1 = nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if temb_channels > 0: + self.temb_proj = nn.Linear(temb_channels, out_channels) + if norm_type == "group": + self.norm2 = nn.GroupNorm( + num_groups=32, num_channels=out_channels, eps=1e-6, affine=True + ) + elif norm_type == "pixel": + self.norm2 = LTX2AudioPixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {norm_type}") + self.dropout = nn.Dropout(dropout) + if causality_axis is not None: + self.conv2 = LTX2AudioCausalConv2d( + out_channels, + out_channels, + kernel_size=3, + stride=1, + causality_axis=causality_axis, + ) + else: + self.conv2 = nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + if causality_axis is not None: + self.conv_shortcut = LTX2AudioCausalConv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + causality_axis=causality_axis, + ) + else: + self.conv_shortcut = nn.Conv2d( + in_channels, out_channels, 
class LTX2AudioDownsample(nn.Module):
    """
    Halve the two trailing dimensions of a 4D tensor.

    With ``with_conv=True`` the input is zero-padded according to
    ``causality_axis`` and passed through a stride-2 3x3 convolution; otherwise a
    2x2 average pool is used and ``causality_axis`` is ignored.

    Args:
        in_channels: Channel count of the input (and output) tensor.
        with_conv: Use the learned stride-2 convolution instead of average pooling.
        causality_axis: One of ``"none"``, ``"width"``, ``"height"``,
            ``"width-compatibility"``. ``None`` is treated as ``"none"``, matching
            ``LTX2AudioUpsample`` and the non-causal ``nn.Conv2d`` path used by the
            encoder when no causality axis is set.
    """

    # F.pad amounts in (left, right, top, bottom) order, per causality axis.
    _PAD_BY_AXIS = {
        "none": (0, 1, 0, 1),
        "width": (2, 0, 0, 1),
        "height": (0, 1, 2, 0),
        "width-compatibility": (1, 0, 0, 1),
    }

    def __init__(
        self,
        in_channels: int,
        with_conv: bool,
        causality_axis: Optional[str] = "height",
    ) -> None:
        super().__init__()
        self.with_conv = with_conv
        self.causality_axis = causality_axis

        if self.with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Downsample ``x`` by a factor of two.

        Raises:
            ValueError: If ``with_conv`` is set and ``causality_axis`` is not a
                supported value.
        """
        if not self.with_conv:
            # with_conv=False implies that causality_axis is "none"
            return F.avg_pool2d(x, kernel_size=2, stride=2)

        # A missing axis means "not causal"; previously this fell through to the
        # ValueError below even though the signature allows None.
        axis = "none" if self.causality_axis is None else self.causality_axis
        pad = self._PAD_BY_AXIS.get(axis)
        if pad is None:
            raise ValueError(
                f"Invalid `causality_axis` {self.causality_axis}; supported values are `none`, `width`, `height`,"
                f" and `width-compatibility`."
            )

        x = F.pad(x, pad, mode="constant", value=0)
        return self.conv(x)


class LTX2AudioUpsample(nn.Module):
    """
    Double the two trailing dimensions of a 4D tensor with nearest-neighbour interpolation.

    With ``with_conv=True`` the interpolated tensor is passed through a 3x3
    convolution (``LTX2AudioCausalConv2d`` when ``causality_axis`` is not ``None``,
    a plain padded ``nn.Conv2d`` otherwise). For ``"height"`` / ``"width"`` the
    first row / column of the convolved result is then dropped.

    Args:
        in_channels: Channel count of the input (and output) tensor.
        with_conv: Apply the 3x3 convolution after interpolation.
        causality_axis: ``None``, ``"none"``, ``"width"``, ``"height"`` or
            ``"width-compatibility"``.
    """

    def __init__(
        self,
        in_channels: int,
        with_conv: bool,
        causality_axis: Optional[str] = "height",
    ) -> None:
        super().__init__()
        self.with_conv = with_conv
        self.causality_axis = causality_axis
        if self.with_conv:
            if causality_axis is not None:
                self.conv = LTX2AudioCausalConv2d(
                    in_channels,
                    in_channels,
                    kernel_size=3,
                    stride=1,
                    causality_axis=causality_axis,
                )
            else:
                self.conv = nn.Conv2d(
                    in_channels, in_channels, kernel_size=3, stride=1, padding=1
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Upsample ``x`` by a factor of two.

        Raises:
            ValueError: If ``with_conv`` is set and ``causality_axis`` is not a
                supported value.
        """
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if not self.with_conv:
            return x

        x = self.conv(x)
        axis = self.causality_axis
        if axis is None or axis == "none" or axis == "width-compatibility":
            return x
        if axis == "height":
            # Drop the first row along the causal (height) axis.
            return x[:, :, 1:, :]
        if axis == "width":
            # Drop the first column along the causal (width) axis.
            return x[:, :, :, 1:]
        raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
class LTX2AudioEncoder(nn.Module):
    """
    Convolutional encoder: ``conv_in`` -> down stages -> mid block -> norm/SiLU -> ``conv_out``.

    Each down stage holds ``num_res_blocks`` resnet blocks (each optionally
    followed by an attention block when the current resolution is listed in
    ``attn_resolutions``) and, for every stage but the last, a downsampler.
    ``conv_out`` emits ``2 * latent_channels`` channels when ``double_z`` is set,
    otherwise ``latent_channels``.

    Raises:
        ValueError: If ``norm_type`` is neither ``"group"`` nor ``"pixel"``.
    """

    def __init__(
        self,
        base_channels: int = 128,
        output_channels: int = 1,
        num_res_blocks: int = 2,
        attn_resolutions: Optional[Tuple[int, ...]] = None,
        in_channels: int = 2,
        resolution: int = 256,
        latent_channels: int = 8,
        ch_mult: Tuple[int, ...] = (1, 2, 4),
        norm_type: str = "group",
        causality_axis: Optional[str] = "width",
        dropout: float = 0.0,
        mid_block_add_attention: bool = False,
        sample_rate: int = 16000,
        mel_hop_length: int = 160,
        is_causal: bool = True,
        mel_bins: Optional[int] = 64,
        double_z: bool = True,
    ):
        super().__init__()

        # Audio front-end metadata, stored as-is.
        self.sample_rate = sample_rate
        self.mel_hop_length = mel_hop_length
        self.is_causal = is_causal
        self.mel_bins = mel_bins

        # Architecture hyper-parameters.
        self.base_channels = base_channels
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.out_ch = output_channels
        self.give_pre_end = False
        self.tanh_out = False
        self.norm_type = norm_type
        self.latent_channels = latent_channels
        self.channel_multipliers = ch_mult
        self.attn_resolutions = attn_resolutions
        self.causality_axis = causality_axis

        self.z_shape = (1, latent_channels, resolution, resolution)

        def conv3x3(channels_in: int, channels_out: int) -> nn.Module:
            # Causal conv when an axis is configured, plain padded conv otherwise.
            if causality_axis is not None:
                return LTX2AudioCausalConv2d(
                    channels_in,
                    channels_out,
                    kernel_size=3,
                    stride=1,
                    causality_axis=causality_axis,
                )
            return nn.Conv2d(
                channels_in, channels_out, kernel_size=3, stride=1, padding=1
            )

        def resnet(channels_in: int, channels_out: int) -> nn.Module:
            return LTX2AudioResnetBlock(
                in_channels=channels_in,
                out_channels=channels_out,
                temb_channels=self.temb_ch,
                dropout=dropout,
                norm_type=norm_type,
                causality_axis=causality_axis,
            )

        self.conv_in = conv3x3(in_channels, base_channels)

        # Down path.
        self.down = nn.ModuleList()
        width = base_channels
        current_resolution = resolution
        last_level = self.num_resolutions - 1

        for level in range(self.num_resolutions):
            stage = nn.Module()
            stage.block = nn.ModuleList()
            stage.attn = nn.ModuleList()
            stage_width = base_channels * ch_mult[level]

            for _ in range(num_res_blocks):
                stage.block.append(resnet(width, stage_width))
                width = stage_width
                if attn_resolutions and current_resolution in attn_resolutions:
                    stage.attn.append(LTX2AudioAttnBlock(width, norm_type=norm_type))

            if level != last_level:
                stage.downsample = LTX2AudioDownsample(
                    width, True, causality_axis=causality_axis
                )
                current_resolution = current_resolution // 2

            self.down.append(stage)

        # Middle block: resnet -> (attention | identity) -> resnet.
        self.mid = nn.Module()
        self.mid.block_1 = resnet(width, width)
        self.mid.attn_1 = (
            LTX2AudioAttnBlock(width, norm_type=norm_type)
            if mid_block_add_attention
            else nn.Identity()
        )
        self.mid.block_2 = resnet(width, width)

        # Output head.
        z_channels = 2 * latent_channels if double_z else latent_channels
        if norm_type == "group":
            self.norm_out = nn.GroupNorm(
                num_groups=32, num_channels=width, eps=1e-6, affine=True
            )
        elif norm_type == "pixel":
            self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6)
        else:
            raise ValueError(f"Invalid normalization type: {self.norm_type}")
        self.non_linearity = nn.SiLU()

        self.conv_out = conv3x3(width, z_channels)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Encode ``hidden_states`` into the (optionally doubled) latent channels."""
        # hidden_states expected shape: (batch_size, channels, time, num_mel_bins)
        hidden_states = self.conv_in(hidden_states)

        last_level = self.num_resolutions - 1
        for level in range(self.num_resolutions):
            stage = self.down[level]
            has_attention = len(stage.attn) > 0

            for block_idx, block in enumerate(stage.block):
                hidden_states = block(hidden_states, temb=None)
                if has_attention:
                    hidden_states = stage.attn[block_idx](hidden_states)

            if level != last_level and hasattr(stage, "downsample"):
                hidden_states = stage.downsample(hidden_states)

        hidden_states = self.mid.block_1(hidden_states, temb=None)
        hidden_states = self.mid.attn_1(hidden_states)
        hidden_states = self.mid.block_2(hidden_states, temb=None)

        hidden_states = self.norm_out(hidden_states)
        hidden_states = self.non_linearity(hidden_states)
        return self.conv_out(hidden_states)
= (1, 2, 4), + norm_type: str = "group", + causality_axis: Optional[str] = "width", + dropout: float = 0.0, + mid_block_add_attention: bool = False, + sample_rate: int = 16000, + mel_hop_length: int = 160, + is_causal: bool = True, + mel_bins: Optional[int] = 64, + ) -> None: + super().__init__() + + self.sample_rate = sample_rate + self.mel_hop_length = mel_hop_length + self.is_causal = is_causal + self.mel_bins = mel_bins + self.patchifier = LTX2AudioAudioPatchifier( + patch_size=1, + audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR, + sample_rate=sample_rate, + hop_length=mel_hop_length, + is_causal=is_causal, + ) + + self.base_channels = base_channels + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.out_ch = output_channels + self.give_pre_end = False + self.tanh_out = False + self.norm_type = norm_type + self.latent_channels = latent_channels + self.channel_multipliers = ch_mult + self.attn_resolutions = attn_resolutions + self.causality_axis = causality_axis + + base_block_channels = base_channels * self.channel_multipliers[-1] + base_resolution = resolution // (2 ** (self.num_resolutions - 1)) + self.z_shape = (1, latent_channels, base_resolution, base_resolution) + + if self.causality_axis is not None: + self.conv_in = LTX2AudioCausalConv2d( + latent_channels, + base_block_channels, + kernel_size=3, + stride=1, + causality_axis=self.causality_axis, + ) + else: + self.conv_in = nn.Conv2d( + latent_channels, base_block_channels, kernel_size=3, stride=1, padding=1 + ) + self.non_linearity = nn.SiLU() + self.mid = nn.Module() + self.mid.block_1 = LTX2AudioResnetBlock( + in_channels=base_block_channels, + out_channels=base_block_channels, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + if mid_block_add_attention: + self.mid.attn_1 = LTX2AudioAttnBlock( + 
base_block_channels, norm_type=self.norm_type + ) + else: + self.mid.attn_1 = nn.Identity() + self.mid.block_2 = LTX2AudioResnetBlock( + in_channels=base_block_channels, + out_channels=base_block_channels, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + + self.up = nn.ModuleList() + block_in = base_block_channels + curr_res = self.resolution // (2 ** (self.num_resolutions - 1)) + + for level in reversed(range(self.num_resolutions)): + stage = nn.Module() + stage.block = nn.ModuleList() + stage.attn = nn.ModuleList() + block_out = self.base_channels * self.channel_multipliers[level] + + for _ in range(self.num_res_blocks + 1): + stage.block.append( + LTX2AudioResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + ) + block_in = block_out + if self.attn_resolutions: + if curr_res in self.attn_resolutions: + stage.attn.append( + LTX2AudioAttnBlock(block_in, norm_type=self.norm_type) + ) + + if level != 0: + stage.upsample = LTX2AudioUpsample( + block_in, True, causality_axis=self.causality_axis + ) + curr_res *= 2 + + self.up.insert(0, stage) + + final_block_channels = block_in + + if self.norm_type == "group": + self.norm_out = nn.GroupNorm( + num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True + ) + elif self.norm_type == "pixel": + self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {self.norm_type}") + + if self.causality_axis is not None: + self.conv_out = LTX2AudioCausalConv2d( + final_block_channels, + output_channels, + kernel_size=3, + stride=1, + causality_axis=self.causality_axis, + ) + else: + self.conv_out = nn.Conv2d( + final_block_channels, + output_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward( + self, + sample: torch.Tensor, + ) -> torch.Tensor: + _, 
_, frames, mel_bins = sample.shape + + target_frames = frames * LATENT_DOWNSAMPLE_FACTOR + + if self.causality_axis is not None: + target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1) + + target_channels = self.out_ch + target_mel_bins = self.mel_bins if self.mel_bins is not None else mel_bins + + hidden_features = self.conv_in(sample) + hidden_features = self.mid.block_1(hidden_features, temb=None) + hidden_features = self.mid.attn_1(hidden_features) + hidden_features = self.mid.block_2(hidden_features, temb=None) + + for level in reversed(range(self.num_resolutions)): + stage = self.up[level] + for block_idx, block in enumerate(stage.block): + hidden_features = block(hidden_features, temb=None) + if stage.attn: + hidden_features = stage.attn[block_idx](hidden_features) + + if level != 0 and hasattr(stage, "upsample"): + hidden_features = stage.upsample(hidden_features) + + if self.give_pre_end: + return hidden_features + + hidden = self.norm_out(hidden_features) + hidden = self.non_linearity(hidden) + decoded_output = self.conv_out(hidden) + decoded_output = torch.tanh(decoded_output) if self.tanh_out else decoded_output + + _, _, current_time, current_freq = decoded_output.shape + target_time = target_frames + target_freq = target_mel_bins + + decoded_output = decoded_output[ + :, + :target_channels, + : min(current_time, target_time), + : min(current_freq, target_freq), + ] + + time_padding_needed = target_time - decoded_output.shape[2] + freq_padding_needed = target_freq - decoded_output.shape[3] + + if time_padding_needed > 0 or freq_padding_needed > 0: + padding = ( + 0, + max(freq_padding_needed, 0), + 0, + max(time_padding_needed, 0), + ) + decoded_output = F.pad(decoded_output, padding) + + decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq] + + return decoded_output + + +class AutoencoderKLLTX2Audio(ParallelTiledVAE): + r""" + LTX2 audio VAE for encoding and decoding audio latent representations. 
class AutoencoderKLLTX2Audio(ParallelTiledVAE):
    r"""
    LTX2 audio VAE: encodes audio inputs into a latent distribution and decodes latents back.

    The encoder and decoder are built from the same `config.arch_config` values; only the encoder additionally
    receives `double_z`.
    """

    _supports_gradient_checkpointing = False

    def __init__(
        self,
        config: LTXAudioVAEConfig,
    ) -> None:
        super().__init__(config=config)

        arch = config.arch_config
        causality_axis = arch.causality_axis
        attn_resolutions = arch.attn_resolutions
        base_channels = arch.base_channels

        # Constructor arguments shared by the encoder and the decoder.
        shared_kwargs = dict(
            base_channels=base_channels,
            output_channels=arch.output_channels,
            ch_mult=arch.ch_mult,
            num_res_blocks=arch.num_res_blocks,
            attn_resolutions=attn_resolutions,
            in_channels=arch.in_channels,
            resolution=arch.resolution,
            latent_channels=arch.latent_channels,
            norm_type=arch.norm_type,
            causality_axis=causality_axis,
            dropout=arch.dropout,
            mid_block_add_attention=arch.mid_block_add_attention,
            sample_rate=arch.sample_rate,
            mel_hop_length=arch.mel_hop_length,
            is_causal=arch.is_causal,
            mel_bins=arch.mel_bins,
        )
        double_z = arch.double_z

        supported_causality_axes = {"none", "width", "height", "width-compatibility"}
        if causality_axis not in supported_causality_axes:
            raise ValueError(
                f"{causality_axis=} is not valid. Supported values: {supported_causality_axes}"
            )

        # A non-empty collection is converted to a set; an empty/None value is passed through untouched.
        if attn_resolutions:
            shared_kwargs["attn_resolutions"] = set(attn_resolutions)

        self.encoder = LTX2AudioEncoder(**shared_kwargs, double_z=double_z)
        self.decoder = LTX2AudioDecoder(**shared_kwargs)

        # Per-channel statistics for normalizing and denormalizing the latent representation. These statistics
        # are computed over the entire dataset and stored in the model's checkpoint under the AudioVAE state_dict.
        # NOTE(review): std starts as zeros and mean as ones; presumably both are placeholders that the
        # checkpoint overwrites -- confirm.
        self.register_buffer(
            "latents_mean", torch.ones((base_channels,)), persistent=True
        )
        self.register_buffer(
            "latents_std", torch.zeros((base_channels,)), persistent=True
        )

        # TODO: confirm whether the mel compression ratio below is correct
        self.mel_compression_ratio = LATENT_DOWNSAMPLE_FACTOR
        self.use_slicing = False

    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        """Run the encoder on a batch."""
        return self.encoder(x)

    def encode(self, x: torch.Tensor, return_dict: bool = True):
        """
        Encode `x` into a diagonal Gaussian posterior.

        With `use_slicing` enabled and more than one batch element, each element is encoded on its own and the
        results are concatenated. Returns `(posterior,)` when `return_dict` is false, otherwise an
        `AutoencoderKLOutput`.
        """
        slice_batch = self.use_slicing and x.shape[0] > 1
        if slice_batch:
            h = torch.cat([self._encode(single) for single in x.split(1)])
        else:
            h = self._encode(x)

        posterior = DiagonalGaussianDistribution(h)
        if return_dict:
            return AutoencoderKLOutput(latent_dist=posterior)
        return (posterior,)

    def _decode(self, z: torch.Tensor) -> torch.Tensor:
        """Run the decoder on a batch of latents."""
        return self.decoder(z)

    def decode(
        self, z: torch.Tensor, return_dict: bool = True
    ) -> Union[DecoderOutput, torch.Tensor]:
        """
        Decode latents `z`.

        With `use_slicing` enabled and more than one batch element, each element is decoded on its own and the
        results are concatenated. Returns `(decoded,)` when `return_dict` is false, otherwise a `DecoderOutput`.
        """
        slice_batch = self.use_slicing and z.shape[0] > 1
        if slice_batch:
            decoded = torch.cat([self._decode(single) for single in z.split(1)])
        else:
            decoded = self._decode(z)

        if return_dict:
            return DecoderOutput(sample=decoded)
        return (decoded,)

    def forward(
        self,
        sample: torch.Tensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.Tensor]:
        """
        Encode `sample`, take either a random draw or the mode of the posterior, and decode it.

        Returns `(decoded_sample,)` when `return_dict` is false, otherwise the `DecoderOutput` from `decode`.
        """
        posterior = self.encode(sample).latent_dist
        z = posterior.sample(generator=generator) if sample_posterior else posterior.mode()
        dec = self.decode(z)
        if return_dict:
            return dec
        return (dec.sample,)
class PerChannelRMSNorm(nn.Module):
    """
    Parameter-free RMS normalization along a single dimension.

    Every location is divided by the root-mean-square of its values along the chosen dimension:

        y = x / sqrt(mean(x^2, dim=channel_dim, keepdim=True) + eps)
    """

    def __init__(self, channel_dim: int = 1, eps: float = 1e-8) -> None:
        """
        Args:
            channel_dim: Default dimension along which to compute the RMS (typically channels).
            eps: Small constant added to the mean square for numerical stability.
        """
        super().__init__()
        self.channel_dim = channel_dim
        self.eps = eps

    def forward(
        self, x: torch.Tensor, channel_dim: Optional[int] = None
    ) -> torch.Tensor:
        """
        Apply RMS normalization to `x`.

        Args:
            x: Input tensor.
            channel_dim: Optional per-call override of the dimension to normalize over. When `None`, the
                dimension given at construction time is used.

        Returns:
            A tensor with the same shape as `x`.
        """
        # Compare against None (not truthiness) so that an explicit dimension 0 is honored.
        dim = self.channel_dim if channel_dim is None else channel_dim
        # Mean of squared values along `dim`; keepdim so the result broadcasts against `x`.
        mean_sq = torch.mean(x**2, dim=dim, keepdim=True)
        rms = torch.sqrt(mean_sq + self.eps)
        return x / rms
# Like LTXCausalConv3d, but whether causal inference is performed can be specified at runtime
class LTX2VideoCausalConv3d(nn.Module):
    """
    3D convolution whose temporal padding is done by hand in `forward`, by replicating edge frames.

    The wrapped `nn.Conv3d` pads only height and width (by half the kernel size); the time axis gets no
    built-in padding. In causal mode all temporal padding goes in front of the first frame, otherwise it is
    split between the front and the back.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int, int]] = 3,
        stride: Union[int, Tuple[int, int, int]] = 1,
        dilation: Union[int, Tuple[int, int, int]] = 1,
        groups: int = 1,
        spatial_padding_mode: str = "zeros",
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        if not isinstance(kernel_size, tuple):
            kernel_size = (kernel_size,) * 3
        self.kernel_size = kernel_size

        # A scalar dilation applies to the time axis only.
        if not isinstance(dilation, tuple):
            dilation = (dilation, 1, 1)
        if not isinstance(stride, tuple):
            stride = (stride,) * 3

        _, kernel_h, kernel_w = self.kernel_size
        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            self.kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
            padding=(0, kernel_h // 2, kernel_w // 2),
            padding_mode=spatial_padding_mode,
        )

    def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor:
        """
        Pad the time axis (dim 2) with copies of the edge frames, then convolve.

        Args:
            hidden_states: 5D input; dim 2 is treated as time.
            causal: If true, prepend `kernel_t - 1` copies of the first frame. Otherwise prepend and append
                `(kernel_t - 1) // 2` copies of the first and last frame respectively.
        """
        temporal_taps = self.kernel_size[0]
        first_frame = hidden_states[:, :, :1, :, :]

        if causal:
            pieces = [
                first_frame.repeat((1, 1, temporal_taps - 1, 1, 1)),
                hidden_states,
            ]
        else:
            half = (temporal_taps - 1) // 2
            last_frame = hidden_states[:, :, -1:, :, :]
            pieces = [
                first_frame.repeat((1, 1, half, 1, 1)),
                hidden_states,
                last_frame.repeat((1, 1, half, 1, 1)),
            ]

        return self.conv(torch.cat(pieces, dim=2))
# Like LTXVideoResnetBlock3d, but uses new causal Conv3d, normal Conv3d for the conv_shortcut, and the spatial padding
# mode is configurable
class LTX2VideoResnetBlock3d(nn.Module):
    r"""
    A 3D ResNet block used in the LTX 2.0 audiovisual model.

    Args:
        in_channels (`int`):
            Number of input channels.
        out_channels (`int`, *optional*):
            Number of output channels. If None, defaults to `in_channels`.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        eps (`float`, defaults to `1e-6`):
            Epsilon of the LayerNorm applied on the shortcut path (only created when channel counts differ).
        elementwise_affine (`bool`, defaults to `False`):
            Accepted for interface compatibility; not read by this block.
        non_linearity (`str`, defaults to `"swish"`):
            Activation function to use.
        inject_noise (`bool`, defaults to `False`):
            Whether to add per-channel-scaled spatial Gaussian noise after each convolution.
        timestep_conditioning (`bool`, defaults to `False`):
            Whether to create a scale/shift table that modulates the normalized activations with `temb`.
        spatial_padding_mode (`str`, defaults to `"zeros"`):
            Padding mode forwarded to the causal convolutions.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        dropout: float = 0.0,
        eps: float = 1e-6,
        elementwise_affine: bool = False,
        non_linearity: str = "swish",
        inject_noise: bool = False,
        timestep_conditioning: bool = False,
        spatial_padding_mode: str = "zeros",
    ) -> None:
        super().__init__()

        if not out_channels:
            out_channels = in_channels

        def build_conv(conv_in_channels: int) -> LTX2VideoCausalConv3d:
            return LTX2VideoCausalConv3d(
                in_channels=conv_in_channels,
                out_channels=out_channels,
                kernel_size=3,
                spatial_padding_mode=spatial_padding_mode,
            )

        self.nonlinearity = get_activation(non_linearity)

        self.norm1 = PerChannelRMSNorm()
        self.conv1 = build_conv(in_channels)

        self.norm2 = PerChannelRMSNorm()
        self.dropout = nn.Dropout(dropout)
        self.conv2 = build_conv(out_channels)

        # The shortcut only needs a projection when the channel count changes.
        needs_projection = in_channels != out_channels
        self.norm3 = (
            nn.LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True)
            if needs_projection
            else None
        )
        # LTX 2.0 uses a normal nn.Conv3d here rather than LTXVideoCausalConv3d
        self.conv_shortcut = (
            nn.Conv3d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
            )
            if needs_projection
            else None
        )

        self.per_channel_scale1 = (
            nn.Parameter(torch.zeros(in_channels, 1, 1)) if inject_noise else None
        )
        self.per_channel_scale2 = (
            nn.Parameter(torch.zeros(in_channels, 1, 1)) if inject_noise else None
        )

        self.scale_shift_table = (
            nn.Parameter(torch.randn(4, in_channels) / in_channels**0.5)
            if timestep_conditioning
            else None
        )

    @staticmethod
    def _add_spatial_noise(
        hidden_states: torch.Tensor,
        per_channel_scale: torch.Tensor,
        generator: Optional[torch.Generator],
    ) -> torch.Tensor:
        """Draw one noise map over the last two dims, scale it per channel, and add it to every frame."""
        noise = torch.randn(
            hidden_states.shape[-2:],
            generator=generator,
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )[None]
        return hidden_states + (noise * per_channel_scale)[None, :, None, ...]

    def forward(
        self,
        inputs: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,
        causal: bool = True,
    ) -> torch.Tensor:
        """
        Run norm -> (modulation) -> activation -> conv twice, then add the (optionally projected) input.

        Args:
            inputs: 5D input tensor; dim 1 is channels.
            temb: Timestep embedding; read only when the block was built with `timestep_conditioning=True`.
            generator: Optional RNG used for the injected noise.
            causal: Forwarded to both causal convolutions.
        """
        modulated = self.scale_shift_table is not None

        out = self.norm1(inputs)
        if modulated:
            # Split the embedding into four chunks and add the learned table to each.
            temb = (
                temb.unflatten(1, (4, -1))
                + self.scale_shift_table[None, ..., None, None, None]
            )
            shift_1, scale_1, shift_2, scale_2 = temb.unbind(dim=1)
            out = out * (1 + scale_1) + shift_1

        out = self.nonlinearity(out)
        out = self.conv1(out, causal=causal)
        if self.per_channel_scale1 is not None:
            out = self._add_spatial_noise(out, self.per_channel_scale1, generator)

        out = self.norm2(out)
        if modulated:
            out = out * (1 + scale_2) + shift_2

        out = self.nonlinearity(out)
        out = self.dropout(out)
        out = self.conv2(out, causal=causal)
        if self.per_channel_scale2 is not None:
            out = self._add_spatial_noise(out, self.per_channel_scale2, generator)

        shortcut = inputs
        if self.norm3 is not None:
            # LayerNorm normalizes the last dim, so move channels there and back.
            shortcut = self.norm3(shortcut.movedim(1, -1)).movedim(-1, 1)
        if self.conv_shortcut is not None:
            shortcut = self.conv_shortcut(shortcut)

        return out + shortcut
# Like LTX 1.0 LTXVideoDownsampler3d, but uses new causal Conv3d
class LTXVideoDownsampler3d(nn.Module):
    """
    Downsampler that folds `stride`-sized space/time patches into the channel dimension.

    The output is the sum of a convolution branch and a residual branch; both are rearranged so that each
    `stride[0] x stride[1] x stride[2]` patch becomes extra channels. The residual branch then averages groups
    of `group_size` channels.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels after folding.
        stride: Downsampling factor as an int (applied to all three axes) or a `(time, height, width)` tuple.
        spatial_padding_mode: Padding mode forwarded to the causal convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: Union[int, Tuple[int, int, int]] = 1,
        spatial_padding_mode: str = "zeros",
    ) -> None:
        super().__init__()

        self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
        # Use the normalized tuple: indexing the raw argument fails when `stride` is an int.
        stride_volume = self.stride[0] * self.stride[1] * self.stride[2]
        self.group_size = (in_channels * stride_volume) // out_channels

        # The convolution emits fewer channels because folding multiplies them by `stride_volume`.
        out_channels = out_channels // stride_volume

        self.conv = LTX2VideoCausalConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            spatial_padding_mode=spatial_padding_mode,
        )

    def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor:
        """
        Downsample `hidden_states` (5D; dims 2, 3, 4 are split by `stride[0]`, `stride[1]`, `stride[2]`).

        Args:
            hidden_states: Input tensor.
            causal: Forwarded to the causal convolution.
        """
        # Prepend a copy of the leading `stride[0] - 1` frames along dim 2 before folding.
        hidden_states = torch.cat(
            [hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2
        )

        # Residual branch: split each axis into (blocks, stride), move the stride factors next to channels,
        # merge them into channels, then average every `group_size` consecutive channels.
        residual = (
            hidden_states.unflatten(4, (-1, self.stride[2]))
            .unflatten(3, (-1, self.stride[1]))
            .unflatten(2, (-1, self.stride[0]))
        )
        residual = residual.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
        residual = residual.unflatten(1, (-1, self.group_size))
        residual = residual.mean(dim=2)

        # Convolution branch: same folding, without the channel averaging.
        hidden_states = self.conv(hidden_states, causal=causal)
        hidden_states = (
            hidden_states.unflatten(4, (-1, self.stride[2]))
            .unflatten(3, (-1, self.stride[1]))
            .unflatten(2, (-1, self.stride[0]))
        )
        hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
        hidden_states = hidden_states + residual

        return hidden_states
# Like LTX 1.0 LTXVideoUpsampler3d, but uses new causal Conv3d
class LTXVideoUpsampler3d(nn.Module):
    """
    Upsampler that unfolds channels into `stride`-sized space/time patches.

    A causal convolution expands the channels, which are then rearranged so that each group of
    `stride[0] * stride[1] * stride[2]` channels becomes a patch along time/height/width. The first
    `stride[0] - 1` output frames are dropped.

    Args:
        in_channels: Number of input channels.
        stride: Upsampling factor as an int (applied to all three axes) or a `(time, height, width)` tuple.
        residual: Whether to add an unfolded (and channel-repeated) copy of the input to the output.
        upscale_factor: Divides the channel expansion of the convolution.
        spatial_padding_mode: Padding mode forwarded to the causal convolution.
    """

    def __init__(
        self,
        in_channels: int,
        stride: Union[int, Tuple[int, int, int]] = 1,
        residual: bool = False,
        upscale_factor: int = 1,
        spatial_padding_mode: str = "zeros",
    ) -> None:
        super().__init__()

        self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
        self.residual = residual
        self.upscale_factor = upscale_factor

        # Use the normalized tuple: indexing the raw argument fails when `stride` is an int.
        stride_volume = self.stride[0] * self.stride[1] * self.stride[2]
        out_channels = (in_channels * stride_volume) // upscale_factor

        self.conv = LTX2VideoCausalConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            spatial_padding_mode=spatial_padding_mode,
        )

    def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor:
        """
        Upsample a 5D `(batch, channels, frames, height, width)` tensor.

        Args:
            hidden_states: Input tensor.
            causal: Forwarded to the causal convolution.
        """
        batch_size, num_channels, num_frames, height, width = hidden_states.shape

        if self.residual:
            # Unfold the input itself: channels -> (c, s_t, s_h, s_w), interleave with (F, H, W), merge.
            residual = hidden_states.reshape(
                batch_size,
                -1,
                self.stride[0],
                self.stride[1],
                self.stride[2],
                num_frames,
                height,
                width,
            )
            residual = (
                residual.permute(0, 1, 5, 2, 6, 3, 7, 4)
                .flatten(6, 7)
                .flatten(4, 5)
                .flatten(2, 3)
            )
            # Tile along channels so the residual matches the convolution branch's channel count.
            repeats = (
                self.stride[0] * self.stride[1] * self.stride[2]
            ) // self.upscale_factor
            residual = residual.repeat(1, repeats, 1, 1, 1)
            residual = residual[:, :, self.stride[0] - 1 :]

        hidden_states = self.conv(hidden_states, causal=causal)
        hidden_states = hidden_states.reshape(
            batch_size,
            -1,
            self.stride[0],
            self.stride[1],
            self.stride[2],
            num_frames,
            height,
            width,
        )
        hidden_states = (
            hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4)
            .flatten(6, 7)
            .flatten(4, 5)
            .flatten(2, 3)
        )
        # Drop the leading `stride[0] - 1` frames.
        hidden_states = hidden_states[:, :, self.stride[0] - 1 :]

        if self.residual:
            hidden_states = hidden_states + residual

        return hidden_states
# Like LTX 1.0 LTXVideo095DownBlock3D, but with the updated LTX2VideoResnetBlock3d
class LTX2VideoDownBlock3D(nn.Module):
    r"""
    Down block used in the LTXVideo model: a stack of resnet layers followed by an optional downsampler.

    Args:
        in_channels (`int`):
            Number of input channels.
        out_channels (`int`, *optional*):
            Number of output channels. If None, defaults to `in_channels`.
        num_layers (`int`, defaults to `1`):
            Number of resnet layers.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        resnet_eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        resnet_act_fn (`str`, defaults to `"swish"`):
            Activation function to use.
        spatio_temporal_scale (`bool`, defaults to `True`):
            Whether or not to use a downsampling layer. If not used, output dimension would be same as input
            dimension.
        downsample_type (`str`, defaults to `"conv"`):
            One of `"conv"`, `"spatial"`, `"temporal"` or `"spatiotemporal"`. Any other value leaves the
            downsampler list empty.
        spatial_padding_mode (`str`, defaults to `"zeros"`):
            Padding mode forwarded to the convolutions.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        spatio_temporal_scale: bool = True,
        downsample_type: str = "conv",
        spatial_padding_mode: str = "zeros",
    ):
        super().__init__()

        if not out_channels:
            out_channels = in_channels

        # The resnet layers keep the channel count; only the downsampler may change it.
        self.resnets = nn.ModuleList(
            [
                LTX2VideoResnetBlock3d(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    dropout=dropout,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    spatial_padding_mode=spatial_padding_mode,
                )
                for _ in range(num_layers)
            ]
        )

        self.downsamplers = None
        if spatio_temporal_scale:
            self.downsamplers = nn.ModuleList()

            # (time, height, width) strides of the pixel-unshuffle style downsamplers.
            unshuffle_strides = {
                "spatial": (1, 2, 2),
                "temporal": (2, 1, 1),
                "spatiotemporal": (2, 2, 2),
            }
            if downsample_type == "conv":
                self.downsamplers.append(
                    LTX2VideoCausalConv3d(
                        in_channels=in_channels,
                        out_channels=in_channels,
                        kernel_size=3,
                        stride=(2, 2, 2),
                        spatial_padding_mode=spatial_padding_mode,
                    )
                )
            elif downsample_type in unshuffle_strides:
                self.downsamplers.append(
                    LTXVideoDownsampler3d(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        stride=unshuffle_strides[downsample_type],
                        spatial_padding_mode=spatial_padding_mode,
                    )
                )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,
        causal: bool = True,
    ) -> torch.Tensor:
        r"""Apply the resnet layers in order, then every configured downsampler."""

        checkpointing = torch.is_grad_enabled() and self.gradient_checkpointing
        for resnet in self.resnets:
            if checkpointing:
                hidden_states = self._gradient_checkpointing_func(
                    resnet, hidden_states, temb, generator, causal
                )
            else:
                hidden_states = resnet(hidden_states, temb, generator, causal=causal)

        for downsampler in self.downsamplers or ():
            hidden_states = downsampler(hidden_states, causal=causal)

        return hidden_states
# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d
# Like LTX 1.0 LTXVideoMidBlock3d, but with the updated LTX2VideoResnetBlock3d
class LTX2VideoMidBlock3d(nn.Module):
    r"""
    A middle block used in the LTXVideo model: a stack of channel-preserving resnet layers.

    Args:
        in_channels (`int`):
            Number of input channels.
        num_layers (`int`, defaults to `1`):
            Number of resnet layers.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        resnet_eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        resnet_act_fn (`str`, defaults to `"swish"`):
            Activation function to use.
        inject_noise (`bool`, defaults to `False`):
            Forwarded to every resnet layer.
        timestep_conditioning (`bool`, defaults to `False`):
            Whether to create a timestep embedder and condition the resnet layers on its output.
        spatial_padding_mode (`str`, defaults to `"zeros"`):
            Padding mode forwarded to the resnet layers.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int,
        num_layers: int = 1,
        dropout: float = 0.0,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        inject_noise: bool = False,
        timestep_conditioning: bool = False,
        spatial_padding_mode: str = "zeros",
    ) -> None:
        super().__init__()

        # The embedder produces 4 * in_channels features.
        self.time_embedder = (
            PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0)
            if timestep_conditioning
            else None
        )

        self.resnets = nn.ModuleList(
            [
                LTX2VideoResnetBlock3d(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    dropout=dropout,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    inject_noise=inject_noise,
                    timestep_conditioning=timestep_conditioning,
                    spatial_padding_mode=spatial_padding_mode,
                )
                for _ in range(num_layers)
            ]
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,
        causal: bool = True,
    ) -> torch.Tensor:
        r"""Embed `temb` (when an embedder exists) and apply the resnet layers in order."""

        batch = hidden_states.size(0)
        if self.time_embedder is not None:
            temb = self.time_embedder(
                timestep=temb.flatten(),
                resolution=None,
                aspect_ratio=None,
                batch_size=batch,
                hidden_dtype=hidden_states.dtype,
            )
            # Append three singleton dims so the embedding broadcasts over the 5D activations.
            temb = temb.view(batch, -1, 1, 1, 1)

        checkpointing = torch.is_grad_enabled() and self.gradient_checkpointing
        for resnet in self.resnets:
            if checkpointing:
                hidden_states = self._gradient_checkpointing_func(
                    resnet, hidden_states, temb, generator, causal
                )
            else:
                hidden_states = resnet(hidden_states, temb, generator, causal=causal)

        return hidden_states
# Like LTXVideoUpBlock3d but with no conv_in and the updated LTX2VideoResnetBlock3d
class LTX2VideoUpBlock3d(nn.Module):
    r"""
    Up block used in the LTXVideo model: optional channel-changing resnet, optional upsampler, then a stack of
    resnet layers.

    Args:
        in_channels (`int`):
            Number of input channels.
        out_channels (`int`, *optional*):
            Number of output channels. If None, defaults to `in_channels`.
        num_layers (`int`, defaults to `1`):
            Number of resnet layers.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        resnet_eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        resnet_act_fn (`str`, defaults to `"swish"`):
            Activation function to use.
        spatio_temporal_scale (`bool`, defaults to `True`):
            Whether or not to use an upsampling layer (stride `(2, 2, 2)`).
        inject_noise (`bool`, defaults to `False`):
            Forwarded to every resnet layer.
        timestep_conditioning (`bool`, defaults to `False`):
            Whether to create a timestep embedder and condition the resnet layers on its output.
        upsample_residual (`bool`, defaults to `False`):
            Forwarded to the upsampler as `residual`.
        upscale_factor (`int`, defaults to `1`):
            Forwarded to the upsampler; also multiplies the upsampler's input channel count.
        spatial_padding_mode (`str`, defaults to `"zeros"`):
            Padding mode forwarded to the sub-layers.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        spatio_temporal_scale: bool = True,
        inject_noise: bool = False,
        timestep_conditioning: bool = False,
        upsample_residual: bool = False,
        upscale_factor: int = 1,
        spatial_padding_mode: str = "zeros",
    ):
        super().__init__()

        if not out_channels:
            out_channels = in_channels

        def build_resnet(resnet_in_channels: int) -> LTX2VideoResnetBlock3d:
            return LTX2VideoResnetBlock3d(
                in_channels=resnet_in_channels,
                out_channels=out_channels,
                dropout=dropout,
                eps=resnet_eps,
                non_linearity=resnet_act_fn,
                inject_noise=inject_noise,
                timestep_conditioning=timestep_conditioning,
                spatial_padding_mode=spatial_padding_mode,
            )

        self.time_embedder = (
            PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0)
            if timestep_conditioning
            else None
        )

        # Only needed when the block changes the channel count.
        self.conv_in = build_resnet(in_channels) if in_channels != out_channels else None

        self.upsamplers = None
        if spatio_temporal_scale:
            upsampler = LTXVideoUpsampler3d(
                out_channels * upscale_factor,
                stride=(2, 2, 2),
                residual=upsample_residual,
                upscale_factor=upscale_factor,
                spatial_padding_mode=spatial_padding_mode,
            )
            self.upsamplers = nn.ModuleList([upsampler])

        self.resnets = nn.ModuleList(
            [build_resnet(out_channels) for _ in range(num_layers)]
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,
        causal: bool = True,
    ) -> torch.Tensor:
        """Apply `conv_in`, embed `temb`, upsample, then run the resnet layers, in that order."""
        # `conv_in` receives `temb` as passed in, i.e. before the time embedder below is applied.
        if self.conv_in is not None:
            hidden_states = self.conv_in(hidden_states, temb, generator, causal=causal)

        if self.time_embedder is not None:
            batch = hidden_states.size(0)
            temb = self.time_embedder(
                timestep=temb.flatten(),
                resolution=None,
                aspect_ratio=None,
                batch_size=batch,
                hidden_dtype=hidden_states.dtype,
            )
            temb = temb.view(batch, -1, 1, 1, 1)

        for upsampler in self.upsamplers or ():
            hidden_states = upsampler(hidden_states, causal=causal)

        checkpointing = torch.is_grad_enabled() and self.gradient_checkpointing
        for resnet in self.resnets:
            if checkpointing:
                hidden_states = self._gradient_checkpointing_func(
                    resnet, hidden_states, temb, generator, causal
                )
            else:
                hidden_states = resnet(hidden_states, temb, generator, causal=causal)

        return hidden_states
# Like LTX 1.0 LTXVideoEncoder3d but with different default args - the spatiotemporal downsampling pattern is
# different, as is the layers_per_block (the 2.0 VAE is bigger)
class LTX2VideoEncoder3d(nn.Module):
    r"""
    The `LTXVideoEncoder3d` layer of a variational autoencoder that encodes input video samples to its latent
    representation.

    Args:
        in_channels (`int`, defaults to 3):
            Number of input channels.
        out_channels (`int`, defaults to 128):
            Number of latent channels.
        block_out_channels (`Tuple[int, ...]`, defaults to `(256, 512, 1024, 2048)`):
            The number of output channels for each block.
        down_block_types (`Tuple[str, ...]`, defaults to four `"LTX2VideoDownBlock3D"` entries):
            Block class name per down block; only `"LTX2VideoDownBlock3D"` is accepted.
        spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, True)`:
            Whether a block should contain spatio-temporal downscaling layers or not.
        layers_per_block (`Tuple[int, ...]`, defaults to `(4, 6, 6, 2, 2)`):
            The number of layers per block. The last entry is used for the mid block.
        downsample_type (`Tuple[str, ...]`, defaults to `("spatial", "temporal", "spatiotemporal", "spatiotemporal")`):
            The spatiotemporal downsampling pattern per block. Per-layer values can be
                - `"spatial"` (downsample spatial dims by 2x)
                - `"temporal"` (downsample temporal dim by 2x)
                - `"spatiotemporal"` (downsample both spatial and temporal dims by 2x)
        patch_size (`int`, defaults to `4`):
            The size of spatial patches.
        patch_size_t (`int`, defaults to `1`):
            The size of temporal patches.
        resnet_norm_eps (`float`, defaults to `1e-6`):
            Epsilon value for ResNet normalization layers.
        is_causal (`bool`, defaults to `True`):
            Default causality used by `forward` when its `causal` argument is `None`.
        spatial_padding_mode (`str`, defaults to `"zeros"`):
            Padding mode forwarded to all convolutions.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 128,
        block_out_channels: Tuple[int, ...] = (256, 512, 1024, 2048),
        down_block_types: Tuple[str, ...] = (
            "LTX2VideoDownBlock3D",
            "LTX2VideoDownBlock3D",
            "LTX2VideoDownBlock3D",
            "LTX2VideoDownBlock3D",
        ),
        spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, True),
        layers_per_block: Tuple[int, ...] = (4, 6, 6, 2, 2),
        downsample_type: Tuple[str, ...] = (
            "spatial",
            "temporal",
            "spatiotemporal",
            "spatiotemporal",
        ),
        patch_size: int = 4,
        patch_size_t: int = 1,
        resnet_norm_eps: float = 1e-6,
        is_causal: bool = True,
        spatial_padding_mode: str = "zeros",
    ):
        super().__init__()

        self.patch_size = patch_size
        self.patch_size_t = patch_size_t
        # NOTE(review): the patchify step in `forward` yields C * patch_size_t * patch_size**2 channels, so this
        # count matches only when patch_size_t == 1 -- confirm that other values are never used.
        self.in_channels = in_channels * patch_size**2
        self.is_causal = is_causal

        output_channel = out_channels

        self.conv_in = LTX2VideoCausalConv3d(
            in_channels=self.in_channels,
            out_channels=output_channel,
            kernel_size=3,
            stride=1,
            spatial_padding_mode=spatial_padding_mode,
        )

        # down blocks
        num_block_out_channels = len(block_out_channels)
        self.down_blocks = nn.ModuleList([])
        for i in range(num_block_out_channels):
            input_channel = output_channel
            output_channel = block_out_channels[i]

            if down_block_types[i] == "LTX2VideoDownBlock3D":
                down_block = LTX2VideoDownBlock3D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    num_layers=layers_per_block[i],
                    resnet_eps=resnet_norm_eps,
                    spatio_temporal_scale=spatio_temporal_scaling[i],
                    downsample_type=downsample_type[i],
                    spatial_padding_mode=spatial_padding_mode,
                )
            else:
                raise ValueError(f"Unknown down block type: {down_block_types[i]}")

            self.down_blocks.append(down_block)

        # mid block
        self.mid_block = LTX2VideoMidBlock3d(
            in_channels=output_channel,
            num_layers=layers_per_block[-1],
            resnet_eps=resnet_norm_eps,
            spatial_padding_mode=spatial_padding_mode,
        )

        # out: one extra channel on top of `out_channels`; `forward` replicates it (see below).
        self.norm_out = PerChannelRMSNorm()
        self.conv_act = nn.SiLU()
        self.conv_out = LTX2VideoCausalConv3d(
            in_channels=output_channel,
            out_channels=out_channels + 1,
            kernel_size=3,
            stride=1,
            spatial_padding_mode=spatial_padding_mode,
        )

        self.gradient_checkpointing = False

    def forward(
        self, hidden_states: torch.Tensor, causal: Optional[bool] = None
    ) -> torch.Tensor:
        r"""
        The forward method of the `LTXVideoEncoder3d` class.

        Args:
            hidden_states: 5D `(batch, channels, frames, height, width)` input.
            causal: Whether the convolutions pad causally. `None` falls back to `self.is_causal`; an explicit
                `True`/`False` always takes precedence.

        Returns:
            Tensor whose channel dim holds the `conv_out` output followed by copies of its last channel.
        """

        p = self.patch_size
        p_t = self.patch_size_t

        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p
        post_patch_width = width // p
        # Only fall back to the configured default when the caller did not specify a mode; `causal or ...`
        # would silently turn an explicit `False` into `True` for a causal-by-default encoder.
        causal = self.is_causal if causal is None else causal

        # Patchify: split each axis into (blocks, patch) and fold the patch factors into channels.
        hidden_states = hidden_states.reshape(
            batch_size,
            num_channels,
            post_patch_num_frames,
            p_t,
            post_patch_height,
            p,
            post_patch_width,
            p,
        )
        # Patch factors are folded in the order (p_t, width-patch, height-patch).
        hidden_states = hidden_states.permute(0, 1, 3, 7, 5, 2, 4, 6).flatten(1, 4)
        hidden_states = self.conv_in(hidden_states, causal=causal)

        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for down_block in self.down_blocks:
                hidden_states = self._gradient_checkpointing_func(
                    down_block, hidden_states, None, None, causal
                )

            hidden_states = self._gradient_checkpointing_func(
                self.mid_block, hidden_states, None, None, causal
            )
        else:
            for down_block in self.down_blocks:
                hidden_states = down_block(hidden_states, causal=causal)

            hidden_states = self.mid_block(hidden_states, causal=causal)

        hidden_states = self.norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states, causal=causal)

        # Replicate the last channel (C - 2) times and append it, giving 2 * (C - 1) channels in total,
        # where C is the channel count produced by `conv_out`.
        last_channel = hidden_states[:, -1:]
        last_channel = last_channel.repeat(1, hidden_states.size(1) - 2, 1, 1, 1)
        hidden_states = torch.cat([hidden_states, last_channel], dim=1)

        return hidden_states
+ block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`): + The number of output channels for each block. + spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, False)`: + Whether a block should contain spatio-temporal upscaling layers or not. + layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`): + The number of layers per block. + patch_size (`int`, defaults to `4`): + The size of spatial patches. + patch_size_t (`int`, defaults to `1`): + The size of temporal patches. + resnet_norm_eps (`float`, defaults to `1e-6`): + Epsilon value for ResNet normalization layers. + is_causal (`bool`, defaults to `False`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + timestep_conditioning (`bool`, defaults to `False`): + Whether to condition the model on timesteps. + """ + + def __init__( + self, + in_channels: int = 128, + out_channels: int = 3, + block_out_channels: Tuple[int, ...] = (256, 512, 1024), + spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True), + layers_per_block: Tuple[int, ...] = (5, 5, 5, 5), + patch_size: int = 4, + patch_size_t: int = 1, + resnet_norm_eps: float = 1e-6, + is_causal: bool = False, + inject_noise: Tuple[bool, ...] = (False, False, False), + timestep_conditioning: bool = False, + upsample_residual: Tuple[bool, ...] = (True, True, True), + upsample_factor: Tuple[bool, ...] 
= (2, 2, 2), + spatial_padding_mode: str = "reflect", + ) -> None: + super().__init__() + + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.out_channels = out_channels * patch_size**2 + self.is_causal = is_causal + + block_out_channels = tuple(reversed(block_out_channels)) + spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling)) + layers_per_block = tuple(reversed(layers_per_block)) + inject_noise = tuple(reversed(inject_noise)) + upsample_residual = tuple(reversed(upsample_residual)) + upsample_factor = tuple(reversed(upsample_factor)) + output_channel = block_out_channels[0] + + self.conv_in = LTX2VideoCausalConv3d( + in_channels=in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, + ) + + self.mid_block = LTX2VideoMidBlock3d( + in_channels=output_channel, + num_layers=layers_per_block[0], + resnet_eps=resnet_norm_eps, + inject_noise=inject_noise[0], + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) + + # up blocks + num_block_out_channels = len(block_out_channels) + self.up_blocks = nn.ModuleList([]) + for i in range(num_block_out_channels): + input_channel = output_channel // upsample_factor[i] + output_channel = block_out_channels[i] // upsample_factor[i] + + up_block = LTX2VideoUpBlock3d( + in_channels=input_channel, + out_channels=output_channel, + num_layers=layers_per_block[i + 1], + resnet_eps=resnet_norm_eps, + spatio_temporal_scale=spatio_temporal_scaling[i], + inject_noise=inject_noise[i + 1], + timestep_conditioning=timestep_conditioning, + upsample_residual=upsample_residual[i], + upscale_factor=upsample_factor[i], + spatial_padding_mode=spatial_padding_mode, + ) + + self.up_blocks.append(up_block) + + # out + self.norm_out = PerChannelRMSNorm() + self.conv_act = nn.SiLU() + self.conv_out = LTX2VideoCausalConv3d( + in_channels=output_channel, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + 
def forward(
    self,
    hidden_states: torch.Tensor,
    temb: Optional[torch.Tensor] = None,
    causal: Optional[bool] = None,
) -> torch.Tensor:
    """Run the decoder stack and unfold the patch factors out of the channel axis.

    Args:
        hidden_states: 5-D input; it is unpacked as
            ``(batch, channels, frames, height, width)`` after ``conv_out``.
        temb: Optional timestep tensor. It is multiplied by
            ``timestep_scale_multiplier`` when that parameter exists and is
            fed to ``time_embedder`` when one is configured.
        causal: Causality flag forwarded to every sub-module.

    Returns:
        Tensor whose frame axis is multiplied by ``patch_size_t`` and whose
        height/width axes are multiplied by ``patch_size``.
    """
    # NOTE: because of `or`, an explicit ``causal=False`` also falls back to
    # ``self.is_causal``.
    causal = causal or self.is_causal

    h = self.conv_in(hidden_states, causal=causal)

    if self.timestep_scale_multiplier is not None:
        temb = temb * self.timestep_scale_multiplier

    # Mid block first, then every up block, either through the checkpointing
    # wrapper or by calling the block directly.
    checkpointed = torch.is_grad_enabled() and self.gradient_checkpointing
    for block in (self.mid_block, *self.up_blocks):
        if checkpointed:
            h = self._gradient_checkpointing_func(block, h, temb, None, causal)
        else:
            h = block(h, temb, causal=causal)

    h = self.norm_out(h)

    if self.time_embedder is not None:
        batch = h.size(0)
        embedded = self.time_embedder(
            timestep=temb.flatten(),
            resolution=None,
            aspect_ratio=None,
            batch_size=batch,
            hidden_dtype=h.dtype,
        )
        # Split the embedding into a (shift, scale) pair along dim 1.
        embedded = embedded.view(batch, -1, 1, 1, 1).unflatten(1, (2, -1))
        embedded = embedded + self.scale_shift_table[None, ..., None, None, None]
        shift, scale = embedded.unbind(dim=1)
        h = h * (1 + scale) + shift

    h = self.conv_out(self.conv_act(h), causal=causal)

    # Un-patchify: move (p_t, p, p) from the channel axis next to the
    # frame/height/width axes and merge each pair.
    p, p_t = self.patch_size, self.patch_size_t
    batch, _, frames, rows, cols = h.shape
    h = h.reshape(batch, -1, p_t, p, p, frames, rows, cols)
    h = h.permute(0, 1, 5, 2, 6, 4, 7, 3)
    return h.flatten(6, 7).flatten(4, 5).flatten(2, 3)
latent_channels, + block_out_channels, + down_block_types, + spatio_temporal_scaling, + layers_per_block, + downsample_type, + patch_size, + patch_size_t, + resnet_norm_eps, + encoder_causal, + encoder_spatial_padding_mode, + ) + + self.decoder = LTX2VideoDecoder3d( + latent_channels, + out_channels, + decoder_block_out_channels, + decoder_spatio_temporal_scaling, + decoder_layers_per_block, + patch_size, + patch_size_t, + resnet_norm_eps, + decoder_causal, + decoder_spatial_padding_mode, + ) + + latents_mean = torch.zeros((latent_channels,), requires_grad=False) + latents_std = torch.ones((latent_channels,), requires_grad=False) + self.register_buffer("latents_mean", latents_mean, persistent=True) + self.register_buffer("latents_std", latents_std, persistent=True) + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames + # at a fixed frame batch size (based on `self.num_latent_frames_batch_sizes`), the memory requirement can be lowered. + self.use_framewise_encoding = False + self.use_framewise_decoding = False + + # This can be configured based on the amount of GPU memory available. + # `16` for sample frames and `2` for latent frames are sensible defaults for consumer GPUs. + # Setting it to higher values results in higher memory usage. 
+ self.num_sample_frames_batch_size = 16 + self.num_latent_frames_batch_size = 2 + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 512 + self.tile_sample_min_width = 512 + self.tile_sample_min_num_frames = 16 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 448 + self.tile_sample_stride_width = 448 + self.tile_sample_stride_num_frames = 8 + + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_sample_min_num_frames: Optional[int] = None, + tile_sample_stride_height: Optional[float] = None, + tile_sample_stride_width: Optional[float] = None, + tile_sample_stride_num_frames: Optional[float] = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling + artifacts produced across the width dimension. 
+ """ + self.use_tiling = True + self.tile_sample_min_height = ( + tile_sample_min_height or self.tile_sample_min_height + ) + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_min_num_frames = ( + tile_sample_min_num_frames or self.tile_sample_min_num_frames + ) + self.tile_sample_stride_height = ( + tile_sample_stride_height or self.tile_sample_stride_height + ) + self.tile_sample_stride_width = ( + tile_sample_stride_width or self.tile_sample_stride_width + ) + self.tile_sample_stride_num_frames = ( + tile_sample_stride_num_frames or self.tile_sample_stride_num_frames + ) + + def _encode(self, x: torch.Tensor, causal: Optional[bool] = None) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = x.shape + + if self.use_framewise_decoding and num_frames > self.tile_sample_min_num_frames: + return self._temporal_tiled_encode(x, causal=causal) + + if self.use_tiling and ( + width > self.tile_sample_min_width or height > self.tile_sample_min_height + ): + return self.tiled_encode(x, causal=causal) + + enc = self.encoder(x, causal=causal) + + return enc + + def encode( + self, x: torch.Tensor, causal: Optional[bool] = None, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. 
def decode(
    self,
    z: torch.Tensor,
    temb: Optional[torch.Tensor] = None,
    causal: Optional[bool] = None,
    return_dict: bool = True,
) -> Union[DecoderOutput, torch.Tensor]:
    """
    Decode a batch of latents.

    Args:
        z (`torch.Tensor`): Input batch of latent vectors.
        temb (`torch.Tensor`, *optional*): Timestep tensor forwarded to `_decode`.
            On the sliced path it is split along dim 0 together with `z`.
        causal (`bool`, *optional*): Causality flag forwarded to `_decode`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a `DecoderOutput` instead of a plain tuple.

    Returns:
        `DecoderOutput` when `return_dict` is True, otherwise a one-element
        `tuple` holding the decoded tensor.

    Raises:
        ValueError: on the sliced path, if `temb` and `z` do not split into
            the same number of batch slices.
    """
    if self.use_slicing and z.shape[0] > 1:
        z_slices = z.split(1)
        if temb is not None:
            t_slices = temb.split(1)
            if len(t_slices) != len(z_slices):
                raise ValueError(
                    f"temb has {len(t_slices)} batch slices but z has {len(z_slices)}"
                )
            # Pair each latent slice with its own timestep slice. The original
            # iterated the 2-tuple itself (missing `zip`), which only unpacked
            # for batch size 2 and then paired the wrong tensors.
            decoded_slices = [
                self._decode(z_slice, t_slice, causal=causal).sample
                for z_slice, t_slice in zip(z_slices, t_slices)
            ]
        else:
            decoded_slices = [
                self._decode(z_slice, causal=causal).sample for z_slice in z_slices
            ]
        decoded = torch.cat(decoded_slices)
    else:
        decoded = self._decode(z, temb, causal=causal).sample

    if not return_dict:
        return (decoded,)

    return DecoderOutput(sample=decoded)
def tiled_decode(
    self,
    z: torch.Tensor,
    temb: Optional[torch.Tensor],
    causal: Optional[bool] = None,
    return_dict: bool = True,
) -> Union[DecoderOutput, torch.Tensor]:
    r"""
    Decode latents tile by tile along height and width, blending the overlaps.

    Args:
        z (`torch.Tensor`): 5-D latent batch unpacked as
            `(batch, channels, frames, height, width)`.
        temb (`torch.Tensor`, *optional*): Passed unchanged to every decoder call.
        causal (`bool`, *optional*): Passed unchanged to every decoder call.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a `DecoderOutput` instead of a plain tuple.

    Returns:
        `DecoderOutput` when `return_dict` is True, otherwise a one-element
        `tuple` holding the stitched tensor.
    """
    _, _, _, latent_h, latent_w = z.shape
    ratio = self.spatial_compression_ratio
    out_h = latent_h * ratio
    out_w = latent_w * ratio

    # Tile window and step, expressed in latent units.
    window_h = self.tile_sample_min_height // ratio
    window_w = self.tile_sample_min_width // ratio
    step_h = self.tile_sample_stride_height // ratio
    step_w = self.tile_sample_stride_width // ratio

    # Overlap between neighbouring decoded tiles, in sample units.
    overlap_h = self.tile_sample_min_height - self.tile_sample_stride_height
    overlap_w = self.tile_sample_min_width - self.tile_sample_stride_width

    # Decode every overlapping window in row-major order.
    grid = []
    for top in range(0, latent_h, step_h):
        grid.append(
            [
                self.decoder(
                    z[:, :, :, top : top + window_h, left : left + window_w],
                    temb,
                    causal=causal,
                )
                for left in range(0, latent_w, step_w)
            ]
        )

    stitched_rows = []
    for r, tiles in enumerate(grid):
        kept = []
        for c, tile in enumerate(tiles):
            # Blend with the tile above first, then with the tile to the left,
            # and keep only the stride-sized top-left part.
            if r > 0:
                tile = self.blend_v(grid[r - 1][c], tile, overlap_h)
            if c > 0:
                tile = self.blend_h(tiles[c - 1], tile, overlap_w)
            kept.append(
                tile[
                    :,
                    :,
                    :,
                    : self.tile_sample_stride_height,
                    : self.tile_sample_stride_width,
                ]
            )
        stitched_rows.append(torch.cat(kept, dim=4))

    dec = torch.cat(stitched_rows, dim=3)[:, :, :, :out_h, :out_w]

    if return_dict:
        return DecoderOutput(sample=dec)
    return (dec,)
for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_num_frames) + result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :]) + else: + result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :]) + + enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames] + return enc + + def _temporal_tiled_decode( + self, + z: torch.Tensor, + temb: Optional[torch.Tensor], + causal: Optional[bool] = None, + return_dict: bool = True, + ) -> Union[DecoderOutput, torch.Tensor]: + batch_size, num_channels, num_frames, height, width = z.shape + num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1 + + tile_latent_min_height = ( + self.tile_sample_min_height // self.spatial_compression_ratio + ) + tile_latent_min_width = ( + self.tile_sample_min_width // self.spatial_compression_ratio + ) + tile_latent_min_num_frames = ( + self.tile_sample_min_num_frames // self.temporal_compression_ratio + ) + tile_latent_stride_num_frames = ( + self.tile_sample_stride_num_frames // self.temporal_compression_ratio + ) + blend_num_frames = ( + self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames + ) + + row = [] + for i in range(0, num_frames, tile_latent_stride_num_frames): + tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :] + if self.use_tiling and ( + tile.shape[-1] > tile_latent_min_width + or tile.shape[-2] > tile_latent_min_height + ): + decoded = self.tiled_decode( + tile, temb, causal=causal, return_dict=True + ).sample + else: + decoded = self.decoder(tile, temb, causal=causal) + if i > 0: + decoded = decoded[:, :, :-1, :, :] + row.append(decoded) + + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_num_frames) + tile = tile[:, :, : self.tile_sample_stride_num_frames, :, :] + result_row.append(tile) + else: + result_row.append( + tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :] + ) + + dec = torch.cat(result_row, 
class AvgDown3D(nn.Module):
    """Parameter-free downsampling by folding time/space factors into channels
    and averaging channel groups.

    ``factor_t * factor_s**2`` neighbouring positions are moved onto the
    channel axis, then the resulting channels are averaged in groups of
    ``group_size`` so that ``out_channels`` remain.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        factor_t,
        factor_s=1,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.factor_t = factor_t
        self.factor_s = factor_s
        self.factor = self.factor_t * self.factor_s * self.factor_s

        # The folded channel count must split evenly into `out_channels` groups.
        assert in_channels * self.factor % out_channels == 0
        self.group_size = in_channels * self.factor // out_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Downsample a 5-D tensor unpacked as ``(B, C, T, H, W)``."""
        ft, fs = self.factor_t, self.factor_s

        # Zero-pad at the front of the time axis up to a multiple of factor_t.
        front_pad = (ft - x.shape[2] % ft) % ft
        x = F.pad(x, (0, 0, 0, 0, front_pad, 0))

        batch, chans, frames, rows, cols = x.shape
        t_out, h_out, w_out = frames // ft, rows // fs, cols // fs

        # Expose the per-axis factors, then move them next to the channel axis.
        x = x.view(batch, chans, t_out, ft, h_out, fs, w_out, fs)
        x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()

        # Group the folded channels and average within each group.
        x = x.view(batch, self.out_channels, self.group_size, t_out, h_out, w_out)
        return x.mean(dim=2)
class WanRMS_norm(nn.Module):
    r"""
    RMS-style normalization: L2-normalize along one axis, then rescale by
    ``sqrt(dim)`` and a learnable ``gamma`` (plus an optional learnable bias).
    """

    def __init__(
        self,
        dim: int,
        channel_first: bool = True,
        images: bool = True,
        bias: bool = False,
    ) -> None:
        super().__init__()

        # Parameter shape: (dim,) for channel-last input; otherwise dim followed
        # by two singleton axes (images) or three (non-image input).
        if channel_first:
            trailing = (1, 1) if images else (1, 1, 1)
            param_shape = (dim, *trailing)
        else:
            param_shape = (dim,)

        self.channel_first = channel_first
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(param_shape))
        if bias:
            self.bias = nn.Parameter(torch.zeros(param_shape))
        else:
            # Plain float so that `+ self.bias` is a no-op offset.
            self.bias = 0.0

    def forward(self, x):
        """Normalize along dim 1 (channel-first) or the last dim (channel-last)."""
        axis = 1 if self.channel_first else -1
        unit = F.normalize(x, dim=axis)
        return unit * self.scale * self.gamma + self.bias
def bind_context(
    is_first_frame_var,
    feat_cache_var,
    feat_idx_var,
    cache_t_value,
    first_chunk_var,
):
    """Install the shared context objects into this module's globals.

    The forward helpers in this module read ``is_first_frame``,
    ``feat_cache``, ``feat_idx`` and ``first_chunk`` via ``.get()`` and use
    ``cache_t`` as a slice length, so this must run before they are called.
    """
    global is_first_frame, feat_cache, feat_idx, cache_t, first_chunk
    is_first_frame, feat_cache, feat_idx, cache_t, first_chunk = (
        is_first_frame_var,
        feat_cache_var,
        feat_idx_var,
        cache_t_value,
        first_chunk_var,
    )
feat_idx.set(_feat_idx) + elif not first_frame and hasattr(self, "time_conv"): + x = self.time_conv(x) + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.resample(x) + x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4) + + _feat_cache = feat_cache.get() + _feat_idx = feat_idx.get() + if self.mode == "downsample3d": + if _feat_cache is not None: + idx = _feat_idx + if _feat_cache[idx] is None: + _feat_cache[idx] = x.clone() + _feat_idx += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv(torch.cat([_feat_cache[idx][:, :, -1:, :, :], x], 2)) + _feat_cache[idx] = cache_x + _feat_idx += 1 + feat_cache.set(_feat_cache) + feat_idx.set(_feat_idx) + elif not first_frame and hasattr(self, "time_conv"): + x = self.time_conv(x) + return x + + +def residual_block_forward(self, x): + _ensure_bound() + # Apply shortcut connection + h = self.conv_shortcut(x) + + # First normalization and activation + x = self.norm1(x) + x = self.nonlinearity(x) + + _feat_cache = feat_cache.get() + _feat_idx = feat_idx.get() + if _feat_cache is not None: + idx = _feat_idx + cache_x = x[:, :, -cache_t:, :, :].clone() + if cache_x.shape[2] < 2 and _feat_cache[idx] is not None: + cache_x = torch.cat( + [ + _feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + + x = self.conv1(x, _feat_cache[idx]) + _feat_cache[idx] = cache_x + _feat_idx += 1 + feat_cache.set(_feat_cache) + feat_idx.set(_feat_idx) + else: + x = self.conv1(x) + + # Second normalization and activation + x = self.norm2(x) + x = self.nonlinearity(x) + + # Dropout + x = self.dropout(x) + + _feat_cache = feat_cache.get() + _feat_idx = feat_idx.get() + if _feat_cache is not None: + idx = _feat_idx + cache_x = x[:, :, -cache_t:, :, :].clone() + if cache_x.shape[2] < 2 and _feat_cache[idx] is not 
def attention_block_forward(self, x):
    """Per-frame single-head self-attention over spatial positions, with a residual add.

    The input is unpacked as ``(batch, channels, frames, height, width)``.
    Frames are folded into the batch axis, so attention only mixes the
    ``height * width`` positions inside each frame.
    """
    residual = x
    batch, chans, frames, rows, cols = x.size()
    folded = batch * frames

    # [b, c, t, h, w] -> [(b*t), c, h, w]
    x = x.permute(0, 2, 1, 3, 4).reshape(folded, chans, rows, cols)
    x = self.norm(x)

    # Joint projection, then split into query/key/value of shape
    # [(b*t), 1, h*w, c].
    qkv = self.to_qkv(x).reshape(folded, 1, chans * 3, -1)
    query, key, value = qkv.permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1)

    attended = torch.nn.functional.scaled_dot_product_attention(query, key, value)

    # [(b*t), 1, h*w, c] -> [(b*t), c, h, w]
    attended = attended.squeeze(1).permute(0, 2, 1)
    attended = attended.reshape(folded, chans, rows, cols)

    # Output projection
    x = self.proj(attended)

    # [(b*t), c, h, w] -> [b, c, t, h, w]
    x = x.view(batch, frames, chans, rows, cols).permute(0, 2, 1, 3, 4)

    return x + residual
def tensor_pad(x: torch.Tensor, len_to_pad: int, dim: int = -2):
    """Append ``len_to_pad`` zeros to ``x`` along ``dim``.

    Args:
        x: Tensor to extend.
        len_to_pad: Number of zero entries appended along ``dim``.
        dim: Axis to extend; negative values count from the end.

    Returns:
        A new tensor equal to ``x`` followed by zeros along ``dim``, with the
        dtype and device of ``x``.
    """
    # Normalize negative axes first. The original sliced `x.shape[dim + 1:]`
    # with the raw value, which for dim=-1 is `x.shape[0:]` (the whole shape)
    # and produced a zero block of the wrong rank.
    if dim < 0:
        dim += x.dim()

    pad_shape = list(x.shape)
    pad_shape[dim] = len_to_pad
    zeros = torch.zeros(pad_shape, dtype=x.dtype, device=x.device)
    return torch.cat([x, zeros], dim=dim)
expected_height = orig_height // (2**downsample_count) + factor = world_size * (2**downsample_count) + pad_h = (factor - orig_height % factor) % factor + if pad_h: + x = F.pad(x, (0, 0, 0, pad_h, 0, 0)) + expected_local_height = (orig_height + pad_h) // (2**downsample_count) // world_size + x = tensor_chunk(x, dim=-2, world_size=world_size, rank=rank) + return x, expected_height, expected_local_height + + +def ensure_local_height(x: torch.Tensor, expected_local_height: int | None): + if expected_local_height is None: + return x + if x.shape[-2] < expected_local_height: + pad = expected_local_height - x.shape[-2] + return F.pad(x, (0, 0, 0, pad, 0, 0)) + if x.shape[-2] > expected_local_height: + return x[..., :expected_local_height, :].contiguous() + return x + + +def split_for_parallel_decode( + x: torch.Tensor, upsample_count: int, world_size: int, rank: int +): + expected_height = x.shape[-2] * (2**upsample_count) + x = tensor_chunk(x, dim=-2, world_size=world_size, rank=rank) + return x, expected_height + + +def gather_and_trim_height(x: torch.Tensor, expected_height: int | None): + if expected_height is None: + return x + x = get_sp_group().all_gather(x, dim=-2) + if x.shape[-2] != expected_height: + x = x[..., :expected_height, :].contiguous() + return x + + +def _ensure_recv_buf( + recv_buf: torch.Tensor | None, reference: torch.Tensor +) -> torch.Tensor: + if ( + recv_buf is None + or recv_buf.shape != reference.shape + or recv_buf.dtype != reference.dtype + or recv_buf.device != reference.device + ): + return torch.empty_like(reference) + return recv_buf + + +def halo_exchange( + x: torch.Tensor, + height_halo_size: int = 1, + recv_top_buf: torch.Tensor | None = None, + recv_bottom_buf: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if height_halo_size == 0: + return x, recv_top_buf, recv_bottom_buf + + sp_group = get_sp_group() + rank = get_sp_parallel_rank() + world_size = get_sp_world_size() + group = 
sp_group.device_group + group_ranks = sp_group.ranks + + top_row = x[..., :height_halo_size, :].contiguous() + bottom_row = x[..., -height_halo_size:, :].contiguous() + + recv_top_buf = _ensure_recv_buf(recv_top_buf, top_row) + recv_bottom_buf = _ensure_recv_buf(recv_bottom_buf, bottom_row) + + # use batched P2P operations + p2p_ops = [] + + if rank > 0: + # has previous neighbor, recv previous rank's data to recv_top_buf and send top_row to it. + prev_rank = group_ranks[rank - 1] + p2p_ops.append(dist.P2POp(dist.irecv, recv_top_buf, prev_rank, group)) + p2p_ops.append(dist.P2POp(dist.isend, top_row, prev_rank, group)) + if rank < world_size - 1: + # has next neighbor, send bottom_row to next rank and recv next rank's data to recv_bottom_buf. + next_rank = group_ranks[rank + 1] + p2p_ops.append(dist.P2POp(dist.isend, bottom_row, next_rank, group)) + p2p_ops.append(dist.P2POp(dist.irecv, recv_bottom_buf, next_rank, group)) + + if rank == 0: + recv_top_buf.zero_() + if rank == world_size - 1: + recv_bottom_buf.zero_() + + if p2p_ops: + reqs = dist.batch_isend_irecv(p2p_ops) + for req in reqs: + req.wait() + + return ( + torch.concat([recv_top_buf, x, recv_bottom_buf], dim=-2), + recv_top_buf, + recv_bottom_buf, + ) + + +class WanDistConv2d(nn.Conv2d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, int, int], + stride: int | tuple[int, int, int] = 1, + padding: int | tuple[int, int, int] = 0, + height_padding: tuple[int, int] | None = None, + ): + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + + self.height_halo_size = (self.kernel_size[-2] - 1) // 2 + if height_padding is None: + height_padding = (self.padding[-2], self.padding[-2]) + self.height_pad_top, self.height_pad_bottom = height_padding + + self.padding: tuple[int, int] + if self.height_halo_size > 0: + self._padding = (self.padding[1], self.padding[1], 0, 0) + 
class WanDistConv2d(nn.Conv2d):
    """2D convolution over a feature map whose height is sharded across ranks.

    Width padding is applied locally with ``F.pad``.  When the kernel covers
    more than one row, neighbouring rows are obtained through
    ``halo_exchange`` instead of local height padding, and output rows whose
    window would reach beyond the globally padded input are dropped in
    ``forward``.  The module's own ``padding`` is reset to zero so the parent
    convolution adds nothing on top.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Nominal padding of the equivalent non-sharded convolution.
        height_padding: Optional ``(top, bottom)`` padding of the global
            tensor; defaults to the height component of ``padding`` on both
            sides.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int | tuple[int, int, int],
        stride: int | tuple[int, int, int] = 1,
        padding: int | tuple[int, int, int] = 0,
        height_padding: tuple[int, int] | None = None,
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

        conv_padding = self.padding
        self.height_halo_size = (self.kernel_size[-2] - 1) // 2
        if height_padding is None:
            height_padding = (conv_padding[-2], conv_padding[-2])
        self.height_pad_top, self.height_pad_bottom = height_padding

        # F.pad order: (left, right, top, bottom).  With a halo the rows above
        # and below come from the neighbours, so no local height padding.
        width_pad = conv_padding[1]
        height_pad = 0 if self.height_halo_size > 0 else conv_padding[0]
        self._padding = (width_pad, width_pad, height_pad, height_pad)

        self.padding = (0, 0)
        self._halo_recv_top_buf: torch.Tensor | None = None
        self._halo_recv_bottom_buf: torch.Tensor | None = None
        self.rank = get_sp_parallel_rank()
        self.world_size = get_sp_world_size()

    def forward(self, x):
        """Convolve the local shard and return this rank's output rows."""
        x = F.pad(x, self._padding)

        x_padded, self._halo_recv_top_buf, self._halo_recv_bottom_buf = halo_exchange(
            x,
            height_halo_size=self.height_halo_size,
            recv_top_buf=self._halo_recv_top_buf,
            recv_bottom_buf=self._halo_recv_bottom_buf,
        )

        halo = self.height_halo_size
        pad_top = self.height_pad_top
        stride = self.stride[-2]
        local_height = x.shape[-2]
        global_start = self.rank * local_height

        if halo > 0 and stride > 1:
            # Drop leading rows so the first window lines up with the stride
            # grid of the global (top-padded) tensor.
            shift = (global_start - halo + pad_top) % stride
            if shift:
                x_padded = x_padded[..., shift:, :]
                global_start += shift

        out = super().forward(x_padded)
        if halo == 0:
            return out

        # Global row index of the first row of ``x_padded``.
        first_row = global_start - halo
        # Last row index of the globally bottom-padded tensor.
        last_row = local_height * self.world_size - 1 + self.height_pad_bottom
        kernel = self.kernel_size[-2]

        # Integer ceil / floor of the window bounds (all operands are ints):
        # ceil((-pad_top - first_row) / stride) and
        # floor((last_row - (kernel - 1) - first_row) / stride).
        min_i = -((pad_top + first_row) // stride)
        max_i = (last_row - (kernel - 1) - first_row) // stride

        start = max(min_i, 0)
        end = min(max_i + 1, out.shape[-2])
        if start != 0 or end != out.shape[-2]:
            out = out[..., start:end, :]
        return out
1) // 2 + + self.padding: tuple[int, int, int] + # Set up causal padding, let the halo to control height padding + if self.height_halo_size > 0: + self._padding: tuple[int, ...] = ( + self.padding[2], + self.padding[2], + 0, + 0, + 2 * self.padding[0], + 0, + ) + else: + self._padding: tuple[int, ...] = ( + self.padding[2], + self.padding[2], + self.padding[1], + self.padding[1], + 2 * self.padding[0], + 0, + ) + self.padding = (0, 0, 0) + self._halo_recv_top_buf: torch.Tensor | None = None + self._halo_recv_bottom_buf: torch.Tensor | None = None + self.rank = get_sp_parallel_rank() + self.world_size = get_sp_world_size() + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + + x = F.pad(x, padding) + + x = ( + x.to(self.weight.dtype) if current_platform.is_mps() else x + ) # casting needed for mps since amp isn't supported + + x_padded, self._halo_recv_top_buf, self._halo_recv_bottom_buf = halo_exchange( + x, + height_halo_size=self.height_halo_size, + recv_top_buf=self._halo_recv_top_buf, + recv_bottom_buf=self._halo_recv_bottom_buf, + ) + + pad_top = self.height_pad_top + stride = self.stride[-2] + global_start = self.rank * x.shape[-2] + if self.height_halo_size > 0 and stride > 1: + shift = (global_start - self.height_halo_size + pad_top) % stride + if shift: + x_padded = x_padded[..., shift:, :] + global_start += shift + + out = super().forward(x_padded) + + if self.height_halo_size == 0: + return out + + local_height = x.shape[-2] + global_height = local_height * self.world_size + halo = self.height_halo_size + pad_bottom = self.height_pad_bottom + kernel = self.kernel_size[-2] + min_i = math.ceil(((-pad_top) - (global_start - halo)) / stride) + max_i = math.floor( + ((global_height - 1 + pad_bottom) - (kernel - 1) - (global_start - halo)) + / stride + ) + start = max(min_i, 0) + end = 
min(max_i + 1, out.shape[-2]) + if start != 0 or end != out.shape[-2]: + out = out[..., start:end, :] + + return out + + +class WanDistZeroPad2d(nn.Module): + """Apply 2D padding once globally across sequence-parallel height splits.""" + + def __init__(self, padding: tuple[int, int, int, int]) -> None: + super().__init__() + self.padding = padding # (left, right, top, bottom) + self.rank = get_sp_parallel_rank() + self.world_size = get_sp_world_size() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + left, right, top, bottom = self.padding + if self.world_size <= 1: + return F.pad(x, (left, right, top, bottom)) + # Only the first/last rank should contribute global top/bottom padding. + top = top if self.rank == 0 else 0 + bottom = bottom if self.rank == self.world_size - 1 else 0 + return F.pad(x, (left, right, top, bottom)) + + +class WanDistResample(nn.Module): + r""" + A custom resampling module for 2D and 3D data used for parallel decoding. + + Args: + dim (int): The number of input/output channels. + mode (str): The resampling mode. Must be one of: + - 'none': No resampling (identity operation). + - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution. + - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution. + - 'downsample2d': 2D downsampling with zero-padding and convolution. + - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution. + """ + + def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None: + super().__init__() + self.dim = dim + self.mode = mode + + # default to dim //2 + if upsample_out_dim is None: + upsample_out_dim = dim // 2 + + # layers + # We support parallel encode/decode; downsample uses halo exchange as well. 
class WanDistResample(nn.Module):
    r"""
    A custom resampling module for 2D and 3D data used for parallel encoding/decoding.

    The spatial layers are the height-sharded variants (``WanDistConv2d``,
    ``WanDistZeroPad2d``); the temporal convolution uses a ``(3, 1, 1)`` kernel
    and stays a plain ``WanCausalConv3d``.

    Args:
        dim (int): The number of input/output channels.
        mode (str): The resampling mode. One of:
            - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
            - 'upsample3d': 'upsample2d' plus a causal temporal convolution (``time_conv``).
            - 'downsample2d': 2D downsampling with zero-padding and a stride-2 convolution.
            - 'downsample3d': 'downsample2d' plus a strided causal temporal convolution (``time_conv``).
            Any other value (e.g. 'none') yields an identity ``resample``.
        upsample_out_dim (int, optional): Output channels of the upsampling
            convolution. Defaults to ``dim // 2``.
    """

    def __init__(
        self, dim: int, mode: str, upsample_out_dim: int | None = None
    ) -> None:
        super().__init__()
        self.dim = dim
        self.mode = mode

        # default to dim // 2
        if upsample_out_dim is None:
            upsample_out_dim = dim // 2

        # The 2d and 3d variants share the same spatial path; the 3d variants
        # only add ``time_conv`` afterwards.
        if mode in ("upsample2d", "upsample3d"):
            self.resample = nn.Sequential(
                WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
                WanDistConv2d(dim, upsample_out_dim, 3, padding=1),
            )
            if mode == "upsample3d":
                self.time_conv = WanCausalConv3d(
                    dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)
                )
        elif mode in ("downsample2d", "downsample3d"):
            # Right padding is applied locally; the bottom padding of the
            # global tensor is passed to the convolution via height_padding.
            self.resample = nn.Sequential(
                WanDistZeroPad2d((0, 1, 0, 0)),
                WanDistConv2d(dim, dim, 3, stride=(2, 2), height_padding=(0, 1)),
            )
            if mode == "downsample3d":
                self.time_conv = WanCausalConv3d(
                    dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
                )
        else:
            self.resample = nn.Identity()

    def forward(self, x):
        return resample_forward(self, x)
class WanDistResidualBlock(nn.Module):
    r"""
    Residual block built from height-sharded causal 3D convolutions.

    Args:
        in_dim (int): Number of channels of the incoming tensor.
        out_dim (int): Number of channels produced by the block.
        dropout (float, optional): Probability of the dropout layer. Default is 0.0.
        non_linearity (str, optional): Activation name passed to ``get_act_fn``. Default is "silu".
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        dropout: float = 0.0,
        non_linearity: str = "silu",
    ) -> None:
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.nonlinearity = get_act_fn(non_linearity)

        # norm -> conv, twice, with dropout ahead of the second convolution
        self.norm1 = WanRMS_norm(in_dim, images=False)
        self.conv1 = WanDistCausalConv3d(in_dim, out_dim, 3, padding=1)
        self.norm2 = WanRMS_norm(out_dim, images=False)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = WanDistCausalConv3d(out_dim, out_dim, 3, padding=1)

        # A 1x1x1 projection is only needed when the channel count changes.
        if in_dim == out_dim:
            self.conv_shortcut = nn.Identity()
        else:
            self.conv_shortcut = WanDistCausalConv3d(in_dim, out_dim, 1)

    def forward(self, x):
        return residual_block_forward(self, x)


class WanDistAttentionBlock(nn.Module):
    r"""
    Single-head causal self-attention over a height-sharded feature map.

    With more than one sequence-parallel rank the input is all-gathered along
    the height axis before attention and the result is split back so each rank
    keeps its own slice.

    Args:
        dim (int): The number of channels in the input tensor.
    """

    def __init__(self, dim) -> None:
        super().__init__()
        self.dim = dim

        # layers
        self.norm = WanRMS_norm(dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        self.proj = nn.Conv2d(dim, dim, 1)

        # sequence-parallel layout
        self.rank = get_sp_parallel_rank()
        self.world_size = get_sp_world_size()
        self.sp_group = get_sp_group()

    def forward(self, x):
        sharded = self.world_size > 1
        if sharded:
            x = self.sp_group.all_gather(x, dim=-2)
        x = attention_block_forward(self, x.contiguous())
        if not sharded:
            return x
        return torch.chunk(x, self.world_size, dim=-2)[self.rank]
class WanDistMidBlock(nn.Module):
    """
    Middle block of the WanVAE encoder/decoder, height-sharded variant.

    Holds ``num_layers + 1`` residual blocks with one attention block between
    each consecutive pair.

    Args:
        dim (int): Number of input/output channels.
        dropout (float): Dropout rate.
        non_linearity (str): Type of non-linearity to use.
        num_layers (int): Number of attention blocks.
    """

    def __init__(
        self,
        dim: int,
        dropout: float = 0.0,
        non_linearity: str = "silu",
        num_layers: int = 1,
    ):
        super().__init__()
        self.dim = dim

        # One leading residual block, then (attention, residual) per layer.
        res_layers = [WanDistResidualBlock(dim, dim, dropout, non_linearity)]
        attn_layers = []
        for _ in range(num_layers):
            attn_layers.append(WanDistAttentionBlock(dim))
            res_layers.append(WanDistResidualBlock(dim, dim, dropout, non_linearity))

        self.attentions = nn.ModuleList(attn_layers)
        self.resnets = nn.ModuleList(res_layers)

        self.gradient_checkpointing = False

    def forward(self, x):
        return mid_block_forward(self, x)


class WanDistResidualDownBlock(nn.Module):
    """
    Encoder stage with an ``AvgDown3D`` shortcut, height-sharded variant.

    Args:
        in_dim: Input channels.
        out_dim: Output channels.
        dropout: Dropout rate of the residual blocks.
        num_res_blocks: Number of residual blocks on the main path.
        temperal_downsample: Whether the stage also downsamples in time.
        down_flag: Whether the stage downsamples at all.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        dropout,
        num_res_blocks,
        temperal_downsample=False,
        down_flag=False,
    ):
        super().__init__()

        # Shortcut path with downsample
        time_factor = 2 if temperal_downsample else 1
        space_factor = 2 if down_flag else 1
        self.avg_shortcut = AvgDown3D(
            in_dim,
            out_dim,
            factor_t=time_factor,
            factor_s=space_factor,
        )

        # Main path: only the first residual block changes the channel count.
        self.resnets = nn.ModuleList(
            [
                WanDistResidualBlock(in_dim if k == 0 else out_dim, out_dim, dropout)
                for k in range(num_res_blocks)
            ]
        )

        # Optional trailing downsample block
        self.downsampler = None
        if down_flag:
            mode = "downsample3d" if temperal_downsample else "downsample2d"
            self.downsampler = WanDistResample(out_dim, mode=mode)

    def forward(self, x):
        return residual_down_block_forward(self, x)
class WanDistResidualUpBlock(nn.Module):
    """
    Decoder stage with an optional ``DupUp3D`` shortcut, height-sharded variant.

    Args:
        in_dim (int): Input dimension
        out_dim (int): Output dimension
        num_res_blocks (int): Number of residual blocks (one extra block is created)
        dropout (float): Dropout rate
        temperal_upsample (bool): Whether to upsample on temporal dimension
        up_flag (bool): Whether to upsample or not
        non_linearity (str): Type of non-linearity to use
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        num_res_blocks: int,
        dropout: float = 0.0,
        temperal_upsample: bool = False,
        up_flag: bool = False,
        non_linearity: str = "silu",
    ):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # Shortcut path, present only when the stage upsamples.
        self.avg_shortcut = (
            DupUp3D(
                in_dim,
                out_dim,
                factor_t=2 if temperal_upsample else 1,
                factor_s=2,
            )
            if up_flag
            else None
        )

        # num_res_blocks + 1 residual blocks; only the first changes channels.
        self.resnets = nn.ModuleList(
            [
                WanDistResidualBlock(
                    in_dim if k == 0 else out_dim, out_dim, dropout, non_linearity
                )
                for k in range(num_res_blocks + 1)
            ]
        )

        # Trailing upsampling layer if needed
        self.upsampler = None
        if up_flag:
            upsample_mode = "upsample3d" if temperal_upsample else "upsample2d"
            self.upsampler = WanDistResample(
                out_dim, mode=upsample_mode, upsample_out_dim=out_dim
            )

        self.gradient_checkpointing = False

    def forward(self, x):
        return residual_up_block_forward(self, x)
class WanDistUpBlock(nn.Module):
    """
    Decoder stage of residual blocks plus optional upsampling, height-sharded variant.

    Args:
        in_dim (int): Input dimension
        out_dim (int): Output dimension
        num_res_blocks (int): Number of residual blocks (one extra block is created)
        dropout (float): Dropout rate
        upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
        non_linearity (str): Type of non-linearity to use
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        num_res_blocks: int,
        dropout: float = 0.0,
        upsample_mode: str | None = None,
        non_linearity: str = "silu",
    ):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # num_res_blocks + 1 residual blocks; only the first changes channels.
        self.resnets = nn.ModuleList(
            [
                WanDistResidualBlock(
                    in_dim if k == 0 else out_dim, out_dim, dropout, non_linearity
                )
                for k in range(num_res_blocks + 1)
            ]
        )

        # Optional upsampling layer, kept in a ModuleList of length one.
        if upsample_mode is None:
            self.upsamplers = None
        else:
            self.upsamplers = nn.ModuleList(
                [WanDistResample(out_dim, mode=upsample_mode)]
            )

        self.gradient_checkpointing = False

    def forward(self, x):
        return up_block_forward(self, x)
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextvars +from contextlib import contextmanager + +import torch +import torch.distributed as dist +import torch.nn as nn +from einops import rearrange + +from sglang.multimodal_gen.configs.models.vaes import WanVAEConfig +from sglang.multimodal_gen.runtime.distributed.parallel_state import ( + get_sp_parallel_rank, + get_sp_world_size, +) +from sglang.multimodal_gen.runtime.layers.activation import get_act_fn +from sglang.multimodal_gen.runtime.models.vaes.common import ( + DiagonalGaussianDistribution, + ParallelTiledVAE, +) +from sglang.multimodal_gen.runtime.models.vaes.parallel.wan_common_utils import ( + AvgDown3D, + DupUp3D, + WanCausalConv3d, + WanRMS_norm, + WanUpsample, + attention_block_forward, + bind_context, + mid_block_forward, + resample_forward, + residual_block_forward, + residual_down_block_forward, + residual_up_block_forward, + up_block_forward, +) +from sglang.multimodal_gen.runtime.models.vaes.parallel.wan_dist_utils import ( + WanDistAttentionBlock, + WanDistCausalConv3d, + WanDistMidBlock, + WanDistResample, + WanDistResidualBlock, + WanDistResidualDownBlock, + WanDistResidualUpBlock, + WanDistUpBlock, + ensure_local_height, + gather_and_trim_height, + split_for_parallel_decode, + split_for_parallel_encode, +) + +CACHE_T = 2 + +is_first_frame = contextvars.ContextVar("is_first_frame", default=False) +feat_cache = contextvars.ContextVar("feat_cache", default=None) +feat_idx = contextvars.ContextVar("feat_idx", default=0) +first_chunk = contextvars.ContextVar("first_chunk", default=None) + +bind_context(is_first_frame, 
@contextmanager
def forward_context(
    first_frame_arg=False, feat_cache_arg=None, feat_idx_arg=None, first_chunk_arg=None
):
    """Temporarily bind the module-level forward-pass context variables.

    Sets ``is_first_frame``, ``feat_cache``, ``feat_idx`` and ``first_chunk``
    to the given values for the duration of the ``with`` body and restores the
    previous values afterwards, also when the body raises.
    """
    bindings = (
        (is_first_frame, first_frame_arg),
        (feat_cache, feat_cache_arg),
        (feat_idx, feat_idx_arg),
        (first_chunk, first_chunk_arg),
    )
    # Keep each token next to its variable so it can be reset on exit.
    tokens = [(var, var.set(value)) for var, value in bindings]
    try:
        yield
    finally:
        for var, token in tokens:
            var.reset(token)
class WanResample(nn.Module):
    r"""
    A custom resampling module for 2D and 3D data.

    Args:
        dim (int): The number of input/output channels.
        mode (str): The resampling mode. One of:
            - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
            - 'upsample3d': 'upsample2d' plus a causal temporal convolution (``time_conv``).
            - 'downsample2d': 2D downsampling with zero-padding and a stride-2 convolution.
            - 'downsample3d': 'downsample2d' plus a strided causal temporal convolution (``time_conv``).
            Any other value (e.g. 'none') yields an identity ``resample``.
        upsample_out_dim (int, optional): Output channels of the upsampling
            convolution. Defaults to ``dim // 2``.
    """

    def __init__(
        self, dim: int, mode: str, upsample_out_dim: int | None = None
    ) -> None:
        super().__init__()
        self.dim = dim
        self.mode = mode

        # default to dim // 2
        if upsample_out_dim is None:
            upsample_out_dim = dim // 2

        # The 2d and 3d variants share the same spatial path; the 3d variants
        # only add ``time_conv`` afterwards.
        if mode in ("upsample2d", "upsample3d"):
            self.resample = nn.Sequential(
                WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
                nn.Conv2d(dim, upsample_out_dim, 3, padding=1),
            )
            if mode == "upsample3d":
                self.time_conv = WanCausalConv3d(
                    dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)
                )
        elif mode in ("downsample2d", "downsample3d"):
            # Pad one column on the right and one row at the bottom before the
            # unpadded stride-2 convolution.
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
            )
            if mode == "downsample3d":
                self.time_conv = WanCausalConv3d(
                    dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
                )
        else:
            self.resample = nn.Identity()

    def forward(self, x):
        return resample_forward(self, x)
class WanResidualBlock(nn.Module):
    r"""
    Residual block built from causal 3D convolutions.

    Args:
        in_dim (int): Number of channels of the incoming tensor.
        out_dim (int): Number of channels produced by the block.
        dropout (float, optional): Probability of the dropout layer. Default is 0.0.
        non_linearity (str, optional): Activation name passed to ``get_act_fn``. Default is "silu".
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        dropout: float = 0.0,
        non_linearity: str = "silu",
    ) -> None:
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.nonlinearity = get_act_fn(non_linearity)

        # norm -> conv, twice, with dropout ahead of the second convolution
        self.norm1 = WanRMS_norm(in_dim, images=False)
        self.conv1 = WanCausalConv3d(in_dim, out_dim, 3, padding=1)
        self.norm2 = WanRMS_norm(out_dim, images=False)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = WanCausalConv3d(out_dim, out_dim, 3, padding=1)

        # A 1x1x1 projection is only needed when the channel count changes.
        if in_dim == out_dim:
            self.conv_shortcut = nn.Identity()
        else:
            self.conv_shortcut = WanCausalConv3d(in_dim, out_dim, 1)

    def forward(self, x):
        return residual_block_forward(self, x)


class WanAttentionBlock(nn.Module):
    r"""
    Causal self-attention with a single head.

    Args:
        dim (int): The number of channels in the input tensor.
    """

    def __init__(self, dim) -> None:
        super().__init__()
        self.dim = dim

        # normalisation, fused q/k/v projection and output projection
        self.norm = WanRMS_norm(dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        self.proj = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        return attention_block_forward(self, x)
class WanMidBlock(nn.Module):
    """
    Middle block for WanVAE encoder and decoder.

    Holds ``num_layers + 1`` residual blocks with one attention block between
    each consecutive pair.

    Args:
        dim (int): Number of input/output channels.
        dropout (float): Dropout rate.
        non_linearity (str): Type of non-linearity to use.
        num_layers (int): Number of attention blocks.
    """

    def __init__(
        self,
        dim: int,
        dropout: float = 0.0,
        non_linearity: str = "silu",
        num_layers: int = 1,
    ):
        super().__init__()
        self.dim = dim

        # One leading residual block, then (attention, residual) per layer.
        res_layers = [WanResidualBlock(dim, dim, dropout, non_linearity)]
        attn_layers = []
        for _ in range(num_layers):
            attn_layers.append(WanAttentionBlock(dim))
            res_layers.append(WanResidualBlock(dim, dim, dropout, non_linearity))

        self.attentions = nn.ModuleList(attn_layers)
        self.resnets = nn.ModuleList(res_layers)

        self.gradient_checkpointing = False

    def forward(self, x):
        return mid_block_forward(self, x)


class WanResidualDownBlock(nn.Module):
    """
    Encoder stage with an ``AvgDown3D`` shortcut.

    Args:
        in_dim: Input channels.
        out_dim: Output channels.
        dropout: Dropout rate of the residual blocks.
        num_res_blocks: Number of residual blocks on the main path.
        temperal_downsample: Whether the stage also downsamples in time.
        down_flag: Whether the stage downsamples at all.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        dropout,
        num_res_blocks,
        temperal_downsample=False,
        down_flag=False,
    ):
        super().__init__()

        # Shortcut path with downsample
        time_factor = 2 if temperal_downsample else 1
        space_factor = 2 if down_flag else 1
        self.avg_shortcut = AvgDown3D(
            in_dim,
            out_dim,
            factor_t=time_factor,
            factor_s=space_factor,
        )

        # Main path: only the first residual block changes the channel count.
        self.resnets = nn.ModuleList(
            [
                WanResidualBlock(in_dim if k == 0 else out_dim, out_dim, dropout)
                for k in range(num_res_blocks)
            ]
        )

        # Optional trailing downsample block
        self.downsampler = None
        if down_flag:
            mode = "downsample3d" if temperal_downsample else "downsample2d"
            self.downsampler = WanResample(out_dim, mode=mode)

    def forward(self, x):
        return residual_down_block_forward(self, x)
class WanEncoder3d(nn.Module):
    r"""
    A 3D encoder module.

    Args:
        in_channels (int): Number of channels of the input tensor.
        dim (int): The base number of channels in the first layer.
        z_dim (int): The dimensionality of the latent space.
        dim_mult (list of int): Multipliers for the number of channels in each block.
        num_res_blocks (int): Number of residual blocks in each block.
        attn_scales (list of float): Scales at which to apply attention mechanisms.
        temperal_downsample (list of bool): Whether to downsample temporally in each block.
        dropout (float): Dropout rate for the dropout layers.
        non_linearity (str): Type of non-linearity to use.
        is_residual (bool): Build each stage as a residual down block (Wan 2.2 VAE).
        use_parallel_encode (bool): Shard the height axis across the
            sequence-parallel group when that group has more than one rank.
    """

    def __init__(
        self,
        in_channels: int = 3,
        dim=128,
        z_dim=4,
        dim_mult=(1, 2, 4, 4),
        num_res_blocks=2,
        attn_scales=(),
        temperal_downsample=(True, True, False),
        dropout=0.0,
        non_linearity: str = "silu",
        is_residual: bool = False,  # wan 2.2 vae use a residual downblock
        use_parallel_encode: bool = False,
    ):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        dim_mult = list(dim_mult)
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = list(attn_scales)
        self.temperal_downsample = list(temperal_downsample)
        self.nonlinearity = get_act_fn(non_linearity)
        self.use_parallel_encode = use_parallel_encode
        self.downsample_count = max(len(dim_mult) - 1, 0)

        # Query the sequence-parallel layout once; without an initialised
        # process group the encoder behaves as a single rank.
        if dist.is_initialized():
            self.world_size = get_sp_world_size()
            self.rank = get_sp_parallel_rank()
        else:
            self.world_size = 1
            self.rank = 0

        # dimensions
        dims = [dim * u for u in [1] + dim_mult]
        scale = 1.0

        # Pick the height-sharded building blocks only when they are needed.
        if use_parallel_encode and self.world_size > 1:
            CausalConv3d = WanDistCausalConv3d
            ResidualDownBlock = WanDistResidualDownBlock
            ResidualBlock = WanDistResidualBlock
            AttentionBlock = WanDistAttentionBlock
            Resample = WanDistResample
            MidBlock = WanDistMidBlock
        else:
            CausalConv3d = WanCausalConv3d
            ResidualDownBlock = WanResidualDownBlock
            ResidualBlock = WanResidualBlock
            AttentionBlock = WanAttentionBlock
            Resample = WanResample
            MidBlock = WanMidBlock

        # init block
        self.conv_in = CausalConv3d(in_channels, dims[0], 3, padding=1)

        # downsample blocks
        last_stage = len(dim_mult) - 1
        self.down_blocks = nn.ModuleList([])
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:], strict=True)):
            # residual (+attention) blocks
            if is_residual:
                # The residual down block carries its own downsampler.
                self.down_blocks.append(
                    ResidualDownBlock(
                        in_dim,
                        out_dim,
                        dropout,
                        num_res_blocks,
                        temperal_downsample=(
                            temperal_downsample[i] if i != last_stage else False
                        ),
                        down_flag=i != last_stage,
                    )
                )
            else:
                for _ in range(num_res_blocks):
                    self.down_blocks.append(ResidualBlock(in_dim, out_dim, dropout))
                    if scale in attn_scales:
                        self.down_blocks.append(AttentionBlock(out_dim))
                    in_dim = out_dim

                # downsample block
                if i != last_stage:
                    mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
                    self.down_blocks.append(Resample(out_dim, mode=mode))
                    scale /= 2.0

        # Channel count after the last stage (what the loop variable
        # ``out_dim`` held on exit).
        final_dim = dims[-1]

        # middle blocks
        self.mid_block = MidBlock(final_dim, dropout, non_linearity, num_layers=1)

        # output blocks
        self.norm_out = WanRMS_norm(final_dim, images=False)
        self.conv_out = CausalConv3d(final_dim, z_dim, 3, padding=1)

        self.gradient_checkpointing = False

    def _conv_with_feat_cache(self, conv, x):
        """Run ``conv`` on ``x``, using and updating the feature cache if one is bound.

        Without a cache in the ``feat_cache`` context variable this is just
        ``conv(x)``.  Otherwise the cache slot at the current ``feat_idx`` is
        passed to ``conv``, the slot is replaced by the last ``CACHE_T`` frames
        of ``x`` (prefixed with the previous slot's last frame when fewer than
        two frames are available), and ``feat_idx`` is advanced by one.
        """
        cache = feat_cache.get()
        if cache is None:
            return conv(x)

        idx = feat_idx.get()
        cache_x = x[:, :, -CACHE_T:, :, :].clone()
        if cache_x.shape[2] < 2 and cache[idx] is not None:
            # cache last frame of last two chunk
            cache_x = torch.cat(
                [
                    cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
                    cache_x,
                ],
                dim=2,
            )
        x = conv(x, cache[idx])
        cache[idx] = cache_x
        feat_cache.set(cache)
        feat_idx.set(idx + 1)
        return x

    def forward(self, x):
        """Encode ``x``; with parallel encoding the height is sharded and re-gathered."""
        parallel = self.use_parallel_encode and self.world_size > 1
        expected_local_height = None
        expected_height = None
        if parallel:
            x, expected_height, expected_local_height = split_for_parallel_encode(
                x, self.downsample_count, self.world_size, self.rank
            )

        x = self._conv_with_feat_cache(self.conv_in, x)

        ## downsamples
        for layer in self.down_blocks:
            x = layer(x)

        ## middle
        if parallel:
            x = ensure_local_height(x, expected_local_height)
        x = self.mid_block(x)

        ## head
        x = self.norm_out(x)
        x = self.nonlinearity(x)
        x = self._conv_with_feat_cache(self.conv_out, x)

        if parallel:
            x = gather_and_trim_height(x, expected_height)
        return x
# adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
class WanResidualUpBlock(nn.Module):
    """
    Decoder stage with an optional ``DupUp3D`` shortcut.

    Args:
        in_dim (int): Input dimension
        out_dim (int): Output dimension
        num_res_blocks (int): Number of residual blocks (one extra block is created)
        dropout (float): Dropout rate
        temperal_upsample (bool): Whether to upsample on temporal dimension
        up_flag (bool): Whether to upsample or not
        non_linearity (str): Type of non-linearity to use
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        num_res_blocks: int,
        dropout: float = 0.0,
        temperal_upsample: bool = False,
        up_flag: bool = False,
        non_linearity: str = "silu",
    ):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # Shortcut path, present only when the stage upsamples.
        self.avg_shortcut = (
            DupUp3D(
                in_dim,
                out_dim,
                factor_t=2 if temperal_upsample else 1,
                factor_s=2,
            )
            if up_flag
            else None
        )

        # num_res_blocks + 1 residual blocks; only the first changes channels.
        self.resnets = nn.ModuleList(
            [
                WanResidualBlock(
                    in_dim if k == 0 else out_dim, out_dim, dropout, non_linearity
                )
                for k in range(num_res_blocks + 1)
            ]
        )

        # Trailing upsampling layer if needed
        self.upsampler = None
        if up_flag:
            upsample_mode = "upsample3d" if temperal_upsample else "upsample2d"
            self.upsampler = WanResample(
                out_dim, mode=upsample_mode, upsample_out_dim=out_dim
            )

        self.gradient_checkpointing = False

    def forward(self, x):
        return residual_up_block_forward(self, x)


class WanUpBlock(nn.Module):
    """
    Decoder stage of residual blocks plus optional upsampling.

    Args:
        in_dim (int): Input dimension
        out_dim (int): Output dimension
        num_res_blocks (int): Number of residual blocks (one extra block is created)
        dropout (float): Dropout rate
        upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
        non_linearity (str): Type of non-linearity to use
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        num_res_blocks: int,
        dropout: float = 0.0,
        upsample_mode: str | None = None,
        non_linearity: str = "silu",
    ):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # num_res_blocks + 1 residual blocks; only the first changes channels.
        self.resnets = nn.ModuleList(
            [
                WanResidualBlock(
                    in_dim if k == 0 else out_dim, out_dim, dropout, non_linearity
                )
                for k in range(num_res_blocks + 1)
            ]
        )

        # Optional upsampling layer, kept in a ModuleList of length one.
        if upsample_mode is None:
            self.upsamplers = None
        else:
            self.upsamplers = nn.ModuleList([WanResample(out_dim, mode=upsample_mode)])

        self.gradient_checkpointing = False

    def forward(self, x):
        return up_block_forward(self, x)
+ non_linearity (str): Type of non-linearity to use. + """ + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=(1, 2, 4, 4), + num_res_blocks=2, + attn_scales=(), + temperal_upsample=(False, True, True), + dropout=0.0, + non_linearity: str = "silu", + out_channels: int = 3, + is_residual: bool = False, + use_parallel_decode: bool = False, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + dim_mult = list(dim_mult) + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = list(attn_scales) + self.temperal_upsample = list(temperal_upsample) + + self.nonlinearity = get_act_fn(non_linearity) + self.use_parallel_decode = use_parallel_decode + self.upsample_count = 0 + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + + world_size = 1 + if dist.is_initialized(): + world_size = get_sp_world_size() + + if use_parallel_decode and world_size > 1: + CausalConv3d = WanDistCausalConv3d + MidBlock = WanDistMidBlock + ResidualUpBlock = WanDistResidualUpBlock + UpBlock = WanDistUpBlock + else: + CausalConv3d = WanCausalConv3d + MidBlock = WanMidBlock + ResidualUpBlock = WanResidualUpBlock + UpBlock = WanUpBlock + + # init block + self.conv_in = CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.mid_block = MidBlock(dims[0], dropout, non_linearity, num_layers=1) + + # upsample blocks + self.upsample_count = 0 + self.up_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:], strict=True)): + # residual (+attention) blocks + if i > 0 and not is_residual: + # wan vae 2.1 + in_dim = in_dim // 2 + + # determine if we need upsampling + up_flag = i != len(dim_mult) - 1 + # determine upsampling mode, if not upsampling, set to None + upsample_mode = None + if up_flag and temperal_upsample[i]: + upsample_mode = "upsample3d" + elif up_flag: + upsample_mode = "upsample2d" + + # Create and add the upsampling block + if is_residual: + up_block = ResidualUpBlock( + 
in_dim=in_dim, + out_dim=out_dim, + num_res_blocks=num_res_blocks, + dropout=dropout, + temperal_upsample=temperal_upsample[i] if up_flag else False, + up_flag=up_flag, + non_linearity=non_linearity, + ) + else: + up_block = UpBlock( + in_dim=in_dim, + out_dim=out_dim, + num_res_blocks=num_res_blocks, + dropout=dropout, + upsample_mode=upsample_mode, + non_linearity=non_linearity, + ) + self.up_blocks.append(up_block) + if up_flag: + self.upsample_count += 1 + + # output blocks + self.norm_out = WanRMS_norm(out_dim, images=False) + self.conv_out = CausalConv3d(out_dim, out_channels, 3, padding=1) + + self.gradient_checkpointing = False + self.world_size = 1 + self.rank = 0 + if dist.is_initialized(): + self.world_size = get_sp_world_size() + self.rank = get_sp_parallel_rank() + + def forward(self, x): + expected_height = None + if self.use_parallel_decode and self.world_size > 1: + x, expected_height = split_for_parallel_decode( + x, self.upsample_count, self.world_size, self.rank + ) + + ## conv1 + _feat_cache = feat_cache.get() + _feat_idx = feat_idx.get() + if _feat_cache is not None: + idx = _feat_idx + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and _feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat( + [ + _feat_cache[idx][:, :, -1, :, :] + .unsqueeze(2) + .to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv_in(x, _feat_cache[idx]) + _feat_cache[idx] = cache_x + _feat_idx += 1 + feat_cache.set(_feat_cache) + feat_idx.set(_feat_idx) + else: + x = self.conv_in(x) + + ## middle + x = self.mid_block(x) + + ## upsamples + for up_block in self.up_blocks: + x = up_block(x) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + _feat_cache = feat_cache.get() + _feat_idx = feat_idx.get() + if _feat_cache is not None: + idx = _feat_idx + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and _feat_cache[idx] is not None: + # cache last frame of last two chunk + 
cache_x = torch.cat( + [ + _feat_cache[idx][:, :, -1, :, :] + .unsqueeze(2) + .to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv_out(x, _feat_cache[idx]) + _feat_cache[idx] = cache_x + _feat_idx += 1 + feat_cache.set(_feat_cache) + feat_idx.set(_feat_idx) + else: + x = self.conv_out(x) + + if self.use_parallel_decode and self.world_size > 1: + x = gather_and_trim_height(x, expected_height) + return x + + +def patchify(x, patch_size): + if patch_size == 1: + return x + + if x.dim() == 4: + x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange( + x, + "b c f (h q) (w r) -> b (c r q) f h w", + q=patch_size, + r=patch_size, + ) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + + return x + + +def unpatchify(x, patch_size): + if patch_size == 1: + return x + + if x.dim() == 4: + x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange( + x, + "b (c r q) f h w -> b c f (h q) (w r)", + q=patch_size, + r=patch_size, + ) + + return x + + +class AutoencoderKLWan(ParallelTiledVAE): + r""" + A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. + Introduced in [Wan 2.1]. 
+ """ + + _supports_gradient_checkpointing = False + + def __init__( + self, + config: WanVAEConfig, + ) -> None: + nn.Module.__init__(self) + ParallelTiledVAE.__init__(self, config) + + self.z_dim = config.z_dim + self.temperal_downsample = list(config.temperal_downsample) + self.temperal_upsample = list(config.temperal_downsample)[::-1] + + if config.decoder_base_dim is None: + decoder_base_dim = config.base_dim + else: + decoder_base_dim = config.decoder_base_dim + + self.latents_mean = list(config.latents_mean) + self.latents_std = list(config.latents_std) + self.shift_factor = config.shift_factor + self.use_parallel_encode = getattr(config, "use_parallel_encode", False) + self.use_parallel_decode = getattr(config, "use_parallel_decode", False) + + if config.load_encoder: + self.encoder = WanEncoder3d( + in_channels=config.in_channels, + dim=config.base_dim, + z_dim=self.z_dim * 2, + dim_mult=config.dim_mult, + num_res_blocks=config.num_res_blocks, + attn_scales=config.attn_scales, + temperal_downsample=self.temperal_downsample, + dropout=config.dropout, + is_residual=config.is_residual, + use_parallel_encode=self.use_parallel_encode, + ) + self.quant_conv = WanCausalConv3d(self.z_dim * 2, self.z_dim * 2, 1) + self.post_quant_conv = WanCausalConv3d(self.z_dim, self.z_dim, 1) + + if config.load_decoder: + self.decoder = WanDecoder3d( + dim=decoder_base_dim, + z_dim=self.z_dim, + dim_mult=config.dim_mult, + num_res_blocks=config.num_res_blocks, + attn_scales=config.attn_scales, + temperal_upsample=self.temperal_upsample, + dropout=config.dropout, + out_channels=config.out_channels, + is_residual=config.is_residual, + use_parallel_decode=self.use_parallel_decode, + ) + + self.use_feature_cache = config.use_feature_cache + + def clear_cache(self) -> None: + + def _count_conv3d(model) -> int: + count = 0 + for m in model.modules(): + if isinstance(m, WanCausalConv3d) or isinstance(m, WanDistCausalConv3d): + count += 1 + return count + + if self.config.load_decoder: 
+ self._conv_num = _count_conv3d(self.decoder) + self._conv_idx = 0 + self._feat_map = [None] * self._conv_num + # cache encode + if self.config.load_encoder: + self._enc_conv_num = _count_conv3d(self.encoder) + self._enc_conv_idx = 0 + self._enc_feat_map = [None] * self._enc_conv_num + + def encode(self, x: torch.Tensor) -> torch.Tensor: + if self.use_feature_cache: + self.clear_cache() + if self.config.patch_size is not None: + x = patchify(x, patch_size=self.config.patch_size) + with forward_context( + feat_cache_arg=self._enc_feat_map, feat_idx_arg=self._enc_conv_idx + ): + t = x.shape[2] + iter_ = 1 + (t - 1) // 4 + for i in range(iter_): + feat_idx.set(0) + if i == 0: + out = self.encoder(x[:, :, :1, :, :]) + else: + out_ = self.encoder(x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :]) + out = torch.cat([out, out_], 2) + enc = self.quant_conv(out) + mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :] + enc = torch.cat([mu, logvar], dim=1) + enc = DiagonalGaussianDistribution(enc) + self.clear_cache() + else: + for block in self.encoder.down_blocks: + if isinstance(block, WanResample) and block.mode == "downsample3d": + _padding = list(block.time_conv._padding) + _padding[4] = 2 + block.time_conv._padding = tuple(_padding) + enc = ParallelTiledVAE.encode(self, x) + + return enc + + def _encode(self, x: torch.Tensor, first_frame=False) -> torch.Tensor: + with forward_context(first_frame_arg=first_frame): + out = self.encoder(x) + enc = self.quant_conv(out) + mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :] + enc = torch.cat([mu, logvar], dim=1) + return enc + + def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + first_frame = x[:, :, 0, :, :].unsqueeze(2) + first_frame = self._encode(first_frame, first_frame=True) + + enc = ParallelTiledVAE.tiled_encode(self, x) + enc = enc[:, :, 1:] + enc = torch.cat([first_frame, enc], dim=2) + return enc + + def spatial_tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + 
first_frame = x[:, :, 0, :, :].unsqueeze(2) + first_frame = self._encode(first_frame, first_frame=True) + + enc = ParallelTiledVAE.spatial_tiled_encode(self, x) + enc = enc[:, :, 1:] + enc = torch.cat([first_frame, enc], dim=2) + return enc + + def decode(self, z: torch.Tensor) -> torch.Tensor: + if self.use_feature_cache: + self.clear_cache() + iter_ = z.shape[2] + x = self.post_quant_conv(z) + with forward_context( + feat_cache_arg=self._feat_map, feat_idx_arg=self._conv_idx + ): + for i in range(iter_): + feat_idx.set(0) + if i == 0: + first_chunk.set(True) + out = self.decoder(x[:, :, i : i + 1, :, :]) + else: + first_chunk.set(False) + out_ = self.decoder(x[:, :, i : i + 1, :, :]) + out = torch.cat([out, out_], 2) + + if self.config.patch_size is not None: + out = unpatchify(out, patch_size=self.config.patch_size) + + out = out.float() + out = torch.clamp(out, min=-1.0, max=1.0) + self.clear_cache() + else: + out = ParallelTiledVAE.decode(self, z) + + return out + + def _decode(self, z: torch.Tensor, first_frame=False) -> torch.Tensor: + x = self.post_quant_conv(z) + with forward_context(first_frame_arg=first_frame): + out = self.decoder(x) + + out = torch.clamp(out, min=-1.0, max=1.0) + + return out + + def tiled_decode(self, z: torch.Tensor) -> torch.Tensor: + self.blend_num_frames *= 2 + dec = ParallelTiledVAE.tiled_decode(self, z) + start_frame_idx = self.temporal_compression_ratio - 1 + dec = dec[:, :, start_frame_idx:] + return dec + + def spatial_tiled_decode(self, z: torch.Tensor) -> torch.Tensor: + dec = ParallelTiledVAE.spatial_tiled_decode(self, z) + start_frame_idx = self.temporal_compression_ratio - 1 + dec = dec[:, :, start_frame_idx:] + return dec + + def parallel_tiled_decode(self, z: torch.FloatTensor) -> torch.FloatTensor: + self.blend_num_frames *= 2 + dec = ParallelTiledVAE.parallel_tiled_decode(self, z) + start_frame_idx = self.temporal_compression_ratio - 1 + dec = dec[:, :, start_frame_idx:] + return dec + + def forward( + self, + 
sample: torch.Tensor, + sample_posterior: bool = False, + generator: torch.Generator | None = None, + ) -> torch.Tensor: + """ + Args: + sample (`torch.Tensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z) + return dec + + +EntryClass = AutoencoderKLWan diff --git a/sglang/python/sglang/multimodal_gen/runtime/models/vision_utils.py b/sglang/python/sglang/multimodal_gen/runtime/models/vision_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2086170d59cb96421d261e85f34869c82bb55d76 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/models/vision_utils.py @@ -0,0 +1,297 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +import os +import tempfile +from collections.abc import Callable +from urllib.parse import unquote, urlparse + +import imageio +import numpy as np +import PIL.Image +import PIL.ImageOps +import requests +import torch +from packaging import version + +if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): + PIL_INTERPOLATION = { + "linear": PIL.Image.Resampling.BILINEAR, + "bilinear": PIL.Image.Resampling.BILINEAR, + "bicubic": PIL.Image.Resampling.BICUBIC, + "lanczos": PIL.Image.Resampling.LANCZOS, + "nearest": PIL.Image.Resampling.NEAREST, + } +else: + PIL_INTERPOLATION = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + "nearest": PIL.Image.NEAREST, + } + + +def pil_to_numpy(images: list[PIL.Image.Image] | PIL.Image.Image) -> np.ndarray: + r""" + Convert a PIL image or a list of PIL images to NumPy arrays. 
+ + Args: + images (`PIL.Image.Image` or `List[PIL.Image.Image]`): + The PIL image or list of images to convert to NumPy format. + + Returns: + `np.ndarray`: + A NumPy array representation of the images. + """ + if not isinstance(images, list): + images = [images] + images = [np.array(image).astype(np.float32) / 255.0 for image in images] + images_arr: np.ndarray = np.stack(images, axis=0) + + return images_arr + + +def numpy_to_pt(images: np.ndarray) -> torch.Tensor: + r""" + Convert a NumPy image to a PyTorch tensor. + + Args: + images (`np.ndarray`): + The NumPy image array to convert to PyTorch format. + + Returns: + `torch.Tensor`: + A PyTorch tensor representation of the images. + """ + if images.ndim == 3: + images = images[..., None] + + images = torch.from_numpy(images.transpose(0, 3, 1, 2)) + return images + + +def normalize(images: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor: + r""" + Normalize an image array to [-1,1]. + + Args: + images (`np.ndarray` or `torch.Tensor`): + The image array to normalize. + + Returns: + `np.ndarray` or `torch.Tensor`: + The normalized image array. + """ + return 2.0 * images - 1.0 + + +# adapted from diffusers.utils import load_image +def load_image( + image: str | PIL.Image.Image, + convert_method: Callable[[PIL.Image.Image], PIL.Image.Image] | None = None, +) -> PIL.Image.Image: + """ + Loads `image` to a PIL Image. + + Args: + image (`str` or `PIL.Image.Image`): + The image to convert to the PIL Image format. + convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], *optional*): + A conversion method to apply to the image after loading it. When set to `None` the image will be converted + "RGB". + """ + if isinstance(image, str): + if image.startswith("http://") or image.startswith("https://"): + image = PIL.Image.open(requests.get(image, stream=True).raw) + elif os.path.isfile(image): + image = PIL.Image.open(image) + else: + raise ValueError( + f"Incorrect path or URL. 
URLs must start with `http://` or `https://`, and {image} is not a valid path." + ) + elif isinstance(image, PIL.Image.Image): + image = image + else: + raise ValueError( + "Incorrect format used for the image. Should be a URL linking to an image, a local path, or a PIL image." + ) + + image = PIL.ImageOps.exif_transpose(image) + + if convert_method is not None: + image = convert_method(image) + else: + image = image.convert("RGB") + + return image + + +# adapted from diffusers.utils import load_video +def load_video( + video: str, + convert_method: ( + Callable[[list[PIL.Image.Image]], list[PIL.Image.Image]] | None + ) = None, +) -> list[PIL.Image.Image]: + """ + Loads `video` to a list of PIL Image. + Args: + video (`str`): + A URL or Path to a video to convert to a list of PIL Image format. + convert_method (Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]], *optional*): + A conversion method to apply to the video after loading it. When set to `None` the images will be converted + to "RGB". + Returns: + `List[PIL.Image.Image]`: + The video as a list of PIL images. + """ + is_url = video.startswith("http://") or video.startswith("https://") + is_file = os.path.isfile(video) + was_tempfile_created = False + + if not (is_url or is_file): + raise ValueError( + f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {video} is not a valid path." + ) + + if is_url: + response = requests.get(video, stream=True) + if response.status_code != 200: + raise ValueError( + f"Failed to download video. 
Status code: {response.status_code}" + ) + + parsed_url = urlparse(video) + file_name = os.path.basename(unquote(parsed_url.path)) + + suffix = os.path.splitext(file_name)[1] or ".mp4" + with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp_file: + video_path = temp_file.name + video_data = response.iter_content(chunk_size=8192) + for chunk in video_data: + temp_file.write(chunk) + + video = video_path + + pil_images = [] + if video.endswith(".gif"): + gif = PIL.Image.open(video) + try: + while True: + pil_images.append(gif.copy()) + gif.seek(gif.tell() + 1) + except EOFError: + pass + + else: + try: + imageio.plugins.ffmpeg.get_exe() + except AttributeError: + raise AttributeError( + "`Unable to find an ffmpeg installation on your machine. Please install via `pip install imageio-ffmpeg" + ) from None + + with imageio.get_reader(video) as reader: + # Read all frames + for frame in reader: + pil_images.append(PIL.Image.fromarray(frame)) + + if was_tempfile_created: + os.remove(video_path) + + if convert_method is not None: + pil_images = convert_method(pil_images) + + return pil_images + + +def get_default_height_width( + image: PIL.Image.Image | np.ndarray | torch.Tensor, + vae_scale_factor: int, + height: int | None = None, + width: int | None = None, +) -> tuple[int, int]: + r""" + Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`. + + Args: + image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`): + The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it + should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch + tensor, it should have shape `[batch, channels, height, width]`. + height (`Optional[int]`, *optional*, defaults to `None`): + The height of the preprocessed image. If `None`, the height of the `image` input will be used. 
+ width (`Optional[int]`, *optional*, defaults to `None`): + The width of the preprocessed image. If `None`, the width of the `image` input will be used. + + Returns: + `Tuple[int, int]`: + A tuple containing the height and width, both resized to the nearest integer multiple of + `vae_scale_factor`. + """ + + if height is None: + if isinstance(image, PIL.Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[2] + else: + height = image.shape[1] + + if width is None: + if isinstance(image, PIL.Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[3] + else: + width = image.shape[2] + + width, height = ( + x - x % vae_scale_factor for x in (width, height) + ) # resize to integer multiple of vae_scale_factor + + return height, width + + +def resize( + image: PIL.Image.Image | np.ndarray | torch.Tensor, + height: int, + width: int, + resize_mode: str = "default", # "default", "fill", "crop" + resample: str = "lanczos", +) -> PIL.Image.Image | np.ndarray | torch.Tensor: + """ + Resize image. + + Args: + image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`): + The image input, can be a PIL image, numpy array or pytorch tensor. + height (`int`): + The height to resize to. + width (`int`): + The width to resize to. + resize_mode (`str`, *optional*, defaults to `default`): + The resize mode to use, can be one of `default` or `fill`. If `default`, will resize the image to fit + within the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, + will resize the image to fit within the specified width and height, maintaining the aspect ratio, and + then center the image within the dimensions, filling empty with data from image. If `crop`, will resize + the image to fit within the specified width and height, maintaining the aspect ratio, and then center + the image within the dimensions, cropping the excess. 
Note that resize_mode `fill` and `crop` are only + supported for PIL image input. + + Returns: + `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`: + The resized image. + """ + if resize_mode != "default" and not isinstance(image, PIL.Image.Image): + raise ValueError( + f"Only PIL image input is supported for resize_mode {resize_mode}" + ) + assert isinstance(image, PIL.Image.Image) + if resize_mode == "default": + image = image.resize((width, height), resample=PIL_INTERPOLATION[resample]) + else: + raise ValueError(f"resize_mode {resize_mode} is not supported") + return image diff --git a/sglang/python/sglang/multimodal_gen/runtime/models/vocoder/ltx_2_vocoder.py b/sglang/python/sglang/multimodal_gen/runtime/models/vocoder/ltx_2_vocoder.py new file mode 100644 index 0000000000000000000000000000000000000000..82ad20d2a31500a35516f4ed1e933cf0072ccdf7 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/models/vocoder/ltx_2_vocoder.py @@ -0,0 +1,193 @@ +import math +from abc import ABC +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from sglang.multimodal_gen.configs.models.vocoder.ltx_vocoder import LTXVocoderConfig + + +class ResBlock(nn.Module): + def __init__( + self, + channels: int, + kernel_size: int = 3, + stride: int = 1, + dilations: Tuple[int, ...] 
= (1, 3, 5), + leaky_relu_negative_slope: float = 0.1, + padding_mode: str = "same", + ): + super().__init__() + self.dilations = dilations + self.negative_slope = leaky_relu_negative_slope + + self.convs1 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + stride=stride, + dilation=dilation, + padding=padding_mode, + ) + for dilation in dilations + ] + ) + + self.convs2 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + stride=stride, + dilation=1, + padding=padding_mode, + ) + for _ in range(len(dilations)) + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for conv1, conv2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, negative_slope=self.negative_slope) + xt = conv1(xt) + xt = F.leaky_relu(xt, negative_slope=self.negative_slope) + xt = conv2(xt) + x = x + xt + return x + + +class LTX2Vocoder(ABC, nn.Module): + r""" + LTX 2.0 vocoder for converting generated mel spectrograms back to audio waveforms. + """ + + def __init__( + self, + config: LTXVocoderConfig, + ): + super().__init__() + self.config = config + self.sample_rate = ( + getattr(config.arch_config, "sample_rate", None) + or getattr(config.arch_config, "sampling_rate", None) + or getattr(config.arch_config, "audio_sample_rate", None) + ) + + in_channels = config.arch_config.in_channels + hidden_channels = config.arch_config.hidden_channels + out_channels = config.arch_config.out_channels + upsample_kernel_sizes = config.arch_config.upsample_kernel_sizes + upsample_factors = config.arch_config.upsample_factors + resnet_kernel_sizes = config.arch_config.resnet_kernel_sizes + resnet_dilations = config.arch_config.resnet_dilations + leaky_relu_negative_slope = config.arch_config.leaky_relu_negative_slope + + self.num_upsample_layers = len(upsample_kernel_sizes) + self.resnets_per_upsample = len(resnet_kernel_sizes) + self.out_channels = out_channels + self.total_upsample_factor = math.prod(upsample_factors) + self.negative_slope = 
leaky_relu_negative_slope + + if self.num_upsample_layers != len(upsample_factors): + raise ValueError( + f"`upsample_kernel_sizes` and `upsample_factors` should be lists of the same length but are length" + f" {self.num_upsample_layers} and {len(upsample_factors)}, respectively." + ) + + if self.resnets_per_upsample != len(resnet_dilations): + raise ValueError( + f"`resnet_kernel_sizes` and `resnet_dilations` should be lists of the same length but are length" + f" {len(self.resnets_per_upsample)} and {len(resnet_dilations)}, respectively." + ) + + self.conv_in = nn.Conv1d( + in_channels, hidden_channels, kernel_size=7, stride=1, padding=3 + ) + + self.upsamplers = nn.ModuleList() + self.resnets = nn.ModuleList() + input_channels = hidden_channels + for i, (stride, kernel_size) in enumerate( + zip(upsample_factors, upsample_kernel_sizes) + ): + output_channels = input_channels // 2 + self.upsamplers.append( + nn.ConvTranspose1d( + input_channels, # hidden_channels // (2 ** i) + output_channels, # hidden_channels // (2 ** (i + 1)) + kernel_size, + stride=stride, + padding=(kernel_size - stride) // 2, + ) + ) + + for kernel_size, dilations in zip(resnet_kernel_sizes, resnet_dilations): + self.resnets.append( + ResBlock( + output_channels, + kernel_size, + dilations=dilations, + leaky_relu_negative_slope=leaky_relu_negative_slope, + ) + ) + input_channels = output_channels + + self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3) + + def forward( + self, hidden_states: torch.Tensor, time_last: bool = False + ) -> torch.Tensor: + r""" + Forward pass of the vocoder. + + Args: + hidden_states (`torch.Tensor`): + Input Mel spectrogram tensor of shape `(batch_size, num_channels, time, num_mel_bins)` if `time_last` + is `False` (the default) or shape `(batch_size, num_channels, num_mel_bins, time)` if `time_last` is + `True`. 
+ time_last (`bool`, *optional*, defaults to `False`): + Whether the last dimension of the input is the time/frame dimension or the Mel bins dimension. + + Returns: + `torch.Tensor`: + Audio waveform tensor of shape (batch_size, out_channels, audio_length) + """ + + # Ensure that the time/frame dimension is last + if not time_last: + hidden_states = hidden_states.transpose(2, 3) + # Combine channels and frequency (mel bins) dimensions + hidden_states = hidden_states.flatten(1, 2) + + hidden_states = self.conv_in(hidden_states) + + for i in range(self.num_upsample_layers): + hidden_states = F.leaky_relu( + hidden_states, negative_slope=self.negative_slope + ) + hidden_states = self.upsamplers[i](hidden_states) + + # Run all resnets in parallel on hidden_states + start = i * self.resnets_per_upsample + end = (i + 1) * self.resnets_per_upsample + resnet_outputs = torch.stack( + [self.resnets[j](hidden_states) for j in range(start, end)], dim=0 + ) + + hidden_states = torch.mean(resnet_outputs, dim=0) + + # NOTE: unlike the first leaky ReLU, this leaky ReLU is set to use the default F.leaky_relu negative slope of + # 0.01 (whereas the others usually use a slope of 0.1). 
Not sure if this is intended + hidden_states = F.leaky_relu(hidden_states, negative_slope=0.01) + hidden_states = self.conv_out(hidden_states) + hidden_states = torch.tanh(hidden_states) + + return hidden_states + + +EntryClass = LTX2Vocoder diff --git a/sglang/python/sglang/multimodal_gen/runtime/pipelines/__init__.py b/sglang/python/sglang/multimodal_gen/runtime/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..af2eb7d103a81378d30056cf536353bd24621a5e --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/pipelines/__init__.py @@ -0,0 +1 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo diff --git a/sglang/python/sglang/multimodal_gen/runtime/pipelines/comfyui_flux_pipeline.py b/sglang/python/sglang/multimodal_gen/runtime/pipelines/comfyui_flux_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..f6910187712c31db159faaa824ab32b46fa6cf35 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/pipelines/comfyui_flux_pipeline.py @@ -0,0 +1,687 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo +# SPDX-License-Identifier: Apache-2.0 + +import os +import re +from typing import Any, Generator + +import torch + +from sglang.multimodal_gen.configs.models.dits.flux import FluxConfig +from sglang.multimodal_gen.runtime.distributed import get_local_torch_device +from sglang.multimodal_gen.runtime.models.registry import ModelRegistry +from sglang.multimodal_gen.runtime.models.schedulers.scheduling_comfyui_passthrough import ( + ComfyUIPassThroughScheduler, +) +from sglang.multimodal_gen.runtime.pipelines_core import LoRAPipeline +from sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import ( + ComposedPipelineBase, +) +from sglang.multimodal_gen.runtime.pipelines_core.stages import ( + ComfyUILatentPreparationStage, + DenoisingStage, +) +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from 
logger = init_logger(__name__)


class ComfyUIFluxPipeline(LoRAPipeline, ComposedPipelineBase):
    """
    Simplified pipeline for ComfyUI integration with only denoising stage.

    This pipeline requires pre-processed inputs:
    - prompt_embeds: Pre-encoded text embeddings (list of tensors)
    - negative_prompt_embeds: Pre-encoded negative prompt embeddings (if using CFG)
    - latents: Optional initial noise latents (will be generated if not provided)

    Usage:
        generator = DiffGenerator.from_pretrained(
            model_path="path/to/model",
            pipeline_class_name="ComfyUIFluxPipeline",
            device="cuda",
        )
    """

    pipeline_name = "ComfyUIFluxPipeline"

    # Configuration classes for safetensors files without model_index.json
    from sglang.multimodal_gen.configs.pipeline_configs.flux import FluxPipelineConfig
    from sglang.multimodal_gen.configs.sample.flux import FluxSamplingParams

    pipeline_config_cls = FluxPipelineConfig
    sampling_params_cls = FluxSamplingParams

    _required_config_modules = [
        "transformer",
        "scheduler",
    ]

    # Checkpoint names of the final adaLN modulation layer, whose two halves
    # are stored in the opposite order from what the SGLang model expects.
    _ADALN_SWAP_NAMES = (
        "final_layer.adaLN_modulation.1.weight",
        "final_layer.adaLN_modulation.1.bias",
    )

    def initialize_pipeline(self, server_args: ServerArgs):
        """
        Initialize the pipeline with ComfyUI pass-through scheduler.
        This scheduler does not modify latents, allowing ComfyUI to handle denoising.

        Also runs ``vae_config.post_init()`` once (when the pipeline config has
        a VAE config that has not been post-initialized yet), even though no
        VAE module is loaded by this pipeline.
        """
        self.modules["scheduler"] = ComfyUIPassThroughScheduler(
            num_train_timesteps=1000
        )

        if hasattr(server_args.pipeline_config, "vae_config"):
            vae_config = server_args.pipeline_config.vae_config
            if hasattr(vae_config, "post_init") and not hasattr(
                vae_config, "_post_init_called"
            ):
                vae_config.post_init()
                logger.info(
                    "Called vae_config.post_init() to set spatial_compression_ratio. "
                    f"spatial_compression_ratio={vae_config.arch_config.spatial_compression_ratio}"
                )

    def load_modules(
        self,
        server_args: ServerArgs,
        loaded_modules: dict[str, torch.nn.Module] | None = None,
    ) -> dict[str, Any]:
        """
        Load modules for ComfyUIFluxPipeline.

        If model_path is a safetensors file, load transformer directly from it
        without requiring model_index.json. Otherwise, fall back to default loading.

        Args:
            server_args: Server configuration forwarded to the selected loader.
            loaded_modules: Optional pre-built modules forwarded to the loader.

        Returns:
            Mapping from module name to the loaded module.
        """
        if os.path.isfile(self.model_path) and self.model_path.endswith(".safetensors"):
            logger.info(
                "Detected safetensors file, loading transformer directly from: %s",
                self.model_path,
            )
            return self._load_transformer_from_safetensors(server_args, loaded_modules)

        logger.info(
            "Model path is a directory, using default loading method: %s",
            self.model_path,
        )
        return super().load_modules(server_args, loaded_modules)

    def _load_and_convert_weights_from_safetensors(
        self,
        model_cls: type,
        dit_config: FluxConfig,
        hf_config: dict,
        safetensors_list: list[str],
        updated_mapping: dict,
        qkv_size: int,
        mlp_hidden_dim: int,
        has_guidance_embeds: bool,
        default_dtype: torch.dtype,
    ) -> tuple[torch.nn.Module, dict]:
        """
        Load and convert weights from safetensors file, then load them into the model.

        The model is instantiated on the local device, the checkpoint tensors are
        streamed through ``_convert_comfyui_weights`` and the name-mapping
        function, and each resulting tensor is copied into the matching entry of
        the model's ``state_dict``. Problems (unknown targets, shape mismatches,
        missing keys, incomplete merge groups) are logged, not raised.

        Args:
            model_cls: Transformer class to instantiate.
            dit_config: Config passed to ``model_cls`` as ``config``.
            hf_config: Dict passed to ``model_cls`` as ``hf_config``.
            safetensors_list: Paths of the safetensors files to read.
            updated_mapping: Name mapping handed to ``get_param_names_mapping``.
            qkv_size: Row count of the fused q/k/v projection.
            mlp_hidden_dim: Row count of the MLP part of ``linear1``.
            has_guidance_embeds: Whether ``guidance_in.*`` weights are kept.
            default_dtype: Default torch dtype used while building the model.

        Returns:
            ``(model, reverse_param_names_mapping)`` where the mapping goes from
            target name to ``(source_name, merge_index, num_params_to_merge)``.
        """
        from collections import defaultdict

        from sglang.multimodal_gen.runtime.loader.utils import (
            get_param_names_mapping,
            set_default_torch_dtype,
        )
        from sglang.multimodal_gen.runtime.loader.weight_utils import (
            safetensors_weights_iterator,
        )

        logger.info(
            "Converting ComfyUI Flux weights to SGLang format and loading model..."
        )

        # Create model on target device
        device = get_local_torch_device()
        with set_default_torch_dtype(default_dtype):
            model = model_cls(**{"config": dit_config, "hf_config": hf_config})
            model = model.to(device)

        # Verify model has guidance_embedder if config says it should
        has_guidance_embedder = hasattr(model.time_text_embed, "guidance_embedder")
        if has_guidance_embeds and not has_guidance_embedder:
            logger.warning(
                "Config has guidance_embeds=True but model doesn't have guidance_embedder. "
                "This may indicate a configuration mismatch."
            )
        elif not has_guidance_embeds and has_guidance_embedder:
            logger.warning(
                "Config has guidance_embeds=False but model has guidance_embedder. "
                "This may indicate a configuration mismatch."
            )

        # guidance_in.* names are part of the mapping; when the model has no
        # guidance embedder they are dropped inside _convert_comfyui_weights().
        param_names_mapping_fn = get_param_names_mapping(updated_mapping)

        converted_weights = self._convert_comfyui_weights(
            weight_iterator=safetensors_weights_iterator(safetensors_list),
            qkv_size=qkv_size,
            mlp_hidden_dim=mlp_hidden_dim,
            has_guidance_embeds=has_guidance_embeds,
        )

        model_state_dict = model.state_dict()
        missing_keys = set(model_state_dict.keys())
        unexpected_keys = []
        loaded_count = 0
        reverse_param_names_mapping = {}

        # Parts of merged parameters, keyed by target name then merge index.
        to_merge_params = defaultdict(dict)

        # Process weights incrementally: load immediately after conversion
        for source_name, tensor in converted_weights:
            target_name, merge_index, num_params_to_merge = param_names_mapping_fn(
                source_name
            )
            reverse_param_names_mapping[target_name] = (
                source_name,
                merge_index,
                num_params_to_merge,
            )

            if merge_index is not None:
                parts = to_merge_params[target_name]
                parts[merge_index] = tensor
                if len(parts) != num_params_to_merge:
                    # Still waiting for the remaining parts of this parameter.
                    continue
                tensor = torch.cat(
                    [parts[i] for i in range(num_params_to_merge)], dim=0
                )
                # Drop the collected parts so their memory can be reclaimed.
                del to_merge_params[target_name]

            if target_name not in model_state_dict:
                # Debug logging for unmapped parameters
                if "norm_out.linear" in target_name:
                    logger.warning(
                        f"norm_out.linear parameter {target_name} not found in model state_dict. "
                        f"Source: {source_name}"
                    )
                unexpected_keys.append(target_name)
                continue

            param = model_state_dict[target_name]
            # Check shape on both the direct and the merged path: copy_() would
            # silently broadcast a broadcastable-but-wrong tensor.
            if tensor.shape != param.shape:
                logger.warning(
                    f"Shape mismatch for {target_name}: "
                    f"loaded {tensor.shape} vs model {param.shape}, skipping. "
                    f"Source: {source_name}"
                )
                unexpected_keys.append(target_name)
                continue

            # Debug logging for norm_out.linear to verify mapping
            if (
                "norm_out.linear" in target_name
                or "final_layer.adaLN_modulation" in source_name
            ):
                logger.info(
                    f"Loading norm_out.linear: {source_name} -> {target_name}, "
                    f"shape: {tensor.shape}"
                )

            param.data.copy_(tensor.to(device=param.device, dtype=param.dtype))
            missing_keys.discard(target_name)
            loaded_count += 1

        if to_merge_params:
            # Groups that never received all their parts were not loaded.
            logger.warning(
                "Incomplete merged parameters (first 10): %s",
                sorted(to_merge_params)[:10],
            )

        optional_missing_keys = []
        required_missing_keys = []
        # Sorted so the log output is stable between runs.
        for key in sorted(missing_keys):
            # A missing bias is treated as optional when its weight was loaded.
            # Only the trailing ".bias" is rewritten, not every occurrence.
            if (
                key.endswith(".bias")
                and key.removesuffix(".bias") + ".weight" not in missing_keys
            ):
                optional_missing_keys.append(key)
            else:
                required_missing_keys.append(key)

        if required_missing_keys:
            logger.warning(
                f"Required missing keys (first 10): {required_missing_keys[:10]}..."
            )
        if optional_missing_keys:
            logger.info(
                f"Optional missing keys (bias parameters, {len(optional_missing_keys)} total): "
                f"These will use default values (zeros)"
            )
        if unexpected_keys:
            logger.warning(f"Unexpected keys (first 10): {unexpected_keys[:10]}...")

        logger.info(f"Successfully loaded {loaded_count} weight tensors")

        return model, reverse_param_names_mapping

    @staticmethod
    def _split_qkv_rows(
        tensor: torch.Tensor, hidden_size: int
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Slice the first ``3 * hidden_size`` rows of ``tensor`` into q, k, v views.

        Slicing the leading dimension works the same for 1-D biases and
        2-D weights, so no per-type branch is needed.
        """
        return (
            tensor[:hidden_size],
            tensor[hidden_size : 2 * hidden_size],
            tensor[2 * hidden_size : 3 * hidden_size],
        )

    def _convert_comfyui_weights(
        self,
        weight_iterator: Generator[tuple[str, torch.Tensor], None, None],
        qkv_size: int,
        mlp_hidden_dim: int,
        has_guidance_embeds: bool,
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        """
        Convert ComfyUI Flux weights to SGLang format.
        Splits fused qkv weights into to_q/to_k/to_v plus proj_mlp.
        Filters out guidance_in weights if model doesn't support guidance embeddings.
        Handles scale/shift order difference between ComfyUI and AdaLayerNormContinuous.

        Args:
            weight_iterator: Yields ``(name, tensor)`` pairs from the checkpoint.
            qkv_size: Row count of the fused q/k/v projection (``3 * hidden``).
            mlp_hidden_dim: Row count of the MLP part of ``linear1``.
            has_guidance_embeds: When False, ``guidance_in.*`` entries are dropped.

        Yields:
            ``(name, tensor)`` pairs; split tensors already carry SGLang names,
            all other names are passed through for the name mapping to handle.
        """
        hidden_size = qkv_size // 3

        for name, tensor in weight_iterator:
            if not has_guidance_embeds and name.startswith("guidance_in."):
                logger.debug(
                    f"Skipping {name} (model doesn't support guidance embeddings)"
                )
                continue

            # Split fused qkv in double blocks into separate q/k/v projections
            match = re.match(
                r"double_blocks\.(\d+)\.(img_attn|txt_attn)\.qkv\.(weight|bias)$", name
            )
            if match:
                block_idx, attn_type, param_type = match.groups()

                if tensor.shape[0] < 3 * hidden_size:
                    logger.warning(
                        f"{name} shape {tensor.shape} smaller than expected qkv size {3 * hidden_size}, skipping"
                    )
                    continue

                target_prefix = f"transformer_blocks.{block_idx}.attn"
                if attn_type == "img_attn":
                    proj_names = ("to_q", "to_k", "to_v")
                else:
                    # txt_attn corresponds to encoder projections
                    proj_names = ("add_q_proj", "add_k_proj", "add_v_proj")

                parts = self._split_qkv_rows(tensor, hidden_size)
                for proj_name, part in zip(proj_names, parts):
                    yield f"{target_prefix}.{proj_name}.{param_type}", part
                continue

            # Split fused linear1 in single blocks into q/k/v plus the MLP input
            match = re.match(r"single_blocks\.(\d+)\.linear1\.(weight|bias)$", name)
            if match:
                block_idx, param_type = match.groups()
                expected_size = qkv_size + mlp_hidden_dim

                if tensor.shape[0] < expected_size:
                    logger.warning(
                        f"linear1.{param_type} shape {tensor.shape} doesn't match "
                        f"expected size {expected_size}, skipping"
                    )
                    continue

                q_tensor, k_tensor, v_tensor = self._split_qkv_rows(
                    tensor[:qkv_size], hidden_size
                )
                mlp_tensor = tensor[qkv_size:]

                target_prefix = f"single_transformer_blocks.{block_idx}"
                yield f"{target_prefix}.attn.to_q.{param_type}", q_tensor
                yield f"{target_prefix}.attn.to_k.{param_type}", k_tensor
                yield f"{target_prefix}.attn.to_v.{param_type}", v_tensor
                yield f"{target_prefix}.proj_mlp.{param_type}", mlp_tensor
                continue

            if name in self._ADALN_SWAP_NAMES:
                # ComfyUI: output order is [shift, scale]
                # AdaLayerNormContinuous: expects [scale, shift]
                # Swap the two halves along dim 0; this is the same operation
                # for the (2 * hidden, hidden) weight and the (2 * hidden,) bias.
                half_size = tensor.shape[0] // 2
                swapped_tensor = torch.cat(
                    [tensor[half_size:], tensor[:half_size]], dim=0
                )
                logger.info(
                    f"Swapped scale/shift order for {name}: "
                    f"shape {tensor.shape} -> {swapped_tensor.shape}"
                )
                yield name, swapped_tensor
                continue

            # Other weights pass through (handled by param_names_mapping)
            yield name, tensor

    def _load_transformer_from_safetensors(
        self,
        server_args: ServerArgs,
        loaded_modules: dict[str, torch.nn.Module] | None = None,
    ) -> dict[str, Any]:
        """
        Load transformer directly from safetensors file without model_index.json.

        Side effects: replaces/updates ``server_args.pipeline_config.dit_config``
        (forcing ``guidance_embeds=True`` and extending ``param_names_mapping``),
        sets ``server_args.model_paths["transformer"]``, and temporarily swaps
        ``param_names_mapping`` on the resolved model class while loading.

        Args:
            server_args: Server configuration; mutated as described above.
            loaded_modules: If it contains ``"transformer"``, that module is
                returned as-is and nothing is loaded from disk.

        Returns:
            Dict with the ``"transformer"`` module and the ``"scheduler"``
            registered in ``self.modules`` (``None`` if absent).
        """
        if loaded_modules is not None and "transformer" in loaded_modules:
            logger.info("Using provided transformer module")
            components = {
                "transformer": loaded_modules["transformer"],
                "scheduler": self.modules.get("scheduler"),
            }
            return components

        if hasattr(server_args.pipeline_config, "dit_config"):
            dit_config = server_args.pipeline_config.dit_config
            if not isinstance(dit_config, FluxConfig):
                logger.warning("dit_config is not FluxConfig, creating new FluxConfig")
                dit_config = FluxConfig()
                server_args.pipeline_config.dit_config = dit_config
        else:
            logger.info("Creating default FluxConfig")
            dit_config = FluxConfig()
            server_args.pipeline_config.dit_config = dit_config

        # Set guidance_embeds to True for ComfyUI Flux models
        dit_config.arch_config.guidance_embeds = True
        logger.info("Set guidance_embeds=True for ComfyUI Flux model")

        if dit_config.arch_config.param_names_mapping is None:
            dit_config.arch_config.param_names_mapping = {}

        # ComfyUI Flux uses different parameter names than SGLang Flux
        # Key differences:
        # - ComfyUI: single_blocks.{i}.linear1 (fused QKV + MLP input)
        # - SGLang: single_transformer_blocks.{i}.attn.to_qkv + proj_mlp (separate)
        # - ComfyUI: single_blocks.{i}.linear2
        # - SGLang: single_transformer_blocks.{i}.proj_out
        # - ComfyUI: double_blocks.{i}.img_attn.qkv / txt_attn.qkv
        # - SGLang: transformer_blocks.{i}.attn.to_qkv / attn.to_added_qkv
        #
        # Note: fused layers like linear1 need custom weight splitting, which is
        # done in _convert_comfyui_weights() before names reach this mapping.
        # Every entry is a direct rename: (target, merge_index=None, num_parts=None).
        name_pairs = [
            # Double stream blocks - attention layers
            (
                r"double_blocks\.(\d+)\.img_attn\.qkv\.(weight|bias)$",
                r"transformer_blocks.\1.attn.to_qkv.\2",
            ),
            (
                r"double_blocks\.(\d+)\.txt_attn\.qkv\.(weight|bias)$",
                r"transformer_blocks.\1.attn.to_added_qkv.\2",
            ),
            (
                r"double_blocks\.(\d+)\.img_attn\.proj\.(weight|bias)$",
                r"transformer_blocks.\1.attn.to_out.0.\2",
            ),
            (
                r"double_blocks\.(\d+)\.txt_attn\.proj\.(weight|bias)$",
                r"transformer_blocks.\1.attn.to_add_out.\2",
            ),
            (
                r"double_blocks\.(\d+)\.img_attn\.norm\.query_norm\.scale$",
                r"transformer_blocks.\1.attn.norm_q.weight",
            ),
            (
                r"double_blocks\.(\d+)\.img_attn\.norm\.key_norm\.scale$",
                r"transformer_blocks.\1.attn.norm_k.weight",
            ),
            (
                r"double_blocks\.(\d+)\.txt_attn\.norm\.query_norm\.scale$",
                r"transformer_blocks.\1.attn.norm_added_q.weight",
            ),
            (
                r"double_blocks\.(\d+)\.txt_attn\.norm\.key_norm\.scale$",
                r"transformer_blocks.\1.attn.norm_added_k.weight",
            ),
            # Double stream blocks - MLP layers (map to net structure)
            (
                r"double_blocks\.(\d+)\.img_mlp\.0\.(weight|bias)$",
                r"transformer_blocks.\1.ff.net.0.proj.\2",
            ),
            (
                r"double_blocks\.(\d+)\.img_mlp\.2\.(weight|bias)$",
                r"transformer_blocks.\1.ff.net.2.\2",
            ),
            (
                r"double_blocks\.(\d+)\.txt_mlp\.0\.(weight|bias)$",
                r"transformer_blocks.\1.ff_context.net.0.proj.\2",
            ),
            (
                r"double_blocks\.(\d+)\.txt_mlp\.2\.(weight|bias)$",
                r"transformer_blocks.\1.ff_context.net.2.\2",
            ),
            # Double stream blocks - modulation layers
            (
                r"double_blocks\.(\d+)\.img_mod\.lin\.(weight|bias)$",
                r"transformer_blocks.\1.norm1.linear.\2",
            ),
            (
                r"double_blocks\.(\d+)\.txt_mod\.lin\.(weight|bias)$",
                r"transformer_blocks.\1.norm1_context.linear.\2",
            ),
            # Single stream blocks - linear2 maps to proj_out
            (
                r"single_blocks\.(\d+)\.linear2\.(weight|bias)$",
                r"single_transformer_blocks.\1.proj_out.\2",
            ),
            # Single stream blocks - norm layers (scale -> weight)
            (
                r"single_blocks\.(\d+)\.norm\.query_norm\.scale$",
                r"single_transformer_blocks.\1.attn.norm_q.weight",
            ),
            (
                r"single_blocks\.(\d+)\.norm\.key_norm\.scale$",
                r"single_transformer_blocks.\1.attn.norm_k.weight",
            ),
            # Single stream blocks - modulation (maps to norm.linear)
            (
                r"single_blocks\.(\d+)\.modulation\.lin\.(weight|bias)$",
                r"single_transformer_blocks.\1.norm.linear.\2",
            ),
            # Time and guidance embeddings
            (
                r"^time_in\.in_layer\.(weight|bias)$",
                r"time_text_embed.timestep_embedder.linear_1.\1",
            ),
            (
                r"^time_in\.out_layer\.(weight|bias)$",
                r"time_text_embed.timestep_embedder.linear_2.\1",
            ),
            (r"^txt_in\.(weight|bias)$", r"context_embedder.\1"),
            (
                r"^vector_in\.in_layer\.(weight|bias)$",
                r"time_text_embed.text_embedder.linear_1.\1",
            ),
            (
                r"^vector_in\.out_layer\.(weight|bias)$",
                r"time_text_embed.text_embedder.linear_2.\1",
            ),
            # Final layer mappings
            (r"^final_layer\.linear\.(weight|bias)$", r"proj_out.\1"),
            (r"^final_layer\.norm_final\.(weight|bias)$", r"norm_out.\1"),
            (
                r"^final_layer\.adaLN_modulation\.1\.(weight|bias)$",
                r"norm_out.linear.\1",
            ),
            # Image input embedding
            (r"^img_in\.(weight|bias)$", r"x_embedder.\1"),
            # Guidance embeddings (if model supports guidance)
            (
                r"^guidance_in\.in_layer\.(weight|bias)$",
                r"time_text_embed.guidance_embedder.linear_1.\1",
            ),
            (
                r"^guidance_in\.out_layer\.(weight|bias)$",
                r"time_text_embed.guidance_embedder.linear_2.\1",
            ),
        ]
        comfyui_flux_mappings = {
            pattern: (target, None, None) for pattern, target in name_pairs
        }

        # Merge ComfyUI mappings with existing mappings (ComfyUI mappings take precedence)
        updated_mapping = {
            **dit_config.arch_config.param_names_mapping,
            **comfyui_flux_mappings,
        }
        dit_config.arch_config.param_names_mapping = updated_mapping
        logger.info(
            "Added ComfyUI weight name mappings for Flux model. "
            f"Total mappings: {len(updated_mapping)}"
        )

        cls_name = "FluxTransformer2DModel"
        model_cls, _ = ModelRegistry.resolve_model_cls(cls_name)
        logger.info("Resolved transformer class: %s", cls_name)

        safetensors_list = [self.model_path]
        logger.info("Loading weights from: %s", safetensors_list)
        default_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.dit_precision]
        server_args.model_paths["transformer"] = os.path.dirname(self.model_path) or "."
        hf_config = {}

        hidden_size = (
            dit_config.arch_config.num_attention_heads
            * dit_config.arch_config.attention_head_dim
        )
        mlp_ratio = getattr(dit_config.arch_config, "mlp_ratio", 4.0)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        qkv_size = 3 * hidden_size
        # Single source of truth: the flag that was just written to the config.
        has_guidance_embeds = bool(dit_config.arch_config.guidance_embeds)

        # The class attribute is shared process-wide, so it is patched only for
        # the duration of the load and always restored, even if loading raises
        # or the previous value was None.
        original_mapping = model_cls.param_names_mapping
        model_cls.param_names_mapping = updated_mapping
        logger.info(
            "Temporarily updated model class param_names_mapping with ComfyUI mappings. "
            f"Total mappings: {len(updated_mapping)}"
        )
        try:
            # Load and convert weights from safetensors file
            model, reverse_param_names_mapping = (
                self._load_and_convert_weights_from_safetensors(
                    model_cls=model_cls,
                    dit_config=dit_config,
                    hf_config=hf_config,
                    safetensors_list=safetensors_list,
                    updated_mapping=updated_mapping,
                    qkv_size=qkv_size,
                    mlp_hidden_dim=mlp_hidden_dim,
                    has_guidance_embeds=has_guidance_embeds,
                    default_dtype=default_dtype,
                )
            )
        finally:
            model_cls.param_names_mapping = original_mapping

        # Inference only: switch to eval mode and freeze every parameter.
        model = model.eval()
        for param in model.parameters():
            param.requires_grad = False

        model.reverse_param_names_mapping = reverse_param_names_mapping

        total_params = sum(p.numel() for p in model.parameters())
        logger.info("Loaded transformer with %.2fB parameters", total_params / 1e9)

        components = {
            "transformer": model,
            "scheduler": self.modules.get("scheduler"),
        }

        logger.info("Successfully loaded modules: %s", list(components.keys()))
        return components

    def create_pipeline_stages(self, server_args: ServerArgs):
        """Register the latent-preparation stage followed by the denoising stage."""
        logger.info(
            "ComfyUIFluxPipeline.create_pipeline_stages() called - creating latent_preparation_stage and denoising_stage"
        )

        self.add_stages(
            [
                ComfyUILatentPreparationStage(
                    scheduler=self.get_module("scheduler"),
                    transformer=self.get_module("transformer"),
                ),
                DenoisingStage(
                    transformer=self.get_module("transformer"),
                    scheduler=self.get_module("scheduler"),
                ),
            ]
        )

        logger.info(
            f"ComfyUIFluxPipeline stages created: {list(self._stage_name_mapping.keys())}"
        )


EntryClass = ComfyUIFluxPipeline
logger = init_logger(__name__)


class ComfyUIQwenImagePipelineBase(LoRAPipeline, ComposedPipelineBase):
    """
    Shared ComfyUI QwenImage pipeline that runs only the denoising part.

    Callers are expected to supply already-processed inputs:
    - prompt_embeds: Pre-encoded text embeddings (list of tensors)
    - latents: Pre-processed image latents in sequence format [B, S, D]

    Usage:
        generator = DiffGenerator.from_pretrained(
            model_path="path/to/model",
            pipeline_class_name="ComfyUIQwenImagePipeline",
            device="cuda",
        )
    """

    # Forwarded to QwenImageArchConfig; concrete subclasses set their own value.
    zero_cond_t: bool = False

    pipeline_name = "ComfyUIQwenImagePipeline"

    _required_config_modules = ["transformer", "scheduler"]

    def initialize_pipeline(self, server_args: ServerArgs):
        """
        Install the ComfyUI pass-through scheduler and post-initialize the VAE config.

        The VAE model itself is not loaded by this pipeline; only its config
        is post-initialized and the resulting vae_scale_factor is logged.
        """
        scheduler = ComfyUIPassThroughScheduler(num_train_timesteps=1000)
        self.modules["scheduler"] = scheduler

        vae_cfg = server_args.pipeline_config.vae_config
        vae_cfg.post_init()
        logger.info(
            "Called vae_config.post_init() to set vae_scale_factor. "
            f"vae_scale_factor={vae_cfg.arch_config.vae_scale_factor}"
        )

    def load_modules(
        self,
        server_args: ServerArgs,
        loaded_modules: dict[str, torch.nn.Module] | None = None,
    ) -> dict[str, Any]:
        """
        Load the modules used by this pipeline.

        A ``model_path`` that is an existing ``.safetensors`` file is loaded
        directly (no model_index.json needed); anything else is delegated to
        the parent class loader.
        """
        path = self.model_path
        is_single_file = os.path.isfile(path) and path.endswith(".safetensors")

        if not is_single_file:
            logger.info(
                "Model path is a directory, using default loading method: %s",
                self.model_path,
            )
            return super().load_modules(server_args, loaded_modules)

        logger.info(
            "Detected safetensors file, loading transformer directly from: %s",
            self.model_path,
        )
        return self._load_transformer_from_safetensors(server_args, loaded_modules)

    def _load_transformer_from_safetensors(
        self,
        server_args: ServerArgs,
        loaded_modules: dict[str, torch.nn.Module] | None = None,
    ) -> dict[str, Any]:
        """Build the transformer from a single safetensors file (no model_index.json)."""
        scheduler = self.modules.get("scheduler")

        # A transformer handed in by the caller is used as-is.
        if loaded_modules is not None and "transformer" in loaded_modules:
            logger.info("Using provided transformer module")
            return {
                "transformer": loaded_modules["transformer"],
                "scheduler": scheduler,
            }

        # Config, name mapping, model class and dtype.
        prepared = self._prepare_dit_config_and_mapping(server_args)
        dit_config, updated_mapping, model_cls, default_dtype = prepared

        weight_files = [self.model_path]
        logger.info("Loading weights from: %s", weight_files)

        # Instantiate on the meta device (optionally sharded), then fill weights.
        transformer = self._instantiate_model(
            model_cls, dit_config, default_dtype, updated_mapping, server_args
        )
        self._load_weights_into_model(
            transformer, weight_files, default_dtype, updated_mapping, server_args
        )

        components = {"transformer": transformer, "scheduler": scheduler}
        logger.info("Successfully loaded modules: %s", list(components))
        return components

    def _prepare_dit_config_and_mapping(self, server_args: ServerArgs):
        """
        Create the DiT config and the ComfyUI name mapping.

        Writes the new config to ``server_args.pipeline_config.dit_config`` and
        sets ``server_args.model_paths["transformer"]``.

        Returns:
            ``(dit_config, updated_mapping, model_cls, default_dtype)``.
        """
        from sglang.multimodal_gen.configs.models.dits.qwenimage import (
            QwenImageArchConfig,
        )

        arch_config = QwenImageArchConfig(
            patch_size=2,
            in_channels=64,
            out_channels=16,
            num_layers=60,
            attention_head_dim=128,
            num_attention_heads=24,
            joint_attention_dim=3584,
            pooled_projection_dim=768,
            guidance_embeds=False,
            axes_dims_rope=(16, 56, 56),
            zero_cond_t=self.zero_cond_t,
        )
        dit_config = QwenImageDitConfig(arch_config=arch_config)
        server_args.pipeline_config.dit_config = dit_config

        if dit_config.arch_config.param_names_mapping is None:
            dit_config.arch_config.param_names_mapping = {}

        # ComfyUI-specific entries go last so they win on key collisions.
        updated_mapping = dict(dit_config.arch_config.param_names_mapping)
        updated_mapping.update({r"^model\.diffusion_model\.(.*)$": r"\1"})
        dit_config.arch_config.param_names_mapping = updated_mapping
        logger.info(
            "Added ComfyUI weight name mappings to param_names_mapping. "
            f"Total mappings: {len(updated_mapping)}"
        )

        cls_name = "QwenImageTransformer2DModel"
        resolved = ModelRegistry.resolve_model_cls(cls_name)
        model_cls, _ = resolved
        logger.info("Resolved transformer class: %s", cls_name)

        precision = server_args.pipeline_config.dit_precision
        default_dtype = PRECISION_TO_TYPE[precision]
        parent_dir = os.path.dirname(self.model_path)
        server_args.model_paths["transformer"] = parent_dir or "."
        assert server_args.hsdp_shard_dim is not None, "hsdp_shard_dim must be set"
        logger.info(
            "Loading %s from safetensors file, default_dtype: %s",
            cls_name,
            default_dtype,
        )
        return dit_config, updated_mapping, model_cls, default_dtype

    def _instantiate_model(
        self,
        model_cls,
        dit_config,
        default_dtype,
        updated_mapping,
        server_args: ServerArgs,
    ):
        """
        Instantiate ``model_cls`` on the meta device and shard it when FSDP is on.

        ``model_cls.param_names_mapping`` is replaced by ``updated_mapping``
        while this runs and is put back in the ``finally`` clause.
        """
        from sglang.multimodal_gen.runtime.platforms import current_platform

        hf_config = {}
        saved_mapping = model_cls.param_names_mapping
        model_cls.param_names_mapping = updated_mapping
        logger.info(
            "Temporarily updated model class param_names_mapping with ComfyUI mappings. "
            f"Total mappings: {len(updated_mapping)}"
        )

        try:
            mp_policy = MixedPrecisionPolicy(
                torch.bfloat16, torch.float32, None, cast_forward_inputs=False
            )
            set_mixed_precision_policy(
                param_dtype=torch.bfloat16,
                reduce_dtype=torch.float32,
                output_dtype=None,
                mp_policy=mp_policy,
            )

            with set_default_torch_dtype(default_dtype), torch.device("meta"):
                model = model_cls(config=dit_config, hf_config=hf_config)

            fsdp_enabled = server_args.use_fsdp_inference
            if current_platform.is_mps():
                fsdp_enabled = False
                logger.info("Disabling FSDP for MPS platform as it's not compatible")

            if fsdp_enabled:
                mesh_shape = (
                    server_args.hsdp_replicate_dim,
                    server_args.hsdp_shard_dim,
                )
                device_mesh = init_device_mesh(
                    current_platform.device_type,
                    mesh_shape=mesh_shape,
                    mesh_dim_names=("replicate", "shard"),
                )
                shard_model(
                    model,
                    cpu_offload=server_args.dit_cpu_offload,
                    reshard_after_forward=True,
                    mp_policy=mp_policy,
                    mesh=device_mesh,
                    fsdp_shard_conditions=model._fsdp_shard_conditions,
                    pin_cpu_memory=server_args.pin_cpu_memory,
                )
        finally:
            model_cls.param_names_mapping = saved_mapping

        return model

    def _load_weights_into_model(
        self,
        model,
        safetensors_list,
        default_dtype,
        updated_mapping,
        server_args: ServerArgs,
    ):
        """
        Stream the safetensors weights into ``model`` and freeze its parameters.

        Raises:
            RuntimeError: if a parameter or buffer is still on the meta device
                after loading.
        """
        load_model_from_full_model_state_dict(
            model,
            safetensors_weights_iterator(safetensors_list),
            get_local_torch_device(),
            default_dtype,
            strict=True,
            cpu_offload=server_args.dit_cpu_offload,
            param_names_mapping=get_param_names_mapping(updated_mapping),
        )

        # Nothing may remain on the meta device; parameters are frozen.
        named_tensors = chain(model.named_parameters(), model.named_buffers())
        for tensor_name, tensor in named_tensors:
            if tensor.is_meta:
                raise RuntimeError(
                    f"Unexpected param or buffer {tensor_name} on meta device."
                )
            if isinstance(tensor, torch.nn.Parameter):
                tensor.requires_grad = False

        param_count = sum(p.numel() for p in model.parameters())
        logger.info("Loaded transformer with %.2fB parameters", param_count / 1e9)

    def create_pipeline_stages(self, server_args: ServerArgs):
        """Register the latent-preparation stage followed by the denoising stage."""
        logger.info(
            f"{self.__class__.__name__}.create_pipeline_stages() called - creating latent_preparation_stage and denoising_stage"
        )

        latent_stage = ComfyUILatentPreparationStage(
            scheduler=self.get_module("scheduler"),
            transformer=self.get_module("transformer"),
        )
        denoise_stage = DenoisingStage(
            transformer=self.get_module("transformer"),
            scheduler=self.get_module("scheduler"),
        )
        self.add_stages([latent_stage, denoise_stage])

        logger.info(
            f"{self.__class__.__name__} stages created: {list(self._stage_name_mapping.keys())}"
        )


class ComfyUIQwenImagePipeline(ComfyUIQwenImagePipelineBase):
    """ComfyUI QwenImage pipeline for text-to-image generation."""

    pipeline_name = "ComfyUIQwenImagePipeline"
    zero_cond_t = False

    from sglang.multimodal_gen.configs.pipeline_configs.qwen_image import (
        QwenImagePipelineConfig,
    )
    from sglang.multimodal_gen.configs.sample.qwenimage import QwenImageSamplingParams

    pipeline_config_cls = QwenImagePipelineConfig
    sampling_params_cls = QwenImageSamplingParams


class ComfyUIQwenImageEditPipeline(ComfyUIQwenImagePipelineBase):
    """ComfyUI QwenImage pipeline for image-to-image editing."""

    pipeline_name = "ComfyUIQwenImageEditPipeline"
    zero_cond_t = True

    from sglang.multimodal_gen.configs.pipeline_configs.qwen_image import (
        QwenImageEditPlusPipelineConfig,
    )
    from sglang.multimodal_gen.configs.sample.qwenimage import (
        QwenImageEditPlusSamplingParams,
    )

    pipeline_config_cls = QwenImageEditPlusPipelineConfig
    sampling_params_cls = QwenImageEditPlusSamplingParams


EntryClass = [ComfyUIQwenImagePipeline, ComfyUIQwenImageEditPipeline]
class ComfyUIZImagePipeline(LoRAPipeline, ComposedPipelineBase):
    """
    Simplified pipeline for ComfyUI integration with only denoising stage.

    This pipeline requires pre-processed inputs:
    - prompt_embeds: Pre-encoded text embeddings (list of tensors)
    - negative_prompt_embeds: Pre-encoded negative prompt embeddings (if using CFG)
    - latents: Optional initial noise latents (will be generated if not provided)

    The transformer can be loaded either through the default module loader or
    directly from a single ComfyUI-style ``.safetensors`` file, in which case
    weight names are remapped and fused qkv projections are split on the fly.

    Usage:
        generator = DiffGenerator.from_pretrained(
            model_path="path/to/model",
            pipeline_class_name="ComfyUIZImagePipeline",
            device="cuda",
        )
    """

    pipeline_name = "ComfyUIZImagePipeline"
    from sglang.multimodal_gen.configs.pipeline_configs.zimage import (
        ZImagePipelineConfig,
    )
    from sglang.multimodal_gen.configs.sample.zimage import ZImageSamplingParams

    pipeline_config_cls = ZImagePipelineConfig
    sampling_params_cls = ZImageSamplingParams

    _required_config_modules = [
        "transformer",
        "scheduler",
    ]

    # Fused attention projections in ComfyUI checkpoints:
    # (layers|noise_refiner|context_refiner).{i}.attention.qkv.(weight|bias)
    _FUSED_QKV_PATTERN = re.compile(
        r"(layers|noise_refiner|context_refiner)\.(\d+)\.attention\.qkv\.(weight|bias)$"
    )

    def initialize_pipeline(self, server_args: ServerArgs):
        """
        Initialize the pipeline with ComfyUI pass-through scheduler.
        This scheduler does not modify latents, allowing ComfyUI to handle denoising.
        """
        self.modules["scheduler"] = ComfyUIPassThroughScheduler(
            num_train_timesteps=1000
        )

        # Ensure VAE config is properly initialized even though we don't load the VAE model
        # This is necessary because get_freqs_cis uses spatial_compression_ratio
        if hasattr(server_args.pipeline_config, "vae_config"):
            vae_config = server_args.pipeline_config.vae_config
            if hasattr(vae_config, "post_init") and not hasattr(
                vae_config, "_post_init_called"
            ):
                vae_config.post_init()
                logger.info(
                    "Called vae_config.post_init() to set spatial_compression_ratio. "
                    f"spatial_compression_ratio={vae_config.arch_config.spatial_compression_ratio}"
                )

    def load_modules(
        self,
        server_args: ServerArgs,
        loaded_modules: dict[str, torch.nn.Module] | None = None,
    ) -> dict[str, Any]:
        """
        Load modules for ComfyUIZImagePipeline.

        If model_path is a safetensors file, load transformer directly from it
        without requiring model_index.json. Otherwise, fall back to default loading.

        Args:
            server_args: Server configuration forwarded to the selected loader.
            loaded_modules: Optional pre-built modules forwarded to the selected loader.

        Returns:
            Mapping of module name to loaded module.
        """
        if os.path.isfile(self.model_path) and self.model_path.endswith(".safetensors"):
            logger.info(
                "Detected safetensors file, loading transformer directly from: %s",
                self.model_path,
            )
            return self._load_transformer_from_safetensors(server_args, loaded_modules)

        # Anything that is not an existing *.safetensors file lands here; that is
        # not necessarily a directory, so the message does not claim it is one.
        logger.info(
            "Model path is not a single safetensors file, using default loading method: %s",
            self.model_path,
        )
        return super().load_modules(server_args, loaded_modules)

    def _convert_comfyui_qkv_weights(
        self,
        weight_iterator: Generator[tuple[str, torch.Tensor], None, None],
        dim: int,
        num_heads: int,
        num_kv_heads: int,
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        """
        Convert ComfyUI zimage qkv weights to SGLang format.
        Splits merged qkv.weight / qkv.bias into separate to_q, to_k, to_v tensors.

        Args:
            weight_iterator: Iterator yielding (name, tensor) pairs from safetensors
            dim: Model dimension
            num_heads: Number of attention heads
            num_kv_heads: Number of key-value heads

        Yields:
            (name, tensor) pairs with qkv weights split into to_q, to_k, to_v;
            every other pair is passed through unchanged.

        Raises:
            ValueError: If a fused qkv tensor's leading dimension is not
                q_size + k_size + v_size. Without this check the remainder slice
                would silently absorb the difference into to_v.
        """
        head_dim = dim // num_heads
        q_size = dim
        k_size = head_dim * num_kv_heads
        v_size = head_dim * num_kv_heads
        fused_size = q_size + k_size + v_size

        for name, tensor in weight_iterator:
            match = self._FUSED_QKV_PATTERN.match(name)
            if match is None:
                # Pass through other weights unchanged
                yield name, tensor
                continue

            if tensor.shape[0] != fused_size:
                raise ValueError(
                    f"Cannot split {name}: leading dimension {tensor.shape[0]} does not "
                    f"match q_size + k_size + v_size = {fused_size} "
                    f"(dim={dim}, num_heads={num_heads}, num_kv_heads={num_kv_heads})"
                )

            module_name, layer_idx, param_type = match.groups()
            base_name = f"{module_name}.{layer_idx}.attention"

            # Weight shape: (q_size + k_size + v_size, dim); bias shape:
            # (q_size + k_size + v_size,). Slicing the first axis handles both.
            q_part = tensor[:q_size]
            k_part = tensor[q_size : q_size + k_size]
            v_part = tensor[q_size + k_size :]

            logger.debug(
                f"Splitting {name} (shape {tensor.shape}) into "
                f"to_q ({q_part.shape}), to_k ({k_part.shape}), to_v ({v_part.shape})"
            )

            yield f"{base_name}.to_q.{param_type}", q_part
            yield f"{base_name}.to_k.{param_type}", k_part
            yield f"{base_name}.to_v.{param_type}", v_part

    def _load_transformer_from_safetensors(
        self,
        server_args: ServerArgs,
        loaded_modules: dict[str, torch.nn.Module] | None = None,
    ) -> dict[str, Any]:
        """
        Load transformer directly from safetensors file without model_index.json.

        This method:
        1. Uses hardcoded ZImageDitConfig for zimage model
        2. Loads transformer from the safetensors file
        3. Uses ComfyUIPassThroughScheduler (already created in initialize_pipeline)

        Side effects visible in this method: ``server_args.pipeline_config.dit_config``
        and its ``param_names_mapping`` are updated, ``server_args.model_paths["transformer"]``
        is set, and the model class' ``param_names_mapping`` is swapped for the
        duration of the load and restored afterwards.

        Returns:
            Dict with the "transformer" and "scheduler" components.

        Raises:
            RuntimeError: If any parameter or buffer is still on the meta device
                after loading.
        """
        # Check if transformer is already provided
        if loaded_modules is not None and "transformer" in loaded_modules:
            logger.info("Using provided transformer module")
            components = {
                "transformer": loaded_modules["transformer"],
                "scheduler": self.modules.get("scheduler"),
            }
            return components

        if hasattr(server_args.pipeline_config, "dit_config"):
            dit_config = server_args.pipeline_config.dit_config
            if not isinstance(dit_config, ZImageDitConfig):
                logger.warning(
                    "dit_config is not ZImageDitConfig, creating new ZImageDitConfig"
                )
                dit_config = ZImageDitConfig()
                server_args.pipeline_config.dit_config = dit_config
        else:
            logger.info("Creating default ZImageDitConfig")
            dit_config = ZImageDitConfig()
            server_args.pipeline_config.dit_config = dit_config

        if dit_config.arch_config.param_names_mapping is None:
            dit_config.arch_config.param_names_mapping = {}

        # Add mappings for norm layers: map from ComfyUI format (k_norm/q_norm) to SGLang format (norm_k/norm_q)
        # The regex matches the source name from safetensors, and the tuple specifies the target name in the model
        # Note: qkv weights are handled separately by _convert_comfyui_qkv_weights function
        comfyui_norm_mappings = {
            r"(.*)\.attention\.k_norm\.weight$": (
                r"\1.attention.norm_k.weight",
                None,
                None,
            ),
            r"(.*)\.attention\.q_norm\.weight$": (
                r"\1.attention.norm_q.weight",
                None,
                None,
            ),
            r"(.*)\.attention\.out\.weight$": (
                r"\1.attention.to_out.0.weight",
                None,
                None,
            ),
            r"^final_layer\.(.*)$": (r"all_final_layer.2-1.\1", None, None),
            r"^x_embedder\.(.*)$": (r"all_x_embedder.2-1.\1", None, None),
        }

        # Merge ComfyUI mappings with existing mappings (ComfyUI mappings take precedence)
        updated_mapping = {
            **dit_config.arch_config.param_names_mapping,
            **comfyui_norm_mappings,
        }
        dit_config.arch_config.param_names_mapping = updated_mapping
        logger.info(
            "Added ComfyUI weight name mappings (k_norm/q_norm -> norm_k/norm_q) to param_names_mapping. "
            f"Total mappings: {len(updated_mapping)}"
        )

        cls_name = "ZImageTransformer2DModel"
        model_cls, _ = ModelRegistry.resolve_model_cls(cls_name)
        logger.info("Resolved transformer class: %s", cls_name)
        safetensors_list = [self.model_path]
        logger.info("Loading weights from: %s", safetensors_list)

        default_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.dit_precision]
        server_args.model_paths["transformer"] = os.path.dirname(self.model_path) or "."
        hf_config = {}

        assert server_args.hsdp_shard_dim is not None, "hsdp_shard_dim must be set"
        logger.info(
            "Loading %s from safetensors file, default_dtype: %s",
            cls_name,
            default_dtype,
        )

        # The mapping is a class attribute, so it is swapped in only for the
        # duration of the load and restored in the `finally` below.
        original_mapping = model_cls.param_names_mapping
        model_cls.param_names_mapping = updated_mapping
        logger.info(
            "Temporarily updated model class param_names_mapping with ComfyUI mappings. "
            f"Total mappings: {len(updated_mapping)}"
        )

        try:
            # Create model first (same as maybe_load_fsdp_model)
            from sglang.multimodal_gen.runtime.platforms import current_platform

            mp_policy = MixedPrecisionPolicy(
                torch.bfloat16, torch.float32, None, cast_forward_inputs=False
            )

            set_mixed_precision_policy(
                param_dtype=torch.bfloat16,
                reduce_dtype=torch.float32,
                output_dtype=None,
                mp_policy=mp_policy,
            )

            # Build on the meta device so no real memory is allocated before
            # the weights are streamed in.
            with set_default_torch_dtype(default_dtype), torch.device("meta"):
                model = model_cls(**{"config": dit_config, "hf_config": hf_config})

            # Check if we should use FSDP
            use_fsdp = server_args.use_fsdp_inference
            if current_platform.is_mps():
                use_fsdp = False
                logger.info("Disabling FSDP for MPS platform as it's not compatible")

            if use_fsdp:
                device_mesh = init_device_mesh(
                    current_platform.device_type,
                    mesh_shape=(
                        server_args.hsdp_replicate_dim,
                        server_args.hsdp_shard_dim,
                    ),
                    mesh_dim_names=("replicate", "shard"),
                )
                shard_model(
                    model,
                    cpu_offload=server_args.dit_cpu_offload,
                    reshard_after_forward=True,
                    mp_policy=mp_policy,
                    mesh=device_mesh,
                    fsdp_shard_conditions=model._fsdp_shard_conditions,
                    pin_cpu_memory=server_args.pin_cpu_memory,
                )

            # Get model dimensions for qkv splitting
            arch_config = dit_config.arch_config
            dim = arch_config.dim
            num_heads = arch_config.num_attention_heads
            num_kv_heads = arch_config.n_kv_heads

            # Create weight iterator with qkv conversion
            base_weight_iterator = safetensors_weights_iterator(safetensors_list)
            converted_weight_iterator = self._convert_comfyui_qkv_weights(
                base_weight_iterator, dim, num_heads, num_kv_heads
            )

            # Load weights
            param_names_mapping_fn = get_param_names_mapping(updated_mapping)
            load_model_from_full_model_state_dict(
                model,
                converted_weight_iterator,
                get_local_torch_device(),
                default_dtype,
                strict=True,
                cpu_offload=server_args.dit_cpu_offload,
                param_names_mapping=param_names_mapping_fn,
            )

            # Check for meta parameters
            for n, p in chain(model.named_parameters(), model.named_buffers()):
                if p.is_meta:
                    raise RuntimeError(
                        f"Unexpected param or buffer {n} on meta device."
                    )
                if isinstance(p, torch.nn.Parameter):
                    p.requires_grad = False
        finally:
            model_cls.param_names_mapping = original_mapping

        total_params = sum(p.numel() for p in model.parameters())
        logger.info("Loaded transformer with %.2fB parameters", total_params / 1e9)

        components = {
            "transformer": model,
            "scheduler": self.modules.get("scheduler"),
        }

        logger.info("Successfully loaded modules: %s", list(components.keys()))
        return components

    def create_pipeline_stages(self, server_args: ServerArgs):
        """Register the latent-preparation stage followed by the denoising stage."""
        logger.info(
            "ComfyUIZImagePipeline.create_pipeline_stages() called - creating latent_preparation_stage and denoising_stage"
        )

        self.add_stages(
            [
                ComfyUILatentPreparationStage(
                    scheduler=self.get_module("scheduler"),
                    transformer=self.get_module("transformer"),
                ),
                DenoisingStage(
                    transformer=self.get_module("transformer"),
                    scheduler=self.get_module("scheduler"),
                ),
            ]
        )

        logger.info(
            f"ComfyUIZImagePipeline stages created: {list(self._stage_name_mapping.keys())}"
        )
class DiffusersExecutionStage(PipelineStage):
    """Pipeline stage that wraps diffusers pipeline execution.

    The stage builds keyword arguments from the request, drops those the
    wrapped pipeline's ``__call__`` does not accept, runs the pipeline, and
    normalizes whatever it returns into a float tensor in [0, 1] on the CPU.
    """

    def __init__(self, diffusers_pipe: DiffusionPipeline):
        super().__init__()
        self.diffusers_pipe = diffusers_pipe

    def forward(self, batch: Req, server_args: ServerArgs) -> Req:
        """Execute the diffusers pipeline and store the result on ``batch.output``."""

        kwargs = self._build_pipeline_kwargs(batch, server_args)

        # Filter kwargs to only those supported by the pipeline, warn about ignored args
        kwargs, _ = self._filter_pipeline_kwargs(kwargs)

        # Request tensor output for cleaner handling
        if "output_type" not in kwargs:
            kwargs["output_type"] = "pt"

        with torch.no_grad(), warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            try:
                output = self.diffusers_pipe(**kwargs)
            except TypeError as e:
                # Some pipelines don't support output_type="pt"
                if "output_type" in str(e):
                    kwargs.pop("output_type", None)
                    output = self.diffusers_pipe(**kwargs)
                else:
                    raise

        batch.output = self._extract_output(output)
        if batch.output is not None:
            batch.output = self._postprocess_output(batch.output)

        return batch

    def _filter_pipeline_kwargs(
        self, kwargs: dict, *, strict: bool = False
    ) -> tuple[dict, list[str]]:
        """Filter kwargs to those accepted by the pipeline's __call__.

        Args:
            kwargs: Arguments to filter
            strict: If True, raise ValueError on unsupported args; otherwise warn

        Returns:
            Tuple of (filtered_kwargs, ignored_keys)

        Raises:
            ValueError: If ``strict`` is True and at least one argument is unsupported.
        """
        try:
            sig = inspect.signature(self.diffusers_pipe.__call__)
        except (ValueError, TypeError):
            # Signature not introspectable: pass everything through untouched.
            return kwargs, []

        params = sig.parameters
        accepts_var_kwargs = any(
            p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()
        )
        if accepts_var_kwargs:
            return kwargs, []

        valid = set(params.keys()) - {"self"}

        filtered = {}
        ignored = []
        for k, v in kwargs.items():
            if k in valid:
                filtered[k] = v
            else:
                ignored.append(k)

        if ignored:
            pipe_name = type(self.diffusers_pipe).__name__
            msg = (
                f"Pipeline '{pipe_name}' does not support: {', '.join(sorted(ignored))}. "
                "These arguments will be ignored."
            )
            if strict:
                raise ValueError(msg)
            logger.warning(msg)

        return filtered, ignored

    def _extract_output(self, output: Any) -> torch.Tensor | None:
        """Extract tensor output from pipeline result.

        Tries the known result attributes in order and returns the first one
        that converts to a tensor; returns None (after a warning) otherwise.
        """
        for attr in ["images", "frames", "video", "sample", "pred_original_sample"]:
            if not hasattr(output, attr):
                continue

            data = getattr(output, attr)
            if data is None:
                continue

            result = self._convert_to_tensor(data)
            if result is not None:
                logger.debug(
                    "Extracted output from '%s': shape=%s, dtype=%s",
                    attr,
                    result.shape,
                    result.dtype,
                )
                return result

        logger.warning("Could not extract output from pipeline result")
        return None

    def _convert_to_tensor(self, data: Any) -> torch.Tensor | None:
        """Convert various data formats to a tensor.

        Tensors are returned as-is. Numpy arrays are converted to float,
        rescaled from [0, 255] when any value exceeds 1.0, and moved from
        channels-last to the layout ``_fix_output_shape`` expects as input.
        """
        if isinstance(data, torch.Tensor):
            return data

        if isinstance(data, np.ndarray):
            tensor = torch.from_numpy(data).float()
            # max() raises on an empty tensor, so guard it.
            if tensor.numel() > 0 and tensor.max() > 1.0:
                tensor = tensor / 255.0
            if tensor.ndim == 4:
                # (B, H, W, C) -> (B, C, H, W)
                tensor = tensor.permute(0, 3, 1, 2)
            elif tensor.ndim == 5:
                # (B, T, H, W, C) -> (B, T, C, H, W). _fix_output_shape always
                # swaps axes 1 and 2 of 5-D input, so emitting (B, C, T, H, W)
                # here would be swapped back to (B, T, C, H, W) downstream.
                tensor = tensor.permute(0, 1, 4, 2, 3)
            return tensor

        if hasattr(data, "mode"):  # PIL Image
            return T.ToTensor()(data)

        if isinstance(data, list) and len(data) > 0:
            return self._convert_list_to_tensor(data)

        return None

    def _convert_list_to_tensor(self, data: list) -> torch.Tensor | None:
        """Convert a list of items to a tensor.

        A single item yields that item's tensor; several items are stacked and
        the first two axes are swapped.
        """
        # NOTE(review): for multi-item lists this returns (C, N, H, W); the 4-D
        # branch of _fix_output_shape documents (T, C, H, W) input, so the two
        # appear to disagree on layout - confirm against downstream consumers.
        first = data[0]

        # Nested list (e.g., [[frame1, frame2, ...]] for video batches)
        if isinstance(first, list) and len(first) > 0:
            data = first
            first = data[0]

        if hasattr(first, "mode"):  # PIL images
            tensors = [T.ToTensor()(img) for img in data]
            stacked = torch.stack(tensors)
            if len(tensors) > 1:
                return stacked.permute(1, 0, 2, 3)  # (T, C, H, W) -> (C, T, H, W)
            return stacked[0]

        if isinstance(first, torch.Tensor):
            stacked = torch.stack(data)
            if len(data) > 1:
                return stacked.permute(1, 0, 2, 3)
            return stacked[0]

        if isinstance(first, np.ndarray):
            tensors = [torch.from_numpy(arr).float() for arr in data]
            if tensors[0].max() > 1.0:
                tensors = [t / 255.0 for t in tensors]
            if tensors[0].ndim == 3:
                # (H, W, C) -> (C, H, W)
                tensors = [t.permute(2, 0, 1) for t in tensors]
            stacked = torch.stack(tensors)
            if len(data) > 1:
                return stacked.permute(1, 0, 2, 3)
            return stacked[0]

        return None

    def _postprocess_output(self, output: torch.Tensor) -> torch.Tensor:
        """Post-process output tensor to ensure valid values and correct shape."""
        output = output.cpu().float()

        # Handle NaN or Inf values
        if torch.isnan(output).any() or torch.isinf(output).any():
            logger.warning("Output contains invalid values, fixing...")
            output = torch.nan_to_num(output, nan=0.5, posinf=1.0, neginf=0.0)

        # Normalize to [0, 1] range if needed: values well outside [0, 1] are
        # treated as [-1, 1] data and shifted.
        min_val, max_val = output.min().item(), output.max().item()
        if min_val < -0.5 or max_val > 1.5:
            output = (output + 1) / 2

        output = output.clamp(0, 1)

        # Ensure correct shape for downstream processing
        output = self._fix_output_shape(output)

        logger.debug("Final output tensor shape: %s", output.shape)
        return output

    def _fix_output_shape(self, output: torch.Tensor) -> torch.Tensor:
        """Fix tensor shape for downstream processing.

        Expected: (B, C, H, W) for images or (B, C, T, H, W) for videos.
        """
        if output.dim() == 5:
            # Video: (B, T, C, H, W) -> (B, C, T, H, W)
            return output.permute(0, 2, 1, 3, 4)

        if output.dim() == 4:
            if output.shape[0] == 1 or output.shape[1] in [1, 3, 4]:
                return output  # Already (B, C, H, W)
            # (T, C, H, W) -> (1, C, T, H, W)
            return output.unsqueeze(0).permute(0, 2, 1, 3, 4)

        if output.dim() == 3:
            c, h, w = output.shape
            # Channels-last heuristic: large first axis, tiny last axis.
            if c > 4 and w <= 4:
                output = output.permute(2, 0, 1)
            # Expand single-channel images to three channels.
            if output.shape[0] == 1:
                output = output.repeat(3, 1, 1)
            return output.unsqueeze(0)

        if output.dim() == 2:
            # (H, W) -> (1, 3, H, W)
            return output.unsqueeze(0).repeat(3, 1, 1).unsqueeze(0)

        return output

    def _build_pipeline_kwargs(self, batch: Req, server_args: ServerArgs) -> dict:
        """Build kwargs dict for diffusers pipeline call.

        Only request fields that are set are forwarded; entries in
        ``batch.extra["diffusers_kwargs"]`` are applied last and therefore
        override the values derived from the request.
        """
        kwargs = {}

        if batch.prompt is not None:
            kwargs["prompt"] = batch.prompt

        if batch.negative_prompt:
            kwargs["negative_prompt"] = batch.negative_prompt

        if batch.num_inference_steps is not None:
            kwargs["num_inference_steps"] = batch.num_inference_steps

        if batch.guidance_scale is not None:
            kwargs["guidance_scale"] = batch.guidance_scale

        if batch.true_cfg_scale is not None:
            kwargs["true_cfg_scale"] = batch.true_cfg_scale

        if batch.height is not None:
            kwargs["height"] = batch.height

        if batch.width is not None:
            kwargs["width"] = batch.width

        if batch.num_frames is not None and batch.num_frames > 1:
            kwargs["num_frames"] = batch.num_frames

        # Generator for reproducibility
        if batch.generator is not None:
            kwargs["generator"] = batch.generator
        elif batch.seed is not None:
            device = self._get_pipeline_device()
            kwargs["generator"] = torch.Generator(device=device).manual_seed(batch.seed)

        # Image input for img2img or inpainting
        image = self._load_input_image(batch)
        if image is not None:
            kwargs["image"] = image

        if batch.num_outputs_per_prompt > 1:
            kwargs["num_images_per_prompt"] = batch.num_outputs_per_prompt

        # Extra diffusers-specific kwargs
        if batch.extra:
            diffusers_kwargs = batch.extra.get("diffusers_kwargs", {})
            if diffusers_kwargs:
                kwargs.update(diffusers_kwargs)

        return kwargs

    def _get_pipeline_device(self) -> torch.device | str:
        """Get the device the pipeline is running on.

        Returns:
            The ``torch.device`` of the first parameter of the first available
            component among unet / transformer / vae, or the platform's device
            type string when none of them has parameters.
        """
        for attr in ["unet", "transformer", "vae"]:
            component = getattr(self.diffusers_pipe, attr, None)
            if component is not None:
                try:
                    return next(component.parameters()).device
                except StopIteration:
                    pass
        return current_platform.device_type

    def _load_input_image(self, batch: Req) -> Image.Image | None:
        """Load input image from batch.

        Preference order: a PIL image in ``condition_image``, a PIL image in
        ``pixel_values``, then ``image_path`` (local path or http(s) URL).
        Load failures are logged and reported as None.
        """
        # Check for PIL image in condition_image or pixel_values
        if batch.condition_image is not None and isinstance(
            batch.condition_image, Image.Image
        ):
            return batch.condition_image
        if batch.pixel_values is not None and isinstance(
            batch.pixel_values, Image.Image
        ):
            return batch.pixel_values

        if not batch.image_path:
            return None

        # NOTE(review): this rewrites batch.image_path in place, keeping only the
        # first entry of a list - confirm no later consumer needs the full list.
        if isinstance(batch.image_path, list):
            batch.image_path = batch.image_path[0]

        try:
            if batch.image_path.startswith(("http://", "https://")):
                response = requests.get(batch.image_path, timeout=30)
                response.raise_for_status()
                return Image.open(BytesIO(response.content)).convert("RGB")
            return Image.open(batch.image_path).convert("RGB")
        except Exception as e:
            logger.error("Failed to load image from %s: %s", batch.image_path, e)
            return None
class DiffusersPipeline(ComposedPipelineBase):
    """
    Pipeline wrapper that uses vanilla diffusers pipelines.

    This allows running any diffusers-supported model through sglang's infrastructure
    without requiring native sglang implementation.
    """

    pipeline_name = "DiffusersPipeline"
    is_video_pipeline = False
    _required_config_modules: list[str] = []

    def __init__(
        self,
        model_path: str,
        server_args: ServerArgs,
        required_config_modules: list[str] | None = None,
        loaded_modules: dict[str, torch.nn.Module] | None = None,
        executor: PipelineExecutor | None = None,
    ):
        """Load the diffusers pipeline eagerly and set up stage bookkeeping.

        ``required_config_modules`` and ``loaded_modules`` are accepted for
        signature compatibility and are not used here.
        """
        self.server_args = server_args
        self.model_path = model_path
        self._stages: list[PipelineStage] = []
        self._stage_name_mapping: dict[str, PipelineStage] = {}
        self.modules: dict[str, Any] = {}
        self.memory_usages: dict[str, float] = {}
        self.post_init_called = False
        self.executor = executor or SyncExecutor(server_args=server_args)
        self._cache_dit_enabled = False

        logger.info("Loading diffusers pipeline from %s", model_path)
        self.diffusers_pipe = self._load_diffusers_pipeline(model_path, server_args)
        self._detect_pipeline_type()

    def _load_diffusers_pipeline(self, model_path: str, server_args: ServerArgs) -> Any:
        """Load the diffusers pipeline and apply the configured optimizations.

        Steps:
        - Resolve/download the model and pick a dtype.
        - Call ``DiffusionPipeline.from_pretrained``; if the pipeline class is
          missing from diffusers, retry with ``custom_pipeline`` from the repo;
          if the failure looks dtype-related, retry once in float32.
        - Move the pipeline to the local device, then apply VAE optimizations,
          attention backend, cache-dit and torch.compile as configured.

        Raises:
            RuntimeError: If the pipeline class is not in diffusers and the
                custom-pipeline fallback also fails.
        """

        original_model_path = model_path  # Keep original for custom_pipeline
        model_path = maybe_download_model(model_path, force_diffusers_model=True)
        self.model_path = model_path

        dtype = self._get_dtype(server_args)
        logger.info("Loading diffusers pipeline with dtype=%s", dtype)

        # Build common kwargs for from_pretrained
        load_kwargs = {
            "torch_dtype": dtype,
            "trust_remote_code": server_args.trust_remote_code,
            "revision": server_args.revision,
        }

        # Add quantization config if provided (e.g., BitsAndBytesConfig for 4/8-bit)
        config = server_args.pipeline_config
        if config is not None:
            quant_config = getattr(config, "quantization_config", None)
            if quant_config is not None:
                load_kwargs["quantization_config"] = quant_config
                logger.info(
                    "Using quantization config: %s", type(quant_config).__name__
                )

        try:
            pipe = DiffusionPipeline.from_pretrained(model_path, **load_kwargs)
        except AttributeError as e:
            if "has no attribute" in str(e):
                # Custom pipeline class not in diffusers - try loading with custom_pipeline
                logger.info(
                    "Pipeline class not found in diffusers, trying custom_pipeline from repo..."
                )
                # NOTE(review): the fallback below executes code shipped with the
                # model repo regardless of the user's trust_remote_code setting.
                # Behavior is kept, but the override is made visible in the logs.
                if not server_args.trust_remote_code:
                    logger.warning(
                        "Forcing trust_remote_code=True to load custom pipeline code "
                        "from %s although trust_remote_code was not enabled",
                        original_model_path,
                    )
                try:
                    custom_kwargs = {
                        **load_kwargs,
                        "custom_pipeline": original_model_path,
                    }
                    custom_kwargs["trust_remote_code"] = True
                    pipe = DiffusionPipeline.from_pretrained(
                        model_path, **custom_kwargs
                    )
                except Exception as e2:
                    # Python formats the message as "... has no attribute 'Name'",
                    # so the optional quote must be skipped to capture the name.
                    match = re.search(r"has no attribute '?(\w+)", str(e))
                    class_name = match.group(1) if match else "unknown"
                    raise RuntimeError(
                        f"Pipeline class '{class_name}' not found in diffusers and no custom pipeline.py in repo. "
                        f"Try: pip install --upgrade diffusers (some pipelines require latest version). "
                        f"Original error: {e}"
                    ) from e2
            else:
                raise
        except Exception as e:
            # Only retry with float32 for dtype-related errors
            if "dtype" in str(e).lower() or "float" in str(e).lower():
                logger.warning(
                    "Failed with dtype=%s, falling back to float32: %s", dtype, e
                )
                load_kwargs["torch_dtype"] = torch.float32
                pipe = DiffusionPipeline.from_pretrained(model_path, **load_kwargs)
            else:
                raise

        pipe = pipe.to(get_local_torch_device())
        # Apply VAE memory optimizations from pipeline config
        self._apply_vae_optimizations(pipe, server_args)
        # Apply attention backend if specified
        self._apply_attention_backend(pipe, server_args)
        # Apply cache-dit acceleration if configured
        pipe = self._apply_cache_dit(pipe, server_args)
        # Apply torch.compile if enabled and supported
        pipe = self._apply_torch_compile(pipe, server_args)
        logger.info("Loaded diffusers pipeline: %s", pipe.__class__.__name__)
        return pipe

    def _apply_vae_optimizations(self, pipe: Any, server_args: ServerArgs) -> None:
        """Apply VAE memory optimizations (tiling, slicing) from pipeline config."""
        config = server_args.pipeline_config
        if config is None:
            return

        # VAE slicing: decode latents slice-by-slice for lower peak memory
        # https://huggingface.co/docs/diffusers/optimization/memory#vae-slicing
        if getattr(config, "vae_slicing", False):
            if hasattr(pipe, "enable_vae_slicing"):
                pipe.enable_vae_slicing()
                logger.info("Enabled VAE slicing for lower memory usage")

        # VAE tiling: decode latents tile-by-tile for large images
        # https://huggingface.co/docs/diffusers/optimization/memory#vae-tiling
        if getattr(config, "vae_tiling", False):
            if hasattr(pipe, "enable_vae_tiling"):
                pipe.enable_vae_tiling()
                logger.info("Enabled VAE tiling for large image support")

    def _apply_attention_backend(self, pipe: Any, server_args: ServerArgs) -> None:
        """Apply attention backend setting from server_args.

        See: https://huggingface.co/docs/diffusers/main/en/optimization/attention_backends
        Available backends: flash, _flash_3_hub, sage, xformers, native, etc.

        Names that collide with SGLang's own backend names are skipped; failures
        to set a backend on a component are logged and do not abort loading.
        """
        backend = server_args.attention_backend

        if backend is None:
            return

        backend = backend.lower()
        sglang_backends = {e.name.lower() for e in AttentionBackendEnum} | {
            "fa3",
            "fa4",
        }
        if backend in sglang_backends:
            logger.debug(
                "Skipping diffusers attention backend '%s' because it matches a "
                "SGLang backend name. Use diffusers backend names when running "
                "the diffusers backend.",
                backend,
            )
            return

        for component_name in ["transformer", "unet"]:
            component = getattr(pipe, component_name, None)
            if component is not None and hasattr(component, "set_attention_backend"):
                try:
                    component.set_attention_backend(backend)
                    logger.info(
                        "Set attention backend '%s' on %s", backend, component_name
                    )
                except Exception as e:
                    logger.warning(
                        "Failed to set attention backend '%s' on %s: %s",
                        backend,
                        component_name,
                        e,
                    )

    def _apply_cache_dit(self, pipe: Any, server_args: ServerArgs) -> Any:
        """Enable cache-dit for diffusers pipeline if configured.

        Raises:
            RuntimeError: If cache-dit is missing or lacks ``load_configs``.
            ValueError: If the cache-dit config cannot be loaded.
        """
        cache_dit_config = server_args.cache_dit_config
        if not cache_dit_config:
            return pipe

        try:
            import cache_dit
        except ImportError as e:
            raise RuntimeError(
                "cache-dit is required for --cache-dit-config. "
                "Install it with `pip install cache-dit`."
            ) from e

        if not hasattr(cache_dit, "load_configs"):
            raise RuntimeError(
                "cache-dit>=1.2.0 is required for --cache-dit-config. "
                "Please upgrade cache-dit."
            )

        try:
            cache_options = cache_dit.load_configs(cache_dit_config)
        except Exception as e:
            raise ValueError(
                "Failed to load cache-dit config. Provide a YAML/JSON path (or a dict "
                "supported by cache-dit>=1.2.0)."
            ) from e

        try:
            pipe = cache_dit.enable_cache(pipe, **cache_options)
        except Exception:
            # cache-dit is an external integration and can raise a variety of errors.
            logger.exception("Failed to enable cache-dit for diffusers pipeline")
            raise

        logger.info("Enabled cache-dit for diffusers pipeline")
        self._cache_dit_enabled = True
        return pipe

    def _apply_torch_compile(self, pipe: Any, server_args: ServerArgs) -> Any:
        """Apply torch.compile to the pipeline if configured and supported.

        Compilation failures of individual components are logged and skipped.
        """
        if not server_args.enable_torch_compile:
            return pipe

        # check if the pipeline has 'transformer' or 'unet' components which are
        # typically the most expensive parts to compile. 'transformer_2' for some
        # video pipelines, e.g, Wan 2.2 series, also check for that.
        compilable_components = ["transformer", "transformer_2", "unet"]
        if not any(hasattr(pipe, comp) for comp in compilable_components):
            logger.warning(
                "Pipeline does not have 'transformer' or 'unet' components. "
                "torch.compile may not provide significant benefits and could increase latency."
            )
            return pipe

        if self._cache_dit_enabled:
            try:
                import cache_dit

                if hasattr(cache_dit, "set_compile_configs"):
                    cache_dit.set_compile_configs()
            except Exception as e:
                logger.warning(
                    f"Failed to set torch_compile configs for cache-dit: {e}"
                )

        for comp in compilable_components:
            if hasattr(pipe, comp):
                try:
                    component = getattr(pipe, comp)
                    # TODO(DefTruth): Add support for 'compile_repeated_blocks' for 'transformer'
                    # modules which can significantly reduce compilation time for large models
                    # with repeated blocks.
                    if isinstance(component, torch.nn.Module) and hasattr(
                        component, "compile"
                    ):
                        # Prefer in-place compilation if supported. According to PyTorch documentation:
                        # https://docs.pytorch.org/docs/stable/generated/torch.compile.html
                        component.compile()
                    else:
                        compiled_component = torch.compile(component)
                        setattr(pipe, comp, compiled_component)
                    logger.info(
                        f"Applied torch.compile to {comp} component of the pipeline"
                    )
                except Exception as e:
                    logger.warning(f"Failed to apply torch.compile to {comp}: {e}")

        return pipe

    def _get_dtype(self, server_args: ServerArgs) -> torch.dtype:
        """
        Determine the dtype to use for model loading.

        Defaults to bfloat16 when the device module reports support, float16
        otherwise; an explicit ``dit_precision`` of fp16/bf16/fp32 overrides it.
        """
        dtype = (
            torch.bfloat16
            if torch.get_device_module().is_bf16_supported()
            else torch.float16
        )

        if hasattr(server_args, "pipeline_config") and server_args.pipeline_config:
            dit_precision = server_args.pipeline_config.dit_precision
            if dit_precision == "fp16":
                dtype = torch.float16
            elif dit_precision == "bf16":
                dtype = torch.bfloat16
            elif dit_precision == "fp32":
                dtype = torch.float32

        return dtype

    def _detect_pipeline_type(self):
        """Detect if this is an image or video pipeline.

        Purely name-based: the lower-cased pipeline class name is checked for
        known video-related substrings.
        """
        pipe_class_name = self.diffusers_pipe.__class__.__name__.lower()
        video_indicators = ["video", "animat", "cogvideo", "wan", "hunyuan"]
        self.is_video_pipeline = any(ind in pipe_class_name for ind in video_indicators)
        logger.debug(
            "Detected pipeline type: %s",
            "video" if self.is_video_pipeline else "image",
        )

    def load_modules(
        self,
        server_args: ServerArgs,
        loaded_modules: dict[str, torch.nn.Module] | None = None,
    ) -> dict[str, Any]:
        """Skip sglang's module loading - diffusers handles it."""
        return {"diffusers_pipeline": self.diffusers_pipe}

    def create_pipeline_stages(self, server_args: ServerArgs):
        """Create the execution stage wrapping the diffusers pipeline."""
        self.add_stage(
            DiffusersExecutionStage(self.diffusers_pipe), "diffusers_execution"
        )

    def initialize_pipeline(self, server_args: ServerArgs):
        """Initialize the pipeline."""
        pass

    def post_init(self) -> None:
        """Post initialization hook; runs initialization and stage creation once."""
        if self.post_init_called:
            return
        self.post_init_called = True
        self.initialize_pipeline(self.server_args)
        self.create_pipeline_stages(self.server_args)

    def add_stage(
        self, stage: PipelineStage, stage_name: str | None = None
    ) -> "DiffusersPipeline":
        """Add a stage to the pipeline.

        Raises:
            ValueError: If ``stage_name`` is already registered.
        """
        if stage_name is None:
            stage_name = self._infer_stage_name(stage)
        if stage_name in self._stage_name_mapping:
            raise ValueError(f"Duplicate stage name detected: {stage_name}")

        self._stages.append(stage)
        self._stage_name_mapping[stage_name] = stage
        return self

    @property
    def stages(self) -> list[PipelineStage]:
        """List of stages in the pipeline."""
        return self._stages

    @torch.no_grad()
    def forward(self, batch: Req, server_args: ServerArgs) -> Req:
        """Execute the pipeline on the given batch."""
        if not self.post_init_called:
            self.post_init()
        return self.executor.execute(self.stages, batch, server_args)

    @classmethod
    def from_pretrained(
        cls,
        model_path: str,
        device: str | None = None,
        torch_dtype: torch.dtype | None = None,
        pipeline_config: str | PipelineConfig | None = None,
        args: argparse.Namespace | None = None,
        required_config_modules: list[str] | None = None,
        loaded_modules: dict[str, torch.nn.Module] | None = None,
        **kwargs,
    ) -> "DiffusersPipeline":
        """Load a pipeline from a pretrained model using diffusers backend.

        Only ``model_path`` and ``**kwargs`` feed ``ServerArgs.from_kwargs``;
        ``device``, ``torch_dtype``, ``pipeline_config`` and ``args`` are accepted
        but not read in this method.
        """
        kwargs["model_path"] = model_path
        server_args = ServerArgs.from_kwargs(**kwargs)

        pipe = cls(
            model_path,
            server_args,
            required_config_modules=required_config_modules,
            loaded_modules=loaded_modules,
        )
        pipe.post_init()
        return pipe

    def get_module(self, module_name: str, default_value: Any = None) -> Any:
        """Get a module by name; "diffusers_pipeline" returns the wrapped pipeline."""
        if module_name == "diffusers_pipeline":
            return self.diffusers_pipe
        return self.modules.get(module_name, default_value)
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Linearly map an image sequence length to a shift value ``mu``.

    The line passes through ``(base_seq_len, base_shift)`` and
    ``(max_seq_len, max_shift)``. Lengths outside that interval are
    extrapolated, not clamped.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
"vae", + "transformer", + "scheduler", + ] + + def create_pipeline_stages(self, server_args: ServerArgs): + self.add_stage(InputValidationStage()) + + self.add_stage( + TextEncodingStage( + text_encoders=[ + self.get_module("text_encoder"), + self.get_module("text_encoder_2"), + ], + tokenizers=[ + self.get_module("tokenizer"), + self.get_module("tokenizer_2"), + ], + ), + "prompt_encoding_stage_primary", + ) + + self.add_standard_timestep_preparation_stage(prepare_extra_kwargs=[prepare_mu]) + self.add_standard_latent_preparation_stage() + self.add_standard_denoising_stage() + self.add_standard_decoding_stage() + + +EntryClass = FluxPipeline diff --git a/sglang/python/sglang/multimodal_gen/runtime/pipelines/flux_2.py b/sglang/python/sglang/multimodal_gen/runtime/pipelines/flux_2.py new file mode 100644 index 0000000000000000000000000000000000000000..78beb90e31d039549ed1bff28c87cd6e20e682f6 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/pipelines/flux_2.py @@ -0,0 +1,62 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo +# SPDX-License-Identifier: Apache-2.0 + +from diffusers.image_processor import VaeImageProcessor + +from sglang.multimodal_gen.runtime.pipelines_core import LoRAPipeline, Req +from sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import ( + ComposedPipelineBase, +) +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +def compute_empirical_mu(batch: Req, server_args: ServerArgs): + num_steps = batch.num_inference_steps + image_seq_len = batch.raw_latent_shape[1] + a1, b1 = 8.73809524e-05, 1.89833333 + a2, b2 = 0.00016927, 0.45666666 + + if image_seq_len > 4300: + mu = a2 * image_seq_len + b2 + return "mu", float(mu) + + m_200 = a2 * image_seq_len + b2 + m_10 = a1 * image_seq_len + b1 + + a = (m_200 - m_10) / 190.0 + b = m_200 - 200.0 * a + mu = a * num_steps + b + + 
class Flux2Pipeline(LoRAPipeline, ComposedPipelineBase):
    """Pipeline registered as ``Flux2Pipeline``, assembled via ``add_standard_ti2i_stages``."""

    pipeline_name = "Flux2Pipeline"

    _required_config_modules = [
        "text_encoder",
        "tokenizer",
        "vae",
        "transformer",
        "scheduler",
    ]

    def create_pipeline_stages(self, server_args: ServerArgs):
        """Register the stages for this pipeline.

        A ``VaeImageProcessor`` is built with twice the VAE scale factor taken
        from the pipeline config. The same processor instance is passed both to
        the stage builder and to the image-VAE stage kwargs.
        ``compute_empirical_mu`` supplies the extra ``mu`` timestep kwarg.
        """
        arch_config = server_args.pipeline_config.vae_config.arch_config
        doubled_scale = arch_config.vae_scale_factor * 2
        image_processor = VaeImageProcessor(vae_scale_factor=doubled_scale)

        ti2i_options = {
            "include_input_validation": True,
            "vae_image_processor": image_processor,
            "prompt_encoding": "text",
            "image_vae_stage_kwargs": {"vae_image_processor": image_processor},
            "prepare_extra_timestep_kwargs": [compute_empirical_mu],
        }
        self.add_standard_ti2i_stages(**ti2i_options)
class GlmImagePipeline(LoRAPipeline, ComposedPipelineBase):
    """Pipeline registered as ``GlmImagePipeline``.

    Stages, in order: GLM-specific pre-denoising stage, denoising stage, then
    the standard decoding stage.
    """

    pipeline_name = "GlmImagePipeline"

    _required_config_modules = [
        "text_encoder",
        "tokenizer",
        "vae",
        "vision_language_encoder",
        "processor",
        "transformer",
        "scheduler",
    ]

    def create_pipeline_stages(self, server_args: ServerArgs):
        """Register the three stages that make up this pipeline."""
        # Each keyword of the pre-denoising stage has the same name as the
        # module it receives, so the kwargs are built from the names directly.
        before_denoising_inputs = (
            "vae",
            "text_encoder",
            "tokenizer",
            "processor",
            "transformer",
            "scheduler",
            "vision_language_encoder",
        )
        before_denoising = GlmImageBeforeDenoisingStage(
            **{name: self.get_module(name) for name in before_denoising_inputs}
        )
        self.add_stage(before_denoising, "glm_image_before_denoising_stage")

        # No explicit name here: add_stage is left to choose one.
        denoising = DenoisingStage(
            transformer=self.get_module("transformer"),
            scheduler=self.get_module("scheduler"),
        )
        self.add_stage(denoising)

        self.add_standard_decoding_stage()
class HeliosPipeline(LoRAPipeline, ComposedPipelineBase):
    """
    Helios video diffusion pipeline with LoRA support.

    Text-to-video pipeline whose denoising runs in a dedicated chunked
    denoising stage (see ``create_pipeline_stages``).
    """

    pipeline_name = "HeliosPipeline"

    _required_config_modules = [
        "text_encoder",
        "tokenizer",
        "vae",
        "transformer",
        "scheduler",
    ]

    def initialize_pipeline(self, server_args: ServerArgs):
        """Apply pipeline-config overrides to the already-loaded scheduler.

        Does nothing when no scheduler module is loaded.
        """
        # The scheduler comes from the model's scheduler_config.json and is
        # kept as loaded (it carries settings such as use_dynamic_shifting and
        # time_shift_type); only the values below are overridden.
        scheduler = self.modules.get("scheduler")
        pipeline_config = server_args.pipeline_config
        if scheduler is None:
            return

        if pipeline_config.flow_shift is not None:
            scheduler.set_shift(pipeline_config.flow_shift)

        # Stage 2/3 configuration, only when enabled in the pipeline config.
        if pipeline_config.is_enable_stage2:
            scheduler.config.stages = pipeline_config.pyramid_num_stages
            scheduler.config.scheduler_type = pipeline_config.scheduler_type
            scheduler.config.gamma = pipeline_config.gamma
            scheduler.init_sigmas_for_each_stage()

    def create_pipeline_stages(self, server_args: ServerArgs) -> None:
        """Register input validation, text encoding, latent prep, denoising, decoding."""
        self.add_stage(InputValidationStage())
        self.add_standard_text_encoding_stage()
        self.add_standard_latent_preparation_stage()

        # There is deliberately no standard timestep-preparation stage: the
        # Helios denoising stage calls scheduler.set_timesteps itself, per
        # chunk, with mu.
        chunked_denoising = HeliosChunkedDenoisingStage(
            transformer=self.get_module("transformer"),
            scheduler=self.modules["scheduler"],
        )
        self.add_stage(chunked_denoising, "helios_chunked_denoising_stage")

        # The standard decoding stage VAE-decodes the denoised latents.
        self.add_standard_decoding_stage()
+ +Shape pipeline: BeforeDenoising -> Denoising -> Export -> Save +Paint pipeline (optional): Preprocess -> TexGen -> Postprocess +""" + +from __future__ import annotations + +import glob +import importlib +import os +from itertools import chain +from typing import Any + +import torch +import torch.nn as nn + +from sglang.multimodal_gen.configs.pipeline_configs.hunyuan3d import ( + Hunyuan3D2PipelineConfig, +) +from sglang.multimodal_gen.runtime.loader.fsdp_load import ( + load_model_from_full_model_state_dict, + set_default_torch_dtype, +) +from sglang.multimodal_gen.runtime.loader.utils import get_param_names_mapping +from sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import ( + ComposedPipelineBase, +) +from sglang.multimodal_gen.runtime.pipelines_core.stages import ( + Hunyuan3DPaintPostprocessStage, + Hunyuan3DPaintPreprocessStage, + Hunyuan3DPaintTexGenStage, + Hunyuan3DShapeBeforeDenoisingStage, + Hunyuan3DShapeDenoisingStage, + Hunyuan3DShapeExportStage, + Hunyuan3DShapeSaveStage, +) +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class Hunyuan3D2Pipeline(ComposedPipelineBase): + """Hunyuan3D 2.0 image-to-mesh pipeline. 
@staticmethod
def _resolve_shape_dir(
    model_path: str,
    subfolder: str,
    use_safetensors: bool,
    variant: str | None,
) -> tuple[str, str]:
    """Locate (or download) the shape subfolder and return (config_path, ckpt_path).

    Args:
        model_path: Local directory, or the repo id passed to
            ``snapshot_download`` when no local directory is found.
        subfolder: Sub-directory of ``model_path`` that holds the shape model.
        use_safetensors: Look for ``*.safetensors`` instead of ``*.ckpt``.
        variant: Optional variant tag embedded in the checkpoint file name.

    Returns:
        ``(config_path, ckpt_path)``. Neither path is guaranteed to exist: if
        nothing matches, the default file names are returned so that the
        caller's open/load fails on the expected location.
    """
    local_path = os.path.join(model_path, subfolder)
    if not os.path.exists(local_path):
        local_path = os.path.expanduser(local_path)

    if not os.path.exists(local_path):
        logger.info(
            "Local path %s not found, downloading from HuggingFace Hub",
            local_path,
        )
        from huggingface_hub import snapshot_download

        downloaded = snapshot_download(
            repo_id=model_path,
            allow_patterns=[f"{subfolder}/*"],
        )
        local_path = os.path.join(downloaded, subfolder)

    config_path = os.path.join(local_path, "config.yaml")
    if not os.path.exists(config_path):
        for alt in ("config.yml", "model_config.yaml"):
            alt_path = os.path.join(local_path, alt)
            if os.path.exists(alt_path):
                config_path = alt_path
                break

    # Candidate checkpoint names, most preferred first. The first entry is the
    # name this function has always looked for.
    if use_safetensors:
        extension = "safetensors"
        ckpt_names = [
            f"model.{variant}.safetensors" if variant else "model.safetensors"
        ]
    else:
        extension = "ckpt"
        ckpt_names = [f"model-{variant}.ckpt" if variant else "model.ckpt"]
        if variant:
            # The safetensors branch uses a dotted variant name; accept the
            # same spelling for .ckpt before guessing from a glob.
            ckpt_names.append(f"model.{variant}.ckpt")

    ckpt_path = os.path.join(local_path, ckpt_names[0])
    named_matches = [
        path
        for path in (os.path.join(local_path, name) for name in ckpt_names)
        if os.path.exists(path)
    ]
    if named_matches:
        ckpt_path = named_matches[0]
    else:
        # glob returns entries in arbitrary order; sort so the fallback choice
        # is the same on every run and every filesystem.
        files = sorted(glob.glob(os.path.join(local_path, f"*.{extension}")))
        if files:
            ckpt_path = files[0]

    logger.info("Config path: %s", config_path)
    logger.info("Checkpoint path: %s", ckpt_path)
    return config_path, ckpt_path
" + "Download the model or check network connectivity." + ) + + logger.info("Resolved paint model directory: %s", local_path) + return local_path + + @staticmethod + def _load_and_split_checkpoint( + ckpt_path: str, use_safetensors: bool + ) -> dict[str, dict[str, torch.Tensor]]: + """Load a bundled checkpoint and split by the first '.' in each key.""" + if use_safetensors: + import safetensors.torch + + flat = safetensors.torch.load_file(ckpt_path, device="cpu") + ckpt: dict[str, dict[str, torch.Tensor]] = {} + for key, value in flat.items(): + component = key.split(".")[0] + sub_key = key[len(component) + 1 :] + ckpt.setdefault(component, {})[sub_key] = value + return ckpt + else: + return torch.load(ckpt_path, map_location="cpu", weights_only=True) + + # Component loading helpers + @classmethod + def _load_dit_model( + cls, + cfg: dict[str, Any], + weights: dict[str, torch.Tensor], + device: torch.device, + dtype: torch.dtype, + ) -> nn.Module: + """Load the DiT model using meta-device instantiation + standard weight loading.""" + if "target" not in cfg: + raise KeyError("Expected key 'target' in model config.") + target_cls = cls._resolve_class(cfg["target"]) + params = cfg.get("params", {}) + + if hasattr(target_cls, "build_config_from_params"): + dit_config = target_cls.build_config_from_params(params) + init_kwargs: dict[str, Any] = {"config": dit_config, "hf_config": {}} + else: + init_kwargs = params + + with set_default_torch_dtype(dtype), torch.device("meta"): + model = target_cls(**init_kwargs) + + weight_iterator = ((k, v) for k, v in weights.items()) + param_names_mapping_fn = get_param_names_mapping(model.param_names_mapping) + + load_model_from_full_model_state_dict( + model, + weight_iterator, + device, + dtype, + strict=False, + param_names_mapping=param_names_mapping_fn, + ) + + for name, p in chain(model.named_parameters(), model.named_buffers()): + if p.is_meta: + raise RuntimeError(f"Unexpected param or buffer {name} on meta device.") + if 
@classmethod
def _load_simple_component(
    cls,
    cfg: dict[str, Any],
    weights: dict[str, torch.Tensor] | None,
    device: torch.device,
    dtype: torch.dtype,
) -> nn.Module:
    """Build a component (VAE / conditioner) directly and load its weights.

    Args:
        cfg: Component config. ``cfg["target"]`` names the class to resolve;
            the optional ``cfg["params"]`` holds its constructor kwargs.
        weights: State dict for the component, or ``None`` to skip loading.
        device: Device the component is moved to.
        dtype: Default dtype during construction and dtype of the final move.

    Returns:
        The component in eval mode.

    Raises:
        KeyError: If ``cfg`` has no ``"target"`` key.
    """
    if "target" not in cfg:
        raise KeyError("Expected key 'target' in component config.")

    component_cls = cls._resolve_class(cfg["target"])
    ctor_kwargs = cfg.get("params", {})

    with set_default_torch_dtype(dtype):
        module = component_cls(**ctor_kwargs)

    # Non-strict load: keys that do not line up between checkpoint and module
    # are tolerated instead of raising.
    if weights is not None:
        module.load_state_dict(weights, strict=False)

    module.to(device=device, dtype=dtype)
    return module.eval()
def initialize_pipeline(self, server_args: ServerArgs):
    """Check that the pipeline config is a ``Hunyuan3D2PipelineConfig``.

    Raises:
        TypeError: If ``server_args.pipeline_config`` has any other type.
    """
    pipeline_config = server_args.pipeline_config
    if isinstance(pipeline_config, Hunyuan3D2PipelineConfig):
        return
    raise TypeError(
        "Hunyuan3D2Pipeline requires Hunyuan3D2PipelineConfig, "
        f"got {type(pipeline_config)}"
    )
scheduler=self.get_module("hy3dshape_scheduler"), + config=config, + ), + ) + self.add_stage( + stage_name="shape_denoising", + stage=Hunyuan3DShapeDenoisingStage( + transformer=self.get_module("hy3dshape_model"), + scheduler=self.get_module("hy3dshape_scheduler"), + ), + ) + self.add_stage( + stage_name="shape_export", + stage=Hunyuan3DShapeExportStage( + vae=self.get_module("hy3dshape_vae"), + config=config, + ), + ) + self.add_stage( + stage_name="shape_save", + stage=Hunyuan3DShapeSaveStage(config=config), + ) + + # Paint: 3 stages (optional) + if config.paint_enable: + self.add_stage( + stage_name="paint_preprocess", + stage=Hunyuan3DPaintPreprocessStage(config=config), + ) + self.add_stage( + stage_name="paint_texgen", + stage=Hunyuan3DPaintTexGenStage( + config=config, + paint_dir=self.get_module("hy3dpaint_dir"), + ), + ) + self.add_stage( + stage_name="paint_postprocess", + stage=Hunyuan3DPaintPostprocessStage(config=config), + ) + + +EntryClass = Hunyuan3D2Pipeline diff --git a/sglang/python/sglang/multimodal_gen/runtime/pipelines/hunyuan_pipeline.py b/sglang/python/sglang/multimodal_gen/runtime/pipelines/hunyuan_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..301c759f4008363d552f062ba6d2634f83b0ea07 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/pipelines/hunyuan_pipeline.py @@ -0,0 +1,61 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +""" +Hunyuan video diffusion pipeline implementation. + +This module contains an implementation of the Hunyuan video diffusion pipeline +using the modular pipeline architecture. 
class HunyuanVideoPipeline(ComposedPipelineBase):
    """Pipeline registered as ``HunyuanVideoPipeline``.

    Uses two text encoders with their matching tokenizers, followed by the
    standard timestep, latent, denoising and decoding stages.
    """

    pipeline_name = "HunyuanVideoPipeline"

    _required_config_modules = [
        "text_encoder",
        "text_encoder_2",
        "tokenizer",
        "tokenizer_2",
        "vae",
        "transformer",
        "scheduler",
    ]

    def create_pipeline_stages(self, server_args: ServerArgs):
        """Register the pipeline stages in execution order."""
        self.add_stage(InputValidationStage())

        # Encoders and tokenizers are paired by position in the two lists.
        text_encoders = [
            self.get_module(name) for name in ("text_encoder", "text_encoder_2")
        ]
        tokenizers = [self.get_module(name) for name in ("tokenizer", "tokenizer_2")]
        prompt_encoding = TextEncodingStage(
            text_encoders=text_encoders,
            tokenizers=tokenizers,
        )
        self.add_stage(prompt_encoding, "prompt_encoding_stage_primary")

        self.add_standard_timestep_preparation_stage()
        self.add_standard_latent_preparation_stage()
        self.add_standard_denoising_stage()
        self.add_standard_decoding_stage()
def prepare_mu(batch: Req, server_args: ServerArgs):
    """Return the ``("mu", value)`` extra timestep kwarg.

    The shift is evaluated at a fixed sequence length of 4096, which equals
    ``max_seq_len``, so the result is a constant that does not depend on the
    request or on the server configuration. ``batch`` and ``server_args`` are
    accepted only to match the ``prepare_extra_kwargs`` callback signature.

    Returns:
        Tuple of the kwarg name ``"mu"`` and its value.
    """
    # Values from LTX2Pipeline in diffusers
    mu = calculate_shift(
        4096,
        base_seq_len=1024,
        max_seq_len=4096,
        base_shift=0.95,
        max_shift=2.05,
    )
    return "mu", mu
os.path.exists(config_path): + with open(config_path, "r") as f: + return json.load(f) + + # Fallback to direct config.json in subfolder if standard structure + config_path = os.path.join(model_path, component_name, "config.json") + if os.path.exists(config_path): + with open(config_path, "r") as f: + return json.load(f) + + except Exception as e: + logger.warning(f"Failed to load config for {component_name}: {e}") + + return {} + + +def _filter_kwargs_for_cls(cls, kwargs): + """Filter kwargs to only include those accepted by cls.__init__""" + sig = inspect.signature(cls.__init__) + return {k: v for k, v in kwargs.items() if k in sig.parameters} + + +class LTX2FlowMatchScheduler(FlowMatchEulerDiscreteScheduler): + """Override ``_time_shift_exponential`` to use torch f32 instead of numpy f64.""" + + def _time_shift_exponential(self, mu, sigma, t): + if isinstance(t, np.ndarray): + t_torch = torch.from_numpy(t).to(torch.float32) + result = math.exp(mu) / (math.exp(mu) + (1 / t_torch - 1) ** sigma) + return result.numpy() + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +class LTX2Pipeline(ComposedPipelineBase): + # NOTE: must match `model_index.json`'s `_class_name` for native dispatch. 
+ pipeline_name = "LTX2Pipeline" + + _required_config_modules = [ + "transformer", + "text_encoder", + "tokenizer", + "scheduler", + "vae", + "audio_vae", + "vocoder", + "connectors", + ] + + def initialize_pipeline(self, server_args: ServerArgs): + orig = self.get_module("scheduler") + self.modules["scheduler"] = LTX2FlowMatchScheduler.from_config(orig.config) + + def create_pipeline_stages(self, server_args: ServerArgs): + self.add_stages( + [ + InputValidationStage(), + TextEncodingStage( + text_encoders=[self.get_module("text_encoder")], + tokenizers=[self.get_module("tokenizer")], + ), + LTX2TextConnectorStage(connectors=self.get_module("connectors")), + ] + ) + + self.add_standard_timestep_preparation_stage(prepare_extra_kwargs=[prepare_mu]) + + self.add_stages( + [ + LTX2AVLatentPreparationStage( + scheduler=self.get_module("scheduler"), + transformer=self.get_module("transformer"), + audio_vae=self.get_module("audio_vae"), + ), + LTX2AVDenoisingStage( + transformer=self.get_module("transformer"), + scheduler=self.get_module("scheduler"), + vae=self.get_module("vae"), + audio_vae=self.get_module("audio_vae"), + ), + LTX2AVDecodingStage( + vae=self.get_module("vae"), + audio_vae=self.get_module("audio_vae"), + vocoder=self.get_module("vocoder"), + pipeline=self, + ), + ] + ) + + +EntryClass = LTX2Pipeline diff --git a/sglang/python/sglang/multimodal_gen/runtime/pipelines/mova_pipeline.py b/sglang/python/sglang/multimodal_gen/runtime/pipelines/mova_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..57e1c50296323a10e69689263190a1953f9d6d5f --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/pipelines/mova_pipeline.py @@ -0,0 +1,105 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +MOVA pipeline integration (native SGLang pipeline). 
class MOVAPipeline(ComposedPipelineBase):
    """MOVA pipeline with SGLang stage orchestration."""

    pipeline_name = "MOVA"
    is_video_pipeline = True
    _required_config_modules = [
        "video_vae",
        "audio_vae",
        "text_encoder",
        "tokenizer",
        "scheduler",
        "video_dit",
        "video_dit_2",
        "audio_dit",
        "dual_tower_bridge",
    ]
    pipeline_config_cls = MOVAPipelineConfig
    sampling_params_cls = MOVASamplingParams

    def initialize_pipeline(self, server_args: ServerArgs) -> None:
        """
        Initialize the pipeline.

        MOVA supports Context Parallel (sequence parallel) through USPAttention,
        which uses Ulysses-style all-to-all communication for distributed attention.
        This hook only logs when sequence parallelism is enabled.
        """
        if server_args.sp_degree > 1:
            logger.info(
                "MOVA Context Parallel enabled with sp_degree=%d. "
                "Using USPAttention for distributed self-attention.",
                server_args.sp_degree,
            )

    def create_pipeline_stages(self, server_args: ServerArgs) -> None:
        """Register the MOVA stages in execution order.

        The image-VAE encoding stage is added only when the video DiT requires
        a VAE embedding. A missing ``require_vae_embedding`` attribute is
        treated as ``True``.
        """
        self.add_stage(InputValidationStage())
        self.add_standard_text_encoding_stage()

        # Read the flag once: it decides whether the encoding stage exists and
        # is passed to latent preparation, and the two uses must agree.
        require_vae_embedding = getattr(
            self.get_module("video_dit"), "require_vae_embedding", True
        )
        if require_vae_embedding:
            self.add_stage(ImageVAEEncodingStage(vae=self.get_module("video_vae")))

        self.add_stage(
            MOVALatentPreparationStage(
                audio_vae=self.get_module("audio_vae"),
                require_vae_embedding=require_vae_embedding,
            ),
            "mova_latent_preparation_stage",
        )
        self.add_stage(
            MOVATimestepPreparationStage(
                scheduler=self.get_module("scheduler"),
            ),
            "mova_timestep_preparation_stage",
        )
        self.add_stage(
            MOVADenoisingStage(
                video_dit=self.get_module("video_dit"),
                video_dit_2=self.get_module("video_dit_2"),
                audio_dit=self.get_module("audio_dit"),
                dual_tower_bridge=self.get_module("dual_tower_bridge"),
                scheduler=self.get_module("scheduler"),
            ),
            "mova_denoising_stage",
        )
        self.add_stage(
            MOVADecodingStage(
                video_vae=self.get_module("video_vae"),
                audio_vae=self.get_module("audio_vae"),
            ),
            "mova_decoding_stage",
        )
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """
    Linearly map an image sequence length to a shift value ``mu``.

    The mapping is the straight line through the points
    ``(base_seq_len, base_shift)`` and ``(max_seq_len, max_shift)``; it is
    evaluated at ``image_seq_len`` without clamping, so lengths outside that
    interval extrapolate.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
class QwenImageLayeredPipeline(QwenImageEditPipeline):
    """Qwen-Image "layered" pipeline built on top of the edit pipeline."""

    pipeline_name = "QwenImageLayeredPipeline"

    _required_config_modules = [
        "vae",
        "tokenizer",
        "processor",
        "transformer",
        "scheduler",
    ]

    def create_pipeline_stages(self, server_args: ServerArgs):
        """
        Register the stages: one combined pre-denoising stage, then the
        standard timestep-preparation (with ``prepare_mu_layered`` as extra
        kwargs hook), denoising and decoding stages.
        """
        # Every module the pre-denoising stage takes is passed under its own
        # module name, so the kwargs can be built straight from the names.
        stage_module_names = (
            "vae",
            "tokenizer",
            "processor",
            "transformer",
            "scheduler",
        )
        stage_modules = {
            module_name: self.get_module(module_name)
            for module_name in stage_module_names
        }
        before_denoising = QwenImageLayeredBeforeDenoisingStage(
            model_path=self.model_path,
            **stage_modules,
        )
        self.add_stage(before_denoising)

        extra_timestep_hooks = [prepare_mu_layered]
        self.add_standard_timestep_preparation_stage(
            prepare_extra_kwargs=extra_timestep_hooks
        )
        self.add_standard_denoising_stage()
        self.add_standard_decoding_stage()
class WanCausalDMDPipeline(LoRAPipeline, ComposedPipelineBase):
    """Wan pipeline whose denoising step is the causal DMD stage."""

    pipeline_name = "WanCausalDMDPipeline"

    _required_config_modules = [
        "text_encoder",
        "tokenizer",
        "vae",
        "transformer",
        "scheduler",
    ]

    def create_pipeline_stages(self, server_args: ServerArgs) -> None:
        """
        Register the stages in order: input validation, text encoding,
        latent preparation, causal DMD denoising, decoding.
        """
        # Pre-denoising stages.
        validation_stage = InputValidationStage()
        self.add_stage(validation_stage)
        self.add_standard_text_encoding_stage()
        self.add_standard_latent_preparation_stage()

        # Causal DMD replaces the standard denoising stage.
        causal_denoiser = CausalDMDDenoisingStage(
            transformer=self.get_module("transformer"),
            scheduler=self.get_module("scheduler"),
        )
        self.add_stage(causal_denoiser)

        self.add_standard_decoding_stage()
class WanDMDPipeline(LoRAPipeline, ComposedPipelineBase):
    """
    LoRA-capable Wan pipeline that denoises with the DMD stage.
    """

    pipeline_name = "WanDMDPipeline"

    _required_config_modules = [
        "text_encoder",
        "tokenizer",
        "vae",
        "transformer",
        "scheduler",
    ]

    def initialize_pipeline(self, server_args: ServerArgs):
        """
        Overwrite the ``"scheduler"`` module with a flow-match Euler scheduler
        configured with the pipeline config's ``flow_shift``.
        """
        flow_shift = server_args.pipeline_config.flow_shift
        self.modules["scheduler"] = FlowMatchEulerDiscreteScheduler(shift=flow_shift)

    def create_pipeline_stages(self, server_args: ServerArgs) -> None:
        """
        Register the stages in order: input validation, text encoding,
        timestep preparation, latent preparation, DMD denoising, decoding.
        """
        self.add_stage(InputValidationStage())

        self.add_standard_text_encoding_stage()
        self.add_standard_timestep_preparation_stage()
        self.add_standard_latent_preparation_stage()

        dmd_denoiser = DmdDenoisingStage(
            transformer=self.get_module("transformer"),
            scheduler=self.get_module("scheduler"),
        )
        self.add_stage(dmd_denoiser)

        self.add_standard_decoding_stage()
class WanImageToVideoDmdPipeline(LoRAPipeline, ComposedPipelineBase):
    """Wan image-to-video pipeline that denoises with the DMD stage."""

    pipeline_name = "WanImageToVideoDmdPipeline"

    _required_config_modules = [
        "text_encoder",
        "tokenizer",
        "vae",
        "transformer",
        "scheduler",
        "image_encoder",
        "image_processor",
    ]

    def initialize_pipeline(self, server_args: ServerArgs):
        """
        Overwrite the ``"scheduler"`` module with a flow-match Euler scheduler
        configured with the pipeline config's ``flow_shift``.
        """
        flow_shift = server_args.pipeline_config.flow_shift
        self.modules["scheduler"] = FlowMatchEulerDiscreteScheduler(shift=flow_shift)

    def create_pipeline_stages(self, server_args: ServerArgs):
        """
        Register the standard text+image-to-video stages, with image VAE
        encoding positioned ``"after_latent"`` and a DMD denoising stage.
        """

        def _make_dmd_denoiser():
            # Built lazily: modules are looked up only when the factory is
            # invoked. "transformer_2" is not in the required list above, so
            # get_module may hand back its default (None) for it.
            return DmdDenoisingStage(
                transformer=self.get_module("transformer"),
                scheduler=self.get_module("scheduler"),
                transformer_2=self.get_module("transformer_2"),
            )

        self.add_standard_ti2v_stages(
            image_vae_encoding_position="after_latent",
            denoising_stage_factory=_make_dmd_denoiser,
        )
class WanImageToVideoPipeline(LoRAPipeline, ComposedPipelineBase):
    """Wan image-to-video pipeline using the standard text+image-to-video stages."""

    pipeline_name = "WanImageToVideoPipeline"

    _required_config_modules = [
        "text_encoder",
        "tokenizer",
        "vae",
        "transformer",
        "scheduler",
        "image_encoder",
        "image_processor",
    ]

    def initialize_pipeline(self, server_args: ServerArgs):
        """
        Overwrite the ``"scheduler"`` module with a ``FlowUniPCMultistepScheduler``
        configured with the pipeline config's ``flow_shift``.
        """
        flow_shift = server_args.pipeline_config.flow_shift
        unipc_scheduler = FlowUniPCMultistepScheduler(shift=flow_shift)
        self.modules["scheduler"] = unipc_scheduler

    def create_pipeline_stages(self, server_args: ServerArgs):
        """Register the standard text+image-to-video stage sequence with defaults."""
        self.add_standard_ti2v_stages()
class WanPipeline(LoRAPipeline, ComposedPipelineBase):
    """
    LoRA-capable Wan video diffusion pipeline.
    """

    pipeline_name = "WanPipeline"

    _required_config_modules = [
        "text_encoder",
        "tokenizer",
        "vae",
        "transformer",
        "scheduler",
    ]

    def initialize_pipeline(self, server_args: ServerArgs):
        """
        Overwrite the ``"scheduler"`` module with a ``FlowUniPCMultistepScheduler``
        configured with the pipeline config's ``flow_shift``.
        """
        # The UniPC scheduler comes from the official Wan2.1 repository, not
        # from diffusers.
        flow_shift = server_args.pipeline_config.flow_shift
        unipc_scheduler = FlowUniPCMultistepScheduler(shift=flow_shift)
        self.modules["scheduler"] = unipc_scheduler

    def create_pipeline_stages(self, server_args: ServerArgs) -> None:
        """Register the stage sequence provided by ``add_standard_t2i_stages``."""
        self.add_standard_t2i_stages()
def build_pipeline(
    server_args: ServerArgs,
) -> PipelineWithLoRA:
    """
    Select the pipeline class for ``server_args`` and instantiate it.

    Selection rules:
      * If ``server_args.pipeline_class_name`` is set, pipeline discovery is
        run and the class is looked up by that name in the pipeline registry.
      * Otherwise the class is taken from ``get_model_info`` for the model
        path, backend and model id (i.e. derived from ``model_index.json``,
        so a valid HF diffusers layout is expected).

    The selected class is then called as ``pipeline_cls(model_path, server_args)``.

    Raises:
        ValueError: ``pipeline_class_name`` was given but is not registered.
    """
    model_path = server_args.model_path

    if not server_args.pipeline_class_name:
        # No explicit override: let the registry resolve the class from the model.
        logger.info("No pipeline_class_name specified, using model_index.json")
        model_info = get_model_info(
            model_path,
            backend=server_args.backend,
            model_id=server_args.model_id,
        )
        pipeline_cls = model_info.pipeline_cls
        logger.info(f"Using pipeline from model_index.json: {pipeline_cls.__name__}")
    else:
        # Explicit override: resolve the class by name.
        from sglang.multimodal_gen.registry import (
            _PIPELINE_REGISTRY,
            _discover_and_register_pipelines,
        )

        _discover_and_register_pipelines()
        logger.info(f"Requested pipeline_class_name: {server_args.pipeline_class_name}")
        logger.info(
            f"Available pipelines in registry: {list(_PIPELINE_REGISTRY.keys())}"
        )
        pipeline_cls = _PIPELINE_REGISTRY.get(server_args.pipeline_class_name)
        if pipeline_cls is None:
            raise ValueError(
                f"Pipeline class '{server_args.pipeline_class_name}' not found in registry. "
                f"Available pipelines: {list(_PIPELINE_REGISTRY.keys())}"
            )
        logger.info(
            f"✓ Using explicitly specified pipeline: {server_args.pipeline_class_name} (class: {pipeline_cls.__name__})"
        )

    # instantiate the pipeline
    pipeline = pipeline_cls(model_path, server_args)
    logger.info("Pipeline instantiated")

    return cast(PipelineWithLoRA, pipeline)
https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +""" +Base class for composed pipelines. + +This module defines the base class for pipelines that are composed of multiple stages. +""" + +import os +import re +from abc import ABC, abstractmethod +from typing import Any, Callable, Literal, cast + +import torch +from tqdm import tqdm + +from sglang.multimodal_gen.runtime.loader.component_loaders.component_loader import ( + PipelineComponentLoader, +) +from sglang.multimodal_gen.runtime.pipelines_core.executors.pipeline_executor import ( + PipelineExecutor, +) +from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch, Req +from sglang.multimodal_gen.runtime.pipelines_core.stages import ( + DecodingStage, + DenoisingStage, + ImageEncodingStage, + ImageVAEEncodingStage, + InputValidationStage, + LatentPreparationStage, + PipelineStage, + TextEncodingStage, + TimestepPreparationStage, +) +from sglang.multimodal_gen.runtime.platforms import current_platform +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import ( + maybe_download_model, + verify_model_config_and_directory, +) +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class ComposedPipelineBase(ABC): + """ + Base class for pipelines composed of multiple stages. + + This class provides the framework for creating pipelines by composing multiple + stages together. Each stage is responsible for a specific part of the diffusion + process, and the pipeline orchestrates the execution of these stages. 
def __init__(
    self,
    model_path: str,
    server_args: ServerArgs,
    required_config_modules: list[str] | None = None,
    loaded_modules: dict[str, torch.nn.Module] | None = None,
    executor: PipelineExecutor | None = None,
):
    """
    Build a ready-to-use pipeline: pick an executor, load the modules, then
    run ``__post_init__``.

    Args:
        model_path: Model location; stored on the instance.
        server_args: Server configuration, stored and forwarded to the
            executor builder and the module loader.
        required_config_modules: When given, replaces ``_required_config_modules``
            for this instance.
        loaded_modules: Optional pre-built modules forwarded to ``load_modules``.
        executor: Executor to use; when falsy one is created with
            ``build_executor``.

    The pipeline is meant to be stateless and not hold any batch state.
    """
    self.server_args = server_args
    self.model_path: str = model_path

    # Stage bookkeeping: ordered list plus lookup by stage name.
    self._stages: list[PipelineStage] = []
    self._stage_name_mapping: dict[str, PipelineStage] = {}

    if not executor:
        executor = self.build_executor(server_args=server_args)
    self.executor = executor

    if required_config_modules is not None:
        self._required_config_modules = required_config_modules
    if self._required_config_modules is None:
        raise NotImplementedError("Subclass must set _required_config_modules")

    # module name -> GPU memory usage; must exist before load_modules runs,
    # which records into it.
    self.memory_usages: dict[str, float] = {}

    # Modules are loaded eagerly, as part of construction.
    logger.info("Loading pipeline modules...")
    self.modules = self.load_modules(server_args, loaded_modules)

    self.__post_init__()
def get_module(self, module_name: str, default_value: Any = None) -> Any:
    """
    Look up a loaded module by name.

    Returns the entry stored in ``self.modules`` under ``module_name``
    (including a stored ``None``), or ``default_value`` when there is no
    such entry.
    """
    registry = self.modules
    if module_name in registry:
        return registry[module_name]
    return default_value
def load_modules(
    self,
    server_args: ServerArgs,
    loaded_modules: dict[str, torch.nn.Module] | None = None,
) -> dict[str, Any]:
    """
    Load every component listed in ``required_config_modules``.

    The component table is read from the model's ``model_index.json`` (via
    ``_load_config``). Non-module keys are stripped, a non-null
    ``boundary_ratio`` is copied into ``server_args.pipeline_config.dit_config``
    (and, when a transformer is involved, ``transformer_2`` becomes required),
    and each required module is then loaded with ``PipelineComponentLoader``.

    Args:
        server_args: Server configuration; used for component path overrides
            and passed to the component loader. Its ``dit_config.boundary_ratio``
            may be updated as described above.
        loaded_modules: Optional pre-built components keyed by module name.
            A module found here is used as-is instead of being loaded from
            config/pretrained weights.

    Returns:
        Mapping from module name to the loaded component.

    Raises:
        ValueError: A required module is neither listed in ``model_index.json``
            nor resolvable through ``_extra_config_module_map``, or a required
            module ended up not loaded.
    """
    # ``_required_config_modules`` is normally declared at class level (or is
    # the caller's list passed to __init__), and this method appends to /
    # removes from it. Work on a per-instance copy so those edits cannot leak
    # into the class, other instances, or the caller's list.
    self._required_config_modules = list(self._required_config_modules)

    model_index = self._load_config()
    logger.info("Loading pipeline modules from config: %s", model_index)

    # remove keys that are not pipeline modules (tolerate their absence)
    model_index.pop("_class_name", None)
    model_index.pop("_diffusers_version", None)

    boundary_ratio = model_index.get("boundary_ratio")
    if boundary_ratio is not None:
        has_transformer = (
            "transformer" in model_index
            or "transformer_2" in model_index
            or "transformer" in self.required_config_modules
            or "transformer_2" in self.required_config_modules
        )
        if has_transformer:
            logger.info(
                "MoE pipeline detected. Adding transformer_2 to self.required_config_modules..."
            )
            if "transformer_2" not in self.required_config_modules:
                self.required_config_modules.append("transformer_2")
        else:
            logger.info(
                "Boundary ratio found in model_index.json without transformers; "
                "using it for pipeline config only."
            )
        logger.info(
            "Setting boundary ratio to %s",
            boundary_ratio,
        )
        server_args.pipeline_config.dit_config.boundary_ratio = boundary_ratio

    model_index.pop("boundary_ratio", None)
    # used by Wan2.2 ti2v
    model_index.pop("expand_timesteps", None)

    # some sanity checks
    # NOTE(review): the threshold is `> 1` while the message says "at least
    # one" — kept as-is; confirm which one is intended.
    assert (
        len(model_index) > 1
    ), "model_index.json must contain at least one pipeline module"

    # Narrow the index down to the required modules. Aliases from
    # ``_extra_config_module_map`` are resolved against the *unfiltered* index:
    # filtering first would raise a bare KeyError for the missing module and
    # would also have dropped the alias target, so the fallback could never run.
    available_index = model_index
    model_index = {}
    for module_name in self.required_config_modules:
        if module_name in available_index:
            model_index[module_name] = available_index[module_name]
            continue

        if module_name not in self._extra_config_module_map:
            raise ValueError(
                f"Required module: {module_name} was not found in model_index.json "
                f"modules: {list(available_index.keys())}"
            )

        extra_module_value = self._extra_config_module_map[module_name]
        logger.warning(
            "model_index.json does not contain a %s module, but found {%s: %s} in _extra_config_module_map, adding to model_index.",
            module_name,
            module_name,
            extra_module_value,
        )
        if extra_module_value not in available_index:
            raise ValueError(
                f"Required module key: {module_name} value: {extra_module_value} "
                f"was not found in loaded modules {list(available_index.keys())}"
            )
        logger.info("Using module %s for %s", extra_module_value, module_name)
        model_index[module_name] = available_index[extra_module_value]

    # all the component models used by the pipeline
    # (same list object as self.required_config_modules: removals made in the
    # loop below are therefore reflected in the final completeness check)
    required_modules = self.required_config_modules
    logger.info("Loading required components: %s", required_modules)

    loaded_components = {}
    for module_name, (
        transformers_or_diffusers,
        _architecture,
    ) in tqdm(iterable=model_index.items(), desc="Loading required modules"):
        if transformers_or_diffusers is None:
            # A null entry means the component is absent for this model;
            # stop requiring it instead of failing the final check.
            logger.warning(
                "Module %s in model_index.json has null value, removing from required_config_modules",
                module_name,
            )
            if module_name in self.required_config_modules:
                self.required_config_modules.remove(module_name)
            continue
        if module_name not in required_modules:
            logger.info("Skipping module %s", module_name)
            continue
        if loaded_modules is not None and module_name in loaded_modules:
            logger.info("Using module %s already provided", module_name)
            loaded_components[module_name] = loaded_modules[module_name]
            continue

        # we load the module from the extra config module map if it exists
        if module_name in self._extra_config_module_map:
            load_module_name = self._extra_config_module_map[module_name]
        else:
            load_module_name = module_name

        component_model_path = self._resolve_component_path(
            server_args, module_name, load_module_name
        )
        module, memory_usage = PipelineComponentLoader.load_component(
            component_name=load_module_name,
            component_model_path=component_model_path,
            transformers_or_diffusers=transformers_or_diffusers,
            server_args=server_args,
        )

        self.memory_usages[load_module_name] = memory_usage

        if module_name in loaded_components:
            logger.warning("Overwriting module %s", module_name)
        loaded_components[module_name] = module

    # Check if all required modules were loaded
    for module_name in required_modules:
        if (
            module_name not in loaded_components
            or loaded_components[module_name] is None
        ):
            raise ValueError(
                f"Required module: {module_name} was not found in loaded modules: {list(loaded_components.keys())}"
            )

    logger.debug(
        "Memory usage of loaded modules (GiB): %s. Available memory: %s",
        self.memory_usages,
        round(current_platform.get_available_gpu_memory(), 2),
    )

    return loaded_components
"ComposedPipelineBase": + + assert self.modules is not None, "No modules are registered" + + if stage_name is None: + stage_name = self._infer_stage_name(stage) + if stage_name in self._stage_name_mapping: + raise ValueError(f"Duplicate stage name detected: {stage_name}") + + self._stages.append(stage) + self._stage_name_mapping[stage_name] = stage + return self + + def add_stages( + self, stages: list[PipelineStage | tuple[PipelineStage, str]] + ) -> "ComposedPipelineBase": + + for item in stages: + if isinstance(item, tuple): + stage, name = item + self.add_stage(stage, name) + else: + self.add_stage(item) + return self + + def add_stage_if( + self, + condition: bool | Callable[[], bool], + stage: PipelineStage, + ) -> "ComposedPipelineBase": + should_add = condition() if callable(condition) else condition + if should_add: + self.add_stage(stage) + return self + + def get_stage(self, stage_name: str) -> PipelineStage | None: + """Get a stage by name.""" + return self._stage_name_mapping.get(stage_name) + + def add_standard_text_encoding_stage( + self, + text_encoder_key: str = "text_encoder", + tokenizer_key: str = "tokenizer", + ) -> "ComposedPipelineBase": + return self.add_stage( + TextEncodingStage( + text_encoders=[self.get_module(text_encoder_key)], + tokenizers=[self.get_module(tokenizer_key)], + ), + ) + + def add_standard_timestep_preparation_stage( + self, + scheduler_key: str = "scheduler", + prepare_extra_kwargs: list[Callable] | None = [], + ) -> "ComposedPipelineBase": + return self.add_stage( + TimestepPreparationStage( + scheduler=self.get_module(scheduler_key), + prepare_extra_set_timesteps_kwargs=prepare_extra_kwargs, + ), + ) + + def add_standard_latent_preparation_stage( + self, + scheduler_key: str = "scheduler", + transformer_key: str = "transformer", + ) -> "ComposedPipelineBase": + return self.add_stage( + LatentPreparationStage( + scheduler=self.get_module(scheduler_key), + transformer=self.get_module(transformer_key), + ), + ) + + def 
add_standard_denoising_stage( + self, + transformer_key: str = "transformer", + transformer_2_key: str | None = "transformer_2", + scheduler_key: str = "scheduler", + vae_key: str | None = "vae", + ) -> "ComposedPipelineBase": + + kwargs = { + "transformer": self.get_module(transformer_key), + "scheduler": self.get_module(scheduler_key), + } + + if transformer_2_key: + transformer_2 = self.get_module(transformer_2_key, None) + if transformer_2 is not None: + kwargs["transformer_2"] = transformer_2 + + if vae_key: + vae = self.get_module(vae_key, None) + if vae is not None: + kwargs["vae"] = vae + kwargs["pipeline"] = self + + return self.add_stage(DenoisingStage(**kwargs)) + + def add_standard_decoding_stage( + self, + vae_key: str = "vae", + ) -> "ComposedPipelineBase": + + return self.add_stage( + DecodingStage(vae=self.get_module(vae_key), pipeline=self), + ) + + def add_standard_t2i_stages( + self, + include_input_validation: bool = True, + prepare_extra_timestep_kwargs: list[Callable] | None = [], + ) -> "ComposedPipelineBase": + + if include_input_validation: + self.add_stage(InputValidationStage()) + + self.add_standard_text_encoding_stage() + + self.add_standard_latent_preparation_stage() + self.add_standard_timestep_preparation_stage( + prepare_extra_kwargs=prepare_extra_timestep_kwargs + ) + self.add_standard_denoising_stage() + self.add_standard_decoding_stage() + + return self + + def add_standard_ti2i_stages( + self, + *, + include_input_validation: bool = True, + vae_image_processor: Any | None = None, + prompt_encoding: Literal["text", "image_encoding"] = "text", + text_encoder_key: str = "text_encoder", + tokenizer_key: str = "tokenizer", + image_processor_key: str = "processor", + prompt_text_encoder_key: str = "text_encoder", + image_vae_key: str = "vae", + image_vae_stage_kwargs: dict[str, Any] | None = None, + prepare_extra_timestep_kwargs: list[Callable] | None = [], + ) -> "ComposedPipelineBase": + if include_input_validation: + 
self.add_stage( + InputValidationStage(vae_image_processor=vae_image_processor) + ) + + if prompt_encoding == "text": + self.add_standard_text_encoding_stage( + text_encoder_key=text_encoder_key, + tokenizer_key=tokenizer_key, + ) + elif prompt_encoding == "image_encoding": + self.add_stage( + ImageEncodingStage( + image_processor=self.get_module(image_processor_key), + text_encoder=self.get_module(prompt_text_encoder_key), + ), + ) + else: + raise ValueError(f"Unknown prompt_encoding: {prompt_encoding}") + + self.add_stage( + ImageVAEEncodingStage( + vae=self.get_module(image_vae_key), + **(image_vae_stage_kwargs or {}), + ), + ) + + self.add_standard_latent_preparation_stage() + + self.add_standard_timestep_preparation_stage( + prepare_extra_kwargs=prepare_extra_timestep_kwargs + ) + self.add_standard_denoising_stage() + self.add_standard_decoding_stage() + return self + + def add_standard_ti2v_stages( + self, + *, + include_input_validation: bool = True, + vae_image_processor: Any | None = None, + text_encoder_key: str = "text_encoder", + tokenizer_key: str = "tokenizer", + image_encoder_key: str = "image_encoder", + image_processor_key: str = "image_processor", + image_vae_key: str = "vae", + image_vae_stage_kwargs: dict[str, Any] | None = None, + image_vae_encoding_position: Literal[ + "before_timestep", "after_latent" + ] = "before_timestep", + prepare_extra_timestep_kwargs: list[Callable] | None = [], + denoising_stage_factory: Callable[[], PipelineStage] | None = None, + ) -> "ComposedPipelineBase": + if include_input_validation: + self.add_stage( + InputValidationStage(vae_image_processor=vae_image_processor) + ) + + self.add_standard_text_encoding_stage( + text_encoder_key=text_encoder_key, + tokenizer_key=tokenizer_key, + ) + + image_encoder = self.get_module(image_encoder_key, None) + image_processor = self.get_module(image_processor_key, None) + self.add_stage_if( + image_encoder is not None and image_processor is not None, + ImageEncodingStage( + 
image_encoder=image_encoder, + image_processor=image_processor, + ), + ) + + if image_vae_encoding_position == "before_timestep": + self.add_stage( + ImageVAEEncodingStage( + vae=self.get_module(image_vae_key), + **(image_vae_stage_kwargs or {}), + ) + ) + + self.add_standard_latent_preparation_stage() + self.add_standard_timestep_preparation_stage( + prepare_extra_kwargs=prepare_extra_timestep_kwargs + ) + if image_vae_encoding_position == "after_latent": + self.add_stage( + ImageVAEEncodingStage( + vae=self.get_module(image_vae_key), + **(image_vae_stage_kwargs or {}), + ) + ) + elif image_vae_encoding_position != "before_timestep": + raise ValueError( + f"Unknown image_vae_encoding_position: {image_vae_encoding_position}" + ) + + if denoising_stage_factory is None: + self.add_standard_denoising_stage() + else: + self.add_stage(denoising_stage_factory()) + + self.add_standard_decoding_stage() + return self + + # TODO(will): don't hardcode no_grad + @torch.no_grad() + def forward( + self, + batch: Req, + server_args: ServerArgs, + ) -> OutputBatch: + """ + Generate a video or image using the pipeline. + + Args: + batch: The batch to generate from. + server_args: The inference arguments. + Returns: + Req: The batch with the generated video or image. + """ + + if self.is_lora_set() and not self.is_lora_effective(): + logger.warning( + "LoRA adapter is set, but not effective. 
class ParallelExecutor(PipelineExecutor):
    """
    Executor that runs each stage according to the parallelism type it declares.

    The correctness of the execution relies on the parallelism_type declared by stages
    """

    def collect_from_main(self, batches: list[Req]) -> list[Req]:
        """
        Broadcast ``batches`` from the first rank of the SP / CFG groups.

        Args:
            batches: The requests held by the calling rank.

        Returns:
            The broadcast result. Without returning it the broadcast has no
            observable effect for the caller, because ``batches`` is only
            rebound locally.
        """
        # TODO: fix this condition
        if self.server_args.sp_degree != 1:
            sp_group = get_sp_group()
            batches = broadcast_pyobj(
                batches,
                sp_group.rank,
                sp_group.cpu_group,
                src=sp_group.ranks[0],
            )

        if self.server_args.enable_cfg_parallel:
            # NOTE(review): `self.worker` is not assigned anywhere in this
            # class; presumably it is attached by whoever owns the executor.
            # Confirm before relying on this branch.
            batches = broadcast_pyobj(
                batches,
                self.worker.cfg_group.rank,
                self.worker.cfg_cpu_group,
                src=self.worker.cfg_group.ranks[0],
            )

        return batches

    def _execute(
        self,
        stages: List[PipelineStage],
        batch: Req,
        server_args: ServerArgs,
    ) -> OutputBatch:
        """
        Execute all pipeline stages respecting their declared parallelism type.

        Args:
            stages: Stages to run, in order.
            batch: The request to process.
            server_args: The server arguments.

        Returns:
            The batch as produced by the last stage executed on this rank.
        """
        if server_args.enable_cfg_parallel:
            rank = get_classifier_free_guidance_rank()
        else:
            rank = get_world_rank()
        cfg_group = get_cfg_group()

        # TODO: decide when to gather on main when CFG_PARALLEL -> MAIN_RANK_ONLY
        for stage in stages:
            paradigm = stage.parallelism_type

            if paradigm == StageParallelismType.MAIN_RANK_ONLY:
                if rank == 0:
                    # Only main rank executes, others just wait
                    batch = stage(batch, server_args)
                torch.distributed.barrier()

            elif paradigm == StageParallelismType.CFG_PARALLEL:
                # Rank 0 publishes its batch; every other rank adopts it before
                # running the stage itself.
                obj_list = [batch] if rank == 0 else []
                broadcasted_list = broadcast_pyobj(
                    obj_list, rank=rank, dist_group=cfg_group.cpu_group, src=0
                )
                if rank != 0:
                    batch = broadcasted_list[0]
                batch = stage(batch, server_args)

                torch.distributed.barrier()

            elif paradigm == StageParallelismType.REPLICATED:
                # Every rank runs the stage on its own copy; no synchronization.
                batch = stage(batch, server_args)
        return batch

    def execute(
        self,
        stages: List[PipelineStage],
        batch: Req,
        server_args: ServerArgs,
    ) -> OutputBatch:
        """Run ``stages`` on ``batch`` and return the resulting batch."""
        batch = self._execute(stages, batch, server_args)
        return batch
class PipelineExecutor(ABC):
    """
    Abstract base class for all pipeline executors.

    An executor drives the pipeline's stages and takes care of the parallelism
    and communication those stages require.
    """

    def __init__(self, server_args):
        self.server_args = server_args

    def execute_with_profiling(
        self,
        stages: List["PipelineStage"],
        batch: Req,
        server_args: ServerArgs,
    ) -> OutputBatch:
        """Run ``execute`` inside the profiling context, dumping on rank 0."""
        with self.profile_execution(batch, dump_rank=0):
            result = self.execute(stages, batch, server_args)
        return result

    @abstractmethod
    def execute(
        self,
        stages: List["PipelineStage"],
        batch: Req,
        server_args: ServerArgs,
    ) -> OutputBatch:
        """
        Execute the pipeline stages.

        Args:
            stages: A list of pipeline stages to execute.
            batch: The batch to process.
            server_args: The server arguments.

        Returns:
            The processed batch.
        """
        raise NotImplementedError

    @contextlib.contextmanager
    def profile_execution(self, batch: Req, dump_rank: int = 0):
        """
        Context manager that profiles the wrapped block when the batch asks for it.

        Profiling is skipped when ``batch.profile`` is falsy or the batch is a
        warmup request. Otherwise a profiler is created on entry and stopped on
        exit, even if the wrapped block raises.
        """
        if not batch.profile or batch.is_warmup:
            # Nothing to profile: run the wrapped block untouched.
            yield
            return

        profiler = SGLDiffusionProfiler(
            request_id=batch.request_id,
            rank=get_world_rank(),
            full_profile=batch.profile_all_stages,
            num_steps=batch.num_profiled_timesteps,
            num_inference_steps=batch.num_inference_steps,
        )
        try:
            yield
        finally:
            profiler.stop(dump_rank=dump_rank)
+ """ + + def run_profile_all_stages( + self, + stages: List[PipelineStage], + batch: Req, + server_args: ServerArgs, + ) -> OutputBatch: + """ + Execute all pipeline stages sequentially. + """ + for stage in stages: + batch = stage(batch, server_args) + profiler = SGLDiffusionProfiler.get_instance() + if profiler: + profiler.step_stage() + return batch + + def execute( + self, + stages: List[PipelineStage], + batch: Req, + server_args: ServerArgs, + ) -> OutputBatch: + """ + Execute the pipeline stages sequentially. + """ + + batch = self.run_profile_all_stages(stages, batch, server_args) + + return batch diff --git a/sglang/python/sglang/multimodal_gen/runtime/pipelines_core/lora_format_adapter.py b/sglang/python/sglang/multimodal_gen/runtime/pipelines_core/lora_format_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..889ceea7d2b26c02f6e59b727464f2c9d4f3c212 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/pipelines_core/lora_format_adapter.py @@ -0,0 +1,388 @@ +from __future__ import annotations + +import logging +from enum import Enum +from typing import Dict, Iterable, Mapping, Optional + +import torch +from diffusers.loaders import lora_conversion_utils as lcu + +logger = logging.getLogger("LoRAFormatAdapter") + + +class LoRAFormat(str, Enum): + """Supported external LoRA formats before normalization.""" + + STANDARD = "standard" + NON_DIFFUSERS_SD = "non-diffusers-sd" + QWEN_IMAGE_STANDARD = "qwen-image-standard" + XLABS_FLUX = "xlabs-ai" + KOHYA_FLUX = "kohya-flux" + WAN = "wan" + + +def _sample_keys(keys: Iterable[str], k: int = 20) -> list[str]: + out = [] + for i, key in enumerate(keys): + if i >= k: + break + out.append(key) + return out + + +def _has_substring_key(keys: Iterable[str], substr: str) -> bool: + return any(substr in k for k in keys) + + +def _has_prefix_key(keys: Iterable[str], prefix: str) -> bool: + return any(k.startswith(prefix) for k in keys) + + +def _looks_like_xlabs_flux_key(k: str) -> 
bool: + """XLabs FLUX-style keys under double_blocks/single_blocks with lora down/up.""" + if not (k.endswith(".down.weight") or k.endswith(".up.weight")): + return False + + if not k.startswith( + ( + "double_blocks.", + "single_blocks.", + "diffusion_model.double_blocks", + "diffusion_model.single_blocks", + ) + ): + return False + + return ".processor." in k or ".proj_lora" in k or ".qkv_lora" in k + + +def _looks_like_kohya_flux(state_dict: Mapping[str, torch.Tensor]) -> bool: + """Kohya FLUX LoRA (flux_lora.py) under lora_unet_double/single_blocks_ prefixes.""" + if not state_dict: + return False + keys = state_dict.keys() + return any( + k.startswith("lora_unet_double_blocks_") + or k.startswith("lora_unet_single_blocks_") + for k in keys + ) + + +def _looks_like_non_diffusers_sd(state_dict: Mapping[str, torch.Tensor]) -> bool: + """Classic non-diffusers SD LoRA (Kohya/A1111/sd-scripts).""" + if not state_dict: + return False + keys = state_dict.keys() + return all( + k.startswith(("lora_unet_", "lora_te_", "lora_te1_", "lora_te2_")) for k in keys + ) + + +def _looks_like_wan_lora(state_dict: Mapping[str, torch.Tensor]) -> bool: + """Wan2.2 distill LoRAs (Wan-AI / Wan2.2-Distill-Loras style).""" + if not state_dict: + return False + + for k in state_dict.keys(): + if not k.startswith("diffusion_model.blocks."): + continue + if ".lora_down" not in k and ".lora_up" not in k: + continue + if ".cross_attn." in k or ".self_attn." in k or ".ffn." in k or ".norm3." 
in k: + return True + + return False + + +def _looks_like_qwen_image(state_dict: Mapping[str, torch.Tensor]) -> bool: + keys = list(state_dict.keys()) + if not keys: + return False + return _has_prefix_key(keys, "transformer.transformer_blocks.") and ( + _has_substring_key(keys, ".lora.down.weight") + or _has_substring_key(keys, ".lora.up.weight") + ) + + +def detect_lora_format_from_state_dict( + state_dict: Mapping[str, torch.Tensor], +) -> LoRAFormat: + """Classify LoRA format by key patterns only.""" + keys = list(state_dict.keys()) + if not keys: + return LoRAFormat.STANDARD + + if _has_substring_key(keys, ".lora_A") or _has_substring_key(keys, ".lora_B"): + return LoRAFormat.STANDARD + + if any(_looks_like_xlabs_flux_key(k) for k in keys): + return LoRAFormat.XLABS_FLUX + if _looks_like_kohya_flux(state_dict): + return LoRAFormat.KOHYA_FLUX + + if _looks_like_wan_lora(state_dict): + return LoRAFormat.WAN + + if _looks_like_qwen_image(state_dict): + return LoRAFormat.STANDARD + + if _looks_like_non_diffusers_sd(state_dict): + return LoRAFormat.NON_DIFFUSERS_SD + + if _has_substring_key(keys, ".lora.down") or _has_substring_key(keys, ".lora_up"): + return LoRAFormat.NON_DIFFUSERS_SD + + return LoRAFormat.STANDARD + + +def _convert_qwen_image_standard( + state_dict: Mapping[str, torch.Tensor], + log: logging.Logger, +) -> Dict[str, torch.Tensor]: + """Qwen-Image: transformer.*.lora.down/up -> transformer_blocks.*.lora_A/B.""" + out: Dict[str, torch.Tensor] = {} + + for name, tensor in state_dict.items(): + new_name = name + + if new_name.startswith("transformer."): + new_name = new_name[len("transformer.") :] + + if new_name.endswith(".lora.down.weight"): + new_name = new_name.replace(".lora.down.weight", ".lora_A.weight") + elif new_name.endswith(".lora.up.weight"): + new_name = new_name.replace(".lora.up.weight", ".lora_B.weight") + + out[new_name] = tensor + + sample = _sample_keys(out.keys(), 20) + return out + + +def _convert_non_diffusers_sd_simple( + 
state_dict: Mapping[str, torch.Tensor], + log: logging.Logger, +) -> Dict[str, torch.Tensor]: + """Generic down/up -> A/B conversion for non-diffusers SD-like formats.""" + out: Dict[str, torch.Tensor] = {} + + for name, tensor in state_dict.items(): + new_name = name + + if "lora_down.weight" in new_name: + new_name = new_name.replace("lora_down.weight", "lora_A.weight") + elif "lora_up.weight" in new_name: + new_name = new_name.replace("lora_up.weight", "lora_B.weight") + elif new_name.endswith(".lora_down"): + new_name = new_name.replace(".lora_down", ".lora_A") + elif new_name.endswith(".lora_up"): + new_name = new_name.replace(".lora_up", ".lora_B") + + out[new_name] = tensor + + sample = _sample_keys(out.keys(), 20) + log.info( + "[LoRAFormatAdapter] after NON_DIFFUSERS_SD simple conversion, " + "sample keys (<=20): %s", + ", ".join(sample), + ) + return out + + +def _convert_with_diffusers_utils_if_available( + state_dict: Mapping[str, torch.Tensor], + log: logging.Logger, +) -> Optional[Dict[str, torch.Tensor]]: + """Use diffusers.lora_conversion_utils if available.""" + try: + if hasattr(lcu, "maybe_convert_state_dict"): + converted = lcu.maybe_convert_state_dict( # type: ignore[attr-defined] + state_dict + ) + else: + converted = dict(state_dict) + + if not isinstance(converted, dict): + converted = dict(converted) + + sample = _sample_keys(converted.keys(), 20) + log.info( + "[LoRAFormatAdapter] diffusers.lora_conversion_utils converted keys, " + "sample keys (<=20): %s", + ", ".join(sample), + ) + return converted + except Exception as exc: # pragma: no cover + log.warning( + "[LoRAFormatAdapter] diffusers lora_conversion_utils failed, " + "falling back to internal converters. 
def _convert_via_diffusers_candidates(
    state_dict: Mapping[str, torch.Tensor],
    candidate_names: tuple[str, ...],
    log: logging.Logger,
    unavailable_warning: str,
    no_converter_warning: str,
    success_info: str,
    all_failed_warning: str,
) -> Dict[str, torch.Tensor]:
    """
    Run the first working converter among ``candidate_names`` from lora_conversion_utils.

    Each candidate that exists as a callable attribute of ``lcu`` is tried in
    order on a fresh shallow copy of ``state_dict``. The first one returning a
    dict (or a tuple whose first item is a dict) wins. If no candidate exists,
    or all of them fail, a plain dict copy of the input is returned and a
    warning is logged.
    """
    available = []
    for candidate in candidate_names:
        converter = getattr(lcu, candidate, None)
        if callable(converter):
            available.append((candidate, converter))

    if len(available) == 0:
        log.warning(no_converter_warning)
        return dict(state_dict)

    most_recent_error: Optional[Exception] = None

    for candidate, converter in available:
        try:
            result = converter(dict(state_dict))
            # Some converters return (state_dict, extras); keep the dict part.
            if isinstance(result, tuple) and isinstance(result[0], dict):
                result = result[0]
            if not isinstance(result, dict):
                raise TypeError(f"Converter {candidate} returned {type(result)}")
            log.info(success_info.format(name=candidate))
            return result
        except Exception as exc:
            # Remember the failure and fall through to the next candidate.
            most_recent_error = exc

    log.warning(all_failed_warning.format(last_err=most_recent_error))
    return dict(state_dict)
def _convert_kohya_flux_via_diffusers(
    state_dict: Mapping[str, torch.Tensor],
    log: logging.Logger,
) -> Dict[str, torch.Tensor]:
    """Delegate Kohya FLUX LoRA conversion to the diffusers helper candidates."""
    # Converter names looked up on lora_conversion_utils, in priority order.
    kohya_converter_names = (
        "_convert_kohya_flux_lora_to_diffusers",
        "convert_kohya_flux_lora_to_diffusers",
    )
    unavailable_msg = (
        "[LoRAFormatAdapter] Kohya FLUX detected but diffusers is unavailable."
    )
    missing_msg = "[LoRAFormatAdapter] No Kohya FLUX converter found."
    success_msg = "[LoRAFormatAdapter] Converted Kohya FLUX LoRA using {name}"
    failure_msg = (
        "[LoRAFormatAdapter] Kohya FLUX conversion failed; "
        "last error: {last_err}"
    )

    return _convert_via_diffusers_candidates(
        state_dict,
        kohya_converter_names,
        log=log,
        unavailable_warning=unavailable_msg,
        no_converter_warning=missing_msg,
        success_info=success_msg,
        all_failed_warning=failure_msg,
    )
def normalize_lora_state_dict(
    state_dict: Mapping[str, torch.Tensor],
    logger: Optional[logging.Logger] = None,
) -> Dict[str, torch.Tensor]:
    """
    Detect the LoRA format of ``state_dict`` and convert it to the canonical layout.

    Args:
        state_dict: Raw LoRA weights keyed by parameter name.
        logger: Logger to report progress to; the module-level logger is used
            when none is given.

    Returns:
        The converted state dict produced by ``convert_lora_state_dict_by_format``.
    """
    # The parameter shadows the module-level `logger`, hence the globals() lookup.
    active_log = logger or globals()["logger"]

    input_names = list(state_dict.keys())
    active_log.info(
        "[LoRAFormatAdapter] normalize_lora_state_dict called, #keys=%d",
        len(input_names),
    )
    if input_names:
        preview_before = ", ".join(_sample_keys(input_names, 20))
        active_log.info(
            "[LoRAFormatAdapter] before convert, sample keys (<=20): %s",
            preview_before,
        )

    detected = detect_lora_format_from_state_dict(state_dict)
    active_log.info("[LoRAFormatAdapter] detected format: %s", detected)

    result = convert_lora_state_dict_by_format(state_dict, detected, active_log)

    output_names = list(result.keys())
    if output_names:
        preview_after = ", ".join(_sample_keys(output_names, 20))
        active_log.info(
            "[LoRAFormatAdapter] after convert, sample keys (<=20): %s",
            preview_after,
        )

    return result
class LoRAPipeline(ComposedPipelineBase):
    """
    Pipeline that supports injecting LoRA adapters into the diffusion transformer.

    LoRA state is tracked per logical target name: "transformer",
    "transformer_2" and "critic". The critic's layers are taken from the
    pipeline module registered as "fake_score_transformer".
    """

    # Type annotations for instance attributes (initialized in __init__)
    # [lora_nickname][target_LoRA_weight_name_in_SGLang_dit] = weight
    # e.g., [jinx][transformer_blocks.0.attn.to_v.lora_A]
    lora_adapters: dict[str, dict[str, torch.Tensor]]
    loaded_adapter_paths: dict[str, str]  # nickname -> lora_path
    # Track current adapter per module: {"transformer": "high_lora", "transformer_2": "low_lora"}
    cur_adapter_name: dict[str, str]
    cur_adapter_path: dict[str, str]
    cur_adapter_strength: dict[str, float]  # Track current strength per module
    cur_adapter_config: dict[str, tuple[list[str], list[float]]]
    # [dit_layer_name] = wrapped_lora_layer
    lora_layers: dict[str, BaseLayerWithLoRA]
    lora_layers_critic: dict[str, BaseLayerWithLoRA]
    lora_layers_transformer_2: dict[str, BaseLayerWithLoRA]
    server_args: ServerArgs
    exclude_lora_layers: list[str]
    device: torch.device
    lora_target_modules: list[str] | None
    lora_path: str | None
    lora_nickname: str
    lora_rank: int | None
    lora_alpha: int | None
    lora_initialized: bool
    # Track merge status per module: {"transformer": True, "transformer_2": False}
    is_lora_merged: dict[str, bool]
    # Valid target values for set_lora (class constant, must not be mutated)
    VALID_TARGETS: list[str] = ["all", "transformer", "transformer_2", "critic"]
    # Logical target name -> key in self.modules, for targets whose names differ.
    _TARGET_MODULE_KEYS: dict[str, str] = {"critic": "fake_score_transformer"}

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Initialize all mutable instance attributes to avoid sharing across instances
        self.lora_adapters = defaultdict(dict)
        self.loaded_adapter_paths = {}
        self.cur_adapter_name = {}
        self.cur_adapter_path = {}
        self.cur_adapter_strength = {}
        # Track full LoRA config: {module_name: (nickname_list, strength_list)}
        self.cur_adapter_config = {}
        self.lora_layers = {}
        self.lora_layers_critic = {}
        self.lora_layers_transformer_2 = {}
        self.is_lora_merged = {}
        self.lora_initialized = False
        self.lora_rank = None
        self.lora_alpha = None
        self.lora_path = None
        self.lora_nickname = "default"

        # Initialize from server_args
        self.device = get_local_torch_device()
        self.exclude_lora_layers = (
            self.server_args.pipeline_config.dit_config.arch_config.exclude_lora_layers
        )
        self.lora_target_modules = self.server_args.lora_target_modules
        self.lora_path = self.server_args.lora_path
        self.lora_nickname = self.server_args.lora_nickname
        if self.lora_path is not None:
            # set_lora() performs the one-time layer conversion itself, with
            # layerwise offload temporarily disabled so the wrappers see the
            # real weights. Calling convert_to_lora_layers() directly here
            # would bypass that protection.
            self.set_lora(
                self.lora_nickname, self.lora_path, strength=self.server_args.lora_scale  # type: ignore
            )  # type: ignore

    def is_target_layer(self, module_name: str) -> bool:
        """
        Return True if `module_name` should be wrapped with a LoRA layer.

        When `lora_target_modules` is None every layer is a target; otherwise a
        layer is a target if any configured name is a substring of `module_name`.
        """
        if self.lora_target_modules is None:
            return True
        return any(
            target_name in module_name for target_name in self.lora_target_modules
        )

    def _resolve_module(self, module_name: str) -> Any:
        """
        Look up the pipeline module for a logical target name.

        Tries `module_name` as a key of `self.modules` first, then falls back to
        the alias in `_TARGET_MODULE_KEYS` (e.g. "critic" ->
        "fake_score_transformer"). Returns None if neither key is present.
        """
        module = self.modules.get(module_name)
        if module is None:
            alias = self._TARGET_MODULE_KEYS.get(module_name)
            if alias is not None:
                module = self.modules.get(alias)
        return module

    def _get_target_lora_layers(
        self, target: str
    ) -> tuple[list[tuple[str, dict[str, BaseLayerWithLoRA]]], str | None]:
        """
        Return a list of (module_name, lora_layers_dict) based on the target.

        Args:
            target: One of "all", "transformer", "transformer_2", "critic".

        Returns:
            A tuple of (result, error_message):
            - result: List of tuples (module_name, lora_layers_dict) to operate on.
            - error_message: Error description if target is invalid or module doesn't exist, None otherwise.
        """
        if target == "all":
            # "transformer" is always included; the optional modules only when
            # they actually have converted LoRA layers.
            result: list[tuple[str, dict[str, BaseLayerWithLoRA]]] = [
                ("transformer", self.lora_layers)
            ]
            if self.lora_layers_transformer_2:
                result.append(("transformer_2", self.lora_layers_transformer_2))
            if self.lora_layers_critic:
                result.append(("critic", self.lora_layers_critic))
            return result, None
        elif target == "transformer":
            return [("transformer", self.lora_layers)], None
        elif target == "transformer_2":
            if not self.lora_layers_transformer_2:
                return [], "transformer_2 does not exist in this pipeline"
            return [("transformer_2", self.lora_layers_transformer_2)], None
        elif target == "critic":
            if not self.lora_layers_critic:
                return (
                    [],
                    "critic (fake_score_transformer) does not exist in this pipeline",
                )
            return [("critic", self.lora_layers_critic)], None
        else:
            return [], f"Invalid target: {target}. Valid targets: {self.VALID_TARGETS}"

    @contextmanager
    def _temporarily_disable_offload(
        self,
        target_modules: list[tuple[str, dict[str, BaseLayerWithLoRA]]] | None = None,
        target: str | None = None,
        use_module_names_only: bool = False,
    ):
        """
        Context manager to temporarily disable layerwise offload for the given modules.

        Args:
            target_modules: List of (module_name, lora_layers_dict) tuples. If None, will be determined from target.
            target: Target string ("all", "transformer", etc.). Used if target_modules is None.
            use_module_names_only: If True, determine module names directly from target without requiring
                LoRA initialization. Used for convert_to_lora_layers scenario.

        Yields:
            List of modules that had offload disabled (empty if nothing applied).
            Offload is re-enabled for exactly those modules on exit, including
            when the body raises.
        """
        from sglang.multimodal_gen.runtime.utils.layerwise_offload import (
            OffloadableDiTMixin,
        )

        module_names: list[str] = []
        if target_modules is not None:
            # Extract module names from target_modules
            module_names = [module_name for module_name, _ in target_modules]
        elif target is not None:
            if use_module_names_only:
                if target == "all":
                    # Include the critic: convert_to_lora_layers converts it too.
                    module_names = ["transformer", "transformer_2", "critic"]
                elif target in ("transformer", "transformer_2", "critic"):
                    module_names = [target]
            else:
                resolved_targets, _ = self._get_target_lora_layers(target)
                module_names = [module_name for module_name, _ in resolved_targets]

        if not module_names:
            yield []
            return

        # clear device cache to free up unused memory
        if torch.get_device_module().is_available():
            torch.get_device_module().synchronize()
            torch.get_device_module().empty_cache()

        offload_disabled_modules = []
        try:
            # Disabling happens inside the try so that a failure part-way
            # through still re-enables the modules already switched off.
            for module_name in module_names:
                module = self._resolve_module(module_name)
                if module is not None and isinstance(module, OffloadableDiTMixin):
                    if module.layerwise_offload_managers is not None:
                        module.disable_offload()
                        offload_disabled_modules.append(module)
            yield offload_disabled_modules
        finally:
            # Re-enable layerwise offload: sync weights to CPU and restore hooks
            for module in offload_disabled_modules:
                module.enable_offload()

    def convert_module_lora_layers(
        self,
        module: torch.nn.Module,
        module_name: str,
        target_lora_layers: dict[str, BaseLayerWithLoRA],
        check_exclude: bool = True,
    ) -> int:
        """
        Convert layers in a module to LoRA layers.

        Args:
            module: The module to convert.
            module_name: The name of the module (for replace_submodule).
            target_lora_layers: The dictionary to store the converted LoRA layers.
            check_exclude: Whether to check the exclude_lora_layers list.

        Returns:
            The number of layers converted.
        """
        converted_count = 0
        for name, layer in module.named_modules():
            if not self.is_target_layer(name):
                continue

            if check_exclude:
                excluded = any(
                    exclude_layer in name for exclude_layer in self.exclude_lora_layers
                )
                if excluded:
                    continue

            # wrap_with_lora_layer returns None for layers it does not wrap.
            lora_layer = wrap_with_lora_layer(
                layer,
                lora_rank=self.lora_rank,
                lora_alpha=self.lora_alpha,
            )
            if lora_layer is not None:
                target_lora_layers[name] = lora_layer
                replace_submodule(self.modules[module_name], name, lora_layer)
                converted_count += 1

        return converted_count

    def convert_to_lora_layers(self) -> None:
        """
        Unified method to convert the transformer to a LoRA transformer.

        Runs at most once per pipeline (guarded by `lora_initialized`). Converts
        "transformer", and "transformer_2" / "fake_score_transformer" when they
        are present and not None.
        """
        if self.lora_initialized:
            return
        self.lora_initialized = True

        # Convert transformer
        converted_count = self.convert_module_lora_layers(
            self.modules["transformer"],
            "transformer",
            self.lora_layers,
            check_exclude=True,
        )
        logger.info("Converted %d layers to LoRA layers", converted_count)

        # Convert transformer_2 if exists (e.g., Wan2.2 A14B dual-transformer)
        if (
            "transformer_2" in self.modules
            and self.modules["transformer_2"] is not None
        ):
            converted_count_2 = self.convert_module_lora_layers(
                self.modules["transformer_2"],
                "transformer_2",
                self.lora_layers_transformer_2,
                check_exclude=True,
            )
            logger.info(
                "Converted %d layers to LoRA layers in transformer_2", converted_count_2
            )

        # Convert fake_score_transformer if exists. The None check mirrors the
        # transformer_2 branch and get_lora_status().
        if (
            "fake_score_transformer" in self.modules
            and self.modules["fake_score_transformer"] is not None
        ):
            converted_count_critic = self.convert_module_lora_layers(
                self.modules["fake_score_transformer"],
                "fake_score_transformer",
                self.lora_layers_critic,
                check_exclude=False,
            )
            logger.info(
                "Converted %d layers to LoRA layers in the critic model",
                converted_count_critic,
            )

    def _normalize_lora_params(
        self,
        lora_nickname: str | list[str],
        lora_path: str | None | list[str | None],
        strength: float | list[float],
        target: str | list[str],
    ) -> tuple[list[str], list[str | None], list[float], list[str]]:
        """
        Normalize LoRA parameters to lists for multi-LoRA support.

        Requirements:
            - at least one nickname must be given
            - each nickname must have a corresponding lora_path (no implicit repeat)
            - strength / target if scalar broadcast, else length must match nickname

        Raises:
            ValueError: If no nickname is given or any list length does not match.
        """
        # nickname
        if isinstance(lora_nickname, str):
            lora_nicknames = [lora_nickname]
        else:
            lora_nicknames = lora_nickname
        if not lora_nicknames:
            # An empty list would otherwise fail later with an opaque IndexError.
            raise ValueError("At least one lora_nickname must be provided.")

        # lora_path: require 1:1 mapping with nickname (no implicit repeat)
        if isinstance(lora_path, list):
            lora_paths = lora_path
        else:
            lora_paths = [lora_path]
        if len(lora_paths) != len(lora_nicknames):
            raise ValueError(
                f"Length mismatch: lora_nickname has {len(lora_nicknames)} items, "
                f"but lora_path has {len(lora_paths)} items. "
                "Provide one path per nickname."
            )

        # strength and target: allow scalar broadcast, else length must match
        if isinstance(strength, (int, float)):
            strengths = [float(strength)] * len(lora_nicknames)
        else:
            strengths = [float(s) for s in strength]
        if len(strengths) != len(lora_nicknames):
            raise ValueError(
                f"Length mismatch: lora_nickname has {len(lora_nicknames)} items, "
                f"but strength has {len(strengths)} items"
            )

        if isinstance(target, str):
            targets = [target] * len(lora_nicknames)
        else:
            targets = target
        if len(targets) != len(lora_nicknames):
            raise ValueError(
                f"Length mismatch: lora_nickname has {len(lora_nicknames)} items, "
                f"but target has {len(targets)} items"
            )
        return lora_nicknames, lora_paths, strengths, targets

    def _check_lora_config_matches(
        self,
        module_name: str,
        target_nicknames: list[str],
        target_strengths: list[float],
        adapter_updated: bool,
    ) -> bool:
        """
        Check if current LoRA configuration matches the target configuration.

        Args:
            module_name: The name of the module to check.
            target_nicknames: List of LoRA nicknames to apply.
            target_strengths: List of LoRA strengths to apply.
            adapter_updated: Whether any adapter was updated/loaded.

        Returns:
            True if the configuration matches exactly (including order and strength), False otherwise.
        """
        if not self.is_lora_merged.get(module_name, False):
            return False
        if adapter_updated:
            return False  # Adapter was updated, need to reapply

        stored_config = self.cur_adapter_config.get(module_name)
        if stored_config is None:
            return False

        stored_nicknames, stored_strengths = stored_config
        # Compare: nickname list and strength list must match exactly (including order)
        return (
            stored_nicknames == target_nicknames
            and stored_strengths == target_strengths
        )

    def _apply_lora_to_layers(
        self,
        lora_layers: dict[str, BaseLayerWithLoRA],
        lora_nicknames: list[str],
        lora_paths: list[str | None],
        rank: int,
        strengths: list[float],
        clear_existing: bool = False,
    ) -> int:
        """
        Apply LoRA weights to the given lora_layers. Supports multiple LoRA adapters.

        Args:
            lora_layers: The dictionary of LoRA layers to apply weights to.
            lora_nicknames: The list of nicknames of the LoRA adapters.
            lora_paths: The list of paths to the LoRA adapters. Must match length of lora_nicknames.
            rank: The distributed rank (for logging).
            strengths: The list of LoRA strengths for merge. Must match length of lora_nicknames.
            clear_existing: If True, clear existing LoRA weights before adding new ones.

        Returns:
            The number of (layer, adapter) applications performed.

        Raises:
            ValueError: If lora_paths or strengths differ in length from lora_nicknames.
        """
        if len(lora_paths) != len(lora_nicknames):
            raise ValueError(
                f"Length mismatch: lora_nicknames has {len(lora_nicknames)} items, "
                f"but lora_paths has {len(lora_paths)} items"
            )
        if len(strengths) != len(lora_nicknames):
            raise ValueError(
                f"Length mismatch: lora_nicknames has {len(lora_nicknames)} items, "
                f"but strengths has {len(strengths)} items"
            )

        adapted_count = 0
        for name, layer in lora_layers.items():
            # Key names depend only on the layer, so build them once per layer.
            lora_A_name = name + ".lora_A"
            lora_B_name = name + ".lora_B"
            alpha_key = name + ".alpha"
            layer_has_lora = False

            # Apply all LoRA adapters in order
            for idx, (nickname, path, lora_strength) in enumerate(
                zip(lora_nicknames, lora_paths, strengths)
            ):
                adapter = self.lora_adapters[nickname]
                if lora_A_name not in adapter or lora_B_name not in adapter:
                    if rank == 0 and idx == 0:  # Only warn for first missing LoRA
                        logger.warning(
                            "LoRA adapter %s does not contain the weights for layer '%s'. LoRA will not be applied to it.",
                            path,
                            name,
                        )
                    continue

                lora_A = adapter[lora_A_name]
                # Some LoRA checkpoints (e.g. Lightning distill) store per-layer alpha as ".alpha".
                # If present, we must apply the standard LoRA scaling: scale = alpha / rank.
                # NOTE(review): for fused weights stacked by load_lora_adapter, dim 0 is the
                # stack index rather than the LoRA rank - confirm the layer handles this.
                try:
                    inferred_rank = int(lora_A.shape[0])
                except Exception:
                    inferred_rank = None
                # Default to None for some checkpoints without ".alpha"
                inferred_alpha: int | None = None
                if alpha_key in adapter:
                    try:
                        inferred_alpha = int(adapter[alpha_key].item())
                    except Exception:
                        inferred_alpha = None

                if inferred_rank is not None:
                    layer.lora_rank = inferred_rank
                    layer.lora_alpha = (
                        inferred_alpha if inferred_alpha is not None else inferred_rank
                    )

                layer.set_lora_weights(
                    lora_A,
                    adapter[lora_B_name],
                    lora_path=path,
                    strength=lora_strength,
                    # Clear on the first adapter actually applied to this layer.
                    # Keying on idx == 0 would keep stale weights whenever the
                    # first adapter lacks this layer but a later one has it.
                    clear_existing=(clear_existing and not layer_has_lora),
                )
                layer_has_lora = True
                adapted_count += 1

            # Only disable if no LoRA was applied at all
            if lora_nicknames and not layer_has_lora:
                layer.disable_lora = True
        return adapted_count

    def is_lora_effective(self, target: str = "all") -> bool:
        """
        Check if LoRA is currently effective (merged) for the specified target.

        Args:
            target: Which transformer to check. "all" returns True if any is merged.
        """
        if target == "all":
            return any(self.is_lora_merged.values())
        return self.is_lora_merged.get(target, False)

    def is_lora_set(self, target: str = "all") -> bool:
        """
        Check if LoRA has been set for the specified target.

        Args:
            target: Which transformer to check. "all" returns True if any is set.
        """
        if not self.lora_initialized:
            return False
        if target == "all":
            return bool(self.cur_adapter_name)
        return target in self.cur_adapter_name

    def load_lora_adapter(self, lora_path: str, lora_nickname: str, rank: int):
        """
        Load the LoRA, and setup the lora_adapters for later weight replacement

        Args:
            lora_path: Path or identifier passed to maybe_download_lora.
            lora_nickname: Key under which the weights are stored in lora_adapters.
                Any weights previously stored under this nickname are discarded.
            rank: The distributed rank of this process.

        Raises:
            ValueError: If two checkpoint entries map to the same target weight name.
        """
        assert lora_path is not None

        # Only rank 0 downloads to avoid race conditions where other ranks
        # try to load incomplete downloads
        if rank == 0:
            lora_local_path = maybe_download_lora(lora_path)
        else:
            lora_local_path = None

        # Synchronize all ranks after download completes
        if dist.is_initialized():
            dist.barrier()

        # Non-rank-0 workers now download (will hit cache since rank 0 completed)
        if rank != 0:
            lora_local_path = maybe_download_lora(lora_path)

        raw_state_dict = load_file(lora_local_path)
        lora_state_dict = normalize_lora_state_dict(raw_state_dict, logger=logger)

        if lora_nickname in self.lora_adapters:
            self.lora_adapters[lora_nickname].clear()

        config = self.server_args.pipeline_config.dit_config.arch_config

        param_names_mapping_fn = get_param_names_mapping(
            config.param_names_mapping
            or self.modules["transformer"].param_names_mapping
        )
        lora_param_names_mapping_fn = get_param_names_mapping(
            config.lora_param_names_mapping
            or self.modules["transformer"].lora_param_names_mapping
        )

        to_merge_params: defaultdict[Hashable, dict[Any, Any]] = defaultdict(dict)
        for name, weight in lora_state_dict.items():
            name = name.replace("diffusion_model.", "")
            name = name.replace(".weight", "")
            # misc-format -> HF-format
            name, _, _ = lora_param_names_mapping_fn(name)
            # HF-format (LoRA) -> SGLang-dit-format
            target_name, merge_index, num_params_to_merge = param_names_mapping_fn(name)
            # for fuse B(out_dim, r) @ A(r, in_dim) -> (N, out_dim, r) @ (N, r, in_dim)
            # see param mapping in HunyuanVideoArchConfig
            if merge_index is not None:
                to_merge_params[target_name][merge_index] = weight
                if len(to_merge_params[target_name]) == num_params_to_merge:
                    sorted_tensors = [
                        to_merge_params[target_name][i]
                        for i in range(num_params_to_merge)
                    ]
                    # Use stack instead of cat because it needs to be compatible with TP.
                    weight = torch.stack(sorted_tensors, dim=0)
                    del to_merge_params[target_name]
                else:
                    # Wait until every part of the fused parameter has been seen.
                    continue

            if target_name in self.lora_adapters[lora_nickname]:
                raise ValueError(
                    f"Dit target weight name {target_name} already exists in lora_adapters[{lora_nickname}]"
                )
            self.lora_adapters[lora_nickname][target_name] = weight.to(self.device)

        if to_merge_params and rank == 0:
            # Fused groups that never received all their parts are not stored.
            logger.warning(
                "LoRA adapter %s: %d fused parameter group(s) were incomplete and skipped: %s",
                lora_path,
                len(to_merge_params),
                sorted(str(key) for key in to_merge_params),
            )

        self.loaded_adapter_paths[lora_nickname] = lora_path
        logger.info("Rank %d: loaded LoRA adapter %s", rank, lora_path)

    def set_lora(
        self,
        lora_nickname: str | list[str],
        lora_path: str | None | list[str | None] = None,
        target: str | list[str] = "all",
        strength: float | list[float] = 1.0,
    ):  # type: ignore
        """
        Load LoRA adapter(s) into the pipeline and apply them to the specified transformer(s).
        Supports both single LoRA (backward compatible) and multiple LoRA adapters.

        Args:
            lora_nickname: Nickname or list of nicknames of the adapters to apply.
            lora_path: Path(s) to load from; one per nickname. May be None for a
                nickname that is already loaded.
            target: Target(s) from VALID_TARGETS; scalar is broadcast to all nicknames.
            strength: Strength(s); scalar is broadcast to all nicknames.

        Raises:
            ValueError: On length mismatches, invalid targets, or an unknown
                nickname given without a path.
        """
        # Normalize inputs to lists for multi-LoRA support
        lora_nicknames, lora_paths, strengths, targets = self._normalize_lora_params(
            lora_nickname, lora_path, strength, target
        )

        # Validate targets
        invalid_targets = [t for t in targets if t not in self.VALID_TARGETS]
        if invalid_targets:
            raise ValueError(
                f"Invalid target(s): {invalid_targets}. Valid targets: {self.VALID_TARGETS}"
            )

        # Disable layerwise offload before convert_to_lora_layers to ensure weights are accessible
        # This is critical because convert_to_lora_layers needs to save cpu_weight from actual weights,
        # not from offloaded placeholder tensors
        if not self.lora_initialized:
            with self._temporarily_disable_offload(
                target="all", use_module_names_only=True
            ):
                self.convert_to_lora_layers()

        # Check adapter presence and load missing adapters
        adapter_updated = False
        # Guarded like the barrier in load_lora_adapter, so single-process use
        # without an initialized process group works.
        rank = dist.get_rank() if dist.is_initialized() else 0

        # load required adapters
        for nickname, path in zip(lora_nicknames, lora_paths):
            if nickname not in self.lora_adapters and path is None:
                raise ValueError(
                    f"Adapter {nickname} not found in the pipeline. Please provide lora_path to load it."
                )
            # (Re)load when a path is given and it is new for this nickname.
            if path is not None and self.loaded_adapter_paths.get(nickname) != path:
                adapter_updated = True
                self.load_lora_adapter(path, nickname, rank)

        # Group by target to apply separately
        target_to_indices: dict[str, list[int]] = {}
        for idx, tgt in enumerate(targets):
            target_to_indices.setdefault(tgt, []).append(idx)

        adapted_count = 0
        for tgt, idx_list in target_to_indices.items():
            target_modules, error = self._get_target_lora_layers(tgt)
            if error:
                logger.warning("set_lora: %s", error)
            if not target_modules:
                continue

            tgt_nicknames = [lora_nicknames[i] for i in idx_list]
            tgt_paths = [lora_paths[i] for i in idx_list]
            tgt_strengths = [strengths[i] for i in idx_list]
            merged_name = ",".join(tgt_nicknames)

            # Skip if LoRA configuration matches exactly (including order and strength).
            # Every module of the target is checked, because modules can diverge
            # when an earlier call addressed only one of them. The check runs
            # before offload is disabled so a no-op call stays cheap.
            if all(
                self._check_lora_config_matches(
                    module_name, tgt_nicknames, tgt_strengths, adapter_updated
                )
                for module_name, _ in target_modules
            ):
                logger.info("LoRA configuration matches exactly, skipping")
                continue

            # Disable layerwise offload if enabled: load all layers to GPU
            # the LoRA weights merging process requires weights being on device
            with self._temporarily_disable_offload(target_modules=target_modules):
                # Apply LoRA to modules for this target
                for module_name, lora_layers_dict in target_modules:
                    count = self._apply_lora_to_layers(
                        lora_layers_dict,
                        tgt_nicknames,
                        tgt_paths,
                        rank,
                        tgt_strengths,
                        clear_existing=True,
                    )
                    adapted_count += count
                    self.cur_adapter_name[module_name] = merged_name
                    self.cur_adapter_path[module_name] = ",".join(
                        str(p or self.loaded_adapter_paths.get(n, ""))
                        for n, p in zip(tgt_nicknames, tgt_paths)
                    )
                    self.is_lora_merged[module_name] = True
                    self.cur_adapter_strength[module_name] = tgt_strengths[0]
                    # Store full configuration for multi-LoRA support (preserves order and all strengths)
                    self.cur_adapter_config[module_name] = (
                        tgt_nicknames.copy(),
                        tgt_strengths.copy(),
                    )

        logger.info(
            "Rank %d: LoRA adapter(s) %s applied to %d layers (targets: %s, strengths: %s)",
            rank,
            ", ".join(map(str, lora_paths)) if lora_paths else None,
            adapted_count,
            ", ".join(targets) if len(set(targets)) > 1 else targets[0],
            (
                ", ".join(f"{s:.2f}" for s in strengths)
                if len(strengths) > 1
                else f"{strengths[0]:.2f}"
            ),
        )

    def merge_lora_weights(self, target: str = "all", strength: float = 1.0) -> None:
        """
        Merge LoRA weights into the base model for the specified target.

        This operation is idempotent - calling it when LoRA is already merged is safe.

        Args:
            target: Which transformer(s) to merge. One of "all", "transformer",
                "transformer_2", "critic".
            strength: LoRA strength for merge, default 1.0.
        """
        target_modules, error = self._get_target_lora_layers(target)
        if error:
            logger.warning("merge_lora_weights: %s", error)
        if not target_modules:
            return

        # Disable layerwise offload if enabled: load all layers to GPU
        with self._temporarily_disable_offload(target_modules=target_modules):
            for module_name, lora_layers_dict in target_modules:
                if self.is_lora_merged.get(module_name, False):
                    # Check if strength is the same - if so, skip (idempotent)
                    if self.cur_adapter_strength.get(module_name) == strength:
                        logger.warning(
                            "LoRA weights are already merged for %s with same strength",
                            module_name,
                        )
                        continue
                    # Different strength requested - allow re-merge (layer handles unmerge internally)
                    logger.info(
                        "Re-merging LoRA weights for %s with new strength %s",
                        module_name,
                        strength,
                    )
                for name, layer in lora_layers_dict.items():
                    # Only re-enable LoRA for layers that actually have LoRA weights
                    has_lora_weights = (
                        hasattr(layer, "lora_A") and layer.lora_A is not None
                    )
                    if not has_lora_weights:
                        continue
                    if hasattr(layer, "disable_lora"):
                        layer.disable_lora = False
                    try:
                        layer.merge_lora_weights(strength=strength)
                    except Exception as e:
                        # Best effort: keep merging the remaining layers.
                        logger.warning("Could not merge layer %s: %s", name, e)
                        continue
                self.is_lora_merged[module_name] = True
                self.cur_adapter_strength[module_name] = strength
                # If the recorded per-adapter strengths no longer describe what
                # is merged, forget the config so the next set_lora() re-applies
                # instead of skipping on a stale "exact match".
                stored_config = self.cur_adapter_config.get(module_name)
                if stored_config is not None and any(
                    s != strength for s in stored_config[1]
                ):
                    self.cur_adapter_config.pop(module_name, None)
                logger.info(
                    "LoRA weights merged for %s (strength: %s)", module_name, strength
                )

    def unmerge_lora_weights(self, target: str = "all") -> None:
        """
        Unmerge LoRA weights from the base model for the specified target.
        This also disables LoRA so it won't be computed on-the-fly.

        This operation is idempotent - calling it when LoRA is not merged is safe.

        Args:
            target: Which transformer(s) to unmerge. One of "all", "transformer",
                "transformer_2", "critic".
        """
        target_modules, error = self._get_target_lora_layers(target)
        if error:
            logger.warning("unmerge_lora_weights: %s", error)
        if not target_modules:
            return

        # Only modules that are currently merged need any work (or offload changes).
        merged_modules = []
        for module_name, lora_layers_dict in target_modules:
            if not self.is_lora_merged.get(module_name, False):
                logger.warning(
                    "LoRA weights are not merged for %s, skipping", module_name
                )
                continue
            merged_modules.append((module_name, lora_layers_dict))
        if not merged_modules:
            return

        # Disable layerwise offload if enabled: load all layers to GPU.
        # Entered once for all affected modules rather than once per module.
        with self._temporarily_disable_offload(target_modules=merged_modules):
            for module_name, lora_layers_dict in merged_modules:
                for name, layer in lora_layers_dict.items():
                    # Check layer-level state to avoid raising exception
                    if hasattr(layer, "merged") and not layer.merged:
                        logger.warning("Layer %s is not merged, skipping", name)
                    else:
                        try:
                            layer.unmerge_lora_weights()
                        except ValueError as e:
                            logger.warning("Could not unmerge layer %s: %s", name, e)
                    # Disable LoRA in every case (skipped, unmerged or failed)
                    # to prevent on-the-fly computation.
                    if hasattr(layer, "disable_lora"):
                        layer.disable_lora = True
                self.is_lora_merged[module_name] = False
                self.cur_adapter_strength.pop(module_name, None)
                self.cur_adapter_config.pop(module_name, None)
                logger.info("LoRA weights unmerged for %s", module_name)

    def get_lora_status(self) -> dict[str, Any]:
        """
        Summarize loaded LoRA adapters and current application status per module.

        Returns a plain Python dict with no tensor values to allow safe JSON serialization.
        The dict has two keys: "loaded_adapters" (list of {nickname, path}) and
        "active" (status per existing module that currently has LoRA merged).
        """
        # Loaded adapters: list of {nickname, path}
        loaded_adapters = [
            {"nickname": nickname, "path": path}
            for nickname, path in self.loaded_adapter_paths.items()
        ]

        def _module_status(module_name: str) -> list[dict] | None:
            # return list of dict to support multi-lora in the future
            if not self.is_lora_merged.get(module_name, False):
                return None
            return [
                {
                    "nickname": self.cur_adapter_name.get(module_name, None),
                    "path": self.cur_adapter_path.get(module_name, None),
                    "merged": self.is_lora_merged.get(module_name, False),
                    "strength": self.cur_adapter_strength.get(module_name, None),
                }
            ]

        # Build active usage per module only for modules that exist in this pipeline.
        # Each pair is (key in self.modules, key used for LoRA status tracking).
        active: dict[str, Any] = {}
        for modules_key, status_key in (
            ("transformer", "transformer"),
            ("transformer_2", "transformer_2"),
            ("fake_score_transformer", "critic"),
        ):
            if modules_key not in self.modules or self.modules[modules_key] is None:
                continue
            status = _module_status(status_key)
            if status is not None:
                active[status_key] = status

        return {
            "loaded_adapters": loaded_adapters,
            "active": active,
        }
https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/model_executor/forward_batch_info.py +""" +Data structures for functional pipeline processing. + +This module defines the dataclasses used to pass state between pipeline components +in a functional manner, reducing the need for explicit parameter passing. +""" + +from __future__ import annotations + +import os +import pprint +from copy import deepcopy +from dataclasses import MISSING, asdict, dataclass, field, fields +from typing import Any, Optional + +import PIL.Image +import torch + +from sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams +from sglang.multimodal_gen.runtime.server_args import ( + ServerArgs, + _sanitize_for_logging, +) +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.runtime.utils.perf_logger import RequestMetrics +from sglang.multimodal_gen.utils import align_to + +logger = init_logger(__name__) + +SAMPLING_PARAMS_FIELDS = {f.name for f in fields(SamplingParams)} + + +@dataclass(init=False) +class Req: + """ + Complete state passed through the pipeline execution. + + This dataclass contains all information needed during the diffusion pipeline + execution, allowing methods to update specific components without needing + to manage numerous individual parameters. + + [IMPORTANT] Fields that overlap with SamplingParams are automatically delegated to the + sampling_params member via __getattr__ and __setattr__. 
+ """ + + sampling_params: SamplingParams | None = None + + generator: torch.Generator | list[torch.Generator] | None = None + + # Image encoder hidden states + image_embeds: list[torch.Tensor] = field(default_factory=list) + + original_condition_image_size: tuple[int, int] = None + condition_image: torch.Tensor | PIL.Image.Image | None = None + vae_image: torch.Tensor | PIL.Image.Image | None = None + pixel_values: torch.Tensor | PIL.Image.Image | None = None + preprocessed_image: torch.Tensor | None = None + + output_file_ext: str | None = None + # Primary encoder embeddings + prompt_embeds: list[torch.Tensor] | torch.Tensor = field(default_factory=list) + negative_prompt_embeds: list[torch.Tensor] | None = None + prompt_attention_mask: list[torch.Tensor] | None = None + negative_attention_mask: list[torch.Tensor] | None = None + clip_embedding_pos: list[torch.Tensor] | None = None + clip_embedding_neg: list[torch.Tensor] | None = None + + pooled_embeds: list[torch.Tensor] = field(default_factory=list) + neg_pooled_embeds: list[torch.Tensor] = field(default_factory=list) + + # Additional text-related parameters + max_sequence_length: int | None = None + prompt_template: dict[str, Any] | None = None + do_classifier_free_guidance: bool = False + + seeds: list[int] | None = None + + # Tracking if embeddings are already processed + is_prompt_processed: bool = False + + # Audio Embeddings (LTX-2) + audio_prompt_embeds: list[torch.Tensor] | torch.Tensor = field(default_factory=list) + negative_audio_prompt_embeds: list[torch.Tensor] | torch.Tensor = field( + default_factory=list + ) + + # Latent tensors + latents: torch.Tensor | None = None + y: torch.Tensor | None = None + # Flux-2 + latent_ids: torch.Tensor | None = None + + # Audio Latents + audio_latents: torch.Tensor | None = None + audio_noise: torch.Tensor | None = None + raw_audio_latent_shape: tuple[int, ...] 
| None = None + + # Audio Parameters + generate_audio: bool = True + + raw_latent_shape: torch.Tensor | None = None + noise_pred: torch.Tensor | None = None + # vae-encoded condition image + image_latent: torch.Tensor | list[torch.Tensor] | None = None + condition_image_latent_ids: torch.Tensor | list[torch.Tensor] | None = None + vae_image_sizes: list[tuple[int, int]] | None = None + + # Latent dimensions + height_latents: list[int] | int | None = None + width_latents: list[int] | int | None = None + + # Timesteps + timesteps: torch.Tensor | None = None + paired_timesteps: torch.Tensor | None = None + timestep: torch.Tensor | float | int | None = None + step_index: int | None = None + + eta: float = 0.0 + sigmas: list[float] | None = None + + n_tokens: int | None = None + + # Other parameters that may be needed by specific schedulers + extra_step_kwargs: dict[str, Any] = field(default_factory=dict) + + # Component modules (populated by the pipeline) + modules: dict[str, Any] = field(default_factory=dict) + + trajectory_timesteps: list[torch.Tensor] | None = None + trajectory_latents: torch.Tensor | None = None + trajectory_audio_latents: torch.Tensor | None = None + + # Extra parameters that might be needed by specific pipeline implementations + extra: dict[str, Any] = field(default_factory=dict) + + is_warmup: bool = False + + # STA parameters + STA_param: list | None = None + is_cfg_negative: bool = False + mask_search_final_result_pos: list[list] | None = None + mask_search_final_result_neg: list[list] | None = None + + # VSA parameters + VSA_sparsity: float = 0.0 + + # stage logging + metrics: Optional["RequestMetrics"] = None + + # results + output: torch.Tensor | None = None + audio: torch.Tensor | None = None + audio_sample_rate: int | None = None + + def __init__(self, **kwargs): + # Initialize dataclass fields + for name, field in self.__class__.__dataclass_fields__.items(): + if name in kwargs: + object.__setattr__(self, name, kwargs.pop(name)) + elif 
field.default is not MISSING: + object.__setattr__(self, name, field.default) + elif field.default_factory is not MISSING: + object.__setattr__(self, name, field.default_factory()) + + for name, value in kwargs.items(): + setattr(self, name, value) + + self.validate() + + def __getattr__(self, name: str) -> Any: + """ + Delegate attribute access to sampling_params if not found in Req. + This is only called when the attribute is not found in the instance. + """ + if name == "sampling_params": + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'" + ) + + sampling_params = object.__getattribute__(self, "sampling_params") + if sampling_params is not None and hasattr(sampling_params, name): + return getattr(sampling_params, name) + + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'" + ) + + def __setattr__(self, name: str, value: Any) -> None: + """ + Smart attribute setting: + 1. If field exists in Req, set it in Req + 2. Else if field exists in sampling_params, set it in sampling_params + 3. 
Else set it in Req (for dynamic attributes) + """ + if name == "sampling_params": + object.__setattr__(self, name, value) + return + + if name in self.__class__.__dataclass_fields__: + object.__setattr__(self, name, value) + return + + try: + sampling_params = object.__getattribute__(self, "sampling_params") + except AttributeError: + sampling_params = None + + if sampling_params is not None and hasattr(sampling_params, name): + setattr(sampling_params, name, value) + return + + if sampling_params is None and name in SAMPLING_PARAMS_FIELDS: + new_sp = SamplingParams() + object.__setattr__(self, "sampling_params", new_sp) + setattr(new_sp, name, value) + return + + object.__setattr__(self, name, value) + + @property + def batch_size(self): + # Determine batch size + if isinstance(self.prompt, list): + batch_size = len(self.prompt) + elif self.prompt is not None: + batch_size = 1 + else: + batch_size = self.prompt_embeds[0].shape[0] + + # Adjust batch size for number of videos per prompt + batch_size *= self.num_outputs_per_prompt + return batch_size + + def output_file_path(self, num_outputs=1, output_idx=None): + output_file_name = self.output_file_name + if num_outputs > 1 and output_file_name: + base, ext = os.path.splitext(output_file_name) + output_file_name = f"{base}_{output_idx}{ext}" + + if self.output_path is None or not output_file_name: + return None + return os.path.join(self.output_path, output_file_name) + + def set_as_warmup(self, warmup_steps: int = 1): + self.is_warmup = True + self.save_output = False + self.suppress_logs = True + self.extra["cache_dit_num_inference_steps"] = self.num_inference_steps + self.num_inference_steps = warmup_steps + + def copy_as_warmup(self, warmup_steps: int = 1) -> "Req": + req = deepcopy(self) + req.set_as_warmup(warmup_steps) + return req + + def validate(self): + """Initialize dependent fields after dataclass initialization.""" + # Set do_classifier_free_guidance based on guidance scale and negative prompt + if 
@dataclass
class OutputBatch:
    """
    Final output (after pipeline completion).

    Every field is optional and defaults to "absent" (``None``, or ``0.0`` for
    ``peak_memory_mb``), so an instance can be built with only the fields that
    are available.
    """

    # Generated result tensor; same name and type as ``Req.output``.
    output: torch.Tensor | None = None
    # Audio result and its sample rate; same names and types as on ``Req``.
    audio: torch.Tensor | None = None
    audio_sample_rate: int | None = None
    # Trajectory data; NOTE(review): presumably only filled when trajectory
    # recording is requested — confirm against the producing stage.
    trajectory_timesteps: list[torch.Tensor] | None = None
    trajectory_latents: torch.Tensor | None = None
    trajectory_decoded: list[torch.Tensor] | None = None
    # Error message; NOTE(review): presumably set instead of ``output`` when
    # the request failed — confirm against the caller.
    error: str | None = None
    # Paths of files written for this request, if any.
    output_file_paths: list[str] | None = None

    # Logged metrics info. NOTE(review): the earlier comment referred to
    # ``Req.timings``; the visible ``Req`` field is ``metrics`` — confirm.
    metrics: Optional["RequestMetrics"] = None

    # For ComfyUI integration: noise prediction from denoising stage
    noise_pred: torch.Tensor | None = None
    # Peak memory in megabytes (unit taken from the field name).
    peak_memory_mb: float = 0.0
+""" + +from sglang.multimodal_gen.runtime.pipelines_core.stages.base import PipelineStage +from sglang.multimodal_gen.runtime.pipelines_core.stages.causal_denoising import ( + CausalDMDDenoisingStage, +) +from sglang.multimodal_gen.runtime.pipelines_core.stages.comfyui_latent_preparation import ( + ComfyUILatentPreparationStage, +) +from sglang.multimodal_gen.runtime.pipelines_core.stages.decoding import DecodingStage +from sglang.multimodal_gen.runtime.pipelines_core.stages.decoding_av import ( + LTX2AVDecodingStage, +) +from sglang.multimodal_gen.runtime.pipelines_core.stages.denoising import DenoisingStage +from sglang.multimodal_gen.runtime.pipelines_core.stages.denoising_av import ( + LTX2AVDenoisingStage, +) +from sglang.multimodal_gen.runtime.pipelines_core.stages.denoising_dmd import ( + DmdDenoisingStage, +) +from sglang.multimodal_gen.runtime.pipelines_core.stages.encoding import EncodingStage + +# Hunyuan3D paint stages +from sglang.multimodal_gen.runtime.pipelines_core.stages.hunyuan3d_paint import ( + Hunyuan3DPaintPostprocessStage, + Hunyuan3DPaintPreprocessStage, + Hunyuan3DPaintTexGenStage, +) + +# Hunyuan3D shape stages +from sglang.multimodal_gen.runtime.pipelines_core.stages.hunyuan3d_shape import ( + Hunyuan3DShapeBeforeDenoisingStage, + Hunyuan3DShapeDenoisingStage, + Hunyuan3DShapeExportStage, + Hunyuan3DShapeSaveStage, +) +from sglang.multimodal_gen.runtime.pipelines_core.stages.image_encoding import ( + ImageEncodingStage, + ImageVAEEncodingStage, +) +from sglang.multimodal_gen.runtime.pipelines_core.stages.input_validation import ( + InputValidationStage, +) +from sglang.multimodal_gen.runtime.pipelines_core.stages.latent_preparation import ( + LatentPreparationStage, +) +from sglang.multimodal_gen.runtime.pipelines_core.stages.latent_preparation_av import ( + LTX2AVLatentPreparationStage, +) +from sglang.multimodal_gen.runtime.pipelines_core.stages.text_connector import ( + LTX2TextConnectorStage, +) +from 
sglang.multimodal_gen.runtime.pipelines_core.stages.text_encoding import ( + TextEncodingStage, +) +from sglang.multimodal_gen.runtime.pipelines_core.stages.timestep_preparation import ( + TimestepPreparationStage, +) + +__all__ = [ + "PipelineStage", + "InputValidationStage", + "TimestepPreparationStage", + "LatentPreparationStage", + "ComfyUILatentPreparationStage", + "LTX2AVLatentPreparationStage", + "DenoisingStage", + "DmdDenoisingStage", + "LTX2AVDenoisingStage", + "CausalDMDDenoisingStage", + "EncodingStage", + "DecodingStage", + "LTX2AVDecodingStage", + "ImageEncodingStage", + "ImageVAEEncodingStage", + "TextEncodingStage", + "LTX2TextConnectorStage", + # Hunyuan3D shape stages + "Hunyuan3DShapeBeforeDenoisingStage", + "Hunyuan3DShapeDenoisingStage", + "Hunyuan3DShapeExportStage", + "Hunyuan3DShapeSaveStage", + # Hunyuan3D paint stages + "Hunyuan3DPaintPreprocessStage", + "Hunyuan3DPaintTexGenStage", + "Hunyuan3DPaintPostprocessStage", +] diff --git a/sglang/python/sglang/multimodal_gen/runtime/pipelines_core/stages/base.py b/sglang/python/sglang/multimodal_gen/runtime/pipelines_core/stages/base.py new file mode 100644 index 0000000000000000000000000000000000000000..e55243294ccc43f2a129873a5bdd89507040bb65 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/pipelines_core/stages/base.py @@ -0,0 +1,239 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +""" +Base classes for pipeline stages. + +This module defines the abstract base classes for pipeline stages that can be +composed to create complete diffusion pipelines. 
class PipelineStage(ABC):
    """
    Common base class for every stage of a diffusion pipeline.

    A stage is one composable step of the pipeline. Subclasses implement
    :meth:`forward`; callers invoke the stage object itself, which wraps
    ``forward`` with input verification, profiling and output verification
    (see :meth:`__call__`).
    """

    def __init__(self):
        # Server arguments are fetched from the global accessor at construction.
        self.server_args = get_global_server_args()

    def log_info(self, msg, *args):
        """Log ``msg`` at INFO level with the stage name as prefix; silent in ComfyUI mode."""
        if not self.server_args.comfyui_mode:
            logger.info(f"[{self.__class__.__name__}] {msg}", *args)

    def log_warning(self, msg, *args):
        """Log ``msg`` at WARNING level with the stage name as prefix."""
        logger.warning(f"[{self.__class__.__name__}] {msg}", *args)

    def log_error(self, msg, *args):
        """Log ``msg`` at ERROR level with the stage name as prefix."""
        logger.error(f"[{self.__class__.__name__}] {msg}", *args)

    def log_debug(self, msg, *args):
        """Log ``msg`` at DEBUG level with the stage name as prefix."""
        logger.debug(f"[{self.__class__.__name__}] {msg}", *args)

    def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
        """
        Check the batch before :meth:`forward` runs.

        The base implementation performs no checks and returns an empty
        ``VerificationResult``. Subclasses override it, for example:

            from sglang.multimodal_gen.runtime.pipelines_core.stages.validators import (
                StageValidators as V,
                VerificationResult,
            )

            def verify_input(self, batch, server_args):
                result = VerificationResult()
                result.add_check("height", batch.height, V.positive_int_divisible(8))
                result.add_check("width", batch.width, V.positive_int_divisible(8))
                result.add_check("image_latent", batch.image_latent, V.is_tensor)
                return result

        Args:
            batch: The request about to be processed.
            server_args: The server arguments passed to the stage call.

        Returns:
            A ``VerificationResult`` describing the checks performed.
        """
        return VerificationResult()

    def maybe_free_model_hooks(self):
        """Hook for subclasses; the base implementation does nothing."""
        return None

    def load_model(self):
        """Load the model for the stage; the base implementation does nothing."""
        return None

    def offload_model(self):
        """Offload the model for the stage; the base implementation does nothing."""
        return None

    @property
    def parallelism_type(self) -> StageParallelismType:
        """How the stage is distributed across ranks; replicated unless overridden."""
        # A cfg-parallel variant (MAIN_RANK_ONLY when enable_cfg_parallel is
        # set) existed here only as commented-out code and is not active.
        return StageParallelismType.REPLICATED

    def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
        """
        Check the batch returned by :meth:`forward`.

        Args:
            batch: The value returned by ``forward``.
            server_args: The server arguments passed to the stage call.

        Returns:
            A ``VerificationResult``; the base implementation performs no
            checks and returns an empty one.
        """
        return VerificationResult()

    def _run_verification(
        self,
        verification_result: VerificationResult,
        stage_name: str,
        verification_type: str,
    ) -> None:
        """
        Raise if a verification result reports failed fields.

        Args:
            verification_result: Result of ``verify_input`` or ``verify_output``.
            stage_name: Name of the current stage, used in the error message.
            verification_type: ``"input"`` or ``"output"``.

        Raises:
            StageVerificationError: When the result is invalid and lists at
                least one failed field.
        """
        if verification_result.is_valid():
            return

        failing = verification_result.get_failed_fields()
        if not failing:
            # An invalid result without failed fields is deliberately not an error here.
            return

        details = verification_result.get_failure_summary()
        failing_joined = ", ".join(failing)
        raise StageVerificationError(
            f"{verification_type.capitalize()} verification failed for {stage_name}: "
            f"Failed fields: {failing_joined}\n"
            f"Details: {details}"
        )

    @property
    def device(self) -> torch.device:
        """A ``torch.device`` built from the current platform's device type."""
        return torch.device(current_platform.device_type)

    def set_logging(self, enable: bool):
        """
        Record whether logging is enabled for this stage.

        Args:
            enable: Value stored in ``self._enable_logging``.
        """
        self._enable_logging = enable

    def __call__(
        self,
        batch: Req,
        server_args: ServerArgs,
    ) -> Req:
        """
        Run the stage: verify the input, execute :meth:`forward` under the
        stage profiler, then verify the output. Subclasses should override
        ``forward`` rather than this method.

        Args:
            batch: The request to process.
            server_args: The server arguments forwarded to the hooks.

        Returns:
            Whatever ``forward`` returned.

        Raises:
            Exception: Any error raised during verification is logged and
                re-raised unchanged.
        """
        stage_name = self.__class__.__name__

        # Verification before execution.
        try:
            self._run_verification(
                self.verify_input(batch, server_args), stage_name, "input"
            )
        except Exception as e:
            logger.error("Input verification failed for %s: %s", stage_name, str(e))
            raise

        # Stage start/end lines are suppressed for warm-up requests and in ComfyUI mode.
        profiler = StageProfiler(
            stage_name,
            logger=logger,
            metrics=batch.metrics,
            log_stage_start_end=not batch.is_warmup
            and not (self.server_args and self.server_args.comfyui_mode),
            perf_dump_path_provided=batch.perf_dump_path is not None,
        )
        with profiler:
            result = self.forward(batch, server_args)

        # Verification after execution.
        try:
            self._run_verification(
                self.verify_output(result, server_args), stage_name, "output"
            )
        except Exception as e:
            logger.error("Output verification failed for %s: %s", stage_name, str(e))
            raise

        return result

    @abstractmethod
    def forward(
        self,
        batch: Req,
        server_args: ServerArgs,
    ) -> Req:
        """
        Stage-specific processing; must be implemented by subclasses.

        Args:
            batch: The request to process.
            server_args: The server arguments passed to the stage call.

        Returns:
            The updated batch information after this stage's processing.
        """
        raise NotImplementedError

    def backward(
        self,
        batch: Req,
        server_args: ServerArgs,
    ) -> Req:
        """Not implemented in the base class; always raises ``NotImplementedError``."""
        raise NotImplementedError
def forward(
    self,
    batch: Req,
    server_args: ServerArgs,
) -> Req:
    """
    Run block-wise causal DMD denoising over ``batch.latents``.

    Steps, in order:
      1. Derive ``self.frame_seq_length`` from the latent spatial size and the
         transformer patch size, and build the DMD timesteps (optionally
         warped through the scheduler timesteps).
      2. Create the KV / cross-attention caches on first use, or reset them.
      3. If ``batch.image_latent`` is present, feed it through the transformer
         at timestep zero to warm the caches.
      4. For each frame block, iterate over the timesteps (predict, convert to
         a video prediction, re-noise to the next timestep), write the result
         back into ``latents``, then run the transformer once more on the
         result at the context timestep to update the KV cache.

    Args:
        batch: Request carrying ``latents`` (unpacked as ``b, c, t, h, w``),
            ``prompt_embeds``, ``prompt_attention_mask``, ``generator`` and
            optionally ``image_latent``.
        server_args: Server arguments; ``pipeline_config`` supplies the DMD
            steps, the warp flag and the optional ``context_noise``.

    Returns:
        The same ``batch`` with ``batch.latents`` set to the denoised latents.

    Raises:
        AssertionError: If ``batch.latents`` is ``None`` or the first prompt
            embedding contains NaNs.
        ValueError: If ``dmd_denoising_steps`` is empty, or if the frame count
            is not compatible with ``num_frames_per_block``.
    """
    # Checked up front: the latent shape is dereferenced a few lines below,
    # so a later assert could never fire with its intended message.
    assert batch.latents is not None, "latents must be provided"

    target_dtype = torch.bfloat16
    autocast_enabled = (
        target_dtype != torch.float32
    ) and not server_args.disable_autocast

    latent_seq_length = batch.latents.shape[-1] * batch.latents.shape[-2]
    patch_ratio = (
        self.transformer.config.arch_config.patch_size[-1]
        * self.transformer.config.arch_config.patch_size[-2]
    )
    self.frame_seq_length = latent_seq_length // patch_ratio
    # TODO(will): make this a parameter once we add i2v support
    independent_first_frame = self.transformer.independent_first_frame

    # Timesteps for DMD
    timesteps = torch.tensor(
        server_args.pipeline_config.dmd_denoising_steps, dtype=torch.long
    ).cpu()
    if timesteps.numel() == 0:
        # Without any step the per-block loop below never binds
        # ``attn_metadata`` and used to die with an UnboundLocalError.
        raise ValueError(
            "dmd_denoising_steps must contain at least one timestep for causal DMD denoising"
        )

    if server_args.pipeline_config.warp_denoising_step:
        logger.info("Warping timesteps...")
        scheduler_timesteps = torch.cat(
            (self.scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32))
        )
        # NOTE(review): 1000 presumably equals the scheduler's number of
        # training timesteps — confirm before changing the scheduler config.
        timesteps = scheduler_timesteps[1000 - timesteps]
    timesteps = timesteps.to(get_local_torch_device())
    logger.info("Using timesteps: %s", timesteps)

    # Image kwargs (kept empty unless caller provides compatible args)
    image_kwargs: dict = {}

    pos_cond_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            # "encoder_hidden_states_2": batch.clip_embedding_pos,
            "encoder_attention_mask": batch.prompt_attention_mask,
        },
    )

    # STA
    if self.attn_backend.get_enum() == AttentionBackendEnum.SLIDING_TILE_ATTN:
        self.prepare_sta_param(batch, server_args)

    # Latents and prompts
    assert batch.latents is not None, "latents must be provided"
    latents = batch.latents  # [B, C, T, H, W]
    b, c, t, h, w = latents.shape
    prompt_embeds = batch.prompt_embeds
    assert torch.isnan(prompt_embeds[0]).sum() == 0

    # Initialize or reset caches
    if self.kv_cache1 is None:
        self._initialize_kv_cache(
            batch_size=latents.shape[0], dtype=target_dtype, device=latents.device
        )
        self._initialize_crossattn_cache(
            batch_size=latents.shape[0],
            max_text_len=server_args.pipeline_config.text_encoder_configs[
                0
            ].arch_config.text_len,
            dtype=target_dtype,
            device=latents.device,
        )
    else:
        assert self.crossattn_cache is not None
        # reset cross-attention cache
        for block_index in range(self.num_transformer_blocks):
            self.crossattn_cache[block_index]["is_init"] = False  # type: ignore
        # reset kv cache pointers
        for block_index in range(len(self.kv_cache1)):
            self.kv_cache1[block_index]["global_end_index"] = (
                torch.tensor(  # type: ignore
                    [0], dtype=torch.long, device=latents.device
                )
            )
            self.kv_cache1[block_index]["local_end_index"] = (
                torch.tensor(  # type: ignore
                    [0], dtype=torch.long, device=latents.device
                )
            )

    # Optional: cache context features from provided image latents prior to generation
    current_start_frame = 0
    if getattr(batch, "image_latent", None) is not None:
        image_latent = batch.image_latent
        assert image_latent is not None
        input_frames = image_latent.shape[2]
        # timestep zero (or configured context noise) for cache warm-up
        t_zero = torch.zeros(
            [latents.shape[0]], device=latents.device, dtype=torch.long
        )
        if independent_first_frame and input_frames >= 1:
            # warm-up with the very first frame independently
            image_first_btchw = (
                image_latent[:, :, :1, :, :].to(target_dtype).permute(0, 2, 1, 3, 4)
            )
            with torch.autocast(
                device_type=current_platform.device_type,
                dtype=target_dtype,
                enabled=autocast_enabled,
            ):
                _ = self.transformer(
                    image_first_btchw,
                    prompt_embeds,
                    t_zero,
                    kv_cache=self.kv_cache1,
                    crossattn_cache=self.crossattn_cache,
                    current_start=current_start_frame * self.frame_seq_length,
                    **image_kwargs,
                    **pos_cond_kwargs,
                )
            current_start_frame += 1
            remaining_frames = input_frames - 1
        else:
            remaining_frames = input_frames

        # process remaining input frames in blocks of num_frames_per_block
        while remaining_frames > 0:
            block = min(self.num_frames_per_block, remaining_frames)
            ref_btchw = (
                image_latent[
                    :, :, current_start_frame : current_start_frame + block, :, :
                ]
                .to(target_dtype)
                .permute(0, 2, 1, 3, 4)
            )
            with torch.autocast(
                device_type=current_platform.device_type,
                dtype=target_dtype,
                enabled=autocast_enabled,
            ):
                _ = self.transformer(
                    ref_btchw,
                    prompt_embeds,
                    t_zero,
                    kv_cache=self.kv_cache1,
                    crossattn_cache=self.crossattn_cache,
                    current_start=current_start_frame * self.frame_seq_length,
                    **image_kwargs,
                    **pos_cond_kwargs,
                )
            current_start_frame += block
            remaining_frames -= block

    # Base position offset from any cache warm-up
    pos_start_base = current_start_frame

    # Determine block sizes
    if not independent_first_frame or (
        independent_first_frame and batch.image_latent is not None
    ):
        if t % self.num_frames_per_block != 0:
            raise ValueError(
                "num_frames must be divisible by num_frames_per_block for causal DMD denoising"
            )
        num_blocks = t // self.num_frames_per_block
        block_sizes = [self.num_frames_per_block] * num_blocks
        start_index = 0
    else:
        # Independent first frame without image conditioning: the first
        # block holds a single frame, the rest are full-size blocks.
        if (t - 1) % self.num_frames_per_block != 0:
            raise ValueError(
                "(num_frames - 1) must be divisible by num_frame_per_block when independent_first_frame=True"
            )
        num_blocks = (t - 1) // self.num_frames_per_block
        block_sizes = [1] + [self.num_frames_per_block] * num_blocks
        start_index = 0

    # DMD loop in causal blocks
    with self.progress_bar(total=len(block_sizes) * len(timesteps)) as progress_bar:
        for current_num_frames in block_sizes:
            current_latents = latents[
                :, :, start_index : start_index + current_num_frames, :, :
            ]
            # use BTCHW for DMD conversion routines
            noise_latents_btchw = current_latents.permute(0, 2, 1, 3, 4)
            video_raw_latent_shape = noise_latents_btchw.shape

            for i, t_cur in enumerate(timesteps):
                # Copy for pred conversion
                noise_latents = noise_latents_btchw.clone()
                latent_model_input = current_latents.to(target_dtype)

                if (
                    batch.image_latent is not None
                    and independent_first_frame
                    and start_index == 0
                ):
                    latent_model_input = torch.cat(
                        [latent_model_input, batch.image_latent.to(target_dtype)],
                        dim=2,
                    )

                # Prepare inputs
                t_expand = t_cur.repeat(latent_model_input.shape[0])

                # Attention metadata if needed
                if (
                    self.attn_backend.get_enum()
                    == AttentionBackendEnum.VIDEO_SPARSE_ATTN
                ):
                    self.attn_metadata_builder_cls = (
                        self.attn_backend.get_builder_cls()
                    )
                    if self.attn_metadata_builder_cls is not None:
                        self.attn_metadata_builder = (
                            self.attn_metadata_builder_cls()
                        )
                        attn_metadata = self.attn_metadata_builder.build(  # type: ignore
                            current_timestep=i,  # type: ignore
                            raw_latent_shape=(
                                current_num_frames,
                                h,
                                w,
                            ),  # type: ignore
                            patch_size=server_args.pipeline_config.dit_config.patch_size,  # type: ignore
                            STA_param=batch.STA_param,  # type: ignore
                            VSA_sparsity=server_args.attention_backend_config.VSA_sparsity,  # type: ignore
                            device=get_local_torch_device(),  # type: ignore
                        )  # type: ignore
                        assert (
                            attn_metadata is not None
                        ), "attn_metadata cannot be None"
                    else:
                        attn_metadata = None
                else:
                    attn_metadata = None

                with (
                    torch.autocast(
                        device_type=current_platform.device_type,
                        dtype=target_dtype,
                        enabled=autocast_enabled,
                    ),
                    set_forward_context(
                        current_timestep=i,
                        attn_metadata=attn_metadata,
                        forward_batch=batch,
                    ),
                ):
                    # Run transformer; follow DMD stage pattern
                    t_expanded_noise = t_cur * torch.ones(
                        (latent_model_input.shape[0], 1),
                        device=latent_model_input.device,
                        dtype=torch.long,
                    )
                    pred_noise_btchw = self.transformer(
                        latent_model_input,
                        prompt_embeds,
                        t_expanded_noise,
                        kv_cache=self.kv_cache1,
                        crossattn_cache=self.crossattn_cache,
                        current_start=(pos_start_base + start_index)
                        * self.frame_seq_length,
                        start_frame=start_index,
                        **image_kwargs,
                        **pos_cond_kwargs,
                    ).permute(0, 2, 1, 3, 4)

                # Convert pred noise to pred video with FM Euler scheduler utilities
                pred_video_btchw = pred_noise_to_pred_video(
                    pred_noise=pred_noise_btchw.flatten(0, 1),
                    noise_input_latent=noise_latents.flatten(0, 1),
                    timestep=t_expand,
                    scheduler=self.scheduler,
                ).unflatten(0, pred_noise_btchw.shape[:2])

                if i < len(timesteps) - 1:
                    # Not the last step: re-noise the prediction to the next timestep.
                    next_timestep = timesteps[i + 1] * torch.ones(
                        [1], dtype=torch.long, device=pred_video_btchw.device
                    )
                    noise = torch.randn(
                        video_raw_latent_shape,
                        dtype=pred_video_btchw.dtype,
                        generator=(
                            batch.generator[0]
                            if isinstance(batch.generator, list)
                            else batch.generator
                        ),
                        device=self.device,
                    )
                    noise_btchw = noise
                    noise_latents_btchw = self.scheduler.add_noise(
                        pred_video_btchw.flatten(0, 1),
                        noise_btchw.flatten(0, 1),
                        next_timestep,
                    ).unflatten(0, pred_video_btchw.shape[:2])
                    current_latents = noise_latents_btchw.permute(0, 2, 1, 3, 4)
                else:
                    # Last step: keep the clean prediction.
                    current_latents = pred_video_btchw.permute(0, 2, 1, 3, 4)

                if progress_bar is not None:
                    progress_bar.update()

            # Write back and advance
            latents[:, :, start_index : start_index + current_num_frames, :, :] = (
                current_latents
            )

            # Re-run with context timestep to update KV cache using clean context
            context_noise = getattr(server_args.pipeline_config, "context_noise", 0)
            t_context = torch.ones(
                [latents.shape[0]], device=latents.device, dtype=torch.long
            ) * int(context_noise)
            context_bcthw = current_latents.to(target_dtype)
            with (
                torch.autocast(
                    device_type=current_platform.device_type,
                    dtype=target_dtype,
                    enabled=autocast_enabled,
                ),
                # attn_metadata here is the value left by the last timestep iteration.
                set_forward_context(
                    current_timestep=0,
                    attn_metadata=attn_metadata,
                    forward_batch=batch,
                ),
            ):
                t_expanded_context = t_context.unsqueeze(1)
                _ = self.transformer(
                    context_bcthw,
                    prompt_embeds,
                    t_expanded_context,
                    kv_cache=self.kv_cache1,
                    crossattn_cache=self.crossattn_cache,
                    current_start=(pos_start_base + start_index)
                    * self.frame_seq_length,
                    start_frame=start_index,
                    **image_kwargs,
                    **pos_cond_kwargs,
                )
            start_index += current_num_frames

    batch.latents = latents
    return batch
+ """ + kv_cache1 = [] + num_attention_heads = self.transformer.num_attention_heads + attention_head_dim = self.transformer.attention_head_dim + if self.local_attn_size != -1: + kv_cache_size = self.local_attn_size * self.frame_seq_length + else: + kv_cache_size = self.frame_seq_length * self.sliding_window_num_frames + + for _ in range(self.num_transformer_blocks): + kv_cache1.append( + { + "k": torch.zeros( + [ + batch_size, + kv_cache_size, + num_attention_heads, + attention_head_dim, + ], + dtype=dtype, + device=device, + ), + "v": torch.zeros( + [ + batch_size, + kv_cache_size, + num_attention_heads, + attention_head_dim, + ], + dtype=dtype, + device=device, + ), + "global_end_index": torch.tensor( + [0], dtype=torch.long, device=device + ), + "local_end_index": torch.tensor( + [0], dtype=torch.long, device=device + ), + } + ) + + self.kv_cache1 = kv_cache1 + + def _initialize_crossattn_cache( + self, batch_size, max_text_len, dtype, device + ) -> None: + """ + Initialize a Per-GPU cross-attention cache aligned with the Wan model assumptions. 
+ """ + crossattn_cache = [] + num_attention_heads = self.transformer.num_attention_heads + attention_head_dim = self.transformer.attention_head_dim + for _ in range(self.num_transformer_blocks): + crossattn_cache.append( + { + "k": torch.zeros( + [ + batch_size, + max_text_len, + num_attention_heads, + attention_head_dim, + ], + dtype=dtype, + device=device, + ), + "v": torch.zeros( + [ + batch_size, + max_text_len, + num_attention_heads, + attention_head_dim, + ], + dtype=dtype, + device=device, + ), + "is_init": False, + } + ) + self.crossattn_cache = crossattn_cache + + def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + """Verify denoising stage inputs.""" + result = VerificationResult() + result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)]) + result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty) + result.add_check("image_embeds", batch.image_embeds, V.is_list) + result.add_check( + "image_latent", batch.image_latent, V.none_or_tensor_with_dims(5) + ) + result.add_check( + "num_inference_steps", batch.num_inference_steps, V.positive_int + ) + result.add_check("guidance_scale", batch.guidance_scale, V.non_negative_float) + result.add_check("eta", batch.eta, V.non_negative_float) + result.add_check("generator", batch.generator, V.generator_or_list_generators) + result.add_check( + "do_classifier_free_guidance", + batch.do_classifier_free_guidance, + V.bool_value, + ) + result.add_check( + "negative_prompt_embeds", + batch.negative_prompt_embeds, + lambda x: not batch.do_classifier_free_guidance or V.list_not_empty(x), + ) + return result diff --git a/sglang/python/sglang/multimodal_gen/runtime/pipelines_core/stages/comfyui_latent_preparation.py b/sglang/python/sglang/multimodal_gen/runtime/pipelines_core/stages/comfyui_latent_preparation.py new file mode 100644 index 0000000000000000000000000000000000000000..0b4e53a928d711eb1f42c4a916f0b39e30d5c8f5 --- /dev/null +++ 
class ComfyUILatentPreparationStage(LatentPreparationStage):
    """
    Latent preparation for ComfyUI pipelines, with a tensor-device fix.

    Before delegating to ``LatentPreparationStage.forward`` this stage moves
    tensors found on the batch to the local device on ranks other than 0 when
    sequence parallelism is active (tensors can arrive on the wrong device
    after being pickled and unpickled via ``broadcast_pyobj``).
    """

    @staticmethod
    def _fix_tensor_device(value, target_device):
        """
        Return ``value`` with every contained tensor placed on ``target_device``.

        Tensors already on ``target_device`` are returned as-is; others are
        detached, cloned and moved. Lists and tuples are rebuilt recursively;
        any other value is returned unchanged.
        """
        relocate = ComfyUILatentPreparationStage._fix_tensor_device

        if isinstance(value, torch.Tensor):
            if value.device != target_device:
                return value.detach().clone().to(target_device)
            return value
        if isinstance(value, list):
            return [relocate(item, target_device) for item in value]
        if isinstance(value, tuple):
            return tuple(relocate(item, target_device) for item in value)
        return value

    @staticmethod
    def _has_tensor(value):
        """Return True if ``value`` is a tensor or a (nested) list/tuple containing one."""
        if isinstance(value, torch.Tensor):
            return True
        if not isinstance(value, (list, tuple)):
            return False
        for item in value:
            if ComfyUILatentPreparationStage._has_tensor(item):
                return True
        return False

    def _relocate_batch_tensors(self, batch, target_device):
        """Move every tensor-bearing attribute of ``batch`` to ``target_device`` in place."""
        if dataclasses.is_dataclass(batch):
            for field in dataclasses.fields(batch):
                current = getattr(batch, field.name, None)
                if current is not None and self._has_tensor(current):
                    setattr(
                        batch,
                        field.name,
                        self._fix_tensor_device(current, target_device),
                    )
            return

        # Non-dataclass batches: walk public, non-callable attributes.
        for attr_name in dir(batch):
            if attr_name.startswith("_"):
                continue
            if callable(getattr(batch, attr_name, None)):
                continue
            try:
                current = getattr(batch, attr_name, None)
                if current is not None and self._has_tensor(current):
                    setattr(
                        batch,
                        attr_name,
                        self._fix_tensor_device(current, target_device),
                    )
            except (AttributeError, TypeError):
                # Attributes that cannot be read or assigned are left untouched.
                continue

    def forward(self, batch, server_args):
        """
        Prepare latents, fixing tensor devices first where needed.

        Order of operations:
          1. With more than one sequence-parallel rank, on ranks other than 0,
             move the batch's tensors to the local device.
          2. Remember the shape of ``batch.latents`` (if any).
          3. Run the parent class's ``forward``.
          4. Store the remembered shape as ``raw_latent_shape`` on the result,
             i.e. the shape from before any packing done by the parent.

        Returns:
            The result of the parent ``forward``, with ``raw_latent_shape``
            set when ``batch.latents`` was present on entry.
        """
        if get_sp_world_size() > 1:
            sp_group = get_sp_group()
            target_device = get_local_torch_device()

            if sp_group.rank != 0:
                logger.debug(
                    f"[ComfyUILatentPreparationStage] Fixing tensor device on rank={sp_group.rank} "
                    f"target_device={target_device}"
                )
                self._relocate_batch_tensors(batch, target_device)

        # Captured before the parent runs so later unpadding can use the
        # original (e.g. 4D spatial) shape rather than a packed sequence shape.
        pre_pack_shape = None
        if batch.latents is not None:
            pre_pack_shape = batch.latents.shape

        result = super().forward(batch, server_args)

        if pre_pack_shape is not None:
            result.raw_latent_shape = pre_pack_shape

        return result
class DecodingStage(PipelineStage):
    """
    Stage for decoding latent representations into pixel space.

    This stage handles the decoding of latent representations into the final
    output format (e.g., pixel values).
    """

    def __init__(self, vae, pipeline=None) -> None:
        """
        Args:
            vae: VAE used for decoding.
            pipeline: Optional owning pipeline; only a weak reference is kept.
        """
        super().__init__()
        self.vae: ParallelTiledVAE = vae
        self.pipeline = weakref.ref(pipeline) if pipeline else None

    @property
    def parallelism_type(self) -> StageParallelismType:
        """MAIN_RANK_ONLY when CFG parallelism is enabled, REPLICATED otherwise."""
        if get_global_server_args().enable_cfg_parallel:
            return StageParallelismType.MAIN_RANK_ONLY
        return StageParallelismType.REPLICATED

    def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
        """Verify decoding stage inputs (currently no checks are registered)."""
        result = VerificationResult()
        # Denoised latents for VAE decoding: [batch_size, channels, frames, height_latents, width_latents]
        # result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
        return result

    def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
        """Verify decoding stage outputs (currently no checks are registered)."""
        result = VerificationResult()
        # Decoded video/images: [batch_size, channels, frames, height, width]
        # result.add_check("output", batch.output, [V.is_tensor, V.with_dims(5)])
        return result

    def scale_and_shift(self, latents: torch.Tensor, server_args):
        """
        Undo the VAE latent normalization: divide by the scaling factor, then
        add the shift factor when one is configured.

        Args:
            latents: Latent tensor to de-normalize. The caller's tensor is not
                modified; the division below produces a new tensor.
            server_args: Its ``pipeline_config.get_decode_scale_and_shift``
                supplies ``(scaling_factor, shift_factor)``.

        Returns:
            The de-normalized latents.
        """
        scaling_factor, shift_factor = (
            server_args.pipeline_config.get_decode_scale_and_shift(
                latents.device, latents.dtype, self.vae
            )
        )

        # 1. scale
        if isinstance(scaling_factor, torch.Tensor):
            latents = latents / scaling_factor.to(latents.device, latents.dtype)
        else:
            latents = latents / scaling_factor

        # 2. apply shifting if needed
        # In-place is safe here: `latents` was rebound to a fresh tensor above.
        if shift_factor is not None:
            if isinstance(shift_factor, torch.Tensor):
                latents += shift_factor.to(latents.device, latents.dtype)
            else:
                latents += shift_factor
        return latents

    @torch.no_grad()
    def decode(self, latents: torch.Tensor, server_args: ServerArgs) -> torch.Tensor:
        """
        Decode latent representations into pixel space using VAE.

        Args:
            latents: Input latent tensor with shape (batch, channels, frames, height_latents, width_latents)
            server_args: Configuration containing:
                - disable_autocast: Whether to disable automatic mixed precision (default: False)
                - pipeline_config.vae_precision: VAE computation precision ("fp32", "fp16", "bf16")
                - pipeline_config.vae_tiling: Whether to enable VAE tiling for memory efficiency

        Returns:
            Decoded tensor rescaled from the VAE output range to [0, 1] via
            ``x / 2 + 0.5`` and clamped. The tensor is NOT moved to CPU or cast
            here; it keeps the device and dtype produced by the VAE.
        """
        device = get_local_torch_device()
        self.vae = self.vae.to(device)
        latents = latents.to(device)
        # Setup VAE precision
        vae_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.vae_precision]
        vae_autocast_enabled = (
            vae_dtype != torch.float32
        ) and not server_args.disable_autocast

        # scale and shift
        latents = self.scale_and_shift(latents, server_args)
        # Preprocess latents before decoding (e.g., unpatchify for standard Flux2 VAE)
        latents = server_args.pipeline_config.preprocess_decoding(
            latents, server_args, vae=self.vae
        )

        # Decode latents
        with torch.autocast(
            device_type=current_platform.device_type,
            dtype=vae_dtype,
            enabled=vae_autocast_enabled,
        ):
            try:
                # TODO: make it more specific
                if server_args.pipeline_config.vae_tiling:
                    self.vae.enable_tiling()
            except Exception as e:
                # Tiling stays best-effort, but leave a trace instead of
                # swallowing the failure silently.
                logger.debug("Could not enable VAE tiling, continuing without it: %s", e)
            if not vae_autocast_enabled:
                # Without autocast the input must be cast to the VAE precision by hand.
                latents = latents.to(vae_dtype)
            decode_output = self.vae.decode(latents)
            image = _ensure_tensor_decode_output(decode_output)

        # De-normalize image to [0, 1] range
        image = (image / 2 + 0.5).clamp(0, 1)
        return image

    def load_model(self):
        """Load the VAE if ``server_args.model_loaded["vae"]`` says it is not loaded."""
        # load vae if not already loaded (used for memory constrained devices)
        pipeline = self.pipeline() if self.pipeline else None
        if not self.server_args.model_loaded["vae"]:
            loader = VAELoader()
            self.vae = loader.load(
                self.server_args.model_paths["vae"], self.server_args
            )
            if pipeline:
                pipeline.add_module("vae", self.vae)
            self.server_args.model_loaded["vae"] = True

    def offload_model(self):
        """Release the VAE after decoding according to the offload settings."""
        # Offload models if needed
        self.maybe_free_model_hooks()

        if self.server_args.vae_cpu_offload:
            self.vae.to("cpu", non_blocking=True)

        if torch.backends.mps.is_available():
            # On MPS the VAE is dropped entirely and marked as not loaded, so
            # the next load_model() call loads it again.
            del self.vae
            pipeline = self.pipeline() if self.pipeline else None
            if pipeline is not None and "vae" in pipeline.modules:
                del pipeline.modules["vae"]
            self.server_args.model_loaded["vae"] = False

    @torch.no_grad()
    def forward(
        self,
        batch: Req,
        server_args: ServerArgs,
    ) -> OutputBatch:
        """
        Decode latent representations into pixel space.

        This method processes the batch through the VAE decoder, converting latent
        representations to pixel-space video/images. It also optionally decodes
        trajectory latents for visualization purposes.

        Args:
            batch: Request carrying ``latents`` and, when
                ``return_trajectory_decoded`` is set, ``trajectory_latents``.
            server_args: Runtime configuration forwarded to ``decode`` and to
                ``pipeline_config.post_decoding``.

        Returns:
            An ``OutputBatch`` with the decoded frames and trajectory data.

        Raises:
            AssertionError: If trajectory decoding is requested but
                ``batch.trajectory_latents`` is ``None``.
        """
        # load vae if not already loaded (used for memory constrained devices)
        self.load_model()

        frames = self.decode(batch.latents, server_args)

        # decode trajectory latents if needed
        if batch.return_trajectory_decoded:
            assert (
                batch.trajectory_latents is not None
            ), "batch should have trajectory latents"

            # 1. Batch trajectory decoding to improve GPU utilization
            # batch.trajectory_latents is [batch_size, timesteps, channels, frames, height, width]
            B, T, C, F, H, W = batch.trajectory_latents.shape
            # reshape() equals view() for contiguous input and copies instead
            # of raising when the trajectory tensor is not contiguous.
            flat_latents = batch.trajectory_latents.reshape(B * T, C, F, H, W)

            logger.info("decoding %s trajectory latents in batch", B * T)
            # Use the optimized batch decode
            all_decoded = self.decode(flat_latents, server_args)

            # 2. Reshape back
            # Keep on GPU to allow faster vectorized post-processing
            decoded_tensor = all_decoded.reshape(B, T, *all_decoded.shape[1:])

            # Convert to list of tensors (per timestep) as expected by OutputBatch
            # Each element in list is [B, channels, frames, H_out, W_out]
            trajectory_decoded = [decoded_tensor[:, i] for i in range(T)]
        else:
            trajectory_decoded = None

        frames = server_args.pipeline_config.post_decoding(frames, server_args)

        # Update batch with decoded image
        output_batch = OutputBatch(
            output=frames,
            trajectory_timesteps=batch.trajectory_timesteps,
            trajectory_latents=batch.trajectory_latents,
            trajectory_decoded=trajectory_decoded,
            metrics=batch.metrics,
        )

        self.offload_model()

        return output_batch
class LTX2AVDecodingStage(DecodingStage):
    """
    LTX-2 specific decoding stage that handles both video and audio decoding.
    """

    def __init__(self, vae, audio_vae, vocoder, pipeline=None):
        """
        Args:
            vae: Video VAE, handled by the parent ``DecodingStage``.
            audio_vae: VAE that decodes audio latents to a spectrogram.
            vocoder: Module that turns the spectrogram into a waveform.
            pipeline: Optional owning pipeline, forwarded to the parent.
        """
        super().__init__(vae, pipeline)
        self.audio_vae = audio_vae
        self.vocoder = vocoder
        # Add video processor for postprocessing
        from diffusers.video_processor import VideoProcessor

        self.video_processor = VideoProcessor(vae_scale_factor=32)

    # The parent's forward runs under no_grad; an override does not inherit the
    # decorator, so it is applied again to keep decoding free of autograd state.
    @torch.no_grad()
    def forward(self, batch: Req, server_args: ServerArgs) -> OutputBatch:
        """
        Decode video latents and, when present, audio latents.

        Args:
            batch: Request carrying ``latents`` and optionally ``audio_latents``.
            server_args: Runtime configuration (VAE precision, autocast, tiling).

        Returns:
            An ``OutputBatch`` whose ``output`` is the post-processed video. When
            ``batch.audio_latents`` is set, ``audio`` holds the waveform as a CPU
            float32 tensor and ``audio_sample_rate`` the first truthy value of
            the vocoder, audio VAE and pipeline-config sample rates.

        Raises:
            ValueError: If the spectrogram does not match the vocoder's
                expected input channel count.
        """
        self.load_model()

        device = get_local_torch_device()
        self.vae = self.vae.to(device)
        self.vae.eval()
        latents = batch.latents.to(device)

        vae_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.vae_precision]
        vae_autocast_enabled = (
            vae_dtype != torch.float32
        ) and not server_args.disable_autocast

        # 1. Decode Video
        # The VAE runs in bfloat16 for the decode and is cast to the configured
        # precision afterwards; `finally` guarantees the cast even on failure.
        self.vae.to(torch.bfloat16)
        try:
            latents = latents.to(torch.bfloat16)
            std = self.vae.latents_std.view(1, -1, 1, 1, 1).to(latents)
            mean = self.vae.latents_mean.view(1, -1, 1, 1, 1).to(latents)
            latents = latents * std + mean
            latents = server_args.pipeline_config.preprocess_decoding(
                latents, server_args, vae=self.vae
            )

            with torch.autocast(
                device_type=current_platform.device_type,
                dtype=vae_dtype,
                enabled=vae_autocast_enabled,
            ):
                try:
                    if server_args.pipeline_config.vae_tiling:
                        self.vae.enable_tiling()
                except Exception as e:
                    # Tiling is best-effort; record why it was skipped.
                    logger.debug(
                        "Could not enable VAE tiling, continuing without it: %s", e
                    )
                decode_output = self.vae.decode(latents)
                if isinstance(decode_output, tuple):
                    video = decode_output[0]
                elif hasattr(decode_output, "sample"):
                    video = decode_output.sample
                else:
                    video = decode_output
        finally:
            self.vae.to(vae_dtype)

        video = self.video_processor.postprocess_video(video, output_type="np")

        output_batch = OutputBatch(
            output=video,
            trajectory_timesteps=batch.trajectory_timesteps,
            trajectory_latents=batch.trajectory_latents,
            trajectory_decoded=None,
            metrics=batch.metrics,
        )

        # 2. Decode Audio
        audio_latents = getattr(batch, "audio_latents", None)
        if audio_latents is not None:
            # Ensure device/dtype
            self.audio_vae = self.audio_vae.to(device)
            self.vocoder = self.vocoder.to(device)
            self.audio_vae.eval()
            self.vocoder.eval()

            # Prefer an explicit `dtype` attribute, then the first parameter's
            # dtype, and finally float32 for a parameter-less module.
            dtype = getattr(self.audio_vae, "dtype", None)
            if dtype is None:
                first_param = next(self.audio_vae.parameters(), None)
                dtype = first_param.dtype if first_param is not None else torch.float32
            audio_latents = audio_latents.to(device, dtype=dtype)

            latents_std = getattr(self.audio_vae, "latents_std", None)
            if isinstance(latents_std, torch.Tensor) and torch.all(latents_std == 0):
                logger.warning(
                    "audio_vae.latents_std is all zeros; audio denorm may be incorrect."
                )

            # Decode latents to spectrogram
            spectrogram = self.audio_vae.decode(audio_latents, return_dict=False)[0]
            if hasattr(self.vocoder, "conv_in") and hasattr(
                self.vocoder.conv_in, "in_channels"
            ):
                # Fail early with a readable message instead of a shape error
                # inside the vocoder.
                expected_in = int(self.vocoder.conv_in.in_channels)
                actual_in = int(spectrogram.shape[1]) * int(spectrogram.shape[3])
                if actual_in != expected_in:
                    raise ValueError(
                        f"Vocoder expects channels*mel_bins={expected_in}, got {actual_in} from spectrogram shape {tuple(spectrogram.shape)}"
                    )
            # Decode spectrogram to waveform
            waveform = self.vocoder(spectrogram)
            output_batch.audio = waveform.cpu().float()

            # Each lookup tolerates a missing attribute (or a None parent).
            pipeline_audio_cfg = getattr(
                server_args.pipeline_config, "audio_vae_config", None
            )
            pipeline_audio_arch = getattr(pipeline_audio_cfg, "arch_config", None)
            pipeline_audio_sr = getattr(pipeline_audio_arch, "sample_rate", None)

            vocoder_sr = getattr(self.vocoder, "sample_rate", None)
            audio_vae_sr = getattr(self.audio_vae, "sample_rate", None)
            output_batch.audio_sample_rate = (
                vocoder_sr or audio_vae_sr or pipeline_audio_sr
            )

        self.offload_model()
        return output_batch
get_classifier_free_guidance_rank, +) +from sglang.multimodal_gen.runtime.layers.attention.selector import get_attn_backend +from sglang.multimodal_gen.runtime.layers.attention.STA_configuration import ( + configure_sta, + save_mask_search_results, +) +from sglang.multimodal_gen.runtime.loader.component_loaders.transformer_loader import ( + TransformerLoader, +) +from sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context +from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req +from sglang.multimodal_gen.runtime.pipelines_core.stages.base import ( + PipelineStage, + StageParallelismType, +) +from sglang.multimodal_gen.runtime.pipelines_core.stages.validators import ( + StageValidators as V, +) +from sglang.multimodal_gen.runtime.pipelines_core.stages.validators import ( + VerificationResult, +) +from sglang.multimodal_gen.runtime.platforms import ( + AttentionBackendEnum, + current_platform, +) +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.runtime.utils.perf_logger import StageProfiler +from sglang.multimodal_gen.runtime.utils.profiler import SGLDiffusionProfiler +from sglang.multimodal_gen.utils import dict_to_3d_list, masks_like + +logger = init_logger(__name__) + + +class DenoisingStage(PipelineStage): + """ + Stage for running the denoising loop in diffusion pipelines. + + This stage handles the iterative denoising process that transforms + the initial noise into the final output. 
def __init__(
    self, transformer, scheduler, pipeline=None, transformer_2=None, vae=None
) -> None:
    """
    Set up the denoising stage.

    Args:
        transformer: Primary denoising transformer.
        scheduler: Scheduler stored on the stage for later use.
        pipeline: Optional owning pipeline; only a weak reference is stored.
        transformer_2: Optional second transformer.
        vae: Optional VAE stored on the stage.
    """
    super().__init__()
    self.transformer = transformer
    self.transformer_2 = transformer_2

    # Per-head width is derived from the DiT config and is used below to
    # choose the attention backend.
    dit_config = self.server_args.pipeline_config.dit_config
    head_dim = dit_config.hidden_size // dit_config.num_attention_heads

    # torch compile (falsy entries, e.g. a missing second transformer, are skipped)
    for candidate in (self.transformer, self.transformer_2):
        if candidate:
            self._maybe_enable_torch_compile(candidate)

    self.scheduler = scheduler
    self.vae = vae
    if pipeline:
        self.pipeline = weakref.ref(pipeline)
    else:
        self.pipeline = None

    # TODO(will): hack, should use the actual one in dit
    self.attn_backend = get_attn_backend(head_size=head_dim, dtype=torch.float16)

    # cfg
    self.guidance = None

    # misc
    self.profiler = None

    # cache-dit state (for delayed mounting and idempotent control)
    self._cache_dit_enabled = False
    self._cached_num_steps = None
    self._is_warmed_up = False
def _maybe_enable_cache_dit(
    self, num_inference_steps: int | tuple[int, int], batch: Req
) -> None:
    """Enable cache-dit on the transformers if configured (idempotent).

    This method should be called after the transformer is fully loaded
    and before torch.compile is applied.

    For dual-transformer models (e.g., Wan2.2), this enables cache-dit on both
    transformers with (potentially) different configurations.

    Behavior by state:
      * cache-dit already enabled on this stage: only the cache context is
        refreshed for the new step count, then the method returns.
      * ``envs.SGLANG_CACHE_DIT_ENABLED`` is falsy or ``batch.is_warmup`` is
        truthy: nothing happens.
      * otherwise: the SCM settings are read from ``envs``, cache configs are
        built, ``self.transformer`` (and ``self.transformer_2``) are replaced
        by their cache-enabled versions, and the stage is marked as enabled.

    Args:
        num_inference_steps: Total step count, or a
            ``(num_high_noise_steps, num_low_noise_steps)`` tuple.
        batch: Current request; only ``batch.is_warmup`` is read here.
    """
    # NOTE(review): the two names below are bound only for tuple input, yet the
    # dual-transformer branch further down reads them whenever
    # self.transformer_2 is not None - presumably callers always pass a tuple
    # in that case; confirm.
    if isinstance(num_inference_steps, tuple):
        num_high_noise_steps, num_low_noise_steps = num_inference_steps

    # NOTE: When a new request arrives, we need to refresh the cache-dit context.
    if self._cache_dit_enabled:
        scm_preset = envs.SGLANG_CACHE_DIT_SCM_PRESET
        # The string "none" is passed on as None in the refresh path.
        scm_preset = None if scm_preset == "none" else scm_preset
        if isinstance(num_inference_steps, tuple):
            refresh_context_on_dual_transformer(
                self.transformer,
                self.transformer_2,
                num_high_noise_steps,
                num_low_noise_steps,
                scm_preset=scm_preset,
            )
        else:
            refresh_context_on_transformer(
                self.transformer,
                num_inference_steps,
                scm_preset=scm_preset,
            )
        return

    # check if cache-dit is enabled in config
    if not envs.SGLANG_CACHE_DIT_ENABLED or batch.is_warmup:
        return

    world_size = get_world_size()
    parallelized = world_size > 1

    # Process groups are handed to cache-dit only when the corresponding
    # parallel dimension really spans more than one rank.
    sp_group = None
    tp_group = None
    if parallelized:
        sp_group_candidate = get_sp_group()
        tp_group_candidate = get_tp_group()

        sp_world_size = sp_group_candidate.world_size if sp_group_candidate else 1
        tp_world_size = tp_group_candidate.world_size if tp_group_candidate else 1

        has_sp = sp_world_size > 1
        has_tp = tp_world_size > 1

        sp_group = sp_group_candidate.device_group if has_sp else None
        tp_group = tp_group_candidate.device_group if has_tp else None

        logger.info(
            "cache-dit enabled in distributed environment (world_size=%d, has_sp=%s, has_tp=%s)",
            world_size,
            has_sp,
            has_tp,
        )
    # === Parse SCM configuration from envs ===
    # SCM is shared between primary and secondary transformers
    scm_preset = envs.SGLANG_CACHE_DIT_SCM_PRESET
    scm_compute_bins_str = envs.SGLANG_CACHE_DIT_SCM_COMPUTE_BINS
    scm_cache_bins_str = envs.SGLANG_CACHE_DIT_SCM_CACHE_BINS
    scm_policy = envs.SGLANG_CACHE_DIT_SCM_POLICY

    # parse custom bins if provided (both must be set together)
    scm_compute_bins = None
    scm_cache_bins = None
    if scm_compute_bins_str and scm_cache_bins_str:
        try:
            scm_compute_bins = [
                int(x.strip()) for x in scm_compute_bins_str.split(",")
            ]
            scm_cache_bins = [int(x.strip()) for x in scm_cache_bins_str.split(",")]
        except ValueError as e:
            # NOTE(review): if only the cache bins fail to parse,
            # scm_compute_bins keeps its parsed list while scm_cache_bins
            # stays None - confirm get_scm_mask tolerates that with
            # preset "none".
            logger.warning("Failed to parse SCM bins: %s. SCM disabled.", e)
            scm_preset = "none"
    elif scm_compute_bins_str or scm_cache_bins_str:
        # Only one of the bins was provided - warn user
        logger.warning(
            "SCM custom bins require both compute_bins and cache_bins. "
            "Only one was provided (compute=%s, cache=%s). Falling back to preset '%s'.",
            scm_compute_bins_str,
            scm_cache_bins_str,
            scm_preset,
        )

    # generate SCM mask using cache-dit's steps_mask()
    # cache-dit handles step count validation and scaling internally
    # For tuple input the primary mask is sized by the high-noise step count.
    steps_computation_mask = get_scm_mask(
        preset=scm_preset,
        num_inference_steps=(
            num_inference_steps
            if isinstance(num_inference_steps, int)
            else num_high_noise_steps
        ),
        compute_bins=scm_compute_bins,
        cache_bins=scm_cache_bins,
    )

    # Second mask, sized by the low-noise step count, for the secondary transformer.
    if isinstance(num_inference_steps, tuple):
        steps_computation_mask_2 = get_scm_mask(
            preset=scm_preset,
            num_inference_steps=num_low_noise_steps,
            compute_bins=scm_compute_bins,
            cache_bins=scm_cache_bins,
        )

    # build config for primary transformer (high-noise expert)
    primary_config = CacheDitConfig(
        enabled=True,
        Fn_compute_blocks=envs.SGLANG_CACHE_DIT_FN,
        Bn_compute_blocks=envs.SGLANG_CACHE_DIT_BN,
        max_warmup_steps=envs.SGLANG_CACHE_DIT_WARMUP,
        residual_diff_threshold=envs.SGLANG_CACHE_DIT_RDT,
        max_continuous_cached_steps=envs.SGLANG_CACHE_DIT_MC,
        enable_taylorseer=envs.SGLANG_CACHE_DIT_TAYLORSEER,
        taylorseer_order=envs.SGLANG_CACHE_DIT_TS_ORDER,
        num_inference_steps=(
            num_inference_steps
            if isinstance(num_inference_steps, int)
            else num_high_noise_steps
        ),
        # SCM fields
        steps_computation_mask=steps_computation_mask,
        steps_computation_policy=scm_policy,
    )

    if self.transformer_2 is not None:
        # dual transformer
        # build config for secondary transformer (low-noise expert)
        # uses secondary parameters which inherit from primary if not explicitly set
        secondary_config = CacheDitConfig(
            enabled=True,
            Fn_compute_blocks=envs.SGLANG_CACHE_DIT_SECONDARY_FN,
            Bn_compute_blocks=envs.SGLANG_CACHE_DIT_SECONDARY_BN,
            max_warmup_steps=envs.SGLANG_CACHE_DIT_SECONDARY_WARMUP,
            residual_diff_threshold=envs.SGLANG_CACHE_DIT_SECONDARY_RDT,
            max_continuous_cached_steps=envs.SGLANG_CACHE_DIT_SECONDARY_MC,
            enable_taylorseer=envs.SGLANG_CACHE_DIT_SECONDARY_TAYLORSEER,
            taylorseer_order=envs.SGLANG_CACHE_DIT_SECONDARY_TS_ORDER,
            num_inference_steps=num_low_noise_steps,
            # SCM fields - shared with primary
            steps_computation_mask=steps_computation_mask_2,
            steps_computation_policy=scm_policy,
        )

        # for dual transformers, must use BlockAdapter to enable cache on both simultaneously.
        # Don't call enable_cache separately on each transformer.
        self.transformer, self.transformer_2 = enable_cache_on_dual_transformer(
            self.transformer,
            self.transformer_2,
            primary_config,
            secondary_config,
            model_name="wan2.2",
            sp_group=sp_group,
            tp_group=tp_group,
        )
        logger.info(
            "cache-dit enabled on dual transformers (steps=%d, %d)",
            num_high_noise_steps,
            num_low_noise_steps,
        )
    else:
        # single transformer
        self.transformer = enable_cache_on_transformer(
            self.transformer,
            primary_config,
            model_name="transformer",
            sp_group=sp_group,
            tp_group=tp_group,
        )
        logger.info(
            "cache-dit enabled on transformer (steps=%d, Fn=%d, Bn=%d, rdt=%.3f)",
            num_inference_steps,
            envs.SGLANG_CACHE_DIT_FN,
            envs.SGLANG_CACHE_DIT_BN,
            envs.SGLANG_CACHE_DIT_RDT,
        )

    # From now on calls take the refresh-only path at the top of this method.
    self._cache_dit_enabled = True
    self._cached_num_steps = num_inference_steps
def _preprocess_latents_for_ti2v(
    self, latents, target_dtype, batch, server_args: ServerArgs
):
    """
    Blend the encoded condition image into the latents for Wan2.2 TI2V.

    The condition image is encoded with the VAE, normalized with the VAE's
    shift/scaling factors, and combined with ``latents`` through the mask
    returned by ``masks_like``. The blended tensor is written to
    ``batch.latents`` on the local device.

    Args:
        latents: Latent tensor; must be 5-dimensional after the dtype cast.
        target_dtype: Dtype the latents are cast to before blending.
        batch: Request providing ``condition_image``, ``num_frames``,
            ``height`` and ``width``; its ``latents`` attribute is overwritten.
        server_args: Supplies ``vae_cpu_offload`` and the VAE/DiT arch configs.

    Returns:
        ``(seq_len, z, reserved_frames_masks)``: the token count rounded up to
        a multiple of the SP world size, the normalized image latent, and the
        mask list produced by ``masks_like``.

    Raises:
        AssertionError: If ``batch.image_latent`` is set, no VAE is available,
            or a tensor is not 5-dimensional.
    """
    # FIXME: should probably move to latent preparation stage, to handle with offload
    # Wan2.2 TI2V directly replaces the first frame of the latent with
    # the image latent instead of appending along the channel dim
    assert batch.image_latent is None, "TI2V task should not have image latents"
    assert self.vae is not None, "VAE is not provided for TI2V task"

    condition_image = batch.condition_image
    self.vae = self.vae.to(condition_image.device)
    z = self.vae.encode(condition_image).mean.float()
    if self.vae.device != "cpu" and server_args.vae_cpu_offload:
        self.vae = self.vae.to("cpu")

    # Normalize the image latent: subtract the shift (if any), then scale.
    shift = getattr(self.vae, "shift_factor", None)
    if shift is not None:
        z -= shift.to(z.device, z.dtype) if isinstance(shift, torch.Tensor) else shift

    scale = self.vae.scaling_factor
    z = z * (scale.to(z.device, z.dtype) if isinstance(scale, torch.Tensor) else scale)

    # Original author's note: z is [B, C, 1, H, W], latents stay [B, C, T, H, W].
    latent_model_input = latents.to(target_dtype)
    assert latent_model_input.ndim == 5

    # masks_like gets the latent without its leading dim; the mask regains a
    # leading dim of 1 so it broadcasts against the 5-D tensors.
    _, reserved_frames_masks = masks_like([latent_model_input.squeeze(0)], zero=True)
    keep_mask = reserved_frames_masks[0].unsqueeze(0)

    # Mask value 1 keeps the incoming latent, 0 takes the image latent.
    blended = (1.0 - keep_mask) * z + keep_mask * latent_model_input
    assert blended.ndim == 5
    batch.latents = blended.to(get_local_torch_device())

    vae_arch = server_args.pipeline_config.vae_config.arch_config
    temporal_scale = vae_arch.scale_factor_temporal
    spatial_scale = vae_arch.scale_factor_spatial
    patch_size = server_args.pipeline_config.dit_config.arch_config.patch_size

    latent_frames = (batch.num_frames - 1) // temporal_scale + 1
    token_count = (
        latent_frames
        * (batch.height // spatial_scale)
        * (batch.width // spatial_scale)
        // (patch_size[1] * patch_size[2])
    )

    # Round up to a multiple of the sequence-parallel world size.
    sp_world_size = get_sp_world_size()
    seq_len = int(math.ceil(token_count / sp_world_size)) * sp_world_size
    return seq_len, z, reserved_frames_masks
reserved_frames_mask, + "c (n t) h w -> c n t h w", + n=sp_world_size, + ).contiguous() + reserved_frames_mask_sp_tensor = reserved_frames_mask_sp_tensor[ + :, rank_in_sp_group, :, :, : + ] + reserved_frames_mask_sp = ( + reserved_frames_mask_sp_tensor # Store as tensor, not list + ) + else: + reserved_frames_mask_sp = reserved_frames_mask + else: + reserved_frames_mask_sp = None + else: + # SP not enabled or latents not sharded + z_sp = z + reserved_frames_mask_sp = ( + reserved_frames_masks[0] if reserved_frames_masks is not None else None + ) # Extract tensor + + return reserved_frames_mask_sp, z_sp + + def _handle_boundary_ratio( + self, + server_args, + batch, + ): + """ + (Wan2.2) Calculate timestep to switch from high noise expert to low noise expert + """ + boundary_ratio = server_args.pipeline_config.dit_config.boundary_ratio + if batch.boundary_ratio is not None: + logger.info( + "Overriding boundary ratio from %s to %s", + boundary_ratio, + batch.boundary_ratio, + ) + boundary_ratio = batch.boundary_ratio + + if boundary_ratio is not None: + boundary_timestep = boundary_ratio * self.scheduler.num_train_timesteps + else: + boundary_timestep = None + + return boundary_timestep + + def _prepare_denoising_loop(self, batch: Req, server_args: ServerArgs): + """ + Prepare all necessary invariant variables for the denoising loop. + + Returns: + A dictionary containing all the prepared variables for the denoising loop. 
+ """ + assert self.transformer is not None + pipeline = self.pipeline() if self.pipeline else None + + boundary_timestep = self._handle_boundary_ratio(server_args, batch) + # Get timesteps and calculate warmup steps + timesteps = batch.timesteps + num_inference_steps = batch.num_inference_steps + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + if self.transformer_2 is not None: + assert boundary_timestep is not None, "boundary_timestep must be provided" + num_high_noise_steps = (timesteps >= boundary_timestep).sum().item() + num_low_noise_steps = num_inference_steps - num_high_noise_steps + cache_dit_num_inference_steps = (num_high_noise_steps, num_low_noise_steps) + else: + cache_dit_num_inference_steps = num_inference_steps + + if not server_args.model_loaded["transformer"]: + # FIXME: reuse more code + loader = TransformerLoader() + self.transformer = loader.load( + server_args.model_paths["transformer"], server_args, "transformer" + ) + # enable cache-dit before torch.compile (delayed mounting) + self._maybe_enable_cache_dit(cache_dit_num_inference_steps, batch) + self._maybe_enable_torch_compile(self.transformer) + if pipeline: + pipeline.add_module("transformer", self.transformer) + server_args.model_loaded["transformer"] = True + else: + self._maybe_enable_cache_dit(cache_dit_num_inference_steps, batch) + + # Prepare extra step kwargs for scheduler + extra_step_kwargs = self.prepare_extra_func_kwargs( + self.scheduler.step, + {"generator": batch.generator, "eta": batch.eta}, + ) + + # Setup precision and autocast settings + target_dtype = torch.bfloat16 + autocast_enabled = ( + target_dtype != torch.float32 + ) and not server_args.disable_autocast + + # Prepare image latents and embeddings for I2V generation + image_embeds = batch.image_embeds + if len(image_embeds) > 0: + image_embeds = [ + image_embed.to(target_dtype) for image_embed in image_embeds + ] + + # Prepare STA parameters + if self.attn_backend.get_enum() == 
AttentionBackendEnum.SLIDING_TILE_ATTN: + self.prepare_sta_param(batch, server_args) + + # Get latents and embeddings + latents = batch.latents + prompt_embeds = batch.prompt_embeds + # Removed Tensor truthiness assert to avoid GPU sync + neg_prompt_embeds = None + if batch.do_classifier_free_guidance: + neg_prompt_embeds = batch.negative_prompt_embeds + assert neg_prompt_embeds is not None + # Removed Tensor truthiness assert to avoid GPU sync + + # specifically for Wan2_2_TI2V_5B_Config, not applicable for FastWan2_2_TI2V_5B_Config + should_preprocess_for_wan_ti2v = ( + server_args.pipeline_config.task_type == ModelTaskType.TI2V + and batch.condition_image is not None + and type(server_args.pipeline_config) is Wan2_2_TI2V_5B_Config + ) + + # TI2V specific preparations - before SP sharding + if should_preprocess_for_wan_ti2v: + seq_len, z, reserved_frames_masks = self._preprocess_latents_for_ti2v( + latents, target_dtype, batch, server_args + ) + else: + seq_len, z, reserved_frames_masks = ( + None, + None, + None, + ) + + # Handle sequence parallelism after TI2V processing + self._preprocess_sp_latents(batch, server_args) + latents = batch.latents + + # Shard z and reserved_frames_mask for TI2V if SP is enabled + if should_preprocess_for_wan_ti2v: + reserved_frames_mask_sp, z_sp = self._postprocess_latents_for_ti2v( + z, reserved_frames_masks, batch + ) + else: + reserved_frames_mask_sp, z_sp = ( + reserved_frames_masks[0] if reserved_frames_masks is not None else None + ), z + + guidance = self.get_or_build_guidance( + # TODO: replace with raw_latent_shape? 
+ latents.shape[0], + latents.dtype, + latents.device, + ) + + image_kwargs = self.prepare_extra_func_kwargs( + getattr(self.transformer, "forward", self.transformer), + { + # TODO: make sure on-device + "encoder_hidden_states_image": image_embeds, + "mask_strategy": dict_to_3d_list(None, t_max=50, l_max=60, h_max=24), + }, + ) + + pos_cond_kwargs = self.prepare_extra_func_kwargs( + getattr(self.transformer, "forward", self.transformer), + { + "encoder_hidden_states_2": batch.clip_embedding_pos, + "encoder_attention_mask": batch.prompt_attention_mask, + } + | server_args.pipeline_config.prepare_pos_cond_kwargs( + batch, + self.device, + getattr(self.transformer, "rotary_emb", None), + dtype=target_dtype, + ) + | dict( + encoder_hidden_states=server_args.pipeline_config.get_pos_prompt_embeds( + batch + ) + ), + ) + + if batch.do_classifier_free_guidance: + neg_cond_kwargs = self.prepare_extra_func_kwargs( + getattr(self.transformer, "forward", self.transformer), + { + "encoder_hidden_states_2": batch.clip_embedding_neg, + "encoder_attention_mask": batch.negative_attention_mask, + } + | server_args.pipeline_config.prepare_neg_cond_kwargs( + batch, + self.device, + getattr(self.transformer, "rotary_emb", None), + dtype=target_dtype, + ) + | dict( + encoder_hidden_states=server_args.pipeline_config.get_neg_prompt_embeds( + batch + ) + ), + ) + else: + neg_cond_kwargs = {} + + return { + "extra_step_kwargs": extra_step_kwargs, + "target_dtype": target_dtype, + "autocast_enabled": autocast_enabled, + "timesteps": timesteps, + "num_inference_steps": num_inference_steps, + "num_warmup_steps": num_warmup_steps, + "image_kwargs": image_kwargs, + "pos_cond_kwargs": pos_cond_kwargs, + "neg_cond_kwargs": neg_cond_kwargs, + "latents": latents, + "prompt_embeds": prompt_embeds, + "neg_prompt_embeds": neg_prompt_embeds, + "boundary_timestep": boundary_timestep, + "z": z_sp, # Use SP-sharded version + # ndim == 5 + "reserved_frames_mask": reserved_frames_mask_sp, # Use SP-sharded 
version + "seq_len": seq_len, + "guidance": guidance, + } + + def _post_denoising_loop( + self, + batch: Req, + latents: torch.Tensor, + trajectory_latents: list, + trajectory_timesteps: list, + server_args: ServerArgs, + is_warmup: bool = False, + ): + # Gather results if using sequence parallelism + if trajectory_latents: + trajectory_tensor = torch.stack(trajectory_latents, dim=1) + trajectory_timesteps_tensor = torch.stack(trajectory_timesteps, dim=0) + else: + trajectory_tensor = None + trajectory_timesteps_tensor = None + + # Gather results if using sequence parallelism + latents, trajectory_tensor = self._postprocess_sp_latents( + batch, latents, trajectory_tensor + ) + + # Gather noise_pred if using sequence parallelism + # noise_pred has the same shape as latents (sharded along sequence dimension) + if ( + get_sp_world_size() > 1 + and getattr(batch, "did_sp_shard_latents", False) + and server_args.comfyui_mode + and hasattr(batch, "noise_pred") + and batch.noise_pred is not None + ): + batch.noise_pred = server_args.pipeline_config.gather_latents_for_sp( + batch.noise_pred + ) + if hasattr(batch, "raw_latent_shape"): + orig_s = batch.raw_latent_shape[1] + if batch.noise_pred.shape[1] > orig_s: + batch.noise_pred = batch.noise_pred[:, :orig_s, :] + + if trajectory_tensor is not None and trajectory_timesteps_tensor is not None: + batch.trajectory_timesteps = trajectory_timesteps_tensor.cpu() + batch.trajectory_latents = trajectory_tensor.cpu() + + # Update batch with final latents + batch.latents = self.server_args.pipeline_config.post_denoising_loop( + latents, batch + ) + + offload_mgr = getattr(self.transformer, "_layerwise_offload_manager", None) + if offload_mgr is not None and getattr(offload_mgr, "enabled", False): + offload_mgr.release_all() + + if self.transformer_2 is not None: + offload_mgr_2 = getattr( + self.transformer_2, "_layerwise_offload_manager", None + ) + if offload_mgr_2 is not None and getattr(offload_mgr_2, "enabled", False): + 
offload_mgr_2.release_all() + + # Save STA mask search results if needed + if ( + not is_warmup + and self.attn_backend.get_enum() == AttentionBackendEnum.SLIDING_TILE_ATTN + and server_args.attention_backend_config.STA_mode == "STA_SEARCHING" + ): + self.save_sta_search_results(batch) + + # deallocate transformer if on mps + pipeline = self.pipeline() if self.pipeline else None + if torch.backends.mps.is_available() and not is_warmup: + logger.info( + "Memory before deallocating transformer: %s", + torch.mps.current_allocated_memory(), + ) + del self.transformer + if pipeline is not None and "transformer" in pipeline.modules: + del pipeline.modules["transformer"] + server_args.model_loaded["transformer"] = False + logger.info( + "Memory after deallocating transformer: %s", + torch.mps.current_allocated_memory(), + ) + + # reset offload managers with prefetching first layer for next forward + for dit in filter(None, [self.transformer, self.transformer_2]): + if isinstance(dit, OffloadableDiTMixin): + # release all DiT weights to avoid peak VRAM usage, which may increasing the latency for next req + # TODO: should be make this an option? + for manager in dit.layerwise_offload_managers: + manager.release_all() + + def _preprocess_sp_latents(self, batch: Req, server_args: ServerArgs): + """Shard latents for Sequence Parallelism if applicable.""" + if get_sp_world_size() <= 1: + return + + if batch.latents is not None: + ( + batch.latents, + did_shard, + ) = server_args.pipeline_config.shard_latents_for_sp(batch, batch.latents) + batch.did_sp_shard_latents = did_shard + else: + batch.did_sp_shard_latents = False + + # image_latent must be sharded consistently with latents when it is + # concatenated along the sequence dimension in the denoising loop. 
+ if batch.image_latent is not None: + batch.image_latent, _ = server_args.pipeline_config.shard_latents_for_sp( + batch, batch.image_latent + ) + + def _postprocess_sp_latents( + self, + batch: Req, + latents: torch.Tensor, + trajectory_tensor: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Gather latents after Sequence Parallelism if they were sharded.""" + if get_sp_world_size() > 1 and getattr(batch, "did_sp_shard_latents", False): + latents = self.server_args.pipeline_config.gather_latents_for_sp(latents) + if trajectory_tensor is not None: + # trajectory_tensor shapes: + # - video: [b, num_steps, c, t_local, h, w] -> gather on dim=3 + # - image: [b, num_steps, s_local, d] -> gather on dim=2 + trajectory_tensor = trajectory_tensor.to(get_local_torch_device()) + gather_dim = 3 if trajectory_tensor.dim() >= 5 else 2 + trajectory_tensor = sequence_model_parallel_all_gather( + trajectory_tensor, dim=gather_dim + ) + if gather_dim == 2 and hasattr(batch, "raw_latent_shape"): + orig_s = batch.raw_latent_shape[1] + if trajectory_tensor.shape[2] > orig_s: + trajectory_tensor = trajectory_tensor[:, :, :orig_s, :] + return latents, trajectory_tensor + + def step_profile(self): + profiler = SGLDiffusionProfiler.get_instance() + if profiler: + profiler.step_denoising_step() + + def _manage_device_placement( + self, + model_to_use: nn.Module, + model_to_offload: nn.Module | None, + server_args: ServerArgs, + ): + """ + Manages the offload / load behavior of dit + """ + if not server_args.dit_cpu_offload: + return + + # FSDP manages offloading internally + if server_args.use_fsdp_inference: + return + + # Offload the unused model if it's on CUDA + if ( + model_to_offload is not None + and next(model_to_offload.parameters()).device.type == "cuda" + ): + model_to_offload.to("cpu") + + # Load the model to use if it's on CPU + if ( + model_to_use is not None + and next(model_to_use.parameters()).device.type == "cpu" + ): + 
model_to_use.to(get_local_torch_device()) + + def _select_and_manage_model( + self, + t_int: int, + boundary_timestep: float | None, + server_args: ServerArgs, + batch: Req, + ): + if boundary_timestep is None or t_int >= boundary_timestep: + # High-noise stage + current_model = self.transformer + model_to_offload = self.transformer_2 + current_guidance_scale = batch.guidance_scale + else: + # Low-noise stage + current_model = self.transformer_2 + model_to_offload = self.transformer + current_guidance_scale = batch.guidance_scale_2 + + self._manage_device_placement(current_model, model_to_offload, server_args) + + assert current_model is not None, "The model for the current step is not set." + return current_model, current_guidance_scale + + def expand_timestep_before_forward( + self, + batch: Req, + server_args: ServerArgs, + t_device, + target_dtype, + seq_len: int | None, + reserved_frames_mask, + ): + bsz = batch.raw_latent_shape[0] + should_preprocess_for_wan_ti2v = ( + server_args.pipeline_config.task_type == ModelTaskType.TI2V + and batch.condition_image is not None + and type(server_args.pipeline_config) is Wan2_2_TI2V_5B_Config + ) + + # expand timestep + if should_preprocess_for_wan_ti2v: + # Explicitly cast t_device to the target float type at the beginning. + # This ensures any precision-based rounding (e.g., float32(999.0) -> bfloat16(1000.0)) + # is applied consistently *before* it's used by any rank. + t_device_rounded = t_device.to(target_dtype) + + local_seq_len = seq_len + if get_sp_world_size() > 1 and getattr( + batch, "did_sp_shard_latents", False + ): + local_seq_len = seq_len // get_sp_world_size() + + if get_sp_parallel_rank() == 0 and reserved_frames_mask is not None: + # Rank 0 has the first frame, create a special timestep tensor + # NOTE: The spatial downsampling in the next line is suspicious but kept + # to match original model's potential training configuration. 
+ temp_ts = ( + reserved_frames_mask[0][:, ::2, ::2] * t_device_rounded + ).flatten() + + # Pad to full local sequence length + temp_ts = torch.cat( + [ + temp_ts, + temp_ts.new_ones(local_seq_len - temp_ts.size(0)) + * t_device_rounded, + ] + ) + timestep = temp_ts.unsqueeze(0).repeat(bsz, 1) + else: + # Other ranks get a uniform timestep tensor of the correct shape [B, local_seq_len] + timestep = t_device.repeat(bsz, local_seq_len) + else: + timestep = t_device.repeat(bsz) + return timestep + + def post_forward_for_ti2v_task( + self, batch: Req, server_args: ServerArgs, reserved_frames_mask, latents, z + ): + """ + For Wan2.2 ti2v task, global first frame should be replaced with encoded image after each timestep + """ + should_preprocess_for_wan_ti2v = ( + server_args.pipeline_config.task_type == ModelTaskType.TI2V + and batch.condition_image is not None + and type(server_args.pipeline_config) is Wan2_2_TI2V_5B_Config + ) + if should_preprocess_for_wan_ti2v: + # Apply TI2V mask blending with SP-aware z and reserved_frames_mask. + # This ensures the first frame is always the condition image after each step. + # This is only applied on rank 0, where z is not None. + if z is not None and reserved_frames_mask is not None: + # z: [1, C, 1, H, W] + # latents: [1, C, T_local, H, W] + # reserved_frames_mask: [C, T_local, H, W] + # Unsqueeze mask to [1, C, T_local, H, W] for broadcasting. + # z will broadcast along the time dimension. + latents = ( + 1.0 - reserved_frames_mask.unsqueeze(0) + ) * z + reserved_frames_mask.unsqueeze(0) * latents + + return latents + + @torch.no_grad() + def forward( + self, + batch: Req, + server_args: ServerArgs, + ) -> Req: + """ + Run the denoising loop. 
+ """ + # Prepare variables for the denoising loop + + prepared_vars = self._prepare_denoising_loop(batch, server_args) + extra_step_kwargs = prepared_vars["extra_step_kwargs"] + target_dtype = prepared_vars["target_dtype"] + autocast_enabled = prepared_vars["autocast_enabled"] + timesteps = prepared_vars["timesteps"] + num_inference_steps = prepared_vars["num_inference_steps"] + num_warmup_steps = prepared_vars["num_warmup_steps"] + image_kwargs = prepared_vars["image_kwargs"] + pos_cond_kwargs = prepared_vars["pos_cond_kwargs"] + neg_cond_kwargs = prepared_vars["neg_cond_kwargs"] + latents = prepared_vars["latents"] + boundary_timestep = prepared_vars["boundary_timestep"] + z = prepared_vars["z"] + reserved_frames_mask = prepared_vars["reserved_frames_mask"] + seq_len = prepared_vars["seq_len"] + guidance = prepared_vars["guidance"] + + # Initialize lists for ODE trajectory + trajectory_timesteps: list[torch.Tensor] = [] + trajectory_latents: list[torch.Tensor] = [] + + # Run denoising loop + denoising_start_time = time.time() + + # to avoid device-sync caused by timestep comparison + is_warmup = batch.is_warmup + self.scheduler.set_begin_index(0) + timesteps_cpu = timesteps.cpu() + num_timesteps = timesteps_cpu.shape[0] + with torch.autocast( + device_type=current_platform.device_type, + dtype=target_dtype, + enabled=autocast_enabled, + ): + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t_host in enumerate(timesteps_cpu): + with StageProfiler( + f"denoising_step_{i}", + logger=logger, + metrics=batch.metrics, + perf_dump_path_provided=batch.perf_dump_path is not None, + ): + t_int = int(t_host.item()) + t_device = timesteps[i] + current_model, current_guidance_scale = ( + self._select_and_manage_model( + t_int=t_int, + boundary_timestep=boundary_timestep, + server_args=server_args, + batch=batch, + ) + ) + + # Expand latents for I2V + latent_model_input = latents.to(target_dtype) + if batch.image_latent is not None: + assert ( + not 
server_args.pipeline_config.task_type + == ModelTaskType.TI2V + ), "image latents should not be provided for TI2V task" + latent_model_input = torch.cat( + [latent_model_input, batch.image_latent], dim=1 + ).to(target_dtype) + + timestep = self.expand_timestep_before_forward( + batch, + server_args, + t_device, + target_dtype, + seq_len, + reserved_frames_mask, + ) + + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t_device + ) + + # Predict noise residual + attn_metadata = self._build_attn_metadata( + i, + batch, + server_args, + timestep_value=t_int, + timesteps=timesteps_cpu, + ) + noise_pred = self._predict_noise_with_cfg( + current_model=current_model, + latent_model_input=latent_model_input, + timestep=timestep, + batch=batch, + timestep_index=i, + attn_metadata=attn_metadata, + target_dtype=target_dtype, + current_guidance_scale=current_guidance_scale, + image_kwargs=image_kwargs, + pos_cond_kwargs=pos_cond_kwargs, + neg_cond_kwargs=neg_cond_kwargs, + server_args=server_args, + guidance=guidance, + latents=latents, + ) + + # Save noise_pred to batch for external access (e.g., ComfyUI) + if server_args.comfyui_mode: + batch.noise_pred = noise_pred + + # Compute the previous noisy sample + latents = self.scheduler.step( + model_output=noise_pred, + timestep=t_device, + sample=latents, + **extra_step_kwargs, + return_dict=False, + )[0] + + latents = self.post_forward_for_ti2v_task( + batch, server_args, reserved_frames_mask, latents, z + ) + + # save trajectory latents if needed + if batch.return_trajectory_latents: + trajectory_timesteps.append(t_host) + trajectory_latents.append(latents) + + # Update progress bar + if i == num_timesteps - 1 or ( + (i + 1) > num_warmup_steps + and (i + 1) % self.scheduler.order == 0 + and progress_bar is not None + ): + progress_bar.update() + + if not is_warmup: + self.step_profile() + + denoising_end_time = time.time() + + if num_timesteps > 0 and not is_warmup: + self.log_info( + "average time 
per step: %.4f seconds", + (denoising_end_time - denoising_start_time) / len(timesteps), + ) + + self._post_denoising_loop( + batch=batch, + latents=latents, + trajectory_latents=trajectory_latents, + trajectory_timesteps=trajectory_timesteps, + server_args=server_args, + is_warmup=is_warmup, + ) + return batch + + # TODO: this will extends the preparation stage, should let subclass/passed-in variables decide which to prepare + def prepare_extra_func_kwargs(self, func, kwargs) -> dict[str, Any]: + """ + Prepare extra kwargs for the scheduler step / denoise step. + + Args: + func: The function to prepare kwargs for. + kwargs: The kwargs to prepare. + """ + import functools + + # Handle cache-dit's partial wrapping logic. + # Cache-dit wraps the forward method with functools.partial where args[0] is the instance. + # We access `_original_forward` if available to inspect the underlying signature. + # See: https://github.com/vipshop/cache-dit + if isinstance(func, functools.partial) and func.args: + func = getattr(func.args[0], "_original_forward", func) + + # Unwrap any decorators (e.g. functools.wraps) + target_func = inspect.unwrap(func) + + # Filter kwargs based on the signature + params = inspect.signature(target_func).parameters + return {k: v for k, v in kwargs.items() if k in params} + + def progress_bar( + self, iterable: Iterable | None = None, total: int | None = None + ) -> tqdm: + """ + Create a progress bar for the denoising process. + """ + local_rank = get_world_group().local_rank + disable = local_rank != 0 + return tqdm(iterable=iterable, total=total, disable=disable) + + def rescale_noise_cfg( + self, noise_cfg, noise_pred_text, guidance_rescale=0.0 + ) -> torch.Tensor: + """ + Rescale noise prediction according to guidance_rescale. + + Based on findings of "Common Diffusion Noise Schedules and Sample Steps are Flawed" + (https://arxiv.org/pdf/2305.08891.pdf), Section 3.4. + + Args: + noise_cfg: The noise prediction with guidance. 
+ noise_pred_text: The text-conditioned noise prediction. + guidance_rescale: The guidance rescale factor. + + Returns: + The rescaled noise prediction. + """ + std_text = noise_pred_text.std( + dim=list(range(1, noise_pred_text.ndim)), keepdim=True + ) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # Rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # Mix with the original results from guidance by factor guidance_rescale + noise_cfg = ( + guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + ) + return noise_cfg + + def _build_attn_metadata( + self, + i: int, + batch: Req, + server_args: ServerArgs, + *, + timestep_value: int | None = None, + timesteps: torch.Tensor | None = None, + ) -> Any | None: + """ + Build attention metadata for custom attention backends. + + Args: + i: The current timestep index. + """ + attn_metadata = None + self.attn_metadata_builder = None + try: + self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls() + except NotImplementedError: + self.attn_metadata_builder_cls = None + if self.attn_metadata_builder_cls: + self.attn_metadata_builder = self.attn_metadata_builder_cls() + if ( + self.attn_backend.get_enum() == AttentionBackendEnum.SLIDING_TILE_ATTN + or self.attn_backend.get_enum() == AttentionBackendEnum.VIDEO_SPARSE_ATTN + ): + attn_metadata = self.attn_metadata_builder.build( + current_timestep=i, + raw_latent_shape=batch.raw_latent_shape[2:5], + patch_size=server_args.pipeline_config.dit_config.patch_size, + STA_param=batch.STA_param, + VSA_sparsity=server_args.attention_backend_config.VSA_sparsity, + device=get_local_torch_device(), + ) + elif ( + self.attn_backend.get_enum() == AttentionBackendEnum.SPARSE_VIDEO_GEN_2_ATTN + ): + if timestep_value is None or timesteps is None: + raise ValueError( + "timestep_value and timesteps must be provided for SVG2 attention metadata" + ) + + svg2_cfg = 
server_args.attention_backend_config or {} + num_layers = server_args.pipeline_config.dit_config.num_layers + if ( + server_args.pipeline_config.dit_config.prefix.lower() == "hunyuan" + and hasattr(server_args.pipeline_config.dit_config, "num_single_layers") + ): + num_layers += server_args.pipeline_config.dit_config.num_single_layers + first_layers_fp = svg2_cfg.get("svg2_first_layers_fp", 0.03) + if first_layers_fp <= 1.0: + first_layers_fp = math.floor(first_layers_fp * num_layers) + first_layers_fp = max(0, min(int(first_layers_fp), num_layers)) + + first_times_fp = svg2_cfg.get("svg2_first_times_fp", 0.2) + if first_times_fp <= 1.0: + num_fp_steps = math.floor(first_times_fp * len(timesteps)) + if num_fp_steps > 0: + first_times_fp = float(timesteps[num_fp_steps - 1].item() - 1) + else: + first_times_fp = float(timesteps.max().item() + 1) + + current_timestep = int(timestep_value) + + cache = batch.extra.get("svg2_cache") + if cache is None: + from sglang.multimodal_gen.runtime.layers.attention.backends.sparse_video_gen_2_attn import ( + Svg2Cache, + ) + + cache = Svg2Cache() + batch.extra["svg2_cache"] = cache + + patch_size = server_args.pipeline_config.dit_config.patch_size + if isinstance(patch_size, list): + patch_size = tuple(patch_size) + if isinstance(patch_size, int): + patch_size_t = getattr( + server_args.pipeline_config.dit_config, "patch_size_t", None + ) + if patch_size_t is not None: + patch_size = (patch_size_t, patch_size, patch_size) + + context_length = 0 + prompt_length = None + if server_args.pipeline_config.dit_config.prefix.lower() == "hunyuan": + prompt_embeds = server_args.pipeline_config.get_pos_prompt_embeds(batch) + if isinstance(prompt_embeds, list): + text_embeds = prompt_embeds[0] if prompt_embeds else None + else: + text_embeds = prompt_embeds + if isinstance(text_embeds, torch.Tensor) and text_embeds.ndim >= 2: + context_length = int(text_embeds.shape[1]) + if context_length > 0 and batch.prompt_attention_mask: + mask = 
batch.prompt_attention_mask[0] + if isinstance(mask, torch.Tensor): + if mask.shape[-1] > context_length: + mask = mask[:, -context_length:] + prompt_length = int(mask[0].sum().item()) + if prompt_length is None: + prompt_length = context_length + + attn_metadata = self.attn_metadata_builder.build( + current_timestep=current_timestep, + raw_latent_shape=batch.raw_latent_shape, + patch_size=patch_size, + num_q_centroids=svg2_cfg.get("svg2_num_q_centroids", 300), + num_k_centroids=svg2_cfg.get("svg2_num_k_centroids", 1000), + top_p_kmeans=svg2_cfg.get("svg2_top_p_kmeans", 0.9), + min_kc_ratio=svg2_cfg.get("svg2_min_kc_ratio", 0.1), + kmeans_iter_init=svg2_cfg.get("svg2_kmeans_iter_init", 50), + kmeans_iter_step=svg2_cfg.get("svg2_kmeans_iter_step", 2), + zero_step_kmeans_init=svg2_cfg.get("svg2_zero_step_kmeans_init", False), + first_layers_fp=first_layers_fp, + first_times_fp=first_times_fp, + context_length=context_length, + prompt_length=prompt_length, + cache=cache, + calculate_density=False, # only need density when doing head load balancing + ) + elif self.attn_backend.get_enum() == AttentionBackendEnum.VMOBA_ATTN: + moba_params = server_args.attention_backend_config.moba_config.copy() + moba_params.update( + { + "current_timestep": i, + "raw_latent_shape": batch.raw_latent_shape[2:5], + "patch_size": server_args.pipeline_config.dit_config.patch_size, + "device": get_local_torch_device(), + } + ) + elif self.attn_backend.get_enum() == AttentionBackendEnum.FA: + attn_metadata = self.attn_metadata_builder.build( + raw_latent_shape=batch.raw_latent_shape + ) + else: + # attn_metadata can be None for SDPA attention backend + return None + + return attn_metadata + + def _predict_noise( + self, + current_model, + latent_model_input, + timestep, + target_dtype, + guidance: torch.Tensor, + **kwargs, + ): + return current_model( + hidden_states=latent_model_input, + timestep=timestep, + guidance=guidance, + **kwargs, + ) + + def _predict_noise_with_cfg( + self, + 
current_model: nn.Module, + latent_model_input: torch.Tensor, + timestep, + batch: Req, + timestep_index: int, + attn_metadata, + target_dtype, + current_guidance_scale, + image_kwargs: dict[str, Any], + pos_cond_kwargs: dict[str, Any], + neg_cond_kwargs: dict[str, Any], + server_args, + guidance, + latents, + ): + """ + Predict the noise residual with classifier-free guidance. + + Args: + current_model: The transformer model to use for the current step. + latent_model_input: The input latents for the model. + timestep: The expanded timestep tensor. + batch: The current batch information. + timestep_index: The current timestep index. + attn_metadata: Attention metadata for custom backends. + target_dtype: The target data type for autocasting. + current_guidance_scale: The guidance scale for the current step. + image_kwargs: Keyword arguments for image conditioning. + pos_cond_kwargs: Keyword arguments for positive prompt conditioning. + neg_cond_kwargs: Keyword arguments for negative prompt conditioning. + + Returns: + The predicted noise. + """ + noise_pred_cond: torch.Tensor | None = None + noise_pred_uncond: torch.Tensor | None = None + cfg_rank = get_classifier_free_guidance_rank() + # positive pass + if not (server_args.enable_cfg_parallel and cfg_rank != 0): + batch.is_cfg_negative = False + with set_forward_context( + current_timestep=timestep_index, + attn_metadata=attn_metadata, + forward_batch=batch, + ): + noise_pred_cond = self._predict_noise( + current_model=current_model, + latent_model_input=latent_model_input, + timestep=timestep, + target_dtype=target_dtype, + guidance=guidance, + **image_kwargs, + **pos_cond_kwargs, + ) + # TODO: can it be moved to after _predict_noise_with_cfg? + noise_pred_cond = server_args.pipeline_config.slice_noise_pred( + noise_pred_cond, latents + ) + if not batch.do_classifier_free_guidance: + # If CFG is disabled, we are done. Return the conditional prediction. 
+ return noise_pred_cond + + # negative pass + if not server_args.enable_cfg_parallel or cfg_rank != 0: + batch.is_cfg_negative = True + with set_forward_context( + current_timestep=timestep_index, + attn_metadata=attn_metadata, + forward_batch=batch, + ): + noise_pred_uncond = self._predict_noise( + current_model=current_model, + latent_model_input=latent_model_input, + timestep=timestep, + target_dtype=target_dtype, + guidance=guidance, + **image_kwargs, + **neg_cond_kwargs, + ) + noise_pred_uncond = server_args.pipeline_config.slice_noise_pred( + noise_pred_uncond, latents + ) + + # Combine predictions + if server_args.enable_cfg_parallel: + # Each rank computes its partial contribution and we sum via all-reduce: + # final = s*cond + (1-s)*uncond + if cfg_rank == 0: + assert noise_pred_cond is not None + partial = current_guidance_scale * noise_pred_cond + else: + assert noise_pred_uncond is not None + partial = (1 - current_guidance_scale) * noise_pred_uncond + + noise_pred = cfg_model_parallel_all_reduce(partial) + + if batch.cfg_normalization and float(batch.cfg_normalization) > 0: + factor = float(batch.cfg_normalization) + pred_f = noise_pred.float() + new_norm = torch.linalg.vector_norm(pred_f) + if cfg_rank == 0: + cond_f = noise_pred_cond.float() + ori_norm = torch.linalg.vector_norm(cond_f) + else: + ori_norm = torch.empty_like(new_norm) + ori_norm = get_cfg_group().broadcast(ori_norm, src=0) + max_norm = ori_norm * factor + + if new_norm > max_norm: + noise_pred = noise_pred * (max_norm / new_norm) + + # Guidance rescale: broadcast std(cond) from rank 0, compute std(cfg) locally + if batch.guidance_rescale > 0.0: + std_cfg = noise_pred.std( + dim=list(range(1, noise_pred.ndim)), keepdim=True + ) + if cfg_rank == 0: + assert noise_pred_cond is not None + std_text = noise_pred_cond.std( + dim=list(range(1, noise_pred_cond.ndim)), keepdim=True + ) + else: + std_text = torch.empty_like(std_cfg) + # Broadcast std_text from local src=0 to all ranks in CFG 
def prepare_sta_param(self, batch: Req, server_args: ServerArgs):
    """
    Prepare Sliding Tile Attention (STA) parameters and settings.

    Resolves the configured STA mode, builds the matching configuration with
    ``configure_sta`` and stores it on ``batch.STA_param``. It also resets
    ``batch.mask_search_final_result_pos`` and
    ``batch.mask_search_final_result_neg`` to one empty list per timestep.

    Args:
        batch: Current request. ``timesteps``, ``num_frames``, ``height`` and
            ``width`` are read; the STA fields listed above are written.
        server_args: Server arguments. ``attention_backend_config.STA_mode``
            and ``attention_backend_config.skip_time_steps`` are read.

    Raises:
        ValueError: If ``batch.timesteps`` is None, if
            ``SGLANG_DIFFUSION_ATTENTION_CONFIG`` is unset in inference mode,
            or if the resolved mode has no STA configuration branch.
        NotImplementedError: If mask search/tuning is requested for an
            unsupported resolution.
        Exception: Whatever the ``STA_Mode`` lookup raised for an unknown mode
            name is logged and re-raised unchanged.
    """
    # TODO(kevin): STA mask search, currently only support Wan2.1 with 69x768x1280
    requested_mode = None
    try:
        requested_mode = server_args.attention_backend_config.STA_mode
        STA_mode = STA_Mode[requested_mode]
    except Exception:
        # Report the requested name: STA_mode is still unbound when the lookup
        # fails, so formatting it here would raise and hide the real error.
        logger.error(f"Passed STA_mode: {requested_mode} doesn't exist")
        raise
    skip_time_steps = server_args.attention_backend_config.skip_time_steps
    if batch.timesteps is None:
        raise ValueError("Timesteps must be provided")
    timesteps_num = batch.timesteps.shape[0]

    logger.info("STA_mode: %s", STA_mode)
    # Compare against the enum member, like every other mode check below.
    resolution = (batch.num_frames, batch.height, batch.width)
    if resolution != (69, 768, 1280) and STA_mode != STA_Mode.STA_INFERENCE:
        raise NotImplementedError(
            "STA mask search/tuning is not supported for this resolution"
        )

    if STA_mode in (
        STA_Mode.STA_SEARCHING,
        STA_Mode.STA_TUNING,
        STA_Mode.STA_TUNING_CFG,
    ):
        size = (batch.width, batch.height)
        if size == (1280, 768):
            # TODO: make it configurable
            default_sparse_masks = [
                "3, 1, 10",
                "1, 5, 7",
                "3, 3, 3",
                "1, 6, 5",
                "1, 3, 10",
                "3, 6, 1",
            ]
            sparse_mask_candidates_searching = list(default_sparse_masks)
            sparse_mask_candidates_tuning = list(default_sparse_masks)
            full_mask = ["3,6,10"]
        else:
            raise NotImplementedError(
                "STA mask search is not supported for this resolution"
            )

    layer_num = self.transformer.config.num_layers
    # specific for HunyuanVideo
    if hasattr(self.transformer.config, "num_single_layers"):
        layer_num += self.transformer.config.num_single_layers
    head_num = self.transformer.config.num_attention_heads

    if STA_mode == STA_Mode.STA_SEARCHING:
        STA_param = configure_sta(
            mode=STA_Mode.STA_SEARCHING,
            layer_num=layer_num,
            head_num=head_num,
            time_step_num=timesteps_num,
            mask_candidates=sparse_mask_candidates_searching + full_mask,
            # last is full mask; Can add more sparse masks while keep last one as full mask
        )
    elif STA_mode == STA_Mode.STA_TUNING:
        STA_param = configure_sta(
            mode=STA_Mode.STA_TUNING,
            layer_num=layer_num,
            head_num=head_num,
            time_step_num=timesteps_num,
            mask_search_files_path=f"output/mask_search_result_pos_{size[0]}x{size[1]}/",
            mask_candidates=sparse_mask_candidates_tuning,
            full_attention_mask=[int(x) for x in full_mask[0].split(",")],
            skip_time_steps=skip_time_steps,  # taken from attention_backend_config
            save_dir=f"output/mask_search_strategy_{size[0]}x{size[1]}/",  # Custom save directory
            timesteps=timesteps_num,
        )
    elif STA_mode == STA_Mode.STA_TUNING_CFG:
        STA_param = configure_sta(
            mode=STA_Mode.STA_TUNING_CFG,
            layer_num=layer_num,
            head_num=head_num,
            time_step_num=timesteps_num,
            mask_search_files_path_pos=f"output/mask_search_result_pos_{size[0]}x{size[1]}/",
            mask_search_files_path_neg=f"output/mask_search_result_neg_{size[0]}x{size[1]}/",
            mask_candidates=sparse_mask_candidates_tuning,
            full_attention_mask=[int(x) for x in full_mask[0].split(",")],
            skip_time_steps=skip_time_steps,
            save_dir=f"output/mask_search_strategy_{size[0]}x{size[1]}/",
            timesteps=timesteps_num,
        )
    elif STA_mode == STA_Mode.STA_INFERENCE:
        import sglang.multimodal_gen.envs as envs

        config_file = envs.SGLANG_DIFFUSION_ATTENTION_CONFIG
        if config_file is None:
            raise ValueError("SGLANG_DIFFUSION_ATTENTION_CONFIG is not set")
        STA_param = configure_sta(
            mode=STA_Mode.STA_INFERENCE,
            layer_num=layer_num,
            head_num=head_num,
            time_step_num=timesteps_num,
            load_path=config_file,
        )
    else:
        # Without this branch STA_param would be unbound on the assignment below.
        raise ValueError(
            f"STA_mode {STA_mode} has no STA parameter configuration"
        )

    batch.STA_param = STA_param
    batch.mask_search_final_result_pos = [[] for _ in range(timesteps_num)]
    batch.mask_search_final_result_neg = [[] for _ in range(timesteps_num)]
def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
    """Check the fields of ``batch`` that the denoising stage reads.

    Each entry below is registered on a fresh ``VerificationResult`` through
    ``add_check(name, value, validators)``, in the listed order.

    Args:
        batch: Request whose fields are validated.
        server_args: Not used by these checks.

    Returns:
        The ``VerificationResult`` holding every registered check.
    """
    # Checks for "latents" ([V.is_tensor, V.with_dims(5)]) and "image_latent"
    # (V.none_or_tensor_with_dims(5)) are disabled temporarily for
    # image-generation models.
    pending_checks = (
        ("timesteps", batch.timesteps, [V.is_tensor, V.min_dims(1)]),
        ("prompt_embeds", batch.prompt_embeds, V.list_not_empty),
        ("image_embeds", batch.image_embeds, V.is_list),
        ("num_inference_steps", batch.num_inference_steps, V.positive_int),
        ("guidance_scale", batch.guidance_scale, V.non_negative_float),
        ("eta", batch.eta, V.non_negative_float),
        ("generator", batch.generator, V.generator_or_list_generators),
        (
            "do_classifier_free_guidance",
            batch.do_classifier_free_guidance,
            V.bool_value,
        ),
        (
            "negative_prompt_embeds",
            batch.negative_prompt_embeds,
            # Negative embeddings only have to be non-empty when CFG is on.
            lambda x: not batch.do_classifier_free_guidance or V.list_not_empty(x),
        ),
    )

    report = VerificationResult()
    for field_name, field_value, validators in pending_checks:
        report.add_check(field_name, field_value, validators)
    return report
sglang.multimodal_gen.runtime.utils.perf_logger import StageProfiler +from sglang.multimodal_gen.utils import PRECISION_TO_TYPE + +logger = init_logger(__name__) + + +class LTX2AVDenoisingStage(DenoisingStage): + """ + LTX-2 specific denoising stage that handles joint video and audio generation. + """ + + def __init__(self, transformer, scheduler, vae=None, audio_vae=None, **kwargs): + super().__init__( + transformer=transformer, scheduler=scheduler, vae=vae, **kwargs + ) + self.audio_vae = audio_vae + + @staticmethod + def _get_video_latent_num_frames_for_model( + batch: Req, server_args: ServerArgs, latents: torch.Tensor + ) -> int: + """Return the latent-frame length the DiT model should see. + + - If video latents were time-sharded for SP and are packed as token latents + ([B, S, D]), the model only sees the local shard and must use the local + latent-frame count (stored on the batch during SP sharding). + - Otherwise, fall back to the global latent-frame count inferred from the + requested output frames and the VAE temporal compression ratio. + """ + did_sp_shard = bool(getattr(batch, "did_sp_shard_latents", False)) + is_token_latents = isinstance(latents, torch.Tensor) and latents.ndim == 3 + + if did_sp_shard and is_token_latents: + if not hasattr(batch, "sp_video_latent_num_frames"): + raise ValueError( + "SP-sharded LTX2 token latents require `batch.sp_video_latent_num_frames` " + "to be set by `LTX2PipelineConfig.shard_latents_for_sp()`." 
+ ) + return int(batch.sp_video_latent_num_frames) + + pc = server_args.pipeline_config + return int((batch.num_frames - 1) // int(pc.vae_temporal_compression) + 1) + + @staticmethod + def _truncate_sp_padded_token_latents( + batch: Req, latents: torch.Tensor + ) -> torch.Tensor: + """Remove token padding introduced by SP time-sharding (if applicable).""" + did_sp_shard = bool(getattr(batch, "did_sp_shard_latents", False)) + if not did_sp_shard or not ( + isinstance(latents, torch.Tensor) and latents.ndim == 3 + ): + return latents + + raw_shape = getattr(batch, "raw_latent_shape", None) + if not (isinstance(raw_shape, tuple) and len(raw_shape) == 3): + return latents + + orig_s = int(raw_shape[1]) + cur_s = int(latents.shape[1]) + if cur_s == orig_s: + return latents + if cur_s < orig_s: + raise ValueError( + f"Unexpected gathered token-latents seq_len {cur_s} < original seq_len {orig_s}." + ) + return latents[:, :orig_s, :].contiguous() + + def _maybe_enable_cache_dit(self, num_inference_steps: int, batch: Req) -> None: + """Disable cache-dit for TI2V-style requests (image-conditioned), to avoid stale activations. + + NOTE: base denoising stage calls this hook with (num_inference_steps, batch). 
+ """ + if getattr(self, "_disable_cache_dit_for_request", False): + return + return super()._maybe_enable_cache_dit(num_inference_steps, batch) + + @staticmethod + def _resize_center_crop( + img: PIL.Image.Image, *, width: int, height: int + ) -> PIL.Image.Image: + return img.resize((width, height), resample=PIL.Image.Resampling.BILINEAR) + + @staticmethod + def _apply_video_codec_compression( + img_array: np.ndarray, crf: int = 33 + ) -> np.ndarray: + """Encode as a single H.264 frame and decode back to simulate compression artifacts.""" + if crf == 0: + return img_array + height, width = img_array.shape[0] // 2 * 2, img_array.shape[1] // 2 * 2 + img_array = img_array[:height, :width] + buffer = BytesIO() + container = av.open(buffer, mode="w", format="mp4") + stream = container.add_stream( + "libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"} + ) + stream.height, stream.width = height, width + frame = av.VideoFrame.from_ndarray(img_array, format="rgb24").reformat( + format="yuv420p" + ) + container.mux(stream.encode(frame)) + container.mux(stream.encode()) + container.close() + buffer.seek(0) + container = av.open(buffer) + decoded = next(container.decode(container.streams.video[0])) + container.close() + return decoded.to_ndarray(format="rgb24") + + @staticmethod + def _resize_center_crop_tensor( + img: PIL.Image.Image, + *, + width: int, + height: int, + device: torch.device, + dtype: torch.dtype, + apply_codec_compression: bool = True, + codec_crf: int = 33, + ) -> torch.Tensor: + """Resize, center-crop, and normalize to [1, C, 1, H, W] tensor in [-1, 1].""" + img_array = np.array(img).astype(np.uint8)[..., :3] + if apply_codec_compression: + img_array = LTX2AVDenoisingStage._apply_video_codec_compression( + img_array, crf=codec_crf + ) + tensor = ( + torch.from_numpy(img_array.astype(np.float32)) + .permute(2, 0, 1) + .unsqueeze(0) + .to(device=device) + ) + src_h, src_w = tensor.shape[2], tensor.shape[3] + scale = max(height / src_h, width / 
def _prepare_ltx2_image_latent(self, batch: Req, server_args: ServerArgs) -> None:
    """Encode `batch.image_path` into packed token latents for LTX-2 TI2V.

    Does nothing when the batch already carries ``image_latent`` with a
    positive ``ltx2_num_image_tokens``, or when ``batch.image_path`` is None
    (both fields are then reset to ``None`` / ``0``). Otherwise the image is
    loaded, VAE-encoded, normalized with the VAE's ``latents_mean`` /
    ``latents_std``, packed with ``pipeline_config.maybe_pack_latents`` and
    stored on ``batch.image_latent``. ``batch.condition_image`` is also set.

    The VAE is moved to the latents' device/dtype for encoding. It is cast
    back to the configured VAE precision (and offloaded to CPU when
    ``server_args.vae_cpu_offload`` is set) even if encoding fails.

    Raises:
        ValueError: Missing width/height, missing VAE, missing generator in
            "sample" mode, unsupported sample mode, non-rank-3 packed latents,
            or a token count that differs from one latent frame's tokens.
    """
    if (
        batch.image_latent is not None
        and int(getattr(batch, "ltx2_num_image_tokens", 0)) > 0
    ):
        return
    batch.ltx2_num_image_tokens = 0
    batch.image_latent = None

    if batch.image_path is None:
        return
    if batch.width is None or batch.height is None:
        raise ValueError("width/height must be provided for LTX-2 TI2V.")
    if self.vae is None:
        raise ValueError("VAE must be provided for LTX-2 TI2V.")

    image_path = (
        batch.image_path[0]
        if isinstance(batch.image_path, list)
        else batch.image_path
    )

    img = load_image(image_path)
    batch.condition_image = self._resize_center_crop(
        img, width=int(batch.width), height=int(batch.height)
    )

    original_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.vae_precision]
    if isinstance(batch.latents, torch.Tensor):
        latents_device = batch.latents.device
        encode_dtype = batch.latents.dtype
    else:
        # Same guard as the device lookup: without tensor latents, encode on
        # CPU in the configured VAE precision instead of failing on `.dtype`.
        latents_device = torch.device("cpu")
        encode_dtype = original_dtype
    vae_autocast_enabled = (
        original_dtype != torch.float32
    ) and not server_args.disable_autocast

    self.vae = self.vae.to(device=latents_device, dtype=encode_dtype)
    try:
        video_condition = self._resize_center_crop_tensor(
            img,
            width=int(batch.width),
            height=int(batch.height),
            device=latents_device,
            dtype=encode_dtype,
        )

        with torch.autocast(
            device_type=current_platform.device_type,
            dtype=original_dtype,
            enabled=vae_autocast_enabled,
        ):
            try:
                if server_args.pipeline_config.vae_tiling:
                    self.vae.enable_tiling()
            except Exception as exc:
                # Tiling stays best-effort, but the failure is no longer silent.
                logger.warning(
                    "Could not enable VAE tiling for LTX-2 TI2V encoding: %s", exc
                )
            if not vae_autocast_enabled:
                video_condition = video_condition.to(encode_dtype)

            latent_dist: DiagonalGaussianDistribution = self.vae.encode(
                video_condition
            )
            if isinstance(latent_dist, AutoencoderKLOutput):
                latent_dist = latent_dist.latent_dist

        mode = server_args.pipeline_config.vae_config.encode_sample_mode()
        if mode == "argmax":
            latent = latent_dist.mode()
        elif mode == "sample":
            if batch.generator is None:
                raise ValueError("Generator must be provided for VAE sampling.")
            latent = latent_dist.sample(batch.generator)
        else:
            raise ValueError(f"Unsupported encode_sample_mode: {mode}")

        # Per-channel normalization: normalized = (x - mean) / std
        mean = self.vae.latents_mean.view(1, -1, 1, 1, 1).to(latent)
        std = self.vae.latents_std.view(1, -1, 1, 1, 1).to(latent)
        latent = (latent - mean) / std

        packed = server_args.pipeline_config.maybe_pack_latents(
            latent, latent.shape[0], batch
        )
        if not (isinstance(packed, torch.Tensor) and packed.ndim == 3):
            raise ValueError("Expected packed image latents [B, S0, D].")

        # Fail-fast token count: must match one latent frame's tokens.
        vae_sf = int(server_args.pipeline_config.vae_scale_factor)
        patch = int(server_args.pipeline_config.patch_size)
        latent_h = int(batch.height) // vae_sf
        latent_w = int(batch.width) // vae_sf
        expected_tokens = (latent_h // patch) * (latent_w // patch)
        if int(packed.shape[1]) != int(expected_tokens):
            raise ValueError(
                "LTX-2 conditioning token count mismatch: "
                f"{int(packed.shape[1])=} {int(expected_tokens)=}."
            )

        batch.image_latent = packed
        batch.ltx2_num_image_tokens = int(packed.shape[1])

        if batch.debug:
            logger.info(
                "LTX2 TI2V conditioning prepared: %d tokens (shape=%s) for %sx%s",
                batch.ltx2_num_image_tokens,
                tuple(batch.image_latent.shape),
                batch.width,
                batch.height,
            )
    finally:
        # Restore the VAE on every exit path so a failed encode does not leave
        # it in the latents' dtype or skip the requested CPU offload.
        self.vae.to(original_dtype)
        if server_args.vae_cpu_offload:
            self.vae = self.vae.to("cpu")
+ self._disable_cache_dit_for_request = batch.image_path is not None + + # Prepare variables for the denoising loop + + prepared_vars = self._prepare_denoising_loop(batch, server_args) + extra_step_kwargs = prepared_vars["extra_step_kwargs"] + target_dtype = prepared_vars["target_dtype"] + autocast_enabled = prepared_vars["autocast_enabled"] + timesteps = prepared_vars["timesteps"] + num_inference_steps = prepared_vars["num_inference_steps"] + num_warmup_steps = prepared_vars["num_warmup_steps"] + image_kwargs = prepared_vars["image_kwargs"] + pos_cond_kwargs = prepared_vars["pos_cond_kwargs"] + neg_cond_kwargs = prepared_vars["neg_cond_kwargs"] + latents = prepared_vars["latents"] + boundary_timestep = prepared_vars["boundary_timestep"] + z = prepared_vars["z"] + reserved_frames_mask = prepared_vars["reserved_frames_mask"] + seq_len = prepared_vars["seq_len"] + guidance = prepared_vars["guidance"] + + audio_latents = batch.audio_latents + audio_scheduler = copy.deepcopy(self.scheduler) + + # Prepare TI2V conditioning once (encode image -> patchify tokens). + self._prepare_ltx2_image_latent(batch, server_args) + + # For LTX-2 packed token latents, SP sharding happens on the time dimension + # (frames). The model must see local latent frames (RoPE offset is applied + # inside the model using SP rank). 
+ latent_num_frames_for_model = self._get_video_latent_num_frames_for_model( + batch=batch, server_args=server_args, latents=latents + ) + latent_height = batch.height // server_args.pipeline_config.vae_scale_factor + latent_width = batch.width // server_args.pipeline_config.vae_scale_factor + + # Initialize lists for ODE trajectory + trajectory_timesteps: list[torch.Tensor] = [] + trajectory_latents: list[torch.Tensor] = [] + trajectory_audio_latents: list[torch.Tensor] = [] + + # Run denoising loop + denoising_start_time = time.time() + + # to avoid device-sync caused by timestep comparison + is_warmup = batch.is_warmup + self.scheduler.set_begin_index(0) + audio_scheduler.set_begin_index(0) + timesteps_cpu = timesteps.cpu() + num_timesteps = timesteps_cpu.shape[0] + + do_ti2v = self._should_apply_ltx2_ti2v(batch) + num_img_tokens = int(getattr(batch, "ltx2_num_image_tokens", 0)) + denoise_mask = None + clean_latent = None + if do_ti2v: + if not (isinstance(latents, torch.Tensor) and latents.ndim == 3): + raise ValueError("LTX-2 TI2V expects packed token latents [B, S, D].") + latents[:, :num_img_tokens, :] = batch.image_latent[ + :, :num_img_tokens, : + ].to(device=latents.device, dtype=latents.dtype) + denoise_mask = torch.ones( + (latents.shape[0], latents.shape[1], 1), + device=latents.device, + dtype=torch.float32, + ) + denoise_mask[:, :num_img_tokens, :] = 0.0 + clean_latent = latents.detach().clone() + clean_latent[:, :num_img_tokens, :] = batch.image_latent[ + :, :num_img_tokens, : + ].to(device=latents.device, dtype=latents.dtype) + + with torch.autocast( + device_type=current_platform.device_type, + dtype=target_dtype, + enabled=autocast_enabled, + ): + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t_host in enumerate(timesteps_cpu): + with StageProfiler( + f"denoising_step_{i}", + logger=logger, + metrics=batch.metrics, + perf_dump_path_provided=batch.perf_dump_path is not None, + ): + t_int = int(t_host.item()) + 
t_device = timesteps[i] + current_model, current_guidance_scale = ( + self._select_and_manage_model( + t_int=t_int, + boundary_timestep=boundary_timestep, + server_args=server_args, + batch=batch, + ) + ) + + # Predict noise residual + attn_metadata = self._build_attn_metadata(i, batch, server_args) + + # === LTX-2 sigma-space Euler step (flow matching) === + # Use scheduler-generated sigmas (includes terminal sigma=0). + sigmas = getattr(self.scheduler, "sigmas", None) + if sigmas is None or not isinstance(sigmas, torch.Tensor): + raise ValueError( + "Expected scheduler.sigmas to be a tensor for LTX-2." + ) + sigma = sigmas[i].to(device=latents.device, dtype=torch.float32) + sigma_next = sigmas[i + 1].to( + device=latents.device, dtype=torch.float32 + ) + dt = sigma_next - sigma + + latent_model_input = latents.to(target_dtype) + audio_latent_model_input = audio_latents.to(target_dtype) + + latent_num_frames = latent_num_frames_for_model + + # Audio latent dims + if audio_latent_model_input.ndim == 3: + audio_num_frames_latent = int( + audio_latent_model_input.shape[1] + ) + elif audio_latent_model_input.ndim == 4: + audio_num_frames_latent = int( + audio_latent_model_input.shape[2] + ) + else: + raise ValueError( + f"Unexpected audio latents rank: {audio_latent_model_input.ndim}, shape={tuple(audio_latent_model_input.shape)}" + ) + + # LTX-2 model can generate coords internally. + video_coords = None + audio_coords = None + + timestep = t_device.expand(int(latent_model_input.shape[0])) + if do_ti2v and denoise_mask is not None: + timestep_video = timestep.unsqueeze( + -1 + ) * denoise_mask.squeeze(-1) + else: + timestep_video = timestep + timestep_audio = timestep + + # Conditions + encoder_hidden_states = batch.prompt_embeds[0] + audio_encoder_hidden_states = batch.audio_prompt_embeds[0] + encoder_attention_mask = batch.prompt_attention_mask + + # Follow ltx-pipelines structure: separate pos/neg forward passes, + # then apply CFG on denoised (x0) predictions. 
+ with set_forward_context( + current_timestep=i, attn_metadata=attn_metadata + ): + v_pos, a_v_pos = current_model( + hidden_states=latent_model_input, + audio_hidden_states=audio_latent_model_input, + encoder_hidden_states=encoder_hidden_states, + audio_encoder_hidden_states=audio_encoder_hidden_states, + timestep=timestep_video, + audio_timestep=timestep_audio, + encoder_attention_mask=encoder_attention_mask, + audio_encoder_attention_mask=encoder_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=batch.fps, + audio_num_frames=audio_num_frames_latent, + video_coords=video_coords, + audio_coords=audio_coords, + return_latents=False, + return_dict=False, + ) + + if batch.do_classifier_free_guidance: + neg_encoder_hidden_states = ( + batch.negative_prompt_embeds[0] + ) + neg_audio_encoder_hidden_states = ( + batch.negative_audio_prompt_embeds[0] + ) + neg_encoder_attention_mask = ( + batch.negative_attention_mask + ) + + v_neg, a_v_neg = current_model( + hidden_states=latent_model_input, + audio_hidden_states=audio_latent_model_input, + encoder_hidden_states=neg_encoder_hidden_states, + audio_encoder_hidden_states=neg_audio_encoder_hidden_states, + timestep=timestep_video, + audio_timestep=timestep_audio, + encoder_attention_mask=neg_encoder_attention_mask, + audio_encoder_attention_mask=neg_encoder_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=batch.fps, + audio_num_frames=audio_num_frames_latent, + video_coords=video_coords, + audio_coords=audio_coords, + return_latents=False, + return_dict=False, + ) + else: + v_neg = None + a_v_neg = None + + v_pos = v_pos.float() + a_v_pos = a_v_pos.float() + if v_neg is not None: + v_neg = v_neg.float() + if a_v_neg is not None: + a_v_neg = a_v_neg.float() + + # Velocity -> denoised (x0): x0 = x - sigma * v + sigma_val = float(sigma.item()) + denoised_video = (latents.float() - sigma_val * v_pos).to( + latents.dtype + ) + 
denoised_audio = ( + audio_latents.float() - sigma_val * a_v_pos + ).to(audio_latents.dtype) + + if ( + batch.do_classifier_free_guidance + and v_neg is not None + and a_v_neg is not None + ): + denoised_video_neg = ( + latents.float() - sigma_val * v_neg + ).to(latents.dtype) + denoised_audio_neg = ( + audio_latents.float() - sigma_val * a_v_neg + ).to(audio_latents.dtype) + denoised_video = denoised_video + ( + batch.guidance_scale - 1.0 + ) * (denoised_video - denoised_video_neg) + denoised_audio = denoised_audio + ( + batch.guidance_scale - 1.0 + ) * (denoised_audio - denoised_audio_neg) + + # Apply conditioning mask (keep conditioned tokens clean). + if ( + do_ti2v + and denoise_mask is not None + and clean_latent is not None + ): + denoised_video = ( + denoised_video * denoise_mask + + clean_latent.float() * (1.0 - denoise_mask) + ) + + # Euler step in sigma space: x_next = x + (sigma_next - sigma) * v, + # where v = (x - x0) / sigma. + if sigma_val == 0.0: + v_video = torch.zeros_like(denoised_video) + v_audio = torch.zeros_like(denoised_audio) + else: + v_video = ( + (latents.float() - denoised_video.float()) / sigma_val + ).to(latents.dtype) + v_audio = ( + (audio_latents.float() - denoised_audio.float()) + / sigma_val + ).to(audio_latents.dtype) + + latents = (latents.float() + v_video.float() * dt).to( + dtype=latents.dtype + ) + audio_latents = ( + audio_latents.float() + v_audio.float() * dt + ).to(dtype=audio_latents.dtype) + + if do_ti2v: + latents[:, :num_img_tokens, :] = batch.image_latent[ + :, :num_img_tokens, : + ].to(device=latents.device, dtype=latents.dtype) + + latents = self.post_forward_for_ti2v_task( + batch, server_args, reserved_frames_mask, latents, z + ) + + # save trajectory latents if needed + if batch.return_trajectory_latents: + trajectory_timesteps.append(t_host) + trajectory_latents.append(latents) + if audio_latents is not None: + trajectory_audio_latents.append(audio_latents) + + # Update progress bar + if i == num_timesteps - 
1 or ( + (i + 1) > num_warmup_steps + and (i + 1) % self.scheduler.order == 0 + and progress_bar is not None + ): + progress_bar.update() + + if not is_warmup: + self.step_profile() + + denoising_end_time = time.time() + + if num_timesteps > 0 and not is_warmup: + self.log_info( + "average time per step: %.4f seconds", + (denoising_end_time - denoising_start_time) / len(timesteps), + ) + + batch.audio_latents = audio_latents + self._post_denoising_loop( + batch=batch, + latents=latents, + trajectory_latents=trajectory_latents, + trajectory_timesteps=trajectory_timesteps, + trajectory_audio_latents=trajectory_audio_latents, + server_args=server_args, + is_warmup=is_warmup, + ) + + return batch + + def _post_denoising_loop( + self, + batch: Req, + latents: torch.Tensor, + trajectory_latents: list, + trajectory_timesteps: list, + trajectory_audio_latents: list, + server_args: ServerArgs, + is_warmup: bool = False, + ): + # 1. Handle Trajectory (Video) - Copy from base + if trajectory_latents: + trajectory_tensor = torch.stack(trajectory_latents, dim=1) + trajectory_timesteps_tensor = torch.stack(trajectory_timesteps, dim=0) + else: + trajectory_tensor = None + trajectory_timesteps_tensor = None + + latents, trajectory_tensor = self._postprocess_sp_latents( + batch, latents, trajectory_tensor + ) + + # If SP time-sharding padded whole frames worth of tokens, remove padding + # after gather and before unpacking. + latents = self._truncate_sp_padded_token_latents(batch, latents) + + if trajectory_tensor is not None and trajectory_timesteps_tensor is not None: + batch.trajectory_timesteps = trajectory_timesteps_tensor.cpu() + batch.trajectory_latents = trajectory_tensor.cpu() + + # 2. Handle Trajectory (Audio) - LTX-2 specific + if trajectory_audio_latents: + trajectory_audio_tensor = torch.stack(trajectory_audio_latents, dim=1) + # We don't have SP support for audio latents yet (or needed?) + batch.trajectory_audio_latents = trajectory_audio_tensor.cpu() + + # 3. 
Unpack and Denormalize + # Call pipeline_config._unpad_and_unpack_latents + # latents is video latents. + # batch.audio_latents is audio latents. + + audio_latents = batch.audio_latents + + # NOTE: self.vae and self.audio_vae should be populated via __init__ or manual setting + if self.vae is None or self.audio_vae is None: + logger.warning( + "VAE or Audio VAE not found in DenoisingStage. Skipping unpack and denormalize." + ) + batch.latents = latents + batch.audio_latents = audio_latents + else: + latents, audio_latents = ( + server_args.pipeline_config._unpad_and_unpack_latents( + latents, audio_latents, batch, self.vae, self.audio_vae + ) + ) + + batch.latents = latents + batch.audio_latents = audio_latents + + # 4. Cleanup + offload_mgr = getattr(self.transformer, "_layerwise_offload_manager", None) + if offload_mgr is not None and getattr(offload_mgr, "enabled", False): + offload_mgr.release_all() + + def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + """Verify denoising stage inputs. + + Note: LTX-2 connector stage converts `prompt_embeds`/`negative_prompt_embeds` + from list-of-tensors to a single tensor (video context) and stores audio + context separately. + """ + + result = VerificationResult() + result.add_check("timesteps", batch.timesteps, [V.is_tensor, V.min_dims(1)]) + + # LTX-2 may carry prompt embeddings as either a tensor (preferred) or legacy list. + result.add_check( + "prompt_embeds", + batch.prompt_embeds, + lambda x: V.is_tensor(x) or V.list_not_empty(x), + ) + + # Keep base expectation: image_embeds is always a list (may be empty). 
class LTX2RefinementStage(LTX2AVDenoisingStage):
    """LTX-2 stage that re-noises latents and denoises them on a fixed sigma list.

    ``forward`` adds Gaussian noise scaled by the first distilled sigma, swaps
    the scheduler's ``sigmas`` / ``timesteps`` / ``num_inference_steps`` for
    values derived from ``distilled_sigmas``, runs the parent denoising loop,
    and always puts the scheduler's previous values back.
    """

    def __init__(
        self, transformer, scheduler, distilled_sigmas, vae=None, audio_vae=None
    ):
        super().__init__(transformer, scheduler, vae, audio_vae)
        # Stored as a tensor so it can be moved to the scheduler/latents device.
        self.distilled_sigmas = torch.tensor(distilled_sigmas)

    def forward(self, batch: Req, server_args: ServerArgs) -> Req:
        # Step 1: perturb the incoming latents with noise scaled by sigma[0].
        first_sigma = self.distilled_sigmas[0].to(batch.latents.device)
        batch.latents = batch.latents + torch.randn_like(batch.latents) * first_sigma

        # Step 2: remember the scheduler state that is about to be overridden.
        saved_state = (
            self.scheduler.sigmas,
            self.scheduler.timesteps,
            self.scheduler.num_inference_steps,
        )

        # Install the distilled schedule; timesteps are approximated as sigma * 1000.
        self.scheduler.sigmas = self.distilled_sigmas.to(self.scheduler.sigmas.device)
        self.scheduler.timesteps = self.scheduler.sigmas * 1000
        self.scheduler.num_inference_steps = len(self.distilled_sigmas) - 1

        try:
            # Step 3: run the parent denoising loop on the distilled schedule.
            return super().forward(batch, server_args)
        finally:
            # Restore the scheduler whether or not the loop succeeded.
            (
                self.scheduler.sigmas,
                self.scheduler.timesteps,
                self.scheduler.num_inference_steps,
            ) = saved_state

    def do_classifier_free_guidance(self, batch: Req) -> bool:
        # Stage 2 uses simple denoising (no CFG)
        return False
class DmdDenoisingStage(DenoisingStage):
    """
    Denoising stage for DMD.

    Runs a fixed list of timesteps (``pipeline_config.dmd_denoising_steps``).
    After each step the predicted clean video is re-noised to the next
    timestep; the last step keeps the prediction as the final latents.
    """

    def __init__(self, transformer, scheduler, transformer_2=None) -> None:
        super().__init__(
            transformer=transformer, scheduler=scheduler, transformer_2=transformer_2
        )
        # The scheduler passed in is replaced by a flow-match Euler scheduler.
        self.scheduler = FlowMatchEulerDiscreteScheduler(shift=8.0)

    def _preprocess_sp_latents(self, batch: Req, server_args: ServerArgs):
        # 1. to shard latents (B, C, T, H, W) along dim 2
        super()._preprocess_sp_latents(batch, server_args)

        # 2. DMD expects (B, T, C, H, W) for the main latents in the loop
        if batch.latents is not None:
            batch.latents = batch.latents.permute(0, 2, 1, 3, 4)

        # Note: batch.image_latent is kept as (B, C, T, H, W) here

    def _postprocess_sp_latents(
        self,
        batch: Req,
        latents: torch.Tensor,
        trajectory_tensor: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        # 1. convert back from DMD's (B, T, C, H, W) to standard (B, C, T, H, W)
        #    this is because base gather_latents_for_sp expects dim=2 for T
        latents = latents.permute(0, 2, 1, 3, 4)

        # 2. use base method to gather
        return super()._postprocess_sp_latents(batch, latents, trajectory_tensor)

    def forward(
        self,
        batch: Req,
        server_args: ServerArgs,
    ) -> Req:
        """
        Run the denoising loop.

        Args:
            batch: Request state; ``image_embeds``, ``image_latent``,
                ``generator``, ``metrics`` and ``perf_dump_path`` are read.
            server_args: Server configuration; supplies the DMD timestep list.

        Returns:
            The same ``batch`` after ``_post_denoising_loop`` has been applied.
        """
        prepared_vars = self._prepare_denoising_loop(batch, server_args)

        target_dtype = prepared_vars["target_dtype"]
        autocast_enabled = prepared_vars["autocast_enabled"]
        num_warmup_steps = prepared_vars["num_warmup_steps"]
        latents = prepared_vars["latents"]
        video_raw_latent_shape = latents.shape

        timesteps = torch.tensor(
            server_args.pipeline_config.dmd_denoising_steps,
            dtype=torch.long,
            device=get_local_torch_device(),
        )

        # prepare image_kwargs
        image_embeds = batch.image_embeds
        if len(image_embeds) > 0:
            image_embeds = [img.to(target_dtype) for img in image_embeds]

        image_kwargs = self.prepare_extra_func_kwargs(
            self.transformer.forward,
            {
                "encoder_hidden_states_image": image_embeds,
                "mask_strategy": dict_to_3d_list(None, t_max=50, l_max=60, h_max=24),
            },
        )

        pos_cond_kwargs = prepared_vars["pos_cond_kwargs"]

        # The expert-switch boundary does not change between steps, so resolve
        # it (and emit its override log line) once instead of on every step.
        boundary_timestep = (
            self._handle_boundary_ratio(server_args, batch)
            if self.transformer_2 is not None
            else None
        )

        denoising_loop_start_time = time.time()
        with self.progress_bar(total=len(timesteps)) as progress_bar:
            for i, t in enumerate(timesteps):
                # Skip if interrupted
                if hasattr(self, "interrupt") and self.interrupt:
                    continue

                with StageProfiler(
                    f"denoising_step_{i}",
                    logger=logger,
                    metrics=batch.metrics,
                    perf_dump_path_provided=batch.perf_dump_path is not None,
                ):
                    t_int = int(t.item())
                    if self.transformer_2 is not None:
                        # The per-expert guidance scale is returned but not used here.
                        current_model, _ = self._select_and_manage_model(
                            t_int=t_int,
                            boundary_timestep=boundary_timestep,
                            server_args=server_args,
                            batch=batch,
                        )
                    else:
                        current_model = self.transformer
                        self._manage_device_placement(current_model, None, server_args)

                    # Expand latents for I2V
                    noise_latents = latents.clone()
                    latent_model_input = latents.to(target_dtype)

                    if batch.image_latent is not None:
                        latent_model_input = torch.cat(
                            [
                                latent_model_input,
                                batch.image_latent.permute(0, 2, 1, 3, 4),
                            ],
                            dim=2,
                        ).to(target_dtype)
                    assert not torch.isnan(
                        latent_model_input
                    ).any(), "latent_model_input contains nan"

                    # Prepare inputs for transformer
                    t_expand = t.repeat(latent_model_input.shape[0])

                    guidance_expand = self.get_or_build_guidance(
                        latent_model_input.shape[0],
                        target_dtype,
                        get_local_torch_device(),
                    )

                    # Predict noise residual
                    with torch.autocast(
                        device_type=current_platform.device_type,
                        dtype=target_dtype,
                        enabled=autocast_enabled,
                    ):
                        attn_metadata = self._build_attn_metadata(i, batch, server_args)

                        batch.is_cfg_negative = False
                        with set_forward_context(
                            current_timestep=i,
                            attn_metadata=attn_metadata,
                            forward_batch=batch,
                        ):
                            # Run transformer
                            pred_noise = current_model(
                                hidden_states=latent_model_input.permute(0, 2, 1, 3, 4),
                                timestep=t_expand,
                                guidance=guidance_expand,
                                **image_kwargs,
                                **pos_cond_kwargs,
                            ).permute(0, 2, 1, 3, 4)

                        pred_video = pred_noise_to_pred_video(
                            pred_noise=pred_noise.flatten(0, 1),
                            noise_input_latent=noise_latents.flatten(0, 1),
                            timestep=t_expand,
                            scheduler=self.scheduler,
                        ).unflatten(0, pred_noise.shape[:2])

                        is_last_step = i == len(timesteps) - 1
                        if not is_last_step:
                            # Re-noise the prediction to the next timestep.
                            next_timestep = timesteps[i + 1] * torch.ones(
                                [1], dtype=torch.long, device=pred_video.device
                            )
                            noise = torch.randn(
                                video_raw_latent_shape,
                                dtype=pred_video.dtype,
                                generator=batch.generator[0],
                                device=self.device,
                            )
                            latents = self.scheduler.add_noise(
                                pred_video.flatten(0, 1),
                                noise.flatten(0, 1),
                                next_timestep,
                            ).unflatten(0, pred_video.shape[:2])
                        else:
                            latents = pred_video

                        # Update progress bar. The None check guards both
                        # conditions, so a missing bar can never be dereferenced.
                        on_update_step = (i + 1) > num_warmup_steps and (
                            i + 1
                        ) % self.scheduler.order == 0
                        if progress_bar is not None and (
                            is_last_step or on_update_step
                        ):
                            progress_bar.update()

                    self.step_profile()

        denoising_loop_end_time = time.time()
        if len(timesteps) > 0:
            self.log_info(
                "average time per step: %.4f seconds",
                (denoising_loop_end_time - denoising_loop_start_time) / len(timesteps),
            )

        self._post_denoising_loop(
            batch=batch,
            latents=latents,
            trajectory_latents=[],
            trajectory_timesteps=[],
            server_args=server_args,
        )

        return batch

    def _select_and_manage_model(
        self,
        t_int: int,
        boundary_timestep: float | None,
        server_args: ServerArgs,
        batch: Req,
    ):
        """Pick the expert for timestep ``t_int`` and arrange device placement.

        Returns:
            ``(current_model, current_guidance_scale)``.
        """
        if boundary_timestep is None or t_int >= boundary_timestep:
            # High-noise stage
            current_model = self.transformer
            model_to_offload = self.transformer_2
            current_guidance_scale = batch.guidance_scale
        else:
            # Low-noise stage
            current_model = self.transformer_2
            model_to_offload = self.transformer
            current_guidance_scale = batch.guidance_scale_2

        self._manage_device_placement(current_model, model_to_offload, server_args)

        assert current_model is not None, "The model for the current step is not set."
        return current_model, current_guidance_scale

    def _manage_device_placement(
        self,
        model_to_use: torch.nn.Module,
        model_to_offload: torch.nn.Module | None,
        server_args: ServerArgs,
    ):
        """
        Manages the offload / load behavior of dit

        Does nothing unless ``server_args.dit_cpu_offload`` is set.
        """
        if not server_args.dit_cpu_offload:
            return

        # Offload the unused model whenever it sits on an accelerator. Testing
        # for "not CPU" rather than the literal "cuda" keeps offloading working
        # on other accelerator backends; meta tensors hold no data to move.
        if model_to_offload is not None:
            offload_device_type = next(model_to_offload.parameters()).device.type
            if offload_device_type not in ("cpu", "meta"):
                model_to_offload.to("cpu")

        # Load the model to use if it's on CPU
        if (
            model_to_use is not None
            and next(model_to_use.parameters()).device.type == "cpu"
        ):
            model_to_use.to(get_local_torch_device())

    def _handle_boundary_ratio(
        self,
        server_args,
        batch,
    ):
        """
        (Wan2.2) Calculate timestep to switch from high noise expert to low noise expert

        A per-request ``batch.boundary_ratio`` overrides the configured one.
        Returns ``None`` when no ratio is available.
        """
        boundary_ratio = server_args.pipeline_config.dit_config.boundary_ratio
        if batch.boundary_ratio is not None:
            logger.info(
                "Overriding boundary ratio from %s to %s",
                boundary_ratio,
                batch.boundary_ratio,
            )
            boundary_ratio = batch.boundary_ratio

        if boundary_ratio is not None:
            boundary_timestep = boundary_ratio * self.scheduler.num_train_timesteps
        else:
            boundary_timestep = None

        return boundary_timestep
class EncodingStage(PipelineStage):
    """
    Stage for encoding pixel space representations into latent space.

    This stage handles the encoding of pixel-space video/images into latent
    representations for further processing in the diffusion pipeline.
    """

    def __init__(self, vae: ParallelTiledVAE) -> None:
        super().__init__()
        self.vae: ParallelTiledVAE = vae

    def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
        """Verify encoding stage inputs."""
        result = VerificationResult()
        # Input video/images for VAE encoding: [batch_size, channels, frames, height, width]
        result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
        return result

    def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
        """Verify encoding stage outputs."""
        result = VerificationResult()
        # Encoded latents: [batch_size, channels, frames, height_latents, width_latents]
        result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
        return result

    # no_grad belongs on the method that runs the VAE: encoding is inference
    # only and must not build an autograd graph.
    @torch.no_grad()
    def forward(
        self,
        batch: Req,
        server_args: ServerArgs,
    ) -> Req:
        """
        Encode pixel space representations into latent space.

        Args:
            batch: Request whose ``latents`` holds the pixel-space tensor.
            server_args: Server configuration (VAE precision, tiling,
                autocast and CPU-offload switches).

        Returns:
            The batch with encoded latents.
        """
        assert batch.latents is not None and isinstance(batch.latents, torch.Tensor)

        self.vae = self.vae.to(get_local_torch_device())

        # Setup VAE precision
        vae_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.vae_precision]
        vae_autocast_enabled = (
            vae_dtype != torch.float32
        ) and not server_args.disable_autocast

        # Normalize input to [-1, 1] range (reverse of decoding normalization)
        latents = (batch.latents * 2.0 - 1.0).clamp(-1, 1)

        # Move to appropriate device and dtype
        latents = latents.to(get_local_torch_device())

        # Encode image to latents
        with torch.autocast(
            device_type=current_platform.device_type,
            dtype=vae_dtype,
            enabled=vae_autocast_enabled,
        ):
            if server_args.pipeline_config.vae_tiling:
                self.vae.enable_tiling()
            # if server_args.vae_sp:
            #     self.vae.enable_parallel()
            if not vae_autocast_enabled:
                # Without autocast the input is cast to the VAE dtype explicitly.
                latents = latents.to(vae_dtype)
            latents = self.vae.encode(latents).mean

        # Update batch with encoded latents
        batch.latents = latents

        # Offload models if needed
        self.maybe_free_model_hooks()

        if server_args.vae_cpu_offload:
            self.vae.to("cpu")

        return batch
def guidance_scale_embedding(
    w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """Build sinusoidal embeddings of guidance scales.

    Args:
        w: 1-D tensor of guidance scales.
        embedding_dim: Width of the returned embedding.
        dtype: dtype used for the frequencies and the result.

    Returns:
        Tensor of shape ``(len(w), embedding_dim)``: sines, then cosines, then
        one zero column when ``embedding_dim`` is odd.
    """
    assert len(w.shape) == 1
    scaled = w * 1000.0

    n_freqs = embedding_dim // 2
    log_step = torch.log(torch.tensor(10000.0)) / (n_freqs - 1)
    freqs = torch.exp(torch.arange(n_freqs, dtype=dtype) * -log_step)
    angles = scaled.to(dtype)[:, None] * freqs[None, :]
    embedding = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if embedding_dim % 2 == 1:
        # Odd widths get one trailing zero column.
        embedding = torch.nn.functional.pad(embedding, (0, 1))
    assert embedding.shape == (scaled.shape[0], embedding_dim)
    return embedding


def extract_into_tensor(
    a: torch.Tensor, t: torch.Tensor, x_shape: tuple, n_gen: int
) -> torch.Tensor:
    """Gather ``a`` at indices ``t``, tile it ``n_gen`` times, and pad with singleton dims.

    Returns:
        Tensor of shape ``(b, n_gen, 1, ..., 1)`` with ``len(x_shape) - 2``
        trailing singleton dimensions.
    """
    picked = a.gather(-1, t).repeat(n_gen)
    # Split the tiled vector into rows of n_gen entries ("(b n) -> b n").
    grid = picked.reshape(-1, n_gen)
    rows, cols = grid.shape[0], grid.shape[1]
    trailing = (1,) * (len(x_shape) - 2)
    return grid.reshape(rows, cols, *trailing)
def get_predicted_original_sample(
    model_output: torch.Tensor,
    timesteps: torch.Tensor,
    sample: torch.Tensor,
    prediction_type: str,
    alphas: torch.Tensor,
    sigmas: torch.Tensor,
    n_gen: int,
) -> torch.Tensor:
    """Get predicted original sample from model output.

    ``model_output`` is regrouped from ``(b n) c h w`` to ``b n c h w`` before
    the formula for ``prediction_type`` is applied.

    Raises:
        ValueError: If ``prediction_type`` is not ``epsilon``, ``sample`` or
            ``v_prediction``.
    """
    alphas = extract_into_tensor(alphas, timesteps, sample.shape, n_gen)
    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape, n_gen)
    model_output = rearrange(model_output, "(b n) c h w -> b n c h w", n=n_gen)

    if prediction_type == "epsilon":
        pred_x_0 = (sample - sigmas * model_output) / alphas
    elif prediction_type == "sample":
        pred_x_0 = model_output
    elif prediction_type == "v_prediction":
        pred_x_0 = alphas * sample - sigmas * model_output
    else:
        raise ValueError(
            f"Prediction type {prediction_type} is not supported; "
            "currently, `epsilon`, `sample`, and `v_prediction` are supported."
        )

    return pred_x_0


def get_predicted_noise(
    model_output: torch.Tensor,
    timesteps: torch.Tensor,
    sample: torch.Tensor,
    prediction_type: str,
    alphas: torch.Tensor,
    sigmas: torch.Tensor,
    n_gen: int,
) -> torch.Tensor:
    """Get predicted noise from model output.

    ``model_output`` is regrouped from ``(b n) c h w`` to ``b n c h w`` before
    the formula for ``prediction_type`` is applied.

    Raises:
        ValueError: If ``prediction_type`` is not ``epsilon``, ``sample`` or
            ``v_prediction``.
    """
    alphas = extract_into_tensor(alphas, timesteps, sample.shape, n_gen)
    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape, n_gen)
    model_output = rearrange(model_output, "(b n) c h w -> b n c h w", n=n_gen)

    if prediction_type == "epsilon":
        pred_epsilon = model_output
    elif prediction_type == "sample":
        pred_epsilon = (sample - alphas * model_output) / sigmas
    elif prediction_type == "v_prediction":
        pred_epsilon = alphas * model_output + sigmas * sample
    else:
        raise ValueError(
            f"Prediction type {prediction_type} is not supported; "
            "currently, `epsilon`, `sample`, and `v_prediction` are supported."
        )

    return pred_epsilon


def to_rgb_image(maybe_rgba):
    """Convert RGBA image to RGB.

    RGB images are returned unchanged. RGBA images are pasted, using their
    alpha channel as the mask, onto a uniform gray (127, 127, 127) canvas.

    Raises:
        ValueError: If the image mode is neither ``RGB`` nor ``RGBA``.
    """
    if maybe_rgba.mode == "RGB":
        return maybe_rgba
    if maybe_rgba.mode == "RGBA":
        # PIL is only needed when a new canvas has to be created.
        from PIL import Image

        # A constant canvas; no random number generator is involved.
        canvas = Image.new("RGB", maybe_rgba.size, (127, 127, 127))
        canvas.paste(maybe_rgba, mask=maybe_rgba.getchannel("A"))
        return canvas
    raise ValueError(f"Unsupported image type: {maybe_rgba.mode}")
class DDIMSolver:
    """DDIM solver for fast sampling.

    Picks ``ddim_timesteps`` evenly spaced training timesteps and caches the
    cumulative alpha at each one together with the value one step earlier.
    """

    def __init__(
        self,
        alpha_cumprods: np.ndarray,
        timesteps: int = 1000,
        ddim_timesteps: int = 50,
    ):
        stride = timesteps // ddim_timesteps
        # 1-based multiples of the stride, shifted to 0-based indices.
        picked = (np.arange(1, ddim_timesteps + 1) * stride).round().astype(
            np.int64
        ) - 1
        alphas_at = alpha_cumprods[picked]
        # The first "previous" entry is the very first cumulative alpha.
        alphas_before = np.asarray(
            [alpha_cumprods[0]] + alpha_cumprods[picked[:-1]].tolist()
        )

        self.ddim_timesteps = torch.from_numpy(picked).long()
        self.ddim_alpha_cumprods = torch.from_numpy(alphas_at)
        self.ddim_alpha_cumprods_prev = torch.from_numpy(alphas_before)

    def to(self, device: torch.device) -> "DDIMSolver":
        """Move all cached tensors to ``device`` and return ``self``."""
        for attr in (
            "ddim_timesteps",
            "ddim_alpha_cumprods",
            "ddim_alpha_cumprods_prev",
        ):
            setattr(self, attr, getattr(self, attr).to(device))
        return self

    def ddim_step(
        self,
        pred_x0: torch.Tensor,
        pred_noise: torch.Tensor,
        timestep_index: torch.Tensor,
        n_gen: int,
    ) -> torch.Tensor:
        """Combine the predicted clean sample and noise into the previous-step sample."""
        prev_alpha = extract_into_tensor(
            self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape, n_gen
        )
        noise_term = (1.0 - prev_alpha).sqrt() * pred_noise
        return prev_alpha.sqrt() * pred_x0 + noise_term
torch.cat([corrected_bgr, alpha_channel], dim=-1) + + return corrected_bgr + + +# Stage 1: Preprocess (UV unwrap + delight + multi-view rendering) +class Hunyuan3DPaintPreprocessStage(PipelineStage): + """Preprocessing: UV unwrap + delight in parallel, then multi-view rendering.""" + + CAMERA_AZIMS = [0, 90, 180, 270, 0, 180] + CAMERA_ELEVS = [0, 0, 0, 0, 90, -90] + VIEW_WEIGHTS = [1, 0.1, 0.5, 0.1, 0.05, 0.05] + + @property + def parallelism_type(self) -> StageParallelismType: + return StageParallelismType.MAIN_RANK_ONLY + + def __init__(self, config: Hunyuan3D2PipelineConfig) -> None: + super().__init__() + self.config = config + self._delight_pipeline = None + self._delight_loaded = False + self._renderer = None + self._renderer_loaded = False + + # --- UV unwrap --- + + def _do_uv_unwrap(self, batch: Req, server_args: ServerArgs) -> Req: + import time + + from sglang.multimodal_gen.runtime.utils.mesh3d_utils import mesh_uv_wrap + + mesh = batch.extra["shape_meshes"] + if isinstance(mesh, list): + mesh = mesh[0] + + try: + start_time = time.time() + mesh = mesh_uv_wrap(mesh) + elapsed = time.time() - start_time + logger.info(f"UV unwrapping completed in {elapsed:.2f}s") + except Exception as e: + logger.warning(f"UV unwrapping failed: {e}") + + batch.extra["paint_mesh"] = mesh + return batch + + # --- Delight --- + + def _load_delight_model(self, server_args: ServerArgs): + if self._delight_loaded: + return + + from diffusers import ( + EulerAncestralDiscreteScheduler, + StableDiffusionInstructPix2PixPipeline, + ) + from huggingface_hub import snapshot_download + + model_path = server_args.model_path + delight_subfolder = getattr( + self.config, "delight_subfolder", "hunyuan3d-delight-v2-0" + ) + + local_path = os.path.join(model_path, delight_subfolder) + if not os.path.exists(local_path): + local_path = os.path.expanduser(local_path) + + if not os.path.exists(local_path): + try: + downloaded = snapshot_download( + repo_id=model_path, + 
# Stage 1: Preprocess (UV unwrap + delight + multi-view rendering)
class Hunyuan3DPaintPreprocessStage(PipelineStage):
    """Preprocessing: UV unwrap + delight in parallel, then multi-view rendering.

    Results are written into ``batch.extra`` under ``paint_mesh``,
    ``delighted_image``, ``normal_maps``, ``position_maps``, ``camera_azims``,
    ``camera_elevs``, ``view_weights`` and ``renderer``.
    """

    # Fixed camera rig: one entry per rendered view.
    CAMERA_AZIMS = [0, 90, 180, 270, 0, 180]
    CAMERA_ELEVS = [0, 0, 0, 0, 90, -90]
    VIEW_WEIGHTS = [1, 0.1, 0.5, 0.1, 0.05, 0.05]

    @property
    def parallelism_type(self) -> StageParallelismType:
        return StageParallelismType.MAIN_RANK_ONLY

    def __init__(self, config: Hunyuan3D2PipelineConfig) -> None:
        super().__init__()
        self.config = config
        # Both heavy helpers are created lazily on first use.
        self._delight_pipeline = None
        self._delight_loaded = False
        self._renderer = None
        self._renderer_loaded = False

    # --- UV unwrap ---

    def _do_uv_unwrap(self, batch: Req, server_args: ServerArgs) -> Req:
        """UV-unwrap the first shape mesh and store it as ``paint_mesh``.

        Best effort: on failure a warning is logged and the original mesh is
        stored instead.
        """
        import time

        from sglang.multimodal_gen.runtime.utils.mesh3d_utils import mesh_uv_wrap

        mesh = batch.extra["shape_meshes"]
        if isinstance(mesh, list):
            mesh = mesh[0]

        try:
            start_time = time.time()
            mesh = mesh_uv_wrap(mesh)
            elapsed = time.time() - start_time
            logger.info(f"UV unwrapping completed in {elapsed:.2f}s")
        except Exception as e:
            logger.warning(f"UV unwrapping failed: {e}")

        batch.extra["paint_mesh"] = mesh
        return batch

    # --- Delight ---

    def _load_delight_model(self, server_args: ServerArgs):
        """Load the delight pipeline once; later calls are no-ops.

        Looks for the delight subfolder under ``server_args.model_path`` and
        falls back to a hub download. If neither works, a warning is logged
        and the pipeline stays ``None``.
        """
        if self._delight_loaded:
            return

        from diffusers import (
            EulerAncestralDiscreteScheduler,
            StableDiffusionInstructPix2PixPipeline,
        )
        from huggingface_hub import snapshot_download

        model_path = server_args.model_path
        delight_subfolder = getattr(
            self.config, "delight_subfolder", "hunyuan3d-delight-v2-0"
        )

        local_path = os.path.join(model_path, delight_subfolder)
        if not os.path.exists(local_path):
            local_path = os.path.expanduser(local_path)

        if not os.path.exists(local_path):
            try:
                downloaded = snapshot_download(
                    repo_id=model_path,
                    allow_patterns=[f"{delight_subfolder}/*"],
                )
                local_path = os.path.join(downloaded, delight_subfolder)
            except Exception as e:
                logger.warning("Could not download delight model: %s", e)
                local_path = None

        if local_path and os.path.exists(local_path):
            pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
                local_path,
                torch_dtype=torch.float16,
                safety_checker=None,
            )
            pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
                pipeline.scheduler.config
            )
            pipeline.set_progress_bar_config(disable=True)
            self._delight_pipeline = pipeline.to(self.device, torch.float16)
            logger.info("Delight model loaded successfully")
        else:
            logger.warning(
                "Delight model not available, skipping delight preprocessing"
            )

        # Marked as loaded even on failure so the lookup is not retried.
        self._delight_loaded = True

    @torch.no_grad()
    def _run_delight(self, image):
        """Run the delight pipeline on ``image`` and return a 512x512 RGB PIL image.

        The pipeline output is color-corrected against the input with
        ``_recorrect_rgb`` and composited onto white using the input's alpha.
        """
        import cv2
        from PIL import Image as PILImage

        image = image.resize((512, 512))

        if image.mode == "RGBA":
            image_array = np.array(image)
            alpha_channel = image_array[:, :, 3]
            erosion_size = 3
            kernel = np.ones((erosion_size, erosion_size), np.uint8)
            # Erode the alpha mask and whiten everything outside it.
            alpha_channel = cv2.erode(alpha_channel, kernel, iterations=1)
            image_array[alpha_channel == 0, :3] = 255
            image_array[:, :, 3] = alpha_channel
            image = PILImage.fromarray(image_array)

            image_tensor = torch.tensor(np.array(image) / 255.0).to(self.device)
            alpha = image_tensor[:, :, 3:]
            rgb_target = image_tensor[:, :, :3]
        else:
            image_tensor = torch.tensor(np.array(image) / 255.0).to(self.device)
            alpha = torch.ones_like(image_tensor)[:, :, :1]
            rgb_target = image_tensor[:, :, :3]

        image = image.convert("RGB")

        # A private generator keeps the result reproducible without reseeding
        # the process-wide RNG, which torch.manual_seed() would do (this method
        # runs in a worker thread alongside other work).
        delight_generator = torch.Generator().manual_seed(42)

        image = self._delight_pipeline(
            prompt=self.config.delight_prompt,
            image=image,
            generator=delight_generator,
            height=512,
            width=512,
            num_inference_steps=self.config.delight_num_inference_steps,
            image_guidance_scale=self.config.delight_cfg_image,
            guidance_scale=self.config.delight_guidance_scale,
        ).images[0]

        image_tensor = torch.tensor(np.array(image) / 255.0).to(self.device)
        rgb_src = image_tensor[:, :, :3]
        image = _recorrect_rgb(rgb_src, rgb_target, alpha)
        # Composite onto a white background using the alpha channel.
        image = image[:, :, :3] * image[:, :, 3:] + torch.ones_like(image[:, :, :3]) * (
            1.0 - image[:, :, 3:]
        )
        image = PILImage.fromarray((image.cpu().numpy() * 255).astype(np.uint8))

        return image

    def _do_delight(self, batch: Req, server_args: ServerArgs) -> Req:
        """Recenter the input image and, if enabled and available, delight it.

        Best effort: a failed delight run logs a warning and keeps the
        recentered image.
        """
        from PIL import Image

        from sglang.multimodal_gen.runtime.utils.mesh3d_utils import recenter_image

        image = Image.open(batch.image_path)
        image = recenter_image(image)

        if not self.config.delight_enable:
            logger.info("Delight preprocessing disabled, using original image")
            batch.extra["delighted_image"] = image
            return batch

        self._load_delight_model(server_args)
        if self._delight_pipeline is not None:
            try:
                image = self._run_delight(image)
                logger.info("Image delight completed")
            except Exception as e:
                logger.warning(f"Image delight failed: {e}")

        batch.extra["delighted_image"] = image
        return batch

    # --- Multi-view rendering ---

    def _init_renderer(self):
        """Create the mesh renderer on first use."""
        if self._renderer_loaded:
            return

        from sglang.multimodal_gen.runtime.utils.mesh3d_utils import MeshRender

        self._renderer = MeshRender(
            default_resolution=self.config.paint_render_size,
            texture_size=self.config.paint_texture_size,
        )
        self._renderer_loaded = True
        logger.info("Mesh renderer initialized")

    def _render_multiview(self, mesh) -> tuple:
        """Render normal and position maps of ``mesh`` for every camera in the rig."""
        self._init_renderer()
        self._renderer.load_mesh(mesh)

        normal_maps = self._renderer.render_normal_multiview(
            self.CAMERA_ELEVS, self.CAMERA_AZIMS, use_abs_coor=True
        )
        position_maps = self._renderer.render_position_multiview(
            self.CAMERA_ELEVS, self.CAMERA_AZIMS
        )

        logger.info(f"Rendered {len(normal_maps)} views for texture generation")
        return normal_maps, position_maps

    # --- Forward ---

    def forward(self, batch: Req, server_args: ServerArgs) -> Req:
        """Run UV unwrap and delight concurrently, then render the multi-view maps.

        When ``batch.extra["_mesh_failed"]`` is set, placeholders are written
        instead and no work is done.
        """
        if batch.extra.get("_mesh_failed"):
            logger.warning("Mesh generation failed, skipping paint preprocessing")
            batch.extra["paint_mesh"] = None
            batch.extra["delighted_image"] = None
            batch.extra["normal_maps"] = []
            batch.extra["position_maps"] = []
            batch.extra["camera_azims"] = self.CAMERA_AZIMS
            batch.extra["camera_elevs"] = self.CAMERA_ELEVS
            batch.extra["view_weights"] = self.VIEW_WEIGHTS
            batch.extra["renderer"] = None
            return batch

        import concurrent.futures
        import copy

        # 1. UV unwrap + delight in parallel. The delight worker gets its own
        #    shallow copy of the request and of `extra`, so the two workers do
        #    not write into the same dict.
        batch_for_uv = batch
        batch_for_delight = copy.copy(batch)
        batch_for_delight.extra = batch.extra.copy()

        with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
            uv_future = executor.submit(self._do_uv_unwrap, batch_for_uv, server_args)
            delight_future = executor.submit(
                self._do_delight, batch_for_delight, server_args
            )
            # .result() re-raises anything a worker raised.
            uv_future.result()
            delight_future.result()

        batch.extra["paint_mesh"] = batch_for_uv.extra.get("paint_mesh")
        batch.extra["delighted_image"] = batch_for_delight.extra.get("delighted_image")

        # 2. Multi-view rendering
        normal_maps, position_maps = self._render_multiview(batch.extra["paint_mesh"])
        batch.extra["normal_maps"] = normal_maps
        batch.extra["position_maps"] = position_maps
        batch.extra["camera_azims"] = self.CAMERA_AZIMS
        batch.extra["camera_elevs"] = self.CAMERA_ELEVS
        batch.extra["view_weights"] = self.VIEW_WEIGHTS
        batch.extra["renderer"] = self._renderer

        return batch

    def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
        result = VerificationResult()
        result.add_check("shape_meshes", batch.extra.get("shape_meshes"), V.not_none)
        result.add_check("image_path", batch.image_path, V.not_none)
        return result

    def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
        result = VerificationResult()
        result.add_check("paint_mesh", batch.extra.get("paint_mesh"), V.not_none)
        result.add_check(
            "delighted_image", batch.extra.get("delighted_image"), V.not_none
        )
        result.add_check("normal_maps", batch.extra.get("normal_maps"), V.is_list)
        result.add_check("position_maps", batch.extra.get("position_maps"), V.is_list)
        result.add_check("renderer", batch.extra.get("renderer"), V.not_none)
        return result
StageParallelismType: + return StageParallelismType.MAIN_RANK_ONLY + + def _load_paint_models(self, server_args: ServerArgs) -> None: + """Load paint models from pre-resolved local path (no network).""" + if self._loaded: + return + if self.paint_dir is None: + logger.warning("No paint model directory resolved, skipping") + self._loaded = True + return + try: + self._do_load_paint(server_args) + logger.info("Paint pipeline loaded successfully") + except Exception as e: + logger.warning("Failed to load paint pipeline: %s", e) + self.vae = None + self.transformer = None + self.scheduler = None + self._loaded = True + + def _do_load_paint(self, server_args: ServerArgs) -> None: + import json + + from diffusers import AutoencoderKL + from diffusers.image_processor import VaeImageProcessor + + from sglang.multimodal_gen.runtime.models.dits.hunyuan3d import ( + UNet2p5DConditionModel, + ) + + local_path = self.paint_dir + logger.info("Loading paint model from %s", local_path) + vae_dir = os.path.join(local_path, "vae") + with open(os.path.join(vae_dir, "config.json"), "r") as f: + vae_config = json.load(f) + vae_config = {k: v for k, v in vae_config.items() if not k.startswith("_")} + self.vae = AutoencoderKL(**vae_config) + st_path = os.path.join(vae_dir, "diffusion_pytorch_model.safetensors") + bin_path = os.path.join(vae_dir, "diffusion_pytorch_model.bin") + if os.path.exists(st_path): + from safetensors.torch import load_file + + state_dict = load_file(st_path) + elif os.path.exists(bin_path): + state_dict = torch.load(bin_path, map_location="cpu", weights_only=True) + else: + raise FileNotFoundError(f"No VAE weights in {vae_dir}") + self.vae.load_state_dict(state_dict) + self.vae = self.vae.to(device=self.device, dtype=torch.float16).eval() + self.transformer = UNet2p5DConditionModel.from_pretrained( + os.path.join(local_path, "unet"), + torch_dtype=torch.float16, + ).to(self.device) + self.is_turbo = bool(getattr(self.config, "paint_turbo_mode", False)) + 
sched_path = os.path.join(local_path, "scheduler", "scheduler_config.json") + with open(sched_path, "r") as f: + sched_cfg = json.load(f) + if self.is_turbo: + from diffusers import LCMScheduler + + self.scheduler = LCMScheduler.from_config(sched_cfg) + else: + from diffusers import EulerAncestralDiscreteScheduler + + self.scheduler = EulerAncestralDiscreteScheduler.from_config( + sched_cfg, timestep_spacing="trailing" + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.solver = DDIMSolver( + self.scheduler.alphas_cumprod.cpu().numpy(), + timesteps=self.scheduler.config.num_train_timesteps, + ddim_timesteps=30, + ).to(self.device) + if server_args.enable_torch_compile: + compile_mode = os.environ.get( + "SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs" + ) + logger.info("Compiling paint transformer with mode: %s", compile_mode) + self.transformer.compile(mode=compile_mode, fullgraph=False, dynamic=None) + + def _convert_pil_list_to_tensor( + self, images: list, device: torch.device + ) -> torch.Tensor: + bg_c = [1.0, 1.0, 1.0] + images_tensor = [] + for batch_imgs in images: + view_imgs = [] + for pil_img in batch_imgs: + if pil_img.mode == "L": + pil_img = pil_img.point( + lambda x: 255 if x > 1 else 0, mode="1" + ).convert("RGB") + img = np.asarray(pil_img, dtype=np.float32) / 255.0 + if img.shape[2] > 3: + alpha = img[:, :, 3:] + img = img[:, :, :3] * alpha + bg_c * (1 - alpha) + img = ( + torch.from_numpy(img) + .permute(2, 0, 1) + .unsqueeze(0) + .contiguous() + .to(device=device, dtype=self.vae.dtype) + ) + view_imgs.append(img) + view_imgs = torch.cat(view_imgs, dim=0) + images_tensor.append(view_imgs.unsqueeze(0)) + return torch.cat(images_tensor, dim=0) + + @torch.no_grad() + def _encode_images(self, images: torch.Tensor) -> torch.Tensor: + batch_size = images.shape[0] + images = rearrange(images, "b n c h w -> (b n) c h w") + 
def _prepare_denoising_inputs(
    self,
    batch: Req,
    server_args: ServerArgs,
) -> dict[str, Any]:
    """Assemble everything the denoising loop needs for texture generation.

    Reads the rendered maps, camera rig and delighted image from
    ``batch.extra``, encodes them with the VAE, builds the conditioning
    kwargs, selects timesteps and samples the initial latents.

    Returns:
        A dict with keys ``timesteps``, ``latents``, ``prompt_embeds``,
        ``negative_prompt_embeds``, ``model_kwargs``, ``num_in_batch``,
        ``num_inference_steps``, ``guidance_scale``, ``do_cfg``,
        ``generator`` and ``num_channels_latents``.
    """
    import random

    from diffusers.utils.torch_utils import randn_tensor

    device = self.device
    normal_maps = batch.extra["normal_maps"]
    position_maps = batch.extra["position_maps"]
    camera_azims = batch.extra["camera_azims"]
    camera_elevs = batch.extra["camera_elevs"]

    num_steps = self.config.paint_num_inference_steps
    guidance_scale = self.config.paint_guidance_scale
    render_size = self.config.paint_resolution
    # One latent per rendered view.
    num_in_batch = len(normal_maps)

    # Fixed seed: Python, NumPy and torch global RNGs are all reseeded here,
    # in addition to the dedicated generator used for the initial latents.
    seed = 0
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    generator = torch.Generator(device=device).manual_seed(seed)

    # Reference image(s): force RGB, scale to [0, 1], stack to (1, n_ref, C, H, W).
    image = batch.extra["delighted_image"]
    if not isinstance(image, list):
        image = [image]
    image = [to_rgb_image(img) for img in image]

    image_vae = [
        torch.tensor(np.array(img, dtype=np.float32) / 255.0) for img in image
    ]
    image_vae = [
        iv.unsqueeze(0).permute(0, 3, 1, 2).unsqueeze(0) for iv in image_vae
    ]
    image_vae = torch.cat(image_vae, dim=1).to(device=device, dtype=self.vae.dtype)
    ref_latents = self._encode_images(image_vae)

    # Lists of rendered maps are resized to the paint resolution and turned
    # into a single-batch tensor; non-list inputs are used as given.
    target_size = render_size
    if isinstance(normal_maps, list):
        normal_maps = [
            (
                img.resize((target_size, target_size))
                if hasattr(img, "resize")
                else img
            )
            for img in normal_maps
        ]
        normal_maps = self._convert_pil_list_to_tensor([normal_maps], device)
    if isinstance(position_maps, list):
        position_maps = [
            (
                img.resize((target_size, target_size))
                if hasattr(img, "resize")
                else img
            )
            for img in position_maps
        ]
        position_maps = self._convert_pil_list_to_tensor([position_maps], device)

    normal_imgs = (
        self._encode_images(normal_maps) if normal_maps is not None else None
    )
    position_imgs = (
        self._encode_images(position_maps) if position_maps is not None else None
    )

    # Integer camera indices for the generated views; the reference view uses index 0.
    camera_info = [
        self._compute_camera_index(azim, elev)
        for azim, elev in zip(camera_azims, camera_elevs)
    ]
    camera_info_gen = torch.tensor([camera_info], device=device, dtype=torch.int64)
    camera_info_ref = torch.tensor([[0]], device=device, dtype=torch.int64)

    # Classifier-free guidance is never used in turbo mode.
    do_cfg = guidance_scale > 1 and not self.is_turbo

    # Turbo mode additionally derives attention masks / voxel indices from the position maps.
    if self.is_turbo and position_maps is not None:
        from sglang.multimodal_gen.runtime.models.dits.hunyuan3d import (
            compute_multi_resolution_discrete_voxel_indice,
            compute_multi_resolution_mask,
        )

        position_attn_mask = compute_multi_resolution_mask(position_maps)
        position_voxel_indices = compute_multi_resolution_discrete_voxel_indice(
            position_maps
        )
    else:
        position_attn_mask = None
        position_voxel_indices = None

    if do_cfg:
        # Batch is doubled as [negative, positive]: the negative half uses
        # zeroed reference latents and a ref_scale of 0.
        negative_ref_latents = torch.zeros_like(ref_latents)
        ref_latents = torch.cat([negative_ref_latents, ref_latents])
        ref_scale = torch.as_tensor([0.0, 1.0]).to(ref_latents)
        if normal_imgs is not None:
            normal_imgs = torch.cat((normal_imgs, normal_imgs))
        if position_imgs is not None:
            position_imgs = torch.cat((position_imgs, position_imgs))
        if position_maps is not None:
            position_maps = torch.cat((position_maps, position_maps))
        camera_info_gen = torch.cat((camera_info_gen, camera_info_gen))
        camera_info_ref = torch.cat((camera_info_ref, camera_info_ref))
    else:
        ref_scale = None

    # Only conditioning inputs that actually exist are forwarded.
    model_kwargs = {
        "ref_latents": ref_latents,
        "num_in_batch": num_in_batch,
    }
    if ref_scale is not None:
        model_kwargs["ref_scale"] = ref_scale
    if normal_imgs is not None:
        model_kwargs["normal_imgs"] = normal_imgs
    if position_imgs is not None:
        model_kwargs["position_imgs"] = position_imgs
    if position_maps is not None:
        model_kwargs["position_maps"] = position_maps
    model_kwargs["camera_info_gen"] = camera_info_gen
    model_kwargs["camera_info_ref"] = camera_info_ref
    if position_attn_mask is not None:
        model_kwargs["position_attn_mask"] = position_attn_mask
    if position_voxel_indices is not None:
        model_kwargs["position_voxel_indices"] = position_voxel_indices

    # The prompt embedding comes from the transformer itself; the negative one is all zeros.
    prompt_embeds = self.transformer.learned_text_clip_gen.repeat(1, 1, 1)
    negative_prompt_embeds = torch.zeros_like(prompt_embeds)

    if self.is_turbo:
        # Every third DDIM timestep, counting down from index 29.
        # NOTE(review): num_steps keeps the configured value in this branch even
        # though the schedule length is set by the index above - confirm that
        # consumers of "num_inference_steps" expect that.
        bsz = 3
        index = torch.arange(29, -1, -bsz, device=device).long()
        timesteps = self.solver.ddim_timesteps[index]
        self.scheduler.set_timesteps(timesteps=timesteps.cpu(), device=device)
        timesteps = self.scheduler.timesteps
    else:
        timesteps, num_steps = retrieve_timesteps(
            self.scheduler, num_steps, device, None, None
        )

    # Initial noise: one latent per view at the VAE-downscaled resolution,
    # scaled by the scheduler's initial sigma.
    num_channels_latents = self.transformer.config.in_channels
    latent_shape = (
        num_in_batch,
        num_channels_latents,
        render_size // self.vae_scale_factor,
        render_size // self.vae_scale_factor,
    )
    latents = randn_tensor(
        latent_shape, generator=generator, device=device, dtype=prompt_embeds.dtype
    )
    latents = latents * self.scheduler.init_noise_sigma

    return {
        "timesteps": timesteps,
        "latents": latents,
        "prompt_embeds": prompt_embeds,
        "negative_prompt_embeds": negative_prompt_embeds,
        "model_kwargs": model_kwargs,
        "num_in_batch": num_in_batch,
        "num_inference_steps": num_steps,
        "guidance_scale": guidance_scale,
        "do_cfg": do_cfg,
        "generator": generator,
        "num_channels_latents": num_channels_latents,
    }
prompt_embeds: torch.Tensor, + negative_prompt_embeds: torch.Tensor, + model_kwargs: dict[str, Any], + num_in_batch: int, + guidance_scale: float, + do_cfg: bool, + generator: torch.Generator, + num_channels_latents: int, + ) -> torch.Tensor: + import inspect + + if do_cfg: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + extra_step_kwargs = {} + if "eta" in inspect.signature(self.scheduler.step).parameters: + extra_step_kwargs["eta"] = 0.0 + if "generator" in inspect.signature(self.scheduler.step).parameters: + extra_step_kwargs["generator"] = generator + + for step_idx, t in enumerate(timesteps): + latents = rearrange(latents, "(b n) c h w -> b n c h w", n=num_in_batch) + latent_model_input = torch.cat([latents] * 2) if do_cfg else latents + latent_model_input = rearrange( + latent_model_input, "b n c h w -> (b n) c h w" + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = rearrange( + latent_model_input, "(b n) c h w -> b n c h w", n=num_in_batch + ) + + with set_forward_context( + current_timestep=step_idx, + attn_metadata=None, + ): + noise_pred = self.transformer( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=None, + cross_attention_kwargs=None, + added_cond_kwargs=None, + return_dict=False, + **model_kwargs, + )[0] + + latents = rearrange(latents, "b n c h w -> (b n) c h w") + + if do_cfg: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + latents = self.scheduler.step( + noise_pred, + t, + latents[:, :num_channels_latents, :, :], + **extra_step_kwargs, + return_dict=False, + )[0] + + return latents + + @torch.no_grad() + def _decode_latents(self, latents: torch.Tensor) -> list: + image = self.vae.decode( + latents / self.vae.config.scaling_factor, return_dict=False + )[0] + return self.image_processor.postprocess(image, output_type="pil") + 
+ def forward(self, batch: Req, server_args: ServerArgs) -> Req: + if batch.extra.get("_mesh_failed"): + logger.warning("Mesh generation failed, skipping paint texgen") + batch.extra["multiview_textures"] = [] + return batch + + self._load_paint_models(server_args) + + delighted_image = batch.extra["delighted_image"] + normal_maps = batch.extra["normal_maps"] + + if self.transformer is not None: + try: + prepared = self._prepare_denoising_inputs(batch, server_args) + + latents = self._denoise_loop( + timesteps=prepared["timesteps"], + latents=prepared["latents"], + prompt_embeds=prepared["prompt_embeds"], + negative_prompt_embeds=prepared["negative_prompt_embeds"], + model_kwargs=prepared["model_kwargs"], + num_in_batch=prepared["num_in_batch"], + guidance_scale=prepared["guidance_scale"], + do_cfg=prepared["do_cfg"], + generator=prepared["generator"], + num_channels_latents=prepared["num_channels_latents"], + ) + + multiview_textures = self._decode_latents(latents) + logger.info( + "Paint pipeline generated %d textures", len(multiview_textures) + ) + + except Exception as e: + logger.error(f"Paint pipeline execution failed: {e}") + import traceback + + traceback.print_exc() + render_size = self.config.paint_resolution + multiview_textures = [ + delighted_image.resize((render_size, render_size)) + for _ in range(len(normal_maps)) + ] + else: + logger.warning( + "Paint pipeline not available, using reference image for all views" + ) + render_size = self.config.paint_resolution + multiview_textures = [ + delighted_image.resize((render_size, render_size)) + for _ in range(len(normal_maps)) + ] + + batch.extra["multiview_textures"] = multiview_textures + logger.info(f"Generated {len(multiview_textures)} texture views") + return batch + + def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + if batch.extra.get("_mesh_failed"): + return VerificationResult() + result = VerificationResult() + result.add_check( + "delighted_image", 
batch.extra.get("delighted_image"), V.not_none + ) + result.add_check("normal_maps", batch.extra.get("normal_maps"), V.is_list) + result.add_check("position_maps", batch.extra.get("position_maps"), V.is_list) + result.add_check("camera_azims", batch.extra.get("camera_azims"), V.is_list) + result.add_check("camera_elevs", batch.extra.get("camera_elevs"), V.is_list) + return result + + def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + result = VerificationResult() + result.add_check( + "multiview_textures", batch.extra.get("multiview_textures"), V.is_list + ) + return result + + +# Stage 3: Postprocess (texture baking + mesh export) +class Hunyuan3DPaintPostprocessStage(PipelineStage): + """Texture baking from multi-view images and final mesh export.""" + + @property + def parallelism_type(self) -> StageParallelismType: + return StageParallelismType.MAIN_RANK_ONLY + + def __init__(self, config: Hunyuan3D2PipelineConfig) -> None: + super().__init__() + self.config = config + + def forward(self, batch: Req, server_args: ServerArgs) -> OutputBatch: + if batch.extra.get("_mesh_failed"): + logger.warning("Mesh generation failed, skipping paint postprocess") + return OutputBatch(output_file_paths=[], metrics=batch.metrics) + + renderer = batch.extra["renderer"] + multiview_textures = batch.extra["multiview_textures"] + camera_elevs = batch.extra["camera_elevs"] + camera_azims = batch.extra["camera_azims"] + view_weights = batch.extra["view_weights"] + + render_size = getattr(self.config, "paint_render_size", 2048) + resized_textures = [] + for tex in multiview_textures: + if hasattr(tex, "resize"): + resized_textures.append(tex.resize((render_size, render_size))) + else: + resized_textures.append(tex) + + try: + texture, mask = renderer.bake_from_multiview( + resized_textures, + camera_elevs, + camera_azims, + view_weights, + method="fast", + ) + + mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype("uint8") + texture = 
renderer.texture_inpaint(texture, mask_np) + + renderer.set_texture(texture) + textured_mesh = renderer.save_mesh() + logger.info("Texture baking completed") + except Exception as e: + logger.error(f"Texture baking failed: {e}") + textured_mesh = batch.extra["paint_mesh"] + + obj_path = batch.extra["shape_obj_path"] + return_path = batch.extra["shape_return_path"] + + try: + textured_mesh.export(obj_path) + if self.config.paint_save_glb: + glb_path = obj_path[:-4] + ".glb" + textured_mesh.export(glb_path) + return_path = glb_path + self._cleanup_obj_artifacts(obj_path) + except Exception as e: + logger.error(f"Mesh export failed: {e}") + + return OutputBatch(output_file_paths=[return_path], metrics=batch.metrics) + + @staticmethod + def _cleanup_obj_artifacts(obj_path: str) -> None: + """Remove OBJ file and trimesh-generated material artifacts.""" + obj_dir = os.path.dirname(obj_path) or "." + targets = [obj_path] + for f in os.listdir(obj_dir): + if f.endswith(".mtl") or (f.startswith("material") and f.endswith(".png")): + targets.append(os.path.join(obj_dir, f)) + for path in targets: + try: + os.remove(path) + except OSError: + pass + + def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + if batch.extra.get("_mesh_failed"): + return VerificationResult() + result = VerificationResult() + result.add_check("renderer", batch.extra.get("renderer"), V.not_none) + result.add_check( + "multiview_textures", batch.extra.get("multiview_textures"), V.is_list + ) + result.add_check("camera_elevs", batch.extra.get("camera_elevs"), V.is_list) + result.add_check("camera_azims", batch.extra.get("camera_azims"), V.is_list) + result.add_check("view_weights", batch.extra.get("view_weights"), V.is_list) + return result + + +__all__ = [ + "Hunyuan3DPaintPreprocessStage", + "Hunyuan3DPaintTexGenStage", + "Hunyuan3DPaintPostprocessStage", +] diff --git a/sglang/python/sglang/multimodal_gen/runtime/pipelines_core/stages/hunyuan3d_shape.py 
b/sglang/python/sglang/multimodal_gen/runtime/pipelines_core/stages/hunyuan3d_shape.py new file mode 100644 index 0000000000000000000000000000000000000000..6c3be73c1af0042991c6a2ee742f6f4683a45e5b --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/pipelines_core/stages/hunyuan3d_shape.py @@ -0,0 +1,527 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Hunyuan3D shape generation stages. + +Four-stage pipeline: BeforeDenoising -> Denoising -> Export -> Save. +""" + +from __future__ import annotations + +import os +from typing import Any + +import numpy as np +import torch + +from sglang.multimodal_gen.configs.pipeline_configs.hunyuan3d import ( + Hunyuan3D2PipelineConfig, +) +from sglang.multimodal_gen.runtime.loader.component_loaders.transformer_loader import ( + TransformerLoader, +) +from sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context +from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch, Req +from sglang.multimodal_gen.runtime.pipelines_core.stages.base import PipelineStage +from sglang.multimodal_gen.runtime.pipelines_core.stages.denoising import DenoisingStage +from sglang.multimodal_gen.runtime.pipelines_core.stages.validators import ( + StageValidators as V, +) +from sglang.multimodal_gen.runtime.pipelines_core.stages.validators import ( + VerificationResult, +) +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.runtime.utils.mesh3d_utils import export_to_trimesh + +logger = init_logger(__name__) + + +def retrieve_timesteps( + scheduler, + num_inference_steps=None, + device=None, + timesteps=None, + sigmas=None, + **kwargs, +): + """Retrieve timesteps from scheduler.""" + import inspect + + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of timesteps or sigmas can be passed.") + + if timesteps is not None: + accepts_timesteps = 
"timesteps" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) + if not accepts_timesteps: + raise ValueError( + f"Scheduler {scheduler.__class__} doesn't support custom timesteps." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + + elif sigmas is not None: + accepts_sigmas = "sigmas" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) + if not accepts_sigmas: + raise ValueError( + f"Scheduler {scheduler.__class__} doesn't support custom sigmas." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + + return timesteps, num_inference_steps + + +def _prepare_shape_image(image_processor, image, mask=None) -> dict: + """Prepare shape image for conditioning.""" + if isinstance(image, torch.Tensor) and isinstance(mask, torch.Tensor): + return {"image": image, "mask": mask} + + if isinstance(image, str) and not os.path.exists(image): + raise FileNotFoundError(f"Couldn't find image at path {image}") + + if not isinstance(image, list): + image = [image] + + outputs = [image_processor(img) for img in image] + cond_input = {k: [] for k in outputs[0].keys()} + for output in outputs: + for key, value in output.items(): + cond_input[key].append(value) + for key, value in cond_input.items(): + if isinstance(value[0], torch.Tensor): + cond_input[key] = torch.cat(value, dim=0) + return cond_input + + +def _move_to_device(payload, device, dtype): + """Recursively move tensors in payload to specified device and dtype.""" + if isinstance(payload, torch.Tensor): + return payload.to(device=device, dtype=dtype) + if isinstance(payload, dict): + return {k: _move_to_device(v, device, dtype) for k, v in payload.items()} + if 
isinstance(payload, list): + return [_move_to_device(v, device, dtype) for v in payload] + return payload + + +class Hunyuan3DShapeBeforeDenoisingStage(PipelineStage): + """Monolithic pre-processing stage for Hunyuan3D shape generation. + + Consolidates input validation, image preprocessing, conditioning, and + latent/timestep preparation into a single stage. + """ + + def __init__( + self, + image_processor: Any, + conditioner: Any, + vae: Any, + model: Any, + scheduler: Any, + config: Hunyuan3D2PipelineConfig, + ) -> None: + super().__init__() + self.image_processor = image_processor + self.conditioner = conditioner + self.vae = vae + self.model = model + self.scheduler = scheduler + self.config = config + + def _validate_input(self, batch: Req, server_args: ServerArgs) -> None: + if batch.image_path is None: + raise ValueError("Hunyuan3D requires 'image_path' input.") + if isinstance(batch.image_path, list): + if len(batch.image_path) != 1: + raise ValueError("Hunyuan3D only supports a single image input.") + batch.image_path = batch.image_path[0] + if not isinstance(batch.image_path, str): + raise ValueError( + f"Hunyuan3D expects image_path as str, got {type(batch.image_path)}" + ) + if not os.path.exists(batch.image_path): + raise FileNotFoundError(f"Image path not found: {batch.image_path}") + if batch.num_outputs_per_prompt != 1: + raise ValueError("Hunyuan3D only supports num_outputs_per_prompt=1.") + + def _prepare_latents(self, batch_size, dtype, device, generator): + from diffusers.utils.torch_utils import randn_tensor + + shape = (batch_size, *self.vae.latent_shape) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents * getattr(self.scheduler, "init_noise_sigma", 1.0) + + def forward(self, batch: Req, server_args: ServerArgs) -> Req: + # 1. Input validation + self._validate_input(batch, server_args) + + # 2. 
Image preprocessing + cond_inputs = _prepare_shape_image(self.image_processor, batch.image_path) + image = cond_inputs.pop("image") + + device = self.device + dtype = next(self.model.parameters()).dtype + image = _move_to_device(image, device, dtype) + cond_inputs = _move_to_device(cond_inputs, device, dtype) + + # 3. Conditioning with CFG + do_cfg = batch.guidance_scale >= 0 and not ( + hasattr(self.model, "guidance_embed") and self.model.guidance_embed is True + ) + + cond = self.conditioner(image=image, **cond_inputs) + if do_cfg: + un_cond = self.conditioner.unconditional_embedding( + image.shape[0], **cond_inputs + ) + + def cat_recursive(a, b): + if isinstance(a, torch.Tensor): + return torch.cat([a, b], dim=0).to(dtype) + out = {} + for key in a.keys(): + out[key] = cat_recursive(a[key], b[key]) + return out + + cond = cat_recursive(cond, un_cond) + + # 4. Latent and timestep preparation + batch_size = image.shape[0] + sigmas = np.linspace(0, 1, batch.num_inference_steps) + timesteps, _ = retrieve_timesteps( + self.scheduler, + batch.num_inference_steps, + device, + sigmas=sigmas, + ) + + generator = batch.generator + if generator is None and batch.seed is not None: + generator = torch.Generator(device=device).manual_seed(batch.seed) + + latents = self._prepare_latents(batch_size, dtype, device, generator) + + guidance = None + if hasattr(self.model, "guidance_embed") and self.model.guidance_embed is True: + guidance = torch.tensor( + [batch.guidance_scale] * batch_size, device=device, dtype=dtype + ) + + # 5. 
Populate batch + batch.prompt_embeds = [cond] + batch.do_classifier_free_guidance = do_cfg + batch.timesteps = timesteps + batch.latents = latents + batch.extra["shape_guidance"] = guidance + batch.extra["shape_image"] = image + return batch + + def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + result = VerificationResult() + result.add_check("image_path", batch.image_path, V.not_none) + result.add_check( + "num_inference_steps", batch.num_inference_steps, V.positive_int + ) + return result + + def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + result = VerificationResult() + result.add_check("timesteps", batch.timesteps, [V.is_tensor, V.min_dims(1)]) + result.add_check("latents", batch.latents, V.is_tensor) + result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty) + return result + + +class Hunyuan3DShapeDenoisingStage(DenoisingStage): + """Denoising stage for Hunyuan3D shape generation.""" + + def __init__(self, transformer: Any, scheduler: Any, **kwargs) -> None: + super().__init__(transformer=transformer, scheduler=scheduler, **kwargs) + + def _prepare_denoising_loop(self, batch: Req, server_args: ServerArgs): + """Prepare Hunyuan3D-specific variables for the base denoising loop.""" + assert self.transformer is not None + pipeline = self.pipeline() if self.pipeline else None + cache_dit_num_inference_steps = batch.extra.get( + "cache_dit_num_inference_steps", batch.num_inference_steps + ) + if not server_args.model_loaded["transformer"]: + loader = TransformerLoader() + self.transformer = loader.load( + server_args.model_paths["transformer"], server_args, "transformer" + ) + self._maybe_enable_cache_dit(cache_dit_num_inference_steps, batch) + self._maybe_enable_torch_compile(self.transformer) + if pipeline: + pipeline.add_module("transformer", self.transformer) + server_args.model_loaded["transformer"] = True + else: + 
self._maybe_enable_cache_dit(cache_dit_num_inference_steps, batch) + + timesteps = batch.timesteps + if timesteps is None: + raise ValueError("Timesteps must be provided") + + latents = batch.latents + if latents is None: + raise ValueError("Latents must be provided") + + cond = batch.prompt_embeds[0] if batch.prompt_embeds else None + if cond is None: + raise ValueError("Conditioning (prompt_embeds) must be provided") + + if batch.raw_latent_shape is None: + batch.raw_latent_shape = latents.shape + + guidance = batch.extra.get("shape_guidance") + num_inference_steps = batch.num_inference_steps + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + extra_step_kwargs = self.prepare_extra_func_kwargs( + self.scheduler.step, + {"generator": batch.generator, "eta": batch.eta}, + ) + + target_dtype = next(self.transformer.parameters()).dtype + autocast_enabled = False + + pos_cond_kwargs = {"encoder_hidden_states": cond} + neg_cond_kwargs = {} + + return { + "extra_step_kwargs": extra_step_kwargs, + "target_dtype": target_dtype, + "autocast_enabled": autocast_enabled, + "timesteps": timesteps, + "num_inference_steps": num_inference_steps, + "num_warmup_steps": num_warmup_steps, + "image_kwargs": {}, + "pos_cond_kwargs": pos_cond_kwargs, + "neg_cond_kwargs": neg_cond_kwargs, + "latents": latents, + "prompt_embeds": batch.prompt_embeds, + "neg_prompt_embeds": None, + "boundary_timestep": None, + "z": None, + "reserved_frames_mask": None, + "seq_len": None, + "guidance": guidance, + } + + def _predict_noise( + self, + current_model, + latent_model_input, + timestep, + target_dtype, + guidance: torch.Tensor, + **kwargs, + ): + """Hunyuan3D-specific noise prediction with normalized timestep.""" + cond = kwargs.get("encoder_hidden_states") + timestep_norm = timestep / self.scheduler.config.num_train_timesteps + return current_model(latent_model_input, timestep_norm, cond, guidance=guidance) + + def _predict_noise_with_cfg( + self, + 
current_model, + latent_model_input: torch.Tensor, + timestep, + batch: Req, + timestep_index: int, + attn_metadata, + target_dtype, + current_guidance_scale, + image_kwargs: dict[str, Any], + pos_cond_kwargs: dict[str, Any], + neg_cond_kwargs: dict[str, Any], + server_args, + guidance, + latents, + ): + """Hunyuan3D-specific CFG: concat latents, single forward, then split.""" + cond = pos_cond_kwargs.get("encoder_hidden_states") + do_cfg = batch.do_classifier_free_guidance + + if do_cfg: + latent_input = torch.cat([latent_model_input] * 2) + else: + latent_input = latent_model_input + + timestep_expanded = timestep.expand(latent_input.shape[0]).to(latents.dtype) + + with set_forward_context( + current_timestep=timestep_index, + attn_metadata=attn_metadata, + forward_batch=batch, + ): + noise_pred = self._predict_noise( + current_model=current_model, + latent_model_input=latent_input, + timestep=timestep_expanded, + target_dtype=target_dtype, + guidance=guidance, + encoder_hidden_states=cond, + ) + + if do_cfg: + noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + current_guidance_scale * ( + noise_pred_cond - noise_pred_uncond + ) + + return noise_pred + + def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + result = VerificationResult() + result.add_check("timesteps", batch.timesteps, [V.is_tensor, V.min_dims(1)]) + result.add_check("latents", batch.latents, V.is_tensor) + result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty) + result.add_check( + "num_inference_steps", batch.num_inference_steps, V.positive_int + ) + result.add_check("guidance_scale", batch.guidance_scale, V.non_negative_float) + return result + + def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + result = VerificationResult() + result.add_check("latents", batch.latents, V.is_tensor) + return result + + +class Hunyuan3DShapeExportStage(PipelineStage): + """VAE decoding and 
mesh extraction stage.""" + + def __init__(self, vae: Any, config: Hunyuan3D2PipelineConfig) -> None: + super().__init__() + self.vae = vae + self.config = config + + def forward(self, batch: Req, server_args: ServerArgs) -> Req: + if self.config.shape_mc_algo is not None: + try: + from sglang.multimodal_gen.runtime.models.vaes.hunyuan3d_vae import ( + SurfaceExtractors, + ) + + self.vae.surface_extractor = SurfaceExtractors[ + self.config.shape_mc_algo + ]() + except ImportError: + logger.warning( + f"Could not load SurfaceExtractors for mc_algo={self.config.shape_mc_algo}" + ) + + latents = batch.latents + + if self.config.shape_output_type != "latent": + latents = 1.0 / self.vae.scale_factor * latents + latents = self.vae(latents) + + outputs = self.vae.latents2mesh( + latents, + bounds=self.config.shape_box_v, + mc_level=self.config.shape_mc_level, + num_chunks=self.config.shape_num_chunks, + octree_resolution=self.config.shape_octree_resolution, + mc_algo=self.config.shape_mc_algo, + enable_pbar=False, + ) + else: + outputs = latents + + if self.config.shape_output_type == "trimesh": + outputs = export_to_trimesh(outputs) + + batch.extra["shape_meshes"] = outputs + return batch + + def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + result = VerificationResult() + result.add_check("latents", batch.latents, V.is_tensor) + return result + + def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + result = VerificationResult() + result.add_check("shape_meshes", batch.extra.get("shape_meshes"), V.not_none) + return result + + +class Hunyuan3DShapeSaveStage(PipelineStage): + """Mesh file export and output decision stage.""" + + def __init__(self, config: Hunyuan3D2PipelineConfig) -> None: + super().__init__() + self.config = config + + def _get_output_paths(self, batch: Req) -> tuple[str, str]: + output_path = batch.output_file_path() or os.path.join( + batch.output_path, "output.obj" + ) + if 
output_path.endswith(".glb"): + obj_path = output_path[:-4] + ".obj" + return obj_path, output_path + if output_path.endswith(".obj"): + return output_path, output_path + return output_path + ".obj", output_path + ".obj" + + def forward(self, batch: Req, server_args: ServerArgs) -> Req | OutputBatch: + mesh_outputs = batch.extra["shape_meshes"] + mesh = mesh_outputs[0] if isinstance(mesh_outputs, list) else mesh_outputs + if isinstance(mesh, list): + mesh = mesh[0] + + if mesh is None: + if batch.is_warmup: + logger.info( + "Skipping mesh export during warmup " + "(surface extraction returned None)" + ) + batch.extra["_mesh_failed"] = True + if self.config.paint_enable: + return batch + return OutputBatch(output_file_paths=[], metrics=batch.metrics) + raise RuntimeError( + "Mesh generation failed: surface extraction returned None. " + "The surface level may be outside the volume data range." + ) + + obj_path, return_path = self._get_output_paths(batch) + output_dir = os.path.dirname(obj_path) + if output_dir: + os.makedirs(output_dir, exist_ok=True) + mesh.export(obj_path) + + batch.extra["shape_obj_path"] = obj_path + batch.extra["shape_return_path"] = return_path + + if self.config.paint_enable: + return batch + + if return_path.endswith(".glb"): + return_path = obj_path + return OutputBatch(output_file_paths=[return_path], timings=batch.timings) + + def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + result = VerificationResult() + result.add_check("shape_meshes", batch.extra.get("shape_meshes"), V.not_none) + return result + + +__all__ = [ + "retrieve_timesteps", + "Hunyuan3DShapeBeforeDenoisingStage", + "Hunyuan3DShapeDenoisingStage", + "Hunyuan3DShapeExportStage", + "Hunyuan3DShapeSaveStage", +] diff --git a/sglang/python/sglang/multimodal_gen/runtime/pipelines_core/stages/image_encoding.py b/sglang/python/sglang/multimodal_gen/runtime/pipelines_core/stages/image_encoding.py new file mode 100644 index 
class ImageEncodingStage(PipelineStage):
    """
    Stage for encoding image prompts into embeddings for diffusion models.

    The condition image is run through ``image_processor`` and then through
    either ``image_encoder`` (result appended to ``batch.image_embeds``) or,
    when no image encoder is set, ``text_encoder`` (result appended to
    ``batch.prompt_embeds`` and, under classifier-free guidance, to
    ``batch.negative_prompt_embeds``).
    """

    def __init__(
        self,
        image_processor,
        image_encoder=None,
        text_encoder=None,
    ) -> None:
        """
        Initialize the image encoding stage.

        Args:
            image_processor: Callable that converts the condition image into
                model inputs (called with ``images=...`` and ``return_tensors="pt"``).
            image_encoder: Optional encoder called with the processed inputs.
            text_encoder: Optional encoder to encode input_ids and pixel values;
                only consulted when ``image_encoder`` is not set.
        """
        super().__init__()
        self.image_processor = image_processor
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder

    def load_model(self):
        """Move the processor/encoder to the local device when CPU offload is enabled."""
        if self.server_args.image_encoder_cpu_offload:
            device = get_local_torch_device()
            self.move_to_device(device)

    def offload_model(self):
        """Move the processor/encoder back to the CPU when CPU offload is enabled."""
        if self.server_args.image_encoder_cpu_offload:
            self.move_to_device("cpu")

    def move_to_device(self, device):
        """Call ``.to(device)`` on each managed component that supports it.

        Does nothing when FSDP inference is enabled.
        """
        if self.server_args.use_fsdp_inference:
            return
        fields = [
            "image_processor",
            "image_encoder",
        ]
        for field in fields:
            processor = getattr(self, field, None)
            if processor and hasattr(processor, "to"):
                setattr(self, field, processor.to(device))

    def encoding_qwen_image_edit(self, outputs, image_inputs):
        """Post-process text-encoder outputs into prompt embeddings."""
        # encoder hidden state
        prompt_embeds = qwen_image_postprocess_text(outputs, image_inputs, 64)
        return prompt_embeds

    @torch.no_grad()
    def forward(
        self,
        batch: Req,
        server_args: ServerArgs,
    ) -> Req:
        """
        Encode the condition image into encoder hidden states.

        Returns ``batch`` unchanged when it carries no condition image.
        """

        if batch.condition_image is None:
            return batch
        cuda_device = get_local_torch_device()

        self.load_model()
        # try/finally so that a failure while encoding still offloads the model.
        try:
            image = batch.condition_image

            image_processor_kwargs = (
                server_args.pipeline_config.prepare_image_processor_kwargs(batch)
            )

            image_inputs = self.image_processor(
                images=image, return_tensors="pt", **image_processor_kwargs
            ).to(cuda_device)
            if self.image_encoder:
                # if an image encoder is provided
                with set_forward_context(current_timestep=0, attn_metadata=None):
                    outputs = self.image_encoder(
                        **image_inputs,
                        **server_args.pipeline_config.image_encoder_extra_args,
                    )
                    image_embeds = server_args.pipeline_config.postprocess_image(
                        outputs
                    )

                batch.image_embeds.append(image_embeds)
            elif self.text_encoder:
                # if a text encoder is provided, e.g. Qwen-Image-Edit
                # 1. neg prompt embeds
                if batch.do_classifier_free_guidance:
                    neg_image_processor_kwargs = (
                        server_args.pipeline_config.prepare_image_processor_kwargs(
                            batch, neg=True
                        )
                    )

                    neg_image_inputs = self.image_processor(
                        images=image, return_tensors="pt", **neg_image_processor_kwargs
                    ).to(cuda_device)

                with set_forward_context(current_timestep=0, attn_metadata=None):
                    outputs = self.text_encoder(
                        input_ids=image_inputs.input_ids,
                        attention_mask=image_inputs.attention_mask,
                        pixel_values=image_inputs.pixel_values,
                        image_grid_thw=image_inputs.image_grid_thw,
                        output_hidden_states=True,
                    )
                    if batch.do_classifier_free_guidance:
                        neg_outputs = self.text_encoder(
                            input_ids=neg_image_inputs.input_ids,
                            attention_mask=neg_image_inputs.attention_mask,
                            pixel_values=neg_image_inputs.pixel_values,
                            image_grid_thw=neg_image_inputs.image_grid_thw,
                            output_hidden_states=True,
                        )
                batch.prompt_embeds.append(
                    self.encoding_qwen_image_edit(outputs, image_inputs)
                )

                if batch.do_classifier_free_guidance:
                    batch.negative_prompt_embeds.append(
                        self.encoding_qwen_image_edit(neg_outputs, neg_image_inputs)
                    )
        finally:
            self.offload_model()

        return batch

    def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
        """Verify image encoding stage inputs."""
        result = VerificationResult()
        if batch.debug:
            logger.debug(f"{batch.condition_image=}")
            logger.debug(f"{batch.image_embeds=}")
        result.add_check("pil_image", batch.condition_image, V.not_none)
        result.add_check("image_embeds", batch.image_embeds, V.is_list)
        return result

    def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
        """Verify image encoding stage outputs (no checks are currently enabled)."""
        result = VerificationResult()
        # result.add_check("image_embeds", batch.image_embeds, V.list_of_tensors_dims(3))
        return result


class ImageVAEEncodingStage(PipelineStage):
    """
    Stage for encoding pixel representations into latent space.

    This stage handles the encoding of pixel representations into the final
    input format (e.g., image_latents).
    """

    def __init__(self, vae: ParallelTiledVAE, **kwargs) -> None:
        """Store the VAE used for encoding; extra keyword arguments are ignored."""
        super().__init__()
        self.vae: ParallelTiledVAE = vae

    def load_model(self):
        """Move the VAE to the local torch device."""
        self.vae = self.vae.to(get_local_torch_device())

    def offload_model(self):
        """Move the VAE back to the CPU when VAE CPU offload is enabled."""
        if self.server_args.vae_cpu_offload:
            self.vae = self.vae.to("cpu")

    def forward(
        self,
        batch: Req,
        server_args: ServerArgs,
    ) -> Req:
        """
        Encode pixel representations into latent space.

        Each condition image (``batch.vae_image`` if set, otherwise
        ``batch.condition_image``) is encoded separately; the per-image
        latents are concatenated along dim 1 into ``batch.image_latent``.

        Raises:
            ValueError: if ``batch.generator`` is None.
        """

        if batch.condition_image is None:
            return batch

        # Fail before the VAE is moved or run: latent retrieval below always
        # receives the generator, so a missing one is a caller error.
        generator = batch.generator
        if generator is None:
            raise ValueError("Generator must be provided")

        self.load_model()
        # try/finally so that a failure while encoding still offloads the VAE.
        try:
            num_frames = batch.num_frames
            device = get_local_torch_device()

            images = (
                batch.vae_image
                if batch.vae_image is not None
                else batch.condition_image
            )
            if not isinstance(images, list):
                images = [images]

            # Setup VAE precision (identical for every image, so computed once)
            vae_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.vae_precision]
            vae_autocast_enabled = (
                vae_dtype != torch.float32
            ) and not server_args.disable_autocast

            all_image_latents = []
            prepare_condition_image_latent_ids = getattr(
                server_args.pipeline_config, "prepare_condition_image_latent_ids", None
            )
            condition_latents = (
                [] if callable(prepare_condition_image_latent_ids) else None
            )
            for image in images:
                image = self.preprocess(
                    image,
                ).to(device, dtype=torch.float32)

                # (B, C, H, W) -> (B, C, 1, H, W)
                image = image.unsqueeze(2)

                if num_frames == 1:
                    video_condition = image
                else:
                    # Only the first frame carries the image; the rest are zeros.
                    video_condition = torch.cat(
                        [
                            image,
                            image.new_zeros(
                                image.shape[0],
                                image.shape[1],
                                num_frames - 1,
                                image.shape[3],
                                image.shape[4],
                            ),
                        ],
                        dim=2,
                    )
                video_condition = video_condition.to(device=device, dtype=torch.float32)

                # Encode Image
                with torch.autocast(
                    device_type=current_platform.device_type,
                    dtype=vae_dtype,
                    enabled=vae_autocast_enabled,
                ):
                    if server_args.pipeline_config.vae_tiling:
                        self.vae.enable_tiling()
                    # if server_args.vae_sp:
                    #     self.vae.enable_parallel()
                    if not vae_autocast_enabled:
                        # Without autocast the input must already be in the VAE dtype.
                        video_condition = video_condition.to(vae_dtype)
                    latent_dist: DiagonalGaussianDistribution = self.vae.encode(
                        video_condition
                    )
                    # for auto_encoder from diffusers
                    if isinstance(latent_dist, AutoencoderKLOutput):
                        latent_dist = latent_dist.latent_dist

                sample_mode = (
                    server_args.pipeline_config.vae_config.encode_sample_mode()
                )

                latent_condition = self.retrieve_latents(
                    latent_dist, generator, sample_mode=sample_mode
                )
                latent_condition = server_args.pipeline_config.postprocess_vae_encode(
                    latent_condition, self.vae
                )

                scaling_factor, shift_factor = (
                    server_args.pipeline_config.get_decode_scale_and_shift(
                        device=latent_condition.device,
                        dtype=latent_condition.dtype,
                        vae=self.vae,
                    )
                )

                # apply shift & scale if needed
                if isinstance(shift_factor, torch.Tensor):
                    shift_factor = shift_factor.to(latent_condition.device)

                if isinstance(scaling_factor, torch.Tensor):
                    scaling_factor = scaling_factor.to(latent_condition.device)

                latent_condition -= shift_factor
                latent_condition = latent_condition * scaling_factor

                if condition_latents is not None:
                    condition_latents.append(latent_condition)

                image_latent = server_args.pipeline_config.postprocess_image_latent(
                    latent_condition, batch
                )
                all_image_latents.append(image_latent)

            batch.image_latent = torch.cat(all_image_latents, dim=1)
            if condition_latents is not None:
                prepare_condition_image_latent_ids(condition_latents, batch)
        finally:
            self.offload_model()
        return batch

    def retrieve_latents(
        self,
        encoder_output: DiagonalGaussianDistribution,
        generator: torch.Generator | None = None,
        sample_mode: str = "sample",
    ):
        """
        Draw latents from the encoder's output distribution.

        Args:
            encoder_output: Distribution returned by the VAE encoder.
            generator: Passed to ``sample`` when ``sample_mode == "sample"``.
            sample_mode: ``"sample"`` to sample, ``"argmax"`` to take the mode.

        Raises:
            AttributeError: if ``sample_mode`` is neither of the two values.
        """
        if sample_mode == "sample":
            return encoder_output.sample(generator)
        elif sample_mode == "argmax":
            return encoder_output.mode()
        else:
            # Exception type kept for callers; the message now names the real cause.
            raise AttributeError(
                f"Unsupported sample_mode {sample_mode!r}; expected 'sample' or 'argmax'"
            )

    def preprocess(
        self,
        image: torch.Tensor | PIL.Image.Image,
    ) -> torch.Tensor:
        """
        Convert a PIL image to a tensor and normalize it when needed.

        Input that already contains negative values is returned without
        normalization.
        """

        if isinstance(image, PIL.Image.Image):
            image = pil_to_numpy(image)  # to np
            image = numpy_to_pt(image)  # to pt

        # A negative minimum means the data has already been normalized.
        if not image.min() < 0:
            image = normalize(image)

        return image

    def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
        """Verify encoding stage inputs."""
        result = VerificationResult()

        assert batch.condition_image is None or (
            isinstance(batch.condition_image, PIL.Image.Image)
            or isinstance(batch.condition_image, torch.Tensor)
            or isinstance(batch.condition_image, list)
        )
        assert batch.height is not None and isinstance(batch.height, int)
        assert batch.width is not None and isinstance(batch.width, int)
        assert batch.num_frames is not None and isinstance(batch.num_frames, int)

        result.add_check("generator", batch.generator, V.generator_or_list_generators)
        result.add_check("height", batch.height, V.positive_int)
        result.add_check("width", batch.width, V.positive_int)
        result.add_check("num_frames", batch.num_frames, V.positive_int)
        return result

    def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
        """Verify encoding stage outputs (no checks are currently enabled)."""
        result = VerificationResult()
        # result.add_check(
        #     "image_latent", batch.image_latent, [V.is_tensor, V.with_dims(5)]
        # )
        return result
class InputValidationStage(PipelineStage):
    """
    Stage for validating and preparing inputs for diffusion pipelines.

    This stage validates that all required inputs are present and properly formatted
    before proceeding with the diffusion process.

    In this stage, input image and output image may be resized
    """

    def __init__(self, vae_image_processor=None):
        """Store the optional image processor handed to the pipeline config hooks."""
        super().__init__()
        self.vae_image_processor = vae_image_processor

    @staticmethod
    def _calculate_dimensions_from_area(
        max_area: float, aspect_ratio: float, mod_value: int
    ) -> tuple[int, int]:
        """
        Calculate output dimensions based on maximum area and aspect ratio.

        Args:
            max_area: Maximum area constraint for the output
            aspect_ratio: Target aspect ratio (height/width)
            mod_value: Value to round dimensions to (typically vae_scale * patch_size)

        Returns:
            Tuple of (width, height) rounded down to multiples of mod_value
        """
        height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
        width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
        return width, height

    @staticmethod
    def _load_condition_image(path):
        """Load one condition image; for a ``.mp4`` path the first video frame is used."""
        if path.endswith(".mp4"):
            return load_video(path)[0]
        return load_image(path)

    def _generate_seeds(self, batch: Req, server_args: ServerArgs):
        """Generate per-output seeds and the matching torch generators."""
        seed = batch.seed
        num_videos_per_prompt = batch.num_outputs_per_prompt

        assert seed is not None
        # One consecutive seed per requested output.
        seeds = [seed + i for i in range(num_videos_per_prompt)]
        batch.seeds = seeds

        # Create generators based on generator_device parameter
        # Note: This will overwrite any existing batch.generator
        generator_device = batch.generator_device

        if generator_device == "cpu":
            device_str = "cpu"
        else:
            device_str = current_platform.device_type

        batch.generator = [
            torch.Generator(device_str).manual_seed(seed) for seed in seeds
        ]

    def preprocess_condition_image(
        self,
        batch: Req,
        server_args: ServerArgs,
        condition_image_width,
        condition_image_height,
    ):
        """
        preprocess condition image
        NOTE: condition image resizing is only allowed in InputValidationStage

        Depending on the task type / pipeline config this resizes (and for
        TI2V also crops and tensorizes) ``batch.condition_image`` and updates
        ``batch.width`` / ``batch.height`` to match.

        Args:
            condition_image_width: Width of the image as loaded from disk.
            condition_image_height: Height of the image as loaded from disk.
        """
        if batch.condition_image is not None and (
            server_args.pipeline_config.task_type == ModelTaskType.I2I
            or server_args.pipeline_config.task_type == ModelTaskType.TI2I
        ):
            # calculate new condition image size
            if not isinstance(batch.condition_image, list):
                batch.condition_image = [batch.condition_image]

            processed_images = []
            # Captured before resizing: the output size is derived from the
            # last image as it was loaded.
            final_image = batch.condition_image[-1]
            config = server_args.pipeline_config
            config.preprocess_vae_image(batch, self.vae_image_processor)

            for img in batch.condition_image:
                size = config.calculate_condition_image_size(img, img.width, img.height)
                if size is not None:
                    width, height = size
                    img, _ = config.preprocess_condition_image(
                        img, width, height, self.vae_image_processor
                    )

                processed_images.append(img)

            batch.condition_image = processed_images
            calculated_size = config.prepare_calculated_size(final_image)

            # adjust output image size
            if calculated_size is not None:
                calculated_width, calculated_height = calculated_size
                # An explicitly requested size wins over the calculated one.
                width = batch.width or calculated_width
                height = batch.height or calculated_height
                multiple_of = (
                    server_args.pipeline_config.vae_config.get_vae_scale_factor() * 2
                )
                width = width // multiple_of * multiple_of
                height = height // multiple_of * multiple_of
                batch.width = width
                batch.height = height

        elif server_args.pipeline_config.task_type == ModelTaskType.TI2V:
            if server_args.pipeline_config.skip_input_image_preprocess:
                return
            # duplicate with vae_image_processor
            # further processing for ti2v task
            if isinstance(
                batch.condition_image, list
            ):  # not support multi image input yet.
                batch.condition_image = batch.condition_image[0]

            img = batch.condition_image
            ih, iw = img.height, img.width
            patch_size = server_args.pipeline_config.dit_config.arch_config.patch_size
            vae_stride = (
                server_args.pipeline_config.vae_config.arch_config.scale_factor_spatial
            )
            dh, dw = patch_size[1] * vae_stride, patch_size[2] * vae_stride
            max_area = 704 * 1280
            ow, oh = best_output_size(iw, ih, dw, dh, max_area)

            # Scale so the image covers the target box, then crop the overflow.
            scale = max(ow / iw, oh / ih)
            img = img.resize((round(iw * scale), round(ih * scale)), Image.LANCZOS)
            logger.debug("resized condition image to: %sx%s", img.height, img.width)

            # center-crop
            x1 = (img.width - ow) // 2
            y1 = (img.height - oh) // 2
            img = img.crop((x1, y1, x1 + ow, y1 + oh))
            assert img.width == ow and img.height == oh

            # to tensor
            img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device).unsqueeze(1)
            img = img.unsqueeze(0)
            batch.height = oh
            batch.width = ow
            # TODO: should we store in a new field: pixel values?
            batch.condition_image = img

        elif isinstance(server_args.pipeline_config, WanI2V480PConfig):
            # TODO: could we merge with above?
            # resize image only, Wan2.1 I2V
            if isinstance(batch.condition_image, list):
                batch.condition_image = batch.condition_image[
                    0
                ]  # not support multi image input yet.

            max_area = server_args.pipeline_config.max_area
            aspect_ratio = condition_image_height / condition_image_width
            mod_value = (
                server_args.pipeline_config.vae_config.arch_config.scale_factor_spatial
                * server_args.pipeline_config.dit_config.arch_config.patch_size[1]
            )
            width, height = self._calculate_dimensions_from_area(
                max_area, aspect_ratio, mod_value
            )

            batch.condition_image = batch.condition_image.resize((width, height))
            batch.height = height
            batch.width = width

        elif isinstance(server_args.pipeline_config, MOVAPipelineConfig):
            # resize image only, MOVA
            image = batch.condition_image
            if isinstance(image, list):
                image = image[0]  # not support multi image input yet.

            max_area = server_args.pipeline_config.max_area
            if hasattr(batch, "height") and hasattr(batch, "width"):
                aspect_ratio = batch.height / batch.width
            else:
                aspect_ratio = (
                    batch.sampling_params.height / batch.sampling_params.width
                )
            mod_value = (
                server_args.pipeline_config.vae_config.arch_config.scale_factor_spatial
                * server_args.pipeline_config.dit_config.arch_config.patch_size[1]
            )
            width, height = self._calculate_dimensions_from_area(
                max_area, aspect_ratio, mod_value
            )

            image, (final_w, final_h) = (
                server_args.pipeline_config.preprocess_condition_image(
                    image, width, height, self.vae_image_processor
                )
            )
            batch.condition_image = image
            batch.width = final_w
            batch.height = final_h

    def forward(
        self,
        batch: Req,
        server_args: ServerArgs,
    ) -> Req:
        """
        Validate and prepare inputs.

        Generates seeds/generators, checks prompt and guidance arguments,
        loads the condition image(s) named by ``batch.image_path`` and fills
        in a default output size when none was resolved.

        Raises:
            ValueError: on a missing prompt, a missing negative prompt under
                classifier-free guidance, a non-positive step count, or a
                negative guidance scale.
        """

        self._generate_seeds(batch, server_args)

        if (
            server_args.pipeline_config.task_type == ModelTaskType.I2M
            and batch.num_inference_steps is None
            and hasattr(server_args.pipeline_config, "shape_num_inference_steps")
        ):
            batch.num_inference_steps = (
                server_args.pipeline_config.shape_num_inference_steps
            )

        # Ensure prompt is properly formatted (I2M can be image-only)
        if (
            server_args.pipeline_config.task_type != ModelTaskType.I2M
            and batch.prompt is None
            and batch.prompt_embeds is None
        ):
            raise ValueError("Either `prompt` or `prompt_embeds` must be provided")

        # Ensure negative prompt is properly formatted if using classifier-free guidance
        if (
            batch.do_classifier_free_guidance
            and batch.negative_prompt is None
            and batch.negative_prompt_embeds is None
        ):
            raise ValueError(
                "For classifier-free guidance, either `negative_prompt` or "
                "`negative_prompt_embeds` must be provided"
            )

        # Validate number of inference steps. None is tolerated here because
        # verify_input explicitly allows it for I2M; comparing None with 0
        # would raise a TypeError instead of a meaningful error.
        if batch.num_inference_steps is not None and batch.num_inference_steps <= 0:
            raise ValueError(
                f"Number of inference steps must be positive, but got {batch.num_inference_steps}"
            )

        # Validate guidance scale if using classifier-free guidance
        # (zero is accepted, matching the non-negative check in verify_input)
        if batch.do_classifier_free_guidance and batch.guidance_scale < 0:
            raise ValueError(
                f"Guidance scale must be non-negative, but got {batch.guidance_scale}"
            )

        # for i2v, get image from image_path
        # @TODO(Wei) hard-coded for wan2.2 5b ti2v for now. Should put this in image_encoding stage
        if batch.image_path is not None:
            if isinstance(batch.image_path, list):
                batch.condition_image = [
                    self._load_condition_image(path) for path in batch.image_path
                ]

                # Use the first image for size reference
                condition_image_width = batch.condition_image[0].width
                condition_image_height = batch.condition_image[0].height
                batch.original_condition_image_size = (
                    condition_image_width,
                    condition_image_height,
                )
            else:
                image = self._load_condition_image(batch.image_path)
                batch.condition_image = image
                condition_image_width, condition_image_height = (
                    image.width,
                    image.height,
                )
                batch.original_condition_image_size = image.size

            if server_args.pipeline_config.task_type != ModelTaskType.I2M:
                self.preprocess_condition_image(
                    batch, server_args, condition_image_width, condition_image_height
                )

        # if height or width is not specified at this point, set default to 720p
        default_height = 720
        default_width = 1280
        if batch.height is None and batch.width is None:
            batch.height = default_height
            batch.width = default_width
        elif batch.height is None:
            # Derive the missing side from the default aspect ratio.
            batch.height = batch.width * default_height // default_width
        elif batch.width is None:
            batch.width = batch.height * default_width // default_height

        return batch

    def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
        """Verify input validation stage inputs."""
        result = VerificationResult()
        result.add_check("seed", batch.seed, [V.not_none, V.non_negative_int])
        result.add_check(
            "num_videos_per_prompt", batch.num_outputs_per_prompt, V.positive_int
        )
        if server_args.pipeline_config.task_type != ModelTaskType.I2M:
            result.add_check(
                "prompt_or_embeds",
                None,
                lambda _: V.string_or_list_strings(batch.prompt)
                or V.list_not_empty(batch.prompt_embeds),
            )

        if server_args.pipeline_config.task_type != ModelTaskType.I2M:
            result.add_check(
                "num_inference_steps", batch.num_inference_steps, V.positive_int
            )
        else:
            # I2M may leave the step count unset.
            result.add_check(
                "num_inference_steps",
                batch.num_inference_steps,
                lambda x: x is None or V.positive_int(x),
            )
        result.add_check(
            "guidance_scale",
            batch.guidance_scale,
            lambda x: not batch.do_classifier_free_guidance or V.non_negative_float(x),
        )
        return result

    def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
        """Verify input validation stage outputs."""
        result = VerificationResult()
        result.add_check("height", batch.height, V.positive_int)
        result.add_check("width", batch.width, V.positive_int)
        result.add_check("seeds", batch.seeds, V.list_not_empty)
        result.add_check("generator", batch.generator, V.generator_or_list_generators)
        return result
class LatentPreparationStage(PipelineStage):
    """
    Pipeline stage that produces the starting latents for denoising.

    Fresh noise is drawn with the shape supplied by the pipeline config unless
    the request already carries latents, in which case those are moved to the
    local device.
    """

    def __init__(self, scheduler, transformer) -> None:
        super().__init__()
        self.scheduler = scheduler
        self.transformer = transformer

    def forward(
        self,
        batch: Req,
        server_args: ServerArgs,
    ) -> Req:
        """
        Populate ``batch.latents`` and ``batch.raw_latent_shape``.

        Returns:
            The same batch, with prepared latent variables.

        Raises:
            ValueError: if height/width are unset, or if a list of generators
                does not match the batch size.
        """
        # Frame count after the VAE-dependent adjustment
        frames_in_latent_space = self.adjust_video_length(batch, server_args)

        n_samples = batch.batch_size

        embed_dtype = batch.prompt_embeds[0].dtype
        target_device = get_local_torch_device()
        rng = batch.generator
        prepared = batch.latents
        if frames_in_latent_space is None:
            frame_count = batch.num_frames
        else:
            frame_count = frames_in_latent_space
        out_height = batch.height
        out_width = batch.width

        # TODO(will): remove this once we add input/output validation for stages
        if out_height is None or out_width is None:
            raise ValueError("Height and width must be provided")

        # A list of generators must supply exactly one generator per sample
        if isinstance(rng, list) and len(rng) != n_samples:
            raise ValueError(
                f"You have passed a list of generators of length {len(rng)}, but requested an effective batch"
                f" size of {n_samples}. Make sure the batch size matches the length of the generators."
            )

        if prepared is not None:
            # Caller-supplied latents are only moved, never reshaped or packed
            prepared = prepared.to(target_device)
        else:
            pipeline_config = server_args.pipeline_config
            noise_shape = pipeline_config.prepare_latent_shape(
                batch, n_samples, frame_count
            )
            prepared = randn_tensor(
                noise_shape, generator=rng, device=target_device, dtype=embed_dtype
            )

            ids = pipeline_config.maybe_prepare_latent_ids(prepared)
            if ids is not None:
                batch.latent_ids = ids.to(device=target_device)

            prepared = pipeline_config.maybe_pack_latents(prepared, n_samples, batch)

        # Schedulers exposing init_noise_sigma expect the noise scaled by it
        if hasattr(self.scheduler, "init_noise_sigma"):
            prepared = prepared * self.scheduler.init_noise_sigma

        batch.latents = prepared
        batch.raw_latent_shape = prepared.shape
        return batch

    def adjust_video_length(self, batch: Req, server_args: ServerArgs) -> int:
        """
        Return the number of latent frames for the requested video length.

        When the VAE config enables temporal scaling, the frame count is
        reduced by the temporal compression ratio; otherwise it is returned
        as-is (converted to ``int``).
        """
        total_frames = batch.num_frames
        vae_config = server_args.pipeline_config.vae_config
        if not vae_config.use_temporal_scaling_frames:
            return int(total_frames)

        ratio = vae_config.arch_config.temporal_compression_ratio
        return int((total_frames - 1) // ratio + 1)

    def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
        """Verify latent preparation stage inputs."""

        def _has_prompt_or_embeds(_):
            return V.string_or_list_strings(batch.prompt) or V.list_not_empty(
                batch.prompt_embeds
            )

        result = VerificationResult()
        result.add_check("prompt_or_embeds", None, _has_prompt_or_embeds)
        result.add_check("prompt_embeds", batch.prompt_embeds, V.list_of_tensors)
        result.add_check(
            "num_videos_per_prompt", batch.num_outputs_per_prompt, V.positive_int
        )
        result.add_check("generator", batch.generator, V.generator_or_list_generators)
        result.add_check("num_frames", batch.num_frames, V.positive_int)
        result.add_check("height", batch.height, V.positive_int)
        result.add_check("width", batch.width, V.positive_int)
        result.add_check("latents", batch.latents, V.none_or_tensor)
        return result

    def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
        """Verify latent preparation stage outputs."""
        result = VerificationResult()
        if batch.debug:
            logger.debug(f"{batch.raw_latent_shape=}")
        # disable temporarily for image-generation models
        # result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
        result.add_check("raw_latent_shape", batch.raw_latent_shape, V.is_tuple)
        return result
class LTX2AVLatentPreparationStage(LatentPreparationStage):
    """
    Latent preparation for LTX-2: video latents via the base stage, plus
    optional audio latents.
    """

    def __init__(self, scheduler, transformer=None, audio_vae=None):
        super().__init__(scheduler, transformer)
        self.audio_vae = audio_vae

    def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
        """Verify latent preparation stage inputs.

        Unlike the base stage, ``prompt_embeds`` may be a single tensor as
        well as a list of tensors.
        """

        def _has_prompt_or_embeds(_):
            return (
                V.string_or_list_strings(batch.prompt)
                or V.list_not_empty(batch.prompt_embeds)
                or V.is_tensor(batch.prompt_embeds)
            )

        result = VerificationResult()
        result.add_check("prompt_or_embeds", None, _has_prompt_or_embeds)

        # Pick the validator matching the container type of prompt_embeds
        embeds_validator = (
            V.list_of_tensors if isinstance(batch.prompt_embeds, list) else V.is_tensor
        )
        result.add_check("prompt_embeds", batch.prompt_embeds, embeds_validator)

        result.add_check(
            "num_videos_per_prompt", batch.num_outputs_per_prompt, V.positive_int
        )
        result.add_check("generator", batch.generator, V.generator_or_list_generators)
        result.add_check("num_frames", batch.num_frames, V.positive_int)
        result.add_check("height", batch.height, V.positive_int)
        result.add_check("width", batch.width, V.positive_int)
        result.add_check("latents", batch.latents, V.none_or_tensor)
        return result

    def forward(self, batch: Req, server_args: ServerArgs) -> Req:
        """Prepare video latents, then (unless disabled) audio latents.

        Sets ``batch.audio_latents`` and ``batch.raw_audio_latent_shape``;
        both are ``None`` when ``batch.generate_audio`` is falsy.
        """
        # Step 1: the base class fills batch.latents / batch.raw_latent_shape
        batch = super().forward(batch, server_args)

        # Step 2: audio is generated unless the request opts out; a request
        # without the attribute counts as opted in
        if not getattr(batch, "generate_audio", True):
            batch.audio_latents = None
            batch.raw_audio_latent_shape = None
            return batch

        target_device = get_local_torch_device()

        embeds = batch.prompt_embeds
        if isinstance(embeds, list) and embeds:
            noise_dtype = embeds[0].dtype
        elif isinstance(embeds, torch.Tensor):
            noise_dtype = embeds.dtype
        else:
            # No embeddings to take a dtype from
            noise_dtype = torch.float16
        rng = batch.generator

        audio = batch.audio_latents
        n_samples = batch.batch_size
        frame_count = batch.num_frames

        if audio is not None:
            audio = audio.to(target_device)
        else:
            audio_shape = server_args.pipeline_config.prepare_audio_latent_shape(
                batch, n_samples, frame_count
            )
            audio = randn_tensor(
                audio_shape, generator=rng, device=target_device, dtype=noise_dtype
            )

        # Packing applies to both generated and caller-supplied audio latents
        audio = server_args.pipeline_config.maybe_pack_audio_latents(
            audio, n_samples, batch
        )

        batch.audio_latents = audio
        batch.raw_audio_latent_shape = audio.shape

        return batch
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    base_shift: float = 0.25,
    max_shift: float = 0.75,
) -> float:
    """
    Compute the shift value ``mu`` for a given image sequence length.

    ``mu = sqrt(image_seq_len / base_seq_len) * max_shift + base_shift``
    """
    m = (image_seq_len / base_seq_len) ** 0.5
    mu = m * max_shift + base_shift
    return mu


def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler: Object exposing `set_timesteps(...)` and a `timesteps` attribute.
        num_inference_steps: Step count used when neither `timesteps` nor `sigmas` is given.
        device: Forwarded to `set_timesteps`.
        timesteps: Optional custom timestep schedule.
        sigmas: Optional custom sigma schedule.

    Returns:
        `(timesteps, num_inference_steps)`; when a custom schedule is used the step count is its length.

    Raises:
        ValueError: if a custom schedule is passed that `set_timesteps` does not accept.
    """
    # Inspect the signature once; both capability flags come from the same mapping.
    set_timesteps_params = inspect.signature(scheduler.set_timesteps).parameters
    accepts_timesteps = "timesteps" in set_timesteps_params
    accepts_sigmas = "sigmas" in set_timesteps_params

    if timesteps is not None and sigmas is not None:
        # Both keywords are forwarded below, so the scheduler must accept both;
        # otherwise the call would fail with an opaque TypeError.
        if not (accepts_timesteps and accepts_sigmas):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep or sigma schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(
            timesteps=timesteps, sigmas=sigmas, device=device, **kwargs
        )
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif timesteps is not None and sigmas is None:
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif timesteps is None and sigmas is not None:
        if not accepts_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor,
    generator: Optional[torch.Generator] = None,
    sample_mode: str = "sample",
):
    """
    Extract latents from an encoder output.

    Uses ``encoder_output.latent_dist`` (sampled, or its mode for
    ``sample_mode == "argmax"``) when present, otherwise falls back to
    ``encoder_output.latents``.

    Raises:
        AttributeError: if neither source of latents is usable.
    """
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")
def __init__(
    self,
    tokenizer,
    processor,
    text_encoder,
    vision_language_encoder,
    vae,
    transformer,
    scheduler,
):
    """Store the injected components and derive the VAE/transformer-dependent defaults.

    Sets `vae_scale_factor` (from the number of VAE block-out channels, or 8 when
    no usable VAE is present), an image processor built with that factor, and
    `default_sample_size` (the transformer config's `sample_size`, or 128).
    """
    super().__init__()

    # Injected components, kept under the same attribute names as the arguments.
    self.tokenizer = tokenizer
    self.processor = processor
    self.text_encoder = text_encoder
    self.vision_language_encoder = vision_language_encoder
    self.vae = vae
    self.transformer = transformer
    self.scheduler = scheduler

    # One factor of two per VAE block after the first; 8 is the fallback.
    if getattr(self, "vae", None):
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    else:
        self.vae_scale_factor = 8
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    has_sample_size = (
        hasattr(self, "transformer")
        and self.transformer is not None
        and hasattr(self.transformer.config, "sample_size")
    )
    if has_sample_size:
        self.default_sample_size = self.transformer.config.sample_size
    else:
        self.default_sample_size = 128
+ + Args: + prompt: The prompt containing H W shape specification + + Returns: + Tuple of (expanded_prompt, token_h, token_w, prev_token_h, prev_token_w) + """ + match = re.search(r"(\d+)\s+(\d+)", prompt) + if match is None: + raise ValueError( + f"Prompt must contain shape info in format 'H W', got: {prompt}" + ) + + token_h, token_w = int(match.group(1)), int(match.group(2)) + ratio = token_h / token_w + prev_token_h = int(sqrt(ratio) * 16) + prev_token_w = int(sqrt(1 / ratio) * 16) + + old_shape = f"{token_h} {token_w}" + new_shape = ( + f"{token_h} {token_w}{prev_token_h} {prev_token_w}" + ) + expanded_prompt = prompt.replace(old_shape, new_shape) + + return expanded_prompt, token_h, token_w, prev_token_h, prev_token_w + + def _build_image_grid_thw( + self, + token_h: int, + token_w: int, + prev_token_h: int, + prev_token_w: int, + existing_grid: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + """ + Build image grid tensor for AR model. + + For text-to-image: creates grid for large image + small image For image-to-image: appends new image to existing + grid + """ + if existing_grid is None or existing_grid.numel() == 0: + # Text-to-image: large image + small image + return torch.tensor( + [ + [1, token_h, token_w], + [1, prev_token_h, prev_token_w], + ], + device=device, + ) + else: + # Image-to-image: append to existing + return torch.cat( + [existing_grid, torch.tensor([[1, token_h, token_w]], device=device)], + dim=0, + ) + + def _calculate_ar_generation_params( + self, + token_h: int, + token_w: int, + prev_token_h: int, + prev_token_w: int, + is_text_to_image: bool, + ) -> Tuple[int, int]: + """ + Calculate max_new_tokens and large_image_start_offset for AR generation. 
+ """ + large_image_tokens = token_h * token_w + small_image_tokens = prev_token_h * prev_token_w + + if is_text_to_image: + max_new_tokens = small_image_tokens + large_image_tokens + 1 + large_image_start_offset = small_image_tokens + else: + max_new_tokens = large_image_tokens + 1 + large_image_start_offset = 0 + + return max_new_tokens, large_image_start_offset + + def _extract_large_image_tokens( + self, + outputs: torch.Tensor, + input_length: int, + large_image_start_offset: int, + large_image_tokens: int, + ) -> torch.Tensor: + """ + Extract the large image tokens from AR model output. + """ + generated_tokens = outputs[0][input_length:] + large_image_start = large_image_start_offset + large_image_end = large_image_start + large_image_tokens + return generated_tokens[large_image_start:large_image_end] + + def _upsample_d32_to_d16( + self, token_ids: torch.Tensor, token_h: int, token_w: int + ) -> torch.Tensor: + """ + Upsample token IDs from d32 format to d16 format. + + AR model generates tokens at d32 resolution (each token = 32x32 pixels). DiT expects tokens at d16 resolution + (each token = 16x16 pixels). This function performs 2x nearest-neighbor upsampling. 
def generate_prior_tokens(
    self,
    prompt: str,
    height: int,
    width: int,
    image: Optional[List[PIL.Image.Image]] = None,
    factor: int = 32,
) -> Tuple[torch.Tensor, Optional[object]]:
    """
    Generate prior tokens using the AR (vision_language_encoder) model.

    Args:
        prompt: Text prompt; placed in the user message after any condition images.
        height: Requested height; rounded down to a multiple of `factor` before being
            handed to the processor as `target_h`.
        width: Requested width; rounded down to a multiple of `factor` before being
            handed to the processor as `target_w`.
        image: Optional list of condition images. ``None`` or an empty list selects
            the text-to-image path.
        factor: Granularity the target size is snapped to.

    Returns:
        Tuple of (prior_token_ids, prior_token_image_ids)
        - prior_token_ids: tokens of the target image after 2x nearest-neighbour
          upsampling, shape [1, token_h * token_w * 4]
        - prior_token_image_ids: whatever `vision_language_encoder.get_image_tokens`
          returns for the condition images, or ``None`` on the text-to-image path
    """
    device = self.vision_language_encoder.device
    height = (height // factor) * factor
    width = (width // factor) * factor

    is_text_to_image = image is None or len(image) == 0

    # Build messages for processor: condition images first, then the text prompt.
    content = []
    if not is_text_to_image:
        for img in image:
            content.append({"type": "image", "image": img})
    content.append({"type": "text", "text": prompt})
    messages = [{"role": "user", "content": content}]

    inputs = self.processor.apply_chat_template(
        messages,
        tokenize=True,
        target_h=height,
        target_w=width,
        return_dict=True,
        return_tensors="pt",
    ).to(device)

    image_grid_thw = inputs.get("image_grid_thw")
    max_new_tokens, large_image_offset, token_h, token_w = (
        self._compute_generation_params(
            image_grid_thw=image_grid_thw, is_text_to_image=is_text_to_image
        )
    )

    prior_token_image_ids = None
    # Gate on the same flag used everywhere else: an empty image list is
    # text-to-image, so there are no condition-image features to tokenize.
    if not is_text_to_image:
        # The last grid row is the image to be generated; the rows before it
        # belong to the condition images.
        prior_token_image_embed = self.vision_language_encoder.get_image_features(
            inputs["pixel_values"], image_grid_thw[:-1]
        )
        prior_token_image_embed = torch.cat(prior_token_image_embed, dim=0)
        prior_token_image_ids = self.vision_language_encoder.get_image_tokens(
            prior_token_image_embed, image_grid_thw[:-1]
        )

    # For GLM-Image, greedy decoding is not allowed; it may cause repetitive outputs.
    # max_new_tokens must be exactly grid_h * grid_w + 1 (the +1 is for EOS).
    outputs = self.vision_language_encoder.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
    )

    prior_token_ids_d32 = self._extract_large_image_tokens(
        outputs,
        inputs["input_ids"].shape[-1],
        large_image_offset,
        token_h * token_w,
    )
    prior_token_ids = self._upsample_token_ids(
        prior_token_ids_d32, token_h, token_w
    )

    return prior_token_ids, prior_token_image_ids
prompt_embeds: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 2048, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + max_sequence_length (`int`, defaults to `2048`): + Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results. + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_glyph_embeds( + prompt, max_sequence_length, device, dtype + ) + + seq_len = prompt_embeds.size(1) + prompt_embeds = prompt_embeds.repeat(1, 1, 1) + prompt_embeds = prompt_embeds.view(1, seq_len, -1) + + negative_prompt_embeds = None + if do_classifier_free_guidance: + negative_prompt = "" + negative_prompt = ( + batch_size * [negative_prompt] + if isinstance(negative_prompt, str) + else negative_prompt + ) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." 
def prepare_latents(
    self,
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
):
    """Draw the initial random latents.

    The latent shape is `(batch_size, num_channels_latents, height // f, width // f)`
    with `f = self.vae_scale_factor`.

    Raises:
        ValueError: If `generator` is a list whose length differs from `batch_size`.
    """
    scale = self.vae_scale_factor
    latent_h = int(height) // scale
    latent_w = int(width) // scale
    shape = (batch_size, num_channels_latents, latent_h, latent_w)

    generator_count_mismatch = (
        isinstance(generator, list) and len(generator) != batch_size
    )
    if generator_count_mismatch:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    return randn_tensor(shape, generator=generator, device=device, dtype=dtype)
@torch.no_grad()
def forward(
    self,
    batch: Req,
    server_args: ServerArgs,
) -> Req:
    """Prepare everything the denoising loop needs and store it on `batch`.

    Steps, in order: load optional condition images, generate prior tokens with
    the AR model, encode the prompt (positive and negative embeddings), preprocess
    the condition images and write their keys/values into a fresh
    `GlmImageKVCache`, draw initial latents, and build the timestep/sigma schedule.

    Args:
        batch: Request; reads `guidance_scale`, `prompt`, `num_inference_steps`,
            `image_path`, `height`, `width` and `seed`.
        server_args: Not used by this method.

    Returns:
        The same `batch` object with embeddings, latents, schedule, prior-token
        tensors, size/crop conditions, KV cache and final height/width assigned.
    """
    guidance_scale = batch.guidance_scale
    prompt = batch.prompt
    num_inference_steps = batch.num_inference_steps

    # Normalise `image_path`: a bare string is one path (iterating it would yield
    # characters), and ``None`` / an empty list both mean "no condition images",
    # matching how `generate_prior_tokens` decides between t2i and i2i.
    # NOTE(review): assumes `image_path` is None, a str, or a list of paths — confirm against `Req`.
    image_paths = batch.image_path
    if isinstance(image_paths, str):
        image_paths = [image_paths]
    if image_paths:
        ar_condition_images = [load_image(img_path) for img_path in image_paths]
    else:
        ar_condition_images = None

    height = batch.height
    width = batch.width

    device = get_local_torch_device()
    max_sequence_length = 1024
    generator = torch.Generator(device=device).manual_seed(batch.seed)
    attention_kwargs = {}
    prompt_embeds = None
    # Negative embeddings are always produced here, regardless of guidance_scale.
    do_classifier_free_guidance = True
    dtype = torch.bfloat16

    self._guidance_scale = guidance_scale
    self._current_timestep = None
    self._interrupt = False

    if ar_condition_images is not None:
        # Fall back to the first condition image's size when none was requested.
        height = height or ar_condition_images[0].height
        width = width or ar_condition_images[0].width
    time_start = time.time()
    prior_token_id, prior_token_image_ids = self.generate_prior_tokens(
        prompt=prompt,
        image=ar_condition_images,
        height=height,
        width=width,
    )
    prior_token_id = prior_token_id.to(device=device)
    time_end = time.time()
    logger.info(f"generate_prior_tokens time: {time_end - time_start}")

    # 3. Encode input prompt
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        do_classifier_free_guidance,
        prompt_embeds=prompt_embeds,
        max_sequence_length=max_sequence_length,
        device=device,
        dtype=dtype,
    )

    # 4. process images: floor each side to a multiple of vae_scale_factor * patch_size
    if ar_condition_images is not None:
        preprocessed_condition_images = []
        for img in ar_condition_images:
            image_height, image_width = (
                img.size[::-1]
                if isinstance(img, PIL.Image.Image)
                else img.shape[:2]
            )
            multiple_of = self.vae_scale_factor * self.transformer.config.patch_size
            image_height = (image_height // multiple_of) * multiple_of
            image_width = (image_width // multiple_of) * multiple_of
            img = self.image_processor.preprocess(
                img, height=image_height, width=image_width
            )
            preprocessed_condition_images.append(img)
        ar_condition_images = preprocessed_condition_images

    # 5. Prepare latents and (optional) condition_images kv cache
    latent_channels = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size=1,
        num_channels_latents=latent_channels,
        height=height,
        width=width,
        dtype=torch.float32,
        device=device,
        generator=generator,
    )

    kv_caches = GlmImageKVCache(num_layers=self.transformer.config.num_layers)

    if ar_condition_images is not None:
        latents_mean = torch.tensor(self.vae.config.latents_mean).view(
            1, self.vae.config.latent_channels, 1, 1
        )
        latents_std = torch.tensor(self.vae.config.latents_std).view(
            1, self.vae.config.latent_channels, 1, 1
        )

        latents_mean = latents_mean.to(device=device, dtype=prompt_embeds.dtype)
        latents_std = latents_std.to(device=device, dtype=prompt_embeds.dtype)

        for condition_image, condition_image_prior_token_id in zip(
            ar_condition_images, prior_token_image_ids
        ):
            condition_image = condition_image.to(
                device=device, dtype=prompt_embeds.dtype
            )

            # "argmax" takes the mode of the latent distribution rather than a sample.
            condition_latent = retrieve_latents(
                self.vae.encode(condition_image),
                generator=generator,
                sample_mode="argmax",
            )
            condition_latent = (condition_latent - latents_mean) / latents_std

            # Do not remove.
            # The reference image is run through one transformer forward pass
            # (timestep tensor of zeros, empty text sequence) purely so that its
            # keys/values are written into the KV cache; the output is discarded.
            with set_forward_context(current_timestep=1, attn_metadata=None):
                _ = self.transformer(
                    hidden_states=condition_latent,
                    encoder_hidden_states=torch.zeros_like(prompt_embeds)[
                        :1, :0, ...
                    ],
                    prior_token_id=condition_image_prior_token_id,
                    prior_token_drop=torch.full_like(
                        condition_image_prior_token_id, False, dtype=torch.bool
                    ),
                    timestep=torch.zeros((1,), device=device),
                    target_size=torch.tensor(
                        [condition_image.shape[-2:]], device=device
                    ),
                    crop_coords=torch.zeros((1, 2), device=device),
                    attention_kwargs=attention_kwargs,
                    kv_caches=kv_caches,
                    kv_caches_mode="write",
                )

    # 6. Prepare additional timestep conditions
    target_size = (height, width)
    target_size = torch.tensor(
        [target_size], dtype=prompt_embeds.dtype, device=device
    )
    crops_coords_top_left = torch.tensor(
        [(0, 0)], dtype=prompt_embeds.dtype, device=device
    )

    # Prepare timesteps
    image_seq_len = (
        (height // self.vae_scale_factor) * (width // self.vae_scale_factor)
    ) // (self.transformer.config.patch_size**2)
    # Linear schedule from num_train_timesteps down towards 1.0; the final point is dropped.
    timesteps = np.linspace(
        self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps + 1
    )[:-1]
    # Truncate to whole timesteps, then go back to float for the sigma division.
    timesteps = timesteps.astype(np.int64).astype(np.float32)
    sigmas = timesteps / self.scheduler.config.num_train_timesteps
    mu = calculate_shift(
        image_seq_len,
        self.scheduler.config.get("base_image_seq_len", 256),
        self.scheduler.config.get("base_shift", 0.25),
        self.scheduler.config.get("max_shift", 0.75),
    )
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu
    )
    self._num_timesteps = len(timesteps)

    # 7. Prepare for denoising loop
    batch.prompt_embeds = [prompt_embeds]
    batch.negative_prompt_embeds = [negative_prompt_embeds]
    batch.latents = latents
    batch.timesteps = timesteps
    batch.num_inference_steps = num_inference_steps
    batch.sigmas = sigmas.tolist()  # Convert numpy array to list for validation
    batch.generator = generator
    batch.raw_latent_shape = latents.shape

    batch.prior_token_id = prior_token_id
    # Conditional pass keeps every prior token; unconditional pass drops them all.
    batch.prior_token_drop_cond = torch.full_like(
        prior_token_id, False, dtype=torch.bool
    )
    batch.prior_token_drop_uncond = torch.full_like(
        prior_token_id, True, dtype=torch.bool
    )
    batch.target_size = target_size
    batch.crop_coords = crops_coords_top_left

    batch.kv_caches = kv_caches

    batch.height = height
    batch.width = width

    return batch
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Linearly map a sequence length to a shift value `mu`.

    The line passes through `(base_seq_len, base_shift)` and
    `(max_seq_len, max_shift)`; values outside that range are extrapolated,
    not clamped.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
def _denoise_one_chunk(
    self,
    latents,
    prompt_embeds,
    negative_prompt_embeds,
    timesteps,
    guidance_scale,
    indices_hidden_states,
    indices_latents_history_short,
    indices_latents_history_mid,
    indices_latents_history_long,
    latents_history_short,
    latents_history_mid,
    latents_history_long,
    target_dtype,
    device,
    is_cfg_zero_star=True,
    use_zero_init=True,
    zero_steps=1,
    batch=None,
    server_args=None,
):
    """Denoise a single chunk with full timestep loop.

    For every timestep the transformer is run on the current latents with the
    positive prompt; when `guidance_scale > 1.0` it is run a second time with the
    negative prompt and the two predictions are combined, either with the
    CFG-Zero* rescaled formula (`is_cfg_zero_star`) or with standard
    classifier-free guidance. The scheduler then advances the latents.

    Args:
        latents: Current chunk latents; first dimension is the batch size.
        prompt_embeds / negative_prompt_embeds: Encoder states for the conditional
            and unconditional passes.
        timesteps: Iterable of timestep tensors (each is `.expand`-ed to batch size).
        guidance_scale: Guidance weight; values <= 1.0 disable the second pass.
        indices_*: Index arguments forwarded unchanged to the transformer.
        latents_history_*: Optional history tensors; cast to `target_dtype` when present.
        target_dtype: dtype the transformer inputs are cast to.
        is_cfg_zero_star: Use the projected (alpha-scaled) guidance formula.
        use_zero_init / zero_steps: With CFG-Zero*, steps with index <= `zero_steps`
            use an all-zero prediction.
        device, batch, server_args: Accepted for interface compatibility; not used here.

    Returns:
        The latents after the last scheduler step.
    """
    batch_size = latents.shape[0]
    do_cfg = guidance_scale > 1.0

    def _cast_history(history):
        return history.to(target_dtype) if history is not None else None

    # Everything below is identical for every timestep and for both the
    # conditional and unconditional pass, so build it (and do the dtype casts) once.
    shared_inputs = dict(
        indices_hidden_states=indices_hidden_states,
        indices_latents_history_short=indices_latents_history_short,
        indices_latents_history_mid=indices_latents_history_mid,
        indices_latents_history_long=indices_latents_history_long,
        latents_history_short=_cast_history(latents_history_short),
        latents_history_mid=_cast_history(latents_history_mid),
        latents_history_long=_cast_history(latents_history_long),
    )

    for i, t in enumerate(timesteps):
        timestep = t.expand(batch_size)
        latent_model_input = latents.to(target_dtype)

        with set_forward_context(
            current_timestep=t,
            forward_batch=None,
            attn_metadata=None,
        ):
            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=prompt_embeds,
                **shared_inputs,
            )

        if do_cfg:
            with set_forward_context(
                current_timestep=t,
                forward_batch=None,
                attn_metadata=None,
            ):
                noise_uncond = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=negative_prompt_embeds,
                    **shared_inputs,
                )

            if is_cfg_zero_star:
                noise_pred_text = noise_pred
                if (i <= zero_steps) and use_zero_init:
                    # Zero-init: the earliest steps contribute no update direction.
                    noise_pred = noise_pred_text * 0.0
                else:
                    # Per-sample scale for the unconditional branch, broadcast
                    # over all non-batch dimensions.
                    positive_flat = noise_pred_text.reshape(batch_size, -1)
                    negative_flat = noise_uncond.reshape(batch_size, -1)

                    alpha = optimized_scale(positive_flat, negative_flat)
                    alpha = alpha.view(
                        batch_size, *([1] * (len(noise_pred_text.shape) - 1))
                    )
                    alpha = alpha.to(noise_pred_text.dtype)

                    noise_pred = noise_uncond * alpha + guidance_scale * (
                        noise_pred_text - noise_uncond * alpha
                    )
            else:
                noise_pred = noise_uncond + guidance_scale * (
                    noise_pred - noise_uncond
                )

        latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

    return latents
latents_history_long, + target_dtype, + device, + pyramid_num_stages, + pyramid_num_inference_steps_list, + is_distilled, + is_amplify_first_chunk, + gamma, + is_cfg_zero_star=True, + use_zero_init=True, + zero_steps=1, + batch=None, + server_args=None, + ): + """Denoise a single chunk using pyramid super-resolution (Stage 2).""" + batch_size, num_channel, num_frames, height, width = latents.shape + patch_size = self.transformer.patch_size + + # Downsample to lowest pyramid level + latents = latents.permute(0, 2, 1, 3, 4).reshape( + batch_size * num_frames, num_channel, height, width + ) + for _ in range(pyramid_num_stages - 1): + height //= 2 + width //= 2 + latents = F.interpolate(latents, size=(height, width), mode="bilinear") * 2 + latents = latents.reshape( + batch_size, num_frames, num_channel, height, width + ).permute(0, 2, 1, 3, 4) + + start_point_list = None + if is_distilled: + start_point_list = [latents] + + do_cfg = guidance_scale > 1.0 + + for i_s in range(pyramid_num_stages): + # Compute mu for current resolution + image_seq_len = ( + latents.shape[-1] + * latents.shape[-2] + * latents.shape[-3] + // (patch_size[0] * patch_size[1] * patch_size[2]) + ) + mu = calculate_shift(image_seq_len) + + self.scheduler.set_timesteps( + pyramid_num_inference_steps_list[i_s], + i_s, + device=device, + mu=mu, + is_amplify_first_chunk=is_amplify_first_chunk, + ) + timesteps = self.scheduler.timesteps + + if i_s > 0: + # Upsample 2x nearest-neighbor + height *= 2 + width *= 2 + latents = latents.permute(0, 2, 1, 3, 4).reshape( + batch_size * num_frames, + num_channel, + height // 2, + width // 2, + ) + latents = F.interpolate(latents, size=(height, width), mode="nearest") + latents = latents.reshape( + batch_size, num_frames, num_channel, height, width + ).permute(0, 2, 1, 3, 4) + + # Renoise with correlated block noise + ori_sigma = 1 - self.scheduler.ori_start_sigmas[i_s] + alpha = 1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma) + beta = alpha * (1 
- ori_sigma) / math.sqrt(gamma) + + bs, ch, nf, h, w = latents.shape + noise = sample_block_noise(bs, ch, nf, h, w, gamma, patch_size) + noise = noise.to(device=device, dtype=target_dtype) + latents = alpha * latents + beta * noise + + if is_distilled: + start_point_list.append(latents) + + # Denoising loop for this pyramid stage + for idx, t in enumerate(timesteps): + timestep = t.expand(batch_size) + latent_model_input = latents.to(target_dtype) + + with set_forward_context( + current_timestep=t, + forward_batch=None, + attn_metadata=None, + ): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + indices_hidden_states=indices_hidden_states, + indices_latents_history_short=indices_latents_history_short, + indices_latents_history_mid=indices_latents_history_mid, + indices_latents_history_long=indices_latents_history_long, + latents_history_short=( + latents_history_short.to(target_dtype) + if latents_history_short is not None + else None + ), + latents_history_mid=( + latents_history_mid.to(target_dtype) + if latents_history_mid is not None + else None + ), + latents_history_long=( + latents_history_long.to(target_dtype) + if latents_history_long is not None + else None + ), + ) + + if do_cfg: + with set_forward_context( + current_timestep=t, + forward_batch=None, + attn_metadata=None, + ): + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + indices_hidden_states=indices_hidden_states, + indices_latents_history_short=indices_latents_history_short, + indices_latents_history_mid=indices_latents_history_mid, + indices_latents_history_long=indices_latents_history_long, + latents_history_short=( + latents_history_short.to(target_dtype) + if latents_history_short is not None + else None + ), + latents_history_mid=( + latents_history_mid.to(target_dtype) + if latents_history_mid is not None + else None + ), 
+ latents_history_long=( + latents_history_long.to(target_dtype) + if latents_history_long is not None + else None + ), + ) + + if is_cfg_zero_star: + noise_pred_text = noise_pred + positive_flat = noise_pred_text.reshape(batch_size, -1) + negative_flat = noise_uncond.reshape(batch_size, -1) + + alpha_cfg = optimized_scale(positive_flat, negative_flat) + alpha_cfg = alpha_cfg.view( + batch_size, + *([1] * (len(noise_pred_text.shape) - 1)), + ) + alpha_cfg = alpha_cfg.to(noise_pred_text.dtype) + + if (i_s == 0 and idx <= zero_steps) and use_zero_init: + noise_pred = noise_pred_text * 0.0 + else: + noise_pred = noise_uncond * alpha_cfg + guidance_scale * ( + noise_pred_text - noise_uncond * alpha_cfg + ) + else: + noise_pred = noise_uncond + guidance_scale * ( + noise_pred - noise_uncond + ) + + latents = self.scheduler.step( + noise_pred, + t, + latents, + return_dict=False, + cur_sampling_step=idx, + dmd_noisy_tensor=( + start_point_list[i_s] if start_point_list is not None else None + ), + dmd_sigmas=self.scheduler.sigmas, + dmd_timesteps=self.scheduler.timesteps, + all_timesteps=timesteps, + )[0] + + return latents + + def forward(self, batch: Req, server_args: ServerArgs) -> Req: + """Run the Helios chunked denoising loop.""" + pipeline_config = server_args.pipeline_config + device = ( + batch.latents.device + if hasattr(batch, "latents") and batch.latents is not None + else torch.device("cuda") + ) + target_dtype = PRECISION_TO_TYPE.get( + server_args.pipeline_config.precision, torch.bfloat16 + ) + + # Get config params + num_latent_frames_per_chunk = pipeline_config.num_latent_frames_per_chunk + history_sizes = sorted(list(pipeline_config.history_sizes), reverse=True) + is_cfg_zero_star = pipeline_config.is_cfg_zero_star + zero_steps = pipeline_config.zero_steps + keep_first_frame = pipeline_config.keep_first_frame + guidance_scale = batch.guidance_scale + num_inference_steps = batch.num_inference_steps + + # Stage 2 params + is_enable_stage2 = 
pipeline_config.is_enable_stage2 + pyramid_num_stages = pipeline_config.pyramid_num_stages + pyramid_num_inference_steps_list = ( + pipeline_config.pyramid_num_inference_steps_list + ) + is_distilled = pipeline_config.is_distilled + is_amplify_first_chunk = pipeline_config.is_amplify_first_chunk + gamma = pipeline_config.gamma + + # Move transformer to GPU if CPU-offloaded + if server_args.dit_cpu_offload and not server_args.use_fsdp_inference: + if next(self.transformer.parameters()).device.type == "cpu": + self.transformer.to(get_local_torch_device()) + + # Get encoder outputs (prompt_embeds is a list of tensors, one per encoder) + prompt_embeds = batch.prompt_embeds + if isinstance(prompt_embeds, list): + prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.to(target_dtype) + negative_prompt_embeds = batch.negative_prompt_embeds + if isinstance(negative_prompt_embeds, list): + negative_prompt_embeds = ( + negative_prompt_embeds[0] if negative_prompt_embeds else None + ) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(target_dtype) + + # Scale factors inherited from the Wan VAE used by Helios + # (AutoencoderKLWan: temporal_compression_ratio=4, spatial_compression_ratio=8) + vae_scale_factor_temporal = 4 + vae_scale_factor_spatial = 8 + + # Compute chunking + height = batch.height + width = batch.width + num_frames = batch.num_frames + num_channels_latents = self.transformer.in_channels + + window_num_frames = ( + num_latent_frames_per_chunk - 1 + ) * vae_scale_factor_temporal + 1 + num_latent_chunk = max( + 1, (num_frames + window_num_frames - 1) // window_num_frames + ) + num_history_latent_frames = sum(history_sizes) + batch_size = 1 # Helios processes one video at a time + + # Prepare history latents + if not keep_first_frame: + history_sizes[-1] = history_sizes[-1] + 1 + history_latents = torch.zeros( + batch_size, + num_channels_latents, + num_history_latent_frames, + height // 
vae_scale_factor_spatial, + width // vae_scale_factor_spatial, + device=device, + dtype=torch.float32, + ) + + # Build frame indices + if keep_first_frame: + indices = torch.arange( + 0, sum([1, *history_sizes, num_latent_frames_per_chunk]) + ) + ( + indices_prefix, + indices_latents_history_long, + indices_latents_history_mid, + indices_latents_history_1x, + indices_hidden_states, + ) = indices.split([1, *history_sizes, num_latent_frames_per_chunk], dim=0) + indices_latents_history_short = torch.cat( + [indices_prefix, indices_latents_history_1x], dim=0 + ) + else: + indices = torch.arange( + 0, sum([*history_sizes, num_latent_frames_per_chunk]) + ) + ( + indices_latents_history_long, + indices_latents_history_mid, + indices_latents_history_short, + indices_hidden_states, + ) = indices.split([*history_sizes, num_latent_frames_per_chunk], dim=0) + + indices_hidden_states = indices_hidden_states.unsqueeze(0) + indices_latents_history_short = indices_latents_history_short.unsqueeze(0) + indices_latents_history_mid = indices_latents_history_mid.unsqueeze(0) + indices_latents_history_long = indices_latents_history_long.unsqueeze(0) + + # Set up scheduler + patch_size = self.transformer.patch_size + image_seq_len = ( + num_latent_frames_per_chunk + * (height // vae_scale_factor_spatial) + * (width // vae_scale_factor_spatial) + // (patch_size[0] * patch_size[1] * patch_size[2]) + ) + # Sigma schedule from near-1.0 (pure noise) to 0.0 (clean); 0.999 avoids singularity + sigmas = np.linspace(0.999, 0.0, num_inference_steps + 1)[:-1] + mu = calculate_shift(image_seq_len) + + # Chunk loop + image_latents = None + total_generated_latent_frames = 0 + + self.log_info( + f"Starting chunked denoising: {num_latent_chunk} chunks, " + f"{num_inference_steps} steps each" + ) + + for k in range(num_latent_chunk): + is_first_chunk = k == 0 + + # Extract history + if keep_first_frame: + ( + latents_history_long, + latents_history_mid, + latents_history_1x, + ) = history_latents[:, :, 
-num_history_latent_frames:].split( + history_sizes, dim=2 + ) + if image_latents is None and is_first_chunk: + latents_prefix = torch.zeros( + ( + batch_size, + num_channels_latents, + 1, + latents_history_1x.shape[-2], + latents_history_1x.shape[-1], + ), + device=device, + dtype=latents_history_1x.dtype, + ) + else: + latents_prefix = image_latents + latents_history_short = torch.cat( + [latents_prefix, latents_history_1x], dim=2 + ) + else: + ( + latents_history_long, + latents_history_mid, + latents_history_short, + ) = history_latents[:, :, -num_history_latent_frames:].split( + history_sizes, dim=2 + ) + + # Generate noise latents for this chunk + latents = torch.randn( + batch_size, + num_channels_latents, + (window_num_frames - 1) // vae_scale_factor_temporal + 1, + height // vae_scale_factor_spatial, + width // vae_scale_factor_spatial, + device=device, + dtype=torch.float32, + ) + + if is_enable_stage2: + # Stage 2: Pyramid SR denoising (handles scheduler internally) + latents = self._denoise_one_chunk_stage2( + latents=latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + guidance_scale=guidance_scale, + indices_hidden_states=indices_hidden_states, + indices_latents_history_short=indices_latents_history_short, + indices_latents_history_mid=indices_latents_history_mid, + indices_latents_history_long=indices_latents_history_long, + latents_history_short=latents_history_short, + latents_history_mid=latents_history_mid, + latents_history_long=latents_history_long, + target_dtype=target_dtype, + device=device, + pyramid_num_stages=pyramid_num_stages, + pyramid_num_inference_steps_list=pyramid_num_inference_steps_list, + is_distilled=is_distilled, + is_amplify_first_chunk=(is_amplify_first_chunk and is_first_chunk), + gamma=gamma, + is_cfg_zero_star=is_cfg_zero_star, + use_zero_init=True, + zero_steps=zero_steps, + batch=batch, + server_args=server_args, + ) + else: + # Stage 1: Standard flat denoising + 
self.scheduler.set_timesteps( + num_inference_steps, device=device, sigmas=sigmas, mu=mu + ) + timesteps = self.scheduler.timesteps + + latents = self._denoise_one_chunk( + latents=latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + timesteps=timesteps, + guidance_scale=guidance_scale, + indices_hidden_states=indices_hidden_states, + indices_latents_history_short=indices_latents_history_short, + indices_latents_history_mid=indices_latents_history_mid, + indices_latents_history_long=indices_latents_history_long, + latents_history_short=latents_history_short, + latents_history_mid=latents_history_mid, + latents_history_long=latents_history_long, + target_dtype=target_dtype, + device=device, + is_cfg_zero_star=is_cfg_zero_star, + use_zero_init=True, + zero_steps=zero_steps, + batch=batch, + server_args=server_args, + ) + + # Extract first frame as image_latents for subsequent chunks + if keep_first_frame and is_first_chunk and image_latents is None: + image_latents = latents[:, :, 0:1, :, :] + + # Update history + total_generated_latent_frames += latents.shape[2] + history_latents = torch.cat([history_latents, latents], dim=2) + + # Move transformer back to CPU after denoising + if server_args.dit_cpu_offload and not server_args.use_fsdp_inference: + if next(self.transformer.parameters()).device.type != "cpu": + self.transformer.to("cpu") + torch.cuda.empty_cache() + + # Store denoised latents for the standard DecodingStage to decode + batch.latents = history_latents[:, :, -total_generated_latent_frames:] + + return batch diff --git a/sglang/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/mova.py b/sglang/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/mova.py new file mode 100644 index 0000000000000000000000000000000000000000..4b77e54447e0769517b810540c76d9adc24eb84c --- /dev/null +++ 
b/sglang/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/mova.py @@ -0,0 +1,918 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +MOVA-specific pipeline stages. + +Sequence Parallelism (SP) Support: +- Video latents are sharded along the sequence dimension (T*H*W) after patchify +- Audio latents are sharded along the sequence dimension (L) after patchify +- USPAttention handles all-to-all communication internally +- Latents are gathered before unpatchify to restore full sequence +""" + +from __future__ import annotations + +import functools +import inspect +import os +from collections.abc import Iterable + +import torch +import torch.nn as nn +from diffusers.utils.torch_utils import randn_tensor +from tqdm.auto import tqdm + +from sglang.multimodal_gen.runtime.distributed import ( + get_local_torch_device, + get_world_group, +) +from sglang.multimodal_gen.runtime.distributed.communication_op import ( + cfg_model_parallel_all_reduce, + sequence_model_parallel_all_gather, +) +from sglang.multimodal_gen.runtime.distributed.parallel_state import ( + get_cfg_group, + get_classifier_free_guidance_rank, + get_sp_parallel_rank, + get_sp_world_size, +) +from sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context + +# Both audio and video DiT use the same sinusoidal_embedding_1d function +# Import from mova_video_dit where it's defined (mova_audio_dit re-exports it) +from sglang.multimodal_gen.runtime.models.dits.mova_video_dit import ( + sinusoidal_embedding_1d, +) + +# Create aliases for backward compatibility +video_sinusoidal_embedding_1d = sinusoidal_embedding_1d +audio_sinusoidal_embedding_1d = sinusoidal_embedding_1d +from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch, Req +from sglang.multimodal_gen.runtime.pipelines_core.stages.base import ( + PipelineStage, + StageParallelismType, +) +from sglang.multimodal_gen.runtime.pipelines_core.stages.decoding import ( + 
_ensure_tensor_decode_output, +) +from sglang.multimodal_gen.runtime.pipelines_core.stages.validators import ( + StageValidators as V, +) +from sglang.multimodal_gen.runtime.pipelines_core.stages.validators import ( + VerificationResult, +) +from sglang.multimodal_gen.runtime.platforms import current_platform +from sglang.multimodal_gen.runtime.server_args import ServerArgs, get_global_server_args +from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.runtime.utils.perf_logger import StageProfiler +from sglang.multimodal_gen.runtime.utils.profiler import SGLDiffusionProfiler +from sglang.multimodal_gen.utils import PRECISION_TO_TYPE + +logger = init_logger(__name__) + + +class MOVALatentPreparationStage(PipelineStage): + """Prepare video/audio noise latents for MOVA.""" + + def __init__(self, audio_vae, require_vae_embedding: bool = True) -> None: + super().__init__() + self.audio_vae = audio_vae + self.require_vae_embedding = require_vae_embedding + + def forward(self, batch: Req, server_args: ServerArgs) -> Req: + batch_size = batch.batch_size + num_frames = batch.num_frames + if num_frames is None: + raise ValueError("num_frames is required for MOVA") + + audio_num_samples = int(self.audio_vae.sample_rate * num_frames / batch.fps) + + video_shape = server_args.pipeline_config.prepare_latent_shape( + batch, batch_size, num_frames + ) + audio_shape = server_args.pipeline_config.prepare_audio_latent_shape( + batch_size, audio_num_samples, self.audio_vae + ) + + device = get_local_torch_device() + generator = batch.generator + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + dit_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.dit_precision] + batch.latents = randn_tensor( + video_shape, generator=generator, device=device, dtype=dit_dtype + ) + batch.audio_latents = randn_tensor( + audio_shape, generator=generator, device=device, dtype=dit_dtype + ) + + if batch.image_latent is not None: + batch.y = batch.image_latent.to(device=device, dtype=dit_dtype) + elif self.require_vae_embedding: + raise ValueError("MOVA requires reference image latents for denoising") + return batch + + +class MOVATimestepPreparationStage(PipelineStage): + """Prepare paired timesteps for MOVA.""" + + def __init__(self, scheduler) -> None: + super().__init__() + self.scheduler = scheduler + + def forward(self, batch: Req, server_args: ServerArgs) -> Req: + self.scheduler.set_timesteps( + batch.num_inference_steps, + denoising_strength=1.0, + shift=getattr(batch, "sigma_shift", self.scheduler.shift), + ) + self.scheduler.set_pair_postprocess_by_name( + "dual_sigma_shift", + visual_shift=getattr(batch, "visual_shift", 5.0), + audio_shift=getattr(batch, "audio_shift", 5.0), + ) + paired = self.scheduler.get_pairs() + batch.paired_timesteps = paired + batch.timesteps = paired + return batch + + +class MOVADenoisingStage(PipelineStage): + """Run MOVA dual-tower denoising loop.""" + + def __init__(self, video_dit, video_dit_2, audio_dit, dual_tower_bridge, scheduler): + super().__init__() + self.video_dit = video_dit + self.video_dit_2 = video_dit_2 + self.audio_dit = audio_dit + self.dual_tower_bridge = dual_tower_bridge + self.scheduler = scheduler + self._cache_dit_enabled = False + self._cached_num_steps = None + self._torch_compiled = False + + @property + def parallelism_type(self) -> StageParallelismType: + if get_global_server_args().enable_cfg_parallel: + return StageParallelismType.CFG_PARALLEL + return StageParallelismType.REPLICATED + + def _predict( + self, + visual_dit, + visual_latents, + audio_latents, + y, + context, + timestep, + 
audio_timestep, + video_fps, + timestep_index: int, + attn_metadata, + forward_batch: Req | None = None, + ): + # Set forward context for distributed attention (USPAttention) + with set_forward_context( + current_timestep=timestep_index, + attn_metadata=attn_metadata, + forward_batch=forward_batch, + ): + return self.inference_single_step( + visual_dit=visual_dit, + visual_latents=visual_latents, + audio_latents=audio_latents, + y=y, + context=context, + timestep=timestep, + audio_timestep=audio_timestep, + video_fps=video_fps, + ) + + def _cfg_combine(self, pos, neg, guidance_scale, cfg_rank, enable_cfg_parallel): + if not enable_cfg_parallel: + return neg + guidance_scale * (pos - neg) + if cfg_rank == 0: + partial = guidance_scale * pos + else: + partial = (1 - guidance_scale) * neg + return cfg_model_parallel_all_reduce(partial) + + def _maybe_enable_torch_compile(self, module: nn.Module, server_args: ServerArgs): + """ + Compile a module with torch.compile, and enable inductor overlap tweak if available. + No-op if torch compile is disabled or the object is not a nn.Module. 
+ """ + if not server_args.enable_torch_compile or not isinstance(module, nn.Module): + return + try: + import torch._inductor.config as _inductor_cfg + + _inductor_cfg.reorder_for_compute_comm_overlap = True + except ImportError: + pass + mode = os.environ.get("SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs") + logger.info("Compiling %s with mode: %s", module.__class__.__name__, mode) + # TODO(triple-mu): support customized fullgraph and dynamic in the future + module.compile(mode=mode, fullgraph=False, dynamic=None) + + def _maybe_compile_dits(self, server_args: ServerArgs): + if self._torch_compiled or not server_args.enable_torch_compile: + return + for module in filter(None, [self.video_dit, self.video_dit_2, self.audio_dit]): + self._maybe_enable_torch_compile(module, server_args) + self._torch_compiled = True + + def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + """Verify denoising stage inputs.""" + result = VerificationResult() + result.add_check("y", batch.y, V.is_tensor) + result.add_check("paired_timesteps", batch.paired_timesteps, V.is_tensor) + result.add_check("latents", batch.latents, V.is_tensor) + result.add_check("audio_latents", batch.audio_latents, V.is_tensor) + result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty) + result.add_check( + "negative_prompt_embeds", + batch.negative_prompt_embeds, + lambda x: not batch.do_classifier_free_guidance or V.list_not_empty(x), + ) + result.add_check( + "num_inference_steps", batch.num_inference_steps, V.positive_int + ) + result.add_check("guidance_scale", batch.guidance_scale, V.non_negative_float) + result.add_check( + "guidance_rescale", batch.guidance_rescale, V.non_negative_float + ) + result.add_check( + "do_classifier_free_guidance", + batch.do_classifier_free_guidance, + V.bool_value, + ) + return result + + def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + """Verify denoising stage outputs.""" + 
result = VerificationResult() + result.add_check("latents", batch.latents, V.is_tensor) + result.add_check("audio_latents", batch.audio_latents, V.is_tensor) + return result + + def progress_bar( + self, iterable: Iterable | None = None, total: int | None = None + ) -> tqdm: + """ + Create a progress bar for the denoising process. + """ + local_rank = get_world_group().local_rank + disable = local_rank != 0 + return tqdm(iterable=iterable, total=total, disable=disable) + + def step_profile(self): + profiler = SGLDiffusionProfiler.get_instance() + if profiler: + profiler.step_denoising_step() + + def rescale_noise_cfg( + self, noise_cfg, noise_pred_text, guidance_rescale=0.0 + ) -> torch.Tensor: + """ + Rescale noise prediction according to guidance_rescale. + + Based on findings of "Common Diffusion Noise Schedules and Sample Steps are Flawed" + (https://arxiv.org/pdf/2305.08891.pdf), Section 3.4. + """ + std_text = noise_pred_text.std( + dim=list(range(1, noise_pred_text.ndim)), keepdim=True + ) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + noise_cfg = ( + guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + ) + return noise_cfg + + def prepare_extra_func_kwargs(self, func, kwargs) -> dict[str, object]: + if not kwargs: + return {} + + if isinstance(func, functools.partial) and func.args: + func = getattr(func.args[0], "_original_forward", func) + + target_func = inspect.unwrap(func) + params = inspect.signature(target_func).parameters + return {k: v for k, v in kwargs.items() if k in params} + + def _build_attn_metadata( + self, i: int, batch: Req, server_args: ServerArgs + ) -> object | None: + return None + + def _manage_device_placement( + self, + model_to_use: nn.Module | None, + model_to_offload: nn.Module | None, + server_args: ServerArgs, + ): + if not server_args.dit_cpu_offload: + return + + if ( + model_to_offload is not None + and 
next(model_to_offload.parameters()).device.type == "cuda" + ): + model_to_offload.to("cpu") + + if ( + model_to_use is not None + and next(model_to_use.parameters()).device.type == "cpu" + ): + model_to_use.to(get_local_torch_device()) + + def _select_visual_dit( + self, timestep: float, boundary_ratio: float | None, server_args: ServerArgs + ): + if boundary_ratio is None or self.video_dit_2 is None: + self._manage_device_placement(self.video_dit, None, server_args) + return self.video_dit + + boundary_timestep = boundary_ratio * self.scheduler.num_train_timesteps + if timestep >= boundary_timestep: + current_model = self.video_dit + model_to_offload = self.video_dit_2 + else: + current_model = self.video_dit_2 + model_to_offload = self.video_dit + + self._manage_device_placement(current_model, model_to_offload, server_args) + return current_model + + def _ensure_shared_models_on_device(self, server_args: ServerArgs): + """Ensure shared denoising modules are on the active device when cpu offload is enabled.""" + self._manage_device_placement(self.audio_dit, None, server_args) + self._manage_device_placement(self.dual_tower_bridge, None, server_args) + + def _apply_guidance_rescale( + self, + noise_pred, + noise_pred_text, + guidance_rescale, + cfg_rank, + enable_cfg_parallel, + ): + if guidance_rescale <= 0.0: + return noise_pred + if enable_cfg_parallel: + std_cfg = noise_pred.std(dim=list(range(1, noise_pred.ndim)), keepdim=True) + if cfg_rank == 0: + assert noise_pred_text is not None + std_text = noise_pred_text.std( + dim=list(range(1, noise_pred_text.ndim)), keepdim=True + ) + else: + std_text = torch.empty_like(std_cfg) + std_text = get_cfg_group().broadcast(std_text, src=0) + noise_pred_rescaled = noise_pred * (std_text / std_cfg) + return guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * ( + noise_pred + ) + return self.rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale) + + @torch.no_grad() + def forward(self, batch: Req, 
server_args: ServerArgs) -> Req: + self._maybe_compile_dits(server_args) + self._ensure_shared_models_on_device(server_args) + + paired_timesteps = batch.paired_timesteps + if paired_timesteps is None: + raise ValueError("paired_timesteps must be set for MOVA") + + y = batch.y if batch.y is not None else batch.image_latent + if getattr(self.video_dit, "require_vae_embedding", False) and y is None: + raise ValueError("MOVA requires reference image latents for denoising") + + boundary_ratio = server_args.pipeline_config.boundary_ratio + total_steps = paired_timesteps.shape[0] + cfg_rank = get_classifier_free_guidance_rank() + enable_cfg_parallel = server_args.enable_cfg_parallel + + is_warmup = batch.is_warmup + extra_step_kwargs = self.prepare_extra_func_kwargs( + self.scheduler.step_from_to, + getattr(batch, "extra_step_kwargs", None) or {}, + ) + + metrics = getattr(batch, "metrics", None) + perf_dump_path_provided = getattr(batch, "perf_dump_path", None) is not None + + with self.progress_bar(total=total_steps) as progress_bar: + for idx_step in range(total_steps): + with StageProfiler( + f"denoising_step_{idx_step}", + logger=logger, + metrics=metrics, + perf_dump_path_provided=perf_dump_path_provided, + ): + pair_t = paired_timesteps[idx_step] + if getattr(pair_t, "shape", None) == (2,): + timestep, audio_timestep = pair_t + else: + timestep = pair_t + audio_timestep = pair_t + + cur_visual_dit = self._select_visual_dit( + timestep.item(), boundary_ratio, server_args + ) + + timestep = timestep.unsqueeze(0).to(device=get_local_torch_device()) + audio_timestep = audio_timestep.unsqueeze(0).to( + device=get_local_torch_device() + ) + + attn_metadata = self._build_attn_metadata( + idx_step, batch, server_args + ) + + if not batch.do_classifier_free_guidance: + visual_noise_pred, audio_noise_pred = self._predict( + cur_visual_dit, + batch.latents, + batch.audio_latents, + y, + batch.prompt_embeds[0], + timestep, + audio_timestep, + batch.fps, + idx_step, + 
attn_metadata, + batch, + ) + else: + if enable_cfg_parallel: + if cfg_rank == 0: + pos = self._predict( + cur_visual_dit, + batch.latents, + batch.audio_latents, + y, + batch.prompt_embeds[0], + timestep, + audio_timestep, + batch.fps, + idx_step, + attn_metadata, + batch, + ) + neg = (None, None) + else: + pos = (None, None) + neg = self._predict( + cur_visual_dit, + batch.latents, + batch.audio_latents, + y, + batch.negative_prompt_embeds[0], + timestep, + audio_timestep, + batch.fps, + idx_step, + attn_metadata, + batch, + ) + else: + pos = self._predict( + cur_visual_dit, + batch.latents, + batch.audio_latents, + y, + batch.prompt_embeds[0], + timestep, + audio_timestep, + batch.fps, + idx_step, + attn_metadata, + batch, + ) + neg = self._predict( + cur_visual_dit, + batch.latents, + batch.audio_latents, + y, + batch.negative_prompt_embeds[0], + timestep, + audio_timestep, + batch.fps, + idx_step, + attn_metadata, + batch, + ) + + visual_noise_pred = self._cfg_combine( + pos[0] if pos[0] is not None else neg[0], + neg[0] if neg[0] is not None else pos[0], + batch.guidance_scale, + cfg_rank, + enable_cfg_parallel, + ) + audio_noise_pred = self._cfg_combine( + pos[1] if pos[1] is not None else neg[1], + neg[1] if neg[1] is not None else pos[1], + batch.guidance_scale, + cfg_rank, + enable_cfg_parallel, + ) + + if batch.guidance_rescale > 0.0: + visual_noise_pred = self._apply_guidance_rescale( + visual_noise_pred, + pos[0] if pos[0] is not None else None, + batch.guidance_rescale, + cfg_rank, + enable_cfg_parallel, + ) + audio_noise_pred = self._apply_guidance_rescale( + audio_noise_pred, + pos[1] if pos[1] is not None else None, + batch.guidance_rescale, + cfg_rank, + enable_cfg_parallel, + ) + + if idx_step + 1 < total_steps: + next_pair_t = paired_timesteps[idx_step + 1] + if getattr(next_pair_t, "shape", None) == (2,): + next_timestep, next_audio_timestep = next_pair_t + else: + next_timestep = next_pair_t + next_audio_timestep = next_pair_t + else: + 
next_timestep = None + next_audio_timestep = None + + batch.latents = self.scheduler.step_from_to( + visual_noise_pred, + timestep, + next_timestep, + batch.latents, + **extra_step_kwargs, + ) + batch.audio_latents = self.scheduler.step_from_to( + audio_noise_pred, + audio_timestep, + next_audio_timestep, + batch.audio_latents, + **extra_step_kwargs, + ) + + if progress_bar is not None: + progress_bar.update() + if not is_warmup and hasattr(self, "step_profile"): + self.step_profile() + + for dit in filter(None, [self.video_dit, self.video_dit_2, self.audio_dit]): + if isinstance(dit, OffloadableDiTMixin): + dit.prepare_for_next_req() + + return batch + + def _shard_sequence_for_sp( + self, x: torch.Tensor, dim: int = 1 + ) -> tuple[torch.Tensor, int]: + """ + Shard tensor along sequence dimension for Sequence Parallelism. + + Args: + x: Input tensor + dim: Dimension to shard along + + Returns: + (sharded_tensor, pad_len) + """ + sp_size = get_sp_world_size() + if sp_size <= 1: + return x, 0 + + sp_rank = get_sp_parallel_rank() + seq_len = x.shape[dim] + + # Pad if needed + pad_len = (sp_size - (seq_len % sp_size)) % sp_size + if pad_len > 0: + pad_shape = list(x.shape) + pad_shape[dim] = pad_len + pad = torch.zeros(pad_shape, dtype=x.dtype, device=x.device) + x = torch.cat([x, pad], dim=dim) + + # Shard + chunk_size = x.shape[dim] // sp_size + start = sp_rank * chunk_size + end = start + chunk_size + idx = [slice(None)] * x.dim() + idx[dim] = slice(start, end) + return x[tuple(idx)], pad_len + + def _gather_sequence_from_sp( + self, x: torch.Tensor, pad_len: int, dim: int = 1 + ) -> torch.Tensor: + """ + Gather tensor along sequence dimension after Sequence Parallelism. 
+ + Args: + x: Sharded tensor + pad_len: Padding length that was added during sharding + dim: Dimension to gather along + + Returns: + Gathered tensor with padding removed + """ + sp_size = get_sp_world_size() + if sp_size <= 1: + return x + + gathered = sequence_model_parallel_all_gather(x, dim=dim) + if pad_len > 0: + idx = [slice(None)] * gathered.dim() + idx[dim] = slice(0, gathered.shape[dim] - pad_len) + gathered = gathered[tuple(idx)] + return gathered + + def inference_single_step( + self, + visual_dit, + visual_latents: torch.Tensor, + audio_latents: torch.Tensor, + y, + context: torch.Tensor, + timestep: torch.Tensor, + audio_timestep: torch.Tensor, + video_fps: float, + ): + """ + Single inference step for MOVA dual-tower denoising. + + Supports Sequence Parallelism (SP): + - After patchify, sequences are sharded across SP ranks + - USPAttention handles distributed attention communication + - Before unpatchify, sequences are gathered back + """ + model_dtype = visual_dit.time_embedding.fc_in.weight.dtype + device = visual_latents.device + + visual_context = context.to(device=device, dtype=model_dtype) + audio_context = context.to(device=device, dtype=model_dtype) + with torch.autocast( + device_type=current_platform.device_type, dtype=torch.float32 + ): + visual_t = visual_dit.time_embedding( + video_sinusoidal_embedding_1d(visual_dit.freq_dim, timestep) + ) + visual_t_mod, _ = visual_dit.time_projection(visual_t) + visual_t_mod = visual_t_mod.unflatten(1, (6, visual_dit.dim)) + + audio_t = self.audio_dit.time_embedding( + audio_sinusoidal_embedding_1d(self.audio_dit.freq_dim, audio_timestep) + ) + audio_t_mod, _ = self.audio_dit.time_projection(audio_t) + audio_t_mod = audio_t_mod.unflatten(1, (6, self.audio_dit.dim)) + + visual_t = visual_t.to(model_dtype) + visual_t_mod = visual_t_mod.to(model_dtype) + audio_t = audio_t.to(model_dtype) + audio_t_mod = audio_t_mod.to(model_dtype) + + visual_context_emb = visual_dit.text_embedding(visual_context) + 
audio_context_emb = self.audio_dit.text_embedding(audio_context) + + visual_x = visual_latents.to(model_dtype) + audio_x = audio_latents.to(model_dtype) + + if getattr(visual_dit, "require_vae_embedding", False): + visual_x = torch.cat([visual_x, y], dim=1) + + # Patchify visual latents + visual_x, (t, h, w) = visual_dit.patchify(visual_x) + grid_size = (t, h, w) + full_visual_seq_len = t * h * w + + # Build visual freqs for full sequence + visual_dit._init_freqs() + visual_freqs = tuple(freq.to(visual_x.device) for freq in visual_dit.freqs) + visual_freqs = ( + torch.cat( + [ + visual_freqs[0][:t].view(t, 1, 1, -1).expand(t, h, w, -1), + visual_freqs[1][:h].view(1, h, 1, -1).expand(t, h, w, -1), + visual_freqs[2][:w].view(1, 1, w, -1).expand(t, h, w, -1), + ], + dim=-1, + ) + .reshape(full_visual_seq_len, 1, -1) + .to(visual_x.device) + ) + + # Patchify audio latents + audio_x, (f,) = self.audio_dit.patchify(audio_x, None) + full_audio_seq_len = f + + # Build audio freqs for full sequence + self.audio_dit._init_freqs() + audio_freqs = ( + torch.cat( + [ + self.audio_dit.freqs[0][:f].view(f, -1).expand(f, -1), + self.audio_dit.freqs[1][:f].view(f, -1).expand(f, -1), + self.audio_dit.freqs[2][:f].view(f, -1).expand(f, -1), + ], + dim=-1, + ) + .reshape(full_audio_seq_len, 1, -1) + .to(audio_x.device) + ) + + # Shard sequences for SP + visual_x, visual_pad_len = self._shard_sequence_for_sp(visual_x, dim=1) + audio_x, audio_pad_len = self._shard_sequence_for_sp(audio_x, dim=1) + + # Shard freqs to match local sequence length + visual_freqs, _ = self._shard_sequence_for_sp(visual_freqs, dim=0) + audio_freqs, _ = self._shard_sequence_for_sp(audio_freqs, dim=0) + + # Forward through dual-tower DiT + visual_x, audio_x = self.forward_dual_tower_dit( + visual_dit=visual_dit, + visual_x=visual_x, + audio_x=audio_x, + visual_context=visual_context_emb, + audio_context=audio_context_emb, + visual_t_mod=visual_t_mod, + audio_t_mod=audio_t_mod, + visual_freqs=visual_freqs, + 
audio_freqs=audio_freqs, + grid_size=grid_size, + video_fps=video_fps, + full_visual_seq_len=full_visual_seq_len, + full_audio_seq_len=full_audio_seq_len, + ) + + # Gather sequences back from SP before head/unpatchify + visual_x = self._gather_sequence_from_sp(visual_x, visual_pad_len, dim=1) + audio_x = self._gather_sequence_from_sp(audio_x, audio_pad_len, dim=1) + + visual_output = visual_dit.head(visual_x, visual_t) + visual_output = visual_dit.unpatchify(visual_output, grid_size) + + audio_output = self.audio_dit.head(audio_x, audio_t) + audio_output = self.audio_dit.unpatchify(audio_output, (f,)) + + return visual_output.float(), audio_output.float() + + def forward_dual_tower_dit( + self, + visual_dit, + visual_x: torch.Tensor, + audio_x: torch.Tensor, + visual_context: torch.Tensor, + audio_context: torch.Tensor, + visual_t_mod: torch.Tensor, + audio_t_mod: torch.Tensor, + visual_freqs: torch.Tensor, + audio_freqs: torch.Tensor, + grid_size: tuple[int, int, int], + video_fps: float, + full_visual_seq_len: int, + full_audio_seq_len: int, + condition_scale: float | None = 1.0, + a2v_condition_scale: float | None = None, + v2a_condition_scale: float | None = None, + ): + """ + Forward pass through dual-tower DiT with cross-modal interaction. 
+ + Sequence Parallelism (SP) Support: + - visual_x and audio_x are already sharded along sequence dimension + - visual_freqs and audio_freqs match the local sequence length + - USPAttention in self-attention handles distributed communication + - LocalAttention in cross-attention operates on local sequence vs replicated context + - Cross-modal attention (dual_tower_bridge) uses LocalAttention (no SP communication) + + Args: + full_visual_seq_len: Full visual sequence length before SP sharding + full_audio_seq_len: Full audio sequence length before SP sharding + """ + min_layers = min(len(visual_dit.blocks), len(self.audio_dit.blocks)) + visual_layers = len(visual_dit.blocks) + sp_size = get_sp_world_size() + + # Build RoPE frequencies for cross-attention if needed (only used when SP == 1) + # When SP > 1, we rebuild freqs inside the loop after gathering full sequences + visual_rope_cos_sin, audio_rope_cos_sin = ( + self.dual_tower_bridge.build_aligned_freqs( + video_fps=video_fps, + grid_size=grid_size, + audio_steps=full_audio_seq_len, + device=visual_x.device, + dtype=visual_x.dtype, + ) + ) + if visual_rope_cos_sin is not None: + visual_rope_cos_sin = [ + self._shard_sequence_for_sp(rope_cos_sin, dim=1)[0] + for rope_cos_sin in visual_rope_cos_sin + ] + if audio_rope_cos_sin is not None: + audio_rope_cos_sin = [ + self._shard_sequence_for_sp(rope_cos_sin, dim=1)[0] + for rope_cos_sin in audio_rope_cos_sin + ] + + for layer_idx in range(min_layers): + visual_block = visual_dit.blocks[layer_idx] + audio_block = self.audio_dit.blocks[layer_idx] + + # Cross-modal interaction via dual tower bridge + # Bridge operations (PerFrameAttentionPooling, RoPE) expect full sequences + # When SP is enabled, we need to gather before bridge and shard after + if self.dual_tower_bridge.should_interact(layer_idx, "a2v"): + visual_x, audio_x = self.dual_tower_bridge( + layer_idx, + visual_x, + audio_x, + x_freqs=visual_rope_cos_sin, + y_freqs=audio_rope_cos_sin, + 
a2v_condition_scale=a2v_condition_scale, + v2a_condition_scale=v2a_condition_scale, + condition_scale=condition_scale, + video_grid_size=grid_size, + ) + + # Self-attention and FFN in DiT blocks + visual_x = visual_block( + visual_x, visual_context, visual_t_mod, visual_freqs + ) + audio_x = audio_block(audio_x, audio_context, audio_t_mod, audio_freqs) + + # Process remaining visual layers (if visual has more layers than audio) + for layer_idx in range(min_layers, visual_layers): + visual_block = visual_dit.blocks[layer_idx] + visual_x = visual_block( + visual_x, visual_context, visual_t_mod, visual_freqs + ) + + return visual_x, audio_x + + +class MOVADecodingStage(PipelineStage): + """Decode video and audio outputs for MOVA.""" + + def __init__(self, video_vae, audio_vae) -> None: + super().__init__() + self.video_vae = video_vae + self.audio_vae = audio_vae + + @property + def parallelism_type(self) -> StageParallelismType: + if get_global_server_args().enable_cfg_parallel: + return StageParallelismType.MAIN_RANK_ONLY + return StageParallelismType.REPLICATED + + @torch.no_grad() + def forward(self, batch: Req, server_args: ServerArgs) -> OutputBatch: + self.video_vae = self.video_vae.to(get_local_torch_device()) + self.audio_vae = self.audio_vae.to(get_local_torch_device()) + + video_latents = server_args.pipeline_config.denormalize_video_latents( + batch.latents, self.video_vae + ) + + vae_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.vae_precision] + vae_autocast_enabled = ( + vae_dtype != torch.float32 + ) and not server_args.disable_autocast + + with torch.autocast( + device_type=current_platform.device_type, + dtype=vae_dtype, + enabled=vae_autocast_enabled, + ): + if server_args.pipeline_config.vae_tiling: + self.video_vae.enable_tiling() + if not vae_autocast_enabled: + video_latents = video_latents.to(vae_dtype) + decode_output = self.video_vae.decode(video_latents) + video = _ensure_tensor_decode_output(decode_output) + + video = (video / 2 + 
0.5).clamp(0, 1) + + with torch.autocast( + device_type=current_platform.device_type, dtype=torch.float32 + ): + audio = self.audio_vae.decode(batch.audio_latents) + output_batch = OutputBatch( + output=video, + audio=audio, + audio_sample_rate=getattr(self.audio_vae, "sample_rate", None), + metrics=batch.metrics, + ) + return output_batch diff --git a/sglang/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/qwen_image_layered.py b/sglang/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/qwen_image_layered.py new file mode 100644 index 0000000000000000000000000000000000000000..4d64759c52fe6dacd7eb4c2934d18ce5ed1a5f8e --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/qwen_image_layered.py @@ -0,0 +1,529 @@ +import inspect +import math +from typing import List, Optional, Union + +import numpy as np +import torch +from diffusers.image_processor import VaeImageProcessor +from diffusers.utils.torch_utils import randn_tensor + +from sglang.multimodal_gen.runtime.distributed import get_local_torch_device +from sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context +from sglang.multimodal_gen.runtime.models.vision_utils import load_image +from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req +from sglang.multimodal_gen.runtime.pipelines_core.stages.base import PipelineStage +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit_plus.calculate_dimensions +def calculate_dimensions(target_area, ratio): + width = math.sqrt(target_area * ratio) + height = width / ratio + + width = round(width / 32) * 32 + height = round(height / 32) * 32 + + return width, height + + +# Copied from 
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, + generator: Optional[torch.Generator] = None, + sample_mode: str = "sample", +): + if sample_mode == "sample": + return encoder_output.sample(generator) + elif sample_mode == "argmax": + return encoder_output.mode() + else: + return encoder_output + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. 
+ """ + if timesteps is not None and sigmas is not None: + raise ValueError( + "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" + ) + if timesteps is not None: + accepts_timesteps = "timesteps" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class QwenImageLayeredBeforeDenoisingStage(PipelineStage): + def __init__( + self, vae, tokenizer, processor, transformer, scheduler, model_path + ) -> None: + super().__init__() + self.vae = vae.to(torch.bfloat16) + from transformers import Qwen2_5_VLForConditionalGeneration + + self.text_encoder = ( + Qwen2_5_VLForConditionalGeneration.from_pretrained( + model_path, subfolder="text_encoder" + ) + .to(get_local_torch_device()) + .to(torch.bfloat16) + ) + self.tokenizer = tokenizer + self.processor = processor + self.transformer = transformer + self.scheduler = scheduler + + self.vae_scale_factor = ( + 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + ) + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2 + ) + self.vl_processor = processor + self.tokenizer_max_length = 1024 + self.latent_channels = self.vae.z_dim if getattr(self, "vae", None) else 16 + + self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode_start_idx = 34 + self.image_caption_prompt_cn = """<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n# 图像标注器\n你是一个专业的图像标注器。请基于输入图像,撰写图注:\n1. +使用自然、描述性的语言撰写图注,不要使用结构化形式或富文本形式。\n2. 通过加入以下内容,丰富图注细节:\n - 对象的属性:如数量、颜色、形状、大小、位置、材质、状态、动作等\n - +对象间的视觉关系:如空间关系、功能关系、动作关系、从属关系、比较关系、因果关系等\n - 环境细节:例如天气、光照、颜色、纹理、气氛等\n - 文字内容:识别图像中清晰可见的文字,不做翻译和解释,用引号在图注中强调\n3. 
+保持真实性与准确性:\n - 不要使用笼统的描述\n - +描述图像中所有可见的信息,但不要加入没有在图像中出现的内容\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>assistant\n""" + self.image_caption_prompt_en = """<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n# Image Annotator\nYou are a professional +image annotator. Please write an image caption based on the input image:\n1. Write the caption using natural, +descriptive language without structured formats or rich text.\n2. Enrich caption details by including: \n - Object +attributes, such as quantity, color, shape, size, material, state, position, actions, and so on\n - Vision Relations +between objects, such as spatial relations, functional relations, possessive relations, attachment relations, action +relations, comparative relations, causal relations, and so on\n - Environmental details, such as weather, lighting, +colors, textures, atmosphere, and so on\n - Identify the text clearly visible in the image, without translation or +explanation, and highlight it in the caption with quotation marks\n3. 
Maintain authenticity and accuracy:\n - Avoid +generalizations\n - Describe all visible information in the image, while do not add information not explicitly shown in +the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>assistant\n""" + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden + def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + + return split_result + + def get_image_caption(self, prompt_image, use_en_prompt=True, device=None): + if use_en_prompt: + prompt = self.image_caption_prompt_en + else: + prompt = self.image_caption_prompt_cn + model_inputs = self.vl_processor( + text=prompt, + images=prompt_image, + padding=True, + return_tensors="pt", + ).to(device) + with set_forward_context(current_timestep=0, attn_metadata=None): + generated_ids = self.text_encoder.generate( + **model_inputs, max_new_tokens=512 + ) + generated_ids_trimmed = [ + out_ids[len(in_ids) :] + for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids) + ] + output_text = self.vl_processor.batch_decode( + generated_ids_trimmed, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + )[0] + return output_text.strip() + + def _get_qwen_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = self.prompt_template_encode + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + txt_tokens = self.tokenizer( + txt, + padding=True, + return_tensors="pt", + ).to(device) + encoder_hidden_states = 
self.text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_hidden_states.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden( + hidden_states, txt_tokens.attention_mask + ) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [ + torch.ones(e.size(0), dtype=torch.long, device=e.device) + for e in split_hidden_states + ] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [ + torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) + for u in split_hidden_states + ] + ) + encoder_attention_mask = torch.stack( + [ + torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) + for u in attn_mask_list + ] + ) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, encoder_attention_mask + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width, layers): + latents = latents.view( + batch_size, layers, num_channels_latents, height // 2, 2, width // 2, 2 + ) + latents = latents.permute(0, 1, 3, 5, 2, 4, 6) + latents = latents.reshape( + batch_size, layers * (height // 2) * (width // 2), num_channels_latents * 4 + ) + + return latents + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 1024, + ): + r""" + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + """ + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + prompt, device + ) + + prompt_embeds = prompt_embeds[:, :max_sequence_length] + prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(num_images_per_prompt, seq_len) + + return prompt_embeds, prompt_embeds_mask + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + self.vae = self.vae.to(get_local_torch_device()) + if isinstance(generator, list): + image_latents = [ + retrieve_latents( + self.vae.encode(image[i : i + 1]), + generator=generator[i], + sample_mode="argmax", + ) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents( + self.vae.encode(image), generator=generator, sample_mode="argmax" + ) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + image_latents = (image_latents - latents_mean) / latents_std + self.vae.to("cpu") + return image_latents + + def prepare_latents( + self, + image, + batch_size, + num_channels_latents, + height, + width, + layers, + dtype, + 
device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + shape = ( + batch_size, + layers + 1, + num_channels_latents, + height, + width, + ) ### the generated first image is combined image + + image_latents = None + if image is not None: + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if ( + batch_size > image_latents.shape[0] + and batch_size % image_latents.shape[0] == 0 + ): + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat( + [image_latents] * additional_image_per_prompt, dim=0 + ) + elif ( + batch_size > image_latents.shape[0] + and batch_size % image_latents.shape[0] != 0 + ): + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + image_latent_height, image_latent_width = image_latents.shape[3:] + image_latents = image_latents.permute( + 0, 2, 1, 3, 4 + ) # (b, c, f, h, w) -> (b, f, c, h, w) + image_latents = self._pack_latents( + image_latents, + batch_size, + num_channels_latents, + image_latent_height, + image_latent_width, + 1, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + if latents is None: + latents = randn_tensor( + shape, generator=generator, device=device, dtype=dtype + ) + latents = self._pack_latents( + latents, batch_size, num_channels_latents, height, width, layers + 1 + ) + else: + latents = latents.to(device=device, dtype=dtype) + + return latents, image_latents + + def forward( + self, + batch: Req, + server_args: ServerArgs, + ) -> Req: + use_en_prompt = True + device = get_local_torch_device() + layers = batch.num_frames + num_inference_steps = batch.num_inference_steps + generator = batch.generator + + assert batch.image_path is not None + image = load_image(batch.image_path[0]) + image = image.convert("RGBA") + image_size = image.size + resolution = 640 # TODO: support user-specified resolution + calculated_width, calculated_height = calculate_dimensions( + resolution * resolution, image_size[0] / image_size[1] + ) + + height = calculated_height + width = calculated_width + + multiple_of = self.vae_scale_factor * 2 + width = width // multiple_of * multiple_of + height = height // multiple_of * multiple_of + + # if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + image = self.image_processor.resize(image, calculated_height, calculated_width) + prompt_image = image + image = self.image_processor.preprocess( + image, calculated_height, calculated_width + ) + image = image.unsqueeze(2) + image = image.to(dtype=torch.bfloat16) + + prompt = self.get_image_caption( + prompt_image, use_en_prompt=use_en_prompt, device=device + ) + + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + prompt=prompt, + device=device, + ) + + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + prompt=batch.negative_prompt, + device=device, + ) + + num_channels_latents = self.transformer.config.in_channels // 4 + latents, image_latents = self.prepare_latents( + image, + 1, + num_channels_latents, + height, + width, + layers, + prompt_embeds.dtype, + 
device, + generator, + ) + img_shapes = [ + [ + *[ + ( + 1, + height // self.vae_scale_factor // 2, + width // self.vae_scale_factor // 2, + ) + for _ in range(layers + 1) + ], + ( + 1, + calculated_height // self.vae_scale_factor // 2, + calculated_width // self.vae_scale_factor // 2, + ), + ] + ] + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 0, num_inference_steps + 1)[:-1] + image_seq_len = latents.shape[1] + base_seqlen = 256 * 256 / 16 / 16 + mu = (image_latents.shape[1] / base_seqlen) ** 0.5 + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + + txt_seq_lens = ( + prompt_embeds_mask.sum(dim=1).tolist() + if prompt_embeds_mask is not None + else None + ) + negative_txt_seq_lens = ( + negative_prompt_embeds_mask.sum(dim=1).tolist() + if negative_prompt_embeds_mask is not None + else None + ) + is_rgb = torch.tensor([0]).to(device=device, dtype=torch.long) + + batch.prompt_embeds = [prompt_embeds] + batch.prompt_embeds_mask = [prompt_embeds_mask] + batch.negative_prompt_embeds = [negative_prompt_embeds] + batch.negative_prompt_embeds_mask = [negative_prompt_embeds_mask] + batch.latents = latents + batch.image_latent = image_latents + batch.num_inference_steps = num_inference_steps + batch.sigmas = sigmas.tolist() # Convert numpy array to list for validation + batch.generator = torch.manual_seed(0) + batch.original_condition_image_size = image_size + batch.raw_latent_shape = latents.shape + batch.txt_seq_lens = txt_seq_lens + batch.img_shapes = img_shapes + + return batch diff --git a/sglang/python/sglang/multimodal_gen/runtime/pipelines_core/stages/text_connector.py b/sglang/python/sglang/multimodal_gen/runtime/pipelines_core/stages/text_connector.py new file mode 100644 index 0000000000000000000000000000000000000000..93baeba6e5673199fad4ca88ec75cd6b6680c90f --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/pipelines_core/stages/text_connector.py @@ -0,0 
+1,92 @@ +import torch + +from sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context +from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req +from sglang.multimodal_gen.runtime.pipelines_core.stages.base import PipelineStage +from sglang.multimodal_gen.runtime.server_args import ServerArgs + + +class LTX2TextConnectorStage(PipelineStage): + """ + Stage for applying LTX-2 Text Connectors to split/transform text embeddings + into video and audio contexts. + """ + + def __init__(self, connectors): + super().__init__() + self.connectors = connectors + + def forward(self, batch: Req, server_args: ServerArgs) -> Req: + # Input: batch.prompt_embeds (from Gemma, [B, S, D]) + # Output: batch.prompt_embeds (Video Context), batch.audio_prompt_embeds (Audio Context) + + prompt_embeds = batch.prompt_embeds + prompt_attention_mask = batch.prompt_attention_mask + neg_prompt_embeds = batch.negative_prompt_embeds + neg_prompt_attention_mask = batch.negative_attention_mask + + if isinstance(prompt_embeds, list): + prompt_embeds = prompt_embeds[0] if len(prompt_embeds) > 0 else None + + if isinstance(prompt_attention_mask, list): + prompt_attention_mask = ( + prompt_attention_mask[0] if len(prompt_attention_mask) > 0 else None + ) + + if isinstance(neg_prompt_embeds, list): + neg_prompt_embeds = ( + neg_prompt_embeds[0] if len(neg_prompt_embeds) > 0 else None + ) + + if isinstance(neg_prompt_attention_mask, list): + neg_prompt_attention_mask = ( + neg_prompt_attention_mask[0] + if len(neg_prompt_attention_mask) > 0 + else None + ) + + # Handle CFG: Concatenate negative and positive inputs + if batch.do_classifier_free_guidance: + + # Concatenate: [Negative, Positive] + prompt_embeds = torch.cat([neg_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat( + [neg_prompt_attention_mask, prompt_attention_mask], dim=0 + ) + + # Prepare additive mask for connectors (as per Diffusers implementation) + dtype = 
prompt_embeds.dtype + + additive_attention_mask = (1 - prompt_attention_mask.to(dtype)) * -1000000.0 + + # Call connectors + # Expects: prompt_embeds, attention_mask, additive_mask=True + with set_forward_context(current_timestep=None, attn_metadata=None): + connector_prompt_embeds, connector_audio_prompt_embeds, connector_mask = ( + self.connectors( + prompt_embeds, additive_attention_mask, additive_mask=True + ) + ) + + # Split results if CFG was enabled + if batch.do_classifier_free_guidance: + neg_embeds, pos_embeds = connector_prompt_embeds.chunk(2, dim=0) + neg_audio_embeds, pos_audio_embeds = connector_audio_prompt_embeds.chunk( + 2, dim=0 + ) + neg_mask, pos_mask = connector_mask.chunk(2, dim=0) + + batch.prompt_embeds = [pos_embeds] + batch.audio_prompt_embeds = [pos_audio_embeds] + batch.prompt_attention_mask = pos_mask + + batch.negative_prompt_embeds = [neg_embeds] + batch.negative_audio_prompt_embeds = [neg_audio_embeds] + batch.negative_attention_mask = neg_mask + else: + # Update positive fields + batch.prompt_embeds = [connector_prompt_embeds] + batch.audio_prompt_embeds = [connector_audio_prompt_embeds] + batch.prompt_attention_mask = connector_mask + + return batch diff --git a/sglang/python/sglang/multimodal_gen/runtime/pipelines_core/stages/text_encoding.py b/sglang/python/sglang/multimodal_gen/runtime/pipelines_core/stages/text_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..7b51cff5f61d820656a6ccad12cd3af4b533f0e9 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/pipelines_core/stages/text_encoding.py @@ -0,0 +1,332 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +""" +Prompt encoding stages for diffusion pipelines. + +This module contains implementations of prompt encoding stages for diffusion pipelines. 
class TextEncodingStage(PipelineStage):
    """
    Stage for encoding text prompts into embeddings for diffusion models.

    Every configured text encoder is run on the positive prompt and, when
    classifier-free guidance is enabled, on the negative prompt. The resulting
    embeddings, pooled embeddings and attention masks are appended to the
    corresponding list fields of the request.
    """

    def __init__(self, text_encoders, tokenizers) -> None:
        """
        Initialize the prompt encoding stage.

        Args:
            text_encoders: Indexable collection of text encoders; entry ``i`` is
                used together with ``tokenizers[i]``.
            tokenizers: Indexable collection of tokenizers, one per text encoder.
        """
        super().__init__()
        self.tokenizers = tokenizers
        self.text_encoders = text_encoders

    @torch.no_grad()
    def forward(
        self,
        batch: Req,
        server_args: ServerArgs,
    ) -> Req:
        """
        Encode the prompt into text encoder hidden states.

        Args:
            batch: Request carrying ``prompt`` (and ``negative_prompt`` when
                classifier-free guidance is enabled). Its embedding / mask list
                fields are extended in place.
            server_args: Provides ``pipeline_config`` with the encoder settings.

        Returns:
            The same ``batch`` object, with the encoder outputs appended.
        """
        assert len(self.tokenizers) == len(self.text_encoders)
        assert len(self.text_encoders) == len(
            server_args.pipeline_config.text_encoder_configs
        )

        # Encode positive prompt with all available encoders
        assert batch.prompt is not None
        prompt_text: str | list[str] = batch.prompt

        all_indices: list[int] = list(range(len(self.text_encoders)))

        prompt_embeds_list, prompt_masks_list, pooler_embeds_list = self.encode_text(
            prompt_text,
            server_args,
            encoder_index=all_indices,
            return_attention_mask=True,
        )

        batch.prompt_embeds.extend(prompt_embeds_list)
        batch.pooled_embeds.extend(pooler_embeds_list)

        # The mask field may still be unset; the embedding fields are expected
        # to be lists already (see verify_input).
        if batch.prompt_attention_mask is None:
            batch.prompt_attention_mask = []
        batch.prompt_attention_mask.extend(prompt_masks_list)

        # Encode negative prompt if CFG is enabled
        if batch.do_classifier_free_guidance:
            assert isinstance(batch.negative_prompt, str)
            neg_embeds_list, neg_masks_list, neg_pooler_embeds_list = self.encode_text(
                batch.negative_prompt,
                server_args,
                encoder_index=all_indices,
                return_attention_mask=True,
            )

            assert batch.negative_prompt_embeds is not None

            batch.negative_prompt_embeds.extend(neg_embeds_list)
            batch.neg_pooled_embeds.extend(neg_pooler_embeds_list)

            if batch.negative_attention_mask is None:
                batch.negative_attention_mask = []
            batch.negative_attention_mask.extend(neg_masks_list)

        return batch

    def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
        """Verify text encoding stage inputs."""
        result = VerificationResult()
        result.add_check("prompt", batch.prompt, V.string_or_list_strings)
        # A negative prompt is only mandatory when CFG is enabled.
        result.add_check(
            "negative_prompt",
            batch.negative_prompt,
            lambda x: not batch.do_classifier_free_guidance or V.string_not_none(x),
        )
        result.add_check(
            "do_classifier_free_guidance",
            batch.do_classifier_free_guidance,
            V.bool_value,
        )
        result.add_check("prompt_embeds", batch.prompt_embeds, V.is_list)
        result.add_check(
            "negative_prompt_embeds", batch.negative_prompt_embeds, V.none_or_list
        )
        return result

    def prepare_tokenizer_kwargs(self, tokenizer_kwargs, **kwargs):
        """Return ``tokenizer_kwargs`` merged with ``kwargs`` (``kwargs`` win on conflicts)."""
        tok_kwargs = tokenizer_kwargs | kwargs

        return tok_kwargs

    @torch.no_grad()
    def encode_text(
        self,
        text: str | list[str],
        server_args: ServerArgs,
        encoder_index: int | list[int] | None = None,
        return_attention_mask: bool = False,
        return_type: str = "list",  # one of: "list", "dict", "stack"
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
        max_length: int | None = None,
        truncation: bool | None = None,
        padding: bool | str | None = None,
        return_overflowing_tokens=None,
        return_length=None,
    ):
        """
        Encode plain text using selected text encoder(s) and return embeddings.

        Args:
            text: A single string or a list of strings to encode.
            server_args: The inference arguments providing pipeline config,
                including tokenizer and encoder settings, preprocess and postprocess
                functions.
            encoder_index: Encoder selector by index. Accepts an int or list of ints;
                ``None`` selects encoder 0.
            return_attention_mask: If True, also return attention masks for each
                selected encoder.
            return_type: "list" (default) returns lists aligned with selection;
                "dict" returns a dict keyed by encoder index as a string; "stack" stacks
                along a new first dimension (requires matching shapes).
            device: Optional device override for inputs; defaults to local torch device.
            dtype: Optional dtype to cast returned embeddings to.
            max_length, truncation, padding, return_overflowing_tokens, return_length:
                Accepted as part of the signature; this implementation does not
                read them (tokenizer options come from the encoder config and
                ``text_encoder_extra_args``).

        Returns:
            Depending on return_type and return_attention_mask:
            - list: ``(embeds, pooled_embeds)`` or
              ``(embeds, attention_masks, pooled_embeds)``. ``pooled_embeds`` is
              only populated for Flux v1 pipeline configs and is empty otherwise.
            - dict: ``Dict[str, Tensor]`` or a tuple of the embeddings dict and the
              attention-mask dict (pooled embeddings are not returned).
            - stack: Tensor of shape [num_encoders, ...] or a tuple with stacked
              attention masks (pooled embeddings are not returned).

        Raises:
            IndexError: If a selected encoder index is out of range.
            ValueError: If ``return_type`` is unknown, or if "stack" is requested
                and the per-encoder tensors differ in shape.
        """

        assert len(self.tokenizers) == len(self.text_encoders)
        assert len(self.text_encoders) == len(
            server_args.pipeline_config.text_encoder_configs
        )

        # Resolve selection into indices
        if encoder_index is None:
            indices: list[int] = [0]
        elif isinstance(encoder_index, int):
            indices = [encoder_index]
        else:
            indices = list(encoder_index)

        # Validate indices are within range
        num_encoders = len(self.text_encoders)
        for idx in indices:
            if idx < 0 or idx >= num_encoders:
                raise IndexError(
                    f"encoder index {idx} out of range [0, {num_encoders - 1}]"
                )

        # Normalize input to list[str]
        assert isinstance(text, str | list)
        if isinstance(text, str):
            texts: list[str] = [text]
        else:
            texts = text

        embeds_list: list[torch.Tensor] = []
        pooled_embeds_list: list[torch.Tensor] = []
        attn_masks_list: list[torch.Tensor] = []

        preprocess_funcs = server_args.pipeline_config.preprocess_text_funcs
        postprocess_funcs = server_args.pipeline_config.postprocess_text_funcs
        text_encoder_extra_args = server_args.pipeline_config.text_encoder_extra_args
        encoder_cfgs = server_args.pipeline_config.text_encoder_configs

        if return_type not in ("list", "dict", "stack"):
            raise ValueError(
                f"Invalid return_type '{return_type}'. Expected one of: 'list', 'dict', 'stack'"
            )

        target_device = device if device is not None else get_local_torch_device()

        # Flux v1 means a FluxPipelineConfig that is not the Flux2 subclass; this
        # does not depend on the encoder, so it is evaluated once.
        is_flux_v1 = isinstance(
            server_args.pipeline_config, FluxPipelineConfig
        ) and not isinstance(server_args.pipeline_config, Flux2PipelineConfig)

        for i in indices:
            tokenizer = self.tokenizers[i]
            text_encoder = self.text_encoders[i]
            encoder_config = encoder_cfgs[i]
            preprocess_func = preprocess_funcs[i]
            postprocess_func = postprocess_funcs[i]
            # Missing or falsy per-encoder extra args fall back to no overrides.
            text_encoder_extra_arg = (
                text_encoder_extra_args[i]
                if i < len(text_encoder_extra_args) and text_encoder_extra_args[i]
                else {}
            )

            processed_text_list: list[str] = [
                preprocess_func(prompt_str) for prompt_str in texts
            ]

            # Prepare tokenizer args
            tok_kwargs = self.prepare_tokenizer_kwargs(
                encoder_config.tokenizer_kwargs,
                **text_encoder_extra_arg,
            )

            text_inputs: dict = server_args.pipeline_config.tokenize_prompt(
                processed_text_list, tokenizer, tok_kwargs
            ).to(target_device)

            input_ids = text_inputs["input_ids"]
            # For Flux v1 the encoder at index 1 is given an all-ones mask
            # instead of the tokenizer's mask.
            is_flux_t5 = is_flux_v1 and i == 1

            if is_flux_t5:
                attention_mask = torch.ones(input_ids.shape[:2], device=target_device)
            else:
                attention_mask = text_inputs["attention_mask"]
            with set_forward_context(current_timestep=0, attn_metadata=None):
                outputs: BaseEncoderOutput = text_encoder(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                    use_cache=False,
                )
            prompt_embeds = postprocess_func(outputs, text_inputs)
            if dtype is not None:
                prompt_embeds = prompt_embeds.to(dtype=dtype)

            embeds_list.append(prompt_embeds)
            if is_flux_v1:
                pooled_embeds_list.append(outputs.pooler_output)
            if return_attention_mask:
                attn_masks_list.append(attention_mask)

        # Shape results according to return_type
        if return_type == "list":
            if return_attention_mask:
                return embeds_list, attn_masks_list, pooled_embeds_list
            return embeds_list, pooled_embeds_list

        if return_type == "dict":
            key_strs = [str(i) for i in indices]
            embeds_dict = dict(zip(key_strs, embeds_list, strict=False))
            if return_attention_mask:
                attn_dict = dict(zip(key_strs, attn_masks_list, strict=False))
                return embeds_dict, attn_dict
            return embeds_dict

        # return_type == "stack"
        # Validate shapes are compatible
        base_shape = list(embeds_list[0].shape)
        for t in embeds_list[1:]:
            if list(t.shape) != base_shape:
                raise ValueError(
                    f"Cannot stack embeddings with differing shapes: {[list(t.shape) for t in embeds_list]}"
                )
        stacked_embeds = torch.stack(embeds_list, dim=0)
        if return_attention_mask:
            base_mask_shape = list(attn_masks_list[0].shape)
            for m in attn_masks_list[1:]:
                if list(m.shape) != base_mask_shape:
                    raise ValueError(
                        f"Cannot stack attention masks with differing shapes: {[list(m.shape) for m in attn_masks_list]}"
                    )
            stacked_masks = torch.stack(attn_masks_list, dim=0)
            return stacked_embeds, stacked_masks
        return stacked_embeds

    def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
        """Verify text encoding stage outputs."""
        result = VerificationResult()
        result.add_check(
            "prompt_embeds", batch.prompt_embeds, V.list_of_tensors_min_dims(2)
        )
        # Negative embeddings are only required when CFG is enabled.
        result.add_check(
            "negative_prompt_embeds",
            batch.negative_prompt_embeds,
            lambda x: not batch.do_classifier_free_guidance
            or V.list_of_tensors_with_min_dims(x, 2),
        )
        if batch.debug:
            logger.debug(f"{batch.prompt_embeds=}")
            logger.debug(f"{batch.negative_prompt_embeds=}")
        return result
class TimestepPreparationStage(PipelineStage):
    """
    Stage for preparing timesteps for the diffusion process.

    This stage calls the scheduler's ``set_timesteps`` (with custom timesteps,
    custom sigmas, or just the number of inference steps) and stores the
    resulting ``scheduler.timesteps`` on the request.
    """

    def __init__(
        self,
        scheduler,
        prepare_extra_set_timesteps_kwargs: (
            list[Callable[[Req, ServerArgs], Tuple[str, Any]]] | None
        ) = None,
    ) -> None:
        """
        Args:
            scheduler: Object exposing ``set_timesteps(...)`` and ``timesteps``.
            prepare_extra_set_timesteps_kwargs: Optional callables invoked as
                ``fn(batch, server_args)`` that each return a ``(key, value)``
                pair forwarded to ``scheduler.set_timesteps`` as a keyword
                argument. ``None`` (the default) or an empty list means no
                extra keyword arguments.
        """
        super().__init__()
        self.scheduler = scheduler
        # ``None`` replaces the former mutable ``[]`` default; both normalize to
        # a fresh empty list here.
        self.prepare_extra_set_timesteps_kwargs = (
            prepare_extra_set_timesteps_kwargs or []
        )

    @property
    def parallelism_type(self) -> StageParallelismType:
        """This stage reports the REPLICATED parallelism type."""
        return StageParallelismType.REPLICATED

    def forward(
        self,
        batch: Req,
        server_args: ServerArgs,
    ) -> Req:
        """
        Prepare timesteps for the diffusion process.

        Args:
            batch: Request providing ``num_inference_steps`` and optionally
                ``timesteps``, ``sigmas`` and ``n_tokens``. ``batch.sigmas`` and
                ``batch.timesteps`` are overwritten.
            server_args: Provides ``pipeline_config.prepare_sigmas``.

        Returns:
            The batch with prepared timesteps.

        Raises:
            ValueError: If both custom timesteps and sigmas are present, or if
                the scheduler's ``set_timesteps`` does not accept the custom
                schedule that was supplied.
        """
        scheduler = self.scheduler
        device = get_local_torch_device()
        num_inference_steps = batch.num_inference_steps
        timesteps = batch.timesteps
        n_tokens = batch.n_tokens

        sigmas = server_args.pipeline_config.prepare_sigmas(
            batch.sigmas, num_inference_steps
        )
        batch.sigmas = sigmas

        def scheduler_accepts(param_name: str) -> bool:
            # Evaluated lazily: the scheduler is only introspected on the code
            # paths that need to know whether a keyword is supported.
            return param_name in inspect.signature(scheduler.set_timesteps).parameters

        # Prepare extra kwargs for set_timesteps
        extra_set_timesteps_kwargs = {}
        if n_tokens is not None and scheduler_accepts("n_tokens"):
            extra_set_timesteps_kwargs["n_tokens"] = n_tokens

        for callee in self.prepare_extra_set_timesteps_kwargs:
            key, value = callee(batch, server_args)
            assert isinstance(key, str)
            extra_set_timesteps_kwargs[key] = value
            # "mu" is additionally recorded on the request's extra dict.
            if key == "mu":
                batch.extra["mu"] = value

        # Handle custom timesteps or sigmas
        if timesteps is not None and sigmas is not None:
            raise ValueError(
                "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
            )

        if timesteps is not None:
            if not scheduler_accepts("timesteps"):
                raise ValueError(
                    f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                    f" timestep schedules. Please check whether you are using the correct scheduler."
                )
            scheduler.set_timesteps(
                timesteps=timesteps, device=device, **extra_set_timesteps_kwargs
            )
        elif sigmas is not None:
            if not scheduler_accepts("sigmas"):
                raise ValueError(
                    f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                    f" sigmas schedules. Please check whether you are using the correct scheduler."
                )
            scheduler.set_timesteps(
                sigmas=sigmas, device=device, **extra_set_timesteps_kwargs
            )
        else:
            scheduler.set_timesteps(
                num_inference_steps, device=device, **extra_set_timesteps_kwargs
            )

        # Update batch with prepared timesteps (every branch above ends by
        # reading them back from the scheduler).
        timesteps = scheduler.timesteps
        batch.timesteps = timesteps
        if not batch.is_warmup:
            self.log_debug("timesteps: %s", timesteps)
        return batch

    def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
        """Verify timestep preparation stage inputs."""
        result = VerificationResult()
        result.add_check(
            "num_inference_steps", batch.num_inference_steps, V.positive_int
        )
        result.add_check("timesteps", batch.timesteps, V.none_or_tensor)
        result.add_check("sigmas", batch.sigmas, V.none_or_list)
        result.add_check("n_tokens", batch.n_tokens, V.none_or_positive_int)
        return result

    def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
        """Verify timestep preparation stage outputs.

        For warmup requests whose timesteps contain NaN, ``batch.timesteps`` is
        replaced by a single-element tensor of ones before validation.
        """
        if (
            batch.is_warmup
            and isinstance(batch.timesteps, torch.Tensor)
            and torch.isnan(batch.timesteps).any()
        ):
            # when num-inference-steps == 1, the last sigma being 1, the 1 / last_sigma could be nan
            # this a workaround for warmup req only
            batch.timesteps = torch.ones(
                (1,), dtype=torch.float32, device=get_local_torch_device()
            )

        result = VerificationResult()
        result.add_check("timesteps", batch.timesteps, [V.is_tensor, V.with_dims(1)])
        return result
SPDX-License-Identifier: Apache-2.0 +""" +Common validators for pipeline stage verification. + +This module provides reusable validation functions that can be used across +all pipeline stages for input/output verification. +""" + +from collections.abc import Callable +from typing import Any + +import torch + + +class StageValidators: + """Common validators for pipeline stages.""" + + @staticmethod + def not_none(value: Any) -> bool: + """Check if value is not None.""" + return value is not None + + @staticmethod + def positive_int(value: Any) -> bool: + """Check if value is a positive integer.""" + return isinstance(value, int) and value > 0 + + @staticmethod + def non_negative_int(value: Any) -> bool: + """Check if value is a non-negative float.""" + return isinstance(value, int | float) and value >= 0 + + @staticmethod + def positive_float(value: Any) -> bool: + """Check if value is a positive float.""" + return isinstance(value, int | float) and value > 0 + + @staticmethod + def non_negative_float(value: Any) -> bool: + """Check if value is a non-negative float.""" + return isinstance(value, int | float) and value >= 0 + + @staticmethod + def divisible_by(value: Any, divisor: int) -> bool: + """Check if value is divisible by divisor.""" + return value is not None and isinstance(value, int) and value % divisor == 0 + + @staticmethod + def is_tensor(value: Any) -> bool: + """Check if value is a torch tensor and doesn't contain NaN values.""" + if not isinstance(value, torch.Tensor): + return False + return not torch.isnan(value).any().item() + + @staticmethod + def tensor_with_dims(value: Any, dims: int) -> bool: + """Check if value is a tensor with specific dimensions and no NaN values.""" + if not isinstance(value, torch.Tensor): + return False + if value.dim() != dims: + return False + return not torch.isnan(value).any().item() + + @staticmethod + def tensor_min_dims(value: Any, min_dims: int) -> bool: + """Check if value is a tensor with at least min_dims 
dimensions and no NaN values.""" + if not isinstance(value, torch.Tensor): + return False + if value.dim() < min_dims: + return False + return not torch.isnan(value).any().item() + + @staticmethod + def tensor_shape_matches(value: Any, expected_shape: tuple) -> bool: + """Check if tensor shape matches expected shape (None for any size) and no NaN values.""" + if not isinstance(value, torch.Tensor): + return False + if len(value.shape) != len(expected_shape): + return False + for actual, expected in zip(value.shape, expected_shape, strict=True): + if expected is not None and actual != expected: + return False + return not torch.isnan(value).any().item() + + @staticmethod + def list_not_empty(value: Any) -> bool: + """Check if value is a non-empty list.""" + return isinstance(value, list) and len(value) > 0 + + @staticmethod + def list_length(value: Any, length: int) -> bool: + """Check if list has specific length.""" + return isinstance(value, list) and len(value) == length + + @staticmethod + def list_min_length(value: Any, min_length: int) -> bool: + """Check if list has at least min_length items.""" + return isinstance(value, list) and len(value) >= min_length + + @staticmethod + def string_not_empty(value: Any) -> bool: + """Check if value is a non-empty string.""" + return isinstance(value, str) and len(value.strip()) > 0 + + @staticmethod + def string_not_none(value: Any) -> bool: + """Check if value is a non-empty string.""" + return isinstance(value, str) and len(value) > 0 + + @staticmethod + def string_or_list_strings(value: Any) -> bool: + """Check if value is a string or list of strings.""" + if isinstance(value, str): + return True + if isinstance(value, list): + return all(isinstance(item, str) for item in value) + return False + + @staticmethod + def bool_value(value: Any) -> bool: + """Check if value is a boolean.""" + return isinstance(value, bool) + + @staticmethod + def generator_or_list_generators(value: Any) -> bool: + """Check if value is a 
Generator or list of Generators.""" + if isinstance(value, torch.Generator): + return True + if isinstance(value, list): + return all(isinstance(item, torch.Generator) for item in value) + return False + + @staticmethod + def is_list(value: Any) -> bool: + """Check if value is a list (can be empty).""" + return isinstance(value, list) + + @staticmethod + def is_tuple(value: Any) -> bool: + """Check if value is a tuple.""" + return isinstance(value, tuple) + + @staticmethod + def none_or_tensor(value: Any) -> bool: + """Check if value is None or a tensor without NaN values.""" + if value is None: + return True + if not isinstance(value, torch.Tensor): + return False + return not torch.isnan(value).any().item() + + @staticmethod + def list_of_tensors_with_dims(value: Any, dims: int) -> bool: + """Check if value is a non-empty list where all items are tensors with specific dimensions and no NaN values.""" + if not isinstance(value, list) or len(value) == 0: + return False + for item in value: + if not isinstance(item, torch.Tensor): + return False + if item.dim() != dims: + return False + if torch.isnan(item).any().item(): + return False + return True + + @staticmethod + def list_of_tensors(value: Any) -> bool: + """Check if value is a non-empty list where all items are tensors without NaN values.""" + if not isinstance(value, list) or len(value) == 0: + return False + for item in value: + if not isinstance(item, torch.Tensor): + return False + if torch.isnan(item).any().item(): + return False + return True + + @staticmethod + def list_of_tensors_with_min_dims(value: Any, min_dims: int) -> bool: + """Check if value is a non-empty list where all items are tensors with at least min_dims dimensions and no NaN values.""" + if not isinstance(value, list) or len(value) == 0: + return False + for item in value: + if not isinstance(item, torch.Tensor): + return False + if item.dim() < min_dims: + return False + if torch.isnan(item).any().item(): + return False + return True + 
+ @staticmethod + def none_or_tensor_with_dims(dims: int) -> Callable[[Any], bool]: + """Return a validator that checks if value is None or a tensor with specific dimensions and no NaN values.""" + + def validator(value: Any) -> bool: + if value is None: + return True + if not isinstance(value, torch.Tensor): + return False + if value.dim() != dims: + return False + return not torch.isnan(value).any().item() + + return validator + + @staticmethod + def none_or_list(value: Any) -> bool: + """Check if value is None or a list.""" + return value is None or isinstance(value, list) + + @staticmethod + def none_or_positive_int(value: Any) -> bool: + """Check if value is None or a positive integer.""" + return value is None or (isinstance(value, int) and value > 0) + + # Helper methods that return functions for common patterns + @staticmethod + def with_dims(dims: int) -> Callable[[Any], bool]: + """Return a validator that checks if tensor has specific dimensions and no NaN values.""" + + def validator(value: Any) -> bool: + return StageValidators.tensor_with_dims(value, dims) + + return validator + + @staticmethod + def min_dims(min_dims: int) -> Callable[[Any], bool]: + """Return a validator that checks if tensor has at least min_dims dimensions and no NaN values.""" + + def validator(value: Any) -> bool: + return StageValidators.tensor_min_dims(value, min_dims) + + return validator + + @staticmethod + def divisible(divisor: int) -> Callable[[Any], bool]: + """Return a validator that checks if value is divisible by divisor.""" + + def validator(value: Any) -> bool: + return StageValidators.divisible_by(value, divisor) + + return validator + + @staticmethod + def positive_int_divisible(divisor: int) -> Callable[[Any], bool]: + """Return a validator that checks if value is a positive integer divisible by divisor.""" + + def validator(value: Any) -> bool: + return ( + isinstance(value, int) + and value > 0 + and StageValidators.divisible_by(value, divisor) + ) + + return 
validator + + @staticmethod + def list_of_tensors_dims(dims: int) -> Callable[[Any], bool]: + """Return a validator that checks if value is a list of tensors with specific dimensions and no NaN values.""" + + def validator(value: Any) -> bool: + return StageValidators.list_of_tensors_with_dims(value, dims) + + return validator + + @staticmethod + def list_of_tensors_min_dims(min_dims: int) -> Callable[[Any], bool]: + """Return a validator that checks if value is a list of tensors with at least min_dims dimensions and no NaN values.""" + + def validator(value: Any) -> bool: + return StageValidators.list_of_tensors_with_min_dims(value, min_dims) + + return validator + + +class ValidationFailure: + """Details about a specific validation failure.""" + + def __init__( + self, + validator_name: str, + actual_value: Any, + expected: str | None = None, + error_msg: str | None = None, + ): + self.validator_name = validator_name + self.actual_value = actual_value + self.expected = expected + self.error_msg = error_msg + + def __str__(self) -> str: + parts = [f"Validator '{self.validator_name}' failed"] + + if self.error_msg: + parts.append(f"Error: {self.error_msg}") + + # Add actual value info (but limit very long representations) + actual_str = self._format_value(self.actual_value) + parts.append(f"Actual: {actual_str}") + + if self.expected: + parts.append(f"Expected: {self.expected}") + + return ". 
".join(parts) + + def _format_value(self, value: Any) -> str: + """Format a value for display in error messages.""" + if value is None: + return "None" + elif isinstance(value, torch.Tensor): + return f"tensor(shape={list(value.shape)}, dtype={value.dtype})" + elif isinstance(value, list): + if len(value) == 0: + return "[]" + elif len(value) <= 3: + item_strs = [self._format_value(item) for item in value] + return f"[{', '.join(item_strs)}]" + else: + return f"list(length={len(value)}, first_item={self._format_value(value[0])})" + elif isinstance(value, str): + if len(value) > 50: + return f"'{value[:47]}...'" + else: + return f"'{value}'" + else: + return f"{type(value).__name__}({value})" + + +class VerificationResult: + """Wrapper class for stage verification results.""" + + def __init__(self) -> None: + self._checks: dict[str, bool] = {} + self._failures: dict[str, list[ValidationFailure]] = {} + + def add_check( + self, + field_name: str, + value: Any, + validators: Callable[[Any], bool] | list[Callable[[Any], bool]], + ) -> "VerificationResult": + """ + Add a validation check for a field. + + Args: + field_name: Name of the field being checked + value: The actual value to validate + validators: Single validation function or list of validation functions. + Each function will be called with the value as its first argument. 
+ + Returns: + Self for method chaining + + Examples: + # Single validator + result.add_check("tensor", my_tensor, V.is_tensor) + + # Multiple validators (all must pass) + result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)]) + + # Using partial functions for parameters + result.add_check("height", batch.height, [V.not_none, V.divisible(8)]) + """ + if not isinstance(validators, list): + validators = [validators] + + failures = [] + all_passed = True + + # Apply all validators and collect detailed failure info + for validator in validators: + try: + passed = validator(value) + if not passed: + all_passed = False + failure = self._create_validation_failure(validator, value) + failures.append(failure) + except Exception as e: + # If any validator raises an exception, consider the check failed + all_passed = False + validator_name = getattr(validator, "__name__", str(validator)) + failure = ValidationFailure( + validator_name=validator_name, + actual_value=value, + error_msg=f"Exception during validation: {str(e)}", + ) + failures.append(failure) + + self._checks[field_name] = all_passed + if not all_passed: + self._failures[field_name] = failures + + return self + + def _create_validation_failure( + self, validator: Callable, value: Any + ) -> ValidationFailure: + """Create a ValidationFailure with detailed information.""" + validator_name = getattr(validator, "__name__", str(validator)) + + # Try to extract meaningful expected value info based on validator type + expected = None + error_msg = None + + # Handle common validator patterns + if hasattr(validator, "__closure__") and validator.__closure__: + # This is likely a closure (like our helper functions) + if "dims" in validator_name or "with_dims" in str(validator): + if isinstance(value, torch.Tensor): + expected = f"tensor with {validator.__closure__[0].cell_contents} dimensions" + else: + expected = "tensor with specific dimensions" + elif "divisible" in str(validator): + expected = ( + 
f"integer divisible by {validator.__closure__[0].cell_contents}" + ) + + # Handle specific validator types and check for NaN values + if validator_name == "is_tensor": + expected = "torch.Tensor without NaN values" + if isinstance(value, torch.Tensor) and torch.isnan(value).any().item(): + error_msg = ( + f"tensor contains {torch.isnan(value).sum().item()} NaN values" + ) + elif validator_name == "positive_int": + expected = "positive integer" + elif validator_name == "not_none": + expected = "non-None value" + elif validator_name == "list_not_empty": + expected = "non-empty list" + elif validator_name == "bool_value": + expected = "boolean value" + elif ( + "tensor_with_dims" in validator_name or "tensor_min_dims" in validator_name + ): + if isinstance(value, torch.Tensor): + if torch.isnan(value).any().item(): + error_msg = f"tensor has {value.dim()} dimensions but contains {torch.isnan(value).sum().item()} NaN values" + else: + error_msg = f"tensor has {value.dim()} dimensions" + elif validator_name == "is_list": + expected = "list" + elif validator_name == "none_or_tensor": + expected = "None or tensor without NaN values" + if isinstance(value, torch.Tensor) and torch.isnan(value).any().item(): + error_msg = ( + f"tensor contains {torch.isnan(value).sum().item()} NaN values" + ) + elif validator_name == "list_of_tensors": + expected = "non-empty list of tensors without NaN values" + if isinstance(value, list) and len(value) > 0: + nan_count = 0 + for item in value: + if ( + isinstance(item, torch.Tensor) + and torch.isnan(item).any().item() + ): + nan_count += torch.isnan(item).sum().item() + if nan_count > 0: + error_msg = ( + f"list contains tensors with total {nan_count} NaN values" + ) + elif "list_of_tensors_with_dims" in validator_name: + expected = ( + "non-empty list of tensors with specific dimensions and no NaN values" + ) + if isinstance(value, list) and len(value) > 0: + nan_count = 0 + for item in value: + if ( + isinstance(item, torch.Tensor) + 
and torch.isnan(item).any().item() + ): + nan_count += torch.isnan(item).sum().item() + if nan_count > 0: + error_msg = ( + f"list contains tensors with total {nan_count} NaN values" + ) + + return ValidationFailure( + validator_name=validator_name, + actual_value=value, + expected=expected, + error_msg=error_msg, + ) + + def is_valid(self) -> bool: + """Check if all validations passed.""" + return all(self._checks.values()) + + def get_failed_fields(self) -> list[str]: + """Get list of fields that failed validation.""" + return [field for field, passed in self._checks.items() if not passed] + + def get_detailed_failures(self) -> dict[str, list[ValidationFailure]]: + """Get detailed failure information for each failed field.""" + return self._failures.copy() + + def get_failure_summary(self) -> str: + """Get a comprehensive summary of all validation failures.""" + if self.is_valid(): + return "All validations passed" + + summary_parts = [] + for field_name, failures in self._failures.items(): + field_summary = f"\n Field '{field_name}':" + for i, failure in enumerate(failures, 1): + field_summary += f"\n {i}. 
{failure}" + summary_parts.append(field_summary) + + return "Validation failures:" + "".join(summary_parts) + + def to_dict(self) -> dict: + """Convert to dictionary for backward compatibility.""" + return self._checks.copy() + + +# Alias for convenience +V = StageValidators diff --git a/sglang/python/sglang/multimodal_gen/runtime/platforms/__init__.py b/sglang/python/sglang/multimodal_gen/runtime/platforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..55aa8ad5d2d477b05202b95294b5af65d0298aae --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/platforms/__init__.py @@ -0,0 +1,214 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/platforms/__init__.py + +import traceback +from typing import TYPE_CHECKING + +# imported by other files, do not remove +from sglang.multimodal_gen.runtime.platforms.interface import ( # noqa: F401 + AttentionBackendEnum, + Platform, + PlatformEnum, +) +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.utils import resolve_obj_by_qualname + +logger = init_logger(__name__) + + +def cuda_platform_plugin() -> str | None: + is_cuda = False + + try: + from sglang.multimodal_gen.utils import import_pynvml + + pynvml = import_pynvml() # type: ignore[no-untyped-call] + pynvml.nvmlInit() + try: + # NOTE: Edge case: sgl_diffusion cpu build on a GPU machine. + # Third-party pynvml can be imported in cpu build, + # we need to check if sgl_diffusion is built with cpu too. + # Otherwise, sgl_diffusion will always activate cuda plugin + # on a GPU machine, even if in a cpu build. + is_cuda = pynvml.nvmlDeviceGetCount() > 0 + finally: + pynvml.nvmlShutdown() + except Exception as e: + if "nvml" not in e.__class__.__name__.lower(): + # If the error is not related to NVML, re-raise it. 
+ raise e + + # CUDA is supported on Jetson, but NVML may not be. + import os + + def cuda_is_jetson() -> bool: + return os.path.isfile("/etc/nv_tegra_release") or os.path.exists( + "/sys/class/tegra-firmware" + ) + + if cuda_is_jetson(): + is_cuda = True + if is_cuda: + logger.info("CUDA is available") + + return ( + "sglang.multimodal_gen.runtime.platforms.cuda.CudaPlatform" if is_cuda else None + ) + + +def mps_platform_plugin() -> str | None: + """Detect if MPS (Metal Performance Shaders) is available on macOS.""" + is_mps = False + + try: + import torch + + if torch.backends.mps.is_available(): + is_mps = True + logger.info("MPS (Metal Performance Shaders) is available") + except Exception as e: + logger.info("MPS detection failed: %s", e) + + return "sglang.multimodal_gen.runtime.platforms.mps.MpsPlatform" if is_mps else None + + +def cpu_platform_plugin() -> str | None: + """Detect if CPU platform should be used.""" + # CPU is always available as a fallback + return "sglang.multimodal_gen.runtime.platforms.cpu.CpuPlatform" + + +def rocm_platform_plugin() -> str | None: + is_rocm = False + + try: + import amdsmi + + amdsmi.amdsmi_init() + try: + if len(amdsmi.amdsmi_get_processor_handles()) > 0: + is_rocm = True + logger.info("ROCm platform is available") + finally: + amdsmi.amdsmi_shut_down() + except Exception as e: + logger.info("ROCm platform is unavailable: %s", e) + + return ( + "sglang.multimodal_gen.runtime.platforms.rocm.RocmPlatform" if is_rocm else None + ) + + +def npu_platform_plugin() -> str | None: + is_npu = False + + try: + import torch + + if torch.npu.is_available(): + is_npu = True + logger.info("NPU is available") + except Exception as e: + logger.info("NPU detection failed: %s", e) + return ( + "sglang.multimodal_gen.runtime.platforms.npu.NPUPlatformBase" + if is_npu + else None + ) + + +def musa_platform_plugin() -> str | None: + is_musa = False + + try: + import pymtml + + pymtml.mtmlLibraryInit() + try: + is_musa = 
pymtml.mtmlLibraryCountDevice() > 0 + finally: + pymtml.mtmlLibraryShutDown() + except Exception as e: + logger.info("MUSA platform is unavailable: %s", e) + + return ( + "sglang.multimodal_gen.runtime.platforms.musa.MusaPlatform" if is_musa else None + ) + + +builtin_platform_plugins = { + "cuda": cuda_platform_plugin, + "rocm": rocm_platform_plugin, + "mps": mps_platform_plugin, + "cpu": cpu_platform_plugin, + "npu": npu_platform_plugin, + "musa": musa_platform_plugin, +} + + +def resolve_current_platform_cls_qualname() -> str: + # TODO(will): if we need to support other platforms, we should consider if + # vLLM's plugin architecture is suitable for our needs. + + # Try MPS first on macOS + platform_cls_qualname = mps_platform_plugin() + if platform_cls_qualname is not None: + return platform_cls_qualname + + # Fall back to ROCm + platform_cls_qualname = rocm_platform_plugin() + if platform_cls_qualname is not None: + return platform_cls_qualname + + # Fall back to CUDA + platform_cls_qualname = cuda_platform_plugin() + if platform_cls_qualname is not None: + return platform_cls_qualname + + # Fall back to NPU + platform_cls_qualname = npu_platform_plugin() + if platform_cls_qualname is not None: + return platform_cls_qualname + + # Fall back to MUSA + platform_cls_qualname = musa_platform_plugin() + if platform_cls_qualname is not None: + return platform_cls_qualname + + # Fall back to CPU as last resort + platform_cls_qualname = cpu_platform_plugin() + if platform_cls_qualname is not None: + return platform_cls_qualname + + raise RuntimeError("No platform plugin found. Please check your " "installation.") + + +_current_platform: Platform | None = None +_init_trace: str = "" + +current_platform: Platform + + +def __getattr__(name: str): + if name == "current_platform": + # lazy init current_platform. + # 1. out-of-tree platform plugins need `from sglang.multimodal_gen.runtime.platforms import + # Platform` so that they can inherit `Platform` class. 
Therefore, + # we cannot resolve `current_platform` during the import of + # `sglang.multimodal_gen.runtime.platforms`. + global _current_platform + if _current_platform is None: + platform_cls_qualname = resolve_current_platform_cls_qualname() + _current_platform = resolve_obj_by_qualname(platform_cls_qualname)() + global _init_trace + _init_trace = "".join(traceback.format_stack()) + return _current_platform + elif name in globals(): + return globals()[name] + else: + raise AttributeError(f"No attribute named '{name}' exists in {__name__}.") + + +__all__ = ["Platform", "PlatformEnum", "current_platform", "_init_trace"] diff --git a/sglang/python/sglang/multimodal_gen/runtime/platforms/cpu.py b/sglang/python/sglang/multimodal_gen/runtime/platforms/cpu.py new file mode 100644 index 0000000000000000000000000000000000000000..c937c15a9c43dc75f3e5f96eee123badc7c78a8b --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/platforms/cpu.py @@ -0,0 +1,88 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/platforms/cpu.py + +import platform +from functools import lru_cache +from typing import Any + +import psutil +import torch + +from sglang.multimodal_gen.runtime.platforms.interface import ( + CpuArchEnum, + Platform, + PlatformEnum, +) + + +class CpuPlatform(Platform): + _enum = PlatformEnum.CPU + device_name = "CPU" + device_type = "cpu" + dispatch_key = "CPU" + + @classmethod + def get_cpu_architecture(cls) -> CpuArchEnum: + """Get the CPU architecture.""" + machine = platform.machine().lower() + if machine in ("x86_64", "amd64", "i386", "i686"): + return CpuArchEnum.X86 + elif machine in ("arm64", "aarch64"): + return CpuArchEnum.ARM + else: + return CpuArchEnum.UNSPECIFIED + + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: + return platform.processor() + + @classmethod + def get_device_uuid(cls, 
device_id: int = 0) -> str: + return platform.machine() + + @classmethod + @lru_cache(maxsize=1) + def get_device_total_memory(cls, device_id: int = 0) -> int: + + return psutil.virtual_memory().total + + @classmethod + def is_async_output_supported(cls, enforce_eager: bool | None) -> bool: + return True + + @classmethod + def get_current_memory_usage( + cls, device: torch.types.Device | None = None + ) -> float: + # For CPU, we can't easily get memory usage without additional libraries + return 0.0 + + @classmethod + def get_available_gpu_memory( + cls, + device_id: int = 0, + distributed: bool = False, + empty_cache: bool = True, + cpu_group: Any = None, + ) -> float: + + total_free_memory = psutil.virtual_memory().available + # For simplicity, we assume 1 NUMA node for now in this platform abstraction + # as get_cpu_ids_by_node is not available in multimodal_gen.runtime.utils + n_numa_node = 1 + free_memory = total_free_memory / n_numa_node + + if distributed: + import torch.distributed as dist + + tensor = torch.tensor(free_memory, dtype=torch.float32) + dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=cpu_group) + free_memory = float(tensor.item()) + + return free_memory / (1 << 30) + + @classmethod + def get_device_communicator_cls(cls) -> str: + return "sglang.multimodal_gen.runtime.distributed.device_communicators.cpu_communicator.CpuCommunicator" diff --git a/sglang/python/sglang/multimodal_gen/runtime/platforms/cuda.py b/sglang/python/sglang/multimodal_gen/runtime/platforms/cuda.py new file mode 100644 index 0000000000000000000000000000000000000000..84c75100b2e6b27aef97c858158c6d91fa52e267 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/platforms/cuda.py @@ -0,0 +1,535 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/platforms/cuda.py +"""Code inside this file can safely assume cuda platform, e.g. 
importing +pynvml. However, it should not initialize cuda context. +""" + +import os +from collections.abc import Callable +from functools import lru_cache, wraps +from typing import Any, TypeVar + +import psutil +import torch +from typing_extensions import ParamSpec + +from sglang.multimodal_gen import envs +from sglang.multimodal_gen.runtime.platforms.interface import ( + AttentionBackendEnum, + DeviceCapability, + Platform, + PlatformEnum, +) +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.utils import import_pynvml + +logger = init_logger(__name__) + +_P = ParamSpec("_P") +_R = TypeVar("_R") + +pynvml = import_pynvml() # type: ignore[no-untyped-call] + +# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models +# see https://github.com/huggingface/diffusers/issues/9704 for details +torch.backends.cuda.enable_cudnn_sdp(False) + + +def device_id_to_physical_device_id(device_id: int) -> int: + if "CUDA_VISIBLE_DEVICES" in os.environ: + device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",") + if device_ids == [""]: + msg = ( + "CUDA_VISIBLE_DEVICES is set to empty string, which means" + " GPU support is disabled. If you are using ray, please unset" + " the environment variable `CUDA_VISIBLE_DEVICES` inside the" + " worker/actor. " + "Check https://github.com/vllm-project/vllm/issues/8402 for" + " more information." 
+ ) + raise RuntimeError(msg) + physical_device_id = device_ids[device_id] + return int(physical_device_id) + else: + return device_id + + +def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]: + @wraps(fn) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: + pynvml.nvmlInit() + try: + return fn(*args, **kwargs) + finally: + pynvml.nvmlShutdown() + + return wrapper + + +class CudaPlatformBase(Platform): + _enum = PlatformEnum.CUDA + device_name: str = "cuda" + device_type: str = "cuda" + dispatch_key: str = "CUDA" + device_control_env_var: str = "CUDA_VISIBLE_DEVICES" + + @classmethod + def get_local_torch_device(cls) -> torch.device: + return torch.device(f"cuda:{envs.LOCAL_RANK}") + + @classmethod + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None: + raise NotImplementedError + + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: + raise NotImplementedError + + @classmethod + @lru_cache(maxsize=1) + def get_device_total_memory(cls, device_id: int = 0) -> int: + raise NotImplementedError + + @classmethod + def is_async_output_supported(cls, enforce_eager: bool | None) -> bool: + if enforce_eager: + logger.warning( + "To see benefits of async output processing, enable CUDA " + "graph. 
Since, enforce-eager is enabled, async output " + "processor cannot be used" + ) + return False + return True + + @classmethod + def is_full_nvlink(cls, device_ids: list[int]) -> bool: + raise NotImplementedError + + @classmethod + def log_warnings(cls) -> None: + pass + + @classmethod + def get_current_memory_usage( + cls, device: torch.types.Device | None = None + ) -> float: + torch.cuda.reset_peak_memory_stats(device) + return float(torch.cuda.max_memory_allocated(device)) + + @classmethod + def get_available_gpu_memory( + cls, + device_id: int = 0, + distributed: bool = False, + empty_cache: bool = True, + cpu_group: Any = None, + ) -> float: + if empty_cache: + torch.cuda.empty_cache() + + if torch.distributed.is_initialized(): + device_id = torch.distributed.get_rank() + + device_props = torch.cuda.get_device_properties(device_id) + if device_props.is_integrated: + free_gpu_memory = psutil.virtual_memory().available + else: + free_gpu_memory, _ = torch.cuda.mem_get_info(device_id) + + if distributed: + import torch.distributed as dist + + tensor = torch.tensor(free_gpu_memory, dtype=torch.float32, device="cuda") + dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=cpu_group) + free_gpu_memory = float(tensor.item()) + + return free_gpu_memory / (1 << 30) + + @classmethod + def get_attn_backend_cls_str( + cls, + selected_backend: AttentionBackendEnum | None, + head_size: int, + dtype: torch.dtype, + ) -> str: + target_backend: AttentionBackendEnum | None = None + # TODO(will): maybe come up with a more general interface for local attention + # if distributed is False, we always try to use Flash attn + if selected_backend == AttentionBackendEnum.SLIDING_TILE_ATTN: + try: + from st_attn import sliding_tile_attention # noqa: F401 + + from sglang.multimodal_gen.runtime.layers.attention.backends.sliding_tile_attn import ( # noqa: F401 + SlidingTileAttentionBackend, + ) + + logger.info("Using Sliding Tile Attention backend") + + return 
"sglang.multimodal_gen.runtime.layers.attention.backends.sliding_tile_attn.SlidingTileAttentionBackend" + except ImportError as e: + logger.error( + "Failed to import Sliding Tile Attention backend: %s", str(e) + ) + raise ImportError( + "Sliding Tile Attention backend is not installed. " + ) from e + elif selected_backend == AttentionBackendEnum.SAGE_ATTN: + try: + from sageattention import sageattn # noqa: F401 + + from sglang.multimodal_gen.runtime.layers.attention.backends.sage_attn import ( # noqa: F401 + SageAttentionBackend, + ) + + logger.info("Using Sage Attention backend") + + return "sglang.multimodal_gen.runtime.layers.attention.backends.sage_attn.SageAttentionBackend" + except ImportError as e: + logger.info(e) + logger.info( + "Sage Attention backend is not installed (To install it, run `pip install sageattention==2.2.0 --no-build-isolation`). Falling back to Flash Attention." + ) + target_backend = AttentionBackendEnum.FA + elif selected_backend == AttentionBackendEnum.SAGE_ATTN_3: + try: + from sglang.multimodal_gen.runtime.layers.attention.backends.sage_attn3 import ( # noqa: F401 + SageAttention3Backend, + ) + + logger.info("Using Sage Attention 3 backend") + return "sglang.multimodal_gen.runtime.layers.attention.backends.sage_attn3.SageAttention3Backend" + except ImportError as e: + logger.info(e) + logger.info( + "Sage Attention 3 backend is not installed (To install it, see https://github.com/thu-ml/SageAttention/tree/main/sageattention3_blackwell#installation). Falling back to Torch SDPA." 
+ ) + target_backend = AttentionBackendEnum.TORCH_SDPA + elif selected_backend == AttentionBackendEnum.VIDEO_SPARSE_ATTN: + try: + from vsa import block_sparse_attn # noqa: F401 + + from sglang.multimodal_gen.runtime.layers.attention.backends.video_sparse_attn import ( # noqa: F401 + VideoSparseAttentionBackend, + ) + + logger.info("Using Video Sparse Attention backend") + + return "sglang.multimodal_gen.runtime.layers.attention.backends.video_sparse_attn.VideoSparseAttentionBackend" + except ImportError as e: + logger.error( + "Failed to import Video Sparse Attention backend: %s", str(e) + ) + raise ImportError( + "Video Sparse Attention backend is not installed." + ) from e + elif selected_backend == AttentionBackendEnum.SPARSE_VIDEO_GEN_2_ATTN: + try: + from svg.kernels.triton.permute import ( # noqa: F401 + apply_inverse_permutation_triton, + permute_tensor_by_labels_triton, + ) + from svg.kmeans_utils import ( # noqa: F401 + batch_kmeans_Euclid, + density_calculation, + dynamic_block_sparse_fwd_flashinfer, + identify_dynamic_map, + ) + + from sglang.multimodal_gen.runtime.layers.attention.backends.sparse_video_gen_2_attn import ( # noqa: F401 + SparseVideoGen2AttentionBackend, + ) + + logger.info("Using Sparse Video Gen 2 (SAP) Attention backend") + return "sglang.multimodal_gen.runtime.layers.attention.backends.sparse_video_gen_2_attn.SparseVideoGen2AttentionBackend" + except ImportError as e: + logger.error( + "Failed to import Sparse Video Gen 2 (SAP) Attention backend: %s", + str(e), + ) + raise ImportError( + "Sparse Video Gen 2 (SAP) Attention backend is not installed. 
" + "Please install it by following the instructions at " + "https://github.com/svg-project/Sparse-VideoGen" + ) from e + elif selected_backend == AttentionBackendEnum.VMOBA_ATTN: + try: + from kernel.attn.vmoba_attn.vmoba import moba_attn_varlen # noqa: F401 + + from sglang.multimodal_gen.runtime.layers.attention.backends.vmoba import ( # noqa: F401 + VMOBAAttentionBackend, + ) + + logger.info("Using Video MOBA Attention backend") + + return "sglang.multimodal_gen.runtime.layers.attention.backends.vmoba.VMOBAAttentionBackend" + except ImportError as e: + logger.error( + "Failed to import Video MoBA Attention backend: %s", str(e) + ) + raise ImportError( + "Video MoBA Attention backend is not installed. " + ) from e + elif selected_backend == AttentionBackendEnum.AITER: + logger.info("Using AITer backend") + return "sglang.multimodal_gen.runtime.layers.attention.backends.aiter.AITerBackend" + elif selected_backend == AttentionBackendEnum.TORCH_SDPA: + logger.info("Using Torch SDPA backend") + return "sglang.multimodal_gen.runtime.layers.attention.backends.sdpa.SDPABackend" + elif selected_backend == AttentionBackendEnum.SLA_ATTN: + logger.info("Using Sparse Linear Attention backend") + return "sglang.multimodal_gen.runtime.layers.attention.backends.sparse_linear_attn.SparseLinearAttentionBackend" + elif selected_backend == AttentionBackendEnum.SAGE_SLA_ATTN: + logger.info("Using Sage Sparse Linear Attention backend") + return "sglang.multimodal_gen.runtime.layers.attention.backends.sparse_linear_attn.SageSparseLinearAttentionBackend" + elif selected_backend == AttentionBackendEnum.FA2: + from sglang.multimodal_gen.runtime.layers.attention.backends.flash_attn_2 import ( # noqa: F401 + FlashAttention2Backend, + ) + + logger.info("Using FlashAttention2 backend") + return "sglang.multimodal_gen.runtime.layers.attention.backends.flash_attn_2.FlashAttention2Backend" + elif selected_backend in [ + AttentionBackendEnum.FA, + ]: + if cls.is_sm120(): + logger.info( + 
"FlashAttention is not supported on SM12.x in this build; falling back to Torch SDPA." + ) + target_backend = AttentionBackendEnum.TORCH_SDPA + elif cls.is_blackwell(): + from sglang.multimodal_gen.runtime.layers.attention.backends.flash_attn import ( + set_fa_ver, + ) + + set_fa_ver(4) + target_backend = AttentionBackendEnum.FA + else: + target_backend = AttentionBackendEnum.FA + elif selected_backend: + raise ValueError(f"Invalid attention backend for {cls.device_name}") + else: + if cls.is_sm120(): + # On SM12.x, the sgl-kernel FlashAttention wheels may not include + # support yet. Default to Torch SDPA for correctness. + logger.info("Defaulting to Torch SDPA backend on SM12.x") + target_backend = AttentionBackendEnum.TORCH_SDPA + elif cls.is_blackwell(): + from sglang.multimodal_gen.runtime.layers.attention.backends.flash_attn import ( + set_fa_ver, + ) + + set_fa_ver(4) + target_backend = AttentionBackendEnum.FA + else: + target_backend = AttentionBackendEnum.FA + + # Ensure we have a target backend selected before validation/fallback. + if target_backend is None: + target_backend = AttentionBackendEnum.FA + + if target_backend == AttentionBackendEnum.FA and cls.is_blackwell(): + from sglang.multimodal_gen.runtime.layers.attention.backends.flash_attn import ( + set_fa_ver, + ) + + set_fa_ver(4) + + if not cls.has_device_capability(80): + logger.info("Cannot use FlashAttention backend for Volta and Turing GPUs.") + target_backend = AttentionBackendEnum.TORCH_SDPA + elif dtype not in (torch.float16, torch.bfloat16): + logger.info( + "Cannot use FlashAttention backend for dtype other than " + "torch.float16 or torch.bfloat16." + ) + target_backend = AttentionBackendEnum.TORCH_SDPA + # FlashAttn is valid for the model, checking if the package is + # installed. 
+ if target_backend == AttentionBackendEnum.FA: + try: + from sglang.multimodal_gen.runtime.layers.attention.backends.flash_attn import ( # noqa: F401 + FlashAttentionBackend, + ) + + supported_sizes = FlashAttentionBackend.get_supported_head_sizes() + if head_size not in supported_sizes: + logger.info( + "Cannot use FlashAttention backend for head size %d.", + head_size, + ) + target_backend = AttentionBackendEnum.TORCH_SDPA + except ImportError: + logger.info( + "Cannot use FlashAttention backend because the " + "flash_attn package is not found. " + "Make sure that flash_attn was built and installed " + "(on by default)." + ) + target_backend = AttentionBackendEnum.TORCH_SDPA + + if target_backend == AttentionBackendEnum.TORCH_SDPA: + logger.info("Using Torch SDPA backend") + + return "sglang.multimodal_gen.runtime.layers.attention.backends.sdpa.SDPABackend" + + logger.info("Using FlashAttention (FA3 for hopper, FA4 for blackwell) backend") + + return "sglang.multimodal_gen.runtime.layers.attention.backends.flash_attn.FlashAttentionBackend" + + @classmethod + def get_device_communicator_cls(cls) -> str: + return "sglang.multimodal_gen.runtime.distributed.device_communicators.cuda_communicator.CudaCommunicator" # noqa + + +# NVML utils +# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, +# all the related functions work on real physical device ids. 
# the major benefit of using NVML is that it will not initialize CUDA
class NvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform whose device queries are answered through NVML.

    Logical device ids are translated to physical ids with
    ``device_id_to_physical_device_id`` (which honours
    ``CUDA_VISIBLE_DEVICES``) before being handed to NVML. Methods decorated
    with ``with_nvml_context`` run between ``nvmlInit``/``nvmlShutdown``;
    methods decorated with ``lru_cache`` memoise per ``(cls, *args)``.
    """

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return the compute capability of ``device_id``.

        Returns ``None`` if a ``RuntimeError`` is raised while resolving the
        device (e.g. ``CUDA_VISIBLE_DEVICES`` is set to an empty string).
        """
        try:
            physical_device_id = device_id_to_physical_device_id(device_id)
            handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
            major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
            return DeviceCapability(major=major, minor=minor)
        except RuntimeError:
            return None

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def has_device_capability(
        cls,
        capability: tuple[int, int] | int,
        device_id: int = 0,
    ) -> bool:
        """Cached wrapper around the base-class capability comparison.

        Returns ``False`` instead of propagating a ``RuntimeError``.
        """
        try:
            return bool(super().has_device_capability(capability, device_id))
        except RuntimeError:
            return False

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML device name for the logical ``device_id``."""
        physical_device_id = device_id_to_physical_device_id(device_id)
        return cls._get_physical_device_name(physical_device_id)

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Return the NVML UUID for the logical ``device_id``."""
        physical_device_id = device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return str(pynvml.nvmlDeviceGetUUID(handle))

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the ``total`` field of NVML's memory info for ``device_id``."""
        physical_device_id = device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)

    @classmethod
    @with_nvml_context
    def is_full_nvlink(cls, physical_device_ids: list[int]) -> bool:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)

        ``physical_device_ids`` are passed to NVML as-is (no
        ``CUDA_VISIBLE_DEVICES`` translation). Returns ``False`` as soon as
        one pair reports a non-OK NVLink P2P status or the query raises
        ``pynvml.NVMLError``.
        """
        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
        for i, handle in enumerate(handles):
            # Pair each device only with the ones after it, so every
            # unordered pair is checked exactly once.
            for peer_handle in handles[i + 1 :]:
                try:
                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                        handle,
                        peer_handle,
                        pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
                    )
                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                        return False
                except pynvml.NVMLError:
                    logger.exception(
                        "NVLink detection failed. This is normal if"
                        " your machine has no NVLink equipped."
                    )
                    return False
        return True

    @classmethod
    def _get_physical_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML name of a *physical* device index.

        Not wrapped in ``with_nvml_context``; the callers in this class are.
        """
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        return str(pynvml.nvmlDeviceGetName(handle))

    @classmethod
    @with_nvml_context
    def log_warnings(cls) -> None:
        """Warn when several different device models are present and
        ``CUDA_DEVICE_ORDER`` is not ``PCI_BUS_ID``."""
        device_count: int = pynvml.nvmlDeviceGetCount()
        if device_count > 1:
            device_names = [
                cls._get_physical_device_name(i) for i in range(device_count)
            ]
            if (
                len(set(device_names)) > 1
                and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"
            ):
                logger.warning(
                    "Detected different devices in the system: %s. Please"
                    " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                    "avoid unexpected behavior.",
                    ", ".join(device_names),
                )


class NonNvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform that answers device queries through ``torch.cuda``
    instead of NVML."""

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return the compute capability reported by ``torch.cuda``."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name reported by ``torch.cuda``."""
        return str(torch.cuda.get_device_name(device_id))

    @classmethod
    # Keyed on device_id, so keep more than one entry (same size as the
    # NVML-backed variant) to avoid evicting on every alternating lookup.
    @lru_cache(maxsize=8)
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from the torch device properties."""
        device_props = torch.cuda.get_device_properties(device_id)
        return int(device_props.total_memory)

    @classmethod
    def is_full_nvlink(cls, physical_device_ids: list[int]) -> bool:
        """Always report no NVLink; this variant does not query topology."""
        # No exception is being handled here, so use `error` rather than
        # `exception` (which would append a bogus "NoneType: None" traceback).
        logger.error(
            "NVLink detection not possible, as context support was"
            " not found. Assuming no NVLink available."
        )
        return False


# Autodetect either NVML-enabled or non-NVML platform
# based on whether NVML is available.
+nvml_available = False +try: + try: + pynvml.nvmlInit() + nvml_available = True + except Exception: + # On Jetson, NVML is not supported. + nvml_available = False +finally: + if nvml_available: + pynvml.nvmlShutdown() + +CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform + +try: + from sphinx.ext.autodoc.mock import _MockModule + + if not isinstance(pynvml, _MockModule): + CudaPlatform.log_warnings() +except ModuleNotFoundError: + CudaPlatform.log_warnings() diff --git a/sglang/python/sglang/multimodal_gen/runtime/platforms/interface.py b/sglang/python/sglang/multimodal_gen/runtime/platforms/interface.py new file mode 100644 index 0000000000000000000000000000000000000000..8c03fb32487fc3ef7b955dad87252b0d8732dfc7 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/platforms/interface.py @@ -0,0 +1,389 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/platforms/interface.py +from __future__ import annotations + +import enum +import random +from functools import lru_cache +from typing import TYPE_CHECKING, Any, NamedTuple + +import numpy as np +import torch + +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.utils import resolve_obj_by_qualname + +if TYPE_CHECKING: + from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import ( + AttentionImpl, + ) + +logger = init_logger(__name__) + + +class AttentionBackendEnum(enum.Enum): + FA2 = enum.auto() + FA = enum.auto() + SLIDING_TILE_ATTN = enum.auto() + TORCH_SDPA = enum.auto() + SAGE_ATTN = enum.auto() + SAGE_ATTN_3 = enum.auto() + VIDEO_SPARSE_ATTN = enum.auto() + SPARSE_VIDEO_GEN_2_ATTN = enum.auto() + VMOBA_ATTN = enum.auto() + AITER = enum.auto() + SLA_ATTN = enum.auto() + SAGE_SLA_ATTN = enum.auto() + NO_ATTENTION = enum.auto() + + def __str__(self): + return 
self.name.lower() + + @property + def is_sparse(self) -> bool: + return self in { + AttentionBackendEnum.SLIDING_TILE_ATTN, + AttentionBackendEnum.VIDEO_SPARSE_ATTN, + AttentionBackendEnum.SPARSE_VIDEO_GEN_2_ATTN, + AttentionBackendEnum.VMOBA_ATTN, + AttentionBackendEnum.SLA_ATTN, + AttentionBackendEnum.SAGE_SLA_ATTN, + } + + +class PlatformEnum(enum.Enum): + CUDA = enum.auto() + ROCM = enum.auto() + TPU = enum.auto() + CPU = enum.auto() + MPS = enum.auto() + NPU = enum.auto() + MUSA = enum.auto() + OOT = enum.auto() + UNSPECIFIED = enum.auto() + + +class CpuArchEnum(enum.Enum): + X86 = enum.auto() + ARM = enum.auto() + UNSPECIFIED = enum.auto() + + +class DeviceCapability(NamedTuple): + major: int + minor: int + + def as_version_str(self) -> str: + return f"{self.major}.{self.minor}" + + def to_int(self) -> int: + """ + Express device capability as an integer ````. + + It is assumed that the minor version is always a single digit. + """ + assert 0 <= self.minor < 10 + return self.major * 10 + self.minor + + +class Platform: + _enum: PlatformEnum + device_name: str + device_type: str + device: torch.device | None = None # Dummy attribute for compatibility + + # available dispatch keys: + # check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa + # use "CPU" as a fallback for platforms not registered in PyTorch + dispatch_key: str = "CPU" + + # The torch.compile backend for compiling simple and + # standalone functions. The default value is "inductor" to keep + # the same behavior as PyTorch. + # NOTE: for the forward part of the model, vLLM has another separate + # compilation strategy. 
+ simple_compile_backend: str = "inductor" + + supported_quantization: list[str] = [] + + @lru_cache(maxsize=1) + def is_cuda(self) -> bool: + return self.is_cuda_static() + + @lru_cache(maxsize=1) + def is_npu(self) -> bool: + return self._enum == PlatformEnum.NPU + + @lru_cache(maxsize=1) + def is_rocm(self) -> bool: + return self.is_rocm_static() + + @lru_cache(maxsize=1) + def is_tpu(self) -> bool: + return self._enum == PlatformEnum.TPU + + @lru_cache(maxsize=1) + def is_cpu(self) -> bool: + return self._enum == PlatformEnum.CPU + + @classmethod + @lru_cache(maxsize=1) + def is_blackwell(cls): + if not cls.is_cuda_static(): + return False + return torch.cuda.get_device_capability()[0] == 10 + + @classmethod + @lru_cache(maxsize=1) + def is_hopper(cls): + if not cls.is_cuda_static(): + return False + return torch.cuda.get_device_capability() == (9, 0) + + @classmethod + @lru_cache(maxsize=1) + def is_sm120(cls): + if not cls.is_cuda_static(): + return False + return torch.cuda.get_device_capability()[0] == 12 + + @classmethod + def is_cuda_static(cls) -> bool: + return getattr(cls, "_enum", None) == PlatformEnum.CUDA + + @classmethod + def is_rocm_static(cls) -> bool: + return getattr(cls, "_enum", None) == PlatformEnum.ROCM + + @lru_cache(maxsize=1) + def is_hpu(self) -> bool: + return hasattr(torch, "hpu") and torch.hpu.is_available() + + @lru_cache(maxsize=1) + def is_xpu(self) -> bool: + return hasattr(torch, "xpu") and torch.xpu.is_available() + + @lru_cache(maxsize=1) + def is_npu(self) -> bool: + return hasattr(torch, "npu") and torch.npu.is_available() + + def is_out_of_tree(self) -> bool: + return self._enum == PlatformEnum.OOT + + @lru_cache(maxsize=1) + def is_cuda_alike(self) -> bool: + """Stateless version of :func:`torch.cuda.is_available`.""" + return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM, PlatformEnum.MUSA) + + @lru_cache(maxsize=1) + def is_mps(self) -> bool: + return self._enum == PlatformEnum.MPS + + @lru_cache(maxsize=1) + def 
is_musa(self): + try: + return hasattr(torch, "musa") and torch.musa.is_available() + except ModuleNotFoundError: + return False + + @lru_cache(maxsize=1) + def is_hip(self) -> bool: + return self.is_rocm() + + @classmethod + @lru_cache(maxsize=1) + def is_amp_supported(cls) -> bool: + return True + + @classmethod + def get_local_torch_device(cls) -> torch.device: + raise NotImplementedError + + @classmethod + def get_attn_backend_cls_str( + cls, + selected_backend: AttentionBackendEnum | None, + head_size: int, + dtype: torch.dtype, + ) -> str: + """Get the attention backend class of a device.""" + return "" + + @classmethod + def get_device_capability( + cls, + device_id: int = 0, + ) -> DeviceCapability | None: + """Stateless version of :func:`torch.cuda.get_device_capability`.""" + return None + + @classmethod + def has_device_capability( + cls, + capability: tuple[int, int] | int, + device_id: int = 0, + ) -> bool: + """ + Test whether this platform is compatible with a device capability. + + The ``capability`` argument can either be: + + - A tuple ``(major, minor)``. + - An integer ````. (See :meth:`DeviceCapability.to_int`) + """ + current_capability = cls.get_device_capability(device_id=device_id) + if current_capability is None: + return False + + if isinstance(capability, tuple): + return current_capability >= capability + + return current_capability.to_int() >= capability + + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: + """Get the name of a device.""" + raise NotImplementedError + + @classmethod + def get_device_uuid(cls, device_id: int = 0) -> str: + """Get the uuid of a device, e.g. 
@lru_cache(maxsize=1)
def get_torch_distributed_backend_str(self) -> str:
    """Return the ``torch.distributed`` backend name for this platform.

    The platform predicates are probed in a fixed order and the first one
    that returns true decides the backend. ``NotImplementedError`` is raised
    when none of them matches.
    """
    # (predicate method name, backend) pairs; order matters and each predicate
    # is only looked up and called if all earlier ones returned false.
    probe_order = (
        ("is_cuda_alike", "nccl"),
        ("is_npu", "hccl"),
        ("is_musa", "mccl"),
        ("is_mps", "gloo"),
    )
    for predicate_name, backend_name in probe_order:
        if getattr(self, predicate_name)():
            return backend_name

    raise NotImplementedError(
        "No Accelerators(AMD/NV/MTT GPU, AMD MI instinct accelerators) available"
    )
@classmethod
def verify_quantization(cls, quant: str) -> None:
    """
    Check that ``quant`` is a quantization method this platform accepts.

    An empty ``supported_quantization`` list disables the check. Otherwise a
    ``ValueError`` is raised when ``quant`` is not in the list.
    """
    allowed = cls.supported_quantization
    if not allowed or quant in allowed:
        return

    raise ValueError(
        f"{quant} quantization is currently not supported in "
        f"{cls.device_name}."
    )
+ """ + return "sglang.multimodal_gen.runtime.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase" # noqa + + @classmethod + def get_cpu_architecture(cls) -> CpuArchEnum: + """Get the CPU architecture of the current platform.""" + return CpuArchEnum.UNSPECIFIED + + @classmethod + def enable_dit_layerwise_offload_for_wan_by_default(cls) -> bool: + """Whether to enable DIT layerwise offload by default on the current platform.""" + return True + + def get_attn_backend(self, *args, **kwargs) -> AttentionImpl: + attention_cls_str = self.get_attn_backend_cls_str(*args, **kwargs) + return resolve_obj_by_qualname(attention_cls_str) + + +class UnspecifiedPlatform(Platform): + _enum = PlatformEnum.UNSPECIFIED + device_type = "" diff --git a/sglang/python/sglang/multimodal_gen/runtime/platforms/mps.py b/sglang/python/sglang/multimodal_gen/runtime/platforms/mps.py new file mode 100644 index 0000000000000000000000000000000000000000..bb9116a4d62c97c7a0aa6539745c78d28b895925 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/platforms/mps.py @@ -0,0 +1,127 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo +from functools import lru_cache +from typing import Any + +import psutil +import torch + +from sglang.multimodal_gen.runtime.platforms import ( + AttentionBackendEnum, + Platform, + PlatformEnum, +) +from sglang.multimodal_gen.runtime.platforms.interface import DeviceCapability +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +# SPDX-License-Identifier: Apache-2.0 + + +logger = init_logger(__name__) + + +class MpsPlatform(Platform): + _enum = PlatformEnum.MPS + device_name: str = "mps" + device_type: str = "mps" + dispatch_key: str = "MPS" + device_control_env_var: str = "MPS_VISIBLE_DEVICES" + + @classmethod + @lru_cache(maxsize=1) + def is_amp_supported(cls) -> bool: + return False + + @classmethod + def get_local_torch_device(cls) -> torch.device: + return torch.device("mps") + + 
@classmethod + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None: + raise NotImplementedError + + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: + raise NotImplementedError + + @classmethod + def get_device_uuid(cls, device_id: int = 0) -> str: + raise NotImplementedError + + @classmethod + @lru_cache(maxsize=1) + def get_device_total_memory(cls, device_id: int = 0) -> int: + + return psutil.virtual_memory().total + + @classmethod + def is_async_output_supported(cls, enforce_eager: bool | None) -> bool: + if enforce_eager: + logger.warning( + "To see benefits of async output processing, enable MPS " + "graph. Since, enforce-eager is enabled, async output " + "processor cannot be used" + ) + return False + return True + + @classmethod + def get_current_memory_usage( + cls, device: torch.types.Device | None = None + ) -> float: + return 0.0 + + @classmethod + def get_available_gpu_memory( + cls, + device_id: int = 0, + distributed: bool = False, + empty_cache: bool = True, + cpu_group: Any = None, + ) -> float: + + if empty_cache: + torch.mps.empty_cache() + + # For MPS, available memory is essentially the system available memory + free_memory = psutil.virtual_memory().available + + if distributed: + import torch.distributed as dist + + tensor = torch.tensor(free_memory, dtype=torch.float32) + dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=cpu_group) + free_memory = float(tensor.item()) + + return free_memory / (1 << 30) + + @classmethod + def get_attn_backend_cls_str( + cls, + selected_backend: AttentionBackendEnum | None, + head_size: int, + dtype: torch.dtype, + ) -> str: + # MPS supports SDPA (Scaled Dot-Product Attention) which is the most compatible + logger.info("Using Torch SDPA backend for MPS.") + return ( + "sglang.multimodal_gen.runtime.layers.attention.backends.sdpa.SDPABackend" + ) + + @classmethod + def get_device_communicator_cls(cls) -> str: + # Use base communicator for MPS + return 
def device_id_to_physical_device_id(device_id: int) -> int:
    """Translate a logical device index into a physical device id.

    When ``MUSA_VISIBLE_DEVICES`` is unset the index is returned unchanged.
    Otherwise the variable is split on commas and the entry at position
    ``device_id`` is returned as an ``int``.

    Raises:
        RuntimeError: if ``MUSA_VISIBLE_DEVICES`` is set to an empty string.
    """
    visible_devices = os.environ.get("MUSA_VISIBLE_DEVICES")
    if visible_devices is None:
        return device_id

    entries = visible_devices.split(",")
    if entries == [""]:
        raise RuntimeError(
            "MUSA_VISIBLE_DEVICES is set to empty string, which means"
            " GPU support is disabled. If you are using ray, please unset"
            " the environment variable `MUSA_VISIBLE_DEVICES` inside the"
            " worker/actor. "
            "Check https://github.com/vllm-project/vllm/issues/8402 for"
            " more information."
        )
    return int(entries[device_id])
Since, enforce-eager is enabled, async output " + "processor cannot be used" + ) + return False + return True + + @classmethod + def is_full_mtlink(cls, device_ids: list[int]) -> bool: + raise NotImplementedError + + @classmethod + def log_warnings(cls) -> None: + pass + + @classmethod + def get_current_memory_usage( + cls, device: torch.types.Device | None = None + ) -> float: + torch.cuda.reset_peak_memory_stats(device) + return float(torch.cuda.max_memory_allocated(device)) + + @classmethod + def get_available_gpu_memory( + cls, + device_id: int = 0, + distributed: bool = False, + empty_cache: bool = True, + cpu_group: Any = None, + ) -> float: + if empty_cache: + torch.cuda.empty_cache() + + if torch.distributed.is_initialized(): + device_id = torch.distributed.get_rank() + + device_props = torch.cuda.get_device_properties(device_id) + if device_props.is_integrated: + free_gpu_memory = psutil.virtual_memory().available + else: + free_gpu_memory, _ = torch.cuda.mem_get_info(device_id) + + if distributed: + import torch.distributed as dist + + tensor = torch.tensor(free_gpu_memory, dtype=torch.float32, device="musa") + dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=cpu_group) + free_gpu_memory = float(tensor.item()) + + return free_gpu_memory / (1 << 30) + + @classmethod + def get_attn_backend_cls_str( + cls, + selected_backend: AttentionBackendEnum | None, + head_size: int, + dtype: torch.dtype, + ) -> str: + logger.info("Using Torch SDPA backend.") + return ( + "sglang.multimodal_gen.runtime.layers.attention.backends.sdpa.SDPABackend" + ) + + @classmethod + def get_device_communicator_cls(cls) -> str: + return "sglang.multimodal_gen.runtime.distributed.device_communicators.cuda_communicator.CudaCommunicator" # noqa + + +# MTML utils +# Note that MTML is not affected by `MUSA_VISIBLE_DEVICES`, +# all the related functions work on real physical device ids. 
+# the major benefit of using MTML is that it will not initialize MUSA +class MtmlMusaPlatform(MusaPlatformBase): + @classmethod + @lru_cache(maxsize=8) + @with_mtml_context + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None: + try: + physical_device_id = device_id_to_physical_device_id(device_id) + handle = pymtml.nvmlDeviceGetHandleByIndex(physical_device_id) + major, minor = pymtml.nvmlDeviceGetCudaComputeCapability(handle) + return DeviceCapability(major=major, minor=minor) + except RuntimeError: + return None + + @classmethod + @lru_cache(maxsize=8) + @with_mtml_context + def has_device_capability( + cls, + capability: tuple[int, int] | int, + device_id: int = 0, + ) -> bool: + try: + return bool(super().has_device_capability(capability, device_id)) + except RuntimeError: + return False + + @classmethod + @lru_cache(maxsize=8) + @with_mtml_context + def get_device_name(cls, device_id: int = 0) -> str: + physical_device_id = device_id_to_physical_device_id(device_id) + return cls._get_physical_device_name(physical_device_id) + + @classmethod + @lru_cache(maxsize=8) + @with_mtml_context + def get_device_uuid(cls, device_id: int = 0) -> str: + physical_device_id = device_id_to_physical_device_id(device_id) + handle = pymtml.nvmlDeviceGetHandleByIndex(physical_device_id) + return str(pymtml.nvmlDeviceGetUUID(handle)) + + @classmethod + @lru_cache(maxsize=8) + @with_mtml_context + def get_device_total_memory(cls, device_id: int = 0) -> int: + physical_device_id = device_id_to_physical_device_id(device_id) + handle = pymtml.nvmlDeviceGetHandleByIndex(physical_device_id) + return int(pymtml.nvmlDeviceGetMemoryInfo(handle).total) + + @classmethod + @with_mtml_context + def is_full_mtlink(cls, physical_device_ids: list[int]) -> bool: + """ + query if the set of gpus are fully connected by mtlink (1 hop) + """ + handles = [pymtml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids] + for i, handle in enumerate(handles): + for j, 
class NonMtmlMusaPlatform(MusaPlatformBase):
    """MUSA platform variant that queries devices through ``torch.cuda`` calls.

    NOTE(review): presumably ``torch.cuda`` is redirected to MUSA by the
    ``torchada`` import at the top of the module - confirm.
    """

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return the capability reported by ``torch.cuda`` for ``device_id``."""
        cap_major, cap_minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=cap_major, minor=cap_minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name reported by ``torch.cuda`` as a string."""
        reported_name = torch.cuda.get_device_name(device_id)
        return str(reported_name)

    @classmethod
    @lru_cache(maxsize=1)
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from the device properties as an ``int``."""
        return int(torch.cuda.get_device_properties(device_id).total_memory)

    @classmethod
    def is_full_mtlink(cls, physical_device_ids: list[int]) -> bool:
        """Always report no MTLink: this variant has no way to query it."""
        logger.error(
            "MTLink detection not possible, as context support was"
            " not found. Assuming no MTLink available."
        )
        return False
# Probe MTML once at import time: it counts as available only when
# `pymtml.nvmlInit()` succeeds, and the probe is skipped entirely when the
# `MUSA_DISABLE_MTML` environment variable is present.
mtml_available = False

if "MUSA_DISABLE_MTML" not in os.environ:
    try:
        pymtml.nvmlInit()
    except Exception:
        mtml_available = False
    else:
        mtml_available = True
        # Balance the successful init above.
        pymtml.nvmlShutdown()

MusaPlatform = MtmlMusaPlatform if mtml_available else NonMtmlMusaPlatform

# Skip the warnings only when sphinx is importable and `pymtml` is one of its
# autodoc mock modules; in every other case emit them.
try:
    from sphinx.ext.autodoc.mock import _MockModule

    pymtml_is_mocked = isinstance(pymtml, _MockModule)
    if not pymtml_is_mocked:
        MusaPlatform.log_warnings()
except ModuleNotFoundError:
    MusaPlatform.log_warnings()

if __name__ == "__main__":
    # Manual smoke check of the selected platform class.
    print(MusaPlatform.__name__)
    print(MusaPlatform.get_device_name())
    print(MusaPlatform.get_device_capability())
    print(MusaPlatform.get_device_total_memory())
    print(MusaPlatform.is_full_mtlink([0, 1, 2, 3, 4, 5, 6, 7]))
class NPUPlatformBase(Platform):
    """Platform implementation for NPU devices accessed through ``torch.npu``.

    Device visibility is controlled through ``ASCEND_RT_VISIBLE_DEVICES``.
    """

    _enum = PlatformEnum.NPU
    device_name: str = "npu"
    device_type: str = "npu"
    dispatch_key: str = "NPU"
    device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES"

    @classmethod
    def get_local_torch_device(cls) -> torch.device:
        """Return the ``npu`` device indexed by ``envs.LOCAL_RANK``."""
        return torch.device(f"npu:{envs.LOCAL_RANK}")

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return ``None``: this platform reports no device capability."""
        return None

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name reported by ``torch.npu`` as a string."""
        return str(torch.npu.get_device_name(device_id))

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from the device properties as an ``int``."""
        device_props = torch.npu.get_device_properties(device_id)
        return int(device_props.total_memory)

    @classmethod
    def is_async_output_supported(cls, enforce_eager: bool | None) -> bool:
        """Return ``False`` (with a warning) when ``enforce_eager`` is truthy, else ``True``."""
        if enforce_eager:
            logger.warning(
                "To see benefits of async output processing, enable NPU "
                "graph. Since, enforce-eager is enabled, async output "
                "processor cannot be used"
            )
            return False
        return True

    @classmethod
    def is_full_nvlink(cls, physical_device_ids: list[int]) -> bool:
        """Always report no NVLink: this platform cannot detect it."""
        # `logger.error`, not `logger.exception`: no exception is being handled
        # here, so `exception()` would append a meaningless "NoneType: None"
        # traceback to the log record.
        logger.error(
            "NVLink detection not possible, as context support was"
            " not found. Assuming no NVLink available."
        )
        return False

    @classmethod
    def get_available_gpu_memory(
        cls,
        device_id: int = 0,
        distributed: bool = False,
        empty_cache: bool = True,
        cpu_group: Any = None,
    ) -> float:
        """Return the free device memory in GiB.

        Args:
            device_id: Device index passed to ``torch.npu.mem_get_info``.
            distributed: When true, take the minimum across ``cpu_group`` with
                an all-reduce.
            empty_cache: When true, call ``torch.npu.empty_cache()`` first.
            cpu_group: Process group handed to ``dist.all_reduce``.
        """
        if empty_cache:
            torch.npu.empty_cache()

        free_gpu_memory, _ = torch.npu.mem_get_info(device_id)

        if distributed:
            import torch.distributed as dist

            tensor = torch.tensor(free_gpu_memory, dtype=torch.float32, device="npu")
            dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=cpu_group)
            free_gpu_memory = float(tensor.item())

        # Bytes -> GiB.
        return free_gpu_memory / (1 << 30)

    @classmethod
    def log_warnings(cls) -> None:
        """No platform-specific warnings are emitted."""
        pass

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Reset the peak-memory statistics, then return ``max_memory_allocated`` as a float."""
        torch.npu.reset_peak_memory_stats(device)
        return float(torch.npu.max_memory_allocated(device))

    @classmethod
    def get_attn_backend_cls_str(
        cls,
        selected_backend: AttentionBackendEnum | None,
        head_size: int,
        dtype: torch.dtype,
    ) -> str:
        """Return the SDPA backend path; the arguments are not consulted."""
        logger.info("Using Torch SDPA backend.")
        return (
            "sglang.multimodal_gen.runtime.layers.attention.backends.sdpa.SDPABackend"
        )

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the qualified name of the device communicator class."""
        return "sglang.multimodal_gen.runtime.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa

    @classmethod
    def enable_dit_layerwise_offload_for_wan_by_default(cls) -> bool:
        """The performance of the layerwise_offload feature depends on the device's memory size and the memory size occupied by the model. Use --dit-layerwise-offload True if it is suitable for your case."""
        return False
@classmethod
def get_attn_backend_cls_str(
    cls,
    selected_backend: AttentionBackendEnum | None,
    head_size: int,
    dtype: torch.dtype,
) -> str:
    """Choose the attention backend class path for ROCm.

    Args:
        selected_backend: Requested backend, or ``None`` to auto-select.
        head_size: Attention head size, checked against the sizes
            FlashAttention reports as supported.
        dtype: Model dtype; FlashAttention is only chosen for
            ``torch.float16`` / ``torch.bfloat16``.

    Returns:
        Fully qualified name of the backend class.

    Raises:
        ValueError: if ``selected_backend`` is SLIDING_TILE_ATTN, SAGE_ATTN,
            or any other truthy value not handled below.
    """
    # Explicit requests that resolve immediately (SDPA, AITer) return here;
    # FA and None fall through to the auto-selection below.
    if selected_backend == AttentionBackendEnum.TORCH_SDPA:
        logger.info("Using Torch SDPA backend.")
        return "sglang.multimodal_gen.runtime.layers.attention.backends.sdpa.SDPABackend"

    elif selected_backend in (AttentionBackendEnum.FA, None):
        pass

    elif selected_backend == AttentionBackendEnum.AITER:
        # A non-fp16/bf16 dtype only triggers a warning; AITer is still returned.
        if dtype not in (torch.float16, torch.bfloat16):
            logger.warning(
                "AITer backend works best with fp16/bf16 inputs but got dtype=%s. "
                "Proceeding with AITer anyway.",
                dtype,
            )
        logger.info("Using AITer backend on ROCm.")
        return "sglang.multimodal_gen.runtime.layers.attention.backends.aiter.AITerBackend"

    elif selected_backend in (
        AttentionBackendEnum.SLIDING_TILE_ATTN,
        AttentionBackendEnum.SAGE_ATTN,
    ):
        raise ValueError(
            f"{selected_backend.name} is not supported on {cls.device_name}."
        )
    elif selected_backend:
        # Any remaining truthy value is rejected.
        raise ValueError(
            f"Invalid attention backend for {cls.device_name}: {selected_backend}"
        )

    # Auto-selection: start from FlashAttention and downgrade to SDPA whenever
    # one of its requirements (dtype, importability, head size) is not met.
    target_backend = AttentionBackendEnum.FA
    if dtype not in (torch.float16, torch.bfloat16):
        logger.info(
            "Cannot use FlashAttention backend for dtype other than "
            "torch.float16 or torch.bfloat16."
        )
        target_backend = AttentionBackendEnum.TORCH_SDPA

    if target_backend == AttentionBackendEnum.FA:
        try:
            # Imported only to verify that the package is installed.
            import flash_attn  # noqa: F401

            from sglang.multimodal_gen.runtime.layers.attention.backends.flash_attn import (  # noqa: F401
                FlashAttentionBackend,
            )

            supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
            if head_size not in supported_sizes:
                logger.info(
                    "Cannot use FlashAttention-2 backend for head size %d.",
                    head_size,
                )
                target_backend = AttentionBackendEnum.TORCH_SDPA
        except ImportError:
            # A failed import downgrades to SDPA, even if FA was requested explicitly.
            logger.info(
                "Cannot use FlashAttention backend because the "
                "flash_attn package is not found. "
                "Make sure that flash_attn was built and installed "
                "(on by default)."
            )
            target_backend = AttentionBackendEnum.TORCH_SDPA

    if target_backend == AttentionBackendEnum.TORCH_SDPA:
        logger.info("Using Torch SDPA backend.")

        return "sglang.multimodal_gen.runtime.layers.attention.backends.sdpa.SDPABackend"

    logger.info("Using Flash Attention backend.")

    return "sglang.multimodal_gen.runtime.layers.attention.backends.flash_attn.FlashAttentionBackend"
def warp(tenInput: torch.Tensor, tenFlow: torch.Tensor) -> torch.Tensor:
    """Warp ``tenInput`` by the flow field ``tenFlow`` using ``grid_sample``.

    Args:
        tenInput: Tensor to sample from; indexed as 4-D, with ``shape[2]`` and
            ``shape[3]`` used as height and width.
        tenFlow: Flow tensor indexed as 4-D. Channel 0 is the horizontal and
            channel 1 the vertical displacement, expressed in pixels of
            ``tenInput`` (they are divided by ``(size - 1) / 2`` to reach the
            normalized ``[-1, 1]`` grid range).

    Returns:
        ``tenInput`` bilinearly resampled at ``base_grid + flow``, with border
        padding and ``align_corners=True``.
    """
    # Build base grid for the current size
    tenHorizontal = (
        torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=tenFlow.device)
        .view(1, 1, 1, tenFlow.shape[3])
        .expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
    )
    tenVertical = (
        torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=tenFlow.device)
        .view(1, 1, tenFlow.shape[2], 1)
        .expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
    )
    tenGrid = torch.cat([tenHorizontal, tenVertical], dim=1)

    # Convert pixel displacements to normalized grid units.
    tenFlow = torch.cat(
        [
            tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
            tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0),
        ],
        dim=1,
    )

    grid = (tenGrid + tenFlow).permute(0, 2, 3, 1)
    # grid_sample requires input and grid to share a dtype. The base grid is
    # created in the default dtype, so without this cast a call whose input
    # dtype differs from the grid dtype (e.g. half-precision frames) fails.
    # When the dtypes already match this is a no-op.
    grid = grid.to(dtype=tenInput.dtype)
    return F.grid_sample(
        input=tenInput,
        grid=grid,
        mode="bilinear",
        padding_mode="border",
        align_corners=True,
    )
class ResConv(nn.Module):
    """Residual 3x3 convolution with a learnable per-channel scale (RIFE 4.22).

    Computes ``LeakyReLU(conv(x) * beta + x)``. ``beta`` starts at one and has
    shape ``(1, c, 1, 1)``.
    """

    def __init__(self, c: int, dilation: int = 1):
        super().__init__()
        # padding == dilation keeps the spatial size unchanged for a 3x3 kernel.
        self.conv = nn.Conv2d(
            in_channels=c,
            out_channels=c,
            kernel_size=3,
            stride=1,
            padding=dilation,
            dilation=dilation,
            groups=1,
        )
        self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scaled_branch = self.conv(x) * self.beta
        return self.relu(scaled_branch + x)
class Head(nn.Module):
    """Small convolutional encoder mapping a 3-channel image to 4 channels.

    A stride-2 convolution halves the resolution, two stride-1 convolutions
    refine the 16-channel features, and a stride-2 transposed convolution
    upsamples again, so for even ``H``/``W`` the output has the input's
    spatial size. LeakyReLU(0.2) follows every layer except the last.
    """

    def __init__(self):
        super().__init__()
        self.cnn0 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1)
        self.cnn1 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1)
        self.cnn2 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1)
        self.cnn3 = nn.ConvTranspose2d(16, 4, kernel_size=4, stride=2, padding=1)
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = x
        # The three activated stages share the same conv -> LeakyReLU shape.
        for stage in (self.cnn0, self.cnn1, self.cnn2):
            hidden = self.relu(stage(hidden))
        # The final transposed convolution is intentionally left un-activated.
        return self.cnn3(hidden)
class Model:
    """Inference wrapper around ``IFNet``: checkpoint loading and frame synthesis."""

    def __init__(self):
        self.flownet = IFNet()
        self.device_type: str = "cpu"

    def eval(self) -> "Model":
        """Switch the wrapped network to eval mode and return ``self`` for chaining."""
        self.flownet.eval()
        return self

    def device(self) -> torch.device:
        """Return the device holding the network's first parameter."""
        return next(self.flownet.parameters()).device

    def load_model(self, path: str, strip_module_prefix: bool = True) -> None:
        """Load weights from ``{path}/flownet.pkl`` into the wrapped network.

        Args:
            path: Directory containing ``flownet.pkl``.
            strip_module_prefix: If True, strip the leading ``module.`` prefix
                that ``DataParallel`` / ``DistributedDataParallel`` adds to
                keys. When the checkpoint has no prefixed key at all, its keys
                are used unchanged instead of being silently discarded. If
                False, only keys that do not contain ``module.`` are loaded.

        Raises:
            FileNotFoundError: If ``flownet.pkl`` does not exist under ``path``.
        """
        flownet_path = os.path.join(path, "flownet.pkl")
        if not os.path.isfile(flownet_path):
            raise FileNotFoundError(
                f"RIFE weight file not found: {flownet_path}\n"
                "Expected layout: <model_dir>/flownet.pkl"
            )

        prefix = "module."

        def convert(param):
            if not strip_module_prefix:
                return {k: v for k, v in param.items() if prefix not in k}
            stripped = {
                k[len(prefix) :]: v for k, v in param.items() if k.startswith(prefix)
            }
            # A checkpoint saved without DataParallel has no prefixed keys;
            # fall back to its keys as-is rather than loading an empty dict.
            return stripped if stripped else dict(param)

        # NOTE(security): weights_only=False unpickles arbitrary objects; only
        # point this at checkpoints from a trusted source.
        state = torch.load(flownet_path, map_location="cpu", weights_only=False)
        load_result = self.flownet.load_state_dict(convert(state), strict=False)
        if load_result.missing_keys:
            # strict=False would otherwise hide a partially initialised network.
            logger.warning(
                "RIFE checkpoint %s did not provide %d parameter(s), e.g. %s",
                flownet_path,
                len(load_result.missing_keys),
                load_result.missing_keys[0],
            )
        logger.info("Loaded RIFE weights from %s", flownet_path)

    def inference(
        self,
        img0: torch.Tensor,
        img1: torch.Tensor,
        scale: float = 1.0,
        timestep: float = 0.5,
    ) -> torch.Tensor:
        """Interpolate a single intermediate frame between img0 and img1.

        Args:
            img0: First frame, a 4-D tensor (unpacked as ``n, c, h, w``).
            img1: Second frame; padded with the same amounts as ``img0``.
            scale: Divides the ``[8, 4, 2, 1]`` scale list handed to the network.
            timestep: Forwarded unchanged to the network.

        Returns:
            The network's final merged frame, cropped back to ``h`` x ``w``.
        """
        _, _, h, w = img0.shape

        # Pad to multiples of 32 so that RIFE's downsample/upsample round-trips
        # preserve spatial dimensions exactly.
        ph = ((h - 1) // 32 + 1) * 32
        pw = ((w - 1) // 32 + 1) * 32
        pad = (0, pw - w, 0, ph - h)
        img0 = F.pad(img0, pad)
        img1 = F.pad(img1, pad)

        imgs = torch.cat((img0, img1), 1)
        scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
        with torch.no_grad():
            _, _, merged = self.flownet(
                imgs,
                timestep=timestep,
                scale_list=scale_list,
            )

        # Crop back to original resolution
        return merged[3][:, :, :h, :w]
def interpolate(
    self,
    frames: list[np.ndarray],
    exp: int = 1,
    scale: float = 1.0,
) -> tuple[list[np.ndarray], int]:
    """
    Interpolate frames using RIFE.

    Args:
        frames: List of uint8 numpy arrays with shape [H, W, 3].
        exp: Exponent for interpolation factor. 0 -> 1x (input returned
            unchanged), 1 -> 2x, 2 -> 4x.
        scale: RIFE inference scale. Use 0.5 for high-resolution inputs.

    Returns:
        (interpolated_frames, multiplier) where multiplier = 2**exp. With
        fewer than two input frames the input is returned with multiplier 1.

    Raises:
        ValueError: If ``exp`` is below 1 and not 0. Such values yield zero
            intermediates per pair, for which the recursive midpoint
            generation never reaches its ``n == 1`` base case.
    """
    if len(frames) < 2:
        logger.warning(
            "Frame interpolation requires at least 2 frames; returning input unchanged."
        )
        return frames, 1

    # Validate before loading the model so bad input fails fast and cheaply.
    if exp == 0:
        return frames, 1
    if exp < 1:
        raise ValueError(f"exp must be 0 or an integer >= 1, got {exp!r}")

    model = self._ensure_model_loaded()
    device = model.device()

    n_intermediate = 2**exp // 2  # intermediates per adjacent pair

    result: list[np.ndarray] = []
    for first, second in zip(frames, frames[1:]):
        I0 = self._frame_to_tensor(first, device)
        I1 = self._frame_to_tensor(second, device)

        intermediate_tensors = self._make_inference(
            model, I0, I1, n_intermediate, scale
        )

        # Keep the original left frame, then the synthesised in-betweens.
        result.append(first)
        result.extend(self._tensor_to_frame(t) for t in intermediate_tensors)

    result.append(frames[-1])
    multiplier = 2**exp
    return result, multiplier
async def run_zeromq_broker(server_args: ServerArgs):
    """
    This function runs as a background task in the FastAPI process.
    It listens for TCP requests from offline clients (e.g., DiffGenerator),
    forwards each one to the scheduler through ``async_scheduler_client`` and
    sends the scheduler's reply back.

    The loop runs until the task is cancelled or binding fails; in both cases
    the socket and context are released.

    NOTE(security): payloads are decoded with ``pickle.loads`` and the socket
    is bound on all interfaces (``tcp://*``), so any peer that can reach
    ``broker_port`` can submit arbitrary pickles. Restrict network access to
    this port accordingly.
    """
    ctx = zmq.asyncio.Context()
    # This is the REP socket that listens for requests from DiffGenerator
    socket = ctx.socket(zmq.REP)
    broker_endpoint = f"tcp://*:{server_args.broker_port}"
    try:
        socket.bind(broker_endpoint)
        logger.info(f"ZMQ Broker is listening for offline jobs on {broker_endpoint}")

        while True:
            # 1. Receive a request from an offline client
            try:
                payload = await socket.recv()
            except Exception as e:
                # No request is pending, so a REP socket has nobody to reply to.
                logger.error(f"Error in ZMQ Broker: {e}", exc_info=True)
                continue

            try:
                request_batch = pickle.loads(payload)
                logger.info("Broker received an offline job from a client.")

                # 2. Forward the request to the main Scheduler via the shared client
                response_batch = await async_scheduler_client.forward(request_batch)
                reply = pickle.dumps(response_batch)
            except Exception as e:
                logger.error(f"Error in ZMQ Broker: {e}", exc_info=True)
                # A reply must be sent to prevent the client from hanging
                reply = pickle.dumps({"status": "error", "message": str(e)})

            # 3. Send the Scheduler's reply (or the error) back to the offline client
            try:
                await socket.send(reply)
            except Exception:
                logger.error("ZMQ Broker failed to send a reply", exc_info=True)
    finally:
        # Runs on cancellation and on bind failure alike.
        socket.close(linger=0)
        ctx.term()
class AsyncSchedulerClient:
    """
    An asynchronous, singleton client for communicating with the Scheduler service.
    Designed for use in asynchronous environments like FastAPI entrypoints.

    To support high concurrency, it creates a new REQ socket for each request
    rather than sharing a single one (which would cause ZMQ state errors).
    """

    def __init__(self):
        self.context = None
        self.server_args = None

    def initialize(self, server_args: ServerArgs):
        """Create the asyncio ZMQ context, replacing any still-open one."""
        if self.context is not None and not self.context.closed:
            logger.warning(
                "AsyncSchedulerClient is already initialized. Re-initializing."
            )
            self.close()

        self.server_args = server_args
        self.context = zmq.asyncio.Context()
        logger.debug("AsyncSchedulerClient initialized with zmq.asyncio.Context")

    def _new_req_socket(self, recv_timeout_ms: int):
        """Create a REQ socket with LINGER=0 and the given receive timeout (ms)."""
        sock = self.context.socket(zmq.REQ)
        sock.setsockopt(zmq.LINGER, 0)
        sock.setsockopt(zmq.RCVTIMEO, recv_timeout_ms)
        return sock

    async def forward(self, batch: Any) -> Any:
        """Sends a batch or request to the scheduler and waits for the response.

        Raises:
            RuntimeError: If the client has not been initialized or was closed.
            TimeoutError: If no reply arrives within the receive timeout.
        """
        if self.context is None or self.context.closed:
            raise RuntimeError(
                "AsyncSchedulerClient is not initialized. Call initialize() first."
            )

        # Create a temporary REQ socket for this request to allow concurrency
        # 100 minute timeout
        socket = self._new_req_socket(6000000)

        try:
            # connect() sits inside the try so a failure cannot leak the socket.
            socket.connect(self.server_args.scheduler_endpoint)
            await socket.send(pickle.dumps(batch))
            payload = await socket.recv()
            return pickle.loads(payload)
        except zmq.error.Again:
            logger.error("Timeout waiting for response from scheduler.")
            raise TimeoutError("Scheduler did not respond in time.")
        finally:
            socket.close()

    async def ping(self) -> bool:
        """
        Checks if the scheduler server is alive using a temporary socket.

        Returns False when the client is not initialized or no reply arrives
        within two seconds.
        """
        if self.context is None or self.context.closed:
            logger.error("Cannot ping: client is not initialized.")
            return False

        ping_socket = self._new_req_socket(2000)

        try:
            ping_socket.connect(self.server_args.scheduler_endpoint)
            await ping_socket.send(pickle.dumps({"method": "ping"}))
            await ping_socket.recv()
            return True
        except zmq.error.Again:
            return False
        finally:
            ping_socket.close()

    def close(self):
        """Terminates the context.

        This client keeps no long-lived socket; the per-request sockets are
        closed by ``forward()`` / ``ping()`` themselves.
        """
        if self.context:
            self.context.term()
            self.context = None
https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/server_args.py +"""The arguments of sglang-diffusion Inference.""" + +import argparse +import dataclasses +import inspect +import json +import math +import os +import random +import sys +import tempfile +from dataclasses import field +from enum import Enum +from typing import Any, Optional + +import addict +import yaml + +from sglang.multimodal_gen import envs +from sglang.multimodal_gen.configs.models.encoders import T5Config +from sglang.multimodal_gen.configs.pipeline_configs.base import PipelineConfig +from sglang.multimodal_gen.configs.quantization import NunchakuSVDQuantArgs +from sglang.multimodal_gen.runtime.layers.quantization.configs.nunchaku_config import ( + NunchakuConfig, +) +from sglang.multimodal_gen.runtime.loader.utils import BYTES_PER_GB +from sglang.multimodal_gen.runtime.platforms import ( + AttentionBackendEnum, + current_platform, +) +from sglang.multimodal_gen.runtime.utils.common import ( + is_port_available, + is_valid_ipv6_address, +) +from sglang.multimodal_gen.runtime.utils.logging_utils import ( + configure_logger, + init_logger, +) +from sglang.multimodal_gen.utils import ( + FlexibleArgumentParser, + StoreBoolean, + expand_path_fields, +) + +logger = init_logger(__name__) + + +def _is_torch_tensor(obj: Any) -> tuple[bool, Any]: + """Return (is_tensor, torch_module_or_None) without importing torch at module import time.""" + try: + import torch # type: ignore + + return isinstance(obj, torch.Tensor), torch + except Exception: + return False, None + + +def _sanitize_for_logging(obj: Any, key_hint: str | None = None) -> Any: + """Recursively convert objects to JSON-serializable forms for concise logging. + + Rules: + - Drop any field/dict key named 'param_names_mapping'. + - Render Enums using their value. + - Render torch.Tensor as a compact summary; if key name is 'scaling_factor', include stats. + - Dataclasses are expanded to dicts and sanitized recursively. 
+ - Callables/functions are rendered as their qualified name. + - Redact sensitive fields like 'prompt' and 'negative_prompt' (only show length). + - Fallback to str(...) for unknown types. + """ + # Handle simple types quickly + if obj is None or isinstance(obj, (str, int, float, bool)): + # redact sensitive prompt fields + if key_hint in ("prompt", "negative_prompt"): + if isinstance(obj, str): + return f"" + return obj + + # Enum -> value for readability + if isinstance(obj, Enum): + return obj.value + + # torch.Tensor handling (lazy import) + is_tensor, torch_mod = _is_torch_tensor(obj) + if is_tensor: + try: + ten = obj.detach().cpu() + if key_hint == "scaling_factor": + # Provide a compact, single-line summary for scaling_factor + stats = { + "shape": list(ten.shape), + "dtype": str(ten.dtype), + } + # Stats might fail for some dtypes; guard individually + try: + stats["min"] = float(ten.min().item()) + except Exception: + pass + try: + stats["max"] = float(ten.max().item()) + except Exception: + pass + try: + stats["mean"] = float(ten.float().mean().item()) + except Exception: + pass + return {"tensor": "scaling_factor", **stats} + # Generic tensor summary + return {"tensor": True, "shape": list(ten.shape), "dtype": str(ten.dtype)} + except Exception: + return "" + + # Dataclasses -> dict + if dataclasses.is_dataclass(obj): + result: dict[str, Any] = {} + for f in dataclasses.fields(obj): + if not f.repr: + continue + name = f.name + if "names_mapping" in name: # drop noisy mappings + continue + try: + value = getattr(obj, name) + except Exception: + continue + result[name] = _sanitize_for_logging(value, key_hint=name) + return result + + # Dicts -> sanitize keys/values; drop 'param_names_mapping' + if isinstance(obj, dict): + result_dict: dict[str, Any] = {} + for k, v in obj.items(): + try: + key_str = str(k) + except Exception: + key_str = "" + if key_str == "param_names_mapping": + continue + result_dict[key_str] = _sanitize_for_logging(v, 
key_hint=key_str) + return result_dict + + # Sequences/Sets -> list + if isinstance(obj, (list, tuple, set)): + return [_sanitize_for_logging(x, key_hint=key_hint) for x in obj] + + # Functions / Callables -> qualified name + try: + if inspect.isroutine(obj) or inspect.isclass(obj): + module = getattr(obj, "__module__", "") + qn = getattr(obj, "__qualname__", getattr(obj, "__name__", "")) + return f"{module}.{qn}" if module else qn + except Exception: + pass + + # Fallback: string representation + try: + return str(obj) + except Exception: + return "" + + +class Backend(str, Enum): + """ + Enumeration for different model backends. + - AUTO: Automatically select backend (prefer sglang native, fallback to diffusers) + - SGLANG: Use sglang's native optimized implementation + - DIFFUSERS: Use vanilla diffusers pipeline (supports all diffusers models) + """ + + AUTO = "auto" + SGLANG = "sglang" + DIFFUSERS = "diffusers" + + @classmethod + def from_string(cls, value: str) -> "Backend": + """Convert string to Backend enum.""" + try: + return cls(value.lower()) + except ValueError: + raise ValueError( + f"Invalid backend: {value}. Must be one of: {', '.join([m.value for m in cls])}" + ) from None + + @classmethod + def choices(cls) -> list[str]: + """Get all available choices as strings for argparse.""" + return [backend.value for backend in cls] + + +@dataclasses.dataclass +class ServerArgs: + # Model and path configuration (for convenience) + model_path: str + + # explicit model ID override (e.g. 
"Qwen-Image") + model_id: str | None = None + + # Model backend (sglang native or diffusers) + backend: Backend = Backend.AUTO + + # Attention + attention_backend: str = None + attention_backend_config: addict.Dict | None = None + cache_dit_config: str | dict[str, Any] | None = ( + None # cache-dit config for diffusers + ) + + # Distributed executor backend + nccl_port: Optional[int] = None + + # HuggingFace specific parameters + trust_remote_code: bool = False + revision: str | None = None + + # Parallelism + num_gpus: int = 1 + tp_size: Optional[int] = None + sp_degree: Optional[int] = None + # sequence parallelism + ulysses_degree: Optional[int] = None + ring_degree: Optional[int] = None + # data parallelism + # number of data parallelism groups + dp_size: int = 1 + # number of gpu in a dp group + dp_degree: int = 1 + # cfg parallel + enable_cfg_parallel: bool = False + + hsdp_replicate_dim: int = 1 + hsdp_shard_dim: Optional[int] = None + dist_timeout: int | None = 3600 # 1 hour + + pipeline_config: PipelineConfig = field(default_factory=PipelineConfig, repr=False) + + # Pipeline override + pipeline_class_name: str | None = ( + None # Override pipeline class from model_index.json + ) + + # LoRA parameters + # (Wenxuan) prefer to keep it here instead of in pipeline config to not make it complicated. + lora_path: str | None = None + lora_nickname: str = "default" # for swapping adapters in the pipeline + lora_scale: float = 1.0 # LoRA scale for merging (e.g., 0.125 for Hyper-SD) + + # Component path overrides (key = model_index.json component name, value = path) + component_paths: dict[str, str] = field(default_factory=dict) + + # path to pre-quantized transformer weights (single .safetensors or directory). + transformer_weights_path: str | None = None + # can restrict layers to adapt, e.g. ["q_proj"] + # Will adapt only q, k, v, o by default. 
+ lora_target_modules: list[str] | None = None + + # CPU offload parameters + dit_cpu_offload: bool | None = None + dit_layerwise_offload: bool | None = None + dit_offload_prefetch_size: float = 0.0 + text_encoder_cpu_offload: bool | None = None + image_encoder_cpu_offload: bool | None = None + vae_cpu_offload: bool | None = None + use_fsdp_inference: bool = False + pin_cpu_memory: bool = True + + # ComfyUI integration + comfyui_mode: bool = False + + # Compilation + enable_torch_compile: bool = False + + # warmup + warmup: bool = False + warmup_resolutions: list[str] = None + warmup_steps: int = 1 + + disable_autocast: bool | None = None + + # Quantization / Nunchaku SVDQuant configuration + nunchaku_config: NunchakuSVDQuantArgs | NunchakuConfig | None = field( + default_factory=NunchakuSVDQuantArgs, repr=False + ) + + # Master port for distributed inference + # TODO: do not hard code + master_port: int | None = None + + # http server endpoint config + host: str | None = "127.0.0.1" + port: int | None = 30000 + + # TODO: webui and their endpoint, check if webui_port is available. 
+ webui: bool = False + webui_port: int | None = 12312 + + scheduler_port: int = 5555 + + output_path: str | None = "outputs/" + input_save_path: str | None = "inputs/uploads" + + # Prompt text file for batch processing + prompt_file_path: str | None = None + + # model paths for correct deallocation + model_paths: dict[str, str] = field(default_factory=dict) + model_loaded: dict[str, bool] = field( + default_factory=lambda: { + "transformer": True, + "vae": True, + "video_vae": True, + "audio_vae": True, + "video_dit": True, + "audio_dit": True, + "dual_tower_bridge": True, + } + ) + + # # DMD parameters + # dmd_denoising_steps: List[int] | None = field(default=None) + + # MoE parameters used by Wan2.2 + boundary_ratio: float | None = None + + # Logging + log_level: str = "info" + + @property + def broker_port(self) -> int: + return self.port + 1 + + @property + def is_local_mode(self) -> bool: + """ + If no server is running when a generation task begins, 'local_mode' will be enabled: a dedicated server will be launched + """ + return self.host is None or self.port is None + + def _adjust_path(self): + expand_path_fields(self) + self._adjust_save_paths() + + def _adjust_parameters(self): + """set defaults and normalize values.""" + self._adjust_offload() + self._adjust_path() + self._adjust_quant_config() + self._adjust_warmup() + self._adjust_network_ports() + # adjust parallelism before attention backend + self._adjust_parallelism() + self._adjust_attention_backend() + self._adjust_platform_specific() + self._adjust_autocast() + self.adjust_pipeline_config() + + def _validate_parameters(self): + """check consistency and raise errors for invalid configs""" + self._validate_pipeline() + self._validate_offload() + self._validate_parallelism() + self._validate_cfg_parallel() + + def _adjust_save_paths(self): + """Normalize empty-string save paths to None (disabled).""" + if self.output_path is not None and self.output_path.strip() == "": + self.output_path = None + 
def adjust_pipeline_config(self):
    """Turn on sequence-parallel folding for T5 text encoders.

    Only acts when ``tp_size`` is 1 and ``sp_degree`` is greater than 1. Every
    ``T5Config`` in ``pipeline_config.text_encoder_configs`` then gets
    ``parallel_folding = True`` and ``parallel_folding_mode = "sp"``; one info
    line is logged if at least one config was changed.
    """
    sp_without_tp = self.tp_size == 1 and self.sp_degree > 1
    if not sp_without_tp:
        return

    t5_configs = [
        cfg
        for cfg in self.pipeline_config.text_encoder_configs
        if isinstance(cfg, T5Config)
    ]
    for cfg in t5_configs:
        cfg.parallel_folding = True
        cfg.parallel_folding_mode = "sp"

    if t5_configs:
        logger.info(
            "Enabled T5 text encoder parallel folding (mode=sp) for %s (tp_size=%s, sp_degree=%s).",
            self.__class__.__name__,
            self.tp_size,
            self.sp_degree,
        )
"Disabling some offloading (except dit, text_encoder) for image generation model" + ) + if self.dit_cpu_offload is None: + self.dit_cpu_offload = True + if self.text_encoder_cpu_offload is None: + self.text_encoder_cpu_offload = True + if self.image_encoder_cpu_offload is None: + self.image_encoder_cpu_offload = False + if self.vae_cpu_offload is None: + self.vae_cpu_offload = False + else: + if self.dit_cpu_offload is None: + self.dit_cpu_offload = True + if self.text_encoder_cpu_offload is None: + self.text_encoder_cpu_offload = True + if self.image_encoder_cpu_offload is None: + self.image_encoder_cpu_offload = True + if self.vae_cpu_offload is None: + self.vae_cpu_offload = True + + def _adjust_attention_backend(self): + if self.attention_backend in ["fa3", "fa4"]: + self.attention_backend = "fa" + + # attention_backend_config + if self.attention_backend_config is None: + self.attention_backend_config = addict.Dict() + elif isinstance(self.attention_backend_config, str): + self.attention_backend_config = addict.Dict( + self._parse_attention_backend_config(self.attention_backend_config) + ) + + if self.ring_degree > 1: + if self.attention_backend is not None and self.attention_backend not in ( + "fa", + "sage_attn", + ): + raise ValueError( + "Ring Attention is only supported for flash attention or sage attention backend for now" + ) + if self.attention_backend is None: + self.attention_backend = "fa" + logger.info( + "Ring Attention is currently only supported for flash attention or sage attention; " + "attention_backend has been automatically set to flash attention" + ) + + if self.attention_backend is None and self.backend != Backend.DIFFUSERS: + self._set_default_attention_backend() + + def _adjust_warmup(self): + if self.warmup_resolutions is not None: + self.warmup = True + + if self.warmup: + logger.info( + "Warmup enabled, the launch time is expected to be longer than usual" + ) + + def _adjust_network_ports(self): + self.port = 
self.settle_port(self.port) + initial_scheduler_port = self.scheduler_port + ( + random.randint(0, 100) if self.scheduler_port == 5555 else 0 + ) + self.scheduler_port = self.settle_port(initial_scheduler_port) + initial_master_port = ( + self.master_port + if self.master_port is not None + else (30005 + random.randint(0, 100)) + ) + self.master_port = self.settle_port(initial_master_port, 37) + + def _adjust_parallelism(self): + if self.tp_size is None: + self.tp_size = 1 + + if self.hsdp_shard_dim is None: + self.hsdp_shard_dim = self.num_gpus + + # adjust sp_degree: allocate all remaining GPUs after TP and DP + if self.sp_degree is None: + num_gpus_per_group = self.dp_size * self.tp_size + if self.enable_cfg_parallel: + num_gpus_per_group *= 2 + if self.num_gpus % num_gpus_per_group == 0: + self.sp_degree = self.num_gpus // num_gpus_per_group + else: + # Will be validated later + self.sp_degree = 1 + + if ( + self.ulysses_degree is None + and self.ring_degree is None + and self.sp_degree != 1 + ): + self.ulysses_degree = self.sp_degree + logger.info( + f"Automatically set ulysses_degree=sp_degree={self.ulysses_degree} for best performance" + ) + + if self.ulysses_degree is None: + self.ulysses_degree = 1 + logger.debug( + f"Ulysses degree not set, using default value {self.ulysses_degree}" + ) + + if self.ring_degree is None: + self.ring_degree = 1 + logger.debug(f"Ring degree not set, using default value {self.ring_degree}") + + def _adjust_platform_specific(self): + if current_platform.is_mps(): + self.use_fsdp_inference = False + self.dit_layerwise_offload = False + + # automatically enable dit_layerwise_offload for Wan/MOVA models if appropriate + if not envs.SGLANG_CACHE_DIT_ENABLED: + pipeline_name_lower = self.pipeline_config.__class__.__name__.lower() + if ( + ("wan" in pipeline_name_lower or "mova" in pipeline_name_lower) + and self.dit_layerwise_offload is None + and current_platform.enable_dit_layerwise_offload_for_wan_by_default() + ): + logger.info( 
+ f"Automatically enable dit_layerwise_offload for {self.pipeline_config.__class__.__name__} " + "for low memory and performance balance" + ) + self.dit_layerwise_offload = True + + def _adjust_autocast(self): + if self.disable_autocast is None: + self.disable_autocast = not self.pipeline_config.enable_autocast + + def _parse_attention_backend_config(self, config_str: str) -> dict[str, Any]: + """parse attention backend config from string.""" + if not config_str: + return {} + + # 1. treat as file path + if os.path.exists(config_str): + if config_str.endswith((".yaml", ".yml")): + with open(config_str, "r") as f: + return yaml.safe_load(f) + elif config_str.endswith(".json"): + with open(config_str, "r") as f: + return json.load(f) + + # 2. treat as JSON string + try: + return json.loads(config_str) + except json.JSONDecodeError: + pass + + # 3. treat as k=v pairs (simple implementation). e.g., "sparsity=0.5,enable_x=true" + try: + config = {} + pairs = config_str.split(",") + for pair in pairs: + k, v = pair.split("=", 1) + k = k.strip() + v = v.strip() + if v.lower() == "true": + v = True + elif v.lower() == "false": + v = False + elif v.replace(".", "", 1).isdigit(): + v = float(v) if "." in v else int(v) + config[k] = v + return config + except Exception: + raise ValueError(f"Could not parse attention backend config: {config_str}") + + def __post_init__(self): + # configure logger before use + configure_logger(server_args=self) + + # 1. adjust parameters + self._adjust_parameters() + + # 2. 
Validate parameters + self._validate_parameters() + + # log clean server_args + try: + safe_args = _sanitize_for_logging(self, key_hint="server_args") + logger.info("server_args: %s", json.dumps(safe_args, ensure_ascii=False)) + except Exception: + # Fallback to default repr if sanitization fails + logger.info(f"server_args: {self}") + + @staticmethod + def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: + # Model and path configuration + parser.add_argument( + "--model-path", + type=str, + help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.", + ) + parser.add_argument( + "--model-id", + type=str, + default=ServerArgs.model_id, + help=( + "Override the model ID used for config resolution. " + "Useful when --model-path is a local directory whose name does not match " + "any registered HF repo name. Should be the repo name portion of the HF ID " + "(e.g. 'Qwen-Image' for 'Qwen/Qwen-Image')." + ), + ) + # attention + parser.add_argument( + "--attention-backend", + type=str, + default=None, + help=( + "The attention backend to use. For SGLang-native pipelines, use " + "values like fa, torch_sdpa, sage_attn, etc. For diffusers pipelines, " + "use diffusers attention backend names such as flash, _flash_3_hub, " + "sage, or xformers." + ), + ) + parser.add_argument( + "--attention-backend-config", + type=str, + default=None, + help="Configuration for the attention backend. Can be a JSON string, a path to a JSON/YAML file, or key=value pairs.", + ) + parser.add_argument( + "--cache-dit-config", + type=str, + default=ServerArgs.cache_dit_config, + help="Path to a Cache-DiT YAML/JSON config. 
Enables cache-dit for diffusers backend.", + ) + + # HuggingFace specific parameters + parser.add_argument( + "--trust-remote-code", + action=StoreBoolean, + default=ServerArgs.trust_remote_code, + help="Trust remote code when loading HuggingFace models", + ) + parser.add_argument( + "--revision", + type=str, + default=ServerArgs.revision, + help="The specific model version to use (can be a branch name, tag name, or commit id)", + ) + + # Parallelism + parser.add_argument( + "--num-gpus", + type=int, + default=ServerArgs.num_gpus, + help="The number of GPUs to use.", + ) + parser.add_argument( + "--tp-size", + type=int, + default=None, + help="The tensor parallelism size. Defaults to 1 if not specified.", + ) + parser.add_argument( + "--sp-degree", + type=int, + default=None, + help="The sequence parallelism size. If not specified, will use all remaining GPUs after accounting for TP and DP.", + ) + parser.add_argument( + "--ulysses-degree", + type=int, + default=ServerArgs.ulysses_degree, + help="Ulysses sequence parallel degree. Used in attention layer.", + ) + parser.add_argument( + "--ring-degree", + type=int, + default=ServerArgs.ring_degree, + help="Ring sequence parallel degree. Used in attention layer.", + ) + parser.add_argument( + "--enable-cfg-parallel", + action="store_true", + default=ServerArgs.enable_cfg_parallel, + help="Enable cfg parallel.", + ) + parser.add_argument( + "--data-parallel-size", + "--dp-size", + "--dp", + type=int, + default=ServerArgs.dp_size, + help="The data parallelism size.", + ) + + parser.add_argument( + "--hsdp-replicate-dim", + type=int, + default=ServerArgs.hsdp_replicate_dim, + help="The data parallelism size.", + ) + parser.add_argument( + "--hsdp-shard-dim", + type=int, + default=None, + help="The data parallelism shards. Defaults to num_gpus if not specified.", + ) + parser.add_argument( + "--dist-timeout", + type=int, + default=ServerArgs.dist_timeout, + help="Timeout for torch.distributed operations in seconds. 
" + "Increase this value if you encounter 'Connection closed by peer' errors after the service is idle. ", + ) + + # Prompt text file for batch processing + parser.add_argument( + "--prompt-file-path", + type=str, + default=ServerArgs.prompt_file_path, + help="Path to a text file containing prompts (one per line) for batch processing", + ) + + parser.add_argument( + "--mask-strategy-file-path", + type=str, + help="Path to mask strategy JSON file for STA", + ) + parser.add_argument( + "--enable-torch-compile", + action=StoreBoolean, + default=ServerArgs.enable_torch_compile, + help="Use torch.compile to speed up DiT inference." + + "However, will likely cause precision drifts. See (https://github.com/pytorch/pytorch/issues/145213)", + ) + + # warmup + parser.add_argument( + "--warmup", + action=StoreBoolean, + default=ServerArgs.warmup, + help="Perform some warmup after server starts (if `--warmup-resolutions` is specified) or before processing the first request (if `--warmup-resolutions` is not specified)." + "Recommended to enable when benchmarking to ensure fair comparison and best performance." + "When enabled with `--warmup-resolutions` unspecified, look for the line ending with `(with warmup excluded)` for actual processing time.", + ) + parser.add_argument( + "--warmup-resolutions", + type=str, + nargs="+", + default=ServerArgs.warmup_resolutions, + help="Specify resolutions for server to warmup. e.g., `--warmup-resolutions 256x256, 720x720`", + ) + parser.add_argument( + "--warmup-steps", + type=int, + default=ServerArgs.warmup_steps, + help="The number of warmup steps to perform for each resolution.", + ) + + parser.add_argument( + "--dit-cpu-offload", + action=StoreBoolean, + help="Use CPU offload for DiT inference. 
Enable if run out of memory with FSDP.", + ) + parser.add_argument( + "--dit-layerwise-offload", + action=StoreBoolean, + default=ServerArgs.dit_layerwise_offload, + help="Enable layerwise CPU offload with async H2D prefetch overlap for supported DiT models (e.g., Wan, MOVA). " + "Cannot be used together with cache-dit (SGLANG_CACHE_DIT_ENABLED), dit_cpu_offload, or use_fsdp_inference.", + ) + parser.add_argument( + "--dit-offload-prefetch-size", + type=float, + default=ServerArgs.dit_offload_prefetch_size, + help="The size of prefetch for dit-layerwise-offload. If the value is between 0.0 and 1.0, it is treated as a ratio of the total number of layers. If the value is >= 1, it is treated as the absolute number of layers. 0.0 means prefetch 1 layer (lowest memory). Values above 0.5 might have peak memory close to no offload but worse performance.", + ) + parser.add_argument( + "--use-fsdp-inference", + action=StoreBoolean, + help="Use FSDP for inference by sharding the model weights. Latency is very low due to prefetch--enable if run out of memory.", + ) + parser.add_argument( + "--text-encoder-cpu-offload", + action=StoreBoolean, + help="Use CPU offload for text encoder. Enable if run out of memory.", + ) + parser.add_argument( + "--image-encoder-cpu-offload", + action=StoreBoolean, + help="Use CPU offload for image encoder. Enable if run out of memory.", + ) + parser.add_argument( + "--vae-cpu-offload", + action=StoreBoolean, + help="Use CPU offload for VAE. Enable if run out of memory.", + ) + parser.add_argument( + "--pin-cpu-memory", + action=StoreBoolean, + help='Pin memory for CPU offload. Only added as a temp workaround if it throws "CUDA error: invalid argument". 
' + "Should be enabled in almost all cases", + ) + parser.add_argument( + "--disable-autocast", + action=StoreBoolean, + help="Disable autocast for denoising loop and vae decoding in pipeline sampling", + ) + + # Nunchaku SVDQuant quantization parameters + NunchakuSVDQuantArgs.add_cli_args(parser) + + # Master port for distributed inference + parser.add_argument( + "--master-port", + type=int, + default=ServerArgs.master_port, + help="Master port for distributed inference. If not set, a random free port will be used.", + ) + parser.add_argument( + "--scheduler-port", + type=int, + default=ServerArgs.scheduler_port, + help="Port for the scheduler server.", + ) + parser.add_argument( + "--host", + type=str, + default=ServerArgs.host, + help="Host for the HTTP API server.", + ) + parser.add_argument( + "--port", + type=int, + default=ServerArgs.port, + help="Port for the HTTP API server.", + ) + parser.add_argument( + "--webui", + action=StoreBoolean, + default=ServerArgs.webui, + help="Whether to use webui for better display", + ) + + parser.add_argument( + "--webui-port", + type=int, + default=ServerArgs.webui_port, + help="Whether to use webui for better display", + ) + parser.add_argument( + "--output-path", + type=str, + default=ServerArgs.output_path, + help='Directory path to save generated images/videos. Set to "" to disable persistent saving.', + ) + parser.add_argument( + "--input-save-path", + type=str, + default=ServerArgs.input_save_path, + help='Directory path to save uploaded input images/videos. 
Set to "" to disable persistent saving.', + ) + + # LoRA + parser.add_argument( + "--lora-path", + type=str, + default=ServerArgs.lora_path, + help="The path to the LoRA adapter weights (can be local file path or HF hub id) to launch with", + ) + parser.add_argument( + "--lora-nickname", + type=str, + default=ServerArgs.lora_nickname, + help="The nickname for the LoRA adapter to launch with", + ) + parser.add_argument( + "--lora-scale", + type=float, + default=ServerArgs.lora_scale, + help="LoRA scale for merging (e.g., 0.125 for Hyper-SD). Same as lora_scale in Diffusers", + ) + # Add pipeline configuration arguments + PipelineConfig.add_cli_args(parser) + + # Logging + parser.add_argument( + "--log-level", + type=str, + default=ServerArgs.log_level, + help="The logging level of all loggers.", + ) + parser.add_argument( + "--backend", + type=str, + choices=Backend.choices(), + default=ServerArgs.backend.value, + help="The model backend to use. 'auto' prefers sglang native and falls back to diffusers. " + "'sglang' uses native optimized implementation. 'diffusers' uses vanilla diffusers pipeline.", + ) + return parser + + def url(self): + host = self.host + if not host or host == "0.0.0.0": + host = "127.0.0.1" + elif host == "::": + host = "::1" + if is_valid_ipv6_address(host): + return f"http://[{host}]:{self.port}" + else: + return f"http://{host}:{self.port}" + + @property + def scheduler_endpoint(self): + """ + Internal endpoint for scheduler. + Prefers the configured host but normalizes localhost -> 127.0.0.1 to avoid ZMQ issues. + """ + scheduler_host = self.host + if scheduler_host is None or scheduler_host == "localhost": + scheduler_host = "127.0.0.1" + return f"tcp://{scheduler_host}:{self.scheduler_port}" + + def settle_port( + self, port: int, port_inc: int = 42, max_attempts: int = 100 + ) -> int: + """ + Find an available port with retry logic. 
+ """ + attempts = 0 + original_port = port + + while attempts < max_attempts: + if is_port_available(port): + if attempts > 0: + logger.info( + f"Port {original_port} was unavailable, using port {port} instead" + ) + return port + + attempts += 1 + if port < 60000: + port += port_inc + else: + # Wrap around with randomization to avoid collision + port = 5000 + random.randint(0, 1000) + + raise RuntimeError( + f"Failed to find available port after {max_attempts} attempts " + f"(started from port {original_port})" + ) + + @staticmethod + def _extract_component_paths( + unknown_args: list[str], + ) -> tuple[dict[str, str], list[str]]: + """ + Extract dynamic ``---path`` args from unrecognised CLI args. + """ + component_paths: dict[str, str] = {} + remaining: list[str] = [] + i = 0 + while i < len(unknown_args): + arg = unknown_args[i] + key_part = arg.split("=", 1)[0] if "=" in arg else arg + if key_part.startswith("--") and key_part.endswith("-path"): + component = key_part[2:-5].replace("-", "_") + if "=" in arg: + component_paths[component] = arg.split("=", 1)[1] + elif i + 1 < len(unknown_args) and not unknown_args[i + 1].startswith( + "-" + ): + i += 1 + component_paths[component] = unknown_args[i] + else: + remaining.append(arg) + i += 1 + continue + else: + remaining.append(arg) + i += 1 + + # canonicalize and validate + for component, path in component_paths.items(): + path = os.path.expanduser(path) + component_paths[component] = path + return component_paths, remaining + + @classmethod + def from_cli_args( + cls, args: argparse.Namespace, unknown_args: list[str] | None = None + ) -> "ServerArgs": + if unknown_args is None: + unknown_args = [] + + # extract dynamic ---path from unknown args + dynamic_paths, remaining = cls._extract_component_paths(unknown_args) + if remaining: + raise SystemExit(f"error: unrecognized arguments: {' '.join(remaining)}") + + provided_args = cls.get_provided_args(args, unknown_args) + + # Handle config file + config_file = 
provided_args.get("config") + if config_file: + config_args = cls.load_config_file(config_file) + provided_args = {**config_args, **provided_args} + + if dynamic_paths: + existing = dict(provided_args.get("component_paths") or {}) + existing.update(dynamic_paths) + provided_args["component_paths"] = existing + + return cls.from_dict(provided_args) + + @classmethod + def from_dict(cls, kwargs: dict[str, Any]) -> "ServerArgs": + """Create a ServerArgs object from a dictionary.""" + attrs = [attr.name for attr in dataclasses.fields(cls)] + server_args_kwargs: dict[str, Any] = {} + + component_paths = dict(kwargs.get("component_paths") or {}) + if component_paths: + server_args_kwargs["component_paths"] = component_paths + + for attr in attrs: + if attr == "pipeline_config": + pipeline_config = PipelineConfig.from_kwargs(kwargs) + logger.debug(f"Using PipelineConfig: {type(pipeline_config)}") + server_args_kwargs["pipeline_config"] = pipeline_config + elif attr == "nunchaku_config": + nunchaku_config = NunchakuSVDQuantArgs.from_dict(kwargs) + server_args_kwargs["nunchaku_config"] = nunchaku_config + elif attr in kwargs: + server_args_kwargs[attr] = kwargs[attr] + + return cls(**server_args_kwargs) + + @staticmethod + def load_config_file(config_file: str) -> dict[str, Any]: + """Load a config file.""" + if config_file.endswith(".json"): + with open(config_file, "r") as f: + return json.load(f) + elif config_file.endswith((".yaml", ".yml")): + try: + import yaml + except ImportError: + raise ImportError( + "Please install PyYAML to use YAML config files. 
" + "`pip install pyyaml`" + ) + with open(config_file, "r") as f: + return yaml.safe_load(f) + else: + raise ValueError(f"Unsupported config file format: {config_file}") + + @classmethod + def from_kwargs(cls, **kwargs: Any) -> "ServerArgs": + # Convert backend string to enum if necessary + if "backend" in kwargs and isinstance(kwargs["backend"], str): + kwargs["backend"] = Backend.from_string(kwargs["backend"]) + + kwargs["pipeline_config"] = PipelineConfig.from_kwargs(kwargs) + return cls(**kwargs) + + @staticmethod + def get_provided_args( + args: argparse.Namespace, unknown_args: list[str] + ) -> dict[str, Any]: + """Get the arguments provided by the user.""" + provided_args = {} + # We need to check against the raw command-line arguments to see what was + # explicitly provided by the user, vs. what's a default value from argparse. + raw_argv = sys.argv + unknown_args + + # Create a set of argument names that were present on the command line. + # This handles both styles: '--arg=value' and '--arg value'. + provided_arg_names = set() + for arg in raw_argv: + if arg.startswith("--"): + # For '--arg=value', this gets 'arg'; for '--arg', this also gets 'arg'. + arg_name = arg.split("=", 1)[0].replace("-", "_").lstrip("_") + provided_arg_names.add(arg_name) + + # Populate provided_args if the argument from the namespace was on the command line. 
+ for k, v in vars(args).items(): + if k in provided_arg_names: + provided_args[k] = v + + return provided_args + + def _validate_pipeline(self): + if self.pipeline_config is None: + raise ValueError("pipeline_config is not set in ServerArgs") + + self.pipeline_config.check_pipeline_config() + + def _validate_offload(self): + # validate dit_offload_prefetch_size + if self.dit_offload_prefetch_size > 1 and ( + isinstance(self.dit_offload_prefetch_size, float) + and not self.dit_offload_prefetch_size.is_integer() + ): + self.dit_offload_prefetch_size = int( + math.floor(self.dit_offload_prefetch_size) + ) + logger.info( + f"Invalid --dit-offload-prefetch-size value passed, truncated to: {self.dit_offload_prefetch_size}" + ) + + if 0.5 <= self.dit_offload_prefetch_size < 1.0: + logger.info( + "We do not recommend --dit-offload-prefetch-size to be between 0.5 and 1.0" + ) + + # validate dit_layerwise_offload conflicts + if self.dit_layerwise_offload: + if self.dit_offload_prefetch_size < 0.0: + raise ValueError("dit_offload_prefetch_size must be non-negative") + + if self.use_fsdp_inference: + logger.warning( + "dit_layerwise_offload is enabled, automatically disabling use_fsdp_inference." + ) + self.use_fsdp_inference = False + + if self.dit_cpu_offload is None: + logger.warning( + "dit_layerwise_offload is enabled, automatically disabling dit_cpu_offload." + ) + self.dit_cpu_offload = False + + if envs.SGLANG_CACHE_DIT_ENABLED: + raise ValueError( + "dit_layerwise_offload cannot be enabled together with cache-dit. " + "cache-dit may reuse skipped blocks whose weights have been released by layerwise offload, " + "causing shape mismatch errors. " + "Please disable either --dit-layerwise-offload or SGLANG_CACHE_DIT_ENABLED." 
+ ) + + def _validate_parallelism(self): + if self.sp_degree > self.num_gpus or self.num_gpus % self.sp_degree != 0: + raise ValueError( + f"num_gpus ({self.num_gpus}) must be >= and divisible by sp_degree ({self.sp_degree})" + ) + + if ( + self.hsdp_replicate_dim > self.num_gpus + or self.num_gpus % self.hsdp_replicate_dim != 0 + ): + raise ValueError( + f"num_gpus ({self.num_gpus}) must be >= and divisible by hsdp_replicate_dim ({self.hsdp_replicate_dim})" + ) + + if ( + self.hsdp_shard_dim > self.num_gpus + or self.num_gpus % self.hsdp_shard_dim != 0 + ): + raise ValueError( + f"num_gpus ({self.num_gpus}) must be >= and divisible by hsdp_shard_dim ({self.hsdp_shard_dim})" + ) + + if self.num_gpus % self.dp_size != 0: + raise ValueError( + f"num_gpus ({self.num_gpus}) must be divisible by dp_size ({self.dp_size})" + ) + + if self.dp_size < 1: + raise ValueError("--dp-size must be a natural number") + + if self.dp_size > 1: + raise ValueError("DP is not yet supported") + + num_gpus_per_group = self.dp_size * self.tp_size + if self.enable_cfg_parallel: + num_gpus_per_group *= 2 + + if self.num_gpus % num_gpus_per_group != 0: + raise ValueError( + f"num_gpus ({self.num_gpus}) must be divisible by (dp_size * tp_size{' * 2' if self.enable_cfg_parallel else ''}) = {num_gpus_per_group}" + ) + + if self.sp_degree != self.ring_degree * self.ulysses_degree: + raise ValueError( + f"sp_degree ({self.sp_degree}) must equal ring_degree * ulysses_degree " + f"({self.ring_degree} * {self.ulysses_degree} = {self.ring_degree * self.ulysses_degree})" + ) + + if os.getenv("SGLANG_CACHE_DIT_ENABLED", "").lower() == "true": + has_sp = self.sp_degree > 1 + has_tp = self.tp_size > 1 + if has_sp and has_tp: + logger.warning( + "cache-dit is enabled with hybrid parallelism (SP + TP). " + "Proceeding anyway (SGLang integration may support this mode)." 
+ ) + + def _validate_cfg_parallel(self): + if self.enable_cfg_parallel and self.num_gpus == 1: + raise ValueError( + "CFG Parallelism is enabled via `--enable-cfg-parallel`, but num_gpus == 1" + ) + + def _set_default_attention_backend(self) -> None: + """Configure ROCm defaults when users do not specify an attention backend.""" + if current_platform.is_rocm(): + default_backend = AttentionBackendEnum.AITER.name.lower() + self.attention_backend = default_backend + logger.info( + "Attention backend not specified. Using '%s' by default on ROCm " + "to match SGLang SRT defaults.", + default_backend, + ) + + +@dataclasses.dataclass +class PortArgs: + # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq) + scheduler_input_ipc_name: str + + # The port for nccl initialization (torch.dist) + nccl_port: int + + # The ipc filename for rpc call between Engine and Scheduler + rpc_ipc_name: str + + # The ipc filename for Scheduler to send metrics + metrics_ipc_name: str + + # Master port for distributed inference + master_port: int | None = None + + @staticmethod + def from_server_args( + server_args: ServerArgs, dp_rank: Optional[int] = None + ) -> "PortArgs": + if server_args.nccl_port is None: + nccl_port = server_args.scheduler_port + random.randint(100, 1000) + while True: + if is_port_available(nccl_port): + break + if nccl_port < 60000: + nccl_port += 42 + else: + nccl_port -= 43 + else: + nccl_port = server_args.nccl_port + + # Normal case, use IPC within a single node + return PortArgs( + scheduler_input_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}", + nccl_port=nccl_port, + rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}", + metrics_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}", + master_port=server_args.master_port, + ) + + +_global_server_args = None + + +def prepare_server_args(argv: list[str]) -> ServerArgs: + """ + Prepare the inference arguments from the command line 
arguments. + """ + parser = FlexibleArgumentParser() + ServerArgs.add_cli_args(parser) + raw_args, unknown_args = parser.parse_known_args(argv) + server_args = ServerArgs.from_cli_args(raw_args, unknown_args) + return server_args + + +def set_global_server_args(server_args: ServerArgs): + """ + Set the global sgl_diffusion config for each process + """ + global _global_server_args + _global_server_args = server_args + + +def get_global_server_args() -> ServerArgs: + if _global_server_args is None: + # in ci, usually when we test custom ops/modules directly, + # we don't set the sgl_diffusion config. In that case, we set a default + # config. + # TODO(will): may need to handle this for CI. + raise ValueError("Global sgl_diffusion args is not set.") + return _global_server_args diff --git a/sglang/python/sglang/multimodal_gen/runtime/utils/common.py b/sglang/python/sglang/multimodal_gen/runtime/utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..b5e3dc098b9bdd4af091452f4bd55b69a57d3396 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/utils/common.py @@ -0,0 +1,317 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +import ipaddress +import logging +import os +import platform +import signal +import socket +import sys +import threading +from functools import lru_cache + +import psutil +import torch +import zmq + +# use the native logger to avoid circular import +logger = logging.getLogger(__name__) + + +def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None): + """Kill the process and all its child processes.""" + # Remove sigchld handler to avoid spammy logs. 
+ if threading.current_thread() is threading.main_thread(): + signal.signal(signal.SIGCHLD, signal.SIG_DFL) + + if parent_pid is None: + parent_pid = os.getpid() + include_parent = False + + try: + itself = psutil.Process(parent_pid) + except psutil.NoSuchProcess: + return + + children = itself.children(recursive=True) + for child in children: + if child.pid == skip_pid: + continue + try: + child.kill() + except psutil.NoSuchProcess: + pass + + if include_parent: + try: + if parent_pid == os.getpid(): + itself.kill() + sys.exit(0) + + itself.kill() + + # Sometime processes cannot be killed with SIGKILL (e.g, PID=1 launched by kubernetes), + # so we send an additional signal to kill them. + itself.send_signal(signal.SIGQUIT) + except psutil.NoSuchProcess: + pass + + +def add_prefix(name: str, prefix: str) -> str: + """Add a weight path prefix to a module name. + + Args: + name: base module name. + prefix: weight prefix str to added to the front of `name` concatenated with `.`. + + Returns: + The string `prefix.name` if prefix is non-empty, otherwise just `name`. + """ + return name if not prefix else f"{prefix}.{name}" + + +def is_valid_ipv6_address(address: str) -> bool: + try: + ipaddress.IPv6Address(address) + return True + except ValueError: + return False + + +def configure_ipv6(dist_init_addr): + addr = dist_init_addr + end = addr.find("]") + if end == -1: + raise ValueError("invalid IPv6 address format: missing ']'") + + host = addr[: end + 1] + + # this only validates the address without brackets: we still need the below checks. + # if it's invalid, immediately raise an error so we know it's not formatting issues. 
+ if not is_valid_ipv6_address(host[1:end]): + raise ValueError(f"invalid IPv6 address: {host}") + + port_str = None + if len(addr) > end + 1: + if addr[end + 1] == ":": + port_str = addr[end + 2 :] + else: + raise ValueError("received IPv6 address format: expected ':' after ']'") + + if not port_str: + raise ValueError( + "a port must be specified in IPv6 address (format: [ipv6]:port)" + ) + + try: + port = int(port_str) + except ValueError: + raise ValueError(f"invalid port in IPv6 address: '{port_str}'") + return port, host + + +def is_port_available(port): + """Return whether a port is available.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(("", port)) + s.listen(1) + return True + except socket.error: + return False + except OverflowError: + return False + + +def get_zmq_socket( + context: zmq.Context, + socket_type: zmq.SocketType, + endpoint: str, + bind: bool, + max_bind_retries: int = 10, +) -> tuple[zmq.Socket, str]: + """ + Create and configure a ZMQ socket. + + Args: + context: ZMQ context + socket_type: Type of ZMQ socket + endpoint: Endpoint string (e.g., "tcp://localhost:5555") + bind: Whether to bind (True) or connect (False) + max_bind_retries: Maximum number of retries if bind fails due to address already in use + + Returns: + A tuple of (socket, actual_endpoint). The actual_endpoint may differ from the + requested endpoint if bind retry was needed. 
+ """ + mem = psutil.virtual_memory() + total_mem = mem.total / 1024**3 + available_mem = mem.available / 1024**3 + if total_mem > 32 and available_mem > 16: + buf_size = int(0.5 * 1024**3) + else: + buf_size = -1 + + socket = context.socket(socket_type) + if endpoint.find("[") != -1: + socket.setsockopt(zmq.IPV6, 1) + + def set_send_opt(): + socket.setsockopt(zmq.SNDHWM, 0) + socket.setsockopt(zmq.SNDBUF, buf_size) + + def set_recv_opt(): + socket.setsockopt(zmq.RCVHWM, 0) + socket.setsockopt(zmq.RCVBUF, buf_size) + + if socket_type == zmq.PUSH: + set_send_opt() + elif socket_type == zmq.PULL: + set_recv_opt() + elif socket_type in [zmq.DEALER, zmq.REQ, zmq.REP, zmq.ROUTER]: + set_send_opt() + set_recv_opt() + else: + raise ValueError(f"Unsupported socket type: {socket_type}") + + if bind: + # Parse port from endpoint for retry logic + import re + + port_match = re.search(r":(\d+)$", endpoint) + + if port_match and max_bind_retries > 1: + original_port = int(port_match.group(1)) + last_exception = None + + for attempt in range(max_bind_retries): + try: + current_endpoint = endpoint + if attempt > 0: + # Try next port (increment by 42 to match settle_port logic) + current_port = original_port + attempt * 42 + current_endpoint = re.sub( + r":(\d+)$", f":{current_port}", endpoint + ) + logger.info( + f"ZMQ bind failed for port {original_port + (attempt - 1) * 42}, " + f"retrying with port {current_port} (attempt {attempt + 1}/{max_bind_retries})" + ) + + socket.bind(current_endpoint) + + if attempt > 0: + logger.warning( + f"Successfully bound ZMQ socket to {current_endpoint} after {attempt + 1} attempts. " + f"Original port {original_port} was unavailable." 
+ ) + + return socket, current_endpoint + + except zmq.ZMQError as e: + last_exception = e + if e.errno == zmq.EADDRINUSE and attempt < max_bind_retries - 1: + # Address already in use, try next port + continue + elif attempt == max_bind_retries - 1: + # Last attempt failed + logger.error( + f"Failed to bind ZMQ socket after {max_bind_retries} attempts. " + f"Original endpoint: {endpoint}, Last tried port: {original_port + attempt * 42}" + ) + raise + else: + # Different error, raise immediately + raise + + # Should not reach here, but just in case + if last_exception: + raise last_exception + else: + # No retry logic needed (either no port in endpoint or max_bind_retries == 1) + socket.bind(endpoint) + return socket, endpoint + else: + socket.connect(endpoint) + return socket, endpoint + + return socket, endpoint + + +# https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip + + +@lru_cache(maxsize=1) +def is_host_cpu_x86() -> bool: + machine = platform.machine().lower() + return ( + machine in ("x86_64", "amd64", "i386", "i686") + and hasattr(torch, "cpu") + and torch.cpu.is_available() + ) + + +# cuda + + +def set_cuda_arch(): + capability = torch.cuda.get_device_capability() + arch = f"{capability[0]}.{capability[1]}" + os.environ["TORCH_CUDA_ARCH_LIST"] = f"{arch}{'+PTX' if arch == '9.0' else ''}" + + +# musa + + +def set_musa_arch(): + capability = torch.cuda.get_device_capability() + arch = f"{capability[0]}{capability[1]}" + os.environ["TORCH_MUSA_ARCH_LIST"] = f"{arch}" + + +# env var managements + +_warned_bool_env_var_keys = set() + + +def get_bool_env_var(name: str, default: str = "false") -> bool: + value = os.getenv(name, default) + value = str(value).strip().lower() + + truthy_values = {"1", "true", "yes", "y", "t", "on"} + falsy_values = {"0", "false", "no", "n", "f", "off", ""} + + if (value not in truthy_values) and (value not in falsy_values): + if value not in _warned_bool_env_var_keys: + logger.warning( + f"get_bool_env_var({name}) see 
non-understandable value={value} and treat as false" + ) + _warned_bool_env_var_keys.add(value) + + return value in truthy_values + + +try: + import sgl_kernel # noqa: F401 + + is_intel_amx_backend_available = hasattr( + torch.ops.sgl_kernel, "convert_weight_packed" + ) +except: + is_intel_amx_backend_available = False + +try: + # move torch._C._cpu._is_amx_tile_supported() from cpu_has_amx_support + # to support torch compile + is_amx_tile_supported = torch._C._cpu._is_amx_tile_supported() +except: + is_amx_tile_supported = False + + +def cpu_has_amx_support(): + return is_amx_tile_supported and is_intel_amx_backend_available + + +def use_intel_amx_backend(layer): + return getattr(layer, "use_intel_amx_backend", False) diff --git a/sglang/python/sglang/multimodal_gen/runtime/utils/distributed.py b/sglang/python/sglang/multimodal_gen/runtime/utils/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..80afa07218d8d745c7530b354e40090816e221f2 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/utils/distributed.py @@ -0,0 +1,234 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +import pickle +from typing import Any, List, Optional + +import numpy as np +import torch +import torch.distributed as dist + +from sglang.multimodal_gen.runtime.platforms import current_platform + + +def broadcast_pyobj( + data: List[Any], + rank: int, + dist_group: Optional[torch.distributed.ProcessGroup] = None, + src: int = 0, + force_cpu_device: bool = True, +): + """Broadcast inputs from src rank to all other ranks with torch.dist backend. + The `rank` here refer to the source rank on global process group (regardless + of dist_group argument). 
+ """ + + device = torch.device( + current_platform.device_type if not force_cpu_device else "cpu" + ) + + if rank == src: + if data is None or len(data) == 0: + tensor_size = torch.tensor([0], dtype=torch.long, device=device) + dist.broadcast(tensor_size, src=src, group=dist_group) + else: + serialized_data = pickle.dumps(data) + size = len(serialized_data) + + tensor_data = torch.ByteTensor( + np.frombuffer(serialized_data, dtype=np.uint8).copy() + ).to(device) + tensor_size = torch.tensor([size], dtype=torch.long, device=device) + + dist.broadcast(tensor_size, src=src, group=dist_group) + dist.broadcast(tensor_data, src=src, group=dist_group) + return data + else: + tensor_size = torch.tensor([0], dtype=torch.long, device=device) + dist.broadcast(tensor_size, src=src, group=dist_group) + size = tensor_size.item() + + if size == 0: + return [] + + tensor_data = torch.empty(size, dtype=torch.uint8, device=device) + dist.broadcast(tensor_data, src=src, group=dist_group) + + serialized_data = bytes(tensor_data.cpu().numpy()) + data = pickle.loads(serialized_data) + return data + + +def generate_masked_orthogonal_rank_groups( + world_size: int, parallel_size: list[int], mask: list[bool] +) -> list[list[int]]: + """Generate orthogonal parallel groups based on the parallel size and mask. + + Arguments: + world_size (int): world size + + parallel_size (List[int]): + The parallel size of each orthogonal parallel type. For example, if + tensor_parallel_size = 2, pipeline_model_parallel_group = 3, data_parallel_size = 4, + and the parallel mapping order is tp-pp-dp, then the parallel_size = [2, 3, 4]. + + mask (List[bool]): + The mask controls which parallel methods the generated groups represent. If mask[i] is + True, it means the generated group contains the i-th parallelism method. 
For example, + if parallel_size = [tp_size, pp_size, dp_size], and mask = [True, False , True], then + the generated group is the `tp-dp` group, if the mask = [False, True, False], then the + generated group is the `pp` group. + + Algorithm: + For orthogonal parallelism, such as tp/dp/pp/cp, the global_rank and + + If we want to get the `dp_group` (tp_size * pp_size groups of dp_size ranks each. + For example, if the gpu size is 8 and order is 'tp-pp-dp', size is '2-2-2', and the + dp_group here is [[0, 4], [1, 5], [2, 6], [3, 7]].) + The tp_rank and pp_rank will be combined to form the `dp_group_index`. + dp_group_index = tp_rank + pp_rank * tp_size (2) + + So, Given that tp_rank and pp_rank satisfy equation (2), and dp_rank in + range(0, dp_size), the ranks in dp_group[dp_group_index] satisfies the + equation (1). + + This function solve this math problem. + + For example, if the parallel_size = [tp_size, dp_size, pp_size] = [2, 3, 4], + and the mask = [False, True, False]. Then, + dp_group_index(0) = tp_rank(0) + pp_rank(0) * 2 + dp_group_index(1) = tp_rank(1) + pp_rank(0) * 2 + ... + dp_group_index(7) = tp_rank(1) + pp_rank(3) * 2 + + dp_group[0] = 0 + range(0, 3) * 2 + 0 = [0, 2, 4] + dp_group[1] = 1 + range(0, 3) * 2 + 0 = [1, 3, 5] + ... + dp_group[7] = 1 + range(0, 3) * 2 + 3 * 2 * 3 = [19, 21, 23] + """ + + def prefix_product(a: List[int], init=1) -> List[int]: + r = [init] + for v in a: + init = init * v + r.append(init) + return r + + def inner_product(a: List[int], b: List[int]) -> int: + return sum([x * y for x, y in zip(a, b)]) + + def decompose(index, shape, stride=None): + """ + This function solve the math problem below: + There is an equation: + index = sum(idx[i] * stride[i]) + And given the value of index, stride. + Return the idx. + This function will used to get the pp/dp/pp_rank + from group_index and rank_in_group. 
+ """ + if stride is None: + stride = prefix_product(shape) + idx = [(index // d) % s for s, d in zip(shape, stride)] + # stride is a prefix_product result. And the value of stride[-1] + # is not used. + assert ( + sum([x * y for x, y in zip(idx, stride[:-1])]) == index + ), "idx {} with shape {} mismatch the return idx {}".format(index, shape, idx) + return idx + + masked_shape = [s for s, m in zip(parallel_size, mask) if m] + unmasked_shape = [s for s, m in zip(parallel_size, mask) if not m] + + global_stride = prefix_product(parallel_size) + masked_stride = [d for d, m in zip(global_stride, mask) if m] + unmasked_stride = [d for d, m in zip(global_stride, mask) if not m] + + group_size = prefix_product(masked_shape)[-1] + num_of_group = world_size // group_size + + ranks = [] + for group_index in range(num_of_group): + # get indices from unmaksed for group_index. + decomposed_group_idx = decompose(group_index, unmasked_shape) + rank = [] + for rank_in_group in range(group_size): + # get indices from masked for rank_in_group. + decomposed_rank_idx = decompose(rank_in_group, masked_shape) + rank.append( + inner_product(decomposed_rank_idx, masked_stride) + + inner_product(decomposed_group_idx, unmasked_stride) + ) + ranks.append(rank) + return ranks + + +class RankGenerator(object): + def __init__( + self, + tp: int, + sp: int, + pp: int, + cfg: int, + dp: int, + order: str, + rank_offset: int = 0, + ) -> None: + self.tp = tp + self.sp = sp + self.pp = pp + self.cfg = cfg + self.dp = dp + self.rank_offset = rank_offset + self.world_size = tp * sp * pp * cfg * dp + + self.name_to_size = { + "tp": self.tp, + "sp": self.sp, + "pp": self.pp, + "cfg": self.cfg, + "dp": self.dp, + } + order = order.lower() + + for name in self.name_to_size.keys(): + if name not in order and self.name_to_size[name] != 1: + raise RuntimeError( + f"The size of ({name}) is ({self.name_to_size[name]}), but you haven't specified the order ({self.order})." 
+ ) + elif name not in order: + order = order + "-" + name + + self.order = order + self.ordered_size = [] + + for token in order.split("-"): + self.ordered_size.append(self.name_to_size[token]) + + def get_mask(self, order: str, token: str): + ordered_token = order.split("-") + token = token.split("-") + mask = [False] * len(ordered_token) + for t in token: + mask[ordered_token.index(t)] = True + return mask + + def get_ranks(self, token): + """Get rank group by input token. + + Arguments: + token (str): + Specify the ranks type that want to get. If we want + to obtain multiple parallel types, we can use a hyphen + '-' to separate them. For example, if we want to obtain + the TP_DP group, the token should be 'tp-dp'. + + """ + mask = self.get_mask(self.order, token) + ranks = generate_masked_orthogonal_rank_groups( + self.world_size, self.ordered_size, mask + ) + if self.rank_offset > 0: + for rank_group in ranks: + for i in range(len(rank_group)): + rank_group[i] += self.rank_offset + return ranks diff --git a/sglang/python/sglang/multimodal_gen/runtime/utils/hf_diffusers_utils.py b/sglang/python/sglang/multimodal_gen/runtime/utils/hf_diffusers_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..63859a407faf47b6ad310e50c3b328ac99fd4264 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/utils/hf_diffusers_utils.py @@ -0,0 +1,953 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from SGLang: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/hf_transformers_utils.py + +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for Huggingface Transformers.""" + +import contextlib +import glob +import json +import os +import shutil +import time +from functools import reduce +from pathlib import Path +from typing import Any, Dict, List, Optional, Union, cast + +from diffusers.loaders.lora_base import ( + _best_guess_weight_name, # watch out for potetential removal from diffusers +) +from huggingface_hub.errors import ( + LocalEntryNotFoundError, + RepositoryNotFoundError, + RevisionNotFoundError, +) +from requests.exceptions import ConnectionError as RequestsConnectionError +from requests.exceptions import RequestException +from safetensors import safe_open +from transformers import AutoConfig, PretrainedConfig + +from sglang.multimodal_gen.runtime.layers.quantization import ( + QuantizationConfig, + get_quantization_config, +) +from sglang.multimodal_gen.runtime.loader.utils import _clean_hf_config_inplace +from sglang.multimodal_gen.runtime.loader.weight_utils import get_lock +from sglang.multimodal_gen.runtime.platforms import current_platform +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.srt.environ import envs +from sglang.utils import is_in_ci + +logger = init_logger(__name__) + + +def _check_index_files_for_missing_shards( + model_path: str, +) -> tuple[bool, list[str], list[str]]: + """ + Check all subdirectories for missing shards based on index files. 
+ + This catches cases where a model download was interrupted, leaving + some safetensors shards missing while the index file exists. + + Args: + model_path: Path to the model directory + + Returns: + Tuple of (all_valid, missing_files, checked_subdirs) + """ + missing_files = [] + checked_subdirs = [] + + # Add common subdirectories for diffusers models + try: + subdirs = os.listdir(model_path) + except OSError as e: + logger.warning("Failed to list model directory %s: %s", model_path, e) + return True, [], [] # Assume valid if we can't check + + # Check the root directory and all subdirectories that might contain model weights + dirs_to_check = [model_path] + + for subdir in subdirs: + subdir_path = os.path.join(model_path, subdir) + if os.path.isdir(subdir_path): + dirs_to_check.append(subdir_path) + + for dir_path in dirs_to_check: + # Find all safetensors index files + index_files = glob.glob(os.path.join(dir_path, "*.safetensors.index.json")) + + for index_file in index_files: + checked_subdirs.append(os.path.basename(dir_path)) + try: + with open(index_file) as f: + index_data = json.load(f) + + weight_map = index_data.get("weight_map", {}) + if not weight_map: + continue + + # Get unique files referenced in weight_map + required_files = set(weight_map.values()) + + for file_name in required_files: + file_path = os.path.join(dir_path, file_name) + if not os.path.exists(file_path): + relative_path = os.path.relpath(file_path, model_path) + missing_files.append(relative_path) + + except Exception as e: + logger.warning("Failed to read index file %s: %s", index_file, e) + continue + + return len(missing_files) == 0, missing_files, checked_subdirs + + +def _cleanup_model_cache(model_path: str, reason: str) -> bool: + """ + Remove the model cache directory to force a clean re-download. 
+ + Args: + model_path: Path to the model directory (snapshot path) + reason: Reason for cleanup (for logging) + + Returns: + True if cleanup was performed, False otherwise + """ + # Navigate up to the model root directory: snapshots/hash -> snapshots -> model_root + # HF cache structure: models--org--name/snapshots/hash/ + try: + snapshot_dir = os.path.abspath(model_path) + snapshots_dir = os.path.dirname(snapshot_dir) + repo_folder = os.path.dirname(snapshots_dir) + + # Verify this looks like an HF cache structure + if os.path.basename(snapshots_dir) != "snapshots": + logger.warning( + "Model path %s doesn't appear to be in HF cache structure, skipping cleanup", + model_path, + ) + return False + + logger.warning( + "Removing model cache at %s. Reason: %s", + repo_folder, + reason, + ) + shutil.rmtree(repo_folder) + logger.info("Successfully removed corrupted cache directory") + return True + except Exception as e: + logger.error( + "Failed to remove corrupted cache directory %s: %s. " + "Manual cleanup may be required.", + model_path, + e, + ) + return False + + +def _ci_validate_diffusers_model(model_path: str) -> tuple[bool, bool]: + """ + CI-specific validation for diffusers models. + + Checks all subdirectories (transformer, transformer_2, vae, etc.) for + missing shards based on their index files. If issues are found in CI, + cleans up the cache to force re-download. + + Args: + model_path: Path to the model directory + + Returns: + Tuple of (is_valid, cleanup_performed) + - is_valid: True if the model is valid + - cleanup_performed: True if cleanup was performed (only relevant when is_valid=False) + """ + if not is_in_ci(): + return True, False + is_valid, missing_files, checked_subdirs = _check_index_files_for_missing_shards( + model_path + ) + + if not is_valid: + logger.error( + "CI validation failed for %s. Missing %d file(s): %s. 
" + "Checked subdirectories: %s", + model_path, + len(missing_files), + missing_files[:5] if len(missing_files) > 5 else missing_files, + checked_subdirs, + ) + cleanup_performed = _cleanup_model_cache( + model_path, + f"Missing {len(missing_files)} shard file(s): {missing_files[:3]}", + ) + return False, cleanup_performed + + if checked_subdirs: + logger.info( + "CI validation passed for %s. Checked subdirectories: %s", + model_path, + checked_subdirs, + ) + + return True, False + + +def _verify_diffusers_model_complete(path: str) -> bool: + """Check if a diffusers model directory has all required component subdirectories.""" + config_path = os.path.join(path, "model_index.json") + if not os.path.exists(config_path): + return False + + try: + with open(config_path) as config_file: + model_index = json.load(config_file) + except Exception as exc: + logger.warning("Failed to read model_index.json at %s: %s", config_path, exc) + return False + + component_keys = [ + key + for key, value in model_index.items() + if isinstance(value, (list, tuple)) + and len(value) == 2 + and all(isinstance(item, str) for item in value) + ] + if component_keys: + return all(os.path.exists(os.path.join(path, key)) for key in component_keys) + + return os.path.exists(os.path.join(path, "transformer")) and os.path.exists( + os.path.join(path, "vae") + ) + + +_CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = { + # ChatGLMConfig.model_type: ChatGLMConfig, + # DbrxConfig.model_type: DbrxConfig, + # ExaoneConfig.model_type: ExaoneConfig, + # Qwen2_5_VLConfig.model_type: Qwen2_5_VLConfig, +} + +for name, cls in _CONFIG_REGISTRY.items(): + with contextlib.suppress(ValueError): + AutoConfig.register(name, cls) + + +def download_from_hf(model_path: str): + if os.path.exists(model_path): + return model_path + + return snapshot_download(model_path, allow_patterns=["*.json", "*.bin", "*.model"]) + + +def get_hf_config( + component_model_path: str, + trust_remote_code: bool, + revision: str | 
None = None, + model_override_args: dict | None = None, + **kwargs, +) -> PretrainedConfig: + if check_gguf_file(component_model_path): + raise NotImplementedError("GGUF models are not supported.") + + config = AutoConfig.from_pretrained( + component_model_path, + trust_remote_code=trust_remote_code, + revision=revision, + **kwargs, + ) + if config.model_type in _CONFIG_REGISTRY: + config_class = _CONFIG_REGISTRY[config.model_type] + config = config_class.from_pretrained(component_model_path, revision=revision) + # NOTE(HandH1998): Qwen2VL requires `_name_or_path` attribute in `config`. + config._name_or_path = component_model_path + if model_override_args: + config.update(model_override_args) + + return config + + +def get_config( + model: str, + trust_remote_code: bool, + revision: Optional[str] = None, + model_override_args: Optional[dict] = None, + **kwargs, +): + return AutoConfig.from_pretrained( + model, trust_remote_code=trust_remote_code, revision=revision, **kwargs + ) + + +def load_dict(file_path): + if not os.path.exists(file_path): + return {} + try: + # Load the config directly from the file + with open(file_path) as f: + config_dict: dict[str, Any] = json.load(f) + if "_diffusers_version" in config_dict: + config_dict.pop("_diffusers_version") + # TODO(will): apply any overrides from inference args + return config_dict + except Exception as e: + raise RuntimeError( + f"Failed to load diffusers config from {file_path}: {e}" + ) from e + + +def get_diffusers_component_config( + component_path: str, +) -> dict[str, Any]: + """Gets a configuration of a submodule for the given diffusers model.""" + # Download from HuggingFace Hub if path doesn't exist locally + if not os.path.exists(component_path): + component_path = maybe_download_model(component_path) + + config_names = ["generation_config.json"] + # By default, we load config.json, but scheduler_config.json for scheduler + if "scheduler" in component_path: + 
config_names.append("scheduler_config.json") + else: + config_names.append("config.json") + + config_file_paths = [ + os.path.join(component_path, config_name) for config_name in config_names + ] + + combined_config = reduce( + lambda acc, path: acc | load_dict(path), config_file_paths, {} + ) + + _clean_hf_config_inplace(combined_config) + + logger.debug("HF model config: %s", combined_config) + + return combined_config + + +def replace_prefix(key: str, prefix_mapping: dict[str, str]) -> str: + for prefix, new_prefix in prefix_mapping.items(): + if key.startswith(prefix): + key = key.replace(prefix, new_prefix, 1) + return key + + +def get_quant_config( + model_config, + packed_modules_mapping: Dict[str, List[str]] = {}, + remap_prefix: Dict[str, str] | None = None, +) -> QuantizationConfig: + if "quantization_config" not in model_config: + return None + quant_cls = get_quantization_config( + model_config["quantization_config"]["quant_method"] + ) + + # GGUF doesn't have config file + if model_config["quantization_config"]["quant_method"] == "gguf": + return quant_cls.from_config({}) + + # Read the quantization config from the HF model config, if available. + hf_quant_config = model_config["quantization_config"] + # some vision model may keep quantization_config in their text_config + hf_text_config = getattr(model_config, "text_config", None) + if hf_quant_config is None and hf_text_config is not None: + hf_quant_config = getattr(hf_text_config, "quantization_config", None) + if hf_quant_config is None: + # compressed-tensors uses a compressions_config + hf_quant_config = getattr(model_config, "compression_config", None) + if hf_quant_config is not None: + hf_quant_config["packed_modules_mapping"] = packed_modules_mapping + return quant_cls.from_config(hf_quant_config) + # In case of bitsandbytes/QLoRA, get quant config from the adapter model. 
+ else: + model_name_or_path = model_config["model_path"] + is_local = os.path.isdir(model_name_or_path) + hf_folder = model_name_or_path + + possible_config_filenames = quant_cls.get_config_filenames() + + # If the quantization config is not found, use the default config. + if not possible_config_filenames: + return quant_cls() + + config_files = glob.glob(os.path.join(hf_folder, "*.json")) + + quant_config_files = [ + f for f in config_files if any(f.endswith(x) for x in possible_config_filenames) + ] + if len(quant_config_files) == 0: + raise ValueError( + f"Cannot find the config file for {model_config['quantization_config']['quant_method']}" + ) + if len(quant_config_files) > 1: + raise ValueError( + f"Found multiple config files for {model_config['quantization_config']['quant_method']}: " + f"{quant_config_files}" + ) + + quant_config_file = quant_config_files[0] + with open(quant_config_file) as f: + config = json.load(f) + if remap_prefix is not None: + exclude_modules = [ + replace_prefix(key, remap_prefix) + for key in config["quantization"]["exclude_modules"] + ] + config["quantization"]["exclude_modules"] = exclude_modules + config["packed_modules_mapping"] = packed_modules_mapping + return quant_cls.from_config(config) + + +# Models don't use the same configuration key for determining the maximum +# context length. Store them here so we can sanely check them. +# NOTE: The ordering here is important. Some models have two of these and we +# have a preference for which value gets used. +CONTEXT_LENGTH_KEYS = [ + "max_sequence_length", + "seq_length", + "max_seq_len", + "model_max_length", + "max_position_embeddings", +] + + +def attach_additional_stop_token_ids(tokenizer): + # Special handling for stop token <|eom_id|> generated by llama 3 tool use. 
+ if "<|eom_id|>" in tokenizer.get_added_vocab(): + tokenizer.additional_stop_token_ids = { + tokenizer.get_added_vocab()["<|eom_id|>"] + } + else: + tokenizer.additional_stop_token_ids = None + + +def check_gguf_file(model: str | os.PathLike) -> bool: + """Check if the file is a GGUF model.""" + model = Path(model) + if not model.is_file(): + return False + elif model.suffix == ".gguf": + return True + + with open(model, "rb") as f: + header = f.read(4) + return header == b"GGUF" + + +def maybe_download_lora( + model_name_or_path: str, local_dir: str | None = None, download: bool = True +) -> str: + """ + Check if the model path is a Hugging Face Hub model ID and download it if needed. + Args: + model_name_or_path: Local path or Hugging Face Hub model ID + local_dir: Local directory to save the model + download: Whether to download the model from Hugging Face Hub + + Returns: + Local path to the model + """ + allow_patterns = ["*.json", "*.safetensors", "*.bin"] + + local_path = maybe_download_model( + model_name_or_path, + local_dir, + download, + is_lora=True, + allow_patterns=allow_patterns, + ) + # return directly if local_path is a file + if os.path.isfile(local_path): + return local_path + + weight_name = _best_guess_weight_name(local_path, file_extension=".safetensors") + # AMD workaround: PR 15813 changed from model_name_or_path to local_path, + # which can return None. Fall back to original behavior on ROCm. + if weight_name is None and current_platform.is_rocm(): + weight_name = _best_guess_weight_name( + model_name_or_path, file_extension=".safetensors" + ) + return os.path.join(local_path, weight_name) + + +def verify_model_config_and_directory(model_path: str) -> dict[str, Any]: + """ + Verify that the model directory contains a valid diffusers configuration. 
+ + Args: + model_path: Path to the model directory + + Returns: + The loaded model configuration as a dictionary + """ + + # Check for model_index.json which is required for diffusers models + config_path = os.path.join(model_path, "model_index.json") + if not os.path.exists(config_path): + raise ValueError( + f"Model directory {model_path} does not contain model_index.json. " + "Only HuggingFace diffusers format is supported." + ) + + # Load the config + with open(config_path) as f: + config = json.load(f) + + # Verify diffusers version exists + if "_diffusers_version" not in config: + raise ValueError("model_index.json does not contain _diffusers_version") + + logger.info("Diffusers version: %s", config["_diffusers_version"]) + + component_keys = [ + key + for key, value in config.items() + if isinstance(value, (list, tuple)) + and len(value) == 2 + and all(isinstance(item, str) for item in value) + ] + if component_keys: + missing_components = [ + component_key + for component_key in component_keys + if not os.path.exists(os.path.join(model_path, component_key)) + ] + if missing_components: + missing_str = ", ".join(missing_components) + raise ValueError( + f"Model directory {model_path} is missing required component " + f"directories: {missing_str}." + ) + else: + transformer_dir = os.path.join(model_path, "transformer") + vae_dir = os.path.join(model_path, "vae") + if not os.path.exists(transformer_dir): + raise ValueError( + f"Model directory {model_path} does not contain a transformer/ directory." + ) + if not os.path.exists(vae_dir): + raise ValueError( + f"Model directory {model_path} does not contain a vae/ directory." + ) + return cast(dict[str, Any], config) + + +def maybe_download_model_index(model_name_or_path: str) -> dict[str, Any]: + """ + Download and extract just the model_index.json for a Hugging Face model. 
+ + Args: + model_name_or_path: Path or HF Hub model ID + + Returns: + The parsed model_index.json as a dictionary + """ + import tempfile + + from huggingface_hub.errors import EntryNotFoundError + + # If it's a local path, verify it directly + if os.path.exists(model_name_or_path): + try: + return verify_model_config_and_directory(model_name_or_path) + except ValueError: + # Not a pipeline, maybe a single model. + config_path = os.path.join(model_name_or_path, "config.json") + if os.path.exists(config_path): + with open(config_path) as f: + config = json.load(f) + return config + raise + + # For remote models, download just the model_index.json + try: + with tempfile.TemporaryDirectory() as tmp_dir: + # Download just the model_index.json file + model_index_path = hf_hub_download( + repo_id=model_name_or_path, + filename="model_index.json", + local_dir=tmp_dir, + ) + + # Load the model_index.json + with open(model_index_path) as f: + config: dict[str, Any] = json.load(f) + + # Verify it has the required fields + if "_class_name" not in config: + raise ValueError( + f"model_index.json for {model_name_or_path} does not contain _class_name field" + ) + + if "_diffusers_version" not in config: + raise ValueError( + f"model_index.json for {model_name_or_path} does not contain _diffusers_version field" + ) + + # Add the pipeline name for downstream use + config["pipeline_name"] = config["_class_name"] + + logger.debug( + "Downloaded model_index.json for %s, pipeline: %s", + model_name_or_path, + config["_class_name"], + ) + return config + except EntryNotFoundError: + logger.warning( + "model_index.json not found for %s. 
Assuming it is a single model and downloading it.", + model_name_or_path, + ) + local_path = maybe_download_model(model_name_or_path) + config_path = os.path.join(local_path, "config.json") + if not os.path.exists(config_path): + raise ValueError( + f"Failed to find config.json for {model_name_or_path} after failing to find model_index.json" + f"You might be looking for models ending with '-Diffusers'" + ) + with open(config_path) as f: + config = json.load(f) + return config + except Exception as e: + raise ValueError( + f"Failed to download or parse model_index.json for {model_name_or_path}: {e}" + ) from e + + +def maybe_download_model( + model_name_or_path: str, + local_dir: str | None = None, + download: bool = True, + is_lora: bool = False, + allow_patterns: list[str] | None = None, + force_diffusers_model: bool = False, +) -> str: + """ + Check if the model path is a Hugging Face Hub model ID and download it if needed. + + Args: + model_name_or_path: Local path or Hugging Face Hub model ID + local_dir: Local directory to save the model + download: Whether to download the model from Hugging Face Hub + is_lora: If True, skip model completeness verification (LoRA models don't have transformer/vae directories) + force_diffusers_model: If True, apply diffusers model check. Otherwise it should be a component model + Returns: + Local path to the model + """ + + # 1. 
Local path check: if path exists locally, verify it's complete (skip for LoRA) + if os.path.exists(model_name_or_path): + if not force_diffusers_model: + return model_name_or_path + if is_lora or _verify_diffusers_model_complete(model_name_or_path): + if not is_lora: + is_valid, cleanup_performed = _ci_validate_diffusers_model( + model_name_or_path + ) + if not is_valid: + if cleanup_performed: + logger.warning( + "CI validation failed for local model at %s, " + "cache has been cleaned up, will re-download", + model_name_or_path, + ) + # Fall through to download + else: + raise ValueError( + f"CI validation failed for local model at {model_name_or_path}. " + "Some safetensors shards are missing. " + "Please manually delete the model directory and retry." + ) + else: + logger.info("Model already exists locally and is complete") + return model_name_or_path + else: + logger.info("Model already exists locally and is complete") + return model_name_or_path + else: + logger.warning( + "Local model at %s appears incomplete (missing required components), " + "will attempt re-download", + model_name_or_path, + ) + + # 2. 
Cache-first strategy (Fast Path) + # Try to read from HF cache without network access + try: + logger.info( + "Checking for cached model in HF Hub cache for %s...", model_name_or_path + ) + local_path = snapshot_download( + repo_id=model_name_or_path, + ignore_patterns=["*.onnx", "*.msgpack"], + local_dir=local_dir, + local_files_only=True, + max_workers=8, + ) + if not force_diffusers_model: + return str(local_path) + if is_lora or _verify_diffusers_model_complete(local_path): + if not is_lora: + is_valid, cleanup_performed = _ci_validate_diffusers_model(local_path) + if not is_valid: + logger.warning( + "CI validation failed for cached model at %s, " + "%s, will re-download", + local_path, + ( + "cache has been cleaned up" + if cleanup_performed + else "cleanup was not performed" + ), + ) + # Fall through to download + else: + logger.info("Found complete model in cache at %s", local_path) + return str(local_path) + else: + logger.info("Found complete model in cache at %s", local_path) + return str(local_path) + else: + if not download: + raise ValueError( + f"Model {model_name_or_path} found in cache but is incomplete and download=False." + ) + logger.info( + "Model found in cache but incomplete, will download from HF Hub" + ) + except LocalEntryNotFoundError: + if not download: + raise ValueError( + f"Model {model_name_or_path} not found in local cache and download=False." + ) + logger.info("Model not found in cache, will download from HF Hub") + except Exception as e: + logger.warning( + "Unexpected error while checking cache for %s: %s, will attempt download", + model_name_or_path, + e, + ) + if not download: + raise ValueError( + f"Error checking cache for {model_name_or_path} and download=False: {e}" + ) from e + + # 3. 
Download strategy (with retry mechanism) + MAX_RETRIES = 5 + for attempt in range(MAX_RETRIES): + try: + logger.info( + "Downloading model snapshot from HF Hub for %s (attempt %d/%d)...", + model_name_or_path, + attempt + 1, + MAX_RETRIES, + ) + with get_lock(model_name_or_path).acquire(poll_interval=2): + local_path = snapshot_download( + repo_id=model_name_or_path, + ignore_patterns=["*.onnx", "*.msgpack"], + allow_patterns=allow_patterns, + local_dir=local_dir, + max_workers=8, + ) + + if not force_diffusers_model: + return str(local_path) + # Verify downloaded model is complete (skip for LoRA) + elif not is_lora and not _verify_diffusers_model_complete(local_path): + logger.warning( + "Downloaded model at %s is incomplete, retrying with force_download=True", + local_path, + ) + with get_lock(model_name_or_path).acquire(poll_interval=2): + local_path = snapshot_download( + repo_id=model_name_or_path, + ignore_patterns=["*.onnx", "*.msgpack"], + local_dir=local_dir, + max_workers=8, + force_download=True, + ) + if not _verify_diffusers_model_complete(local_path): + raise ValueError( + f"Downloaded model at {local_path} is still incomplete after forced re-download. " + "The model repository may be missing required components (model_index.json, transformer/, or vae/)." + ) + + # CI validation: check all subdirectories for missing shards after download + if not is_lora: + is_valid, cleanup_performed = _ci_validate_diffusers_model(local_path) + if not is_valid: + # In CI, if validation fails after download, we have a serious issue + # If cleanup was performed, the next retry should get a fresh download + raise ValueError( + f"CI validation failed for downloaded model at {local_path}. " + f"Some safetensors shards are missing. Cleanup performed: {cleanup_performed}." 
def hf_hub_download(
    repo_id: str,
    filename: str,
    local_dir: Optional[Union[str, Path]] = None,
    **kwargs,
) -> str:
    """Download one file of a repository from the configured model hub.

    The backend is chosen at call time from ``envs.SGLANG_USE_MODELSCOPE``:
    ``huggingface_hub.hf_hub_download`` when the flag is off, ModelScope's
    ``model_file_download`` when it is on.

    Args:
        repo_id: Repository identifier (sent to ModelScope as ``model_id``).
        filename: File inside the repository (sent to ModelScope as
            ``file_path``).
        local_dir: Optional target directory. Forwarded as ``local_dir`` to
            Hugging Face and as ``cache_dir`` to ModelScope.
        **kwargs: Forwarded unchanged to whichever backend is selected.

    Returns:
        The value returned by the selected backend.
    """
    use_modelscope = envs.SGLANG_USE_MODELSCOPE.get()

    if not use_modelscope:
        # Imported lazily so that only the backend in use has to be installed.
        from huggingface_hub import hf_hub_download as _hf_hub_download

        return _hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            local_dir=local_dir,
            **kwargs,
        )

    from modelscope import model_file_download

    return model_file_download(
        model_id=repo_id,
        file_path=filename,
        cache_dir=local_dir,
        **kwargs,
    )
def get_quant_config_from_safetensors_metadata(
    file_path: str,
) -> Optional[QuantizationConfig]:
    """Extract quantization config from a safetensors file's metadata header.

    The ``_quantization_metadata`` entry of the header is parsed as JSON. When
    the parsed object has no ``quant_method`` but carries ``format_version``
    and ``layers`` keys, and at least one layer entry has a ``format`` string
    containing ``"float8"``, it is treated as fp8 with a dynamic activation
    scheme.

    Args:
        file_path: Path of the safetensors file to inspect.

    Returns:
        The config built by the class that ``get_quantization_config`` returns
        for the detected method, or ``None`` when the metadata is missing,
        malformed, not a JSON object, or cannot be turned into a config.
    """
    metadata = get_metadata_from_safetensors_file(file_path)
    if not metadata:
        return None

    quant_config_str = metadata.get("_quantization_metadata")
    if not quant_config_str:
        return None
    try:
        quant_config_dict = json.loads(quant_config_str)
    except Exception as e:
        # Best effort: unreadable metadata simply means "no quantization info".
        logger.debug(
            "Ignoring unparsable quantization metadata in %s: %s", file_path, e
        )
        return None

    # Valid JSON is not necessarily an object (it may be a list, a string, a
    # number, ...). Everything below relies on dict semantics, so bail out here
    # instead of failing with AttributeError/TypeError further down.
    if not isinstance(quant_config_dict, dict):
        logger.debug(
            "Ignoring quantization metadata in %s: expected a JSON object, got %s",
            file_path,
            type(quant_config_dict).__name__,
        )
        return None

    # handle diffusers fp8 safetensors metadata format
    if (
        "quant_method" not in quant_config_dict
        and "format_version" in quant_config_dict
        and "layers" in quant_config_dict
    ):
        layers = quant_config_dict.get("layers")
        # A non-dict "layers" value carries no per-layer entries to inspect.
        layer_entries = layers.values() if isinstance(layers, dict) else ()

        def _is_float8_entry(entry) -> bool:
            if not isinstance(entry, dict):
                return False
            fmt = entry.get("format", "")
            return isinstance(fmt, str) and "float8" in fmt

        if any(_is_float8_entry(v) for v in layer_entries):
            quant_config_dict["quant_method"] = "fp8"
            quant_config_dict["activation_scheme"] = "dynamic"

    quant_method = quant_config_dict.get("quant_method")
    if not quant_method:
        return None

    try:
        quant_cls = get_quantization_config(quant_method)
        config = quant_cls.from_config(quant_config_dict)
        logger.debug(f"Get quantization config from safetensors file: {file_path}")
        return config
    except Exception as e:
        # Unknown method or incompatible config: treat as "no usable config",
        # but leave a trace for debugging.
        logger.debug(
            "Could not build quantization config %r from %s: %s",
            quant_method,
            file_path,
            e,
        )
        return None
+ + This utility offloads per-layer parameters/buffers from GPU to CPU, and + supports async H2D prefetch using a dedicated CUDA stream. + + Typical usage: + - Construct the manager with the target model and the list-like module + attribute that represents transformer blocks (e.g. ``blocks``). + - Call :meth:`initialize` once to offload weights and prefetch layer 0. + - During forward, call :meth:`prefetch_layer` for the next layer and + :meth:`release_layer` for the finished layer. + """ + + def __init__( + self, + model: torch.nn.Module, + *, + layers_attr_str: str, + num_layers: int, + enabled: bool, + pin_cpu_memory: bool = True, + prefetch_size: int = 1, + ) -> None: + self.model = model + self.layers_attr_str = layers_attr_str + self.num_layers = num_layers + self.pin_cpu_memory = pin_cpu_memory + self.prefetch_size = min(max(1, prefetch_size), self.num_layers) + self.enabled = bool(enabled and torch.get_device_module().is_available()) + if not self.enabled: + return + self.device = torch.device( + current_platform.device_type, torch.get_device_module().current_device() + ) + self.copy_stream = torch.get_device_module().Stream() + + self._layer_name_re = re.compile( + rf"(^|\.){re.escape(layers_attr_str)}\.(\d+)(\.|$)" + ) + + # layer_idx -> {dtype: consolidated_pinned_cpu_tensor} + # stores the consolidated weight from a same layer, of same dtype + self._consolidated_cpu_weights: Dict[int, Dict[torch.dtype, torch.Tensor]] = {} + # layer_idx -> {name: {dtype, offset, numel, shape}} + # stores the offset and numel of each weight from a same layer, of same dtype + self._weight_metadata: Dict[int, Dict[str, Dict[str, Any]]] = {} + # layer indices that are already in gpu + self._gpu_layers: Set[int] = set() + # layer_idx -> torch.get_device_module().Event for fine-grained sync, to make sure the weight is resident in pre-hook + self._prefetch_events: Dict[int, torch.get_device_module().Event] = {} + + self._named_parameters: Dict[str, torch.nn.Parameter] = {} + 
self._named_buffers: Dict[str, torch.Tensor] = {} + # Store forward hooks for removal + self._forward_hooks: List[Any] = [] + + self._initialize() + + def _match_layer_idx(self, name: str) -> int | None: + m = self._layer_name_re.search(name) + if not m: + return None + try: + return int(m.group(2)) + except Exception: + return None + + @torch.compiler.disable + def _initialize(self) -> None: + if not self.enabled: + return + + self._named_parameters = dict(self.model.named_parameters()) + self._named_buffers = dict(self.model.named_buffers()) + + # 1. collect and group tensors by layer and dtype + layer_groups: Dict[int, Dict[torch.dtype, List[Tuple[str, torch.Tensor]]]] = {} + all_tensors = chain(self._named_parameters.items(), self._named_buffers.items()) + for name, tensor in all_tensors: + layer_idx = self._match_layer_idx(name) + if layer_idx is None or layer_idx >= self.num_layers: + continue + layer_groups.setdefault(layer_idx, {}).setdefault(tensor.dtype, []).append( + (name, tensor) + ) + + # 2. 
@torch.compiler.disable
def prefetch_layer(self, layer_idx: int, non_blocking: bool = True) -> None:
    """Copy one layer's consolidated CPU weights to the device.

    For every dtype buffer of ``layer_idx`` a device buffer of the same shape
    is allocated and filled on ``self.copy_stream``; an event recorded on that
    stream is stored in ``self._prefetch_events[layer_idx]``; finally each
    tensor of the layer is re-pointed (``.data``) to its slice of the matching
    device buffer and the layer is added to ``self._gpu_layers``.

    The call is idempotent: it returns without doing anything when the manager
    is disabled, when ``layer_idx`` is out of range, when the layer is already
    listed in ``self._gpu_layers``, or when it has no consolidated CPU buffers.

    Args:
        layer_idx: Index of the layer to load.
        non_blocking: Forwarded to ``Tensor.copy_`` for the host-to-device copy.
    """
    if not self.enabled or self.device is None or self.copy_stream is None:
        return
    if layer_idx < 0 or layer_idx >= self.num_layers:
        return
    if layer_idx in self._gpu_layers:
        return
    if layer_idx not in self._consolidated_cpu_weights:
        return
    # Make the copy stream wait for everything already queued on the current
    # stream before it starts allocating/copying.
    # NOTE(review): presumably this keeps the copy from overtaking compute that
    # still uses recently released buffers - confirm.
    self.copy_stream.wait_stream(torch.get_device_module().current_stream())

    # create gpu buffer and load from CPU buffer
    gpu_buffers: Dict[torch.dtype, torch.Tensor] = {}
    with torch.get_device_module().stream(self.copy_stream):
        for dtype, cpu_buffer in self._consolidated_cpu_weights[layer_idx].items():
            gpu_buffer = torch.empty(
                cpu_buffer.shape, dtype=dtype, device=self.device
            )
            gpu_buffer.copy_(cpu_buffer, non_blocking=non_blocking)
            gpu_buffers[dtype] = gpu_buffer

    # record the prefetch event of this layer
    event = torch.get_device_module().Event()
    event.record(self.copy_stream)
    self._prefetch_events[layer_idx] = event

    # restore model's weights by their metadata using gpu buffer
    for name, meta in self._weight_metadata[layer_idx].items():
        dtype = meta["dtype"]
        gpu_buffer = gpu_buffers[dtype]

        # map the parameter's data to the correct slice of the GPU buffer
        target = self.get_target_with_name(name)
        target.data = gpu_buffer[
            meta["offset"] : meta["offset"] + meta["numel"]
        ].view(meta["shape"])

    # Mark the layer as resident so that repeated calls return early.
    self._gpu_layers.add(layer_idx)
The weights on cpu is untouched + """ + if not self.enabled or self.device is None: + return + + # clear prefetch event, since it's useless and needs to be reset + self._prefetch_events.pop(layer_idx, None) + + if layer_idx <= 0: + return + + if layer_idx not in self._gpu_layers: + return + + for name, meta in self._weight_metadata.get(layer_idx, {}).items(): + target = self.get_target_with_name(name) + target.data = torch.empty((1,), device=self.device, dtype=meta["dtype"]) + + self._gpu_layers.discard(layer_idx) + + @torch.compiler.disable + def release_all(self) -> None: + if not self.enabled or self.device is None: + return + if self.copy_stream is not None: + torch.get_device_module().current_stream().wait_stream(self.copy_stream) + + for layer_idx in list(self._gpu_layers): + self.release_layer(layer_idx) + + @torch.compiler.disable + def load_all_layers(self) -> None: + """Load all layers from CPU to GPU.""" + if not self.enabled or self.device is None: + return + if self.copy_stream is not None: + torch.get_device_module().current_stream().wait_stream(self.copy_stream) + + for layer_idx in range(self.num_layers): + if layer_idx not in self._gpu_layers: + self.prefetch_layer(layer_idx, non_blocking=False) + + @torch.compiler.disable + def sync_layer_to_cpu(self, layer_idx: int) -> None: + """Sync a layer's weights from GPU back to CPU.""" + if not self.enabled or layer_idx not in self._gpu_layers: + return + if layer_idx not in self._consolidated_cpu_weights: + return + + if self.copy_stream is not None: + torch.get_device_module().current_stream().wait_stream(self.copy_stream) + + # Collect current GPU weights and write back to CPU buffer + for name, meta in self._weight_metadata.get(layer_idx, {}).items(): + target = self.get_target_with_name(name) + gpu_weight = target.data.flatten().cpu() + + dtype = meta["dtype"] + cpu_buffer = self._consolidated_cpu_weights[layer_idx][dtype] + offset = meta["offset"] + numel = meta["numel"] + cpu_buffer[offset : offset 
+ numel].copy_(gpu_weight) + + @torch.compiler.disable + def sync_all_layers_to_cpu(self) -> None: + """Sync all loaded layers' weights from GPU back to CPU.""" + if not self.enabled or self.device is None: + return + if self.copy_stream is not None: + torch.get_device_module().current_stream().wait_stream(self.copy_stream) + + for layer_idx in list(self._gpu_layers): + self.sync_layer_to_cpu(layer_idx) + + @torch.compiler.disable + def update_cpu_weights( + self, weight_dict: Dict[str, torch.Tensor] + ) -> Set[str] | None: + """Update consolidated CPU buffers with new weights. + + When layerwise offload (--dit-layerwise-offload) is enabled, the + offload manager replaces GPU parameters with small torch.empty((1,)) + placeholders while real weights live in consolidated pinned CPU + buffers. + + The refit process writes new weights directly into the CPU buffers, + bypassing the placeholders. For any layer that happens to be resident + on the GPU at update time, the live GPU tensor is also updated. + + Args: + weight_dict: Mapping of parameter name to new weight tensor. + + Returns: + Set of parameter names that were successfully updated. + + Raises: + ValueError: If a weight's shape does not match the recorded + metadata (i.e., the real shape, not the placeholder shape). 
+ """ + if not self.enabled: + return None + + updated_names: Set[str] = set() + for name, loaded_weight in weight_dict.items(): + layer_idx = self._match_layer_idx(name) + if layer_idx is None: + continue + meta_layer = self._weight_metadata.get(layer_idx) + if meta_layer is None or name not in meta_layer: + continue + + meta = meta_layer[name] + if tuple(meta["shape"]) != tuple(loaded_weight.shape): + raise ValueError( + f"Shape mismatch for {name}: " + f"expected={tuple(meta['shape'])}, " + f"loaded={tuple(loaded_weight.shape)}" + ) + + dtype = meta["dtype"] + offset = meta["offset"] + numel = meta["numel"] + cpu_buffer = self._consolidated_cpu_weights[layer_idx][dtype] + cpu_buffer[offset : offset + numel].copy_( + loaded_weight.to(dtype=dtype).flatten() + ) + + # If this layer is currently on GPU, update the live parameter. + if layer_idx in self._gpu_layers: + target = self.get_target_with_name(name) + target.data.copy_(loaded_weight.to(dtype=target.dtype)) + + updated_names.add(name) + + return updated_names + + def iter_cpu_weights(self): + """Yield (name, tensor) pairs from consolidated CPU buffers. + + This reconstructs the original weight tensors (with correct shapes) + from the flat CPU buffers using stored metadata. Unlike + model.named_parameters(), which returns (1,) placeholders + when offload is enabled, this method returns the real weights and + can be used for checksum computation. 
def register_forward_hooks(self) -> None:
    """Install the prefetch (pre) and release (post) hooks on every layer.

    The layer list is read from ``self.model`` via ``self.layers_attr_str``.
    Hooks installed by an earlier call are removed first, so the method is
    safe to call repeatedly: each layer ends up with exactly one pre-hook and
    one post-hook from this manager. Does nothing when the manager is
    disabled.
    """
    if not self.enabled:
        return

    layers = getattr(self.model, self.layers_attr_str)

    def make_pre_hook(i):
        def hook(module, input):
            # wait only for the current layer if it's being prefetched
            if i == 0:
                self.prepare_for_next_req(non_blocking=False)
            if i in self._prefetch_events:
                torch.get_device_module().current_stream().wait_event(
                    self._prefetch_events[i]
                )

            # trigger batch prefetch (i + prefetch_size ~ i + 2 * prefetch_size) if needed
            if i % self.prefetch_size == 0:
                for j in range(i + self.prefetch_size, i + 2 * self.prefetch_size):
                    # indices past the last layer wrap around to the first ones
                    layer_to_prefetch = j % self.num_layers
                    self.prefetch_layer(layer_to_prefetch, non_blocking=True)

        return hook

    def make_post_hook(i):
        def hook(module, input, output):
            # the layer has finished its forward pass: drop its device weights
            self.release_layer(i)

        return hook

    # Detach hooks from a previous registration before forgetting their
    # handles. Merely clearing the list would leave the old hooks attached
    # (and unremovable), so every layer would prefetch/release twice.
    for stale_handle in self._forward_hooks:
        stale_handle.remove()
    self._forward_hooks.clear()

    # register prefetch & release hooks for each layer
    for i, layer in enumerate(layers):
        pre_hook_handle = layer.register_forward_pre_hook(make_pre_hook(i))
        post_hook_handle = layer.register_forward_hook(make_post_hook(i))
        self._forward_hooks.extend([pre_hook_handle, post_hook_handle])
def iter_materialized_weights(module: torch.nn.Module):
    """Yield ``(name, tensor)`` pairs holding real weight data, offload or not.

    Without enabled offload managers this is just ``module.named_parameters()``.
    Otherwise every tensor exposed by the enabled managers' ``iter_cpu_weights()``
    is yielded first, followed by those entries of ``module.named_parameters()``
    whose names were not already produced by a manager.

    NOTE(review): the managers' CPU buffers appear to include buffers as well as
    parameters, so the offload path may yield names the plain path never does -
    confirm whether callers depend on that.
    """
    active_managers: list = []
    if isinstance(module, OffloadableDiTMixin):
        registered = module.layerwise_offload_managers
        if registered:
            active_managers = [mgr for mgr in registered if mgr.enabled]

    if len(active_managers) == 0:
        yield from module.named_parameters()
        return

    # First pass: real tensors reconstructed from the managers' CPU buffers.
    already_emitted: set[str] = set()
    for mgr in active_managers:
        for weight_name, cpu_tensor in mgr.iter_cpu_weights():
            already_emitted.add(weight_name)
            yield weight_name, cpu_tensor

    # Second pass: whatever the managers did not cover (e.g. final norms,
    # embeddings).
    for weight_name, parameter in module.named_parameters():
        if weight_name in already_emitted:
            continue
        yield weight_name, parameter
class SortedHelpFormatter(argparse.HelpFormatter):
    """Help formatter that lists arguments ordered by their option strings."""

    def add_arguments(self, actions):
        # Sort a copy (stable sort on the list of option strings) and hand it
        # to the stock implementation; the caller's sequence is left untouched.
        ordered_actions = list(actions)
        ordered_actions.sort(key=lambda action: action.option_strings)
        super().add_arguments(ordered_actions)
def get_is_main_process():
    """Return True when the ``RANK`` environment variable denotes rank 0.

    A missing or non-integer ``RANK`` is treated as rank 0.
    """
    raw_rank = os.environ.get("RANK")
    if raw_rank is None:
        return True
    try:
        return int(raw_rank) == 0
    except ValueError:
        # Unparsable value: fall back to rank 0.
        return True
def init_logger(name: str) -> _SGLDiffusionLogger:
    """Return the logger called ``name`` with process-aware methods patched in.

    ``info``, ``debug``, ``warning`` and ``error`` are replaced on the logger
    instance by wrappers that accept ``main_process_only`` and
    ``local_main_process_only`` and delegate to ``_log_process_aware``;
    ``info_once`` and ``warning_once`` are attached as well.

    The logger's effective level is looked up on every call rather than once
    here. Module-level loggers are typically created at import time, before
    the logging level is configured, so a level captured at creation would go
    stale and the "log on every process at DEBUG" rule in
    ``_log_process_aware`` would never apply.

    Args:
        name: Logger name, passed to ``logging.getLogger``.

    Returns:
        The patched logger, typed as ``_SGLDiffusionLogger``.
    """
    logger = logging.getLogger(name)

    # Patch instance methods
    setattr(logger, "info_once", MethodType(_print_info_once, logger))
    setattr(logger, "warning_once", MethodType(_print_warning_once, logger))

    def _create_patched_method(
        level: int,
        main_process_only_default: bool,
        local_main_process_only_default: bool,
    ):
        # NOTE: _method must call _log_process_aware directly; the latter uses
        # stacklevel=3 to report the original caller's location.
        def _method(
            self: Logger,
            msg: object,
            *args: Any,
            main_process_only: bool = main_process_only_default,
            local_main_process_only: bool = local_main_process_only_default,
            **kwargs: Any,
        ) -> None:
            _log_process_aware(
                # evaluated per call so later level changes are honoured
                self.getEffectiveLevel(),
                level,
                self,
                msg,
                *args,
                main_process_only=main_process_only,
                local_main_process_only=local_main_process_only,
                **kwargs,
            )

        return _method

    # (method name, level, main_process_only default, local_main_process_only default)
    patched_methods = (
        ("info", logging.INFO, True, True),
        ("debug", logging.DEBUG, True, True),
        ("warning", logging.WARNING, False, True),
        ("error", logging.ERROR, False, False),
    )
    for method_name, level, main_default, local_default in patched_methods:
        setattr(
            logger,
            method_name,
            MethodType(
                _create_patched_method(level, main_default, local_default), logger
            ),
        )

    return cast(_SGLDiffusionLogger, logger)
def suppress_loggers(loggers_to_suppress: list[str], level: int = logging.WARNING):
    """Set each named logger to ``level`` and report what it was set to before.

    Names are processed in order; the returned dict maps each given name to the
    level its logger had at the moment it was processed.
    """
    previous_levels = {}

    for name in loggers_to_suppress:
        target = logging.getLogger(name)
        previous_levels.update({name: target.level})
        target.setLevel(level)

    return previous_levels
# source: https://github.com/vllm-project/vllm/blob/a11f4a81e027efd9ef783b943489c222950ac989/vllm/utils/system_utils.py#L60
@contextlib.contextmanager
def suppress_stdout():
    """
    Suppress stdout from C libraries at the file descriptor level.

    Only suppresses stdout, not stderr, to preserve error messages.
    Suppression is unconditional: it does not depend on the logging level.

    If ``sys.stdout`` has no usable file descriptor (it was replaced by an
    in-memory stream, or is ``None``), there is nothing to redirect and the
    body runs without suppression instead of raising.

    Example:
        with suppress_stdout():
            # C library calls that would normally print to stdout
            torch.distributed.new_group(ranks, backend="gloo")
    """
    try:
        stdout_fd = sys.stdout.fileno()
    except (AttributeError, OSError, ValueError):
        # io.UnsupportedOperation derives from both OSError and ValueError;
        # AttributeError covers sys.stdout being None or a minimal stand-in.
        stdout_fd = None

    if stdout_fd is None:
        yield
        return

    # Keep a duplicate of the real stdout so it can be restored afterwards.
    stdout_dup = os.dup(stdout_fd)
    try:
        devnull_fd = os.open(os.devnull, os.O_WRONLY)
        try:
            # Flush Python-level buffers before swapping the descriptor so
            # earlier output still reaches the real stdout.
            sys.stdout.flush()
            os.dup2(devnull_fd, stdout_fd)
            yield
        finally:
            try:
                # Anything still buffered was written inside the block and is
                # discarded together with the C-level output.
                sys.stdout.flush()
            finally:
                # Restore the real stdout even if the flush above failed.
                os.dup2(stdout_dup, stdout_fd)
                os.close(devnull_fd)
    finally:
        os.close(stdout_dup)
def transform_pos(
    mtx: Union[np.ndarray, torch.Tensor],
    pos: torch.Tensor,
    keepdim: bool = False,
) -> torch.Tensor:
    """Transform positions by a matrix (row-vector convention: ``pos @ mtx.T``).

    Args:
        mtx: Transformation matrix, as a NumPy array or a tensor. It is moved
            to ``pos.device`` before use.
        pos: 2-D tensor of positions. When the last dimension is 3, a column
            of ones is appended to make the positions homogeneous; otherwise
            ``pos`` is used as given.
        keepdim: If True, return the product as is; if False, prepend a
            leading dimension of size 1.

    Returns:
        The transformed positions. ``mtx`` and ``pos`` are both converted to
        their promoted common dtype, so mixed precision inputs (for example a
        float32 matrix with float64 positions) work instead of raising.
    """
    t_mtx = torch.from_numpy(mtx) if isinstance(mtx, np.ndarray) else mtx
    common_dtype = torch.promote_types(pos.dtype, t_mtx.dtype)
    t_mtx = t_mtx.to(device=pos.device, dtype=common_dtype)

    if pos.shape[-1] == 3:
        # Build the homogeneous column directly on the target device/dtype
        # rather than allocating on the CPU and copying it over.
        ones = torch.ones((pos.shape[0], 1), dtype=common_dtype, device=pos.device)
        posw = torch.cat([pos.to(common_dtype), ones], dim=1)
    else:
        posw = pos.to(common_dtype)

    transformed = torch.matmul(posw, t_mtx.t())
    return transformed if keepdim else transformed[None, ...]
def get_mv_matrix(
    elev: float,
    azim: float,
    camera_distance: float,
    center: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Build the 4x4 world-to-camera (model-view) matrix for a camera pose.

    The elevation is negated and the azimuth is offset by 90 degrees before
    conversion to radians. The camera sits at ``camera_distance`` from the
    origin in that direction and looks towards ``center`` (the origin when
    ``center`` is None), using +Z as the world up direction.

    Returns:
        A ``float32`` array of shape (4, 4).
    """
    elev_rad = math.radians(-elev)
    azim_rad = math.radians(azim + 90)
    cos_e, sin_e = math.cos(elev_rad), math.sin(elev_rad)
    cos_a, sin_a = math.cos(azim_rad), math.sin(azim_rad)

    camera_position = np.array(
        [
            camera_distance * cos_e * cos_a,
            camera_distance * cos_e * sin_a,
            camera_distance * sin_e,
        ]
    )

    target = np.array([0, 0, 0]) if center is None else np.array(center)

    # Unit vector from the camera towards the look-at target.
    forward = target - camera_position
    forward = forward / np.linalg.norm(forward)

    # Orthonormal camera frame derived from the world +Z axis.
    world_up = np.array([0, 0, 1.0])
    right = np.cross(forward, world_up)
    right = right / np.linalg.norm(right)
    cam_up = np.cross(right, forward)
    cam_up = cam_up / np.linalg.norm(cam_up)

    # Camera-to-world as a 3x4 [R | t] block.
    c2w = np.concatenate(
        [np.stack([right, cam_up, -forward], axis=-1), camera_position[:, None]],
        axis=-1,
    )

    # Invert the rigid transform: R^T and -R^T t.
    rot_t = np.transpose(c2w[:3, :3], (1, 0))
    w2c = np.zeros((4, 4))
    w2c[:3, :3] = rot_t
    w2c[:3, 3:] = -np.matmul(rot_t, c2w[:3, 3:])
    w2c[3, 3] = 1.0

    return w2c.astype(np.float32)
def export_to_trimesh(mesh_output: Any) -> Any:
    """Convert mesh output (single object or list) into ``trimesh.Trimesh``.

    Each input object must expose ``mesh_v`` and ``mesh_f``. The column order
    of ``mesh_f`` is reversed (flipping face winding) and the reversed array
    is stored back on the input object before the Trimesh is built, so the
    input is modified by this call.

    Args:
        mesh_output: One mesh object, or a list of mesh objects in which
            ``None`` entries are allowed.

    Returns:
        A single ``trimesh.Trimesh`` for a non-list input; for a list input,
        a list of the same length with ``None`` kept in place.
    """

    def _to_trimesh(item):
        # Reverse face winding; the reversed view replaces item.mesh_f.
        item.mesh_f = item.mesh_f[:, ::-1]
        return trimesh.Trimesh(item.mesh_v, item.mesh_f)

    if not isinstance(mesh_output, list):
        return _to_trimesh(mesh_output)
    return [None if item is None else _to_trimesh(item) for item in mesh_output]
def stride_from_shape(shape: Tuple[int, ...]) -> List[int]:
    """Return row-major element strides for ``shape``.

    The last entry is always 1 and each earlier entry is the product of all
    later dimension sizes, e.g. ``(2, 3, 4) -> [12, 4, 1]``. An empty shape
    yields ``[1]``.
    """
    strides = [1] * max(len(shape), 1)
    # Fill from the second-to-last axis backwards; each stride is the next
    # stride multiplied by the next axis length.
    for axis in range(len(shape) - 2, -1, -1):
        strides[axis] = strides[axis + 1] * shape[axis + 1]
    return strides
indices_10 = indices_00 + torch.tensor( + [1, 0], dtype=torch.long, device=indices.device + ) + indices_11 = indices_00 + torch.tensor( + [1, 1], dtype=torch.long, device=indices.device + ) + + h = indices[..., 0] - indices_00[..., 0].float() + w = indices[..., 1] - indices_00[..., 1].float() + w_00 = (1 - h) * (1 - w) + w_01 = (1 - h) * w + w_10 = h * (1 - w) + w_11 = h * w + + result = torch.zeros(H, W, C, device=values.device, dtype=values.dtype) + count = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype) + weights = torch.ones_like(values[..., :1]) + + result, count = scatter_add_nd_with_count( + result, + count, + indices_00, + values * w_00.unsqueeze(1), + weights * w_00.unsqueeze(1), + ) + result, count = scatter_add_nd_with_count( + result, + count, + indices_01, + values * w_01.unsqueeze(1), + weights * w_01.unsqueeze(1), + ) + result, count = scatter_add_nd_with_count( + result, + count, + indices_10, + values * w_10.unsqueeze(1), + weights * w_10.unsqueeze(1), + ) + result, count = scatter_add_nd_with_count( + result, + count, + indices_11, + values * w_11.unsqueeze(1), + weights * w_11.unsqueeze(1), + ) + + if return_count: + return result, count + + mask = count.squeeze(-1) > 0 + result[mask] = result[mask] / count[mask].repeat(1, C) + + return result + + +class MeshRender: + """Mesh renderer using CUDA rasterization for texture generation.""" + + def __init__( + self, + camera_distance: float = 1.45, + camera_type: str = "orth", + default_resolution: int = 1024, + texture_size: int = 1024, + bake_mode: str = "linear", + device: str = "cuda", + ): + """Initialize the mesh renderer.""" + self.device = device + + self.set_default_render_resolution(default_resolution) + self.set_default_texture_resolution(texture_size) + + self.camera_distance = camera_distance + self.camera_type = camera_type + self.bake_angle_thres = 75 + self.bake_unreliable_kernel_size = int( + (2 / 512) * max(self.default_resolution[0], self.default_resolution[1]) + ) + 
def set_default_render_resolution(
    self, default_resolution: Union[int, Tuple[int, int]]
):
    """Store the default rendering resolution on ``self.default_resolution``.

    A single int ``n`` is expanded to the pair ``(n, n)``; any other value is
    stored unchanged.
    """
    self.default_resolution = (
        (default_resolution, default_resolution)
        if isinstance(default_resolution, int)
        else default_resolution
    )
sglang.multimodal_gen.csrc.render.hunyuan3d_rasterizer import interpolate + + barycentric = rast_out[0, ..., :-1] + findices = rast_out[0, ..., -1].int() + + if attr.dim() == 2: + attr = attr.unsqueeze(0) + + result = interpolate(attr, findices, barycentric, tri) + return result + + def load_mesh( + self, + mesh: Union[trimesh.Trimesh, trimesh.Scene], + scale_factor: float = 1.15, + auto_center: bool = True, + ): + """Load a mesh for rendering.""" + if isinstance(mesh, trimesh.Scene): + mesh = mesh.dump(concatenate=True) + + self.mesh_copy = mesh.copy() + + vtx_pos = mesh.vertices.astype(np.float32) + pos_idx = mesh.faces.astype(np.int32) + + # Get UV coordinates if available + if hasattr(mesh.visual, "uv") and mesh.visual.uv is not None: + vtx_uv = mesh.visual.uv.astype(np.float32) + uv_idx = pos_idx.copy() + else: + vtx_uv = None + uv_idx = None + + self.vtx_pos = torch.from_numpy(vtx_pos).to(self.device).float() + self.pos_idx = torch.from_numpy(pos_idx).to(self.device).to(torch.int32) + + if vtx_uv is not None and uv_idx is not None: + self.vtx_uv = torch.from_numpy(vtx_uv).to(self.device).float() + self.uv_idx = torch.from_numpy(uv_idx).to(self.device).to(torch.int32) + else: + self.vtx_uv = None + self.uv_idx = None + + # Coordinate transformation (Y-up to Z-up) + self.vtx_pos[:, [0, 1]] = -self.vtx_pos[:, [0, 1]] + self.vtx_pos[:, [1, 2]] = self.vtx_pos[:, [2, 1]] + if self.vtx_uv is not None: + self.vtx_uv[:, 1] = 1.0 - self.vtx_uv[:, 1] + + if auto_center: + max_bb = (self.vtx_pos - 0).max(0)[0] + min_bb = (self.vtx_pos - 0).min(0)[0] + center = (max_bb + min_bb) / 2 + scale = torch.norm(self.vtx_pos - center, dim=1).max() * 2.0 + self.vtx_pos = (self.vtx_pos - center) * (scale_factor / float(scale)) + self.scale_factor = scale_factor + + def save_mesh(self) -> trimesh.Trimesh: + """Save mesh with current texture, reusing the original mesh object.""" + texture_data = self.get_texture() + texture_img = Image.fromarray((texture_data * 255).astype(np.uint8)) 
def get_mesh(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Return the mesh as NumPy arrays with the coordinate change undone.

    Vertex positions are copied, columns 1 and 2 are swapped, and columns 0
    and 1 are then negated. When UVs are present, the second UV component is
    replaced by ``1 - v`` on a copy.

    Returns:
        ``(vtx_pos, pos_idx, vtx_uv, uv_idx)``; the last two are ``None`` when
        ``self.vtx_uv`` is ``None``.
    """
    faces = self.pos_idx.cpu().numpy()

    # Work on a copy so the stored tensors are not modified.
    positions = self.vtx_pos.cpu().numpy().copy()
    positions[:, [1, 2]] = positions[:, [2, 1]]
    positions[:, [0, 1]] = -positions[:, [0, 1]]

    if self.vtx_uv is None:
        return positions, faces, None, None

    uvs = self.vtx_uv.cpu().numpy().copy()
    uvs[:, 1] = 1.0 - uvs[:, 1]
    uv_faces = self.uv_idx.cpu().numpy()
    return positions, faces, uvs, uv_faces
camera_distance is None else camera_distance + ), + center=center, + ) + + pos_camera = transform_pos(r_mv, self.vtx_pos, keepdim=True) + pos_clip = transform_pos(proj, pos_camera) + + return pos_camera, pos_clip + + def render_normal( + self, + elev: float, + azim: float, + camera_distance: Optional[float] = None, + center: Optional[np.ndarray] = None, + resolution: Optional[Tuple[int, int]] = None, + bg_color: List[float] = [1, 1, 1], + use_abs_coor: bool = False, + normalize_rgb: bool = True, + return_type: str = "th", + ) -> Union[torch.Tensor, np.ndarray, Image.Image]: + """Render normal map from a viewpoint.""" + pos_camera, pos_clip = self._get_pos_from_mvp( + elev, azim, camera_distance, center + ) + + if resolution is None: + resolution = self.default_resolution + if isinstance(resolution, (int, float)): + resolution = (int(resolution), int(resolution)) + + rast_out = self._rasterize(pos_clip, self.pos_idx, resolution) + + # Compute face normals + if use_abs_coor: + mesh_triangles = self.vtx_pos[self.pos_idx[:, :3].long(), :] + else: + pos_camera_3d = pos_camera[:, :3] / pos_camera[:, 3:4] + mesh_triangles = pos_camera_3d[self.pos_idx[:, :3].long(), :] + + face_normals = F.normalize( + torch.cross( + mesh_triangles[:, 1, :] - mesh_triangles[:, 0, :], + mesh_triangles[:, 2, :] - mesh_triangles[:, 0, :], + dim=-1, + ), + dim=-1, + ) + + # Compute vertex normals + vertex_normals = trimesh.geometry.mean_vertex_normals( + vertex_count=self.vtx_pos.shape[0], + faces=self.pos_idx.cpu().numpy(), + face_normals=face_normals.cpu().numpy(), + ) + vertex_normals = ( + torch.from_numpy(vertex_normals).float().to(self.device).contiguous() + ) + + # Interpolate normals + normal = self._interpolate(vertex_normals[None, ...], rast_out, self.pos_idx) + + # Apply visibility mask + visible_mask = torch.clamp(rast_out[..., -1:], 0, 1) + bg_tensor = torch.tensor(bg_color, dtype=torch.float32, device=self.device) + normal = normal * visible_mask + bg_tensor * (1 - visible_mask) 
+ + if normalize_rgb: + normal = (normal + 1) * 0.5 + + image = normal[0, ...] + + if return_type == "np": + image = image.cpu().numpy() + elif return_type == "pl": + image = image.cpu().numpy() * 255 + image = Image.fromarray(image.astype(np.uint8)) + + return image + + def render_position( + self, + elev: float, + azim: float, + camera_distance: Optional[float] = None, + center: Optional[np.ndarray] = None, + resolution: Optional[Tuple[int, int]] = None, + bg_color: List[float] = [1, 1, 1], + return_type: str = "th", + ) -> Union[torch.Tensor, np.ndarray, Image.Image]: + """Render position map from a viewpoint.""" + pos_camera, pos_clip = self._get_pos_from_mvp( + elev, azim, camera_distance, center + ) + + if resolution is None: + resolution = self.default_resolution + if isinstance(resolution, (int, float)): + resolution = (int(resolution), int(resolution)) + + rast_out = self._rasterize(pos_clip, self.pos_idx, resolution) + + # Position colors (normalized vertex positions) + tex_position = 0.5 - self.vtx_pos[:, :3] / self.scale_factor + tex_position = tex_position.contiguous() + + # Interpolate positions + position = self._interpolate(tex_position[None, ...], rast_out, self.pos_idx) + + # Apply visibility mask + visible_mask = torch.clamp(rast_out[..., -1:], 0, 1) + bg_tensor = torch.tensor(bg_color, dtype=torch.float32, device=self.device) + position = position * visible_mask + bg_tensor * (1 - visible_mask) + + image = position[0, ...] 
def render_normal_multiview(
    self,
    camera_elevs: List[float],
    camera_azims: List[float],
    use_abs_coor: bool = True,
) -> List[Image.Image]:
    """Render one normal map per (elevation, azimuth) pair.

    The two lists are paired positionally with ``zip``, so extra entries in
    the longer list are ignored. Each view is rendered through
    ``self.render_normal`` with ``return_type="pl"``.

    Returns:
        The rendered views, in input order.
    """
    return [
        self.render_normal(elev, azim, use_abs_coor=use_abs_coor, return_type="pl")
        for elev, azim in zip(camera_elevs, camera_azims)
    ]
image.float().to(self.device) + resolution = image.shape[:2] + channel = image.shape[-1] + + pos_camera, pos_clip = self._get_pos_from_mvp( + elev, azim, camera_distance, center + ) + + rast_out = self._rasterize(pos_clip, self.pos_idx, resolution) + visible_mask = torch.clamp(rast_out[..., -1:], 0, 1)[0, ...] + + # Compute vertex normals for angle-based weighting + pos_camera_3d = pos_camera[:, :3] / pos_camera[:, 3:4] + v0 = pos_camera_3d[self.pos_idx[:, 0].long(), :] + v1 = pos_camera_3d[self.pos_idx[:, 1].long(), :] + v2 = pos_camera_3d[self.pos_idx[:, 2].long(), :] + face_normals = F.normalize(torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1) + + vertex_normals = trimesh.geometry.mean_vertex_normals( + vertex_count=self.vtx_pos.shape[0], + faces=self.pos_idx.cpu().numpy(), + face_normals=face_normals.cpu().numpy(), + ) + vertex_normals = ( + torch.from_numpy(vertex_normals).float().to(self.device).contiguous() + ) + + # Interpolate normals and UVs + normal = self._interpolate(vertex_normals[None, ...], rast_out, self.pos_idx) + normal = normal[0, ...] + + if self.vtx_uv is not None: + uv = self._interpolate(self.vtx_uv[None, ...], rast_out, self.uv_idx) + else: + # No UV coordinates + texture = torch.zeros( + self.texture_size[1], self.texture_size[0], channel, device=self.device + ) + cos_map = torch.zeros( + self.texture_size[1], self.texture_size[0], 1, device=self.device + ) + boundary_map = torch.zeros_like(cos_map) + return texture, cos_map, boundary_map + + # Compute depth for sketch + tex_depth = pos_camera_3d[:, 2].reshape(1, -1, 1).contiguous() + depth = self._interpolate(tex_depth, rast_out, self.pos_idx)[0, ...] 
+ depth_masked = depth[visible_mask > 0] + if depth_masked.numel() > 0: + depth_max, depth_min = depth_masked.max(), depth_masked.min() + depth_normalized = (depth - depth_min) / (depth_max - depth_min + 1e-8) + else: + depth_normalized = depth + depth_image = depth_normalized * visible_mask + + sketch_image = self._render_sketch_from_depth(depth_image) + + # Cosine weighting + lookat = torch.tensor([[0, 0, -1]], device=self.device) + cos_image = torch.nn.functional.cosine_similarity(lookat, normal.view(-1, 3)) + cos_image = cos_image.view(normal.shape[0], normal.shape[1], 1) + + cos_thres = np.cos(self.bake_angle_thres / 180 * np.pi) + cos_image[cos_image < cos_thres] = 0 + + # Shrink visible mask + kernel_size = self.bake_unreliable_kernel_size * 2 + 1 + kernel = torch.ones((1, 1, kernel_size, kernel_size), dtype=torch.float32).to( + sketch_image.device + ) + + visible_mask_proc = visible_mask.permute(2, 0, 1).unsqueeze(0).float() + visible_mask_proc = F.conv2d( + 1.0 - visible_mask_proc, kernel, padding=kernel_size // 2 + ) + visible_mask_proc = 1.0 - (visible_mask_proc > 0).float() + visible_mask_proc = visible_mask_proc.squeeze(0).permute(1, 2, 0) + + sketch_proc = sketch_image.permute(2, 0, 1).unsqueeze(0) + sketch_proc = F.conv2d(sketch_proc, kernel, padding=kernel_size // 2) + sketch_proc = (sketch_proc > 0).float() + sketch_proc = sketch_proc.squeeze(0).permute(1, 2, 0) + visible_mask_proc = visible_mask_proc * (sketch_proc < 0.5) + + cos_image[visible_mask_proc == 0] = 0 + + # Linear baking + proj_mask = (visible_mask_proc != 0).view(-1) + uv_flat = uv.squeeze(0).contiguous().view(-1, 2)[proj_mask] + image_flat = image.squeeze(0).contiguous().view(-1, channel)[proj_mask] + cos_flat = cos_image.contiguous().view(-1, 1)[proj_mask] + sketch_flat = sketch_image.contiguous().view(-1, 1)[proj_mask] + + texture = linear_grid_put_2d( + self.texture_size[1], self.texture_size[0], uv_flat[..., [1, 0]], image_flat + ) + cos_map = linear_grid_put_2d( + 
@torch.no_grad()
def fast_bake_texture(
    self,
    textures: List[torch.Tensor],
    cos_maps: List[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Merge per-view projected textures by weighted averaging.

    Args:
        textures: Per-view texture maps; all are expected to share the shape
            of ``textures[0]`` (grid dimensions followed by channels).
        cos_maps: Per-view weight maps with a trailing dimension of 1,
            paired positionally with ``textures``.

    Returns:
        ``(texture, trust_mask)``: the weighted average clamped to [0, 1],
        and a boolean mask that is True where the accumulated weight exceeds
        1e-8.

    The accumulators take their grid shape from ``textures[0]`` rather than
    from ``self.texture_size``. The projected maps are laid out as
    (texture_size[1], texture_size[0], C), so sizing the buffers from
    ``texture_size`` directly only matched for square textures.
    """
    grid_shape = tuple(textures[0].shape[:-1])
    channel = textures[0].shape[-1]
    texture_merge = torch.zeros((*grid_shape, channel), device=self.device)
    trust_map_merge = torch.zeros((*grid_shape, 1), device=self.device)

    for texture, cos_map in zip(textures, cos_maps):
        view_sum = (cos_map > 0).sum()
        painted_sum = ((cos_map > 0) * (trust_map_merge > 0)).sum()
        # Skip a view when more than 99% of the texels it covers have already
        # received a contribution from earlier views.
        if view_sum > 0 and painted_sum / view_sum > 0.99:
            continue
        texture_merge += texture * cos_map
        trust_map_merge += cos_map

    # Normalise by accumulated weight; the clamp avoids division by zero
    # where no view contributed.
    texture_merge = texture_merge / torch.clamp(trust_map_merge, min=1e-8)
    texture_merge = texture_merge.clamp(0.0, 1.0)

    return texture_merge, trust_map_merge > 1e-8
alpha_channel = np.array(image)[:, :, 3] + non_zero_indices = np.argwhere(alpha_channel > 0) + if non_zero_indices.size == 0: + raise ValueError("Image is fully transparent") + + min_row, min_col = non_zero_indices.min(axis=0) + max_row, max_col = non_zero_indices.max(axis=0) + + cropped_image = image.crop((min_col, min_row, max_col + 1, max_row + 1)) + + width, height = cropped_image.size + border_width = int(width * border_ratio) + border_height = int(height * border_ratio) + + new_width = width + 2 * border_width + new_height = height + 2 * border_height + square_size = max(new_width, new_height) + + new_image = PILImage.new("RGBA", (square_size, square_size), (255, 255, 255, 0)) + + paste_x = (square_size - new_width) // 2 + border_width + paste_y = (square_size - new_height) // 2 + border_height + new_image.paste(cropped_image, (paste_x, paste_y)) + return new_image + + +class ImageProcessorV2: + """Image processor for Hunyuan3D single-view input.""" + + # External module path aliases for compatibility with Hunyuan3D configs + _aliases = [ + "hy3dshape.preprocessors.ImageProcessorV2", + "hy3dgen.shapegen.preprocessors.ImageProcessorV2", + ] + + def __init__(self, size=512, border_ratio=None): + self.size = size + self.border_ratio = border_ratio + + @staticmethod + def recenter(image, border_ratio: float = 0.2): + """recenter an image to leave some empty space at the image border.""" + + if image.shape[-1] == 4: + mask = image[..., 3] + else: + mask = np.ones_like(image[..., 0:1]) * 255 + image = np.concatenate([image, mask], axis=-1) + mask = mask[..., 0] + + height, width, channels = image.shape + + size = max(height, width) + result = np.zeros((size, size, channels), dtype=np.uint8) + + coords = np.nonzero(mask) + x_min, x_max = coords[0].min(), coords[0].max() + y_min, y_max = coords[1].min(), coords[1].max() + crop_h = x_max - x_min + crop_w = y_max - y_min + if crop_h == 0 or crop_w == 0: + raise ValueError("input image is empty") + desired_size = 
def load_image(self, image, border_ratio=0.15, to_tensor=True):
    """Load, recenter and resize one input image together with its mask.

    Args:
        image: Either a file path (read with OpenCV, keeping any alpha
            channel) or a PIL image (converted to RGBA).
        border_ratio: Forwarded to ``self.recenter``.
        to_tensor: If True, both outputs are passed through
            ``array_to_tensor``; otherwise they are returned as arrays.

    Returns:
        ``(image, mask)`` resized to ``(self.size, self.size)``; the mask has
        a trailing channel axis added.

    Raises:
        ValueError: If ``image`` is a path that OpenCV could not read.
        TypeError: If ``image`` is neither a path string nor a PIL image.
    """
    if isinstance(image, str):
        path = image
        image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise ValueError(f"Failed to read image from path: {path}")
        image, mask = self.recenter(image, border_ratio=border_ratio)
        # Files read through OpenCV are in BGR channel order.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    elif isinstance(image, Image.Image):
        image = image.convert("RGBA")
        image = np.asarray(image)
        image, mask = self.recenter(image, border_ratio=border_ratio)
    else:
        # Without this branch `mask` would be unbound further down.
        raise TypeError(
            "image must be a file path or a PIL image, got "
            f"{type(image).__name__}"
        )

    image = cv2.resize(image, (self.size, self.size), interpolation=cv2.INTER_CUBIC)
    mask = cv2.resize(mask, (self.size, self.size), interpolation=cv2.INTER_NEAREST)
    mask = mask[..., np.newaxis]

    if to_tensor:
        image = array_to_tensor(image)
        mask = array_to_tensor(mask)
    return image, mask
"hy3dshape.preprocessors.MVImageProcessorV2", + ] + + return_view_idx = True + + def __init__(self, size=512, border_ratio=None): + super().__init__(size, border_ratio) + self.view2idx = {"front": 0, "left": 1, "back": 2, "right": 3} + + def __call__(self, image_dict, border_ratio=0.15, to_tensor=True, **kwargs): + if self.border_ratio is not None: + border_ratio = self.border_ratio + + images = [] + masks = [] + view_idxs = [] + for view_tag, image in image_dict.items(): + view_idxs.append(self.view2idx[view_tag]) + image, mask = self.load_image( + image, border_ratio=border_ratio, to_tensor=to_tensor + ) + images.append(image) + masks.append(mask) + + zipped_lists = zip(view_idxs, images, masks) + sorted_zipped_lists = sorted(zipped_lists) + view_idxs, images, masks = zip(*sorted_zipped_lists) + + image = torch.cat(images, 0).unsqueeze(0) + mask = torch.cat(masks, 0).unsqueeze(0) + outputs = {"image": image, "mask": mask, "view_idxs": view_idxs} + return outputs + + +# All tool classes available in this module for resolution +TOOL_CLASSES = ( + ImageProcessorV2, + MVImageProcessorV2, +) + + +def resolve_hunyuan3d_tool(target: str): + """Resolve a Hunyuan3D tool class by target string.""" + # First, try to match against _aliases + for cls in TOOL_CLASSES: + aliases = getattr(cls, "_aliases", []) + if target in aliases: + return cls + + # Then, try to match against class names + for cls in TOOL_CLASSES: + if cls.__name__ == target: + return cls + + return None + + +__all__ = [ + "transform_pos", + "get_mv_matrix", + "get_orthographic_projection_matrix", + "get_perspective_projection_matrix", + "export_to_trimesh", + "mesh_uv_wrap", + "meshVerticeInpaint", + "stride_from_shape", + "scatter_add_nd_with_count", + "linear_grid_put_2d", + "MeshRender", + "recenter_image", + "array_to_tensor", + "ImageProcessorV2", + "MVImageProcessorV2", + "TOOL_CLASSES", + "resolve_hunyuan3d_tool", +] diff --git a/sglang/python/sglang/multimodal_gen/runtime/utils/perf_logger.py 
@dataclasses.dataclass
class RequestMetrics:
    """Performance metrics for a single request, including timings and memory snapshots.

    Only ``request_id`` is accepted by the constructor; the remaining fields
    start empty and are filled through the ``record_*`` methods. Declaring
    them as real dataclass fields makes the generated ``__eq__``/``__repr__``
    reflect the recorded data (with no fields, every instance compared equal).
    """

    request_id: str
    # stage name -> duration in milliseconds
    stages: Dict[str, float] = dataclasses.field(default_factory=dict, init=False)
    # denoising step durations in milliseconds, ordered by step index
    steps: list[float] = dataclasses.field(default_factory=list, init=False)
    total_duration_ms: float = dataclasses.field(default=0.0, init=False)
    # memory tracking: {checkpoint_name: MemorySnapshot}
    memory_snapshots: Dict[str, "MemorySnapshot"] = dataclasses.field(
        default_factory=dict, init=False
    )

    @property
    def total_duration_s(self) -> float:
        """Total duration converted from milliseconds to seconds."""
        return self.total_duration_ms / 1000.0

    def record_stage(self, stage_name: str, duration_s: float):
        """Records the duration of a pipeline stage"""
        self.stages[stage_name] = duration_s * 1000  # Store as milliseconds

    def record_steps(self, index: int, duration_s: float):
        """Records the duration of a denoising step.

        Steps must be recorded in order: ``index`` has to equal the number of
        steps recorded so far.
        """
        assert index == len(self.steps)
        self.steps.append(duration_s * 1000)  # Store as milliseconds

    def record_memory_snapshot(self, checkpoint_name: str, snapshot: "MemorySnapshot"):
        """Stores ``snapshot`` under ``checkpoint_name``, replacing any previous one."""
        self.memory_snapshots[checkpoint_name] = snapshot

    def to_dict(self) -> Dict[str, Any]:
        """Serializes the metrics data to a dictionary."""
        return {
            "request_id": self.request_id,
            "stages": self.stages,
            "steps": self.steps,
            "total_duration_ms": self.total_duration_ms,
            "memory_snapshots": {
                name: snapshot.to_dict()
                for name, snapshot in self.memory_snapshots.items()
            },
        }
class StageProfiler:
    """
    A unified context manager, records performance metrics (usually of a single Stage or a step) into a provided RequestMetrics object (usually from a Req).

    Stages whose name starts with ``denoising_step_`` are recorded as steps
    (the suffix is parsed as the step index); every other stage is recorded
    by name. Exceptions raised inside the ``with`` body are logged and then
    re-raised (``__exit__`` always returns ``False``).
    """

    _STEP_PREFIX = "denoising_step_"
    _SYNC_ENV_VAR = "SGLANG_DIFFUSION_SYNC_STAGE_PROFILING"

    def __init__(
        self,
        stage_name: str,
        logger: "_SGLDiffusionLogger",
        metrics: Optional["RequestMetrics"],
        log_stage_start_end: bool = False,
        perf_dump_path_provided: bool = False,
        capture_memory: bool = False,
    ):
        self.stage_name = stage_name
        self.metrics = metrics
        self.logger = logger
        self.start_time = 0.0
        # Timing is recorded when a dump path was given or the env switch is on.
        self.log_timing = perf_dump_path_provided or envs.SGLANG_DIFFUSION_STAGE_LOGGING
        self.log_stage_start_end = log_stage_start_end
        self.capture_memory = capture_memory

    def _is_denoising_step(self) -> bool:
        """True when the stage name carries the denoising-step prefix."""
        return self.stage_name.startswith(self._STEP_PREFIX)

    def _is_active(self) -> bool:
        """True when this profiler has anything to record or log."""
        return bool((self.log_timing and self.metrics) or self.log_stage_start_end)

    def _maybe_synchronize(self) -> None:
        """Synchronize the device around denoising steps when the env switch is set."""
        if (
            os.environ.get(self._SYNC_ENV_VAR, "0") == "1"
            and self._is_denoising_step()
            and torch.get_device_module().is_available()
        ):
            torch.get_device_module().synchronize()

    def __enter__(self):
        if self.log_stage_start_end:
            msg = f"[{self.stage_name}] started..."
            if self.logger.isEnabledFor(logging.DEBUG):
                msg += f" ({round(current_platform.get_available_gpu_memory(), 2)} GB left)"
            self.logger.info(msg)

        if self._is_active():
            self._maybe_synchronize()
            self.start_time = time.perf_counter()

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self._is_active():
            return False

        self._maybe_synchronize()
        execution_time_s = time.perf_counter() - self.start_time

        if exc_type:
            self.logger.error(
                "[%s] Error during execution after %.4f ms: %s",
                self.stage_name,
                execution_time_s * 1000,
                exc_val,
                exc_info=True,
            )
            return False

        if self.log_stage_start_end:
            self.logger.info(
                f"[{self.stage_name}] finished in {execution_time_s:.4f} seconds",
            )

        if self.log_timing and self.metrics:
            # Prefix check (not substring): the index is sliced off assuming
            # the prefix sits at position 0, and the sync logic above already
            # uses startswith.
            if self._is_denoising_step():
                index = int(self.stage_name[len(self._STEP_PREFIX) :])
                self.metrics.record_steps(index, execution_time_s)
            else:
                self.metrics.record_stage(self.stage_name, execution_time_s)

            # capture memory snapshot after stage if requested
            if self.capture_memory and torch.get_device_module().is_available():
                snapshot = capture_memory_snapshot()
                self.metrics.record_memory_snapshot(
                    f"after_{self.stage_name}", snapshot
                )

        return False
+ + Notice that RequestMetrics stores the performance metrics of a single request + """ + + @classmethod + def dump_benchmark_report( + cls, + file_path: str, + metrics: "RequestMetrics", + meta: Optional[Dict[str, Any]] = None, + tag: str = "benchmark_dump", + ): + """ + Static method to dump a standardized benchmark report to a file. + Eliminates duplicate logic in CLI/Client code. + """ + formatted_steps = [ + {"name": name, "duration_ms": duration_ms} + for name, duration_ms in metrics.stages.items() + ] + + denoise_steps_ms = [ + {"step": idx, "duration_ms": duration_ms} + for idx, duration_ms in enumerate(metrics.steps) + ] + + memory_checkpoints = { + name: snapshot.to_dict() + for name, snapshot in metrics.memory_snapshots.items() + } + + report = { + "timestamp": datetime.now(UTC).isoformat(), + "request_id": metrics.request_id, + "commit_hash": get_git_commit_hash(), + "tag": tag, + "total_duration_ms": metrics.total_duration_ms, + "steps": formatted_steps, + "denoise_steps_ms": denoise_steps_ms, + "memory_checkpoints": memory_checkpoints, + "meta": meta or {}, + } + + try: + abs_path = os.path.abspath(file_path) + os.makedirs(os.path.dirname(abs_path), exist_ok=True) + with open(abs_path, "w", encoding="utf-8") as f: + json.dump(report, f, indent=2) + logger.info(f"Metrics dumped to: {CYAN}{abs_path}{RESET}") + except IOError as e: + logger.error(f"Failed to dump metrics to {abs_path}: {e}") + + @classmethod + def log_request_summary( + cls, + metrics: "RequestMetrics", + tag: str = "total_inference_time", + ): + """logs the stage metrics and total duration for a completed request + to the performance_log file. 
+ + Note that this accords to the time spent internally in server, postprocess is not included + """ + formatted_stages = [ + {"name": name, "execution_time_ms": duration_ms} + for name, duration_ms in metrics.stages.items() + ] + + memory_checkpoints = { + name: snapshot.to_dict() + for name, snapshot in metrics.memory_snapshots.items() + } + + record = RequestPerfRecord( + metrics.request_id, + commit_hash=get_git_commit_hash(), + tag="pipeline_stage_metrics", + stages=formatted_stages, + steps=metrics.steps, + total_duration_ms=metrics.total_duration_ms, + memory_snapshots=memory_checkpoints, + ) + + try: + if get_is_main_process(): + log_dir = get_diffusion_perf_log_dir() + if not os.path.exists(log_dir): + os.makedirs(log_dir, exist_ok=True) + + log_file = os.path.join(log_dir, "performance.log") + + with open(log_file, "a", encoding="utf-8") as f: + f.write(json.dumps(dataclasses.asdict(record)) + "\n") + + except (OSError, PermissionError) as e: + print(f"WARNING: Failed to log performance record: {e}", file=sys.stderr) diff --git a/sglang/python/sglang/multimodal_gen/runtime/utils/profiler.py b/sglang/python/sglang/multimodal_gen/runtime/utils/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..19a0ebfd84390366da06d79adb2398e3b27539bd --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/runtime/utils/profiler.py @@ -0,0 +1,183 @@ +import gzip +import os + +import torch + +from sglang.multimodal_gen.runtime.platforms import current_platform +from sglang.multimodal_gen.runtime.utils.logging_utils import CYAN, RESET, init_logger + +if current_platform.is_npu(): + import torch_npu + + patches = [ + ["profiler.profile", torch_npu.profiler.profile], + ["profiler.schedule", torch_npu.profiler.schedule], + ] + torch_npu._apply_patches(patches) + +logger = init_logger(__name__) + + +class SGLDiffusionProfiler: + """ + A wrapper around torch.profiler to simplify usage in pipelines. 
+ Supports both full profiling and scheduled profiling. + + + 1. if profile_all_stages is on: profile all stages, including all denoising steps + 2. otherwise, if num_profiled_timesteps is specified: profile {num_profiled_timesteps} denoising steps. profile all steps if num_profiled_timesteps==-1 + """ + + _instance = None + + def __init__( + self, + request_id: str | None = None, + rank: int = 0, + full_profile: bool = False, + num_steps: int | None = None, + num_inference_steps: int | None = None, + log_dir: str | None = None, + ): + self.request_id = request_id or "profile_trace" + self.rank = rank + self.full_profile = full_profile + + self.log_dir = ( + log_dir + if log_dir is not None + else os.getenv("SGLANG_TORCH_PROFILER_DIR", "./logs") + ) + + try: + os.makedirs(self.log_dir, exist_ok=True) + except OSError: + pass + + activities = [torch.profiler.ProfilerActivity.CPU] + if torch.cuda.is_available() or ( + hasattr(torch, "musa") and torch.musa.is_available() + ): + activities.append(torch.profiler.ProfilerActivity.CUDA) + if current_platform.is_npu(): + activities.append(torch_npu.profiler.ProfilerActivity.NPU) + + common_torch_profiler_args = dict( + activities=activities, + record_shapes=True, + with_stack=True, + on_trace_ready=( + None + if not current_platform.is_npu() + else torch_npu.profiler.tensorboard_trace_handler(self.log_dir) + ), + ) + if self.full_profile: + # profile all stages + self.profiler = torch.profiler.profile(**common_torch_profiler_args) + self.profile_mode_id = "full stages" + else: + # profile denoising stage only + warmup = 1 + num_actual_steps = num_inference_steps if num_steps == -1 else num_steps + self.num_active_steps = num_actual_steps + warmup + self.profiler = torch.profiler.profile( + **common_torch_profiler_args, + schedule=torch.profiler.schedule( + skip_first=0, + wait=0, + warmup=warmup, + active=self.num_active_steps, + repeat=1, + ), + ) + self.profile_mode_id = f"{num_actual_steps} steps" + + 
logger.info(f"Profiling request: {request_id} for {self.profile_mode_id}...") + + self.has_stopped = False + + SGLDiffusionProfiler._instance = self + self.start() + + def start(self): + logger.info("Starting Profiler...") + self.profiler.start() + + def _step(self): + self.profiler.step() + + def step_stage(self): + if self.full_profile: + self._step() + + def step_denoising_step(self): + if not self.full_profile: + if self.num_active_steps >= 0: + self._step() + self.num_active_steps -= 1 + else: + # early exit when enough steps are captured, to reduce the trace file size + self.stop(dump_rank=0) + + @classmethod + def get_instance(cls) -> "SGLDiffusionProfiler": + return cls._instance + + def stop(self, export_trace: bool = True, dump_rank: int | None = None): + if self.has_stopped: + return + self.has_stopped = True + logger.info("Stopping Profiler...") + if torch.cuda.is_available() or ( + hasattr(torch, "musa") and torch.musa.is_available() + ): + torch.cuda.synchronize() + if current_platform.is_npu(): + torch.npu.synchronize() + export_trace = False # set to false because our internal torch_npu.profiler will generate trace file + self.profiler.stop() + + if export_trace: + if dump_rank is not None and dump_rank != self.rank: + pass + else: + self._export_trace() + + SGLDiffusionProfiler._instance = None + + def _export_trace(self): + + try: + os.makedirs(self.log_dir, exist_ok=True) + sanitized_profile_mode_id = self.profile_mode_id.replace(" ", "_") + trace_path = os.path.abspath( + os.path.join( + self.log_dir, + f"{self.request_id}-{sanitized_profile_mode_id}-global-rank{self.rank}.trace.json.gz", + ) + ) + self.profiler.export_chrome_trace(trace_path) + + if self._check_trace_integrity(trace_path): + logger.info(f"Saved profiler traces to: {CYAN}{trace_path}{RESET}") + else: + logger.warning(f"Trace file may be corrupted: {trace_path}") + except Exception as e: + logger.error(f"Failed to save trace: {e}") + + def _check_trace_integrity(self, trace_path: 
def run_command(command) -> bool:
    """Run ``command`` and report whether it succeeded.

    The child's stdout and stderr are merged and echoed line by line to this
    process's stdout while the command runs.

    Args:
        command: argument list handed to ``subprocess.Popen``.

    Returns:
        ``True`` if the process exited with code 0, ``False`` otherwise (the
        exit code is printed in that case).
    """
    print(f"Running command: {shlex.join(command)}")

    with subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        encoding="utf-8",
    ) as process:
        for line in process.stdout:
            sys.stdout.write(line)
        process.wait()
        if process.returncode == 0:
            return True
        print(f"Command failed with exit code {process.returncode}")
        return False
class TestFlux_T2V(CLIBase):
    """CLI generation test configured with the default FLUX.1-dev test model.

    Only class-level configuration is set here; the test methods come from
    ``CLIBase``. NOTE(review): the class name says "T2V" but ``data_type`` is
    ``DataType.IMAGE`` — presumably a text-to-image test; confirm.
    """

    model_path = DEFAULT_FLUX_1_DEV_MODEL_NAME_FOR_TEST
    extra_args = []
    data_type: DataType = DataType.IMAGE
+ +Usage: + python3 run_suite.py --suite --partition-id --total-partitions + +Example: + python3 run_suite.py --suite 1-gpu --partition-id 0 --total-partitions 4 +""" + +import argparse +import os +import random +import subprocess +import sys +from pathlib import Path + +import tabulate + +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + +_UPDATE_WEIGHTS_FROM_DISK_TEST_FILE = "test_update_weights_from_disk.py" +_UPDATE_WEIGHTS_MODEL_PAIR_ENV = "SGLANG_MMGEN_UPDATE_WEIGHTS_PAIR" +_UPDATE_WEIGHTS_MODEL_PAIR_IDS = ( + "FLUX.2-klein-base-4B", + "Qwen-Image", +) + +SUITES = { + # no GPU required; safe to run on any CPU-only runner + "unit": [ + "../unit/test_sampling_params_validate.py", + "../unit/test_storage.py", + "../unit/test_lora_format_adapter.py", + "../unit/test_server_args_unit.py", + # add new unit tests here + ], + "1-gpu": [ + "test_server_a.py", + "test_server_b.py", + # cli test + "../cli/test_generate_t2i_perf.py", + "test_update_weights_from_disk.py", + # add new 1-gpu test files here + ], + "2-gpu": [ + "test_server_2_gpu_a.py", + "test_server_2_gpu_b.py", + # add new 2-gpu test files here + ], +} + +suites_ascend = { + "1-npu": [ + "ascend/test_server_1_npu.py", + # add new 1-npu test files here + ], + "2-npu": [ + "ascend/test_server_2_npu.py", + # add new 2-npu test files here + ], +} + +SUITES.update(suites_ascend) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Run multimodal_gen test suite") + parser.add_argument( + "--suite", + type=str, + required=True, + choices=list(SUITES.keys()), + help="The test suite to run (e.g., 1-gpu, 2-gpu)", + ) + parser.add_argument( + "--partition-id", + type=int, + default=0, + help="Index of the current partition (for parallel execution)", + ) + parser.add_argument( + "--total-partitions", + type=int, + default=1, + help="Total number of partitions", + ) + parser.add_argument( + "--base-dir", + type=str, + default="server", + 
def collect_test_items(files, filter_expr=None):
    """Return the pytest node IDs found in ``files``.

    Runs ``pytest --collect-only -q`` (optionally with ``-k filter_expr``) in
    a subprocess and extracts every output line that looks like a node ID,
    i.e. contains ``::``.

    Raises:
        RuntimeError: if pytest exits with a code other than 0 or 5
            (5 = nothing collected, which is tolerated).
    """
    filter_args = ["-k", filter_expr] if filter_expr else []
    cmd = [sys.executable, "-m", "pytest", "--collect-only", "-q"]
    cmd += filter_args
    cmd += list(files)

    print(f"Collecting tests with command: {' '.join(cmd)}")
    result = subprocess.run(cmd, capture_output=True, text=True)
    exit_code = result.returncode

    # pytest exit codes:
    #   0: success
    #   1: tests collected but some had errors during collection
    #   2: test execution interrupted
    #   3: internal error
    #   4: command line usage error
    #   5: no tests collected (may be expected with filters)
    if exit_code != 0 and exit_code != 5:
        error_msg = (
            f"pytest --collect-only failed with exit code {result.returncode}\n"
            f"Command: {' '.join(cmd)}\n"
        )
        if result.stderr:
            error_msg += f"stderr:\n{result.stderr}\n"
        if result.stdout:
            error_msg += f"stdout:\n{result.stdout}\n"
        logger.error(error_msg)
        raise RuntimeError(error_msg)

    if exit_code == 5:
        print(
            "No tests were collected (exit code 5). This may be expected with filters."
        )

    # pytest -q prints one node ID per line, e.g.
    # test_file.py::TestClass::test_method[param]
    collected = []
    for raw_line in result.stdout.strip().split("\n"):
        entry = raw_line.strip()
        # Drop blank lines and anything that is not a node ID.
        if not entry or "::" not in entry:
            continue
        # Drop separator/summary lines.
        if entry.startswith(("=", "-", " ")):
            continue
        # Keep only the first token when extra info follows the ID.
        candidate = entry.split()[0] if " " in entry else entry
        if "::" in candidate:
            collected.append(candidate)

    print(f"Collected {len(collected)} test items")
    return collected
def main():
    """Entry point: resolve the suite's files, shard the collected test items
    across partitions and run this partition's share through pytest.

    Exits with pytest's return code, with 0 when there is nothing to run,
    and with 1 when the base directory is missing or the partition arguments
    are invalid.
    """
    args = parse_args()

    # Reject invalid sharding up front: an out-of-range partition id would
    # otherwise select nothing and exit 0, and total_partitions == 0 would
    # raise ZeroDivisionError in the modulo below.
    if args.total_partitions < 1 or not (
        0 <= args.partition_id < args.total_partitions
    ):
        print(
            f"Error: partition-id must be in range [0, {args.total_partitions}) "
            f"and total-partitions must be >= 1 (got partition-id={args.partition_id})."
        )
        sys.exit(1)

    # 1. resolve base path
    current_file_path = Path(__file__).resolve()
    test_root_dir = current_file_path.parent
    target_dir = test_root_dir / args.base_dir

    if not target_dir.exists():
        print(f"Error: Target directory {target_dir} does not exist.")
        sys.exit(1)

    # 2. get files from suite
    suite_files_rel = SUITES[args.suite]
    _maybe_pin_update_weights_model_pair(suite_files_rel)

    suite_files_abs = []
    for f_rel in suite_files_rel:
        f_abs = target_dir / f_rel
        if not f_abs.exists():
            print(f"Warning: Test file {f_rel} not found in {target_dir}. Skipping.")
            continue
        suite_files_abs.append(str(f_abs))

    if not suite_files_abs:
        print(f"No valid test files found for suite '{args.suite}'.")
        sys.exit(0)

    # 3. collect all test items and partition by items (not files)
    all_test_items = collect_test_items(suite_files_abs, filter_expr=args.filter)

    if not all_test_items:
        print(f"No test items found for suite '{args.suite}'.")
        sys.exit(0)

    # Round-robin partition by test items; equivalent to keeping index i
    # where i % total_partitions == partition_id.
    my_items = all_test_items[args.partition_id :: args.total_partitions]

    partition_info = f"{args.partition_id + 1}/{args.total_partitions} (0-based id={args.partition_id})"
    headers = ["Suite", "Partition"]
    rows = [[args.suite, partition_info]]

    def _item_report(verb: str) -> str:
        # Table plus one bullet per test item assigned to this partition.
        report = tabulate.tabulate(rows, headers=headers, tablefmt="psql") + "\n"
        report += f"✅ {verb} {len(my_items)} test(s):\n"
        for item in my_items:
            report += f" - {item}\n"
        return report

    # Print test info at beginning (similar to test/run_suite.py pretty_print_tests)
    print(_item_report("Enabled"), flush=True)
    print(
        f"Suite: {args.suite} | Partition: {args.partition_id}/{args.total_partitions}"
    )
    print(f"Selected {len(suite_files_abs)} files:")
    for f in suite_files_abs:
        print(f" - {os.path.basename(f)}")

    if not my_items:
        print("No items assigned to this partition. Exiting success.")
        sys.exit(0)

    print(f"Running {len(my_items)} items in this shard: {', '.join(my_items)}")

    # 4. execute with the specific test items (already filtered at collection)
    exit_code = run_pytest(my_items)

    # Print tests again at the end for visibility
    print("\n" + _item_report("Executed"), flush=True)

    sys.exit(exit_code)
+ +Usage: + python gen_diffusion_ci_outputs.py --suite 1-gpu --partition-id 0 --total-partitions 2 --out-dir ./output + python gen_diffusion_ci_outputs.py --suite 1-gpu --case-ids qwen_image_t2i flux_image_t2i --out-dir ./output +""" + +import argparse +import os +import sys +from pathlib import Path + +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.test.run_suite import SUITES, collect_test_items, run_pytest + +logger = init_logger(__name__) + + +def main(): + """Main entry point.""" + parser = argparse.ArgumentParser(description="Generate diffusion CI outputs") + parser.add_argument( + "--suite", + type=str, + choices=["1-gpu", "2-gpu"], + required=True, + help="Test suite to run (1-gpu or 2-gpu)", + ) + parser.add_argument( + "--partition-id", + type=int, + required=False, + help="Partition ID for matrix partitioning (0-based)", + ) + parser.add_argument( + "--total-partitions", + type=int, + required=False, + help="Total number of partitions", + ) + parser.add_argument( + "--out-dir", + type=str, + required=True, + help="Output directory for generated files", + ) + parser.add_argument( + "--continue-on-error", + action="store_true", + help="Continue processing other cases if one fails", + ) + parser.add_argument( + "--case-ids", + type=str, + nargs="*", + required=False, + help="Specific case IDs to run (space-separated). 
If provided, only these cases will be run.", + ) + + args = parser.parse_args() + + # Validate partition arguments + if args.partition_id is not None and args.total_partitions is not None: + if args.partition_id < 0 or args.partition_id >= args.total_partitions: + parser.error(f"partition-id must be in range [0, {args.total_partitions})") + elif args.partition_id is not None or args.total_partitions is not None: + parser.error( + "Both --partition-id and --total-partitions must be provided together" + ) + + # Create output directory + out_dir = Path(args.out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + # Set environment variables for GT generation mode + os.environ["SGLANG_GEN_GT"] = "1" + os.environ["SGLANG_GT_OUTPUT_DIR"] = str(out_dir.absolute()) + os.environ["SGLANG_SKIP_CONSISTENCY"] = ( + "1" # Skip consistency checks in GT gen mode + ) + + logger.info(f"GT generation mode enabled") + logger.info(f"Output directory: {out_dir}") + + # Resolve test files path (same as run_suite.py) + current_file_path = Path(__file__).resolve() + test_root_dir = current_file_path.parent.parent # scripts -> test + target_dir = test_root_dir / "server" + + # Get files from suite (same as run_suite.py) + suite_files_rel = SUITES[args.suite] + suite_files_abs = [] + for f_rel in suite_files_rel: + f_abs = target_dir / f_rel + if not f_abs.exists(): + logger.warning(f"Test file {f_rel} not found in {target_dir}. 
Skipping.") + continue + suite_files_abs.append(str(f_abs)) + + if not suite_files_abs: + logger.error(f"No valid test files found for suite '{args.suite}'.") + sys.exit(1) + + # Build pytest filter for case_ids if provided + filter_expr = None + if args.case_ids: + # pytest parametrized test format: test_diffusion_generation[case_id] + filters = [f"test_diffusion_generation[{case_id}]" for case_id in args.case_ids] + filter_expr = " or ".join(filters) + logger.info(f"Filtering by case IDs: {args.case_ids}") + + # Collect all test items (same as run_suite.py) + all_test_items = collect_test_items(suite_files_abs, filter_expr=filter_expr) + + if not all_test_items: + logger.warning(f"No test items found for suite '{args.suite}'.") + sys.exit(0) + + # Partition by test items (same as run_suite.py) + partition_id = args.partition_id if args.partition_id is not None else 0 + total_partitions = args.total_partitions if args.total_partitions is not None else 1 + + my_items = [ + item + for i, item in enumerate(all_test_items) + if i % total_partitions == partition_id + ] + + logger.info( + f"Partition {partition_id}/{total_partitions}: " + f"running {len(my_items)} of {len(all_test_items)} test items" + ) + + if not my_items: + logger.warning("No items assigned to this partition. 
Exiting success.") + sys.exit(0) + + # Run pytest with the specific test items (same as run_suite.py) + exit_code = run_pytest(my_items) + + if exit_code != 0: + if args.continue_on_error: + logger.warning(f"pytest exited with code {exit_code}") + else: + sys.exit(exit_code) + + +if __name__ == "__main__": + main() diff --git a/sglang/python/sglang/multimodal_gen/test/scripts/gen_perf_baselines.py b/sglang/python/sglang/multimodal_gen/test/scripts/gen_perf_baselines.py new file mode 100644 index 0000000000000000000000000000000000000000..982e51374e3f0f1f00463aa6be69a12db3278d20 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/test/scripts/gen_perf_baselines.py @@ -0,0 +1,210 @@ +import argparse +import inspect +import json +import os +import re +import sys +from pathlib import Path + +from openai import OpenAI + +from sglang.multimodal_gen.test.server.test_server_utils import ( + ServerManager, + get_generate_fn, +) +from sglang.multimodal_gen.test.server.testcase_configs import ( + BASELINE_CONFIG, + DiffusionTestCase, +) +from sglang.multimodal_gen.test.test_utils import ( + get_dynamic_server_port, + wait_for_req_perf_record, +) + + +def _all_cases() -> list[DiffusionTestCase]: + import sglang.multimodal_gen.test.server.testcase_configs as cfg + + cases: list[DiffusionTestCase] = [] + for _, v in inspect.getmembers(cfg): + if isinstance(v, list) and v and isinstance(v[0], DiffusionTestCase): + cases.extend(v) + + seen: set[str] = set() + out: list[DiffusionTestCase] = [] + for c in cases: + if c.id not in seen: + seen.add(c.id) + out.append(c) + return out + + +def _baseline_path() -> Path: + import sglang.multimodal_gen.test.server.testcase_configs as cfg + + return Path(cfg.__file__).with_name("perf_baselines.json") + + +def _openai_client(port: int) -> OpenAI: + return OpenAI(api_key="sglang-anything", base_url=f"http://localhost:{port}/v1") + + +def _build_server_extra_args(case: DiffusionTestCase) -> str: + server_args = case.server_args + a = 
os.environ.get("SGLANG_TEST_SERVE_ARGS", "") + a += f" --num-gpus {server_args.num_gpus}" + if server_args.tp_size is not None: + a += f" --tp-size {server_args.tp_size}" + if server_args.ulysses_degree is not None: + a += f" --ulysses-degree {server_args.ulysses_degree}" + if server_args.dit_layerwise_offload: + a += " --dit-layerwise-offload true" + if server_args.dit_offload_prefetch_size: + a += f" --dit-offload-prefetch-size {server_args.dit_offload_prefetch_size}" + if server_args.ring_degree is not None: + a += f" --ring-degree {server_args.ring_degree}" + if server_args.lora_path: + a += f" --lora-path {server_args.lora_path}" + + # default warmup + a += " --warmup" + + for extra_arg in server_args.extras: + a += f" {extra_arg}" + return a + + +def _build_env_vars(case: DiffusionTestCase) -> dict[str, str]: + if case.server_args.enable_cache_dit: + return {"SGLANG_CACHE_DIT_ENABLED": "true"} + return {} + + +def _torch_cleanup() -> None: + try: + import gc + + gc.collect() + except Exception: + pass + try: + import torch + + if torch.get_device_module().is_available(): + torch.get_device_module().synchronize() + torch.get_device_module().empty_cache() + except Exception: + pass + + +def _run_case(case: DiffusionTestCase) -> dict: + default_port = get_dynamic_server_port() + port = int(os.environ.get("SGLANG_TEST_SERVER_PORT", default_port)) + mgr = ServerManager( + model=case.server_args.model_path, + port=port, + wait_deadline=float(os.environ.get("SGLANG_TEST_WAIT_SECS", "1200")), + extra_args=_build_server_extra_args(case), + env_vars=_build_env_vars(case), + ) + ctx = mgr.start() + try: + sp = case.sampling_params + output_size = os.environ.get("SGLANG_TEST_OUTPUT_SIZE", sp.output_size) + client = _openai_client(ctx.port) + gen = get_generate_fn( + model_path=case.server_args.model_path, + modality=case.server_args.modality, + sampling_params=sp, + ) + rid, _ = gen(case.id, client) + rec = wait_for_req_perf_record( + rid, + ctx.perf_log_path, + 
timeout=float(os.environ.get("SGLANG_PERF_TIMEOUT", "300")), + ) + if rec is None: + raise RuntimeError(f"missing perf record: {case.id}") + from sglang.multimodal_gen.test.server.testcase_configs import ( + PerformanceSummary, + ) + + perf = PerformanceSummary.from_req_perf_record( + rec, BASELINE_CONFIG.step_fractions + ) + if case.server_args.modality == "video" and sp.num_frames and sp.num_frames > 0: + if "per_frame_generation" not in perf.stage_metrics: + perf.stage_metrics["per_frame_generation"] = perf.e2e_ms / sp.num_frames + + return { + "stages_ms": {k: round(v, 2) for k, v in perf.stage_metrics.items()}, + "denoise_step_ms": { + str(k): round(v, 2) for k, v in perf.all_denoise_steps.items() + }, + "expected_e2e_ms": round(perf.e2e_ms, 2), + "expected_avg_denoise_ms": round(perf.avg_denoise_ms, 2), + "expected_median_denoise_ms": round(perf.median_denoise_ms, 2), + } + finally: + ctx.cleanup() + + +def main() -> int: + ap = argparse.ArgumentParser() + ap.add_argument("--baseline", default="") + ap.add_argument("--out", default="") + ap.add_argument("--match", default="") + ap.add_argument("--case", action="append", default=[]) + ap.add_argument("--all-from-baseline", action="store_true") + ap.add_argument("--timeout", type=float, default=300.0) + args = ap.parse_args() + + os.environ.setdefault("SGLANG_GEN_BASELINE", "1") + os.environ["SGLANG_PERF_TIMEOUT"] = str(args.timeout) + + baseline_path = Path(args.baseline) if args.baseline else _baseline_path() + out_path = Path(args.out) if args.out else baseline_path + data = json.loads(baseline_path.read_text(encoding="utf-8")) + scenarios = data.setdefault("scenarios", {}) + + ids = set(args.case) if args.case else None + pat = re.compile(args.match) if args.match else None + if args.all_from_baseline: + ids = set(scenarios.keys()) + pat = None + + all_cases = _all_cases() + cases = [] + for c in all_cases: + if ids and c.id not in ids: + continue + if pat and not pat.search(c.id): + continue + 
cases.append(c) + + if args.all_from_baseline and ids: + case_ids = {c.id for c in all_cases} + missing = sorted([i for i in ids if i not in case_ids]) + if missing: + sys.stderr.write(f"missing cases in testcase_configs.py: {len(missing)}\n") + + if not cases: + return 0 + + for c in cases: + prev = scenarios.get(c.id, {}) + note = prev.get("notes") + baseline = _run_case(c) + if note is not None: + baseline["notes"] = note + scenarios[c.id] = baseline + sys.stdout.write(f"{c.id}\n") + sys.stdout.flush() + _torch_cleanup() + + out_path.write_text(json.dumps(data, indent=4) + "\n", encoding="utf-8") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/sglang/python/sglang/multimodal_gen/test/server/ascend/perf_baselines_npu.json b/sglang/python/sglang/multimodal_gen/test/server/ascend/perf_baselines_npu.json new file mode 100644 index 0000000000000000000000000000000000000000..c901a46f29322adb97d90b81b230f713aa74d717 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/test/server/ascend/perf_baselines_npu.json @@ -0,0 +1,206 @@ +{ + "metadata": { + "model": "Diffusion Server", + "hardware": "CI A2 64GB pool", + "description": "Reference numbers captured from the CI diffusion server baseline run" + }, + "scenarios": { + "flux_image_t2i_npu": { + "stages_ms": { + "InputValidationStage": 0.07, + "TextEncodingStage": 154.51, + "TimestepPreparationStage": 53.52, + "LatentPreparationStage": 0.39, + "DenoisingStage": 19423.39, + "DecodingStage": 40.14 + }, + "denoise_step_ms": { + "0": 123.16, + "1": 91.7, + "2": 265.62, + "3": 402.68, + "4": 402.86, + "5": 402.78, + "6": 402.99, + "7": 402.77, + "8": 402.59, + "9": 402.93, + "10": 402.05, + "11": 402.99, + "12": 402.29, + "13": 403.07, + "14": 402.62, + "15": 402.99, + "16": 402.68, + "17": 403.0, + "18": 402.74, + "19": 402.85, + "20": 402.83, + "21": 403.03, + "22": 402.56, + "23": 402.84, + "24": 402.79, + "25": 402.95, + "26": 402.65, + "27": 403.01, + "28": 402.66, + "29": 402.92, + "30": 
402.75, + "31": 403.0, + "32": 402.9, + "33": 402.48, + "34": 402.85, + "35": 402.03, + "36": 402.93, + "37": 402.3, + "38": 403.12, + "39": 402.83, + "40": 402.84, + "41": 402.75, + "42": 402.97, + "43": 402.62, + "44": 402.91, + "45": 402.81, + "46": 402.97, + "47": 402.57, + "48": 403.0, + "49": 402.75 + }, + "expected_e2e_ms": 23819.1, + "expected_avg_denoise_ms": 388.22, + "expected_median_denoise_ms": 402.82 + }, + "flux_2_image_t2i_2npu": { + "stages_ms": { + "InputValidationStage": 0.06, + "TextEncodingStage": 5628.31, + "ImageVAEEncodingStage": 0.01, + "LatentPreparationStage": 0.75, + "TimestepPreparationStage": 30.68, + "DenoisingStage": 55002.26, + "DecodingStage": 43.73 + }, + "denoise_step_ms": { + "0": 110.35, + "1": 301.82, + "2": 1139.81, + "3": 1114.17, + "4": 1099.34, + "5": 1099.12, + "6": 1100.16, + "7": 1099.67, + "8": 1099.09, + "9": 1089.81, + "10": 1109.73, + "11": 1099.97, + "12": 1100.26, + "13": 1099.67, + "14": 1099.79, + "15": 1099.6, + "16": 1100.16, + "17": 1099.87, + "18": 1100.02, + "19": 1099.34, + "20": 1099.6, + "21": 1099.45, + "22": 1100.2, + "23": 1099.29, + "24": 1098.86, + "25": 1090.38, + "26": 1109.19, + "27": 1099.67, + "28": 1100.06, + "29": 1099.22, + "30": 1100.08, + "31": 1098.86, + "32": 1099.73, + "33": 1099.11, + "34": 1100.13, + "35": 1103.97, + "36": 1095.26, + "37": 1099.38, + "38": 1099.34, + "39": 1099.17, + "40": 1100.08, + "41": 1089.89, + "42": 1106.69, + "43": 1102.57, + "44": 1100.17, + "45": 1099.21, + "46": 1100.42, + "47": 1099.38, + "48": 1099.59, + "49": 1099.47 + }, + "expected_e2e_ms": 64195.08, + "expected_avg_denoise_ms": 1065.0, + "expected_median_denoise_ms": 1099.63 + }, + "wan2_1_t2v_1.3b_1_npu": { + "stages_ms": { + "InputValidationStage": 0.07, + "TextEncodingStage": 876.11, + "LatentPreparationStage": 0.25, + "TimestepPreparationStage": 2.9, + "DenoisingStage": 26188.0, + "DecodingStage": 320.03, + "per_frame_generation": null + }, + "denoise_step_ms": { + "0": 103.56, + "1": 329.59, + 
"2": 545.23, + "3": 537.0, + "4": 536.27, + "5": 536.29, + "6": 536.33, + "7": 536.0, + "8": 536.17, + "9": 536.28, + "10": 535.53, + "11": 536.04, + "12": 536.42, + "13": 536.09, + "14": 536.32, + "15": 536.25, + "16": 536.36, + "17": 536.21, + "18": 536.29, + "19": 536.15, + "20": 536.28, + "21": 536.5, + "22": 536.46, + "23": 536.06, + "24": 536.45, + "25": 536.24, + "26": 536.14, + "27": 536.13, + "28": 536.22, + "29": 536.15, + "30": 535.94, + "31": 536.1, + "32": 536.13, + "33": 536.2, + "34": 536.24, + "35": 536.34, + "36": 536.54, + "37": 536.42, + "38": 536.41, + "39": 536.42, + "40": 536.13, + "41": 536.32, + "42": 536.23, + "43": 536.16, + "44": 536.05, + "45": 536.18, + "46": 536.08, + "47": 536.34, + "48": 536.26, + "49": 535.41 + }, + "expected_e2e_ms": 38738.17, + "expected_avg_denoise_ms": 523.62, + "expected_median_denoise_ms": 536.23 + } + } +} diff --git a/sglang/python/sglang/multimodal_gen/test/server/ascend/test_server_1_npu.py b/sglang/python/sglang/multimodal_gen/test/server/ascend/test_server_1_npu.py new file mode 100644 index 0000000000000000000000000000000000000000..3be09a8992dc82f53f1d58feffc3bf578c8c1991 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/test/server/ascend/test_server_1_npu.py @@ -0,0 +1,29 @@ +""" +Config-driven diffusion performance test with pytest parametrization. 
+ + +If the actual run is significantly better than the baseline, the improved cases with their updated baseline will be printed +""" + +from __future__ import annotations + +import pytest + +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.test.server.ascend.testcase_configs_npu import ONE_NPU_CASES +from sglang.multimodal_gen.test.server.test_server_common import ( # noqa: F401 + DiffusionServerBase, + diffusion_server, +) +from sglang.multimodal_gen.test.server.testcase_configs import DiffusionTestCase + +logger = init_logger(__name__) + + +class TestDiffusionServerOneNpu(DiffusionServerBase): + """Performance tests for 1-NPU diffusion cases.""" + + @pytest.fixture(params=ONE_NPU_CASES, ids=lambda c: c.id) + def case(self, request) -> DiffusionTestCase: + """Provide a DiffusionTestCase for each 1-NPU test.""" + return request.param diff --git a/sglang/python/sglang/multimodal_gen/test/server/ascend/test_server_2_npu.py b/sglang/python/sglang/multimodal_gen/test/server/ascend/test_server_2_npu.py new file mode 100644 index 0000000000000000000000000000000000000000..91bf37badae14c1a3249fde5aabdef904f91e963 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/test/server/ascend/test_server_2_npu.py @@ -0,0 +1,29 @@ +""" +Config-driven diffusion performance test with pytest parametrization. 
+ + +If the actual run is significantly better than the baseline, the improved cases with their updated baseline will be printed +""" + +from __future__ import annotations + +import pytest + +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.test.server.ascend.testcase_configs_npu import TWO_NPU_CASES +from sglang.multimodal_gen.test.server.test_server_common import ( # noqa: F401 + DiffusionServerBase, + diffusion_server, +) +from sglang.multimodal_gen.test.server.testcase_configs import DiffusionTestCase + +logger = init_logger(__name__) + + +class TestDiffusionServerTwoNpu(DiffusionServerBase): + """Performance tests for 2-NPU diffusion cases.""" + + @pytest.fixture(params=TWO_NPU_CASES, ids=lambda c: c.id) + def case(self, request) -> DiffusionTestCase: + """Provide a DiffusionTestCase for each 2-NPU test.""" + return request.param diff --git a/sglang/python/sglang/multimodal_gen/test/server/ascend/testcase_configs_npu.py b/sglang/python/sglang/multimodal_gen/test/server/ascend/testcase_configs_npu.py new file mode 100644 index 0000000000000000000000000000000000000000..c9e160f10bea30929e3003fde9eec61e595be8af --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/test/server/ascend/testcase_configs_npu.py @@ -0,0 +1,45 @@ +from sglang.multimodal_gen.test.server.testcase_configs import ( + T2V_PROMPT, + DiffusionSamplingParams, + DiffusionServerArgs, + DiffusionTestCase, + T2I_sampling_params, +) + +ONE_NPU_CASES: list[DiffusionTestCase] = [ + # === Text to Image (T2I) === + DiffusionTestCase( + "flux_image_t2i_npu", + DiffusionServerArgs( + model_path="/root/.cache/modelscope/hub/models/black-forest-labs/FLUX.1-dev", + modality="image", + ), + T2I_sampling_params, + ), + # === Text to Video (T2V) === + DiffusionTestCase( + "wan2_1_t2v_1.3b_1_npu", + DiffusionServerArgs( + model_path="/root/.cache/modelscope/hub/models/Wan-AI/Wan2.1-T2V-1.3B-Diffusers", + modality="video", + custom_validator="video", + ), + 
DiffusionSamplingParams( + prompt=T2V_PROMPT, + ), + ), +] + +TWO_NPU_CASES: list[DiffusionTestCase] = [ + # === Text to Image (T2I) === + DiffusionTestCase( + "flux_2_image_t2i_2npu", + DiffusionServerArgs( + model_path="/root/.cache/modelscope/hub/models/black-forest-labs/FLUX.2-dev", + modality="image", + num_gpus=2, + tp_size=2, + ), + T2I_sampling_params, + ), +] diff --git a/sglang/python/sglang/multimodal_gen/test/server/conftest.py b/sglang/python/sglang/multimodal_gen/test/server/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..9b5ddde583d75557cc2f5870f2e6374d94a77db3 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/test/server/conftest.py @@ -0,0 +1,160 @@ +import os + +import pytest + +print("[CONFTEST] Loading conftest.py at import time") + + +def pytest_configure(config): + """ + Create the perf results StashKey once and store it in config. + This hook runs once per test session, before module double-import issues. + """ + if not hasattr(config, "_diffusion_perf_key"): + config._diffusion_perf_key = pytest.StashKey[list]() + print(f"[CONFTEST] Created perf_results_key: {config._diffusion_perf_key}") + + +def add_perf_results(config, results: list): + """Add performance results to the shared stash.""" + # Get the shared key from config (created once in pytest_configure) + key = config._diffusion_perf_key + existing = config.stash.get(key, []) + existing.extend(results) + config.stash[key] = existing + print(f"[CONFTEST] Added {len(results)} results, total now: {len(existing)}") + + +@pytest.fixture(scope="session") +def perf_config(request): + """Provide access to pytest config for storing perf results.""" + return request.config + + +def _write_github_step_summary(content: str): + """Write content to GitHub Step Summary if available.""" + summary_file = os.environ.get("GITHUB_STEP_SUMMARY") + if summary_file: + with open(summary_file, "a") as f: + f.write(content) + + +def _write_results_json(results: list, 
output_path: str = "diffusion-results.json"): + """Write performance results to JSON file for CI artifact collection.""" + import json + + try: + with open(output_path, "w") as f: + json.dump(results, f, indent=2) + print(f"[CONFTEST] Wrote results to {output_path}") + except Exception as e: + print(f"[CONFTEST] Failed to write results JSON: {e}") + + +def _generate_diffusion_markdown_report(results: list) -> str: + """Generate a markdown report for diffusion performance results.""" + if not results: + return "" + + gpu_config = os.environ.get("GPU_CONFIG", "") + header = "## Diffusion Performance Summary" + if gpu_config: + header += f" [{gpu_config}]" + header += "\n\n" + + # Main performance table + markdown = header + markdown += "| Test Suite | Test Name | Modality | E2E (ms) | Avg Denoise (ms) | Median Denoise (ms) |\n" + markdown += "| ---------- | --------- | -------- | -------- | ---------------- | ------------------- |\n" + + for entry in sorted(results, key=lambda x: (x["class_name"], x["test_name"])): + modality = entry.get("modality", "image") + markdown += ( + f"| {entry['class_name']} | {entry['test_name']} | {modality} | " + f"{entry['e2e_ms']:.2f} | {entry['avg_denoise_ms']:.2f} | " + f"{entry['median_denoise_ms']:.2f} |\n" + ) + + # Video-specific metrics table (if any video tests) + video_results = [r for r in results if r.get("modality") == "video"] + if video_results: + markdown += "\n### Video Generation Metrics\n\n" + markdown += "| Test Name | FPS | Total Frames | Avg Frame Time (ms) |\n" + markdown += "| --------- | --- | ------------ | ------------------- |\n" + for entry in video_results: + fps = entry.get("frames_per_second", "N/A") + frames = entry.get("total_frames", "N/A") + avg_frame = entry.get("avg_frame_time_ms", "N/A") + if isinstance(fps, float): + fps = f"{fps:.2f}" + if isinstance(avg_frame, float): + avg_frame = f"{avg_frame:.2f}" + markdown += f"| {entry['test_name']} | {fps} | {frames} | {avg_frame} |\n" + + return markdown 
+ + +def pytest_sessionfinish(session): + """ + This hook is called by pytest at the end of the entire test session. + It prints a consolidated summary of all performance results. + """ + # Get results from stash using the shared key from config + key = session.config._diffusion_perf_key + results = session.config.stash.get(key, []) + print(f"\n[DEBUG] pytest_sessionfinish called, has {len(results)} entries") + if not results: + print("[DEBUG] No results collected, skipping summary output") + return + + # Print to stdout (existing behavior) + print("\n\n" + "=" * 35 + " Performance Summary " + "=" * 35) + print( + f"{'Test Suite':<30} | {'Test Name':<20} | {'E2E (ms)':>12} | {'Avg Denoise (ms)':>18} | {'Median Denoise (ms)':>20}" + ) + print( + "-" * 30 + + "-+-" + + "-" * 20 + + "-+-" + + "-" * 12 + + "-+-" + + "-" * 18 + + "-+-" + + "-" * 20 + ) + + for entry in sorted(results, key=lambda x: x["class_name"]): + print( + f"{entry['class_name']:<30} | {entry['test_name']:<20} | {entry['e2e_ms']:>12.2f} | " + f"{entry['avg_denoise_ms']:>18.2f} | {entry['median_denoise_ms']:>20.2f}" + ) + + print("=" * 91) + + print("\n\n" + "=" * 36 + " Detailed Reports " + "=" * 37) + for entry in sorted(results, key=lambda x: x["class_name"]): + print(f"\n--- Details for {entry['class_name']} / {entry['test_name']} ---") + stage_report = ", ".join( + f"{name}:{duration:.2f}ms" + for name, duration in entry.get("stage_metrics", {}).items() + ) + if stage_report: + print(f" Stages: {stage_report}") + + sampled_steps = entry.get("sampled_steps") or {} + if sampled_steps: + step_report = ", ".join( + f"{idx}:{duration:.2f}ms" + for idx, duration in sorted(sampled_steps.items()) + ) + print(f" Sampled Steps: {step_report}") + print("=" * 91) + + # Write to GitHub Step Summary (new behavior for CI monitoring) + markdown_report = _generate_diffusion_markdown_report(results) + if markdown_report: + _write_github_step_summary(markdown_report) + + # Write results to JSON file for CI 
artifact collection + _write_results_json(results) diff --git a/sglang/python/sglang/multimodal_gen/test/server/perf_baselines.json b/sglang/python/sglang/multimodal_gen/test/server/perf_baselines.json new file mode 100644 index 0000000000000000000000000000000000000000..335d121bad03191ac7e4277273b42ac09acb3abc --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/test/server/perf_baselines.json @@ -0,0 +1,2141 @@ +{ + "metadata": { + "model": "Diffusion Server", + "hardware": "CI H100 80GB pool", + "description": "Reference numbers captured from the CI diffusion server baseline run" + }, + "tolerances": { + "long_term": { + "e2e": 0.1, + "denoise_stage": 0.05, + "non_denoise_stage": 0.4, + "denoise_step": 0.2, + "denoise_agg": 0.1 + }, + "pr_test": { + "e2e": 0.15, + "denoise_stage": 0.1, + "non_denoise_stage": 0.6, + "denoise_step": 0.25, + "denoise_agg": 0.15 + } + }, + "improvement_reporting": { + "threshold": 0.2 + }, + "sampling": { + "step_fractions": [ + 0.0, + 0.2, + 0.4, + 0.6, + 0.8, + 1.0 + ] + }, + "scenarios": { + "qwen_image_t2i": { + "notes": "Single-image generation using the default prompt", + "stages_ms": { + "DecodingStage": 51.86, + "TextEncodingStage": 611.83, + "InputValidationStage": 0.05, + "DenoisingStage": 14289.46, + "LatentPreparationStage": 0.2, + "TimestepPreparationStage": 3.34 + }, + "denoise_step_ms": { + "0": 240.5, + "1": 279.1, + "2": 283.29, + "3": 296.63, + "4": 287.72, + "5": 283.39, + "6": 283.98, + "7": 291.82, + "8": 283.1, + "9": 284.43, + "10": 288.95, + "11": 285.6, + "12": 285.99, + "13": 285.47, + "14": 289.66, + "15": 285.74, + "16": 284.15, + "17": 290.27, + "18": 288.04, + "19": 284.57, + "20": 286.69, + "21": 288.95, + "22": 287.09, + "23": 285.6, + "24": 289.31, + "25": 285.48, + "26": 285.53, + "27": 288.13, + "28": 287.65, + "29": 285.97, + "30": 288.9, + "31": 287.97, + "32": 286.48, + "33": 285.38, + "34": 286.62, + "35": 288.22, + "36": 285.6, + "37": 286.61, + "38": 287.06, + "39": 286.2, + "40": 284.6, + 
"41": 285.69, + "42": 288.46, + "43": 285.53, + "44": 285.34, + "45": 285.74, + "46": 287.25, + "47": 285.0, + "48": 286.82, + "49": 287.19 + }, + "expected_e2e_ms": 14959.11, + "expected_avg_denoise_ms": 285.67, + "expected_median_denoise_ms": 286.1 + }, + "qwen_image_t2i_2_gpus": { + "stages_ms": { + "InputValidationStage": 0.04, + "TextEncodingStage": 693.2, + "TimestepPreparationStage": 2.84, + "LatentPreparationStage": 9.13, + "DenoisingStage": 24529.77, + "DecodingStage": 612.79 + }, + "denoise_step_ms": { + "0": 405.94, + "1": 420.06, + "2": 414.79, + "3": 392.4, + "4": 408.14, + "5": 605.0, + "6": 469.39, + "7": 574.04, + "8": 539.61, + "9": 452.93, + "10": 279.36, + "11": 271.8, + "12": 438.26, + "13": 552.65, + "14": 576.1, + "15": 679.84, + "16": 543.0, + "17": 512.81, + "18": 522.27, + "19": 545.06, + "20": 545.85, + "21": 523.83, + "22": 519.36, + "23": 513.78, + "24": 532.54, + "25": 524.94, + "26": 542.59, + "27": 570.91, + "28": 568.73, + "29": 564.52, + "30": 564.57, + "31": 544.94, + "32": 496.81, + "33": 488.98, + "34": 457.18, + "35": 441.42, + "36": 437.44, + "37": 477.6, + "38": 429.17, + "39": 465.55, + "40": 448.25, + "41": 511.83, + "42": 450.6, + "43": 375.78, + "44": 504.4, + "45": 524.44, + "46": 535.22, + "47": 514.52, + "48": 431.58, + "49": 410.68 + }, + "expected_e2e_ms": 25850.45, + "expected_avg_denoise_ms": 490.43, + "expected_median_denoise_ms": 512.32 + }, + "flux_image_t2i": { + "stages_ms": { + "DecodingStage": 32.72, + "TextEncodingStage": 51.96, + "InputValidationStage": 0.03, + "DenoisingStage": 7545.16, + "LatentPreparationStage": 0.2, + "TimestepPreparationStage": 2.43 + }, + "denoise_step_ms": { + "0": 50.06, + "1": 58.88, + "2": 151.24, + "3": 150.97, + "4": 151.23, + "5": 151.63, + "6": 159.11, + "7": 158.31, + "8": 153.42, + "9": 151.42, + "10": 151.91, + "11": 151.05, + "12": 151.52, + "13": 157.2, + "14": 152.76, + "15": 153.85, + "16": 153.02, + "17": 151.09, + "18": 151.49, + "19": 155.13, + "20": 155.2, + "21": 
152.82, + "22": 152.2, + "23": 150.99, + "24": 152.74, + "25": 153.45, + "26": 153.63, + "27": 154.92, + "28": 152.72, + "29": 151.84, + "30": 151.84, + "31": 152.44, + "32": 153.03, + "33": 154.07, + "34": 152.36, + "35": 153.48, + "36": 152.05, + "37": 152.45, + "38": 152.42, + "39": 154.91, + "40": 152.68, + "41": 153.43, + "42": 151.62, + "43": 153.52, + "44": 153.13, + "45": 152.85, + "46": 152.33, + "47": 151.61, + "48": 152.4, + "49": 152.33 + }, + "expected_e2e_ms": 7798.99, + "expected_avg_denoise_ms": 150.77, + "expected_median_denoise_ms": 152.45 + }, + "flux_2_image_t2i": { + "stages_ms": { + "LatentPreparationStage": 0.52, + "TimestepPreparationStage": 2.91, + "TextEncodingStage": 518.54, + "ImageVAEEncodingStage": 0.0, + "InputValidationStage": 0.05, + "DenoisingStage": 24901.97, + "DecodingStage": 8.98 + }, + "denoise_step_ms": { + "0": 69.14, + "1": 132.57, + "2": 508.67, + "3": 493.52, + "4": 504.31, + "5": 492.99, + "6": 501.91, + "7": 495.18, + "8": 500.87, + "9": 497.36, + "10": 498.74, + "11": 497.46, + "12": 499.08, + "13": 494.65, + "14": 500.35, + "15": 496.89, + "16": 500.23, + "17": 497.01, + "18": 501.68, + "19": 493.8, + "20": 501.1, + "21": 494.81, + "22": 501.04, + "23": 499.27, + "24": 500.04, + "25": 497.14, + "26": 499.05, + "27": 494.91, + "28": 496.89, + "29": 498.53, + "30": 497.94, + "31": 497.09, + "32": 497.7, + "33": 497.58, + "34": 496.43, + "35": 497.7, + "36": 497.37, + "37": 497.17, + "38": 499.27, + "39": 495.52, + "40": 501.67, + "41": 495.11, + "42": 500.69, + "43": 501.61, + "44": 501.91, + "45": 495.58, + "46": 499.37, + "47": 496.8, + "48": 497.49, + "49": 495.69 + }, + "expected_e2e_ms": 25832.82, + "expected_avg_denoise_ms": 489.43, + "expected_median_denoise_ms": 497.53 + }, + "flux_2_klein_image_t2i": { + "stages_ms": { + "DecodingStage": 9.27, + "TextEncodingStage": 92.17, + "InputValidationStage": 0.05, + "ImageVAEEncodingStage": 0.0, + "DenoisingStage": 252.01, + "LatentPreparationStage": 0.42, + 
"TimestepPreparationStage": 1.5 + }, + "denoise_step_ms": { + "0": 19.91, + "1": 19.32, + "2": 51.99, + "3": 61.78 + }, + "expected_e2e_ms": 430.73, + "expected_avg_denoise_ms": 38.25, + "expected_median_denoise_ms": 35.95 + }, + "layerwise_offload": { + "stages_ms": { + "InputValidationStage": 0.06, + "TextEncodingStage": 513.58, + "LatentPreparationStage": 0.46, + "TimestepPreparationStage": 2.38, + "DenoisingStage": 52187.62, + "DecodingStage": 190.31 + }, + "denoise_step_ms": { + "0": 1033.45, + "1": 137.03, + "2": 1046.96, + "3": 1039.28, + "4": 1039.05, + "5": 1043.91, + "6": 1041.75, + "7": 1037.6, + "8": 1043.54, + "9": 1048.63, + "10": 1039.8, + "11": 1042.25, + "12": 1041.54, + "13": 1045.89, + "14": 1038.99, + "15": 1041.82, + "16": 1038.32, + "17": 1045.53, + "18": 1046.54, + "19": 1041.22, + "20": 1044.55, + "21": 1041.31, + "22": 1051.28, + "23": 1043.12, + "24": 1044.65, + "25": 1042.25, + "26": 1046.47, + "27": 1052.9, + "28": 1039.04, + "29": 1042.39, + "30": 1045.33, + "31": 1038.05, + "32": 1037.76, + "33": 1037.93, + "34": 1052.85, + "35": 1045.59, + "36": 1054.32, + "37": 1044.59, + "38": 1043.57, + "39": 1041.93, + "40": 1043.59, + "41": 1046.17, + "42": 1046.92, + "43": 1047.04, + "44": 1046.8, + "45": 1041.86, + "46": 1041.05, + "47": 1044.04, + "48": 1039.77, + "49": 1047.12 + }, + "expected_e2e_ms": 53290.15, + "expected_avg_denoise_ms": 1025.35, + "expected_median_denoise_ms": 1043.33 + }, + "flux_2_ti2i": { + "stages_ms": { + "InputValidationStage": 99.82, + "TextEncodingStage": 519.88, + "ImageVAEEncodingStage": 254.56, + "LatentPreparationStage": 12.4, + "TimestepPreparationStage": 2.71, + "DenoisingStage": 54705.41, + "DecodingStage": 469.47 + }, + "denoise_step_ms": { + "0": 1067.03, + "1": 271.58, + "2": 1073.07, + "3": 1071.93, + "4": 1100.0, + "5": 1102.28, + "6": 1088.3, + "7": 1089.09, + "8": 1086.95, + "9": 1089.33, + "10": 1089.28, + "11": 1096.51, + "12": 1098.88, + "13": 1080.84, + "14": 1098.44, + "15": 1100.88, + "16": 
1086.83, + "17": 1090.58, + "18": 1096.35, + "19": 1086.25, + "20": 1082.71, + "21": 1097.6, + "22": 1098.72, + "23": 1100.9, + "24": 1099.02, + "25": 1101.52, + "26": 1098.75, + "27": 1101.41, + "28": 1091.75, + "29": 1087.2, + "30": 1101.33, + "31": 1098.14, + "32": 1100.14, + "33": 1098.91, + "34": 1100.05, + "35": 1099.12, + "36": 1100.22, + "37": 1103.29, + "38": 1092.79, + "39": 1086.59, + "40": 1094.81, + "41": 1105.6, + "42": 1100.54, + "43": 1099.95, + "44": 1096.5, + "45": 1086.69, + "46": 1095.85, + "47": 1092.85, + "48": 1086.17, + "49": 1099.67 + }, + "expected_e2e_ms": 56308.23, + "expected_avg_denoise_ms": 1077.26, + "expected_median_denoise_ms": 1096.5 + }, + "flux_2_ti2i_multi_image_cache_dit": { + "stages_ms": { + "ImageVAEEncodingStage": 282.83, + "DenoisingStage": 26936.93, + "DecodingStage": 129.33, + "TextEncodingStage": 737.01, + "LatentPreparationStage": 0.84, + "TimestepPreparationStage": 20.57, + "InputValidationStage": 84.29 + }, + "denoise_step_ms": { + "0": 1846.55, + "1": 232.6, + "2": 1621.05, + "3": 1606.1, + "4": 1441.78, + "5": 59.2, + "6": 60.32, + "7": 235.36, + "8": 1443.21, + "9": 59.71, + "10": 59.79, + "11": 236.12, + "12": 1439.11, + "13": 59.97, + "14": 60.46, + "15": 239.29, + "16": 1441.32, + "17": 60.76, + "18": 60.68, + "19": 240.16, + "20": 1442.46, + "21": 60.14, + "22": 61.1, + "23": 239.26, + "24": 1443.28, + "25": 59.41, + "26": 60.74, + "27": 238.02, + "28": 1444.38, + "29": 59.09, + "30": 59.12, + "31": 241.69, + "32": 1441.99, + "33": 60.44, + "34": 61.93, + "35": 241.0, + "36": 1443.56, + "37": 60.55, + "38": 61.07, + "39": 238.11, + "40": 1443.33, + "41": 59.08, + "42": 60.43, + "43": 239.23, + "44": 1444.53, + "45": 61.77, + "46": 61.65, + "47": 239.84, + "48": 1443.89, + "49": 59.92 + }, + "expected_e2e_ms": 28591.93, + "expected_avg_denoise_ms": 533.42, + "expected_median_denoise_ms": 235.74 + }, + "flux_image_t2i_2_gpus": { + "stages_ms": { + "InputValidationStage": 0.03, + "TextEncodingStage": 74.47, + 
"TimestepPreparationStage": 2.23, + "LatentPreparationStage": 6.17, + "DenoisingStage": 8400.49, + "DecodingStage": 381.56 + }, + "denoise_step_ms": { + "0": 73.27, + "1": 166.6, + "2": 167.31, + "3": 168.7, + "4": 168.83, + "5": 171.05, + "6": 174.64, + "7": 170.92, + "8": 169.69, + "9": 169.21, + "10": 167.71, + "11": 177.62, + "12": 166.44, + "13": 174.61, + "14": 170.43, + "15": 169.47, + "16": 167.24, + "17": 169.15, + "18": 169.51, + "19": 172.3, + "20": 172.19, + "21": 172.36, + "22": 168.39, + "23": 168.47, + "24": 170.55, + "25": 170.96, + "26": 168.43, + "27": 169.01, + "28": 169.62, + "29": 170.95, + "30": 171.83, + "31": 171.92, + "32": 170.1, + "33": 170.46, + "34": 169.91, + "35": 168.91, + "36": 170.27, + "37": 170.23, + "38": 169.62, + "39": 169.66, + "40": 169.57, + "41": 169.42, + "42": 168.59, + "43": 171.12, + "44": 169.6, + "45": 169.93, + "46": 171.23, + "47": 171.03, + "48": 170.14, + "49": 169.4 + }, + "expected_e2e_ms": 9006.3, + "expected_avg_denoise_ms": 167.89, + "expected_median_denoise_ms": 169.67 + }, + "zimage_image_t2i": { + "stages_ms": { + "InputValidationStage": 0.03, + "TextEncodingStage": 403.47, + "TimestepPreparationStage": 1.41, + "LatentPreparationStage": 0.11, + "DenoisingStage": 756.21, + "DecodingStage": 29.41 + }, + "denoise_step_ms": { + "0": 22.29, + "1": 75.04, + "2": 93.82, + "3": 93.34, + "4": 93.38, + "5": 93.58, + "6": 94.01, + "7": 93.97, + "8": 94.32 + }, + "expected_e2e_ms": 1292.92, + "expected_avg_denoise_ms": 83.75, + "expected_median_denoise_ms": 93.58 + }, + "zimage_image_t2i_fp8": { + "stages_ms": { + "InputValidationStage": 0.04, + "TextEncodingStage": 428.59, + "LatentPreparationStage": 0.14, + "TimestepPreparationStage": 47.26, + "DenoisingStage": 778.56, + "DecodingStage": 10.39 + }, + "denoise_step_ms": { + "0": 40.9, + "1": 61.08, + "2": 95.65, + "3": 95.83, + "4": 95.65, + "5": 96.09, + "6": 96.23, + "7": 96.04, + "8": 96.29 + }, + "expected_e2e_ms": 1370.28, + "expected_avg_denoise_ms": 85.97, + 
"expected_median_denoise_ms": 95.83 + }, + "zimage_image_t2i_multi_lora": { + "stages_ms": { + "InputValidationStage": 0.04, + "TextEncodingStage": 413.69, + "TimestepPreparationStage": 1.3, + "LatentPreparationStage": 0.11, + "DenoisingStage": 813.7, + "DecodingStage": 34.51 + }, + "denoise_step_ms": { + "0": 30.35, + "1": 74.53, + "2": 99.34, + "3": 100.92, + "4": 99.46, + "5": 100.57, + "6": 99.72, + "7": 100.86, + "8": 103.87 + }, + "expected_e2e_ms": 1464.31, + "expected_avg_denoise_ms": 89.96, + "expected_median_denoise_ms": 99.72 + }, + "zimage_image_t2i_2_gpus": { + "stages_ms": { + "InputValidationStage": 0.08, + "TextEncodingStage": 420.74, + "TimestepPreparationStage": 1.5, + "LatentPreparationStage": 0.12, + "DenoisingStage": 1304.07, + "DecodingStage": 37.83 + }, + "denoise_step_ms": { + "0": 49.76, + "1": 155.22, + "2": 155.98, + "3": 156.16, + "4": 157.04, + "5": 156.54, + "6": 156.29, + "7": 157.36, + "8": 156.05 + }, + "expected_e2e_ms": 1464.87, + "expected_avg_denoise_ms": 144.49, + "expected_median_denoise_ms": 156.16 + }, + "qwen_image_edit_ti2i": { + "stages_ms": { + "LatentPreparationStage": 0.16, + "TimestepPreparationStage": 2.62, + "ImageEncodingStage": 1174.26, + "ImageVAEEncodingStage": 132.67, + "InputValidationStage": 38.1, + "DenoisingStage": 38135.64, + "DecodingStage": 139.72 + }, + "denoise_step_ms": { + "0": 618.31, + "1": 769.07, + "2": 766.91, + "3": 762.77, + "4": 764.26, + "5": 765.27, + "6": 767.35, + "7": 764.18, + "8": 766.16, + "9": 766.89, + "10": 766.1, + "11": 764.96, + "12": 763.52, + "13": 765.22, + "14": 765.44, + "15": 763.9, + "16": 763.19, + "17": 764.83, + "18": 765.36, + "19": 765.19, + "20": 765.96, + "21": 765.74, + "22": 765.87, + "23": 764.85, + "24": 765.44, + "25": 765.95, + "26": 766.21, + "27": 767.91, + "28": 765.45, + "29": 764.81, + "30": 766.26, + "31": 765.37, + "32": 766.71, + "33": 765.67, + "34": 766.64, + "35": 765.98, + "36": 766.04, + "37": 764.19, + "38": 765.15, + "39": 766.33, + "40": 
767.68, + "41": 765.36, + "42": 766.61, + "43": 766.06, + "44": 765.26, + "45": 765.29, + "46": 764.64, + "47": 766.07, + "48": 762.89, + "49": 763.01 + }, + "expected_e2e_ms": 39706.9, + "expected_avg_denoise_ms": 762.57, + "expected_median_denoise_ms": 765.44 + }, + "qwen_image_t2i_cache_dit_enabled": { + "stages_ms": { + "InputValidationStage": 0.05, + "TextEncodingStage": 675.95, + "TimestepPreparationStage": 3.21, + "LatentPreparationStage": 0.2, + "DenoisingStage": 5248.83, + "DecodingStage": 52.24 + }, + "denoise_step_ms": { + "0": 227.68, + "1": 277.41, + "2": 276.7, + "3": 291.52, + "4": 52.8, + "5": 6.58, + "6": 231.58, + "7": 52.55, + "8": 7.69, + "9": 230.59, + "10": 52.58, + "11": 7.14, + "12": 6.95, + "13": 234.71, + "14": 53.28, + "15": 7.09, + "16": 6.63, + "17": 233.93, + "18": 52.71, + "19": 6.64, + "20": 6.5, + "21": 231.37, + "22": 52.28, + "23": 6.61, + "24": 6.48, + "25": 232.86, + "26": 54.92, + "27": 7.51, + "28": 7.19, + "29": 233.51, + "30": 52.97, + "31": 6.72, + "32": 7.02, + "33": 233.14, + "34": 52.47, + "35": 6.66, + "36": 6.52, + "37": 233.84, + "38": 51.49, + "39": 6.87, + "40": 6.74, + "41": 233.75, + "42": 52.65, + "43": 6.62, + "44": 6.55, + "45": 233.45, + "46": 52.33, + "47": 6.55, + "48": 232.58, + "49": 52.84 + }, + "expected_e2e_ms": 5982.78, + "expected_avg_denoise_ms": 104.84, + "expected_median_denoise_ms": 102.01 + }, + "wan2_1_t2v_1.3b_teacache_enabled": { + "stages_ms": { + "DenoisingStage": 4598.36, + "InputValidationStage": 0.07, + "DecodingStage": 552.92, + "LatentPreparationStage": 0.26, + "per_frame_generation": null, + "TextEncodingStage": 1114.01, + "TimestepPreparationStage": 2.1 + }, + "denoise_step_ms": { + "0": 94.24, + "1": 172.68, + "2": 169.48, + "3": 169.08, + "4": 168.38, + "5": 167.27, + "6": 62.95, + "7": 119.56, + "8": 53.34, + "9": 121.85, + "10": 47.64, + "11": 125.75, + "12": 3.24, + "13": 48.21, + "14": 125.17, + "15": 3.71, + "16": 48.15, + "17": 124.61, + "18": 3.3, + "19": 47.25, + "20": 
129.33, + "21": 3.11, + "22": 48.03, + "23": 127.46, + "24": 3.37, + "25": 45.6, + "26": 127.17, + "27": 3.35, + "28": 49.83, + "29": 125.42, + "30": 3.19, + "31": 42.76, + "32": 131.19, + "33": 2.93, + "34": 130.04, + "35": 44.77, + "36": 131.45, + "37": 44.06, + "38": 131.02, + "39": 43.48, + "40": 130.42, + "41": 45.24, + "42": 129.46, + "43": 44.6, + "44": 130.33, + "45": 173.84, + "46": 175.58, + "47": 168.16, + "48": 173.85, + "49": 177.56 + }, + "expected_e2e_ms": 6497.84, + "expected_avg_denoise_ms": 91.85, + "expected_median_denoise_ms": 120.7 + }, + "wan2_1_t2v_1.3b": { + "stages_ms": { + "InputValidationStage": 0.07, + "TextEncodingStage": 2237.78, + "TimestepPreparationStage": 2.1, + "LatentPreparationStage": 0.84, + "DenoisingStage": 13041.23, + "DecodingStage": 1274.63, + "per_frame_generation": null + }, + "denoise_step_ms": { + "0": 224.71, + "1": 248.13, + "2": 246.48, + "3": 247.87, + "4": 249.38, + "5": 246.76, + "6": 250.42, + "7": 250.81, + "8": 250.98, + "9": 249.9, + "10": 246.72, + "11": 249.79, + "12": 250.46, + "13": 249.19, + "14": 247.55, + "15": 250.12, + "16": 247.57, + "17": 247.21, + "18": 247.32, + "19": 247.42, + "20": 248.21, + "21": 247.19, + "22": 247.72, + "23": 247.45, + "24": 247.9, + "25": 247.87, + "26": 247.18, + "27": 247.65, + "28": 246.91, + "29": 248.26, + "30": 247.82, + "31": 247.73, + "32": 247.38, + "33": 247.84, + "34": 247.46, + "35": 247.52, + "36": 247.94, + "37": 248.76, + "38": 248.01, + "39": 247.45, + "40": 247.84, + "41": 248.33, + "42": 247.41, + "43": 248.16, + "44": 248.18, + "45": 248.44, + "46": 248.65, + "47": 247.73, + "48": 247.48, + "49": 247.54 + }, + "expected_e2e_ms": 18382.19, + "expected_avg_denoise_ms": 260.76, + "expected_median_denoise_ms": 247.84 + }, + "wan2_1_t2v_1.3b_text_encoder_cpu_offload": { + "stages_ms": { + "InputValidationStage": 0.09, + "TextEncodingStage": 2480.54, + "TimestepPreparationStage": 3.73, + "LatentPreparationStage": 1.34, + "DenoisingStage": 12514.88, + 
"DecodingStage": 1147.6, + "per_frame_generation": null + }, + "denoise_step_ms": { + "0": 487.21, + "1": 243.47, + "2": 244.28, + "3": 244.06, + "4": 244.77, + "5": 245.86, + "6": 245.38, + "7": 246.74, + "8": 246.28, + "9": 245.58, + "10": 245.6, + "11": 245.21, + "12": 245.08, + "13": 245.03, + "14": 245.53, + "15": 245.36, + "16": 246.17, + "17": 245.32, + "18": 244.37, + "19": 246.83, + "20": 245.87, + "21": 244.93, + "22": 245.11, + "23": 245.23, + "24": 245.76, + "25": 245.44, + "26": 246.47, + "27": 244.56, + "28": 244.76, + "29": 244.79, + "30": 244.76, + "31": 244.8, + "32": 245.11, + "33": 245.27, + "34": 245.37, + "35": 245.3, + "36": 244.84, + "37": 245.26, + "38": 245.38, + "39": 245.31, + "40": 244.7, + "41": 245.84, + "42": 245.66, + "43": 246.68, + "44": 245.38, + "45": 245.98, + "46": 246.02, + "47": 245.96, + "48": 245.31, + "49": 244.99 + }, + "expected_e2e_ms": 16161.11, + "expected_avg_denoise_ms": 250.18, + "expected_median_denoise_ms": 245.32 + }, + "wan2_1_t2v_1.3b_cfg_parallel": { + "stages_ms": { + "InputValidationStage": 0.08, + "TextEncodingStage": 2700.44, + "TimestepPreparationStage": 2.82, + "LatentPreparationStage": 2.0, + "DenoisingStage": 11640.75, + "DecodingStage": 890.63, + "per_frame_generation": null + }, + "denoise_step_ms": { + "0": 266.91, + "1": 211.32, + "2": 206.59, + "3": 208.12, + "4": 210.68, + "5": 210.28, + "6": 213.92, + "7": 211.25, + "8": 212.89, + "9": 205.35, + "10": 205.92, + "11": 208.99, + "12": 207.1, + "13": 208.1, + "14": 206.52, + "15": 205.5, + "16": 205.24, + "17": 204.93, + "18": 207.05, + "19": 203.78, + "20": 205.23, + "21": 203.87, + "22": 204.28, + "23": 203.8, + "24": 206.02, + "25": 207.2, + "26": 209.53, + "27": 207.46, + "28": 206.77, + "29": 208.14, + "30": 208.05, + "31": 208.78, + "32": 209.23, + "33": 209.72, + "34": 208.26, + "35": 208.55, + "36": 205.24, + "37": 204.96, + "38": 203.77, + "39": 210.2, + "40": 202.57, + "41": 204.77, + "42": 204.96, + "43": 203.8, + "44": 203.9, + "45": 
204.49, + "46": 207.75, + "47": 209.09, + "48": 207.51, + "49": 207.38 + }, + "expected_e2e_ms": 15245.6, + "expected_avg_denoise_ms": 224.37, + "expected_median_denoise_ms": 207.15 + }, + "turbo_wan2_1_t2v_1.3b": { + "stages_ms": { + "InputValidationStage": 0.06, + "TextEncodingStage": 2508.95, + "TimestepPreparationStage": 73.51, + "LatentPreparationStage": 1.34, + "DmdDenoisingStage": 1285.25, + "DecodingStage": 805.04, + "per_frame_generation": null + }, + "denoise_step_ms": { + "0": 897.62, + "1": 126.04, + "2": 126.52, + "3": 128.26 + }, + "expected_e2e_ms": 4686.66, + "expected_avg_denoise_ms": 319.61, + "expected_median_denoise_ms": 127.39 + }, + "wan2_2_ti2v_5b": { + "stages_ms": { + "InputValidationStage": 96.27, + "TextEncodingStage": 2238.81, + "TimestepPreparationStage": 2.39, + "LatentPreparationStage": 27.62, + "DenoisingStage": 134069.79, + "DecodingStage": 13559.79, + "per_frame_generation": null + }, + "denoise_step_ms": { + "0": 3181.0, + "1": 2561.67, + "2": 2578.49, + "3": 2582.1, + "4": 2572.24, + "5": 2577.72, + "6": 2581.35, + "7": 2578.79, + "8": 2584.98, + "9": 2588.49, + "10": 2594.37, + "11": 2591.19, + "12": 2591.32, + "13": 2595.35, + "14": 2594.35, + "15": 2595.62, + "16": 2596.35, + "17": 2596.11, + "18": 2597.24, + "19": 2603.13, + "20": 2599.9, + "21": 2601.48, + "22": 2603.58, + "23": 2601.13, + "24": 2600.47, + "25": 2604.13, + "26": 2606.04, + "27": 2605.3, + "28": 2602.02, + "29": 2601.83, + "30": 2603.57, + "31": 2606.63, + "32": 2606.1, + "33": 2602.24, + "34": 2603.29, + "35": 2602.34, + "36": 2602.16, + "37": 2608.14, + "38": 2603.48, + "39": 2601.7, + "40": 2603.96, + "41": 2604.58, + "42": 2606.67, + "43": 2603.52, + "44": 2599.88, + "45": 2598.66, + "46": 2600.74, + "47": 2602.31, + "48": 2608.4, + "49": 2606.02 + }, + "expected_e2e_ms": 150004.2, + "expected_avg_denoise_ms": 2608.84, + "expected_median_denoise_ms": 2601.59 + }, + "qwen_image_edit_2509_ti2i": { + "stages_ms": { + "InputValidationStage": 213.24, + 
"ImageEncodingStage": 1089.12, + "ImageVAEEncodingStage": 304.56, + "TimestepPreparationStage": 2.94, + "LatentPreparationStage": 0.2, + "DenoisingStage": 50724.5, + "DecodingStage": 601.02 + }, + "denoise_step_ms": { + "0": 1057.09, + "1": 1267.06, + "2": 1268.33, + "3": 1268.94, + "4": 1270.36, + "5": 1270.44, + "6": 1268.61, + "7": 1270.21, + "8": 1274.98, + "9": 1271.57, + "10": 1273.15, + "11": 1271.56, + "12": 1272.69, + "13": 1271.62, + "14": 1274.04, + "15": 1276.81, + "16": 1272.2, + "17": 1269.33, + "18": 1275.96, + "19": 1274.43, + "20": 1272.57, + "21": 1275.28, + "22": 1273.63, + "23": 1275.06, + "24": 1277.39, + "25": 1277.27, + "26": 1274.74, + "27": 1273.38, + "28": 1276.77, + "29": 1275.59, + "30": 1275.51, + "31": 1274.9, + "32": 1274.8, + "33": 1279.03, + "34": 1272.9, + "35": 1274.67, + "36": 1272.61, + "37": 1272.82, + "38": 1276.41, + "39": 1273.55 + }, + "expected_e2e_ms": 52938.04, + "expected_avg_denoise_ms": 1267.96, + "expected_median_denoise_ms": 1273.46 + }, + "qwen_image_layered_i2i": { + "stages_ms": { + "QwenImageLayeredBeforeDenoisingStage": 2897.28, + "DecodingStage": 312.93, + "DenoisingStage": 39417.66, + "TimestepPreparationStage": 2.29 + }, + "denoise_step_ms": { + "0": 657.28, + "1": 799.2, + "2": 790.35, + "3": 785.79, + "4": 792.9, + "5": 795.78, + "6": 791.28, + "7": 790.87, + "8": 786.47, + "9": 791.03, + "10": 788.77, + "11": 790.57, + "12": 788.7, + "13": 786.01, + "14": 791.43, + "15": 789.88, + "16": 791.18, + "17": 792.78, + "18": 792.06, + "19": 790.47, + "20": 792.48, + "21": 789.13, + "22": 792.12, + "23": 789.36, + "24": 790.2, + "25": 790.87, + "26": 792.37, + "27": 794.92, + "28": 792.9, + "29": 791.43, + "30": 793.01, + "31": 793.71, + "32": 794.15, + "33": 787.93, + "34": 792.12, + "35": 794.01, + "36": 789.05, + "37": 790.51, + "38": 793.29, + "39": 791.94, + "40": 788.94, + "41": 788.85, + "42": 789.76, + "43": 788.89, + "44": 791.62, + "45": 788.04, + "46": 790.03, + "47": 786.82, + "48": 789.75, + "49": 
789.0 + }, + "expected_e2e_ms": 42660.88, + "expected_avg_denoise_ms": 788.2, + "expected_median_denoise_ms": 790.72 + }, + "fastwan2_2_ti2v_5b": { + "stages_ms": { + "InputValidationStage": 300.0, + "TextEncodingStage": 843.86, + "TimestepPreparationStage": 58.66, + "LatentPreparationStage": 28.55, + "DmdDenoisingStage": 499.34, + "DecodingStage": 1924.01, + "per_frame_generation": null + }, + "denoise_step_ms": { + "0": 164.76, + "1": 165.6, + "2": 165.84 + }, + "expected_e2e_ms": 7722.91, + "expected_avg_denoise_ms": 165.42, + "expected_median_denoise_ms": 165.66 + }, + "fast_hunyuan_video": { + "stages_ms": { + "InputValidationStage": 0.34, + "TextEncodingStage": 550.63, + "TimestepPreparationStage": 44.28, + "LatentPreparationStage": 0.29, + "DenoisingStage": 9154.39, + "DecodingStage": 5995.09, + "per_frame_generation": null + }, + "denoise_step_ms": { + "0": 485.99, + "1": 1399.84, + "2": 1399.91, + "3": 1397.79, + "4": 1400.61, + "5": 1402.53 + }, + "expected_e2e_ms": 16672.15, + "expected_avg_denoise_ms": 1608.46, + "expected_median_denoise_ms": 1488.48 + }, + "wan2_2_i2v_a14b_2gpu": { + "stages_ms": { + "InputValidationStage": 18.45, + "TextEncodingStage": 3337.77, + "TimestepPreparationStage": 2.9, + "LatentPreparationStage": 1.25, + "ImageVAEEncodingStage": 1655.89, + "DenoisingStage": 106972.82, + "DecodingStage": 1355.52, + "per_frame_generation": null + }, + "denoise_step_ms": { + "0": 1525.6, + "1": 1582.6, + "2": 1597.84, + "3": 1601.34, + "4": 1600.86, + "5": 1598.32, + "6": 1600.93, + "7": 1599.88, + "8": 1600.0, + "9": 1600.55, + "10": 1599.27, + "11": 1600.59, + "12": 1600.17, + "13": 1599.72, + "14": 1599.76, + "15": 24098.85, + "16": 1601.29, + "17": 1598.89, + "18": 1600.12, + "19": 1600.52, + "20": 1599.59, + "21": 1600.37, + "22": 1600.35, + "23": 1599.7, + "24": 1599.92, + "25": 1599.75, + "26": 1600.2, + "27": 1600.06, + "28": 1600.41, + "29": 1599.35, + "30": 1600.69, + "31": 1600.15, + "32": 1599.33, + "33": 1599.86, + "34": 1600.52, + 
"35": 1599.84, + "36": 1600.38, + "37": 1599.23, + "38": 1600.27, + "39": 1599.78 + }, + "expected_e2e_ms": 123182.9887, + "expected_avg_denoise_ms": 2831.0, + "expected_median_denoise_ms": 1600.09 + }, + "turbo_wan2_2_i2v_a14b_2gpu": { + "stages_ms": { + "InputValidationStage": 25.01, + "TextEncodingStage": 5198.6, + "TimestepPreparationStage": 56.26, + "LatentPreparationStage": 1.4, + "ImageVAEEncodingStage": 1001.89, + "DmdDenoisingStage": 4487.79, + "DecodingStage": 821.01, + "per_frame_generation": null + }, + "denoise_step_ms": { + "0": 3042.56, + "1": 485.88, + "2": 721.2, + "3": 475.58 + }, + "expected_e2e_ms": 11605.97, + "expected_avg_denoise_ms": 1120.4, + "expected_median_denoise_ms": 481.74 + }, + "wan2_1_i2v_14b_480P_2gpu": { + "stages_ms": { + "InputValidationStage": 38.23, + "TextEncodingStage": 3550.36, + "ImageEncodingStage": 3462.55, + "TimestepPreparationStage": 2.6, + "LatentPreparationStage": 9.73, + "ImageVAEEncodingStage": 2290.98, + "DenoisingStage": 415021.17, + "DecodingStage": 3016.1, + "per_frame_generation": null + }, + "denoise_step_ms": { + "0": 10200.25, + "1": 8222.39, + "2": 8279.38, + "3": 8301.48, + "4": 8338.87, + "5": 8352.39, + "6": 8354.64, + "7": 8353.64, + "8": 8315.58, + "9": 8308.48, + "10": 8299.65, + "11": 8292.7, + "12": 8292.73, + "13": 8285.21, + "14": 8276.06, + "15": 8270.41, + "16": 8273.04, + "17": 8266.04, + "18": 8267.7, + "19": 8264.06, + "20": 8259.32, + "21": 8257.26, + "22": 8253.02, + "23": 8251.77, + "24": 8260.97, + "25": 8251.39, + "26": 8237.43, + "27": 8241.33, + "28": 8235.96, + "29": 8240.6, + "30": 8232.48, + "31": 8237.85, + "32": 8244.3, + "33": 8236.79, + "34": 8239.83, + "35": 8239.89, + "36": 8239.12, + "37": 8246.74, + "38": 8235.67, + "39": 8242.77, + "40": 8241.17, + "41": 8240.24, + "42": 8237.01, + "43": 8231.26, + "44": 8232.85, + "45": 8226.56, + "46": 8236.98, + "47": 8226.73, + "48": 8220.49, + "49": 8217.04 + }, + "expected_e2e_ms": 426697.37, + "expected_avg_denoise_ms": 8300.19, + 
"expected_median_denoise_ms": 8267.01 + }, + "wan2_1_i2v_14b_720P_2gpu": { + "stages_ms": { + "InputValidationStage": 53.67, + "TextEncodingStage": 2838, + "ImageEncodingStage": 3123.99, + "TimestepPreparationStage": 3.39, + "LatentPreparationStage": 8.41, + "ImageVAEEncodingStage": 2261.05, + "DenoisingStage": 417418.12, + "DecodingStage": 2968.35 + }, + "denoise_step_ms": { + "0": 11848.08, + "1": 8220.3, + "2": 8274.3, + "3": 8298.9, + "4": 8303.34, + "5": 8322.44, + "6": 8314.37, + "7": 8318.54, + "8": 8304.94, + "9": 8303.04, + "10": 8305.22, + "11": 8296.22, + "12": 8289.2, + "13": 8294.19, + "14": 8294.87, + "15": 8285.96, + "16": 8284.98, + "17": 8281.61, + "18": 8277.35, + "19": 8287.46, + "20": 8280.3, + "21": 8279.18, + "22": 8279.37, + "23": 8280.16, + "24": 8282.67, + "25": 8272.14, + "26": 8279.37, + "27": 8271.66, + "28": 8274.6, + "29": 8272.88, + "30": 8273.76, + "31": 8266.17, + "32": 8267.77, + "33": 8266.88, + "34": 8263.14, + "35": 8265.97, + "36": 8267.76, + "37": 8268.03, + "38": 8262.24, + "39": 8261.4, + "40": 8263.65, + "41": 8272.46, + "42": 8254.9, + "43": 8261.03, + "44": 8252.92, + "45": 8262.49, + "46": 8253.67, + "47": 8254.92, + "48": 8257.08, + "49": 8236.56 + }, + "expected_e2e_ms": 427536.9, + "expected_avg_denoise_ms": 8348.21, + "expected_median_denoise_ms": 8274.45 + }, + "wan2_2_t2v_a14b_2gpu": { + "stages_ms": { + "InputValidationStage": 0.07, + "TextEncodingStage": 2575.3, + "TimestepPreparationStage": 1.99, + "LatentPreparationStage": 1.26, + "DenoisingStage": 156678.8406, + "DecodingStage": 2702.7, + "per_frame_generation": null + }, + "denoise_step_ms": { + "0": 17908.3, + "1": 2379.69, + "2": 2393.59, + "3": 2400.91, + "4": 2398.76, + "5": 2403.1, + "6": 2403.26, + "7": 2399.48, + "8": 2401.33, + "9": 2398.4, + "10": 2401.14, + "11": 2409.1, + "12": 2401.16, + "13": 2408.74, + "14": 2404.97, + "15": 2400.51, + "16": 2402.84, + "17": 2401.87, + "18": 2399.67, + "19": 2400.71, + "20": 2399.23, + "21": 2400.13, + "22": 
2400.64, + "23": 2399.15, + "24": 2399.58, + "25": 2400.26, + "26": 35247.02, + "27": 2390.25, + "28": 2398.42, + "29": 2399.8, + "30": 2400.08, + "31": 2400.58, + "32": 2403.68, + "33": 2399.37, + "34": 2401.53, + "35": 2399.69, + "36": 2399.9, + "37": 2400.75, + "38": 2398.97, + "39": 2399.12 + }, + "expected_e2e_ms": 149864.99, + "expected_avg_denoise_ms": 3608.89, + "expected_median_denoise_ms": 2400.38 + }, + "wan2_1_t2v_14b_2gpu": { + "stages_ms": { + "InputValidationStage": 0.05, + "TextEncodingStage": 2310.34, + "TimestepPreparationStage": 2.42, + "LatentPreparationStage": 27.7, + "DenoisingStage": 803631.52, + "DecodingStage": 8898.74, + "per_frame_generation": null + }, + "denoise_step_ms": { + "0": 17347.88, + "1": 15956.93, + "2": 16027.54, + "3": 16054.15, + "4": 16081.46, + "5": 16062.7, + "6": 16058.56, + "7": 16057.58, + "8": 16061.04, + "9": 16120.97, + "10": 16036.84, + "11": 16019.6, + "12": 16042.29, + "13": 16039.87, + "14": 16063.0, + "15": 16036.16, + "16": 16079.82, + "17": 16019.7, + "18": 16061.5, + "19": 16039.95, + "20": 16009.42, + "21": 16051.01, + "22": 16039.31, + "23": 16048.22, + "24": 16071.41, + "25": 16078.75, + "26": 16061.78, + "27": 16018.39, + "28": 16041.44, + "29": 16039.64, + "30": 16041.89, + "31": 16039.6, + "32": 16038.97, + "33": 15999.48, + "34": 16019.93, + "35": 16040.27, + "36": 16020.3, + "37": 16039.38, + "38": 15999.4, + "39": 16022.15, + "40": 16042.32, + "41": 16016.62, + "42": 15998.92, + "43": 16041.48, + "44": 15999.63, + "45": 16003.21, + "46": 15995.91, + "47": 16023.52, + "48": 16016.64, + "49": 16019.6 + }, + "expected_e2e_ms": 814884.71, + "expected_avg_denoise_ms": 16062.92, + "expected_median_denoise_ms": 16039.62 + }, + "wan2_2_t2v_a14b_lora_2gpu": { + "stages_ms": { + "InputValidationStage": 0.09, + "TextEncodingStage": 2552.97, + "TimestepPreparationStage": 1.99, + "LatentPreparationStage": 1.29, + "DenoisingStage": 154340.69, + "DecodingStage": 2730.86, + "per_frame_generation": null + }, + 
"denoise_step_ms": { + "0": 26510.7, + "1": 2381.25, + "2": 2396.9, + "3": 2400.96, + "4": 2402.47, + "5": 2399.6, + "6": 2400.5, + "7": 2401.13, + "8": 2399.32, + "9": 2400.0, + "10": 2401.35, + "11": 2400.04, + "12": 2408.27, + "13": 2407.08, + "14": 2405.92, + "15": 2403.99, + "16": 2402.12, + "17": 2402.52, + "18": 2398.08, + "19": 2399.9, + "20": 2400.14, + "21": 2398.64, + "22": 2401.32, + "23": 2400.75, + "24": 2399.27, + "25": 2400.21, + "26": 36387.55, + "27": 2399.77, + "28": 2398.09, + "29": 2404.64, + "30": 2400.68, + "31": 2404.3, + "32": 2392.44, + "33": 2390.56, + "34": 2396.05, + "35": 2394.86, + "36": 2396.07, + "37": 2398.49, + "38": 2394.77, + "39": 2394.19 + }, + "expected_e2e_ms": 159643.06, + "expected_avg_denoise_ms": 3851.87, + "expected_median_denoise_ms": 2400.09 + }, + "wan2_1_t2v_1_3b_lora_1gpu": { + "stages_ms": { + "InputValidationStage": 0.06, + "TextEncodingStage": 2467.44, + "TimestepPreparationStage": 2.96, + "LatentPreparationStage": 1.87, + "DenoisingStage": 14859.47, + "DecodingStage": 1199.31, + "per_frame_generation": null + }, + "denoise_step_ms": { + "0": 233.29, + "1": 265.02, + "2": 257.83, + "3": 260.27, + "4": 261.43, + "5": 258.58, + "6": 256.64, + "7": 256.91, + "8": 258.41, + "9": 257.84, + "10": 257.08, + "11": 257.0, + "12": 258.44, + "13": 257.1, + "14": 256.95, + "15": 257.2, + "16": 256.84, + "17": 257.64, + "18": 257.22, + "19": 257.42, + "20": 256.91, + "21": 256.99, + "22": 257.17, + "23": 257.63, + "24": 258.89, + "25": 257.46, + "26": 257.3, + "27": 257.42, + "28": 257.19, + "29": 257.65, + "30": 257.39, + "31": 256.93, + "32": 258.23, + "33": 257.62, + "34": 281.86, + "35": 295.86, + "36": 296.73, + "37": 287.21, + "38": 300.87, + "39": 303.47, + "40": 294.09, + "41": 270.52, + "42": 256.53, + "43": 256.58, + "44": 256.29, + "45": 255.81, + "46": 256.34, + "47": 256.08, + "48": 255.92, + "49": 255.87 + }, + "expected_e2e_ms": 18547.46, + "expected_avg_denoise_ms": 297.09, + "expected_median_denoise_ms": 
257.42 + }, + "wan2_1_i2v_14b_lora_2gpu": { + "stages_ms": { + "InputValidationStage": 23.97, + "TextEncodingStage": 2485.39, + "ImageEncodingStage": 2372.07, + "TimestepPreparationStage": 2.6, + "LatentPreparationStage": 0.18, + "ImageVAEEncodingStage": 2500.13, + "DenoisingStage": 193514.04, + "DecodingStage": 3341.78, + "per_frame_generation": null + }, + "denoise_step_ms": { + "0": 6680.62, + "1": 3765.8, + "2": 3774.63, + "3": 3772.93, + "4": 3781.13, + "5": 3778.22, + "6": 3776.41, + "7": 3772.02, + "8": 3776.15, + "9": 3768.82, + "10": 3775.31, + "11": 3771.32, + "12": 3774.33, + "13": 3772.5, + "14": 3778.41, + "15": 3775.31, + "16": 3771.38, + "17": 3774.87, + "18": 3780.01, + "19": 3772.85, + "20": 3773.65, + "21": 3774.47, + "22": 3774.39, + "23": 3773.08, + "24": 3776.71, + "25": 3780.01, + "26": 3774.83, + "27": 3773.27, + "28": 3773.76, + "29": 3772.75, + "30": 3773.01, + "31": 3773.34, + "32": 3773.13, + "33": 3774.12, + "34": 3772.19, + "35": 3774.7, + "36": 3773.98, + "37": 3772.47, + "38": 3771.72, + "39": 3774.07, + "40": 3773.71, + "41": 3773.6, + "42": 3772.12, + "43": 3773.75, + "44": 3782.43, + "45": 3779.66, + "46": 3779.86, + "47": 3774.58, + "48": 3770.54, + "49": 3776.76 + }, + "expected_e2e_ms": 204257.12, + "expected_avg_denoise_ms": 3855.55, + "expected_median_denoise_ms": 3774.03 + }, + "flux_2_image_t2i_2_gpus": { + "stages_ms": { + "InputValidationStage": 0.05, + "TextEncodingStage": 518.88, + "ImageVAEEncodingStage": 0.0, + "LatentPreparationStage": 0.45, + "TimestepPreparationStage": 3.41, + "DenoisingStage": 26377.63, + "DecodingStage": 321.94 + }, + "denoise_step_ms": { + "0": 129.07, + "1": 437.16, + "2": 437.7, + "3": 437.67, + "4": 437.84, + "5": 438.03, + "6": 438.09, + "7": 437.65, + "8": 437.95, + "9": 438.31, + "10": 437.99, + "11": 438.54, + "12": 438.47, + "13": 438.2, + "14": 438.56, + "15": 438.69, + "16": 438.69, + "17": 438.98, + "18": 437.96, + "19": 438.9, + "20": 438.87, + "21": 438.04, + "22": 437.88, + "23": 
439.09, + "24": 438.61, + "25": 437.68, + "26": 439.2, + "27": 439.63, + "28": 438.65, + "29": 439.32, + "30": 439.01, + "31": 438.84, + "32": 438.72, + "33": 439.09, + "34": 438.3, + "35": 439.48, + "36": 438.2, + "37": 439.67, + "38": 440.65, + "39": 439.96, + "40": 439.0, + "41": 439.2, + "42": 439.37, + "43": 439.98, + "44": 438.6, + "45": 439.58, + "46": 440.23, + "47": 440.1, + "48": 440.21, + "49": 439.22 + }, + "expected_e2e_ms": 27624.8, + "expected_avg_denoise_ms": 518.23, + "expected_median_denoise_ms": 528.06 + }, + "qwen_image_edit_2511_ti2i": { + "stages_ms": { + "InputValidationStage": 55.15, + "ImageEncodingStage": 770.33, + "ImageVAEEncodingStage": 88.06, + "TimestepPreparationStage": 2.12, + "LatentPreparationStage": 0.14, + "DenoisingStage": 23869.32, + "DecodingStage": 108.23 + }, + "denoise_step_ms": { + "0": 478.35, + "1": 608.56, + "2": 588.51, + "3": 607.26, + "4": 599.37, + "5": 595.19, + "6": 603.22, + "7": 594.48, + "8": 605.06, + "9": 597.63, + "10": 601.03, + "11": 597.18, + "12": 598.82, + "13": 600.05, + "14": 598.57, + "15": 601.4, + "16": 595.17, + "17": 599.21, + "18": 600.86, + "19": 600.93, + "20": 600.35, + "21": 600.63, + "22": 597.58, + "23": 600.73, + "24": 599.36, + "25": 600.48, + "26": 600.33, + "27": 599.34, + "28": 599.61, + "29": 599.71, + "30": 596.03, + "31": 599.85, + "32": 599.36, + "33": 601.58, + "34": 597.91, + "35": 600.79, + "36": 599.29, + "37": 601.64, + "38": 598.24, + "39": 599.87 + }, + "expected_e2e_ms": 24895.28, + "expected_avg_denoise_ms": 596.59, + "expected_median_denoise_ms": 599.66 + }, + "fsdp-inference": { + "stages_ms": { + "InputValidationStage": 0.04, + "TextEncodingStage": 411.12, + "TimestepPreparationStage": 1.44, + "LatentPreparationStage": 0.1, + "DenoisingStage": 1569.61, + "DecodingStage": 41.43 + }, + "denoise_step_ms": { + "0": 165.33, + "1": 158.34, + "2": 167.65, + "3": 179.11, + "4": 183.98, + "5": 175.08, + "6": 178.34, + "7": 178.53, + "8": 178.08 + }, + "expected_e2e_ms": 
2103.05, + "expected_avg_denoise_ms": 173.83, + "expected_median_denoise_ms": 178.08 + }, + "hunyuan3d_shape_gen": { + "stages_ms": { + "Hunyuan3DShapeBeforeDenoisingStage": 31.42, + "Hunyuan3DShapeDenoisingStage": 3259.83, + "Hunyuan3DShapeExportStage": 8735.55, + "Hunyuan3DShapeSaveStage": 981.64, + "Hunyuan3DPaintPreprocessStage": 226071.67, + "Hunyuan3DPaintTexGenStage": 11083.05, + "Hunyuan3DPaintPostprocessStage": 7469.29 + }, + "denoise_step_ms": { + "0": 32.26, + "1": 63.34, + "2": 65.44, + "3": 65.44, + "4": 65.6, + "5": 65.81, + "6": 65.82, + "7": 65.48, + "8": 65.9, + "9": 65.77, + "10": 65.54, + "11": 65.68, + "12": 65.85, + "13": 65.77, + "14": 65.7, + "15": 65.78, + "16": 66.0, + "17": 66.15, + "18": 65.91, + "19": 66.5, + "20": 65.76, + "21": 66.08, + "22": 66.06, + "23": 66.23, + "24": 65.79, + "25": 65.58, + "26": 65.88, + "27": 65.67, + "28": 65.87, + "29": 66.09, + "30": 65.81, + "31": 65.91, + "32": 66.18, + "33": 65.93, + "34": 66.26, + "35": 66.26, + "36": 66.27, + "37": 65.57, + "38": 66.02, + "39": 66.19, + "40": 65.23, + "41": 66.11, + "42": 66.18, + "43": 65.86, + "44": 65.86, + "45": 65.92, + "46": 65.65, + "47": 65.78, + "48": 66.01, + "49": 66.08 + }, + "expected_e2e_ms": 257696.97, + "expected_avg_denoise_ms": 65.16, + "expected_median_denoise_ms": 65.86 + }, + "wan2_1_t2v_1.3b_frame_interp_2x": { + "stages_ms": { + "TextEncodingStage": 1104.4, + "TimestepPreparationStage": 2.19, + "LatentPreparationStage": 0.15, + "DenoisingStage": 8502.22, + "DecodingStage": 498.36, + "InputValidationStage": 0.07 + }, + "denoise_step_ms": { + "0": 91.83, + "1": 174.57, + "2": 170.48, + "3": 169.33, + "4": 169.24, + "5": 177.43, + "6": 173.73, + "7": 171.67, + "8": 170.98, + "9": 168.61, + "10": 169.96, + "11": 174.75, + "12": 172.33, + "13": 170.62, + "14": 169.84, + "15": 168.86, + "16": 171.32, + "17": 174.7, + "18": 172.31, + "19": 171.71, + "20": 170.98, + "21": 169.83, + "22": 170.54, + "23": 173.08, + "24": 172.11, + "25": 171.49, + "26": 
171.0, + "27": 170.9, + "28": 171.78, + "29": 173.44, + "30": 171.14, + "31": 170.72, + "32": 170.64, + "33": 170.58, + "34": 172.51, + "35": 171.74, + "36": 171.57, + "37": 170.73, + "38": 171.49, + "39": 170.98, + "40": 172.63, + "41": 171.88, + "42": 171.71, + "43": 170.94, + "44": 170.31, + "45": 171.25, + "46": 171.43, + "47": 171.55, + "48": 172.08, + "49": 169.92 + }, + "expected_e2e_ms": 10464.97, + "expected_avg_denoise_ms": 169.92, + "expected_median_denoise_ms": 171.37 + }, + "flux_2_klein_ti2i_2_gpus": { + "stages_ms": { + "InputValidationStage": 40.19, + "TextEncodingStage": 88.84, + "ImageVAEEncodingStage": 80.81, + "LatentPreparationStage": 1.05, + "TimestepPreparationStage": 28.64, + "DenoisingStage": 354.04, + "DecodingStage": 11.11 + }, + "denoise_step_ms": { + "0": 33.54, + "1": 61.3, + "2": 86.9, + "3": 87.55 + }, + "expected_e2e_ms": 716.81, + "expected_avg_denoise_ms": 67.32, + "expected_median_denoise_ms": 74.1 + } + } +} diff --git a/sglang/python/sglang/multimodal_gen/test/server/test_server_2_gpu_a.py b/sglang/python/sglang/multimodal_gen/test/server/test_server_2_gpu_a.py new file mode 100644 index 0000000000000000000000000000000000000000..3668f63e6334663a6c536863f15e09b1c4f13302 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/test/server/test_server_2_gpu_a.py @@ -0,0 +1,25 @@ +""" +2 GPU tests +""" + +from __future__ import annotations + +import pytest + +from sglang.multimodal_gen.test.server.test_server_common import ( # noqa: F401 + DiffusionServerBase, + diffusion_server, +) +from sglang.multimodal_gen.test.server.testcase_configs import ( + TWO_GPU_CASES_A, + DiffusionTestCase, +) + + +class TestDiffusionServerTwoGpu(DiffusionServerBase): + """Performance tests for 2-GPU diffusion cases.""" + + @pytest.fixture(params=TWO_GPU_CASES_A, ids=lambda c: c.id) + def case(self, request) -> DiffusionTestCase: + """Provide a DiffusionTestCase for each 2-GPU test.""" + return request.param diff --git 
a/sglang/python/sglang/multimodal_gen/test/server/test_server_2_gpu_b.py b/sglang/python/sglang/multimodal_gen/test/server/test_server_2_gpu_b.py new file mode 100644 index 0000000000000000000000000000000000000000..2c9b5cdc7640ab9186a43cc4a586adce2352a403 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/test/server/test_server_2_gpu_b.py @@ -0,0 +1,25 @@ +""" +2 GPU tests +""" + +from __future__ import annotations + +import pytest + +from sglang.multimodal_gen.test.server.test_server_common import ( # noqa: F401 + DiffusionServerBase, + diffusion_server, +) +from sglang.multimodal_gen.test.server.testcase_configs import ( + TWO_GPU_CASES_B, + DiffusionTestCase, +) + + +class TestDiffusionServerTwoGpu(DiffusionServerBase): + """Performance tests for 2-GPU diffusion cases.""" + + @pytest.fixture(params=TWO_GPU_CASES_B, ids=lambda c: c.id) + def case(self, request) -> DiffusionTestCase: + """Provide a DiffusionTestCase for each 2-GPU test.""" + return request.param diff --git a/sglang/python/sglang/multimodal_gen/test/server/test_server_a.py b/sglang/python/sglang/multimodal_gen/test/server/test_server_a.py new file mode 100644 index 0000000000000000000000000000000000000000..fdf072ec89e14f4215b852d14309758b4e80e4b9 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/test/server/test_server_a.py @@ -0,0 +1,31 @@ +""" +Config-driven diffusion performance test with pytest parametrization. 
+ + +If the actual run is significantly better than the baseline, the improved cases with their updated baseline will be printed +""" + +from __future__ import annotations + +import pytest + +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.test.server.test_server_common import ( # noqa: F401 + DiffusionServerBase, + diffusion_server, +) +from sglang.multimodal_gen.test.server.testcase_configs import ( + ONE_GPU_CASES_A, + DiffusionTestCase, +) + +logger = init_logger(__name__) + + +class TestDiffusionServerOneGpu(DiffusionServerBase): + """Performance tests for 1-GPU diffusion cases.""" + + @pytest.fixture(params=ONE_GPU_CASES_A, ids=lambda c: c.id) + def case(self, request) -> DiffusionTestCase: + """Provide a DiffusionTestCase for each 1-GPU test.""" + return request.param diff --git a/sglang/python/sglang/multimodal_gen/test/server/test_server_b.py b/sglang/python/sglang/multimodal_gen/test/server/test_server_b.py new file mode 100644 index 0000000000000000000000000000000000000000..1a0432db6f3bca0df47d521292ef5f43d1ca14fa --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/test/server/test_server_b.py @@ -0,0 +1,31 @@ +""" +Config-driven diffusion performance test with pytest parametrization. 
+ + +If the actual run is significantly better than the baseline, the improved cases with their updated baseline will be printed +""" + +from __future__ import annotations + +import pytest + +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.test.server.test_server_common import ( # noqa: F401 + DiffusionServerBase, + diffusion_server, +) +from sglang.multimodal_gen.test.server.testcase_configs import ( + ONE_GPU_CASES_B, + DiffusionTestCase, +) + +logger = init_logger(__name__) + + +class TestDiffusionServerOneGpu(DiffusionServerBase): + """Performance tests for 1-GPU diffusion cases.""" + + @pytest.fixture(params=ONE_GPU_CASES_B, ids=lambda c: c.id) + def case(self, request) -> DiffusionTestCase: + """Provide a DiffusionTestCase for each 1-GPU test.""" + return request.param diff --git a/sglang/python/sglang/multimodal_gen/test/server/test_server_common.py b/sglang/python/sglang/multimodal_gen/test/server/test_server_common.py new file mode 100644 index 0000000000000000000000000000000000000000..690c5ba6263676c9134b5beece61643c8f1552bf --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/test/server/test_server_common.py @@ -0,0 +1,897 @@ +""" +Config-driven diffusion generation test with pytest parametrization. 
+ + +If the actual run is significantly better than the baseline, the improved cases with their updated baseline will be printed +""" + +from __future__ import annotations + +import os +from pathlib import Path +from typing import Any, Callable + +import openai +import pytest +import requests +from openai import OpenAI + +from sglang.multimodal_gen.runtime.platforms import current_platform +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.runtime.utils.perf_logger import RequestPerfRecord +from sglang.multimodal_gen.test.server import conftest +from sglang.multimodal_gen.test.server.test_server_utils import ( + VALIDATOR_REGISTRY, + PerformanceValidator, + ServerContext, + ServerManager, + get_generate_fn, +) +from sglang.multimodal_gen.test.server.testcase_configs import ( + BASELINE_CONFIG, + DiffusionTestCase, + PerformanceSummary, + ScenarioConfig, +) +from sglang.multimodal_gen.test.test_utils import ( + _consistency_gt_filenames, + extract_key_frames_from_video, + get_dynamic_server_port, + wait_for_req_perf_record, +) + +logger = init_logger(__name__) + + +@pytest.fixture +def diffusion_server(case: DiffusionTestCase) -> ServerContext: + """Start a diffusion server for a single case and tear it down afterwards.""" + server_args = case.server_args + + # Skip ring attention tests on AMD/ROCm - Ring Attention requires Flash Attention + # which is not available on AMD. Use Ulysses parallelism instead. 
+ if ( + current_platform.is_hip() + and server_args.ring_degree is not None + and server_args.ring_degree > 1 + ): + pytest.skip( + f"Skipping {case.id}: Ring Attention (ring_degree={server_args.ring_degree}) " + "requires Flash Attention which is not available on AMD/ROCm" + ) + + default_port = get_dynamic_server_port() + port = int(os.environ.get("SGLANG_TEST_SERVER_PORT", default_port)) + sampling_params = case.sampling_params + extra_args = os.environ.get("SGLANG_TEST_SERVE_ARGS", "") + + # In GT generation mode, force --backend diffusers + if os.environ.get("SGLANG_GEN_GT", "0") == "1": + if "--backend" not in extra_args: + extra_args = "--backend diffusers " + extra_args.strip() + + extra_args += f" --num-gpus {server_args.num_gpus}" + + if server_args.tp_size is not None: + extra_args += f" --tp-size {server_args.tp_size}" + + if server_args.ulysses_degree is not None: + extra_args += f" --ulysses-degree {server_args.ulysses_degree}" + + if server_args.dit_layerwise_offload: + extra_args += f" --dit-layerwise-offload true" + + if server_args.dit_offload_prefetch_size: + extra_args += ( + f" --dit-offload-prefetch-size {server_args.dit_offload_prefetch_size}" + ) + + if server_args.text_encoder_cpu_offload: + extra_args += f" --text-encoder-cpu-offload" + + if server_args.ring_degree is not None: + extra_args += f" --ring-degree {server_args.ring_degree}" + + # LoRA support + if server_args.lora_path: + extra_args += f" --lora-path {server_args.lora_path}" + + # default warmup + extra_args += f" --warmup" + + for arg in server_args.extras: + extra_args += f" {arg}" + + # Build custom environment variables + env_vars = {} + if server_args.enable_cache_dit: + env_vars["SGLANG_CACHE_DIT_ENABLED"] = "true" + + # start server + manager = ServerManager( + model=server_args.model_path, + port=port, + wait_deadline=float(os.environ.get("SGLANG_TEST_WAIT_SECS", "1200")), + extra_args=extra_args, + env_vars=env_vars, + ) + ctx = manager.start() + + try: + # 
Reconstruct output size for OpenAI API + # Allow override via environment variable (useful for AMD where large resolutions can cause GPU hang) + output_size = os.environ.get( + "SGLANG_TEST_OUTPUT_SIZE", sampling_params.output_size + ) + except Exception as exc: + logger.error("Warm-up failed for %s: %s", case.id, exc) + ctx.cleanup() + raise + + try: + yield ctx + finally: + ctx.cleanup() + + +class DiffusionServerBase: + """Performance tests for all diffusion models/scenarios. + + This single test class runs against all cases defined in ONE_GPU_CASES. + Each case gets its own server instance via the parametrized fixture. + """ + + _perf_results: list[dict[str, Any]] = [] + _improved_baselines: list[dict[str, Any]] = [] + _pytest_config = None # Store pytest config for stash access + + @classmethod + def setup_class(cls): + cls._perf_results = [] + cls._improved_baselines = [] + + @classmethod + def teardown_class(cls): + print( + f"\n[DEBUG teardown_class] Called for {cls.__name__}, _perf_results has {len(cls._perf_results)} entries" + ) + if cls._pytest_config: + # Add results to pytest stash (shared across all import contexts) + for result in cls._perf_results: + result["class_name"] = cls.__name__ + conftest.add_perf_results(cls._pytest_config, cls._perf_results) + print( + f"[DEBUG teardown_class] Added {len(cls._perf_results)} results to stash" + ) + else: + print( + "[DEBUG teardown_class] No pytest_config available, skipping stash update" + ) + + if cls._improved_baselines: + import json + + output = """ +--- POTENTIAL BASELINE IMPROVEMENTS DETECTED --- +The following test cases performed significantly better than their baselines. 
+Consider updating perf_baselines.json with the snippets below: +""" + for item in cls._improved_baselines: + output += ( + f'\n"{item["id"]}": {json.dumps(item["baseline"], indent=4)},\n' + ) + print(output) + + @pytest.fixture(autouse=True) + def _capture_pytest_config(self, request): + """Capture pytest config for use in teardown_class.""" + self.__class__._pytest_config = request.config + + def _client(self, ctx: ServerContext) -> OpenAI: + """Get OpenAI client for the server.""" + return OpenAI( + api_key="sglang-anything", + base_url=f"http://localhost:{ctx.port}/v1", + ) + + def run_and_collect( + self, + ctx: ServerContext, + case_id: str, + generate_fn: Callable[[str, openai.Client], tuple[str, bytes]], + ) -> tuple[RequestPerfRecord, bytes]: + """Run generation and collect performance records. + + Returns: + Tuple of (performance_record, content_bytes) + """ + log_path = ctx.perf_log_path + log_wait_timeout = 30 + + client = self._client(ctx) + rid, content = generate_fn(case_id, client) + + req_perf_record = wait_for_req_perf_record( + rid, + log_path, + timeout=log_wait_timeout, + ) + + return (req_perf_record, content) + + def _validate_and_record( + self, + case: DiffusionTestCase, + perf_record: RequestPerfRecord, + ) -> None: + """Validate metrics and record results.""" + is_baseline_generation_mode = os.environ.get("SGLANG_GEN_BASELINE", "0") == "1" + + scenario = BASELINE_CONFIG.scenarios.get(case.id) + missing_scenario = False + if scenario is None: + # Create dummy scenario to allow metric collection + scenario = type( + "DummyScenario", + (), + { + "expected_e2e_ms": 0, + "expected_avg_denoise_ms": 0, + "expected_median_denoise_ms": 0, + "stages_ms": {}, + "denoise_step_ms": {}, + }, + )() + if not is_baseline_generation_mode: + missing_scenario = True + + validator_name = case.server_args.custom_validator or "default" + validator_class = VALIDATOR_REGISTRY.get(validator_name, PerformanceValidator) + + validator = validator_class( + 
scenario=scenario, + tolerances=BASELINE_CONFIG.tolerances, + step_fractions=BASELINE_CONFIG.step_fractions, + ) + + summary = validator.collect_metrics(perf_record) + + if case.run_perf_check: + if is_baseline_generation_mode or missing_scenario: + self._dump_baseline_for_testcase(case, summary, missing_scenario) + if missing_scenario: + pytest.fail( + f"Testcase '{case.id}' not found in perf_baselines.json" + ) + return + + self._check_for_improvement(case, summary, scenario) + + # only run performance validation if run_perf_check is True + try: + validator.validate(perf_record, case.sampling_params.num_frames) + except AssertionError as e: + logger.error(f"Performance validation failed for {case.id}:\n{e}") + self._dump_baseline_for_testcase(case, summary, missing_scenario) + raise + + result = { + "test_name": case.id, + "modality": case.server_args.modality, + "e2e_ms": summary.e2e_ms, + "avg_denoise_ms": summary.avg_denoise_ms, + "median_denoise_ms": summary.median_denoise_ms, + "stage_metrics": summary.stage_metrics, + "sampled_steps": summary.sampled_steps, + } + + # video-specific metrics + if summary.frames_per_second: + result.update( + { + "frames_per_second": summary.frames_per_second, + "total_frames": summary.total_frames, + "avg_frame_time_ms": summary.avg_frame_time_ms, + } + ) + + self.__class__._perf_results.append(result) + print( + f"[DEBUG _validate_and_record] Appended result for {case.id}, class {self.__class__.__name__} now has {len(self.__class__._perf_results)} results" + ) + + def _check_for_improvement( + self, + case: DiffusionTestCase, + summary: PerformanceSummary, + scenario: "ScenarioConfig", + ) -> None: + """Check for potential significant performance improvements and record them.""" + is_improved = False + threshold = BASELINE_CONFIG.improvement_threshold + + def is_sig_faster(actual, expected): + if expected == 0 or expected is None: + return False + return actual < expected * (1 - threshold) + + def 
safe_get_metric(metric_dict, key): + val = metric_dict.get(key) + return val if val is not None else float("inf") + + # Check for any significant improvement + if ( + is_sig_faster(summary.e2e_ms, scenario.expected_e2e_ms) + or is_sig_faster(summary.avg_denoise_ms, scenario.expected_avg_denoise_ms) + or is_sig_faster( + summary.median_denoise_ms, scenario.expected_median_denoise_ms + ) + ): + is_improved = True + # Combine metrics, always taking the better (lower) value + new_stages = { + stage: min( + safe_get_metric(summary.stage_metrics, stage), + safe_get_metric(scenario.stages_ms, stage), + ) + for stage in set(summary.stage_metrics) | set(scenario.stages_ms) + } + new_denoise_steps = { + step: min( + safe_get_metric(summary.all_denoise_steps, step), + safe_get_metric(scenario.denoise_step_ms, step), + ) + for step in set(summary.all_denoise_steps.keys()) + | set(scenario.denoise_step_ms) + } + + # Check for stage-level improvements + if not is_improved: + for stage, new_val in new_stages.items(): + if is_sig_faster(new_val, scenario.stages_ms.get(stage, float("inf"))): + is_improved = True + break + if not is_improved: + for step, new_val in new_denoise_steps.items(): + if is_sig_faster( + new_val, scenario.denoise_step_ms.get(step, float("inf")) + ): + is_improved = True + break + + if is_improved: + new_baseline = { + "stages_ms": {k: round(v, 2) for k, v in new_stages.items()}, + "denoise_step_ms": { + str(k): round(v, 2) for k, v in new_denoise_steps.items() + }, + "expected_e2e_ms": round( + min(summary.e2e_ms, scenario.expected_e2e_ms), 2 + ), + "expected_avg_denoise_ms": round( + min(summary.avg_denoise_ms, scenario.expected_avg_denoise_ms), 2 + ), + "expected_median_denoise_ms": round( + min(summary.median_denoise_ms, scenario.expected_median_denoise_ms), + 2, + ), + } + self._improved_baselines.append({"id": case.id, "baseline": new_baseline}) + + def _dump_baseline_for_testcase( + self, + case: DiffusionTestCase, + summary: "PerformanceSummary", + 
missing_scenario: bool = False, + ) -> None: + """Dump performance metrics as a JSON scenario for baselines.""" + import json + + denoise_steps_formatted = { + str(k): round(v, 2) for k, v in summary.all_denoise_steps.items() + } + stages_formatted = {k: round(v, 2) for k, v in summary.stage_metrics.items()} + + baseline = { + "stages_ms": stages_formatted, + "denoise_step_ms": denoise_steps_formatted, + "expected_e2e_ms": round(summary.e2e_ms, 2), + "expected_avg_denoise_ms": round(summary.avg_denoise_ms, 2), + "expected_median_denoise_ms": round(summary.median_denoise_ms, 2), + } + + # Video-specific metrics + if case.server_args.modality == "video": + if "per_frame_generation" not in baseline["stages_ms"]: + baseline["stages_ms"]["per_frame_generation"] = ( + round(summary.avg_frame_time_ms, 2) + if summary.avg_frame_time_ms + else None + ) + action = "add" if missing_scenario else "update" + output = f""" +{action} this baseline in the "scenarios" section of perf_baselines.json: + +"{case.id}": {json.dumps(baseline, indent=4)} + +""" + logger.error(output) + + def _save_gt_output( + self, + case: DiffusionTestCase, + content: bytes, + ) -> None: + """Save generated content as ground truth files. 
+ + Args: + case: Test case configuration + content: Generated content bytes (image or video) + """ + gt_output_dir = os.environ.get("SGLANG_GT_OUTPUT_DIR") + if not gt_output_dir: + logger.error("SGLANG_GT_OUTPUT_DIR not set, cannot save GT output") + return + + out_dir = Path(gt_output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + num_gpus = case.server_args.num_gpus + is_video = case.server_args.modality == "video" + + if is_video: + # Extract key frames from video + frames = extract_key_frames_from_video( + content, num_frames=case.sampling_params.num_frames + ) + + if len(frames) != 3: + logger.warning( + f"{case.id}: expected 3 frames, got {len(frames)}, skipping frame save" + ) + return + + # Save frames (reuse naming from _consistency_gt_filenames) + filenames = _consistency_gt_filenames(case.id, num_gpus, is_video=True) + from PIL import Image + + for frame, fn in zip(frames, filenames): + frame_path = out_dir / fn + Image.fromarray(frame).save(frame_path) + logger.info(f"Saved GT frame: {frame_path}") + else: + # Save image + from sglang.multimodal_gen.test.test_utils import detect_image_format + + detected_format = detect_image_format(content) + filenames = _consistency_gt_filenames( + case.id, num_gpus, is_video=False, output_format=detected_format + ) + output_path = out_dir / filenames[0] + output_path.write_bytes(content) + logger.info(f"Saved GT image: {output_path} (format: {detected_format})") + + def _test_lora_api_functionality( + self, + ctx: ServerContext, + case: DiffusionTestCase, + generate_fn: Callable[[str, openai.Client], tuple[str, bytes]], + ) -> None: + """ + Test LoRA API functionality with end-to-end validation: merge, unmerge, and set_lora. + This test verifies that each API call succeeds AND that generation works after each operation. 
+ """ + base_url = f"http://localhost:{ctx.port}/v1" + client = OpenAI(base_url=base_url, api_key="dummy") + + # Test 1: unmerge_lora_weights - API should succeed and generation should work + logger.info("[LoRA E2E] Testing unmerge_lora_weights for %s", case.id) + resp = requests.post(f"{base_url}/unmerge_lora_weights") + assert resp.status_code == 200, f"unmerge_lora_weights failed: {resp.text}" + + logger.info("[LoRA E2E] Verifying generation after unmerge for %s", case.id) + rid_after_unmerge, _ = generate_fn(case.id, client) + assert rid_after_unmerge is not None, "Generation after unmerge failed" + logger.info("[LoRA E2E] Generation after unmerge succeeded") + + # Test 2: merge_lora_weights - API should succeed and generation should work + logger.info("[LoRA E2E] Testing merge_lora_weights for %s", case.id) + resp = requests.post(f"{base_url}/merge_lora_weights") + assert resp.status_code == 200, f"merge_lora_weights failed: {resp.text}" + + logger.info("[LoRA E2E] Verifying generation after re-merge for %s", case.id) + rid_after_merge, _ = generate_fn(case.id, client) + assert rid_after_merge is not None, "Generation after merge failed" + logger.info("[LoRA E2E] Generation after merge succeeded") + + # Test 3: set_lora (re-set the same adapter) - API should succeed and generation should work + logger.info("[LoRA E2E] Testing set_lora for %s", case.id) + resp = requests.post(f"{base_url}/set_lora", json={"lora_nickname": "default"}) + assert resp.status_code == 200, f"set_lora failed: {resp.text}" + + logger.info("[LoRA E2E] Verifying generation after set_lora for %s", case.id) + rid_after_set, _ = generate_fn(case.id, client) + assert rid_after_set is not None, "Generation after set_lora failed" + logger.info("[LoRA E2E] Generation after set_lora succeeded") + + # Test 4: list_loras - API should return the expected list of LoRA adapters + logger.info("[LoRA E2E] Testing list_loras for %s", case.id) + resp = requests.get(f"{base_url}/list_loras") + assert 
resp.status_code == 200, f"list_loras failed: {resp.text}" + lora_info = resp.json() + logger.info("[LoRA E2E] list_loras returned %s", lora_info) + assert ( + isinstance(lora_info["loaded_adapters"], list) + and len(lora_info["loaded_adapters"]) > 0 + ), "loaded_adapters should be a non-empty list" + assert any( + a.get("nickname") == "default" for a in lora_info["loaded_adapters"] + ), f"nickname 'default' not found in loaded_adapters: {lora_info['loaded_adapters']}" + logger.info("[LoRA E2E] list_loras returned expected LoRA adapters") + + logger.info("[LoRA E2E] All LoRA API E2E tests passed for %s", case.id) + + def _test_lora_dynamic_switch_e2e( + self, + ctx: ServerContext, + case: DiffusionTestCase, + generate_fn: Callable[[str, openai.Client], tuple[str, bytes]], + second_lora_path: str, + ) -> None: + """ + Test dynamic LoRA switching with end-to-end validation. + This test verifies that switching between LoRA adapters works correctly + and generation succeeds after each switch. 
+ """ + base_url = f"http://localhost:{ctx.port}/v1" + client = OpenAI(base_url=base_url, api_key="dummy") + + # Test 1: Generate with initial LoRA + logger.info( + "[LoRA Switch E2E] Testing generation with initial LoRA for %s", case.id + ) + rid_initial, _ = generate_fn(case.id, client) + assert rid_initial is not None, "Generation with initial LoRA failed" + logger.info("[LoRA Switch E2E] Generation with initial LoRA succeeded") + + # Test 2: Switch to second LoRA and generate + logger.info( + "[LoRA Switch E2E] Switching to second LoRA adapter for %s", case.id + ) + resp = requests.post( + f"{base_url}/set_lora", + json={"lora_nickname": "lora2", "lora_path": second_lora_path}, + ) + assert ( + resp.status_code == 200 + ), f"set_lora to second adapter failed: {resp.text}" + + logger.info( + "[LoRA Switch E2E] Verifying generation with second LoRA for %s", case.id + ) + rid_second, _ = generate_fn(case.id, client) + assert rid_second is not None, "Generation with second LoRA failed" + logger.info("[LoRA Switch E2E] Generation with second LoRA succeeded") + + # Test 3: Switch back to original LoRA and generate + logger.info("[LoRA Switch E2E] Switching back to original LoRA for %s", case.id) + resp = requests.post(f"{base_url}/set_lora", json={"lora_nickname": "default"}) + assert resp.status_code == 200, f"set_lora back to default failed: {resp.text}" + + logger.info( + "[LoRA Switch E2E] Verifying generation after switching back for %s", + case.id, + ) + rid_switched_back, _ = generate_fn(case.id, client) + assert rid_switched_back is not None, "Generation after switching back failed" + logger.info("[LoRA Switch E2E] Generation after switching back succeeded") + + logger.info( + "[LoRA Switch E2E] All dynamic switch E2E tests passed for %s", case.id + ) + + def _test_dynamic_lora_loading( + self, + ctx: ServerContext, + case: DiffusionTestCase, + ) -> None: + """ + Test dynamic LoRA loading after server startup. 
+ + This test reproduces the LayerwiseOffload + set_lora issue: + - Server starts WITHOUT lora_path (LayerwiseOffloadManager initializes first) + - Then set_lora is called via API to load LoRA dynamically + - This tests the interaction between layerwise offload and dynamic LoRA loading + """ + base_url = f"http://localhost:{ctx.port}/v1" + dynamic_lora_path = case.server_args.dynamic_lora_path + + # Call set_lora to load LoRA dynamically after server startup + logger.info( + "[Dynamic LoRA] Loading LoRA dynamically via set_lora API for %s", case.id + ) + logger.info("[Dynamic LoRA] LoRA path: %s", dynamic_lora_path) + resp = requests.post( + f"{base_url}/set_lora", + json={"lora_nickname": "default", "lora_path": dynamic_lora_path}, + ) + assert resp.status_code == 200, f"Dynamic set_lora failed: {resp.text}" + logger.info("[Dynamic LoRA] set_lora succeeded for %s", case.id) + + def _test_multi_lora_e2e( + self, + ctx: ServerContext, + case: DiffusionTestCase, + generate_fn: Callable[[str, openai.Client], tuple[str, bytes]], + first_lora_path: str, + second_lora_path: str, + ) -> None: + """ + Test multiple LoRA adapters with different set_lora input scenarios. + Tests: basic multi-LoRA, different strengths, cached adapters, switch back to single. 
+ """ + base_url = f"http://localhost:{ctx.port}/v1" + client = OpenAI(base_url=base_url, api_key="dummy") + + # Test 1: Basic multi-LoRA with list format + resp = requests.post( + f"{base_url}/set_lora", + json={ + "lora_nickname": ["default", "lora2"], + "lora_path": [first_lora_path, second_lora_path], + "target": "all", + "strength": [1.0, 1.0], + }, + ) + assert ( + resp.status_code == 200 + ), f"set_lora with multiple adapters failed: {resp.text}" + rid, _ = generate_fn(case.id, client) + assert rid is not None + + # Test 2: Different strengths + resp = requests.post( + f"{base_url}/set_lora", + json={ + "lora_nickname": ["default", "lora2"], + "lora_path": [first_lora_path, second_lora_path], + "target": "all", + "strength": [0.8, 0.5], + }, + ) + assert ( + resp.status_code == 200 + ), f"set_lora with different strengths failed: {resp.text}" + rid, _ = generate_fn(case.id, client) + assert rid is not None + + # Test 3: Different targets + requests.post(f"{base_url}/set_lora", json={"lora_nickname": "default"}) + resp = requests.post( + f"{base_url}/set_lora", + json={ + "lora_nickname": ["default", "lora2"], + "lora_path": [first_lora_path, second_lora_path], + "target": ["transformer", "transformer_2"], + "strength": [0.8, 0.5], + }, + ) + assert ( + resp.status_code == 200 + ), f"set_lora with cached adapters failed: {resp.text}" + rid, _ = generate_fn(case.id, client) + assert rid is not None + + # Test 4: Switch back to single LoRA + resp = requests.post(f"{base_url}/set_lora", json={"lora_nickname": "default"}) + assert ( + resp.status_code == 200 + ), f"set_lora back to single adapter failed: {resp.text}" + rid, _ = generate_fn(case.id, client) + assert rid is not None + + logger.info("[Multi-LoRA] All multi-LoRA tests passed for %s", case.id) + + def _test_v1_models_endpoint( + self, ctx: ServerContext, case: DiffusionTestCase + ) -> None: + """ + Test /v1/models endpoint returns OpenAI-compatible response. 
+ This endpoint is required for sgl-model-gateway router compatibility. + """ + base_url = f"http://localhost:{ctx.port}" + + # Test GET /v1/models + logger.info("[Models API] Testing GET /v1/models for %s", case.id) + resp = requests.get(f"{base_url}/v1/models") + assert resp.status_code == 200, f"/v1/models failed: {resp.text}" + + data = resp.json() + assert ( + data["object"] == "list" + ), f"Expected object='list', got {data.get('object')}" + assert len(data["data"]) >= 1, "Expected at least one model in response" + + model = data["data"][0] + assert "id" in model, "Model missing 'id' field" + assert ( + model["object"] == "model" + ), f"Expected object='model', got {model.get('object')}" + assert ( + model["id"] == case.server_args.model_path + ), f"Model ID mismatch: expected {case.server_args.model_path}, got {model['id']}" + + # Verify extended diffusion-specific fields + assert "num_gpus" in model, "Model missing 'num_gpus' field" + assert "task_type" in model, "Model missing 'task_type' field" + assert "dit_precision" in model, "Model missing 'dit_precision' field" + assert "vae_precision" in model, "Model missing 'vae_precision' field" + assert ( + model["num_gpus"] == case.server_args.num_gpus + ), f"num_gpus mismatch: expected {case.server_args.num_gpus}, got {model['num_gpus']}" + # Verify task_type is consistent with the modality specified in the test config. + # We can't access pipeline_config from test config, but we can validate against modality. + modality_to_valid_task_types = { + "image": {"T2I", "I2I", "TI2I"}, + "video": {"T2V", "I2V", "TI2V"}, + "3d": {"I2M"}, + } + valid_task_types = modality_to_valid_task_types.get( + case.server_args.modality, set() + ) + assert model["task_type"] in valid_task_types, ( + f"task_type '{model['task_type']}' not valid for modality " + f"'{case.server_args.modality}'. 
Expected one of: {valid_task_types}" + ) + logger.info( + "[Models API] GET /v1/models returned valid response with extended fields" + ) + + # Test GET /v1/models/{model_path} + model_path = model["id"] + logger.info("[Models API] Testing GET /v1/models/%s", model_path) + resp = requests.get(f"{base_url}/v1/models/{model_path}") + assert resp.status_code == 200, f"/v1/models/{model_path} failed: {resp.text}" + + single_model = resp.json() + assert single_model["id"] == model_path, "Single model ID mismatch" + assert single_model["object"] == "model", "Single model object type mismatch" + + # Verify extended fields on single model endpoint too + assert "num_gpus" in single_model, "Single model missing 'num_gpus' field" + assert "task_type" in single_model, "Single model missing 'task_type' field" + assert single_model["task_type"] in valid_task_types, ( + f"Single model task_type '{single_model['task_type']}' not valid for modality " + f"'{case.server_args.modality}'. Expected one of: {valid_task_types}" + ) + logger.info( + "[Models API] GET /v1/models/{model_path} returned valid response with extended fields" + ) + + # Test GET /v1/models/{non_existent_model} returns 404 + logger.info("[Models API] Testing GET /v1/models/non_existent_model") + resp = requests.get(f"{base_url}/v1/models/non_existent_model") + assert resp.status_code == 404, f"Expected 404, got {resp.status_code}" + error_data = resp.json() + assert "error" in error_data, "404 response missing 'error' field" + assert ( + error_data["error"]["code"] == "model_not_found" + ), f"Incorrect error code: {error_data['error'].get('code')}" + logger.info("[Models API] GET /v1/models/non_existent returns 404 as expected") + + logger.info("[Models API] All /v1/models tests passed for %s", case.id) + + def _test_t2v_rejects_input_reference( + self, ctx: ServerContext, case: DiffusionTestCase + ) -> None: + if case.server_args.modality != "video": + return + + base_url = f"http://localhost:{ctx.port}" + resp = 
requests.get(f"{base_url}/v1/models") + assert resp.status_code == 200, f"/v1/models failed: {resp.text}" + data = resp.json().get("data", []) + if not data: + pytest.fail("/v1/models returned empty model list") + + task_type = data[0].get("task_type") + if task_type != "T2V": + return + + prompt = case.sampling_params.prompt or "test" + payload = {"prompt": prompt, "input_reference": "dummy"} + if case.sampling_params.output_size: + payload["size"] = case.sampling_params.output_size + + resp = requests.post(f"{base_url}/v1/videos", json=payload) + assert ( + resp.status_code == 400 + ), f"Expected 400 for T2V input_reference, got {resp.status_code}: {resp.text}" + detail = resp.json().get("detail", "") + assert ( + "input_reference is not supported" in detail + ), f"Unexpected error detail for T2V input_reference: {detail}" + + def test_diffusion_generation( + self, + case: DiffusionTestCase, + diffusion_server: ServerContext, + ): + """Single parametrized test that runs for all cases. + + This test performs: + 1. Generation + 2. Performance validation against baselines + 3. Consistency validation against ground truth + + Pytest will execute this test once per case in ONE_GPU_CASES, + with test IDs like: + - test_diffusion_generation[qwen_image_text] + - test_diffusion_generation[qwen_image_edit] + - etc. 
+ """ + # Check if we're in GT generation mode + is_gt_gen_mode = os.environ.get("SGLANG_GEN_GT", "0") == "1" + + # Dynamic LoRA loading test - tests LayerwiseOffload + set_lora interaction + # Server starts WITHOUT lora_path, then set_lora is called after startup + if case.server_args.dynamic_lora_path and not is_gt_gen_mode: + self._test_dynamic_lora_loading(diffusion_server, case) + + generate_fn = get_generate_fn( + model_path=case.server_args.model_path, + modality=case.server_args.modality, + sampling_params=case.sampling_params, + ) + + # Single generation - output is reused for both validations + perf_record, content = self.run_and_collect( + diffusion_server, + case.id, + generate_fn, + ) + + if is_gt_gen_mode: + # GT generation mode: save output and skip all validations/tests + self._save_gt_output(case, content) + return + + # Validation 1: Performance + self._validate_and_record(case, perf_record) + + # Mesh correctness check (Chamfer Distance) for 3D models + if case.server_args.custom_validator == "mesh": + from sglang.multimodal_gen.test.server.test_server_utils import ( + MESH_OUTPUT_PATHS, + validate_mesh_correctness, + ) + + mesh_path = MESH_OUTPUT_PATHS.pop(case.id, None) + if mesh_path: + validate_mesh_correctness(mesh_path) + + # Test /v1/models endpoint for router compatibility + self._test_v1_models_endpoint(diffusion_server, case) + self._test_t2v_rejects_input_reference(diffusion_server, case) + + # LoRA API functionality test with E2E validation (only for LoRA-enabled cases) + if case.server_args.lora_path or case.server_args.dynamic_lora_path: + self._test_lora_api_functionality(diffusion_server, case, generate_fn) + + # Test dynamic LoRA switching (requires a second LoRA adapter) + if case.server_args.second_lora_path: + self._test_lora_dynamic_switch_e2e( + diffusion_server, + case, + generate_fn, + case.server_args.second_lora_path, + ) + + # Test multi-LoRA functionality + self._test_multi_lora_e2e( + diffusion_server, + case, + 
generate_fn, + case.server_args.lora_path, + case.server_args.second_lora_path, + ) diff --git a/sglang/python/sglang/multimodal_gen/test/server/test_server_utils.py b/sglang/python/sglang/multimodal_gen/test/server/test_server_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f72de0a88bf74d02ab8719bbe241d256bad46d9e --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/test/server/test_server_utils.py @@ -0,0 +1,1306 @@ +""" +Server management and performance validation for diffusion tests. +""" + +from __future__ import annotations + +import base64 +import os +import shlex +import subprocess +import sys +import tempfile +import threading +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable, Sequence +from urllib.request import urlopen + +import pytest +from openai import Client + +from sglang.multimodal_gen.benchmarks.compare_perf import calculate_upper_bound +from sglang.multimodal_gen.runtime.platforms import current_platform +from sglang.multimodal_gen.runtime.utils.common import kill_process_tree +from sglang.multimodal_gen.runtime.utils.logging_utils import ( + globally_suppress_loggers, + init_logger, +) +from sglang.multimodal_gen.runtime.utils.perf_logger import RequestPerfRecord +from sglang.multimodal_gen.test.server.testcase_configs import ( + DiffusionSamplingParams, + PerformanceSummary, + ScenarioConfig, + ToleranceConfig, +) +from sglang.multimodal_gen.test.slack_utils import upload_file_to_slack +from sglang.multimodal_gen.test.test_utils import ( + get_expected_image_format, + get_video_frame_count, + is_image_url, + prepare_perf_log, + validate_image, + validate_image_file, + validate_openai_video, + validate_video_file, +) + +logger = init_logger(__name__) + +globally_suppress_loggers() + +# Tracks mesh output file paths from generate_mesh for later correctness validation. +# Keyed by case_id, cleaned up after use. 
+MESH_OUTPUT_PATHS: dict[str, str] = {} + + +def download_image_from_url(url: str) -> Path: + """Download an image from a URL to a temporary file. + + Args: + url: The URL of the image to download + + Returns: + Path to the downloaded temporary file + """ + logger.info(f"Downloading image from URL: {url}") + + # Determine file extension from URL + ext = ".jpg" # default + if url.lower().endswith((".png", ".jpeg", ".jpg", ".webp", ".gif")): + ext = url[url.rfind(".") :] + + # Create temporary file + temp_file = ( + Path(tempfile.gettempdir()) / f"diffusion_test_image_{int(time.time())}{ext}" + ) + + try: + with urlopen(url, timeout=30) as response: + temp_file.write_bytes(response.read()) + logger.info(f"Downloaded image to: {temp_file}") + return temp_file + except Exception as e: + logger.error(f"Failed to download image from {url}: {e}") + raise + + +def parse_dimensions(size_string: str | None) -> tuple[int | None, int | None]: + """Parse a size string in "widthxheight" format to (width, height) tuple. + + Args: + size_string: Size string in "widthxheight" format (e.g., "1024x1024") or None. + Spaces are automatically stripped. + + Returns: + Tuple of (width, height) as integers if parsing succeeds, (None, None) otherwise. 
+ """ + if not size_string: + return (None, None) + + # Strip spaces from the entire string + size_string = size_string.strip() + if not size_string: + return (None, None) + + # Split by "x" + parts = size_string.split("x") + if len(parts) != 2: + return (None, None) + + # Strip spaces from each part and try to convert to int + try: + width_str = parts[0].strip() + height_str = parts[1].strip() + + if not width_str or not height_str: + return (None, None) + + width = int(width_str) + height = int(height_str) + + # Validate that both are positive + if width <= 0 or height <= 0: + return (None, None) + + return (width, height) + except ValueError: + return (None, None) + + +@dataclass +class ServerContext: + """Context for a running diffusion server.""" + + port: int + process: subprocess.Popen + model: str + stdout_file: Path + perf_log_path: Path + log_dir: Path + _stdout_fh: Any = field(repr=False) + _log_thread: threading.Thread | None = field(default=None, repr=False) + + def cleanup(self) -> None: + """Clean up server resources.""" + try: + kill_process_tree(self.process.pid) + except Exception: + pass + try: + self._stdout_fh.flush() + self._stdout_fh.close() + except Exception: + pass + + # ROCm/AMD: Extra cleanup to ensure GPU memory is released between tests + # This is needed because ROCm memory release can be slower than CUDA + if current_platform.is_hip(): + self._cleanup_rocm_gpu_memory() + # Clean up downloaded models if HF cache is not persistent + # This prevents disk exhaustion in CI when cache is not mounted + self._cleanup_hf_cache_if_not_persistent() + + def _cleanup_hf_cache_if_not_persistent(self) -> None: + """Clean up HF cache if it's not on a persistent volume. + + When running in CI without persistent cache, downloaded models accumulate + and can cause disk/memory exhaustion. This cleans up the model after each + test if the cache is not persistent. 
+ """ + import shutil + + hf_home = os.environ.get("HF_HOME", "") + if not hf_home: + return + + hf_hub_cache = os.path.join(hf_home, "hub") + + # Check if HF cache is on a persistent volume by looking for a marker file + # or checking if the directory existed before this test run + persistent_marker = os.path.join(hf_home, ".persistent_cache") + if os.path.exists(persistent_marker): + logger.info("HF cache is persistent, skipping cleanup") + return + + # Check if the cache directory is empty or was just created + # If it has very few models, it's likely not persistent + if not os.path.exists(hf_hub_cache): + return + + try: + # Get model cache directories + model_dirs = [ + d + for d in os.listdir(hf_hub_cache) + if d.startswith("models--") + and os.path.isdir(os.path.join(hf_hub_cache, d)) + ] + + # If there are cached models but no persistent marker, clean up + # to prevent disk exhaustion in CI + if model_dirs: + logger.info( + "HF cache appears non-persistent (no .persistent_cache marker), " + "cleaning up %d model(s) to prevent disk exhaustion", + len(model_dirs), + ) + for model_dir in model_dirs: + model_path = os.path.join(hf_hub_cache, model_dir) + try: + shutil.rmtree(model_path) + logger.info("Cleaned up model cache: %s", model_dir) + except Exception as e: + logger.warning("Failed to clean up %s: %s", model_dir, e) + except Exception as e: + logger.warning("Error during HF cache cleanup: %s", e) + + def _cleanup_rocm_gpu_memory(self) -> None: + """ROCm-specific cleanup to ensure GPU memory is fully released.""" + import gc + + # Wait for process to fully terminate + try: + self.process.wait(timeout=30) + except Exception: + pass + + # Force garbage collection multiple times + for _ in range(3): + gc.collect() + + # Clear HIP memory on all GPUs + try: + import torch + + for i in range(torch.cuda.device_count()): + with torch.cuda.device(i): + torch.cuda.empty_cache() + torch.cuda.synchronize() + except Exception: + pass + + # Wait for GPU memory to be 
released (ROCm can be much slower than CUDA) + # The GPU driver needs time to reclaim memory from killed processes + time.sleep(15) + + +class ServerManager: + """Manages diffusion server lifecycle.""" + + def __init__( + self, + model: str, + port: int, + wait_deadline: float = 1200.0, + extra_args: str = "", + env_vars: dict[str, str] | None = None, + ): + self.model = model + self.port = port + self.wait_deadline = wait_deadline + self.extra_args = extra_args + self.env_vars = env_vars or {} + + def _wait_for_rocm_gpu_memory_clear(self, max_wait: float = 60.0) -> None: + """ROCm-specific: Wait for GPU memory to be mostly free before starting. + + ROCm GPU memory release from killed processes can be significantly slower + than CUDA, so we need to wait longer and be more patient. + """ + try: + import torch + + if not torch.cuda.is_available(): + return + + start_time = time.time() + last_total_used = float("inf") + + while time.time() - start_time < max_wait: + # Check GPU memory usage + total_used = 0 + for i in range(torch.cuda.device_count()): + mem_info = torch.cuda.mem_get_info(i) + free, total = mem_info + used = total - free + total_used += used + + # If less than 5GB is used across all GPUs, we're good + if total_used < 5 * 1024 * 1024 * 1024: # 5GB + logger.info( + "[server-test] ROCm GPU memory is clear (used: %.2f GB)", + total_used / (1024**3), + ) + return + + # Log progress + elapsed = int(time.time() - start_time) + if total_used < last_total_used: + logger.info( + "[server-test] ROCm: GPU memory clearing (used: %.2f GB, elapsed: %ds)", + total_used / (1024**3), + elapsed, + ) + else: + logger.info( + "[server-test] ROCm: Waiting for GPU memory (used: %.2f GB, elapsed: %ds)", + total_used / (1024**3), + elapsed, + ) + last_total_used = total_used + time.sleep(3) + + # Final warning with detailed GPU info + logger.warning( + "[server-test] ROCm GPU memory not fully cleared after %.0fs (used: %.2f GB). 
" + "Proceeding anyway - this may cause OOM.", + max_wait, + total_used / (1024**3), + ) + except Exception as e: + logger.debug("[server-test] Could not check ROCm GPU memory: %s", e) + + def start(self) -> ServerContext: + """Start the diffusion server and wait for readiness.""" + # ROCm/AMD: Wait for GPU memory to be clear before starting + # This prevents OOM when running sequential tests on ROCm + if current_platform.is_hip(): + self._wait_for_rocm_gpu_memory_clear() + + log_dir, perf_log_path = prepare_perf_log() + + safe_model_name = self.model.replace("/", "_") + stdout_path = ( + Path(tempfile.gettempdir()) + / f"sgl_server_{self.port}_{safe_model_name}.log" + ) + stdout_path.unlink(missing_ok=True) + + command = [ + "sglang", + "serve", + "--model-path", + self.model, + "--port", + str(self.port), + "--log-level=debug", + ] + if self.extra_args.strip(): + command.extend(self.extra_args.strip().split()) + + env = os.environ.copy() + env["SGLANG_DIFFUSION_STAGE_LOGGING"] = "1" + env["SGLANG_PERF_LOG_DIR"] = log_dir.as_posix() + + # Apply custom environment variables + env.update(self.env_vars) + + # TODO: unify with run_command + logger.info(f"Running command: {shlex.join(command)}") + + process = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + env=env, + ) + + log_thread = None + stdout_fh = stdout_path.open("w", encoding="utf-8", buffering=1) + if process.stdout: + + def _log_pipe(pipe: Any, file: Any) -> None: + """Read from pipe and write to file and stdout.""" + try: + with pipe: + for line in iter(pipe.readline, ""): + sys.stdout.write(line) + sys.stdout.flush() + file.write(line) + file.flush() + except Exception as e: + logger.error("Log pipe thread error: %s", e) + finally: + file.close() + logger.debug("Log pipe thread finished.") + + log_thread = threading.Thread( + target=_log_pipe, args=(process.stdout, stdout_fh) + ) + log_thread.daemon = True + log_thread.start() + + logger.info( 
+ "[server-test] Starting server pid=%s, model=%s, log=%s", + process.pid, + self.model, + stdout_path, + ) + + self._wait_for_ready(process, stdout_path) + + return ServerContext( + port=self.port, + process=process, + model=self.model, + stdout_file=stdout_path, + perf_log_path=perf_log_path, + log_dir=log_dir, + _stdout_fh=stdout_fh, + _log_thread=log_thread, + ) + + def _wait_for_ready(self, process: subprocess.Popen, stdout_path: Path) -> None: + """Wait for server to become ready.""" + start = time.time() + ready_message = "Application startup complete." + log_period = 30 + prev_log_period_count = 0 + + while time.time() - start < self.wait_deadline: + if process.poll() is not None: + tail = self._get_log_tail(stdout_path) + raise RuntimeError( + f"Server exited early (code {process.returncode}).\n{tail}" + ) + + if stdout_path.exists(): + try: + content = stdout_path.read_text(encoding="utf-8", errors="ignore") + if ready_message in content: + logger.info("[server-test] Server ready") + return + except Exception as e: + logger.debug("Could not read log yet: %s", e) + + elapsed = int(time.time() - start) + if (elapsed // log_period) > prev_log_period_count: + prev_log_period_count = elapsed // log_period + logger.info("[server-test] Waiting for server... 
elapsed=%ss", elapsed) + time.sleep(1) + + tail = self._get_log_tail(stdout_path) + raise TimeoutError(f"Server not ready within {self.wait_deadline}s.\n{tail}") + + @staticmethod + def _get_log_tail(path: Path, lines: int = 200) -> str: + """Get the last N lines from a log file.""" + try: + content = path.read_text(encoding="utf-8", errors="ignore") + return "\n".join(content.splitlines()[-lines:]) + except Exception: + return "" + + +class PerformanceValidator: + """Validates performance metrics against expectations.""" + + is_video_gen: bool = False + + def __init__( + self, + scenario: ScenarioConfig, + tolerances: ToleranceConfig, + step_fractions: Sequence[float], + ): + self.scenario = scenario + self.tolerances = tolerances + self.step_fractions = step_fractions + self.is_baseline_generation_mode = ( + os.environ.get("SGLANG_GEN_BASELINE", "0") == "1" + ) + + def _assert_le( + self, + name: str, + actual: float, + expected: float, + tolerance: float, + min_abs_tolerance_ms: float = 20.0, + ): + """Assert that actual is less than or equal to expected within a tolerance. + + Uses the larger of relative tolerance or absolute tolerance to prevent + flaky failures on very fast operations. + + For AMD GPUs, uses 100% higher tolerance and issues warning instead of assertion. 
+ """ + # Check if running on AMD GPU + is_amd = current_platform.is_hip() + + if is_amd: + # Use 100% higher tolerance for AMD (2x the expected value) + amd_tolerance = 1.0 # 100% + upper_bound = calculate_upper_bound( + expected, amd_tolerance, min_abs_tolerance_ms + ) + if actual > upper_bound: + logger.warning( + f"[AMD PERF WARNING] Validation would fail for '{name}'.\n" + f" Actual: {actual:.4f}ms\n" + f" Expected: {expected:.4f}ms\n" + f" AMD Limit: {upper_bound:.4f}ms " + f"(rel_tol: {amd_tolerance:.1%}, abs_pad: {min_abs_tolerance_ms}ms)\n" + f" Original tolerance was: {tolerance:.1%}" + ) + else: + upper_bound = calculate_upper_bound( + expected, tolerance, min_abs_tolerance_ms + ) + assert actual <= upper_bound, ( + f"Validation failed for '{name}'.\n" + f" Actual: {actual:.4f}ms\n" + f" Expected: {expected:.4f}ms\n" + f" Limit: {upper_bound:.4f}ms " + f"(rel_tol: {tolerance:.1%}, abs_pad: {min_abs_tolerance_ms}ms)" + ) + + def validate( + self, perf_record: RequestPerfRecord, *args, **kwargs + ) -> PerformanceSummary: + """Validate all performance metrics and return summary.""" + summary = self.collect_metrics(perf_record) + if self.is_baseline_generation_mode: + return summary + + self._validate_e2e(summary) + self._validate_denoise_agg(summary) + self._validate_denoise_steps(summary) + self._validate_stages(summary) + + return summary + + def collect_metrics( + self, + perf_record: RequestPerfRecord, + ) -> PerformanceSummary: + return PerformanceSummary.from_req_perf_record(perf_record, self.step_fractions) + + def _validate_e2e(self, summary: PerformanceSummary) -> None: + """Validate end-to-end performance.""" + assert summary.e2e_ms > 0, "E2E duration missing" + self._assert_le( + "E2E Latency", + summary.e2e_ms, + self.scenario.expected_e2e_ms, + self.tolerances.e2e, + ) + + def _validate_denoise_agg(self, summary: PerformanceSummary) -> None: + """Validate aggregate denoising metrics.""" + assert summary.avg_denoise_ms > 0, "Denoising step 
timings missing" + + self._assert_le( + "Average Denoise Step", + summary.avg_denoise_ms, + self.scenario.expected_avg_denoise_ms, + self.tolerances.denoise_agg, + ) + self._assert_le( + "Median Denoise Step", + summary.median_denoise_ms, + self.scenario.expected_median_denoise_ms, + self.tolerances.denoise_agg, + ) + + def _validate_denoise_steps(self, summary: PerformanceSummary) -> None: + """Validate individual denoising steps.""" + for idx, actual in summary.sampled_steps.items(): + expected = self.scenario.denoise_step_ms.get(idx) + if expected is None: + continue + # FIXME: hardcode, looser for first step + tolerance = 0.4 if idx == 0 else self.tolerances.denoise_step + + self._assert_le( + f"Denoise Step {idx}", + actual, + expected, + tolerance, + ) + + def _validate_stages(self, summary: PerformanceSummary) -> None: + """Validate stage-level metrics.""" + assert summary.stage_metrics, "Stage metrics missing" + + for stage, expected in self.scenario.stages_ms.items(): + if stage == "per_frame_generation" and self.is_video_gen: + continue + actual = summary.stage_metrics.get(stage) + assert actual is not None, f"Stage {stage} timing missing" + tolerance = ( + self.tolerances.denoise_stage + if stage == "DenoisingStage" + else self.tolerances.non_denoise_stage + ) + self._assert_le( + f"Stage '{stage}'", + actual, + expected, + tolerance, + min_abs_tolerance_ms=120.0, # relax absolute tolerance for non-denoising stages + ) + + +class VideoPerformanceValidator(PerformanceValidator): + """Extended validator for video diffusion with frame-level metrics.""" + + is_video_gen = True + + def validate( + self, + perf_record: RequestPerfRecord, + num_frames: int | None = None, + ) -> PerformanceSummary: + """Validate video metrics including frame generation rates.""" + summary = super().validate(perf_record) + + if num_frames and summary.e2e_ms > 0: + summary.total_frames = num_frames + summary.avg_frame_time_ms = summary.e2e_ms / num_frames + 
summary.frames_per_second = 1000.0 / summary.avg_frame_time_ms + + if not self.is_baseline_generation_mode: + self._validate_frame_rate(summary) + + return summary + + def _validate_frame_rate(self, summary: PerformanceSummary) -> None: + """Validate frame generation performance.""" + expected_frame_time = self.scenario.stages_ms.get("per_frame_generation") + if expected_frame_time and summary.avg_frame_time_ms: + self._assert_le( + "Average Frame Time", + summary.avg_frame_time_ms, + expected_frame_time, + self.tolerances.denoise_stage, + ) + + +class MeshValidator(PerformanceValidator): + """Validator for 3D mesh generation. Inherits perf validation from PerformanceValidator.""" + + pass + + +HUNYUAN3D_REFERENCE_URL = ( + "https://raw.githubusercontent.com/sgl-project/sgl-test-files/" + "main/diffusion-ci/consistency_gt/1-gpu/hunyuan3d_2_0/hunyuan3d.glb" +) + + +def _download_reference_mesh(url: str) -> Path: + """Download a reference mesh from URL, caching in temp dir.""" + import hashlib + + cache_name = f"ref_mesh_{hashlib.md5(url.encode()).hexdigest()}.glb" + cache_path = Path(tempfile.gettempdir()) / cache_name + if cache_path.exists(): + logger.info(f"Using cached reference mesh: {cache_path}") + return cache_path + + logger.info(f"Downloading reference mesh from: {url}") + with urlopen(url, timeout=60) as resp: + cache_path.write_bytes(resp.read()) + logger.info(f"Reference mesh cached at: {cache_path}") + return cache_path + + +def validate_mesh_correctness( + generated_mesh_path: str, + reference_url: str = HUNYUAN3D_REFERENCE_URL, + num_sample_points: int = 4096, + cd_threshold_ratio: float = 0.01, + random_seed: int = 42, +): + """Validate mesh geometric similarity against a reference via Chamfer Distance. + + Downloads the reference mesh from a URL (cached), samples point clouds from + both meshes, and asserts Chamfer Distance is within threshold. 
+ """ + import numpy as np + + try: + import trimesh + except ImportError: + pytest.fail("trimesh is required for mesh validation: pip install trimesh") + + from scipy.spatial import cKDTree + + # Load generated mesh + generated_mesh = trimesh.load(generated_mesh_path) + if isinstance(generated_mesh, trimesh.Scene): + generated_mesh = generated_mesh.dump(concatenate=True) + + # Download and load reference mesh + ref_path = _download_reference_mesh(reference_url) + reference_mesh = trimesh.load(str(ref_path)) + if isinstance(reference_mesh, trimesh.Scene): + reference_mesh = reference_mesh.dump(concatenate=True) + + # Bounding box diagonal for threshold normalization + ref_bbox = reference_mesh.bounding_box.bounds + bbox_diagonal = float(np.linalg.norm(ref_bbox[1] - ref_bbox[0])) + cd_threshold = cd_threshold_ratio * bbox_diagonal + + # Sample point clouds + np.random.seed(random_seed) + gen_points = np.array( + generated_mesh.sample(num_sample_points, return_index=True)[0] + ) + ref_points = np.array( + reference_mesh.sample(num_sample_points, return_index=True)[0] + ) + + # Bidirectional Chamfer Distance + tree1 = cKDTree(gen_points) + tree2 = cKDTree(ref_points) + forward_cd = float(np.mean(tree2.query(gen_points)[0] ** 2)) + backward_cd = float(np.mean(tree1.query(ref_points)[0] ** 2)) + total_cd = forward_cd + backward_cd + + assert total_cd <= cd_threshold, ( + f"Chamfer Distance check failed: total_cd={total_cd:.6f}, " + f"threshold={cd_threshold:.6f} ({cd_threshold_ratio * 100:.2f}% of bbox diagonal {bbox_diagonal:.4f})" + ) + + +# Registry of validators by name +VALIDATOR_REGISTRY = { + "default": PerformanceValidator, + "video": VideoPerformanceValidator, + "mesh": MeshValidator, +} + + +def get_generate_fn( + model_path: str, + modality: str, + sampling_params: DiffusionSamplingParams, +) -> Callable[[str, Client], tuple[str, bytes]]: + """Return appropriate generation function for the case.""" + # Allow override via environment variable (useful for AMD 
where large resolutions cause slow VAE) + output_size = os.environ.get("SGLANG_TEST_OUTPUT_SIZE", sampling_params.output_size) + n = sampling_params.num_outputs_per_prompt + + def _create_and_download_video( + client, + case_id, + *, + model: str, + size: str, + prompt: str | None = None, + seconds: int | None = None, + input_reference: Any | None = None, + extra_body: dict[Any] | None = None, + expected_frame_count: int | None = None, + ) -> str: + """ + Create a video job via /v1/videos, poll until completion, + then download the binary content and validate it. + + Returns request-id + """ + + create_kwargs: dict[str, Any] = { + "model": model, + "size": size, + } + if prompt is not None: + create_kwargs["prompt"] = prompt + if seconds is not None: + create_kwargs["seconds"] = seconds + if input_reference is not None: + create_kwargs["input_reference"] = input_reference # triggers multipart + if extra_body is not None: + create_kwargs["extra_body"] = extra_body + + job = client.videos.create(**create_kwargs) # type: ignore[attr-defined] + video_id = job.id + + job_completed = False + is_baseline_generation_mode = os.environ.get("SGLANG_GEN_BASELINE", "0") == "1" + # Check if running on AMD GPU - use longer timeout + is_amd = current_platform.is_hip() + if is_baseline_generation_mode: + timeout = 3600.0 + elif is_amd: + timeout = 2400.0 # 40 minutes for AMD + else: + timeout = 1200.0 + deadline = time.time() + timeout + while True: + page = client.videos.list() # type: ignore[attr-defined] + item = next((v for v in page.data if v.id == video_id), None) + + if item and getattr(item, "status", None) == "completed": + job_completed = True + break + + if time.time() > deadline: + break + + time.sleep(1) + + if not job_completed: + if is_baseline_generation_mode: + logger.warning( + f"{case_id}: video job {video_id} timed out during baseline generation. " + "Attempting to collect performance data anyway." 
+ ) + return (video_id, b"") + + if is_amd: + logger.warning( + f"[AMD TIMEOUT WARNING] {case_id}: video job {video_id} did not complete " + f"within {timeout}s timeout. This may indicate performance issues on AMD." + ) + pytest.skip( + f"{case_id}: video job timed out on AMD after {timeout}s - skipping" + ) + + pytest.fail(f"{case_id}: video job {video_id} did not complete in time") + + # download video + resp = client.videos.download_content(video_id=video_id) # type: ignore[attr-defined] + content = resp.read() + validate_openai_video(content) + + expected_filename = f"{video_id}.mp4" + tmp_path = expected_filename + with open(tmp_path, "wb") as f: + f.write(content) + + # Validate output file + expected_width, expected_height = parse_dimensions(size) + validate_video_file( + tmp_path, expected_filename, expected_width, expected_height + ) + + if expected_frame_count is not None: + actual_count = get_video_frame_count(tmp_path) + assert actual_count == expected_frame_count, ( + f"{case_id}: frame count mismatch after interpolation — " + f"expected {expected_frame_count}, got {actual_count}" + ) + + upload_file_to_slack( + case_id=case_id, + model=model_path, + prompt=sampling_params.prompt, + file_path=tmp_path, + origin_file_path=sampling_params.image_path, + ) + os.remove(tmp_path) + + return (video_id, content) + + video_seconds = sampling_params.seconds or 4 + + def generate_image(case_id, client) -> tuple[str, bytes]: + """T2I: Text to Image generation.""" + if not sampling_params.prompt: + pytest.skip(f"{case_id}: no text prompt configured") + + # Request parameters that affect output format + req_output_format = None # Not specified in current request + req_background = None # Not specified in current request + + # Build extra_body for optional features + extra_body = {} + if sampling_params.enable_teacache: + extra_body["enable_teacache"] = True + + response = client.images.with_raw_response.generate( + model=model_path, + prompt=sampling_params.prompt, 
+ n=n, + size=output_size, + response_format="b64_json", + extra_body=extra_body if extra_body else None, + ) + result = response.parse() + validate_image(result.data[0].b64_json) + + rid = result.id + + img_data = base64.b64decode(result.data[0].b64_json) + # Infer expected format from request parameters + expected_ext = get_expected_image_format(req_output_format, req_background) + expected_filename = f"{result.created}.{expected_ext}" + tmp_path = expected_filename + with open(tmp_path, "wb") as f: + f.write(img_data) + + # Validate output file + expected_width, expected_height = parse_dimensions(output_size) + validate_image_file( + tmp_path, + expected_filename, + expected_width, + expected_height, + output_format=req_output_format, + background=req_background, + ) + + upload_file_to_slack( + case_id=case_id, + model=model_path, + prompt=sampling_params.prompt, + file_path=tmp_path, + ) + os.remove(tmp_path) + + return (rid, img_data) + + def generate_image_edit(case_id, client) -> tuple[str, bytes]: + """TI2I: Text + Image -> Image edit.""" + if not sampling_params.prompt or not sampling_params.image_path: + pytest.skip(f"{case_id}: no edit config") + + image_paths = sampling_params.image_path + + if not isinstance(image_paths, list): + image_paths = [image_paths] + + new_image_paths = [] + for image_path in image_paths: + if is_image_url(image_path): + new_image_paths.append(download_image_from_url(str(image_path))) + else: + local_path = Path(image_path) + new_image_paths.append(local_path) + if not local_path.exists(): + pytest.skip(f"{case_id}: file missing: {image_path}") + + image_paths = new_image_paths + + # Request parameters that affect output format + req_output_format = ( + sampling_params.output_format + ) # Not specified in current request + req_background = None # Not specified in current request + + # Build extra_body for optional features + extra_body = {"num_frames": sampling_params.num_frames} + if sampling_params.enable_teacache: + 
extra_body["enable_teacache"] = True + + images = [open(image_path, "rb") for image_path in image_paths] + try: + response = client.images.with_raw_response.edit( + model=model_path, + image=images, + prompt=sampling_params.prompt, + n=n, + size=output_size, + response_format="b64_json", + output_format=req_output_format, + extra_body=extra_body, + ) + finally: + for img in images: + img.close() + + result = response.parse() + validate_image(result.data[0].b64_json) + + img_data = base64.b64decode(result.data[0].b64_json) + rid = result.id + + # Infer expected format from request parameters + expected_ext = get_expected_image_format(req_output_format, req_background) + expected_filename = f"{rid}.{expected_ext}" + tmp_path = expected_filename + with open(tmp_path, "wb") as f: + f.write(img_data) + + # Validate output file + expected_width, expected_height = parse_dimensions(output_size) + validate_image_file( + tmp_path, + expected_filename, + expected_width, + expected_height, + output_format=req_output_format, + background=req_background, + ) + + upload_file_to_slack( + case_id=case_id, + model=model_path, + prompt=sampling_params.prompt, + file_path=tmp_path, + origin_file_path=sampling_params.image_path, + ) + os.remove(tmp_path) + + return (rid, img_data) + + def generate_image_edit_url(case_id, client) -> tuple[str, bytes]: + """TI2I: Text + Image ? 
Image edit using direct URL transfer (no pre-download).""" + if not sampling_params.prompt or not sampling_params.image_path: + pytest.skip(f"{case_id}: no edit config") + # Handle both single URL and list of URLs + image_urls = sampling_params.image_path + if not isinstance(image_urls, list): + image_urls = [image_urls] + + # Validate all URLs + for url in image_urls: + if not is_image_url(url): + pytest.skip( + f"{case_id}: image_path must be a URL for URL direct test: {url}" + ) + + # Request parameters that affect output format + req_output_format = ( + sampling_params.output_format + ) # Not specified in current request + req_background = None # Not specified in current request + + response = client.images.with_raw_response.edit( + model=model_path, + prompt=sampling_params.prompt, + image=[], # Only for OpenAI verification + n=n, + size=sampling_params.output_size, + response_format="b64_json", + output_format=req_output_format, + extra_body={"url": image_urls, "num_frames": sampling_params.num_frames}, + ) + + result = response.parse() + rid = result.id + + validate_image(result.data[0].b64_json) + + # Save and upload result for verification + img_data = base64.b64decode(result.data[0].b64_json) + # Infer expected format from request parameters + expected_ext = get_expected_image_format(req_output_format, req_background) + expected_filename = f"{rid}.{expected_ext}" + tmp_path = expected_filename + with open(tmp_path, "wb") as f: + f.write(img_data) + + # Validate output file + expected_width, expected_height = parse_dimensions(sampling_params.output_size) + validate_image_file( + tmp_path, + expected_filename, + expected_width, + expected_height, + output_format=req_output_format, + background=req_background, + ) + + upload_file_to_slack( + case_id=case_id, + model=model_path, + prompt=sampling_params.prompt, + file_path=tmp_path, + origin_file_path=str(sampling_params.image_path), + ) + os.remove(tmp_path) + + return (rid, img_data) + + def 
generate_video(case_id, client) -> tuple[str, bytes]: + """T2V: Text ? Video.""" + if not sampling_params.prompt: + pytest.skip(f"{case_id}: no text prompt configured") + + # Build extra_body for optional features + extra_body = {} + if sampling_params.enable_teacache: + extra_body["enable_teacache"] = True + if sampling_params.num_frames: + extra_body["num_frames"] = sampling_params.num_frames + if sampling_params.enable_frame_interpolation: + extra_body["enable_frame_interpolation"] = True + extra_body["frame_interpolation_exp"] = ( + sampling_params.frame_interpolation_exp + ) + + # Compute expected output frame count for validation + expected_frame_count = None + if sampling_params.enable_frame_interpolation and sampling_params.num_frames: + n = sampling_params.num_frames + exp = sampling_params.frame_interpolation_exp + expected_frame_count = (n - 1) * (2**exp) + 1 + + return _create_and_download_video( + client, + case_id, + model=model_path, + prompt=sampling_params.prompt, + size=output_size, + seconds=video_seconds, + extra_body=extra_body if extra_body else None, + expected_frame_count=expected_frame_count, + ) + + def generate_image_to_video(case_id, client) -> tuple[str, bytes]: + """I2V: Image -> Video (optional prompt).""" + if not sampling_params.image_path: + pytest.skip(f"{case_id}: no input image configured") + + if is_image_url(sampling_params.image_path): + image_path = download_image_from_url(str(sampling_params.image_path)) + else: + image_path = Path(sampling_params.image_path) + if not image_path.exists(): + pytest.skip(f"{case_id}: file missing: {image_path}") + + # Build extra_body for optional features + extra_body = {} + if sampling_params.enable_teacache: + extra_body["enable_teacache"] = True + + with image_path.open("rb") as fh: + return _create_and_download_video( + client, + case_id, + model=model_path, + prompt=sampling_params.prompt, + size=output_size, + seconds=video_seconds, + input_reference=fh, + extra_body=extra_body if 
extra_body else None, + ) + + def generate_text_url_image_to_video(case_id, client) -> tuple[str, bytes]: + if not sampling_params.prompt or not sampling_params.image_path: + pytest.skip(f"{case_id}: no edit config") + + # Build extra_body for optional features + extra_body = {"reference_url": sampling_params.image_path} + if sampling_params.enable_teacache: + extra_body["enable_teacache"] = True + + return _create_and_download_video( + client, + case_id, + model=model_path, + prompt=sampling_params.prompt, + size=sampling_params.output_size, + seconds=video_seconds, + extra_body={ + "reference_url": sampling_params.image_path, + "fps": sampling_params.fps, + "num_frames": sampling_params.num_frames, + }, + ) + + def generate_text_image_to_video(case_id, client) -> tuple[str, bytes]: + """TI2V: Text + Image -> Video.""" + if not sampling_params.prompt or not sampling_params.image_path: + pytest.skip(f"{case_id}: no edit config") + + if is_image_url(sampling_params.image_path): + image_path = download_image_from_url(str(sampling_params.image_path)) + else: + image_path = Path(sampling_params.image_path) + if not image_path.exists(): + pytest.skip(f"{case_id}: file missing: {image_path}") + + # Build extra_body for optional features + extra_body = {} + if sampling_params.enable_teacache: + extra_body["enable_teacache"] = True + + with image_path.open("rb") as fh: + return _create_and_download_video( + client, + case_id, + model=model_path, + prompt=sampling_params.prompt, + size=output_size, + seconds=video_seconds, + input_reference=fh, + extra_body={ + "fps": sampling_params.fps, + "num_frames": sampling_params.num_frames, + }, + ) + + def generate_mesh(case_id, client) -> tuple[str, bytes]: + """I2M: Image to Mesh generation using async /v1/meshes API.""" + import requests as http_requests + + if not sampling_params.image_path: + pytest.skip(f"{case_id}: no input image configured for mesh generation") + + image_path = sampling_params.image_path + if 
isinstance(image_path, str) and is_image_url(image_path): + image_path = download_image_from_url(image_path) + elif isinstance(image_path, Path): + if not image_path.exists(): + pytest.skip(f"{case_id}: image file missing: {image_path}") + else: + image_path = Path(str(image_path)) + if not image_path.exists(): + pytest.skip(f"{case_id}: image file missing: {image_path}") + + base_url = str(client.base_url).rstrip("/") + if base_url.endswith("/v1"): + base_url = base_url[:-3] + + create_url = f"{base_url}/v1/meshes" + + with open(str(image_path), "rb") as img_file: + files = {"image": (Path(str(image_path)).name, img_file, "image/png")} + data = { + "prompt": "generate 3d mesh", + "model": model_path, + "seed": "0", + "guidance_scale": "5.0", + "num_inference_steps": "50", + } + + logger.info(f"[Mesh Gen] Sending request to {create_url}") + + try: + response = http_requests.post( + create_url, files=files, data=data, timeout=60 + ) + except Exception as e: + pytest.fail(f"{case_id}: mesh creation request failed: {e}") + + if response.status_code != 200: + pytest.fail(f"{case_id}: mesh creation failed: {response.text}") + + job = response.json() + mesh_id = job.get("id") + if not mesh_id: + pytest.fail(f"{case_id}: no mesh id in response: {job}") + + poll_url = f"{base_url}/v1/meshes/{mesh_id}" + poll_interval = 5 + max_wait = 1200 + elapsed = 0 + + while elapsed < max_wait: + time.sleep(poll_interval) + elapsed += poll_interval + + try: + poll_resp = http_requests.get(poll_url, timeout=30) + except Exception as e: + logger.warning(f"[Mesh Gen] Poll failed: {e}") + continue + + if poll_resp.status_code != 200: + continue + + status_data = poll_resp.json() + status = status_data.get("status", "") + + if status == "completed": + content_url = f"{base_url}/v1/meshes/{mesh_id}/content" + try: + content_resp = http_requests.get(content_url, timeout=60) + except Exception as e: + pytest.fail(f"{case_id}: mesh download failed: {e}") + + if content_resp.status_code != 200: 
+ pytest.fail(f"{case_id}: mesh download failed: {content_resp.text}") + + temp_path = Path(tempfile.gettempdir()) / f"mesh_test_{mesh_id}.glb" + temp_path.write_bytes(content_resp.content) + MESH_OUTPUT_PATHS[case_id] = str(temp_path) + + logger.info(f"[Mesh Gen] Mesh downloaded to {temp_path}") + return (mesh_id, b"") + elif status == "failed": + error = status_data.get("error", {}) + pytest.fail(f"{case_id}: mesh generation failed: {error}") + + pytest.fail(f"{case_id}: mesh generation timed out after {max_wait}s") + + if modality == "3d": + fn = generate_mesh + elif modality == "video": + if sampling_params.image_path and sampling_params.prompt: + if getattr(sampling_params, "direct_url_test", False): + fn = generate_text_url_image_to_video + else: + fn = generate_text_image_to_video + elif sampling_params.image_path: + fn = generate_image_to_video + else: + fn = generate_video + elif sampling_params.prompt and sampling_params.image_path: + if getattr(sampling_params, "direct_url_test", False): + fn = generate_image_edit_url + else: + fn = generate_image_edit + else: + fn = generate_image + + return fn diff --git a/sglang/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py b/sglang/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py new file mode 100644 index 0000000000000000000000000000000000000000..ac3c8d7ff40d398aa06cfa1fdd94e9dc55ca5d8e --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py @@ -0,0 +1,674 @@ +"""Tests for diffusion `update_weights_from_disk`. + +This module verifies the ability to update model weights in place without restarting +the server, which is critical for RL workflows and iterative fine-tuning scenarios. 
+ +Author: + +Menyang Liu, https://github.com/dreamyang-liu +Chenyang Zhao, https://github.com/zhaochenyang20 + +We use two model pairs for testing (base model / instruct model pairs): + +- FLUX.2-klein-base-4B / FLUX.2-klein-4B +- Qwen/Qwen-Image / Qwen/Qwen-Image-2512 + +These model pairs share the same architecture but differ in transformer +weights. The basic testing logic is to refit the instruct model into the +base model and verify that the checksums of the transformer weights are the same, +which simulates the real-world RL scenario. However, since these two model +pairs only differ in transformer weights, and we want to verify updating a +specific module with the update_weights_from_disk API, we need to create a perturbed +instruct model that adds noise to the vae weights. In this sense, the instruct +model differs from the base model in vae and transformer weights, while the text +encoders are still the same. + +To strictly verify the correctness of the refit API, we compare the +SHA-256 checksums computed on disk and on the server. + +NOTE and TODO: In the "refit a specific module" test, we randomly select one module +from the transformer and vae to refit the server and keep other modules the same. +As described above, the vae's weights are perturbed. If we select the vae to be the +target module, ideally speaking, we should assert that the refitted vae's checksum +is the same as the one directly computed from the perturbed vae weights on disk. However, +since there is complex weight-name remapping and QKV merging during model loading, +it is not easy to compare the server-disk checksum for vae and text encoder directly. +Therefore, if the target module is vae, we only verify that the refitted vae's checksum +is different from the base model's vae's checksum. + +It would be a good issue for the community to add a comparison of the server-disk +checksum for vae and text encoder to this test. 
+ +============================================================================= + +Test organization: + +7 test cases in 2 classes; +two model pairs are tested locally, one in CI. + +============================================================================= + +Class 1: TestUpdateWeightsFromDisk (6 tests) — API contract, checksum & rollback +Class 2: TestUpdateWeightsFromDiskWithOffload (1 test) — Offload-aware update + checksum + +----------------------------------------------------------------------------- + +Class 1: TestUpdateWeightsFromDisk + +Validate the update_weights_from_disk API contract, request/response shape, +error handling, checksum verification, and corrupted-weight rollback. + +All tests share one class-scoped server (same process, same in-memory weights). +Tests that require "base model then update" should be explicitly reset to +base model first so behavior is order-independent and updates are real +(base -> perturbed), not no-ops (perturbed -> perturbed). + + • test_update_weights_from_disk_default + + base model -> perturbed model with flush_cache=True. + Verifies after-update transformer checksum == perturbed model's + transformer disk checksum + + + • test_update_weights_specific_modules + + base -> perturbed with flush_cache=False. Randomly selects one module + from _DIFFERING_MODULES (transformer and vae) as target_modules, updates + only that module. Verifies that: + (1) targeted module's in-memory checksum changed; + (2) non-targeted modules' in-memory checksums are unchanged. + + • test_update_weights_nonexistent_model + + model_path set to a non-existent path; must fail (400, success=False). + + Ensure server is healthy after failed update and server's transformer + checksums equal base model's transformer disk checksum. + + • test_update_weights_missing_model_path + + Request body empty (no model_path); must fail (400, success=False). 
+ + Ensure server is healthy after failed update and server's transformer + checksums equal base model's transformer disk checksum. + + • test_update_weights_nonexistent_module + + target_modules=["nonexistent_module"]; must fail (400, success=False). + + Verify server is healthy after failed update and server's transformer + checksum equals base model's transformer disk checksum. + + • test_corrupted_weights_rollback + + All-or-nothing rollback: We first refit the server from base model -> + perturbed model. We manually truncate the vae weights of the base + model to get a corrupted model. We then call the refit to update + the server from the perturbed model -> corrupted model. Verify that: + + 1. The update fails due to the truncated vae; the server should roll back to the + perturbed model, i.e., server's transformer weights == perturbed model's + transformer weights != base model's transformer weights. + + 2. After the rollback, server's vae weights == perturbed model's vae + weights != base model's vae weights. + + 3. After the rollback, server's text encoder weights == base model's + text encoder weights == perturbed model's text encoder weights. + +----------------------------------------------------------------------------- + +Class 2: TestUpdateWeightsFromDiskWithOffload + + +Ensure weight updates and checksum verification work when layerwise offload is enabled +(--dit-layerwise-offload). With offload, parameters live in CPU buffers and only +small torch.empty((1,)) placeholders are left on GPU; the updater must write into CPU buffers +and update prefetched GPU tensors without shape mismatch. + + • test_update_weights_with_offload_enabled + + Server with --dit-layerwise-offload (base). Load perturbed checkpoint; + must succeed (200, success=True), no "Shape mismatch". Server's transformer checksum + matches perturbed model's transformer disk checksum. 
+""" + +from __future__ import annotations + +import functools +import os +import random +import shutil +import tempfile +import threading +from collections.abc import Callable + +import pytest +import requests +from safetensors.torch import load_file, save_file + +from sglang.multimodal_gen.runtime.loader.utils import ( + _list_safetensors_files, +) +from sglang.multimodal_gen.runtime.loader.weight_utils import ( + compute_weights_checksum, + safetensors_weights_iterator, +) +from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import maybe_download_model +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.test.server.test_server_utils import ( + ServerManager, +) +from sglang.multimodal_gen.test.test_utils import ( + DEFAULT_FLUX_2_KLEIN_4B_MODEL_NAME_FOR_TEST, + DEFAULT_FLUX_2_KLEIN_BASE_4B_MODEL_NAME_FOR_TEST, + DEFAULT_QWEN_IMAGE_2512_MODEL_NAME_FOR_TEST, + DEFAULT_QWEN_IMAGE_MODEL_NAME_FOR_TEST, + get_dynamic_server_port, + is_in_ci, +) + +logger = init_logger(__name__) + + +_TRANSFORMER_MODULE = "transformer" +_VAE_MODULE = "vae" +_TEXT_ENCODER_MODULE_PREFIX = "text_encoder" + + +# Modules whose weights differ between the base model and the +# perturbed checkpoint +_DIFFERING_MODULES: list[str] = [_TRANSFORMER_MODULE, _VAE_MODULE] + +_ALL_MODEL_PAIRS: list[tuple[str, str]] = [ + ( + DEFAULT_FLUX_2_KLEIN_BASE_4B_MODEL_NAME_FOR_TEST, + DEFAULT_FLUX_2_KLEIN_4B_MODEL_NAME_FOR_TEST, + ), + ( + DEFAULT_QWEN_IMAGE_MODEL_NAME_FOR_TEST, + DEFAULT_QWEN_IMAGE_2512_MODEL_NAME_FOR_TEST, + ), +] + + +_CI_MODEL_PAIR_ENV = "SGLANG_MMGEN_UPDATE_WEIGHTS_PAIR" + + +def _resolve_active_model_pairs() -> list[tuple[str, str]]: + if not is_in_ci(): + return _ALL_MODEL_PAIRS + + pair_by_id = {pair[0].split("/")[-1]: pair for pair in _ALL_MODEL_PAIRS} + selected_pair_id = os.environ.get(_CI_MODEL_PAIR_ENV) + if selected_pair_id is None: + return [random.choice(_ALL_MODEL_PAIRS)] + + selected_pair = 
pair_by_id.get(selected_pair_id) + if selected_pair is None: + valid_ids = ", ".join(sorted(pair_by_id)) + raise ValueError( + f"Invalid {_CI_MODEL_PAIR_ENV}={selected_pair_id!r}. " + f"Expected one of: {valid_ids}." + ) + return [selected_pair] + + +_ACTIVE_MODEL_PAIRS = _resolve_active_model_pairs() +_PAIR_IDS = [p[0].split("/")[-1] for p in _ACTIVE_MODEL_PAIRS] + + +@functools.lru_cache(maxsize=None) +def _compute_checksum_from_disk(model_path: str, module_name: str) -> str: + """Compute SHA-256 checksum from safetensors files on disk. + + Uses the same compute_weights_checksum function as the server, + so the checksums are directly comparable. + + Results are cached (keyed on model_path and module_name) because the + same disk checksum is requested multiple times across tests. + """ + local_path = maybe_download_model(model_path) + weights_dir = os.path.join(local_path, module_name) + assert os.path.exists( + weights_dir + ), f"No weights dir for {module_name} in {local_path}" + + safetensors_files = _list_safetensors_files(weights_dir) + assert safetensors_files, f"No safetensors files in {weights_dir}" + + return compute_weights_checksum(safetensors_weights_iterator(safetensors_files)) + + +def _clone_model_with_modified_module( + src_model: str, + dst_model: str, + target_module: str, + transform_safetensor: Callable[[str, str], None], +) -> None: + # Symlink root-level files (model_index.json, etc.). 
+ for fname in os.listdir(src_model): + src_path = os.path.join(src_model, fname) + dst_path = os.path.join(dst_model, fname) + if os.path.isfile(src_path) and not os.path.exists(dst_path): + os.symlink(src_path, dst_path) + + for module_dir in sorted(os.listdir(src_model)): + src_dir = os.path.join(src_model, module_dir) + dst_dir = os.path.join(dst_model, module_dir) + if not os.path.isdir(src_dir): + continue + + if module_dir != target_module: + if not os.path.exists(dst_dir): + os.symlink(src_dir, dst_dir) + continue + + os.makedirs(dst_dir, exist_ok=True) + transformed = False + for fname in sorted(os.listdir(src_dir)): + src_file = os.path.join(src_dir, fname) + dst_file = os.path.join(dst_dir, fname) + if not os.path.isfile(src_file): + continue + + if not fname.endswith(".safetensors") or transformed: + if not os.path.exists(dst_file): + os.symlink(src_file, dst_file) + continue + + transform_safetensor(src_file, dst_file) + transformed = True + + +def _truncate_safetensor(src_file: str, dst_file: str) -> None: + shutil.copy2(src_file, dst_file) + size = os.path.getsize(dst_file) + with open(dst_file, "r+b") as f: + f.truncate(size - 2) + logger.info( + "Created corrupted safetensors: %s (%d -> %d bytes)", + dst_file, + size, + size - 2, + ) + + +def _perturb_safetensor(src_file: str, dst_file: str) -> None: + + tensors = load_file(src_file) + perturbed = { + k: (t + 0.01 if t.is_floating_point() else t) for k, t in tensors.items() + } + save_file(perturbed, dst_file) + logger.info("Created perturbed safetensors: %s", dst_file) + + +class _UpdateWeightsApiMixin: + def _update_weights( + self, + base_url: str, + model_path: str, + flush_cache: bool = True, + target_modules: list[str] | None = None, + timeout: int = 300, + ) -> tuple[dict, int]: + payload = {"model_path": model_path, "flush_cache": flush_cache} + if target_modules is not None: + payload["target_modules"] = target_modules + response = requests.post( + f"{base_url}/update_weights_from_disk", + 
json=payload, + timeout=timeout, + ) + return response.json(), response.status_code + + def _get_weights_checksum( + self, + base_url: str, + module_names: list[str] | None = None, + timeout: int = 300, + ) -> dict: + payload = {} + if module_names is not None: + payload["module_names"] = module_names + response = requests.post( + f"{base_url}/get_weights_checksum", + json=payload, + timeout=timeout, + ) + assert ( + response.status_code == 200 + ), f"get_weights_checksum failed: {response.status_code} {response.text}" + return response.json() + + def _assert_server_matches_model( + self, + base_url: str, + expected_model: str, + ) -> None: + server_checksums = self._get_weights_checksum( + base_url, module_names=[_TRANSFORMER_MODULE] + ) + expected_cs = _compute_checksum_from_disk(expected_model, _TRANSFORMER_MODULE) + server_cs = server_checksums.get(_TRANSFORMER_MODULE) + assert server_cs == expected_cs, ( + f"Checksum mismatch on '{_TRANSFORMER_MODULE}'\n" + f" expected({expected_model}): {expected_cs}\n" + f" server: {server_cs}" + ) + + +class TestUpdateWeightsFromDisk(_UpdateWeightsApiMixin): + + @pytest.fixture( + scope="class", + params=_ACTIVE_MODEL_PAIRS, + ids=_PAIR_IDS, + ) + def diffusion_server_no_offload(self, request): + default_model, source_model = request.param + port = get_dynamic_server_port() + wait_deadline = float(os.environ.get("SGLANG_TEST_WAIT_SECS", "600")) + + manager = ServerManager( + model=default_model, + port=port, + wait_deadline=wait_deadline, + extra_args="--num-gpus 1", + ) + + # Ensure models are local before spawning threads that need the paths. + local_default = maybe_download_model(default_model) + local_source = maybe_download_model(source_model) + + perturbed_vae_model_dir = tempfile.mkdtemp(prefix="sglang_perturbed_vae_") + corrupted_vae_model_dir = tempfile.mkdtemp(prefix="sglang_corrupted_") + + # Run all disk I/O in background while the server boots. 
+ bg_threads = [ + threading.Thread( + target=_compute_checksum_from_disk, args=(default_model, module) + ) + for module in _DIFFERING_MODULES + ] + [ + threading.Thread( + target=_clone_model_with_modified_module, + args=( + local_source, + perturbed_vae_model_dir, + _VAE_MODULE, + _perturb_safetensor, + ), + ), + threading.Thread( + target=_clone_model_with_modified_module, + args=( + local_default, + corrupted_vae_model_dir, + _VAE_MODULE, + _truncate_safetensor, + ), + ), + ] + for t in bg_threads: + t.start() + + ctx = manager.start() + for t in bg_threads: + t.join() + + # Sanity: all _DIFFERING_MODULES should differ between base and perturbed. + for module in _DIFFERING_MODULES: + assert _compute_checksum_from_disk( + default_model, module + ) != _compute_checksum_from_disk(perturbed_vae_model_dir, module), ( + f"Assumption violated: {module} should differ between " + f"{default_model} and {perturbed_vae_model_dir}" + ) + + try: + yield ctx, default_model, perturbed_vae_model_dir, corrupted_vae_model_dir + finally: + ctx.cleanup() + shutil.rmtree(perturbed_vae_model_dir, ignore_errors=True) + shutil.rmtree(corrupted_vae_model_dir, ignore_errors=True) + + def test_update_weights_from_disk_default(self, diffusion_server_no_offload): + """Default update (target_modules=None, flush_cache=True): all changed modules updated.""" + ctx, default_model, perturbed_model_dir, _ = diffusion_server_no_offload + base_url = f"http://localhost:{ctx.port}" + + self._update_weights(base_url, default_model, flush_cache=True) + + result, status_code = self._update_weights( + base_url, perturbed_model_dir, flush_cache=True + ) + assert status_code == 200 + assert result.get("success", False), f"Update failed: {result.get('message')}" + + self._assert_server_matches_model(base_url, perturbed_model_dir) + + def test_update_weights_specific_modules(self, diffusion_server_no_offload): + ctx, default_model, perturbed_model_dir, _ = diffusion_server_no_offload + base_url = 
f"http://localhost:{ctx.port}" + + # Reset server to default_model. + self._update_weights(base_url, default_model) + before_checksums = self._get_weights_checksum( + base_url, module_names=_DIFFERING_MODULES + ) + + target_modules = [random.choice(_DIFFERING_MODULES)] + result, status_code = self._update_weights( + base_url, + perturbed_model_dir, + target_modules=target_modules, + flush_cache=False, + ) + assert status_code == 200, f"Update failed: {result}" + assert result.get("success", False), f"Update failed: {result.get('message')}" + + after_checksums = self._get_weights_checksum( + base_url, module_names=_DIFFERING_MODULES + ) + + # Targeted module should have changed. + for name in target_modules: + assert after_checksums.get(name) != before_checksums.get(name), ( + f"Targeted module '{name}' checksum should change after update\n" + f" before: {before_checksums.get(name)}\n" + f" after: {after_checksums.get(name)}" + ) + + # Non-targeted modules should be unchanged. + for name, cs in after_checksums.items(): + if name in target_modules or cs == "not_found": + continue + assert cs == before_checksums.get(name), ( + f"Non-targeted module '{name}' should be unchanged\n" + f" before: {before_checksums.get(name)}\n" + f" after: {cs}" + ) + + def test_update_weights_nonexistent_model(self, diffusion_server_no_offload): + """Nonexistent model path must fail (400). 
Server healthy, checksums == base disk.""" + ctx, default_model, _, _ = diffusion_server_no_offload + base_url = f"http://localhost:{ctx.port}" + + self._update_weights(base_url, default_model) + + result, status_code = self._update_weights( + base_url, + "/nonexistent/path/to/model", + timeout=60, + ) + logger.info(f"Update result for nonexistent model: {result}") + + assert status_code == 400, f"Expected 400, got {status_code}" + assert not result.get("success", True), "Should fail for nonexistent model" + self._assert_server_matches_model(base_url, default_model) + + def test_update_weights_missing_model_path(self, diffusion_server_no_offload): + """Request without model_path must fail (400). Server healthy, checksums == base disk.""" + ctx, default_model, _, _ = diffusion_server_no_offload + base_url = f"http://localhost:{ctx.port}" + + self._update_weights(base_url, default_model) + + response = requests.post( + f"{base_url}/update_weights_from_disk", + json={}, + timeout=30, + ) + + assert response.status_code == 400, f"Expected 400, got {response.status_code}" + result = response.json() + assert not result.get("success", True), "Should fail when model_path is missing" + self._assert_server_matches_model(base_url, default_model) + + def test_update_weights_nonexistent_module(self, diffusion_server_no_offload): + """Nonexistent module must fail (400). 
Server healthy, checksums == base disk.""" + ctx, default_model, perturbed_model_dir, _ = diffusion_server_no_offload + base_url = f"http://localhost:{ctx.port}" + + self._update_weights(base_url, default_model) + + result, status_code = self._update_weights( + base_url, + perturbed_model_dir, + target_modules=["nonexistent_module"], + timeout=60, + ) + logger.info(f"Update nonexistent module result: {result}") + + assert status_code == 400, f"Expected 400, got {status_code}" + assert not result.get("success", True), "Should fail for nonexistent module" + assert "not found in pipeline" in result.get("message", "") + self._assert_server_matches_model(base_url, default_model) + + def test_corrupted_weights_rollback(self, diffusion_server_no_offload): + ctx, default_model, perturbed_model_dir, corrupted_vae_model_dir = ( + diffusion_server_no_offload + ) + base_url = f"http://localhost:{ctx.port}" + + # base → perturbed + self._update_weights(base_url, default_model) + base_checksums = self._get_weights_checksum(base_url) + + result, status_code = self._update_weights(base_url, perturbed_model_dir) + assert status_code == 200 and result.get("success") + perturbed_checksums = self._get_weights_checksum(base_url) + + text_encoder_modules = sorted( + name + for name in perturbed_checksums + if _TEXT_ENCODER_MODULE_PREFIX in name + and perturbed_checksums.get(name) != "not_found" + and base_checksums.get(name) != "not_found" + ) + assert ( + text_encoder_modules + ), "Expected at least one text encoder module checksum" + + # perturbed → corrupted (should fail and rollback) + rollback_targets = [_TRANSFORMER_MODULE, _VAE_MODULE] + result, status_code = self._update_weights( + base_url, + corrupted_vae_model_dir, + target_modules=rollback_targets, + ) + assert ( + status_code == 400 + ), f"Expected 400 on corrupted weights, got {status_code}" + assert not result.get("success", True) + message = result.get("message", "") + assert "rolled back" in message.lower() + # The 
updater reports the first failing module in the error message. + # With ordered target_modules=[transformer, vae], this makes the + # failure point explicit: transformer is processed first, then vae fails. + assert ( + "Failed to update module 'vae'" in message + ), f"Expected vae to be the explicit failure point, got: {message}" + rolled_back_checksums = self._get_weights_checksum(base_url) + + # 1) transformer: server == perturbed != base + transformer_base = base_checksums.get(_TRANSFORMER_MODULE) + transformer_perturbed = perturbed_checksums.get(_TRANSFORMER_MODULE) + transformer_rolled_back = rolled_back_checksums.get(_TRANSFORMER_MODULE) + assert transformer_rolled_back == transformer_perturbed + assert transformer_rolled_back != transformer_base + + # 2) vae: server == perturbed != base + vae_base = base_checksums.get(_VAE_MODULE) + vae_perturbed = perturbed_checksums.get(_VAE_MODULE) + vae_rolled_back = rolled_back_checksums.get(_VAE_MODULE) + assert vae_rolled_back == vae_perturbed + assert vae_rolled_back != vae_base + + # 3) text encoder(s): server == base == perturbed + for name in text_encoder_modules: + assert rolled_back_checksums.get(name) == perturbed_checksums.get( + name + ), f"Text encoder module '{name}' should stay equal to perturbed" + assert rolled_back_checksums.get(name) == base_checksums.get( + name + ), f"Text encoder module '{name}' should stay equal to base" + + +class TestUpdateWeightsFromDiskWithOffload(_UpdateWeightsApiMixin): + """Test update_weights_from_disk with layerwise offload enabled.""" + + @pytest.fixture(scope="class", params=_ACTIVE_MODEL_PAIRS, ids=_PAIR_IDS) + def diffusion_server_with_offload(self, request): + default_model, source_model = request.param + port = get_dynamic_server_port() + wait_deadline = float(os.environ.get("SGLANG_TEST_WAIT_SECS", "600")) + + local_source = maybe_download_model(source_model) + perturbed_vae_model_dir = tempfile.mkdtemp(prefix="sglang_perturbed_vae_") + + clone_thread = 
threading.Thread( + target=_clone_model_with_modified_module, + args=( + local_source, + perturbed_vae_model_dir, + _VAE_MODULE, + _perturb_safetensor, + ), + ) + clone_thread.start() + + manager = ServerManager( + model=default_model, + port=port, + wait_deadline=wait_deadline, + extra_args="--num-gpus 1 --dit-layerwise-offload true", + ) + + ctx = manager.start() + clone_thread.join() + + try: + yield ctx, default_model, perturbed_vae_model_dir + finally: + ctx.cleanup() + shutil.rmtree(perturbed_vae_model_dir, ignore_errors=True) + + def test_update_weights_with_offload_enabled(self, diffusion_server_with_offload): + ctx, _, perturbed_model_dir = diffusion_server_with_offload + base_url = f"http://localhost:{ctx.port}" + + result, status_code = self._update_weights(base_url, perturbed_model_dir) + assert status_code == 200, f"Expected 200, got {status_code}" + assert result.get("success", False), f"Update failed: {result.get('message')}" + + message = result.get("message", "") + assert "Shape mismatch" not in message, f"Shape mismatch detected: {message}" + + self._assert_server_matches_model(base_url, perturbed_model_dir) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/sglang/python/sglang/multimodal_gen/test/server/testcase_configs.py b/sglang/python/sglang/multimodal_gen/test/server/testcase_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..130182f8f6b7e33618b8b15a83ab5d100e53e779 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/test/server/testcase_configs.py @@ -0,0 +1,865 @@ +""" +Configuration and data structures for diffusion performance tests. + +Usage: + +pytest python/sglang/multimodal_gen/test/server/test_server_a.py +# for a single testcase, look for the name of the testcases in DIFFUSION_CASES +pytest python/sglang/multimodal_gen/test/server/test_server_a.py -k qwen_image_t2i + + +To add a new testcase: +1. 
add your testcase with case-id: `my_new_test_case_id` to DIFFUSION_CASES +2. run `SGLANG_GEN_BASELINE=1 pytest -s python/sglang/multimodal_gen/test/server/ -k my_new_test_case_id` +3. insert or override the corresponding scenario in `scenarios` section of perf_baselines.json with the output baseline of step-2 + + +""" + +from __future__ import annotations + +import json +import os +import statistics +from dataclasses import dataclass, field +from pathlib import Path +from typing import Sequence + +from sglang.multimodal_gen.runtime.platforms import current_platform +from sglang.multimodal_gen.runtime.utils.perf_logger import RequestPerfRecord +from sglang.multimodal_gen.test.test_utils import ( + DEFAULT_FLUX_1_DEV_MODEL_NAME_FOR_TEST, + DEFAULT_FLUX_2_DEV_MODEL_NAME_FOR_TEST, + DEFAULT_FLUX_2_KLEIN_4B_MODEL_NAME_FOR_TEST, + DEFAULT_QWEN_IMAGE_EDIT_2509_MODEL_NAME_FOR_TEST, + DEFAULT_QWEN_IMAGE_EDIT_2511_MODEL_NAME_FOR_TEST, + DEFAULT_QWEN_IMAGE_EDIT_MODEL_NAME_FOR_TEST, + DEFAULT_QWEN_IMAGE_LAYERED_MODEL_NAME_FOR_TEST, + DEFAULT_QWEN_IMAGE_MODEL_NAME_FOR_TEST, + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_WAN_2_1_I2V_14B_480P_MODEL_NAME_FOR_TEST, + DEFAULT_WAN_2_1_I2V_14B_720P_MODEL_NAME_FOR_TEST, + DEFAULT_WAN_2_1_T2V_1_3B_MODEL_NAME_FOR_TEST, + DEFAULT_WAN_2_1_T2V_14B_MODEL_NAME_FOR_TEST, + DEFAULT_WAN_2_2_I2V_A14B_MODEL_NAME_FOR_TEST, + DEFAULT_WAN_2_2_T2V_A14B_MODEL_NAME_FOR_TEST, + DEFAULT_WAN_2_2_TI2V_5B_MODEL_NAME_FOR_TEST, +) + + +@dataclass +class ToleranceConfig: + """Tolerance ratios for performance validation.""" + + e2e: float + denoise_stage: float + non_denoise_stage: float + denoise_step: float + denoise_agg: float + + @classmethod + def load_profile(cls, all_tolerances: dict, profile_name: str) -> ToleranceConfig: + """Load a specific tolerance profile from a dictionary of profiles.""" + # Support both flat structure (backward compatibility) and profiled structure + if "e2e" in all_tolerances and not isinstance(all_tolerances["e2e"], dict): + 
tol_data = all_tolerances + actual_profile = "legacy/flat" + else: + tol_data = all_tolerances.get( + profile_name, all_tolerances.get("pr_test", {}) + ) + actual_profile = ( + profile_name if profile_name in all_tolerances else "pr_test" + ) + + if not tol_data: + raise ValueError( + f"No tolerance profile found for '{profile_name}' and no default 'pr_test' profile exists." + ) + + print(f"--- Performance Tolerance Profile: {actual_profile} ---") + + return cls( + e2e=float(os.getenv("SGLANG_E2E_TOLERANCE", tol_data["e2e"])), + denoise_stage=float( + os.getenv("SGLANG_STAGE_TIME_TOLERANCE", tol_data["denoise_stage"]) + ), + non_denoise_stage=float( + os.getenv( + "SGLANG_NON_DENOISE_STAGE_TIME_TOLERANCE", + tol_data["non_denoise_stage"], + ) + ), + denoise_step=float( + os.getenv("SGLANG_DENOISE_STEP_TOLERANCE", tol_data["denoise_step"]) + ), + denoise_agg=float( + os.getenv("SGLANG_DENOISE_AGG_TOLERANCE", tol_data["denoise_agg"]) + ), + ) + + +@dataclass +class ScenarioConfig: + """Expected performance metrics for a test scenario.""" + + stages_ms: dict[str, float] + denoise_step_ms: dict[int, float] + expected_e2e_ms: float + expected_avg_denoise_ms: float + expected_median_denoise_ms: float + + +@dataclass +class BaselineConfig: + """Full baseline configuration.""" + + scenarios: dict[str, ScenarioConfig] + step_fractions: Sequence[float] + tolerances: ToleranceConfig + improvement_threshold: float + + @classmethod + def load(cls, path: Path) -> BaselineConfig: + """Load baseline configuration from JSON file.""" + with path.open("r", encoding="utf-8") as fh: + data = json.load(fh) + + # Get tolerance profile, defaulting to 'pr_test' + profile_name = "pr_test" + tolerances = ToleranceConfig.load_profile( + data.get("tolerances", {}), profile_name + ) + + scenarios = {} + for name, cfg in data["scenarios"].items(): + scenarios[name] = ScenarioConfig( + stages_ms=cfg["stages_ms"], + denoise_step_ms={int(k): v for k, v in cfg["denoise_step_ms"].items()}, + 
expected_e2e_ms=float(cfg["expected_e2e_ms"]), + expected_avg_denoise_ms=float(cfg["expected_avg_denoise_ms"]), + expected_median_denoise_ms=float(cfg["expected_median_denoise_ms"]), + ) + + return cls( + scenarios=scenarios, + step_fractions=tuple(data["sampling"]["step_fractions"]), + tolerances=tolerances, + improvement_threshold=data.get("improvement_reporting", {}).get( + "threshold", 0.2 + ), + ) + + def update(self, path: Path): + """Load baseline configuration from JSON file.""" + with path.open("r", encoding="utf-8") as fh: + data = json.load(fh) + + scenarios_new = {} + for name, cfg in data["scenarios"].items(): + scenarios_new[name] = ScenarioConfig( + stages_ms=cfg["stages_ms"], + denoise_step_ms={int(k): v for k, v in cfg["denoise_step_ms"].items()}, + expected_e2e_ms=float(cfg["expected_e2e_ms"]), + expected_avg_denoise_ms=float(cfg["expected_avg_denoise_ms"]), + expected_median_denoise_ms=float(cfg["expected_median_denoise_ms"]), + ) + + self.scenarios.update(scenarios_new) + return self + + +@dataclass +class DiffusionServerArgs: + """Configuration for a single model/scenario test case.""" + + model_path: str # HF repo or local path + modality: str = "image" # "image" or "video" or "3d" + + custom_validator: str | None = None # optional custom validator name + # resources + num_gpus: int = 1 + tp_size: int | None = None + ulysses_degree: int | None = None + ring_degree: int | None = None + cfg_parallel: bool | None = None + # LoRA + lora_path: str | None = ( + None # LoRA adapter path (HF repo or local path, loaded at startup) + ) + dynamic_lora_path: str | None = ( + None # LoRA path for dynamic loading test (loaded via set_lora after startup) + ) + second_lora_path: str | None = ( + None # Second LoRA adapter path for multi-LoRA testing + ) + + dit_layerwise_offload: bool = False + dit_offload_prefetch_size: int | float | None = None + enable_cache_dit: bool = False + text_encoder_cpu_offload: bool = False + + extras: list[str] = 
field(default_factory=lambda: []) + + def __post_init__(self): + if self.modality == "image": + self.custom_validator = "image" + elif self.modality == "video": + self.custom_validator = "video" + elif self.modality == "3d": + self.custom_validator = "mesh" + + +@dataclass(frozen=True) +class DiffusionSamplingParams: + """Configuration for a single model/scenario test case.""" + + output_size: str = "" + + # inputs and conditioning + prompt: str | None = None # text prompt for generation + image_path: Path | str | None = None # input image/video for editing (Path or URL) + + # duration + seconds: int = 1 # for video: duration in seconds + num_frames: int | None = None # for video: number of frames + fps: int | None = None # for video: frames per second + + # URL direct test flag - if True, don't pre-download URL images + direct_url_test: bool = False + + # output format + output_format: str | None = None # "png", "jpeg", "mp4", etc. + + num_outputs_per_prompt: int = 1 + + # TeaCache acceleration + enable_teacache: bool = False + + # Frame interpolation + enable_frame_interpolation: bool = False + frame_interpolation_exp: int = 1 # 1 = 2×, 2 = 4× + + +@dataclass(frozen=True) +class DiffusionTestCase: + """Configuration for a single model/scenario test case.""" + + id: str # pytest test id and scenario name + server_args: DiffusionServerArgs + sampling_params: DiffusionSamplingParams + run_perf_check: bool = True + + +def sample_step_indices( + step_map: dict[int, float], fractions: Sequence[float] +) -> list[int]: + if not step_map: + return [] + max_idx = max(step_map.keys()) + indices = set() + for fraction in fractions: + idx = min(max_idx, max(0, int(round(fraction * max_idx)))) + if idx in step_map: + indices.add(idx) + return sorted(indices) + + +@dataclass +class PerformanceSummary: + """Summary of performance of a request, built from RequestPerfRecord""" + + e2e_ms: float + avg_denoise_ms: float + median_denoise_ms: float + # { "stage_1": time_1, "stage_2": 
time_2 } + stage_metrics: dict[str, float] + step_metrics: list[float] + sampled_steps: dict[int, float] + all_denoise_steps: dict[int, float] + frames_per_second: float | None = None + total_frames: int | None = None + avg_frame_time_ms: float | None = None + + @staticmethod + def from_req_perf_record( + record: RequestPerfRecord, step_fractions: Sequence[float] + ): + """Collect all performance metrics into a summary without validation.""" + e2e_ms = record.total_duration_ms + + step_durations = record.steps + avg_denoise = 0.0 + median_denoise = 0.0 + if step_durations: + avg_denoise = sum(step_durations) / len(step_durations) + median_denoise = statistics.median(step_durations) + + per_step = {index: s for index, s in enumerate(step_durations)} + sample_indices = sample_step_indices(per_step, step_fractions) + sampled_steps = {idx: per_step[idx] for idx in sample_indices} + + # convert from list to dict + stage_metrics = {} + for item in record.stages: + if isinstance(item, dict) and "name" in item: + val = item.get("execution_time_ms", 0.0) + stage_metrics[item["name"]] = val + + return PerformanceSummary( + e2e_ms=e2e_ms, + avg_denoise_ms=avg_denoise, + median_denoise_ms=median_denoise, + stage_metrics=stage_metrics, + step_metrics=step_durations, + sampled_steps=sampled_steps, + all_denoise_steps=per_step, + ) + + +T2I_sampling_params = DiffusionSamplingParams( + prompt="Doraemon is eating dorayaki", + output_size="1024x1024", +) + +TI2I_sampling_params = DiffusionSamplingParams( + prompt="Convert 2D style to 3D style", + image_path="https://github.com/lm-sys/lm-sys.github.io/releases/download/test/TI2I_Qwen_Image_Edit_Input.jpg", +) + +MULTI_IMAGE_TI2I_sampling_params = DiffusionSamplingParams( + prompt="The magician bear is on the left, the alchemist bear is on the right, facing each other in the central park square.", + image_path=[ + "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/edit2509/edit2509_1.jpg", + 
"https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/edit2509/edit2509_2.jpg", + ], + direct_url_test=True, +) +MULTI_IMAGE_TI2I_UPLOAD_sampling_params = DiffusionSamplingParams( + prompt="The magician bear is on the left, the alchemist bear is on the right, facing each other in the central park square.", + image_path=[ + "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/edit2509/edit2509_1.jpg", + "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/edit2509/edit2509_2.jpg", + ], +) +MULTI_FRAME_I2I_sampling_params = DiffusionSamplingParams( + prompt="a high quality, cute halloween themed illustration, consistent style and lighting", + image_path=[ + "https://raw.githubusercontent.com/QwenLM/Qwen-Image-Layered/main/assets/test_images/4.png" + ], + num_frames=4, + direct_url_test=True, + output_format="png", +) + +T2V_PROMPT = "A curious raccoon" + +TI2V_sampling_params = DiffusionSamplingParams( + prompt="The man in the picture slowly turns his head, his expression enigmatic and otherworldly. The camera performs a slow, cinematic dolly out, focusing on his face. Moody lighting, neon signs glowing in the background, shallow depth of field.", + image_path="https://is1-ssl.mzstatic.com/image/thumb/Music114/v4/5f/fa/56/5ffa56c2-ea1f-7a17-6bad-192ff9b6476d/825646124206.jpg/600x600bb.jpg", + direct_url_test=True, +) + +TURBOWAN_I2V_sampling_params = DiffusionSamplingParams( + prompt="The man in the picture slowly turns his head, his expression enigmatic and otherworldly. The camera performs a slow, cinematic dolly out, focusing on his face. 
Moody lighting, neon signs glowing in the background, shallow depth of field.", + image_path="https://is1-ssl.mzstatic.com/image/thumb/Music114/v4/5f/fa/56/5ffa56c2-ea1f-7a17-6bad-192ff9b6476d/825646124206.jpg/600x600bb.jpg", + direct_url_test=True, + output_size="960x960", + num_frames=4, + fps=4, +) + +# All test cases with clean default values +# To test different models, simply add more DiffusionCase entries +ONE_GPU_CASES_A: list[DiffusionTestCase] = [ + # === Text to Image (T2I) === + DiffusionTestCase( + "qwen_image_t2i", + DiffusionServerArgs( + model_path=DEFAULT_QWEN_IMAGE_MODEL_NAME_FOR_TEST, + modality="image", + ), + T2I_sampling_params, + ), + DiffusionTestCase( + "qwen_image_t2i_cache_dit_enabled", + DiffusionServerArgs( + model_path=DEFAULT_QWEN_IMAGE_MODEL_NAME_FOR_TEST, + modality="image", + enable_cache_dit=True, + ), + T2I_sampling_params, + ), + DiffusionTestCase( + "flux_image_t2i", + DiffusionServerArgs( + model_path=DEFAULT_FLUX_1_DEV_MODEL_NAME_FOR_TEST, modality="image" + ), + T2I_sampling_params, + ), + # TODO: modeling of flux different from official flux, so weights can't be loaded + # consider opting for a different quantized hf-repo + # DiffusionTestCase( + # "flux_image_t2i_override_transformer_weights_path_fp8", + # DiffusionServerArgs( + # model_path="black-forest-labs/FLUX.1-dev", modality="image", + # extras=["--transformer-weights-path black-forest-labs/FLUX.1-dev-FP8"] + # ), + # T2I_sampling_params, + # ), + DiffusionTestCase( + "flux_2_image_t2i", + DiffusionServerArgs( + model_path=DEFAULT_FLUX_2_DEV_MODEL_NAME_FOR_TEST, modality="image" + ), + T2I_sampling_params, + ), + DiffusionTestCase( + "flux_2_klein_image_t2i", + DiffusionServerArgs( + model_path=DEFAULT_FLUX_2_KLEIN_4B_MODEL_NAME_FOR_TEST, + modality="image", + ), + T2I_sampling_params, + ), + # TODO: replace with a faster model to test the --dit-layerwise-offload + # TODO: currently, we don't support sending more than one request in test, and setting 
`num_outputs_per_prompt` to 2 doesn't guarantee the denoising be executed twice, + # so we do one warmup and send one request instead + DiffusionTestCase( + "layerwise_offload", + DiffusionServerArgs( + model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + modality="image", + dit_layerwise_offload=True, + dit_offload_prefetch_size=2, + ), + T2I_sampling_params, + ), + DiffusionTestCase( + "zimage_image_t2i", + DiffusionServerArgs( + model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, modality="image" + ), + T2I_sampling_params, + ), + DiffusionTestCase( + "zimage_image_t2i_fp8", + DiffusionServerArgs( + model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + modality="image", + extras=["--transformer-path MickJ/Z-Image-Turbo-fp8"], + ), + T2I_sampling_params, + ), + # Multi-LoRA test case for Z-Image-Turbo + DiffusionTestCase( + "zimage_image_t2i_multi_lora", + DiffusionServerArgs( + model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + modality="image", + lora_path="reverentelusarca/elusarca-anime-style-lora-z-image-turbo", + second_lora_path="tarn59/pixel_art_style_lora_z_image_turbo", + ), + T2I_sampling_params, + ), + # === Text and Image to Image (TI2I) === + DiffusionTestCase( + "qwen_image_edit_ti2i", + DiffusionServerArgs( + model_path=DEFAULT_QWEN_IMAGE_EDIT_MODEL_NAME_FOR_TEST, modality="image" + ), + TI2I_sampling_params, + ), + DiffusionTestCase( + "qwen_image_edit_2509_ti2i", + DiffusionServerArgs( + model_path=DEFAULT_QWEN_IMAGE_EDIT_2509_MODEL_NAME_FOR_TEST, + modality="image", + ), + MULTI_IMAGE_TI2I_sampling_params, + ), + DiffusionTestCase( + "qwen_image_edit_2511_ti2i", + DiffusionServerArgs( + model_path=DEFAULT_QWEN_IMAGE_EDIT_2511_MODEL_NAME_FOR_TEST, + modality="image", + ), + TI2I_sampling_params, + ), + DiffusionTestCase( + "qwen_image_layered_i2i", + DiffusionServerArgs( + model_path=DEFAULT_QWEN_IMAGE_LAYERED_MODEL_NAME_FOR_TEST, + modality="image", + ), + MULTI_FRAME_I2I_sampling_params, + ), +] + +HUNYUAN3D_SHAPE_sampling_params = DiffusionSamplingParams( + 
prompt="", + image_path="https://raw.githubusercontent.com/sgl-project/sgl-test-files/main/diffusion-ci/consistency_gt/1-gpu/hunyuan3d_2_0/hunyuan3d.png", +) + +ONE_GPU_CASES_B: list[DiffusionTestCase] = [ + # === Text to Video (T2V) === + DiffusionTestCase( + "wan2_1_t2v_1.3b", + DiffusionServerArgs( + model_path=DEFAULT_WAN_2_1_T2V_1_3B_MODEL_NAME_FOR_TEST, + modality="video", + custom_validator="video", + ), + DiffusionSamplingParams( + prompt=T2V_PROMPT, + ), + ), + DiffusionTestCase( + "wan2_1_t2v_1.3b_text_encoder_cpu_offload", + DiffusionServerArgs( + model_path=DEFAULT_WAN_2_1_T2V_1_3B_MODEL_NAME_FOR_TEST, + modality="video", + custom_validator="video", + text_encoder_cpu_offload=True, + ), + DiffusionSamplingParams( + prompt=T2V_PROMPT, + ), + ), + # TeaCache acceleration test for Wan video model + DiffusionTestCase( + "wan2_1_t2v_1.3b_teacache_enabled", + DiffusionServerArgs( + model_path=DEFAULT_WAN_2_1_T2V_1_3B_MODEL_NAME_FOR_TEST, + modality="video", + custom_validator="video", + ), + DiffusionSamplingParams( + prompt=T2V_PROMPT, + enable_teacache=True, + ), + ), + # Frame interpolation correctness (2× / exp=1) + # Uses the same 1.3B model already in the suite; + DiffusionTestCase( + "wan2_1_t2v_1.3b_frame_interp_2x", + DiffusionServerArgs( + model_path="Wan-AI/Wan2.1-T2V-1.3B-Diffusers", + modality="video", + custom_validator="video", + ), + DiffusionSamplingParams( + prompt=T2V_PROMPT, + enable_frame_interpolation=True, + frame_interpolation_exp=1, + ), + ), + # LoRA test case for single transformer + merge/unmerge API test + # Note: Uses dynamic_lora_path instead of lora_path to test LayerwiseOffload + set_lora interaction + # Server starts WITHOUT LoRA, then set_lora is called after startup (Wan models auto-enable layerwise offload) + DiffusionTestCase( + "wan2_1_t2v_1_3b_lora_1gpu", + DiffusionServerArgs( + model_path=DEFAULT_WAN_2_1_T2V_1_3B_MODEL_NAME_FOR_TEST, + modality="video", + custom_validator="video", + num_gpus=1, + 
dynamic_lora_path="Cseti/Wan-LoRA-Arcane-Jinx-v1", + ), + DiffusionSamplingParams( + prompt="csetiarcane Nfj1nx with blue hair, a woman walking in a cyberpunk city at night", + ), + ), + # NOTE(mick): flaky + # DiffusionTestCase( + # "hunyuan_video", + # DiffusionServerArgs( + # model_path="hunyuanvideo-community/HunyuanVideo", + # modality="video", + # ), + # DiffusionSamplingParams( + # prompt=T2V_PROMPT, + # ), + # ), + DiffusionTestCase( + "flux_2_ti2i", + DiffusionServerArgs( + model_path=DEFAULT_FLUX_2_DEV_MODEL_NAME_FOR_TEST, modality="image" + ), + TI2I_sampling_params, + ), + DiffusionTestCase( + "flux_2_t2i_customized_vae_path", + DiffusionServerArgs( + model_path=DEFAULT_FLUX_2_DEV_MODEL_NAME_FOR_TEST, + modality="image", + extras=["--vae-path=fal/FLUX.2-Tiny-AutoEncoder"], + ), + T2I_sampling_params, + run_perf_check=False, + ), + DiffusionTestCase( + "fast_hunyuan_video", + DiffusionServerArgs( + model_path="FastVideo/FastHunyuan-diffusers", + modality="video", + custom_validator="video", + ), + DiffusionSamplingParams( + prompt=T2V_PROMPT, + ), + ), + # === Text and Image to Video (TI2V) === + DiffusionTestCase( + "wan2_2_ti2v_5b", + DiffusionServerArgs( + model_path=DEFAULT_WAN_2_2_TI2V_5B_MODEL_NAME_FOR_TEST, + modality="video", + custom_validator="video", + ), + TI2V_sampling_params, + ), + DiffusionTestCase( + "fastwan2_2_ti2v_5b", + DiffusionServerArgs( + model_path="FastVideo/FastWan2.2-TI2V-5B-FullAttn-Diffusers", + modality="video", + custom_validator="video", + ), + TI2V_sampling_params, + ), +] + +# Skip hunyuan3d on AMD: marching_cubes surface extraction produces invalid SDF on ROCm. +if not current_platform.is_hip(): + ONE_GPU_CASES_B.append( + DiffusionTestCase( + "hunyuan3d_shape_gen", + DiffusionServerArgs( + model_path="tencent/Hunyuan3D-2", + modality="3d", + ), + HUNYUAN3D_SHAPE_sampling_params, + ), + ) +# Skip turbowan on AMD: Triton requires 81920 shared memory, but AMD only has 65536. 
# NOTE: the `turbo_wan2_1_t2v_1.3b` case itself is registered exactly once, in the
# `if not current_platform.is_hip():` block that follows TWO_GPU_CASES_B further
# down in this module. It used to be appended to ONE_GPU_CASES_B here as well,
# which left two cases with the same id in that list.

# First list of test cases whose ids are tagged for multi-GPU runs. It is a plain
# mutable list on purpose: the non-HIP block later in this module appends to it.
TWO_GPU_CASES_A: list[DiffusionTestCase] = [
    DiffusionTestCase(
        "wan2_2_i2v_a14b_2gpu",
        DiffusionServerArgs(
            model_path=DEFAULT_WAN_2_2_I2V_A14B_MODEL_NAME_FOR_TEST,
            modality="video",
            custom_validator="video",
            # NOTE(review): every other entry in this list sets num_gpus=2, while
            # this one falls back to the dataclass default (num_gpus=1) despite
            # the "_2gpu" suffix in its id. Left unchanged because the recorded
            # perf baseline may depend on it -- confirm whether this is intended.
        ),
        TI2V_sampling_params,
    ),
    DiffusionTestCase(
        "wan2_2_t2v_a14b_2gpu",
        DiffusionServerArgs(
            model_path=DEFAULT_WAN_2_2_T2V_A14B_MODEL_NAME_FOR_TEST,
            modality="video",
            custom_validator="video",
            num_gpus=2,
        ),
        DiffusionSamplingParams(
            prompt=T2V_PROMPT,
        ),
    ),
    # LoRA test case for transformer_2 support
    DiffusionTestCase(
        "wan2_2_t2v_a14b_lora_2gpu",
        DiffusionServerArgs(
            model_path=DEFAULT_WAN_2_2_T2V_A14B_MODEL_NAME_FOR_TEST,
            modality="video",
            custom_validator="video",
            num_gpus=2,
            lora_path="Cseti/wan2.2-14B-Arcane_Jinx-lora-v1",
        ),
        DiffusionSamplingParams(
            prompt="Nfj1nx with blue hair, a woman walking in a cyberpunk city at night",
        ),
    ),
    DiffusionTestCase(
        "wan2_1_t2v_14b_2gpu",
        DiffusionServerArgs(
            model_path=DEFAULT_WAN_2_1_T2V_14B_MODEL_NAME_FOR_TEST,
            modality="video",
            num_gpus=2,
            custom_validator="video",
        ),
        DiffusionSamplingParams(
            prompt=T2V_PROMPT,
            output_size="832x480",
        ),
    ),
    # Same 1.3B T2V model as the single-GPU suite, with cfg_parallel enabled.
    DiffusionTestCase(
        "wan2_1_t2v_1.3b_cfg_parallel",
        DiffusionServerArgs(
            model_path=DEFAULT_WAN_2_1_T2V_1_3B_MODEL_NAME_FOR_TEST,
            modality="video",
            custom_validator="video",
            num_gpus=2,
            cfg_parallel=True,
        ),
        DiffusionSamplingParams(
            prompt=T2V_PROMPT,
        ),
    ),
    # Small image model launched with the extra `--use-fsdp-inference` server flag.
    DiffusionTestCase(
        "fsdp-inference",
        DiffusionServerArgs(
            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
            modality="image",
            num_gpus=2,
            extras=["--use-fsdp-inference"],
        ),
        T2I_sampling_params,
    ),
]


+TWO_GPU_CASES_B = [ + DiffusionTestCase( + "wan2_1_i2v_14b_480P_2gpu", + DiffusionServerArgs( + model_path=DEFAULT_WAN_2_1_I2V_14B_480P_MODEL_NAME_FOR_TEST, + modality="video", + custom_validator="video", + num_gpus=2, + ), + TI2V_sampling_params, + ), + # I2V LoRA test case + DiffusionTestCase( + "wan2_1_i2v_14b_lora_2gpu", + DiffusionServerArgs( + model_path=DEFAULT_WAN_2_1_I2V_14B_720P_MODEL_NAME_FOR_TEST, + modality="video", + custom_validator="video", + num_gpus=2, + lora_path="starsfriday/Wan2.1-Divine-Power-LoRA", + ), + TI2V_sampling_params, + ), + DiffusionTestCase( + "wan2_1_i2v_14b_720P_2gpu", + DiffusionServerArgs( + model_path=DEFAULT_WAN_2_1_I2V_14B_720P_MODEL_NAME_FOR_TEST, + modality="video", + custom_validator="video", + num_gpus=2, + ), + TI2V_sampling_params, + ), + DiffusionTestCase( + "qwen_image_t2i_2_gpus", + DiffusionServerArgs( + model_path=DEFAULT_QWEN_IMAGE_MODEL_NAME_FOR_TEST, + modality="image", + num_gpus=2, + # test ring attn + ulysses_degree=1, + ring_degree=2, + ), + T2I_sampling_params, + ), + DiffusionTestCase( + "zimage_image_t2i_2_gpus", + DiffusionServerArgs( + model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + modality="image", + num_gpus=2, + ulysses_degree=2, + ), + T2I_sampling_params, + ), + DiffusionTestCase( + "flux_image_t2i_2_gpus", + DiffusionServerArgs( + model_path=DEFAULT_FLUX_1_DEV_MODEL_NAME_FOR_TEST, + modality="image", + num_gpus=2, + ), + T2I_sampling_params, + ), + DiffusionTestCase( + "flux_2_image_t2i_2_gpus", + DiffusionServerArgs( + model_path=DEFAULT_FLUX_2_DEV_MODEL_NAME_FOR_TEST, + modality="image", + num_gpus=2, + tp_size=2, + ), + T2I_sampling_params, + ), + DiffusionTestCase( + "flux_2_klein_ti2i_2_gpus", + DiffusionServerArgs( + model_path="black-forest-labs/FLUX.2-klein-4B", + modality="image", + num_gpus=2, + ), + TI2I_sampling_params, + ), +] + +if not current_platform.is_hip(): + # Flux2 multi-image edit with cache-dit, regression test + ONE_GPU_CASES_B.append( + DiffusionTestCase( + 
"flux_2_ti2i_multi_image_cache_dit", + DiffusionServerArgs( + model_path="black-forest-labs/FLUX.2-dev", + modality="image", + enable_cache_dit=True, + ), + MULTI_IMAGE_TI2I_UPLOAD_sampling_params, + ) + ) + # Skip turbowan because Triton requires 81920 shared memory, but AMD only has 65536. + ONE_GPU_CASES_B.append( + DiffusionTestCase( + "turbo_wan2_1_t2v_1.3b", + DiffusionServerArgs( + model_path="IPostYellow/TurboWan2.1-T2V-1.3B-Diffusers", + modality="video", + custom_validator="video", + ), + DiffusionSamplingParams( + prompt=T2V_PROMPT, + ), + ) + ) + # Skip turbowan because Triton requires 81920 shared memory, but AMD only has 65536. + TWO_GPU_CASES_A.append( + DiffusionTestCase( + "turbo_wan2_2_i2v_a14b_2gpu", + DiffusionServerArgs( + model_path="IPostYellow/TurboWan2.2-I2V-A14B-Diffusers", + modality="video", + custom_validator="video", + num_gpus=2, + tp_size=2, + ), + TURBOWAN_I2V_sampling_params, + ) + ) + +# Load global configuration +BASELINE_CONFIG = BaselineConfig.load( + Path(__file__).with_name("perf_baselines.json") +).update(Path(__file__).parent / "ascend" / "perf_baselines_npu.json") diff --git a/sglang/python/sglang/multimodal_gen/test/slack_utils.py b/sglang/python/sglang/multimodal_gen/test/slack_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..34ba78371d75e3fdf3926df5faedd8a332168c99 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/test/slack_utils.py @@ -0,0 +1,216 @@ +""" +This file upload the media generated in diffusion-nightly-test to a slack channel of SGLang +""" + +import logging +import os +import tempfile +from datetime import datetime +from typing import List, Union +from urllib.parse import urlparse +from urllib.request import urlopen + +from sglang.multimodal_gen.runtime.utils.perf_logger import get_git_commit_hash + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +import inspect + +try: + import sglang.multimodal_gen.test.server.testcase_configs as configs 
+ from sglang.multimodal_gen.test.server.testcase_configs import DiffusionTestCase + + ALL_CASES = [] + for name, value in inspect.getmembers(configs): + if name.endswith("_CASES") or "_CASES_" in name: + if ( + isinstance(value, list) + and len(value) > 0 + and isinstance(value[0], DiffusionTestCase) + ): + ALL_CASES.extend(value) + elif isinstance(value, list) and len(value) == 0: + # Assume empty list with matching name is a valid case list container + pass + + # Deduplicate cases by ID + seen_ids = set() + unique_cases = [] + for c in ALL_CASES: + if c.id not in seen_ids: + seen_ids.add(c.id) + unique_cases.append(c) + ALL_CASES = unique_cases + +except Exception as e: + logger.warning(f"Failed to import test cases: {e}") + ALL_CASES = [] + + +def _get_status_message(run_id, current_case_id, thread_messages=None): + date_str = datetime.now().strftime("%d/%m") + base_header = f"""🧵 for nightly test of {date_str} +*Git Revision:* {get_git_commit_hash()} +*GitHub Run ID:* {run_id} +*Total Tasks:* {len(ALL_CASES)} +""" + + if not ALL_CASES: + return base_header + + default_emoji_for_case_in_progress = "⏳" + status_map = {c.id: default_emoji_for_case_in_progress for c in ALL_CASES} + + if thread_messages: + for msg in thread_messages: + text = msg.get("text", "") + # Look for case_id in the message (format: *Case ID:* `case_id`) + for c in ALL_CASES: + if f"*Case ID:* `{c.id}`" in text: + status_map[c.id] = "✅" + + if current_case_id: + status_map[current_case_id] = "✅" + + lines = [base_header, "", "*Tasks Status:*"] + + # Calculate padding + max_len = max(len(c.id) for c in ALL_CASES) if ALL_CASES else 10 + max_len = max(max_len, len("Case ID")) + + # Build markdown table inside a code block + table_lines = ["```"] + table_lines.append(f"| {'Case ID'.ljust(max_len)} | Status |") + table_lines.append(f"| {'-' * max_len} | :----: |") + + for c in ALL_CASES: + mark = status_map.get(c.id, default_emoji_for_case_in_progress) + table_lines.append(f"| 
{c.id.ljust(max_len)} | {mark} |") + + table_lines.append("```") + + lines.extend(table_lines) + + return "\n".join(lines) + + +def upload_file_to_slack( + case_id: str = None, + model: str = None, + prompt: str = None, + file_path: str = None, + origin_file_path: Union[str, List[str]] = None, +) -> bool: + temp_paths = [] + try: + from slack_sdk import WebClient + + run_id = os.getenv("GITHUB_RUN_ID", "local") + + token = os.environ.get("SGLANG_DIFFUSION_SLACK_TOKEN") + if not token: + logger.info(f"Slack upload failed: no token") + return False + + if not file_path or not os.path.exists(file_path): + logger.info(f"Slack upload failed: no file path") + return False + + origin_paths = [] + if isinstance(origin_file_path, str): + if origin_file_path: + origin_paths.append(origin_file_path) + elif isinstance(origin_file_path, list): + origin_paths = [p for p in origin_file_path if p] + + final_origin_paths = [] + for path in origin_paths: + if path.startswith(("http", "https")): + try: + suffix = os.path.splitext(urlparse(path).path)[1] or ".tmp" + with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf: + with urlopen(path) as response: + tf.write(response.read()) + temp_paths.append(tf.name) + final_origin_paths.append(tf.name) + except Exception as e: + logger.warning(f"Failed to download {path}: {e}") + else: + final_origin_paths.append(path) + + uploads = [] + for i, path in enumerate(final_origin_paths): + if os.path.exists(path): + title = ( + "Original Image" + if len(final_origin_paths) == 1 + else f"Original Image {i+1}" + ) + uploads.append({"file": path, "title": title}) + + uploads.append({"file": file_path, "title": "Generated Image"}) + + message = ( + f"*Case ID:* `{case_id}`\n" f"*Model:* `{model}`\n" f"*Prompt:* {prompt}" + ) + + client = WebClient(token=token) + channel_id = "C0A02NDF7UY" + thread_ts = None + + parent_msg_text = None + try: + history = client.conversations_history(channel=channel_id, limit=100) + for msg in 
history.get("messages", []): + if f"*GitHub Run ID:* {run_id}" in msg.get("text", ""): + # Use thread_ts if it exists (msg is a reply), otherwise use ts (msg is a parent) + thread_ts = msg.get("thread_ts") or msg.get("ts") + parent_msg_text = msg.get("text", "") + logger.info(f"Found thread_ts: {thread_ts}") + break + except Exception as e: + logger.warning(f"Failed to search slack history: {e}") + + if not thread_ts: + try: + text = _get_status_message(run_id, case_id) + response = client.chat_postMessage(channel=channel_id, text=text) + thread_ts = response["ts"] + except Exception as e: + logger.warning(f"Failed to create parent thread: {e}") + + # Upload first to ensure it's in history + client.files_upload_v2( + channel=channel_id, + file_uploads=uploads, + initial_comment=message, + thread_ts=thread_ts, + ) + + # Then update status based on thread replies + if thread_ts: + try: + replies = client.conversations_replies( + channel=channel_id, ts=thread_ts, limit=200 + ) + messages = replies.get("messages", []) + new_text = _get_status_message(run_id, case_id, messages) + + # Only update if changed significantly (ignoring timestamp diffs if any) + # But here we just check text content + if new_text != parent_msg_text: + client.chat_update(channel=channel_id, ts=thread_ts, text=new_text) + except Exception as e: + logger.warning(f"Failed to update parent message: {e}") + + logger.info(f"File uploaded successfully: {os.path.basename(file_path)}") + return True + + except Exception as e: + logger.info(f"Slack upload failed: {e}") + return False + finally: + for p in temp_paths: + if os.path.exists(p): + os.remove(p) diff --git a/sglang/python/sglang/multimodal_gen/test/test_files/launch_flux.json b/sglang/python/sglang/multimodal_gen/test/test_files/launch_flux.json new file mode 100644 index 0000000000000000000000000000000000000000..6a9d838209918a03c813151a7c69fa81cb05dcfc --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/test/test_files/launch_flux.json @@ 
-0,0 +1,11 @@ +{ + "model_path": "black-forest-labs/FLUX.1-dev", + "prompt": "A beautiful woman in a red dress walking down a street", + "text_encoder_cpu_offload": true, + "pin_cpu_memory": true, + "save_output": true, + "width": 720, + "height": 720, + "output_path": "outputs", + "output_file_name": "FLUX.1-dev, single gpu" +} diff --git a/sglang/python/sglang/multimodal_gen/test/test_files/launch_wan.json b/sglang/python/sglang/multimodal_gen/test/test_files/launch_wan.json new file mode 100644 index 0000000000000000000000000000000000000000..eeb9ddf9dd9ac863be409db7d354fbc0f4c470db --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/test/test_files/launch_wan.json @@ -0,0 +1,11 @@ +{ + "model_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", + "prompt": "A beautiful woman in a red dress walking down a street", + "text_encoder_cpu_offload": true, + "pin_cpu_memory": true, + "save_output": true, + "width": 720, + "height": 720, + "output_path": "outputs", + "output_file_name": "Wan2.1-T2V-1.3B-Diffusers, single gpu" +} diff --git a/sglang/python/sglang/multimodal_gen/test/test_utils.py b/sglang/python/sglang/multimodal_gen/test/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..62e738a81b24b42f2bd5b188133cff824ba7d23e --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/test/test_utils.py @@ -0,0 +1,642 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo +import base64 +import io +import json +import os +import socket +import subprocess +import tempfile +import time +from pathlib import Path +from urllib.parse import urljoin + +import cv2 +import httpx +import numpy as np +from PIL import Image + +from sglang.multimodal_gen.runtime.utils.common import get_bool_env_var +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.runtime.utils.perf_logger import ( + RequestPerfRecord, + get_diffusion_perf_log_dir, +) + +logger = init_logger(__name__) + +# 
--------------------------------------------------------------------------- +# Common model IDs for diffusion tests +# +# Centralised here so every test file references the same constants instead +# of scattering hard-coded strings. When adding a new model that will be +# reused across tests, define it here. +# --------------------------------------------------------------------------- + +DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "Tongyi-MAI/Z-Image-Turbo" + +# Qwen image generation models +DEFAULT_QWEN_IMAGE_MODEL_NAME_FOR_TEST = "Qwen/Qwen-Image" +DEFAULT_QWEN_IMAGE_2512_MODEL_NAME_FOR_TEST = "Qwen/Qwen-Image-2512" +DEFAULT_QWEN_IMAGE_EDIT_MODEL_NAME_FOR_TEST = "Qwen/Qwen-Image-Edit" +DEFAULT_QWEN_IMAGE_EDIT_2509_MODEL_NAME_FOR_TEST = "Qwen/Qwen-Image-Edit-2509" +DEFAULT_QWEN_IMAGE_EDIT_2511_MODEL_NAME_FOR_TEST = "Qwen/Qwen-Image-Edit-2511" +DEFAULT_QWEN_IMAGE_LAYERED_MODEL_NAME_FOR_TEST = "Qwen/Qwen-Image-Layered" + +# FLUX image generation models +DEFAULT_FLUX_1_DEV_MODEL_NAME_FOR_TEST = "black-forest-labs/FLUX.1-dev" +DEFAULT_FLUX_2_DEV_MODEL_NAME_FOR_TEST = "black-forest-labs/FLUX.2-dev" +DEFAULT_FLUX_2_KLEIN_4B_MODEL_NAME_FOR_TEST = "black-forest-labs/FLUX.2-klein-4B" +DEFAULT_FLUX_2_KLEIN_BASE_4B_MODEL_NAME_FOR_TEST = ( + "black-forest-labs/FLUX.2-klein-base-4B" +) + +# Wan video generation models +DEFAULT_WAN_2_1_T2V_1_3B_MODEL_NAME_FOR_TEST = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" +DEFAULT_WAN_2_1_T2V_14B_MODEL_NAME_FOR_TEST = "Wan-AI/Wan2.1-T2V-14B-Diffusers" +DEFAULT_WAN_2_1_I2V_14B_480P_MODEL_NAME_FOR_TEST = ( + "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers" +) +DEFAULT_WAN_2_1_I2V_14B_720P_MODEL_NAME_FOR_TEST = ( + "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers" +) +DEFAULT_WAN_2_2_TI2V_5B_MODEL_NAME_FOR_TEST = "Wan-AI/Wan2.2-TI2V-5B-Diffusers" +DEFAULT_WAN_2_2_T2V_A14B_MODEL_NAME_FOR_TEST = "Wan-AI/Wan2.2-T2V-A14B-Diffusers" +DEFAULT_WAN_2_2_I2V_A14B_MODEL_NAME_FOR_TEST = "Wan-AI/Wan2.2-I2V-A14B-Diffusers" + + +def print_value_formatted(description: str, value: int | float | 
str): + """Helper function to print a metric value formatted.""" + if isinstance(value, int): + if value >= 1e6: + value_str = f"{value / 1e6:<30.2f}M" + elif value >= 1e3: + value_str = f"{value / 1e3:<30.2f}K" + else: + value_str = f"{value:<30}" + elif isinstance(value, float): + value_str = f"{value:<30.2f}" + else: + value_str = f"{value:<30}" + + print(f"{description:<45} {value_str}") + + +def print_divider(length: int, char: str = "-"): + """Helper function to print a divider line.""" + print(char * length) + + +def is_image_url(image_path: str | Path | None) -> bool: + """Check if image_path is a URL.""" + if image_path is None: + return False + return isinstance(image_path, str) and ( + image_path.startswith("http://") or image_path.startswith("https://") + ) + + +def probe_port(host="127.0.0.1", port=30010, timeout=2.0) -> bool: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.settimeout(timeout) + try: + s.connect((host, port)) + return True + except OSError: + return False + + +def is_in_ci() -> bool: + return get_bool_env_var("SGLANG_IS_IN_CI") + + +def get_dynamic_server_port() -> int: + cuda_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0") + if not cuda_devices: + cuda_devices = "0" + try: + first_device_id = int(cuda_devices.split(",")[0].strip()[0]) + except (ValueError, IndexError): + first_device_id = 0 + + if is_in_ci(): + base_port = 10000 + first_device_id * 2000 + else: + base_port = 20000 + first_device_id * 1000 + + return base_port + 1000 + + +def find_free_port(host: str = "127.0.0.1") -> int: + """Bind to port 0 and let the OS assign an available port.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind((host, 0)) + return s.getsockname()[1] + + +def wait_for_server_health( + base_url: str, + path: str = "/health", + timeout: float = 180.0, + interval: float = 1.0, +) -> None: + """Poll ``GET `` until it returns HTTP 200.""" + deadline = time.time() + timeout + last_err: httpx.RequestError | 
None = None + last_status: int | None = None + while time.time() < deadline: + try: + r = httpx.get(urljoin(base_url, path), timeout=5.0) + last_status = r.status_code + if r.status_code == 200: + return + except httpx.RequestError as e: + last_err = e + time.sleep(interval) + raise TimeoutError( + f"Server at {urljoin(base_url, path)} not healthy after {timeout}s. " + f"{last_status=} {last_err=}" + ) + + +def post_json( + base_url: str, + path: str, + payload: dict, + timeout: float = 300.0, +) -> httpx.Response: + """POST JSON to ```` and return the response.""" + return httpx.post(urljoin(base_url, path), json=payload, timeout=timeout) + + +# --------------------------------------------------------------------------- +# GPU memory helpers (nvidia-smi) +# --------------------------------------------------------------------------- + + +def query_gpu_mem_used_mib(gpu_index: int = 0, required: bool = False) -> int | None: + """Return GPU memory usage in MiB via ``nvidia-smi``, or *None* on failure. + + When *required* is ``True`` the function raises instead of returning ``None``. + """ + try: + out = subprocess.check_output( + [ + "nvidia-smi", + f"--id={gpu_index}", + "--query-gpu=memory.used", + "--format=csv,noheader,nounits", + ], + text=True, + ).strip() + return int(out.splitlines()[0].strip()) + except Exception as e: + logger.warning(f"nvidia-smi memory query failed: {type(e).__name__}: {e}") + assert not required, ( + "nvidia-smi memory query is unavailable; " + "cannot enforce GPU memory assertions." + ) + return None + + +def require_gpu_mem_query(gpu_index: int = 0) -> int: + """Same as :func:`query_gpu_mem_used_mib` but asserts availability. + + Raises ``AssertionError`` when ``nvidia-smi`` is unavailable instead of + returning ``None``, so callers can rely on a valid ``int`` result. 
+ """ + mem = query_gpu_mem_used_mib(gpu_index, required=True) + assert mem is not None + return mem + + +def assert_gpu_mem_changed( + label: str, + before_mib: int, + after_mib: int, + min_delta_mib: int, +) -> None: + """Assert that GPU memory changed by at least *min_delta_mib* MiB.""" + delta = abs(after_mib - before_mib) + logger.debug( + f"[MEM] {label}: before={before_mib} MiB after={after_mib} MiB |delta|={delta} MiB" + ) + assert delta >= min_delta_mib, ( + f"GPU memory change too small for '{label}': " + f"|after-before|={delta} MiB < {min_delta_mib} MiB " + f"(before={before_mib} MiB, after={after_mib} MiB)" + ) + + +def is_mp4(data: bytes) -> bool: + """Check if data represents a valid MP4 file by magic bytes.""" + if len(data) < 8: + return False + return data[4:8] == b"ftyp" + + +def is_jpeg(data: bytes) -> bool: + # JPEG files start with: FF D8 FF + return data.startswith(b"\xff\xd8\xff") + + +def is_png(data): + # PNG files start with: 89 50 4E 47 0D 0A 1A 0A + return data.startswith(b"\x89PNG\r\n\x1a\n") + + +def is_webp(data: bytes) -> bool: + # WebP files start with: RIFF....WEBP + return data[:4] == b"RIFF" and data[8:12] == b"WEBP" + + +def detect_image_format(data: bytes) -> str: + """Detect image format from bytes (magic). Returns 'png'|'jpeg'|'webp'; default 'png'.""" + if len(data) < 12: + return "png" + if is_png(data): + return "png" + if is_jpeg(data): + return "jpeg" + if is_webp(data): + return "webp" + return "png" + + +def get_expected_image_format( + output_format: str | None = None, + background: str | None = None, +) -> str: + """Infer expected image format based on request parameters. 
+ Args: + output_format: The output_format parameter from the request (png/jpeg/webp/jpg) + background: The background parameter from the request (transparent/opaque/auto) + Returns: + Expected file extension: "jpg", "png", or "webp" + """ + fmt = (output_format or "").lower() + if fmt in {"png", "webp", "jpeg", "jpg"}: + return "jpg" if fmt == "jpeg" else fmt + if (background or "auto").lower() == "transparent": + return "png" + return "jpg" # Default + + +def wait_for_port(host="127.0.0.1", port=30010, deadline=300.0, interval=0.5): + end = time.time() + deadline + last_err = None + while time.time() < end: + if probe_port(host, port, timeout=interval): + return True + time.sleep(interval) + raise TimeoutError(f"Port {host}:{port} not ready. Last error: {last_err}") + + +def check_image_size(ut, image, width, height): + # check image size + ut.assertEqual(image.size, (width, height)) + + +def get_perf_log_dir() -> Path: + """Gets the performance log directory from the centralized sglang utility.""" + log_dir_str = get_diffusion_perf_log_dir() + if not log_dir_str: + raise RuntimeError( + "Performance logging is disabled (SGLANG_PERF_LOG_DIR is empty), " + "but a test tried to access the log directory." 
+ ) + return Path(log_dir_str) + + +def _ensure_log_path(log_dir: Path) -> Path: + log_dir.mkdir(parents=True, exist_ok=True) + return log_dir / "performance.log" + + +def clear_perf_log(log_dir: Path) -> Path: + """Delete the perf log file so tests can watch for fresh entries.""" + log_path = _ensure_log_path(log_dir) + if log_path.exists(): + log_path.unlink() + logger.info("[server-test] Monitoring perf log at %s", log_path.as_posix()) + return log_path + + +def prepare_perf_log() -> tuple[Path, Path]: + """Convenience helper to resolve and clear the perf log in one call.""" + log_dir = get_perf_log_dir() + log_path = clear_perf_log(log_dir) + return log_dir, log_path + + +def read_perf_logs(log_path: Path) -> list[RequestPerfRecord]: + if not log_path.exists(): + return [] + records: list[RequestPerfRecord] = [] + with log_path.open("r", encoding="utf-8") as fh: + for line in fh: + line = line.strip() + if not line: + continue + try: + record_dict = json.loads(line) + records.append(RequestPerfRecord(**record_dict)) + except json.JSONDecodeError: + continue + return records + + +def wait_for_req_perf_record( + request_id: str, + log_path: Path, + timeout: float = 30.0, +) -> RequestPerfRecord | None: + """ + the stage metrics of this request should be in the performance_log file with {request-id} + """ + logger.info(f"Waiting for req perf record with request id: {request_id}") + deadline = time.time() + timeout + while time.time() < deadline: + records = read_perf_logs(log_path) + for record in records: + if record.request_id == request_id: + return record + + time.sleep(0.5) + + if os.environ.get("SGLANG_GEN_BASELINE", "0") == "1": + return None + + logger.error(f"record: {records}") + raise AssertionError(f"Timeout waiting for stage metrics for request {request_id} ") + + +def validate_image(b64_json: str) -> None: + """Decode and validate that image is PNG or JPEG.""" + image_bytes = base64.b64decode(b64_json) + assert is_png(image_bytes) or 
is_jpeg(image_bytes), "Image must be PNG or JPEG" + + +def validate_video(b64_json: str) -> None: + """Decode and validate that video is a valid format.""" + video_bytes = base64.b64decode(b64_json) + is_webm = video_bytes[:4] == b"\x1a\x45\xdf\xa3" + assert is_mp4(video_bytes) or is_webm, "Video must be MP4 or WebM" + + +def validate_openai_video(video_bytes: bytes) -> None: + """Validate that video is MP4 or WebM by magic bytes.""" + is_webm = video_bytes.startswith(b"\x1a\x45\xdf\xa3") + assert is_mp4(video_bytes) or is_webm, "Video must be MP4 or WebM" + + +def validate_image_file( + file_path: str, + expected_filename: str, + expected_width: int | None = None, + expected_height: int | None = None, + output_format: str | None = None, + background: str | None = None, +) -> None: + """Validate image output file: existence, extension, size, filename, format, dimensions.""" + # Infer expected format from request parameters + expected_ext = get_expected_image_format(output_format, background) + + # 1. File existence + assert os.path.exists(file_path), f"Image file does not exist: {file_path}" + + # 2. Extension check + assert file_path.endswith( + f".{expected_ext}" + ), f"Expected .{expected_ext} extension, got: {file_path}" + + # 3. File size > 0 + file_size = os.path.getsize(file_path) + assert file_size > 0, f"Image file is empty: {file_path}" + + # 4. Filename validation + actual_filename = os.path.basename(file_path) + assert ( + actual_filename == expected_filename + ), f"Filename mismatch: expected '{expected_filename}', got '{actual_filename}'" + + # 5. 
Image format validation (magic bytes check based on expected format) + with open(file_path, "rb") as f: + header = f.read(12) # Read enough bytes for webp detection + if expected_ext == "png": + assert is_png(header), f"File is not a valid PNG: {file_path}" + elif expected_ext == "jpg": + assert is_jpeg(header), f"File is not a valid JPEG: {file_path}" + elif expected_ext == "webp": + assert is_webp(header), f"File is not a valid WebP: {file_path}" + + # 6. Image dimension validation (reuse PIL) + if expected_width is not None and expected_height is not None: + with Image.open(file_path) as img: + width, height = img.size + assert ( + width == expected_width + ), f"Width mismatch: expected {expected_width}, got {width}" + assert ( + height == expected_height + ), f"Height mismatch: expected {expected_height}, got {height}" + + +def _get_video_dimensions_from_metadata( + cap: cv2.VideoCapture, +) -> tuple[int, int] | None: + """Get video dimensions from metadata properties. + + Args: + cap: OpenCV VideoCapture object + + Returns: + Tuple of (width, height) if successful, None if metadata is invalid + """ + width = cap.get(cv2.CAP_PROP_FRAME_WIDTH) + height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT) + + if width == 0 or height == 0: + return None + + return int(width), int(height) + + +def _get_video_dimensions_from_frame(cap: cv2.VideoCapture) -> tuple[int, int]: + """Get video dimensions by reading the first frame. + + Args: + cap: OpenCV VideoCapture object + + Returns: + Tuple of (width, height) + + """ + ret, frame = cap.read() + if not ret or frame is None: + raise ValueError("Unable to read video frame to get dimensions") + + # frame.shape is (height, width, channels) + height, width = frame.shape[:2] + return int(width), int(height) + + +def get_video_dimensions(file_path: str) -> tuple[int, int]: + """Get video dimensions (width, height) from a video file. + + Tries to get dimensions from metadata first, falls back to reading first frame. 
+ + Returns: + Tuple of (width, height) + + """ + cap = cv2.VideoCapture(file_path) + try: + # Try to get dimensions from metadata first + dimensions = _get_video_dimensions_from_metadata(cap) + if dimensions is not None: + return dimensions + + # Fall back to reading first frame + return _get_video_dimensions_from_frame(cap) + finally: + cap.release() + + +def get_video_frame_count(file_path: str) -> int: + """Return the number of frames in a video file using OpenCV.""" + cap = cv2.VideoCapture(file_path) + try: + count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + if count > 0: + return count + # Fallback: count frames manually + n = 0 + while cap.read()[0]: + n += 1 + return n + finally: + cap.release() + + +def validate_video_file( + file_path: str, + expected_filename: str, + expected_width: int | None = None, + expected_height: int | None = None, +) -> None: + """Validate video output file: existence, extension, size, filename, format, dimensions.""" + # 1. File existence + assert os.path.exists(file_path), f"Video file does not exist: {file_path}" + + # 2. Extension check + assert file_path.endswith(".mp4"), f"Expected .mp4 extension, got: {file_path}" + + # 3. File size > 0 + file_size = os.path.getsize(file_path) + assert file_size > 0, f"Video file is empty: {file_path}" + + # 4. Filename validation + actual_filename = os.path.basename(file_path) + assert ( + actual_filename == expected_filename + ), f"Filename mismatch: expected '{expected_filename}', got '{actual_filename}'" + + # 5. Video format validation (reuse is_mp4) + with open(file_path, "rb") as f: + header = f.read(32) + assert is_mp4(header), f"File is not a valid MP4: {file_path}" + + # 6. 
Video dimension validation (using OpenCV) + if expected_width is not None and expected_height is not None: + actual_width, actual_height = get_video_dimensions(file_path) + assert ( + actual_width == expected_width + ), f"Video width mismatch: expected {expected_width}, got {actual_width}" + assert ( + actual_height == expected_height + ), f"Video height mismatch: expected {expected_height}, got {actual_height}" + + +def output_format_to_ext(output_format: str | None) -> str: + """Map output_format to file extension. Used by GT naming and consistency check.""" + if not output_format: + return "png" + of = output_format.lower() + if of == "jpeg": + return "jpg" + if of in ("png", "webp", "jpg"): + return of + return "png" + + +def _consistency_gt_filenames( + case_id: str, num_gpus: int, is_video: bool, output_format: str | None = None +) -> list[str]: + """Return the list of GT image filenames for a case. Reused by GT generation and consistency check.""" + n = num_gpus + if is_video: + return [ + f"{case_id}_{n}gpu_frame_0.png", + f"{case_id}_{n}gpu_frame_mid.png", + f"{case_id}_{n}gpu_frame_last.png", + ] + ext = output_format_to_ext(output_format) + return [f"{case_id}_{n}gpu.{ext}"] + + +def extract_key_frames_from_video( + video_bytes: bytes, + num_frames: int | None = None, +) -> list[np.ndarray]: + """ + Extract key frames (first, middle, last) from video bytes. + + Args: + video_bytes: Raw video bytes (MP4 format) + num_frames: Total number of frames (if known), used for validation + + Returns: + List of numpy arrays [first_frame, middle_frame, last_frame]. 
+ """ + with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp: + tmp.write(video_bytes) + tmp_path = tmp.name + + try: + cap = cv2.VideoCapture(tmp_path) + if not cap.isOpened(): + raise ValueError("Failed to open video file") + + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + if total_frames < 1: + raise ValueError("Video has no frames") + + first_idx = 0 + mid_idx = total_frames // 2 + last_idx = total_frames - 1 + key_indices = [first_idx, mid_idx, last_idx] + + frames = [] + for idx in key_indices: + cap.set(cv2.CAP_PROP_POS_FRAMES, idx) + ret, frame = cap.read() + if not ret: + raise ValueError(f"Failed to read frame at index {idx}") + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frames.append(frame_rgb) + + cap.release() + logger.info( + f"Extracted {len(frames)} key frames from video " + f"(total: {total_frames}, indices: {key_indices})" + ) + return frames + + finally: + os.unlink(tmp_path) + + +def image_bytes_to_numpy(image_bytes: bytes) -> np.ndarray: + """Convert image bytes to numpy array.""" + img = Image.open(io.BytesIO(image_bytes)).convert("RGB") + return np.array(img) diff --git a/sglang/python/sglang/multimodal_gen/test/unit/test_lora_format_adapter.py b/sglang/python/sglang/multimodal_gen/test/unit/test_lora_format_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..48bf29af3b6f4f1d26f3fb59e2c665c5f888a033 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/test/unit/test_lora_format_adapter.py @@ -0,0 +1,324 @@ +""" +test_lora_format_adapter.py + +Small regression test for the LoRA format adapter. + +It downloads several public LoRA checkpoints from Hugging Face, runs +format detection and normalization, and prints a compact summary table. 
+""" + +import logging +import os +import tempfile +from typing import Dict, List + +import torch +from huggingface_hub import hf_hub_download +from safetensors.torch import load_file + +from sglang.multimodal_gen.runtime.pipelines_core.lora_format_adapter import ( + LoRAFormat, + detect_lora_format_from_state_dict, + normalize_lora_state_dict, +) + +logging.basicConfig(level=logging.INFO, force=True) +logger = logging.getLogger("lora_test") + +ROOT_DIR = os.path.join(tempfile.gettempdir(), "sglang_lora_tests") +os.makedirs(ROOT_DIR, exist_ok=True) + + +def download_lora( + repo_id: str, + filename: str, + local_name: str, +) -> str: + """ + Download a LoRA safetensors file into ROOT_DIR and return its local path. + """ + print(f"=== Downloading LoRA from {repo_id} ({filename}) ===") + path = hf_hub_download( + repo_id=repo_id, + filename=filename, + local_dir=ROOT_DIR, + local_dir_use_symlinks=False, + ) + dst = os.path.join(ROOT_DIR, local_name) + if os.path.abspath(path) != os.path.abspath(dst): + try: + import shutil + + shutil.copy2(path, dst) + except Exception: + dst = path + print(f"Saved to: {dst}") + return dst + + +def is_diffusers_style_keys( + sd: Dict[str, torch.Tensor], + debug_name: str = "", +) -> bool: + """ + Relaxed structural check that a state_dict looks like diffusers-style LoRA. + + The check verifies: + 1) No known non-diffusers prefixes. + 2) No non-diffusers suffixes such as alpha / dora_scale / magnitude vectors. + 3) Most top-level roots match common diffusers module namespaces. 
+ """ + if not sd: + print(f"[{debug_name}] diffusers-style check: EMPTY state_dict") + return False + + keys: List[str] = list(sd.keys()) + total = len(keys) + + banned_prefixes = ( + "lora_unet_", + "lora_te_", + "lora_te1_", + "lora_te2_", + "lora_unet_double_blocks_", + "lora_unet_single_blocks_", + ) + bad_prefix_keys = [k for k in keys if k.startswith(banned_prefixes)] + cond1 = len(bad_prefix_keys) == 0 + + banned_suffixes = ( + ".alpha", + ".dora_scale", + ".lora_magnitude_vector", + ) + bad_suffix_keys = [k for k in keys if k.endswith(banned_suffixes)] + cond2 = len(bad_suffix_keys) == 0 + + allowed_roots = { + "unet", + "text_encoder", + "text_encoder_2", + "transformer", + "prior", + "image_encoder", + "vae", + "diffusion_model", + } + root_names = [k.split(".", 1)[0] for k in keys] + root_ok_count = sum(r in allowed_roots for r in root_names) + cond3 = root_ok_count >= 0.6 * total + + ok = cond1 and cond2 and cond3 + + if not ok: + print(f"[{debug_name}] diffusers-style check FAILED (relaxed):") + print(f" total keys = {total}") + print( + f" cond1(no banned prefixes) = {cond1}, bad_prefix_keys={len(bad_prefix_keys)}" + ) + if not cond1 and bad_prefix_keys: + print(" example bad prefix key:", bad_prefix_keys[0]) + print( + f" cond2(no banned suffixes) = {cond2}, bad_suffix_keys={len(bad_suffix_keys)}" + ) + if not cond2 and bad_suffix_keys: + print(" example bad suffix key:", bad_suffix_keys[0]) + print(f" cond3(allowed roots>=60%) = {cond3}, root_ok_count={root_ok_count}") + return ok + + +def run_single_test( + name: str, + repo_id: str, + filename: str, + local_name: str, + expected_before: LoRAFormat, + expected_after: LoRAFormat = LoRAFormat.STANDARD, +): + """ + Run a single end-to-end test for one LoRA checkpoint. + + Steps: + 1) Download. + 2) Detect format on raw keys. + 3) Normalize via lora_format_adapter. + 4) Detect again on the normalized dict. + 5) Optionally check for diffusers-style key structure. 
+ """ + logger.info(f"=== Running test: {name} ===") + local_path = download_lora(repo_id, filename, local_name) + raw_state = load_file(local_path) + + detected_before = detect_lora_format_from_state_dict(raw_state) + norm_state = normalize_lora_state_dict(raw_state, logger=logger) + detected_after = detect_lora_format_from_state_dict(norm_state) + standard_like = is_diffusers_style_keys(norm_state, debug_name=name) + + passed = detected_before == expected_before and detected_after == expected_after + + return { + "name": name, + "expected_before": expected_before.value, + "detected_before": detected_before.value, + "expected_after": expected_after.value, + "detected_after": detected_after.value, + "standard_like_keys": standard_like, + "pass": passed, + "num_keys_raw": len(raw_state), + "num_keys_norm": len(norm_state), + } + + +def _run_all_tests() -> List[Dict]: + results: List[Dict] = [] + + # SDXL LoRA that is already in diffusers/PEFT format. + results.append( + run_single_test( + name="HF standard SDXL LoRA", + repo_id="jbilcke-hf/sdxl-cinematic-1", + filename="pytorch_lora_weights.safetensors", + local_name="sdxl_cinematic1_pytorch_lora_weights.safetensors", + expected_before=LoRAFormat.STANDARD, + expected_after=LoRAFormat.STANDARD, + ) + ) + + # XLabs FLUX LoRA (non-diffusers → diffusers). + results.append( + run_single_test( + name="XLabs FLUX Realism LoRA", + repo_id="XLabs-AI/flux-RealismLora", + filename="lora.safetensors", + local_name="flux_realism_lora.safetensors", + expected_before=LoRAFormat.XLABS_FLUX, + expected_after=LoRAFormat.STANDARD, + ) + ) + + # Kohya-style FLUX LoRA (sd-scripts flux_lora.py → diffusers). 
+ results.append( + run_single_test( + name="Kohya-style Flux LoRA", + repo_id="kohya-ss/misc-models", + filename="flux-hasui-lora-d4-sigmoid-raw-gs1.0.safetensors", + local_name="flux_hasui_lora_d4_sigmoid_raw_gs1_0.safetensors", + expected_before=LoRAFormat.KOHYA_FLUX, + expected_after=LoRAFormat.STANDARD, + ) + ) + + # Classic Kohya/A1111 SD LoRA (non-diffusers SD → diffusers). + results.append( + run_single_test( + name="Kohya-style SD LoRA", + repo_id="kohya-ss/misc-models", + filename="fp-1f-chibi-1024.safetensors", + local_name="fp_1f_chibi_1024.safetensors", + expected_before=LoRAFormat.NON_DIFFUSERS_SD, + expected_after=LoRAFormat.STANDARD, + ) + ) + + # Wan2.1 Fun Reward LoRA (ComfyUI format → diffusers). + results.append( + run_single_test( + name="Wan2.1 Fun Reward LoRA (Comfy)", + repo_id="alibaba-pai/Wan2.1-Fun-Reward-LoRAs", + filename="Wan2.1-Fun-1.3B-InP-MPS.safetensors", + local_name="wan21_fun_1_3b_inp_mps.safetensors", + expected_before=LoRAFormat.NON_DIFFUSERS_SD, + expected_after=LoRAFormat.STANDARD, + ) + ) + + # Qwen-Image EVA LoRA (already diffusers/PEFT-style). + results.append( + run_single_test( + name="Qwen-Image EVA LoRA", + repo_id="starsfriday/Qwen-Image-EVA-LoRA", + filename="qwen_image_eva.safetensors", + local_name="qwen_image_eva.safetensors", + expected_before=LoRAFormat.STANDARD, + expected_after=LoRAFormat.STANDARD, + ) + ) + + # Qwen-Image Lightning LoRA (non-diffusers Qwen → diffusers). + results.append( + run_single_test( + name="Qwen-Image Lightning LoRA", + repo_id="lightx2v/Qwen-Image-Lightning", + filename="Qwen-Image-Lightning-4steps-V1.0-bf16.safetensors", + local_name="qwen_image_lightning_4steps_v1_bf16.safetensors", + expected_before=LoRAFormat.NON_DIFFUSERS_SD, + expected_after=LoRAFormat.STANDARD, + ) + ) + + # Classic Painting Z-Image Turbo LoRA (Z-Image family). 
+ results.append( + run_single_test( + name="Classic Painting Z-Image LoRA", + repo_id="renderartist/Classic-Painting-Z-Image-Turbo-LoRA", + filename="Classic_Painting_Z_Image_Turbo_v1_renderartist_1750.safetensors", + local_name="classic_painting_z_image_turbo_v1_renderartist_1750.safetensors", + expected_before=LoRAFormat.STANDARD, + expected_after=LoRAFormat.STANDARD, + ) + ) + + return results + + +def _print_summary(results: List[Dict]) -> None: + print("\n================ LoRA format adapter test ================") + + header = ( + f"{'Test Name':30} " + f"{'Exp(b)':12} " + f"{'Act(b)':12} " + f"{'Exp(a)':12} " + f"{'Act(a)':12} " + f"{'StdLike':8} " + f"{'#Raw':7} " + f"{'#Norm':7} " + f"{'PASS':5}" + ) + print(header) + print("-" * len(header)) + + for r in results: + print( + f"{r['name'][:30]:30} " + f"{r['expected_before'][:12]:12} " + f"{r['detected_before'][:12]:12} " + f"{r['expected_after'][:12]:12} " + f"{r['detected_after'][:12]:12} " + f"{str(r['standard_like_keys']):8} " + f"{r['num_keys_raw']:7d} " + f"{r['num_keys_norm']:7d} " + f"{str(r['pass']):5}" + ) + + print("=========================================================\n") + + +def main() -> None: + results = _run_all_tests() + _print_summary(results) + + if not all(r["pass"] for r in results): + raise SystemExit(1) + + +class TestLoRAFormatAdapter: + def test_lora_format_adapter_all_formats(self): + results = _run_all_tests() + assert all( + r["pass"] for r in results + ), "At least one LoRA format adapter case failed" + + +if __name__ == "__main__": + main() diff --git a/sglang/python/sglang/multimodal_gen/test/unit/test_sampling_params_validate.py b/sglang/python/sglang/multimodal_gen/test/unit/test_sampling_params_validate.py new file mode 100644 index 0000000000000000000000000000000000000000..0373d1ccc6b08d6c9eb9a167184f937d1c1efdb5 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/test/unit/test_sampling_params_validate.py @@ -0,0 +1,49 @@ +import math +import unittest + +from 
sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams + + +class TestSamplingParamsValidate(unittest.TestCase): + def test_prompt_path_suffix(self): + with self.assertRaisesRegex(ValueError, r"prompt_path"): + SamplingParams(prompt_path="bad.png") + + def test_num_outputs_per_prompt_must_be_positive(self): + with self.assertRaisesRegex(ValueError, r"num_outputs_per_prompt"): + SamplingParams(num_outputs_per_prompt=0) + + def test_fps_must_be_positive_int(self): + with self.assertRaisesRegex(ValueError, r"\bfps\b"): + SamplingParams(fps=0) + with self.assertRaisesRegex(ValueError, r"\bfps\b"): + SamplingParams(fps=None) # type: ignore[arg-type] + + def test_num_inference_steps_optional_but_if_set_must_be_positive(self): + SamplingParams(num_inference_steps=None) + with self.assertRaisesRegex(ValueError, r"num_inference_steps"): + SamplingParams(num_inference_steps=-1) + + def test_guidance_scale_must_be_finite_non_negative_if_set(self): + SamplingParams(guidance_scale=None) + with self.assertRaisesRegex(ValueError, r"guidance_scale"): + SamplingParams(guidance_scale=math.nan) + with self.assertRaisesRegex(ValueError, r"guidance_scale"): + SamplingParams(guidance_scale=-0.1) + + def test_guidance_rescale_must_be_finite_non_negative(self): + with self.assertRaisesRegex(ValueError, r"guidance_rescale"): + SamplingParams(guidance_rescale=-1.0) + with self.assertRaisesRegex(ValueError, r"guidance_rescale"): + SamplingParams(guidance_rescale=math.inf) + + def test_boundary_ratio_range(self): + SamplingParams(boundary_ratio=None) + with self.assertRaisesRegex(ValueError, r"boundary_ratio"): + SamplingParams(boundary_ratio=1.5) + with self.assertRaisesRegex(ValueError, r"boundary_ratio"): + SamplingParams(boundary_ratio=math.nan) + + +if __name__ == "__main__": + unittest.main() diff --git a/sglang/python/sglang/multimodal_gen/test/unit/test_server_args_unit.py b/sglang/python/sglang/multimodal_gen/test/unit/test_server_args_unit.py new file mode 100644 
index 0000000000000000000000000000000000000000..59d7584d6f902bd7bf5c7638574c297cad2850f5 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/test/unit/test_server_args_unit.py @@ -0,0 +1,50 @@ +import os +import unittest + +from sglang.multimodal_gen.registry import _get_config_info +from sglang.multimodal_gen.runtime.server_args import ServerArgs + + +class TestServerArgsPathExpansion(unittest.TestCase): + def test_tilde_model_path_is_expanded(self): + args = ServerArgs.from_dict({"model_path": "~/fake/local/model"}) + expected = os.path.expanduser("~/fake/local/model") + self.assertEqual(args.model_path, expected) + self.assertFalse(args.model_path.startswith("~")) + + def test_absolute_path_is_unchanged(self): + args = ServerArgs.from_dict({"model_path": "/data/my-model"}) + self.assertEqual(args.model_path, "/data/my-model") + + +class TestModelIdResolution(unittest.TestCase): + def setUp(self): + _get_config_info.cache_clear() + + def test_model_id_overrides_arbitrary_local_path(self): + # a local path whose directory name does not match any HF repo name; + # --model-id tells the engine which config to use + info = _get_config_info("/data/my-custom-qwen", model_id="Qwen-Image") + self.assertIsNotNone(info) + from sglang.multimodal_gen.configs.pipeline_configs.qwen_image import ( + QwenImagePipelineConfig, + ) + + self.assertIs(info.pipeline_config_cls, QwenImagePipelineConfig) + + def test_model_id_works_after_tilde_expansion(self): + # simulate the full flow: user passes ~/..., engine expands and resolves + expanded = os.path.expanduser("~/.cache/huggingface/hub/bbb/snapshots/ccc") + _get_config_info.cache_clear() + info = _get_config_info(expanded, model_id="Qwen-Image") + self.assertIsNotNone(info) + + def test_model_id_unknown_falls_back_without_crash(self): + # unrecognized model_id: should warn and fall back to path-based detection + # with an unresolvable path, expect RuntimeError from the detector step + with self.assertRaises((RuntimeError, 
Exception)): + _get_config_info("/data/no-such-model", model_id="NonExistentModelXYZ") + + +if __name__ == "__main__": + unittest.main() diff --git a/sglang/python/sglang/multimodal_gen/test/unit/test_storage.py b/sglang/python/sglang/multimodal_gen/test/unit/test_storage.py new file mode 100644 index 0000000000000000000000000000000000000000..18f0c9ef1de6ac5e07c051a8f712b772451ede59 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/test/unit/test_storage.py @@ -0,0 +1,232 @@ +""" +Test suite for S3 CloudStorage integration. + +Tests verify file upload, cleanup, URL generation, and error handling. +""" + +import asyncio +import importlib +import os +from types import SimpleNamespace + +import pytest + +import sglang.multimodal_gen.runtime.entrypoints.openai.storage as storage_mod +from sglang.multimodal_gen.runtime.entrypoints.openai.storage import CloudStorage + + +def _create_temp_file(tmp_path, name="test.png", content=b"\x89PNG\r\n\x1a\nfake"): + """Create a temporary test file.""" + p = tmp_path / name + p.write_bytes(content) + return str(p) + + +# UNIT TESTS + + +def test_upload_file_success(tmp_path): + """Test successful upload with correct URL generation.""" + file_path = _create_temp_file(tmp_path, "image.png") + + storage_mod.cloud_storage.enabled = True + storage_mod.cloud_storage.bucket_name = "my-bucket" + storage_mod.cloud_storage.endpoint_url = "https://s3.example.com" + storage_mod.cloud_storage.region_name = None + + called = {} + + def fake_upload(local_path, bucket, key, ExtraArgs=None): + called["local_path"] = local_path + called["bucket"] = bucket + called["key"] = key + called["extra"] = ExtraArgs + + storage_mod.cloud_storage.client = SimpleNamespace(upload_file=fake_upload) + + url = asyncio.run(storage_mod.cloud_storage.upload_file(file_path, "image.png")) + + assert url == "https://s3.example.com/my-bucket/image.png" + assert called["local_path"] == file_path + assert called["bucket"] == "my-bucket" + assert called["key"] == 
"image.png" + assert called["extra"]["ContentType"] == "image/png" + + +def test_upload_and_cleanup(tmp_path): + """Test that local file is deleted after successful upload.""" + file_path = _create_temp_file(tmp_path, "cleanup.png") + + storage_mod.cloud_storage.enabled = True + storage_mod.cloud_storage.bucket_name = "my-bucket" + storage_mod.cloud_storage.endpoint_url = "https://s3.example.com" + storage_mod.cloud_storage.client = SimpleNamespace( + upload_file=lambda *args, **kwargs: None + ) + + assert os.path.exists(file_path) + + url = asyncio.run(storage_mod.cloud_storage.upload_and_cleanup(file_path)) + + assert url == "https://s3.example.com/my-bucket/cleanup.png" + assert not os.path.exists(file_path) + + +def test_upload_failure_preserves_file(tmp_path): + """Test that file is preserved when upload fails.""" + file_path = _create_temp_file(tmp_path, "preserve.png") + + storage_mod.cloud_storage.enabled = True + storage_mod.cloud_storage.bucket_name = "my-bucket" + storage_mod.cloud_storage.endpoint_url = "https://s3.example.com" + + def fake_upload_raises(*args, **kwargs): + raise RuntimeError("simulated failure") + + storage_mod.cloud_storage.client = SimpleNamespace(upload_file=fake_upload_raises) + + result = asyncio.run(storage_mod.cloud_storage.upload_and_cleanup(file_path)) + + assert result is None + assert os.path.exists(file_path) + + +def test_disabled_storage_returns_none(tmp_path): + """Test that disabled storage returns None.""" + file_path = _create_temp_file(tmp_path, "test.png") + + prev_enabled = storage_mod.cloud_storage.enabled + storage_mod.cloud_storage.enabled = False + + try: + result = asyncio.run( + storage_mod.cloud_storage.upload_file(file_path, "test.png") + ) + assert result is None + finally: + storage_mod.cloud_storage.enabled = prev_enabled + + +def test_aws_url_with_region(tmp_path): + """Test AWS S3 URL generation with specific region.""" + file_path = _create_temp_file(tmp_path, "aws.png") + + 
storage_mod.cloud_storage.enabled = True + storage_mod.cloud_storage.bucket_name = "aws-bucket" + storage_mod.cloud_storage.endpoint_url = None + storage_mod.cloud_storage.region_name = "us-west-2" + storage_mod.cloud_storage.client = SimpleNamespace( + upload_file=lambda *args, **kwargs: None + ) + + url = asyncio.run(storage_mod.cloud_storage.upload_file(file_path, "aws.png")) + + assert url == "https://aws-bucket.s3.us-west-2.amazonaws.com/aws.png" + + +def test_aws_url_default_region(tmp_path): + """Test AWS S3 URL defaults to us-east-1 when region not specified.""" + file_path = _create_temp_file(tmp_path, "default.png") + + storage_mod.cloud_storage.enabled = True + storage_mod.cloud_storage.bucket_name = "default-bucket" + storage_mod.cloud_storage.endpoint_url = None + storage_mod.cloud_storage.region_name = None + storage_mod.cloud_storage.client = SimpleNamespace( + upload_file=lambda *args, **kwargs: None + ) + + url = asyncio.run(storage_mod.cloud_storage.upload_file(file_path, "default.png")) + + assert url == "https://default-bucket.s3.us-east-1.amazonaws.com/default.png" + + +def test_custom_endpoint_url(tmp_path): + """Test URL generation with custom endpoint (MinIO/OSS/COS).""" + file_path = _create_temp_file(tmp_path, "custom.png") + + storage_mod.cloud_storage.enabled = True + storage_mod.cloud_storage.bucket_name = "custom-bucket" + storage_mod.cloud_storage.endpoint_url = "https://minio.example.com/" + storage_mod.cloud_storage.region_name = None + storage_mod.cloud_storage.client = SimpleNamespace( + upload_file=lambda *args, **kwargs: None + ) + + url = asyncio.run(storage_mod.cloud_storage.upload_file(file_path, "custom.png")) + + # Verify trailing slash is stripped + assert url == "https://minio.example.com/custom-bucket/custom.png" + + +def test_content_type_detection(tmp_path): + """Test Content-Type header for different file extensions.""" + storage_mod.cloud_storage.enabled = True + storage_mod.cloud_storage.bucket_name = "test-bucket" 
+ storage_mod.cloud_storage.endpoint_url = "https://s3.test" + + test_cases = [ + ("image.png", "image/png"), + ("image.jpg", "image/jpeg"), + ("image.jpeg", "image/jpeg"), + ("image.webp", "image/webp"), + ("video.mp4", "video/mp4"), + ("file.bin", "application/octet-stream"), + ] + + for filename, expected_type in test_cases: + called = {} + + def fake_upload(local_path, bucket, key, ExtraArgs=None): + called["content_type"] = ExtraArgs.get("ContentType") + + storage_mod.cloud_storage.client = SimpleNamespace(upload_file=fake_upload) + + file_path = _create_temp_file(tmp_path, filename) + asyncio.run(storage_mod.cloud_storage.upload_file(file_path, filename)) + + assert called["content_type"] == expected_type + + +# requires moto and boto3 +has_moto = ( + importlib.util.find_spec("moto") is not None + and importlib.util.find_spec("boto3") is not None +) + + +@pytest.mark.skipif(not has_moto, reason="moto/boto3 not installed") +def test_integration_with_moto(tmp_path): + """Integration test using moto to mock real S3 service.""" + import boto3 + from moto import mock_aws + + os.environ["SGLANG_CLOUD_STORAGE_TYPE"] = "s3" + os.environ["SGLANG_S3_BUCKET_NAME"] = "integration-test" + os.environ["SGLANG_S3_REGION_NAME"] = "us-east-1" + + with mock_aws(): + s3 = boto3.client("s3", region_name="us-east-1") + s3.create_bucket(Bucket="integration-test") + + storage = CloudStorage() + assert storage.is_enabled() + + file_path = _create_temp_file(tmp_path, "integration.png", b"test_data") + + url = asyncio.run(storage.upload_and_cleanup(file_path)) + + assert url is not None + assert "integration-test" in url + assert "integration.png" in url + assert not os.path.exists(file_path) + + obj = s3.get_object(Bucket="integration-test", Key="integration.png") + assert obj["Body"].read() == b"test_data" + + for key in [ + "SGLANG_CLOUD_STORAGE_TYPE", + "SGLANG_S3_BUCKET_NAME", + "SGLANG_S3_REGION_NAME", + ]: + os.environ.pop(key, None) diff --git 
a/sglang/python/sglang/multimodal_gen/third_party/__init__.py b/sglang/python/sglang/multimodal_gen/third_party/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..af2eb7d103a81378d30056cf536353bd24621a5e --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/third_party/__init__.py @@ -0,0 +1 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo diff --git a/sglang/python/sglang/multimodal_gen/third_party/pynvml.py b/sglang/python/sglang/multimodal_gen/third_party/pynvml.py new file mode 100644 index 0000000000000000000000000000000000000000..52467b0259319a28b96c20d21f12833b415d3e09 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/third_party/pynvml.py @@ -0,0 +1,7227 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# copied from https://pypi.org/project/nvidia-ml-py +# version 12.570.86 + +##### +# Copyright (c) 2011-2023, NVIDIA Corporation. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of the NVIDIA Corporation nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGE. +##### + +import os +import string +import sys +import threading + +## +# Python bindings for the NVML library +## +from ctypes import * +from functools import wraps + +## C Type mappings ## +## Enums +_nvmlEnableState_t = c_uint +NVML_FEATURE_DISABLED = 0 +NVML_FEATURE_ENABLED = 1 + +_nvmlBrandType_t = c_uint +NVML_BRAND_UNKNOWN = 0 +NVML_BRAND_QUADRO = 1 +NVML_BRAND_TESLA = 2 +NVML_BRAND_NVS = 3 +NVML_BRAND_GRID = ( + 4 # Deprecated from API reporting. Keeping definition for backward compatibility. +) +NVML_BRAND_GEFORCE = 5 +NVML_BRAND_TITAN = 6 +NVML_BRAND_NVIDIA_VAPPS = 7 # NVIDIA Virtual Applications +NVML_BRAND_NVIDIA_VPC = 8 # NVIDIA Virtual PC +NVML_BRAND_NVIDIA_VCS = 9 # NVIDIA Virtual Compute Server +NVML_BRAND_NVIDIA_VWS = 10 # NVIDIA RTX Virtual Workstation +NVML_BRAND_NVIDIA_CLOUD_GAMING = 11 # NVIDIA Cloud Gaming +NVML_BRAND_NVIDIA_VGAMING = NVML_BRAND_NVIDIA_CLOUD_GAMING # Deprecated from API reporting. Keeping definition for backward compatibility. 
+NVML_BRAND_QUADRO_RTX = 12 +NVML_BRAND_NVIDIA_RTX = 13 +NVML_BRAND_NVIDIA = 14 +NVML_BRAND_GEFORCE_RTX = 15 # Unused +NVML_BRAND_TITAN_RTX = 16 # Unused +NVML_BRAND_COUNT = 17 + +_nvmlTemperatureThresholds_t = c_uint +NVML_TEMPERATURE_THRESHOLD_SHUTDOWN = 0 +NVML_TEMPERATURE_THRESHOLD_SLOWDOWN = 1 +NVML_TEMPERATURE_THRESHOLD_MEM_MAX = 2 +NVML_TEMPERATURE_THRESHOLD_GPU_MAX = 3 +NVML_TEMPERATURE_THRESHOLD_ACOUSTIC_MIN = 4 +NVML_TEMPERATURE_THRESHOLD_ACOUSTIC_CURR = 5 +NVML_TEMPERATURE_THRESHOLD_ACOUSTIC_MAX = 6 +NVML_TEMPERATURE_THRESHOLD_GPS_CURR = 7 +NVML_TEMPERATURE_THRESHOLD_COUNT = 8 + +_nvmlTemperatureSensors_t = c_uint +NVML_TEMPERATURE_GPU = 0 +NVML_TEMPERATURE_COUNT = 1 + + +_nvmlComputeMode_t = c_uint +NVML_COMPUTEMODE_DEFAULT = 0 +NVML_COMPUTEMODE_EXCLUSIVE_THREAD = 1 ## Support Removed +NVML_COMPUTEMODE_PROHIBITED = 2 +NVML_COMPUTEMODE_EXCLUSIVE_PROCESS = 3 +NVML_COMPUTEMODE_COUNT = 4 + +_nvmlMemoryLocation_t = c_uint +NVML_MEMORY_LOCATION_L1_CACHE = 0 +NVML_MEMORY_LOCATION_L2_CACHE = 1 +NVML_MEMORY_LOCATION_DEVICE_MEMORY = 2 +NVML_MEMORY_LOCATION_DRAM = 2 +NVML_MEMORY_LOCATION_REGISTER_FILE = 3 +NVML_MEMORY_LOCATION_TEXTURE_MEMORY = 4 +NVML_MEMORY_LOCATION_TEXTURE_SHM = 5 +NVML_MEMORY_LOCATION_CBU = 6 +NVML_MEMORY_LOCATION_SRAM = 7 +NVML_MEMORY_LOCATION_COUNT = 8 + +NVML_NVLINK_MAX_LINKS = 18 + +# For backwards compatibility, maintain the incorrectly-named "LANES" define +NVML_NVLINK_MAX_LANES = NVML_NVLINK_MAX_LINKS + +_nvmlNvLinkErrorCounter_t = c_uint +NVML_NVLINK_ERROR_DL_REPLAY = 0 +NVML_NVLINK_ERROR_DL_RECOVERY = 1 +NVML_NVLINK_ERROR_DL_CRC_FLIT = 2 +NVML_NVLINK_ERROR_DL_CRC_DATA = 3 +NVML_NVLINK_ERROR_DL_ECC_DATA = 4 +NVML_NVLINK_ERROR_COUNT = 5 + +_nvmlNvLinkEccLaneErrorCounter_t = c_uint +NVML_NVLINK_ERROR_DL_ECC_LANE0 = 0 +NVML_NVLINK_ERROR_DL_ECC_LANE1 = 1 +NVML_NVLINK_ERROR_DL_ECC_LANE2 = 2 +NVML_NVLINK_ERROR_DL_ECC_LANE3 = 3 +NVML_NVLINK_ERROR_DL_ECC_COUNT = 5 + +_nvmlNvLinkCapability_t = c_uint +NVML_NVLINK_CAP_P2P_SUPPORTED = 0 
+NVML_NVLINK_CAP_SYSMEM_ACCESS = 1 +NVML_NVLINK_CAP_P2P_ATOMICS = 2 +NVML_NVLINK_CAP_SYSMEM_ATOMICS = 3 +NVML_NVLINK_CAP_SLI_BRIDGE = 4 +NVML_NVLINK_CAP_VALID = 5 +NVML_NVLINK_CAP_COUNT = 6 + +_nvmlNvLinkUtilizationCountPktTypes_t = c_uint +NVML_NVLINK_COUNTER_PKTFILTER_NOP = 0x1 +NVML_NVLINK_COUNTER_PKTFILTER_READ = 0x2 +NVML_NVLINK_COUNTER_PKTFILTER_WRITE = 0x4 +NVML_NVLINK_COUNTER_PKTFILTER_RATOM = 0x8 +NVML_NVLINK_COUNTER_PKTFILTER_NRATOM = 0x10 +NVML_NVLINK_COUNTER_PKTFILTER_FLUSH = 0x20 +NVML_NVLINK_COUNTER_PKTFILTER_RESPDATA = 0x40 +NVML_NVLINK_COUNTER_PKTFILTER_RESPNODATA = 0x80 +NVML_NVLINK_COUNTER_PKTFILTER_ALL = 0xFF + +_nvmlNvLinkUtilizationCountUnits_t = c_uint +NVML_NVLINK_COUNTER_UNIT_CYCLES = 0 +NVML_NVLINK_COUNTER_UNIT_PACKETS = 1 +NVML_NVLINK_COUNTER_UNIT_BYTES = 2 +NVML_NVLINK_COUNTER_UNIT_RESERVED = 3 +NVML_NVLINK_COUNTER_UNIT_COUNT = 4 + +_nvmlNvLinkDeviceType_t = c_uint +NVML_NVLINK_DEVICE_TYPE_GPU = 0x00 +NVML_NVLINK_DEVICE_TYPE_IBMNPU = 0x01 +NVML_NVLINK_DEVICE_TYPE_SWITCH = 0x02 +NVML_NVLINK_DEVICE_TYPE_UNKNOWN = 0xFF + +# These are deprecated, instead use _nvmlMemoryErrorType_t +_nvmlEccBitType_t = c_uint +NVML_SINGLE_BIT_ECC = 0 +NVML_DOUBLE_BIT_ECC = 1 +NVML_ECC_ERROR_TYPE_COUNT = 2 + +_nvmlEccCounterType_t = c_uint +NVML_VOLATILE_ECC = 0 +NVML_AGGREGATE_ECC = 1 +NVML_ECC_COUNTER_TYPE_COUNT = 2 + +_nvmlMemoryErrorType_t = c_uint +NVML_MEMORY_ERROR_TYPE_CORRECTED = 0 +NVML_MEMORY_ERROR_TYPE_UNCORRECTED = 1 +NVML_MEMORY_ERROR_TYPE_COUNT = 2 + +_nvmlClockType_t = c_uint +NVML_CLOCK_GRAPHICS = 0 +NVML_CLOCK_SM = 1 +NVML_CLOCK_MEM = 2 +NVML_CLOCK_VIDEO = 3 +NVML_CLOCK_COUNT = 4 + +_nvmlClockId_t = c_uint +NVML_CLOCK_ID_CURRENT = 0 +NVML_CLOCK_ID_APP_CLOCK_TARGET = 1 +NVML_CLOCK_ID_APP_CLOCK_DEFAULT = 2 +NVML_CLOCK_ID_CUSTOMER_BOOST_MAX = 3 +NVML_CLOCK_ID_COUNT = 4 + +_nvmlDriverModel_t = c_uint +NVML_DRIVER_WDDM = 0 +NVML_DRIVER_WDM = 1 +NVML_DRIVER_MCDM = 2 + +NVML_MAX_GPU_PERF_PSTATES = 16 + +_nvmlPstates_t = c_uint +NVML_PSTATE_0 = 0 
+NVML_PSTATE_1 = 1 +NVML_PSTATE_2 = 2 +NVML_PSTATE_3 = 3 +NVML_PSTATE_4 = 4 +NVML_PSTATE_5 = 5 +NVML_PSTATE_6 = 6 +NVML_PSTATE_7 = 7 +NVML_PSTATE_8 = 8 +NVML_PSTATE_9 = 9 +NVML_PSTATE_10 = 10 +NVML_PSTATE_11 = 11 +NVML_PSTATE_12 = 12 +NVML_PSTATE_13 = 13 +NVML_PSTATE_14 = 14 +NVML_PSTATE_15 = 15 +NVML_PSTATE_UNKNOWN = 32 + +_nvmlInforomObject_t = c_uint +NVML_INFOROM_OEM = 0 +NVML_INFOROM_ECC = 1 +NVML_INFOROM_POWER = 2 +NVML_INFOROM_DEN = 3 +NVML_INFOROM_COUNT = 4 + +_nvmlReturn_t = c_uint +NVML_SUCCESS = 0 +NVML_ERROR_UNINITIALIZED = 1 +NVML_ERROR_INVALID_ARGUMENT = 2 +NVML_ERROR_NOT_SUPPORTED = 3 +NVML_ERROR_NO_PERMISSION = 4 +NVML_ERROR_ALREADY_INITIALIZED = 5 +NVML_ERROR_NOT_FOUND = 6 +NVML_ERROR_INSUFFICIENT_SIZE = 7 +NVML_ERROR_INSUFFICIENT_POWER = 8 +NVML_ERROR_DRIVER_NOT_LOADED = 9 +NVML_ERROR_TIMEOUT = 10 +NVML_ERROR_IRQ_ISSUE = 11 +NVML_ERROR_LIBRARY_NOT_FOUND = 12 +NVML_ERROR_FUNCTION_NOT_FOUND = 13 +NVML_ERROR_CORRUPTED_INFOROM = 14 +NVML_ERROR_GPU_IS_LOST = 15 +NVML_ERROR_RESET_REQUIRED = 16 +NVML_ERROR_OPERATING_SYSTEM = 17 +NVML_ERROR_LIB_RM_VERSION_MISMATCH = 18 +NVML_ERROR_IN_USE = 19 +NVML_ERROR_MEMORY = 20 +NVML_ERROR_NO_DATA = 21 +NVML_ERROR_VGPU_ECC_NOT_SUPPORTED = 22 +NVML_ERROR_INSUFFICIENT_RESOURCES = 23 +NVML_ERROR_FREQ_NOT_SUPPORTED = 24 +NVML_ERROR_ARGUMENT_VERSION_MISMATCH = 25 +NVML_ERROR_DEPRECATED = 26 +NVML_ERROR_NOT_READY = 27 +NVML_ERROR_GPU_NOT_FOUND = 28 +NVML_ERROR_INVALID_STATE = 29 +NVML_ERROR_UNKNOWN = 999 + +_nvmlFanState_t = c_uint +NVML_FAN_NORMAL = 0 +NVML_FAN_FAILED = 1 + +_nvmlFanControlPolicy_t = c_uint +NVML_FAN_POLICY_TEMPERATURE_CONTINOUS_SW = 0 +NVML_FAN_POLICY_MANUAL = 1 + +_nvmlLedColor_t = c_uint +NVML_LED_COLOR_GREEN = 0 +NVML_LED_COLOR_AMBER = 1 + +_nvmlGpuOperationMode_t = c_uint +NVML_GOM_ALL_ON = 0 +NVML_GOM_COMPUTE = 1 +NVML_GOM_LOW_DP = 2 + +_nvmlPageRetirementCause_t = c_uint +NVML_PAGE_RETIREMENT_CAUSE_MULTIPLE_SINGLE_BIT_ECC_ERRORS = 0 +NVML_PAGE_RETIREMENT_CAUSE_DOUBLE_BIT_ECC_ERROR = 1 
+NVML_PAGE_RETIREMENT_CAUSE_COUNT = 2 + +_nvmlRestrictedAPI_t = c_uint +NVML_RESTRICTED_API_SET_APPLICATION_CLOCKS = 0 +NVML_RESTRICTED_API_SET_AUTO_BOOSTED_CLOCKS = 1 +NVML_RESTRICTED_API_COUNT = 2 + +_nvmlBridgeChipType_t = c_uint +NVML_BRIDGE_CHIP_PLX = 0 +NVML_BRIDGE_CHIP_BRO4 = 1 +NVML_MAX_PHYSICAL_BRIDGE = 128 + +_nvmlValueType_t = c_uint +NVML_VALUE_TYPE_DOUBLE = 0 +NVML_VALUE_TYPE_UNSIGNED_INT = 1 +NVML_VALUE_TYPE_UNSIGNED_LONG = 2 +NVML_VALUE_TYPE_UNSIGNED_LONG_LONG = 3 +NVML_VALUE_TYPE_SIGNED_LONG_LONG = 4 +NVML_VALUE_TYPE_SIGNED_INT = 5 +NVML_VALUE_TYPE_UNSIGNED_SHORT = 6 +NVML_VALUE_TYPE_COUNT = 7 + +_nvmlNvlinkVersion_t = c_uint +NVML_NVLINK_VERSION_INVALID = 0 +NVML_NVLINK_VERSION_1_0 = 1 +NVML_NVLINK_VERSION_2_0 = 2 +NVML_NVLINK_VERSION_2_2 = 3 +NVML_NVLINK_VERSION_3_0 = 4 +NVML_NVLINK_VERSION_3_1 = 5 +NVML_NVLINK_VERSION_4_0 = 6 +NVML_NVLINK_VERSION_5_0 = 7 + +_nvmlPerfPolicyType_t = c_uint +NVML_PERF_POLICY_POWER = 0 +NVML_PERF_POLICY_THERMAL = 1 +NVML_PERF_POLICY_SYNC_BOOST = 2 +NVML_PERF_POLICY_BOARD_LIMIT = 3 +NVML_PERF_POLICY_LOW_UTILIZATION = 4 +NVML_PERF_POLICY_RELIABILITY = 5 +NVML_PERF_POLICY_TOTAL_APP_CLOCKS = 10 +NVML_PERF_POLICY_TOTAL_BASE_CLOCKS = 11 +NVML_PERF_POLICY_COUNT = 12 + +_nvmlEncoderQueryType_t = c_uint +NVML_ENCODER_QUERY_H264 = 0 +NVML_ENCODER_QUERY_HEVC = 1 +NVML_ENCODER_QUERY_AV1 = 2 +NVML_ENCODER_QUERY_UNKNOWN = 255 + +_nvmlFBCSessionType_t = c_uint +NVML_FBC_SESSION_TYPE_UNKNOWN = 0 +NVML_FBC_SESSION_TYPE_TOSYS = 1 +NVML_FBC_SESSION_TYPE_CUDA = 2 +NVML_FBC_SESSION_TYPE_VID = 3 +NVML_FBC_SESSION_TYPE_HWENC = 4 + +_nvmlDetachGpuState_t = c_uint +NVML_DETACH_GPU_KEEP = 0 +NVML_DETACH_GPU_REMOVE = 1 + +_nvmlPcieLinkState_t = c_uint +NVML_PCIE_LINK_KEEP = 0 +NVML_PCIE_LINK_SHUT_DOWN = 1 + +_nvmlSamplingType_t = c_uint +NVML_TOTAL_POWER_SAMPLES = 0 +NVML_GPU_UTILIZATION_SAMPLES = 1 +NVML_MEMORY_UTILIZATION_SAMPLES = 2 +NVML_ENC_UTILIZATION_SAMPLES = 3 +NVML_DEC_UTILIZATION_SAMPLES = 4 +NVML_PROCESSOR_CLK_SAMPLES = 5 
# Remaining sampling-type values; the enumeration's earlier members are
# defined immediately before this span.
NVML_MEMORY_CLK_SAMPLES = 6
NVML_MODULE_POWER_SAMPLES = 7
NVML_JPG_UTILIZATION_SAMPLES = 8
NVML_OFA_UTILIZATION_SAMPLES = 9
NVML_SAMPLINGTYPE_COUNT = 10

# PCIe utilization counter selectors.
_nvmlPcieUtilCounter_t = c_uint
NVML_PCIE_UTIL_TX_BYTES = 0
NVML_PCIE_UTIL_RX_BYTES = 1
NVML_PCIE_UTIL_COUNT = 2

# GPU topology levels; NVML_TOPOLOGY_CPU is an alias of NVML_TOPOLOGY_NODE.
_nvmlGpuTopologyLevel_t = c_uint
NVML_TOPOLOGY_INTERNAL = 0
NVML_TOPOLOGY_SINGLE = 10
NVML_TOPOLOGY_MULTIPLE = 20
NVML_TOPOLOGY_HOSTBRIDGE = 30
NVML_TOPOLOGY_NODE = 40
NVML_TOPOLOGY_CPU = NVML_TOPOLOGY_NODE
NVML_TOPOLOGY_SYSTEM = 50

# Peer-to-peer capability indices.
_nvmlGpuP2PCapsIndex_t = c_uint
# Must be a plain integer like every sibling index. The previous spelling
# "(0,)" bound a one-element tuple, which compares unequal to 0 and is not
# an integer value.
NVML_P2P_CAPS_INDEX_READ = 0
NVML_P2P_CAPS_INDEX_WRITE = 1
NVML_P2P_CAPS_INDEX_NVLINK = 2
NVML_P2P_CAPS_INDEX_ATOMICS = 3
#
# NVML_P2P_CAPS_INDEX_PROP is deprecated.
# Use NVML_P2P_CAPS_INDEX_PCI instead.
#
NVML_P2P_CAPS_INDEX_PROP = 4
NVML_P2P_CAPS_INDEX_PCI = 4
NVML_P2P_CAPS_INDEX_UNKNOWN = 5

# Peer-to-peer status codes. The misspelled "SUPPORED" name is kept and the
# correctly spelled name is defined as an alias of it, so both resolve to 1.
_nvmlGpuP2PStatus_t = c_uint
NVML_P2P_STATUS_OK = 0
NVML_P2P_STATUS_CHIPSET_NOT_SUPPORED = 1
NVML_P2P_STATUS_CHIPSET_NOT_SUPPORTED = NVML_P2P_STATUS_CHIPSET_NOT_SUPPORED
NVML_P2P_STATUS_GPU_NOT_SUPPORTED = 2
NVML_P2P_STATUS_IOH_TOPOLOGY_NOT_SUPPORTED = 3
NVML_P2P_STATUS_DISABLED_BY_REGKEY = 4
NVML_P2P_STATUS_NOT_SUPPORTED = 5
NVML_P2P_STATUS_UNKNOWN = 6

# Device architecture identifiers.
_nvmlDeviceArchitecture_t = c_uint
NVML_DEVICE_ARCH_KEPLER = 2
NVML_DEVICE_ARCH_MAXWELL = 3
NVML_DEVICE_ARCH_PASCAL = 4
NVML_DEVICE_ARCH_VOLTA = 5
NVML_DEVICE_ARCH_TURING = 6
NVML_DEVICE_ARCH_AMPERE = 7
NVML_DEVICE_ARCH_ADA = 8
NVML_DEVICE_ARCH_HOPPER = 9
NVML_DEVICE_ARCH_BLACKWELL = 10
NVML_DEVICE_ARCH_T23X = 11
NVML_DEVICE_ARCH_UNKNOWN = 0xFFFFFFFF

# PCI bus Types
_nvmlBusType_t = c_uint
NVML_BUS_TYPE_UNKNOWN = 0
NVML_BUS_TYPE_PCI = 1
NVML_BUS_TYPE_PCIE = 2
NVML_BUS_TYPE_FPCI = 3
NVML_BUS_TYPE_AGP = 4

# Power source identifiers.
_nvmlPowerSource_t = c_uint
NVML_POWER_SOURCE_AC = 0x00000000
NVML_POWER_SOURCE_BATTERY = 0x00000001
NVML_POWER_SOURCE_UNDERSIZED = 0x00000002

# Type alias only; the NVML_ADAPTIVE_CLOCKING_INFO_STATUS_* values follow
# this span.
_nvmlAdaptiveClockInfoStatus_t = c_uint
+NVML_ADAPTIVE_CLOCKING_INFO_STATUS_DISABLED = 0x00000000 +NVML_ADAPTIVE_CLOCKING_INFO_STATUS_ENABLED = 0x00000001 + +_nvmlClockLimitId_t = c_uint +NVML_CLOCK_LIMIT_ID_RANGE_START = 0xFFFFFF00 +NVML_CLOCK_LIMIT_ID_TDP = 0xFFFFFF01 +NVML_CLOCK_LIMIT_ID_UNLIMITED = 0xFFFFFF02 + +_nvmlPcieLinkMaxSpeed_t = c_uint +NVML_PCIE_LINK_MAX_SPEED_INVALID = 0x00000000 +NVML_PCIE_LINK_MAX_SPEED_2500MBPS = 0x00000001 +NVML_PCIE_LINK_MAX_SPEED_5000MBPS = 0x00000002 +NVML_PCIE_LINK_MAX_SPEED_8000MBPS = 0x00000003 +NVML_PCIE_LINK_MAX_SPEED_16000MBPS = 0x00000004 +NVML_PCIE_LINK_MAX_SPEED_32000MBPS = 0x00000005 +NVML_PCIE_LINK_MAX_SPEED_64000MBPS = 0x00000006 + +_nvmlPcieAtomicsCapability_t = c_uint +NVML_PCIE_ATOMICS_CAP_FETCHADD32 = 0x01 +NVML_PCIE_ATOMICS_CAP_FETCHADD64 = 0x02 +NVML_PCIE_ATOMICS_CAP_SWAP32 = 0x04 +NVML_PCIE_ATOMICS_CAP_SWAP64 = 0x08 +NVML_PCIE_ATOMICS_CAP_CAS32 = 0x10 +NVML_PCIE_ATOMICS_CAP_CAS64 = 0x20 +NVML_PCIE_ATOMICS_CAP_CAS128 = 0x40 +NVML_PCIE_ATOMICS_OPS_MAX = 7 + +_nvmlAffinityScope_t = c_uint +NVML_AFFINITY_SCOPE_NODE = 0 +NVML_AFFINITY_SCOPE_SOCKET = 1 + +_nvmlDeviceGpuRecoveryAction_t = c_uint +NVML_GPU_RECOVERY_ACTION_NONE = 0 +NVML_GPU_RECOVERY_ACTION_GPU_RESET = 1 +NVML_GPU_RECOVERY_ACTION_NODE_REBOOT = 2 +NVML_GPU_RECOVERY_ACTION_DRAIN_P2P = 3 +NVML_GPU_RECOVERY_ACTION_DRAIN_AND_RESET = 4 + +# C preprocessor defined values +nvmlFlagDefault = 0 +nvmlFlagForce = 1 +NVML_INIT_FLAG_NO_GPUS = 1 +NVML_INIT_FLAG_NO_ATTACH = 2 + +NVML_MAX_GPC_COUNT = 32 + +# buffer size +NVML_DEVICE_INFOROM_VERSION_BUFFER_SIZE = 16 +NVML_DEVICE_UUID_BUFFER_SIZE = 80 +NVML_DEVICE_UUID_V2_BUFFER_SIZE = 96 +NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE = 80 +NVML_SYSTEM_NVML_VERSION_BUFFER_SIZE = 80 +NVML_DEVICE_NAME_BUFFER_SIZE = 64 +NVML_DEVICE_NAME_V2_BUFFER_SIZE = 96 +NVML_DEVICE_SERIAL_BUFFER_SIZE = 30 +NVML_DEVICE_PART_NUMBER_BUFFER_SIZE = 80 +NVML_DEVICE_GPU_PART_NUMBER_BUFFER_SIZE = 80 +NVML_DEVICE_VBIOS_VERSION_BUFFER_SIZE = 32 +NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE = 32 
+NVML_DEVICE_PCI_BUS_ID_BUFFER_V2_SIZE = 16 +NVML_GRID_LICENSE_BUFFER_SIZE = 128 +NVML_VGPU_NAME_BUFFER_SIZE = 64 +NVML_GRID_LICENSE_FEATURE_MAX_COUNT = 3 +NVML_VGPU_METADATA_OPAQUE_DATA_SIZE = sizeof(c_uint) + 256 +NVML_VGPU_PGPU_METADATA_OPAQUE_DATA_SIZE = 256 +NVML_DEVICE_GPU_FRU_PART_NUMBER_BUFFER_SIZE = ( + 0x14 # NV2080_GPU_MAX_PRODUCT_PART_NUMBER_LENGTH +) +NVML_PERF_MODES_BUFFER_SIZE = 2048 + +# Format strings +NVML_DEVICE_PCI_BUS_ID_LEGACY_FMT = "%04X:%02X:%02X.0" +NVML_DEVICE_PCI_BUS_ID_FMT = "%08X:%02X:%02X.0" + +NVML_VALUE_NOT_AVAILABLE_ulonglong = c_ulonglong(-1) +NVML_VALUE_NOT_AVAILABLE_uint = c_uint(-1) + +""" + Field Identifiers. + + All Identifiers pertain to a device. Each ID is only used once and is guaranteed never to change. +""" +NVML_FI_DEV_ECC_CURRENT = 1 # Current ECC mode. 1=Active. 0=Inactive +NVML_FI_DEV_ECC_PENDING = 2 # Pending ECC mode. 1=Active. 0=Inactive + +# ECC Count Totals +NVML_FI_DEV_ECC_SBE_VOL_TOTAL = 3 # Total single bit volatile ECC errors +NVML_FI_DEV_ECC_DBE_VOL_TOTAL = 4 # Total double bit volatile ECC errors +NVML_FI_DEV_ECC_SBE_AGG_TOTAL = 5 # Total single bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_DBE_AGG_TOTAL = 6 # Total double bit aggregate (persistent) ECC errors +# Individual ECC locations +NVML_FI_DEV_ECC_SBE_VOL_L1 = 7 # L1 cache single bit volatile ECC errors +NVML_FI_DEV_ECC_DBE_VOL_L1 = 8 # L1 cache double bit volatile ECC errors +NVML_FI_DEV_ECC_SBE_VOL_L2 = 9 # L2 cache single bit volatile ECC errors +NVML_FI_DEV_ECC_DBE_VOL_L2 = 10 # L2 cache double bit volatile ECC errors +NVML_FI_DEV_ECC_SBE_VOL_DEV = 11 # Device memory single bit volatile ECC errors +NVML_FI_DEV_ECC_DBE_VOL_DEV = 12 # Device memory double bit volatile ECC errors +NVML_FI_DEV_ECC_SBE_VOL_REG = 13 # Register file single bit volatile ECC errors +NVML_FI_DEV_ECC_DBE_VOL_REG = 14 # Register file double bit volatile ECC errors +NVML_FI_DEV_ECC_SBE_VOL_TEX = 15 # Texture memory single bit volatile ECC errors 
+NVML_FI_DEV_ECC_DBE_VOL_TEX = 16 # Texture memory double bit volatile ECC errors +NVML_FI_DEV_ECC_DBE_VOL_CBU = 17 # CBU double bit volatile ECC errors +NVML_FI_DEV_ECC_SBE_AGG_L1 = 18 # L1 cache single bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_DBE_AGG_L1 = 19 # L1 cache double bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_SBE_AGG_L2 = 20 # L2 cache single bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_DBE_AGG_L2 = 21 # L2 cache double bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_SBE_AGG_DEV = ( + 22 # Device memory single bit aggregate (persistent) ECC errors +) +NVML_FI_DEV_ECC_DBE_AGG_DEV = ( + 23 # Device memory double bit aggregate (persistent) ECC errors +) +NVML_FI_DEV_ECC_SBE_AGG_REG = ( + 24 # Register File single bit aggregate (persistent) ECC errors +) +NVML_FI_DEV_ECC_DBE_AGG_REG = ( + 25 # Register File double bit aggregate (persistent) ECC errors +) +NVML_FI_DEV_ECC_SBE_AGG_TEX = ( + 26 # Texture memory single bit aggregate (persistent) ECC errors +) +NVML_FI_DEV_ECC_DBE_AGG_TEX = ( + 27 # Texture memory double bit aggregate (persistent) ECC errors +) +NVML_FI_DEV_ECC_DBE_AGG_CBU = 28 # CBU double bit aggregate ECC errors + +# Page Retirement +NVML_FI_DEV_RETIRED_SBE = 29 # Number of retired pages because of single bit errors +NVML_FI_DEV_RETIRED_DBE = 30 # Number of retired pages because of double bit errors +NVML_FI_DEV_RETIRED_PENDING = 31 # If any pages are pending retirement. 1=yes. 0=no. 
+ +# NvLink Flit Error Counters +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L0 = ( + 32 # NVLink flow control CRC Error Counter for Lane 0 +) +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L1 = ( + 33 # NVLink flow control CRC Error Counter for Lane 1 +) +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L2 = ( + 34 # NVLink flow control CRC Error Counter for Lane 2 +) +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L3 = ( + 35 # NVLink flow control CRC Error Counter for Lane 3 +) +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L4 = ( + 36 # NVLink flow control CRC Error Counter for Lane 4 +) +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L5 = ( + 37 # NVLink flow control CRC Error Counter for Lane 5 +) +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_TOTAL = ( + 38 # NVLink flow control CRC Error Counter total for all Lanes +) + +# NvLink CRC Data Error Counters +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L0 = ( + 39 # NVLink data CRC Error Counter for Lane 0 +) +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L1 = ( + 40 # NVLink data CRC Error Counter for Lane 1 +) +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L2 = ( + 41 # NVLink data CRC Error Counter for Lane 2 +) +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L3 = ( + 42 # NVLink data CRC Error Counter for Lane 3 +) +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L4 = ( + 43 # NVLink data CRC Error Counter for Lane 4 +) +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L5 = ( + 44 # NVLink data CRC Error Counter for Lane 5 +) +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_TOTAL = ( + 45 # NvLink data CRC Error Counter total for all Lanes +) + +# NvLink Replay Error Counters +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L0 = 46 # NVLink Replay Error Counter for Lane 0 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L1 = 47 # NVLink Replay Error Counter for Lane 1 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L2 = 48 # NVLink Replay Error Counter for Lane 2 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L3 = 49 # NVLink Replay Error Counter for Lane 3 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L4 = 50 # NVLink Replay Error Counter 
for Lane 4 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L5 = 51 # NVLink Replay Error Counter for Lane 5 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_TOTAL = ( + 52 # NVLink Replay Error Counter total for all Lanes +) + +# NvLink Recovery Error Counters +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L0 = ( + 53 # NVLink Recovery Error Counter for Lane 0 +) +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L1 = ( + 54 # NVLink Recovery Error Counter for Lane 1 +) +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L2 = ( + 55 # NVLink Recovery Error Counter for Lane 2 +) +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L3 = ( + 56 # NVLink Recovery Error Counter for Lane 3 +) +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L4 = ( + 57 # NVLink Recovery Error Counter for Lane 4 +) +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L5 = ( + 58 # NVLink Recovery Error Counter for Lane 5 +) +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_TOTAL = ( + 59 # NVLink Recovery Error Counter total for all Lanes +) + +# NvLink Bandwidth Counters +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L0 = ( + 60 # NVLink Bandwidth Counter for Counter Set 0, Lane 0 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L1 = ( + 61 # NVLink Bandwidth Counter for Counter Set 0, Lane 1 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L2 = ( + 62 # NVLink Bandwidth Counter for Counter Set 0, Lane 2 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L3 = ( + 63 # NVLink Bandwidth Counter for Counter Set 0, Lane 3 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L4 = ( + 64 # NVLink Bandwidth Counter for Counter Set 0, Lane 4 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L5 = ( + 65 # NVLink Bandwidth Counter for Counter Set 0, Lane 5 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_TOTAL = ( + 66 # NVLink Bandwidth Counter Total for Counter Set 0, All Lanes +) + +# NvLink Bandwidth Counters +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L0 = ( + 67 # NVLink Bandwidth Counter for Counter Set 1, Lane 0 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L1 = ( + 68 # NVLink Bandwidth Counter for Counter Set 1, Lane 1 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L2 = ( + 69 # NVLink Bandwidth 
Counter for Counter Set 1, Lane 2 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L3 = ( + 70 # NVLink Bandwidth Counter for Counter Set 1, Lane 3 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L4 = ( + 71 # NVLink Bandwidth Counter for Counter Set 1, Lane 4 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L5 = ( + 72 # NVLink Bandwidth Counter for Counter Set 1, Lane 5 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_TOTAL = ( + 73 # NVLink Bandwidth Counter Total for Counter Set 1, All Lanes +) + +# Perf Policy Counters +NVML_FI_DEV_PERF_POLICY_POWER = 74 # Perf Policy Counter for Power Policy +NVML_FI_DEV_PERF_POLICY_THERMAL = 75 # Perf Policy Counter for Thermal Policy +NVML_FI_DEV_PERF_POLICY_SYNC_BOOST = 76 # Perf Policy Counter for Sync boost Policy +NVML_FI_DEV_PERF_POLICY_BOARD_LIMIT = 77 # Perf Policy Counter for Board Limit +NVML_FI_DEV_PERF_POLICY_LOW_UTILIZATION = ( + 78 # Perf Policy Counter for Low GPU Utilization Policy +) +NVML_FI_DEV_PERF_POLICY_RELIABILITY = 79 # Perf Policy Counter for Reliability Policy +NVML_FI_DEV_PERF_POLICY_TOTAL_APP_CLOCKS = ( + 80 # Perf Policy Counter for Total App Clock Policy +) +NVML_FI_DEV_PERF_POLICY_TOTAL_BASE_CLOCKS = ( + 81 # Perf Policy Counter for Total Base Clocks Policy +) + +# Memory temperatures +NVML_FI_DEV_MEMORY_TEMP = 82 # Memory temperature for the device + +# Energy Counter +NVML_FI_DEV_TOTAL_ENERGY_CONSUMPTION = ( + 83 # Total energy consumption for the GPU in mJ since the driver was last reloaded +) + +# NVLink Speed +NVML_FI_DEV_NVLINK_SPEED_MBPS_L0 = 84 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L1 = 85 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L2 = 86 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L3 = 87 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L4 = 88 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L5 = 89 +NVML_FI_DEV_NVLINK_SPEED_MBPS_COMMON = 90 + +# NVLink Link Count +NVML_FI_DEV_NVLINK_LINK_COUNT = 91 + +# Page Retirement pending fields +NVML_FI_DEV_RETIRED_PENDING_SBE = 92 +NVML_FI_DEV_RETIRED_PENDING_DBE = 93 + +# PCIe replay and replay rollover counters +NVML_FI_DEV_PCIE_REPLAY_COUNTER = 94 
+NVML_FI_DEV_PCIE_REPLAY_ROLLOVER_COUNTER = 95 + +# NvLink Flit Error Counters +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L6 = ( + 96 # NVLink flow control CRC Error Counter for Lane 6 +) +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L7 = ( + 97 # NVLink flow control CRC Error Counter for Lane 7 +) +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L8 = ( + 98 # NVLink flow control CRC Error Counter for Lane 8 +) +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L9 = ( + 99 # NVLink flow control CRC Error Counter for Lane 9 +) +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L10 = ( + 100 # NVLink flow control CRC Error Counter for Lane 10 +) +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L11 = ( + 101 # NVLink flow control CRC Error Counter for Lane 11 +) + +# NvLink CRC Data Error Counters +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L6 = ( + 102 # NVLink data CRC Error Counter for Lane 6 +) +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L7 = ( + 103 # NVLink data CRC Error Counter for Lane 7 +) +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L8 = ( + 104 # NVLink data CRC Error Counter for Lane 8 +) +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L9 = ( + 105 # NVLink data CRC Error Counter for Lane 9 +) +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L10 = ( + 106 # NVLink data CRC Error Counter for Lane 10 +) +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L11 = ( + 107 # NVLink data CRC Error Counter for Lane 11 +) + +# NvLink Replay Error Counters +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L6 = 108 # NVLink Replay Error Counter for Lane 6 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L7 = 109 # NVLink Replay Error Counter for Lane 7 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L8 = 110 # NVLink Replay Error Counter for Lane 8 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L9 = 111 # NVLink Replay Error Counter for Lane 9 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L10 = ( + 112 # NVLink Replay Error Counter for Lane 10 +) +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L11 = ( + 113 # NVLink Replay Error Counter for Lane 11 +) + +# NvLink Recovery Error Counters 
+NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L6 = ( + 114 # NVLink Recovery Error Counter for Lane 6 +) +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L7 = ( + 115 # NVLink Recovery Error Counter for Lane 7 +) +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L8 = ( + 116 # NVLink Recovery Error Counter for Lane 8 +) +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L9 = ( + 117 # NVLink Recovery Error Counter for Lane 9 +) +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L10 = ( + 118 # NVLink Recovery Error Counter for Lane 10 +) +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L11 = ( + 119 # NVLink Recovery Error Counter for Lane 11 +) + +# NvLink Bandwidth Counters +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L6 = ( + 120 # NVLink Bandwidth Counter for Counter Set 0, Lane 6 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L7 = ( + 121 # NVLink Bandwidth Counter for Counter Set 0, Lane 7 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L8 = ( + 122 # NVLink Bandwidth Counter for Counter Set 0, Lane 8 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L9 = ( + 123 # NVLink Bandwidth Counter for Counter Set 0, Lane 9 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L10 = ( + 124 # NVLink Bandwidth Counter for Counter Set 0, Lane 10 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L11 = ( + 125 # NVLink Bandwidth Counter for Counter Set 0, Lane 11 +) + +# NvLink Bandwidth Counters +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L6 = ( + 126 # NVLink Bandwidth Counter for Counter Set 1, Lane 6 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L7 = ( + 127 # NVLink Bandwidth Counter for Counter Set 1, Lane 7 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L8 = ( + 128 # NVLink Bandwidth Counter for Counter Set 1, Lane 8 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L9 = ( + 129 # NVLink Bandwidth Counter for Counter Set 1, Lane 9 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L10 = ( + 130 # NVLink Bandwidth Counter for Counter Set 1, Lane 10 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L11 = ( + 131 # NVLink Bandwidth Counter for Counter Set 1, Lane 11 +) + +# NVLink Speed +NVML_FI_DEV_NVLINK_SPEED_MBPS_L6 = 132 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L7 = 133 
# Field-value IDs 134-202 (continuation of the NVLink speed group).
NVML_FI_DEV_NVLINK_SPEED_MBPS_L8 = 134
NVML_FI_DEV_NVLINK_SPEED_MBPS_L9 = 135
NVML_FI_DEV_NVLINK_SPEED_MBPS_L10 = 136
NVML_FI_DEV_NVLINK_SPEED_MBPS_L11 = 137

# NVLink Throughput Counters (values in KiB)
NVML_FI_DEV_NVLINK_THROUGHPUT_DATA_TX = 138  # TX data throughput
NVML_FI_DEV_NVLINK_THROUGHPUT_DATA_RX = 139  # RX data throughput
NVML_FI_DEV_NVLINK_THROUGHPUT_RAW_TX = 140  # TX data + protocol overhead
NVML_FI_DEV_NVLINK_THROUGHPUT_RAW_RX = 141  # RX data + protocol overhead

# Row Remapper
NVML_FI_DEV_REMAPPED_COR = 142
NVML_FI_DEV_REMAPPED_UNC = 143
NVML_FI_DEV_REMAPPED_PENDING = 144
NVML_FI_DEV_REMAPPED_FAILURE = 145

# Remote device NVLink ID
NVML_FI_DEV_NVLINK_REMOTE_NVLINK_ID = 146

# Number of NVLinks connected to NVSwitch
NVML_FI_DEV_NVSWITCH_CONNECTED_LINK_COUNT = 147

# NvLink ECC Data Error Counters -- NVLink data ECC error counter, one ID per link (0-11)
NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L0 = 148  # link 0
NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L1 = 149  # link 1
NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L2 = 150  # link 2
NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L3 = 151  # link 3
NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L4 = 152  # link 4
NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L5 = 153  # link 5
NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L6 = 154  # link 6
NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L7 = 155  # link 7
NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L8 = 156  # link 8
NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L9 = 157  # link 9
NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L10 = 158  # link 10
NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L11 = 159  # link 11
NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_TOTAL = 160  # total for all links

NVML_FI_DEV_NVLINK_ERROR_DL_REPLAY = 161
NVML_FI_DEV_NVLINK_ERROR_DL_RECOVERY = 162
NVML_FI_DEV_NVLINK_ERROR_DL_CRC = 163
NVML_FI_DEV_NVLINK_GET_SPEED = 164
NVML_FI_DEV_NVLINK_GET_STATE = 165
NVML_FI_DEV_NVLINK_GET_VERSION = 166

NVML_FI_DEV_NVLINK_GET_POWER_STATE = 167
NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD = 168

NVML_FI_DEV_PCIE_L0_TO_RECOVERY_COUNTER = 169

NVML_FI_DEV_C2C_LINK_COUNT = 170
NVML_FI_DEV_C2C_LINK_GET_STATUS = 171
NVML_FI_DEV_C2C_LINK_GET_MAX_BW = 172

NVML_FI_DEV_PCIE_COUNT_CORRECTABLE_ERRORS = 173
NVML_FI_DEV_PCIE_COUNT_NAKS_RECEIVED = 174
NVML_FI_DEV_PCIE_COUNT_RECEIVER_ERROR = 175
NVML_FI_DEV_PCIE_COUNT_BAD_TLP = 176
NVML_FI_DEV_PCIE_COUNT_NAKS_SENT = 177
NVML_FI_DEV_PCIE_COUNT_BAD_DLLP = 178
NVML_FI_DEV_PCIE_COUNT_NON_FATAL_ERROR = 179
NVML_FI_DEV_PCIE_COUNT_FATAL_ERROR = 180
NVML_FI_DEV_PCIE_COUNT_UNSUPPORTED_REQ = 181
NVML_FI_DEV_PCIE_COUNT_LCRC_ERROR = 182
NVML_FI_DEV_PCIE_COUNT_LANE_ERROR = 183

NVML_FI_DEV_IS_RESETLESS_MIG_SUPPORTED = 184

NVML_FI_DEV_POWER_AVERAGE = 185
NVML_FI_DEV_POWER_INSTANT = 186
NVML_FI_DEV_POWER_MIN_LIMIT = 187
NVML_FI_DEV_POWER_MAX_LIMIT = 188
NVML_FI_DEV_POWER_DEFAULT_LIMIT = 189
NVML_FI_DEV_POWER_CURRENT_LIMIT = 190
NVML_FI_DEV_ENERGY = 191
NVML_FI_DEV_POWER_REQUESTED_LIMIT = 192

NVML_FI_DEV_TEMPERATURE_SHUTDOWN_TLIMIT = 193
NVML_FI_DEV_TEMPERATURE_SLOWDOWN_TLIMIT = 194
NVML_FI_DEV_TEMPERATURE_MEM_MAX_TLIMIT = 195
NVML_FI_DEV_TEMPERATURE_GPU_MAX_TLIMIT = 196

NVML_FI_DEV_PCIE_COUNT_TX_BYTES = 197
NVML_FI_DEV_PCIE_COUNT_RX_BYTES = 198

NVML_FI_DEV_IS_MIG_MODE_INDEPENDENT_MIG_QUERY_CAPABLE = 199

NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_MAX = 200

NVML_FI_DEV_NVLINK_COUNT_XMIT_PACKETS = 201
NVML_FI_DEV_NVLINK_COUNT_XMIT_BYTES = 202
# Field-value IDs 203-241 (note: the table skips 231-234).
NVML_FI_DEV_NVLINK_COUNT_RCV_PACKETS = 203
NVML_FI_DEV_NVLINK_COUNT_RCV_BYTES = 204
NVML_FI_DEV_NVLINK_COUNT_VL15_DROPPED = 205  # Deprecated, do not use
NVML_FI_DEV_NVLINK_COUNT_MALFORMED_PACKET_ERRORS = 206
NVML_FI_DEV_NVLINK_COUNT_BUFFER_OVERRUN_ERRORS = 207
NVML_FI_DEV_NVLINK_COUNT_RCV_ERRORS = 208
NVML_FI_DEV_NVLINK_COUNT_RCV_REMOTE_ERRORS = 209
NVML_FI_DEV_NVLINK_COUNT_RCV_GENERAL_ERRORS = 210
NVML_FI_DEV_NVLINK_COUNT_LOCAL_LINK_INTEGRITY_ERRORS = 211
NVML_FI_DEV_NVLINK_COUNT_XMIT_DISCARDS = 212

NVML_FI_DEV_NVLINK_COUNT_LINK_RECOVERY_SUCCESSFUL_EVENTS = 213
NVML_FI_DEV_NVLINK_COUNT_LINK_RECOVERY_FAILED_EVENTS = 214
NVML_FI_DEV_NVLINK_COUNT_LINK_RECOVERY_EVENTS = 215

NVML_FI_DEV_NVLINK_COUNT_RAW_BER_LANE0 = 216  # Deprecated, do not use
NVML_FI_DEV_NVLINK_COUNT_RAW_BER_LANE1 = 217  # Deprecated, do not use
NVML_FI_DEV_NVLINK_COUNT_RAW_BER = 218  # Deprecated, do not use
NVML_FI_DEV_NVLINK_COUNT_EFFECTIVE_ERRORS = 219
NVML_FI_DEV_NVLINK_COUNT_EFFECTIVE_BER = 220
NVML_FI_DEV_NVLINK_COUNT_SYMBOL_ERRORS = 221
NVML_FI_DEV_NVLINK_COUNT_SYMBOL_BER = 222

NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_MIN = 223
# Values are in the form NVML_NVLINK_LOW_POWER_THRESHOLD_UNIT_*
NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_UNITS = 224
NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_SUPPORTED = 225

# Deprecated: use NVML_FI_DEV_GET_GPU_RECOVERY_ACTION instead
NVML_FI_DEV_RESET_STATUS = 226
# Deprecated: use NVML_FI_DEV_GET_GPU_RECOVERY_ACTION instead
NVML_FI_DEV_DRAIN_AND_RESET_STATUS = 227
NVML_FI_DEV_PCIE_OUTBOUND_ATOMICS_MASK = 228
NVML_FI_DEV_PCIE_INBOUND_ATOMICS_MASK = 229
NVML_FI_DEV_GET_GPU_RECOVERY_ACTION = 230

NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_0 = 235
NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_1 = 236
NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_2 = 237
NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_3 = 238
NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_4 = 239
NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_5 = 240
NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_6 = 241
# Field-value IDs 242-268, followed by NVML_FI_MAX.
NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_7 = 242
NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_8 = 243
NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_9 = 244
NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_10 = 245
NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_11 = 246
NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_12 = 247
NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_13 = 248
NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_14 = 249
NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_15 = 250
NVML_FI_PWR_SMOOTHING_ENABLED = 251  # Enablement (0/DISABLED or 1/ENABLED)
NVML_FI_PWR_SMOOTHING_PRIV_LVL = 252  # Current privilege level
NVML_FI_PWR_SMOOTHING_IMM_RAMP_DOWN_ENABLED = 253  # Immediate ramp down enablement (0/DISABLED or 1/ENABLED)
NVML_FI_PWR_SMOOTHING_APPLIED_TMP_CEIL = 254  # Applied TMP ceiling value
NVML_FI_PWR_SMOOTHING_APPLIED_TMP_FLOOR = 255  # Applied TMP floor value
NVML_FI_PWR_SMOOTHING_MAX_PERCENT_TMP_FLOOR_SETTING = 256  # Max % TMP Floor value
NVML_FI_PWR_SMOOTHING_MIN_PERCENT_TMP_FLOOR_SETTING = 257  # Min % TMP Floor value
NVML_FI_PWR_SMOOTHING_HW_CIRCUITRY_PERCENT_LIFETIME_REMAINING = 258  # HW Circuitry % lifetime remaining
NVML_FI_PWR_SMOOTHING_MAX_NUM_PRESET_PROFILES = 259  # Max number of preset profiles
NVML_FI_PWR_SMOOTHING_PROFILE_PERCENT_TMP_FLOOR = 260  # % TMP floor for a given profile
NVML_FI_PWR_SMOOTHING_PROFILE_RAMP_UP_RATE = 261  # Ramp up rate in mW/s for a given profile
NVML_FI_PWR_SMOOTHING_PROFILE_RAMP_DOWN_RATE = 262  # Ramp down rate in mW/s for a given profile
NVML_FI_PWR_SMOOTHING_PROFILE_RAMP_DOWN_HYST_VAL = 263  # Ramp down hysteresis value in ms for a given profile
NVML_FI_PWR_SMOOTHING_ACTIVE_PRESET_PROFILE = 264  # Active preset profile number
NVML_FI_PWR_SMOOTHING_ADMIN_OVERRIDE_PERCENT_TMP_FLOOR = 265  # % TMP floor for a given profile
NVML_FI_PWR_SMOOTHING_ADMIN_OVERRIDE_RAMP_UP_RATE = 266  # Ramp up rate in mW/s for a given profile
NVML_FI_PWR_SMOOTHING_ADMIN_OVERRIDE_RAMP_DOWN_RATE = 267  # Ramp down rate in mW/s for a given profile
NVML_FI_PWR_SMOOTHING_ADMIN_OVERRIDE_RAMP_DOWN_HYST_VAL = 268  # Ramp down hysteresis value in ms for a given profile

NVML_FI_MAX = 269  # One greater than the largest field ID defined above

# NVML_FI_DEV_NVLINK_GET_STATE state enums
NVML_NVLINK_STATE_INACTIVE = 0x0
NVML_NVLINK_STATE_ACTIVE = 0x1
NVML_NVLINK_STATE_SLEEP = 0x2

# Units reported for NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_UNITS
NVML_NVLINK_LOW_POWER_THRESHOLD_UNIT_100US = 0
NVML_NVLINK_LOW_POWER_THRESHOLD_UNIT_50US = 1

## Enums needed for the method nvmlDeviceGetVirtualizationMode and nvmlDeviceSetVirtualizationMode
NVML_GPU_VIRTUALIZATION_MODE_NONE = 0  # Represents Bare Metal GPU
NVML_GPU_VIRTUALIZATION_MODE_PASSTHROUGH = 1  # Device is associated with GPU-Passthrough
NVML_GPU_VIRTUALIZATION_MODE_VGPU = 2  # Device is associated with vGPU inside virtual machine.
NVML_GPU_VIRTUALIZATION_MODE_HOST_VGPU = 3  # Device is associated with VGX hypervisor in vGPU mode
NVML_GPU_VIRTUALIZATION_MODE_HOST_VSGA = 4  # Device is associated with VGX hypervisor in vSGA mode

## Lib loading ##
# Module-level state for the loaded NVML shared library.
nvmlLib = None  # stays None until the library has been loaded
libLoadLock = threading.Lock()
_nvmlLib_refcount = 0  # Incremented on each nvmlInit and decremented on nvmlShutdown

## vGPU Management
_nvmlVgpuTypeId_t = c_uint
_nvmlVgpuInstance_t = c_uint

_nvmlVgpuVmIdType_t = c_uint
NVML_VGPU_VM_ID_DOMAIN_ID = 0
NVML_VGPU_VM_ID_UUID = 1

_nvmlGridLicenseFeatureCode_t = c_uint
NVML_GRID_LICENSE_FEATURE_CODE_UNKNOWN = 0
NVML_GRID_LICENSE_FEATURE_CODE_VGPU = 1
NVML_GRID_LICENSE_FEATURE_CODE_NVIDIA_RTX = 2
NVML_GRID_LICENSE_FEATURE_CODE_VWORKSTATION = 2  # deprecated, use NVML_GRID_LICENSE_FEATURE_CODE_NVIDIA_RTX.
NVML_GRID_LICENSE_FEATURE_CODE_GAMING = 3
NVML_GRID_LICENSE_FEATURE_CODE_COMPUTE = 4

_nvmlGridLicenseExpiryStatus_t = c_uint8
# These must be plain integers so they compare equal to the c_uint8 `status`
# field they describe. They were previously written with a trailing comma,
# which silently turned each one into a one-element tuple such as (0,).
NVML_GRID_LICENSE_EXPIRY_NOT_AVAILABLE = 0  # Expiry information not available
NVML_GRID_LICENSE_EXPIRY_INVALID = 1  # Invalid expiry or error fetching expiry
NVML_GRID_LICENSE_EXPIRY_VALID = 2  # Valid expiry
NVML_GRID_LICENSE_EXPIRY_NOT_APPLICABLE = 3  # Expiry not applicable
NVML_GRID_LICENSE_EXPIRY_PERMANENT = 4  # Permanent expiry

_nvmlVgpuCapability_t = c_uint
NVML_VGPU_CAP_NVLINK_P2P = 0  # vGPU P2P over NVLink is supported
NVML_VGPU_CAP_GPUDIRECT = 1  # GPUDirect capability is supported
NVML_VGPU_CAP_MULTI_VGPU_EXCLUSIVE = 2  # vGPU profile cannot be mixed with other vGPU profiles in same VM
NVML_VGPU_CAP_EXCLUSIVE_TYPE = 3  # vGPU profile cannot run on a GPU alongside other profiles of different type
NVML_VGPU_CAP_EXCLUSIVE_SIZE = 4  # vGPU profile cannot run on a GPU alongside other profiles of different size
NVML_VGPU_CAP_COUNT = 5

_nvmlVgpuDriverCapability_t = c_uint
NVML_VGPU_DRIVER_CAP_HETEROGENEOUS_MULTI_VGPU = 0  # Supports mixing of different vGPU profiles within one guest VM
NVML_VGPU_DRIVER_CAP_WARM_UPDATE = 1  # Supports FSR and warm update of vGPU host driver without terminating the running guest VM
NVML_VGPU_DRIVER_CAP_COUNT = 2

_nvmlDeviceVgpuCapability_t = c_uint
NVML_DEVICE_VGPU_CAP_FRACTIONAL_MULTI_VGPU = 0  # Query whether the fractional vGPU profiles on this GPU can be used in multi-vGPU configurations
NVML_DEVICE_VGPU_CAP_HETEROGENEOUS_TIMESLICE_PROFILES = 1  # Query whether the GPU supports concurrent execution of timesliced vGPU profiles of differing types
NVML_DEVICE_VGPU_CAP_HETEROGENEOUS_TIMESLICE_SIZES = 2  # Query whether the GPU supports concurrent execution of timesliced vGPU profiles of differing framebuffer sizes
NVML_DEVICE_VGPU_CAP_READ_DEVICE_BUFFER_BW = 3  # Query the GPU's read_device_buffer expected bandwidth capacity in megabytes per second
NVML_DEVICE_VGPU_CAP_WRITE_DEVICE_BUFFER_BW = 4  # Query the GPU's write_device_buffer expected bandwidth capacity in megabytes per second
NVML_DEVICE_VGPU_CAP_DEVICE_STREAMING = 5  # Query whether the vGPU profiles on the GPU supports migration data streaming
NVML_DEVICE_VGPU_CAP_MINI_QUARTER_GPU = 6  # Set/Get support of mini-quarter vGPU profiles
NVML_DEVICE_VGPU_CAP_COMPUTE_MEDIA_ENGINE_GPU = 7  # Set/Get support for compute media engine vGPU profiles
NVML_DEVICE_VGPU_CAP_WARM_UPDATE = 8  # Query whether the GPU supports FSR and warm update
NVML_DEVICE_VGPU_CAP_HOMOGENEOUS_PLACEMENTS = 9  # Query whether the GPU supports reporting of placements of timesliced vGPU profiles with identical framebuffer sizes
NVML_DEVICE_VGPU_CAP_COUNT = 10

_nvmlVgpuGuestInfoState_t = c_uint
NVML_VGPU_INSTANCE_GUEST_INFO_STATE_UNINITIALIZED = 0
NVML_VGPU_INSTANCE_GUEST_INFO_STATE_INITIALIZED = 1

_nvmlVgpuVmCompatibility_t = c_uint
NVML_VGPU_VM_COMPATIBILITY_NONE = 0x0
NVML_VGPU_VM_COMPATIBILITY_COLD = 0x1
NVML_VGPU_VM_COMPATIBILITY_HIBERNATE = 0x2
NVML_VGPU_VM_COMPATIBILITY_SLEEP = 0x4
NVML_VGPU_VM_COMPATIBILITY_LIVE = 0x8

_nvmlVgpuPgpuCompatibilityLimitCode_t = c_uint
NVML_VGPU_COMPATIBILITY_LIMIT_NONE = 0x0
NVML_VGPU_COMPATIBILITY_LIMIT_HOST_DRIVER = 0x1
NVML_VGPU_COMPATIBILITY_LIMIT_GUEST_DRIVER = 0x2
NVML_VGPU_COMPATIBILITY_LIMIT_GPU = 0x4
NVML_VGPU_COMPATIBILITY_LIMIT_OTHER = 0x80000000

_nvmlHostVgpuMode_t = c_uint
NVML_HOST_VGPU_MODE_NON_SRIOV = 0
NVML_HOST_VGPU_MODE_SRIOV = 1

_nvmlConfComputeGpusReadyState_t = c_uint
NVML_CC_ACCEPTING_CLIENT_REQUESTS_FALSE = 0
NVML_CC_ACCEPTING_CLIENT_REQUESTS_TRUE = 1

_nvmlConfComputeGpuCaps_t = c_uint
NVML_CC_SYSTEM_GPUS_CC_NOT_CAPABLE = 0
NVML_CC_SYSTEM_GPUS_CC_CAPABLE = 1

_nvmlConfComputeCpuCaps_t = c_uint
NVML_CC_SYSTEM_CPU_CAPS_NONE = 0
NVML_CC_SYSTEM_CPU_CAPS_AMD_SEV = 1
NVML_CC_SYSTEM_CPU_CAPS_INTEL_TDX = 2
NVML_CC_SYSTEM_CPU_CAPS_AMD_SEV_SNP = 3
NVML_CC_SYSTEM_CPU_CAPS_AMD_SNP_VTOM = 4

_nvmlConfComputeDevToolsMode_t = c_uint
NVML_CC_SYSTEM_DEVTOOLS_MODE_OFF = 0
NVML_CC_SYSTEM_DEVTOOLS_MODE_ON = 1

NVML_CC_SYSTEM_MULTIGPU_NONE = 0
NVML_CC_SYSTEM_MULTIGPU_PROTECTED_PCIE = 1

NVML_CC_SYSTEM_ENVIRONMENT_UNAVAILABLE = 0
NVML_CC_SYSTEM_ENVIRONMENT_SIM = 1
NVML_CC_SYSTEM_ENVIRONMENT_PROD = 2

_nvmlConfComputeCcFeature_t = c_uint
NVML_CC_SYSTEM_FEATURE_DISABLED = 0
NVML_CC_SYSTEM_FEATURE_ENABLED = 1

_nvmlConfComputeCcKeyRotationThreshAttackerAdv_t = c_uint
NVML_CC_KEY_ROTATION_THRESH_ATTACKER_ADVANTAGE_MIN = 50
NVML_CC_KEY_ROTATION_THRESH_ATTACKER_ADVANTAGE_MAX = 65

# GSP firmware
NVML_GSP_FIRMWARE_VERSION_BUF_SIZE = 0x40


class NVMLLibraryMismatchError(Exception):
    """Distinct exception type; carries no behavior of its own."""

    pass


## Error Checking ##
class NVMLError(Exception):
    """
    Exception carrying an NVML return code in ``self.value``.

    Constructing ``NVMLError(code)`` returns an instance of the code-specific
    subclass registered in ``_valClassMapping`` when one exists (see
    _extractNVMLErrorsAsClasses), otherwise a plain NVMLError.
    """

    # error code -> generated NVMLError subclass
    _valClassMapping = dict()
    # List of currently known error codes
    # (__str__ adds entries lazily for codes that are not listed here)
    _errcode_to_string = {
        NVML_ERROR_UNINITIALIZED: "Uninitialized",
        NVML_ERROR_INVALID_ARGUMENT: "Invalid Argument",
        NVML_ERROR_NOT_SUPPORTED: "Not Supported",
        NVML_ERROR_NO_PERMISSION: "Insufficient Permissions",
        NVML_ERROR_ALREADY_INITIALIZED: "Already Initialized",
        NVML_ERROR_NOT_FOUND: "Not Found",
        NVML_ERROR_INSUFFICIENT_SIZE: "Insufficient Size",
        NVML_ERROR_INSUFFICIENT_POWER: "Insufficient External Power",
        NVML_ERROR_DRIVER_NOT_LOADED: "Driver Not Loaded",
        NVML_ERROR_TIMEOUT: "Timeout",
        NVML_ERROR_IRQ_ISSUE: "Interrupt Request Issue",
        NVML_ERROR_LIBRARY_NOT_FOUND: "NVML Shared Library Not Found",
        NVML_ERROR_FUNCTION_NOT_FOUND: "Function Not Found",
        NVML_ERROR_CORRUPTED_INFOROM: "Corrupted infoROM",
        NVML_ERROR_GPU_IS_LOST: "GPU is lost",
        NVML_ERROR_RESET_REQUIRED: "GPU requires restart",
        NVML_ERROR_OPERATING_SYSTEM: "The operating system has blocked the request.",
        NVML_ERROR_LIB_RM_VERSION_MISMATCH: "RM has detected an NVML/RM version mismatch.",
        NVML_ERROR_MEMORY: "Insufficient Memory",
        NVML_ERROR_UNKNOWN: "Unknown Error",
    }

    def __new__(typ, value):
        """
        Maps value to a proper subclass of NVMLError.
        See _extractNVMLErrorsAsClasses function for more details
        """
        if typ == NVMLError:
            typ = NVMLError._valClassMapping.get(value, typ)
        obj = Exception.__new__(typ)
        obj.value = value
        return obj

    def __str__(self):
        """Return the message for ``self.value``, caching looked-up strings."""
        try:
            if self.value not in NVMLError._errcode_to_string:
                NVMLError._errcode_to_string[self.value] = str(
                    nvmlErrorString(self.value)
                )
            return NVMLError._errcode_to_string[self.value]
        except NVMLError:
            # The string lookup itself raised; fall back to the bare code.
            return "NVML Error with code %d" % self.value

    def __eq__(self, other):
        """Two errors are equal when they carry the same NVML code."""
        try:
            return self.value == other.value
        except AttributeError:
            # `other` carries no code (e.g. an int or None): let Python fall
            # back to its default comparison instead of raising.
            return NotImplemented

    def __hash__(self):
        # Defining __eq__ removes the inherited __hash__ on Python 3; restore
        # one consistent with equality so instances remain hashable.
        return hash(self.value)


def nvmlExceptionClass(nvmlErrorCode):
    """
    Return the NVMLError subclass registered for ``nvmlErrorCode``.

    Raises ValueError when no subclass is registered for that code.
    """
    if nvmlErrorCode not in NVMLError._valClassMapping:
        raise ValueError("nvmlErrorCode %s is not valid" % nvmlErrorCode)
    return NVMLError._valClassMapping[nvmlErrorCode]


def _extractNVMLErrorsAsClasses():
    """
    Generates a hierarchy of classes on top of NVMLError class.

    Each NVML Error gets a new NVMLError subclass. This way try,except blocks can filter appropriate
    exceptions more easily.

    NVMLError is a parent class. Each NVML_ERROR_* gets its own subclass.
    e.g. NVML_ERROR_ALREADY_INITIALIZED will be turned into NVMLError_AlreadyInitialized
    """
    this_module = sys.modules[__name__]
    nvmlErrorsNames = [x for x in dir(this_module) if x.startswith("NVML_ERROR_")]
    for err_name in nvmlErrorsNames:
        # e.g. Turn NVML_ERROR_ALREADY_INITIALIZED into NVMLError_AlreadyInitialized
        class_name = "NVMLError_" + string.capwords(
            err_name.replace("NVML_ERROR_", ""), "_"
        ).replace("_", "")
        err_val = getattr(this_module, err_name)

        def gen_new(val):
            # Factory so each generated __new__ closes over its own `val`
            # rather than the loop variable.
            def new(typ):
                obj = NVMLError.__new__(typ, val)
                return obj

            return new

        new_error_class = type(class_name, (NVMLError,), {"__new__": gen_new(err_val)})
        new_error_class.__module__ = __name__
        setattr(this_module, class_name, new_error_class)
        NVMLError._valClassMapping[err_val] = new_error_class


_extractNVMLErrorsAsClasses()


def _nvmlCheckReturn(ret):
    """Return ``ret`` unchanged when it is NVML_SUCCESS, else raise NVMLError(ret)."""
    if ret != NVML_SUCCESS:
        raise NVMLError(ret)
    return ret


## Function access ##
# function pointers are cached to prevent unnecessary libLoadLock locking
_nvmlGetFunctionPointer_cache = dict()


def _nvmlGetFunctionPointer(name):
    """
    Look up attribute ``name`` on the loaded NVML library, with caching.

    Raises NVMLError(NVML_ERROR_UNINITIALIZED) when ``nvmlLib`` is None and
    NVMLError(NVML_ERROR_FUNCTION_NOT_FOUND) when the library has no such
    attribute.
    """
    global nvmlLib

    # Cache hits skip the lock entirely.
    if name in _nvmlGetFunctionPointer_cache:
        return _nvmlGetFunctionPointer_cache[name]

    # The context manager releases the lock on every exit path.
    with libLoadLock:
        # ensure library was loaded
        if nvmlLib is None:
            raise NVMLError(NVML_ERROR_UNINITIALIZED)
        try:
            _nvmlGetFunctionPointer_cache[name] = getattr(nvmlLib, name)
            return _nvmlGetFunctionPointer_cache[name]
        except AttributeError:
            raise NVMLError(NVML_ERROR_FUNCTION_NOT_FOUND)


## Alternative object
# Allows the object to be printed
# Allows mismatched types to be assigned
#  - like None when the Structure variant requires c_uint
class nvmlFriendlyObject(object):
    """Plain attribute bag populated from a dictionary."""

    def __init__(self, dictionary):
        for x in dictionary:
            setattr(self, x, dictionary[x])

    def __str__(self):
        return self.__dict__.__str__()


def nvmlStructToFriendlyObject(struct):
    """Copy every ``_fields_`` entry of ``struct`` into an nvmlFriendlyObject, decoding bytes values."""
    d = {}
    for x in struct._fields_:
        key = x[0]
        value = getattr(struct, key)
        # only need to convert from bytes if bytes, no need to check python version.
        d[key] = value.decode() if isinstance(value, bytes) else value
    obj = nvmlFriendlyObject(d)
    return obj


# pack the object so it can be passed to the NVML library
def nvmlFriendlyObjectToStruct(obj, model):
    """
    Copy the attributes of ``obj`` named by ``model._fields_`` into ``model``.

    Returns ``model``. Raises KeyError when ``obj`` lacks one of the fields.
    """
    for x in model._fields_:
        key = x[0]
        value = obj.__dict__[key]
        # any c_char_p in python3 needs to be bytes, default encoding works fine.
        # Only text is encoded: numeric fields (as produced by
        # nvmlStructToFriendlyObject) have no .encode() and are assigned as-is.
        if sys.version_info >= (3,) and isinstance(value, str):
            value = value.encode()
        setattr(model, key, value)
    return model


## Unit structures
class struct_c_nvmlUnit_t(Structure):
    pass  # opaque handle


c_nvmlUnit_t = POINTER(struct_c_nvmlUnit_t)


class _PrintableStructure(Structure):
    """
    Abstract class that produces nicer __str__ output than ctypes.Structure.
    Instead of the default object repr, this class will print
      class_name(field_name: formatted_value, field_name: formatted_value)

    _fmt_ is a dictionary mapping a field name to a %-style format string,
    e.g. class that has _field_ 'hex_value', c_uint could be formatted with
      _fmt_ = {"hex_value" : "%08X"}
    to produce nicer output.
    Default formatting string for all fields can be set with key "" like:
      _fmt_ = {"" : "%d MHz"} # e.g all values are numbers in MHz.
    If not set it's assumed to be just "%s"

    Exact format of returned str from this class is subject to change in the future.
    """

    _fmt_ = {}

    def __str__(self):
        result = []
        for x in self._fields_:
            key = x[0]
            value = getattr(self, key)
            # Per-field format wins, then the "" default, then plain "%s".
            fmt = "%s"
            if key in self._fmt_:
                fmt = self._fmt_[key]
            elif "" in self._fmt_:
                fmt = self._fmt_[""]
            result.append(("%s: " + fmt) % (key, value))
        return self.__class__.__name__ + "(" + ", ".join(result) + ")"

    def __getattribute__(self, name):
        res = super(_PrintableStructure, self).__getattribute__(name)
        # need to convert bytes to unicode for python3 don't need to for python2
        # Python 2 strings are of both str and bytes
        # Python 3 strings are not of type bytes
        # ctypes should convert everything to the correct values otherwise
        if isinstance(res, bytes):
            if isinstance(res, str):
                return res
            return res.decode()
        return res

    def __setattr__(self, name, value):
        if isinstance(value, str):
            # encoding a python2 string returns the same value, since python2 strings are bytes already
            # bytes passed in python3 will be ignored.
            value = value.encode()
        super(_PrintableStructure, self).__setattr__(name, value)


# Unit info: four fixed 96-byte character buffers.
class c_nvmlUnitInfo_t(_PrintableStructure):
    _fields_ = [
        ("name", c_char * 96),
        ("id", c_char * 96),
        ("serial", c_char * 96),
        ("firmwareVersion", c_char * 96),
    ]


class c_nvmlC2cModeInfo_v1_t(_PrintableStructure):
    _fields_ = [("isC2cEnabled", c_uint)]


nvmlC2cModeInfo_v1 = 0x1000008


class c_nvmlLedState_t(_PrintableStructure):
    _fields_ = [
        ("cause", c_char * 256),
        ("color", _nvmlLedColor_t),
    ]


class c_nvmlPSUInfo_t(_PrintableStructure):
    _fields_ = [
        ("state", c_char * 256),
        ("current", c_uint),
        ("voltage", c_uint),
        ("power", c_uint),
    ]


class c_nvmlUnitFanInfo_t(_PrintableStructure):
    _fields_ = [
        ("speed", c_uint),
        ("state", _nvmlFanState_t),
    ]


# Fixed array of 24 fan entries plus a count field.
class c_nvmlUnitFanSpeeds_t(_PrintableStructure):
    _fields_ = [("fans", c_nvmlUnitFanInfo_t * 24), ("count", c_uint)]


## Device structures
class struct_c_nvmlDevice_t(Structure):
    pass  # opaque handle


c_nvmlDevice_t = POINTER(struct_c_nvmlDevice_t)


class nvmlPciInfoExt_v1_t(_PrintableStructure):
    _fields_ = [
        ("version", c_uint),
        ("domain", c_uint),
        ("bus", c_uint),
        ("device", c_uint),
        ("pciDeviceId", c_uint),
        ("pciSubSystemId", c_uint),
        ("baseClass", c_uint),
        ("subClass", c_uint),
        ("busId", c_char * NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE),
    ]
    # Numeric fields print as zero-padded hex.
    _fmt_ = {
        "version": "0x%04X",
        "domain": "0x%04X",
        "bus": "0x%02X",
        "device": "0x%02X",
        "pciDeviceId": "0x%08X",
        "pciSubSystemId": "0x%08X",
        "baseClass": "0x%01X",
        "subClass": "0x%01X",
    }


nvmlPciInfoExt_v1 = 0x1000040


# Legacy pciInfo used for _v1 and _v2
class nvmlPciInfo_v2_t(_PrintableStructure):
    _fields_ = [
        ("busId", c_char * NVML_DEVICE_PCI_BUS_ID_BUFFER_V2_SIZE),
        ("domain", c_uint),
        ("bus", c_uint),
        ("device", c_uint),
        ("pciDeviceId", c_uint),
        # Added in 2.285
        ("pciSubSystemId", c_uint),
        ("reserved0", c_uint),
        ("reserved1", c_uint),
        ("reserved2", c_uint),
        ("reserved3", c_uint),
    ]
    _fmt_ = {
        "domain": "0x%04X",
        "bus": "0x%02X",
        "device": "0x%02X",
        "pciDeviceId": "0x%08X",
        "pciSubSystemId": "0x%08X",
    }


class nvmlPciInfo_t(_PrintableStructure):
    _fields_ = [
        # Moved to the new busId location below
        ("busIdLegacy", c_char * NVML_DEVICE_PCI_BUS_ID_BUFFER_V2_SIZE),
        ("domain", c_uint),
        ("bus", c_uint),
        ("device", c_uint),
        ("pciDeviceId", c_uint),
        # Added in 2.285
        ("pciSubSystemId", c_uint),
        # New busId replaced the long deprecated and reserved fields with a
        # field of the same size in 9.0
        ("busId", c_char * NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE),
    ]
    _fmt_ = {
        "domain": "0x%08X",
        "bus": "0x%02X",
        "device": "0x%02X",
        "pciDeviceId": "0x%08X",
        "pciSubSystemId": "0x%08X",
    }


class c_nvmlSystemDriverBranchInfo_v1_t(_PrintableStructure):
    _fields_ = [
        ("version", c_uint),
        ("branch", c_char * NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE),
    ]


SystemDriverBranchInfo_v1 = 0x1000054


class c_nvmlExcludedDeviceInfo_t(_PrintableStructure):
    _fields_ = [("pci", nvmlPciInfo_t), ("uuid", c_char * NVML_DEVICE_UUID_BUFFER_SIZE)]


class nvmlNvLinkUtilizationControl_t(_PrintableStructure):
    _fields_ = [
        ("units", _nvmlNvLinkUtilizationCountUnits_t),
        ("pktfilter", _nvmlNvLinkUtilizationCountPktTypes_t),
    ]


class c_nvmlMemory_t(_PrintableStructure):
    _fields_ = [
        ("total", c_ulonglong),
        ("free", c_ulonglong),
        ("used", c_ulonglong),
    ]
    # "" is the default-format key: every field prints as "<n> B".
    _fmt_ = {"": "%d B"}


class c_nvmlMemory_v2_t(_PrintableStructure):
    _fields_ = [
        ("version", c_uint),
        ("total", c_ulonglong),
        ("reserved", c_ulonglong),
        ("free", c_ulonglong),
        ("used", c_ulonglong),
    ]
    _fmt_ = {"": "%d B"}


nvmlMemory_v2 = 0x02000028


class c_nvmlBAR1Memory_t(_PrintableStructure):
    _fields_ = [
        ("bar1Total", c_ulonglong),
        ("bar1Free", c_ulonglong),
        ("bar1Used", c_ulonglong),
    ]
    _fmt_ = {"": "%d B"}


class nvmlClkMonFaultInfo_t(Structure):
    _fields_ = [("clkApiDomain", c_uint), ("clkDomainFaultMask", c_uint)]


# Length of the clkMonList array in nvmlClkMonStatus_t.
MAX_CLK_DOMAINS = 32


class nvmlClkMonStatus_t(Structure):
    _fields_ = [
        ("bGlobalStatus", c_uint),
        ("clkMonListSize", c_uint),
        ("clkMonList", nvmlClkMonFaultInfo_t * MAX_CLK_DOMAINS),
    ]


# On Windows with the WDDM driver, usedGpuMemory is reported as None
# Code that processes this structure should check for None, I.E.
#
# if (info.usedGpuMemory is None):
#     # TODO handle the error
#     pass
# else:
#    print("Using %d MiB of memory" % (info.usedGpuMemory / 1024 / 1024))
# endif
#
# See NVML documentation for more information
#
# NOTE: the classes below are ctypes layouts; field order and types define the
# binary layout passed to the library, so they must not be reordered.
class c_nvmlProcessInfo_v2_t(_PrintableStructure):
    _fields_ = [
        ("pid", c_uint),
        ("usedGpuMemory", c_ulonglong),
        ("gpuInstanceId", c_uint),
        ("computeInstanceId", c_uint),
    ]
    # usedGpuMemory prints as "<n> B"; other fields use the default "%s".
    _fmt_ = {"usedGpuMemory": "%d B"}


# v3 and the unversioned name are aliases of the v2 layout.
c_nvmlProcessInfo_v3_t = c_nvmlProcessInfo_v2_t

c_nvmlProcessInfo_t = c_nvmlProcessInfo_v3_t

_nvmlProcessMode_t = c_uint
NVML_PROCESS_MODE_COMPUTE = 0
NVML_PROCESS_MODE_GRAPHICS = 1
NVML_PROCESS_MODE_MPS = 2


# Same leading fields as c_nvmlProcessInfo_v2_t plus usedGpuCcProtectedMemory.
# Plain Structure, so it has no custom __str__ or bytes/str conversion.
class c_nvmlProcessDetail_v1_t(Structure):
    _fields_ = [
        ("pid", c_uint),
        ("usedGpuMemory", c_ulonglong),
        ("gpuInstanceId", c_uint),
        ("computeInstanceId", c_uint),
        ("usedGpuCcProtectedMemory", c_ulonglong),
    ]


# Versioned list header; procArray points at c_nvmlProcessDetail_v1_t entries.
class c_nvmlProcessDetailList_v1_t(_PrintableStructure):
    _fields_ = [
        ("version", c_uint),
        ("mode", _nvmlProcessMode_t),
        ("numProcArrayEntries", c_uint),
        ("procArray", POINTER(c_nvmlProcessDetail_v1_t)),
    ]
    # NOTE(review): "%d B" prints an entry count with a byte suffix -- looks
    # like a copy/paste from usedGpuMemory above; confirm before changing.
    _fmt_ = {"numProcArrayEntries": "%d B"}


c_nvmlProcessDetailList_t = c_nvmlProcessDetailList_v1_t

# presumably the value callers store in `version` for the v1 list -- confirm at call sites
nvmlProcessDetailList_v1 = 0x1000018


class c_nvmlBridgeChipInfo_t(_PrintableStructure):
    _fields_ = [
        ("type", _nvmlBridgeChipType_t),
        ("fwVersion", c_uint),
    ]


# bridgeCount plus a fixed array of 128 bridge-chip entries.
class c_nvmlBridgeChipHierarchy_t(_PrintableStructure):
    _fields_ = [
        ("bridgeCount", c_uint),
        ("bridgeChipInfo", c_nvmlBridgeChipInfo_t * 128),
    ]


class c_nvmlEccErrorCounts_t(_PrintableStructure):
    _fields_ = [
        ("l1Cache", c_ulonglong),
        ("l2Cache", c_ulonglong),
        ("deviceMemory", c_ulonglong),
        ("registerFile", c_ulonglong),
    ]


class c_nvmlUtilization_t(_PrintableStructure):
    _fields_ = [
        ("gpu", c_uint),
        ("memory", c_uint),
    ]
    # "" is the default-format key: both fields print as "<n> %".
    _fmt_ = {"": "%d %%"}


# Added in 2.285
class c_nvmlHwbcEntry_t(_PrintableStructure):
    _fields_ = [
        ("hwbcId", c_uint),
        ("firmwareVersion", c_char * 32),
    ]


# Union: all members share the same storage. Which member is meaningful is
# presumably indicated by an accompanying value-type field -- confirm at call sites.
class c_nvmlValue_t(Union):
    _fields_ = [
        ("dVal", c_double),
        ("uiVal", c_uint),
        ("ulVal", c_ulong),
        ("ullVal", c_ulonglong),
        ("sllVal", c_longlong),
        ("siVal", c_int),
        ("usVal", c_ushort),
    ]


class c_nvmlSample_t(_PrintableStructure):
    _fields_ = [
        ("timeStamp", c_ulonglong),
        ("sampleValue", c_nvmlValue_t),
    ]


class c_nvmlViolationTime_t(_PrintableStructure):
    _fields_ = [
        ("referenceTime", c_ulonglong),
        ("violationTime", c_ulonglong),
    ]


# One field-value record: IDs, timing, a value-type tag, a return code and the value union.
class c_nvmlFieldValue_t(_PrintableStructure):
    _fields_ = [
        ("fieldId", c_uint32),
        ("scopeId", c_uint32),
        ("timestamp", c_int64),
        ("latencyUsec", c_int64),
        ("valueType", _nvmlValueType_t),
        ("nvmlReturn", _nvmlReturn_t),
        ("value", c_nvmlValue_t),
    ]


# Length of the bwModes array below.
NVML_NVLINK_TOTAL_SUPPORTED_BW_MODES = 23

nvmlNvlinkSupportedBwModes_v1 = 0x100001C


class c_nvmlNvlinkSupportedBwModes_v1_t(_PrintableStructure):
    _fields_ = [
        ("version", c_uint),
        ("bwModes", c_uint8 * NVML_NVLINK_TOTAL_SUPPORTED_BW_MODES),
        ("totalBwModes", c_uint8),
    ]

    def __init__(self):
        # Pre-fill the version field with this struct's version tag.
        super(c_nvmlNvlinkSupportedBwModes_v1_t, self).__init__(
            version=nvmlNvlinkSupportedBwModes_v1
        )


nvmlNvlinkGetBwMode_v1 = 0x100000C


class c_nvmlNvlinkGetBwMode_v1_t(_PrintableStructure):
    _fields_ = [("version", c_uint), ("bIsBest", c_uint), ("bwMode", c_uint8)]

    def __init__(self):
        # Pre-fill the version field with this struct's version tag.
        super(c_nvmlNvlinkGetBwMode_v1_t, self).__init__(version=nvmlNvlinkGetBwMode_v1)


nvmlNvlinkSetBwMode_v1 = 0x100000C


class c_nvmlNvlinkSetBwMode_v1_t(_PrintableStructure):
    _fields_ = [("version", c_uint), ("bSetBest", c_uint), ("bwMode", c_uint8)]

    def __init__(self):
        # Pre-fill the version field with this struct's version tag.
        super(c_nvmlNvlinkSetBwMode_v1_t, self).__init__(version=nvmlNvlinkSetBwMode_v1)


# The vGPU structs below do not set `version` themselves; presumably callers
# assign the matching *_v1 / *_v2 tag before use -- confirm at call sites.
class c_nvmlVgpuHeterogeneousMode_v1_t(_PrintableStructure):
    _fields_ = [
        ("version", c_uint),
        ("mode", c_uint),
    ]


VgpuHeterogeneousMode_v1 = 0x1000008


class c_nvmlVgpuPlacementId_v1_t(_PrintableStructure):
    _fields_ = [
        ("version", c_uint),
        ("placementId", c_uint),
    ]


VgpuPlacementId_v1 = 0x1000008


class c_nvmlVgpuPlacementList_v1_t(_PrintableStructure):
    _fields_ = [
        ("version", c_uint),
        ("count", c_uint),
        ("placementSize", c_uint),
        ("placementIds", POINTER(c_uint)),
    ]


VgpuPlacementList_v1 = 0x1000018

NVML_VGPU_PGPU_HETEROGENEOUS_MODE = 0
NVML_VGPU_PGPU_HOMOGENEOUS_MODE = 1


# v2 swaps the order of count/placementSize relative to v1 and appends `mode`.
class c_nvmlVgpuPlacementList_v2_t(_PrintableStructure):
    _fields_ = [
        ("version", c_uint),
        ("placementSize", c_uint),
        ("count", c_uint),
        ("placementIds", POINTER(c_uint)),
        ("mode", c_uint),
    ]


VgpuPlacementList_v2 = 0x2000020


class c_nvmlVgpuTypeBar1Info_v1_t(_PrintableStructure):
    _fields_ = [
        ("version", c_uint),
        ("bar1Size", c_ulonglong),
    ]


VgpuTypeBar1Info_v1 = 0x1000010


class c_nvmlVgpuInstanceUtilizationSample_t(_PrintableStructure):
    _fields_ = [
        ("vgpuInstance", _nvmlVgpuInstance_t),
        ("timeStamp", c_ulonglong),
        ("smUtil", c_nvmlValue_t),
        ("memUtil", c_nvmlValue_t),
        ("encUtil", c_nvmlValue_t),
        ("decUtil", c_nvmlValue_t),
    ]


# Info_v1 puts timeStamp before vgpuInstance and adds jpgUtil/ofaUtil.
class c_nvmlVgpuInstanceUtilizationInfo_v1_t(_PrintableStructure):
    _fields_ = [
        ("timeStamp", c_ulonglong),
        ("vgpuInstance", _nvmlVgpuInstance_t),
        ("smUtil", c_nvmlValue_t),
        ("memUtil", c_nvmlValue_t),
        ("encUtil", c_nvmlValue_t),
        ("decUtil", c_nvmlValue_t),
        ("jpgUtil", c_nvmlValue_t),
        ("ofaUtil", c_nvmlValue_t),
    ]


# Versioned header; vgpuUtilArray points at c_nvmlVgpuInstanceUtilizationInfo_v1_t entries.
class c_nvmlVgpuInstancesUtilizationInfo_v1_t(_PrintableStructure):
    _fields_ = [
        ("version", c_uint),
        ("sampleValType", _nvmlValueType_t),
        ("vgpuInstanceCount", c_uint),
        ("lastSeenTimeStamp", c_ulonglong),
        ("vgpuUtilArray", POINTER(c_nvmlVgpuInstanceUtilizationInfo_v1_t)),
    ]


VgpuInstancesUtilizationInfo_v1 = 0x01000020


class c_nvmlVgpuProcessUtilizationSample_t(_PrintableStructure):
    _fields_ = [
        ("vgpuInstance", _nvmlVgpuInstance_t),
        ("pid", c_uint),
        ("processName", c_char * NVML_VGPU_NAME_BUFFER_SIZE),
        ("timeStamp", c_ulonglong),
        ("smUtil", c_uint),
        ("memUtil", c_uint),
        ("encUtil", c_uint),
        ("decUtil", c_uint),
    ]


# Info_v1 leads with processName/timeStamp and adds jpgUtil/ofaUtil.
class c_nvmlVgpuProcessUtilizationInfo_v1_t(_PrintableStructure):
    _fields_ = [
        ("processName", c_char * NVML_VGPU_NAME_BUFFER_SIZE),
        ("timeStamp", c_ulonglong),
        ("vgpuInstance", _nvmlVgpuInstance_t),
        ("pid", c_uint),
        ("smUtil", c_uint),
        ("memUtil", c_uint),
        ("encUtil", c_uint),
        ("decUtil", c_uint),
        ("jpgUtil", c_uint),
        ("ofaUtil", c_uint),
    ]


# Versioned header; vgpuProcUtilArray points at c_nvmlVgpuProcessUtilizationInfo_v1_t entries.
class c_nvmlVgpuProcessesUtilizationInfo_v1_t(_PrintableStructure):
    _fields_ = [
        ("version", c_uint),
        ("vgpuProcessCount", c_uint),
        ("lastSeenTimeStamp", c_ulonglong),
        ("vgpuProcUtilArray", POINTER(c_nvmlVgpuProcessUtilizationInfo_v1_t)),
    ]


VgpuProcessesUtilizationInfo_v1 = 0x01000018


class nvmlVgpuRuntimeState_v1_t(_PrintableStructure):
    _fields_ = [
        ("version", c_uint),
        ("size", c_ulonglong),
    ]


VgpuRuntimeState_v1 = 0x1000010


# Date/time components plus a status byte.
class c_nvmlVgpuLicenseExpiry_t(_PrintableStructure):
    _fields_ = [
        ("year", c_uint32),
        ("month", c_uint16),
        ("day", c_uint16),
        ("hour", c_uint16),
        ("min", c_uint16),
        ("sec", c_uint16),
        ("status", c_uint8),
    ]


NVML_GRID_LICENSE_STATE_UNKNOWN = 0
NVML_GRID_LICENSE_STATE_UNINITIALIZED = 1
NVML_GRID_LICENSE_STATE_UNLICENSED_UNRESTRICTED = 2
NVML_GRID_LICENSE_STATE_UNLICENSED_RESTRICTED = 3
NVML_GRID_LICENSE_STATE_UNLICENSED = 4
NVML_GRID_LICENSE_STATE_LICENSED = 5


class c_nvmlVgpuLicenseInfo_t(_PrintableStructure):
    _fields_ = [
        ("isLicensed", c_uint8),
        ("licenseExpiry", c_nvmlVgpuLicenseExpiry_t),
        ("currentState", c_uint),
    ]


class c_nvmlEncoderSession_t(_PrintableStructure):
    _fields_ = [
        ("sessionId", c_uint),
        ("pid", c_uint),
        ("vgpuInstance", _nvmlVgpuInstance_t),
        ("codecType", c_uint),
        ("hResolution", c_uint),
        ("vResolution", c_uint),
        ("averageFps", c_uint),
        ("encodeLatency", c_uint),
    ]
+class c_nvmlProcessUtilizationSample_t(_PrintableStructure): + _fields_ = [ + ("pid", c_uint), + ("timeStamp", c_ulonglong), + ("smUtil", c_uint), + ("memUtil", c_uint), + ("encUtil", c_uint), + ("decUtil", c_uint), + ] + + +class c_nvmlProcessUtilizationInfo_v1_t(_PrintableStructure): + _fields_ = [ + ("timeStamp", c_ulonglong), + ("pid", c_uint), + ("smUtil", c_uint), + ("memUtil", c_uint), + ("encUtil", c_uint), + ("decUtil", c_uint), + ("jpgUtil", c_uint), + ("ofaUtil", c_uint), + ] + + +class c_nvmlProcessesUtilizationInfo_v1_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("processSamplesCount", c_uint), + ("lastSeenTimeStamp", c_ulonglong), + ("procUtilArray", POINTER(c_nvmlProcessUtilizationInfo_v1_t)), + ] + + +ProcessesUtilizationInfo_v1 = 0x01000018 + + +class c_nvmlGridLicenseExpiry_t(_PrintableStructure): + _fields_ = [ + ("year", c_uint32), + ("month", c_uint16), + ("day", c_uint16), + ("hour", c_uint16), + ("min", c_uint16), + ("sec", c_uint16), + ("status", c_uint8), + ] + + +class c_nvmlGridLicensableFeature_v4_t(_PrintableStructure): + _fields_ = [ + ("featureCode", _nvmlGridLicenseFeatureCode_t), + ("featureState", c_uint), + ("licenseInfo", c_char * NVML_GRID_LICENSE_BUFFER_SIZE), + ("productName", c_char * NVML_GRID_LICENSE_BUFFER_SIZE), + ("featureEnabled", c_uint), + ("licenseExpiry", c_nvmlGridLicenseExpiry_t), + ] + + +class c_nvmlGridLicensableFeatures_v4_t(_PrintableStructure): + _fields_ = [ + ("isGridLicenseSupported", c_int), + ("licensableFeaturesCount", c_uint), + ( + "gridLicensableFeatures", + c_nvmlGridLicensableFeature_v4_t * NVML_GRID_LICENSE_FEATURE_MAX_COUNT, + ), + ] + + +class c_nvmlGridLicensableFeature_v3_t(_PrintableStructure): + _fields_ = [ + ("featureCode", _nvmlGridLicenseFeatureCode_t), + ("featureState", c_uint), + ("licenseInfo", c_char * NVML_GRID_LICENSE_BUFFER_SIZE), + ("productName", c_char * NVML_GRID_LICENSE_BUFFER_SIZE), + ("featureEnabled", c_uint), + ] + + +class 
c_nvmlGridLicensableFeatures_v3_t(_PrintableStructure): + _fields_ = [ + ("isGridLicenseSupported", c_int), + ("licensableFeaturesCount", c_uint), + ( + "gridLicensableFeatures", + c_nvmlGridLicensableFeature_v3_t * NVML_GRID_LICENSE_FEATURE_MAX_COUNT, + ), + ] + + +class c_nvmlGridLicensableFeature_v2_t(_PrintableStructure): + _fields_ = [ + ("featureCode", _nvmlGridLicenseFeatureCode_t), + ("featureState", c_uint), + ("licenseInfo", c_char * NVML_GRID_LICENSE_BUFFER_SIZE), + ("productName", c_char * NVML_GRID_LICENSE_BUFFER_SIZE), + ] + + +class c_nvmlGridLicensableFeatures_v2_t(_PrintableStructure): + _fields_ = [ + ("isGridLicenseSupported", c_int), + ("licensableFeaturesCount", c_uint), + ( + "gridLicensableFeatures", + c_nvmlGridLicensableFeature_v2_t * NVML_GRID_LICENSE_FEATURE_MAX_COUNT, + ), + ] + + +class c_nvmlGridLicensableFeature_t(_PrintableStructure): + _fields_ = [ + ("featureCode", _nvmlGridLicenseFeatureCode_t), + ("featureState", c_uint), + ("licenseInfo", c_char * NVML_GRID_LICENSE_BUFFER_SIZE), + ] + + +class c_nvmlGridLicensableFeatures_t(_PrintableStructure): + _fields_ = [ + ("isGridLicenseSupported", c_int), + ("licensableFeaturesCount", c_uint), + ( + "gridLicensableFeatures", + c_nvmlGridLicensableFeature_t * NVML_GRID_LICENSE_FEATURE_MAX_COUNT, + ), + ] + + +class c_nvmlMarginTemperature_v1_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("marginTemperature", c_int), + ] + + +nvmlMarginTemperature_v1 = 0x1000008 + + +## Event structures +class struct_c_nvmlEventSet_t(Structure): + pass # opaque handle + + +c_nvmlEventSet_t = POINTER(struct_c_nvmlEventSet_t) + +nvmlEventTypeSingleBitEccError = 0x0000000000000001 +nvmlEventTypeDoubleBitEccError = 0x0000000000000002 +nvmlEventTypePState = 0x0000000000000004 +nvmlEventTypeXidCriticalError = 0x0000000000000008 +nvmlEventTypeClock = 0x0000000000000010 +nvmlEventTypePowerSourceChange = 0x0000000000000080 +nvmlEventMigConfigChange = 0x0000000000000100 
+nvmlEventTypeSingleBitEccErrorStorm = 0x0000000000000200 +nvmlEventTypeDramRetirementEvent = 0x0000000000000400 +nvmlEventTypeDramRetirementFailure = 0x0000000000000800 +nvmlEventTypeNonFatalPoisonError = 0x0000000000001000 +nvmlEventTypeFatalPoisonError = 0x0000000000002000 +nvmlEventTypeGpuUnavailableError = 0x0000000000004000 +nvmlEventTypeGpuRecoveryAction = 0x0000000000008000 +nvmlEventTypeNone = 0x0000000000000000 +nvmlEventTypeAll = ( + nvmlEventTypeNone + | nvmlEventTypeSingleBitEccError + | nvmlEventTypeDoubleBitEccError + | nvmlEventTypePState + | nvmlEventTypeClock + | nvmlEventTypePowerSourceChange + | nvmlEventTypeXidCriticalError + | nvmlEventMigConfigChange + | nvmlEventTypeSingleBitEccErrorStorm + | nvmlEventTypeDramRetirementEvent + | nvmlEventTypeDramRetirementFailure + | nvmlEventTypeNonFatalPoisonError + | nvmlEventTypeFatalPoisonError + | nvmlEventTypeGpuUnavailableError + | nvmlEventTypeGpuRecoveryAction +) + +## Clock Event Reasons defines +nvmlClocksEventReasonGpuIdle = 0x0000000000000001 +nvmlClocksEventReasonApplicationsClocksSetting = 0x0000000000000002 +nvmlClocksEventReasonUserDefinedClocks = nvmlClocksEventReasonApplicationsClocksSetting # deprecated, use nvmlClocksEventReasonApplicationsClocksSetting +nvmlClocksEventReasonSwPowerCap = 0x0000000000000004 +nvmlClocksEventReasonHwSlowdown = 0x0000000000000008 +nvmlClocksEventReasonSyncBoost = 0x0000000000000010 +nvmlClocksEventReasonSwThermalSlowdown = 0x0000000000000020 +nvmlClocksEventReasonHwThermalSlowdown = 0x0000000000000040 +nvmlClocksEventReasonHwPowerBrakeSlowdown = 0x0000000000000080 +nvmlClocksEventReasonDisplayClockSetting = 0x0000000000000100 +nvmlClocksEventReasonNone = 0x0000000000000000 +nvmlClocksEventReasonAll = ( + nvmlClocksEventReasonNone + | nvmlClocksEventReasonGpuIdle + | nvmlClocksEventReasonApplicationsClocksSetting + | nvmlClocksEventReasonSwPowerCap + | nvmlClocksEventReasonHwSlowdown + | nvmlClocksEventReasonSyncBoost + | 
nvmlClocksEventReasonSwThermalSlowdown + | nvmlClocksEventReasonHwThermalSlowdown + | nvmlClocksEventReasonHwPowerBrakeSlowdown + | nvmlClocksEventReasonDisplayClockSetting +) + +## Following have been deprecated +nvmlClocksThrottleReasonGpuIdle = 0x0000000000000001 +nvmlClocksThrottleReasonApplicationsClocksSetting = 0x0000000000000002 +nvmlClocksThrottleReasonUserDefinedClocks = nvmlClocksThrottleReasonApplicationsClocksSetting # deprecated, use nvmlClocksThrottleReasonApplicationsClocksSetting +nvmlClocksThrottleReasonSwPowerCap = 0x0000000000000004 +nvmlClocksThrottleReasonHwSlowdown = 0x0000000000000008 +nvmlClocksThrottleReasonSyncBoost = 0x0000000000000010 +nvmlClocksThrottleReasonSwThermalSlowdown = 0x0000000000000020 +nvmlClocksThrottleReasonHwThermalSlowdown = 0x0000000000000040 +nvmlClocksThrottleReasonHwPowerBrakeSlowdown = 0x0000000000000080 +nvmlClocksThrottleReasonDisplayClockSetting = 0x0000000000000100 +nvmlClocksThrottleReasonNone = 0x0000000000000000 +nvmlClocksThrottleReasonAll = ( + nvmlClocksThrottleReasonNone + | nvmlClocksThrottleReasonGpuIdle + | nvmlClocksThrottleReasonApplicationsClocksSetting + | nvmlClocksThrottleReasonSwPowerCap + | nvmlClocksThrottleReasonHwSlowdown + | nvmlClocksThrottleReasonSyncBoost + | nvmlClocksThrottleReasonSwThermalSlowdown + | nvmlClocksThrottleReasonHwThermalSlowdown + | nvmlClocksThrottleReasonHwPowerBrakeSlowdown + | nvmlClocksThrottleReasonDisplayClockSetting +) + + +class c_nvmlEventData_t(_PrintableStructure): + _fields_ = [ + ("device", c_nvmlDevice_t), + ("eventType", c_ulonglong), + ("eventData", c_ulonglong), + ("gpuInstanceId", c_uint), + ("computeInstanceId", c_uint), + ] + _fmt_ = {"eventType": "0x%08X"} + + +class c_nvmlAccountingStats_t(_PrintableStructure): + _fields_ = [ + ("gpuUtilization", c_uint), + ("memoryUtilization", c_uint), + ("maxMemoryUsage", c_ulonglong), + ("time", c_ulonglong), + ("startTime", c_ulonglong), + ("isRunning", c_uint), + ("reserved", c_uint * 5), + ] + + +class 
c_nvmlVgpuVersion_t(Structure): + _fields_ = [("minVersion", c_uint), ("maxVersion", c_uint)] + + +class c_nvmlVgpuMetadata_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("revision", c_uint), + ("guestInfoState", _nvmlVgpuGuestInfoState_t), + ("guestDriverVersion", c_char * NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE), + ("hostDriverVersion", c_char * NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE), + ("reserved", c_uint * 6), + ("vgpuVirtualizationCaps", c_uint), + ("guestVgpuVersion", c_uint), + ("opaqueDataSize", c_uint), + ("opaqueData", c_char * NVML_VGPU_METADATA_OPAQUE_DATA_SIZE), + ] + + +class c_nvmlVgpuPgpuMetadata_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("revision", c_uint), + ("hostDriverVersion", c_char * NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE), + ("pgpuVirtualizationCaps", c_uint), + ("reserved", c_uint * 5), + ("hostSupportedVgpuRange", c_nvmlVgpuVersion_t), + ("opaqueDataSize", c_uint), + ("opaqueData", c_char * NVML_VGPU_PGPU_METADATA_OPAQUE_DATA_SIZE), + ] + + +class c_nvmlVgpuPgpuCompatibility_t(Structure): + _fields_ = [ + ("vgpuVmCompatibility", _nvmlVgpuVmCompatibility_t), + ("compatibilityLimitCode", _nvmlVgpuPgpuCompatibilityLimitCode_t), + ] + + +## vGPU scheduler policy defines +NVML_VGPU_SCHEDULER_POLICY_UNKNOWN = 0 +NVML_VGPU_SCHEDULER_POLICY_BEST_EFFORT = 1 +NVML_VGPU_SCHEDULER_POLICY_EQUAL_SHARE = 2 +NVML_VGPU_SCHEDULER_POLICY_FIXED_SHARE = 3 + +## Supported vGPU scheduler policy count +NVML_SUPPORTED_VGPU_SCHEDULER_POLICY_COUNT = 3 + +NVML_SCHEDULER_SW_MAX_LOG_ENTRIES = 200 + +NVML_VGPU_SCHEDULER_ARR_DEFAULT = 0 +NVML_VGPU_SCHEDULER_ARR_DISABLE = 1 +NVML_VGPU_SCHEDULER_ARR_ENABLE = 2 + + +class c_nvmlVgpuSchedDataWithARR_t(_PrintableStructure): + _fields_ = [ + ("avgFactor", c_uint), + ("timeslice", c_uint), + ] + + +class c_nvmlVgpuSchedData_t(_PrintableStructure): + _fields_ = [ + ("timeslice", c_uint), + ] + + +class c_nvmlVgpuSchedulerParams_t(Union): + _fields_ = [ + ("vgpuSchedDataWithARR", 
c_nvmlVgpuSchedDataWithARR_t), + ("vgpuSchedData", c_nvmlVgpuSchedData_t), + ] + + +class c_nvmlVgpuSchedulerLogEntry_t(_PrintableStructure): + _fields_ = [ + ("timestamp", c_ulonglong), + ("timeRunTotal", c_ulonglong), + ("timeRun", c_ulonglong), + ("swRunlistId", c_uint), + ("targetTimeSlice", c_ulonglong), + ("cumulativePreemptionTime", c_ulonglong), + ] + + +class c_nvmlVgpuSchedulerLog_t(_PrintableStructure): + _fields_ = [ + ("engineId", c_uint), + ("schedulerPolicy", c_uint), + ("arrMode", c_uint), + ("schedulerParams", c_nvmlVgpuSchedulerParams_t), + ("entriesCount", c_uint), + ( + "logEntries", + c_nvmlVgpuSchedulerLogEntry_t * NVML_SCHEDULER_SW_MAX_LOG_ENTRIES, + ), + ] + + +class c_nvmlVgpuSchedulerGetState_t(_PrintableStructure): + _fields_ = [ + ("schedulerPolicy", c_uint), + ("arrMode", c_uint), + ("schedulerParams", c_nvmlVgpuSchedulerParams_t), + ] + + +class c_nvmlVgpuSchedSetDataWithARR_t(_PrintableStructure): + _fields_ = [ + ("avgFactor", c_uint), + ("frequency", c_uint), + ] + + +class c_nvmlVgpuSchedSetData_t(_PrintableStructure): + _fields_ = [ + ("timeslice", c_uint), + ] + + +class c_nvmlVgpuSchedulerSetParams_t(Union): + _fields_ = [ + ("vgpuSchedDataWithARR", c_nvmlVgpuSchedSetDataWithARR_t), + ("vgpuSchedData", c_nvmlVgpuSchedSetData_t), + ] + + +class c_nvmlVgpuSchedulerSetState_t(_PrintableStructure): + _fields_ = [ + ("schedulerPolicy", c_uint), + ("enableARRMode", c_uint), + ("schedulerParams", c_nvmlVgpuSchedulerSetParams_t), + ] + + +class c_nvmlVgpuSchedulerCapabilities_t(_PrintableStructure): + _fields_ = [ + ("supportedSchedulers", c_uint * NVML_SUPPORTED_VGPU_SCHEDULER_POLICY_COUNT), + ("maxTimeslice", c_uint), + ("minTimeslice", c_uint), + ("isArrModeSupported", c_uint), + ("maxFrequencyForARR", c_uint), + ("minFrequencyForARR", c_uint), + ("maxAvgFactorForARR", c_uint), + ("minAvgFactorForARR", c_uint), + ] + + +class c_nvmlFBCStats_t(Structure): + _fields_ = [ + ("sessionsCount", c_uint), + ("averageFPS", c_uint), + 
("averageLatency", c_uint), + ] + + +class c_nvmlFBCSession_t(_PrintableStructure): + _fields_ = [ + ("sessionId", c_uint), + ("pid", c_uint), + ("vgpuInstance", _nvmlVgpuInstance_t), + ("displayOrdinal", c_uint), + ("sessionType", c_uint), + ("sessionFlags", c_uint), + ("hMaxResolution", c_uint), + ("vMaxResolution", c_uint), + ("hResolution", c_uint), + ("vResolution", c_uint), + ("averageFPS", c_uint), + ("averageLatency", c_uint), + ] + + +NVML_DEVICE_MIG_DISABLE = 0x0 +NVML_DEVICE_MIG_ENABLE = 0x1 + +NVML_GPU_INSTANCE_PROFILE_1_SLICE = 0x0 +NVML_GPU_INSTANCE_PROFILE_2_SLICE = 0x1 +NVML_GPU_INSTANCE_PROFILE_3_SLICE = 0x2 +NVML_GPU_INSTANCE_PROFILE_4_SLICE = 0x3 +NVML_GPU_INSTANCE_PROFILE_7_SLICE = 0x4 +NVML_GPU_INSTANCE_PROFILE_8_SLICE = 0x5 +NVML_GPU_INSTANCE_PROFILE_6_SLICE = 0x6 +NVML_GPU_INSTANCE_PROFILE_1_SLICE_REV1 = 0x7 +NVML_GPU_INSTANCE_PROFILE_2_SLICE_REV1 = 0x8 +NVML_GPU_INSTANCE_PROFILE_1_SLICE_REV2 = 0x9 +NVML_GPU_INSTANCE_PROFILE_1_SLICE_GFX = 0xA +NVML_GPU_INSTANCE_PROFILE_2_SLICE_GFX = 0xB +NVML_GPU_INSTANCE_PROFILE_4_SLICE_GFX = 0xC +NVML_GPU_INSTANCE_PROFILE_COUNT = 0xD + + +class c_nvmlGpuInstancePlacement_t(Structure): + _fields_ = [("start", c_uint), ("size", c_uint)] + + +class c_nvmlGpuInstanceProfileInfo_t(Structure): + _fields_ = [ + ("id", c_uint), + ("isP2pSupported", c_uint), + ("sliceCount", c_uint), + ("instanceCount", c_uint), + ("multiprocessorCount", c_uint), + ("copyEngineCount", c_uint), + ("decoderCount", c_uint), + ("encoderCount", c_uint), + ("jpegCount", c_uint), + ("ofaCount", c_uint), + ("memorySizeMB", c_ulonglong), + ] + + +nvmlGpuInstanceProfileInfo_v2 = 0x02000098 + + +class c_nvmlGpuInstanceProfileInfo_v2_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("id", c_uint), + ("isP2pSupported", c_uint), + ("sliceCount", c_uint), + ("instanceCount", c_uint), + ("multiprocessorCount", c_uint), + ("copyEngineCount", c_uint), + ("decoderCount", c_uint), + ("encoderCount", c_uint), + ("jpegCount", c_uint), + 
("ofaCount", c_uint), + ("memorySizeMB", c_ulonglong), + ("name", c_char * NVML_DEVICE_NAME_V2_BUFFER_SIZE), + ] + + def __init__(self): + super(c_nvmlGpuInstanceProfileInfo_v2_t, self).__init__( + version=nvmlGpuInstanceProfileInfo_v2 + ) + + +class c_nvmlGpuInstanceInfo_t(Structure): + _fields_ = [ + ("device", c_nvmlDevice_t), + ("id", c_uint), + ("profileId", c_uint), + ("placement", c_nvmlGpuInstancePlacement_t), + ] + + +class struct_c_nvmlGpuInstance_t(Structure): + pass # opaque handle + + +c_nvmlGpuInstance_t = POINTER(struct_c_nvmlGpuInstance_t) + +NVML_COMPUTE_INSTANCE_PROFILE_1_SLICE = 0x0 +NVML_COMPUTE_INSTANCE_PROFILE_2_SLICE = 0x1 +NVML_COMPUTE_INSTANCE_PROFILE_3_SLICE = 0x2 +NVML_COMPUTE_INSTANCE_PROFILE_4_SLICE = 0x3 +NVML_COMPUTE_INSTANCE_PROFILE_7_SLICE = 0x4 +NVML_COMPUTE_INSTANCE_PROFILE_8_SLICE = 0x5 +NVML_COMPUTE_INSTANCE_PROFILE_6_SLICE = 0x6 +NVML_COMPUTE_INSTANCE_PROFILE_1_SLICE_REV1 = 0x7 +NVML_COMPUTE_INSTANCE_PROFILE_COUNT = 0x8 + +NVML_COMPUTE_INSTANCE_ENGINE_PROFILE_SHARED = 0x0 +NVML_COMPUTE_INSTANCE_ENGINE_PROFILE_COUNT = 0x1 + + +class c_nvmlComputeInstancePlacement_t(Structure): + _fields_ = [("start", c_uint), ("size", c_uint)] + + +class c_nvmlComputeInstanceProfileInfo_t(Structure): + _fields_ = [ + ("id", c_uint), + ("sliceCount", c_uint), + ("instanceCount", c_uint), + ("multiprocessorCount", c_uint), + ("sharedCopyEngineCount", c_uint), + ("sharedDecoderCount", c_uint), + ("sharedEncoderCount", c_uint), + ("sharedJpegCount", c_uint), + ("sharedOfaCount", c_uint), + ] + + +nvmlComputeInstanceProfileInfo_v2 = 0x02000088 + + +class c_nvmlComputeInstanceProfileInfo_v2_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("id", c_uint), + ("sliceCount", c_uint), + ("instanceCount", c_uint), + ("multiprocessorCount", c_uint), + ("sharedCopyEngineCount", c_uint), + ("sharedDecoderCount", c_uint), + ("sharedEncoderCount", c_uint), + ("sharedJpegCount", c_uint), + ("sharedOfaCount", c_uint), + ("name", c_char * 
NVML_DEVICE_NAME_V2_BUFFER_SIZE), + ] + + def __init__(self): + super(c_nvmlComputeInstanceProfileInfo_v2_t, self).__init__( + version=nvmlComputeInstanceProfileInfo_v2 + ) + + +class c_nvmlComputeInstanceInfo_t(Structure): + _fields_ = [ + ("device", c_nvmlDevice_t), + ("gpuInstance", c_nvmlGpuInstance_t), + ("id", c_uint), + ("profileId", c_uint), + ("placement", c_nvmlComputeInstancePlacement_t), + ] + + +NVML_MAX_GPU_UTILIZATIONS = 8 +NVML_GPU_UTILIZATION_DOMAIN_GPU = 0 +NVML_GPU_UTILIZATION_DOMAIN_FB = 1 +NVML_GPU_UTILIZATION_DOMAIN_VID = 2 +NVML_GPU_UTILIZATION_DOMAIN_BUS = 3 + + +class c_nvmlGpuDynamicPstatesUtilization_t(Structure): + _fields_ = [ + ("bIsPresent", c_uint, 1), + ("percentage", c_uint), + ("incThreshold", c_uint), + ("decThreshold", c_uint), + ] + + +class c_nvmlGpuDynamicPstatesInfo_t(Structure): + _fields_ = [ + ("flags", c_uint), + ( + "utilization", + c_nvmlGpuDynamicPstatesUtilization_t * NVML_MAX_GPU_UTILIZATIONS, + ), + ] + + +NVML_MAX_THERMAL_SENSORS_PER_GPU = 3 + +NVML_THERMAL_TARGET_NONE = 0 +NVML_THERMAL_TARGET_GPU = 1 +NVML_THERMAL_TARGET_MEMORY = 2 +NVML_THERMAL_TARGET_POWER_SUPPLY = 4 +NVML_THERMAL_TARGET_BOARD = 8 +NVML_THERMAL_TARGET_VCD_BOARD = 9 +NVML_THERMAL_TARGET_VCD_INLET = 10 +NVML_THERMAL_TARGET_VCD_OUTLET = 11 +NVML_THERMAL_TARGET_ALL = 15 +NVML_THERMAL_TARGET_UNKNOWN = -1 + +NVML_THERMAL_CONTROLLER_NONE = 0 +NVML_THERMAL_CONTROLLER_GPU_INTERNAL = 1 +NVML_THERMAL_CONTROLLER_ADM1032 = 2 +NVML_THERMAL_CONTROLLER_ADT7461 = 3 +NVML_THERMAL_CONTROLLER_MAX6649 = 4 +NVML_THERMAL_CONTROLLER_MAX1617 = 5 +NVML_THERMAL_CONTROLLER_LM99 = 6 +NVML_THERMAL_CONTROLLER_LM89 = 7 +NVML_THERMAL_CONTROLLER_LM64 = 8 +NVML_THERMAL_CONTROLLER_G781 = 9 +NVML_THERMAL_CONTROLLER_ADT7473 = 10 +NVML_THERMAL_CONTROLLER_SBMAX6649 = 11 +NVML_THERMAL_CONTROLLER_VBIOSEVT = 12 +NVML_THERMAL_CONTROLLER_OS = 13 +NVML_THERMAL_CONTROLLER_NVSYSCON_CANOAS = 14 +NVML_THERMAL_CONTROLLER_NVSYSCON_E551 = 15 +NVML_THERMAL_CONTROLLER_MAX6649R = 16 
+NVML_THERMAL_CONTROLLER_ADT7473S = 17 +NVML_THERMAL_CONTROLLER_UNKNOWN = -1 + + +class c_nvmlGpuThermalSensor_t(Structure): + _fields_ = [ + ("controller", c_int), + ("defaultMinTemp", c_int), + ("defaultMaxTemp", c_int), + ("currentTemp", c_int), + ("target", c_int), + ] + + +class c_nvmlGpuThermalSettings_t(Structure): + _fields_ = [ + ("count", c_uint), + ("sensor", c_nvmlGpuThermalSensor_t * NVML_MAX_THERMAL_SENSORS_PER_GPU), + ] + + +_nvmlCoolerControl_t = c_uint +NVML_THERMAL_COOLER_SIGNAL_NONE = 0 +NVML_THERMAL_COOLER_SIGNAL_TOGGLE = 1 +NVML_THERMAL_COOLER_SIGNAL_VARIABLE = 2 +NVML_THERMAL_COOLER_SIGNAL_COUNT = 3 + +_nvmlCoolerTarget_t = c_uint +NVML_THERMAL_COOLER_TARGET_NONE = 1 << 0 +NVML_THERMAL_COOLER_TARGET_GPU = 1 << 1 +NVML_THERMAL_COOLER_TARGET_MEMORY = 1 << 2 +NVML_THERMAL_COOLER_TARGET_POWER_SUPPLY = 1 << 3 +NVML_THERMAL_COOLER_TARGET_GPU_RELATED = ( + NVML_THERMAL_COOLER_TARGET_GPU + | NVML_THERMAL_COOLER_TARGET_MEMORY + | NVML_THERMAL_COOLER_TARGET_POWER_SUPPLY +) + + +class c_nvmlCoolerInfo_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("index", c_uint), + ("coolerControlType", _nvmlCoolerControl_t), + ("coolerTarget", _nvmlCoolerTarget_t), + ] + + +nvmlCoolerInfo_v1 = 0x1000010 + + +def nvmlDeviceGetCoolerInfo(handle): + c_coolerInfo = c_nvmlCoolerInfo_t() + c_coolerInfo.version = nvmlCoolerInfo_v1 + c_coolerInfo.index = 0 + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCoolerInfo") + ret = fn(handle, byref(c_coolerInfo)) + _nvmlCheckReturn(ret) + return [c_coolerInfo.coolerControlType, c_coolerInfo.coolerTarget] + + +class struct_c_nvmlComputeInstance_t(Structure): + pass # opaque handle + + +c_nvmlComputeInstance_t = POINTER(struct_c_nvmlComputeInstance_t) + + +class c_nvmlDeviceAttributes(Structure): + _fields_ = [ + ("multiprocessorCount", c_uint), + ("sharedCopyEngineCount", c_uint), + ("sharedDecoderCount", c_uint), + ("sharedEncoderCount", c_uint), + ("sharedJpegCount", c_uint), + ("sharedOfaCount", c_uint), + 
("gpuInstanceSliceCount", c_uint), + ("computeInstanceSliceCount", c_uint), + ("memorySizeMB", c_ulonglong), + ] + + +class c_nvmlRowRemapperHistogramValues(Structure): + _fields_ = [ + ("max", c_uint), + ("high", c_uint), + ("partial", c_uint), + ("low", c_uint), + ("none", c_uint), + ] + + +NVML_GPU_CERT_CHAIN_SIZE = 0x1000 +NVML_GPU_ATTESTATION_CERT_CHAIN_SIZE = 0x1400 +NVML_CC_GPU_CEC_NONCE_SIZE = 0x20 +NVML_CC_GPU_ATTESTATION_REPORT_SIZE = 0x2000 +NVML_CC_GPU_CEC_ATTESTATION_REPORT_SIZE = 0x1000 +NVML_CC_CEC_ATTESTATION_REPORT_NOT_PRESENT = 0 +NVML_CC_CEC_ATTESTATION_REPORT_PRESENT = 1 + + +class c_nvmlConfComputeSystemState_t(Structure): + _fields_ = [ + ("environment", c_uint), + ("ccFeature", c_uint), + ("devToolsMode", c_uint), + ] + + +nvmlSystemConfComputeSettings_v1 = 0x1000014 + + +class c_nvmlSystemConfComputeSettings_v1_t(Structure): + _fields_ = [ + ("version", c_uint), + ("environment", c_uint), + ("ccFeature", c_uint), + ("devToolsMode", c_uint), + ("multiGpuMode", c_uint), + ] + + def __init__(self): + super(c_nvmlSystemConfComputeSettings_v1_t, self).__init__( + version=nvmlSystemConfComputeSettings_v1 + ) + + +class c_nvmlConfComputeSystemCaps_t(Structure): + _fields_ = [ + ("cpuCaps", c_uint), + ("gpusCaps", c_uint), + ] + + +class c_nvmlConfComputeMemSizeInfo_t(Structure): + _fields_ = [ + ("protectedMemSizeKib", c_ulonglong), + ("unprotectedMemSizeKib", c_ulonglong), + ] + + +class c_nvmlConfComputeGpuCertificate_t(Structure): + _fields_ = [ + ("certChainSize", c_uint), + ("attestationCertChainSize", c_uint), + ("certChain", c_uint8 * NVML_GPU_CERT_CHAIN_SIZE), + ("attestationCertChain", c_uint8 * NVML_GPU_ATTESTATION_CERT_CHAIN_SIZE), + ] + + +class c_nvmlConfComputeGpuAttestationReport_t(Structure): + _fields_ = [ + ("isCecAttestationReportPresent", c_uint), + ("attestationReportSize", c_uint), + ("cecAttestationReportSize", c_uint), + ("nonce", c_uint8 * NVML_CC_GPU_CEC_NONCE_SIZE), + ("attestationReport", c_uint8 * 
NVML_CC_GPU_ATTESTATION_REPORT_SIZE), + ("cecAttestationReport", c_uint8 * NVML_CC_GPU_CEC_ATTESTATION_REPORT_SIZE), + ] + + +class c_nvmlConfComputeSetKeyRotationThresholdInfo_t(Structure): + _fields_ = [ + ("version", c_uint), + ("maxAttackerAdvantage", c_ulong), + ] + + +ConfComputeSetKeyRotationThresholdInfo_v1 = 0x1000010 + + +class c_nvmlConfComputeGetKeyRotationThresholdInfo_t(Structure): + _fields_ = [ + ("version", c_uint), + ("attackerAdvantage", c_ulong), + ] + + +ConfComputeGetKeyRotationThresholdInfo_v1 = 0x1000010 + + +## string/bytes conversion for ease of use +def convertStrBytes(func): + """ + In python 3, strings are unicode instead of bytes, and need to be converted for ctypes + Args from caller: (1, 'string', <__main__.c_nvmlDevice_t at 0xFFFFFFFF>) + Args passed to function: (1, b'string', <__main__.c_nvmlDevice_t at 0xFFFFFFFF)> + ---- + Returned from function: b'returned string' + Returned to caller: 'returned string' + """ + + @wraps(func) + def wrapper(*args, **kwargs): + # encoding a str returns bytes in python 2 and 3 + args = [arg.encode() if isinstance(arg, str) else arg for arg in args] + res = func(*args, **kwargs) + # In python 2, str and bytes are the same + # In python 3, str is unicode and should be decoded. + # Ctypes handles most conversions, this only effects c_char and char arrays. + if isinstance(res, bytes): + if isinstance(res, str): + return res + return res.decode() + return res + + if sys.version_info >= (3,): + return wrapper + return func + + +def throwOnVersionMismatch(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except NVMLError_FunctionNotFound: + raise NVMLLibraryMismatchError( + "Unversioned function called and the " + "pyNVML version does not match the NVML lib version. 
" + "Either use matching pyNVML and NVML lib versions or " + "use a versioned function such as " + func.__name__ + "_v2" + ) + + return wrapper + + +## C function wrappers ## +def nvmlInitWithFlags(flags): + _LoadNvmlLibrary() + + # + # Initialize the library + # + fn = _nvmlGetFunctionPointer("nvmlInitWithFlags") + ret = fn(flags) + _nvmlCheckReturn(ret) + + # Atomically update refcount + global _nvmlLib_refcount + libLoadLock.acquire() + _nvmlLib_refcount += 1 + libLoadLock.release() + return None + + +def nvmlInit(): + nvmlInitWithFlags(0) + return None + + +def _LoadNvmlLibrary(): + """ + Load the library if it isn't loaded already + """ + global nvmlLib + + if nvmlLib is None: + # lock to ensure only one caller loads the library + libLoadLock.acquire() + + try: + # ensure the library still isn't loaded + if nvmlLib is None: + try: + if sys.platform[:3] == "win": + # cdecl calling convention + try: + # Check for nvml.dll in System32 first for DCH drivers + nvmlLib = CDLL( + os.path.join( + os.getenv("WINDIR", "C:/Windows"), + "System32/nvml.dll", + ) + ) + except OSError as ose: + # If nvml.dll is not found in System32, it should be in ProgramFiles + # load nvml.dll from %ProgramFiles%/NVIDIA Corporation/NVSMI/nvml.dll + nvmlLib = CDLL( + os.path.join( + os.getenv("ProgramFiles", "C:/Program Files"), + "NVIDIA Corporation/NVSMI/nvml.dll", + ) + ) + else: + # assume linux + nvmlLib = CDLL("libnvidia-ml.so.1") + except OSError as ose: + _nvmlCheckReturn(NVML_ERROR_LIBRARY_NOT_FOUND) + if nvmlLib is None: + _nvmlCheckReturn(NVML_ERROR_LIBRARY_NOT_FOUND) + finally: + # lock is always freed + libLoadLock.release() + + +def nvmlShutdown(): + # + # Leave the library loaded, but shutdown the interface + # + fn = _nvmlGetFunctionPointer("nvmlShutdown") + ret = fn() + _nvmlCheckReturn(ret) + + # Atomically update refcount + global _nvmlLib_refcount + libLoadLock.acquire() + if 0 < _nvmlLib_refcount: + _nvmlLib_refcount -= 1 + libLoadLock.release() + return None + + +# 
Added in 2.285 +@convertStrBytes +def nvmlErrorString(result): + fn = _nvmlGetFunctionPointer("nvmlErrorString") + fn.restype = c_char_p # otherwise return is an int + ret = fn(result) + return ret + + +# Added in 2.285 +@convertStrBytes +def nvmlSystemGetNVMLVersion(): + c_version = create_string_buffer(NVML_SYSTEM_NVML_VERSION_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlSystemGetNVMLVersion") + ret = fn(c_version, c_uint(NVML_SYSTEM_NVML_VERSION_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_version.value + + +def nvmlSystemGetCudaDriverVersion(): + c_cuda_version = c_int() + fn = _nvmlGetFunctionPointer("nvmlSystemGetCudaDriverVersion") + ret = fn(byref(c_cuda_version)) + _nvmlCheckReturn(ret) + return c_cuda_version.value + + +def nvmlSystemGetCudaDriverVersion_v2(): + c_cuda_version = c_int() + fn = _nvmlGetFunctionPointer("nvmlSystemGetCudaDriverVersion_v2") + ret = fn(byref(c_cuda_version)) + _nvmlCheckReturn(ret) + return c_cuda_version.value + + +# Added in 2.285 +@convertStrBytes +def nvmlSystemGetProcessName(pid): + c_name = create_string_buffer(1024) + fn = _nvmlGetFunctionPointer("nvmlSystemGetProcessName") + ret = fn(c_uint(pid), c_name, c_uint(1024)) + _nvmlCheckReturn(ret) + return c_name.value + + +@convertStrBytes +def nvmlSystemGetDriverVersion(): + c_version = create_string_buffer(NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlSystemGetDriverVersion") + ret = fn(c_version, c_uint(NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_version.value + + +# Added in 2.285 +def nvmlSystemGetHicVersion(): + c_count = c_uint(0) + hics = None + fn = _nvmlGetFunctionPointer("nvmlSystemGetHicVersion") + + # get the count + ret = fn(byref(c_count), None) + + # this should only fail with insufficient size + if (ret != NVML_SUCCESS) and (ret != NVML_ERROR_INSUFFICIENT_SIZE): + raise NVMLError(ret) + + # If there are no hics + if c_count.value == 0: + return [] + + hic_array = c_nvmlHwbcEntry_t * 
c_count.value + hics = hic_array() + ret = fn(byref(c_count), hics) + _nvmlCheckReturn(ret) + return hics + + +def nvmlSystemGetDriverBranch(): + c_branchInfo = c_nvmlSystemDriverBranchInfo_v1_t(0) + c_branchInfo.version = SystemDriverBranchInfo_v1 + fn = _nvmlGetFunctionPointer("nvmlSystemGetDriverBranch") + ret = fn(byref(c_branchInfo), c_uint(NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_branchInfo + + +## Unit get functions +def nvmlUnitGetCount(): + c_count = c_uint() + fn = _nvmlGetFunctionPointer("nvmlUnitGetCount") + ret = fn(byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + + +def nvmlUnitGetHandleByIndex(index): + c_index = c_uint(index) + unit = c_nvmlUnit_t() + fn = _nvmlGetFunctionPointer("nvmlUnitGetHandleByIndex") + ret = fn(c_index, byref(unit)) + _nvmlCheckReturn(ret) + return unit + + +def nvmlUnitGetUnitInfo(unit): + c_info = c_nvmlUnitInfo_t() + fn = _nvmlGetFunctionPointer("nvmlUnitGetUnitInfo") + ret = fn(unit, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + + +def nvmlUnitGetLedState(unit): + c_state = c_nvmlLedState_t() + fn = _nvmlGetFunctionPointer("nvmlUnitGetLedState") + ret = fn(unit, byref(c_state)) + _nvmlCheckReturn(ret) + return c_state + + +def nvmlUnitGetPsuInfo(unit): + c_info = c_nvmlPSUInfo_t() + fn = _nvmlGetFunctionPointer("nvmlUnitGetPsuInfo") + ret = fn(unit, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + + +def nvmlUnitGetTemperature(unit, type): + c_temp = c_uint() + fn = _nvmlGetFunctionPointer("nvmlUnitGetTemperature") + ret = fn(unit, c_uint(type), byref(c_temp)) + _nvmlCheckReturn(ret) + return c_temp.value + + +def nvmlUnitGetFanSpeedInfo(unit): + c_speeds = c_nvmlUnitFanSpeeds_t() + fn = _nvmlGetFunctionPointer("nvmlUnitGetFanSpeedInfo") + ret = fn(unit, byref(c_speeds)) + _nvmlCheckReturn(ret) + return c_speeds + + +# added to API +def nvmlUnitGetDeviceCount(unit): + c_count = c_uint(0) + # query the unit to determine device count + fn = 
_nvmlGetFunctionPointer("nvmlUnitGetDevices") + ret = fn(unit, byref(c_count), None) + if ret == NVML_ERROR_INSUFFICIENT_SIZE: + ret = NVML_SUCCESS + _nvmlCheckReturn(ret) + return c_count.value + + +def nvmlUnitGetDevices(unit): + c_count = c_uint(nvmlUnitGetDeviceCount(unit)) + device_array = c_nvmlDevice_t * c_count.value + c_devices = device_array() + fn = _nvmlGetFunctionPointer("nvmlUnitGetDevices") + ret = fn(unit, byref(c_count), c_devices) + _nvmlCheckReturn(ret) + return c_devices + + +## Device get functions +def nvmlDeviceGetCount(): + c_count = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCount_v2") + ret = fn(byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + + +def nvmlDeviceGetHandleByIndex(index): + c_index = c_uint(index) + device = c_nvmlDevice_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetHandleByIndex_v2") + ret = fn(c_index, byref(device)) + _nvmlCheckReturn(ret) + return device + + +@convertStrBytes +def nvmlDeviceGetHandleBySerial(serial): + c_serial = c_char_p(serial) + device = c_nvmlDevice_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetHandleBySerial") + ret = fn(c_serial, byref(device)) + _nvmlCheckReturn(ret) + return device + + +@convertStrBytes +def nvmlDeviceGetHandleByUUID(uuid): + c_uuid = c_char_p(uuid) + device = c_nvmlDevice_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetHandleByUUID") + ret = fn(c_uuid, byref(device)) + _nvmlCheckReturn(ret) + return device + + +@convertStrBytes +def nvmlDeviceGetHandleByPciBusId(pciBusId): + c_busId = c_char_p(pciBusId) + device = c_nvmlDevice_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetHandleByPciBusId_v2") + ret = fn(c_busId, byref(device)) + _nvmlCheckReturn(ret) + return device + + +@convertStrBytes +def nvmlDeviceGetName(handle): + c_name = create_string_buffer(NVML_DEVICE_NAME_V2_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetName") + ret = fn(handle, c_name, c_uint(NVML_DEVICE_NAME_V2_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return 
c_name.value + + +class c_nvmlDevicePerfModes_v1_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("str", c_char * NVML_PERF_MODES_BUFFER_SIZE), + ] + + +nvmlDevicePerfModes_v1 = 0x1000804 + + +@convertStrBytes +def nvmlDeviceGetPerformanceModes(handle): + perfModes = c_nvmlDevicePerfModes_v1_t() + perfModes.version = nvmlDevicePerfModes_v1 + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPerformanceModes") + ret = fn(handle, byref(perfModes)) + _nvmlCheckReturn(ret) + return perfModes.str + + +class c_nvmlDeviceCurrentClockFreqs_v1_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("str", c_char * NVML_PERF_MODES_BUFFER_SIZE), + ] + + +nvmlDeviceCurrentClockFreqs_v1 = 0x1000804 + + +@convertStrBytes +def nvmlDeviceGetCurrentClockFreqs(handle): + currentClockFreqs = c_nvmlDeviceCurrentClockFreqs_v1_t() + currentClockFreqs.version = nvmlDeviceCurrentClockFreqs_v1 + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCurrentClockFreqs") + ret = fn(handle, byref(currentClockFreqs)) + _nvmlCheckReturn(ret) + return currentClockFreqs.str + + +def nvmlDeviceGetBoardId(handle): + c_id = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetBoardId") + ret = fn(handle, byref(c_id)) + _nvmlCheckReturn(ret) + return c_id.value + + +def nvmlDeviceGetMultiGpuBoard(handle): + c_multiGpu = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMultiGpuBoard") + ret = fn(handle, byref(c_multiGpu)) + _nvmlCheckReturn(ret) + return c_multiGpu.value + + +def nvmlDeviceGetBrand(handle): + c_type = _nvmlBrandType_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetBrand") + ret = fn(handle, byref(c_type)) + _nvmlCheckReturn(ret) + return c_type.value + + +def nvmlDeviceGetC2cModeInfoV1(handle): + c_info = c_nvmlC2cModeInfo_v1_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetC2cModeInfoV") + ret = fn(handle, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + + +def nvmlDeviceGetC2cModeInfoV(handle): + return nvmlDeviceGetC2cModeInfoV1(handle) + + +@convertStrBytes 
+def nvmlDeviceGetBoardPartNumber(handle): + c_part_number = create_string_buffer(NVML_DEVICE_PART_NUMBER_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetBoardPartNumber") + ret = fn(handle, c_part_number, c_uint(NVML_DEVICE_PART_NUMBER_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_part_number.value + + +@convertStrBytes +def nvmlDeviceGetSerial(handle): + c_serial = create_string_buffer(NVML_DEVICE_SERIAL_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSerial") + ret = fn(handle, c_serial, c_uint(NVML_DEVICE_SERIAL_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_serial.value + + +def nvmlDeviceGetModuleId(handle, moduleId=c_uint()): + isReference = type(moduleId) is not c_uint + moduleIdRef = moduleId if isReference else byref(moduleId) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetModuleId") + ret = fn(handle, moduleIdRef) + if isReference: + return ret + else: + _nvmlCheckReturn(ret) + return moduleId.value + + +def nvmlDeviceGetMemoryAffinity(handle, nodeSetSize, scope): + affinity_array = c_ulonglong * nodeSetSize + c_affinity = affinity_array() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMemoryAffinity") + ret = fn(handle, nodeSetSize, byref(c_affinity), _nvmlAffinityScope_t(scope)) + _nvmlCheckReturn(ret) + return c_affinity + + +def nvmlDeviceGetCpuAffinityWithinScope(handle, cpuSetSize, scope): + affinity_array = c_ulonglong * cpuSetSize + c_affinity = affinity_array() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCpuAffinityWithinScope") + ret = fn(handle, cpuSetSize, byref(c_affinity), _nvmlAffinityScope_t(scope)) + _nvmlCheckReturn(ret) + return c_affinity + + +def nvmlDeviceGetCpuAffinity(handle, cpuSetSize): + affinity_array = c_ulonglong * cpuSetSize + c_affinity = affinity_array() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCpuAffinity") + ret = fn(handle, cpuSetSize, byref(c_affinity)) + _nvmlCheckReturn(ret) + return c_affinity + + +def nvmlDeviceSetCpuAffinity(handle): + fn = 
_nvmlGetFunctionPointer("nvmlDeviceSetCpuAffinity") + ret = fn(handle) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceClearCpuAffinity(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceClearCpuAffinity") + ret = fn(handle) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceGetNumaNodeId(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNumaNodeId") + node = c_int() + ret = fn(handle, byref(node)) + _nvmlCheckReturn(ret) + return node.value + + +def nvmlDeviceGetMinorNumber(handle): + c_minor_number = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMinorNumber") + ret = fn(handle, byref(c_minor_number)) + _nvmlCheckReturn(ret) + return c_minor_number.value + + +@convertStrBytes +def nvmlDeviceGetUUID(handle): + c_uuid = create_string_buffer(NVML_DEVICE_UUID_V2_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetUUID") + ret = fn(handle, c_uuid, c_uint(NVML_DEVICE_UUID_V2_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_uuid.value + + +@convertStrBytes +def nvmlDeviceGetInforomVersion(handle, infoRomObject): + c_version = create_string_buffer(NVML_DEVICE_INFOROM_VERSION_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetInforomVersion") + ret = fn( + handle, + _nvmlInforomObject_t(infoRomObject), + c_version, + c_uint(NVML_DEVICE_INFOROM_VERSION_BUFFER_SIZE), + ) + _nvmlCheckReturn(ret) + return c_version.value + + +# Added in 4.304 +@convertStrBytes +def nvmlDeviceGetInforomImageVersion(handle): + c_version = create_string_buffer(NVML_DEVICE_INFOROM_VERSION_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetInforomImageVersion") + ret = fn(handle, c_version, c_uint(NVML_DEVICE_INFOROM_VERSION_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_version.value + + +# Added in 4.304 +def nvmlDeviceGetInforomConfigurationChecksum(handle): + c_checksum = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetInforomConfigurationChecksum") + ret = fn(handle, byref(c_checksum)) + _nvmlCheckReturn(ret) + return c_checksum.value + + 
+# Added in 4.304 +def nvmlDeviceValidateInforom(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceValidateInforom") + ret = fn(handle) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceGetLastBBXFlushTime(handle): + c_timestamp = c_ulonglong() + c_durationUs = c_ulong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetLastBBXFlushTime") + ret = fn(handle, byref(c_timestamp), byref(c_durationUs)) + _nvmlCheckReturn(ret) + return [c_timestamp.value, c_durationUs.value] + + +def nvmlDeviceGetDisplayMode(handle): + c_mode = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDisplayMode") + ret = fn(handle, byref(c_mode)) + _nvmlCheckReturn(ret) + return c_mode.value + + +def nvmlDeviceGetDisplayActive(handle): + c_mode = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDisplayActive") + ret = fn(handle, byref(c_mode)) + _nvmlCheckReturn(ret) + return c_mode.value + + +def nvmlDeviceGetPersistenceMode(handle): + c_state = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPersistenceMode") + ret = fn(handle, byref(c_state)) + _nvmlCheckReturn(ret) + return c_state.value + + +def nvmlDeviceGetPciInfoExt(handle, c_info): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPciInfoExt") + ret = fn(handle, c_info) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceGetPciInfo_v3(handle): + c_info = nvmlPciInfo_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPciInfo_v3") + ret = fn(handle, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + + +def nvmlDeviceGetPciInfo(handle): + return nvmlDeviceGetPciInfo_v3(handle) + + +def nvmlDeviceGetClockInfo(handle, type): + c_clock = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetClockInfo") + ret = fn(handle, _nvmlClockType_t(type), byref(c_clock)) + _nvmlCheckReturn(ret) + return c_clock.value + + +# Added in 2.285 +def nvmlDeviceGetMaxClockInfo(handle, type): + c_clock = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMaxClockInfo") + ret = fn(handle, 
_nvmlClockType_t(type), byref(c_clock)) + _nvmlCheckReturn(ret) + return c_clock.value + + +# Added in 4.304 +def nvmlDeviceGetApplicationsClock(handle, type): + c_clock = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetApplicationsClock") + ret = fn(handle, _nvmlClockType_t(type), byref(c_clock)) + _nvmlCheckReturn(ret) + return c_clock.value + + +def nvmlDeviceGetMaxCustomerBoostClock(handle, type): + c_clock = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMaxCustomerBoostClock") + ret = fn(handle, _nvmlClockType_t(type), byref(c_clock)) + _nvmlCheckReturn(ret) + return c_clock.value + + +def nvmlDeviceGetClock(handle, type, id): + c_clock = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetClock") + ret = fn(handle, _nvmlClockType_t(type), _nvmlClockId_t(id), byref(c_clock)) + _nvmlCheckReturn(ret) + return c_clock.value + + +# Added in 5.319 +def nvmlDeviceGetDefaultApplicationsClock(handle, type): + c_clock = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDefaultApplicationsClock") + ret = fn(handle, _nvmlClockType_t(type), byref(c_clock)) + _nvmlCheckReturn(ret) + return c_clock.value + + +# Added in 4.304 +def nvmlDeviceGetSupportedMemoryClocks(handle): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedMemoryClocks") + ret = fn(handle, byref(c_count), None) + + if ret == NVML_SUCCESS: + # special case, no clocks + return [] + elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + # typical case + clocks_array = c_uint * c_count.value + c_clocks = clocks_array() + + # make the call again + ret = fn(handle, byref(c_count), c_clocks) + _nvmlCheckReturn(ret) + + procs = [] + for i in range(c_count.value): + procs.append(c_clocks[i]) + + return procs + else: + # error case + raise NVMLError(ret) + + +# Added in 4.304 +def nvmlDeviceGetSupportedGraphicsClocks(handle, memoryClockMHz): + # first call to get the size + c_count = c_uint(0) + fn = 
_nvmlGetFunctionPointer("nvmlDeviceGetSupportedGraphicsClocks") + ret = fn(handle, c_uint(memoryClockMHz), byref(c_count), None) + + if ret == NVML_SUCCESS: + # special case, no clocks + return [] + elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + # typical case + clocks_array = c_uint * c_count.value + c_clocks = clocks_array() + + # make the call again + ret = fn(handle, c_uint(memoryClockMHz), byref(c_count), c_clocks) + _nvmlCheckReturn(ret) + + procs = [] + for i in range(c_count.value): + procs.append(c_clocks[i]) + + return procs + else: + # error case + raise NVMLError(ret) + + +def nvmlDeviceGetFanSpeed(handle): + c_speed = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetFanSpeed") + ret = fn(handle, byref(c_speed)) + _nvmlCheckReturn(ret) + return c_speed.value + + +def nvmlDeviceGetFanSpeed_v2(handle, fan): + c_speed = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetFanSpeed_v2") + ret = fn(handle, fan, byref(c_speed)) + _nvmlCheckReturn(ret) + return c_speed.value + + +class c_nvmlFanSpeedInfo_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("fan", c_uint), + ("speed", c_uint), + ] + + +nvmlFanSpeedInfo_v1 = 0x100000C + + +def nvmlDeviceGetFanSpeedRPM(handle): + c_fanSpeed = c_nvmlFanSpeedInfo_t() + c_fanSpeed.fan = 0 + c_fanSpeed.version = nvmlFanSpeedInfo_v1 + fn = _nvmlGetFunctionPointer("nvmlDeviceGetFanSpeedRPM") + ret = fn(handle, byref(c_fanSpeed)) + _nvmlCheckReturn(ret) + return c_fanSpeed.speed + + +def nvmlDeviceGetTargetFanSpeed(handle, fan): + c_speed = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTargetFanSpeed") + ret = fn(handle, fan, byref(c_speed)) + _nvmlCheckReturn(ret) + return c_speed.value + + +def nvmlDeviceGetNumFans(device): + c_numFans = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNumFans") + ret = fn(device, byref(c_numFans)) + _nvmlCheckReturn(ret) + return c_numFans.value + + +def nvmlDeviceSetDefaultFanSpeed_v2(handle, index): + fn = 
_nvmlGetFunctionPointer("nvmlDeviceSetDefaultFanSpeed_v2") + ret = fn(handle, index) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlDeviceGetMinMaxFanSpeed(handle, minSpeed=c_uint(), maxSpeed=c_uint()): + isReference = (type(minSpeed) is not c_uint) or (type(maxSpeed) is not c_uint) + minSpeedRef = minSpeed if isReference else byref(minSpeed) + maxSpeedRef = maxSpeed if isReference else byref(maxSpeed) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMinMaxFanSpeed") + ret = fn(handle, minSpeedRef, maxSpeedRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else [minSpeed.value, maxSpeed.value] + + +def nvmlDeviceGetFanControlPolicy_v2(handle, fan, fanControlPolicy=c_uint()): + isReference = type(fanControlPolicy) is not c_uint + fanControlPolicyRef = fanControlPolicy if isReference else byref(fanControlPolicy) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetFanControlPolicy_v2") + ret = fn(handle, fan, fanControlPolicyRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else fanControlPolicy.value + + +def nvmlDeviceSetFanControlPolicy(handle, fan, fanControlPolicy): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetFanControlPolicy") + ret = fn(handle, fan, _nvmlFanControlPolicy_t(fanControlPolicy)) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +class c_nvmlTemperature_v1_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("sensorType", _nvmlTemperatureSensors_t), + ("temperature", c_int), + ] + + +nvmlTemperature_v1 = 0x100000C + + +def nvmlDeviceGetTemperatureV1(handle, sensor): + c_temp = c_nvmlTemperature_v1_t() + c_temp.version = nvmlTemperature_v1 + c_temp.sensorType = _nvmlTemperatureSensors_t(sensor) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTemperatureV") + ret = fn(handle, byref(c_temp)) + _nvmlCheckReturn(ret) + return c_temp.temperature + + +def nvmlDeviceGetTemperatureV(handle, sensor, version=nvmlTemperature_v1): + if version == nvmlTemperature_v1: + return nvmlDeviceGetTemperatureV1(handle, sensor) + 
else: + raise NVMLError(NVML_ERROR_ARGUMENT_VERSION_MISMATCH) + + +# DEPRECATED use nvmlDeviceGetTemperatureV instead +def nvmlDeviceGetTemperature(handle, sensor): + c_temp = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTemperature") + ret = fn(handle, _nvmlTemperatureSensors_t(sensor), byref(c_temp)) + _nvmlCheckReturn(ret) + return c_temp.value + + +def nvmlDeviceGetTemperatureThreshold(handle, threshold): + c_temp = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTemperatureThreshold") + ret = fn(handle, _nvmlTemperatureThresholds_t(threshold), byref(c_temp)) + _nvmlCheckReturn(ret) + return c_temp.value + + +def nvmlDeviceSetTemperatureThreshold(handle, threshold, temp): + c_temp = c_uint() + c_temp.value = temp + fn = _nvmlGetFunctionPointer("nvmlDeviceSetTemperatureThreshold") + ret = fn(handle, _nvmlTemperatureThresholds_t(threshold), byref(c_temp)) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceGetMarginTemperature(handle): + c_marginTempInfo = c_nvmlMarginTemperature_v1_t() + c_marginTempInfo.version = nvmlMarginTemperature_v1 + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMarginTemperature") + ret = fn(handle, byref(c_marginTempInfo)) + _nvmlCheckReturn(ret) + return c_marginTempInfo.marginTemperature + + +# DEPRECATED use nvmlDeviceGetPerformanceState +def nvmlDeviceGetPowerState(handle): + c_pstate = _nvmlPstates_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerState") + ret = fn(handle, byref(c_pstate)) + _nvmlCheckReturn(ret) + return c_pstate.value + + +def nvmlDeviceGetPerformanceState(handle): + c_pstate = _nvmlPstates_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPerformanceState") + ret = fn(handle, byref(c_pstate)) + _nvmlCheckReturn(ret) + return c_pstate.value + + +def nvmlDeviceGetPowerManagementMode(handle): + c_pcapMode = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerManagementMode") + ret = fn(handle, byref(c_pcapMode)) + _nvmlCheckReturn(ret) + return c_pcapMode.value + + +def 
nvmlDeviceGetPowerManagementLimit(handle): + c_limit = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerManagementLimit") + ret = fn(handle, byref(c_limit)) + _nvmlCheckReturn(ret) + return c_limit.value + + +# Added in 4.304 +def nvmlDeviceGetPowerManagementLimitConstraints(handle): + c_minLimit = c_uint() + c_maxLimit = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerManagementLimitConstraints") + ret = fn(handle, byref(c_minLimit), byref(c_maxLimit)) + _nvmlCheckReturn(ret) + return [c_minLimit.value, c_maxLimit.value] + + +# Added in 4.304 +def nvmlDeviceGetPowerManagementDefaultLimit(handle): + c_limit = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerManagementDefaultLimit") + ret = fn(handle, byref(c_limit)) + _nvmlCheckReturn(ret) + return c_limit.value + + +# Added in 331 +def nvmlDeviceGetEnforcedPowerLimit(handle): + c_limit = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetEnforcedPowerLimit") + ret = fn(handle, byref(c_limit)) + _nvmlCheckReturn(ret) + return c_limit.value + + +def nvmlDeviceGetPowerUsage(handle): + c_watts = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerUsage") + ret = fn(handle, byref(c_watts)) + _nvmlCheckReturn(ret) + return c_watts.value + + +def nvmlDeviceGetTotalEnergyConsumption(handle): + c_millijoules = c_uint64() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTotalEnergyConsumption") + ret = fn(handle, byref(c_millijoules)) + _nvmlCheckReturn(ret) + return c_millijoules.value + + +# Added in 4.304 +def nvmlDeviceGetGpuOperationMode(handle): + c_currState = _nvmlGpuOperationMode_t() + c_pendingState = _nvmlGpuOperationMode_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuOperationMode") + ret = fn(handle, byref(c_currState), byref(c_pendingState)) + _nvmlCheckReturn(ret) + return [c_currState.value, c_pendingState.value] + + +# Added in 4.304 +def nvmlDeviceGetCurrentGpuOperationMode(handle): + return nvmlDeviceGetGpuOperationMode(handle)[0] + + +# Added in 4.304 +def 
nvmlDeviceGetPendingGpuOperationMode(handle): + return nvmlDeviceGetGpuOperationMode(handle)[1] + + +def nvmlDeviceGetMemoryInfo(handle, version=None): + if not version: + c_memory = c_nvmlMemory_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMemoryInfo") + else: + c_memory = c_nvmlMemory_v2_t() + c_memory.version = version + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMemoryInfo_v2") + ret = fn(handle, byref(c_memory)) + _nvmlCheckReturn(ret) + return c_memory + + +def nvmlDeviceGetBAR1MemoryInfo(handle): + c_bar1_memory = c_nvmlBAR1Memory_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetBAR1MemoryInfo") + ret = fn(handle, byref(c_bar1_memory)) + _nvmlCheckReturn(ret) + return c_bar1_memory + + +def nvmlDeviceGetComputeMode(handle): + c_mode = _nvmlComputeMode_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetComputeMode") + ret = fn(handle, byref(c_mode)) + _nvmlCheckReturn(ret) + return c_mode.value + + +def nvmlDeviceGetCudaComputeCapability(handle): + c_major = c_int() + c_minor = c_int() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCudaComputeCapability") + ret = fn(handle, byref(c_major), byref(c_minor)) + _nvmlCheckReturn(ret) + return (c_major.value, c_minor.value) + + +def nvmlDeviceGetEccMode(handle): + c_currState = _nvmlEnableState_t() + c_pendingState = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetEccMode") + ret = fn(handle, byref(c_currState), byref(c_pendingState)) + _nvmlCheckReturn(ret) + return [c_currState.value, c_pendingState.value] + + +# added to API +def nvmlDeviceGetCurrentEccMode(handle): + return nvmlDeviceGetEccMode(handle)[0] + + +# added to API +def nvmlDeviceGetPendingEccMode(handle): + return nvmlDeviceGetEccMode(handle)[1] + + +def nvmlDeviceGetDefaultEccMode(handle): + c_defaultState = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDefaultEccMode") + ret = fn(handle, byref(c_defaultState)) + _nvmlCheckReturn(ret) + return [c_defaultState.value] + + +def nvmlDeviceGetTotalEccErrors(handle, 
errorType, counterType): + c_count = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTotalEccErrors") + ret = fn( + handle, + _nvmlMemoryErrorType_t(errorType), + _nvmlEccCounterType_t(counterType), + byref(c_count), + ) + _nvmlCheckReturn(ret) + return c_count.value + + +# This is deprecated, instead use nvmlDeviceGetMemoryErrorCounter +def nvmlDeviceGetDetailedEccErrors(handle, errorType, counterType): + c_counts = c_nvmlEccErrorCounts_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDetailedEccErrors") + ret = fn( + handle, + _nvmlMemoryErrorType_t(errorType), + _nvmlEccCounterType_t(counterType), + byref(c_counts), + ) + _nvmlCheckReturn(ret) + return c_counts + + +# Added in 4.304 +def nvmlDeviceGetMemoryErrorCounter(handle, errorType, counterType, locationType): + c_count = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMemoryErrorCounter") + ret = fn( + handle, + _nvmlMemoryErrorType_t(errorType), + _nvmlEccCounterType_t(counterType), + _nvmlMemoryLocation_t(locationType), + byref(c_count), + ) + _nvmlCheckReturn(ret) + return c_count.value + + +def nvmlDeviceGetUtilizationRates(handle): + c_util = c_nvmlUtilization_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetUtilizationRates") + ret = fn(handle, byref(c_util)) + _nvmlCheckReturn(ret) + return c_util + + +def nvmlDeviceGetEncoderUtilization(handle): + c_util = c_uint() + c_samplingPeriod = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetEncoderUtilization") + ret = fn(handle, byref(c_util), byref(c_samplingPeriod)) + _nvmlCheckReturn(ret) + return [c_util.value, c_samplingPeriod.value] + + +def nvmlDeviceGetDecoderUtilization(handle): + c_util = c_uint() + c_samplingPeriod = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDecoderUtilization") + ret = fn(handle, byref(c_util), byref(c_samplingPeriod)) + _nvmlCheckReturn(ret) + return [c_util.value, c_samplingPeriod.value] + + +def nvmlDeviceGetJpgUtilization(handle): + c_util = c_uint() + c_samplingPeriod = c_uint() + 
fn = _nvmlGetFunctionPointer("nvmlDeviceGetJpgUtilization") + ret = fn(handle, byref(c_util), byref(c_samplingPeriod)) + _nvmlCheckReturn(ret) + return [c_util.value, c_samplingPeriod.value] + + +def nvmlDeviceGetOfaUtilization(handle): + c_util = c_uint() + c_samplingPeriod = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetOfaUtilization") + ret = fn(handle, byref(c_util), byref(c_samplingPeriod)) + _nvmlCheckReturn(ret) + return [c_util.value, c_samplingPeriod.value] + + +def nvmlDeviceGetPcieReplayCounter(handle): + c_replay = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPcieReplayCounter") + ret = fn(handle, byref(c_replay)) + _nvmlCheckReturn(ret) + return c_replay.value + + +def nvmlDeviceGetDriverModel(handle): + c_currModel = _nvmlDriverModel_t() + c_pendingModel = _nvmlDriverModel_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDriverModel") + ret = fn(handle, byref(c_currModel), byref(c_pendingModel)) + _nvmlCheckReturn(ret) + return [c_currModel.value, c_pendingModel.value] + + +# added to API +def nvmlDeviceGetCurrentDriverModel(handle): + return nvmlDeviceGetDriverModel(handle)[0] + + +# added to API +def nvmlDeviceGetPendingDriverModel(handle): + return nvmlDeviceGetDriverModel(handle)[1] + + +# Added in 2.285 +@convertStrBytes +def nvmlDeviceGetVbiosVersion(handle): + c_version = create_string_buffer(NVML_DEVICE_VBIOS_VERSION_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVbiosVersion") + ret = fn(handle, c_version, c_uint(NVML_DEVICE_VBIOS_VERSION_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_version.value + + +# Added in 2.285 +def nvmlDeviceGetComputeRunningProcesses_v2(handle): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetComputeRunningProcesses_v2") + ret = fn(handle, byref(c_count), None) + if ret == NVML_SUCCESS: + # special case, no running processes + return [] + elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + # typical case + # oversize the array in case more 
processes are created + c_count.value = c_count.value * 2 + 5 + proc_array = c_nvmlProcessInfo_v2_t * c_count.value + c_procs = proc_array() + # make the call again + ret = fn(handle, byref(c_count), c_procs) + _nvmlCheckReturn(ret) + procs = [] + for i in range(c_count.value): + # use an alternative struct for this object + obj = nvmlStructToFriendlyObject(c_procs[i]) + if obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value: + # special case for WDDM on Windows, see comment above + obj.usedGpuMemory = None + procs.append(obj) + return procs + else: + # error case + raise NVMLError(ret) + + +# Added in 2.285 +def nvmlDeviceGetComputeRunningProcesses_v3(handle): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetComputeRunningProcesses_v3") + ret = fn(handle, byref(c_count), None) + + if ret == NVML_SUCCESS: + # special case, no running processes + return [] + elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + # typical case + # oversize the array in case more processes are created + c_count.value = c_count.value * 2 + 5 + proc_array = c_nvmlProcessInfo_v3_t * c_count.value + c_procs = proc_array() + + # make the call again + ret = fn(handle, byref(c_count), c_procs) + _nvmlCheckReturn(ret) + + procs = [] + for i in range(c_count.value): + # use an alternative struct for this object + obj = nvmlStructToFriendlyObject(c_procs[i]) + if obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value: + # special case for WDDM on Windows, see comment above + obj.usedGpuMemory = None + procs.append(obj) + + return procs + else: + # error case + raise NVMLError(ret) + + +@throwOnVersionMismatch +def nvmlDeviceGetComputeRunningProcesses(handle): + return nvmlDeviceGetComputeRunningProcesses_v3(handle) + + +def nvmlDeviceGetGraphicsRunningProcesses_v2(handle): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGraphicsRunningProcesses_v2") + ret = fn(handle, byref(c_count), None) 
+ if ret == NVML_SUCCESS: + # special case, no running processes + return [] + elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + # typical case + # oversize the array in case more processes are created + c_count.value = c_count.value * 2 + 5 + proc_array = c_nvmlProcessInfo_v2_t * c_count.value + c_procs = proc_array() + # make the call again + ret = fn(handle, byref(c_count), c_procs) + _nvmlCheckReturn(ret) + procs = [] + for i in range(c_count.value): + # use an alternative struct for this object + obj = nvmlStructToFriendlyObject(c_procs[i]) + if obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value: + # special case for WDDM on Windows, see comment above + obj.usedGpuMemory = None + procs.append(obj) + return procs + else: + # error case + raise NVMLError(ret) + + +def nvmlDeviceGetGraphicsRunningProcesses_v3(handle): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGraphicsRunningProcesses_v3") + ret = fn(handle, byref(c_count), None) + + if ret == NVML_SUCCESS: + # special case, no running processes + return [] + elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + # typical case + # oversize the array in case more processes are created + c_count.value = c_count.value * 2 + 5 + proc_array = c_nvmlProcessInfo_v3_t * c_count.value + c_procs = proc_array() + + # make the call again + ret = fn(handle, byref(c_count), c_procs) + _nvmlCheckReturn(ret) + + procs = [] + for i in range(c_count.value): + # use an alternative struct for this object + obj = nvmlStructToFriendlyObject(c_procs[i]) + if obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value: + # special case for WDDM on Windows, see comment above + obj.usedGpuMemory = None + procs.append(obj) + + return procs + else: + # error case + raise NVMLError(ret) + + +@throwOnVersionMismatch +def nvmlDeviceGetGraphicsRunningProcesses(handle): + return nvmlDeviceGetGraphicsRunningProcesses_v3(handle) + + +@throwOnVersionMismatch +def 
nvmlDeviceGetMPSComputeRunningProcesses(handle): + return nvmlDeviceGetMPSComputeRunningProcesses_v3(handle) + + +def nvmlDeviceGetMPSComputeRunningProcesses_v2(handle): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMPSComputeRunningProcesses_v2") + ret = fn(handle, byref(c_count), None) + + if ret == NVML_SUCCESS: + # special case, no running processes + return [] + elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + # typical case + # oversize the array in case more processes are created + c_count.value = c_count.value * 2 + 5 + proc_array = c_nvmlProcessInfo_v2_t * c_count.value + c_procs = proc_array() + + # make the call again + ret = fn(handle, byref(c_count), c_procs) + _nvmlCheckReturn(ret) + + procs = [] + for i in range(c_count.value): + # use an alternative struct for this object + obj = nvmlStructToFriendlyObject(c_procs[i]) + if obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value: + # special case for WDDM on Windows, see comment above + obj.usedGpuMemory = None + procs.append(obj) + + return procs + else: + # error case + raise NVMLError(ret) + + +def nvmlDeviceGetMPSComputeRunningProcesses_v3(handle): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMPSComputeRunningProcesses_v3") + ret = fn(handle, byref(c_count), None) + + if ret == NVML_SUCCESS: + # special case, no running processes + return [] + elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + # typical case + # oversize the array in case more processes are created + c_count.value = c_count.value * 2 + 5 + proc_array = c_nvmlProcessInfo_v3_t * c_count.value + c_procs = proc_array() + + # make the call again + ret = fn(handle, byref(c_count), c_procs) + _nvmlCheckReturn(ret) + + procs = [] + for i in range(c_count.value): + # use an alternative struct for this object + obj = nvmlStructToFriendlyObject(c_procs[i]) + if obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value: + # special case for 
WDDM on Windows, see comment above + obj.usedGpuMemory = None + procs.append(obj) + + return procs + else: + # error case + raise NVMLError(ret) + + +def nvmlDeviceGetRunningProcessDetailList(handle, version, mode): + c_processDetailList = c_nvmlProcessDetailList_t() + c_processDetailList.version = version + c_processDetailList.mode = mode + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetRunningProcessDetailList") + + # first call to get the size + ret = fn(handle, byref(c_processDetailList)) + if ret == NVML_SUCCESS: + # special case, no running processes + return [] + elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + c_procs = c_nvmlProcessDetail_v1_t * c_processDetailList.numProcArrayEntries + c_processDetailList.procArray = cast( + (c_procs)(), POINTER(c_nvmlProcessDetail_v1_t) + ) + + # make the call again + ret = fn(handle, byref(c_processDetailList)) + _nvmlCheckReturn(ret) + + procs = [] + for i in range(c_processDetailList.numProcArrayEntries): + # use an alternative struct for this object + obj = c_processDetailList.procArray[i] + if obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value: + obj.usedGpuMemory = None + if obj.usedGpuCcProtectedMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value: + obj.usedGpuCcProtectedMemory = None + procs.append(obj) + + return procs + else: + # error case + raise NVMLError(ret) + + +def nvmlDeviceGetAutoBoostedClocksEnabled(handle): + c_isEnabled = _nvmlEnableState_t() + c_defaultIsEnabled = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAutoBoostedClocksEnabled") + ret = fn(handle, byref(c_isEnabled), byref(c_defaultIsEnabled)) + _nvmlCheckReturn(ret) + return [c_isEnabled.value, c_defaultIsEnabled.value] + # Throws NVML_ERROR_NOT_SUPPORTED if hardware doesn't support setting auto boosted clocks + + +## Set functions +def nvmlUnitSetLedState(unit, color): + fn = _nvmlGetFunctionPointer("nvmlUnitSetLedState") + ret = fn(unit, _nvmlLedColor_t(color)) + _nvmlCheckReturn(ret) + return None + + +def 
nvmlDeviceSetPersistenceMode(handle, mode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetPersistenceMode") + ret = fn(handle, _nvmlEnableState_t(mode)) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceSetComputeMode(handle, mode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetComputeMode") + ret = fn(handle, _nvmlComputeMode_t(mode)) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceSetEccMode(handle, mode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetEccMode") + ret = fn(handle, _nvmlEnableState_t(mode)) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceClearEccErrorCounts(handle, counterType): + fn = _nvmlGetFunctionPointer("nvmlDeviceClearEccErrorCounts") + ret = fn(handle, _nvmlEccCounterType_t(counterType)) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceSetDriverModel(handle, model): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetDriverModel") + ret = fn(handle, _nvmlDriverModel_t(model)) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceSetAutoBoostedClocksEnabled(handle, enabled): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetAutoBoostedClocksEnabled") + ret = fn(handle, _nvmlEnableState_t(enabled)) + _nvmlCheckReturn(ret) + return None + # Throws NVML_ERROR_NOT_SUPPORTED if hardware doesn't support setting auto boosted clocks + + +def nvmlDeviceSetDefaultAutoBoostedClocksEnabled(handle, enabled, flags): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetDefaultAutoBoostedClocksEnabled") + ret = fn(handle, _nvmlEnableState_t(enabled), c_uint(flags)) + _nvmlCheckReturn(ret) + return None + # Throws NVML_ERROR_NOT_SUPPORTED if hardware doesn't support setting auto boosted clocks + + +def nvmlDeviceSetGpuLockedClocks(handle, minGpuClockMHz, maxGpuClockMHz): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetGpuLockedClocks") + ret = fn(handle, c_uint(minGpuClockMHz), c_uint(maxGpuClockMHz)) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceResetGpuLockedClocks(handle): + fn = 
_nvmlGetFunctionPointer("nvmlDeviceResetGpuLockedClocks") + ret = fn(handle) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceSetMemoryLockedClocks(handle, minMemClockMHz, maxMemClockMHz): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetMemoryLockedClocks") + ret = fn(handle, c_uint(minMemClockMHz), c_uint(maxMemClockMHz)) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceResetMemoryLockedClocks(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceResetMemoryLockedClocks") + ret = fn(handle) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceGetClkMonStatus(handle, c_clkMonInfo=nvmlClkMonStatus_t()): + isReference = type(c_clkMonInfo) is not nvmlClkMonStatus_t + c_clkMonInfoRef = c_clkMonInfo if isReference else byref(c_clkMonInfo) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetClkMonStatus") + ret = fn(handle, c_clkMonInfoRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else c_clkMonInfo + + +# Added in 4.304 +def nvmlDeviceSetApplicationsClocks(handle, maxMemClockMHz, maxGraphicsClockMHz): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetApplicationsClocks") + ret = fn(handle, c_uint(maxMemClockMHz), c_uint(maxGraphicsClockMHz)) + _nvmlCheckReturn(ret) + return None + + +# Added in 4.304 +def nvmlDeviceResetApplicationsClocks(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceResetApplicationsClocks") + ret = fn(handle) + _nvmlCheckReturn(ret) + return None + + +# Added in 4.304 +def nvmlDeviceSetPowerManagementLimit(handle, limit): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetPowerManagementLimit") + ret = fn(handle, c_uint(limit)) + _nvmlCheckReturn(ret) + return None + + +# Added in 4.304 +def nvmlDeviceSetGpuOperationMode(handle, mode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetGpuOperationMode") + ret = fn(handle, _nvmlGpuOperationMode_t(mode)) + _nvmlCheckReturn(ret) + return None + + +# Added in 2.285 +def nvmlEventSetCreate(): + fn = _nvmlGetFunctionPointer("nvmlEventSetCreate") + eventSet = c_nvmlEventSet_t() + ret = 
fn(byref(eventSet)) + _nvmlCheckReturn(ret) + return eventSet + + +# Added in 2.285 +def nvmlDeviceRegisterEvents(handle, eventTypes, eventSet): + fn = _nvmlGetFunctionPointer("nvmlDeviceRegisterEvents") + ret = fn(handle, c_ulonglong(eventTypes), eventSet) + _nvmlCheckReturn(ret) + return None + + +# Added in 2.285 +def nvmlDeviceGetSupportedEventTypes(handle): + c_eventTypes = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedEventTypes") + ret = fn(handle, byref(c_eventTypes)) + _nvmlCheckReturn(ret) + return c_eventTypes.value + + +# raises NVML_ERROR_TIMEOUT exception on timeout +def nvmlEventSetWait_v2(eventSet, timeoutms): + fn = _nvmlGetFunctionPointer("nvmlEventSetWait_v2") + data = c_nvmlEventData_t() + ret = fn(eventSet, byref(data), c_uint(timeoutms)) + _nvmlCheckReturn(ret) + return data + + +def nvmlEventSetWait(eventSet, timeoutms): + return nvmlEventSetWait_v2(eventSet, timeoutms) + + +# Added in 2.285 +def nvmlEventSetFree(eventSet): + fn = _nvmlGetFunctionPointer("nvmlEventSetFree") + ret = fn(eventSet) + _nvmlCheckReturn(ret) + return None + + +# Added in 3.295 +def nvmlDeviceOnSameBoard(handle1, handle2): + fn = _nvmlGetFunctionPointer("nvmlDeviceOnSameBoard") + onSameBoard = c_int() + ret = fn(handle1, handle2, byref(onSameBoard)) + _nvmlCheckReturn(ret) + return onSameBoard.value != 0 + + +# Added in 3.295 +def nvmlDeviceGetCurrPcieLinkGeneration(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCurrPcieLinkGeneration") + gen = c_uint() + ret = fn(handle, byref(gen)) + _nvmlCheckReturn(ret) + return gen.value + + +# Added in 3.295 +def nvmlDeviceGetMaxPcieLinkGeneration(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMaxPcieLinkGeneration") + gen = c_uint() + ret = fn(handle, byref(gen)) + _nvmlCheckReturn(ret) + return gen.value + + +# Added in 3.295 +def nvmlDeviceGetCurrPcieLinkWidth(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCurrPcieLinkWidth") + width = c_uint() + ret = fn(handle, byref(width)) + 
_nvmlCheckReturn(ret) + return width.value + + +# Added in 3.295 +def nvmlDeviceGetMaxPcieLinkWidth(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMaxPcieLinkWidth") + width = c_uint() + ret = fn(handle, byref(width)) + _nvmlCheckReturn(ret) + return width.value + + +def nvmlDeviceGetGpuMaxPcieLinkGeneration(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuMaxPcieLinkGeneration") + gen = c_uint() + ret = fn(handle, byref(gen)) + _nvmlCheckReturn(ret) + return gen.value + + +# Added in 4.304 +def nvmlDeviceGetSupportedClocksThrottleReasons(handle): + c_reasons = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedClocksThrottleReasons") + ret = fn(handle, byref(c_reasons)) + _nvmlCheckReturn(ret) + return c_reasons.value + + +def nvmlDeviceGetSupportedClocksEventReasons(handle): + c_reasons = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedClocksEventReasons") + ret = fn(handle, byref(c_reasons)) + _nvmlCheckReturn(ret) + return c_reasons.value + + +# Added in 4.304 +def nvmlDeviceGetCurrentClocksThrottleReasons(handle): + c_reasons = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCurrentClocksThrottleReasons") + ret = fn(handle, byref(c_reasons)) + _nvmlCheckReturn(ret) + return c_reasons.value + + +def nvmlDeviceGetCurrentClocksEventReasons(handle): + c_reasons = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCurrentClocksEventReasons") + ret = fn(handle, byref(c_reasons)) + _nvmlCheckReturn(ret) + return c_reasons.value + + +# Added in 5.319 +def nvmlDeviceGetIndex(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetIndex") + c_index = c_uint() + ret = fn(handle, byref(c_index)) + _nvmlCheckReturn(ret) + return c_index.value + + +# Added in 5.319 +def nvmlDeviceGetAccountingMode(handle): + c_mode = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAccountingMode") + ret = fn(handle, byref(c_mode)) + _nvmlCheckReturn(ret) + return c_mode.value + + +def 
nvmlDeviceSetAccountingMode(handle, mode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetAccountingMode") + ret = fn(handle, _nvmlEnableState_t(mode)) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceClearAccountingPids(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceClearAccountingPids") + ret = fn(handle) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceGetAccountingStats(handle, pid): + stats = c_nvmlAccountingStats_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAccountingStats") + ret = fn(handle, c_uint(pid), byref(stats)) + _nvmlCheckReturn(ret) + if stats.maxMemoryUsage == NVML_VALUE_NOT_AVAILABLE_ulonglong.value: + # special case for WDDM on Windows, see comment above + stats.maxMemoryUsage = None + return stats + + +def nvmlDeviceGetAccountingPids(handle): + count = c_uint(nvmlDeviceGetAccountingBufferSize(handle)) + pids = (c_uint * count.value)() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAccountingPids") + ret = fn(handle, byref(count), pids) + _nvmlCheckReturn(ret) + return list(map(int, pids[0 : count.value])) + + +def nvmlDeviceGetAccountingBufferSize(handle): + bufferSize = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAccountingBufferSize") + ret = fn(handle, byref(bufferSize)) + _nvmlCheckReturn(ret) + return int(bufferSize.value) + + +def nvmlDeviceGetRetiredPages(device, sourceFilter): + c_source = _nvmlPageRetirementCause_t(sourceFilter) + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetRetiredPages") + + # First call will get the size + ret = fn(device, c_source, byref(c_count), None) + + # this should only fail with insufficient size + if (ret != NVML_SUCCESS) and (ret != NVML_ERROR_INSUFFICIENT_SIZE): + raise NVMLError(ret) + + # call again with a buffer + # oversize the array for the rare cases where additional pages + # are retired between NVML calls + c_count.value = c_count.value * 2 + 5 + page_array = c_ulonglong * c_count.value + c_pages = page_array() + ret = fn(device, c_source, 
byref(c_count), c_pages) + _nvmlCheckReturn(ret) + return list(map(int, c_pages[0 : c_count.value])) + + +def nvmlDeviceGetRetiredPages_v2(device, sourceFilter): + c_source = _nvmlPageRetirementCause_t(sourceFilter) + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetRetiredPages_v2") + + # First call will get the size + ret = fn(device, c_source, byref(c_count), None) + + # this should only fail with insufficient size + if (ret != NVML_SUCCESS) and (ret != NVML_ERROR_INSUFFICIENT_SIZE): + raise NVMLError(ret) + + # call again with a buffer + # oversize the array for the rare cases where additional pages + # are retired between NVML calls + c_count.value = c_count.value * 2 + 5 + page_array = c_ulonglong * c_count.value + c_pages = page_array() + times_array = c_ulonglong * c_count.value + c_times = times_array() + ret = fn(device, c_source, byref(c_count), c_pages, c_times) + _nvmlCheckReturn(ret) + return [ + {"address": int(c_pages[i]), "timestamp": int(c_times[i])} + for i in range(c_count.value) + ] + + +def nvmlDeviceGetRetiredPagesPendingStatus(device): + c_pending = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetRetiredPagesPendingStatus") + ret = fn(device, byref(c_pending)) + _nvmlCheckReturn(ret) + return int(c_pending.value) + + +def nvmlDeviceGetAPIRestriction(device, apiType): + c_permission = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAPIRestriction") + ret = fn(device, _nvmlRestrictedAPI_t(apiType), byref(c_permission)) + _nvmlCheckReturn(ret) + return int(c_permission.value) + + +def nvmlDeviceSetAPIRestriction(handle, apiType, isRestricted): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetAPIRestriction") + ret = fn(handle, _nvmlRestrictedAPI_t(apiType), _nvmlEnableState_t(isRestricted)) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceGetBridgeChipInfo(handle): + bridgeHierarchy = c_nvmlBridgeChipHierarchy_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetBridgeChipInfo") + ret = 
fn(handle, byref(bridgeHierarchy)) + _nvmlCheckReturn(ret) + return bridgeHierarchy + + +def nvmlDeviceGetSamples(device, sampling_type, timeStamp): + c_sampling_type = _nvmlSamplingType_t(sampling_type) + c_time_stamp = c_ulonglong(timeStamp) + c_sample_count = c_uint(0) + c_sample_value_type = _nvmlValueType_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSamples") + + ## First Call gets the size + ret = fn( + device, + c_sampling_type, + c_time_stamp, + byref(c_sample_value_type), + byref(c_sample_count), + None, + ) + + # Stop if this fails + if ret != NVML_SUCCESS: + raise NVMLError(ret) + + sampleArray = c_sample_count.value * c_nvmlSample_t + c_samples = sampleArray() + ret = fn( + device, + c_sampling_type, + c_time_stamp, + byref(c_sample_value_type), + byref(c_sample_count), + c_samples, + ) + _nvmlCheckReturn(ret) + return (c_sample_value_type.value, c_samples[0 : c_sample_count.value]) + + +def nvmlDeviceGetViolationStatus(device, perfPolicyType): + c_perfPolicy_type = _nvmlPerfPolicyType_t(perfPolicyType) + c_violTime = c_nvmlViolationTime_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetViolationStatus") + + ## Invoke the method to get violation time + ret = fn(device, c_perfPolicy_type, byref(c_violTime)) + _nvmlCheckReturn(ret) + return c_violTime + + +def nvmlDeviceGetPcieThroughput(device, counter): + c_util = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPcieThroughput") + ret = fn(device, _nvmlPcieUtilCounter_t(counter), byref(c_util)) + _nvmlCheckReturn(ret) + return c_util.value + + +def nvmlSystemGetTopologyGpuSet(cpuNumber): + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlSystemGetTopologyGpuSet") + + # First call will get the size + ret = fn(cpuNumber, byref(c_count), None) + + if ret != NVML_SUCCESS: + raise NVMLError(ret) + # call again with a buffer + device_array = c_nvmlDevice_t * c_count.value + c_devices = device_array() + ret = fn(cpuNumber, byref(c_count), c_devices) + _nvmlCheckReturn(ret) + return 
list(c_devices[0 : c_count.value]) + + +def nvmlDeviceGetTopologyNearestGpus(device, level): + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTopologyNearestGpus") + + # First call will get the size + ret = fn(device, level, byref(c_count), None) + + if ret != NVML_SUCCESS: + raise NVMLError(ret) + + # call again with a buffer + device_array = c_nvmlDevice_t * c_count.value + c_devices = device_array() + ret = fn(device, level, byref(c_count), c_devices) + _nvmlCheckReturn(ret) + return list(c_devices[0 : c_count.value]) + + +def nvmlDeviceGetTopologyCommonAncestor(device1, device2): + c_level = _nvmlGpuTopologyLevel_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTopologyCommonAncestor") + ret = fn(device1, device2, byref(c_level)) + _nvmlCheckReturn(ret) + return c_level.value + + +def nvmlDeviceGetNvLinkUtilizationCounter(device, link, counter): + c_rxcounter = c_ulonglong() + c_txcounter = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkUtilizationCounter") + ret = fn(device, link, counter, byref(c_rxcounter), byref(c_txcounter)) + _nvmlCheckReturn(ret) + return (c_rxcounter.value, c_txcounter.value) + + +def nvmlDeviceFreezeNvLinkUtilizationCounter(device, link, counter, freeze): + fn = _nvmlGetFunctionPointer("nvmlDeviceFreezeNvLinkUtilizationCounter") + ret = fn(device, link, counter, freeze) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceResetNvLinkUtilizationCounter(device, link, counter): + fn = _nvmlGetFunctionPointer("nvmlDeviceResetNvLinkUtilizationCounter") + ret = fn(device, link, counter) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceSetNvLinkUtilizationControl(device, link, counter, control, reset): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetNvLinkUtilizationControl") + ret = fn(device, link, counter, byref(control), reset) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceGetNvLinkUtilizationControl(device, link, counter): + c_control = nvmlNvLinkUtilizationControl_t() + fn = 
_nvmlGetFunctionPointer("nvmlDeviceGetNvLinkUtilizationControl") + ret = fn(device, link, counter, byref(c_control)) + _nvmlCheckReturn(ret) + return c_control + + +def nvmlDeviceGetNvLinkCapability(device, link, capability): + c_capResult = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkCapability") + ret = fn(device, link, capability, byref(c_capResult)) + _nvmlCheckReturn(ret) + return c_capResult.value + + +def nvmlDeviceGetNvLinkErrorCounter(device, link, counter): + c_result = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkErrorCounter") + ret = fn(device, link, counter, byref(c_result)) + _nvmlCheckReturn(ret) + return c_result.value + + +def nvmlDeviceResetNvLinkErrorCounters(device, link): + fn = _nvmlGetFunctionPointer("nvmlDeviceResetNvLinkErrorCounters") + ret = fn(device, link) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceGetNvLinkRemotePciInfo(device, link): + c_pci = nvmlPciInfo_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkRemotePciInfo_v2") + ret = fn(device, link, byref(c_pci)) + _nvmlCheckReturn(ret) + return c_pci + + +def nvmlDeviceGetNvLinkRemoteDeviceType(handle, link): + c_type = _nvmlNvLinkDeviceType_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkRemoteDeviceType") + ret = fn(handle, link, byref(c_type)) + _nvmlCheckReturn(ret) + return c_type.value + + +def nvmlDeviceGetNvLinkState(device, link): + c_isActive = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkState") + ret = fn(device, link, byref(c_isActive)) + _nvmlCheckReturn(ret) + return c_isActive.value + + +def nvmlDeviceGetNvLinkVersion(device, link): + c_version = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkVersion") + ret = fn(device, link, byref(c_version)) + _nvmlCheckReturn(ret) + return c_version.value + + +def nvmlDeviceModifyDrainState(pciInfo, newState): + fn = _nvmlGetFunctionPointer("nvmlDeviceModifyDrainState") + ret = fn(pointer(pciInfo), newState) + _nvmlCheckReturn(ret) + return 
None + + +def nvmlDeviceQueryDrainState(pciInfo): + c_newState = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceQueryDrainState") + ret = fn(pointer(pciInfo), byref(c_newState)) + _nvmlCheckReturn(ret) + return c_newState.value + + +def nvmlDeviceRemoveGpu(pciInfo): + fn = _nvmlGetFunctionPointer("nvmlDeviceRemoveGpu") + ret = fn(pointer(pciInfo)) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceDiscoverGpus(pciInfo): + fn = _nvmlGetFunctionPointer("nvmlDeviceDiscoverGpus") + ret = fn(pointer(pciInfo)) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceGetFieldValues(handle, fieldIds): + values_arr = c_nvmlFieldValue_t * len(fieldIds) + values = values_arr() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetFieldValues") + + for i, fieldId in enumerate(fieldIds): + try: + values[i].fieldId, values[i].scopeId = fieldId + except TypeError: + values[i].fieldId = fieldId + + ret = fn(handle, c_int32(len(fieldIds)), byref(values)) + _nvmlCheckReturn(ret) + return values + + +def nvmlDeviceClearFieldValues(handle, fieldIds): + values_arr = c_nvmlFieldValue_t * len(fieldIds) + values = values_arr() + fn = _nvmlGetFunctionPointer("nvmlDeviceClearFieldValues") + + for i, fieldId in enumerate(fieldIds): + try: + values[i].fieldId, values[i].scopeId = fieldId + except TypeError: + values[i].fieldId = fieldId + + ret = fn(handle, c_int32(len(fieldIds)), byref(values)) + _nvmlCheckReturn(ret) + return values + + +def nvmlDeviceGetVirtualizationMode(handle): + c_virtualization_mode = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVirtualizationMode") + ret = fn(handle, byref(c_virtualization_mode)) + _nvmlCheckReturn(ret) + return c_virtualization_mode.value + + +def nvmlDeviceSetVirtualizationMode(handle, virtualization_mode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetVirtualizationMode") + return fn(handle, virtualization_mode) + + +def nvmlDeviceGetVgpuHeterogeneousMode(handle): + c_vgpuHeterogeneousMode = c_nvmlVgpuHeterogeneousMode_v1_t(0) + 
c_vgpuHeterogeneousMode.version = VgpuHeterogeneousMode_v1 + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuHeterogeneousMode") + ret = fn(handle, byref(c_vgpuHeterogeneousMode)) + _nvmlCheckReturn(ret) + return c_vgpuHeterogeneousMode.mode + + +def nvmlDeviceSetVgpuHeterogeneousMode(handle, heterogeneous_mode): + c_vgpuHeterogeneousMode = c_nvmlVgpuHeterogeneousMode_v1_t(0) + c_vgpuHeterogeneousMode.version = VgpuHeterogeneousMode_v1 + c_vgpuHeterogeneousMode.mode = heterogeneous_mode + fn = _nvmlGetFunctionPointer("nvmlDeviceSetVgpuHeterogeneousMode") + ret = fn(handle, byref(c_vgpuHeterogeneousMode)) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlVgpuInstanceGetPlacementId(vgpuInstance): + c_placement = c_nvmlVgpuPlacementId_v1_t(0) + c_placement.version = VgpuPlacementId_v1 + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetPlacementId") + ret = fn(vgpuInstance, byref(c_placement)) + _nvmlCheckReturn(ret) + return c_placement.placementId + + +def nvmlDeviceGetVgpuTypeSupportedPlacements(handle, vgpuTypeId, mode=0, version=1): + c_max_instances = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetMaxInstances") + ret = fn(handle, vgpuTypeId, byref(c_max_instances)) + _nvmlCheckReturn(ret) + + if version == 2: + c_vgpu_placements = c_nvmlVgpuPlacementList_v2_t() + c_vgpu_placements.version = VgpuPlacementList_v2 + c_vgpu_placements.count = c_max_instances.value + c_vgpu_placements.mode = mode + elif version == 1: + c_vgpu_placements = c_nvmlVgpuPlacementList_v1_t() + c_vgpu_placements.version = VgpuPlacementList_v1 + else: + raise NVMLError(NVML_ERROR_ARGUMENT_VERSION_MISMATCH) + + c_placements = c_uint * c_max_instances.value + c_vgpu_placements.placementIds = c_placements() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuTypeSupportedPlacements") + ret = fn(handle, vgpuTypeId, byref(c_vgpu_placements)) + _nvmlCheckReturn(ret) + return c_vgpu_placements + + +def nvmlDeviceGetVgpuTypeCreatablePlacements(handle, vgpuTypeId, version=1): + 
c_max_instances = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetMaxInstances") + ret = fn(handle, vgpuTypeId, byref(c_max_instances)) + _nvmlCheckReturn(ret) + + if version == 2: + c_vgpu_placements = c_nvmlVgpuPlacementList_v2_t() + c_vgpu_placements.version = VgpuPlacementList_v2 + c_vgpu_placements.count = c_max_instances.value + elif version == 1: + c_vgpu_placements = c_nvmlVgpuPlacementList_v1_t() + c_vgpu_placements.version = VgpuPlacementList_v1 + + c_placements = c_uint * c_max_instances.value + c_vgpu_placements.placementIds = c_placements() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuTypeCreatablePlacements") + ret = fn(handle, vgpuTypeId, byref(c_vgpu_placements)) + _nvmlCheckReturn(ret) + return c_vgpu_placements + + +def nvmlGetVgpuDriverCapabilities(capability): + c_capResult = c_uint() + fn = _nvmlGetFunctionPointer("nvmlGetVgpuDriverCapabilities") + ret = fn(_nvmlVgpuDriverCapability_t(capability), byref(c_capResult)) + _nvmlCheckReturn(ret) + return c_capResult.value + + +def nvmlDeviceGetVgpuCapabilities(handle, capability): + c_capResult = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuCapabilities") + ret = fn(handle, _nvmlDeviceVgpuCapability_t(capability), byref(c_capResult)) + _nvmlCheckReturn(ret) + return c_capResult.value + + +def nvmlDeviceSetVgpuCapabilities(handle, capability, state): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetVgpuCapabilities") + ret = fn(handle, _nvmlDeviceVgpuCapability_t(capability), state) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlDeviceGetSupportedVgpus(handle): + # first call to get the size + c_vgpu_count = c_uint(0) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedVgpus") + ret = fn(handle, byref(c_vgpu_count), None) + + if ret == NVML_SUCCESS: + # special case, no supported vGPUs + return [] + elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + # typical case + vgpu_type_ids_array = _nvmlVgpuTypeId_t * c_vgpu_count.value + c_vgpu_type_ids = vgpu_type_ids_array() + 
+ # make the call again + ret = fn(handle, byref(c_vgpu_count), c_vgpu_type_ids) + _nvmlCheckReturn(ret) + vgpus = [] + for i in range(c_vgpu_count.value): + vgpus.append(c_vgpu_type_ids[i]) + return vgpus + else: + # error case + raise NVMLError(ret) + + +def nvmlDeviceGetCreatableVgpus(handle): + # first call to get the size + c_vgpu_count = c_uint(0) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCreatableVgpus") + ret = fn(handle, byref(c_vgpu_count), None) + + if ret == NVML_SUCCESS: + # special case, no supported vGPUs + return [] + elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + # typical case + vgpu_type_ids_array = _nvmlVgpuTypeId_t * c_vgpu_count.value + c_vgpu_type_ids = vgpu_type_ids_array() + + # make the call again + ret = fn(handle, byref(c_vgpu_count), c_vgpu_type_ids) + _nvmlCheckReturn(ret) + vgpus = [] + for i in range(c_vgpu_count.value): + vgpus.append(c_vgpu_type_ids[i]) + return vgpus + else: + # error case + raise NVMLError(ret) + + +def nvmlVgpuTypeGetGpuInstanceProfileId(vgpuTypeId): + c_profile_id = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetGpuInstanceProfileId") + ret = fn(vgpuTypeId, byref(c_profile_id)) + _nvmlCheckReturn(ret) + return c_profile_id.value + + +@convertStrBytes +def nvmlVgpuTypeGetClass(vgpuTypeId): + c_class = create_string_buffer(NVML_DEVICE_NAME_BUFFER_SIZE) + c_buffer_size = c_uint(NVML_DEVICE_NAME_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetClass") + ret = fn(vgpuTypeId, c_class, byref(c_buffer_size)) + _nvmlCheckReturn(ret) + return c_class.value + + +@convertStrBytes +def nvmlVgpuTypeGetName(vgpuTypeId): + c_name = create_string_buffer(NVML_DEVICE_NAME_BUFFER_SIZE) + c_buffer_size = c_uint(NVML_DEVICE_NAME_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetName") + ret = fn(vgpuTypeId, c_name, byref(c_buffer_size)) + _nvmlCheckReturn(ret) + return c_name.value + + +def nvmlVgpuTypeGetDeviceID(vgpuTypeId): + c_device_id = c_ulonglong(0) + c_subsystem_id = c_ulonglong(0) + fn = 
_nvmlGetFunctionPointer("nvmlVgpuTypeGetDeviceID") + ret = fn(vgpuTypeId, byref(c_device_id), byref(c_subsystem_id)) + _nvmlCheckReturn(ret) + return (c_device_id.value, c_subsystem_id.value) + + +def nvmlVgpuTypeGetFramebufferSize(vgpuTypeId): + c_fb_size = c_ulonglong(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetFramebufferSize") + ret = fn(vgpuTypeId, byref(c_fb_size)) + _nvmlCheckReturn(ret) + return c_fb_size.value + + +def nvmlVgpuTypeGetNumDisplayHeads(vgpuTypeId): + c_num_heads = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetNumDisplayHeads") + ret = fn(vgpuTypeId, byref(c_num_heads)) + _nvmlCheckReturn(ret) + return c_num_heads.value + + +def nvmlVgpuTypeGetResolution(vgpuTypeId): + c_xdim = c_uint(0) + c_ydim = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetResolution") + ret = fn(vgpuTypeId, 0, byref(c_xdim), byref(c_ydim)) + _nvmlCheckReturn(ret) + return (c_xdim.value, c_ydim.value) + + +@convertStrBytes +def nvmlVgpuTypeGetLicense(vgpuTypeId): + c_license = create_string_buffer(NVML_GRID_LICENSE_BUFFER_SIZE) + c_buffer_size = c_uint(NVML_GRID_LICENSE_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetLicense") + ret = fn(vgpuTypeId, c_license, c_buffer_size) + _nvmlCheckReturn(ret) + return c_license.value + + +def nvmlVgpuTypeGetFrameRateLimit(vgpuTypeId): + c_frl_config = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetFrameRateLimit") + ret = fn(vgpuTypeId, byref(c_frl_config)) + _nvmlCheckReturn(ret) + return c_frl_config.value + + +def nvmlVgpuTypeGetGspHeapSize(vgpuTypeId): + c_gsp_heap = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetGspHeapSize") + ret = fn(vgpuTypeId, byref(c_gsp_heap)) + _nvmlCheckReturn(ret) + return c_gsp_heap.value + + +def nvmlVgpuTypeGetFbReservation(vgpuTypeId): + c_fb_reservation = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetFbReservation") + ret = fn(vgpuTypeId, byref(c_fb_reservation)) + _nvmlCheckReturn(ret) + return c_fb_reservation.value + + +def 
nvmlVgpuInstanceGetRuntimeStateSize(vgpuInstance): + c_runtime_state = nvmlVgpuRuntimeState_v1_t() + c_runtime_state.version = VgpuRuntimeState_v1 + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetRuntimeStateSize") + ret = fn(vgpuInstance, byref(c_runtime_state)) + _nvmlCheckReturn(ret) + return c_runtime_state + + +def nvmlVgpuTypeGetMaxInstances(handle, vgpuTypeId): + c_max_instances = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetMaxInstances") + ret = fn(handle, vgpuTypeId, byref(c_max_instances)) + _nvmlCheckReturn(ret) + return c_max_instances.value + + +def nvmlVgpuTypeGetMaxInstancesPerVm(vgpuTypeId): + c_max_instances_per_vm = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetMaxInstancesPerVm") + ret = fn(vgpuTypeId, byref(c_max_instances_per_vm)) + _nvmlCheckReturn(ret) + return c_max_instances_per_vm.value + + +def nvmlVgpuTypeGetBAR1Info(vgpuTypeId): + c_bar1Info = c_nvmlVgpuTypeBar1Info_v1_t(0) + c_bar1Info.version = VgpuTypeBar1Info_v1 + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetBAR1Info") + ret = fn(vgpuTypeId, byref(c_bar1Info)) + _nvmlCheckReturn(ret) + return c_bar1Info + + +def nvmlDeviceGetActiveVgpus(handle): + # first call to get the size + c_vgpu_count = c_uint(0) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetActiveVgpus") + ret = fn(handle, byref(c_vgpu_count), None) + + if ret == NVML_SUCCESS: + # special case, no active vGPUs + return [] + elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + # typical case + vgpu_instance_array = _nvmlVgpuInstance_t * c_vgpu_count.value + c_vgpu_instances = vgpu_instance_array() + + # make the call again + ret = fn(handle, byref(c_vgpu_count), c_vgpu_instances) + _nvmlCheckReturn(ret) + vgpus = [] + for i in range(c_vgpu_count.value): + vgpus.append(c_vgpu_instances[i]) + return vgpus + else: + # error case + raise NVMLError(ret) + + +@convertStrBytes +def nvmlVgpuInstanceGetVmID(vgpuInstance): + c_vm_id = create_string_buffer(NVML_DEVICE_UUID_BUFFER_SIZE) + c_buffer_size = 
c_uint(NVML_GRID_LICENSE_BUFFER_SIZE) + c_vm_id_type = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetVmID") + ret = fn(vgpuInstance, byref(c_vm_id), c_buffer_size, byref(c_vm_id_type)) + _nvmlCheckReturn(ret) + return (c_vm_id.value, c_vm_id_type.value) + + +@convertStrBytes +def nvmlVgpuInstanceGetUUID(vgpuInstance): + c_uuid = create_string_buffer(NVML_DEVICE_UUID_BUFFER_SIZE) + c_buffer_size = c_uint(NVML_DEVICE_UUID_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetUUID") + ret = fn(vgpuInstance, byref(c_uuid), c_buffer_size) + _nvmlCheckReturn(ret) + return c_uuid.value + + +@convertStrBytes +def nvmlVgpuInstanceGetMdevUUID(vgpuInstance): + c_uuid = create_string_buffer(NVML_DEVICE_UUID_BUFFER_SIZE) + c_buffer_size = c_uint(NVML_DEVICE_UUID_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetMdevUUID") + ret = fn(vgpuInstance, byref(c_uuid), c_buffer_size) + _nvmlCheckReturn(ret) + return c_uuid.value + + +@convertStrBytes +def nvmlVgpuInstanceGetVmDriverVersion(vgpuInstance): + c_driver_version = create_string_buffer(NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE) + c_buffer_size = c_uint(NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetVmDriverVersion") + ret = fn(vgpuInstance, byref(c_driver_version), c_buffer_size) + _nvmlCheckReturn(ret) + return c_driver_version.value + + +def nvmlVgpuInstanceGetLicenseStatus(vgpuInstance): + c_license_status = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetLicenseStatus") + ret = fn(vgpuInstance, byref(c_license_status)) + _nvmlCheckReturn(ret) + return c_license_status.value + + +def nvmlVgpuInstanceGetLicenseInfo_v2(vgpuInstance): + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetLicenseInfo_v2") + c_license_info = c_nvmlVgpuLicenseInfo_t() + ret = fn(vgpuInstance, byref(c_license_info)) + _nvmlCheckReturn(ret) + return c_license_info + + +def nvmlVgpuInstanceGetLicenseInfo(vgpuInstance): + return 
nvmlVgpuInstanceGetLicenseInfo_v2(vgpuInstance) + + +def nvmlVgpuInstanceGetFrameRateLimit(vgpuInstance): + c_frl = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetFrameRateLimit") + ret = fn(vgpuInstance, byref(c_frl)) + _nvmlCheckReturn(ret) + return c_frl.value + + +def nvmlVgpuInstanceGetEccMode(vgpuInstance): + c_mode = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetEccMode") + ret = fn(vgpuInstance, byref(c_mode)) + _nvmlCheckReturn(ret) + return c_mode.value + + +def nvmlVgpuInstanceGetType(vgpuInstance): + c_vgpu_type = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetType") + ret = fn(vgpuInstance, byref(c_vgpu_type)) + _nvmlCheckReturn(ret) + return c_vgpu_type.value + + +def nvmlVgpuInstanceGetEncoderCapacity(vgpuInstance): + c_encoder_capacity = c_ulonglong(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetEncoderCapacity") + ret = fn(vgpuInstance, byref(c_encoder_capacity)) + _nvmlCheckReturn(ret) + return c_encoder_capacity.value + + +def nvmlVgpuInstanceSetEncoderCapacity(vgpuInstance, encoder_capacity): + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceSetEncoderCapacity") + return fn(vgpuInstance, encoder_capacity) + + +def nvmlVgpuInstanceGetFbUsage(vgpuInstance): + c_fb_usage = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetFbUsage") + ret = fn(vgpuInstance, byref(c_fb_usage)) + _nvmlCheckReturn(ret) + return c_fb_usage.value + + +def nvmlVgpuTypeGetCapabilities(vgpuTypeId, capability): + c_cap_result = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetCapabilities") + ret = fn(vgpuTypeId, _nvmlVgpuCapability_t(capability), byref(c_cap_result)) + _nvmlCheckReturn(ret) + return c_cap_result.value + + +def nvmlVgpuInstanceGetGpuInstanceId(vgpuInstance): + c_id = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetGpuInstanceId") + ret = fn(vgpuInstance, byref(c_id)) + _nvmlCheckReturn(ret) + return c_id.value + + +@convertStrBytes +def 
nvmlVgpuInstanceGetGpuPciId(vgpuInstance): + c_vgpuPciId = create_string_buffer(NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetGpuPciId") + ret = fn( + vgpuInstance, c_vgpuPciId, byref(c_uint(NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE)) + ) + _nvmlCheckReturn(ret) + return c_vgpuPciId.value + + +def nvmlDeviceGetVgpuUtilization(handle, timeStamp): + # first call to get the size + c_vgpu_count = c_uint(0) + c_time_stamp = c_ulonglong(timeStamp) + c_sample_value_type = _nvmlValueType_t() + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuUtilization") + ret = fn( + handle, c_time_stamp, byref(c_sample_value_type), byref(c_vgpu_count), None + ) + + if ret == NVML_SUCCESS: + # special case, no active vGPUs + return [] + elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + # typical case + sampleArray = c_vgpu_count.value * c_nvmlVgpuInstanceUtilizationSample_t + c_samples = sampleArray() + + # make the call again + ret = fn( + handle, + c_time_stamp, + byref(c_sample_value_type), + byref(c_vgpu_count), + c_samples, + ) + _nvmlCheckReturn(ret) + + return c_samples[0 : c_vgpu_count.value] + else: + # error case + raise NVMLError(ret) + + +def nvmlDeviceGetVgpuInstancesUtilizationInfo(handle, timeStamp): + # first call to get the size + c_time_stamp = c_ulonglong(timeStamp) + c_vgpuUtilInfo = c_nvmlVgpuInstancesUtilizationInfo_v1_t(0) + c_vgpuUtilInfo.version = VgpuInstancesUtilizationInfo_v1 + c_vgpuUtilInfo.sampleValType = _nvmlValueType_t() + c_vgpuUtilInfo.vgpuInstanceCount = c_uint(0) + c_vgpuUtilInfo.lastSeenTimeStamp = c_time_stamp + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuInstancesUtilizationInfo") + ret = fn(handle, byref(c_vgpuUtilInfo)) + + if ret == NVML_SUCCESS: + # special case, no active vGPUs + return [] + elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + # typical case + sampleArray = ( + c_vgpuUtilInfo.vgpuInstanceCount * c_nvmlVgpuInstanceUtilizationInfo_v1_t + ) + c_samples = sampleArray() + c_vgpuUtilInfo.vgpuUtilArray = 
c_samples + + # make the call again + ret = fn(handle, byref(c_vgpuUtilInfo)) + _nvmlCheckReturn(ret) + + return c_samples[0 : c_vgpuUtilInfo.vgpuInstanceCount] + else: + # error case + raise NVMLError(ret) + + +def nvmlDeviceGetP2PStatus(device1, device2, p2pIndex): + c_p2pstatus = _nvmlGpuP2PStatus_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetP2PStatus") + ret = fn(device1, device2, p2pIndex, byref(c_p2pstatus)) + _nvmlCheckReturn(ret) + return c_p2pstatus.value + + +def nvmlDeviceGetGridLicensableFeatures_v4(handle): + c_get_grid_licensable_features = c_nvmlGridLicensableFeatures_v4_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGridLicensableFeatures_v4") + ret = fn(handle, byref(c_get_grid_licensable_features)) + _nvmlCheckReturn(ret) + + return c_get_grid_licensable_features + + +def nvmlDeviceGetGridLicensableFeatures(handle): + return nvmlDeviceGetGridLicensableFeatures_v4(handle) + + +def nvmlDeviceGetGspFirmwareVersion(handle, version=None): + isUserDefined = version is not None + if not isUserDefined: + version = (c_char * NVML_GSP_FIRMWARE_VERSION_BUF_SIZE)() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGspFirmwareVersion") + ret = fn(handle, version) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isUserDefined else version.value + + +def nvmlDeviceGetGspFirmwareMode(handle, isEnabled=c_uint(), defaultMode=c_uint()): + isReference = type(isEnabled) is not c_uint + isEnabledRef = isEnabled if isReference else byref(isEnabled) + defaultModeRef = defaultMode if isReference else byref(defaultMode) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGspFirmwareMode") + ret = fn(handle, isEnabledRef, defaultModeRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else [isEnabled.value, defaultMode.value] + + +def nvmlDeviceGetEncoderCapacity(handle, encoderQueryType): + c_encoder_capacity = c_ulonglong(0) + c_encoderQuery_type = _nvmlEncoderQueryType_t(encoderQueryType) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetEncoderCapacity") + ret = 
fn(handle, c_encoderQuery_type, byref(c_encoder_capacity)) + _nvmlCheckReturn(ret) + return c_encoder_capacity.value + + +def nvmlDeviceGetVgpuProcessUtilization(handle, timeStamp): + # first call to get the size + c_vgpu_count = c_uint(0) + c_time_stamp = c_ulonglong(timeStamp) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuProcessUtilization") + ret = fn(handle, c_time_stamp, byref(c_vgpu_count), None) + + if ret == NVML_SUCCESS: + # special case, no active vGPUs + return [] + elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + # typical case + sampleArray = c_vgpu_count.value * c_nvmlVgpuProcessUtilizationSample_t + c_samples = sampleArray() + + # make the call again + ret = fn(handle, c_time_stamp, byref(c_vgpu_count), c_samples) + _nvmlCheckReturn(ret) + + return c_samples[0 : c_vgpu_count.value] + else: + # error case + raise NVMLError(ret) + + +def nvmlDeviceGetVgpuProcessesUtilizationInfo(handle, timeStamp): + # first call to get the size + c_time_stamp = c_ulonglong(timeStamp) + c_vgpuProcUtilInfo = c_nvmlVgpuProcessesUtilizationInfo_v1_t(0) + c_vgpuProcUtilInfo.version = VgpuProcessesUtilizationInfo_v1 + c_vgpuProcUtilInfo.vgpuProcessCount = c_uint(0) + c_vgpuProcUtilInfo.lastSeenTimeStamp = c_time_stamp + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuProcessesUtilizationInfo") + ret = fn(handle, byref(c_vgpuProcUtilInfo)) + + if ret == NVML_SUCCESS: + # special case, no active vGPUs + return [] + elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + # typical case + sampleArray = ( + c_vgpuProcUtilInfo.vgpuProcessCount * c_nvmlVgpuProcessUtilizationInfo_v1_t + ) + c_samples = sampleArray() + c_vgpuProcUtilInfo.vgpuProcUtilArray = c_samples + + # make the call again + ret = fn(handle, byref(c_vgpuProcUtilInfo)) + _nvmlCheckReturn(ret) + + return c_samples[0 : c_vgpuProcUtilInfo.vgpuProcessCount] + else: + # error case + raise NVMLError(ret) + + +def nvmlDeviceGetEncoderStats(handle): + c_encoderCount = c_ulonglong(0) + c_encodeFps = c_ulonglong(0) + c_encoderLatency 
= c_ulonglong(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetEncoderStats") + ret = fn(handle, byref(c_encoderCount), byref(c_encodeFps), byref(c_encoderLatency)) + _nvmlCheckReturn(ret) + return (c_encoderCount.value, c_encodeFps.value, c_encoderLatency.value) + + +def nvmlDeviceGetEncoderSessions(handle): + # first call to get the size + c_session_count = c_uint(0) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetEncoderSessions") + ret = fn(handle, byref(c_session_count), None) + + if ret == NVML_SUCCESS: + if c_session_count.value != 0: + # typical case + session_array = c_nvmlEncoderSession_t * c_session_count.value + c_sessions = session_array() + + # make the call again + ret = fn(handle, byref(c_session_count), c_sessions) + _nvmlCheckReturn(ret) + sessions = [] + for i in range(c_session_count.value): + sessions.append(c_sessions[i]) + return sessions + else: + return [] # no active sessions + else: + # error case + raise NVMLError(ret) + + +def nvmlDeviceGetFBCStats(handle): + c_fbcStats = c_nvmlFBCStats_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetFBCStats") + ret = fn(handle, byref(c_fbcStats)) + _nvmlCheckReturn(ret) + return c_fbcStats + + +def nvmlDeviceGetFBCSessions(handle): + # first call to get the size + c_session_count = c_uint(0) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetFBCSessions") + ret = fn(handle, byref(c_session_count), None) + + if ret == NVML_SUCCESS: + if c_session_count.value != 0: + # typical case + session_array = c_nvmlFBCSession_t * c_session_count.value + c_sessions = session_array() + + # make the call again + ret = fn(handle, byref(c_session_count), c_sessions) + _nvmlCheckReturn(ret) + sessions = [] + for i in range(c_session_count.value): + sessions.append(c_sessions[i]) + return sessions + else: + return [] # no active sessions + else: + # error case + raise NVMLError(ret) + + +def nvmlVgpuInstanceGetEncoderStats(vgpuInstance): + c_encoderCount = c_ulonglong(0) + c_encodeFps = c_ulonglong(0) + c_encoderLatency = 
c_ulonglong(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetEncoderStats") + ret = fn( + vgpuInstance, byref(c_encoderCount), byref(c_encodeFps), byref(c_encoderLatency) + ) + _nvmlCheckReturn(ret) + return (c_encoderCount.value, c_encodeFps.value, c_encoderLatency.value) + + +def nvmlVgpuInstanceGetEncoderSessions(vgpuInstance): + # first call to get the size + c_session_count = c_uint(0) + + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetEncoderSessions") + ret = fn(vgpuInstance, byref(c_session_count), None) + + if ret == NVML_SUCCESS: + if c_session_count.value != 0: + # typical case + session_array = c_nvmlEncoderSession_t * c_session_count.value + c_sessions = session_array() + + # make the call again + ret = fn(vgpuInstance, byref(c_session_count), c_sessions) + _nvmlCheckReturn(ret) + sessions = [] + for i in range(c_session_count.value): + sessions.append(c_sessions[i]) + return sessions + else: + return [] # no active sessions + else: + # error case + raise NVMLError(ret) + + +def nvmlVgpuInstanceGetFBCStats(vgpuInstance): + c_fbcStats = c_nvmlFBCStats_t() + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetFBCStats") + ret = fn(vgpuInstance, byref(c_fbcStats)) + _nvmlCheckReturn(ret) + return c_fbcStats + + +def nvmlVgpuInstanceGetFBCSessions(vgpuInstance): + # first call to get the size + c_session_count = c_uint(0) + + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetFBCSessions") + ret = fn(vgpuInstance, byref(c_session_count), None) + + if ret == NVML_SUCCESS: + if c_session_count.value != 0: + # typical case + session_array = c_nvmlFBCSession_t * c_session_count.value + c_sessions = session_array() + + # make the call again + ret = fn(vgpuInstance, byref(c_session_count), c_sessions) + _nvmlCheckReturn(ret) + sessions = [] + for i in range(c_session_count.value): + sessions.append(c_sessions[i]) + return sessions + else: + return [] # no active sessions + else: + # error case + raise NVMLError(ret) + + +def 
nvmlDeviceGetProcessUtilization(handle, timeStamp): + # first call to get the size + c_count = c_uint(0) + c_time_stamp = c_ulonglong(timeStamp) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetProcessUtilization") + ret = fn(handle, None, byref(c_count), c_time_stamp) + + if ret == NVML_ERROR_INSUFFICIENT_SIZE: + # typical case + sampleArray = c_count.value * c_nvmlProcessUtilizationSample_t + c_samples = sampleArray() + + # make the call again + ret = fn(handle, c_samples, byref(c_count), c_time_stamp) + _nvmlCheckReturn(ret) + + return c_samples[0 : c_count.value] + else: + # error case + raise NVMLError(ret) + + +def nvmlDeviceGetProcessesUtilizationInfo(handle, timeStamp): + # first call to get the size + c_time_stamp = c_ulonglong(timeStamp) + c_processesUtilInfo = c_nvmlProcessesUtilizationInfo_v1_t(0) + c_processesUtilInfo.version = ProcessesUtilizationInfo_v1 + c_processesUtilInfo.processSamplesCount = c_uint(0) + c_processesUtilInfo.lastSeenTimeStamp = c_time_stamp + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetProcessesUtilizationInfo") + ret = fn(handle, byref(c_processesUtilInfo)) + + if ret == NVML_ERROR_INSUFFICIENT_SIZE: + # typical case + sampleArray = ( + c_processesUtilInfo.processSamplesCount * c_nvmlProcessUtilizationInfo_v1_t + ) + c_samples = sampleArray() + c_processesUtilInfo.procUtilArray = c_samples + + # make the call again + ret = fn(handle, byref(c_processesUtilInfo)) + _nvmlCheckReturn(ret) + + return c_samples[0 : c_processesUtilInfo.processSamplesCount] + else: + # error case + raise NVMLError(ret) + + +def nvmlVgpuInstanceGetMetadata(vgpuInstance): + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetMetadata") + c_vgpuMetadata = c_nvmlVgpuMetadata_t() + c_bufferSize = c_uint(0) + # Make the first NVML API call to get the c_bufferSize value. + # We have already allocated required buffer above. 
+ ret = fn(vgpuInstance, byref(c_vgpuMetadata), byref(c_bufferSize)) + if ret == NVML_ERROR_INSUFFICIENT_SIZE: + ret = fn(vgpuInstance, byref(c_vgpuMetadata), byref(c_bufferSize)) + _nvmlCheckReturn(ret) + else: + raise NVMLError(ret) + return c_vgpuMetadata + + +def nvmlDeviceGetVgpuMetadata(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuMetadata") + c_vgpuPgpuMetadata = c_nvmlVgpuPgpuMetadata_t() + c_bufferSize = c_uint(0) + # Make the first NVML API call to get the c_bufferSize value. + # We have already allocated required buffer above. + ret = fn(handle, byref(c_vgpuPgpuMetadata), byref(c_bufferSize)) + if ret == NVML_ERROR_INSUFFICIENT_SIZE: + ret = fn(handle, byref(c_vgpuPgpuMetadata), byref(c_bufferSize)) + _nvmlCheckReturn(ret) + else: + raise NVMLError(ret) + return c_vgpuPgpuMetadata + + +def nvmlGetVgpuCompatibility(vgpuMetadata, pgpuMetadata): + fn = _nvmlGetFunctionPointer("nvmlGetVgpuCompatibility") + c_vgpuPgpuCompatibility = c_nvmlVgpuPgpuCompatibility_t() + ret = fn(byref(vgpuMetadata), byref(pgpuMetadata), byref(c_vgpuPgpuCompatibility)) + _nvmlCheckReturn(ret) + return c_vgpuPgpuCompatibility + + +@convertStrBytes +def nvmlDeviceGetPgpuMetadataString(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPgpuMetadataString") + c_pgpuMetadata = create_string_buffer(NVML_VGPU_PGPU_METADATA_OPAQUE_DATA_SIZE) + c_bufferSize = c_uint(0) + # Make the first NVML API call to get the c_bufferSize value. + # We have already allocated required buffer above. 
+ ret = fn(handle, byref(c_pgpuMetadata), byref(c_bufferSize)) + if ret == NVML_ERROR_INSUFFICIENT_SIZE: + ret = fn(handle, byref(c_pgpuMetadata), byref(c_bufferSize)) + _nvmlCheckReturn(ret) + else: + raise NVMLError(ret) + return (c_pgpuMetadata.value, c_bufferSize.value) + + +def nvmlDeviceGetVgpuSchedulerLog(handle): + c_vgpu_sched_log = c_nvmlVgpuSchedulerLog_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuSchedulerLog") + ret = fn(handle, byref(c_vgpu_sched_log)) + _nvmlCheckReturn(ret) + return c_vgpu_sched_log + + +def nvmlDeviceGetVgpuSchedulerState(handle): + c_vgpu_sched_state = c_nvmlVgpuSchedulerGetState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuSchedulerState") + ret = fn(handle, byref(c_vgpu_sched_state)) + _nvmlCheckReturn(ret) + return c_vgpu_sched_state + + +def nvmlDeviceGetVgpuSchedulerCapabilities(handle): + c_vgpu_sched_caps = c_nvmlVgpuSchedulerCapabilities_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuSchedulerCapabilities") + ret = fn(handle, byref(c_vgpu_sched_caps)) + _nvmlCheckReturn(ret) + return c_vgpu_sched_caps + + +def nvmlDeviceSetVgpuSchedulerState(handle, sched_state): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetVgpuSchedulerState") + ret = fn(handle, byref(sched_state)) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlSetVgpuVersion(vgpuVersion): + fn = _nvmlGetFunctionPointer("nvmlSetVgpuVersion") + ret = fn(byref(vgpuVersion)) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlGetVgpuVersion(supported=None, current=None): + isUserDefined = (supported is not None) or (current is not None) + if not isUserDefined: + supported = c_nvmlVgpuVersion_t() + current = c_nvmlVgpuVersion_t() + fn = _nvmlGetFunctionPointer("nvmlGetVgpuVersion") + ret = fn(byref(supported), byref(current)) + _nvmlCheckReturn(ret) + return ( + NVML_SUCCESS + if isUserDefined + else [ + (supported.minVersion, supported.maxVersion), + (current.minVersion, current.maxVersion), + ] + ) + + +def 
nvmlVgpuInstanceGetAccountingMode(vgpuInstance): + c_mode = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetAccountingMode") + ret = fn(vgpuInstance, byref(c_mode)) + _nvmlCheckReturn(ret) + return c_mode.value + + +def nvmlVgpuInstanceGetAccountingPids(vgpuInstance): + c_pidCount = c_uint() + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetAccountingPids") + ret = fn(vgpuInstance, byref(c_pidCount), None) + if ret == NVML_ERROR_INSUFFICIENT_SIZE: + sampleArray = c_pidCount.value * c_uint + c_pidArray = sampleArray() + ret = fn(vgpuInstance, byref(c_pidCount), byref(c_pidArray)) + _nvmlCheckReturn(ret) + else: + raise NVMLError(ret) + return (c_pidCount, c_pidArray) + + +def nvmlVgpuInstanceGetAccountingStats(vgpuInstance, pid): + c_accountingStats = c_nvmlAccountingStats_t() + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetAccountingStats") + ret = fn(vgpuInstance, pid, byref(c_accountingStats)) + _nvmlCheckReturn(ret) + return c_accountingStats + + +def nvmlVgpuInstanceClearAccountingPids(vgpuInstance): + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceClearAccountingPids") + ret = fn(vgpuInstance) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlGetExcludedDeviceCount(): + c_count = c_uint() + fn = _nvmlGetFunctionPointer("nvmlGetExcludedDeviceCount") + ret = fn(byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + + +def nvmlGetExcludedDeviceInfoByIndex(index): + c_index = c_uint(index) + info = c_nvmlExcludedDeviceInfo_t() + fn = _nvmlGetFunctionPointer("nvmlGetExcludedDeviceInfoByIndex") + ret = fn(c_index, byref(info)) + _nvmlCheckReturn(ret) + return info + + +def nvmlDeviceGetHostVgpuMode(handle): + c_host_vgpu_mode = _nvmlHostVgpuMode_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetHostVgpuMode") + ret = fn(handle, byref(c_host_vgpu_mode)) + _nvmlCheckReturn(ret) + return c_host_vgpu_mode.value + + +def nvmlDeviceSetMigMode(device, mode): + c_activationStatus = c_uint() + fn = 
_nvmlGetFunctionPointer("nvmlDeviceSetMigMode") + ret = fn(device, mode, byref(c_activationStatus)) + _nvmlCheckReturn(ret) + return c_activationStatus.value + + +def nvmlDeviceGetMigMode(device): + c_currentMode = c_uint() + c_pendingMode = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMigMode") + ret = fn(device, byref(c_currentMode), byref(c_pendingMode)) + _nvmlCheckReturn(ret) + return [c_currentMode.value, c_pendingMode.value] + + +def nvmlDeviceGetGpuInstanceProfileInfo(device, profile, version=2): + if version == 2: + c_info = c_nvmlGpuInstanceProfileInfo_v2_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuInstanceProfileInfoV") + elif version == 1: + c_info = c_nvmlGpuInstanceProfileInfo_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuInstanceProfileInfo") + else: + raise NVMLError(NVML_ERROR_FUNCTION_NOT_FOUND) + ret = fn(device, profile, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + + +# Define function alias for the API exposed by NVML +nvmlDeviceGetGpuInstanceProfileInfoV = nvmlDeviceGetGpuInstanceProfileInfo + + +def nvmlDeviceGetGpuInstanceRemainingCapacity(device, profileId): + c_count = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuInstanceRemainingCapacity") + ret = fn(device, profileId, byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + + +def nvmlDeviceGetGpuInstancePossiblePlacements( + device, profileId, placementsRef, countRef +): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuInstancePossiblePlacements_v2") + ret = fn(device, profileId, placementsRef, countRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlDeviceCreateGpuInstance(device, profileId): + c_instance = c_nvmlGpuInstance_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceCreateGpuInstance") + ret = fn(device, profileId, byref(c_instance)) + _nvmlCheckReturn(ret) + return c_instance + + +def nvmlDeviceCreateGpuInstanceWithPlacement(device, profileId, placement): + c_instance = c_nvmlGpuInstance_t() + fn = 
_nvmlGetFunctionPointer("nvmlDeviceCreateGpuInstanceWithPlacement") + ret = fn(device, profileId, placement, byref(c_instance)) + _nvmlCheckReturn(ret) + return c_instance + + +def nvmlGpuInstanceDestroy(gpuInstance): + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceDestroy") + ret = fn(gpuInstance) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlDeviceGetGpuInstances(device, profileId, gpuInstancesRef, countRef): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuInstances") + ret = fn(device, profileId, gpuInstancesRef, countRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlDeviceGetGpuInstanceById(device, gpuInstanceId): + c_instance = c_nvmlGpuInstance_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuInstanceById") + ret = fn(device, gpuInstanceId, byref(c_instance)) + _nvmlCheckReturn(ret) + return c_instance + + +def nvmlGpuInstanceGetInfo(gpuInstance): + c_info = c_nvmlGpuInstanceInfo_t() + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetInfo") + ret = fn(gpuInstance, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + + +def nvmlGpuInstanceGetComputeInstanceProfileInfo( + device, profile, engProfile, version=2 +): + if version == 2: + c_info = c_nvmlComputeInstanceProfileInfo_v2_t() + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetComputeInstanceProfileInfoV") + elif version == 1: + c_info = c_nvmlComputeInstanceProfileInfo_t() + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetComputeInstanceProfileInfo") + else: + raise NVMLError(NVML_ERROR_FUNCTION_NOT_FOUND) + ret = fn(device, profile, engProfile, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + + +# Define function alias for the API exposed by NVML +nvmlGpuInstanceGetComputeInstanceProfileInfoV = ( + nvmlGpuInstanceGetComputeInstanceProfileInfo +) + + +def nvmlGpuInstanceGetComputeInstanceRemainingCapacity(gpuInstance, profileId): + c_count = c_uint() + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetComputeInstanceRemainingCapacity") + ret = fn(gpuInstance, 
profileId, byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + + +def nvmlGpuInstanceGetComputeInstancePossiblePlacements( + gpuInstance, profileId, placementsRef, countRef +): + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetComputeInstancePossiblePlacements") + ret = fn(gpuInstance, profileId, placementsRef, countRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlGpuInstanceCreateComputeInstance(gpuInstance, profileId): + c_instance = c_nvmlComputeInstance_t() + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceCreateComputeInstance") + ret = fn(gpuInstance, profileId, byref(c_instance)) + _nvmlCheckReturn(ret) + return c_instance + + +def nvmlGpuInstanceCreateComputeInstanceWithPlacement( + gpuInstance, profileId, placement +): + c_instance = c_nvmlComputeInstance_t() + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceCreateComputeInstanceWithPlacement") + ret = fn(gpuInstance, profileId, placement, byref(c_instance)) + _nvmlCheckReturn(ret) + return c_instance + + +def nvmlComputeInstanceDestroy(computeInstance): + fn = _nvmlGetFunctionPointer("nvmlComputeInstanceDestroy") + ret = fn(computeInstance) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlGpuInstanceGetComputeInstances( + gpuInstance, profileId, computeInstancesRef, countRef +): + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetComputeInstances") + ret = fn(gpuInstance, profileId, computeInstancesRef, countRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlGpuInstanceGetComputeInstanceById(gpuInstance, computeInstanceId): + c_instance = c_nvmlComputeInstance_t() + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetComputeInstanceById") + ret = fn(gpuInstance, computeInstanceId, byref(c_instance)) + _nvmlCheckReturn(ret) + return c_instance + + +def nvmlComputeInstanceGetInfo_v2(computeInstance): + c_info = c_nvmlComputeInstanceInfo_t() + fn = _nvmlGetFunctionPointer("nvmlComputeInstanceGetInfo_v2") + ret = fn(computeInstance, byref(c_info)) + _nvmlCheckReturn(ret) 
+ return c_info + + +def nvmlComputeInstanceGetInfo(computeInstance): + return nvmlComputeInstanceGetInfo_v2(computeInstance) + + +def nvmlDeviceIsMigDeviceHandle(device): + c_isMigDevice = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceIsMigDeviceHandle") + ret = fn(device, byref(c_isMigDevice)) + _nvmlCheckReturn(ret) + return c_isMigDevice + + +def nvmlDeviceGetGpuInstanceId(device): + c_gpuInstanceId = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuInstanceId") + ret = fn(device, byref(c_gpuInstanceId)) + _nvmlCheckReturn(ret) + return c_gpuInstanceId.value + + +def nvmlDeviceGetComputeInstanceId(device): + c_computeInstanceId = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetComputeInstanceId") + ret = fn(device, byref(c_computeInstanceId)) + _nvmlCheckReturn(ret) + return c_computeInstanceId.value + + +def nvmlDeviceGetMaxMigDeviceCount(device): + c_count = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMaxMigDeviceCount") + ret = fn(device, byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + + +def nvmlDeviceGetMigDeviceHandleByIndex(device, index): + c_index = c_uint(index) + migDevice = c_nvmlDevice_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMigDeviceHandleByIndex") + ret = fn(device, c_index, byref(migDevice)) + _nvmlCheckReturn(ret) + return migDevice + + +def nvmlDeviceGetDeviceHandleFromMigDeviceHandle(migDevice): + device = c_nvmlDevice_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDeviceHandleFromMigDeviceHandle") + ret = fn(migDevice, byref(device)) + _nvmlCheckReturn(ret) + return device + + +def nvmlDeviceGetAttributes_v2(device): + c_attrs = c_nvmlDeviceAttributes() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAttributes_v2") + ret = fn(device, byref(c_attrs)) + _nvmlCheckReturn(ret) + return c_attrs + + +def nvmlDeviceGetAttributes(device): + return nvmlDeviceGetAttributes_v2(device) + + +def nvmlDeviceGetRemappedRows(device): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetRemappedRows") + c_corr = 
c_uint() + c_unc = c_uint() + c_bpending = c_uint() + c_bfailure = c_uint() + ret = fn(device, byref(c_corr), byref(c_unc), byref(c_bpending), byref(c_bfailure)) + _nvmlCheckReturn(ret) + return (c_corr.value, c_unc.value, c_bpending.value, c_bfailure.value) + + +def nvmlDeviceGetRowRemapperHistogram(device): + c_vals = c_nvmlRowRemapperHistogramValues() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetRowRemapperHistogram") + ret = fn(device, byref(c_vals)) + _nvmlCheckReturn(ret) + return c_vals + + +def nvmlDeviceGetArchitecture(device): + arch = _nvmlDeviceArchitecture_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetArchitecture") + ret = fn(device, byref(arch)) + _nvmlCheckReturn(ret) + return arch.value + + +def nvmlDeviceGetBusType(device): + c_busType = _nvmlBusType_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetBusType") + ret = fn(device, byref(c_busType)) + _nvmlCheckReturn(ret) + return c_busType.value + + +def nvmlDeviceGetIrqNum(device): + c_irqNum = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetIrqNum") + ret = fn(device, byref(c_irqNum)) + _nvmlCheckReturn(ret) + return c_irqNum.value + + +def nvmlDeviceGetNumGpuCores(device): + c_numCores = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNumGpuCores") + ret = fn(device, byref(c_numCores)) + _nvmlCheckReturn(ret) + return c_numCores.value + + +def nvmlDeviceGetPowerSource(device): + c_powerSource = _nvmlPowerSource_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerSource") + ret = fn(device, byref(c_powerSource)) + _nvmlCheckReturn(ret) + return c_powerSource.value + + +def nvmlDeviceGetMemoryBusWidth(device): + c_memBusWidth = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMemoryBusWidth") + ret = fn(device, byref(c_memBusWidth)) + _nvmlCheckReturn(ret) + return c_memBusWidth.value + + +def nvmlDeviceGetPcieLinkMaxSpeed(device): + c_speed = _nvmlPcieLinkMaxSpeed_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPcieLinkMaxSpeed") + ret = fn(device, byref(c_speed)) + 
_nvmlCheckReturn(ret) + return c_speed.value + + +def nvmlDeviceGetAdaptiveClockInfoStatus(device): + c_adaptiveClockInfoStatus = _nvmlAdaptiveClockInfoStatus_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAdaptiveClockInfoStatus") + ret = fn(device, byref(c_adaptiveClockInfoStatus)) + _nvmlCheckReturn(ret) + return c_adaptiveClockInfoStatus.value + + +def nvmlDeviceGetPcieSpeed(device): + c_speed = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPcieSpeed") + ret = fn(device, byref(c_speed)) + _nvmlCheckReturn(ret) + return c_speed.value + + +def nvmlDeviceGetDynamicPstatesInfo( + device, c_dynamicpstatesinfo=c_nvmlGpuDynamicPstatesInfo_t() +): + isReference = type(c_dynamicpstatesinfo) is not c_nvmlGpuDynamicPstatesInfo_t + dynamicpstatesinfoRef = ( + c_dynamicpstatesinfo if isReference else byref(c_dynamicpstatesinfo) + ) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDynamicPstatesInfo") + ret = fn(device, dynamicpstatesinfoRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else c_dynamicpstatesinfo + + +def nvmlDeviceSetFanSpeed_v2(handle, index, speed): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetFanSpeed_v2") + ret = fn(handle, index, speed) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlDeviceGetThermalSettings( + device, sensorindex, c_thermalsettings=c_nvmlGpuThermalSettings_t() +): + isReference = type(c_thermalsettings) is not c_nvmlGpuThermalSettings_t + thermalsettingsRef = c_thermalsettings if isReference else byref(c_thermalsettings) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetThermalSettings") + ret = fn(device, sensorindex, thermalsettingsRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else c_thermalsettings.sensor[:] + + +def nvmlDeviceGetMinMaxClockOfPState( + device, clockType, pstate, minClockMHz=c_uint(), maxClockMHz=c_uint() +): + isReference = (type(minClockMHz) is not c_uint) or (type(maxClockMHz) is not c_uint) + minClockMHzRef = minClockMHz if isReference else byref(minClockMHz) + 
maxClockMHzRef = maxClockMHz if isReference else byref(maxClockMHz) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMinMaxClockOfPState") + ret = fn( + device, + _nvmlClockType_t(clockType), + _nvmlClockType_t(pstate), + minClockMHzRef, + maxClockMHzRef, + ) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else (minClockMHz.value, maxClockMHz.value) + + +class c_nvmlClockOffset_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("type", _nvmlClockType_t), + ("pstate", _nvmlPstates_t), + ("clockOffsetMHz", c_int), + ("minClockOffsetMHz", c_int), + ("maxClockOffsetMHz", c_int), + ] + + +nvmlClockOffset_v1 = 0x1000018 + + +def nvmlDeviceGetClockOffsets(device, info): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetClockOffsets") + ret = fn(device, info) + return NVML_SUCCESS + + +def nvmlDeviceSetClockOffsets(device, info): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetClockOffsets") + ret = fn(device, info) + return NVML_SUCCESS + + +def nvmlDeviceGetSupportedPerformanceStates(device): + pstates = [] + c_count = c_uint(NVML_MAX_GPU_PERF_PSTATES) + c_size = sizeof(c_uint) * c_count.value + + # NOTE: use 'c_uint' to represent the size of the nvmlPstate_t enumeration. 
+ pstates_array = _nvmlPstates_t * c_count.value + c_pstates = pstates_array() + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedPerformanceStates") + ret = fn(device, c_pstates, c_size) + _nvmlCheckReturn(ret) + + for value in c_pstates: + if value != NVML_PSTATE_UNKNOWN: + pstates.append(value) + + return pstates + + +def nvmlDeviceGetGpcClkVfOffset(device): + offset = c_int32() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpcClkVfOffset") + ret = fn(device, byref(offset)) + _nvmlCheckReturn(ret) + return offset.value + + +def nvmlDeviceSetGpcClkVfOffset(device, offset): + c_offset = c_int32(offset) + fn = _nvmlGetFunctionPointer("nvmlDeviceSetGpcClkVfOffset") + ret = fn(device, c_offset) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlDeviceGetGpcClkMinMaxVfOffset(device, minOffset=c_int(), maxOffset=c_int()): + isReference = (type(minOffset) is not c_int) or (type(maxOffset) is not c_int) + minOffsetRef = minOffset if isReference else byref(minOffset) + maxOffsetRef = maxOffset if isReference else byref(maxOffset) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpcClkMinMaxVfOffset") + ret = fn(device, minOffsetRef, maxOffsetRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else (minOffset.value, maxOffset.value) + + +def nvmlDeviceGetMemClkVfOffset(device): + offset = c_int32() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMemClkVfOffset") + ret = fn(device, byref(offset)) + _nvmlCheckReturn(ret) + return offset.value + + +def nvmlDeviceSetMemClkVfOffset(device, offset): + c_offset = c_int32(offset) + fn = _nvmlGetFunctionPointer("nvmlDeviceSetMemClkVfOffset") + ret = fn(device, c_offset) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlDeviceGetMemClkMinMaxVfOffset(device, minOffset=c_int(), maxOffset=c_int()): + isReference = (type(minOffset) is not c_int) or (type(maxOffset) is not c_int) + minOffsetRef = minOffset if isReference else byref(minOffset) + maxOffsetRef = maxOffset if isReference else byref(maxOffset) + + 
fn = _nvmlGetFunctionPointer("nvmlDeviceGetMemClkMinMaxVfOffset") + ret = fn(device, minOffsetRef, maxOffsetRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else (minOffset.value, maxOffset.value) + + +def nvmlSystemSetConfComputeGpusReadyState(state): + c_state = c_uint(state) + fn = _nvmlGetFunctionPointer("nvmlSystemSetConfComputeGpusReadyState") + ret = fn(c_state) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlSystemGetConfComputeGpusReadyState(): + c_state = c_uint() + fn = _nvmlGetFunctionPointer("nvmlSystemGetConfComputeGpusReadyState") + ret = fn(byref(c_state)) + _nvmlCheckReturn(ret) + return c_state.value + + +def nvmlSystemGetConfComputeCapabilities(): + c_ccSysCaps = c_nvmlConfComputeSystemCaps_t() + fn = _nvmlGetFunctionPointer("nvmlSystemGetConfComputeCapabilities") + ret = fn(byref(c_ccSysCaps)) + _nvmlCheckReturn(ret) + return c_ccSysCaps + + +def nvmlSystemGetConfComputeState(): + c_state = c_nvmlConfComputeSystemState_t() + fn = _nvmlGetFunctionPointer("nvmlSystemGetConfComputeState") + ret = fn(byref(c_state)) + _nvmlCheckReturn(ret) + return c_state + + +def nvmlSystemGetConfComputeSettings(settings): + fn = _nvmlGetFunctionPointer("nvmlSystemGetConfComputeSettings") + return fn(settings) + + +def nvmlDeviceSetConfComputeUnprotectedMemSize(device, c_ccMemSize): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetConfComputeUnprotectedMemSize") + ret = fn(device, c_ccMemSize) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlDeviceGetConfComputeMemSizeInfo(device): + c_ccMemSize = c_nvmlConfComputeMemSizeInfo_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetConfComputeMemSizeInfo") + ret = fn(device, byref(c_ccMemSize)) + _nvmlCheckReturn(ret) + return c_ccMemSize + + +def nvmlDeviceGetConfComputeProtectedMemoryUsage(device): + c_memory = c_nvmlMemory_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetConfComputeProtectedMemoryUsage") + ret = fn(device, byref(c_memory)) + _nvmlCheckReturn(ret) + return c_memory + + 
def nvmlDeviceGetConfComputeGpuCertificate(device):
    """Fetch the confidential-compute GPU certificate for ``device``.

    A fresh ``c_nvmlConfComputeGpuCertificate_t`` is passed by reference to
    the NVML entry point of the same name. The return code is handed to
    ``_nvmlCheckReturn`` and the filled structure is returned.
    """
    certificate = c_nvmlConfComputeGpuCertificate_t()
    query = _nvmlGetFunctionPointer("nvmlDeviceGetConfComputeGpuCertificate")
    _nvmlCheckReturn(query(device, byref(certificate)))
    return certificate


def nvmlDeviceGetConfComputeGpuAttestationReport(device, c_nonce):
    """Request a confidential-compute attestation report for ``device``.

    ``c_nonce`` is an iterable of byte values with a length; it is copied
    into a ``c_uint8`` array and stored in the report's ``nonce`` field
    before the NVML call. The return code is handed to ``_nvmlCheckReturn``
    and the filled ``c_nvmlConfComputeGpuAttestationReport_t`` is returned.
    """
    report = c_nvmlConfComputeGpuAttestationReport_t()
    # The array is built with exactly len(c_nonce) elements.
    # NOTE(review): ctypes will presumably reject the assignment below if
    # that length differs from the struct's declared nonce size - confirm.
    nonce_buffer = (c_uint8 * len(c_nonce))(*c_nonce)
    report.nonce = nonce_buffer
    query = _nvmlGetFunctionPointer("nvmlDeviceGetConfComputeGpuAttestationReport")
    _nvmlCheckReturn(query(device, byref(report)))
    return report


def nvmlSystemSetConfComputeKeyRotationThresholdInfo(max_atk_adv):
    """Set the confidential-compute key-rotation threshold.

    ``max_atk_adv`` is written to the ``maxAttackerAdvantage`` field of a
    versioned request structure. Returns ``NVML_SUCCESS`` once
    ``_nvmlCheckReturn`` has accepted the NVML return code.
    """
    request = c_nvmlConfComputeSetKeyRotationThresholdInfo_t(0)
    request.version = ConfComputeSetKeyRotationThresholdInfo_v1
    request.maxAttackerAdvantage = max_atk_adv
    setter = _nvmlGetFunctionPointer("nvmlSystemSetConfComputeKeyRotationThresholdInfo")
    _nvmlCheckReturn(setter(byref(request)))
    return NVML_SUCCESS


def nvmlSystemGetConfComputeKeyRotationThresholdInfo():
    """Read the confidential-compute key-rotation threshold.

    Returns the ``c_nvmlConfComputeGetKeyRotationThresholdInfo_t`` filled in
    by NVML, after its return code has been handed to ``_nvmlCheckReturn``.
    """
    info = c_nvmlConfComputeGetKeyRotationThresholdInfo_t(0)
    info.version = ConfComputeGetKeyRotationThresholdInfo_v1
    getter = _nvmlGetFunctionPointer("nvmlSystemGetConfComputeKeyRotationThresholdInfo")
    _nvmlCheckReturn(getter(byref(info)))
    return info


## GPM ##
#########

## Enums/defines

#### GPM Metric Identifiers
# Percentage of time any compute/graphics app was active on the GPU. 0.0 - 100.0
NVML_GPM_METRIC_GRAPHICS_UTIL = 1
# Percentage of SMs that were busy. 0.0 - 100.0
NVML_GPM_METRIC_SM_UTIL = 2
# Percentage of warps that were active vs theoretical maximum. 0.0 - 100.0
NVML_GPM_METRIC_SM_OCCUPANCY = 3
# Percentage of time the GPU's SMs were doing integer operations. 0.0 - 100.0
NVML_GPM_METRIC_INTEGER_UTIL = 4
# Percentage of time the GPU's SMs were doing ANY tensor operations. 0.0 - 100.0
NVML_GPM_METRIC_ANY_TENSOR_UTIL = 5
# Percentage of time the GPU's SMs were doing DFMA tensor operations. 0.0 - 100.0
NVML_GPM_METRIC_DFMA_TENSOR_UTIL = 6
# Percentage of time the GPU's SMs were doing HMMA tensor operations. 0.0 - 100.0
NVML_GPM_METRIC_HMMA_TENSOR_UTIL = 7
# Percentage of time the GPU's SMs were doing IMMA tensor operations. 0.0 - 100.0
# (ID 8 is not assigned in this table.)
NVML_GPM_METRIC_IMMA_TENSOR_UTIL = 9
# Percentage of DRAM bw used vs theoretical maximum. 0.0 - 100.0
NVML_GPM_METRIC_DRAM_BW_UTIL = 10
# Percentage of time the GPU's SMs were doing non-tensor FP64 math. 0.0 - 100.0
NVML_GPM_METRIC_FP64_UTIL = 11
# Percentage of time the GPU's SMs were doing non-tensor FP32 math. 0.0 - 100.0
NVML_GPM_METRIC_FP32_UTIL = 12
# Percentage of time the GPU's SMs were doing non-tensor FP16 math. 0.0 - 100.0
NVML_GPM_METRIC_FP16_UTIL = 13
# PCIe traffic from this GPU in MiB/sec
NVML_GPM_METRIC_PCIE_TX_PER_SEC = 20
# PCIe traffic to this GPU in MiB/sec
NVML_GPM_METRIC_PCIE_RX_PER_SEC = 21
# Percent utilization of NVDEC 0. 0.0 - 100.0
NVML_GPM_METRIC_NVDEC_0_UTIL = 30
# Percent utilization of NVDEC 1. 0.0 - 100.0
NVML_GPM_METRIC_NVDEC_1_UTIL = 31
# Percent utilization of NVDEC 2. 0.0 - 100.0
NVML_GPM_METRIC_NVDEC_2_UTIL = 32
# Percent utilization of NVDEC 3. 0.0 - 100.0
NVML_GPM_METRIC_NVDEC_3_UTIL = 33
# Percent utilization of NVDEC 4. 0.0 - 100.0
NVML_GPM_METRIC_NVDEC_4_UTIL = 34
# Percent utilization of NVDEC 5. 0.0 - 100.0
NVML_GPM_METRIC_NVDEC_5_UTIL = 35
# Percent utilization of NVDEC 6. 0.0 - 100.0
NVML_GPM_METRIC_NVDEC_6_UTIL = 36
# Percent utilization of NVDEC 7. 0.0 - 100.0
NVML_GPM_METRIC_NVDEC_7_UTIL = 37
# Percent utilization of NVJPG 0. 0.0 - 100.0
NVML_GPM_METRIC_NVJPG_0_UTIL = 40
# Percent utilization of NVJPG 1. 0.0 - 100.0
NVML_GPM_METRIC_NVJPG_1_UTIL = 41
# Percent utilization of NVJPG 2. 0.0 - 100.0
NVML_GPM_METRIC_NVJPG_2_UTIL = 42
# Percent utilization of NVJPG 3. 0.0 - 100.0
NVML_GPM_METRIC_NVJPG_3_UTIL = 43
# Percent utilization of NVJPG 4. 0.0 - 100.0
NVML_GPM_METRIC_NVJPG_4_UTIL = 44
# Percent utilization of NVJPG 5. 0.0 - 100.0
NVML_GPM_METRIC_NVJPG_5_UTIL = 45
# Percent utilization of NVJPG 6. 0.0 - 100.0
NVML_GPM_METRIC_NVJPG_6_UTIL = 46
# Percent utilization of NVJPG 7. 0.0 - 100.0
NVML_GPM_METRIC_NVJPG_7_UTIL = 47
# Percent utilization of NVOFA 0. 0.0 - 100.0
NVML_GPM_METRIC_NVOFA_0_UTIL = 50
# Percent utilization of NVOFA 1. 0.0 - 100.0
NVML_GPM_METRIC_NVOFA_1_UTIL = 51
# NvLink read bandwidth for all links in MiB/sec
NVML_GPM_METRIC_NVLINK_TOTAL_RX_PER_SEC = 60
# NvLink write bandwidth for all links in MiB/sec
NVML_GPM_METRIC_NVLINK_TOTAL_TX_PER_SEC = 61
# NvLink read bandwidth for link 0 in MiB/sec
NVML_GPM_METRIC_NVLINK_L0_RX_PER_SEC = 62
# NvLink write bandwidth for link 0 in MiB/sec
NVML_GPM_METRIC_NVLINK_L0_TX_PER_SEC = 63
# NvLink read bandwidth for link 1 in MiB/sec
NVML_GPM_METRIC_NVLINK_L1_RX_PER_SEC = 64
# NvLink write bandwidth for link 1 in MiB/sec
NVML_GPM_METRIC_NVLINK_L1_TX_PER_SEC = 65
# NvLink read bandwidth for link 2 in MiB/sec
NVML_GPM_METRIC_NVLINK_L2_RX_PER_SEC = 66
# NvLink write bandwidth for link 2 in MiB/sec
NVML_GPM_METRIC_NVLINK_L2_TX_PER_SEC = 67
# NvLink read bandwidth for link 3 in MiB/sec
NVML_GPM_METRIC_NVLINK_L3_RX_PER_SEC = 68
# NvLink write bandwidth for link 3 in MiB/sec
NVML_GPM_METRIC_NVLINK_L3_TX_PER_SEC = 69
# NvLink read bandwidth for link 4 in MiB/sec
NVML_GPM_METRIC_NVLINK_L4_RX_PER_SEC = 70
# NvLink write bandwidth for link 4 in MiB/sec
NVML_GPM_METRIC_NVLINK_L4_TX_PER_SEC = 71
# NvLink read bandwidth for link 5 in MiB/sec
NVML_GPM_METRIC_NVLINK_L5_RX_PER_SEC = 72
# NvLink write bandwidth for link 5 in MiB/sec
NVML_GPM_METRIC_NVLINK_L5_TX_PER_SEC = 73
# NvLink read bandwidth for link 6 in MiB/sec
NVML_GPM_METRIC_NVLINK_L6_RX_PER_SEC = 74
# NvLink write bandwidth for link 6 in MiB/sec
NVML_GPM_METRIC_NVLINK_L6_TX_PER_SEC = 75
# NvLink read bandwidth for link 7 in MiB/sec
NVML_GPM_METRIC_NVLINK_L7_RX_PER_SEC = 76
# NvLink write bandwidth for link 7 in MiB/sec
NVML_GPM_METRIC_NVLINK_L7_TX_PER_SEC = 77
# NvLink read bandwidth for link 8 in MiB/sec
NVML_GPM_METRIC_NVLINK_L8_RX_PER_SEC = 78
# NvLink write bandwidth for link 8 in MiB/sec
NVML_GPM_METRIC_NVLINK_L8_TX_PER_SEC = 79
# NvLink read bandwidth for link 9 in MiB/sec
NVML_GPM_METRIC_NVLINK_L9_RX_PER_SEC = 80
# NvLink write bandwidth for link 9 in MiB/sec
NVML_GPM_METRIC_NVLINK_L9_TX_PER_SEC = 81
# NvLink read bandwidth for link 10 in MiB/sec
NVML_GPM_METRIC_NVLINK_L10_RX_PER_SEC = 82
# NvLink write bandwidth for link 10 in MiB/sec
NVML_GPM_METRIC_NVLINK_L10_TX_PER_SEC = 83
# NvLink read bandwidth for link 11 in MiB/sec
NVML_GPM_METRIC_NVLINK_L11_RX_PER_SEC = 84
# NvLink write bandwidth for link 11 in MiB/sec
NVML_GPM_METRIC_NVLINK_L11_TX_PER_SEC = 85
# NvLink read bandwidth for link 12 in MiB/sec
NVML_GPM_METRIC_NVLINK_L12_RX_PER_SEC = 86
# NvLink write bandwidth for link 12 in MiB/sec
NVML_GPM_METRIC_NVLINK_L12_TX_PER_SEC = 87
# NvLink read bandwidth for link 13 in MiB/sec
NVML_GPM_METRIC_NVLINK_L13_RX_PER_SEC = 88
# NvLink write bandwidth for link 13 in MiB/sec
NVML_GPM_METRIC_NVLINK_L13_TX_PER_SEC = 89
# NvLink read bandwidth for link 14 in MiB/sec
NVML_GPM_METRIC_NVLINK_L14_RX_PER_SEC = 90
# NvLink write bandwidth for link 14 in MiB/sec
NVML_GPM_METRIC_NVLINK_L14_TX_PER_SEC = 91
# NvLink read bandwidth for link 15 in MiB/sec
NVML_GPM_METRIC_NVLINK_L15_RX_PER_SEC = 92
# NvLink write bandwidth for link 15 in MiB/sec
NVML_GPM_METRIC_NVLINK_L15_TX_PER_SEC = 93
# NvLink read bandwidth for link 16 in MiB/sec
NVML_GPM_METRIC_NVLINK_L16_RX_PER_SEC = 94
# NvLink write bandwidth for link 16 in MiB/sec
NVML_GPM_METRIC_NVLINK_L16_TX_PER_SEC = 95
# NvLink read bandwidth for link 17 in MiB/sec
NVML_GPM_METRIC_NVLINK_L17_RX_PER_SEC = 96
# NvLink write bandwidth for link 17 in MiB/sec
NVML_GPM_METRIC_NVLINK_L17_TX_PER_SEC = 97
# One past the highest metric ID; also the length of the `metrics` array in
# c_nvmlGpmMetricsGet_t below.
NVML_GPM_METRIC_MAX = 98

## Structs


class c_nvmlUnitInfo_t(_PrintableStructure):
    """Four fixed-size 96-byte text fields describing a unit."""

    _fields_ = [
        ("name", c_char * 96),
        ("id", c_char * 96),
        ("serial", c_char * 96),
        ("firmwareVersion", c_char * 96),
    ]


class struct_c_nvmlGpmSample_t(Structure):
    """Opaque GPM sample; only ever handled through a pointer."""

    pass  # opaque handle


# Handle type used for GPM samples: a pointer to the opaque struct above.
c_nvmlGpmSample_t = POINTER(struct_c_nvmlGpmSample_t)


class c_metricInfo_t(Structure):
    """Three C-string pointers carried inside each GPM metric entry."""

    _fields_ = [
        ("shortName", c_char_p),
        ("longName", c_char_p),
        ("unit", c_char_p),
    ]


class c_nvmlGpmMetric_t(_PrintableStructure):
    """One GPM metric slot: ID, per-metric return code, value, and info."""

    _fields_ = [
        ("metricId", c_uint),
        ("nvmlReturn", _nvmlReturn_t),
        ("value", c_double),
        ("metricInfo", c_metricInfo_t),
    ]


class c_nvmlGpmMetricsGet_t(_PrintableStructure):
    """Argument structure for nvmlGpmMetricsGet: two samples plus metric slots."""

    _fields_ = [
        ("version", c_uint),
        ("numMetrics", c_uint),
        ("sample1", c_nvmlGpmSample_t),
        ("sample2", c_nvmlGpmSample_t),
        ("metrics", c_nvmlGpmMetric_t * NVML_GPM_METRIC_MAX),
    ]


NVML_GPM_METRICS_GET_VERSION = 1


class c_nvmlGpmSupport_t(_PrintableStructure):
    """Versioned result structure for nvmlGpmQueryDeviceSupport."""

    _fields_ = [
        ("version", c_uint),
        ("isSupportedDevice", c_uint),
    ]


NVML_GPM_SUPPORT_VERSION = 1

## Functions


def nvmlGpmMetricsGet(metricsGet):
    """Pass ``metricsGet`` by reference to NVML and return the same object."""
    compute = _nvmlGetFunctionPointer("nvmlGpmMetricsGet")
    _nvmlCheckReturn(compute(byref(metricsGet)))
    return metricsGet


def nvmlGpmSampleFree(gpmSample):
    """Release a GPM sample handle; returns ``None``."""
    release = _nvmlGetFunctionPointer("nvmlGpmSampleFree")
    _nvmlCheckReturn(release(gpmSample))
    return


def nvmlGpmSampleAlloc():
    """Allocate a GPM sample via NVML and return its handle."""
    sample = c_nvmlGpmSample_t()
    allocate = _nvmlGetFunctionPointer("nvmlGpmSampleAlloc")
    _nvmlCheckReturn(allocate(byref(sample)))
    return sample


def nvmlGpmSampleGet(device, gpmSample):
    """Fill ``gpmSample`` for ``device`` and return the same handle."""
    capture = _nvmlGetFunctionPointer("nvmlGpmSampleGet")
    _nvmlCheckReturn(capture(device, gpmSample))
    return gpmSample


def nvmlGpmMigSampleGet(device, gpuInstanceId, gpmSample):
    """Fill ``gpmSample`` for one GPU instance of ``device``; returns the handle."""
    capture = _nvmlGetFunctionPointer("nvmlGpmMigSampleGet")
    _nvmlCheckReturn(capture(device, gpuInstanceId, gpmSample))
    return gpmSample


def nvmlGpmQueryDeviceSupport(device):
    """Return a ``c_nvmlGpmSupport_t`` describing GPM support on ``device``."""
    support = c_nvmlGpmSupport_t()
    support.version = NVML_GPM_SUPPORT_VERSION
    query = _nvmlGetFunctionPointer("nvmlGpmQueryDeviceSupport")
    _nvmlCheckReturn(query(device, byref(support)))
    return support


def nvmlGpmSetStreamingEnabled(device, state):
    """Set the GPM streaming state on ``device``; returns ``NVML_SUCCESS``."""
    requested = c_uint(state)
    setter = _nvmlGetFunctionPointer("nvmlGpmSetStreamingEnabled")
    _nvmlCheckReturn(setter(device, requested))
    return NVML_SUCCESS


def nvmlGpmQueryIfStreamingEnabled(device):
    """Return the GPM streaming state of ``device`` as a plain integer."""
    current = c_uint()
    getter = _nvmlGetFunctionPointer("nvmlGpmQueryIfStreamingEnabled")
    _nvmlCheckReturn(getter(device, byref(current)))
    return current.value


# Low Power Structure and Function

NVML_NVLINK_POWER_STATE_HIGH_SPEED = 0x0
NVML_NVLINK_POWER_STATE_LOW = 0x1

NVML_NVLINK_LOW_POWER_THRESHOLD_MIN = 0x1
NVML_NVLINK_LOW_POWER_THRESHOLD_MAX = 0x1FFF
NVML_NVLINK_LOW_POWER_THRESHOLD_RESET = 0xFFFFFFFF
NVML_NVLINK_LOW_POWER_THRESHOLD_DEFAULT = NVML_NVLINK_LOW_POWER_THRESHOLD_RESET


class c_nvmlNvLinkPowerThres_t(Structure):
    """Single-field request structure carrying the low-power threshold."""

    _fields_ = [
        ("lowPwrThreshold", c_uint),
    ]


def nvmlDeviceSetNvLinkDeviceLowPowerThreshold(device, l1threshold):
    """Set the NvLink low-power threshold on ``device``; returns ``NVML_SUCCESS``."""
    request = c_nvmlNvLinkPowerThres_t()
    request.lowPwrThreshold = l1threshold
    setter = _nvmlGetFunctionPointer("nvmlDeviceSetNvLinkDeviceLowPowerThreshold")
    _nvmlCheckReturn(setter(device, byref(request)))
    return NVML_SUCCESS


NVML_GPU_FABRIC_UUID_LEN = 16

_nvmlGpuFabricState_t = c_uint
NVML_GPU_FABRIC_STATE_NOT_SUPPORTED = 0
NVML_GPU_FABRIC_STATE_NOT_STARTED = 1
NVML_GPU_FABRIC_STATE_IN_PROGRESS = 2
NVML_GPU_FABRIC_STATE_COMPLETED = 3


class c_nvmlGpuFabricInfo_t(_PrintableStructure):
    """Unversioned GPU fabric info structure."""

    # NOTE(review): clusterUuid is sized with NVML_DEVICE_UUID_BUFFER_SIZE here,
    # while c_nvmlGpuFabricInfoV_t below sizes the same field with
    # NVML_GPU_FABRIC_UUID_LEN (16). Confirm the intended layout against nvml.h
    # before changing it; the layout is left exactly as it was.
    _fields_ = [
        ("clusterUuid", c_char * NVML_DEVICE_UUID_BUFFER_SIZE),
        ("status", _nvmlReturn_t),
        ("cliqueId", c_uint32),
        ("state", _nvmlGpuFabricState_t),
    ]


# healthMask sub-fields: per-field values, bit shift, and width constants.
NVML_GPU_FABRIC_HEALTH_MASK_DEGRADED_BW_NOT_SUPPORTED = 0
NVML_GPU_FABRIC_HEALTH_MASK_DEGRADED_BW_TRUE = 1
NVML_GPU_FABRIC_HEALTH_MASK_DEGRADED_BW_FALSE = 2
NVML_GPU_FABRIC_HEALTH_MASK_SHIFT_DEGRADED_BW = 0
NVML_GPU_FABRIC_HEALTH_MASK_WIDTH_DEGRADED_BW = 0x11

NVML_GPU_FABRIC_HEALTH_MASK_ROUTE_RECOVERY_NOT_SUPPORTED = 0
NVML_GPU_FABRIC_HEALTH_MASK_ROUTE_RECOVERY_TRUE = 1
NVML_GPU_FABRIC_HEALTH_MASK_ROUTE_RECOVERY_FALSE = 2
NVML_GPU_FABRIC_HEALTH_MASK_SHIFT_ROUTE_RECOVERY = 2
NVML_GPU_FABRIC_HEALTH_MASK_WIDTH_ROUTE_RECOVERY = 0x11

NVML_GPU_FABRIC_HEALTH_MASK_ROUTE_UNHEALTHY_NOT_SUPPORTED = 0
NVML_GPU_FABRIC_HEALTH_MASK_ROUTE_UNHEALTHY_TRUE = 1
NVML_GPU_FABRIC_HEALTH_MASK_ROUTE_UNHEALTHY_FALSE = 2
NVML_GPU_FABRIC_HEALTH_MASK_SHIFT_ROUTE_UNHEALTHY = 4
NVML_GPU_FABRIC_HEALTH_MASK_WIDTH_ROUTE_UNHEALTHY = 0x11

NVML_GPU_FABRIC_HEALTH_MASK_ACCESS_TIMEOUT_RECOVERY_NOT_SUPPORTED = 0
NVML_GPU_FABRIC_HEALTH_MASK_ACCESS_TIMEOUT_RECOVERY_TRUE = 1
NVML_GPU_FABRIC_HEALTH_MASK_ACCESS_TIMEOUT_RECOVERY_FALSE = 2
NVML_GPU_FABRIC_HEALTH_MASK_SHIFT_ACCESS_TIMEOUT_RECOVERY = 6
NVML_GPU_FABRIC_HEALTH_MASK_WIDTH_ACCESS_TIMEOUT_RECOVERY = 0x11

nvmlGpuFabricInfo_v2 = 0x02000024


class c_nvmlGpuFabricInfoV_t(_PrintableStructure):
    """Versioned GPU fabric info; ``version`` is pre-set to ``nvmlGpuFabricInfo_v2``."""

    _fields_ = [
        ("version", c_uint),
        ("clusterUuid", c_char * NVML_GPU_FABRIC_UUID_LEN),
        ("status", _nvmlReturn_t),
        ("cliqueId", c_uint32),
        ("state", _nvmlGpuFabricState_t),
        ("healthMask", c_uint32),
    ]

    def __init__(self):
        # Stamp the version so callers can pass a fresh instance straight in.
        super(c_nvmlGpuFabricInfoV_t, self).__init__(version=nvmlGpuFabricInfo_v2)


def nvmlDeviceGetGpuFabricInfo(device, gpuFabricInfo):
    """Fill the caller-supplied ``gpuFabricInfo``; returns ``NVML_SUCCESS``.

    ``gpuFabricInfo`` is forwarded to NVML unchanged, so the caller is
    responsible for passing it in a form NVML can write through.
    """
    query = _nvmlGetFunctionPointer("nvmlDeviceGetGpuFabricInfo")
    _nvmlCheckReturn(query(device, gpuFabricInfo))
    return NVML_SUCCESS
def nvmlDeviceGetGpuFabricInfoV(device, gpuFabricInfo):
    """Fill the caller-supplied versioned fabric info; returns ``NVML_SUCCESS``.

    ``gpuFabricInfo`` is forwarded to NVML unchanged. The return code is
    handed to ``_nvmlCheckReturn`` before ``NVML_SUCCESS`` is returned.
    """
    fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuFabricInfoV")
    ret = fn(device, gpuFabricInfo)
    _nvmlCheckReturn(ret)
    return NVML_SUCCESS


######################
## Enums/defines
#### NVML GPU NVLINK BW MODE
NVML_GPU_NVLINK_BW_MODE_FULL = 0x0
NVML_GPU_NVLINK_BW_MODE_OFF = 0x1
NVML_GPU_NVLINK_BW_MODE_MIN = 0x2
NVML_GPU_NVLINK_BW_MODE_HALF = 0x3
NVML_GPU_NVLINK_BW_MODE_3QUARTER = 0x4
NVML_GPU_NVLINK_BW_MODE_COUNT = 0x5


def nvmlSystemSetNvlinkBwMode(mode):
    """Set the system-wide NvLink bandwidth mode; returns ``NVML_SUCCESS``."""
    fn = _nvmlGetFunctionPointer("nvmlSystemSetNvlinkBwMode")
    ret = fn(mode)
    _nvmlCheckReturn(ret)
    return NVML_SUCCESS


def nvmlSystemGetNvlinkBwMode():
    """Return the system-wide NvLink bandwidth mode as a plain integer."""
    mode = c_uint()
    fn = _nvmlGetFunctionPointer("nvmlSystemGetNvlinkBwMode")
    ret = fn(byref(mode))
    _nvmlCheckReturn(ret)
    return mode.value


_nvmlPowerScopeType_t = c_uint
NVML_POWER_SCOPE_GPU = 0
NVML_POWER_SCOPE_MODULE = 1
NVML_POWER_SCOPE_MEMORY = 2


class c_nvmlPowerValue_v2_t(_PrintableStructure):
    """Versioned power value: a scope selector plus a value in ``powerValueMw``."""

    _fields_ = [
        ("version", c_uint),
        ("powerScope", _nvmlPowerScopeType_t),
        ("powerValueMw", c_uint),
    ]
    # NOTE(review): presumably consumed by _PrintableStructure when printing;
    # the "B" suffix looks odd next to a field named powerValueMw - confirm.
    _fmt_ = {"": "%d B"}


nvmlPowerValue_v2 = 0x0200000C


def nvmlDeviceSetPowerManagementLimit_v2(
    device, powerScope, powerLimit, version=nvmlPowerValue_v2
):
    """Set a power management limit for the given scope on ``device``.

    Args:
        device: NVML device handle, forwarded unchanged.
        powerScope: One of the ``NVML_POWER_SCOPE_*`` values.
        powerLimit: Limit stored in the ``powerValueMw`` field.
        version: Structure version tag; defaults to ``nvmlPowerValue_v2``.

    Returns:
        ``NVML_SUCCESS`` once ``_nvmlCheckReturn`` has accepted the NVML
        return code.
    """
    c_powerScope = _nvmlPowerScopeType_t(powerScope)
    c_powerValue = c_nvmlPowerValue_v2_t()
    c_powerValue.version = c_uint(version)
    c_powerValue.powerScope = c_powerScope
    c_powerValue.powerValueMw = c_uint(powerLimit)
    fn = _nvmlGetFunctionPointer("nvmlDeviceSetPowerManagementLimit_v2")
    ret = fn(device, byref(c_powerValue))
    # The return code used to be discarded here, so a failed call still
    # reported NVML_SUCCESS. Check it the same way the sibling setters do.
    _nvmlCheckReturn(ret)
    return NVML_SUCCESS


class c_nvmlEccSramErrorStatus_v1_t(_PrintableStructure):
    """Versioned SRAM ECC error counters; ``version`` is pre-set in ``__init__``."""

    _fields_ = [
        ("version", c_uint),
        ("aggregateUncParity", c_ulonglong),
        ("aggregateUncSecDed", c_ulonglong),
        ("aggregateCor", c_ulonglong),
        ("volatileUncParity", c_ulonglong),
        ("volatileUncSecDed", c_ulonglong),
        ("volatileCor", c_ulonglong),
        ("aggregateUncBucketL2", c_ulonglong),
        ("aggregateUncBucketSm", c_ulonglong),
        ("aggregateUncBucketPcie", c_ulonglong),
        ("aggregateUncBucketMcu", c_ulonglong),
        ("aggregateUncBucketOther", c_ulonglong),
        ("bThresholdExceeded", c_uint),
    ]

    def __init__(self):
        # The version constant is defined just below the class; it is only
        # looked up when an instance is created, so the ordering is fine.
        super(c_nvmlEccSramErrorStatus_v1_t, self).__init__(
            version=nvmlEccSramErrorStatus_v1
        )


nvmlEccSramErrorStatus_v1 = 0x1000068


def nvmlDeviceGetSramEccErrorStatus(device, status):
    """Fill the caller-supplied ``status`` structure; returns ``NVML_SUCCESS``."""
    fn = _nvmlGetFunctionPointer("nvmlDeviceGetSramEccErrorStatus")
    ret = fn(device, status)
    _nvmlCheckReturn(ret)
    return NVML_SUCCESS


NVML_DEV_CAP_EGM = 1 << 0
nvmlDeviceCapabilities_v1 = 0x1000008


class c_nvmlDeviceCapabilities_v1_t(_PrintableStructure):
    """Versioned capability mask; ``version`` is pre-set in ``__init__``."""

    _fields_ = [
        ("version", c_uint),
        ("capMask", c_uint),
    ]

    def __init__(self):
        super(c_nvmlDeviceCapabilities_v1_t, self).__init__(
            version=nvmlDeviceCapabilities_v1
        )


def nvmlDeviceGetCapabilities(device, caps):
    """Fill the caller-supplied ``caps`` and return the raw NVML return code.

    Unlike the neighbouring wrappers this one does not call
    ``_nvmlCheckReturn``; callers must inspect the returned code themselves.
    """
    fn = _nvmlGetFunctionPointer("nvmlDeviceGetCapabilities")
    return fn(device, caps)


class c_nvmlPlatformInfo_v1_t(_PrintableStructure):
    """Versioned platform info; ``version`` is pre-set in ``__init__``."""

    _fields_ = [
        ("version", c_uint),
        ("ibGuid", c_char * 16),
        ("rackGuid", c_char * 16),
        ("chassisPhysicalSlotNumber", c_char),
        ("computeSlotIndex", c_char),
        ("nodeIndex", c_char),
        ("peerType", c_char),
        ("moduleId", c_char),
    ]

    def __init__(self):
        super(c_nvmlPlatformInfo_v1_t, self).__init__(version=nvmlPlatformInfo_v1)


nvmlPlatformInfo_v1 = 0x100002C


def nvmlDeviceGetPlatformInfo(device, platformInfo):
    """Fill the caller-supplied ``platformInfo``; returns ``NVML_SUCCESS``."""
    fn = _nvmlGetFunctionPointer("nvmlDeviceGetPlatformInfo")
    ret = fn(device, platformInfo)
    _nvmlCheckReturn(ret)
    return NVML_SUCCESS


class c_nvmlMask255_t(_PrintableStructure):
    """Bit mask stored as eight ``c_uint`` words."""

    _fields_ = [
        ("mask", c_uint * 8),
    ]


NVML_WORKLOAD_POWER_MAX_PROFILES = 255
NVML_POWER_PROFILE_MAX_P = 0
NVML_POWER_PROFILE_MAX_Q = 1
NVML_POWER_PROFILE_COMPUTE = 2
NVML_POWER_PROFILE_MEMORY_BOUND = 3
NVML_POWER_PROFILE_NETWORK = 4
+NVML_POWER_PROFILE_BALANCED = 5 +NVML_POWER_PROFILE_LLM_INFERENCE = 6 +NVML_POWER_PROFILE_LLM_TRAINING = 7 +NVML_POWER_PROFILE_RBM = 8 +NVML_POWER_PROFILE_DCPCIE = 9 +NVML_POWER_PROFILE_HMMA_SPARSE = 10 +NVML_POWER_PROFILE_HMMA_DENSE = 11 +NVML_POWER_PROFILE_SYNC_BALANCED = 12 +NVML_POWER_PROFILE_HPC = 13 +NVML_POWER_PROFILE_MIG = 14 +NVML_POWER_PROFILE_MAX = 15 + +nvmlWorkloadPowerProfileInfo_v1 = 0x100002C + + +class c_nvmlWorkloadPowerProfileInfo_v1_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("profileId", c_uint), + ("priority", c_uint), + ("conflictingmask", c_nvmlMask255_t), + ] + + def __init__(self): + super(c_nvmlWorkloadPowerProfileInfo_v1_t, self).__init__( + version=nvmlWorkloadPowerProfileInfo_v1 + ) + + +nvmlWorkloadPowerProfileProfilesInfo_v1 = 0x1002BF8 + + +class c_nvmlWorkloadPowerProfileProfilesInfo_v1_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("perfProfilesMask", c_nvmlMask255_t), + ( + "perfProfile", + c_nvmlWorkloadPowerProfileInfo_v1_t * NVML_WORKLOAD_POWER_MAX_PROFILES, + ), + ] + + def __init__(self): + super(c_nvmlWorkloadPowerProfileProfilesInfo_v1_t, self).__init__( + version=nvmlWorkloadPowerProfileProfilesInfo_v1 + ) + + +nvmlWorkloadPowerProfileCurrentProfiles_v1 = 0x1000064 + + +class c_nvmlWorkloadPowerProfileCurrentProfiles_v1_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("perfProfilesMask", c_nvmlMask255_t), + ("requestedProfilesMask", c_nvmlMask255_t), + ("enforcedProfilesMask", c_nvmlMask255_t), + ] + + def __init__(self): + super(c_nvmlWorkloadPowerProfileCurrentProfiles_v1_t, self).__init__( + version=nvmlWorkloadPowerProfileCurrentProfiles_v1 + ) + + +nvmlWorkloadPowerProfileRequestedProfiles_v1 = 0x1000024 + + +class c_nvmlWorkloadPowerProfileRequestedProfiles_v1_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("requestedProfilesMask", c_nvmlMask255_t), + ] + + def __init__(self): + super(c_nvmlWorkloadPowerProfileRequestedProfiles_v1_t, 
self).__init__( + version=nvmlWorkloadPowerProfileRequestedProfiles_v1 + ) + + +def nvmlDeviceWorkloadPowerProfileGetProfilesInfo(device, profilesInfo): + fn = _nvmlGetFunctionPointer("nvmlDeviceWorkloadPowerProfileGetProfilesInfo") + ret = fn(device, profilesInfo) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlDeviceWorkloadPowerProfileGetCurrentProfiles(device, currentProfiles): + fn = _nvmlGetFunctionPointer("nvmlDeviceWorkloadPowerProfileGetCurrentProfiles") + ret = fn(device, currentProfiles) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlDeviceWorkloadPowerProfileSetRequestedProfiles(device, requestedProfiles): + fn = _nvmlGetFunctionPointer("nvmlDeviceWorkloadPowerProfileSetRequestedProfiles") + ret = fn(device, requestedProfiles) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlDeviceWorkloadPowerProfileClearRequestedProfiles(device, requestedProfiles): + fn = _nvmlGetFunctionPointer("nvmlDeviceWorkloadPowerProfileClearRequestedProfiles") + ret = fn(device, requestedProfiles) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlDeviceGetNvlinkSupportedBwModes(device, supportedBwModes): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvlinkSupportedBwModes") + ret = fn(device, supportedBwModes) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlDeviceGetNvlinkBwMode(device, getBwMode): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvlinkBwMode") + ret = fn(device, getBwMode) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlDeviceSetNvlinkBwMode(device, setBwMode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetNvlinkBwMode") + ret = fn(device, setBwMode) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +nvmlDramEncryptionInfo_v1 = 0x01000008 + + +class c_nvmlDramEncryptionInfo_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("encryptionState", _nvmlEnableState_t), + ] + + def __init__(self): + super(c_nvmlDramEncryptionInfo_t, self).__init__( + version=nvmlDramEncryptionInfo_v1 + ) + + +def 
nvmlDeviceGetDramEncryptionMode(handle): + c_currState = c_nvmlDramEncryptionInfo_t() + c_pendingState = c_nvmlDramEncryptionInfo_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDramEncryptionMode") + ret = fn(handle, byref(c_currState), byref(c_pendingState)) + _nvmlCheckReturn(ret) + return [c_currState.encryptionState, c_pendingState.encryptionState] + + +# added to API +def nvmlDeviceGetCurrentDramEncryptionMode(handle): + return nvmlDeviceGetDramEncryptionMode(handle)[0] + + +# added to API +def nvmlDeviceGetPendingDramEncryptionMode(handle): + return nvmlDeviceGetDramEncryptionMode(handle)[1] + + +def nvmlDeviceSetDramEncryptionMode(handle, mode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetDramEncryptionMode") + c_dramEncryptionMode = c_nvmlDramEncryptionInfo_t() + c_dramEncryptionMode.encryptionState = mode + ret = fn(handle, byref(c_dramEncryptionMode)) + _nvmlCheckReturn(ret) + return None + + +# Power Smoothing defines +NVML_POWER_SMOOTHING_MAX_NUM_PROFILES = 5 +NVML_POWER_SMOOTHING_ADMIN_OVERRIDE_NOT_SET = 0xFFFFFFFF +NVML_POWER_SMOOTHING_PROFILE_PARAM_PERCENT_TMP_FLOOR = 0 +NVML_POWER_SMOOTHING_PROFILE_PARAM_RAMP_UP_RATE = 1 +NVML_POWER_SMOOTHING_PROFILE_PARAM_RAMP_DOWN_RATE = 2 +NVML_POWER_SMOOTHING_PROFILE_PARAM_RAMP_DOWN_HYSTERESIS = 3 + +nvmlPowerSmoothingState_v1 = 0x1000008 + + +class c_nvmlPowerSmoothingState_v1_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("state", c_uint), + ] + + def __init__(self): + super(c_nvmlPowerSmoothingState_v1_t, self).__init__( + version=nvmlPowerSmoothingState_v1 + ) + + +nvmlPowerSmoothingProfile_v1 = 0x1000018 + + +class c_nvmlPowerSmoothingProfile_v1_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("profileId", c_uint), + ("paramId", c_uint), + ("value", c_double), + ] + + def __init__(self): + super(c_nvmlPowerSmoothingProfile_v1_t, self).__init__( + version=nvmlPowerSmoothingProfile_v1 + ) + + +def nvmlDevicePowerSmoothingActivatePresetProfile(device, profile): + fn = 
_nvmlGetFunctionPointer("nvmlDevicePowerSmoothingActivatePresetProfile") + ret = fn(device, profile) + _nvmlCheckReturn(ret) + + +def nvmlDevicePowerSmoothingUpdatePresetProfileParam(device, profile): + fn = _nvmlGetFunctionPointer("nvmlDevicePowerSmoothingUpdatePresetProfileParam") + ret = fn(device, profile) + _nvmlCheckReturn(ret) + + +def nvmlDevicePowerSmoothingSetState(device, state): + fn = _nvmlGetFunctionPointer("nvmlDevicePowerSmoothingSetState") + ret = fn(device, state) + _nvmlCheckReturn(ret) diff --git a/sglang/python/sglang/multimodal_gen/tools/convert_hf_to_fp8.py b/sglang/python/sglang/multimodal_gen/tools/convert_hf_to_fp8.py new file mode 100644 index 0000000000000000000000000000000000000000..7c3e0051cea6233efd0f3437dc6babc58341b007 --- /dev/null +++ b/sglang/python/sglang/multimodal_gen/tools/convert_hf_to_fp8.py @@ -0,0 +1,319 @@ +# copied and adapted from Slime +""" +Convert HuggingFace safetensors model to FP8 format for efficient inference. + +Example usage: + # convert FLUX.1-dev transformer to FP8 + python -m sglang.multimodal_gen.tools.convert_hf_to_fp8 \ + --model-dir /path/to/FLUX.1-dev/transformer \ + --save-dir /path/to/FLUX.1-dev/transformer-FP8 \ + --strategy block \ + --block-size 128 128 + +Options: + --model-dir MODEL_DIR + path to the directory of the HF safetensors model (e.g., transformer subfolder) + --save-dir SAVE_DIR + path to the directory to save the converted FP8 model + --strategy {block,channel,tensor} + quantization strategy (default: block) + --block-size [BLOCK_SIZE ...] 
+ block size for block quantization, e.g., --block-size 128 128 + --max-workers MAX_WORKERS + number of worker threads for parallel processing (default: 1) +""" + +import argparse +import gc +import json +import os +import shutil +import threading +from concurrent.futures import ThreadPoolExecutor + +import safetensors +import safetensors.torch +import torch +import torch.nn.functional as F +from tqdm import tqdm + +FP8_INFO = torch.finfo(torch.float8_e4m3fn) +FP8_MAX, FP8_MIN = FP8_INFO.max, FP8_INFO.min + + +def ceildiv(a, b): + return -(-a // b) + + +def block_fp8(weight, block_size): + + # per block quant + block_n, block_k = block_size[0], block_size[1] + + shape_0, shape_1 = weight.shape + + n_tiles = ceildiv(shape_0, block_n) + k_tiles = ceildiv(shape_1, block_k) + + q_weight = F.pad( + weight, + (0, k_tiles * block_k - shape_1, 0, n_tiles * block_n - shape_0), + mode="constant", + value=0.0, + ) + + qweight = q_weight.reshape(n_tiles, block_n, k_tiles, block_k) + block_max = torch.max(torch.abs(qweight), dim=1, keepdim=True)[0] + block_max = torch.max(block_max, dim=3, keepdim=True)[0] + + scale = block_max.to(torch.float32) / FP8_MAX + qweight = ( + (qweight / scale) + .clamp(min=FP8_MIN, max=FP8_MAX) + .reshape((n_tiles * block_n, k_tiles * block_k)) + .to(torch.float8_e4m3fn) + ) + qweight = qweight[:shape_0, :shape_1].clone().detach() + scale = scale.squeeze() + + return qweight, scale + + +def channel_fp8(weight): + channel_max = torch.max(weight.abs(), dim=-1, keepdim=True)[0] + scale = channel_max.clamp(min=1e-12).to(torch.float32) / FP8_MAX + qweight = (weight / scale).clamp(min=FP8_MIN, max=FP8_MAX) + qweight = qweight.to(torch.float8_e4m3fn) + return qweight, scale + + +def tensor_fp8(weight): + scale = weight.abs().max().clamp(min=1e-12).to(torch.float32) / FP8_MAX + qweight = (weight / scale).clamp(min=FP8_MIN, max=FP8_MAX) + qweight = qweight.to(torch.float8_e4m3fn) + scale = scale.view(1) + return qweight, scale + + +def quant_fp8(weight, 
strategy, block_size=None): + if strategy == "tensor": + return tensor_fp8(weight) + elif strategy == "channel": + return channel_fp8(weight) + else: + return block_fp8(weight, block_size) + + +class ConversionResult: + def __init__(self): + self.lock = threading.Lock() + self.weight_map = {} + self.param_count = 0 + self.modules_to_not_convert = [] + + def add_result(self, filename, q_weights, module_names): + with self.lock: + for k, v in q_weights.items(): + self.weight_map[k] = filename + self.param_count += v.numel() + self.modules_to_not_convert.extend(module_names) + + +def process_file( + input_path, output_path, filename, strategy, block_size, result_collector +): + if not filename.endswith(".safetensors"): + return + + print(f"Processing {filename}, memory usage: {torch.cuda.memory_allocated()}") + weights = {} + q_weights = {} + + with safetensors.safe_open( + os.path.join(input_path, filename), framework="pt", device="cuda" + ) as f: + for k in f.keys(): + weights[k] = f.get_tensor(k) + + modules_to_not_convert = [] + for key in weights.keys(): + if ( + "weight" in key + and "layernorm" not in key + and "embed" not in key + and "router" not in key + and "mlp.gate." 
def convert_fp8(input_path, output_path, strategy, block_size=None, max_workers=4):
    """Convert every ``*.safetensors`` shard under ``input_path`` to FP8.

    Non-weight files are copied verbatim, shards are handed to ``process_file``
    on a thread pool, and ``config.json`` / ``model.safetensors.index.json``
    are written into ``output_path``.

    Args:
        input_path: Directory holding the source checkpoint.
        output_path: Destination directory; created if it does not exist.
        strategy: ``"block"`` and ``"tensor"`` produce an ``fp8``
            ``quantization_config``; any other value (e.g. ``"channel"``)
            produces a ``compressed-tensors`` config that records the strategy.
        block_size: Stored as ``weight_block_size`` when truthy.
        max_workers: Number of worker threads converting shards.
    """
    input_path = os.path.abspath(input_path)
    os.makedirs(output_path, exist_ok=True)

    # Carry over every regular non-weight file (tokenizer, configs, ...) as-is.
    for filename in os.listdir(input_path):
        if not filename.endswith(".safetensors") and not os.path.isdir(
            os.path.join(input_path, filename)
        ):
            shutil.copyfile(
                os.path.join(input_path, filename), os.path.join(output_path, filename)
            )

    safetensors_files = [
        f for f in os.listdir(input_path) if f.endswith(".safetensors")
    ]

    result_collector = ConversionResult()

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(
                process_file,
                input_path,
                output_path,
                filename,
                strategy,
                block_size,
                result_collector,
            )
            for filename in safetensors_files
        ]

        # .result() re-raises any exception from the worker thread.
        for future in tqdm(futures, desc="Processing files"):
            future.result()

    if strategy == "block" or strategy == "tensor":
        quantization_config = {
            "activation_scheme": "dynamic",
            "fmt": "e4m3",
            "quant_method": "fp8",
        }
        if block_size:
            quantization_config["weight_block_size"] = block_size
        if len(result_collector.modules_to_not_convert) > 0:
            quantization_config["modules_to_not_convert"] = list(
                set(result_collector.modules_to_not_convert)
            )
    else:
        quant_group = {
            "group_0": {
                "input_activations": {
                    "actorder": None,
                    "block_structure": None,
                    "dynamic": True,
                    "group_size": None,
                    "num_bits": 8,
                    "observer": None,
                    "observer_kwargs": {},
                    "strategy": "token",
                    "symmetric": True,
                    "type": "float",
                },
                "output_activations": None,
                "targets": ["Linear"],
                "weights": {
                    "actorder": None,
                    "block_structure": None,
                    "dynamic": False,
                    "group_size": None,
                    "num_bits": 8,
                    "observer": "minmax",
                    "observer_kwargs": {},
                    "strategy": strategy,
                    "symmetric": True,
                    "type": "float",
                },
            },
        }
        quantization_config = {
            "config_groups": quant_group,
            "format": "float-quantized",
            "ignore": list(set(result_collector.modules_to_not_convert)),
            "quant_method": "compressed-tensors",
            "quantization_status": "compressed",
        }

    # Context managers guarantee the JSON files are flushed and closed; the
    # previous bare open() calls left that to garbage collection.
    config_path = os.path.join(input_path, "config.json")
    if os.path.exists(config_path):
        with open(config_path) as config_file:
            cfg = json.load(config_file)
        cfg["quantization_config"] = quantization_config
        with open(os.path.join(output_path, "config.json"), "w") as config_file:
            json.dump(cfg, config_file, indent=2)

    # NOTE(review): total_size receives the element count accumulated by
    # ConversionResult (numel), not a byte size - confirm this is intended.
    index_dict = {
        "weight_map": result_collector.weight_map,
        "metadata": {"total_size": result_collector.param_count},
    }
    with open(
        os.path.join(output_path, "model.safetensors.index.json"), "w"
    ) as index_file:
        json.dump(index_dict, index_file, indent=2)

    gc.collect()
    torch.cuda.empty_cache()
def expand_path_fields(obj) -> None:
    """In-place expanduser on all dataclass fields whose name ends with '_path' or '_paths'.

    A matching field is rewritten when its value is a ``str``, a ``list`` or a
    ``dict``; inside containers only ``str`` items/values are expanded and
    everything else is kept untouched. Fields holding any other type (for
    example ``None``) are left alone.

    Args:
        obj: A dataclass instance; it is mutated via ``setattr``.
    """
    eu = os.path.expanduser

    def _expand(value):
        # One routine for every supported shape, so that the '_path' and
        # '_paths' suffixes are treated identically.
        if isinstance(value, str):
            return eu(value)
        if isinstance(value, list):
            return [eu(x) if isinstance(x, str) else x for x in value]
        return {k: eu(p) if isinstance(p, str) else p for k, p in value.items()}

    for f in fields(obj):
        if not f.name.endswith(("_path", "_paths")):
            continue
        v = getattr(obj, f.name)
        if isinstance(v, (str, list, dict)):
            setattr(obj, f.name, _expand(v))
+ """ + so_file = envs.SGLANG_DIFFUSION_NCCL_SO_PATH + + # manually load the nccl library + if so_file: + logger.info( + "Found nccl from environment variable SGLANG_DIFFUSION_NCCL_SO_PATH=%s", + so_file, + ) + else: + if torch.version.cuda is not None: + so_file = "libnccl.so.2" + elif torch.version.hip is not None: + so_file = "librccl.so.1" + elif hasattr(torch.version, "musa") and torch.version.musa is not None: + so_file = "libmccl.so.2" + else: + raise ValueError("NCCL only supports CUDA, ROCm and MUSA backends.") + logger.info("Found nccl from library %s", so_file) + return str(so_file) + + +prev_set_stream = torch.cuda.set_stream + +_current_stream = None + + +def _patched_set_stream(stream: torch.cuda.Stream | None) -> None: + global _current_stream + _current_stream = stream + if stream is not None: + prev_set_stream(stream) + + +torch.cuda.set_stream = _patched_set_stream + + +def current_stream() -> torch.cuda.Stream | None: + """ + replace `torch.cuda.current_stream()` with `sglang.multimodal_gen.utils.current_stream()`. + it turns out that `torch.cuda.current_stream()` is quite expensive, + as it will construct a new stream object at each call. + here we patch `torch.cuda.set_stream` to keep track of the current stream + directly, so that we can avoid calling `torch.cuda.current_stream()`. + + the underlying hypothesis is that we do not call `torch._C._cuda_setStream` + from C/C++ code. + """ + from sglang.multimodal_gen.runtime.platforms import current_platform + + # For non-CUDA platforms, return None + if not current_platform.is_cuda_alike(): + return None + + global _current_stream + if _current_stream is None: + # when this function is called before any stream is set, + # we return the default stream. + # On ROCm using the default 0 stream in combination with RCCL + # is hurting performance. 
class FlexibleArgumentParser(argparse.ArgumentParser):
    """ArgumentParser that allows both underscore and dash in names.

    On top of ``argparse.ArgumentParser`` it
      * rewrites ``--some_flag`` to ``--some-flag`` before parsing,
      * splices options from a ``--config <file>`` / ``--config=<file>`` YAML
        file into the argument list, and
      * records the explicitly supplied long options in ``namespace._provided``.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Set the default 'formatter_class' to SortedHelpFormatter
        if "formatter_class" not in kwargs:
            kwargs["formatter_class"] = SortedHelpFormatter
        super().__init__(*args, **kwargs)

    @staticmethod
    def _is_config_flag(arg: str) -> bool:
        """True only for ``--config`` / ``--config=<path>``.

        A plain ``startswith("--config")`` would also swallow unrelated options
        that merely share the prefix (e.g. ``--config-foo``).
        """
        return arg == "--config" or arg.startswith("--config=")

    def parse_args(  # type: ignore[override]
        self, args=None, namespace=None
    ) -> argparse.Namespace:
        """Parse ``args`` (default ``sys.argv[1:]``).

        Returns:
            The parsed namespace, with an extra ``_provided`` set holding the
            underscore-normalized names of all ``--long`` options present in
            the (config-expanded) argument list.
        """
        if args is None:
            args = sys.argv[1:]

        if any(self._is_config_flag(arg) for arg in args):
            args = self._pull_args_from_config(args)

        # Convert underscores to dashes in long option names; values after
        # '=' are left untouched. Short options such as "-O3" need no special
        # handling here: argparse already accepts an attached value.
        processed_args = []
        for arg in args:
            if arg.startswith("--"):
                key, sep, value = arg.partition("=")
                processed_args.append("--" + key[2:].replace("_", "-") + sep + value)
            else:
                processed_args.append(arg)

        namespace = super().parse_args(processed_args, namespace)

        # Track which arguments were explicitly provided
        namespace._provided = set()

        i = 0
        while i < len(args):
            arg = args[i]
            i += 1
            if not arg.startswith("--"):
                continue
            namespace._provided.add(arg[2:].split("=", 1)[0].replace("-", "_"))
            # "--key value" form: also step over the value token, if any.
            if "=" not in arg and i < len(args) and not args[i].startswith("-"):
                i += 1

        return namespace  # type: ignore[no-any-return]

    def _pull_args_from_config(self, args: list[str]) -> list[str]:
        """Replace the ``--config`` option by the options stored in that file.

        The options loaded from the file are inserted right after the leading
        command (and after the model tag when the command is ``serve``), so
        that options typed on the command line come later and take precedence:
        cli > config > defaults.

        Example::

            ["generate", "--config", "config.yaml", "--port", "1"]
            # with config.yaml containing "port: 12323" becomes
            ["generate", "--port", "12323", "--port", "1"]

        Raises:
            ValueError: if ``--config`` is given more than once, has no file
                path, or directly follows ``serve`` (no model tag).
        """
        index = -1
        config_arg = None
        for i, arg in enumerate(args):
            if self._is_config_flag(arg):
                if index != -1:
                    raise ValueError("More than one config file specified!")
                index = i
                config_arg = arg

        if config_arg is None:
            return args
        args_before_config = args[:index]
        if "=" in config_arg:
            file_path = config_arg.split("=", 1)[1]
            args_after_config = args[index + 1 :]
        else:
            if index == len(args) - 1:
                raise ValueError(
                    "No config file specified! "
                    "Please check your command-line arguments."
                )
            file_path = args[index + 1]
            args_after_config = args[index + 2 :]

        config_args = self._load_config_file(file_path)

        # 0th index is for {serve,chat,complete}
        # followed by model_tag (only for serve)
        # followed by config args
        # followed by rest of cli args.
        if args[0] == "serve":
            if index == 1:
                raise ValueError(
                    "No model_tag specified! Please check your command-line"
                    " arguments."
                )
            leading = args_before_config[:2]
        else:
            # Slicing (instead of indexing [0]) keeps this working when
            # --config is the very first argument and there is no command.
            leading = args_before_config[:1]
        remaining = args_before_config[len(leading) :]

        return leading + config_args + remaining + args_after_config

    def _load_config_file(self, file_path: str) -> list[str]:
        """Load a yaml file and flatten it into argparse-style tokens.

        Keys are emitted verbatim with a ``--`` prefix; nested mappings are
        joined with dots::

            port: 12323
            vae_config:
              load_encoder: false

        yields ``['--port', '12323', '--vae_config.load_encoder', 'false']``.
        List values become ``--key item1 item2 ...``. An empty file yields
        ``[]``.

        Raises:
            ValueError: if the file extension is not yaml/yml/json.
            Exception: whatever opening/parsing the file raised (after the
                failure has been logged).
        """

        extension: str = file_path.split(".")[-1]
        if extension not in ("yaml", "yml", "json"):
            # Build the message eagerly: ValueError does not apply %-style
            # arguments the way the logging module does.
            raise ValueError(
                f"Config file must be of a yaml/yml/json type. {extension} supplied"
            )

        processed_args: list[str] = []

        try:
            with open(file_path) as config_file:
                # safe_load returns None for an empty document.
                config = yaml.safe_load(config_file) or {}
        except Exception:
            logger.error(
                "Unable to read the config file at %s. "
                "Make sure path is correct",
                file_path,
            )
            raise

        store_boolean_arguments = [
            action.dest for action in self._actions if isinstance(action, StoreBoolean)
        ]

        def process_dict(prefix: str, d: dict[str, Any]):
            for key, value in d.items():
                full_key = f"{prefix}.{key}" if prefix else key

                if isinstance(value, bool) and full_key not in store_boolean_arguments:
                    if value:
                        processed_args.append("--" + full_key)
                    else:
                        processed_args.append("--" + full_key)
                        processed_args.append("false")
                elif isinstance(value, list):
                    processed_args.append("--" + full_key)
                    for item in value:
                        processed_args.append(str(item))
                elif isinstance(value, dict):
                    process_dict(full_key, value)
                else:
                    processed_args.append("--" + full_key)
                    processed_args.append(str(value))

        process_dict("", config)

        return processed_args
def align_to(value: int, alignment: int) -> int:
    """align height, width according to alignment

    Rounds ``value`` up to the nearest multiple of ``alignment``.

    Args:
        value (int): height or width
        alignment (int): target alignment factor

    Returns:
        int: the aligned value
    """
    # Ceiling division via floor division keeps integer inputs exact; going
    # through float (math.ceil(value / alignment)) silently loses precision
    # once the operands no longer fit into a double's mantissa.
    return int(-(-value // alignment) * alignment)
def update_environment_variables(envs: dict[str, str]):
    """Copy every entry of ``envs`` into ``os.environ``.

    A warning is logged for each variable that is already set to a different
    value before it gets overwritten.
    """
    for name in envs:
        desired = envs[name]
        current = os.environ.get(name)
        if current is not None and current != desired:
            logger.warning(
                "Overwriting environment variable %s " "from '%s' to '%s'",
                name,
                current,
                desired,
            )
        os.environ[name] = desired
def shallow_asdict(obj) -> dict[str, Any]:
    """Return ``{field name: value}`` for a dataclass instance.

    Field values are taken as-is with ``getattr``; nothing is copied or
    converted recursively.

    Raises:
        TypeError: if ``obj`` is not a dataclass instance.
    """
    # is_dataclass() is also true for the dataclass *type* itself, so classes
    # have to be rejected explicitly to honour the "instance" contract.
    if isinstance(obj, type) or not is_dataclass(obj):
        raise TypeError("Expected dataclass instance")
    return {f.name: getattr(obj, f.name) for f in fields(obj)}
reduce_dtype: torch.dtype, + output_dtype: torch.dtype | None = None, + mp_policy: MixedPrecisionPolicy | None = None, +): + """Set mixed precision policy globally. + + Args: + param_dtype: Parameter dtype used for training + reduce_dtype: Reduction dtype used for gradients + output_dtype: Optional output dtype + """ + state = MixedPrecisionState( + param_dtype=param_dtype, + reduce_dtype=reduce_dtype, + output_dtype=output_dtype, + mp_policy=mp_policy, + ) + _mixed_precision_state.state = state + + +def get_compute_dtype() -> torch.dtype: + """Get the current compute dtype from mixed precision policy.""" + if not hasattr(_mixed_precision_state, "state"): + return torch.get_default_dtype() + else: + state = get_mixed_precision_state() + return state.param_dtype + + +def dict_to_3d_list( + mask_strategy: dict[str, Any] | None = None, + t_max: int | None = None, + l_max: int | None = None, + h_max: int | None = None, +) -> list[list[list[torch.Tensor | None]]]: + """ + Convert a dictionary of mask indices to a 3D list of tensors. + Args: + mask_strategy: keys are "t_l_h", values are torch.Tensor masks. + t_max, l_max, h_max: if provided (all three), force the output shape to (t_max, l_max, h_max). + If all three are None, infer shape from the data. 
@lru_cache(maxsize=1)
def is_vmoba_available() -> bool:
    """Report whether the vmoba kernel module and flash_attn >= 2.7.4 are present.

    The result is cached after the first call. Any failure while probing
    (missing packages, import errors, unparsable version) yields ``False``.
    """
    import re

    try:
        # find_spec() imports the parent packages of a dotted name and raises
        # ModuleNotFoundError when one of them is missing, so a bare call
        # would crash instead of reporting "not available".
        if importlib.util.find_spec("kernel.csrc.attn.vmoba_attn.vmoba") is None:
            return False
    except (ImportError, ValueError):
        return False
    try:
        import flash_attn

        # Compare numerically: as plain strings "2.10.0" sorts before "2.7.4".
        # Any local build tag ("+cu12...") is dropped first.
        release = flash_attn.__version__.split("+", 1)[0]
        numbers = tuple(int(part) for part in re.findall(r"\d+", release)[:3])
        return numbers >= (2, 7, 4)
    except Exception:
        return False
def best_output_size(w, h, dw, dh, expected_area):
    """Pick an output (width, height) close to the input aspect ratio.

    Two candidates are built, each with width a multiple of ``dw``, height a
    multiple of ``dh`` and an area not exceeding ``expected_area``: one snaps
    the width first, the other snaps the height first. The candidate whose
    aspect ratio deviates least from ``w / h`` is returned; on a tie the
    height-first candidate wins.
    """
    target_ratio = w / h

    # Unconstrained (float) size with the requested area and aspect ratio.
    ideal_w = (expected_area * target_ratio) ** 0.5
    ideal_h = expected_area / ideal_w

    # Candidate 1: snap the width down, then fit the height into the area.
    width_first_w = int(ideal_w // dw * dw)
    width_first_h = int(expected_area / width_first_w // dh * dh)
    assert (
        width_first_w % dw == 0
        and width_first_h % dh == 0
        and width_first_w * width_first_h <= expected_area
    )
    width_first_ratio = width_first_w / width_first_h

    # Candidate 2: snap the height down, then fit the width into the area.
    height_first_h = int(ideal_h // dh * dh)
    height_first_w = int(expected_area / height_first_h // dw * dw)
    assert (
        height_first_h % dh == 0
        and height_first_w % dw == 0
        and height_first_w * height_first_h <= expected_area
    )
    height_first_ratio = height_first_w / height_first_h

    # Relative aspect-ratio distortion (>= 1, 1 means a perfect match).
    width_first_error = max(
        target_ratio / width_first_ratio, width_first_ratio / target_ratio
    )
    height_first_error = max(
        target_ratio / height_first_ratio, height_first_ratio / target_ratio
    )

    if width_first_error < height_first_error:
        return width_first_w, width_first_h
    return height_first_w, height_first_h
+ +Usage: +python3 -m sglang.profiler +""" + +import argparse +import json +import os +import time +from argparse import ArgumentParser +from pathlib import Path +from typing import List, Optional + +import requests + +PROFILER_DIR = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp") + + +def run_profile( + url: Optional[str], + num_steps: int, + activities: List[str], + output_dir: Optional[str] = None, + profile_by_stage: bool = False, + merge_profiles: bool = False, + profile_prefix: Optional[str] = None, + start_step: Optional[int] = None, +) -> str: + if output_dir is None: + output_dir = PROFILER_DIR + + output_dir = Path(os.path.abspath(os.path.normpath(output_dir))) / str(time.time()) + output_dir.mkdir(exist_ok=True, parents=True) + + print(f"Dump profiling traces to {output_dir}") + print( + f"Waiting for {num_steps} steps and the trace to be flushed.... ({profile_by_stage=})" + ) + + # Dump server args. + file_path = Path(output_dir) / "server_args.json" + if not file_path.exists(): + response = requests.get(url + "/get_server_info") + response.raise_for_status() + server_args_data = response.json() + with open(file_path, "w") as file: + file.write(json.dumps(server_args_data)) + + # Start profiler. The API replies when all steps are processed + # and files are generated. 
if __name__ == "__main__":
    # Command-line entry point: collect the profiling options and forward them
    # to run_profile().
    parser = ArgumentParser(
        description="Run live profiling against a running SGLang server."
    )
    parser.add_argument(
        "--url",
        type=str,
        default="http://localhost:30000",
        help="Server or API base url if not using http host and port.",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Profile directory to dump profile traces.",
    )
    parser.add_argument(
        "--num-steps",
        type=int,
        default=5,
        help="The number of forward steps to profile.",
    )
    parser.add_argument(
        "--start-step",
        type=int,
        default=None,
        help="Optional start step forwarded to the server's profiler.",
    )
    # NOTE: BooleanOptionalAction takes no `type=`; passing it is deprecated
    # since Python 3.12 and rejected from 3.14 on.
    parser.add_argument(
        "--profile-by-stage",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Whether to profile prefill and decode separately",
    )
    parser.add_argument(
        "--profile-prefix",
        type=str,
        help="The prefix of this profiler file.",
    )
    parser.add_argument(
        "--cpu",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Whether to profile CPU activity",
    )
    parser.add_argument(
        "--gpu",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Whether to profile GPU activity",
    )
    parser.add_argument(
        "--mem",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Whether to profile memory usage (https://pytorch.org/memory_viz)",
    )
    parser.add_argument(
        "--rpd",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Whether to use ROCM rpd profiler (https://github.com/ROCm/rocmProfileData)",
    )
    parser.add_argument(
        "--merge-profiles",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Whether to merge profiles from all ranks into a single trace file",
    )

    args = parser.parse_args()

    # Translate the on/off switches into the activity names sent to the server.
    activities = []
    if args.cpu:
        activities.append("CPU")
    if args.gpu:
        activities.append("GPU")
    if args.mem:
        activities.append("MEM")
    if args.rpd:
        activities.append("RPD")

    run_profile(
        url=args.url,
        num_steps=args.num_steps,
        activities=activities,
        output_dir=args.output_dir,
        profile_by_stage=args.profile_by_stage,
        profile_prefix=args.profile_prefix,
        merge_profiles=args.merge_profiles,
        start_step=args.start_step,
    )
b/sglang/python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py @@ -0,0 +1,994 @@ +# Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/batch_invariant_ops/batch_invariant_ops.py + +import contextlib +from collections import namedtuple +from collections.abc import Callable +from typing import Any, Dict + +import torch +import triton +import triton.language as tl + +from sglang.srt.layers.deep_gemm_wrapper.configurer import ENABLE_JIT_DEEPGEMM +from sglang.srt.utils.common import calc_diff, get_bool_env_var + +if ENABLE_JIT_DEEPGEMM: + import deep_gemm + +_ENABLE_MM_DEEPGEMM = get_bool_env_var( + "SGLANG_BATCH_INVARIANT_OPS_ENABLE_MM_DEEPGEMM", "1" +) +# If true, allows to fallback to batch variant gemm when the shape cannot be run in DeepGEMM +_ENABLE_MM_FALLBACK_VARIANT = get_bool_env_var( + "SGLANG_BATCH_INVARIANT_OPS_ENABLE_MM_FALLBACK_VARIANT", "0" +) +_ENABLE_MM_COMPARISON_TEST = get_bool_env_var( + "SGLANG_BATCH_INVARIANT_OPS_ENABLE_MM_COMPARISON_TEST" +) + +if not _ENABLE_MM_DEEPGEMM: + print("Disable DeepGEMM in batch invariant ops. 
Performance may be suboptimal.") + +__all__ = [ + "set_batch_invariant_mode", + "is_batch_invariant_mode_enabled", + "disable_batch_invariant_mode", + "enable_batch_invariant_mode", +] + + +def _matmul_launch_metadata( + grid: Callable[..., Any], kernel: Any, args: Dict[str, Any] +) -> Dict[str, Any]: + ret = {} + m, n, k = args["M"], args["N"], args["K"] + ret["name"] = f"{kernel.name} [M={m}, N={n}, K={k}]" + if "tiles_per_update" in args: + ret["name"] = ( + f"{kernel.name} [M={m}, N={n}, K={k}, tiles_per_update={args['tiles_per_update']:02}]" + ) + if "c_ptr" in args: + bytes_per_elem = args["c_ptr"].element_size() + else: + bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2 + ret[f"flops{bytes_per_elem * 8}"] = 2.0 * m * n * k + ret["bytes"] = bytes_per_elem * (m * k + n * k + m * n) + return ret + + +@triton.jit +def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS): + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + return pid_m, pid_n + + +@triton.jit(launch_metadata=_matmul_launch_metadata) +def matmul_kernel_persistent( + a_ptr, + b_ptr, + c_ptr, # + bias_ptr, + M, + N, + K, # + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, # + BLOCK_SIZE_N: tl.constexpr, # + BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + NUM_SMS: tl.constexpr, # + A_LARGE: tl.constexpr, + B_LARGE: tl.constexpr, + C_LARGE: tl.constexpr, + HAS_BIAS: tl.constexpr, +): + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + + offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, 
flatten=True): + pid_m, pid_n = _compute_pid( + tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS + ) + start_m = pid_m * BLOCK_SIZE_M + start_n = pid_n * BLOCK_SIZE_N + offs_am = start_m + tl.arange(0, BLOCK_SIZE_M) + offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N) + if A_LARGE: + offs_am = offs_am.to(tl.int64) + if B_LARGE: + offs_bn = offs_bn.to(tl.int64) + offs_am = tl.where(offs_am < M, offs_am, 0) + offs_bn = tl.where(offs_bn < N, offs_bn, 0) + offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + if A_LARGE or B_LARGE: + offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K).to(tl.int64) + else: + offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + ( + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak + ) + b_ptrs = b_ptr + ( + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn + ) + + a = tl.load( + a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0 + ) + b = tl.load( + b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0 + ) + accumulator = tl.dot(a, b, accumulator) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + if C_LARGE: + offs_cm = offs_cm.to(tl.int64) + offs_cn = offs_cn.to(tl.int64) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + if HAS_BIAS: + bias_ptrs = bias_ptr + offs_cn + bias = tl.load(bias_ptrs, mask=offs_cn < N, other=0.0).to(tl.float32) + accumulator += bias + if c_ptr.dtype.element_ty == tl.float8e4nv: + c = accumulator.to(tl.float8e4nv) + elif c_ptr.dtype.element_ty == tl.bfloat16: + c = accumulator.to(tl.bfloat16) + elif c_ptr.dtype.element_ty == tl.float32: + c = 
accumulator.to(tl.float32) + else: + c = accumulator.to(tl.float16) + tl.store(c_ptrs, c, mask=c_mask) + + +def _matmul_persistent_triton( + a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None +): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.dtype == b.dtype, "Incompatible dtypes" + assert ( + bias is None or bias.dim() == 1 + ), "Currently assuming bias is 1D, let Horace know if you run into this" + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + M, K = a.shape + K, N = b.shape + dtype = a.dtype + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=dtype) + + # 1D launch kernel where each block gets its own program. + def grid(META): + return ( + min( + NUM_SMS, + triton.cdiv(M, META["BLOCK_SIZE_M"]) + * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ), + ) + + configs = { + torch.bfloat16: { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_stages": 3, + "num_warps": 8, + }, + torch.float16: { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_stages": 3, + "num_warps": 8, + }, + torch.float32: { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "num_stages": 3, + "num_warps": 8, + }, + } + # print(a.device, b.device, c.device) + matmul_kernel_persistent[grid]( + a, + b, + c, # + bias, + M, + N, + K, # + a.stride(0), + a.stride(1), # + b.stride(0), + b.stride(1), # + c.stride(0), + c.stride(1), # + NUM_SMS=NUM_SMS, # + A_LARGE=a.numel() > 2**31, + B_LARGE=b.numel() > 2**31, + C_LARGE=c.numel() > 2**31, + HAS_BIAS=bias is not None, + **configs[dtype], + ) + return c + + +def _matmul_persistent_deepgemm( + a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None +): + M, K = a.shape + K, N = b.shape + dtype = a.dtype + out = torch.empty((M, N), device=a.device, dtype=dtype) + + try: + deep_gemm.bf16_gemm_nn(a, b, out) + 
def matmul_persistent(
    a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
):
    """Multiply ``a @ b`` (optionally adding ``bias``), choosing a backend.

    Backend selection, in order:

    1. DeepGEMM, when it is enabled by both module flags, both operands are
       bfloat16, ``a`` is contiguous, ``b`` is contiguous once transposed, and
       ``N`` is at least ``MIN_DEEPGEMM_DIM``. With the comparison-test flag
       set, the Triton result is computed as well and the two are asserted to
       agree before the DeepGEMM result is returned.
    2. ``torch.einsum`` when the fallback-variant flag is set.
    3. The persistent Triton kernel otherwise.
    """
    # DeepGEMM has minimum dimension requirements for TMA descriptors
    MIN_DEEPGEMM_DIM = 16

    _, N = b.shape

    use_deepgemm = (
        _ENABLE_MM_DEEPGEMM
        and ENABLE_JIT_DEEPGEMM
        and (a.dtype == torch.bfloat16)
        and (b.dtype == torch.bfloat16)
        and a.is_contiguous()
        and b.transpose(0, 1).is_contiguous()
        and N >= MIN_DEEPGEMM_DIM
    )

    if use_deepgemm:
        if not _ENABLE_MM_COMPARISON_TEST:
            return _matmul_persistent_deepgemm(a=a, b=b, bias=bias)

        # Debug path: run both backends (Triton first) and require agreement.
        out_triton = _matmul_persistent_triton(a=a, b=b, bias=bias)
        out_deepgemm = _matmul_persistent_deepgemm(a=a, b=b, bias=bias)
        diff = calc_diff(out_triton, out_deepgemm)
        assert diff < 0.0001, f"{diff=} {out_triton=} {out_deepgemm=}"
        return out_deepgemm

    if not _ENABLE_MM_FALLBACK_VARIANT:
        return _matmul_persistent_triton(a=a, b=b, bias=bias)

    # Opt-in fallback to a regular (batch-variant) GEMM.
    out = torch.einsum("ik,kj->ij", a, b)
    if bias is not None:
        out += bias
    return out
def log_softmax(input: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    Compute log_softmax over the last dimension using the Triton kernel.

    Args:
        input: Input tensor
        dim: Dimension along which to compute log_softmax; only -1 or the
            index of the last dimension is accepted

    Returns:
        Tensor with log_softmax applied along the last dimension, with the
        same shape as ``input``

    Raises:
        ValueError: If ``dim`` does not refer to the last dimension
    """
    if dim not in (-1, input.ndim - 1):
        raise ValueError(
            "This implementation only supports log_softmax along the last dimension"
        )

    # Collapse every leading dimension so the kernel sees contiguous 2D rows.
    rows = input.reshape(-1, input.shape[-1]).contiguous()
    n_rows, n_cols = rows.shape
    result = torch.empty_like(rows)

    # One program per row; each row is processed in fixed chunks of 1024 columns.
    _log_softmax_kernel[(n_rows,)](
        rows,
        result,
        rows.stride(0),
        result.stride(0),
        n_cols,
        BLOCK_SIZE=1024,
    )

    # Restore the caller's original shape.
    return result.reshape(input.shape)
+ """ + # Program ID gives us which output element we're computing + pid = tl.program_id(0) + + # Compute output indices + m_idx = pid // K + k_idx = pid % K + + # Bounds check + if m_idx >= M or k_idx >= K: + return + + # Accumulate sum across reduction dimension + acc = 0.0 + for n_start in range(0, N, BLOCK_SIZE): + n_offsets = n_start + tl.arange(0, BLOCK_SIZE) + mask = n_offsets < N + + # Calculate input indices + input_idx = ( + m_idx * input_stride0 + n_offsets * input_stride1 + k_idx * input_stride2 + ) + + # Load and accumulate + vals = tl.load(input_ptr + input_idx, mask=mask, other=0.0) + acc += tl.sum(vals) + + # Compute mean and store + mean_val = acc / N + output_idx = m_idx * output_stride0 + k_idx * output_stride1 + tl.store(output_ptr + output_idx, mean_val) + + +def mean_dim( + input: torch.Tensor, + dim: int, + keepdim: bool = False, + dtype: torch.dtype | None = None, +) -> torch.Tensor: + """ + Triton implementation of torch.mean with single dimension reduction. + + Args: + input: Input tensor + dim: Single dimension along which to compute mean + keepdim: Whether to keep the reduced dimension + dtype: Output dtype. 
If None, uses input dtype (or float32 for integer inputs) + + Returns: + Tensor with mean values along specified dimension + """ + # Validate inputs + assert input.is_cuda, "Input must be a CUDA tensor" + assert ( + -input.ndim <= dim < input.ndim + ), f"Invalid dimension {dim} for tensor with {input.ndim} dimensions" + + # Handle negative dim + if dim < 0: + dim = dim + input.ndim + + # Handle dtype + if dtype is None: + if input.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]: + dtype = torch.float32 + else: + dtype = input.dtype + + # Convert input to appropriate dtype if needed + if input.dtype != dtype: + input = input.to(dtype) + + # Get input shape and strides + shape = list(input.shape) + + # Calculate dimensions for kernel + M = 1 + for i in range(dim): + M *= shape[i] + + N = shape[dim] + + K = 1 + for i in range(dim + 1, len(shape)): + K *= shape[i] + + # Reshape input to 3D view (M, N, K) + input_3d = input.reshape(M, N, K) + + # Create output shape + if keepdim: + output_shape = shape.copy() + output_shape[dim] = 1 + else: + output_shape = shape[:dim] + shape[dim + 1 :] + + # Create output tensor + output = torch.empty(output_shape, dtype=dtype, device=input.device) + + # Reshape output for kernel + if keepdim: + output_2d = output.reshape(M, 1, K).squeeze(1) + else: + output_2d = output.reshape(M, K) + + # Launch kernel + grid = (M * K,) + BLOCK_SIZE = 1024 + + mean_kernel[grid]( + input_3d, + output_2d, + input_3d.stride(0), + input_3d.stride(1), + input_3d.stride(2), + output_2d.stride(0), + output_2d.stride(1) if output_2d.ndim > 1 else 0, + M, + N, + K, + BLOCK_SIZE, + ) + + return output + + +def mm_batch_invariant(a, b): + return matmul_persistent(a, b) + + +def addmm_batch_invariant(bias, a, b): + return matmul_persistent(a, b, bias=bias) + + +def _log_softmax_batch_invariant(input, dim, _half_to_float): + assert not _half_to_float, "not implemented" + return log_softmax(input, dim=dim) + + +def mean_batch_invariant(input, dim, 
keepdim=False, dtype: torch.dtype | None = None): + assert dtype is None or dtype == torch.float32, f"unsupported dtype: {dtype}" + if len(dim) == 1: + return mean_dim(input, dim[0], keepdim=keepdim) + else: + assert input.dtype in { + torch.float16, + torch.bfloat16, + torch.float32, + }, "only float types supported for now" + n_elems = 1 + for d in dim: + n_elems *= input.shape[d] + return torch.sum(input, dim=dim, keepdim=keepdim, dtype=torch.float32) / n_elems + + +@triton.jit +def bmm_kernel_persistent( + a_ptr, + b_ptr, + c_ptr, # + B, + M, + N, + K, # + stride_ab, + stride_am, + stride_ak, + stride_bb, + stride_bk, + stride_bn, + stride_cb, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, # + BLOCK_SIZE_N: tl.constexpr, # + BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + NUM_SMS: tl.constexpr, # + A_LARGE: tl.constexpr, + B_LARGE: tl.constexpr, + C_LARGE: tl.constexpr, +): + """ + Batched matrix multiplication kernel that processes batches in parallel. + Each tile processes a (BLOCK_SIZE_M, BLOCK_SIZE_N) output block for a specific batch. 
+ """ + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles_per_batch = num_pid_m * num_pid_n + num_tiles_total = B * num_tiles_per_batch + + offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + # Process tiles in a deterministic order: batch-major ordering + for tile_id in tl.range(start_pid, num_tiles_total, NUM_SMS, flatten=True): + # Decompose tile_id into batch and within-batch tile + batch_idx = tile_id // num_tiles_per_batch + tile_in_batch = tile_id % num_tiles_per_batch + + pid_m, pid_n = _compute_pid( + tile_in_batch, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS + ) + start_m = pid_m * BLOCK_SIZE_M + start_n = pid_n * BLOCK_SIZE_N + offs_am = start_m + tl.arange(0, BLOCK_SIZE_M) + offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N) + if A_LARGE: + offs_am = offs_am.to(tl.int64) + if B_LARGE: + offs_bn = offs_bn.to(tl.int64) + offs_am = tl.where(offs_am < M, offs_am, 0) + offs_bn = tl.where(offs_bn < N, offs_bn, 0) + offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) + + # Add batch offset + if A_LARGE or B_LARGE: + batch_idx_typed = batch_idx.to(tl.int64) + else: + batch_idx_typed = batch_idx + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + if A_LARGE or B_LARGE: + offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K).to(tl.int64) + else: + offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + ( + batch_idx_typed * stride_ab + + offs_am[:, None] * stride_am + + offs_k[None, :] * stride_ak + ) + b_ptrs = b_ptr + ( + batch_idx_typed * stride_bb + + offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn + ) + + a = tl.load( + a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0 + ) + b = 
def bmm_batch_invariant(a, b, *, out=None):
    """Batched matrix multiply ``(B, M, K) x (B, K, N) -> (B, M, N)``.

    All batches are processed by a single launch of the persistent Triton
    kernel, using one fixed tile configuration per dtype.

    Args:
        a: Tensor of shape ``(B, M, K)``.
        b: Tensor of shape ``(B, K, N)`` with the same dtype as ``a``.
        out: Optional destination tensor of shape ``(B, M, N)``; allocated on
            ``a.device`` with ``a``'s dtype when omitted.

    Returns:
        The result tensor (``out`` itself when it was supplied).

    Raises:
        ValueError: If either operand is not 3-D, the dtype has no tile
            configuration, or ``out`` does not have shape ``(B, M, N)``.
        AssertionError: If batch sizes, inner dimensions or dtypes differ.
    """
    if a.ndim != 3 or b.ndim != 3:
        raise ValueError(
            f"bmm_batch_invariant expects 3D tensors, "
            f"got shapes {a.shape} and {b.shape}"
        )

    # Check constraints
    assert a.shape[0] == b.shape[0], "Batch sizes must match"
    assert a.shape[2] == b.shape[1], "Incompatible dimensions"
    assert a.dtype == b.dtype, "Incompatible dtypes"

    B, M, K = a.shape
    N = b.shape[2]
    dtype = a.dtype

    # Use fixed kernel configuration for determinism
    configs = {
        torch.bfloat16: {
            "BLOCK_SIZE_M": 128,
            "BLOCK_SIZE_N": 128,
            "BLOCK_SIZE_K": 64,
            "GROUP_SIZE_M": 8,
            "num_stages": 3,
            "num_warps": 8,
        },
        torch.float16: {
            "BLOCK_SIZE_M": 128,
            "BLOCK_SIZE_N": 256,
            "BLOCK_SIZE_K": 64,
            "GROUP_SIZE_M": 8,
            "num_stages": 3,
            "num_warps": 8,
        },
        torch.float32: {
            "BLOCK_SIZE_M": 128,
            "BLOCK_SIZE_N": 128,
            "BLOCK_SIZE_K": 32,
            "GROUP_SIZE_M": 8,
            "num_stages": 3,
            "num_warps": 8,
        },
    }

    # Validate before allocating or querying the device.
    config = configs.get(dtype)
    if config is None:
        raise ValueError(
            f"Unsupported dtype {dtype} for bmm_batch_invariant. "
            f"Supported dtypes are: {list(configs.keys())}"
        )

    # Allocate output
    if out is None:
        c = torch.empty((B, M, N), device=a.device, dtype=dtype)
    else:
        # The kernel writes through out's strides using (B, M, N) bounds, so a
        # wrongly shaped destination would be written out of bounds.
        if tuple(out.shape) != (B, M, N):
            raise ValueError(
                f"bmm_batch_invariant expects out of shape {(B, M, N)}, "
                f"got {tuple(out.shape)}"
            )
        c = out

    # Query the device that holds the operands rather than the current device.
    NUM_SMS = torch.cuda.get_device_properties(a.device).multi_processor_count

    # Grid: limit by NUM_SMS for persistent kernel approach
    num_tiles_per_batch = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
        N, config["BLOCK_SIZE_N"]
    )
    num_tiles_total = B * num_tiles_per_batch
    grid = (min(NUM_SMS, num_tiles_total),)

    bmm_kernel_persistent[grid](
        a,
        b,
        c,  #
        B,
        M,
        N,
        K,  #
        a.stride(0),
        a.stride(1),
        a.stride(2),  #
        b.stride(0),
        b.stride(1),
        b.stride(2),  #
        c.stride(0),
        c.stride(1),
        c.stride(2),  #
        NUM_SMS=NUM_SMS,  #
        # The *_LARGE flags switch the kernel's offset arithmetic to int64.
        A_LARGE=a.numel() > 2**31,
        B_LARGE=b.numel() > 2**31,
        C_LARGE=c.numel() > 2**31,
        **config,
    )

    return c
def rms_norm(
    input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    """
    Apply RMS normalization over the last dimension with the Triton kernel.

    The result is ``input / sqrt(mean(input^2) + eps) * weight``, computed
    independently for every row of the flattened input.

    Args:
        input: Input tensor of shape (..., hidden_size)
        weight: Weight tensor of shape (hidden_size,)
        eps: Small constant for numerical stability

    Returns:
        Tensor of the same shape as ``input`` with RMS normalization applied
        along the last dimension

    Raises:
        AssertionError: If ``weight`` is not 1-D or its length differs from
            the last dimension of ``input``
    """
    assert weight.dim() == 1, "Weight must be 1-dimensional"
    assert input.shape[-1] == weight.shape[0], (
        f"Input last dimension ({input.shape[-1]}) must match "
        f"weight dimension ({weight.shape[0]})"
    )

    # Collapse leading dimensions; the kernel works on contiguous 2D rows.
    rows = input.reshape(-1, input.shape[-1]).contiguous()
    weight = weight.contiguous()
    n_rows, n_cols = rows.shape
    normalized = torch.empty_like(rows)

    # One program per row, processed in fixed chunks of 1024 columns.
    _rms_norm_kernel[(n_rows,)](
        rows,
        weight,
        normalized,
        rows.stride(0),
        normalized.stride(0),
        n_cols,
        eps,
        BLOCK_SIZE=1024,
    )
    return normalized.reshape(input.shape)
@contextlib.contextmanager
def set_batch_invariant_mode(enabled: bool = True):
    """Temporarily switch batch-invariant mode on or off.

    On entry the mode is set to ``enabled``. On exit, including when the body
    raises, the state seen at entry is re-established through
    ``enable_batch_invariant_mode`` / ``disable_batch_invariant_mode``, so the
    ``torch.bmm`` monkeypatch and the registered Library are torn down or
    rebuilt together instead of being left half-restored.

    Args:
        enabled: Mode to apply for the duration of the ``with`` block.
    """
    was_enabled = _batch_invariant_MODE
    # enable_batch_invariant_mode stores the original torch.bmm when it patches
    # it, so a saved function is used as the sign that bmm was overridden.
    # NOTE(review): presumably that save happens only under enable_bmm=True;
    # confirm against enable_batch_invariant_mode.
    bmm_was_patched = _original_torch_bmm is not None

    if enabled:
        enable_batch_invariant_mode()
    else:
        disable_batch_invariant_mode()

    try:
        yield
    finally:
        if not was_enabled:
            # Also restores torch.bmm and clears the saved original, which a
            # bare Library._destroy() would leave behind.
            disable_batch_invariant_mode()
        elif not _batch_invariant_MODE:
            # The mode was on before and is off now: register fresh overrides
            # rather than pointing back at a destroyed Library.
            enable_batch_invariant_mode(enable_bmm=bmm_was_patched)
@dataclass
class GrammarStats:
    # Per-grammar statistics record. Every field has a default, so an instance
    # can be created empty and filled in as work proceeds.

    # Filled in by BaseGrammarBackend._init_value_dispatch with the elapsed
    # time.perf_counter() seconds of the dispatch call.
    compilation_time: Optional[float] = None
    # NOTE(review): presumably the number of schemas in the grammar — confirm.
    schema_count: Optional[int] = None
    # NOTE(review): presumably the size of the EBNF text — confirm units.
    ebnf_size: Optional[int] = None
    is_cache_hit: bool = False
    is_grammar_aborted: bool = False
    # default_factory gives each instance its own list.
    tree_traversal_time: List[float] = field(default_factory=list)
    # NOTE(review): presumably the key type used for dispatch ("json",
    # "regex", ...) — confirm against the code that sets it.
    dispatch_type: Optional[str] = None
    num_timeout: int = 0
+ """ + raise NotImplementedError() + + def rollback(self, k: int): + raise NotImplementedError() + + def is_terminated(self): + return False + + def allocate_vocab_mask( + self, vocab_size: int, batch_size: int, device + ) -> torch.Tensor: + raise NotImplementedError() + + def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None: + raise NotImplementedError() + + @staticmethod + def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor: + raise NotImplementedError() + + @staticmethod + def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None: + raise NotImplementedError() + + def copy(self) -> "BaseGrammarObject": + return self + + @property + def finished(self): + return self._finished + + @finished.setter + def finished(self, finished): + self._finished = finished + + def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]: + """ + Try to jump forward in the grammar. + + Returns: + A jump forward helper which may be used in `jump_forward_str_state`. + None if the jump forward is not possible. + """ + raise NotImplementedError() + + def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]: + """ + Jump forward for the grammar. + + Returns: + A tuple of the jump forward string and the next state of the grammar + (which can be used in `jump_and_retokenize` if needed). + """ + raise NotImplementedError() + + def jump_and_retokenize( + self, old_output_ids: List[int], new_output_ids: List[int], next_state: int + ) -> None: + """ + Jump forward occurs, and update the grammar state if needed. 
class BaseGrammarBackend:
    """Base class for grammar backends: dispatches keys and caches grammars.

    Subclasses override the ``dispatch_*`` hooks for the constraint kinds they
    support. The defaults log a warning and return ``None``.
    """

    def __init__(self):
        self.executor = ThreadPoolExecutor()
        # NOTE(review): annotated as holding CacheEntry, but set_cache stores
        # the value it is given and get_cached_or_future_value calls .copy()
        # on whatever it reads back — confirm the intended value type.
        self.cache: Dict[Tuple[str, str], CacheEntry] = {}

    def _not_supported(self, key_type: str, key_string: str) -> None:
        """Log that this backend skips the given constraint; returns ``None``."""
        logger.warning(f"Skip unsupported {key_type=}, {key_string=}")

    def dispatch_fallback(
        self, key_type: str, key_string: str
    ) -> Optional[BaseGrammarObject]:
        """
        This function should not be reached in any case.
        """
        raise ValueError(f"Invalid key_type: {key_type}={key_string}")

    def dispatch_json(self, key_string: str) -> Optional[BaseGrammarObject]:
        return self._not_supported("json", key_string)

    def dispatch_regex(self, key_string: str) -> Optional[BaseGrammarObject]:
        return self._not_supported("regex", key_string)

    def dispatch_ebnf(self, key_string: str) -> Optional[BaseGrammarObject]:
        return self._not_supported("ebnf", key_string)

    def dispatch_structural_tag(self, key_string: str) -> Optional[BaseGrammarObject]:
        return self._not_supported("structural_tag", key_string)

    # _init_value_dispatch routes these two key types as well. Without default
    # hooks, a backend that does not implement them raised AttributeError
    # instead of skipping the constraint like every other unsupported kind.
    def dispatch_structural_pattern(
        self, key_string: str
    ) -> Optional[BaseGrammarObject]:
        return self._not_supported("structural_pattern", key_string)

    def dispatch_structural_pattern_v2(
        self, key_string: str
    ) -> Optional[BaseGrammarObject]:
        return self._not_supported("structural_pattern_v2", key_string)

    def _init_value_dispatch(
        self, key: Tuple[str, str], require_reasoning: bool
    ) -> Optional[BaseGrammarObject]:
        """Build the grammar for ``key`` and record how long the dispatch took.

        Args:
            key: ``(key_type, key_string)`` pair selecting the dispatch hook.
            require_reasoning: Accepted for the executor call signature; not
                read by this method.

        Returns:
            Whatever the selected hook returns (``None`` when unsupported).

        Raises:
            ValueError: From ``dispatch_fallback`` for an unknown ``key_type``.
        """
        s = time.perf_counter()
        key_type, key_string = key
        if key_type == "json":
            grammar = self.dispatch_json(key_string)
        elif key_type == "regex":
            grammar = self.dispatch_regex(key_string)
        elif key_type == "ebnf":
            grammar = self.dispatch_ebnf(key_string)
        elif key_type == "structural_tag":
            grammar = self.dispatch_structural_tag(key_string)
        elif key_type == "structural_pattern":
            grammar = self.dispatch_structural_pattern(key_string)
        elif key_type == "structural_pattern_v2":
            grammar = self.dispatch_structural_pattern_v2(key_string)
        else:
            grammar = self.dispatch_fallback(key_type, key_string)

        if grammar is not None and grammar.grammar_stats is not None:
            grammar.grammar_stats.compilation_time = time.perf_counter() - s
        return grammar

    def get_cached_or_future_value(
        self, key: Tuple[str, str], require_reasoning: bool
    ) -> Tuple[object, bool]:
        """Return a cached grammar copy, or a future that is building it.

        Returns:
            ``(value, cache_hit)``. On a hit, ``value`` is a copy of the cached
            entry with ``maybe_init_reasoning(require_reasoning)`` applied. On
            a miss, ``value`` is the future returned by the executor for
            ``_init_value_dispatch`` and ``cache_hit`` is ``False``.
        """
        value = self.cache.get(key)
        if value:
            copied_value = value.copy()
            copied_value.maybe_init_reasoning(require_reasoning)
            return copied_value, True
        value = self.executor.submit(self._init_value_dispatch, key, require_reasoning)
        return value, False

    def set_cache(self, key: Tuple[str, str], value: BaseGrammarObject):
        """Store ``value`` under ``key`` for later cache hits."""
        self.cache[key] = value

    def reset(self):
        """Drop every cached grammar."""
        self.cache.clear()
" + "Falling back to grammar_backend='none'. " + "Structured outputs (JSON schema, regex, EBNF) will not be available." + ) + server_args.grammar_backend = "none" + return None + elif name == "llguidance": + from sglang.srt.constrained.llguidance_backend import GuidanceBackend + + grammar_backend = GuidanceBackend( + tokenizer=tokenizer, + any_whitespace=not server_args.constrained_json_disable_any_whitespace, + whitespace_pattern=server_args.constrained_json_whitespace_pattern, + ) + elif name == "none": + return None + else: + raise ValueError(f"Invalid grammar backend: {name}") + + if server_args.reasoning_parser and hasattr(tokenizer, "think_end_id"): + from sglang.srt.constrained.reasoner_grammar_backend import ( + ReasonerGrammarBackend, + ) + + grammar_backend = ReasonerGrammarBackend( + grammar_backend, tokenizer.think_end_id + ) + + return grammar_backend diff --git a/sglang/python/sglang/srt/constrained/grammar_manager.py b/sglang/python/sglang/srt/constrained/grammar_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..c2487c8d1089a465a1f05a1d3e93514a78aa6370 --- /dev/null +++ b/sglang/python/sglang/srt/constrained/grammar_manager.py @@ -0,0 +1,195 @@ +from __future__ import annotations + +import logging +import time +from concurrent import futures +from typing import TYPE_CHECKING, List + +import torch + +from sglang.srt.constrained.base_grammar_backend import ( + INVALID_GRAMMAR_OBJ, + create_grammar_backend, +) +from sglang.srt.environ import envs + +if TYPE_CHECKING: + from sglang.srt.managers.io_struct import AbortReq + from sglang.srt.managers.schedule_batch import Req + from sglang.srt.managers.scheduler import Scheduler + +logger = logging.getLogger(__name__) + + +class GrammarManager: + def __init__(self, scheduler: Scheduler): + self.scheduler = scheduler + self.server_args = scheduler.server_args + self.grammar_queue: List[Req] = [] + if not self.server_args.skip_tokenizer_init: + self.grammar_backend = 
create_grammar_backend( + self.server_args, + scheduler.tokenizer, + scheduler.model_config.vocab_size, + scheduler.model_config.hf_eos_token_id, + ) + else: + self.grammar_backend = None + + self.grammar_sync_group = scheduler.dp_tp_cpu_group + self.grammar_sync_size = scheduler.dp_tp_group.world_size + self.grammar_sync_entry = scheduler.dp_tp_group.first_rank + self.is_grammar_sync_entry = scheduler.dp_tp_group.is_first_rank + + self.SGLANG_GRAMMAR_POLL_INTERVAL = envs.SGLANG_GRAMMAR_POLL_INTERVAL.get() + self.SGLANG_GRAMMAR_MAX_POLL_ITERATIONS = ( + envs.SGLANG_GRAMMAR_MAX_POLL_ITERATIONS.get() + ) + + def __len__(self): + return len(self.grammar_queue) + + def clear(self): + if self.grammar_backend: + self.grammar_backend.reset() + + def has_waiting_grammars(self) -> bool: + return len(self.grammar_queue) > 0 + + def abort_requests(self, recv_req: AbortReq): + for req in self.grammar_queue: + if recv_req.abort_all or req.rid.startswith(recv_req.rid): + logger.debug(f"Abort grammar queue request. 
{req.rid=}") + if req.grammar: + req.grammar.cancel() + req.set_finish_with_abort("Aborted by AbortReq.") + + def process_req_with_grammar(self, req: Req) -> bool: + # Init grammar cache for this request + add_to_grammar_queue = False + if ( + req.sampling_params.json_schema is not None + or req.sampling_params.regex is not None + or req.sampling_params.ebnf is not None + or req.sampling_params.structural_tag is not None + ): + if self.grammar_backend is None: + error_msg = "Grammar-based generation (json_schema, regex, ebnf, structural_tag) is not supported when the server is launched with --grammar-backend none" + req.set_finish_with_abort(error_msg) + else: + if req.sampling_params.json_schema is not None: + key = ("json", req.sampling_params.json_schema) + elif req.sampling_params.regex is not None: + key = ("regex", req.sampling_params.regex) + elif req.sampling_params.ebnf is not None: + key = ("ebnf", req.sampling_params.ebnf) + elif req.sampling_params.structural_tag: + key = ("structural_tag", req.sampling_params.structural_tag) + + value, cache_hit = self.grammar_backend.get_cached_or_future_value( + key, req.require_reasoning + ) + req.grammar = value + + if not cache_hit: + req.grammar_key = key + add_to_grammar_queue = True + else: + if value is INVALID_GRAMMAR_OBJ: # We hit a cached invalid grammar. + error_msg = f"Invalid grammar request with cache hit: {key=}" + req.set_finish_with_abort(error_msg) + + if add_to_grammar_queue: + self.grammar_queue.append(req) + + return add_to_grammar_queue + + def get_ready_grammar_requests(self) -> List[Req]: + """ + Move requests whose grammar objects are ready from grammar_queue to waiting_queue. 
+ + Rank i returns two sets ready_reqs_i, failed_reqs_i + ready_reqs_all = all_gather(ready_reqs_i) + failed_reqs_all = all_gather(failed_reqs_i) + + ready_reqs = intersect(ready_reqs_all) + failed_reqs = union(failed_reqs_all) + """ + ready_req_idxs: set[int] = set() + failed_req_idxs: set[int] = set() + + # Poll for ready requests + start_time = time.perf_counter() + while time.perf_counter() - start_time < self.SGLANG_GRAMMAR_POLL_INTERVAL: + for i, req in enumerate(self.grammar_queue): + if i in ready_req_idxs: + continue + + if req.finished() or req.grammar is None: # It is aborted by AbortReq + ready_req_idxs.add(i) + continue + + assert isinstance(req.grammar, futures.Future), f"{req=}" + if req.grammar.done(): + ready_req_idxs.add(i) + + # Sleep a bit to avoid busy waiting + time.sleep(self.SGLANG_GRAMMAR_POLL_INTERVAL / 10) + + # Check failed requests + for i, req in enumerate(self.grammar_queue): + if i not in ready_req_idxs: + self.grammar_queue[i].grammar_wait_ct += 1 + if ( + self.grammar_queue[i].grammar_wait_ct + >= self.SGLANG_GRAMMAR_MAX_POLL_ITERATIONS + ): + # Timeout after max poll iterations + # The actual waiting time is SGLANG_GRAMMAR_MAX_POLL_ITERATIONS * max(SGLANG_GRAMMAR_POLL_INTERVAL, GPU_forward_batch_latency) + failed_req_idxs.add(i) + + # Sync ready and failed requests across all ranks + if self.grammar_sync_size == 1: + synced_ready_req_idxs = ready_req_idxs + synced_failed_req_idxs = failed_req_idxs + else: + all_gather_output = [None] * self.grammar_sync_size + torch.distributed.all_gather_object( + all_gather_output, + (ready_req_idxs, failed_req_idxs), + group=self.grammar_sync_group, + ) + synced_ready_req_idxs = set.intersection(*[x[0] for x in all_gather_output]) + synced_failed_req_idxs = set.union(*[x[1] for x in all_gather_output]) + + # Return ready requests + return_reqs: List[Req] = [] + for i in synced_ready_req_idxs: + req = self.grammar_queue[i] + return_reqs.append(req) + if req.finished() or req.grammar is None: # 
It is aborted by AbortReq + continue + + req.grammar = req.grammar.result() + self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy()) + if req.grammar is INVALID_GRAMMAR_OBJ: + error_msg = f"Invalid grammar request: {req.grammar_key=}" + req.set_finish_with_abort(error_msg) + + # Return failed requests + for i in synced_failed_req_idxs: + req = self.grammar_queue[i] + return_reqs.append(req) + + req.grammar.cancel() + self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ) + error_msg = f"Grammar preprocessing timed out: {req.grammar_key=}" + req.set_finish_with_abort(error_msg) + + # Remove finished requests from grammar_queue + self.grammar_queue = [ + req + for i, req in enumerate(self.grammar_queue) + if i not in synced_ready_req_idxs and i not in synced_failed_req_idxs + ] + return return_reqs diff --git a/sglang/python/sglang/srt/constrained/llguidance_backend.py b/sglang/python/sglang/srt/constrained/llguidance_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..d600a07f3724edaf2888f706f96df8dbfc1f91dc --- /dev/null +++ b/sglang/python/sglang/srt/constrained/llguidance_backend.py @@ -0,0 +1,178 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""Constrained decoding with llguidance backend."""
+
+import json
+import logging
+import os
+from typing import List, Optional, Tuple
+
+import torch
+from llguidance import LLMatcher, LLTokenizer, StructTag, grammar_from
+from llguidance.hf import from_tokenizer
+from llguidance.torch import (
+    allocate_token_bitmask,
+    apply_token_bitmask_inplace,
+    fill_next_token_bitmask,
+)
+
+from sglang.srt.constrained.base_grammar_backend import (
+    INVALID_GRAMMAR_OBJ,
+    BaseGrammarBackend,
+    BaseGrammarObject,
+)
+from sglang.srt.constrained.utils import is_legacy_structural_tag
+
+logger = logging.getLogger(__name__)
+
+
+class GuidanceGrammar(BaseGrammarObject):
+    """Grammar object that delegates matching to an llguidance ``LLMatcher``.
+
+    The serialized grammar and the tokenizer are kept so that ``copy()`` can
+    build a fresh matcher for another request.
+    """
+
+    def __init__(self, llguidance_tokenizer: LLTokenizer, serialized_grammar: str):
+        super().__init__()
+        self.llguidance_tokenizer = llguidance_tokenizer
+        self.serialized_grammar = serialized_grammar
+
+        # The matcher log level is read from the LLGUIDANCE_LOG_LEVEL environment
+        # variable and defaults to 1.
+        self.ll_matcher = LLMatcher(
+            self.llguidance_tokenizer,
+            self.serialized_grammar,
+            log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
+        )
+        # Lazily allocated by allocate_vocab_mask() and reused afterwards.
+        self.bitmask = None
+
+    def accept_token(self, token: int) -> None:
+        """Feed ``token`` to the matcher.
+
+        If the matcher does not consume the token, the matcher error is logged
+        and the grammar is marked as finished.
+        """
+        if not self.ll_matcher.consume_token(token):
+            logger.warning(f"matcher error: {self.ll_matcher.get_error()}")
+            self.finished = True
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        """Mark the grammar finished if the matcher stopped, then let llguidance
+        fill ``vocab_mask`` for position ``idx``."""
+        if self.ll_matcher.is_stopped():
+            self.finished = True
+
+        fill_next_token_bitmask(self.ll_matcher, vocab_mask, idx)
+
+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        """Return a token bitmask with ``batch_size`` rows.
+
+        ``vocab_size`` and ``device`` are not used here: the bitmask is sized from
+        ``self.llguidance_tokenizer.vocab_size`` and is moved to the target device
+        by ``move_vocab_mask``. The cached bitmask is reallocated only when the
+        requested batch is larger than the cached one; otherwise a slice of the
+        cached tensor is returned.
+        """
+        if self.bitmask is None or self.bitmask.shape[0] < batch_size:
+            # only create bitmask when batch gets larger
+            self.bitmask = allocate_token_bitmask(
+                batch_size, self.llguidance_tokenizer.vocab_size
+            )
+            bitmask = self.bitmask
+        else:
+            bitmask = self.bitmask[:batch_size]
+
+        return bitmask
+
+    @staticmethod
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        """Copy ``vocab_mask`` to ``device`` with ``non_blocking=True``."""
+        return vocab_mask.to(device, non_blocking=True)
+
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
+        """Apply the bitmask to ``logits`` in place."""
+        apply_token_bitmask_inplace(logits, vocab_mask)
+
+    def copy(self) -> "GuidanceGrammar":
+        """Build a new grammar (with a new matcher) from the same tokenizer and
+        serialized grammar."""
+        return GuidanceGrammar(
+            llguidance_tokenizer=self.llguidance_tokenizer,
+            serialized_grammar=self.serialized_grammar,
+        )
+
+    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
+        """Return ``(ff_tokens, "")`` when the matcher reports fast-forward
+        tokens, otherwise ``None``. ``tokenizer`` is unused."""
+        ff_tokens = self.ll_matcher.compute_ff_tokens()
+        if ff_tokens:
+            return ff_tokens, ""
+        else:
+            return None
+
+    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
+        """Always return an empty jump-forward string and state ``-1``."""
+        return "", -1
+
+    def jump_and_retokenize(
+        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
+    ) -> None:
+        """No-op for this backend."""
+        pass
+
+
+class GuidanceBackend(BaseGrammarBackend):
+    """Grammar backend that compiles JSON schema, regex, EBNF and legacy
+    structural-tag requests into ``GuidanceGrammar`` objects.
+
+    Every ``dispatch_*`` method returns ``INVALID_GRAMMAR_OBJ`` (after logging)
+    instead of raising when the request cannot be compiled.
+    """
+
+    def __init__(
+        self,
+        tokenizer,
+        any_whitespace: bool = True,
+        whitespace_pattern: Optional[str] = None,
+        n_vocab: Optional[int] = None,
+    ):
+        super().__init__()
+
+        self.tokenizer = tokenizer
+        self.any_whitespace = any_whitespace
+        self.whitespace_pattern = whitespace_pattern
+        self.llguidance_tokenizer = from_tokenizer(self.tokenizer, n_vocab)
+
+    def _from_serialized(self, serialized_grammar) -> Optional[GuidanceGrammar]:
+        """Wrap a serialized grammar in a ``GuidanceGrammar``; on any failure log
+        the error and return ``INVALID_GRAMMAR_OBJ``."""
+        try:
+            return GuidanceGrammar(
+                llguidance_tokenizer=self.llguidance_tokenizer,
+                serialized_grammar=serialized_grammar,
+            )
+        except Exception as e:
+            logger.error(f"Hit invalid grammar: {serialized_grammar=}, {e=}")
+            return INVALID_GRAMMAR_OBJ
+
+    def dispatch_json(self, key_string: str) -> Optional[GuidanceGrammar]:
+        """Compile a JSON schema string using the configured whitespace options."""
+        try:
+            serialized_grammar = LLMatcher.grammar_from_json_schema(
+                key_string,
+                defaults={
+                    "whitespace_flexible": self.any_whitespace,
+                    "whitespace_pattern": self.whitespace_pattern,
+                },
+            )
+        except Exception as e:
+            logger.error(f"Hit invalid json_schema: {key_string=}, {e=}")
+            return INVALID_GRAMMAR_OBJ
+        return self._from_serialized(serialized_grammar)
+
+    def dispatch_regex(self, key_string: str) -> Optional[GuidanceGrammar]:
+        """Compile a regex string.
+
+        Serialization errors are handled the same way as in the other
+        dispatchers so that a bad request cannot raise out of the worker future.
+        """
+        try:
+            serialized_grammar = grammar_from("regex", key_string)
+        except Exception as e:
+            logger.error(f"Hit invalid regex: {key_string=}, {e=}")
+            return INVALID_GRAMMAR_OBJ
+        return self._from_serialized(serialized_grammar)
+
+    def dispatch_ebnf(self, key_string: str) -> Optional[GuidanceGrammar]:
+        """Compile an EBNF grammar string."""
+        try:
+            serialized_grammar = grammar_from("ebnf", key_string)
+            return self._from_serialized(serialized_grammar)
+        except ValueError as e:
+            logger.error(f"Hit invalid ebnf: {key_string=}, {e=}")
+            return INVALID_GRAMMAR_OBJ
+
+    def dispatch_structural_tag(self, key_string: str) -> Optional[GuidanceGrammar]:
+        """Compile a legacy structural tag given as a JSON string.
+
+        Only the legacy format (with "structures" and "triggers") is accepted;
+        anything else fails the assertion and is reported as invalid. Every
+        structure is paired with the first trigger only.
+        """
+        try:
+            structural_tag = json.loads(key_string)
+            assert is_legacy_structural_tag(structural_tag)
+            tags = [
+                StructTag(
+                    begin=structure["begin"],
+                    grammar=structure["schema"],
+                    end=structure["end"],
+                    trigger=structural_tag["triggers"][0],  # TODO?
+                )
+                for structure in structural_tag["structures"]
+            ]
+            g = StructTag.to_grammar(tags)
+            return self._from_serialized(g)
+        except Exception as e:
+            logger.error(f"Hit invalid structural_tag: {key_string=}, {e=}")
+            return INVALID_GRAMMAR_OBJ
diff --git a/sglang/python/sglang/srt/constrained/outlines_backend.py b/sglang/python/sglang/srt/constrained/outlines_backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..28831ab862cf281f10b987da098d4472be6032c0
--- /dev/null
+++ b/sglang/python/sglang/srt/constrained/outlines_backend.py
@@ -0,0 +1,190 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Constrained decoding with outlines backend."""
+
+import json
+import logging
+from typing import Dict, List, Optional, Tuple, Union
+
+import interegular
+import torch
+from outlines.fsm.guide import RegexGuide
+from outlines.models.transformers import TransformerTokenizer
+from pydantic import BaseModel
+
+from sglang.srt.constrained.base_grammar_backend import (
+    INVALID_GRAMMAR_OBJ,
+    BaseGrammarBackend,
+    BaseGrammarObject,
+)
+from sglang.srt.constrained.outlines_jump_forward import OutlinesJumpForwardMap
+
+try:
+    from outlines.fsm.json_schema import build_regex_from_schema
+except ImportError:
+    from outlines_core.fsm.json_schema import build_regex_from_schema
+
+
+logger = logging.getLogger(__name__)
+
+
+class OutlinesGrammar(BaseGrammarObject):
+    """Grammar object that tracks an integer state of an outlines ``RegexGuide``.
+
+    ``jump_forward_map`` may be ``None``; in that case ``try_jump_forward``
+    always returns ``None``.
+    """
+
+    def __init__(
+        self,
+        guide: RegexGuide,
+        jump_forward_map: Union[OutlinesJumpForwardMap, None],
+    ) -> None:
+        super().__init__()
+        self.guide = guide
+        self.jump_forward_map = jump_forward_map
+        self.state = 0
+
+    def accept_token(self, token: int) -> None:
+        """Advance the guide state by ``token``."""
+        self.state = self.guide.get_next_state(self.state, token)
+
+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        """Allocate an all-False boolean mask of shape (batch_size, vocab_size)
+        directly on ``device``."""
+        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
+
+    @staticmethod
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        """Return the mask unchanged; it was already allocated on the device."""
+        return vocab_mask
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        """Set row ``idx`` to True everywhere except at the tokens the guide
+        allows in the current state (True means "masked out")."""
+        tokens = torch.tensor(
+            self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
+        ).to(vocab_mask.device, non_blocking=True)
+        vocab_mask = vocab_mask[idx]
+        vocab_mask.fill_(1)
+        vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))
+
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
+        """Set masked logits to ``-inf`` in place."""
+        logits.masked_fill_(vocab_mask, float("-inf"))
+
+    def copy(self) -> "OutlinesGrammar":
+        """Return a new grammar sharing the guide and map, with state reset to 0."""
+        return OutlinesGrammar(self.guide, self.jump_forward_map)
+
+    def try_jump_forward(self, tokenizer) -> Optional[Tuple]:
+        """Return ``(suffix_ids, state)`` if a jump forward is possible, else ``None``.
+
+        Leading bytes in the range 0x80-0xBF are peeled off the jump-forward
+        byte sequence and converted to ``<0xNN>`` token ids.
+        """
+        if not self.jump_forward_map:
+            return None
+
+        jump_forward_bytes = self.jump_forward_map.jump_forward_byte(self.state)
+        if jump_forward_bytes is None or len(jump_forward_bytes) <= 1:
+            return None
+
+        # preprocess the jump forward string
+        suffix_bytes = []
+        continuation_range = range(0x80, 0xC0)
+        cur_state = self.state
+        while (
+            len(jump_forward_bytes) and jump_forward_bytes[0][0] in continuation_range
+        ):
+            # continuation bytes
+            byte_edge = jump_forward_bytes.pop(0)
+            suffix_bytes.append(byte_edge[0])
+            cur_state = byte_edge[1]
+
+        suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
+        suffix_ids = tokenizer.convert_tokens_to_ids(suffix_tokens)
+        return suffix_ids, cur_state
+
+    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
+        """Return the jump-forward string and next state for the state carried in
+        ``helper`` (the value returned by ``try_jump_forward``)."""
+        _, cur_state = helper
+        return self.jump_forward_map.jump_forward_symbol(cur_state)
+
+    def jump_and_retokenize(
+        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
+    ) -> None:
+        """Adopt ``next_state`` as the current guide state."""
+        self.state = next_state
+
+
+class OutlinesGrammarBackend(BaseGrammarBackend):
+    """Grammar backend that compiles regex and JSON-schema requests with outlines.
+
+    EBNF and structural-tag requests are delegated to the base class.
+    """
+
+    def __init__(
+        self,
+        tokenizer,
+        whitespace_pattern: Optional[str],
+    ):
+        super().__init__()
+
+        try:
+            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
+        except AttributeError:
+            # FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
+            origin_pad_token_id = tokenizer.pad_token_id
+
+            def fset(self, value):
+                self._value = value
+
+            # Replace the pad_token_id property on the tokenizer class with one
+            # that has a setter, then retry and restore the original values.
+            type(tokenizer).pad_token_id = property(
+                fget=type(tokenizer).pad_token_id.fget, fset=fset
+            )
+            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
+            self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
+            self.outlines_tokenizer.pad_token_id = origin_pad_token_id
+            self.outlines_tokenizer.pad_token = (
+                self.outlines_tokenizer.tokenizer.pad_token
+            )
+            self.outlines_tokenizer.vocabulary = (
+                self.outlines_tokenizer.tokenizer.get_vocab()
+            )
+        self.whitespace_pattern = whitespace_pattern
+
+    def _compile_regex(self, regex: str) -> Optional[OutlinesGrammar]:
+        """Build an ``OutlinesGrammar`` for ``regex``.
+
+        Returns ``INVALID_GRAMMAR_OBJ`` when interegular reports invalid syntax.
+        The jump-forward map is always ``None`` here.
+        """
+        try:
+            if hasattr(RegexGuide, "from_regex"):
+                # outlines >= 0.1.1
+                guide = RegexGuide.from_regex(regex, self.outlines_tokenizer)
+            else:
+                # outlines <= 0.0.46
+                guide = RegexGuide(regex, self.outlines_tokenizer)
+        except interegular.patterns.InvalidSyntax as e:
+            logger.error(f"Hit invalid regex schema: {regex=}, {e=}")
+            return INVALID_GRAMMAR_OBJ
+
+        jump_forward_map = None
+        return OutlinesGrammar(guide, jump_forward_map)
+
+    def dispatch_ebnf(self, key_string: str):
+        """Delegate to the base class implementation."""
+        return super().dispatch_ebnf(key_string)
+
+    def dispatch_structural_tag(self, key_string: str):
+        """Delegate to the base class implementation."""
+        return super().dispatch_structural_tag(key_string)
+
+    def dispatch_json(self, key_string: str):
+        """Convert a JSON schema to a regex and compile it; return
+        ``INVALID_GRAMMAR_OBJ`` if the conversion fails."""
+        try:
+            regex = build_regex_from_object(
+                key_string,
+                whitespace_pattern=self.whitespace_pattern,
+            )
+        except (NotImplementedError, json.decoder.JSONDecodeError, ValueError) as e:
+            logger.error(f"Hit invalid json_schema: {key_string=}, {e=}")
+            return INVALID_GRAMMAR_OBJ
+        return self._compile_regex(regex)
+
+    def dispatch_regex(self, key_string: str):
+        """Compile a regex request."""
+        return self._compile_regex(key_string)
+
+
+def build_regex_from_object(
+    object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
+):
+    """Build a regex from a schema given as a JSON string, a dict, or a
+    pydantic model class.
+
+    A model class is serialized via ``model_json_schema()`` and a dict via
+    ``json.dumps``; any other value is passed through as the schema.
+    """
+    if isinstance(object, type(BaseModel)):
+        schema = json.dumps(object.model_json_schema())
+    elif isinstance(object, Dict):
+        schema = json.dumps(object)
+    else:
+        schema = object
+    return build_regex_from_schema(schema, whitespace_pattern)
diff --git a/sglang/python/sglang/srt/constrained/outlines_jump_forward.py b/sglang/python/sglang/srt/constrained/outlines_jump_forward.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e19742c66f470baa51aef6c94d3d5ce8f045800
--- /dev/null
+++ b/sglang/python/sglang/srt/constrained/outlines_jump_forward.py
@@ -0,0 +1,200 @@
+# Copyright 2023-2024 SGLang Team
+# 
Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Faster constrained decoding with jump forward decoding / compressed finite state machine. +Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/ +""" + +import dataclasses +import logging +from collections import defaultdict +from typing import Optional + +import interegular +from interegular import InvalidSyntax +from outlines.caching import cache + +from sglang.srt.utils import get_bool_env_var + +try: + # outlines >= 0.1.0 + from outlines_core.fsm.outlines_core_rs import FSMInfo + from outlines_core.fsm.regex import make_byte_level_fsm, make_deterministic_fsm +except ImportError: + # outlines <= 0.0.46 + from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm + +IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)" + +# Env var was set in sglang.srt.server_args.ServerArgs.__post_init__ +DISABLE_DISK_CACHE = get_bool_env_var("SGLANG_DISABLE_OUTLINES_DISK_CACHE", "true") + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class JumpEdge: + symbol: str = None + symbol_next_state: int = None + byte: int = None + byte_next_state: int = None + + +def disk_cache(expire: Optional[float] = None, typed=False, ignore=()): + if not DISABLE_DISK_CACHE: + return cache(expire, typed, ignore) + else: + return lambda fn: None + + +@disk_cache() +def 
init_state_to_jump_forward(regex_string): + try: + regex_pattern = interegular.parse_pattern(regex_string) + except InvalidSyntax as e: + logger.warning(f"skip invalid regex: {regex_string}, {e=}") + return + + byte_fsm = make_byte_level_fsm(regex_pattern.to_fsm().reduce(), keep_utf8=True) + regex_fsm, _ = make_deterministic_fsm(byte_fsm) + + fsm_info: FSMInfo = regex_fsm.fsm_info + + symbol_to_id = fsm_info.alphabet_symbol_mapping + id_to_symbol = {} + for symbol, id_ in symbol_to_id.items(): + id_to_symbol.setdefault(id_, []).append(symbol) + + transitions = fsm_info.transitions + + outgoings_ct = defaultdict(int) + # NOTE(lsyin): Final states can lead to terminate, so they have one outgoing edge naturally + for s in fsm_info.finals: + outgoings_ct[s] = 1 + + state_to_jump_forward = {} + for (state, id_), next_state in transitions.items(): + if id_ == fsm_info.alphabet_anything_value: + # Arbitrarily symbol cannot be recognized as jump forward + continue + + symbols = id_to_symbol[id_] + for c in symbols: + if len(c) > 1: + # Skip byte level transitions like c = "5E" + continue + + outgoings_ct[state] += 1 + if outgoings_ct[state] > 1: + if state in state_to_jump_forward: + del state_to_jump_forward[state] + break + + state_to_jump_forward[state] = JumpEdge( + symbol=c, + symbol_next_state=next_state, + ) + + # Process the byte level jump forward + outgoings_ct = defaultdict(int) + for s in fsm_info.finals: + outgoings_ct[s] = 1 + + for (state, id_), next_state in transitions.items(): + if id_ == fsm_info.alphabet_anything_value: + continue + symbols = id_to_symbol[id_] + for c in symbols: + byte_ = None + if len(c) == 1 and ord(c) < 0x80: + # ASCII character + byte_ = ord(c) + elif len(c) > 1: + # FIXME: This logic is due to the leading \x00 + # https://github.com/outlines-dev/outlines/pull/930 + byte_ = int(symbols[0][1:], 16) + + if byte_ is not None: + outgoings_ct[state] += 1 + if outgoings_ct[state] > 1: + if state in state_to_jump_forward: + del 
state_to_jump_forward[state] + break + e = state_to_jump_forward.get(state, JumpEdge()) + e.byte = byte_ + e.byte_next_state = next_state + state_to_jump_forward[state] = e + + return state_to_jump_forward + + +class OutlinesJumpForwardMap: + def __init__(self, regex_string): + self.state_to_jump_forward = init_state_to_jump_forward(regex_string) + + def jump_forward_symbol(self, state): + jump_forward_str = "" + next_state = state + while state in self.state_to_jump_forward: + e = self.state_to_jump_forward[state] + if e.symbol is None: + break + jump_forward_str += e.symbol + next_state = e.symbol_next_state + state = next_state + + return jump_forward_str, next_state + + def jump_forward_byte(self, state): + if state not in self.state_to_jump_forward: + return None + + jump_forward_bytes = [] + next_state = None + while state in self.state_to_jump_forward: + e = self.state_to_jump_forward[state] + assert e.byte is not None and e.byte_next_state is not None + jump_forward_bytes.append((e.byte, e.byte_next_state)) + next_state = e.byte_next_state + state = next_state + + return jump_forward_bytes + + def is_jump_forward_symbol_state(self, state): + return ( + state in self.state_to_jump_forward + and self.state_to_jump_forward[state].symbol is not None + ) + + +def test_main(regex_string): + jump_forward_map = OutlinesJumpForwardMap(regex_string) + for state, e in jump_forward_map.state_to_jump_forward.items(): + if e.symbol is not None: + jump_forward_str, next_state = jump_forward_map.jump_forward_symbol(state) + print(f"{state} -> {next_state}", jump_forward_str) + bytes_ = jump_forward_map.jump_forward_byte(state) + print(f"{state} -> {bytes_[-1][1]}", [hex(b) for b, _ in bytes_]) + + +if __name__ == "__main__": + import outlines + + outlines.caching.clear_cache() + test_main(r"The google's DNS sever address is " + IP_REGEX) + test_main(r"霍格沃茨特快列车|霍比特人比尔博") + # 霍格: \xe9\x9c\x8d \xe6\xa0\xbc ... + # 霍比: \xe9\x9c\x8d \xe6\xaf\x94 ... 
+ + test_main(r"[-+]?[0-9]+[ ]*") diff --git a/sglang/python/sglang/srt/constrained/reasoner_grammar_backend.py b/sglang/python/sglang/srt/constrained/reasoner_grammar_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..e2ae8405e315a0ecc961f15f607f84d9ae92f99e --- /dev/null +++ b/sglang/python/sglang/srt/constrained/reasoner_grammar_backend.py @@ -0,0 +1,124 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The baseclass of a backend for reasoner grammar-guided constrained decoding.""" + +from typing import List, Optional, Tuple + +import torch + +from .base_grammar_backend import ( + INVALID_GRAMMAR_OBJ, + BaseGrammarBackend, + BaseGrammarObject, +) + + +class ReasonerGrammarObject(BaseGrammarObject): + def __init__(self, grammar: BaseGrammarObject, think_end_id: int): + super().__init__() + self.grammar = grammar + self.think_end_id = think_end_id + # -1 means thinking has not ended yet + # 0 means just ended thinking in the last token + # + means number of tokens after thinking ended + self.tokens_after_think_end = -1 + + def maybe_init_reasoning(self, reasoning: bool): + self.tokens_after_think_end = -1 if reasoning else 0 + + def transfer_state(self, token: int) -> int: + if self.tokens_after_think_end == -1 and token == self.think_end_id: + self.tokens_after_think_end = 0 + elif self.tokens_after_think_end >= 0: + 
self.tokens_after_think_end += 1 + + def rollback_state(self): + if self.tokens_after_think_end == 0: + self.tokens_after_think_end = -1 + elif self.tokens_after_think_end > 0: + self.tokens_after_think_end -= 1 + + def accept_token(self, token: int): + if self.tokens_after_think_end >= 0: + self.grammar.accept_token(token) + self.transfer_state(token) + + def is_terminated(self): + return self.grammar.is_terminated() + + def rollback(self, k): + steps_after_think = min(k, self.tokens_after_think_end) + if steps_after_think > 0: + self.grammar.rollback(steps_after_think) + + for _ in range(k): + self.rollback_state() + + def allocate_vocab_mask( + self, vocab_size: int, batch_size: int, device + ) -> torch.Tensor: + return self.grammar.allocate_vocab_mask(vocab_size, batch_size, device) + + def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None: + if self.tokens_after_think_end >= 0: + self.grammar.fill_vocab_mask(vocab_mask, idx) + + def move_vocab_mask(self, vocab_mask: torch.Tensor, device) -> torch.Tensor: + return self.grammar.move_vocab_mask(vocab_mask, device) + + @property + def apply_vocab_mask(self): + return self.grammar.apply_vocab_mask + + def copy(self) -> BaseGrammarObject: + return ReasonerGrammarObject(self.grammar.copy(), self.think_end_id) + + @property + def finished(self): + return self.grammar.finished + + @finished.setter + def finished(self, finished): + self.grammar.finished = finished + + def try_jump_forward(self, tokenizer): + return self.grammar.try_jump_forward(tokenizer) + + def jump_forward_str_state(self, helper): + return self.grammar.jump_forward_str_state(helper) + + def jump_and_retokenize( + self, old_output_ids: List[int], new_output_ids: List[int], next_state: int + ): + return self.grammar.jump_and_retokenize( + old_output_ids, new_output_ids, next_state + ) + + +class ReasonerGrammarBackend(BaseGrammarBackend): + def __init__(self, grammar_backend: BaseGrammarBackend, think_end_id): + super().__init__() + 
def is_legacy_structural_tag(obj: Dict) -> bool:
    """Tell whether ``obj`` uses the legacy structural-tag layout.

    A payload is legacy when its ``"structures"`` entry is present and not
    ``None``; it must then also carry ``"triggers"``. Any other payload must
    carry ``"format"``. A missing companion key trips an ``assert``.
    See `StructuralTagResponseFormat` at `sglang.srt.entrypoints.openai.protocol`.
    """
    legacy = obj.get("structures") is not None
    companion_key = "triggers" if legacy else "format"
    assert obj.get(companion_key) is not None
    return legacy
class XGrammarGrammar(BaseGrammarObject):
    """Grammar object driven by an xgrammar ``GrammarMatcher``.

    Besides the matcher it keeps ``accepted_tokens``, a list of the tokens
    accepted through ``accept_token`` that is used in error messages and
    ``__repr__``.
    """

    def __init__(
        self,
        matcher: GrammarMatcher,
        vocab_size: int,
        ctx: CompiledGrammar,
        override_stop_tokens: Optional[Union[List[int], int]],
        key_string: Optional[str] = None,  # TODO (sk): for debugging, remove later
        grammar_stats: Optional[GrammarStats] = None,
    ) -> None:
        """Store the matcher and its compilation context.

        Args:
            matcher: matcher that validates tokens and fills bitmasks.
            vocab_size: stored on the instance and passed on by ``copy``.
            ctx: compiled grammar, reused to build a fresh matcher in ``copy``.
            override_stop_tokens: forwarded to new matchers created by ``copy``.
            key_string: original grammar key, kept for debugging output.
            grammar_stats: stats record; when omitted (or ``None``) a fresh
                ``GrammarStats()`` is created for this instance.
        """
        super().__init__()
        self.matcher = matcher
        self.vocab_size = vocab_size
        self.ctx = ctx
        self.override_stop_tokens = override_stop_tokens
        self.accepted_tokens = []
        self.key_string = key_string
        # A `GrammarStats()` default in the signature is evaluated once and
        # would be shared by every instance constructed without stats.
        self.grammar_stats = GrammarStats() if grammar_stats is None else grammar_stats

    def accept_token(self, token: int):
        """Feed ``token`` to the matcher unless the grammar already terminated.

        Raises:
            ValueError: if the matcher rejects the token.
        """
        if not self.is_terminated():
            self.current_token = token
            accepted = self.matcher.accept_token(token)
            if not accepted:
                # log for debugging
                raise ValueError(
                    f"Tokens not accepted: {token}\n"
                    f"Accepted tokens: {self.accepted_tokens}\n"
                    f"Key string: {self.key_string}"
                )
            else:
                self.accepted_tokens.append(token)

    def rollback(self, k: int):
        """Roll the matcher back ``k`` tokens and trim ``accepted_tokens`` to match."""
        self.matcher.rollback(k)
        # `lst[:-0]` equals `lst[:0]`, so slicing unconditionally would wipe
        # the whole list on a zero-step rollback.
        if k > 0:
            self.accepted_tokens = self.accepted_tokens[:-k]

    def is_terminated(self):
        """Return the matcher's termination flag."""
        return self.matcher.is_terminated()

    def allocate_vocab_mask(
        self, vocab_size: int, batch_size: int, device
    ) -> torch.Tensor:
        """Allocate a token bitmask; ``device`` is not used by this backend."""
        return allocate_token_bitmask(batch_size, vocab_size)

    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        """Write the next-token bitmask for this grammar into row ``idx``."""
        self.matcher.fill_next_token_bitmask(vocab_mask, idx)

    @staticmethod
    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
        """Copy the mask to ``device`` with ``non_blocking=True``."""
        return vocab_mask.to(device, non_blocking=True)

    def apply_vocab_mask(self, logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        """Apply ``vocab_mask`` to ``logits`` in place using the kernel for its device.

        Raises:
            RuntimeError: if no kernel is available for the logits' device.
        """
        device_type = logits.device.type
        # NOTE(review): nothing in this class sets `apply_vocab_mask_cpu`;
        # presumably a base class or caller may attach it - confirm. Reading
        # it with getattr makes a missing hook fall through to the
        # "Unsupported device" error instead of an AttributeError.
        cpu_kernel = getattr(self, "apply_vocab_mask_cpu", None)
        if device_type in ("cuda", "npu", "xpu"):
            if _is_hip:
                apply_token_bitmask_inplace_cuda(logits, vocab_mask)
            else:
                apply_token_bitmask_inplace_triton(logits, vocab_mask)
        elif device_type == "cpu" and cpu_kernel:
            cpu_kernel(logits, vocab_mask)
        else:
            raise RuntimeError(f"Unsupported device: {logits.device.type}")

    def copy(self):
        """Return a new grammar with a fresh matcher over the same compiled context.

        The copy's stats are this object's stats with ``is_cache_hit=True`` and
        an empty ``tree_traversal_time``.
        """
        matcher = GrammarMatcher(
            self.ctx,
            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
            override_stop_tokens=self.override_stop_tokens,
        )
        return XGrammarGrammar(
            matcher,
            self.vocab_size,
            self.ctx,
            self.override_stop_tokens,
            self.key_string,
            dataclasses.replace(
                self.grammar_stats, is_cache_hit=True, tree_traversal_time=[]
            ),
        )

    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
        """Return ``([], s)`` when the matcher reports a non-empty jump-forward string."""
        s = self.matcher.find_jump_forward_string()
        if s:
            return [], s
        return None

    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
        """Unpack the helper from ``try_jump_forward``; the state is always -1."""
        _, data = helper
        return data, -1

    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ):
        """Resynchronise the matcher after the output ids were re-tokenized.

        Keeps the longest common prefix of the two id lists, rolls the matcher
        back over the rest of ``old_output_ids`` and feeds it the rest of
        ``new_output_ids``. ``accepted_tokens`` is not updated here.

        Raises:
            AssertionError: if the matcher rejects one of the new tokens.
        """
        # zip stops at the shorter list, so a shorter `new_output_ids` can no
        # longer raise IndexError while the prefix is measured.
        k = 0
        for old_id, new_id in zip(old_output_ids, new_output_ids):
            if old_id != new_id:
                break
            k += 1

        # rollback to the last token that is the same
        if k < len(old_output_ids):
            self.matcher.rollback(len(old_output_ids) - k)

        for token in new_output_ids[k:]:
            # Call outside the assert: under `python -O` asserts are removed,
            # which would otherwise skip feeding the tokens altogether.
            accepted = self.matcher.accept_token(token)
            assert accepted, f"Token not accepted during retokenize: {token}"

    def __repr__(self):
        return f"XGrammarGrammar({self.key_string=}, {self.accepted_tokens=}, {self.current_token=})"
structural_format.get("json_schema") is None: + structural_format["json_schema"] = {} + + if fmt_type == "tag": + XGrammarGrammarBackend._sanitize_structural_format( + structural_format.get("content") + ) + elif fmt_type in {"sequence", "or"}: + for element in structural_format.get("elements", []): + XGrammarGrammarBackend._sanitize_structural_format(element) + elif fmt_type in {"triggered_tags", "tags_with_separator"}: + for tag in structural_format.get("tags", []): + XGrammarGrammarBackend._sanitize_structural_format(tag) + + @staticmethod + def _sanitize_structural_tag_structures(structural_tag: Dict) -> None: + for structure in structural_tag.get("structures", []): + if structure.get("schema") is None: + structure["schema"] = {} + + def _from_context( + self, ctx: CompiledGrammar, key_string: str, grammar_stats: GrammarStats + ) -> XGrammarGrammar: + matcher = GrammarMatcher( + ctx, + max_rollback_tokens=MAX_ROLLBACK_TOKENS, + override_stop_tokens=self.override_stop_tokens, + ) + return XGrammarGrammar( + matcher, + self.vocab_size, + ctx, + self.override_stop_tokens, + key_string, + grammar_stats, + ) + + def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]: + try: + if key_string == "$$ANY$$": + # Note: This builtin JSON grammar includes *all* valid JSON (including, for example, arrays at the root) + ctx = self.grammar_compiler.compile_builtin_json_grammar() + else: + ctx = self.grammar_compiler.compile_json_schema( + schema=key_string, any_whitespace=self.any_whitespace + ) + + except (RuntimeError, json.decoder.JSONDecodeError, UnicodeDecodeError) as e: + logger.error(f"Hit invalid json_schema: {key_string=}, {e=}") + return INVALID_GRAMMAR_OBJ + return self._from_context(ctx, key_string, GrammarStats(dispatch_type="json")) + + def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]: + try: + ctx = self.grammar_compiler.compile_grammar(key_string) + except RuntimeError as e: + logger.error(f"Hit invalid ebnf: {key_string=}, 
{e=}") + return INVALID_GRAMMAR_OBJ + return self._from_context(ctx, key_string, GrammarStats(dispatch_type="ebnf")) + + def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]: + try: + ctx = self.grammar_compiler.compile_regex(key_string) + except RuntimeError as e: + logger.error(f"Hit invalid regex: {key_string=}, {e=}") + return INVALID_GRAMMAR_OBJ + return self._from_context(ctx, key_string, GrammarStats(dispatch_type="regex")) + + def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]: + try: + # TODO(dark): it's REALLY stupid to construct object from string and decode it again + structural_tag = json.loads(key_string) + if is_legacy_structural_tag(structural_tag): + self._sanitize_structural_tag_structures(structural_tag) + tags = [ + StructuralTagItem( + begin=structure["begin"], + schema=json.dumps(structure["schema"]), + end=structure["end"], + ) + for structure in structural_tag["structures"] + ] + ctx = self.grammar_compiler.compile_structural_tag( + tags, structural_tag["triggers"] + ) + else: + format_dict = structural_tag.get("format") + if isinstance(format_dict, dict): + self._sanitize_structural_format(format_dict) + structural_tag["format"] = format_dict + key_string = json.dumps(structural_tag) + ctx = self.grammar_compiler.compile_structural_tag(key_string) + except (RuntimeError, json.decoder.JSONDecodeError) as e: + logger.error(f"Hit invalid structural_tag: {key_string=}, {e=}") + return INVALID_GRAMMAR_OBJ + return self._from_context( + ctx, key_string, GrammarStats(dispatch_type="structural_tag") + ) + + def reset(self): + self.grammar_compiler.clear_cache() + + +def demo_test(): + from transformers import AutoConfig, AutoTokenizer + + from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST + + tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL_NAME_FOR_TEST) + hf_config = AutoConfig.from_pretrained(DEFAULT_MODEL_NAME_FOR_TEST) + + # Should use vocab size from model config + vocab_size = 
hf_config.vocab_size + eos_token_id = tokenizer.eos_token_id + + backend = XGrammarGrammarBackend( + tokenizer, vocab_size=vocab_size, model_eos_token_ids=[eos_token_id] + ) + regex = r"hello (world|there)" + grammar = backend.dispatch_regex(regex) + tokens = [ + tokenizer.encode(t, add_special_tokens=False)[0] for t in ["hello", " world"] + ] + + # Test termination + grammar.accept_token(tokens[0]) # accept "hello" + grammar.accept_token(tokens[1]) # accept " world" + grammar.accept_token(eos_token_id) # accept EOS + assert grammar.is_terminated() + + # Test rollback the terminated state + grammar.rollback(1) + assert not grammar.is_terminated() + + +if __name__ == "__main__": + demo_test() diff --git a/sglang/python/sglang/srt/debug_utils/__init__.py b/sglang/python/sglang/srt/debug_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/sglang/python/sglang/srt/debug_utils/cuda_coredump.py b/sglang/python/sglang/srt/debug_utils/cuda_coredump.py new file mode 100644 index 0000000000000000000000000000000000000000..1507467ddba0929771ed45eafaf6d10fb931a0fe --- /dev/null +++ b/sglang/python/sglang/srt/debug_utils/cuda_coredump.py @@ -0,0 +1,95 @@ +"""CUDA coredump helpers. + +When SGLANG_CUDA_COREDUMP=1, this module injects CUDA coredump environment +variables into the current process so that GPU exceptions (e.g. illegal +memory access) produce lightweight coredump files for post-mortem analysis +with cuda-gdb. + +The injection happens at module import time via _inject_env() on a +best-effort basis. If any CUDA_* variable is already present in the +environment (e.g. set by the user in the shell), injection is skipped for +that variable and a warning is printed. For strict guarantees, set the +CUDA_* env vars in the shell before launching Python. 
def _inject_env():
    """Set the CUDA coredump environment variables for this process.

    The dump directory is created first. A variable that already exists in the
    environment is left alone and a warning naming both values is issued.
    """
    target_dir = get_dump_dir()
    os.makedirs(target_dir, exist_ok=True)

    wanted = (
        ("CUDA_ENABLE_COREDUMP_ON_EXCEPTION", "1"),
        ("CUDA_COREDUMP_SHOW_PROGRESS", "1"),
        ("CUDA_COREDUMP_GENERATION_FLAGS", _CUDA_COREDUMP_FLAGS),
        ("CUDA_COREDUMP_FILE", f"{target_dir}/cuda_coredump_%h.%p.%t"),
    )
    for var_name, var_value in wanted:
        current = os.environ.get(var_name)
        if current is None:
            os.environ[var_name] = var_value
            continue
        # Warn from this frame so stacklevel=2 still points at our caller.
        warnings.warn(
            f"CUDA coredump env var {var_name} is already set to "
            f"'{current}', skipping injection of '{var_value}'.",
            stacklevel=2,
        )
CUDA coredump env vars at import time. +# The sentinel env var is inherited by child processes, so injection only +# happens once in the top-level process. +_SENTINEL = "_SGLANG_CUDA_COREDUMP_INJECTED" + +if is_enabled() and _SENTINEL not in os.environ: + os.environ[_SENTINEL] = "1" + print(f"Injecting CUDA coredump env vars (pid={os.getpid()})") + _inject_env() diff --git a/sglang/python/sglang/srt/debug_utils/dump_comparator.py b/sglang/python/sglang/srt/debug_utils/dump_comparator.py new file mode 100644 index 0000000000000000000000000000000000000000..6f5a3397d643de9d3933b1792733f1a16996c7c4 --- /dev/null +++ b/sglang/python/sglang/srt/debug_utils/dump_comparator.py @@ -0,0 +1,296 @@ +"""Simplified dump comparator — a self-contained single-file script for comparing +two dump directories tensor-by-tensor. + +For advanced features (unshard, token alignment, per-dimension annotations), see the +full ``comparator/`` package: ``python -m sglang.srt.debug_utils.comparator``. +""" + +import argparse +import functools +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Callable, List, Optional + +import torch + +from sglang.srt.debug_utils.dumper import get_truncated_value + + +def main(args): + import polars as pl + + from sglang.srt.debug_utils.dump_loader import find_row, read_meta + + df_target = read_meta(args.target_path) + df_target = df_target.filter( + (pl.col("step") >= args.start_step) & (pl.col("step") <= args.end_step) + ) + if args.filter: + df_target = df_target.filter(pl.col("filename").str.contains(args.filter)) + assert all(c in df_target.columns for c in ["rank", "step", "dump_index", "name"]) + + df_baseline = read_meta(args.baseline_path) + print("df_target", df_target) + print("df_baseline", df_baseline) + + tensor_dim_descs: List[TensorDimDesc] = _get_tensor_dim_descs() + + for row in df_target.iter_rows(named=True): + path_target = Path(args.target_path) / row["filename"] + + tensor_dim_desc: 
def check_tensor_pair(
    path_baseline,
    path_target,
    diff_threshold: float = 1e-3,
    name="",
    tensor_dim_desc: Optional["TensorDimDesc"] = None,
):
    """Load one baseline/target dump pair and print a comparison report.

    Args:
        path_baseline: path of the baseline dump file.
        path_target: path of the target dump file.
        diff_threshold: relative-difference threshold forwarded to
            ``_compute_and_print_diff``.
        name: tensor name, forwarded to ``_comparison_preprocessor``.
        tensor_dim_desc: optional einops description used to rearrange (and
            optionally crop) the baseline into the target layout.

    Nothing is returned. The function stops early when either file fails to
    load, and after the summary statistics when the shapes still differ.
    """
    x_baseline = _load_object(path_baseline)
    x_target = _load_object(path_target)

    if x_baseline is None or x_target is None:
        print(
            f"Skip comparison because of None: x_baseline={x_baseline}, x_target={x_target}"
        )
        return

    print(
        f"Raw "
        f"[shape] {x_baseline.shape} vs {x_target.shape}\t"
        f"[{'' if x_baseline.dtype == x_target.dtype else '🟠'}dtype] {x_baseline.dtype} vs {x_target.dtype}"
    )

    if tensor_dim_desc is not None:
        import einops

        x_baseline = einops.rearrange(
            x_baseline,
            tensor_dim_desc.baseline_desc + " -> " + tensor_dim_desc.target_desc,
        )
        if tensor_dim_desc.baseline_cropper is not None:
            print("Apply baseline_cropper")
            x_baseline = tensor_dim_desc.baseline_cropper(x_baseline)

    x_baseline, x_target = _comparison_preprocessor(x_baseline, x_target, name=name)
    x_baseline = _try_unify_shape(x_baseline, target_shape=x_target.shape)

    print(
        f"After preprocessor "
        f"[shape] {x_baseline.shape} vs {x_target.shape}\t"
        f"[dtype] {x_baseline.dtype} vs {x_target.dtype}"
    )

    x_baseline_original_dtype = x_baseline.dtype
    x_target_original_dtype = x_target.dtype

    x_target = x_target.float()
    x_baseline = x_baseline.float()

    stat_fns = [
        ("mean", torch.mean),
        ("std", torch.std),
        ("min", torch.min),
        ("max", torch.max),
    ]
    # Every statistic is evaluated on BOTH tensors, and shapes may still differ
    # at this point, so the size guard has to look at the larger of the two.
    # NOTE(review): the guard presumably exists because torch.quantile rejects
    # very large inputs - confirm the exact limit against the torch docs.
    if max(x_baseline.numel(), x_target.numel()) < 10_000_000:
        stat_fns += [
            ("p1", functools.partial(torch.quantile, q=0.01)),
            ("p5", functools.partial(torch.quantile, q=0.05)),
            ("p95", functools.partial(torch.quantile, q=0.95)),
            ("p99", functools.partial(torch.quantile, q=0.99)),
        ]

    # `stat_name` (not `name`) so the loop does not clobber the parameter.
    for stat_name, fn in stat_fns:
        value_baseline = fn(x_baseline).item()
        value_target = fn(x_target).item()
        print(
            f"[{stat_name}] {value_baseline :.4f} vs {value_target:.4f} (diff: {value_target - value_baseline:.4f})"
        )

    if x_baseline.shape != x_target.shape:
        print(f"⚠️ Shape mismatch")
        return

    diff_info = _compute_and_print_diff(
        x_baseline=x_baseline,
        x_target=x_target,
        diff_threshold=diff_threshold,
    )
    needs_print = diff_info["max_abs_diff"] > 1e-3

    if (x_baseline_original_dtype != x_target_original_dtype) and (
        (
            downcast_dtype := _compute_smaller_dtype(
                x_baseline_original_dtype, x_target_original_dtype
            )
        )
        is not None
    ):
        # Repeat the diff at the lower precision of the two original dtypes.
        _compute_and_print_diff(
            x_baseline=x_baseline.to(downcast_dtype),
            x_target=x_target.to(downcast_dtype),
            diff_threshold=diff_threshold,
            prefix_text=f"When downcast to {downcast_dtype}: ",
        )

    if needs_print:
        print(f"x_baseline(sample)={get_truncated_value(x_baseline)}")
        print(f"x_target(sample)={get_truncated_value(x_target)}")
max_abs_diff = raw_abs_diff.max().item() + mean_abs_diff = raw_abs_diff.mean().item() + rel_diff = _calc_rel_diff(x_target, x_baseline) + + rel_diff_marker: str = "❌" if rel_diff > diff_threshold else "✅" + print( + prefix_text + + f"{rel_diff_marker} rel_diff={rel_diff}\t" + + f"max_abs_diff={max_abs_diff}\t" + + f"mean_abs_diff={mean_abs_diff}" + ) + + max_diff_coord = _argmax_coord(raw_abs_diff) + print( + f"max_abs_diff happens at coord={max_diff_coord} with " + f"baseline={x_baseline[max_diff_coord].item()} " + f"target={x_target[max_diff_coord].item()}" + ) + + return dict(max_abs_diff=max_abs_diff) + + +def _argmax_coord(x: torch.Tensor) -> tuple: + flat_idx = x.argmax() + return tuple(idx.item() for idx in torch.unravel_index(flat_idx, x.shape)) + + +def _compute_smaller_dtype(dtype_a, dtype_b): + info_dict = { + (torch.float32, torch.bfloat16): torch.bfloat16, + # ... add more ... + } + return info_dict.get((dtype_a, dtype_b)) or info_dict.get((dtype_b, dtype_a)) + + +def _try_unify_shape(x: torch.Tensor, target_shape): + x_shape = x.shape + num_dim_to_remove = len(x_shape) - len(target_shape) + if (x_shape[num_dim_to_remove:] == target_shape) and all( + val == 1 for val in x_shape[:num_dim_to_remove] + ): + out = functools.reduce(lambda a, _: a.squeeze(0), range(num_dim_to_remove), x) + print(f"Unify shape: {x_shape} -> {out.shape} (to match {target_shape})") + return out + + return x + + +# Copied from DeepGEMM +def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +def _load_object(path): + try: + x = torch.load(path, weights_only=False) + except Exception as e: + print(f"Skip load {path} since error {e}") + return None + + if isinstance(x, dict) and "value" in x: + x = x["value"] + + if not isinstance(x, torch.Tensor): + print(f"Skip load {path} since {type(x)=} is not a Tensor ({x=})") + return None + return x.cuda() + + 
if __name__ == "__main__":
    # python -m sglang.srt.debug_utils.dump_comparator --baseline-path ... --target-path ...
    parser = argparse.ArgumentParser()
    # Both directories are mandatory: main() passes them to read_meta(), so let
    # argparse report a missing one as a usage error instead of handing None on.
    parser.add_argument("--baseline-path", type=str, required=True)
    parser.add_argument("--target-path", type=str, required=True)
    parser.add_argument("--start-step", type=int, default=0)
    parser.add_argument("--end-step", type=int, default=1000000)
    parser.add_argument("--diff-threshold", type=float, default=1e-3)
    parser.add_argument(
        "--filter", type=str, default=None, help="Regex to filter filenames"
    )
    args = parser.parse_args()
    main(args)
converter(result[field_name]) + return result + + +@dataclass +class ValueWithMeta: + value: Any + meta: Dict[str, Any] + + @staticmethod + def load(path: Path) -> "ValueWithMeta": + path = Path(path) + meta_from_filename = parse_meta_from_filename(path) + + try: + raw = torch.load(path, weights_only=False, map_location="cpu") + except Exception as e: + print(f"Skip load {path} since error {e}") + return ValueWithMeta( + value=LOAD_FAILED, meta={**meta_from_filename, "filename": path.name} + ) + + value, meta_from_embedded = _unwrap_dict_format(raw) + return ValueWithMeta( + value=value, + meta={**meta_from_filename, **meta_from_embedded, "filename": path.name}, + ) + + +def _unwrap_dict_format(obj: Any) -> Tuple[Any, Dict[str, Any]]: + if isinstance(obj, dict) and "value" in obj: + meta = obj.get("meta", {}) + assert isinstance(meta, dict), f"Expected meta to be dict, got {type(meta)}" + return obj["value"], meta + return obj, {} + + +class DumpLoader: + def __init__(self): + directory = os.environ.get("SGLANG_DUMP_LOADER_DIR") + + self._enable = directory is not None + if self._enable: + self._directory = Path(directory) + self._df = read_meta(directory) + + @property + def enable(self): + return self._enable + + def load(self, name, **kwargs): + assert self._enable, "Please call DumpLoader.load only when it is enabled" + + from sglang.srt.debug_utils.dumper import dumper + + step = dumper._state.step + conditions = dict(name=name, step=step, **kwargs) + row = find_row(self._df, conditions=conditions) + assert ( + row is not None + ), f"DumpLoader cannot find row given query {name=} {kwargs=} {self._directory=}" + + path = self._directory / row["filename"] + output = torch.load(path, weights_only=False) + if isinstance(output, dict) and "value" in output: + output = output["value"] + + print( + f"[DumpLoader] load from {path=} (query: {name=} {kwargs=}, output: {type(output)})" + ) + return output + + +def read_meta(directory): + directory = Path(directory) + 
def find_row(df: pl.DataFrame, conditions: Dict[str, Any]):
    """Return the single row of ``df`` matching ``conditions``, else ``None``.

    ``None`` is returned both when nothing matches and when several rows
    match; the ambiguous case is additionally reported on stdout.
    """
    rows = filter_rows(df, conditions)
    match_count = len(rows)
    if match_count == 1:
        return rows[0]
    if match_count > 1:
        print(f"find_row find ambiguous results: {rows=}")
    return None
directory.glob("*.pt"): + item: ValueWithMeta = ValueWithMeta.load(p) + tokenizer_path: Optional[str] = item.meta.get("tokenizer_path") + if tokenizer_path is not None: + return str(tokenizer_path) + return None + + +_TYPED_FIELDS: list[tuple[str, Callable[[str], Any]]] = [ + ("rank", int), +] + + +dump_loader = DumpLoader() diff --git a/sglang/python/sglang/srt/debug_utils/dumper.py b/sglang/python/sglang/srt/debug_utils/dumper.py new file mode 100644 index 0000000000000000000000000000000000000000..c2715f19e169b04281efea012d791cbe981919f9 --- /dev/null +++ b/sglang/python/sglang/srt/debug_utils/dumper.py @@ -0,0 +1,1485 @@ +import enum +import functools +import json +import os +import random +import re +import socket +import threading +import time +from abc import ABC, abstractmethod +from collections.abc import Callable +from contextlib import contextmanager +from copy import deepcopy +from dataclasses import asdict, dataclass, field, fields, replace +from functools import cached_property +from http.server import BaseHTTPRequestHandler, HTTPServer +from pathlib import Path +from typing import Any, List, Literal, Optional, Union, get_args, get_type_hints + +import torch +import torch.distributed as dist + +# -------------------------------------- config base ------------------------------------------ + + +@dataclass(frozen=True) +class _BaseConfig(ABC): + def __post_init__(self) -> None: + self._verify_types() + + def _verify_types(self) -> None: + hints = get_type_hints(type(self)) + cls_name = type(self).__name__ + for f in fields(self): + value = getattr(self, f.name) + if value is None: + continue + expected = self._unwrap_type(hints[f.name]) + if not isinstance(value, expected): + raise TypeError( + f"{cls_name}.{f.name}: expected {expected.__name__}, " + f"got {type(value).__name__}" + ) + + @classmethod + @abstractmethod + def _env_prefix(cls) -> str: ... 
+ + @classmethod + def _env_name(cls, field_name: str) -> str: + return f"{cls._env_prefix()}{field_name.upper()}" + + @classmethod + def from_env(cls) -> "_BaseConfig": + return cls( + **{ + f.name: cls._parse_env_field(cls._env_name(f.name), f.default) + for f in fields(cls) + } + ) + + def with_defaults(self, **kwargs) -> "_BaseConfig": + cls = type(self) + actual = { + key: value + for key, value in kwargs.items() + if os.getenv(cls._env_name(key)) is None + } + return replace(self, **actual) if actual else self + + @staticmethod + def _unwrap_type(hint) -> type: + args = get_args(hint) + if args: + return next(a for a in args if a is not type(None)) + return hint + + @classmethod + def _parse_env_field(cls, env_name: str, default): + return cls._parse_env_value(os.getenv(env_name), default) + + @staticmethod + def _parse_env_value(raw, default): + if raw is None or not raw.strip(): + return default + if isinstance(default, bool): + return raw.lower() in ("true", "1") + if isinstance(default, int): + return int(raw) + return raw + + @classmethod + def from_kv_pairs(cls, pairs: Optional[List[str]]) -> "_BaseConfig": + return cls(**cls._kv_pairs_to_dict(pairs)) + + @classmethod + def _kv_pairs_to_dict(cls, pairs: Optional[List[str]]) -> dict: + if not pairs: + return {} + + missing = object() + defaults = {f.name: f.default for f in fields(cls)} + result: dict = {} + + for pair in pairs: + key, sep, value = pair.partition("=") + if not sep: + raise ValueError(f"Invalid config pair (missing '='): {pair!r}") + default = defaults.get(key, missing) + if default is missing: + raise ValueError( + f"Unknown config key {key!r}. 
Valid keys: {sorted(defaults)}" + ) + try: + result[key] = cls._parse_env_value(value, default) + except (ValueError, TypeError) as exc: + field_type = type(default).__name__ + raise TypeError(f"{key}: expected {field_type}, got {value!r}") from exc + + return result + + +_DEFAULT_EXP_NAME_PREFIX = "dump_" + + +@dataclass(frozen=True) +class DumperConfig(_BaseConfig): + enable: bool = False + filter: Optional[str] = None + dir: str = "/tmp/dumper" + enable_output_file: bool = True + enable_output_console: bool = True + enable_value: bool = True + enable_grad: bool = False + enable_model_value: bool = False + enable_model_grad: bool = False + exp_name: Optional[str] = None + cleanup_previous: bool = False + collective_timeout: int = 60 + server_port: str = "-1" + non_intrusive_mode: str = "core" + source_patcher_config: Optional[str] = None + + @classmethod + def _env_prefix(cls) -> str: + # NOTE: should not be `SGLANG_DUMPER_`, otherwise it is weird when dumping Megatron in Miles + return "DUMPER_" + + @property + def server_port_parsed(self) -> Optional[Union[int, Literal["reuse"]]]: + raw = self.server_port + if raw == "reuse": + return "reuse" + port = int(raw) + if port <= 0: + return None + return port + + +# -------------------------------------- dumper core ------------------------------------------ + + +@dataclass +class _DumperState: + dump_index: int = 0 + step: int = 0 + global_ctx: dict = field(default_factory=dict) + captured_output_data: Optional[dict] = None + cleanup_previous_handled: bool = False + + +class _Dumper: + """Utility to dump tensors, which can be useful when comparison checking models. 
+ + Example usage: + dumper.dump("layer_start__hidden_states", hidden_states, layer_id=self.layer_id) + dumper.step() + + Import from non-SGLang system: + ``` + import sys + sys.path.append("/YOUR_PATH/sglang/python/sglang/srt/debug_utils") + from dumper import dumper + ``` + + Then run the program: + `DUMPER_ENABLE=1 python ...` + + Auto-cleanup old dumps before first write: + `DUMPER_CLEANUP_PREVIOUS=1 python ...` + + Alternatively, disable at startup and configure via HTTP: + 1. `python ...` + 2. sglang mode: `curl -X POST http://localhost:30000/dumper/configure -d '{"enable": true}'` + standalone: `curl -X POST http://localhost:40000/dumper/configure -d '{"enable": true}'` + 3. `curl -X POST http://localhost:30000/dumper/configure -d '{"enable": true, "filter": "layer_id=[0-3]"}'` + 4. `curl -X POST http://localhost:30000/dumper/reset` + + Related: `sglang.srt.debug_utils.dump_comparator` for dump comparison + """ + + def __init__(self, *, config: DumperConfig): + self._config = config + self._state = _DumperState() + self._non_intrusives: list["_NonIntrusiveDumper"] = [] + + # ------------------------------- public :: core --------------------------------- + + @property + def may_enable(self) -> bool: + return self._config.enable or self._config.server_port_parsed is not None + + def step(self): + """This should be called on all ranks at the end of each iteration.""" + + self._http_manager # noqa: B018 + + if not self._config.enable: + return + + # Users may want to `dump` only on some ranks, thus determine name here + self._ensure_exp_name() + + self._state.step += 1 + print(f"[Dumper] [{time.time()}] step={self._state.step}") + + def dump( + self, + name: str, + value, + save: bool = True, + dims: Optional[str] = None, + dims_grad: Optional[str] = None, + **kwargs, + ) -> None: + value_meta: dict = {} + grad_meta: dict = {} + if dims is not None: + value_meta["dims"] = dims + grad_meta["dims"] = dims + if dims_grad is not None: + value_meta["dims_grad"] = 
dims_grad + grad_meta["dims"] = dims_grad + + self._dump_inner( + name=name, + value=value, + extra_kwargs=kwargs, + save=save, + enable_value=self._config.enable_value, + enable_curr_grad=False, + enable_future_grad=self._config.enable_grad, + value_tag="Dumper.Value", + grad_tag="Dumper.Grad", + value_meta_only_fields=value_meta, + grad_meta_only_fields=grad_meta, + ) + + def dump_model( + self, + model: "torch.nn.Module", + name_prefix: str = "param", + save: bool = True, + **kwargs, + ) -> None: + for param_name, param in model.named_parameters(): + self._dump_inner( + name=f"{name_prefix}__{param_name}", + value=param, + extra_kwargs=kwargs, + save=save, + enable_value=self._config.enable_model_value, + enable_curr_grad=self._config.enable_model_grad, + enable_future_grad=False, + value_tag="Dumper.ParamValue", + grad_tag="Dumper.ParamGrad", + ) + + def dump_dict(self, name_prefix, data, save: bool = True, **kwargs): + data = _obj_to_dict(data) + for name, value in data.items(): + self.dump(f"{name_prefix}_{name}", value, save=save, **kwargs) + + def set_ctx(self, **kwargs): + """ + Example: + + dumper.configure_default(filter='layer_id=[0-3]') + dumper.set_ctx(layer_id=self.layer_id) + ... + dumper.set_ctx(layer_id=None) + """ + self._state.global_ctx = { + k: v for k, v in (self._state.global_ctx | kwargs).items() if v is not None + } + + def ctx( + self, + _extractor: Optional[Callable[..., dict]] = None, + **static_ctx: Any, + ) -> Callable: + """Decorator that sets context before calling the wrapped function and clears it after. + + Two forms: + @dumper.ctx(lambda self: dict(layer_id=self.layer_id)) + def forward(self, x): ... + + @dumper.ctx(phase="decode") + def decode_step(self, x): ... 
+ """ + if _extractor is not None and static_ctx: + raise ValueError("cannot mix lambda extractor with static kwargs") + if _extractor is None and not static_ctx: + raise ValueError("must provide either a lambda or static kwargs") + + def decorator(fn: Callable) -> Callable: + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + ctx_dict: dict = _extractor(args[0]) if _extractor else static_ctx + self.set_ctx(**ctx_dict) + try: + return fn(*args, **kwargs) + finally: + self.set_ctx(**{k: None for k in ctx_dict}) + + return wrapper + + return decorator + + def apply_source_patches(self) -> None: + """Apply source patches from DUMPER_SOURCE_PATCHER_CONFIG if set. + + Automatically injects ``from sglang.srt.debug_utils.dumper import dumper`` + into every replacement block so users don't need to write it in YAML. + """ + config_path = self._config.source_patcher_config + if not config_path: + return + + from sglang.srt.debug_utils.source_patcher import apply_patches_from_config + + yaml_content: str = Path(config_path).read_text() + print(f"[source_patcher] loading config from {config_path}") + apply_patches_from_config( + yaml_content, + extra_imports=["from sglang.srt.debug_utils.dumper import dumper"], + ) + + def register_non_intrusive_dumper( + self, + model: "torch.nn.Module", + ) -> Optional["_NonIntrusiveDumper"]: + self._http_manager # noqa: B018 + mode = self._config.non_intrusive_mode + if mode == "off": + return None + non_intrusive = _NonIntrusiveDumper(dumper=self, model=model, mode=mode) + self._non_intrusives.append(non_intrusive) + return non_intrusive + + # ------------------------------- public :: secondary --------------------------------- + + def configure(self, **kwargs) -> None: + self._config = replace(self._config, **kwargs) + + def configure_default(self, **kwargs) -> None: + self._config = self._config.with_defaults(**kwargs) + + def reset(self) -> None: + for non_intrusive in self._non_intrusives: + non_intrusive.remove() 
+ self._non_intrusives.clear() + self._state = _DumperState() + + @contextmanager + def capture_output(self): + assert self._state.captured_output_data is None + self._state.captured_output_data = {} + try: + yield self._state.captured_output_data + finally: + self._state.captured_output_data = None + + def get_state(self) -> dict: + return { + "config": asdict(self._config), + "dump_index": self._state.dump_index, + "step": self._state.step, + } + + @cached_property + def _http_manager(self) -> Optional["_DumperHttpManager"]: + if self._config.server_port_parsed is None: + return None + return _DumperHttpManager(self) + + # ------------------------- private :: related to dump ----------------------------- + + def _dump_inner( + self, + *, + name: str, + value, + extra_kwargs: dict, + save: bool, + enable_value: bool, + enable_curr_grad: bool, + enable_future_grad: bool, + value_tag: str, + grad_tag: str, + value_meta_only_fields: Optional[dict] = None, + grad_meta_only_fields: Optional[dict] = None, + ) -> None: + self._http_manager # noqa: B018 + + if not self._config.enable: + return + + recompute_status = _detect_recompute_status() + tags = dict( + name=name, + recompute_status=recompute_status.value, + **extra_kwargs, + **self._state.global_ctx, + ) + + if (f := self._config.filter) is not None and not _evaluate_filter(f, tags): + return + + if not (enable_value or enable_curr_grad or enable_future_grad): + return + + recompute_meta = recompute_status.to_pseudo_parallel_meta() + value = _materialize_value(value) + + if enable_value: + self._dump_single( + tag=value_tag, + tags=tags, + value=value, + save=save, + meta_only_fields={**(value_meta_only_fields or {}), **recompute_meta}, + ) + + if ( + enable_curr_grad + and isinstance(value, torch.Tensor) + and (g := value.grad) is not None + ): + self._dump_single( + tag=grad_tag, + tags={**tags, "name": f"grad__{name}"}, + value=g, + save=save, + meta_only_fields={**(grad_meta_only_fields or {}), 
**recompute_meta}, + ) + + if enable_future_grad: + self._register_dump_grad_hook( + name=name, + tensor=value, + extra_kwargs=extra_kwargs, + save=save, + meta_only_fields=grad_meta_only_fields or {}, + ) + + def _register_dump_grad_hook( + self, + *, + name: str, + tensor, + extra_kwargs: dict, + save: bool, + meta_only_fields: Optional[dict] = None, + ) -> None: + if not isinstance(tensor, torch.Tensor): + return + if not tensor.requires_grad: + return + + captured_step = self._state.step + captured_tags = dict( + name=f"grad__{name}", + **deepcopy(extra_kwargs), + ) + captured_meta_only = meta_only_fields or {} + + def grad_hook(grad: torch.Tensor) -> None: + self._dump_single( + tag="Dumper.Grad", + tags=captured_tags, + value=grad, + save=save, + step=captured_step, + meta_only_fields=captured_meta_only, + ) + + tensor.register_hook(grad_hook) + + def _dump_single( + self, + *, + tag: str, + tags: dict, + value, + save: bool, + step: Optional[int] = None, + meta_only_fields: Optional[dict] = None, + ) -> None: + self._ensure_exp_name() + self._state.dump_index += 1 + + rank = _get_rank() + full_kwargs = dict( + step=(step if step is not None else self._state.step), + rank=rank, + dump_index=self._state.dump_index, + **tags, + ) + full_filename = _format_tags(full_kwargs) + ".pt" + path = Path(self._config.dir) / self._config.exp_name / full_filename + + if self._config.enable_output_console: + print( + f"[{tag}] [{rank}, {time.time()}] {path} " + f"type={type(value)} " + f"shape={value.shape if isinstance(value, torch.Tensor) else None} " + f"dtype={value.dtype if isinstance(value, torch.Tensor) else None} " + f"device={value.device if isinstance(value, torch.Tensor) else None} " + f"id={id(value)} " + f"sample_value={get_truncated_value(value)}" + ) + + capturing = self._state.captured_output_data is not None + if save and (self._config.enable_output_file or capturing): + output_data = { + "value": value, + "meta": dict( + **full_kwargs, + 
**self._static_meta, + **(meta_only_fields or {}), + ), + } + + if capturing: + output_data["value"] = _deepcopy_or_clone(output_data["value"]) + self._state.captured_output_data[tags["name"]] = output_data + else: + if ( + not self._state.cleanup_previous_handled + and self._config.cleanup_previous + ): + self._state.cleanup_previous_handled = True + _cleanup_old_dumps( + Path(self._config.dir), exp_name=self._config.exp_name + ) + + path.parent.mkdir(parents=True, exist_ok=True) + _torch_save(output_data, str(path)) + + # ------------------------------- private :: misc --------------------------------- + + @cached_property + def _static_meta(self) -> dict: + return _compute_static_meta() + + def _ensure_exp_name(self): + if self._config.exp_name is None: + name = _get_default_exp_name( + timeout_seconds=self._config.collective_timeout + ) + self.configure(exp_name=name) + print(f"[Dumper] Choose exp_name={name}") + + +# -------------------------------------- hook dumper ------------------------------------------ + + +class _NonIntrusiveDumper: + _NAME_PREFIX = "non_intrusive__" + _LAYER_NAME_RE = re.compile(r"(?:.+\.)?layers\.(\d+)$") + + def __init__( + self, + dumper: _Dumper, + model: "torch.nn.Module", + mode: str, + ): + self._dumper = dumper + self._mode = mode + self._handles: list = [] + self._core_fields: frozenset[str] = frozenset().union( + *(p.core_fields() for p in _plugins) + ) + + for module_name, module in model.named_modules(): + if ctx := self._detect_module_ctx(module_name, module): + self._register_ctx_hooks(module, ctx=ctx) + + is_root = module_name == "" + pre_hook = self._make_forward_pre_hook( + module_name=module_name, is_root=is_root + ) + hook = self._make_forward_hook(module_name=module_name, is_root=is_root) + self._handles += _register_forward_hook_or_replace_fn( + module, + pre_hook=pre_hook, + hook=hook, + mode="replace_fn" if is_root else "hook", + ) + + def remove(self) -> None: + for handle in self._handles: + handle.remove() + 
self._handles.clear() + + @classmethod + def _detect_module_ctx( + cls, module_name: str, module: "torch.nn.Module" + ) -> Optional[dict]: + match = cls._LAYER_NAME_RE.fullmatch(module_name) + if match: + for plugin in _plugins: + layer_id = plugin.detect_layer_id(module) + if layer_id is not None: + return {"layer_id": layer_id} + return {"layer_id": int(match.group(1))} + return None + + def _register_ctx_hooks(self, module: "torch.nn.Module", *, ctx: dict) -> None: + clear_ctx = {k: None for k in ctx} + self._handles.append( + module.register_forward_pre_hook( + lambda _mod, _input, _ctx=ctx: self._dumper.set_ctx(**_ctx) + ) + ) + self._handles.append( + module.register_forward_hook( + lambda _mod, _input, _output, _clear=clear_ctx: self._dumper.set_ctx( + **_clear + ) + ) + ) + + def _make_forward_pre_hook(self, *, module_name: str, is_root: bool): + def _hook(_module, args, kwargs): + for i, item in enumerate(args): + self._dump_value( + module_name, item, sub_name=f"inputs.{i}", is_root=is_root + ) + for name, value in kwargs.items(): + self._dump_value( + module_name, + value, + sub_name=f"inputs.{name}", + is_root=is_root, + ) + + return _hook + + def _make_forward_hook(self, *, module_name: str, is_root: bool): + def _hook(_module, input, output): + if output is not None: + self._dump_value(module_name, output, sub_name="output", is_root=False) + + return _hook + + def _dump_value( + self, module_name: str, value: Any, sub_name: str, *, is_root: bool + ) -> None: + for key, item in self._convert_value( + value, skip_forward_batch=(not is_root) + ).items(): + effective_key = key or sub_name.rsplit(".", 1)[-1] + if effective_key in self._core_fields: + self._dumper.dump(effective_key, item) + elif self._mode == "all": + parts = [p for p in (module_name, sub_name, key) if p] + self._dumper.dump(self._NAME_PREFIX + ".".join(parts), item) + + @staticmethod + def _convert_value(value, *, skip_forward_batch: bool = False) -> dict[str, Any]: + if isinstance(value, 
torch.Tensor): + return {"": value} + + if isinstance(value, (tuple, list)): + tensors = [t for t in value if isinstance(t, torch.Tensor)] + if len(tensors) == 1: + return {"": tensors[0]} + return {str(i): t for i, t in enumerate(tensors)} + + for plugin in _plugins: + result = plugin.convert_value(value, skip_forward_batch=skip_forward_batch) + if result is not None: + return result + + return {} + + +def _register_forward_hook_or_replace_fn( + module: "torch.nn.Module", + *, + pre_hook, + hook, + mode: str, +) -> list: + """Attach pre/post forward hooks to *module*. + + mode="hook" — standard ``register_forward_pre_hook`` / ``register_forward_hook`` + (fires only via ``__call__``). + mode="replace_fn" — monkey-patch ``module.forward`` so hooks fire even when + callers invoke ``.forward()`` directly (as sglang does for the + root model). + + Returns a list of handle objects with a ``.remove()`` method that undoes + the registration. + """ + if mode == "hook": + return [ + module.register_forward_pre_hook(pre_hook, with_kwargs=True), + module.register_forward_hook(hook), + ] + elif mode == "replace_fn": + original_forward = module.forward + + @functools.wraps(original_forward) + def _wrapped(*args, **kwargs): + pre_hook(module, args, kwargs) + output = original_forward(*args, **kwargs) + hook(module, args, output) + return output + + module.forward = _wrapped + + class _Handle: + def remove(self) -> None: + assert module.forward is _wrapped + module.forward = original_forward + + return [_Handle()] + else: + raise ValueError(f"Unknown mode {mode!r}") + + +# -------------------------------------- util fn ------------------------------------------ + + +def _torch_save(value, path: str): + value = _clone_if_view(value) + try: + try: + return torch.save(value, path) + except RuntimeError as e: + if "not pickleable" in str(e): + stripped = _strip_parameter(value) + if stripped is not value: + print(f"[Dumper] Observe error={e} and try pickling .data") + return 
_torch_save(stripped, path) + raise + except Exception as e: + print(f"[Dumper] Observe error={e} when saving data, skip the tensor") + + +def _map_tensor(value, fn: Callable[[torch.Tensor], torch.Tensor]): + if isinstance(value, dict): + return {k: _map_tensor(v, fn) for k, v in value.items()} + if isinstance(value, torch.Tensor): + return fn(value) + return value + + +def _clone_if_view(value): + def _fn(t: torch.Tensor) -> torch.Tensor: + if t.untyped_storage().nbytes() > t.nelement() * t.element_size(): + return t.clone() + return t + + return _map_tensor(value, _fn) + + +def _strip_parameter(value): + def _fn(t: torch.Tensor) -> torch.Tensor: + if isinstance(t, torch.nn.Parameter): + return t.data + return t + + return _map_tensor(value, _fn) + + +def _collective_with_timeout(fn, operation_name: str, timeout_seconds: int = 60): + completed = threading.Event() + + def watchdog(): + if not completed.wait(timeout=timeout_seconds): + print( + f"\n[Dumper] WARNING: '{operation_name}' has not completed after " + f"{timeout_seconds}s. 
This usually means not all ranks are " + f"participating in this collective operation.\n", + flush=True, + ) + + thread = threading.Thread(target=watchdog, daemon=True) + thread.start() + try: + return fn() + finally: + completed.set() + + +def _get_default_exp_name(timeout_seconds: int = 60): + rank = _get_rank() + now = time.time() + ms = int((now % 1) * 1000) + rand_suffix = random.randint(0, 999) + object_list = [ + ( + ( + f"{_DEFAULT_EXP_NAME_PREFIX}" + f"{time.strftime('%Y%m%d_%H%M%S', time.gmtime(now))}" + f"_{ms:03d}{rand_suffix:03d}" + ) + if rank == 0 + else None + ) + ] + + if dist.is_initialized(): + _collective_with_timeout( + lambda: dist.broadcast_object_list(object_list, device="cuda"), + operation_name="broadcast_object_list in _get_default_exp_name", + timeout_seconds=timeout_seconds, + ) + + return object_list[0] + + +def _cleanup_old_dumps(base_dir: Path, exp_name: Optional[str] = None) -> None: + import shutil + + if _get_rank() == 0: + targets = {entry for entry in base_dir.glob(f"{_DEFAULT_EXP_NAME_PREFIX}*")} + if exp_name: + targets.add(base_dir / exp_name) + targets = {d for d in targets if d.is_dir()} + + for entry in targets: + shutil.rmtree(entry) + print(f"[Dumper] Cleaned up {entry}") + + if dist.is_initialized(): + _collective_with_timeout( + dist.barrier, + operation_name="barrier in _cleanup_old_dumps", + ) + + +def _get_rank(): + if dist.is_initialized(): + return dist.get_rank() + else: + return 0 + + +def _get_world_size(): + if dist.is_initialized(): + return dist.get_world_size() + else: + return 1 + + +def _obj_to_dict(obj): + if isinstance(obj, dict): + return obj + ret = {} + for k in dir(obj): + if k.startswith("__") and k.endswith("__"): + continue + try: + v = getattr(obj, k) + if not callable(v): + ret[k] = v + except Exception: + # Skip attributes that raise an exception on access + continue + return ret + + +def _materialize_value(value): + if callable(value): + value = value() + return value + + +def 
_format_tags(kwargs: dict) -> str: + return "___".join(f"{k}={v}" for k, v in kwargs.items()) + + +class _DefaultNoneDict(dict): + """dict subclass that returns None for missing keys, for filter expression eval.""" + + def __missing__(self, key: str): + return None + + +_FILTER_BUILTINS: dict[str, Any] = {"search": re.search, "match": re.match} + + +def _evaluate_filter(filter_expr: str, tags: dict[str, Any]) -> bool: + """Evaluate a Python filter expression against the tags dict. + + Unknown tag keys resolve to None, so `layer_id is None` works when layer_id is absent. + `re.search` and `re.match` are available as `search()` and `match()`. + """ + namespace = _DefaultNoneDict(tags) + namespace.update(_FILTER_BUILTINS) + return bool(eval(filter_expr, {"__builtins__": {}}, namespace)) + + +def _deepcopy_or_clone(x): + if isinstance(x, torch.Tensor): + return x.clone() + return deepcopy(x) + + +# -------------------------------------- static meta ------------------------------------------ + + +def _compute_static_meta(): + result = { + "world_rank": _get_rank(), + "world_size": _get_world_size(), + } + + for plugin in _plugins: + if info := plugin.collect_parallel_info(): + result[f"{plugin.name}_parallel_info"] = info + + for plugin in _plugins: + tokenizer_path: Optional[str] = plugin.get_tokenizer_path() + if tokenizer_path is not None: + result["tokenizer_path"] = tokenizer_path + break + + return result + + +# -------------------------------------- http manager ------------------------------------------ + + +class _DumperHttpManager: + def __init__(self, dumper: "_Dumper"): + self._dumper = dumper + http_port = self._dumper._config.server_port_parsed + + rpc_broadcast = _create_zmq_rpc_broadcast( + self, + timeout_seconds=self._dumper._config.collective_timeout, + ) + + if _get_rank() == 0: + assert rpc_broadcast is not None + self._rpc_broadcast = rpc_broadcast + + if http_port == "reuse": + print( + "[Dumper] Standalone HTTP server disabled, reusing existing 
ports" + ) + else: + _start_http_server(prefix="/dumper/", target=self, http_port=http_port) + print(f"[Dumper] HTTP server started on port {http_port}") + + # ------------------------------- public --------------------------------- + + def handle_request(self, *, method: str, body: dict[str, Any]) -> list[dict]: + return self._rpc_broadcast._handle_request_inner(method=method, body=body) + + # ------------------------------- private --------------------------------- + + def _handle_request_inner(self, *, method: str, body: dict[str, Any]) -> dict: + if method == "get_state": + return self._dumper.get_state() + elif method == "configure": + self._dumper.configure(**body) + return {} + elif method == "reset": + self._dumper.reset() + return {} + else: + raise ValueError(f"Unknown dumper control method: {method!r}") + + +# -------------------------------------- http control server ------------------------------------------ + + +def _start_http_server(*, prefix: str, target: object, http_port: int): + handler_class = _make_http_handler(prefix=prefix, target=target) + server = HTTPServer(("0.0.0.0", http_port), handler_class) + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + + +def _make_http_handler(*, prefix: str, target): + class _HTTPHandler(BaseHTTPRequestHandler): + def do_POST(self): + if not self.path.startswith(prefix): + self.send_error(404) + return + method = self.path[len(prefix) :] + try: + req_body = self._get_request_body() + print(f"[Dumper#{_get_rank()}] HTTP {self.path} {req_body=}") + result = target.handle_request(method=method, body=req_body) + resp_body = json.dumps(result).encode() + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(resp_body))) + self.end_headers() + self.wfile.write(resp_body) + except Exception as e: + self.send_error(400, str(e)) + + def _get_request_body(self) -> dict: + content_length = 
int(self.headers.get("Content-Length", 0)) + if content_length == 0: + return {} + return json.loads(self.rfile.read(content_length)) + + return _HTTPHandler + + +# -------------------------------------- zmq rpc ------------------------------------------ + + +def _create_zmq_rpc_broadcast( + handler, timeout_seconds: int = 60 +) -> Optional["_ZmqRpcBroadcast"]: + """A general-purpose minimal RPC to support broadcasting executions to multi processes""" + import zmq + + rank = _get_rank() + world_size = dist.get_world_size() if dist.is_initialized() else 1 + + ctx = zmq.Context() + sock = ctx.socket(zmq.REP) + sock.bind("tcp://*:0") + bound_port = int(sock.getsockopt_string(zmq.LAST_ENDPOINT).rsplit(":", 1)[1]) + local_addr = f"tcp://{_get_local_ip_by_remote()}:{bound_port}" + + def serve_loop(): + while True: + try: + req = sock.recv_pyobj() + result = getattr(handler, req["method"])(*req["args"], **req["kwargs"]) + resp = {"result": result, "error": None} + except Exception as e: + print(f"[Dumper.ZmqRpc] error inside handler: {e}") + resp = {"result": None, "error": str(e)} + sock.send_pyobj(resp) + + thread = threading.Thread(target=serve_loop, daemon=True) + thread.start() + print(f"[Dumper.ZmqRpc] rank={rank} server started at {local_addr}") + + if dist.is_initialized(): + all_addresses = [None] * world_size + _collective_with_timeout( + lambda: dist.all_gather_object(all_addresses, local_addr), + operation_name="all_gather_object in _create_zmq_rpc_broadcast", + timeout_seconds=timeout_seconds, + ) + else: + all_addresses = [local_addr] + print(f"[Dumper.ZmqRpc] rank={rank} all_addresses={all_addresses}") + + if rank == 0: + handles = [] + for i, addr in enumerate(all_addresses): + req_socket = ctx.socket(zmq.REQ) + req_socket.connect(addr) + handles.append(_ZmqRpcHandle(req_socket, debug_name=f"rank-{i}")) + return _ZmqRpcBroadcast(handles) + else: + return None + + +class _ZmqRpcHandle: + """Proxy object to call remote handler methods via ZMQ.""" + + def 
__init__(self, socket, debug_name: str): + self._socket = socket + self._debug_name = debug_name + + def __getattr__(self, method_name: str): + def call(*args, **kwargs): + self._socket.send_pyobj( + { + "method": method_name, + "args": args, + "kwargs": kwargs, + } + ) + response = self._socket.recv_pyobj() + if response["error"]: + raise RuntimeError( + f"RPC error on {self._debug_name}: {response['error']}" + ) + return response["result"] + + return call + + +class _RpcBroadcastBase: + """Base for broadcasting method calls to dumper instance(s).""" + + def __getattr__(self, method_name: str): + raise NotImplementedError + + def __init__(self, handles: List[_ZmqRpcHandle]): + self._handles = handles + + +class _ZmqRpcBroadcast(_RpcBroadcastBase): + """Broadcasts method calls to all ZMQ RPC handles. + + Returns a list of results, one per rank (ordered by rank). + """ + + def __init__(self, handles: List[_ZmqRpcHandle]): + self._handles = handles + + def __getattr__(self, method_name: str): + def call(*args, **kwargs): + return [ + getattr(handle, method_name)(*args, **kwargs) + for handle in self._handles + ] + + return call + + +# --------------------------------- copied code (avoid dependency) -------------------------------------- + + +def _get_local_ip_by_remote() -> Optional[str]: + # try ipv4 + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable + return s.getsockname()[0] + except Exception: + pass + + try: + hostname = socket.gethostname() + ip = socket.gethostbyname(hostname) + if ip and ip != "127.0.0.1" and ip != "0.0.0.0": + return ip + except Exception: + pass + + # try ipv6 + try: + s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) + # Google's public DNS server, see + # https://developers.google.com/speed/public-dns/docs/using#addresses + s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable + return s.getsockname()[0] + except Exception: + print("Can not get 
local ip by remote") + return None + + +# -------------------------------------- framework plugins ------------------------------------------ + + +class _RecomputeStatus(enum.Enum): + DISABLED = "disabled" + ORIGINAL = "original" # inside checkpoint, original forward + RECOMPUTE = "recompute" # inside checkpoint, recompute forward + + def to_pseudo_parallel_meta(self) -> dict[str, Any]: + if self == _RecomputeStatus.DISABLED: + return {} + return { + "recompute_pseudo_rank": 1 if self == _RecomputeStatus.RECOMPUTE else 0, + "recompute_pseudo_size": 2, + } + + +class _FrameworkPlugin(ABC): + @property + @abstractmethod + def name(self) -> str: ... + + @abstractmethod + def collect_parallel_info(self) -> dict: ... + + @abstractmethod + def convert_value( + self, value: Any, *, skip_forward_batch: bool + ) -> Optional[dict[str, Any]]: + """Return converted dict, or None if this plugin doesn't handle the value.""" + ... + + @abstractmethod + def detect_layer_id(self, module: "torch.nn.Module") -> Optional[int]: + """Return 0-indexed layer_id, or None if not detectable.""" + ... 
+ + def core_fields(self) -> frozenset[str]: + return frozenset() + + def get_tokenizer_path(self) -> Optional[str]: + return None + + def detect_recompute_status(self) -> _RecomputeStatus: + return _RecomputeStatus.DISABLED + + +class _SGLangPlugin(_FrameworkPlugin): + _available = True + try: + from sglang.srt import distributed as _dist + from sglang.srt.layers import dp_attention as _dp_attn + from sglang.srt.layers.logits_processor import LogitsProcessorOutput + from sglang.srt.model_executor.forward_batch_info import ( + ForwardBatch, + PPProxyTensors, + ) + except ImportError: + _available = False + + @property + def name(self) -> str: + return "sglang" + + def collect_parallel_info(self) -> dict: + if not self._available: + return {} + + info = {} + + try: + info["tp_rank"] = self._dist.get_tensor_model_parallel_rank() + info["tp_size"] = self._dist.get_tensor_model_parallel_world_size() + info["pp_rank"] = self._dist.get_pipeline_model_parallel_rank() + info["pp_size"] = self._dist.get_pipeline_model_parallel_world_size() + info["moe_ep_rank"] = self._dist.get_moe_expert_parallel_rank() + info["moe_ep_size"] = self._dist.get_moe_expert_parallel_world_size() + info["moe_tp_rank"] = self._dist.get_moe_tensor_parallel_rank() + info["moe_tp_size"] = self._dist.get_moe_tensor_parallel_world_size() + info["moe_dp_rank"] = self._dist.get_moe_data_parallel_rank() + info["moe_dp_size"] = self._dist.get_moe_data_parallel_world_size() + except (AttributeError, AssertionError): + info["distributed_error"] = True + + try: + info["enable_dp_attention"] = self._dp_attn.is_dp_attention_enabled() + info["attn_tp_rank"] = self._dp_attn.get_attention_tp_rank() + info["attn_tp_size"] = self._dp_attn.get_attention_tp_size() + info["attn_dp_rank"] = self._dp_attn.get_attention_dp_rank() + info["attn_dp_size"] = self._dp_attn.get_attention_dp_size() + info["local_attn_dp_rank"] = self._dp_attn.get_local_attention_dp_rank() + info["local_attn_dp_size"] = 
self._dp_attn.get_local_attention_dp_size() + info["attn_cp_rank"] = self._dp_attn.get_attention_cp_rank() + info["attn_cp_size"] = self._dp_attn.get_attention_cp_size() + except (AttributeError, AssertionError): + info["dp_attention_error"] = True + + return info + + def convert_value( + self, value: Any, *, skip_forward_batch: bool + ) -> Optional[dict[str, Any]]: + if not self._available: + return None + + if isinstance(value, self.LogitsProcessorOutput): + return {"next_token_logits": value.next_token_logits} + if isinstance(value, self.ForwardBatch): + if skip_forward_batch: + return {} + result = { + "input_ids": value.input_ids, + "seq_lens": value.seq_lens, + "positions": value.positions, + "req_pool_indices": value.req_pool_indices, + } + if value.rids is not None: + result["rids"] = value.rids + return result + if isinstance(value, self.PPProxyTensors): + return {k: v for k, v in value.tensors.items()} + + return None + + def detect_layer_id(self, module: "torch.nn.Module") -> Optional[int]: + if hasattr(module, "layer_id"): + return module.layer_id + return None + + def core_fields(self) -> frozenset[str]: + return frozenset( + {"input_ids", "positions", "seq_lens", "req_pool_indices", "rids"} + ) + + def get_tokenizer_path(self) -> Optional[str]: + if not self._available: + return None + + try: + from sglang.srt.server_args import get_global_server_args + + args = get_global_server_args() + if args is None: + return None + + return args.tokenizer_path + except Exception: + return None + + +class _MegatronPlugin(_FrameworkPlugin): + _available = True + try: + from megatron.core import parallel_state as _mpu + from megatron.core.packed_seq_params import PackedSeqParams + except ImportError: + _available = False + + @property + def name(self) -> str: + return "megatron" + + def collect_parallel_info(self) -> dict: + if not self._available: + return {} + + info = {} + try: + info["tp_rank"] = self._mpu.get_tensor_model_parallel_rank() + info["tp_size"] = 
class _MegatronPlugin(_FrameworkPlugin):
    """Plugin that reads parallel state and packed-sequence params from Megatron.

    ``megatron.core`` is imported once, in the class body. If the import
    raises ``ImportError`` the plugin marks itself unavailable and its methods
    return empty / ``None`` / ``DISABLED`` results.
    """

    _available = True
    try:
        from megatron.core import parallel_state as _mpu
        from megatron.core.packed_seq_params import PackedSeqParams
    except ImportError:
        _available = False

    @property
    def name(self) -> str:
        """Identifier of this plugin."""
        return "megatron"

    def collect_parallel_info(self) -> dict:
        """Collect rank/size values from ``megatron.core.parallel_state``.

        Returns:
            ``{}`` when Megatron is not importable. Otherwise a dict filled in
            getter order; if a getter raises ``AttributeError`` or
            ``AssertionError`` the remaining getters are skipped and
            ``"megatron_error"`` is set to ``True``. When Megatron's global
            args report ``sequence_parallel`` and both TP values were
            collected, ``sp_rank`` / ``sp_size`` mirror ``tp_rank`` / ``tp_size``.
        """
        if not self._available:
            return {}

        # (info key, name of the zero-argument getter), queried in this order.
        mpu_getters = (
            ("tp_rank", "get_tensor_model_parallel_rank"),
            ("tp_size", "get_tensor_model_parallel_world_size"),
            ("pp_rank", "get_pipeline_model_parallel_rank"),
            ("pp_size", "get_pipeline_model_parallel_world_size"),
            ("dp_rank", "get_data_parallel_rank"),
            ("dp_size", "get_data_parallel_world_size"),
            ("cp_rank", "get_context_parallel_rank"),
            ("cp_size", "get_context_parallel_world_size"),
            ("vpp_rank", "get_virtual_pipeline_model_parallel_rank"),
            ("vpp_size", "get_virtual_pipeline_model_parallel_world_size"),
            ("ep_rank", "get_expert_model_parallel_rank"),
            ("ep_size", "get_expert_model_parallel_world_size"),
            ("etp_rank", "get_expert_tensor_parallel_rank"),
            ("etp_size", "get_expert_tensor_parallel_world_size"),
            ("edp_rank", "get_expert_data_parallel_rank"),
            ("edp_size", "get_expert_data_parallel_world_size"),
            ("tcp_rank", "get_tensor_and_context_parallel_rank"),
            ("tcp_size", "get_tensor_and_context_parallel_world_size"),
            ("etmp_rank", "get_expert_tensor_and_model_parallel_rank"),
            ("etmp_size", "get_expert_tensor_and_model_parallel_world_size"),
            ("tp_src_rank", "get_tensor_model_parallel_src_rank"),
            ("mp_src_rank", "get_model_parallel_src_rank"),
            ("dp_src_rank", "get_data_parallel_src_rank"),
        )

        info = {}
        try:
            for key, getter_name in mpu_getters:
                info[key] = getattr(self._mpu, getter_name)()
        except (AttributeError, AssertionError):
            info["megatron_error"] = True

        # Megatron sequence parallel reuses the TP group (no dedicated parallel state API).
        # When sequence_parallel=True, inject sp_rank/sp_size for the comparator unsharder.
        try:
            from megatron.training.global_vars import get_args

            args = get_args()
            # Require BOTH TP keys: the getter loop can stop after tp_rank but
            # before tp_size, and reading a missing key would raise KeyError,
            # which is not among the exceptions handled below.
            has_tp_info = "tp_rank" in info and "tp_size" in info
            if getattr(args, "sequence_parallel", False) and has_tp_info:
                info["sp_rank"] = info["tp_rank"]
                info["sp_size"] = info["tp_size"]
        except (ImportError, AssertionError, AttributeError):
            # Best effort: without Megatron's global args, no SP info is added.
            pass

        return info

    def convert_value(
        self, value: Any, *, skip_forward_batch: bool
    ) -> Optional[dict[str, Any]]:
        """Flatten a ``PackedSeqParams`` into a plain dict.

        Args:
            value: Candidate object to convert.
            skip_forward_batch: Not consulted by this plugin.

        Returns:
            A dict with ``cu_seqlens_q``, ``cu_seqlens_kv`` and ``qkv_format``
            for ``PackedSeqParams`` instances; ``None`` otherwise or when
            Megatron is not importable.
        """
        if not self._available:
            return None
        if isinstance(value, self.PackedSeqParams):
            return {
                "cu_seqlens_q": value.cu_seqlens_q,
                "cu_seqlens_kv": value.cu_seqlens_kv,
                "qkv_format": value.qkv_format,
            }
        return None

    def detect_layer_id(self, module: "torch.nn.Module") -> Optional[int]:
        """Return ``module.layer_number - 1`` if the attribute exists, else ``None``."""
        if not hasattr(module, "layer_number"):
            return None
        return module.layer_number - 1

    def core_fields(self) -> frozenset[str]:
        """Names of the fields this plugin declares as core."""
        return frozenset(
            {"input_ids", "position_ids", "cu_seqlens_q", "cu_seqlens_kv", "qkv_format"}
        )

    def detect_recompute_status(self) -> _RecomputeStatus:
        """Classify the current forward pass with respect to Megatron checkpointing.

        Returns:
            ``DISABLED`` when Megatron is unavailable, when ``is_checkpointing()``
            is false, or when the lookup raises ``ImportError`` /
            ``AttributeError``. While checkpointing, ``RECOMPUTE`` if autograd
            is enabled and ``ORIGINAL`` otherwise.
        """
        if not self._available:
            return _RecomputeStatus.DISABLED
        try:
            from megatron.core.tensor_parallel.random import is_checkpointing

            if not is_checkpointing():
                return _RecomputeStatus.DISABLED
            if torch.is_grad_enabled():
                return _RecomputeStatus.RECOMPUTE
            return _RecomputeStatus.ORIGINAL
        except (ImportError, AttributeError):
            return _RecomputeStatus.DISABLED


_plugins: list[_FrameworkPlugin] = [_SGLangPlugin(), _MegatronPlugin()]


def _detect_recompute_status() -> _RecomputeStatus:
    """Return the first non-``DISABLED`` status reported by a plugin, else ``DISABLED``."""
    for plugin in _plugins:
        status = plugin.detect_recompute_status()
        if status != _RecomputeStatus.DISABLED:
            return status
    return _RecomputeStatus.DISABLED


# -------------------------------------- singleton ------------------------------------------


dumper = _Dumper(config=DumperConfig.from_env())


# -------------------------------------- other utility functions ------------------------------------------
return None + + if isinstance(value, tuple): + return [get_truncated_value(x) for x in value] + + if not isinstance(value, torch.Tensor): + return value + + if value.numel() < 200: + return value + + slices = [slice(0, 5) if dim_size > 50 else slice(None) for dim_size in value.shape] + return value[tuple(slices)] + + +def get_tensor_info(x): + """ + from sglang.srt.debug_utils.dumper import get_tensor_info + """ + if not isinstance(x, torch.Tensor): + return f"type={type(x)} value={x}" + min = x.float().min() if x.numel() > 0 else None + max = x.float().max() if x.numel() > 0 else None + mean = x.float().mean() if x.numel() > 0 else None + torch.set_printoptions(precision=10) + x_sample_head = str(x.flatten()[:5]) + x_sample_tail = str(x.flatten()[-5:]) + torch.set_printoptions(precision=4) + return ( + f"type={type(x)} " + f"shape={x.shape} " + f"dtype={x.dtype} " + f"device={x.device} " + f"stride={x.stride()} " + f"req_grad={x.requires_grad} " + f"min={min} " + f"max={max} " + f"mean={mean} " + f"x_sample_head={x_sample_head} " + f"x_sample_tail={x_sample_tail}" + ) diff --git a/sglang/python/sglang/srt/debug_utils/log_parser.py b/sglang/python/sglang/srt/debug_utils/log_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..2fac5126940ae36f7305e4ec6a3c872e7a6b9245 --- /dev/null +++ b/sglang/python/sglang/srt/debug_utils/log_parser.py @@ -0,0 +1,46 @@ +_PATTERN_DECODE = ( + r"(\(\w+ pid=(?P\d+)(?:,\s*ip=(?P[\d\.]+))?\))?\s*" + r"\[(?P