Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- sglang/.claude/skills/add-jit-kernel/SKILL.md +553 -0
- sglang/.claude/skills/add-sgl-kernel/SKILL.md +358 -0
- sglang/.claude/skills/sglang-bisect-ci-regression/SKILL.md +219 -0
- sglang/.claude/skills/write-sglang-test/SKILL.md +248 -0
- sglang/benchmark/json_jump_forward/README.md +88 -0
- sglang/benchmark/json_jump_forward/bench_other.py +288 -0
- sglang/benchmark/json_jump_forward/bench_sglang.py +143 -0
- sglang/benchmark/json_jump_forward/build_dataset.py +58 -0
- sglang/benchmark/json_jump_forward/dataset.txt +50 -0
- sglang/benchmark/multi_turn_chat/bench_other.py +93 -0
- sglang/benchmark/multi_turn_chat/data_gen.py +29 -0
- sglang/benchmark/tree_of_thought_deep/README.md +51 -0
- sglang/benchmark/tree_of_thought_deep/bench_other.py +222 -0
- sglang/benchmark/tree_of_thought_deep/bench_sglang.py +171 -0
- sglang/docker/configs/.zshrc +27 -0
- sglang/docker/configs/opt/.gitconfig +30 -0
- sglang/docker/configs/opt/.tmux.conf +27 -0
- sglang/docker/configs/opt/.vimrc +45 -0
- sglang/docker/configs/yank +12 -0
- sglang/python/sglang.egg-info/PKG-INFO +120 -0
- sglang/python/sglang.egg-info/SOURCES.txt +0 -0
- sglang/python/sglang.egg-info/dependency_links.txt +1 -0
- sglang/python/sglang.egg-info/entry_points.txt +2 -0
- sglang/python/sglang.egg-info/requires.txt +121 -0
- sglang/python/sglang.egg-info/top_level.txt +1 -0
- sglang/python/sglang/README.md +18 -0
- sglang/python/sglang/__init__.py +83 -0
- sglang/python/sglang/__pycache__/__init__.cpython-311.pyc +0 -0
- sglang/python/sglang/__pycache__/_version.cpython-311.pyc +0 -0
- sglang/python/sglang/__pycache__/bench_serving.cpython-311.pyc +0 -0
- sglang/python/sglang/__pycache__/check_env.cpython-311.pyc +0 -0
- sglang/python/sglang/__pycache__/global_config.cpython-311.pyc +0 -0
- sglang/python/sglang/__pycache__/launch_server.cpython-311.pyc +0 -0
- sglang/python/sglang/__pycache__/utils.cpython-311.pyc +0 -0
- sglang/python/sglang/__pycache__/version.cpython-311.pyc +0 -0
- sglang/python/sglang/_version.py +34 -0
- sglang/python/sglang/bench_offline_throughput.py +543 -0
- sglang/python/sglang/bench_one_batch.py +837 -0
- sglang/python/sglang/bench_one_batch_server.py +49 -0
- sglang/python/sglang/bench_serving.py +2238 -0
- sglang/python/sglang/benchmark/__init__.py +0 -0
- sglang/python/sglang/benchmark/__pycache__/__init__.cpython-311.pyc +0 -0
- sglang/python/sglang/benchmark/__pycache__/utils.cpython-311.pyc +0 -0
- sglang/python/sglang/benchmark/datasets/__init__.py +47 -0
- sglang/python/sglang/benchmark/datasets/__pycache__/__init__.cpython-311.pyc +0 -0
- sglang/python/sglang/benchmark/datasets/__pycache__/common.cpython-311.pyc +0 -0
- sglang/python/sglang/benchmark/datasets/__pycache__/custom.cpython-311.pyc +0 -0
- sglang/python/sglang/benchmark/datasets/__pycache__/generated_shared_prefix.cpython-311.pyc +0 -0
- sglang/python/sglang/benchmark/datasets/__pycache__/image.cpython-311.pyc +0 -0
- sglang/python/sglang/benchmark/datasets/__pycache__/mmmu.cpython-311.pyc +0 -0
sglang/.claude/skills/add-jit-kernel/SKILL.md
ADDED
|
@@ -0,0 +1,553 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
name: add-jit-kernel
|
| 3 |
+
description: Step-by-step tutorial for adding a new lightweight JIT CUDA kernel to sglang's jit_kernel module
|
| 4 |
+
---
|
| 5 |
+
|
| 6 |
+
# Tutorial: Adding a New JIT Kernel to SGLang
|
| 7 |
+
|
| 8 |
+
This tutorial walks through adding a simple element-wise scale operation as a JIT kernel. We'll implement `scale(x, factor) = x * factor` to demonstrate the complete workflow.
|
| 9 |
+
|
| 10 |
+
## Goal
|
| 11 |
+
|
| 12 |
+
Add a new operation that scales each element of a tensor by a scalar factor:
|
| 13 |
+
|
| 14 |
+
- Input: tensor `x` (CUDA) and scalar `factor` (float, passed as C++ template argument)
|
| 15 |
+
- Output: `x * factor` (element-wise), allocated internally
|
| 16 |
+
- Supported dtypes: **FP16 (`torch.float16`), BF16 (`torch.bfloat16`), FP32 (`torch.float32`)**
|
| 17 |
+
|
| 18 |
+
## When to use JIT vs AOT (`sgl-kernel`)
|
| 19 |
+
|
| 20 |
+
- **JIT (`jit_kernel`)**: lightweight, few dependencies, rapid iteration, compiled on first use
|
| 21 |
+
- **AOT (`sgl-kernel`)**: depends on CUTLASS / FlashInfer / DeepGEMM, needs pre-built wheel
|
| 22 |
+
|
| 23 |
+
---
|
| 24 |
+
|
| 25 |
+
## Common Abstractions in `python/sglang/jit_kernel/include/sgl_kernel/`
|
| 26 |
+
|
| 27 |
+
**Always prefer these abstractions over raw CUDA primitives.** They provide safety, readability, and consistency with the rest of the codebase.
|
| 28 |
+
|
| 29 |
+
### `utils.h` — Host-side utilities
|
| 30 |
+
|
| 31 |
+
```cpp
|
| 32 |
+
#include <sgl_kernel/utils.h>
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
- **`host::RuntimeCheck(cond, args...)`** — Assert a condition at runtime; throws `PanicError` with file/line info on failure. Prefer this over bare `assert`.
|
| 36 |
+
- **`host::Panic(args...)`** — Unconditionally throw a `PanicError` with a descriptive message.
|
| 37 |
+
- **`host::div_ceil(a, b)`** — Integer ceiling division `(a + b - 1) / b`.
|
| 38 |
+
- **`host::irange(n)`** / **`host::irange(start, end)`** — Range views for cleaner loops.
|
| 39 |
+
- **`host::pointer::offset(ptr, offsets...)`** — Byte-safe pointer arithmetic on `void*`. Use this instead of raw casts.
|
| 40 |
+
|
| 41 |
+
### `utils.cuh` — Device-side utilities + `LaunchKernel`
|
| 42 |
+
|
| 43 |
+
```cpp
|
| 44 |
+
#include <sgl_kernel/utils.cuh>
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
- **Type aliases**: `fp16_t`, `bf16_t`, `fp32_t`, `fp8_e4m3_t`, `fp8_e5m2_t` and their packed variants `fp16x2_t`, `bf16x2_t`, `fp32x2_t`, etc.
|
| 48 |
+
- **`SGL_DEVICE`** — Expands to `__forceinline__ __device__`. Use on all device functions.
|
| 49 |
+
- **`device::kWarpThreads`** — Constant `32`.
|
| 50 |
+
- **`device::load_as<T>(ptr, offset)`** / **`device::store_as<T>(ptr, val, offset)`** — Type-safe loads/stores from `void*`.
|
| 51 |
+
- **`device::pointer::offset(ptr, offsets...)`** — Pointer arithmetic on device.
|
| 52 |
+
- **`host::LaunchKernel(grid, block, device_or_stream [, smem])`** — RAII kernel launcher that:
|
| 53 |
+
- Resolves the CUDA stream from a `DLDevice` via TVM-FFI automatically.
|
| 54 |
+
- Checks the CUDA error with file/line info after launch via `operator()(kernel, args...)`.
|
| 55 |
+
- Supports `.enable_pdl(bool)` for PDL (Programmatic Dependent Launch, SM90+).
|
| 56 |
+
- **`host::RuntimeDeviceCheck(cudaError_t)`** — Check a CUDA error; throw on failure.
|
| 57 |
+
|
| 58 |
+
### `tensor.h` — Tensor validation (`TensorMatcher`, Symbolic types)
|
| 59 |
+
|
| 60 |
+
```cpp
|
| 61 |
+
#include <sgl_kernel/tensor.h>
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
This is the **primary validation API** for all kernel launchers. Use it to validate every `tvm::ffi::TensorView` argument.
|
| 65 |
+
|
| 66 |
+
- **`host::SymbolicSize{"name"}`** — A named symbolic dimension. Call `.set_value(n)` to pin it, `.unwrap()` to extract after verification.
|
| 67 |
+
- **`host::SymbolicDType`** — Symbolic dtype. Use `.set_options<Ts...>()` to restrict allowed types.
|
| 68 |
+
- **`host::SymbolicDevice`** — Symbolic device. Use `.set_options<kDLCUDA>()` to restrict to CUDA.
|
| 69 |
+
- **`host::TensorMatcher({dims...})`** — Fluent builder for tensor validation:
|
| 70 |
+
- `.with_dtype<T>()` — require a specific C++ type (e.g. `fp16_t`)
|
| 71 |
+
- `.with_dtype<T1, T2, ...>()` — allow a set of types
|
| 72 |
+
- `.with_device<kDLCUDA>(device_sym)` — require CUDA, bind device to symbol
|
| 73 |
+
- `.with_strides({strides...})` — validate strides (omit to require contiguous)
|
| 74 |
+
- `.verify(tensor_view)` — execute the check; throws `PanicError` with full context on failure; **chainable** (`verify(a).verify(b)` to check multiple tensors with the same shape)
|
| 75 |
+
|
| 76 |
+
**Typical pattern:**
|
| 77 |
+
```cpp
|
| 78 |
+
auto N = SymbolicSize{"num_elements"};
|
| 79 |
+
auto device = SymbolicDevice{};
|
| 80 |
+
device.set_options<kDLCUDA>();
|
| 81 |
+
TensorMatcher({N}) //
|
| 82 |
+
.with_dtype<fp16_t>()
|
| 83 |
+
.with_device(device)
|
| 84 |
+
.verify(dst)
|
| 85 |
+
.verify(src); // same shape, dtype, device as dst
|
| 86 |
+
const size_t n = N.unwrap();
|
| 87 |
+
const DLDevice dev = device.unwrap();
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
### `type.cuh` — `dtype_trait<T>` and `packed_t<T>`
|
| 91 |
+
|
| 92 |
+
```cpp
|
| 93 |
+
#include <sgl_kernel/type.cuh>
|
| 94 |
+
```
|
| 95 |
+
|
| 96 |
+
- **`dtype_trait<T>`** — Static trait struct for each scalar type. Provides:
|
| 97 |
+
- `dtype_trait<T>::from(value)` — convert from another type (e.g. `fp32_t` → `fp16_t`)
|
| 98 |
+
- `dtype_trait<T>::abs/sqrt/rsqrt/max/min(x)` — type-dispatched math (for `fp32_t`)
|
| 99 |
+
- **`packed_t<T>`** — Two-element packed alias: `packed_t<fp16_t>` = `fp16x2_t`, `packed_t<bf16_t>` = `bf16x2_t`, `packed_t<fp32_t>` = `fp32x2_t`. Use for vectorized loads/stores.
|
| 100 |
+
- **`device::cast<To, From>(value)`** — Type-safe cast using `dtype_trait`, e.g. `cast<fp32x2_t, fp16x2_t>(v)`.
|
| 101 |
+
|
| 102 |
+
### `vec.cuh` — Vectorized memory access (`AlignedVector`)
|
| 103 |
+
|
| 104 |
+
```cpp
|
| 105 |
+
#include <sgl_kernel/vec.cuh>
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
- **`device::AlignedVector<T, N>`** — Aligned storage for N elements of type T. N must be a power of two, `sizeof(T)*N <= 32`. Enables 128-bit vector loads/stores for bandwidth efficiency.
|
| 109 |
+
- `.load(ptr, offset)` — vectorized load from `ptr[offset]`
|
| 110 |
+
- `.store(ptr, offset)` — vectorized store to `ptr[offset]`
|
| 111 |
+
- `.fill(value)` — fill all lanes
|
| 112 |
+
- `operator[](i)` — element access
|
| 113 |
+
|
| 114 |
+
### `tile.cuh` — `tile::Memory` (strided memory access pattern)
|
| 115 |
+
|
| 116 |
+
```cpp
|
| 117 |
+
#include <sgl_kernel/tile.cuh>
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
- **`device::tile::Memory<T>::cta(blockDim.x)`** — Creates a tile accessor where each thread handles `tid = threadIdx.x` with stride `blockDim.x`. Common for loops over a 1D array.
|
| 121 |
+
- **`.load(ptr, offset)`** — loads `ptr[tid + offset * blockDim.x]`
|
| 122 |
+
- **`.store(ptr, val, offset)`** — stores to `ptr[tid + offset * blockDim.x]`
|
| 123 |
+
- **`.in_bound(n, offset)`** — boundary check
|
| 124 |
+
|
| 125 |
+
### `math.cuh` — Device math (`device::math::`)
|
| 126 |
+
|
| 127 |
+
```cpp
|
| 128 |
+
#include <sgl_kernel/math.cuh>
|
| 129 |
+
```
|
| 130 |
+
|
| 131 |
+
- `device::math::max/min/abs/sqrt/rsqrt<T>(a, b)` — type-dispatched math via `dtype_trait`
|
| 132 |
+
- `device::math::exp/sin/cos(float)` — fast float math wrappers
|
| 133 |
+
|
| 134 |
+
### `warp.cuh` — Warp-level primitives
|
| 135 |
+
|
| 136 |
+
```cpp
|
| 137 |
+
#include <sgl_kernel/warp.cuh>
|
| 138 |
+
```
|
| 139 |
+
|
| 140 |
+
- `device::warp::reduce_sum<T>(value)` — warp-level sum reduction via `__shfl_xor_sync`
|
| 141 |
+
- `device::warp::reduce_max<T>(value)` — warp-level max reduction
|
| 142 |
+
|
| 143 |
+
### `cta.cuh` — CTA-level primitives
|
| 144 |
+
|
| 145 |
+
```cpp
|
| 146 |
+
#include <sgl_kernel/cta.cuh>
|
| 147 |
+
```
|
| 148 |
+
|
| 149 |
+
- `device::cta::reduce_max<T>(value, smem, min_value)` — CTA-wide max using shared memory + warp reduction. Caller is responsible for a `__syncthreads()` after if the result in `smem[0]` is needed.
|
| 150 |
+
|
| 151 |
+
### `atomic.cuh` — Atomic operations
|
| 152 |
+
|
| 153 |
+
```cpp
|
| 154 |
+
#include <sgl_kernel/atomic.cuh>
|
| 155 |
+
```
|
| 156 |
+
|
| 157 |
+
- `device::atomic::max(float* addr, float value)` — float atomic max (handles negative values correctly via bit tricks).
|
| 158 |
+
|
| 159 |
+
### `runtime.cuh` — Occupancy and device info
|
| 160 |
+
|
| 161 |
+
```cpp
|
| 162 |
+
#include <sgl_kernel/runtime.cuh>
|
| 163 |
+
```
|
| 164 |
+
|
| 165 |
+
- `host::runtime::get_blocks_per_sm(kernel, block_dim)` — max active blocks per SM (occupancy)
|
| 166 |
+
- `host::runtime::get_sm_count(device_id)` — number of SMs on the device
|
| 167 |
+
- `host::runtime::get_cc_major(device_id)` — compute capability major version
|
| 168 |
+
|
| 169 |
+
**Persistent kernel pattern** (cap blocks to SM count × occupancy):
|
| 170 |
+
```cpp
|
| 171 |
+
static const uint32_t max_occ = runtime::get_blocks_per_sm(kernel, kBlockSize);
|
| 172 |
+
static const uint32_t num_sm = runtime::get_sm_count(device.unwrap().device_id);
|
| 173 |
+
const auto num_blocks = std::min(num_sm * max_occ, div_ceil(n, kBlockSize));
|
| 174 |
+
LaunchKernel(num_blocks, kBlockSize, device.unwrap())(kernel, params);
|
| 175 |
+
```
|
| 176 |
+
|
| 177 |
+
---
|
| 178 |
+
|
| 179 |
+
## Step 0 (optional): Generate a `.clangd` config for better IDE support
|
| 180 |
+
|
| 181 |
+
```bash
|
| 182 |
+
python -m sglang.jit_kernel
|
| 183 |
+
```
|
| 184 |
+
|
| 185 |
+
---
|
| 186 |
+
|
| 187 |
+
## Step 1: Implement the CUDA kernel in `jit_kernel/csrc/`
|
| 188 |
+
|
| 189 |
+
Create `python/sglang/jit_kernel/csrc/elementwise/scale.cuh`.
|
| 190 |
+
|
| 191 |
+
The implementation fully uses the project abstractions described above:
|
| 192 |
+
|
| 193 |
+
```cpp
|
| 194 |
+
#include <sgl_kernel/tensor.h> // TensorMatcher, SymbolicSize, SymbolicDevice
|
| 195 |
+
#include <sgl_kernel/type.cuh> // dtype_trait, fp16_t, bf16_t, fp32_t
|
| 196 |
+
#include <sgl_kernel/utils.h> // RuntimeCheck, div_ceil
|
| 197 |
+
#include <sgl_kernel/utils.cuh> // LaunchKernel, SGL_DEVICE
|
| 198 |
+
#include <sgl_kernel/vec.cuh> // AlignedVector
|
| 199 |
+
|
| 200 |
+
#include <dlpack/dlpack.h>
|
| 201 |
+
#include <tvm/ffi/container/tensor.h>
|
| 202 |
+
|
| 203 |
+
namespace {
|
| 204 |
+
|
| 205 |
+
// ----------------------------------------------------------------
|
| 206 |
+
// Kernel: element-wise scale using vectorized 128-bit loads/stores
|
| 207 |
+
// T = fp16_t | bf16_t | fp32_t
|
| 208 |
+
// kVecN = number of elements per vector load (e.g. 8 for fp16)
|
| 209 |
+
// kFactor = scale factor encoded as kFactorNumer / kFactorDenom
|
| 210 |
+
// ----------------------------------------------------------------
|
| 211 |
+
template <typename T, int kVecN, int32_t kFactorNumer, int32_t kFactorDenom>
|
| 212 |
+
__global__ void scale_kernel(T* __restrict__ dst,
|
| 213 |
+
const T* __restrict__ src,
|
| 214 |
+
uint32_t n_vecs,
|
| 215 |
+
uint32_t n_remainder,
|
| 216 |
+
uint32_t n_total) {
|
| 217 |
+
constexpr float kFactor = static_cast<float>(kFactorNumer)
|
| 218 |
+
/ static_cast<float>(kFactorDenom);
|
| 219 |
+
|
| 220 |
+
using vec_t = device::AlignedVector<T, kVecN>;
|
| 221 |
+
|
| 222 |
+
// --- vectorised body ---
|
| 223 |
+
const uint32_t vec_stride = blockDim.x * gridDim.x;
|
| 224 |
+
for (uint32_t vi = blockIdx.x * blockDim.x + threadIdx.x;
|
| 225 |
+
vi < n_vecs;
|
| 226 |
+
vi += vec_stride) {
|
| 227 |
+
vec_t v;
|
| 228 |
+
v.load(src, vi);
|
| 229 |
+
#pragma unroll
|
| 230 |
+
for (int i = 0; i < kVecN; ++i) {
|
| 231 |
+
v[i] = static_cast<T>(static_cast<float>(v[i]) * kFactor);
|
| 232 |
+
}
|
| 233 |
+
v.store(dst, vi);
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
// --- scalar tail ---
|
| 237 |
+
const uint32_t base = n_vecs * kVecN;
|
| 238 |
+
const uint32_t scalar_stride = blockDim.x * gridDim.x;
|
| 239 |
+
for (uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;
|
| 240 |
+
i < n_remainder;
|
| 241 |
+
i += scalar_stride) {
|
| 242 |
+
dst[base + i] = static_cast<T>(static_cast<float>(src[base + i]) * kFactor);
|
| 243 |
+
}
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
// ----------------------------------------------------------------
|
| 247 |
+
// Launcher: validates tensors, selects vector width, launches kernel
|
| 248 |
+
// ----------------------------------------------------------------
|
| 249 |
+
template <typename T, int32_t kFactorNumer, int32_t kFactorDenom>
|
| 250 |
+
void scale(tvm::ffi::TensorView dst, tvm::ffi::TensorView src) {
|
| 251 |
+
using namespace host;
|
| 252 |
+
|
| 253 |
+
// 1. Validate input tensors with TensorMatcher
|
| 254 |
+
SymbolicSize N = {"num_elements"};
|
| 255 |
+
SymbolicDevice device_;
|
| 256 |
+
device_.set_options<kDLCUDA>();
|
| 257 |
+
|
| 258 |
+
TensorMatcher({N}) //
|
| 259 |
+
.with_dtype<T>()
|
| 260 |
+
.with_device(device_)
|
| 261 |
+
.verify(dst)
|
| 262 |
+
.verify(src); // same shape / dtype / device as dst
|
| 263 |
+
|
| 264 |
+
const uint32_t n = static_cast<uint32_t>(N.unwrap());
|
| 265 |
+
const DLDevice device = device_.unwrap();
|
| 266 |
+
|
| 267 |
+
RuntimeCheck(n > 0, "scale: num_elements must be > 0, got ", n);
|
| 268 |
+
|
| 269 |
+
// 2. Choose vector width for 128-bit loads (16 bytes)
|
| 270 |
+
// fp16/bf16: 8 elements × 2 bytes = 16 bytes
|
| 271 |
+
// fp32: 4 elements × 4 bytes = 16 bytes
|
| 272 |
+
constexpr int kVecN = 16 / sizeof(T);
|
| 273 |
+
const uint32_t n_vecs = n / kVecN;
|
| 274 |
+
const uint32_t n_remainder = n % kVecN;
|
| 275 |
+
|
| 276 |
+
// 3. Launch
|
| 277 |
+
constexpr uint32_t kBlockSize = 256;
|
| 278 |
+
const uint32_t grid = div_ceil(std::max(n_vecs, n_remainder), kBlockSize);
|
| 279 |
+
|
| 280 |
+
LaunchKernel(grid, kBlockSize, device)(
|
| 281 |
+
scale_kernel<T, kVecN, kFactorNumer, kFactorDenom>,
|
| 282 |
+
static_cast<T*>(dst.data_ptr()),
|
| 283 |
+
static_cast<const T*>(src.data_ptr()),
|
| 284 |
+
n_vecs,
|
| 285 |
+
n_remainder,
|
| 286 |
+
n);
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
} // namespace
|
| 290 |
+
```
|
| 291 |
+
|
| 292 |
+
**Key points:**
|
| 293 |
+
|
| 294 |
+
- Include headers from `sgl_kernel/` — **not** raw CUDA headers for anything already covered
|
| 295 |
+
- Use `TensorMatcher` for all tensor validation; never manually check shape/dtype/device
|
| 296 |
+
- Use `AlignedVector` for vectorised 128-bit loads/stores — significant bandwidth win
|
| 297 |
+
- Use `LaunchKernel` — it resolves the stream and checks errors automatically
|
| 298 |
+
- Use `RuntimeCheck` for runtime assertions with useful error messages
|
| 299 |
+
- `fp16_t` / `bf16_t` / `fp32_t` are the project's type aliases (from `utils.cuh`)
|
| 300 |
+
- `device::cast<To, From>` or `dtype_trait<T>::from(val)` for cross-type conversions
|
| 301 |
+
- `device::math::` functions for device math instead of bare `__` intrinsics
|
| 302 |
+
|
| 303 |
+
---
|
| 304 |
+
|
| 305 |
+
## Step 2: Add the Python wrapper in `jit_kernel/`
|
| 306 |
+
|
| 307 |
+
Create `python/sglang/jit_kernel/scale.py`:
|
| 308 |
+
|
| 309 |
+
```python
|
| 310 |
+
from __future__ import annotations
|
| 311 |
+
|
| 312 |
+
from typing import TYPE_CHECKING
|
| 313 |
+
|
| 314 |
+
import torch
|
| 315 |
+
|
| 316 |
+
from sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args
|
| 317 |
+
|
| 318 |
+
if TYPE_CHECKING:
|
| 319 |
+
from tvm_ffi.module import Module
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
@cache_once
|
| 323 |
+
def _jit_scale_module(dtype: torch.dtype, factor_numer: int, factor_denom: int) -> Module:
|
| 324 |
+
"""Compile and cache the JIT scale module for a given dtype and factor."""
|
| 325 |
+
args = make_cpp_args(dtype, factor_numer, factor_denom)
|
| 326 |
+
return load_jit(
|
| 327 |
+
"scale",
|
| 328 |
+
*args,
|
| 329 |
+
cuda_files=["elementwise/scale.cuh"],
|
| 330 |
+
cuda_wrappers=[("scale", f"scale<{args}>")],
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def scale(src: torch.Tensor, factor: float, out: torch.Tensor | None = None) -> torch.Tensor:
|
| 335 |
+
"""
|
| 336 |
+
Element-wise scale: dst = src * factor.
|
| 337 |
+
|
| 338 |
+
Supported dtypes: torch.float16, torch.bfloat16, torch.float32.
|
| 339 |
+
|
| 340 |
+
Parameters
|
| 341 |
+
----------
|
| 342 |
+
src : CUDA tensor (FP16 / BF16 / FP32)
|
| 343 |
+
factor : scale factor
|
| 344 |
+
out : optional pre-allocated output tensor (same shape/dtype as src)
|
| 345 |
+
|
| 346 |
+
Returns
|
| 347 |
+
-------
|
| 348 |
+
Scaled tensor (dst = src * factor).
|
| 349 |
+
"""
|
| 350 |
+
assert src.is_cuda, "src must be a CUDA tensor"
|
| 351 |
+
assert src.dtype in (torch.float16, torch.bfloat16, torch.float32), (
|
| 352 |
+
f"Unsupported dtype {src.dtype}. Supported: float16, bfloat16, float32"
|
| 353 |
+
)
|
| 354 |
+
if out is None:
|
| 355 |
+
out = torch.empty_like(src)
|
| 356 |
+
else:
|
| 357 |
+
assert out.shape == src.shape, "out shape must match src"
|
| 358 |
+
assert out.dtype == src.dtype, "out dtype must match src"
|
| 359 |
+
|
| 360 |
+
# Encode factor as integer ratio; denom=1000 gives 3 decimal places of precision
|
| 361 |
+
factor_denom = 1000
|
| 362 |
+
factor_numer = round(factor * factor_denom)
|
| 363 |
+
|
| 364 |
+
module = _jit_scale_module(src.dtype, factor_numer, factor_denom)
|
| 365 |
+
module.scale(out, src)
|
| 366 |
+
return out
|
| 367 |
+
```
|
| 368 |
+
|
| 369 |
+
**Key points:**
|
| 370 |
+
|
| 371 |
+
- Use `cache_once` — **not** `functools.lru_cache` (incompatible with `torch.compile`)
|
| 372 |
+
- `load_jit` first arg(s) form the unique build marker; same marker = same cached binary
|
| 373 |
+
- `cuda_wrappers`: `(export_name, kernel_symbol)` — `export_name` is called from Python
|
| 374 |
+
- `make_cpp_args(dtype, ...)` converts `torch.dtype` to C++ type alias:
|
| 375 |
+
|
| 376 |
+
| `torch.dtype` | C++ type |
|
| 377 |
+
|--------------------|------------|
|
| 378 |
+
| `torch.float16` | `fp16_t` |
|
| 379 |
+
| `torch.bfloat16` | `bf16_t` |
|
| 380 |
+
| `torch.float32` | `fp32_t` |
|
| 381 |
+
|
| 382 |
+
---
|
| 383 |
+
|
| 384 |
+
## Step 3 (optional): Tune JIT build flags
|
| 385 |
+
|
| 386 |
+
```python
|
| 387 |
+
return load_jit(
|
| 388 |
+
"scale",
|
| 389 |
+
*args,
|
| 390 |
+
cuda_files=["elementwise/scale.cuh"],
|
| 391 |
+
cuda_wrappers=[("scale", f"scale<{args}>")],
|
| 392 |
+
extra_cuda_cflags=["-O3", "--use_fast_math"],
|
| 393 |
+
)
|
| 394 |
+
```
|
| 395 |
+
|
| 396 |
+
If your kernel requires SM90+, raise a clear Python error before calling `load_jit`:
|
| 397 |
+
|
| 398 |
+
```python
|
| 399 |
+
if torch.cuda.get_device_capability()[0] < 9:
|
| 400 |
+
raise RuntimeError("This kernel requires SM90 (Hopper) or later")
|
| 401 |
+
```
|
| 402 |
+
|
| 403 |
+
---
|
| 404 |
+
|
| 405 |
+
## Step 4: Write tests (required)
|
| 406 |
+
|
| 407 |
+
Create `python/sglang/jit_kernel/tests/test_scale.py`:
|
| 408 |
+
|
| 409 |
+
```python
|
| 410 |
+
import pytest
|
| 411 |
+
import torch
|
| 412 |
+
from sglang.jit_kernel.scale import scale
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
|
| 416 |
+
@pytest.mark.parametrize("size", [1, 127, 128, 1024, 4097]) # cover tail remainder
|
| 417 |
+
@pytest.mark.parametrize("factor", [0.5, 1.0, 2.0, 3.0])
|
| 418 |
+
def test_scale_correctness(dtype, size, factor):
|
| 419 |
+
src = torch.randn(size, dtype=dtype, device="cuda")
|
| 420 |
+
out = scale(src, factor)
|
| 421 |
+
expected = src * factor
|
| 422 |
+
|
| 423 |
+
rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-2, 1e-2)
|
| 424 |
+
torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
|
| 428 |
+
def test_scale_out_param(dtype):
|
| 429 |
+
src = torch.randn(1024, dtype=dtype, device="cuda")
|
| 430 |
+
out = torch.empty_like(src)
|
| 431 |
+
result = scale(src, 2.0, out=out)
|
| 432 |
+
assert result is out
|
| 433 |
+
torch.testing.assert_close(out, src * 2.0, rtol=1e-2, atol=1e-2)
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
def test_scale_cpu_error():
|
| 437 |
+
src = torch.randn(128, dtype=torch.float16) # CPU tensor
|
| 438 |
+
with pytest.raises(AssertionError, match="CUDA"):
|
| 439 |
+
scale(src, 2.0)
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
def test_scale_unsupported_dtype():
|
| 443 |
+
src = torch.randint(0, 10, (128,), dtype=torch.int32, device="cuda")
|
| 444 |
+
with pytest.raises(AssertionError, match="Unsupported dtype"):
|
| 445 |
+
scale(src, 2.0)
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
if __name__ == "__main__":
|
| 449 |
+
pytest.main([__file__, "-v", "-s"])
|
| 450 |
+
```
|
| 451 |
+
|
| 452 |
+
---
|
| 453 |
+
|
| 454 |
+
## Step 5: Add a benchmark (required)
|
| 455 |
+
|
| 456 |
+
Create `python/sglang/jit_kernel/benchmark/bench_scale.py`:
|
| 457 |
+
|
| 458 |
+
```python
|
| 459 |
+
import itertools
|
| 460 |
+
|
| 461 |
+
import torch
|
| 462 |
+
import triton
|
| 463 |
+
import triton.testing
|
| 464 |
+
|
| 465 |
+
from sglang.jit_kernel.benchmark.utils import (
|
| 466 |
+
DEFAULT_DEVICE,
|
| 467 |
+
DEFAULT_DTYPE,
|
| 468 |
+
get_benchmark_range,
|
| 469 |
+
run_benchmark,
|
| 470 |
+
)
|
| 471 |
+
from sglang.jit_kernel.scale import scale as jit_scale
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
SIZE_LIST = get_benchmark_range(
|
| 475 |
+
full_range=[2**n for n in range(10, 20)], # 1K … 512K elements
|
| 476 |
+
ci_range=[4096, 65536],
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
configs = list(itertools.product(SIZE_LIST))
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
@triton.testing.perf_report(
|
| 483 |
+
triton.testing.Benchmark(
|
| 484 |
+
x_names=["size"],
|
| 485 |
+
x_vals=configs,
|
| 486 |
+
line_arg="provider",
|
| 487 |
+
line_vals=["jit", "torch"],
|
| 488 |
+
line_names=["SGL JIT Kernel", "PyTorch"],
|
| 489 |
+
styles=[("blue", "-"), ("red", "--")],
|
| 490 |
+
ylabel="us",
|
| 491 |
+
plot_name="scale-performance",
|
| 492 |
+
args={},
|
| 493 |
+
)
|
| 494 |
+
)
|
| 495 |
+
def benchmark(size: int, provider: str):
|
| 496 |
+
src = torch.randn(size, dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE)
|
| 497 |
+
factor = 2.0
|
| 498 |
+
|
| 499 |
+
if provider == "jit":
|
| 500 |
+
fn = lambda: jit_scale(src, factor)
|
| 501 |
+
else:
|
| 502 |
+
fn = lambda: src * factor
|
| 503 |
+
|
| 504 |
+
return run_benchmark(fn)
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
if __name__ == "__main__":
|
| 508 |
+
benchmark.run(print_data=True)
|
| 509 |
+
```
|
| 510 |
+
|
| 511 |
+
Run:
|
| 512 |
+
|
| 513 |
+
```bash
|
| 514 |
+
python python/sglang/jit_kernel/benchmark/bench_scale.py
|
| 515 |
+
```
|
| 516 |
+
|
| 517 |
+
---
|
| 518 |
+
|
| 519 |
+
## Troubleshooting
|
| 520 |
+
|
| 521 |
+
- **JIT compilation fails**: ensure the `.cuh` file is under `python/sglang/jit_kernel/csrc/`; reduce template argument combinations
|
| 522 |
+
- **CUDA crash / illegal memory access**: `CUDA_LAUNCH_BLOCKING=1`; `compute-sanitizer --tool memcheck python ...`
|
| 523 |
+
- **Unstable benchmark results**: `run_benchmark` uses CUDA-graph-based timing by default
|
| 524 |
+
|
| 525 |
+
---
|
| 526 |
+
|
| 527 |
+
## References
|
| 528 |
+
|
| 529 |
+
- `docs/developer_guide/development_jit_kernel_guide.md`
|
| 530 |
+
- `python/sglang/jit_kernel/utils.py` — `cache_once`, `load_jit`, `make_cpp_args`
|
| 531 |
+
- `python/sglang/jit_kernel/include/sgl_kernel/tensor.h` — `TensorMatcher`, `SymbolicSize/DType/Device`
|
| 532 |
+
- `python/sglang/jit_kernel/include/sgl_kernel/utils.cuh` — type aliases, `LaunchKernel`, `SGL_DEVICE`
|
| 533 |
+
- `python/sglang/jit_kernel/include/sgl_kernel/vec.cuh` — `AlignedVector`
|
| 534 |
+
- `python/sglang/jit_kernel/include/sgl_kernel/tile.cuh` — `tile::Memory`
|
| 535 |
+
- `python/sglang/jit_kernel/include/sgl_kernel/type.cuh` — `dtype_trait`, `packed_t`, `device::cast`
|
| 536 |
+
- `python/sglang/jit_kernel/include/sgl_kernel/math.cuh` — `device::math::`
|
| 537 |
+
- `python/sglang/jit_kernel/include/sgl_kernel/warp.cuh` — `warp::reduce_sum/max`
|
| 538 |
+
- `python/sglang/jit_kernel/include/sgl_kernel/cta.cuh` — `cta::reduce_max`
|
| 539 |
+
- `python/sglang/jit_kernel/include/sgl_kernel/atomic.cuh` — `atomic::max`
|
| 540 |
+
- `python/sglang/jit_kernel/include/sgl_kernel/runtime.cuh` — occupancy / SM count helpers
|
| 541 |
+
- `python/sglang/jit_kernel/csrc/add_constant.cuh` — minimal runnable reference
|
| 542 |
+
- `python/sglang/jit_kernel/csrc/elementwise/rmsnorm.cuh` — real example using `TensorMatcher` + `LaunchKernel` + `tile::Memory`
|
| 543 |
+
- `python/sglang/jit_kernel/csrc/elementwise/qknorm.cuh` — real example using `runtime::get_blocks_per_sm` + persistent kernel pattern
|
| 544 |
+
- `python/sglang/jit_kernel/benchmark/utils.py` — benchmark helpers
|
| 545 |
+
|
| 546 |
+
## Summary of Files Created
|
| 547 |
+
|
| 548 |
+
```
|
| 549 |
+
python/sglang/jit_kernel/csrc/elementwise/scale.cuh # NEW: CUDA kernel
|
| 550 |
+
python/sglang/jit_kernel/scale.py # NEW: Python wrapper
|
| 551 |
+
python/sglang/jit_kernel/tests/test_scale.py # NEW: Tests
|
| 552 |
+
python/sglang/jit_kernel/benchmark/bench_scale.py # NEW: Benchmark
|
| 553 |
+
```
|
sglang/.claude/skills/add-sgl-kernel/SKILL.md
ADDED
|
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
name: add-sgl-kernel
|
| 3 |
+
description: Step-by-step tutorial for adding a heavyweight AOT CUDA/C++ kernel to sgl-kernel (including tests & benchmarks)
|
| 4 |
+
---
|
| 5 |
+
|
| 6 |
+
# Tutorial: Adding a New Kernel to `sgl-kernel` (AOT / Heavyweight)
|
| 7 |
+
|
| 8 |
+
This tutorial walks through adding a simple element-wise scale operation as an AOT kernel. We'll implement `scale(x, factor) = x * factor` to demonstrate the complete workflow.
|
| 9 |
+
|
| 10 |
+
## Goal
|
| 11 |
+
|
| 12 |
+
Add a new operation that scales each element of a tensor by a scalar factor:
|
| 13 |
+
|
| 14 |
+
- Input: tensor `x` (CUDA) and scalar `factor` (float)
|
| 15 |
+
- Output: `x * factor` (element-wise, in-place or into pre-allocated `out`)
|
| 16 |
+
- Supported dtypes: **FP16 (`torch.float16`), BF16 (`torch.bfloat16`), FP32 (`torch.float32`)**
|
| 17 |
+
- Dispatched via `DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16` macro (defined in `sgl-kernel/include/utils.h`)
|
| 18 |
+
|
| 19 |
+
## Two rules of thumb (must follow)
|
| 20 |
+
|
| 21 |
+
1. **Heavyweight kernels go to `sgl-kernel`.** If it depends on CUTLASS / FlashInfer / DeepGEMM (or similarly heavy stacks), implement it in `sgl-kernel/`.
|
| 22 |
+
2. **Lightweight kernels go to `python/sglang/jit_kernel`.** If it is small, has few dependencies, and benefits from rapid iteration, implement it as a JIT kernel instead.
|
| 23 |
+
|
| 24 |
+
In addition, every new kernel must ship with:
|
| 25 |
+
|
| 26 |
+
- **Tests** (pytest)
|
| 27 |
+
- **A benchmark script** (triton.testing)
|
| 28 |
+
|
| 29 |
+
---
|
| 30 |
+
|
| 31 |
+
## Repository integration map
|
| 32 |
+
|
| 33 |
+
You will typically touch these files/areas:
|
| 34 |
+
|
| 35 |
+
- Implementation: `sgl-kernel/csrc/elementwise/scale.cu` (pick the right subdirectory)
|
| 36 |
+
- Public declarations: `sgl-kernel/include/sgl_kernel_ops.h`
|
| 37 |
+
- Torch extension registration: `sgl-kernel/csrc/common_extension.cc`
|
| 38 |
+
- Build: `sgl-kernel/CMakeLists.txt` (`set(SOURCES ...)`)
|
| 39 |
+
- Python API: `sgl-kernel/python/sgl_kernel/` and `sgl-kernel/python/sgl_kernel/__init__.py`
|
| 40 |
+
- Tests: `sgl-kernel/tests/test_scale.py`
|
| 41 |
+
- Benchmarks: `sgl-kernel/benchmark/bench_scale.py`
|
| 42 |
+
|
| 43 |
+
---
|
| 44 |
+
|
| 45 |
+
## Step 1: Implement the kernel in `csrc/`
|
| 46 |
+
|
| 47 |
+
Pick the right subdirectory:
|
| 48 |
+
|
| 49 |
+
- `csrc/elementwise/` — for element-wise ops (our example)
|
| 50 |
+
- `csrc/gemm/`, `csrc/attention/`, `csrc/moe/` — for other categories
|
| 51 |
+
|
| 52 |
+
Create `sgl-kernel/csrc/elementwise/scale.cu`:
|
| 53 |
+
|
| 54 |
+
```cpp
|
| 55 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 56 |
+
#include <c10/cuda/CUDAGuard.h>
|
| 57 |
+
#include <torch/all.h>
|
| 58 |
+
|
| 59 |
+
#include "utils.h" // DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16
|
| 60 |
+
|
| 61 |
+
// scale_kernel: out[i] = input[i] * factor
|
| 62 |
+
// Supports float, half (__half), __nv_bfloat16 via template T
|
| 63 |
+
template <typename T>
|
| 64 |
+
__global__ void scale_kernel(T* __restrict__ out,
|
| 65 |
+
const T* __restrict__ input,
|
| 66 |
+
float factor,
|
| 67 |
+
int64_t n) {
|
| 68 |
+
int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
|
| 69 |
+
if (idx < n) {
|
| 70 |
+
out[idx] = static_cast<T>(static_cast<float>(input[idx]) * factor);
|
| 71 |
+
}
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
void scale(at::Tensor& out, const at::Tensor& input, double factor) {
|
| 75 |
+
TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
|
| 76 |
+
TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
|
| 77 |
+
TORCH_CHECK(out.is_cuda(), "out must be a CUDA tensor");
|
| 78 |
+
TORCH_CHECK(out.is_contiguous(), "out must be contiguous");
|
| 79 |
+
TORCH_CHECK(out.sizes() == input.sizes(), "out and input must have the same shape");
|
| 80 |
+
TORCH_CHECK(out.scalar_type() == input.scalar_type(),
|
| 81 |
+
"out and input must have the same dtype");
|
| 82 |
+
|
| 83 |
+
const int64_t n = input.numel();
|
| 84 |
+
const int threads = 256;
|
| 85 |
+
const int blocks = (n + threads - 1) / threads;
|
| 86 |
+
|
| 87 |
+
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
| 88 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
| 89 |
+
|
| 90 |
+
// Dispatches over float, float16, bfloat16
|
| 91 |
+
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
|
| 92 |
+
scale_kernel<c_type><<<blocks, threads, 0, stream>>>(
|
| 93 |
+
static_cast<c_type*>(out.data_ptr()),
|
| 94 |
+
static_cast<const c_type*>(input.data_ptr()),
|
| 95 |
+
static_cast<float>(factor),
|
| 96 |
+
n);
|
| 97 |
+
cudaError_t status = cudaGetLastError();
|
| 98 |
+
TORCH_CHECK(status == cudaSuccess,
|
| 99 |
+
"scale_kernel launch failed: ", cudaGetErrorString(status));
|
| 100 |
+
return true;
|
| 101 |
+
});
|
| 102 |
+
}
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
**Key points:**
|
| 106 |
+
|
| 107 |
+
- Use `at::Tensor` (PyTorch tensors), `TORCH_CHECK` for validation, `at::cuda::getCurrentCUDAStream()` for stream
|
| 108 |
+
- `DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16` covers `float`, `half` (FP16), `__nv_bfloat16` (BF16)
|
| 109 |
+
- Add device error checking after every kernel launch
|
| 110 |
+
- If a kernel only works on certain architectures, enforce that with `TORCH_CHECK` and skip logic in tests
|
| 111 |
+
|
| 112 |
+
---
|
| 113 |
+
|
| 114 |
+
## Step 2: Add a C++ declaration in `include/sgl_kernel_ops.h`
|
| 115 |
+
|
| 116 |
+
Edit `sgl-kernel/include/sgl_kernel_ops.h`, add to the elementwise section:
|
| 117 |
+
|
| 118 |
+
```cpp
|
| 119 |
+
void scale(at::Tensor& out, const at::Tensor& input, double factor);
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
---
|
| 123 |
+
|
| 124 |
+
## Step 3: Register the op in `csrc/common_extension.cc`
|
| 125 |
+
|
| 126 |
+
Edit `sgl-kernel/csrc/common_extension.cc`, inside `TORCH_LIBRARY_FRAGMENT(sgl_kernel, m)`:
|
| 127 |
+
|
| 128 |
+
```cpp
|
| 129 |
+
// From csrc/elementwise
|
| 130 |
+
m.def("scale(Tensor! out, Tensor input, float factor) -> ()");
|
| 131 |
+
m.impl("scale", torch::kCUDA, &scale);
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
**Key points:**
|
| 135 |
+
|
| 136 |
+
- `Tensor!` means in-place / mutable output argument
|
| 137 |
+
- The schema is important for `torch.compile` and for consistent call signatures
|
| 138 |
+
- If your underlying C++ API uses `float` but PyTorch bindings expect `double`, the implicit cast is fine for scalars; use shims if needed for other types
|
| 139 |
+
|
| 140 |
+
---
|
| 141 |
+
|
| 142 |
+
## Step 4: Add the new source file to `CMakeLists.txt`
|
| 143 |
+
|
| 144 |
+
Edit `sgl-kernel/CMakeLists.txt`, add to `set(SOURCES ...)`:
|
| 145 |
+
|
| 146 |
+
```cmake
|
| 147 |
+
csrc/elementwise/scale.cu
|
| 148 |
+
```
|
| 149 |
+
|
| 150 |
+
**Key points:**
|
| 151 |
+
|
| 152 |
+
- Keep the list **alphabetically sorted** (the file explicitly requires this)
|
| 153 |
+
- If the kernel has arch constraints, reflect that in tests/benchmarks via skip logic
|
| 154 |
+
|
| 155 |
+
---
|
| 156 |
+
|
| 157 |
+
## Step 5: Expose a Python API under `sgl-kernel/python/sgl_kernel/`
|
| 158 |
+
|
| 159 |
+
In `sgl-kernel/python/sgl_kernel/__init__.py`, add:
|
| 160 |
+
|
| 161 |
+
```python
|
| 162 |
+
from torch.ops import sgl_kernel as _ops
|
| 163 |
+
|
| 164 |
+
def scale(out: torch.Tensor, input: torch.Tensor, factor: float) -> None:
|
| 165 |
+
"""
|
| 166 |
+
Element-wise scale: out = input * factor (in-place into out).
|
| 167 |
+
|
| 168 |
+
Supported dtypes: torch.float16, torch.bfloat16, torch.float32.
|
| 169 |
+
|
| 170 |
+
Parameters
|
| 171 |
+
----------
|
| 172 |
+
out : pre-allocated CUDA output tensor (same shape/dtype as input)
|
| 173 |
+
input : CUDA input tensor
|
| 174 |
+
factor : scale factor (float)
|
| 175 |
+
"""
|
| 176 |
+
_ops.scale(out, input, factor)
|
| 177 |
+
```
|
| 178 |
+
|
| 179 |
+
Or export it from the existing module organisation — follow the pattern already used by similar ops in `__init__.py`.
|
| 180 |
+
|
| 181 |
+
---
|
| 182 |
+
|
| 183 |
+
## Step 6: Write tests (required)
|
| 184 |
+
|
| 185 |
+
Create `sgl-kernel/tests/test_scale.py`:
|
| 186 |
+
|
| 187 |
+
```python
|
| 188 |
+
import pytest
|
| 189 |
+
import torch
|
| 190 |
+
import sgl_kernel
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
|
| 194 |
+
@pytest.mark.parametrize("size", [128, 1024, 4096, 65536])
|
| 195 |
+
@pytest.mark.parametrize("factor", [0.5, 1.0, 2.0])
|
| 196 |
+
def test_scale_correctness(dtype, size, factor):
|
| 197 |
+
input = torch.randn(size, dtype=dtype, device="cuda")
|
| 198 |
+
out = torch.empty_like(input)
|
| 199 |
+
|
| 200 |
+
sgl_kernel.scale(out, input, factor)
|
| 201 |
+
|
| 202 |
+
expected = input * factor
|
| 203 |
+
rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-2, 1e-2)
|
| 204 |
+
torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def test_scale_shape_mismatch():
|
| 208 |
+
input = torch.randn(128, dtype=torch.float16, device="cuda")
|
| 209 |
+
out = torch.empty(256, dtype=torch.float16, device="cuda")
|
| 210 |
+
with pytest.raises(RuntimeError, match="same shape"):
|
| 211 |
+
sgl_kernel.scale(out, input, 2.0)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def test_scale_cpu_input():
|
| 215 |
+
input = torch.randn(128, dtype=torch.float16) # CPU
|
| 216 |
+
out = torch.empty_like(input)
|
| 217 |
+
with pytest.raises(RuntimeError, match="CUDA"):
|
| 218 |
+
sgl_kernel.scale(out, input, 2.0)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
if __name__ == "__main__":
|
| 222 |
+
pytest.main([__file__, "-q"])
|
| 223 |
+
```
|
| 224 |
+
|
| 225 |
+
Run:
|
| 226 |
+
|
| 227 |
+
```bash
|
| 228 |
+
pytest sgl-kernel/tests/test_scale.py -q
|
| 229 |
+
```
|
| 230 |
+
|
| 231 |
+
---
|
| 232 |
+
|
| 233 |
+
## Step 7: Add a benchmark (required)
|
| 234 |
+
|
| 235 |
+
Create `sgl-kernel/benchmark/bench_scale.py`:
|
| 236 |
+
|
| 237 |
+
```python
|
| 238 |
+
import itertools
|
| 239 |
+
import os
|
| 240 |
+
|
| 241 |
+
import torch
|
| 242 |
+
import triton
|
| 243 |
+
import triton.testing
|
| 244 |
+
|
| 245 |
+
import sgl_kernel
|
| 246 |
+
|
| 247 |
+
IS_CI = (
|
| 248 |
+
os.getenv("CI", "false").lower() == "true"
|
| 249 |
+
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
dtypes = [torch.float16] if IS_CI else [torch.float16, torch.bfloat16, torch.float32]
|
| 253 |
+
sizes = [4096] if IS_CI else [2**n for n in range(10, 20)] # 1K … 512K
|
| 254 |
+
factors = [2.0]
|
| 255 |
+
|
| 256 |
+
configs = list(itertools.product(dtypes, sizes))
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def torch_scale(input: torch.Tensor, factor: float) -> torch.Tensor:
|
| 260 |
+
return input * factor
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
@triton.testing.perf_report(
|
| 264 |
+
triton.testing.Benchmark(
|
| 265 |
+
x_names=["dtype", "size"],
|
| 266 |
+
x_vals=configs,
|
| 267 |
+
line_arg="provider",
|
| 268 |
+
line_vals=["sglang", "torch"],
|
| 269 |
+
line_names=["SGL Kernel", "PyTorch"],
|
| 270 |
+
styles=[("green", "-"), ("red", "--")],
|
| 271 |
+
ylabel="µs (median)",
|
| 272 |
+
plot_name="scale-performance",
|
| 273 |
+
args={},
|
| 274 |
+
)
|
| 275 |
+
)
|
| 276 |
+
def benchmark(dtype, size, provider):
|
| 277 |
+
input = torch.randn(size, dtype=dtype, device="cuda")
|
| 278 |
+
out = torch.empty_like(input)
|
| 279 |
+
factor = 2.0
|
| 280 |
+
|
| 281 |
+
if provider == "sglang":
|
| 282 |
+
fn = lambda: sgl_kernel.scale(out, input, factor)
|
| 283 |
+
else:
|
| 284 |
+
fn = lambda: torch_scale(input, factor)
|
| 285 |
+
|
| 286 |
+
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
| 287 |
+
fn, quantiles=[0.5, 0.2, 0.8]
|
| 288 |
+
)
|
| 289 |
+
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
if __name__ == "__main__":
|
| 293 |
+
benchmark.run(print_data=True)
|
| 294 |
+
```
|
| 295 |
+
|
| 296 |
+
Run:
|
| 297 |
+
|
| 298 |
+
```bash
|
| 299 |
+
python sgl-kernel/benchmark/bench_scale.py
|
| 300 |
+
```
|
| 301 |
+
|
| 302 |
+
---
|
| 303 |
+
|
| 304 |
+
## Step 8: Build and validate
|
| 305 |
+
|
| 306 |
+
Build:
|
| 307 |
+
|
| 308 |
+
```bash
|
| 309 |
+
cd sgl-kernel
|
| 310 |
+
make build -j16
|
| 311 |
+
```
|
| 312 |
+
|
| 313 |
+
If you need to limit host resource usage:
|
| 314 |
+
|
| 315 |
+
```bash
|
| 316 |
+
cd sgl-kernel
|
| 317 |
+
make build -j1 MAX_JOBS=2 CMAKE_ARGS="-DSGL_KERNEL_COMPILE_THREADS=1"
|
| 318 |
+
```
|
| 319 |
+
|
| 320 |
+
Validate:
|
| 321 |
+
|
| 322 |
+
```bash
|
| 323 |
+
pytest sgl-kernel/tests/test_scale.py -q
|
| 324 |
+
python sgl-kernel/benchmark/bench_scale.py
|
| 325 |
+
```
|
| 326 |
+
|
| 327 |
+
---
|
| 328 |
+
|
| 329 |
+
## Troubleshooting
|
| 330 |
+
|
| 331 |
+
- **Async CUDA errors**: `CUDA_LAUNCH_BLOCKING=1`
|
| 332 |
+
- **Memory errors**: `compute-sanitizer --tool memcheck python ...`
|
| 333 |
+
- **Build is too slow / OOM**: reduce `MAX_JOBS` and `SGL_KERNEL_COMPILE_THREADS`
|
| 334 |
+
- **Binary bloat**: use `sgl-kernel/analyze_whl_kernel_sizes.py`
|
| 335 |
+
- **CMake sources list**: if your `.cu` file is missing from `SOURCES`, the symbol will be undefined at link time
|
| 336 |
+
|
| 337 |
+
---
|
| 338 |
+
|
| 339 |
+
## References
|
| 340 |
+
|
| 341 |
+
- `sgl-kernel/README.md`
|
| 342 |
+
- `sgl-kernel/include/sgl_kernel_ops.h`
|
| 343 |
+
- `sgl-kernel/csrc/common_extension.cc`
|
| 344 |
+
- `sgl-kernel/CMakeLists.txt`
|
| 345 |
+
- `sgl-kernel/include/utils.h` — `DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16` macro and friends
|
| 346 |
+
- `sgl-kernel/csrc/elementwise/activation.cu` — reference for the FP16/BF16/FP32 dispatch pattern
|
| 347 |
+
|
| 348 |
+
## Summary of Files Created/Modified
|
| 349 |
+
|
| 350 |
+
```
|
| 351 |
+
sgl-kernel/csrc/elementwise/scale.cu # NEW: CUDA kernel + launcher
|
| 352 |
+
sgl-kernel/include/sgl_kernel_ops.h # MODIFIED: C++ declaration
|
| 353 |
+
sgl-kernel/csrc/common_extension.cc # MODIFIED: schema + dispatch registration
|
| 354 |
+
sgl-kernel/CMakeLists.txt # MODIFIED: add source file (alphabetical)
|
| 355 |
+
sgl-kernel/python/sgl_kernel/__init__.py # MODIFIED: export Python API
|
| 356 |
+
sgl-kernel/tests/test_scale.py # NEW: tests
|
| 357 |
+
sgl-kernel/benchmark/bench_scale.py # NEW: benchmark
|
| 358 |
+
```
|
sglang/.claude/skills/sglang-bisect-ci-regression/SKILL.md
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SGLang Bisect CI Regression
|
| 2 |
+
|
| 3 |
+
Investigate a consistently failing CI test to find the root cause - whether it's a code regression from a specific PR, a hardware/runner-specific issue, or an environment change. Optionally reproduce the failure on a remote GPU server.
|
| 4 |
+
|
| 5 |
+
## Slash Command
|
| 6 |
+
|
| 7 |
+
`/sglang-bisect-ci-regression <test_name_or_ci_url> [ssh_target] [docker_container]`
|
| 8 |
+
|
| 9 |
+
## When to Use This Skill
|
| 10 |
+
|
| 11 |
+
- A CI test is failing consistently on main (scheduled runs)
|
| 12 |
+
- You need to find which PR introduced a regression
|
| 13 |
+
- You suspect a runner-specific or GPU-specific issue
|
| 14 |
+
- You want to reproduce a CI failure on a remote server
|
| 15 |
+
|
| 16 |
+
## Arguments
|
| 17 |
+
|
| 18 |
+
- **First argument (required)**: Test file name (e.g. `test_lora_tp.py`) or a GitHub Actions job URL
|
| 19 |
+
- **Second argument (optional)**: SSH target for remote reproduction (e.g. `user@host`)
|
| 20 |
+
- **Third argument (optional)**: Docker container name on the SSH target (e.g. `sglang_dev`)
|
| 21 |
+
|
| 22 |
+
If SSH target and docker container are not provided, the skill will only perform the CI log analysis and bisection, without remote reproduction. **Ask the user** for these if reproduction is needed and they weren't provided.
|
| 23 |
+
|
| 24 |
+
## Background: Scheduled CI Runs
|
| 25 |
+
|
| 26 |
+
SGLang uses the `pr-test.yml` workflow with **scheduled runs** (cron-triggered) to periodically test the `main` branch. These runs are the primary data source for detecting regressions:
|
| 27 |
+
|
| 28 |
+
- **Workflow**: `pr-test.yml` with `event: schedule`
|
| 29 |
+
- **Branch**: `main`
|
| 30 |
+
- **Dashboard**: https://github.com/sgl-project/sglang/actions/workflows/pr-test.yml?query=event%3Aschedule
|
| 31 |
+
- **Frequency**: Runs multiple times daily, each pinned to the HEAD of `main` at trigger time
|
| 32 |
+
- **Purpose**: Catches regressions that slip through PR-level CI (e.g., interaction bugs between merged PRs, hardware-specific issues)
|
| 33 |
+
|
| 34 |
+
Always use these scheduled runs (not PR-triggered runs) when bisecting regressions on `main`. The `--event schedule` filter in `gh run list` ensures you only see these periodic main-branch runs.
|
| 35 |
+
|
| 36 |
+
## Workflow
|
| 37 |
+
|
| 38 |
+
### Phase 1: Extract the Failure Signature
|
| 39 |
+
|
| 40 |
+
1. **Get the failing test details from CI logs.** If given a URL, fetch logs directly. If given a test name, find recent scheduled runs of `pr-test.yml` on `main` that failed:
|
| 41 |
+
|
| 42 |
+
```bash
|
| 43 |
+
# List recent scheduled runs targeting main (the primary source of truth for regressions)
|
| 44 |
+
# These are cron-triggered runs visible at:
|
| 45 |
+
# https://github.com/sgl-project/sglang/actions/workflows/pr-test.yml?query=event%3Aschedule
|
| 46 |
+
gh run list --repo sgl-project/sglang --workflow="pr-test.yml" --event schedule --branch main --limit 20 --json databaseId,conclusion,createdAt,headSha
|
| 47 |
+
|
| 48 |
+
# Find the job containing the test
|
| 49 |
+
gh run view {RUN_ID} --repo sgl-project/sglang --json jobs --jq '.jobs[] | select(.conclusion == "failure") | {name, conclusion, databaseId}'
|
| 50 |
+
|
| 51 |
+
# Get the failure details
|
| 52 |
+
gh run view {RUN_ID} --repo sgl-project/sglang --job {JOB_ID} --log 2>&1 | grep -E -B 5 -A 30 "AssertionError|FAIL|Error|{TEST_NAME}"
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
2. **Record the failure signature:**
|
| 56 |
+
- Exact error message and assertion
|
| 57 |
+
- Affected test method name
|
| 58 |
+
- Model/config involved
|
| 59 |
+
- Numeric values (e.g., tolerance diffs, scores)
|
| 60 |
+
- Whether the failure is deterministic (same values across runs)
|
| 61 |
+
|
| 62 |
+
### Phase 2: Temporal Bisection
|
| 63 |
+
|
| 64 |
+
3. **Find the boundary between passing and failing runs.** Walk through the scheduled run history (from the `pr-test.yml` schedule runs on `main`) to identify:
|
| 65 |
+
- Last known PASSING run (sha + date)
|
| 66 |
+
- First known FAILING run (sha + date)
|
| 67 |
+
|
| 68 |
+
```bash
|
| 69 |
+
# For each scheduled run, check the specific partition/job status
|
| 70 |
+
gh run view {RUN_ID} --repo sgl-project/sglang --json jobs --jq '.jobs[] | select(.name == "{JOB_NAME}") | {conclusion, databaseId}'
|
| 71 |
+
|
| 72 |
+
# Verify a specific test passed or failed in a run
|
| 73 |
+
gh run view {RUN_ID} --repo sgl-project/sglang --job {JOB_ID} --log 2>&1 | grep -E "{TEST_NAME}|PASSED|FAILED|logprobs mismatch" | head -10
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
4. **List commits between the boundary:**
|
| 77 |
+
|
| 78 |
+
```bash
|
| 79 |
+
git log --oneline {LAST_PASS_SHA}..{FIRST_FAIL_SHA}
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
5. **Filter for relevant commits** that touch files related to the failing test (model layers, kernels, test utilities, etc.):
|
| 83 |
+
|
| 84 |
+
```bash
|
| 85 |
+
git log --oneline {LAST_PASS_SHA}..{FIRST_FAIL_SHA} -- {relevant_paths}
|
| 86 |
+
```
|
| 87 |
+
|
| 88 |
+
### Phase 3: Runner/Hardware Analysis
|
| 89 |
+
|
| 90 |
+
6. **Check if the failure is runner-specific.** Extract the runner identity from each failing and passing run:
|
| 91 |
+
|
| 92 |
+
```bash
|
| 93 |
+
# Get runner name and machine
|
| 94 |
+
gh run view {RUN_ID} --repo sgl-project/sglang --job {JOB_ID} --log 2>&1 | grep -E "Runner name|Machine name" | head -5
|
| 95 |
+
|
| 96 |
+
# Get GPU/driver info
|
| 97 |
+
gh run view {RUN_ID} --repo sgl-project/sglang --job {JOB_ID} --log 2>&1 | grep -i -E "NVIDIA-SMI|Driver Version|CUDA Version" | head -5
|
| 98 |
+
|
| 99 |
+
# Get package versions
|
| 100 |
+
gh run view {RUN_ID} --repo sgl-project/sglang --job {JOB_ID} --log 2>&1 | grep -E "sgl.kernel.*==|flashinfer.*==" | head -5
|
| 101 |
+
```
|
| 102 |
+
|
| 103 |
+
7. **Correlate runners with pass/fail outcomes.** Build a table:
|
| 104 |
+
|
| 105 |
+
| Run ID | Date | Runner | GPU Type | Driver | Result |
|
| 106 |
+
|--------|------|--------|----------|--------|--------|
|
| 107 |
+
|
| 108 |
+
If all failures map to a specific runner type/GPU and all passes map to another, the issue is **hardware-specific**, not a code regression.
|
| 109 |
+
|
| 110 |
+
### Phase 4: Code Analysis
|
| 111 |
+
|
| 112 |
+
8. **If a code regression is suspected** (failures not runner-specific), examine the candidate commits:
|
| 113 |
+
- Read the changed files
|
| 114 |
+
- Understand how the changes could affect the failing test
|
| 115 |
+
- Look for prefill-vs-decode differences, TP-specific paths, kernel changes
|
| 116 |
+
|
| 117 |
+
9. **If a hardware issue is suspected**, analyze:
|
| 118 |
+
- Kernel compatibility (CUDA compute capability)
|
| 119 |
+
- Driver version differences
|
| 120 |
+
- All-reduce / NCCL behavior differences
|
| 121 |
+
- CUDA graph capture differences across GPU architectures
|
| 122 |
+
|
| 123 |
+
### Phase 5: Remote Reproduction (Optional)
|
| 124 |
+
|
| 125 |
+
Only if SSH target and docker container were provided.
|
| 126 |
+
|
| 127 |
+
10. **Verify the remote environment:**
|
| 128 |
+
|
| 129 |
+
```bash
|
| 130 |
+
ssh {SSH_TARGET} "docker exec {CONTAINER} nvidia-smi --query-gpu=name,driver_version --format=csv"
|
| 131 |
+
ssh {SSH_TARGET} "docker exec {CONTAINER} pip show sgl-kernel sglang flashinfer-python 2>&1 | grep -E 'Name:|Version:'"
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
11. **Ensure latest code is installed.** If the container is stale, update:
|
| 135 |
+
|
| 136 |
+
```bash
|
| 137 |
+
# Try fetching latest main
|
| 138 |
+
ssh {SSH_TARGET} "docker exec {CONTAINER} bash -c 'cd /path/to/sglang && git fetch origin main && git checkout origin/main'"
|
| 139 |
+
# Or download and install from tarball if git auth fails
|
| 140 |
+
ssh {SSH_TARGET} "docker exec {CONTAINER} bash -c 'cd /tmp && curl -L https://github.com/sgl-project/sglang/archive/refs/heads/main.tar.gz | tar xz && cd sglang-main && pip install -e \"python[all]\"'"
|
| 141 |
+
# Reinstall (after git fetch)
|
| 142 |
+
ssh {SSH_TARGET} "docker exec {CONTAINER} bash -c 'cd /path/to/sglang && pip install -e \"python[all]\"'"
|
| 143 |
+
# Install test dependencies if needed
|
| 144 |
+
ssh {SSH_TARGET} "docker exec {CONTAINER} pip install peft rouge-score"
|
| 145 |
+
```
|
| 146 |
+
|
| 147 |
+
12. **Create a minimal reproduction script** that:
|
| 148 |
+
- Uses `if __name__ == '__main__'` with `mp.set_start_method("spawn")`
|
| 149 |
+
- Runs the specific failing test configuration
|
| 150 |
+
- Prints key metrics (diffs, scores, outputs)
|
| 151 |
+
- Exits with code 1 on failure
|
| 152 |
+
|
| 153 |
+
13. **Copy and run the reproduction script:**
|
| 154 |
+
|
| 155 |
+
```bash
|
| 156 |
+
scp /tmp/repro_script.py {SSH_TARGET}:/tmp/
|
| 157 |
+
ssh {SSH_TARGET} "docker cp /tmp/repro_script.py {CONTAINER}:/tmp/"
|
| 158 |
+
ssh {SSH_TARGET} "docker exec -e CUDA_VISIBLE_DEVICES=0,1 {CONTAINER} python3 /tmp/repro_script.py"
|
| 159 |
+
```
|
| 160 |
+
|
| 161 |
+
14. **Run control experiments** to isolate the variable:
|
| 162 |
+
- If suspecting TP issue: run with TP=1 as control
|
| 163 |
+
- If suspecting GPU issue: compare same code on different GPU
|
| 164 |
+
- If suspecting a specific commit: test before/after that commit
|
| 165 |
+
|
| 166 |
+
### Phase 6: Report
|
| 167 |
+
|
| 168 |
+
15. **Produce a structured report:**
|
| 169 |
+
|
| 170 |
+
```markdown
|
| 171 |
+
## CI Regression Bisection Report
|
| 172 |
+
|
| 173 |
+
### Failure Signature
|
| 174 |
+
- **Test**: {test_file}::{test_method}
|
| 175 |
+
- **Error**: {exact error message}
|
| 176 |
+
- **Key metrics**: {numeric values}
|
| 177 |
+
- **Deterministic**: Yes/No
|
| 178 |
+
|
| 179 |
+
### Root Cause Classification
|
| 180 |
+
One of:
|
| 181 |
+
- **Code Regression**: PR #{number} introduced the bug
|
| 182 |
+
- **Hardware-Specific**: Fails on {GPU_TYPE}, passes on others
|
| 183 |
+
- **Environment Change**: New runner/driver/package version
|
| 184 |
+
- **Pre-existing Flakiness**: Intermittent, not a new regression
|
| 185 |
+
|
| 186 |
+
### Evidence
|
| 187 |
+
| Condition | Result |
|
| 188 |
+
|-----------|--------|
|
| 189 |
+
| {condition1} | PASS/FAIL |
|
| 190 |
+
| {condition2} | PASS/FAIL |
|
| 191 |
+
|
| 192 |
+
### Timeline
|
| 193 |
+
- {date}: Last known pass ({sha}, {runner})
|
| 194 |
+
- {date}: First known fail ({sha}, {runner})
|
| 195 |
+
- {date}: Confirmed reproduction on {server}
|
| 196 |
+
|
| 197 |
+
### Recommended Fix
|
| 198 |
+
- **Short-term**: {workaround}
|
| 199 |
+
- **Long-term**: {proper fix}
|
| 200 |
+
```
|
| 201 |
+
|
| 202 |
+
## Key Patterns to Recognize
|
| 203 |
+
|
| 204 |
+
| Pattern | Diagnosis |
|
| 205 |
+
|---------|-----------|
|
| 206 |
+
| Same SHA passes on runner A, fails on runner B | Hardware/runner-specific |
|
| 207 |
+
| All runners fail after commit X | Code regression from commit X |
|
| 208 |
+
| Intermittent - same runner sometimes passes/fails | Flaky test or race condition |
|
| 209 |
+
| Prefill OK but decode fails | TP/all-reduce issue in decode path |
|
| 210 |
+
| Works with TP=1, fails with TP>1 | Tensor parallelism bug |
|
| 211 |
+
| Exact same numeric diff every time | Deterministic bug, not flakiness |
|
| 212 |
+
|
| 213 |
+
## Important Notes
|
| 214 |
+
|
| 215 |
+
- **Always check runner identity** before concluding it's a code regression. Many "consistent" failures are actually runner-specific.
|
| 216 |
+
- **Test partition assignments change over time** as tests are added/removed. A test may move between partitions, landing on different runner types.
|
| 217 |
+
- **H200 runners** use `/root/actions-runner/` path and machine names like `gpu-h200-worker-*`. Non-H200 runners use `/public_sglang_ci/runner-*` paths.
|
| 218 |
+
- When running remote reproduction, use `run_in_background` for long-running tests and check output with `TaskOutput`.
|
| 219 |
+
- Container environments may be stale - always verify package versions match CI before drawing conclusions.
|
sglang/.claude/skills/write-sglang-test/SKILL.md
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
name: write-sglang-test
|
| 3 |
+
description: Guide for writing SGLang CI/UT tests following project conventions. Covers CustomTestCase, CI registration, server fixtures, model selection, and test placement. Use when creating new tests, adding CI test cases, writing unit tests, or when the user asks to add tests for SGLang features.
|
| 4 |
+
---
|
| 5 |
+
|
| 6 |
+
# Writing SGLang CI / UT Tests
|
| 7 |
+
|
| 8 |
+
## Core Rules
|
| 9 |
+
|
| 10 |
+
1. **Always use `CustomTestCase`** — never raw `unittest.TestCase`
|
| 11 |
+
2. **Place tests in `test/registered/<category>/`** — only use `test/manual/` for debugging / non-CI tests
|
| 12 |
+
3. **Reuse server fixtures** — inherit from `DefaultServerBase` or write `setUpClass`/`tearDownClass` with `popen_launch_server`
|
| 13 |
+
4. **Smallest model for model-agnostic functionality** — use `DEFAULT_SMALL_MODEL_NAME_FOR_TEST` (Llama-3.2-1B-Instruct) for basic features that don't depend on model size
|
| 14 |
+
5. **8B for general performance** — use `DEFAULT_MODEL_NAME_FOR_TEST` (Llama-3.1-8B-Instruct, single-node) for performance tests that don't involve spec / DP / parallelism
|
| 15 |
+
6. **Bigger features → discuss case by case** — spec, DP attention, tensor/pipeline parallelism etc. may need multi-GPU suites and specific models
|
| 16 |
+
|
| 17 |
+
---
|
| 18 |
+
|
| 19 |
+
## Test File Template
|
| 20 |
+
|
| 21 |
+
### Functional correctness test (small model)
|
| 22 |
+
|
| 23 |
+
```python
|
| 24 |
+
import unittest
|
| 25 |
+
|
| 26 |
+
import requests
|
| 27 |
+
|
| 28 |
+
from sglang.srt.utils import kill_process_tree
|
| 29 |
+
from sglang.test.ci.ci_register import register_cuda_ci
|
| 30 |
+
from sglang.test.test_utils import (
|
| 31 |
+
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
| 32 |
+
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
| 33 |
+
DEFAULT_URL_FOR_TEST,
|
| 34 |
+
CustomTestCase,
|
| 35 |
+
popen_launch_server,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
register_cuda_ci(est_time=60, suite="stage-b-test-small-1-gpu")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class TestMyFeature(CustomTestCase):
|
| 42 |
+
@classmethod
|
| 43 |
+
def setUpClass(cls):
|
| 44 |
+
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
| 45 |
+
cls.base_url = DEFAULT_URL_FOR_TEST
|
| 46 |
+
cls.process = popen_launch_server(
|
| 47 |
+
cls.model,
|
| 48 |
+
cls.base_url,
|
| 49 |
+
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
| 50 |
+
other_args=["--arg1", "value1"], # feature-specific args
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
@classmethod
|
| 54 |
+
def tearDownClass(cls):
|
| 55 |
+
kill_process_tree(cls.process.pid)
|
| 56 |
+
|
| 57 |
+
def test_basic_functionality(self):
|
| 58 |
+
response = requests.post(
|
| 59 |
+
self.base_url + "/generate",
|
| 60 |
+
json={"text": "Hello", "sampling_params": {"max_new_tokens": 32}},
|
| 61 |
+
)
|
| 62 |
+
self.assertEqual(response.status_code, 200)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
if __name__ == "__main__":
|
| 66 |
+
unittest.main(verbosity=3)
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
### General performance test (8B model, single node, no spec/DP/parallelism)
|
| 70 |
+
|
| 71 |
+
```python
|
| 72 |
+
import time
|
| 73 |
+
import unittest
|
| 74 |
+
|
| 75 |
+
import requests
|
| 76 |
+
|
| 77 |
+
from sglang.srt.utils import kill_process_tree
|
| 78 |
+
from sglang.test.ci.ci_register import register_cuda_ci
|
| 79 |
+
from sglang.test.test_utils import (
|
| 80 |
+
DEFAULT_MODEL_NAME_FOR_TEST,
|
| 81 |
+
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
| 82 |
+
DEFAULT_URL_FOR_TEST,
|
| 83 |
+
CustomTestCase,
|
| 84 |
+
popen_launch_server,
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
register_cuda_ci(est_time=300, suite="stage-b-test-large-1-gpu")
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class TestMyFeaturePerf(CustomTestCase):
|
| 91 |
+
@classmethod
|
| 92 |
+
def setUpClass(cls):
|
| 93 |
+
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
| 94 |
+
cls.base_url = DEFAULT_URL_FOR_TEST
|
| 95 |
+
cls.process = popen_launch_server(
|
| 96 |
+
cls.model,
|
| 97 |
+
cls.base_url,
|
| 98 |
+
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
@classmethod
|
| 102 |
+
def tearDownClass(cls):
|
| 103 |
+
kill_process_tree(cls.process.pid)
|
| 104 |
+
|
| 105 |
+
def test_latency(self):
|
| 106 |
+
start = time.perf_counter()
|
| 107 |
+
response = requests.post(
|
| 108 |
+
self.base_url + "/generate",
|
| 109 |
+
json={"text": "Hello", "sampling_params": {"max_new_tokens": 128}},
|
| 110 |
+
)
|
| 111 |
+
elapsed = time.perf_counter() - start
|
| 112 |
+
self.assertEqual(response.status_code, 200)
|
| 113 |
+
self.assertLess(elapsed, 5.0, "Latency exceeded threshold")
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
if __name__ == "__main__":
|
| 117 |
+
unittest.main(verbosity=3)
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
---
|
| 121 |
+
|
| 122 |
+
## Server Fixture Reuse
|
| 123 |
+
|
| 124 |
+
For tests that only need a standard server, inherit from `DefaultServerBase` and override class attributes:
|
| 125 |
+
|
| 126 |
+
```python
|
| 127 |
+
from sglang.test.server_fixtures.default_fixture import DefaultServerBase
|
| 128 |
+
|
| 129 |
+
class TestMyFeature(DefaultServerBase):
|
| 130 |
+
model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
| 131 |
+
other_args = ["--enable-my-feature"]
|
| 132 |
+
|
| 133 |
+
def test_something(self):
|
| 134 |
+
...
|
| 135 |
+
```
|
| 136 |
+
|
| 137 |
+
Available fixtures in `python/sglang/test/server_fixtures/`:
|
| 138 |
+
|
| 139 |
+
| Fixture | Use case |
|
| 140 |
+
|---------|----------|
|
| 141 |
+
| `DefaultServerBase` | Standard single-server tests |
|
| 142 |
+
| `EagleServerBase` | EAGLE speculative decoding |
|
| 143 |
+
| `PDDisaggregationServerBase` | Disaggregated prefill/decode |
|
| 144 |
+
| `MMMUServerBase` | Multimodal VLM tests |
|
| 145 |
+
|
| 146 |
+
---
|
| 147 |
+
|
| 148 |
+
## CI Registration
|
| 149 |
+
|
| 150 |
+
Every test file in `test/registered/` **must** call a registration function at module level:
|
| 151 |
+
|
| 152 |
+
```python
|
| 153 |
+
from sglang.test.ci.ci_register import register_cuda_ci, register_amd_ci
|
| 154 |
+
|
| 155 |
+
register_cuda_ci(est_time=60, suite="stage-b-test-small-1-gpu")
|
| 156 |
+
register_amd_ci(est_time=60, suite="stage-b-test-small-1-gpu-amd") # optional
|
| 157 |
+
```
|
| 158 |
+
|
| 159 |
+
Parameters:
|
| 160 |
+
- `est_time`: estimated runtime in seconds (used for CI partitioning)
|
| 161 |
+
- `suite`: which CI suite to run in (see below)
|
| 162 |
+
- `nightly=True`: for nightly-only tests (default `False` = per-commit)
|
| 163 |
+
- `disabled="reason"`: temporarily disable with explanation
|
| 164 |
+
|
| 165 |
+
### Suite selection guide
|
| 166 |
+
|
| 167 |
+
**Default cases (1 GPU):**
|
| 168 |
+
|
| 169 |
+
| Scenario | Model | Suite |
|
| 170 |
+
|----------|-------|-------|
|
| 171 |
+
| Model-agnostic basic functionality | 1B (smallest) | `stage-b-test-small-1-gpu` |
|
| 172 |
+
| General performance (no spec/DP/parallelism) | 8B | `stage-b-test-large-1-gpu` |
|
| 173 |
+
|
| 174 |
+
**Bigger features (case by case):**
|
| 175 |
+
|
| 176 |
+
| Scenario | Suite |
|
| 177 |
+
|----------|-------|
|
| 178 |
+
| 2 GPU (e.g. TP=2) | `stage-b-test-large-2-gpu` |
|
| 179 |
+
| 4 GPU (H100) | `stage-c-test-4-gpu-h100` |
|
| 180 |
+
| 8 GPU (H200) | `stage-c-test-8-gpu-h200` |
|
| 181 |
+
| Nightly, 1 GPU | `nightly-1-gpu` |
|
| 182 |
+
| Nightly, 8 GPU | `nightly-8-gpu` |
|
| 183 |
+
|
| 184 |
+
For spec, DP attention, parallelism, disaggregation, etc., discuss with the team to determine the appropriate suite and GPU configuration.
|
| 185 |
+
|
| 186 |
+
---
|
| 187 |
+
|
| 188 |
+
## Model Constants
|
| 189 |
+
|
| 190 |
+
All defined in `python/sglang/test/test_utils.py`:
|
| 191 |
+
|
| 192 |
+
| Constant | Model | When to use |
|
| 193 |
+
|----------|-------|-------------|
|
| 194 |
+
| `DEFAULT_SMALL_MODEL_NAME_FOR_TEST` | Llama-3.2-1B-Instruct | Model-agnostic basic functionality |
|
| 195 |
+
| `DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE` | Llama-3.2-1B | Base (non-instruct) model tests |
|
| 196 |
+
| `DEFAULT_MODEL_NAME_FOR_TEST` | Llama-3.1-8B-Instruct | General performance (single node) |
|
| 197 |
+
| `DEFAULT_MOE_MODEL_NAME_FOR_TEST` | Mixtral-8x7B-Instruct | MoE-specific tests |
|
| 198 |
+
| `DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST` | — | Embedding tests |
|
| 199 |
+
| `DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST` | — | Vision-language tests |
|
| 200 |
+
|
| 201 |
+
---
|
| 202 |
+
|
| 203 |
+
## Test Placement
|
| 204 |
+
|
| 205 |
+
```
|
| 206 |
+
test/
|
| 207 |
+
├── registered/ # CI tests (auto-discovered by run_suite.py)
|
| 208 |
+
│ ├── sampling/ # test_penalty.py, test_sampling_params.py ...
|
| 209 |
+
│ ├── sessions/ # test_session_control.py ...
|
| 210 |
+
│ ├── openai_server/ # basic/, features/, validation/ ...
|
| 211 |
+
│ ├── spec/ # eagle/, utils/ ...
|
| 212 |
+
│ ├── models/ # model-specific accuracy tests
|
| 213 |
+
│ ├── perf/ # performance benchmarks
|
| 214 |
+
│ └── <category>/ # create new category if needed
|
| 215 |
+
├── manual/ # Non-CI: debugging, one-off, manual verification
|
| 216 |
+
└── run_suite.py # CI runner (scans registered/ only)
|
| 217 |
+
```
|
| 218 |
+
|
| 219 |
+
**Decision rule**: if the test should run in CI → `registered/`. If it's for local debugging or requires special hardware not in CI → `manual/`.
|
| 220 |
+
|
| 221 |
+
---
|
| 222 |
+
|
| 223 |
+
## Key Utilities
|
| 224 |
+
|
| 225 |
+
```python
|
| 226 |
+
from sglang.test.test_utils import (
|
| 227 |
+
CustomTestCase, # base class with retry logic
|
| 228 |
+
popen_launch_server, # launch server subprocess
|
| 229 |
+
DEFAULT_URL_FOR_TEST, # auto-configured base URL
|
| 230 |
+
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, # 600s default
|
| 231 |
+
run_bench_serving, # benchmark helper (launch + bench)
|
| 232 |
+
)
|
| 233 |
+
from sglang.srt.utils import kill_process_tree # cleanup server
|
| 234 |
+
```
|
| 235 |
+
|
| 236 |
+
---
|
| 237 |
+
|
| 238 |
+
## Checklist
|
| 239 |
+
|
| 240 |
+
Before submitting a test:
|
| 241 |
+
|
| 242 |
+
- [ ] Inherits from `CustomTestCase` (not `unittest.TestCase`)
|
| 243 |
+
- [ ] Has `register_*_ci(...)` call at module level
|
| 244 |
+
- [ ] Placed in `test/registered/<category>/`
|
| 245 |
+
- [ ] Model selection: smallest for model-agnostic features, 8B for general perf, case-by-case for other complex features
|
| 246 |
+
- [ ] `setUpClass` launches server, `tearDownClass` kills it
|
| 247 |
+
- [ ] Has `if __name__ == "__main__": unittest.main(verbosity=3)`
|
| 248 |
+
- [ ] `est_time` is reasonable (measure locally)
|
sglang/benchmark/json_jump_forward/README.md
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Run benchmark
|
| 2 |
+
|
| 3 |
+
### Dependencies
|
| 4 |
+
|
| 5 |
+
```
|
| 6 |
+
llama_cpp_python 0.2.38
|
| 7 |
+
guidance 0.1.10
|
| 8 |
+
vllm 0.2.7
|
| 9 |
+
outlines 0.0.25
|
| 10 |
+
```
|
| 11 |
+
|
| 12 |
+
### Build dataset
|
| 13 |
+
|
| 14 |
+
When benchmarking long document information retrieval, run the following command to build the dataset:
|
| 15 |
+
|
| 16 |
+
```bash
|
| 17 |
+
pip install wikipedia
|
| 18 |
+
python3 build_dataset.py
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
### Benchmark sglang
|
| 22 |
+
|
| 23 |
+
Run Llama-7B
|
| 24 |
+
|
| 25 |
+
```bash
|
| 26 |
+
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
Benchmark Character Generation
|
| 30 |
+
|
| 31 |
+
```bash
|
| 32 |
+
python3 bench_sglang.py --mode character
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
Benchmark City Information Retrieval
|
| 36 |
+
|
| 37 |
+
```bash
|
| 38 |
+
python3 bench_sglang.py --mode city
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
### Benchmark Outlines + vLLM
|
| 43 |
+
|
| 44 |
+
Run Llama-7B
|
| 45 |
+
|
| 46 |
+
```bash
|
| 47 |
+
python3 -m outlines.serve.serve --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
Benchmark Character Generation
|
| 51 |
+
|
| 52 |
+
```bash
|
| 53 |
+
python3 bench_other.py --mode character --backend outlines
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
Benchmark City Information Retrieval
|
| 57 |
+
|
| 58 |
+
```bash
|
| 59 |
+
python3 bench_other.py --mode city --backend outlines
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
### Benchmark guidance
|
| 63 |
+
|
| 64 |
+
Run Llama-7B and benchmark character generation
|
| 65 |
+
|
| 66 |
+
```bash
|
| 67 |
+
python3 bench_other.py --mode character --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
|
| 68 |
+
```
|
| 69 |
+
|
| 70 |
+
Run Llama-7B and benchmark city information retrieval
|
| 71 |
+
|
| 72 |
+
```bash
|
| 73 |
+
python3 bench_other.py --mode city --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
### Benchmark lmql
|
| 77 |
+
|
| 78 |
+
Run Llama-7B and benchmark character generation
|
| 79 |
+
|
| 80 |
+
```
|
| 81 |
+
python3 bench_other.py --mode character --backend lmql --parallel 1
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
Run Llama-7B and benchmark city information retrieval
|
| 85 |
+
|
| 86 |
+
```
|
| 87 |
+
python3 bench_other.py --mode city --backend lmql --parallel 1
|
| 88 |
+
```
|
sglang/benchmark/json_jump_forward/bench_other.py
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import time
|
| 4 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 5 |
+
from functools import partial
|
| 6 |
+
|
| 7 |
+
import guidance
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
|
| 11 |
+
from sglang.utils import dump_state_text, read_jsonl
|
| 12 |
+
|
| 13 |
+
# there are some FSM bugs with json regex converted from pydantic model
|
| 14 |
+
# here use a string regex instead
|
| 15 |
+
# regex_string = build_regex_from_object(HarryPoterRole)
|
| 16 |
+
character_regex = (
|
| 17 |
+
r"""\{\n"""
|
| 18 |
+
+ r""" "name": "[\w\d\s]{1,16}",\n"""
|
| 19 |
+
+ r""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n"""
|
| 20 |
+
+ r""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n"""
|
| 21 |
+
+ r""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n"""
|
| 22 |
+
+ r""" "wand": \{\n"""
|
| 23 |
+
+ r""" "wood": "[\w\d\s]{1,16}",\n"""
|
| 24 |
+
+ r""" "core": "[\w\d\s]{1,16}",\n"""
|
| 25 |
+
+ r""" "length": [0-9]{1,2}\.[0-9]{0,2}\n"""
|
| 26 |
+
+ r""" \},\n"""
|
| 27 |
+
+ r""" "alive": "(Alive|Deceased)",\n"""
|
| 28 |
+
+ r""" "patronus": "[\w\d\s]{1,16}",\n"""
|
| 29 |
+
+ r""" "bogart": "[\w\d\s]{1,16}"\n"""
|
| 30 |
+
+ r"""\}"""
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
city_regex = (
|
| 34 |
+
r"""\{\n"""
|
| 35 |
+
+ r""" "name": "[\w\d\s]{1,16}",\n"""
|
| 36 |
+
+ r""" "country": "[\w\d\s]{1,16}",\n"""
|
| 37 |
+
+ r""" "latitude": [-+]?[0-9]*\.?[0-9]{0,2},\n"""
|
| 38 |
+
+ r""" "population": [-+]?[0-9]{1,9},\n"""
|
| 39 |
+
+ r""" "top 3 landmarks": \["[\w\d\s]{1,16}", "[\w\d\s]{1,16}", "[\w\d\s]{1,16}"\]\n"""
|
| 40 |
+
+ r"""\}"""
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
# fmt: off
|
| 44 |
+
def character_gen(name, generate):
|
| 45 |
+
s = name + " is a character in Harry Potter. Please fill in the following information about this character.\n"
|
| 46 |
+
s += generate(s, max_tokens=256, regex=character_regex)
|
| 47 |
+
return s
|
| 48 |
+
# fmt: on
|
| 49 |
+
|
| 50 |
+
# fmt: off
|
| 51 |
+
def city_gen(document, generate):
|
| 52 |
+
s = "Please extract the information of a city from the following wikipedia page.\n"
|
| 53 |
+
s += "Page begin.\n" + document + "Page end.\n"
|
| 54 |
+
s += "Here is the name, country, and symbol of the city in JSON format.\n"
|
| 55 |
+
s += generate(s, max_tokens=256, regex=city_regex)
|
| 56 |
+
return s
|
| 57 |
+
# fmt: on
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@guidance
|
| 61 |
+
def character_maker(lm, name):
|
| 62 |
+
regex_str_no_quote = r"[\w\d\s]+"
|
| 63 |
+
regex_float = r"[0-9]+\.[0-9]+"
|
| 64 |
+
lm += f"""\
|
| 65 |
+
{name} is a character in Harry Potter. Please fill in the following information about this character.
|
| 66 |
+
{{
|
| 67 |
+
"name": "{guidance.gen("name", max_tokens=16, regex=regex_str_no_quote)}",
|
| 68 |
+
"house": "{guidance.select(options=['Gryffindor', 'Slytherin', 'Ravenclaw', 'Hufflepuff'], name='house')}",
|
| 69 |
+
"blood status": "{guidance.select(options=['Pure-blood', 'Half-blood', 'Muggle-born'], name='blood status')}",
|
| 70 |
+
"occupation": "{guidance.select(options=['student', 'teacher', 'auror', 'ministry of magic', 'death eater', 'order of the phoenix'], name='occupation')}",
|
| 71 |
+
"wand": {{
|
| 72 |
+
"wood": "{guidance.gen("wood", max_tokens=16, regex=regex_str_no_quote)}",
|
| 73 |
+
"core": "{guidance.gen('core', max_tokens=16, regex=regex_str_no_quote)}",
|
| 74 |
+
"length": {guidance.gen('length', max_tokens=10, regex=regex_float)}
|
| 75 |
+
}},
|
| 76 |
+
"alive": "{guidance.select(options=['Alive', 'Deceased'], name='alive')}",
|
| 77 |
+
"patronus": "{guidance.gen('patronus', max_tokens=16, regex=regex_str_no_quote)}",
|
| 78 |
+
"bogart": "{guidance.gen('bogart', max_tokens=16, regex=regex_str_no_quote)}"
|
| 79 |
+
}}
|
| 80 |
+
"""
|
| 81 |
+
|
| 82 |
+
return lm
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
async def call_generate_lmql(
|
| 86 |
+
prompt, temperature, max_tokens, regex, max_len=4096, model=None, **kwargs
|
| 87 |
+
):
|
| 88 |
+
assert model is not None
|
| 89 |
+
import lmql
|
| 90 |
+
|
| 91 |
+
@lmql.query(model=model)
|
| 92 |
+
async def program(question, max_tokens, regex):
|
| 93 |
+
'''lmql
|
| 94 |
+
"""{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens and REGEX(ANSWER, regex)
|
| 95 |
+
return ANSWER
|
| 96 |
+
'''
|
| 97 |
+
|
| 98 |
+
return await program(
|
| 99 |
+
question=prompt,
|
| 100 |
+
temperature=temperature,
|
| 101 |
+
max_tokens=max_tokens,
|
| 102 |
+
max_len=max_len,
|
| 103 |
+
regex=regex,
|
| 104 |
+
**kwargs,
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
@guidance
|
| 109 |
+
def city_maker(lm, document):
|
| 110 |
+
regex_str_no_quote = r"[\w\d\s]+"
|
| 111 |
+
regex_float = r"[0-9]+\.[0-9]+"
|
| 112 |
+
lm += f"""\
|
| 113 |
+
Please extract the information of a city from the following wikipedia page.
|
| 114 |
+
Page begin.
|
| 115 |
+
{document}
|
| 116 |
+
Page end.
|
| 117 |
+
Here is the name, country, and symbol of the city in JSON format.
|
| 118 |
+
{{
|
| 119 |
+
"name": "{guidance.gen("name", max_tokens=16, regex=regex_str_no_quote)}",
|
| 120 |
+
"country": "{guidance.gen("country", max_tokens=16, regex=regex_str_no_quote)}",
|
| 121 |
+
"latitude": {guidance.gen("latitude", max_tokens=10, regex=regex_float)},
|
| 122 |
+
"population": {guidance.gen("population", max_tokens=10, regex=r"[0-9]+")},
|
| 123 |
+
"top 3 landmarks": [
|
| 124 |
+
"{guidance.gen("landmark1", max_tokens=16, regex=regex_str_no_quote)}", "{guidance.gen("landmark2", max_tokens=16, regex=regex_str_no_quote)}", "{guidance.gen("landmark3", max_tokens=16, regex=regex_str_no_quote)}"
|
| 125 |
+
]
|
| 126 |
+
}}
|
| 127 |
+
"""
|
| 128 |
+
|
| 129 |
+
return lm
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def bench_character(args):
|
| 133 |
+
arguments = []
|
| 134 |
+
with open(args.data_path, "r") as f:
|
| 135 |
+
for line in f:
|
| 136 |
+
arguments.append({"name": line.strip()})
|
| 137 |
+
arguments = arguments[: args.num_jsons]
|
| 138 |
+
|
| 139 |
+
states = [None] * len(arguments)
|
| 140 |
+
|
| 141 |
+
# Select backend
|
| 142 |
+
if args.backend == "outlines":
|
| 143 |
+
call_generate = partial(get_call_generate(args), temperature=0)
|
| 144 |
+
|
| 145 |
+
def get_one_answer(i):
|
| 146 |
+
states[i] = character_gen(**arguments[i], generate=call_generate)
|
| 147 |
+
|
| 148 |
+
elif args.backend == "guidance":
|
| 149 |
+
model = guidance.models.LlamaCpp(
|
| 150 |
+
args.model_path,
|
| 151 |
+
n_gpu_layers=-1,
|
| 152 |
+
n_ctx=args.n_ctx,
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
def get_one_answer(i):
|
| 156 |
+
lm = model + character_maker(**arguments[i])
|
| 157 |
+
states[i] = lm
|
| 158 |
+
|
| 159 |
+
elif args.backend == "lmql":
|
| 160 |
+
import asyncio
|
| 161 |
+
|
| 162 |
+
import lmql
|
| 163 |
+
|
| 164 |
+
model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}")
|
| 165 |
+
call_generate = partial(
|
| 166 |
+
call_generate_lmql,
|
| 167 |
+
model=model,
|
| 168 |
+
max_tokens=256,
|
| 169 |
+
regex=character_regex,
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
async def get_one_answer_async(i):
|
| 173 |
+
states[i] = await call_generate(prompt=arguments[i]["name"], temperature=0)
|
| 174 |
+
|
| 175 |
+
else:
|
| 176 |
+
raise ValueError(f"Invalid backend: {args.backend}")
|
| 177 |
+
|
| 178 |
+
tic = time.perf_counter()
|
| 179 |
+
|
| 180 |
+
if args.backend != "lmql":
|
| 181 |
+
if args.parallel == 1:
|
| 182 |
+
for i in tqdm(range(len(arguments))):
|
| 183 |
+
get_one_answer(i)
|
| 184 |
+
else:
|
| 185 |
+
with ThreadPoolExecutor(args.parallel) as executor:
|
| 186 |
+
rets = list(
|
| 187 |
+
tqdm(
|
| 188 |
+
executor.map(get_one_answer, list(range(len(arguments)))),
|
| 189 |
+
total=len(arguments),
|
| 190 |
+
)
|
| 191 |
+
)
|
| 192 |
+
for _ in rets:
|
| 193 |
+
pass
|
| 194 |
+
else:
|
| 195 |
+
batches = []
|
| 196 |
+
for i in range(0, len(arguments), args.parallel):
|
| 197 |
+
batches.append(list(range(i, min(i + args.parallel, len(arguments)))))
|
| 198 |
+
loop = asyncio.get_event_loop()
|
| 199 |
+
|
| 200 |
+
for bt in tqdm(batches):
|
| 201 |
+
loop.run_until_complete(
|
| 202 |
+
asyncio.gather(*[get_one_answer_async(i) for i in bt])
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
latency = time.perf_counter() - tic
|
| 206 |
+
|
| 207 |
+
return states, latency
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def bench_city_doc(args):
|
| 211 |
+
arguments = []
|
| 212 |
+
for line in read_jsonl(args.data_path):
|
| 213 |
+
arguments.append({"document": line["document"]})
|
| 214 |
+
arguments = arguments[: args.num_jsons]
|
| 215 |
+
|
| 216 |
+
states = [None] * len(arguments)
|
| 217 |
+
|
| 218 |
+
# Select backend
|
| 219 |
+
if args.backend == "outlines":
|
| 220 |
+
call_generate = partial(get_call_generate(args), temperature=0)
|
| 221 |
+
|
| 222 |
+
def get_one_answer(i):
|
| 223 |
+
states[i] = city_gen(**arguments[i], generate=call_generate)
|
| 224 |
+
|
| 225 |
+
elif args.backend == "guidance":
|
| 226 |
+
model = guidance.models.LlamaCpp(
|
| 227 |
+
args.model_path,
|
| 228 |
+
n_gpu_layers=-1,
|
| 229 |
+
n_ctx=args.n_ctx,
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
def get_one_answer(i):
|
| 233 |
+
lm = model + city_maker(**arguments[i])
|
| 234 |
+
states[i] = lm
|
| 235 |
+
|
| 236 |
+
else:
|
| 237 |
+
raise ValueError(f"Invalid backend: {args.backend}")
|
| 238 |
+
|
| 239 |
+
tic = time.perf_counter()
|
| 240 |
+
if args.parallel == 1:
|
| 241 |
+
for i in tqdm(range(len(arguments))):
|
| 242 |
+
get_one_answer(i)
|
| 243 |
+
else:
|
| 244 |
+
with ThreadPoolExecutor(args.parallel) as executor:
|
| 245 |
+
rets = executor.map(get_one_answer, list(range(len(arguments))))
|
| 246 |
+
for _ in rets:
|
| 247 |
+
pass
|
| 248 |
+
|
| 249 |
+
latency = time.perf_counter() - tic
|
| 250 |
+
|
| 251 |
+
return states, latency
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def main(args):
|
| 255 |
+
if args.mode == "character":
|
| 256 |
+
args.data_path = "dataset.txt"
|
| 257 |
+
states, latency = bench_character(args)
|
| 258 |
+
elif args.mode == "city":
|
| 259 |
+
args.data_path = "questions.jsonl"
|
| 260 |
+
states, latency = bench_city_doc(args)
|
| 261 |
+
|
| 262 |
+
# Compute accuracy
|
| 263 |
+
print(f"Latency: {latency:.3f}")
|
| 264 |
+
|
| 265 |
+
# Write results
|
| 266 |
+
dump_state_text(f"tmp_output_{args.backend}_{args.mode}.txt", states)
|
| 267 |
+
|
| 268 |
+
with open(args.result_file, "a") as fout:
|
| 269 |
+
value = {
|
| 270 |
+
"task": "json_jump_forward",
|
| 271 |
+
"backend": args.backend,
|
| 272 |
+
"latency": round(latency, 3),
|
| 273 |
+
"num_jsons": args.num_jsons,
|
| 274 |
+
"mode": args.mode,
|
| 275 |
+
"parallel": args.parallel,
|
| 276 |
+
}
|
| 277 |
+
fout.write(json.dumps(value) + "\n")
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
if __name__ == "__main__":
|
| 281 |
+
parser = argparse.ArgumentParser()
|
| 282 |
+
parser.add_argument("--data-path", type=str)
|
| 283 |
+
parser.add_argument("--num-jsons", type=int, default=50)
|
| 284 |
+
parser.add_argument(
|
| 285 |
+
"--mode", type=str, default="character", choices=["character", "city"]
|
| 286 |
+
)
|
| 287 |
+
args = add_common_other_args_and_parse(parser)
|
| 288 |
+
main(args)
|
sglang/benchmark/json_jump_forward/bench_sglang.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import time
|
| 4 |
+
|
| 5 |
+
import sglang as sgl
|
| 6 |
+
from sglang.test.test_utils import (
|
| 7 |
+
add_common_sglang_args_and_parse,
|
| 8 |
+
select_sglang_backend,
|
| 9 |
+
)
|
| 10 |
+
from sglang.utils import dump_state_text, read_jsonl
|
| 11 |
+
|
| 12 |
+
# there are some FSM bugs with json regex converted from pydantic model
|
| 13 |
+
# here use a string regex instead
|
| 14 |
+
# regex_string = build_regex_from_object(HarryPoterRole)
|
| 15 |
+
character_regex = (
|
| 16 |
+
r"""\{\n"""
|
| 17 |
+
+ r""" "name": "[\w\d\s]{1,16}",\n"""
|
| 18 |
+
+ r""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n"""
|
| 19 |
+
+ r""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n"""
|
| 20 |
+
+ r""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n"""
|
| 21 |
+
+ r""" "wand": \{\n"""
|
| 22 |
+
+ r""" "wood": "[\w\d\s]{1,16}",\n"""
|
| 23 |
+
+ r""" "core": "[\w\d\s]{1,16}",\n"""
|
| 24 |
+
+ r""" "length": [0-9]{1,2}\.[0-9]{0,2}\n"""
|
| 25 |
+
+ r""" \},\n"""
|
| 26 |
+
+ r""" "alive": "(Alive|Deceased)",\n"""
|
| 27 |
+
+ r""" "patronus": "[\w\d\s]{1,16}",\n"""
|
| 28 |
+
+ r""" "bogart": "[\w\d\s]{1,16}"\n"""
|
| 29 |
+
+ r"""\}"""
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
city_regex = (
|
| 33 |
+
r"""\{\n"""
|
| 34 |
+
+ r""" "name": "[\w\d\s]{1,16}",\n"""
|
| 35 |
+
+ r""" "country": "[\w\d\s]{1,16}",\n"""
|
| 36 |
+
+ r""" "latitude": [-+]?[0-9]*\.?[0-9]{0,2},\n"""
|
| 37 |
+
+ r""" "population": [-+]?[0-9]{1,9},\n"""
|
| 38 |
+
+ r""" "top 3 landmarks": \["[\w\d\s]{1,16}", "[\w\d\s]{1,16}", "[\w\d\s]{1,16}"\]\n"""
|
| 39 |
+
+ r"""\}"""
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
# fmt: off
|
| 43 |
+
@sgl.function
|
| 44 |
+
def character_gen(s, name):
|
| 45 |
+
s += name + " is a character in Harry Potter. Please fill in the following information about this character.\n"
|
| 46 |
+
s += sgl.gen("json_output", max_tokens=256, regex=character_regex)
|
| 47 |
+
# fmt: on
|
| 48 |
+
|
| 49 |
+
# fmt: off
|
| 50 |
+
@sgl.function
|
| 51 |
+
def city_gen(s, document):
|
| 52 |
+
s += "Please extract the information of a city from the following wikipedia page.\n"
|
| 53 |
+
s += "Page begin.\n" + document + "Page end.\n"
|
| 54 |
+
s += "Here is the name, country, and symbol of the city in JSON format.\n"
|
| 55 |
+
s += sgl.gen("json_output",max_tokens=256, regex=city_regex)
|
| 56 |
+
# fmt: on
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def bench_city_doc(args):
|
| 60 |
+
arguments = []
|
| 61 |
+
for line in read_jsonl(args.data_path):
|
| 62 |
+
arguments.append({"document": line["document"]})
|
| 63 |
+
arguments = arguments[: args.num_jsons]
|
| 64 |
+
|
| 65 |
+
# Select backend
|
| 66 |
+
backend = select_sglang_backend(args)
|
| 67 |
+
sgl.set_default_backend(backend)
|
| 68 |
+
|
| 69 |
+
# Run requests
|
| 70 |
+
tic = time.perf_counter()
|
| 71 |
+
states = city_gen.run_batch(
|
| 72 |
+
arguments,
|
| 73 |
+
temperature=0,
|
| 74 |
+
num_threads=args.parallel,
|
| 75 |
+
progress_bar=True,
|
| 76 |
+
)
|
| 77 |
+
latency = time.perf_counter() - tic
|
| 78 |
+
|
| 79 |
+
return states, latency
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def bench_character(args):
|
| 83 |
+
arguments = []
|
| 84 |
+
with open(args.data_path, "r") as f:
|
| 85 |
+
for line in f:
|
| 86 |
+
arguments.append({"name": line.strip()})
|
| 87 |
+
arguments = arguments[: args.num_jsons]
|
| 88 |
+
|
| 89 |
+
# Select backend
|
| 90 |
+
backend = select_sglang_backend(args)
|
| 91 |
+
sgl.set_default_backend(backend)
|
| 92 |
+
|
| 93 |
+
# Run requests
|
| 94 |
+
tic = time.perf_counter()
|
| 95 |
+
states = character_gen.run_batch(
|
| 96 |
+
arguments,
|
| 97 |
+
temperature=0,
|
| 98 |
+
num_threads=args.parallel,
|
| 99 |
+
progress_bar=True,
|
| 100 |
+
)
|
| 101 |
+
latency = time.perf_counter() - tic
|
| 102 |
+
|
| 103 |
+
return states, latency
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def main(args):
|
| 107 |
+
if args.mode == "character":
|
| 108 |
+
args.data_path = "dataset.txt"
|
| 109 |
+
states, latency = bench_character(args)
|
| 110 |
+
elif args.mode == "city":
|
| 111 |
+
args.data_path = "questions.jsonl"
|
| 112 |
+
states, latency = bench_city_doc(args)
|
| 113 |
+
|
| 114 |
+
# Compute accuracy
|
| 115 |
+
print(f"Latency: {latency:.3f}")
|
| 116 |
+
|
| 117 |
+
# Write results
|
| 118 |
+
dump_state_text(f"tmp_output_{args.backend}_{args.mode}.txt", states)
|
| 119 |
+
with open(f"{args.backend}_{args.mode}.json", "w") as fout:
|
| 120 |
+
for state in states:
|
| 121 |
+
fout.write(state["json_output"] + "\n")
|
| 122 |
+
|
| 123 |
+
with open(args.result_file, "a") as fout:
|
| 124 |
+
value = {
|
| 125 |
+
"task": "json_jump_forward",
|
| 126 |
+
"backend": args.backend,
|
| 127 |
+
"latency": round(latency, 3),
|
| 128 |
+
"num_jsons": args.num_jsons,
|
| 129 |
+
"mode": args.mode,
|
| 130 |
+
"parallel": args.parallel,
|
| 131 |
+
}
|
| 132 |
+
fout.write(json.dumps(value) + "\n")
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
if __name__ == "__main__":
|
| 136 |
+
parser = argparse.ArgumentParser()
|
| 137 |
+
parser.add_argument("--data-path", type=str)
|
| 138 |
+
parser.add_argument("--num-jsons", type=int, default=50)
|
| 139 |
+
parser.add_argument(
|
| 140 |
+
"--mode", type=str, default="character", choices=["character", "city"]
|
| 141 |
+
)
|
| 142 |
+
args = add_common_sglang_args_and_parse(parser)
|
| 143 |
+
main(args)
|
sglang/benchmark/json_jump_forward/build_dataset.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
|
| 3 |
+
import transformers
|
| 4 |
+
import wikipedia
|
| 5 |
+
|
| 6 |
+
model_path = "meta-llama/Llama-2-7b-chat-hf"
|
| 7 |
+
t = transformers.AutoTokenizer.from_pretrained(model_path)
|
| 8 |
+
city_names = [
|
| 9 |
+
"los angles",
|
| 10 |
+
"london",
|
| 11 |
+
"tokyo",
|
| 12 |
+
"beijing",
|
| 13 |
+
"singapore",
|
| 14 |
+
"paris",
|
| 15 |
+
"dubai",
|
| 16 |
+
"sydney",
|
| 17 |
+
"moscow",
|
| 18 |
+
"rome",
|
| 19 |
+
"toronto",
|
| 20 |
+
"rio de janeiro",
|
| 21 |
+
"istanbul",
|
| 22 |
+
"berlin",
|
| 23 |
+
"auckland",
|
| 24 |
+
"buenos aires",
|
| 25 |
+
"mexico city",
|
| 26 |
+
"mumbai",
|
| 27 |
+
"seoul",
|
| 28 |
+
"bangkok",
|
| 29 |
+
"cairo",
|
| 30 |
+
"athens",
|
| 31 |
+
"jerusalem",
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_content(city_name):
|
| 36 |
+
content = str(wikipedia.page(city_name).content)
|
| 37 |
+
content = content.replace("\n\n", "\n")
|
| 38 |
+
|
| 39 |
+
tokens = t.encode(content)
|
| 40 |
+
|
| 41 |
+
expected_tokens = 3000
|
| 42 |
+
truncate_len = int((expected_tokens / len(tokens)) * len(content))
|
| 43 |
+
truncate_content = content[:truncate_len]
|
| 44 |
+
truncate_tokens = t.encode(truncate_content)
|
| 45 |
+
|
| 46 |
+
# Count token
|
| 47 |
+
print(
|
| 48 |
+
f"city_name: {city_name}, #tokens: {len(tokens)}, #truncate tokens: {len(truncate_tokens)}"
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
return truncate_content
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
if __name__ == "__main__":
|
| 55 |
+
with open("questions.jsonl", "w") as fout:
|
| 56 |
+
for city_name in city_names:
|
| 57 |
+
truncate_content = get_content(city_name)
|
| 58 |
+
fout.write(json.dumps({"document": truncate_content}) + "\n")
|
sglang/benchmark/json_jump_forward/dataset.txt
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Harry Potter
|
| 2 |
+
Hermione Granger
|
| 3 |
+
Ron Weasley
|
| 4 |
+
Albus Dumbledore
|
| 5 |
+
Severus Snape
|
| 6 |
+
Rubeus Hagrid
|
| 7 |
+
Draco Malfoy
|
| 8 |
+
Ginny Weasley
|
| 9 |
+
Fred Weasley
|
| 10 |
+
George Weasley
|
| 11 |
+
Percy Weasley
|
| 12 |
+
Sirius Black
|
| 13 |
+
Remus Lupin
|
| 14 |
+
Neville Longbottom
|
| 15 |
+
Luna Lovegood
|
| 16 |
+
Cedric Diggory
|
| 17 |
+
Cho Chang
|
| 18 |
+
Lord Voldemort
|
| 19 |
+
Minerva McGonagall
|
| 20 |
+
Filius Flitwick
|
| 21 |
+
Dolores Umbridge
|
| 22 |
+
Bellatrix Lestrange
|
| 23 |
+
Lucius Malfoy
|
| 24 |
+
Molly Weasley
|
| 25 |
+
Arthur Weasley
|
| 26 |
+
Nymphadora Tonks
|
| 27 |
+
Dobby
|
| 28 |
+
Moaning Myrtle
|
| 29 |
+
Peter Pettigrew
|
| 30 |
+
Alastor 'Mad-Eye' Moody
|
| 31 |
+
Horace Slughorn
|
| 32 |
+
Vernon Dursley
|
| 33 |
+
Petunia Dursley
|
| 34 |
+
Dudley Dursley
|
| 35 |
+
Argus Filch
|
| 36 |
+
Sybill Trelawney
|
| 37 |
+
Gilderoy Lockhart
|
| 38 |
+
Fleur Delacour
|
| 39 |
+
Viktor Krum
|
| 40 |
+
Bill Weasley
|
| 41 |
+
Oliver Wood
|
| 42 |
+
Cornelius Fudge
|
| 43 |
+
Barty Crouch Sr.
|
| 44 |
+
Barty Crouch Jr.
|
| 45 |
+
Kingsley Shacklebolt
|
| 46 |
+
Quirinus Quirrell
|
| 47 |
+
Nearly Headless Nick
|
| 48 |
+
Aunt Marge
|
| 49 |
+
Griphook
|
| 50 |
+
Ludo Bagman
|
sglang/benchmark/multi_turn_chat/bench_other.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import time
|
| 3 |
+
from argparse import ArgumentParser
|
| 4 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 5 |
+
from functools import partial
|
| 6 |
+
|
| 7 |
+
from data_gen import gen_arguments
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from vllm.transformers_utils.tokenizer import get_tokenizer
|
| 10 |
+
|
| 11 |
+
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
|
| 12 |
+
from sglang.utils import dump_state_text
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def multi_turns(generate, qas):
|
| 16 |
+
s = ""
|
| 17 |
+
for qa in qas:
|
| 18 |
+
s += qa["prompt"]
|
| 19 |
+
s += generate(s, max_tokens=qa["new_tokens"])
|
| 20 |
+
|
| 21 |
+
return s
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def main(args):
|
| 25 |
+
print(args)
|
| 26 |
+
|
| 27 |
+
tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
|
| 28 |
+
|
| 29 |
+
multi_qas = gen_arguments(args, tokenizer)
|
| 30 |
+
|
| 31 |
+
states = [None] * args.num_qa
|
| 32 |
+
|
| 33 |
+
call_generate = partial(get_call_generate(args), temperature=0)
|
| 34 |
+
|
| 35 |
+
def get_one_answer(i):
|
| 36 |
+
states[i] = multi_turns(generate=call_generate, **multi_qas[i])
|
| 37 |
+
|
| 38 |
+
tic = time.perf_counter()
|
| 39 |
+
if args.parallel == 1:
|
| 40 |
+
for i in tqdm(range(len(multi_qas))):
|
| 41 |
+
get_one_answer(i)
|
| 42 |
+
else:
|
| 43 |
+
with ThreadPoolExecutor(args.parallel) as executor:
|
| 44 |
+
rets = list(
|
| 45 |
+
tqdm(
|
| 46 |
+
executor.map(get_one_answer, list(range(len(multi_qas)))),
|
| 47 |
+
total=len(multi_qas),
|
| 48 |
+
)
|
| 49 |
+
)
|
| 50 |
+
for _ in rets:
|
| 51 |
+
pass
|
| 52 |
+
|
| 53 |
+
latency = time.perf_counter() - tic
|
| 54 |
+
|
| 55 |
+
# Compute accuracy
|
| 56 |
+
print(f"Latency: {latency:.3f}")
|
| 57 |
+
|
| 58 |
+
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
| 59 |
+
|
| 60 |
+
with open(args.result_file, "a") as fout:
|
| 61 |
+
value = {
|
| 62 |
+
"task": "multi_turn_chat",
|
| 63 |
+
"backend": args.backend,
|
| 64 |
+
"num_gpus": 1,
|
| 65 |
+
"latency": round(latency, 3),
|
| 66 |
+
"num_requests": args.num_qa,
|
| 67 |
+
"num_turns": args.turns,
|
| 68 |
+
"other": {
|
| 69 |
+
"parallel": args.parallel,
|
| 70 |
+
"output_mode": "long" if args.long else "short",
|
| 71 |
+
},
|
| 72 |
+
}
|
| 73 |
+
fout.write(json.dumps(value) + "\n")
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
if __name__ == "__main__":
|
| 77 |
+
parser = ArgumentParser()
|
| 78 |
+
parser.add_argument("--turns", type=int, default=4)
|
| 79 |
+
parser.add_argument("--num-qa", type=int, default=20)
|
| 80 |
+
parser.add_argument("--min-len-q", type=int, default=256)
|
| 81 |
+
parser.add_argument("--max-len-q", type=int, default=512)
|
| 82 |
+
parser.add_argument("--min-len-a", type=int, default=4)
|
| 83 |
+
parser.add_argument("--max-len-a", type=int, default=8)
|
| 84 |
+
parser.add_argument("--tokenizer", type=str, required=True)
|
| 85 |
+
parser.add_argument("--trust-remote-code", action="store_true")
|
| 86 |
+
parser.add_argument("--long", action="store_true")
|
| 87 |
+
args = add_common_other_args_and_parse(parser)
|
| 88 |
+
|
| 89 |
+
if args.long:
|
| 90 |
+
args.min_len_a = 256
|
| 91 |
+
args.max_len_a = 512
|
| 92 |
+
args.num_qa = 20
|
| 93 |
+
main(args)
|
sglang/benchmark/multi_turn_chat/data_gen.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import string
|
| 3 |
+
|
| 4 |
+
random.seed(42)
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def gen_prompt(tokenizer, token_num):
|
| 8 |
+
cha_set = string.ascii_letters + string.digits
|
| 9 |
+
ret = "".join(random.choices(cha_set, k=token_num))
|
| 10 |
+
while len(tokenizer(ret).input_ids) < token_num:
|
| 11 |
+
ret += random.choice(cha_set)
|
| 12 |
+
return ret
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def gen_arguments(args, tokenizer):
|
| 16 |
+
multi_qas = [{"qas": []} for _ in range(args.num_qa)]
|
| 17 |
+
for i in range(args.num_qa):
|
| 18 |
+
qas = multi_qas[i]["qas"]
|
| 19 |
+
for _ in range(args.turns):
|
| 20 |
+
prompt_len = random.randint(args.min_len_q, args.max_len_q)
|
| 21 |
+
new_tokens = random.randint(args.min_len_a, args.max_len_a)
|
| 22 |
+
qas.append(
|
| 23 |
+
{
|
| 24 |
+
"prompt": gen_prompt(tokenizer, prompt_len),
|
| 25 |
+
"new_tokens": new_tokens,
|
| 26 |
+
}
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
return multi_qas
|
sglang/benchmark/tree_of_thought_deep/README.md
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Download data
|
| 2 |
+
```
|
| 3 |
+
wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
|
| 4 |
+
```
|
| 5 |
+
|
| 6 |
+
## Run benchmark
|
| 7 |
+
|
| 8 |
+
NOTE: This is an implementation for throughput/latency benchmark purposes. The prompts are not tuned to achieve good accuracy on the GSM-8K tasks.
|
| 9 |
+
|
| 10 |
+
### Benchmark sglang
|
| 11 |
+
```
|
| 12 |
+
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
| 13 |
+
```
|
| 14 |
+
|
| 15 |
+
```
|
| 16 |
+
python3 bench_sglang.py --num-questions 32
|
| 17 |
+
python3 bench_sglang.py --num-questions 16 --parallel 1
|
| 18 |
+
```
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
### Benchmark vllm
|
| 22 |
+
```
|
| 23 |
+
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
```
|
| 27 |
+
python3 bench_other.py --num-questions 32 --backend vllm
|
| 28 |
+
```
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
### Benchmark lightllm
|
| 32 |
+
```
|
| 33 |
+
# A10G
|
| 34 |
+
python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
```
|
| 38 |
+
python3 bench_other.py --num-questions 32 --backend lightllm
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
### Benchmark guidance
|
| 43 |
+
```
|
| 44 |
+
python3 bench_other.py --num-questions 8 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
### Benchmark lmql
|
| 48 |
+
|
| 49 |
+
```
|
| 50 |
+
python3 bench_other.py --num-questions 8 --backend lmql --parallel 1
|
| 51 |
+
```
|
sglang/benchmark/tree_of_thought_deep/bench_other.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import ast
|
| 3 |
+
import json
|
| 4 |
+
import re
|
| 5 |
+
import time
|
| 6 |
+
from collections import Counter
|
| 7 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
|
| 13 |
+
from sglang.utils import dump_state_text, read_jsonl
|
| 14 |
+
|
| 15 |
+
INVALID = -9999999
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_answer_value(answer_str):
|
| 19 |
+
answer_str = answer_str.replace(",", "")
|
| 20 |
+
numbers = re.findall(r"\d+", answer_str)
|
| 21 |
+
if len(numbers) < 1:
|
| 22 |
+
return INVALID
|
| 23 |
+
try:
|
| 24 |
+
return ast.literal_eval(numbers[-1])
|
| 25 |
+
except SyntaxError:
|
| 26 |
+
return INVALID
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def most_frequent_number(numbers):
|
| 30 |
+
if not numbers:
|
| 31 |
+
return None
|
| 32 |
+
|
| 33 |
+
frequency = Counter(numbers)
|
| 34 |
+
most_frequent = max(frequency, key=frequency.get)
|
| 35 |
+
return most_frequent
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
USER_PREFIX = "[INST] "
|
| 39 |
+
USER_SUFFIX = " [/INST]"
|
| 40 |
+
ASSISTANT_PREFIX = ""
|
| 41 |
+
ASSISTANT_SUFFIX = " </s><s>"
|
| 42 |
+
|
| 43 |
+
# Use a low temp to make the results more deterministic and the comparison more fair.
|
| 44 |
+
temp = 0.001
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def propose_plan(s, question, num_branches, call_generate):
|
| 48 |
+
s += (
|
| 49 |
+
USER_PREFIX
|
| 50 |
+
+ """Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """
|
| 51 |
+
+ question
|
| 52 |
+
+ USER_SUFFIX
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
s += ASSISTANT_PREFIX
|
| 56 |
+
comps = call_generate(
|
| 57 |
+
s, max_tokens=256, temperature=temp, stop=None, n=num_branches
|
| 58 |
+
)
|
| 59 |
+
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def execute_plan(s, num_branches, call_generate):
|
| 63 |
+
s += (
|
| 64 |
+
USER_PREFIX
|
| 65 |
+
+ """The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short."""
|
| 66 |
+
+ USER_SUFFIX
|
| 67 |
+
)
|
| 68 |
+
s += ASSISTANT_PREFIX
|
| 69 |
+
comps = call_generate(
|
| 70 |
+
s, max_tokens=256, temperature=temp, stop=None, n=num_branches
|
| 71 |
+
)
|
| 72 |
+
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def reflect_solution(s, num_branches, call_generate):
|
| 76 |
+
s += (
|
| 77 |
+
USER_PREFIX
|
| 78 |
+
+ """Okay. Now, evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness."""
|
| 79 |
+
+ USER_SUFFIX
|
| 80 |
+
)
|
| 81 |
+
s += ASSISTANT_PREFIX
|
| 82 |
+
comps = call_generate(
|
| 83 |
+
s, max_tokens=256, temperature=temp, stop=None, n=num_branches
|
| 84 |
+
)
|
| 85 |
+
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def get_final_answer(s, num_branches, call_generate):
|
| 89 |
+
s += (
|
| 90 |
+
USER_PREFIX
|
| 91 |
+
+ """Based on your reflection, do you change your mind? Now, give me the final answer after careful consideration."""
|
| 92 |
+
+ USER_SUFFIX
|
| 93 |
+
)
|
| 94 |
+
s += ASSISTANT_PREFIX
|
| 95 |
+
comps = call_generate(
|
| 96 |
+
s, max_tokens=256, temperature=temp, stop=None, n=num_branches
|
| 97 |
+
)
|
| 98 |
+
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def tree_search(question, num_branches, call_generate):
|
| 102 |
+
plan_forks = propose_plan("", question, num_branches, call_generate)
|
| 103 |
+
|
| 104 |
+
sol_states = []
|
| 105 |
+
for plan in plan_forks:
|
| 106 |
+
forks = execute_plan(plan, num_branches, call_generate)
|
| 107 |
+
sol_states.extend(forks)
|
| 108 |
+
|
| 109 |
+
ref_states = []
|
| 110 |
+
for sol in sol_states:
|
| 111 |
+
forks = reflect_solution(sol, num_branches, call_generate)
|
| 112 |
+
ref_states.extend(forks)
|
| 113 |
+
|
| 114 |
+
solutions = []
|
| 115 |
+
for sol in ref_states:
|
| 116 |
+
ans = get_final_answer(sol, num_branches, call_generate)
|
| 117 |
+
solutions.append(ans)
|
| 118 |
+
|
| 119 |
+
return solutions
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def main(args):
|
| 123 |
+
lines = read_jsonl(args.data_path)
|
| 124 |
+
|
| 125 |
+
# Construct prompts
|
| 126 |
+
num_branches = 2
|
| 127 |
+
questions = []
|
| 128 |
+
labels = []
|
| 129 |
+
for i in range(len(lines[: args.num_questions])):
|
| 130 |
+
questions.append(lines[i]["question"])
|
| 131 |
+
labels.append(get_answer_value(lines[i]["answer"]))
|
| 132 |
+
assert all(l != INVALID for l in labels)
|
| 133 |
+
arguments = [{"question": q, "num_branches": num_branches} for q in questions]
|
| 134 |
+
|
| 135 |
+
# Select backend
|
| 136 |
+
call_generate = get_call_generate(args)
|
| 137 |
+
|
| 138 |
+
# Run requests
|
| 139 |
+
states = [None] * len(questions)
|
| 140 |
+
|
| 141 |
+
tic = time.perf_counter()
|
| 142 |
+
if args.backend != "lmql":
|
| 143 |
+
|
| 144 |
+
def get_one_answer(i):
|
| 145 |
+
states[i] = tree_search(**arguments[i], call_generate=call_generate)
|
| 146 |
+
|
| 147 |
+
if args.parallel == 1:
|
| 148 |
+
for i in tqdm(range(len(questions))):
|
| 149 |
+
get_one_answer(i)
|
| 150 |
+
else:
|
| 151 |
+
with ThreadPoolExecutor(args.parallel) as executor:
|
| 152 |
+
list(
|
| 153 |
+
tqdm(
|
| 154 |
+
executor.map(get_one_answer, list(range(len(questions)))),
|
| 155 |
+
total=len(questions),
|
| 156 |
+
)
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
else:
|
| 160 |
+
import asyncio
|
| 161 |
+
|
| 162 |
+
from lmql_funcs import tree_search_async
|
| 163 |
+
|
| 164 |
+
async def get_one_answer_async(i):
|
| 165 |
+
states[i] = await tree_search_async(
|
| 166 |
+
**arguments[i], call_generate=call_generate
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
batches = [
|
| 170 |
+
[] for _ in range((len(questions) + args.parallel - 1) // args.parallel)
|
| 171 |
+
]
|
| 172 |
+
for i in range(len(questions)):
|
| 173 |
+
batches[i // args.parallel].append(i)
|
| 174 |
+
|
| 175 |
+
loop = asyncio.get_event_loop()
|
| 176 |
+
for bt in tqdm(batches):
|
| 177 |
+
tasks = [get_one_answer_async(k) for k in bt]
|
| 178 |
+
loop.run_until_complete(asyncio.gather(*tasks))
|
| 179 |
+
|
| 180 |
+
latency = time.perf_counter() - tic
|
| 181 |
+
|
| 182 |
+
answers_text = []
|
| 183 |
+
for s in states:
|
| 184 |
+
answers_text.append([x for xs in s for x in xs])
|
| 185 |
+
|
| 186 |
+
preds = []
|
| 187 |
+
for i in range(len(states)):
|
| 188 |
+
answers = [get_answer_value(v) for v in answers_text[i]]
|
| 189 |
+
preds.append(most_frequent_number(answers))
|
| 190 |
+
|
| 191 |
+
# Compute accuracy
|
| 192 |
+
acc = np.mean(np.array(preds) == np.array(labels))
|
| 193 |
+
invalid = np.mean(np.array(preds) == INVALID)
|
| 194 |
+
print(f"Latency: {latency:.3f}")
|
| 195 |
+
print(f"Invalid: {invalid:.3f}")
|
| 196 |
+
print(f"Accuracy: {acc:.3f}")
|
| 197 |
+
|
| 198 |
+
# Write results
|
| 199 |
+
dump_state_text(f"tmp_output_{args.backend}.txt", answers_text)
|
| 200 |
+
|
| 201 |
+
with open(args.result_file, "a") as fout:
|
| 202 |
+
value = {
|
| 203 |
+
"task": "tree_of_thought_gsm8k",
|
| 204 |
+
"backend": args.backend,
|
| 205 |
+
"num_gpus": 1,
|
| 206 |
+
"latency": round(latency, 3),
|
| 207 |
+
"accuracy": round(acc, 3),
|
| 208 |
+
"num_requests": args.num_questions,
|
| 209 |
+
"other": {
|
| 210 |
+
"num_questions": args.num_questions,
|
| 211 |
+
"parallel": args.parallel,
|
| 212 |
+
},
|
| 213 |
+
}
|
| 214 |
+
fout.write(json.dumps(value) + "\n")
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
if __name__ == "__main__":
|
| 218 |
+
parser = argparse.ArgumentParser()
|
| 219 |
+
parser.add_argument("--data-path", type=str, default="test.jsonl")
|
| 220 |
+
parser.add_argument("--num-questions", type=int, default=200)
|
| 221 |
+
args = add_common_other_args_and_parse(parser)
|
| 222 |
+
main(args)
|
sglang/benchmark/tree_of_thought_deep/bench_sglang.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import ast
|
| 3 |
+
import json
|
| 4 |
+
import re
|
| 5 |
+
import time
|
| 6 |
+
from collections import Counter
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
import sglang as sgl
|
| 11 |
+
from sglang.test.test_utils import (
|
| 12 |
+
add_common_sglang_args_and_parse,
|
| 13 |
+
select_sglang_backend,
|
| 14 |
+
)
|
| 15 |
+
from sglang.utils import dump_state_text, read_jsonl
|
| 16 |
+
|
| 17 |
+
INVALID = -9999999
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_answer_value(answer_str):
|
| 21 |
+
answer_str = answer_str.replace(",", "")
|
| 22 |
+
numbers = re.findall(r"\d+", answer_str)
|
| 23 |
+
if len(numbers) < 1:
|
| 24 |
+
return INVALID
|
| 25 |
+
try:
|
| 26 |
+
return ast.literal_eval(numbers[-1])
|
| 27 |
+
except SyntaxError:
|
| 28 |
+
return INVALID
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def most_frequent_number(numbers):
|
| 32 |
+
if not numbers:
|
| 33 |
+
return None
|
| 34 |
+
|
| 35 |
+
frequency = Counter(numbers)
|
| 36 |
+
most_frequent = max(frequency, key=frequency.get)
|
| 37 |
+
return most_frequent
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# Use a low temp to make the results more deterministic and the comparison more fair.
|
| 41 |
+
temp = 0.001
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def propose_plan(s, question, num_branches):
|
| 45 |
+
s += sgl.user(
|
| 46 |
+
"""Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """
|
| 47 |
+
+ question
|
| 48 |
+
)
|
| 49 |
+
forks = s.fork(num_branches)
|
| 50 |
+
forks += sgl.assistant(sgl.gen("plan", max_tokens=256, temperature=temp))
|
| 51 |
+
return forks
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def execute_plan(s, num_branches):
|
| 55 |
+
s += sgl.user(
|
| 56 |
+
"""The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short."""
|
| 57 |
+
)
|
| 58 |
+
forks = s.fork(num_branches)
|
| 59 |
+
forks += sgl.assistant(sgl.gen("answer", max_tokens=256, temperature=temp))
|
| 60 |
+
return forks
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def reflect_solution(s, num_branches):
|
| 64 |
+
s += sgl.user(
|
| 65 |
+
"""Okay. Now, evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness."""
|
| 66 |
+
)
|
| 67 |
+
forks = s.fork(num_branches)
|
| 68 |
+
forks += sgl.assistant(sgl.gen("score", max_tokens=256, temperature=temp))
|
| 69 |
+
return forks
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def get_final_answer(s, num_branches):
|
| 73 |
+
s += sgl.user(
|
| 74 |
+
"""Based on your reflection, do you change your mind? Now, give me the final answer after careful consideration."""
|
| 75 |
+
)
|
| 76 |
+
forks = s.fork(num_branches)
|
| 77 |
+
forks += sgl.assistant(sgl.gen("final_answer", max_tokens=256, temperature=temp))
|
| 78 |
+
return forks
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@sgl.function
|
| 82 |
+
def tree_search(s, question, num_branches):
|
| 83 |
+
plan_forks = propose_plan(s, question, num_branches)
|
| 84 |
+
|
| 85 |
+
sol_states = []
|
| 86 |
+
for plan in plan_forks:
|
| 87 |
+
forks = execute_plan(plan, num_branches)
|
| 88 |
+
sol_states.extend(forks)
|
| 89 |
+
|
| 90 |
+
ref_states = []
|
| 91 |
+
for sol in sol_states:
|
| 92 |
+
forks = reflect_solution(sol, num_branches)
|
| 93 |
+
ref_states.extend(forks)
|
| 94 |
+
|
| 95 |
+
solutions = []
|
| 96 |
+
for sol in ref_states:
|
| 97 |
+
forks = get_final_answer(sol, num_branches)
|
| 98 |
+
solutions.append(forks)
|
| 99 |
+
solutions = [[s.text() for s in forks] for forks in solutions]
|
| 100 |
+
|
| 101 |
+
return solutions
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def main(args):
|
| 105 |
+
lines = read_jsonl(args.data_path)
|
| 106 |
+
lines = list(lines)
|
| 107 |
+
|
| 108 |
+
# Construct prompts
|
| 109 |
+
num_branches = 2
|
| 110 |
+
questions = []
|
| 111 |
+
labels = []
|
| 112 |
+
for i in range(len(lines[: args.num_questions])):
|
| 113 |
+
questions.append(lines[i]["question"])
|
| 114 |
+
labels.append(get_answer_value(lines[i]["answer"]))
|
| 115 |
+
assert all(l != INVALID for l in labels)
|
| 116 |
+
arguments = [{"question": q, "num_branches": num_branches} for q in questions]
|
| 117 |
+
|
| 118 |
+
# Select backend
|
| 119 |
+
backend = select_sglang_backend(args)
|
| 120 |
+
|
| 121 |
+
# Run requests
|
| 122 |
+
tic = time.perf_counter()
|
| 123 |
+
states = tree_search.run_batch(
|
| 124 |
+
arguments,
|
| 125 |
+
temperature=0,
|
| 126 |
+
backend=backend,
|
| 127 |
+
num_threads=args.parallel,
|
| 128 |
+
progress_bar=True,
|
| 129 |
+
)
|
| 130 |
+
latency = time.perf_counter() - tic
|
| 131 |
+
answers_text = []
|
| 132 |
+
for s in states:
|
| 133 |
+
answers_text.append([x for xs in s.ret_value for x in xs])
|
| 134 |
+
|
| 135 |
+
preds = []
|
| 136 |
+
for i in range(len(states)):
|
| 137 |
+
answers = [get_answer_value(v) for v in answers_text[i]]
|
| 138 |
+
preds.append(most_frequent_number(answers))
|
| 139 |
+
|
| 140 |
+
# Compute accuracy
|
| 141 |
+
acc = np.mean(np.array(preds) == np.array(labels))
|
| 142 |
+
invalid = np.mean(np.array(preds) == INVALID)
|
| 143 |
+
print(f"Latency: {latency:.3f}")
|
| 144 |
+
print(f"Invalid: {invalid:.3f}")
|
| 145 |
+
print(f"Accuracy: {acc:.3f}")
|
| 146 |
+
|
| 147 |
+
# Write results
|
| 148 |
+
dump_state_text(f"tmp_output_{args.backend}.txt", answers_text)
|
| 149 |
+
|
| 150 |
+
with open(args.result_file, "a") as fout:
|
| 151 |
+
value = {
|
| 152 |
+
"task": "tree_of_thought_gsm8k",
|
| 153 |
+
"backend": args.backend,
|
| 154 |
+
"num_gpus": 1,
|
| 155 |
+
"latency": round(latency, 3),
|
| 156 |
+
"accuracy": round(acc, 3),
|
| 157 |
+
"num_requests": args.num_questions,
|
| 158 |
+
"other": {
|
| 159 |
+
"num_questions": args.num_questions,
|
| 160 |
+
"parallel": args.parallel,
|
| 161 |
+
},
|
| 162 |
+
}
|
| 163 |
+
fout.write(json.dumps(value) + "\n")
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
if __name__ == "__main__":
|
| 167 |
+
parser = argparse.ArgumentParser()
|
| 168 |
+
parser.add_argument("--data-path", type=str, default="test.jsonl")
|
| 169 |
+
parser.add_argument("--num-questions", type=int, default=200)
|
| 170 |
+
args = add_common_sglang_args_and_parse(parser)
|
| 171 |
+
main(args)
|
sglang/docker/configs/.zshrc
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
export ZSH="/root/.oh-my-zsh"
|
| 2 |
+
|
| 3 |
+
# Theme
|
| 4 |
+
ZSH_THEME="robbyrussell"
|
| 5 |
+
|
| 6 |
+
# Plugins
|
| 7 |
+
plugins=(
|
| 8 |
+
git
|
| 9 |
+
z
|
| 10 |
+
zsh-autosuggestions
|
| 11 |
+
zsh-syntax-highlighting
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
source $ZSH/oh-my-zsh.sh
|
| 15 |
+
|
| 16 |
+
# Aliases
|
| 17 |
+
alias ll='ls -alF'
|
| 18 |
+
alias la='ls -A'
|
| 19 |
+
alias l='ls -CF'
|
| 20 |
+
alias vi='vim'
|
| 21 |
+
|
| 22 |
+
# Enhanced history
|
| 23 |
+
HISTSIZE=10000
|
| 24 |
+
SAVEHIST=10000
|
| 25 |
+
setopt HIST_IGNORE_ALL_DUPS
|
| 26 |
+
setopt HIST_FIND_NO_DUPS
|
| 27 |
+
setopt INC_APPEND_HISTORY
|
sglang/docker/configs/opt/.gitconfig
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[core]
|
| 2 |
+
editor = vim
|
| 3 |
+
whitespace = fix,-indent-with-non-tab,trailing-space,cr-at-eol
|
| 4 |
+
pager = diff-so-fancy | less --tabs=4 -RFX
|
| 5 |
+
|
| 6 |
+
[color]
|
| 7 |
+
ui = true
|
| 8 |
+
|
| 9 |
+
[color "diff-highlight"]
|
| 10 |
+
oldNormal = red bold
|
| 11 |
+
oldHighlight = red bold 52
|
| 12 |
+
newNormal = green bold
|
| 13 |
+
newHighlight = green bold 22
|
| 14 |
+
|
| 15 |
+
[color "diff"]
|
| 16 |
+
meta = 11
|
| 17 |
+
frag = magenta bold
|
| 18 |
+
commit = yellow bold
|
| 19 |
+
old = red bold
|
| 20 |
+
new = green bold
|
| 21 |
+
whitespace = red reverse
|
| 22 |
+
|
| 23 |
+
[alias]
|
| 24 |
+
lg = log --color --graph --pretty=format:'%Cred%h%Creset - %s %Cgreen(%cr) %C(bold blue)<%an>%Creset%C(auto)%d%Creset' --abbrev-commit --
|
| 25 |
+
|
| 26 |
+
[http]
|
| 27 |
+
sslVerify = false
|
| 28 |
+
|
| 29 |
+
[pull]
|
| 30 |
+
rebase = true
|
sglang/docker/configs/opt/.tmux.conf
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Pane border styling
|
| 2 |
+
set -g pane-border-style fg='#742727',bg=black
|
| 3 |
+
set -g pane-active-border-style fg=red,bg=black
|
| 4 |
+
|
| 5 |
+
# Status bar styling
|
| 6 |
+
set -g status-style bg='#0C8A92',fg=black
|
| 7 |
+
|
| 8 |
+
# Change prefix key to backtick
|
| 9 |
+
set-option -g prefix `
|
| 10 |
+
unbind C-b
|
| 11 |
+
bind-key ` send-prefix
|
| 12 |
+
|
| 13 |
+
# Split panes using - and = with current path
|
| 14 |
+
unbind '"'
|
| 15 |
+
bind - splitw -v -c '#{pane_current_path}'
|
| 16 |
+
unbind '%'
|
| 17 |
+
bind = splitw -h -c '#{pane_current_path}'
|
| 18 |
+
|
| 19 |
+
# Vi mode settings
|
| 20 |
+
bind-key -T copy-mode-vi Y send-keys -X copy-pipe 'yank > #{pane_tty}'
|
| 21 |
+
set-window-option -g mode-keys vi
|
| 22 |
+
|
| 23 |
+
# Other settings
|
| 24 |
+
set-option -g escape-time 0
|
| 25 |
+
set-option -g base-index 1
|
| 26 |
+
set-window-option -g mouse on
|
| 27 |
+
set -g history-limit 100000
|
sglang/docker/configs/opt/.vimrc
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
function! Yank(text) abort
|
| 2 |
+
let escape = system('yank', a:text)
|
| 3 |
+
if v:shell_error
|
| 4 |
+
echoerr escape
|
| 5 |
+
else
|
| 6 |
+
call writefile([escape], '/dev/tty', 'b')
|
| 7 |
+
endif
|
| 8 |
+
endfunction
|
| 9 |
+
|
| 10 |
+
noremap <silent> <Leader>y y:<C-U>call Yank(@0)<CR>
|
| 11 |
+
|
| 12 |
+
" automatically run yank(1) whenever yanking in Vim
|
| 13 |
+
function! CopyYank() abort
|
| 14 |
+
call Yank(join(v:event.regcontents, "\n"))
|
| 15 |
+
endfunction
|
| 16 |
+
|
| 17 |
+
autocmd TextYankPost * call CopyYank()
|
| 18 |
+
|
| 19 |
+
" Basic settings
|
| 20 |
+
set number
|
| 21 |
+
syntax on
|
| 22 |
+
set mouse=a
|
| 23 |
+
filetype indent on
|
| 24 |
+
|
| 25 |
+
" Indentation
|
| 26 |
+
set autoindent nosmartindent
|
| 27 |
+
set smarttab
|
| 28 |
+
set expandtab
|
| 29 |
+
set shiftwidth=4
|
| 30 |
+
set softtabstop=4
|
| 31 |
+
|
| 32 |
+
" Visual guides
|
| 33 |
+
set colorcolumn=120
|
| 34 |
+
highlight ColorColumn ctermbg=5
|
| 35 |
+
|
| 36 |
+
" Status line
|
| 37 |
+
set laststatus=2
|
| 38 |
+
set statusline=%<%f\ %h%m%r%=%{\"[\".(&fenc==\"\"?&enc:&fenc).((exists(\"+bomb\")\ &&\ &bomb)?\",B\":\"\").\"]\ \"}%k\ %-14.(%l,%c%V%)\ %P
|
| 39 |
+
|
| 40 |
+
" Backspace behavior
|
| 41 |
+
set backspace=2
|
| 42 |
+
|
| 43 |
+
" Encoding
|
| 44 |
+
set encoding=utf-8
|
| 45 |
+
set fileencoding=utf-8
|
sglang/docker/configs/yank
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
put() {
|
| 3 |
+
esc=$1
|
| 4 |
+
test -n "$TMUX" -o -z "${TERM##screen*}" && esc="\033Ptmux;\033$esc\033\\"
|
| 5 |
+
printf "$esc"
|
| 6 |
+
}
|
| 7 |
+
put "\033]52;c;!\a"
|
| 8 |
+
buf=$( cat "$@" )
|
| 9 |
+
len=$( printf %s "$buf" | wc -c ) max=74994
|
| 10 |
+
test $len -gt $max && echo "$0: input is $(( len - max )) bytes too long" >&2
|
| 11 |
+
put "\033]52;c;$( printf %s "$buf" | head -c $max | base64 | tr -d '\r\n' )\a"
|
| 12 |
+
test -n "$TMUX" && tmux set-buffer "$buf" ||:
|
sglang/python/sglang.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: sglang
|
| 3 |
+
Version: 0.5.9
|
| 4 |
+
Summary: SGLang is a fast serving framework for large language models and vision language models.
|
| 5 |
+
Project-URL: Homepage, https://github.com/sgl-project/sglang
|
| 6 |
+
Project-URL: Bug Tracker, https://github.com/sgl-project/sglang/issues
|
| 7 |
+
Classifier: Programming Language :: Python :: 3
|
| 8 |
+
Classifier: License :: OSI Approved :: Apache Software License
|
| 9 |
+
Requires-Python: >=3.10
|
| 10 |
+
Description-Content-Type: text/markdown
|
| 11 |
+
Requires-Dist: IPython
|
| 12 |
+
Requires-Dist: aiohttp
|
| 13 |
+
Requires-Dist: apache-tvm-ffi<0.2,>=0.1.5
|
| 14 |
+
Requires-Dist: anthropic>=0.20.0
|
| 15 |
+
Requires-Dist: blobfile==3.0.0
|
| 16 |
+
Requires-Dist: build
|
| 17 |
+
Requires-Dist: compressed-tensors
|
| 18 |
+
Requires-Dist: cuda-python==12.9
|
| 19 |
+
Requires-Dist: decord2
|
| 20 |
+
Requires-Dist: datasets
|
| 21 |
+
Requires-Dist: einops
|
| 22 |
+
Requires-Dist: fastapi
|
| 23 |
+
Requires-Dist: flashinfer_python==0.6.4
|
| 24 |
+
Requires-Dist: flashinfer_cubin==0.6.4
|
| 25 |
+
Requires-Dist: gguf
|
| 26 |
+
Requires-Dist: hf_transfer
|
| 27 |
+
Requires-Dist: huggingface_hub
|
| 28 |
+
Requires-Dist: interegular
|
| 29 |
+
Requires-Dist: llguidance<0.8.0,>=0.7.11
|
| 30 |
+
Requires-Dist: modelscope
|
| 31 |
+
Requires-Dist: msgspec
|
| 32 |
+
Requires-Dist: ninja
|
| 33 |
+
Requires-Dist: numpy
|
| 34 |
+
Requires-Dist: nvidia-cutlass-dsl>=4.3.4
|
| 35 |
+
Requires-Dist: nvidia-ml-py
|
| 36 |
+
Requires-Dist: openai-harmony==0.0.4
|
| 37 |
+
Requires-Dist: openai==2.6.1
|
| 38 |
+
Requires-Dist: orjson
|
| 39 |
+
Requires-Dist: outlines==0.1.11
|
| 40 |
+
Requires-Dist: packaging
|
| 41 |
+
Requires-Dist: partial_json_parser
|
| 42 |
+
Requires-Dist: pillow
|
| 43 |
+
Requires-Dist: prometheus-client>=0.20.0
|
| 44 |
+
Requires-Dist: psutil
|
| 45 |
+
Requires-Dist: py-spy
|
| 46 |
+
Requires-Dist: pybase64
|
| 47 |
+
Requires-Dist: pydantic
|
| 48 |
+
Requires-Dist: python-multipart
|
| 49 |
+
Requires-Dist: pyzmq>=25.1.2
|
| 50 |
+
Requires-Dist: quack-kernels==0.2.4
|
| 51 |
+
Requires-Dist: requests
|
| 52 |
+
Requires-Dist: scipy
|
| 53 |
+
Requires-Dist: sentencepiece
|
| 54 |
+
Requires-Dist: setproctitle
|
| 55 |
+
Requires-Dist: sgl-fa4==4.0.3
|
| 56 |
+
Requires-Dist: sgl-kernel==0.3.21
|
| 57 |
+
Requires-Dist: soundfile==0.13.1
|
| 58 |
+
Requires-Dist: tiktoken
|
| 59 |
+
Requires-Dist: timm==1.0.16
|
| 60 |
+
Requires-Dist: torch_memory_saver==0.0.9
|
| 61 |
+
Requires-Dist: torch==2.9.1
|
| 62 |
+
Requires-Dist: torchao==0.9.0
|
| 63 |
+
Requires-Dist: torchaudio==2.9.1
|
| 64 |
+
Requires-Dist: torchcodec==0.8.0; sys_platform != "linux" or (sys_platform == "linux" and platform_machine != "aarch64" and platform_machine != "arm64" and platform_machine != "armv7l")
|
| 65 |
+
Requires-Dist: torchvision
|
| 66 |
+
Requires-Dist: tqdm
|
| 67 |
+
Requires-Dist: transformers==4.57.1
|
| 68 |
+
Requires-Dist: uvicorn
|
| 69 |
+
Requires-Dist: uvloop
|
| 70 |
+
Requires-Dist: watchfiles
|
| 71 |
+
Requires-Dist: xgrammar==0.1.27
|
| 72 |
+
Requires-Dist: smg-grpc-proto>=0.4.1
|
| 73 |
+
Requires-Dist: grpcio>=1.78.0
|
| 74 |
+
Requires-Dist: grpcio-reflection>=1.78.0
|
| 75 |
+
Requires-Dist: grpcio-health-checking>=1.78.0
|
| 76 |
+
Provides-Extra: checkpoint-engine
|
| 77 |
+
Requires-Dist: checkpoint-engine==0.1.2; extra == "checkpoint-engine"
|
| 78 |
+
Provides-Extra: diffusion
|
| 79 |
+
Requires-Dist: PyYAML==6.0.1; extra == "diffusion"
|
| 80 |
+
Requires-Dist: cloudpickle==3.1.2; extra == "diffusion"
|
| 81 |
+
Requires-Dist: diffusers==0.36.0; extra == "diffusion"
|
| 82 |
+
Requires-Dist: imageio==2.36.0; extra == "diffusion"
|
| 83 |
+
Requires-Dist: imageio-ffmpeg==0.5.1; extra == "diffusion"
|
| 84 |
+
Requires-Dist: moviepy>=2.0.0; extra == "diffusion"
|
| 85 |
+
Requires-Dist: opencv-python-headless==4.10.0.84; extra == "diffusion"
|
| 86 |
+
Requires-Dist: remote-pdb==2.1.0; extra == "diffusion"
|
| 87 |
+
Requires-Dist: st_attn==0.0.7; (platform_machine != "aarch64" and platform_machine != "arm64") and extra == "diffusion"
|
| 88 |
+
Requires-Dist: vsa==0.0.4; (platform_machine != "aarch64" and platform_machine != "arm64") and extra == "diffusion"
|
| 89 |
+
Requires-Dist: runai_model_streamer>=0.15.5; extra == "diffusion"
|
| 90 |
+
Requires-Dist: cache-dit==1.2.3; extra == "diffusion"
|
| 91 |
+
Requires-Dist: addict==2.4.0; extra == "diffusion"
|
| 92 |
+
Requires-Dist: av==16.1.0; extra == "diffusion"
|
| 93 |
+
Requires-Dist: scikit-image==0.25.2; extra == "diffusion"
|
| 94 |
+
Requires-Dist: trimesh>=4.0.0; extra == "diffusion"
|
| 95 |
+
Requires-Dist: xatlas; extra == "diffusion"
|
| 96 |
+
Provides-Extra: ray
|
| 97 |
+
Requires-Dist: ray[default]>=2.54.0; extra == "ray"
|
| 98 |
+
Provides-Extra: tracing
|
| 99 |
+
Requires-Dist: opentelemetry-api; extra == "tracing"
|
| 100 |
+
Requires-Dist: opentelemetry-exporter-otlp; extra == "tracing"
|
| 101 |
+
Requires-Dist: opentelemetry-exporter-otlp-proto-grpc; extra == "tracing"
|
| 102 |
+
Requires-Dist: opentelemetry-sdk; extra == "tracing"
|
| 103 |
+
Provides-Extra: test
|
| 104 |
+
Requires-Dist: accelerate; extra == "test"
|
| 105 |
+
Requires-Dist: bitsandbytes; extra == "test"
|
| 106 |
+
Requires-Dist: expecttest; extra == "test"
|
| 107 |
+
Requires-Dist: jsonlines; extra == "test"
|
| 108 |
+
Requires-Dist: lm-eval[api]>=0.4.9.2; extra == "test"
|
| 109 |
+
Requires-Dist: matplotlib; extra == "test"
|
| 110 |
+
Requires-Dist: pandas; extra == "test"
|
| 111 |
+
Requires-Dist: parameterized; extra == "test"
|
| 112 |
+
Requires-Dist: peft; extra == "test"
|
| 113 |
+
Requires-Dist: pytest; extra == "test"
|
| 114 |
+
Requires-Dist: sentence_transformers; extra == "test"
|
| 115 |
+
Requires-Dist: tabulate; extra == "test"
|
| 116 |
+
Provides-Extra: dev
|
| 117 |
+
Requires-Dist: sglang[test]; extra == "dev"
|
| 118 |
+
Provides-Extra: all
|
| 119 |
+
Requires-Dist: sglang[diffusion]; extra == "all"
|
| 120 |
+
Requires-Dist: sglang[tracing]; extra == "all"
|
sglang/python/sglang.egg-info/SOURCES.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
sglang/python/sglang.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
sglang/python/sglang.egg-info/entry_points.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[console_scripts]
|
| 2 |
+
sglang = sglang.cli.main:main
|
sglang/python/sglang.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
IPython
|
| 2 |
+
aiohttp
|
| 3 |
+
apache-tvm-ffi<0.2,>=0.1.5
|
| 4 |
+
anthropic>=0.20.0
|
| 5 |
+
blobfile==3.0.0
|
| 6 |
+
build
|
| 7 |
+
compressed-tensors
|
| 8 |
+
cuda-python==12.9
|
| 9 |
+
decord2
|
| 10 |
+
datasets
|
| 11 |
+
einops
|
| 12 |
+
fastapi
|
| 13 |
+
flashinfer_python==0.6.4
|
| 14 |
+
flashinfer_cubin==0.6.4
|
| 15 |
+
gguf
|
| 16 |
+
hf_transfer
|
| 17 |
+
huggingface_hub
|
| 18 |
+
interegular
|
| 19 |
+
llguidance<0.8.0,>=0.7.11
|
| 20 |
+
modelscope
|
| 21 |
+
msgspec
|
| 22 |
+
ninja
|
| 23 |
+
numpy
|
| 24 |
+
nvidia-cutlass-dsl>=4.3.4
|
| 25 |
+
nvidia-ml-py
|
| 26 |
+
openai-harmony==0.0.4
|
| 27 |
+
openai==2.6.1
|
| 28 |
+
orjson
|
| 29 |
+
outlines==0.1.11
|
| 30 |
+
packaging
|
| 31 |
+
partial_json_parser
|
| 32 |
+
pillow
|
| 33 |
+
prometheus-client>=0.20.0
|
| 34 |
+
psutil
|
| 35 |
+
py-spy
|
| 36 |
+
pybase64
|
| 37 |
+
pydantic
|
| 38 |
+
python-multipart
|
| 39 |
+
pyzmq>=25.1.2
|
| 40 |
+
quack-kernels==0.2.4
|
| 41 |
+
requests
|
| 42 |
+
scipy
|
| 43 |
+
sentencepiece
|
| 44 |
+
setproctitle
|
| 45 |
+
sgl-fa4==4.0.3
|
| 46 |
+
sgl-kernel==0.3.21
|
| 47 |
+
soundfile==0.13.1
|
| 48 |
+
tiktoken
|
| 49 |
+
timm==1.0.16
|
| 50 |
+
torch_memory_saver==0.0.9
|
| 51 |
+
torch==2.9.1
|
| 52 |
+
torchao==0.9.0
|
| 53 |
+
torchaudio==2.9.1
|
| 54 |
+
torchvision
|
| 55 |
+
tqdm
|
| 56 |
+
transformers==4.57.1
|
| 57 |
+
uvicorn
|
| 58 |
+
uvloop
|
| 59 |
+
watchfiles
|
| 60 |
+
xgrammar==0.1.27
|
| 61 |
+
smg-grpc-proto>=0.4.1
|
| 62 |
+
grpcio>=1.78.0
|
| 63 |
+
grpcio-reflection>=1.78.0
|
| 64 |
+
grpcio-health-checking>=1.78.0
|
| 65 |
+
|
| 66 |
+
[:sys_platform != "linux" or (sys_platform == "linux" and platform_machine != "aarch64" and platform_machine != "arm64" and platform_machine != "armv7l")]
|
| 67 |
+
torchcodec==0.8.0
|
| 68 |
+
|
| 69 |
+
[all]
|
| 70 |
+
sglang[diffusion]
|
| 71 |
+
sglang[tracing]
|
| 72 |
+
|
| 73 |
+
[checkpoint-engine]
|
| 74 |
+
checkpoint-engine==0.1.2
|
| 75 |
+
|
| 76 |
+
[dev]
|
| 77 |
+
sglang[test]
|
| 78 |
+
|
| 79 |
+
[diffusion]
|
| 80 |
+
PyYAML==6.0.1
|
| 81 |
+
cloudpickle==3.1.2
|
| 82 |
+
diffusers==0.36.0
|
| 83 |
+
imageio==2.36.0
|
| 84 |
+
imageio-ffmpeg==0.5.1
|
| 85 |
+
moviepy>=2.0.0
|
| 86 |
+
opencv-python-headless==4.10.0.84
|
| 87 |
+
remote-pdb==2.1.0
|
| 88 |
+
runai_model_streamer>=0.15.5
|
| 89 |
+
cache-dit==1.2.3
|
| 90 |
+
addict==2.4.0
|
| 91 |
+
av==16.1.0
|
| 92 |
+
scikit-image==0.25.2
|
| 93 |
+
trimesh>=4.0.0
|
| 94 |
+
xatlas
|
| 95 |
+
|
| 96 |
+
[diffusion:platform_machine != "aarch64" and platform_machine != "arm64"]
|
| 97 |
+
st_attn==0.0.7
|
| 98 |
+
vsa==0.0.4
|
| 99 |
+
|
| 100 |
+
[ray]
|
| 101 |
+
ray[default]>=2.54.0
|
| 102 |
+
|
| 103 |
+
[test]
|
| 104 |
+
accelerate
|
| 105 |
+
bitsandbytes
|
| 106 |
+
expecttest
|
| 107 |
+
jsonlines
|
| 108 |
+
lm-eval[api]>=0.4.9.2
|
| 109 |
+
matplotlib
|
| 110 |
+
pandas
|
| 111 |
+
parameterized
|
| 112 |
+
peft
|
| 113 |
+
pytest
|
| 114 |
+
sentence_transformers
|
| 115 |
+
tabulate
|
| 116 |
+
|
| 117 |
+
[tracing]
|
| 118 |
+
opentelemetry-api
|
| 119 |
+
opentelemetry-exporter-otlp
|
| 120 |
+
opentelemetry-exporter-otlp-proto-grpc
|
| 121 |
+
opentelemetry-sdk
|
sglang/python/sglang.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
sglang
|
sglang/python/sglang/README.md
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code Structure
|
| 2 |
+
|
| 3 |
+
- `eval`: The evaluation utilities.
|
| 4 |
+
- `lang`: The frontend language.
|
| 5 |
+
- `multimodal_gen`: Inference framework for accelerated image/video generation.
|
| 6 |
+
- `srt`: The backend engine for running local models. (SRT = SGLang Runtime).
|
| 7 |
+
- `test`: The test utilities.
|
| 8 |
+
- `api.py`: The public APIs.
|
| 9 |
+
- `bench_offline_throughput.py`: Benchmark the performance in the offline mode.
|
| 10 |
+
- `bench_one_batch.py`: Benchmark the latency of running a single static batch without a server.
|
| 11 |
+
- `bench_one_batch_server.py`: Benchmark the latency of running a single batch with a server.
|
| 12 |
+
- `bench_serving.py`: Benchmark online serving with dynamic requests.
|
| 13 |
+
- `check_env.py`: Check the environment variables and dependencies.
|
| 14 |
+
- `global_config.py`: The global configs and constants.
|
| 15 |
+
- `launch_server.py`: The entry point for launching a local server.
|
| 16 |
+
- `profiler.py`: The profiling entry point to send profile requests.
|
| 17 |
+
- `utils.py`: Common utilities.
|
| 18 |
+
- `version.py`: Version info.
|
sglang/python/sglang/__init__.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SGLang public APIs
|
| 2 |
+
|
| 3 |
+
# Frontend Language APIs
|
| 4 |
+
from sglang.global_config import global_config
|
| 5 |
+
from sglang.lang.api import (
|
| 6 |
+
Engine,
|
| 7 |
+
Runtime,
|
| 8 |
+
assistant,
|
| 9 |
+
assistant_begin,
|
| 10 |
+
assistant_end,
|
| 11 |
+
flush_cache,
|
| 12 |
+
function,
|
| 13 |
+
gen,
|
| 14 |
+
gen_int,
|
| 15 |
+
gen_string,
|
| 16 |
+
get_server_info,
|
| 17 |
+
image,
|
| 18 |
+
select,
|
| 19 |
+
separate_reasoning,
|
| 20 |
+
set_default_backend,
|
| 21 |
+
system,
|
| 22 |
+
system_begin,
|
| 23 |
+
system_end,
|
| 24 |
+
user,
|
| 25 |
+
user_begin,
|
| 26 |
+
user_end,
|
| 27 |
+
video,
|
| 28 |
+
)
|
| 29 |
+
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
|
| 30 |
+
from sglang.lang.choices import (
|
| 31 |
+
greedy_token_selection,
|
| 32 |
+
token_length_normalized,
|
| 33 |
+
unconditional_likelihood_normalized,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
# Lazy import some libraries
|
| 37 |
+
from sglang.utils import LazyImport
|
| 38 |
+
from sglang.version import __version__
|
| 39 |
+
|
| 40 |
+
Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
|
| 41 |
+
LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
|
| 42 |
+
OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
|
| 43 |
+
VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")
|
| 44 |
+
|
| 45 |
+
# Runtime Engine APIs
|
| 46 |
+
ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs")
|
| 47 |
+
Engine = LazyImport("sglang.srt.entrypoints.engine", "Engine")
|
| 48 |
+
|
| 49 |
+
__all__ = [
|
| 50 |
+
"Engine",
|
| 51 |
+
"Runtime",
|
| 52 |
+
"assistant",
|
| 53 |
+
"assistant_begin",
|
| 54 |
+
"assistant_end",
|
| 55 |
+
"flush_cache",
|
| 56 |
+
"function",
|
| 57 |
+
"gen",
|
| 58 |
+
"gen_int",
|
| 59 |
+
"gen_string",
|
| 60 |
+
"get_server_info",
|
| 61 |
+
"image",
|
| 62 |
+
"select",
|
| 63 |
+
"separate_reasoning",
|
| 64 |
+
"set_default_backend",
|
| 65 |
+
"system",
|
| 66 |
+
"system_begin",
|
| 67 |
+
"system_end",
|
| 68 |
+
"user",
|
| 69 |
+
"user_begin",
|
| 70 |
+
"user_end",
|
| 71 |
+
"video",
|
| 72 |
+
"RuntimeEndpoint",
|
| 73 |
+
"greedy_token_selection",
|
| 74 |
+
"token_length_normalized",
|
| 75 |
+
"unconditional_likelihood_normalized",
|
| 76 |
+
"ServerArgs",
|
| 77 |
+
"Anthropic",
|
| 78 |
+
"LiteLLM",
|
| 79 |
+
"OpenAI",
|
| 80 |
+
"VertexAI",
|
| 81 |
+
"global_config",
|
| 82 |
+
"__version__",
|
| 83 |
+
]
|
sglang/python/sglang/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (2.08 kB). View file
|
|
|
sglang/python/sglang/__pycache__/_version.cpython-311.pyc
ADDED
|
Binary file (872 Bytes). View file
|
|
|
sglang/python/sglang/__pycache__/bench_serving.cpython-311.pyc
ADDED
|
Binary file (95.3 kB). View file
|
|
|
sglang/python/sglang/__pycache__/check_env.cpython-311.pyc
ADDED
|
Binary file (24.9 kB). View file
|
|
|
sglang/python/sglang/__pycache__/global_config.cpython-311.pyc
ADDED
|
Binary file (969 Bytes). View file
|
|
|
sglang/python/sglang/__pycache__/launch_server.cpython-311.pyc
ADDED
|
Binary file (2.62 kB). View file
|
|
|
sglang/python/sglang/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (34.6 kB). View file
|
|
|
sglang/python/sglang/__pycache__/version.cpython-311.pyc
ADDED
|
Binary file (1.31 kB). View file
|
|
|
sglang/python/sglang/_version.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# file generated by setuptools-scm
|
| 2 |
+
# don't change, don't track in version control
|
| 3 |
+
|
| 4 |
+
__all__ = [
|
| 5 |
+
"__version__",
|
| 6 |
+
"__version_tuple__",
|
| 7 |
+
"version",
|
| 8 |
+
"version_tuple",
|
| 9 |
+
"__commit_id__",
|
| 10 |
+
"commit_id",
|
| 11 |
+
]
|
| 12 |
+
|
| 13 |
+
TYPE_CHECKING = False
|
| 14 |
+
if TYPE_CHECKING:
|
| 15 |
+
from typing import Tuple
|
| 16 |
+
from typing import Union
|
| 17 |
+
|
| 18 |
+
VERSION_TUPLE = Tuple[Union[int, str], ...]
|
| 19 |
+
COMMIT_ID = Union[str, None]
|
| 20 |
+
else:
|
| 21 |
+
VERSION_TUPLE = object
|
| 22 |
+
COMMIT_ID = object
|
| 23 |
+
|
| 24 |
+
version: str
|
| 25 |
+
__version__: str
|
| 26 |
+
__version_tuple__: VERSION_TUPLE
|
| 27 |
+
version_tuple: VERSION_TUPLE
|
| 28 |
+
commit_id: COMMIT_ID
|
| 29 |
+
__commit_id__: COMMIT_ID
|
| 30 |
+
|
| 31 |
+
__version__ = version = '0.5.9'
|
| 32 |
+
__version_tuple__ = version_tuple = (0, 5, 9)
|
| 33 |
+
|
| 34 |
+
__commit_id__ = commit_id = 'gbbe9c7eeb'
|
sglang/python/sglang/bench_offline_throughput.py
ADDED
|
@@ -0,0 +1,543 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Benchmark the throughput in the offline mode.
|
| 3 |
+
It accepts server arguments (the same as launch_server.py) and benchmark arguments (the same as bench_serving.py).
|
| 4 |
+
|
| 5 |
+
# Usage
|
| 6 |
+
## Sharegpt dataset with default args
|
| 7 |
+
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --num-prompts 10
|
| 8 |
+
|
| 9 |
+
## Random dataset with default args
|
| 10 |
+
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
import asyncio
|
| 15 |
+
import dataclasses
|
| 16 |
+
import inspect
|
| 17 |
+
import json
|
| 18 |
+
import logging
|
| 19 |
+
import os
|
| 20 |
+
import random
|
| 21 |
+
import time
|
| 22 |
+
from typing import Dict, List, Optional
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
|
| 26 |
+
from sglang.benchmark.datasets import DatasetRow, get_dataset
|
| 27 |
+
from sglang.benchmark.datasets.random import sample_random_requests
|
| 28 |
+
from sglang.benchmark.utils import get_tokenizer, set_ulimit
|
| 29 |
+
from sglang.lang.backend.runtime_endpoint import Runtime
|
| 30 |
+
from sglang.srt.entrypoints.engine import Engine
|
| 31 |
+
from sglang.srt.server_args import ServerArgs
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclasses.dataclass
|
| 35 |
+
class BenchArgs:
|
| 36 |
+
backend: str = "engine"
|
| 37 |
+
result_filename: str = ""
|
| 38 |
+
dataset_name: str = "sharegpt"
|
| 39 |
+
dataset_path: str = ""
|
| 40 |
+
num_prompts: int = 1000
|
| 41 |
+
sharegpt_output_len: Optional[int] = None
|
| 42 |
+
sharegpt_context_len: Optional[int] = None
|
| 43 |
+
random_input_len: int = 1024
|
| 44 |
+
random_output_len: int = 1024
|
| 45 |
+
random_range_ratio: float = 0.0
|
| 46 |
+
gsp_num_groups: int = 64
|
| 47 |
+
gsp_prompts_per_group: int = 16
|
| 48 |
+
gsp_system_prompt_len: int = 2048
|
| 49 |
+
gsp_question_len: int = 128
|
| 50 |
+
gsp_output_len: int = 256
|
| 51 |
+
seed: int = 1
|
| 52 |
+
disable_ignore_eos: bool = False
|
| 53 |
+
extra_request_body: Optional[str] = None
|
| 54 |
+
apply_chat_template: bool = False
|
| 55 |
+
profile: bool = False
|
| 56 |
+
skip_warmup: bool = False
|
| 57 |
+
do_not_exit: bool = False
|
| 58 |
+
prompt_suffix: str = ""
|
| 59 |
+
return_logprob: bool = False
|
| 60 |
+
logprob_start_len: int = -1
|
| 61 |
+
|
| 62 |
+
@staticmethod
|
| 63 |
+
def add_cli_args(parser: argparse.ArgumentParser):
|
| 64 |
+
parser.add_argument("--backend", type=str, default=BenchArgs.backend)
|
| 65 |
+
parser.add_argument(
|
| 66 |
+
"--result-filename", type=str, default=BenchArgs.result_filename
|
| 67 |
+
)
|
| 68 |
+
parser.add_argument(
|
| 69 |
+
"--dataset-name",
|
| 70 |
+
type=str,
|
| 71 |
+
default="sharegpt",
|
| 72 |
+
choices=["sharegpt", "random", "generated-shared-prefix"],
|
| 73 |
+
help="Name of the dataset to benchmark on.",
|
| 74 |
+
)
|
| 75 |
+
parser.add_argument(
|
| 76 |
+
"--dataset-path", type=str, default="", help="Path to the dataset."
|
| 77 |
+
)
|
| 78 |
+
parser.add_argument(
|
| 79 |
+
"--num-prompts",
|
| 80 |
+
type=int,
|
| 81 |
+
default=BenchArgs.num_prompts,
|
| 82 |
+
help="Number of prompts to process. Default is 1000.",
|
| 83 |
+
)
|
| 84 |
+
parser.add_argument(
|
| 85 |
+
"--sharegpt-output-len",
|
| 86 |
+
type=int,
|
| 87 |
+
default=BenchArgs.sharegpt_output_len,
|
| 88 |
+
help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
|
| 89 |
+
)
|
| 90 |
+
parser.add_argument(
|
| 91 |
+
"--sharegpt-context-len",
|
| 92 |
+
type=int,
|
| 93 |
+
default=BenchArgs.sharegpt_context_len,
|
| 94 |
+
help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
|
| 95 |
+
)
|
| 96 |
+
parser.add_argument(
|
| 97 |
+
"--random-input-len",
|
| 98 |
+
type=int,
|
| 99 |
+
default=BenchArgs.random_input_len,
|
| 100 |
+
help="Number of input tokens per request, used only for random dataset.",
|
| 101 |
+
)
|
| 102 |
+
parser.add_argument(
|
| 103 |
+
"--random-output-len",
|
| 104 |
+
type=int,
|
| 105 |
+
default=BenchArgs.random_output_len,
|
| 106 |
+
help="Number of output tokens per request, used only for random dataset.",
|
| 107 |
+
)
|
| 108 |
+
parser.add_argument(
|
| 109 |
+
"--random-range-ratio",
|
| 110 |
+
type=float,
|
| 111 |
+
default=BenchArgs.random_range_ratio,
|
| 112 |
+
help="Range of sampled ratio of input/output length, "
|
| 113 |
+
"used only for random dataset.",
|
| 114 |
+
)
|
| 115 |
+
parser.add_argument(
|
| 116 |
+
"--gsp-num-groups",
|
| 117 |
+
type=int,
|
| 118 |
+
default=BenchArgs.gsp_num_groups,
|
| 119 |
+
help="Number of groups with shared prefix, used"
|
| 120 |
+
"only for generate-shared-prefix",
|
| 121 |
+
)
|
| 122 |
+
parser.add_argument(
|
| 123 |
+
"--gsp-prompts-per-group",
|
| 124 |
+
type=int,
|
| 125 |
+
default=BenchArgs.gsp_prompts_per_group,
|
| 126 |
+
help="Number of prompts per group of shared prefix, used"
|
| 127 |
+
"only for generate-shared-prefix",
|
| 128 |
+
)
|
| 129 |
+
parser.add_argument(
|
| 130 |
+
"--gsp-system-prompt-len",
|
| 131 |
+
type=int,
|
| 132 |
+
default=BenchArgs.gsp_system_prompt_len,
|
| 133 |
+
help="System prompt length, used" "only for generate-shared-prefix",
|
| 134 |
+
)
|
| 135 |
+
parser.add_argument(
|
| 136 |
+
"--gsp-question-len",
|
| 137 |
+
type=int,
|
| 138 |
+
default=BenchArgs.gsp_question_len,
|
| 139 |
+
help="Question length, used" "only for generate-shared-prefix",
|
| 140 |
+
)
|
| 141 |
+
parser.add_argument(
|
| 142 |
+
"--gsp-output-len",
|
| 143 |
+
type=int,
|
| 144 |
+
default=BenchArgs.gsp_output_len,
|
| 145 |
+
help="Target length in tokens for outputs in generated-shared-prefix dataset",
|
| 146 |
+
)
|
| 147 |
+
parser.add_argument("--seed", type=int, default=1, help="The random seed.")
|
| 148 |
+
parser.add_argument(
|
| 149 |
+
"--disable-ignore-eos",
|
| 150 |
+
action="store_true",
|
| 151 |
+
help="Disable ignore EOS token",
|
| 152 |
+
)
|
| 153 |
+
parser.add_argument(
|
| 154 |
+
"--extra-request-body",
|
| 155 |
+
metavar='{"key1": "value1", "key2": "value2"}',
|
| 156 |
+
type=str,
|
| 157 |
+
default=BenchArgs.extra_request_body,
|
| 158 |
+
help="Append given JSON object to the request payload. You can use this to specify"
|
| 159 |
+
"additional generate params like sampling params.",
|
| 160 |
+
)
|
| 161 |
+
parser.add_argument(
|
| 162 |
+
"--apply-chat-template",
|
| 163 |
+
action="store_true",
|
| 164 |
+
help="Apply chat template",
|
| 165 |
+
)
|
| 166 |
+
parser.add_argument(
|
| 167 |
+
"--profile",
|
| 168 |
+
action="store_true",
|
| 169 |
+
help="Use Torch Profiler. The endpoint must be launched with "
|
| 170 |
+
"SGLANG_TORCH_PROFILER_DIR to enable profiler.",
|
| 171 |
+
)
|
| 172 |
+
parser.add_argument(
|
| 173 |
+
"--skip-warmup",
|
| 174 |
+
action="store_true",
|
| 175 |
+
help="Skip the warmup batches.",
|
| 176 |
+
)
|
| 177 |
+
parser.add_argument(
|
| 178 |
+
"--do-not-exit",
|
| 179 |
+
action="store_true",
|
| 180 |
+
help="Do not exit the program. This is useful for nsys profile with --duration and --delay.",
|
| 181 |
+
)
|
| 182 |
+
parser.add_argument(
|
| 183 |
+
"--prompt-suffix",
|
| 184 |
+
type=str,
|
| 185 |
+
default="",
|
| 186 |
+
help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
|
| 187 |
+
)
|
| 188 |
+
parser.add_argument(
|
| 189 |
+
"--return-logprob",
|
| 190 |
+
action="store_true",
|
| 191 |
+
help="Enable returning log probabilities.",
|
| 192 |
+
)
|
| 193 |
+
parser.add_argument(
|
| 194 |
+
"--logprob-start-len",
|
| 195 |
+
type=int,
|
| 196 |
+
default=-1,
|
| 197 |
+
help="Start length for logprob. -1 means only return logprobs for output tokens (default). 0 means return logprobs for all tokens including input.",
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
@classmethod
|
| 201 |
+
def from_cli_args(cls, args: argparse.Namespace):
|
| 202 |
+
attrs = [attr.name for attr in dataclasses.fields(cls)]
|
| 203 |
+
return cls(**{attr: getattr(args, attr) for attr in attrs})
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def throughput_test_once(
|
| 207 |
+
backend_name: str,
|
| 208 |
+
backend,
|
| 209 |
+
reqs: List[DatasetRow],
|
| 210 |
+
ignore_eos: bool,
|
| 211 |
+
extra_request_body: Dict,
|
| 212 |
+
profile: bool,
|
| 213 |
+
return_logprob: bool = False,
|
| 214 |
+
logprob_start_len: int = -1,
|
| 215 |
+
):
|
| 216 |
+
measurement_results = {
|
| 217 |
+
"backend": backend_name,
|
| 218 |
+
"successful_requests": len(reqs),
|
| 219 |
+
"total_latency": -1,
|
| 220 |
+
"total_input_tokens": sum(r.prompt_len for r in reqs),
|
| 221 |
+
"total_output_tokens": -1,
|
| 222 |
+
"request_throughput": -1,
|
| 223 |
+
"input_throughput": -1,
|
| 224 |
+
"output_throughput": -1,
|
| 225 |
+
"total_throughput": -1,
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
prompt = [r.prompt for r in reqs]
|
| 229 |
+
sampling_params = [
|
| 230 |
+
{
|
| 231 |
+
"temperature": 0,
|
| 232 |
+
"max_new_tokens": r.output_len,
|
| 233 |
+
"ignore_eos": ignore_eos,
|
| 234 |
+
**extra_request_body,
|
| 235 |
+
}
|
| 236 |
+
for r in reqs
|
| 237 |
+
]
|
| 238 |
+
|
| 239 |
+
if profile:
|
| 240 |
+
assert (
|
| 241 |
+
"SGLANG_TORCH_PROFILER_DIR" in os.environ
|
| 242 |
+
), "Please set SGLANG_TORCH_PROFILER_DIR."
|
| 243 |
+
os.makedirs(os.environ["SGLANG_TORCH_PROFILER_DIR"], exist_ok=True)
|
| 244 |
+
backend.start_profile()
|
| 245 |
+
|
| 246 |
+
st = time.perf_counter()
|
| 247 |
+
gen_out = backend.generate(
|
| 248 |
+
prompt=prompt,
|
| 249 |
+
sampling_params=sampling_params,
|
| 250 |
+
return_logprob=return_logprob,
|
| 251 |
+
logprob_start_len=logprob_start_len,
|
| 252 |
+
)
|
| 253 |
+
latency = time.perf_counter() - st
|
| 254 |
+
|
| 255 |
+
if profile:
|
| 256 |
+
dir = os.getenv("SGLANG_TORCH_PROFILER_DIR")
|
| 257 |
+
known_files = set(os.listdir(dir))
|
| 258 |
+
backend.stop_profile()
|
| 259 |
+
monitor_trace_file(known_files, dir)
|
| 260 |
+
|
| 261 |
+
if backend_name == "runtime":
|
| 262 |
+
gen_out = json.loads(gen_out)
|
| 263 |
+
|
| 264 |
+
server_info = backend.get_server_info()
|
| 265 |
+
|
| 266 |
+
measurement_results["total_latency"] = latency
|
| 267 |
+
measurement_results["total_output_tokens"] = sum(
|
| 268 |
+
o["meta_info"]["completion_tokens"] for o in gen_out
|
| 269 |
+
)
|
| 270 |
+
measurement_results["request_throughput"] = (
|
| 271 |
+
measurement_results["successful_requests"] / latency
|
| 272 |
+
)
|
| 273 |
+
measurement_results["input_throughput"] = (
|
| 274 |
+
measurement_results["total_input_tokens"] / latency
|
| 275 |
+
)
|
| 276 |
+
measurement_results["output_throughput"] = (
|
| 277 |
+
measurement_results["total_output_tokens"] / latency
|
| 278 |
+
)
|
| 279 |
+
measurement_results["total_throughput"] = (
|
| 280 |
+
measurement_results["total_input_tokens"]
|
| 281 |
+
+ measurement_results["total_output_tokens"]
|
| 282 |
+
) / latency
|
| 283 |
+
|
| 284 |
+
if inspect.isawaitable(server_info):
|
| 285 |
+
server_info = asyncio.run(server_info)
|
| 286 |
+
|
| 287 |
+
measurement_results["last_gen_throughput"] = server_info["internal_states"][0][
|
| 288 |
+
"last_gen_throughput"
|
| 289 |
+
]
|
| 290 |
+
|
| 291 |
+
return measurement_results
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def monitor_trace_file(known_files, directory, interval=1):
|
| 295 |
+
print(f"Monitoring {directory} for new trace files...")
|
| 296 |
+
|
| 297 |
+
while True:
|
| 298 |
+
flag = False
|
| 299 |
+
time.sleep(interval)
|
| 300 |
+
current_files = set(os.listdir(directory))
|
| 301 |
+
|
| 302 |
+
new_files = current_files - known_files
|
| 303 |
+
for new_file in new_files:
|
| 304 |
+
new_file_path = os.path.join(directory, new_file)
|
| 305 |
+
print(f"New file detected: {new_file}")
|
| 306 |
+
|
| 307 |
+
previous_size = 0
|
| 308 |
+
while True:
|
| 309 |
+
try:
|
| 310 |
+
current_size = os.path.getsize(new_file_path)
|
| 311 |
+
except FileNotFoundError:
|
| 312 |
+
print(f"File {new_file} is no longer accessible.")
|
| 313 |
+
break
|
| 314 |
+
|
| 315 |
+
if current_size > previous_size:
|
| 316 |
+
previous_size = current_size
|
| 317 |
+
else:
|
| 318 |
+
flag = True
|
| 319 |
+
break
|
| 320 |
+
|
| 321 |
+
time.sleep(interval)
|
| 322 |
+
if flag:
|
| 323 |
+
break
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def _create_ray_engine_backend(server_args: ServerArgs):
|
| 327 |
+
"""Create a RayEngine inside a Ray actor on a placement group.
|
| 328 |
+
|
| 329 |
+
RayEngine requires a placement group, so we launch it inside a Ray actor
|
| 330 |
+
and return a lightweight proxy that forwards calls via ray.get().
|
| 331 |
+
"""
|
| 332 |
+
import ray
|
| 333 |
+
from ray.runtime_env import RuntimeEnv
|
| 334 |
+
from ray.util.placement_group import placement_group
|
| 335 |
+
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
| 336 |
+
|
| 337 |
+
env_vars = {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"}
|
| 338 |
+
if os.environ.get("HF_TOKEN"):
|
| 339 |
+
env_vars["HF_TOKEN"] = os.environ["HF_TOKEN"]
|
| 340 |
+
if not ray.is_initialized():
|
| 341 |
+
ray.init(runtime_env=RuntimeEnv(env_vars=env_vars))
|
| 342 |
+
|
| 343 |
+
total_gpus = server_args.tp_size * server_args.pp_size
|
| 344 |
+
pg = placement_group([{"CPU": 1, "GPU": total_gpus}], strategy="STRICT_PACK")
|
| 345 |
+
ray.get(pg.ready())
|
| 346 |
+
|
| 347 |
+
@ray.remote
|
| 348 |
+
class _EngineActor:
|
| 349 |
+
def __init__(self, **kwargs):
|
| 350 |
+
from sglang.srt.ray.engine import RayEngine
|
| 351 |
+
|
| 352 |
+
self.engine = RayEngine(**kwargs)
|
| 353 |
+
|
| 354 |
+
def call(self, method, **kwargs):
|
| 355 |
+
return getattr(self.engine, method)(**kwargs)
|
| 356 |
+
|
| 357 |
+
actor = _EngineActor.options(
|
| 358 |
+
num_cpus=1,
|
| 359 |
+
num_gpus=0,
|
| 360 |
+
scheduling_strategy=PlacementGroupSchedulingStrategy(
|
| 361 |
+
placement_group=pg,
|
| 362 |
+
placement_group_bundle_index=0,
|
| 363 |
+
),
|
| 364 |
+
).remote(**dataclasses.asdict(server_args))
|
| 365 |
+
|
| 366 |
+
class _Proxy:
|
| 367 |
+
"""Forwards method calls to the remote RayEngine actor."""
|
| 368 |
+
|
| 369 |
+
def generate(self, **kwargs):
|
| 370 |
+
return ray.get(actor.call.remote("generate", **kwargs))
|
| 371 |
+
|
| 372 |
+
def get_server_info(self, **kwargs):
|
| 373 |
+
return ray.get(actor.call.remote("get_server_info", **kwargs))
|
| 374 |
+
|
| 375 |
+
def start_profile(self, **kwargs):
|
| 376 |
+
return ray.get(actor.call.remote("start_profile", **kwargs))
|
| 377 |
+
|
| 378 |
+
def stop_profile(self, **kwargs):
|
| 379 |
+
return ray.get(actor.call.remote("stop_profile", **kwargs))
|
| 380 |
+
|
| 381 |
+
def shutdown(self):
|
| 382 |
+
try:
|
| 383 |
+
ray.get(actor.call.remote("shutdown"), timeout=60)
|
| 384 |
+
except Exception:
|
| 385 |
+
pass
|
| 386 |
+
try:
|
| 387 |
+
ray.util.remove_placement_group(pg)
|
| 388 |
+
except Exception:
|
| 389 |
+
pass
|
| 390 |
+
|
| 391 |
+
return _Proxy()
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
def throughput_test(
|
| 395 |
+
server_args: ServerArgs,
|
| 396 |
+
bench_args: BenchArgs,
|
| 397 |
+
):
|
| 398 |
+
if bench_args.backend == "engine":
|
| 399 |
+
if server_args.use_ray:
|
| 400 |
+
backend = _create_ray_engine_backend(server_args)
|
| 401 |
+
else:
|
| 402 |
+
backend = Engine(**dataclasses.asdict(server_args))
|
| 403 |
+
if not backend:
|
| 404 |
+
raise ValueError("Please provide valid engine arguments")
|
| 405 |
+
elif bench_args.backend == "runtime":
|
| 406 |
+
backend = Runtime(**dataclasses.asdict(server_args))
|
| 407 |
+
else:
|
| 408 |
+
raise ValueError('Please set backend to either "engine" or "runtime"')
|
| 409 |
+
|
| 410 |
+
tokenizer_id = server_args.tokenizer_path or server_args.model_path
|
| 411 |
+
tokenizer = get_tokenizer(tokenizer_id)
|
| 412 |
+
|
| 413 |
+
# Set global environments
|
| 414 |
+
set_ulimit()
|
| 415 |
+
random.seed(bench_args.seed)
|
| 416 |
+
np.random.seed(bench_args.seed)
|
| 417 |
+
|
| 418 |
+
# Parse args
|
| 419 |
+
extra_request_body = {}
|
| 420 |
+
if bench_args.extra_request_body:
|
| 421 |
+
extra_request_body = json.loads(args.extra_request_body)
|
| 422 |
+
|
| 423 |
+
# Read dataset
|
| 424 |
+
input_requests = get_dataset(bench_args, tokenizer)
|
| 425 |
+
|
| 426 |
+
warmup_requests = sample_random_requests(
|
| 427 |
+
input_len=256,
|
| 428 |
+
output_len=16,
|
| 429 |
+
num_prompts=min(bench_args.num_prompts, 16),
|
| 430 |
+
range_ratio=1.0,
|
| 431 |
+
tokenizer=tokenizer,
|
| 432 |
+
dataset_path=bench_args.dataset_path,
|
| 433 |
+
)
|
| 434 |
+
|
| 435 |
+
# Warm up
|
| 436 |
+
if not bench_args.skip_warmup:
|
| 437 |
+
logging.info("\nWarmup...")
|
| 438 |
+
throughput_test_once(
|
| 439 |
+
backend_name=bench_args.backend,
|
| 440 |
+
backend=backend,
|
| 441 |
+
reqs=warmup_requests,
|
| 442 |
+
ignore_eos=not bench_args.disable_ignore_eos,
|
| 443 |
+
extra_request_body=extra_request_body,
|
| 444 |
+
profile=False,
|
| 445 |
+
return_logprob=bench_args.return_logprob,
|
| 446 |
+
logprob_start_len=bench_args.logprob_start_len,
|
| 447 |
+
)
|
| 448 |
+
time.sleep(0.5)
|
| 449 |
+
|
| 450 |
+
logging.info("\nBenchmark...")
|
| 451 |
+
result = throughput_test_once(
|
| 452 |
+
backend_name=bench_args.backend,
|
| 453 |
+
backend=backend,
|
| 454 |
+
reqs=input_requests,
|
| 455 |
+
ignore_eos=not bench_args.disable_ignore_eos,
|
| 456 |
+
extra_request_body=extra_request_body,
|
| 457 |
+
profile=bench_args.profile,
|
| 458 |
+
return_logprob=bench_args.return_logprob,
|
| 459 |
+
logprob_start_len=bench_args.logprob_start_len,
|
| 460 |
+
)
|
| 461 |
+
backend.shutdown()
|
| 462 |
+
|
| 463 |
+
if bench_args.result_filename:
|
| 464 |
+
with open(bench_args.result_filename, "a") as fout:
|
| 465 |
+
fout.write(json.dumps(result) + "\n")
|
| 466 |
+
|
| 467 |
+
print(
|
| 468 |
+
"\n{s:{c}^{n}}".format(s=" Offline Throughput Benchmark Result ", n=50, c="=")
|
| 469 |
+
)
|
| 470 |
+
print("{:<40} {:<10}".format("Backend:", result["backend"]))
|
| 471 |
+
print("{:<40} {:<10}".format("Successful requests:", result["successful_requests"]))
|
| 472 |
+
print("{:<40} {:<10.2f}".format("Benchmark duration (s):", result["total_latency"]))
|
| 473 |
+
print("{:<40} {:<10}".format("Total input tokens:", result["total_input_tokens"]))
|
| 474 |
+
print(
|
| 475 |
+
"{:<40} {:<10}".format("Total generated tokens:", result["total_output_tokens"])
|
| 476 |
+
)
|
| 477 |
+
print(
|
| 478 |
+
"{:<40} {:<10.2f}".format(
|
| 479 |
+
"Last generation throughput (tok/s):", result["last_gen_throughput"]
|
| 480 |
+
)
|
| 481 |
+
)
|
| 482 |
+
print(
|
| 483 |
+
"{:<40} {:<10.2f}".format(
|
| 484 |
+
"Request throughput (req/s):", result["request_throughput"]
|
| 485 |
+
)
|
| 486 |
+
)
|
| 487 |
+
print(
|
| 488 |
+
"{:<40} {:<10.2f}".format(
|
| 489 |
+
"Input token throughput (tok/s):", result["input_throughput"]
|
| 490 |
+
)
|
| 491 |
+
)
|
| 492 |
+
print(
|
| 493 |
+
"{:<40} {:<10.2f}".format(
|
| 494 |
+
"Output token throughput (tok/s):", result["output_throughput"]
|
| 495 |
+
)
|
| 496 |
+
)
|
| 497 |
+
print(
|
| 498 |
+
"{:<40} {:<10.2f}".format(
|
| 499 |
+
"Total token throughput (tok/s):", result["total_throughput"]
|
| 500 |
+
)
|
| 501 |
+
)
|
| 502 |
+
print("=" * 50)
|
| 503 |
+
|
| 504 |
+
return result
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
if __name__ == "__main__":
|
| 508 |
+
parser = argparse.ArgumentParser()
|
| 509 |
+
ServerArgs.add_cli_args(parser)
|
| 510 |
+
BenchArgs.add_cli_args(parser)
|
| 511 |
+
args = parser.parse_args()
|
| 512 |
+
|
| 513 |
+
# handling ModelScope model downloads
|
| 514 |
+
if os.getenv("SGLANG_USE_MODELSCOPE", "false").lower() in ("true", "1"):
|
| 515 |
+
if os.path.exists(args.model_path):
|
| 516 |
+
print(f"Using local model path: {args.model_path}")
|
| 517 |
+
else:
|
| 518 |
+
try:
|
| 519 |
+
from modelscope import snapshot_download
|
| 520 |
+
|
| 521 |
+
print(f"Using ModelScope to download model: {args.model_path}")
|
| 522 |
+
|
| 523 |
+
# download the model and replace args.model_path
|
| 524 |
+
args.model_path = snapshot_download(
|
| 525 |
+
args.model_path,
|
| 526 |
+
)
|
| 527 |
+
print(f"Model downloaded to: {args.model_path}")
|
| 528 |
+
except Exception as e:
|
| 529 |
+
print(f"ModelScope download failed: {str(e)}")
|
| 530 |
+
raise e
|
| 531 |
+
|
| 532 |
+
server_args = ServerArgs.from_cli_args(args)
|
| 533 |
+
bench_args = BenchArgs.from_cli_args(args)
|
| 534 |
+
|
| 535 |
+
logging.basicConfig(
|
| 536 |
+
level=getattr(logging, server_args.log_level.upper()),
|
| 537 |
+
format="%(message)s",
|
| 538 |
+
)
|
| 539 |
+
|
| 540 |
+
throughput_test(server_args, bench_args)
|
| 541 |
+
|
| 542 |
+
while bench_args.do_not_exit:
|
| 543 |
+
pass
|
sglang/python/sglang/bench_one_batch.py
ADDED
|
@@ -0,0 +1,837 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Benchmark the latency of running a single static batch without a server.
|
| 3 |
+
|
| 4 |
+
This script does not launch a server and uses the low-level APIs.
|
| 5 |
+
It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths).
|
| 6 |
+
|
| 7 |
+
# Usage (latency test)
|
| 8 |
+
## with dummy weights:
|
| 9 |
+
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
|
| 10 |
+
## sweep through multiple data points and store (append) the results in a jsonl file:
|
| 11 |
+
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --run-name test_run
|
| 12 |
+
## run with profiling:
|
| 13 |
+
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --profile
|
| 14 |
+
## run with profiling to custom directory:
|
| 15 |
+
export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log
|
| 16 |
+
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 --input-len 256 --profile
|
| 17 |
+
## run with CUDA profiler (nsys):
|
| 18 |
+
nsys profile --force-overwrite=true -o bench_one_batch python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 --input-len 256 --profile --profile-activities CUDA_PROFILER
|
| 19 |
+
# Usage (correctness test):
|
| 20 |
+
python -m sglang.bench_one_batch --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
|
| 21 |
+
|
| 22 |
+
## Reference output (of the correctness test above, can be gpu dependent):
|
| 23 |
+
input_ids=[[1, 450, 7483, 310, 3444, 338], [1, 450, 7483, 310, 278, 3303, 13187, 290, 338], [1, 20628, 338, 263, 6575, 1460, 2462, 322, 306, 763]]
|
| 24 |
+
|
| 25 |
+
prefill logits (first half): tensor([[-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633],
|
| 26 |
+
[-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633],
|
| 27 |
+
[ -9.1875, -10.2500, 2.7129, ..., -4.3359, -4.0664, -4.1328]],
|
| 28 |
+
device='cuda:0')
|
| 29 |
+
|
| 30 |
+
prefill logits (final): tensor([[-8.3125, -7.1172, 3.3457, ..., -4.9570, -4.1328, -3.4141],
|
| 31 |
+
[-8.9141, -9.0156, 4.1445, ..., -4.9922, -4.4961, -4.0781],
|
| 32 |
+
[-9.6328, -9.0547, 4.0195, ..., -5.3047, -4.7148, -4.4570]],
|
| 33 |
+
device='cuda:0')
|
| 34 |
+
|
| 35 |
+
========== Prompt 0 ==========
|
| 36 |
+
<s> The capital of France is Paris.
|
| 37 |
+
The capital of the United States is Washington, D.C.
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
========== Prompt 1 ==========
|
| 41 |
+
<s> The capital of the United Kindom is London.
|
| 42 |
+
The capital of the United Kingdom is London.
|
| 43 |
+
The capital of the
|
| 44 |
+
|
| 45 |
+
========== Prompt 2 ==========
|
| 46 |
+
<s> Today is a sunny day and I like to go for a walk in the park.
|
| 47 |
+
I'm going to the park
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
import argparse
|
| 51 |
+
import copy
|
| 52 |
+
import dataclasses
|
| 53 |
+
import itertools
|
| 54 |
+
import json
|
| 55 |
+
import logging
|
| 56 |
+
import multiprocessing
|
| 57 |
+
import os
|
| 58 |
+
import time
|
| 59 |
+
from types import SimpleNamespace
|
| 60 |
+
from typing import Optional, Tuple
|
| 61 |
+
|
| 62 |
+
import numpy as np
|
| 63 |
+
import torch
|
| 64 |
+
import torch.distributed as dist
|
| 65 |
+
|
| 66 |
+
from sglang.srt.configs.model_config import ModelConfig
|
| 67 |
+
from sglang.srt.distributed.parallel_state import destroy_distributed_environment
|
| 68 |
+
from sglang.srt.entrypoints.engine import _set_envs_and_config
|
| 69 |
+
from sglang.srt.layers.moe import initialize_moe_config
|
| 70 |
+
from sglang.srt.layers.quantization.fp4_utils import initialize_fp4_gemm_config
|
| 71 |
+
from sglang.srt.layers.quantization.fp8_utils import initialize_fp8_gemm_config
|
| 72 |
+
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
| 73 |
+
from sglang.srt.managers.scheduler_dp_attn_mixin import prepare_mlp_sync_batch_raw
|
| 74 |
+
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
| 75 |
+
from sglang.srt.model_executor.model_runner import ModelRunner
|
| 76 |
+
from sglang.srt.sampling.sampling_params import SamplingParams
|
| 77 |
+
from sglang.srt.server_args import PortArgs, ServerArgs
|
| 78 |
+
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
| 79 |
+
from sglang.srt.utils import (
|
| 80 |
+
configure_logger,
|
| 81 |
+
get_bool_env_var,
|
| 82 |
+
kill_process_tree,
|
| 83 |
+
maybe_reindex_device_id,
|
| 84 |
+
require_mlp_sync,
|
| 85 |
+
require_mlp_tp_gather,
|
| 86 |
+
set_gpu_proc_affinity,
|
| 87 |
+
suppress_other_loggers,
|
| 88 |
+
)
|
| 89 |
+
from sglang.srt.utils.hf_transformers_utils import get_tokenizer
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def start_profile(profile_activities, profile_record_shapes=False, rank_print=print):
|
| 93 |
+
"""
|
| 94 |
+
Abstracted function to start profiling based on profile_activities.
|
| 95 |
+
Returns profiler object (or None).
|
| 96 |
+
"""
|
| 97 |
+
if "CUDA_PROFILER" in profile_activities:
|
| 98 |
+
try:
|
| 99 |
+
torch.cuda.cudart().cudaProfilerStart()
|
| 100 |
+
rank_print("CUDA Profiler started (nsys will begin capturing)")
|
| 101 |
+
except Exception as e:
|
| 102 |
+
rank_print(f"Failed to start CUDA profiler: {e}")
|
| 103 |
+
return None
|
| 104 |
+
else:
|
| 105 |
+
activities = []
|
| 106 |
+
if "CPU" in profile_activities:
|
| 107 |
+
activities.append(torch.profiler.ProfilerActivity.CPU)
|
| 108 |
+
if "GPU" in profile_activities:
|
| 109 |
+
activities.append(torch.profiler.ProfilerActivity.CUDA)
|
| 110 |
+
if "XPU" in profile_activities:
|
| 111 |
+
activities.append(torch.profiler.ProfilerActivity.XPU)
|
| 112 |
+
if activities:
|
| 113 |
+
profiler = torch.profiler.profile(
|
| 114 |
+
activities=activities,
|
| 115 |
+
with_stack=True,
|
| 116 |
+
record_shapes=profile_record_shapes,
|
| 117 |
+
)
|
| 118 |
+
profiler.start()
|
| 119 |
+
return profiler
|
| 120 |
+
return None
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def stop_profile(
|
| 124 |
+
profiler,
|
| 125 |
+
profile_activities,
|
| 126 |
+
rank_print=print,
|
| 127 |
+
save_trace=False,
|
| 128 |
+
trace_filename=None,
|
| 129 |
+
stage=None,
|
| 130 |
+
):
|
| 131 |
+
"""
|
| 132 |
+
Abstracted function to stop profiling based on profile_activities.
|
| 133 |
+
Optionally saves trace results and prints completion messages.
|
| 134 |
+
"""
|
| 135 |
+
if "CUDA_PROFILER" in profile_activities:
|
| 136 |
+
try:
|
| 137 |
+
torch.cuda.cudart().cudaProfilerStop()
|
| 138 |
+
rank_print("CUDA Profiler stopped (nsys should dump traces)")
|
| 139 |
+
except Exception as e:
|
| 140 |
+
rank_print(f"Failed to stop CUDA profiler: {e}")
|
| 141 |
+
elif profiler is not None:
|
| 142 |
+
profiler.stop()
|
| 143 |
+
|
| 144 |
+
if save_trace:
|
| 145 |
+
if profiler is not None:
|
| 146 |
+
if trace_filename:
|
| 147 |
+
_save_profile_trace_results(profiler, trace_filename)
|
| 148 |
+
stage_desc = f"for {stage}" if stage else ""
|
| 149 |
+
rank_print(
|
| 150 |
+
f"torch profiler chrome trace {stage_desc} saved to {trace_filename}"
|
| 151 |
+
)
|
| 152 |
+
if "CUDA_PROFILER" in profile_activities:
|
| 153 |
+
rank_print(f"CUDA profiler trace for {stage} completed")
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
@dataclasses.dataclass
|
| 157 |
+
class BenchArgs:
|
| 158 |
+
run_name: str = "default"
|
| 159 |
+
batch_size: Tuple[int] = (1,)
|
| 160 |
+
input_len: Tuple[int] = (1024,)
|
| 161 |
+
output_len: Tuple[int] = (16,)
|
| 162 |
+
prompt_filename: str = ""
|
| 163 |
+
result_filename: str = "result.jsonl"
|
| 164 |
+
correctness_test: bool = False
|
| 165 |
+
# This is only used for correctness test
|
| 166 |
+
cut_len: int = 4
|
| 167 |
+
log_decode_step: int = 0
|
| 168 |
+
profile: bool = False
|
| 169 |
+
profile_record_shapes: bool = False
|
| 170 |
+
profile_activities: Tuple[str] = ("CPU", "GPU")
|
| 171 |
+
profile_stage: str = "all"
|
| 172 |
+
profile_filename_prefix: str = "profile"
|
| 173 |
+
profile_start_step: Optional[int] = None
|
| 174 |
+
profile_steps: Optional[int] = None
|
| 175 |
+
|
| 176 |
+
@staticmethod
|
| 177 |
+
def add_cli_args(parser: argparse.ArgumentParser):
|
| 178 |
+
parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
|
| 179 |
+
parser.add_argument(
|
| 180 |
+
"--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
|
| 181 |
+
)
|
| 182 |
+
parser.add_argument(
|
| 183 |
+
"--input-len", type=int, nargs="+", default=BenchArgs.input_len
|
| 184 |
+
)
|
| 185 |
+
parser.add_argument(
|
| 186 |
+
"--output-len", type=int, nargs="+", default=BenchArgs.output_len
|
| 187 |
+
)
|
| 188 |
+
parser.add_argument(
|
| 189 |
+
"--prompt-filename", type=str, default=BenchArgs.prompt_filename
|
| 190 |
+
)
|
| 191 |
+
parser.add_argument(
|
| 192 |
+
"--result-filename", type=str, default=BenchArgs.result_filename
|
| 193 |
+
)
|
| 194 |
+
parser.add_argument("--correctness-test", action="store_true")
|
| 195 |
+
parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
|
| 196 |
+
parser.add_argument(
|
| 197 |
+
"--log-decode-step",
|
| 198 |
+
type=int,
|
| 199 |
+
default=BenchArgs.log_decode_step,
|
| 200 |
+
help="Log decode latency by step, default is set to zero to disable.",
|
| 201 |
+
)
|
| 202 |
+
parser.add_argument("--profile", action="store_true", help="Enable profiling.")
|
| 203 |
+
parser.add_argument(
|
| 204 |
+
"--profile-record-shapes",
|
| 205 |
+
action="store_true",
|
| 206 |
+
help="Record tensor shapes in profiling results.",
|
| 207 |
+
)
|
| 208 |
+
parser.add_argument(
|
| 209 |
+
"--profile-activities",
|
| 210 |
+
type=str,
|
| 211 |
+
nargs="+",
|
| 212 |
+
default=["CPU", "GPU"],
|
| 213 |
+
choices=["CPU", "GPU", "CUDA_PROFILER", "XPU"],
|
| 214 |
+
help="Profiler activities: CPU, GPU, XPU, CUDA_PROFILER. If CPU/GPU/XPU, use torch profiler. If CUDA_PROFILER, use CUDA profiler.",
|
| 215 |
+
)
|
| 216 |
+
parser.add_argument(
|
| 217 |
+
"--profile-stage",
|
| 218 |
+
type=str,
|
| 219 |
+
default=BenchArgs.profile_stage,
|
| 220 |
+
choices=["all", "prefill", "decode"],
|
| 221 |
+
help="Which stage to profile: all, prefill, or decode only.",
|
| 222 |
+
)
|
| 223 |
+
parser.add_argument(
|
| 224 |
+
"--profile-filename-prefix",
|
| 225 |
+
type=str,
|
| 226 |
+
default=BenchArgs.profile_filename_prefix,
|
| 227 |
+
help="Prefix of the profiling file names. The full profiling result file(s) be "
|
| 228 |
+
'"[profile_filename_prefix]_batch[batch_size]_input[input_len]_output[output_len].trace.json.gz"',
|
| 229 |
+
)
|
| 230 |
+
parser.add_argument(
|
| 231 |
+
"--profile-start-step",
|
| 232 |
+
type=int,
|
| 233 |
+
default=None,
|
| 234 |
+
help="Decode step at which to start profiling (0-indexed). If not specified, defaults to output_len // 2.",
|
| 235 |
+
)
|
| 236 |
+
parser.add_argument(
|
| 237 |
+
"--profile-steps",
|
| 238 |
+
type=int,
|
| 239 |
+
default=None,
|
| 240 |
+
help="Number of decode steps to profile starting from profile-start-step. If not specified, profiles only one step.",
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
@classmethod
|
| 244 |
+
def from_cli_args(cls, args: argparse.Namespace):
|
| 245 |
+
# use the default value's type to cast the args into correct types.
|
| 246 |
+
attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
|
| 247 |
+
result = {}
|
| 248 |
+
for attr, attr_type in attrs:
|
| 249 |
+
value = getattr(args, attr)
|
| 250 |
+
# Handle None values - don't try to cast them
|
| 251 |
+
if value is None or attr_type == type(None):
|
| 252 |
+
result[attr] = value
|
| 253 |
+
else:
|
| 254 |
+
result[attr] = attr_type(value)
|
| 255 |
+
return cls(**result)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def load_model(server_args, port_args, gpu_id, tp_rank):
|
| 259 |
+
suppress_other_loggers()
|
| 260 |
+
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
|
| 261 |
+
moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
|
| 262 |
+
|
| 263 |
+
model_config = ModelConfig.from_server_args(server_args)
|
| 264 |
+
model_runner = ModelRunner(
|
| 265 |
+
model_config=model_config,
|
| 266 |
+
mem_fraction_static=server_args.mem_fraction_static,
|
| 267 |
+
gpu_id=gpu_id,
|
| 268 |
+
tp_rank=tp_rank,
|
| 269 |
+
tp_size=server_args.tp_size,
|
| 270 |
+
moe_ep_rank=moe_ep_rank,
|
| 271 |
+
moe_ep_size=server_args.ep_size,
|
| 272 |
+
pp_rank=0,
|
| 273 |
+
pp_size=1,
|
| 274 |
+
nccl_port=port_args.nccl_port,
|
| 275 |
+
server_args=server_args,
|
| 276 |
+
)
|
| 277 |
+
rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
|
| 278 |
+
tokenizer = get_tokenizer(
|
| 279 |
+
server_args.tokenizer_path,
|
| 280 |
+
tokenizer_mode=server_args.tokenizer_mode,
|
| 281 |
+
trust_remote_code=server_args.trust_remote_code,
|
| 282 |
+
)
|
| 283 |
+
if server_args.tp_size > 1:
|
| 284 |
+
dist.barrier()
|
| 285 |
+
return model_runner, tokenizer
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def prepare_inputs_for_correctness_test(bench_args, tokenizer, custom_prompts):
|
| 289 |
+
prompts = (
|
| 290 |
+
custom_prompts
|
| 291 |
+
if custom_prompts
|
| 292 |
+
else [
|
| 293 |
+
"The capital of France is",
|
| 294 |
+
"The capital of the United Kindom is",
|
| 295 |
+
"Today is a sunny day and I like",
|
| 296 |
+
]
|
| 297 |
+
)
|
| 298 |
+
input_ids = [tokenizer.encode(p) for p in prompts]
|
| 299 |
+
sampling_params = SamplingParams(
|
| 300 |
+
temperature=0,
|
| 301 |
+
max_new_tokens=BenchArgs.output_len,
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
reqs = []
|
| 305 |
+
for i in range(len(prompts)):
|
| 306 |
+
assert len(input_ids[i]) > bench_args.cut_len
|
| 307 |
+
|
| 308 |
+
tmp_input_ids = input_ids[i][: bench_args.cut_len]
|
| 309 |
+
req = Req(
|
| 310 |
+
rid=i,
|
| 311 |
+
origin_input_text=prompts[i],
|
| 312 |
+
origin_input_ids=tmp_input_ids,
|
| 313 |
+
sampling_params=sampling_params,
|
| 314 |
+
)
|
| 315 |
+
req.fill_ids = req.origin_input_ids
|
| 316 |
+
req.logprob_start_len = -1
|
| 317 |
+
req.set_extend_input_len(len(req.fill_ids) - len(req.prefix_indices))
|
| 318 |
+
reqs.append(req)
|
| 319 |
+
|
| 320 |
+
return input_ids, reqs
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def prepare_extend_inputs_for_correctness_test(
|
| 324 |
+
bench_args, input_ids, reqs, model_runner
|
| 325 |
+
):
|
| 326 |
+
for i in range(len(reqs)):
|
| 327 |
+
req: Req = reqs[i]
|
| 328 |
+
req.fill_ids += input_ids[i][bench_args.cut_len :]
|
| 329 |
+
req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
|
| 330 |
+
i, : bench_args.cut_len
|
| 331 |
+
].to(req.prefix_indices.dtype)
|
| 332 |
+
req.logprob_start_len = -1
|
| 333 |
+
req.set_extend_input_len(len(req.fill_ids) - len(req.prefix_indices))
|
| 334 |
+
return reqs
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def prepare_synthetic_inputs_for_latency_test(
|
| 338 |
+
batch_size, input_len, custom_inputs=None
|
| 339 |
+
):
|
| 340 |
+
input_ids = (
|
| 341 |
+
custom_inputs
|
| 342 |
+
if custom_inputs
|
| 343 |
+
else np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32)
|
| 344 |
+
)
|
| 345 |
+
sampling_params = SamplingParams(
|
| 346 |
+
temperature=0,
|
| 347 |
+
max_new_tokens=BenchArgs.output_len,
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
reqs = []
|
| 351 |
+
for i in range(len(input_ids)):
|
| 352 |
+
req = Req(
|
| 353 |
+
rid=i,
|
| 354 |
+
origin_input_text="",
|
| 355 |
+
origin_input_ids=list(input_ids[i]),
|
| 356 |
+
sampling_params=sampling_params,
|
| 357 |
+
)
|
| 358 |
+
req.fill_ids = req.origin_input_ids
|
| 359 |
+
req.logprob_start_len = -1
|
| 360 |
+
req.set_extend_input_len(len(req.fill_ids) - len(req.prefix_indices))
|
| 361 |
+
reqs.append(req)
|
| 362 |
+
|
| 363 |
+
return reqs
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
class TreeCacheNamespace(SimpleNamespace):
|
| 367 |
+
def supports_swa(self) -> bool:
|
| 368 |
+
return False
|
| 369 |
+
|
| 370 |
+
def supports_mamba(self) -> bool:
|
| 371 |
+
return False
|
| 372 |
+
|
| 373 |
+
def is_chunk_cache(self) -> bool:
|
| 374 |
+
return False
|
| 375 |
+
|
| 376 |
+
def is_tree_cache(self) -> bool:
|
| 377 |
+
return not self.is_chunk_cache()
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
@torch.no_grad
|
| 381 |
+
def extend(reqs, model_runner):
|
| 382 |
+
# Create dummy tree_cache for benchmarks (no prefix caching, just allocation)
|
| 383 |
+
dummy_tree_cache = TreeCacheNamespace(
|
| 384 |
+
page_size=model_runner.server_args.page_size,
|
| 385 |
+
device=model_runner.device,
|
| 386 |
+
token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
|
| 387 |
+
)
|
| 388 |
+
|
| 389 |
+
batch = ScheduleBatch.init_new(
|
| 390 |
+
reqs=reqs,
|
| 391 |
+
req_to_token_pool=model_runner.req_to_token_pool,
|
| 392 |
+
token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
|
| 393 |
+
tree_cache=dummy_tree_cache,
|
| 394 |
+
model_config=model_runner.model_config,
|
| 395 |
+
enable_overlap=False,
|
| 396 |
+
spec_algorithm=SpeculativeAlgorithm.NONE,
|
| 397 |
+
)
|
| 398 |
+
batch.prepare_for_extend()
|
| 399 |
+
_maybe_prepare_mlp_sync_batch(batch, model_runner)
|
| 400 |
+
model_worker_batch = batch.get_model_worker_batch()
|
| 401 |
+
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
|
| 402 |
+
logits_output = model_runner.forward(forward_batch).logits_output
|
| 403 |
+
next_token_ids = model_runner.sample(logits_output, forward_batch)
|
| 404 |
+
return next_token_ids, logits_output.next_token_logits, batch
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
@torch.no_grad
|
| 408 |
+
def decode(input_token_ids, batch, model_runner):
|
| 409 |
+
batch.output_ids = input_token_ids
|
| 410 |
+
batch.prepare_for_decode()
|
| 411 |
+
_maybe_prepare_mlp_sync_batch(batch, model_runner)
|
| 412 |
+
model_worker_batch = batch.get_model_worker_batch()
|
| 413 |
+
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
|
| 414 |
+
logits_output = model_runner.forward(forward_batch).logits_output
|
| 415 |
+
next_token_ids = model_runner.sample(logits_output, forward_batch)
|
| 416 |
+
return next_token_ids, logits_output.next_token_logits
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
|
| 420 |
+
if require_mlp_sync(model_runner.server_args):
|
| 421 |
+
prepare_mlp_sync_batch_raw(
|
| 422 |
+
batch,
|
| 423 |
+
dp_size=model_runner.server_args.dp_size,
|
| 424 |
+
attn_tp_size=1,
|
| 425 |
+
tp_group=model_runner.tp_group,
|
| 426 |
+
get_idle_batch=None,
|
| 427 |
+
disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
|
| 428 |
+
require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
|
| 429 |
+
disable_overlap_schedule=model_runner.server_args.disable_overlap_schedule,
|
| 430 |
+
offload_tags=set(),
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
def _read_prompts_from_file(prompt_file, rank_print):
|
| 435 |
+
"""Read custom prompts from the file specified by `--prompt-filename`."""
|
| 436 |
+
if not prompt_file:
|
| 437 |
+
return []
|
| 438 |
+
if not os.path.exists(prompt_file):
|
| 439 |
+
rank_print(
|
| 440 |
+
f"Custom prompt file {prompt_file} not found. Using default inputs..."
|
| 441 |
+
)
|
| 442 |
+
return []
|
| 443 |
+
with open(prompt_file, "r") as pf:
|
| 444 |
+
return pf.readlines()
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
def _get_torch_profiler_output_dir():
|
| 448 |
+
return os.environ.get("SGLANG_TORCH_PROFILER_DIR", "/tmp")
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
def _create_torch_profiler_filename(
|
| 452 |
+
profile_filename_prefix, batch_size, input_len, output_len, stage
|
| 453 |
+
):
|
| 454 |
+
output_dir = _get_torch_profiler_output_dir()
|
| 455 |
+
filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_{stage}.trace.json.gz"
|
| 456 |
+
return os.path.join(output_dir, filename)
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
def _save_profile_trace_results(profiler, filename):
|
| 460 |
+
parent_dir = os.path.dirname(os.path.abspath(filename))
|
| 461 |
+
os.makedirs(parent_dir, exist_ok=True)
|
| 462 |
+
profiler.export_chrome_trace(filename)
|
| 463 |
+
print(
|
| 464 |
+
profiler.key_averages(group_by_input_shape=True).table(
|
| 465 |
+
sort_by="self_cpu_time_total"
|
| 466 |
+
)
|
| 467 |
+
)
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
def correctness_test(
|
| 471 |
+
server_args,
|
| 472 |
+
port_args,
|
| 473 |
+
bench_args,
|
| 474 |
+
gpu_id,
|
| 475 |
+
tp_rank,
|
| 476 |
+
):
|
| 477 |
+
# Configure the logger
|
| 478 |
+
configure_logger(server_args, prefix=f" TP{tp_rank}")
|
| 479 |
+
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
|
| 480 |
+
|
| 481 |
+
# Load the model
|
| 482 |
+
model_runner, tokenizer = load_model(server_args, port_args, gpu_id, tp_rank)
|
| 483 |
+
|
| 484 |
+
# Prepare inputs
|
| 485 |
+
custom_prompts = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
|
| 486 |
+
input_ids, reqs = prepare_inputs_for_correctness_test(
|
| 487 |
+
bench_args, tokenizer, custom_prompts
|
| 488 |
+
)
|
| 489 |
+
rank_print(f"\n{input_ids=}\n")
|
| 490 |
+
|
| 491 |
+
if bench_args.cut_len > 0:
|
| 492 |
+
# Prefill
|
| 493 |
+
next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
|
| 494 |
+
rank_print(f"prefill logits (first half): {next_token_logits} \n")
|
| 495 |
+
|
| 496 |
+
# Prepare extend inputs
|
| 497 |
+
reqs = prepare_extend_inputs_for_correctness_test(
|
| 498 |
+
bench_args, input_ids, reqs, model_runner
|
| 499 |
+
)
|
| 500 |
+
|
| 501 |
+
# Extend (prefill w/ KV cache)
|
| 502 |
+
next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
|
| 503 |
+
rank_print(f"prefill logits (final): {next_token_logits} \n")
|
| 504 |
+
|
| 505 |
+
# Decode
|
| 506 |
+
output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
|
| 507 |
+
for _ in range(bench_args.output_len[0] - 1):
|
| 508 |
+
next_token_ids, _ = decode(next_token_ids, batch, model_runner)
|
| 509 |
+
next_token_ids_list = next_token_ids.tolist()
|
| 510 |
+
for i in range(len(reqs)):
|
| 511 |
+
output_ids[i].append(next_token_ids_list[i])
|
| 512 |
+
|
| 513 |
+
# Print output texts
|
| 514 |
+
for i in range(len(reqs)):
|
| 515 |
+
rank_print(f"========== Prompt {i} ==========")
|
| 516 |
+
rank_print(tokenizer.decode(output_ids[i]), "\n")
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
def synchronize(device):
|
| 520 |
+
torch.get_device_module(device).synchronize()
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
def latency_test_run_once(
|
| 524 |
+
run_name,
|
| 525 |
+
model_runner,
|
| 526 |
+
rank_print,
|
| 527 |
+
reqs,
|
| 528 |
+
batch_size,
|
| 529 |
+
input_len,
|
| 530 |
+
output_len,
|
| 531 |
+
device,
|
| 532 |
+
log_decode_step,
|
| 533 |
+
profile,
|
| 534 |
+
profile_record_shapes,
|
| 535 |
+
profile_activities,
|
| 536 |
+
profile_filename_prefix,
|
| 537 |
+
profile_stage,
|
| 538 |
+
tp_rank,
|
| 539 |
+
profile_start_step=None,
|
| 540 |
+
profile_steps=None,
|
| 541 |
+
):
|
| 542 |
+
max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
|
| 543 |
+
if batch_size > max_batch_size:
|
| 544 |
+
rank_print(
|
| 545 |
+
f"skipping ({batch_size}, {input_len}, {output_len}) due to max batch size limit"
|
| 546 |
+
)
|
| 547 |
+
return
|
| 548 |
+
|
| 549 |
+
model_runner.req_to_token_pool.clear()
|
| 550 |
+
model_runner.token_to_kv_pool_allocator.clear()
|
| 551 |
+
|
| 552 |
+
measurement_results = {
|
| 553 |
+
"run_name": run_name,
|
| 554 |
+
"batch_size": batch_size,
|
| 555 |
+
"input_len": input_len,
|
| 556 |
+
"output_len": output_len,
|
| 557 |
+
}
|
| 558 |
+
|
| 559 |
+
tot_latency = 0
|
| 560 |
+
|
| 561 |
+
profiler = None
|
| 562 |
+
enable_profile_prefill = profile and profile_stage in ["all", "prefill"]
|
| 563 |
+
if enable_profile_prefill:
|
| 564 |
+
profiler = start_profile(
|
| 565 |
+
profile_activities,
|
| 566 |
+
profile_record_shapes=profile_record_shapes,
|
| 567 |
+
rank_print=rank_print,
|
| 568 |
+
)
|
| 569 |
+
|
| 570 |
+
synchronize(device)
|
| 571 |
+
tic = time.perf_counter()
|
| 572 |
+
next_token_ids, _, batch = extend(reqs, model_runner)
|
| 573 |
+
synchronize(device)
|
| 574 |
+
prefill_latency = time.perf_counter() - tic
|
| 575 |
+
|
| 576 |
+
if enable_profile_prefill:
|
| 577 |
+
trace_filename = _create_torch_profiler_filename(
|
| 578 |
+
profile_filename_prefix, batch_size, input_len, output_len, "prefill"
|
| 579 |
+
)
|
| 580 |
+
stop_profile(
|
| 581 |
+
profiler,
|
| 582 |
+
profile_activities,
|
| 583 |
+
rank_print=rank_print,
|
| 584 |
+
save_trace=True,
|
| 585 |
+
trace_filename=trace_filename,
|
| 586 |
+
stage="prefill",
|
| 587 |
+
)
|
| 588 |
+
|
| 589 |
+
tot_latency += prefill_latency
|
| 590 |
+
throughput = input_len * batch_size / prefill_latency
|
| 591 |
+
rank_print(
|
| 592 |
+
f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
|
| 593 |
+
)
|
| 594 |
+
measurement_results["prefill_latency"] = prefill_latency
|
| 595 |
+
measurement_results["prefill_throughput"] = throughput
|
| 596 |
+
|
| 597 |
+
decode_latencies = []
|
| 598 |
+
# Determine profiling start step and end step
|
| 599 |
+
profile_start = (
|
| 600 |
+
profile_start_step if profile_start_step is not None else (output_len // 2)
|
| 601 |
+
)
|
| 602 |
+
profile_end = profile_start + (profile_steps if profile_steps is not None else 1)
|
| 603 |
+
enable_profile_decode = profile and profile_stage in ["all", "decode"]
|
| 604 |
+
profiler = None
|
| 605 |
+
for i in range(output_len - 1):
|
| 606 |
+
synchronize(device)
|
| 607 |
+
# Start profiler at the specified step
|
| 608 |
+
if enable_profile_decode and i == profile_start:
|
| 609 |
+
profiler = start_profile(
|
| 610 |
+
profile_activities,
|
| 611 |
+
profile_record_shapes=profile_record_shapes,
|
| 612 |
+
rank_print=rank_print,
|
| 613 |
+
)
|
| 614 |
+
|
| 615 |
+
tic = time.perf_counter()
|
| 616 |
+
next_token_ids, _ = decode(next_token_ids, batch, model_runner)
|
| 617 |
+
synchronize(device)
|
| 618 |
+
latency = time.perf_counter() - tic
|
| 619 |
+
|
| 620 |
+
# Stop profiler after the specified number of steps
|
| 621 |
+
if enable_profile_decode and profiler is not None and i >= profile_end - 1:
|
| 622 |
+
trace_filename = _create_torch_profiler_filename(
|
| 623 |
+
profile_filename_prefix, batch_size, input_len, output_len, "decode"
|
| 624 |
+
)
|
| 625 |
+
stop_profile(
|
| 626 |
+
profiler,
|
| 627 |
+
profile_activities,
|
| 628 |
+
rank_print=rank_print,
|
| 629 |
+
save_trace=True,
|
| 630 |
+
trace_filename=trace_filename,
|
| 631 |
+
stage="decode",
|
| 632 |
+
)
|
| 633 |
+
profiler = None
|
| 634 |
+
|
| 635 |
+
tot_latency += latency
|
| 636 |
+
throughput = batch_size / latency
|
| 637 |
+
decode_latencies.append(latency)
|
| 638 |
+
if i < 5 or (log_decode_step > 0 and i % log_decode_step == 0):
|
| 639 |
+
rank_print(
|
| 640 |
+
f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
|
| 641 |
+
)
|
| 642 |
+
|
| 643 |
+
# Record decode timing from 2nd output
|
| 644 |
+
if output_len > 1:
|
| 645 |
+
med_decode_latency = np.median(decode_latencies)
|
| 646 |
+
med_decode_throughput = batch_size / med_decode_latency
|
| 647 |
+
rank_print(
|
| 648 |
+
f"Decode. median latency: {med_decode_latency:6.5f} s, median throughput: {med_decode_throughput:9.2f} token/s"
|
| 649 |
+
)
|
| 650 |
+
measurement_results["median_decode_latency"] = med_decode_latency
|
| 651 |
+
measurement_results["median_decode_throughput"] = med_decode_throughput
|
| 652 |
+
|
| 653 |
+
throughput = (input_len + output_len) * batch_size / tot_latency
|
| 654 |
+
rank_print(
|
| 655 |
+
f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
|
| 656 |
+
)
|
| 657 |
+
measurement_results["total_latency"] = tot_latency
|
| 658 |
+
measurement_results["overall_throughput"] = throughput
|
| 659 |
+
return measurement_results
|
| 660 |
+
|
| 661 |
+
|
| 662 |
+
def latency_test(
|
| 663 |
+
server_args,
|
| 664 |
+
port_args,
|
| 665 |
+
bench_args,
|
| 666 |
+
gpu_id,
|
| 667 |
+
tp_rank,
|
| 668 |
+
):
|
| 669 |
+
initialize_moe_config(server_args)
|
| 670 |
+
initialize_fp8_gemm_config(server_args)
|
| 671 |
+
initialize_fp4_gemm_config(server_args)
|
| 672 |
+
|
| 673 |
+
# Set CPU affinity
|
| 674 |
+
if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
|
| 675 |
+
set_gpu_proc_affinity(
|
| 676 |
+
server_args.pp_size, server_args.tp_size, server_args.nnodes, tp_rank
|
| 677 |
+
)
|
| 678 |
+
|
| 679 |
+
# Configure the logger
|
| 680 |
+
configure_logger(server_args, prefix=f" TP{tp_rank}")
|
| 681 |
+
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
|
| 682 |
+
|
| 683 |
+
# Load the model
|
| 684 |
+
model_runner, tokenizer = load_model(server_args, port_args, gpu_id, tp_rank)
|
| 685 |
+
|
| 686 |
+
# Prepare inputs for warm up
|
| 687 |
+
reqs = prepare_synthetic_inputs_for_latency_test(
|
| 688 |
+
bench_args.batch_size[0], bench_args.input_len[0]
|
| 689 |
+
)
|
| 690 |
+
|
| 691 |
+
# Warm up
|
| 692 |
+
rank_print("Warmup ...")
|
| 693 |
+
latency_test_run_once(
|
| 694 |
+
bench_args.run_name,
|
| 695 |
+
model_runner,
|
| 696 |
+
rank_print,
|
| 697 |
+
reqs,
|
| 698 |
+
bench_args.batch_size[0],
|
| 699 |
+
bench_args.input_len[0],
|
| 700 |
+
min(32, bench_args.output_len[0]), # shorter decoding to speed up the warmup
|
| 701 |
+
server_args.device,
|
| 702 |
+
log_decode_step=0,
|
| 703 |
+
profile=False,
|
| 704 |
+
profile_record_shapes=False,
|
| 705 |
+
profile_activities=("CPU", "GPU"),
|
| 706 |
+
profile_filename_prefix="",
|
| 707 |
+
profile_stage="all",
|
| 708 |
+
tp_rank=tp_rank,
|
| 709 |
+
profile_start_step=None,
|
| 710 |
+
profile_steps=None,
|
| 711 |
+
)
|
| 712 |
+
|
| 713 |
+
rank_print("Benchmark ...")
|
| 714 |
+
|
| 715 |
+
custom_inputs = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
|
| 716 |
+
custom_inputs = [tokenizer.encode(p.strip()) for p in custom_inputs]
|
| 717 |
+
custom_input_len = len(custom_inputs)
|
| 718 |
+
|
| 719 |
+
# Run the sweep
|
| 720 |
+
result_list = []
|
| 721 |
+
for bs, il, ol in itertools.product(
|
| 722 |
+
bench_args.batch_size, bench_args.input_len, bench_args.output_len
|
| 723 |
+
):
|
| 724 |
+
bs_aligned_inputs = []
|
| 725 |
+
if custom_inputs:
|
| 726 |
+
if custom_input_len == bs:
|
| 727 |
+
bs_aligned_inputs = custom_inputs
|
| 728 |
+
elif custom_input_len > bs:
|
| 729 |
+
rank_print(
|
| 730 |
+
f"Custom input size ({custom_input_len}) is larger than batch_size ({bs}). "
|
| 731 |
+
f"Using the first {bs} prompts."
|
| 732 |
+
)
|
| 733 |
+
bs_aligned_inputs = copy.deepcopy(custom_inputs[:bs])
|
| 734 |
+
else:
|
| 735 |
+
rank_print(
|
| 736 |
+
f"Custom input size ({custom_input_len}) is smaller than batch_size ({bs}). "
|
| 737 |
+
f"Pad to the desired batch_size with the last prompt."
|
| 738 |
+
)
|
| 739 |
+
bs_aligned_inputs = copy.deepcopy(custom_inputs)
|
| 740 |
+
bs_aligned_inputs.extend(
|
| 741 |
+
[bs_aligned_inputs[-1]] * (bs - custom_input_len)
|
| 742 |
+
)
|
| 743 |
+
|
| 744 |
+
reqs = prepare_synthetic_inputs_for_latency_test(bs, il, bs_aligned_inputs)
|
| 745 |
+
ret = latency_test_run_once(
|
| 746 |
+
bench_args.run_name,
|
| 747 |
+
model_runner,
|
| 748 |
+
rank_print,
|
| 749 |
+
reqs,
|
| 750 |
+
bs,
|
| 751 |
+
il,
|
| 752 |
+
ol,
|
| 753 |
+
server_args.device,
|
| 754 |
+
bench_args.log_decode_step,
|
| 755 |
+
bench_args.profile if tp_rank == 0 else None,
|
| 756 |
+
bench_args.profile_record_shapes if tp_rank == 0 else None,
|
| 757 |
+
bench_args.profile_activities,
|
| 758 |
+
bench_args.profile_filename_prefix,
|
| 759 |
+
bench_args.profile_stage,
|
| 760 |
+
tp_rank,
|
| 761 |
+
bench_args.profile_start_step,
|
| 762 |
+
bench_args.profile_steps,
|
| 763 |
+
)
|
| 764 |
+
if ret is not None:
|
| 765 |
+
result_list.append(ret)
|
| 766 |
+
|
| 767 |
+
# Write results in jsonlines format on rank 0.
|
| 768 |
+
if tp_rank == 0 and bench_args.result_filename:
|
| 769 |
+
with open(bench_args.result_filename, "a") as fout:
|
| 770 |
+
for result in result_list:
|
| 771 |
+
fout.write(json.dumps(result) + "\n")
|
| 772 |
+
|
| 773 |
+
if server_args.tp_size > 1:
|
| 774 |
+
destroy_distributed_environment()
|
| 775 |
+
|
| 776 |
+
|
| 777 |
+
def main(server_args, bench_args):
|
| 778 |
+
server_args.cuda_graph_max_bs = max(bench_args.batch_size)
|
| 779 |
+
|
| 780 |
+
_set_envs_and_config(server_args)
|
| 781 |
+
|
| 782 |
+
if server_args.model_path:
|
| 783 |
+
if bench_args.correctness_test:
|
| 784 |
+
work_func = correctness_test
|
| 785 |
+
else:
|
| 786 |
+
work_func = latency_test
|
| 787 |
+
else:
|
| 788 |
+
raise ValueError(
|
| 789 |
+
"Provide --model-path for running the tests or "
|
| 790 |
+
"provide --result-filename for plotting the results"
|
| 791 |
+
)
|
| 792 |
+
|
| 793 |
+
port_args = PortArgs.init_new(server_args)
|
| 794 |
+
|
| 795 |
+
if server_args.tp_size == 1:
|
| 796 |
+
work_func(server_args, port_args, bench_args, 0, 0)
|
| 797 |
+
else:
|
| 798 |
+
workers = []
|
| 799 |
+
for tp_rank in range(server_args.tp_size):
|
| 800 |
+
with maybe_reindex_device_id(tp_rank) as gpu_id:
|
| 801 |
+
proc = multiprocessing.Process(
|
| 802 |
+
target=work_func,
|
| 803 |
+
args=(
|
| 804 |
+
server_args,
|
| 805 |
+
port_args,
|
| 806 |
+
bench_args,
|
| 807 |
+
gpu_id,
|
| 808 |
+
tp_rank,
|
| 809 |
+
),
|
| 810 |
+
)
|
| 811 |
+
proc.start()
|
| 812 |
+
workers.append(proc)
|
| 813 |
+
|
| 814 |
+
for proc in workers:
|
| 815 |
+
proc.join()
|
| 816 |
+
|
| 817 |
+
proc.terminate()
|
| 818 |
+
|
| 819 |
+
|
| 820 |
+
if __name__ == "__main__":
|
| 821 |
+
parser = argparse.ArgumentParser()
|
| 822 |
+
ServerArgs.add_cli_args(parser)
|
| 823 |
+
BenchArgs.add_cli_args(parser)
|
| 824 |
+
args = parser.parse_args()
|
| 825 |
+
server_args = ServerArgs.from_cli_args(args)
|
| 826 |
+
bench_args = BenchArgs.from_cli_args(args)
|
| 827 |
+
|
| 828 |
+
logging.basicConfig(
|
| 829 |
+
level=getattr(logging, server_args.log_level.upper()),
|
| 830 |
+
format="%(message)s",
|
| 831 |
+
)
|
| 832 |
+
|
| 833 |
+
try:
|
| 834 |
+
main(server_args, bench_args)
|
| 835 |
+
finally:
|
| 836 |
+
if server_args.tp_size != 1:
|
| 837 |
+
kill_process_tree(os.getpid(), include_parent=False)
|
sglang/python/sglang/bench_one_batch_server.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Benchmark the latency of running a single batch with a server.
|
| 3 |
+
|
| 4 |
+
This script launches a server and uses the HTTP interface.
|
| 5 |
+
It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths).
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python3 -m sglang.bench_one_batch_server --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
|
| 9 |
+
|
| 10 |
+
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
|
| 11 |
+
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --show-report --profile --profile-by-stage
|
| 12 |
+
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --output-path results.json --profile
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
|
| 17 |
+
from sglang.srt.server_args import ServerArgs
|
| 18 |
+
from sglang.test.bench_one_batch_server_internal import (
|
| 19 |
+
BenchArgs,
|
| 20 |
+
run_benchmark_internal,
|
| 21 |
+
)
|
| 22 |
+
from sglang.test.nightly_bench_utils import save_results_as_pydantic_models
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
|
| 26 |
+
results, server_info = run_benchmark_internal(server_args, bench_args)
|
| 27 |
+
|
| 28 |
+
# Save results as pydantic models in the JSON format
|
| 29 |
+
if bench_args.pydantic_result_filename:
|
| 30 |
+
save_results_as_pydantic_models(
|
| 31 |
+
results,
|
| 32 |
+
pydantic_result_filename=bench_args.pydantic_result_filename,
|
| 33 |
+
model_path=server_args.model_path,
|
| 34 |
+
server_args=bench_args.server_args_for_metrics,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
return results, server_info
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
if __name__ == "__main__":
|
| 41 |
+
parser = argparse.ArgumentParser()
|
| 42 |
+
ServerArgs.add_cli_args(parser)
|
| 43 |
+
BenchArgs.add_cli_args(parser)
|
| 44 |
+
args = parser.parse_args()
|
| 45 |
+
|
| 46 |
+
server_args = ServerArgs.from_cli_args(args)
|
| 47 |
+
bench_args = BenchArgs.from_cli_args(args)
|
| 48 |
+
|
| 49 |
+
run_benchmark(server_args, bench_args)
|
sglang/python/sglang/bench_serving.py
ADDED
|
@@ -0,0 +1,2238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/backend_request_func.py
|
| 2 |
+
# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/benchmark_serving.py
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Benchmark online serving with dynamic requests.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python3 -m sglang.bench_serving --backend sglang --num-prompt 10
|
| 9 |
+
|
| 10 |
+
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input 1024 --random-output 1024 --random-range-ratio 0.5
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
import asyncio
|
| 15 |
+
import copy
|
| 16 |
+
import importlib.util
|
| 17 |
+
import json
|
| 18 |
+
import os
|
| 19 |
+
import random
|
| 20 |
+
import shutil
|
| 21 |
+
import sys
|
| 22 |
+
import time
|
| 23 |
+
import traceback
|
| 24 |
+
import uuid
|
| 25 |
+
import warnings
|
| 26 |
+
from argparse import ArgumentParser
|
| 27 |
+
from copy import deepcopy
|
| 28 |
+
from dataclasses import dataclass, field, replace
|
| 29 |
+
from datetime import datetime
|
| 30 |
+
from pathlib import Path
|
| 31 |
+
from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Tuple, Union
|
| 32 |
+
|
| 33 |
+
import aiohttp
|
| 34 |
+
import numpy as np
|
| 35 |
+
import requests
|
| 36 |
+
from tqdm.asyncio import tqdm
|
| 37 |
+
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
| 38 |
+
|
| 39 |
+
from sglang.benchmark.datasets import DatasetRow, get_dataset
|
| 40 |
+
from sglang.benchmark.datasets.mooncake import get_mooncake_request_over_time
|
| 41 |
+
from sglang.benchmark.utils import (
|
| 42 |
+
get_tokenizer,
|
| 43 |
+
parse_custom_headers,
|
| 44 |
+
remove_prefix,
|
| 45 |
+
set_ulimit,
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
_ROUTING_KEY_HEADER = "X-SMG-Routing-Key"
|
| 49 |
+
|
| 50 |
+
TERM_PLOTLIB_AVAILABLE = (importlib.util.find_spec("termplotlib") is not None) and (
|
| 51 |
+
shutil.which("gnuplot") is not None
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
global args
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# don't want to import sglang package here
|
| 58 |
+
def _get_bool_env_var(name: str, default: str = "false") -> bool:
|
| 59 |
+
value = os.getenv(name, default)
|
| 60 |
+
return value.lower() in ("true", "1")
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _create_bench_client_session():
|
| 64 |
+
# When the pressure is big, the read buffer could be full before aio thread read
|
| 65 |
+
# the content. We increase the read_bufsize from 64K to 10M.
|
| 66 |
+
# Define constants for timeout and buffer size for clarity and maintainability
|
| 67 |
+
BENCH_AIOHTTP_TIMEOUT_SECONDS = 6 * 60 * 60 # 6 hours
|
| 68 |
+
BENCH_AIOHTTP_READ_BUFSIZE_BYTES = 10 * 1024**2 # 10 MB
|
| 69 |
+
|
| 70 |
+
aiohttp_timeout = aiohttp.ClientTimeout(total=BENCH_AIOHTTP_TIMEOUT_SECONDS)
|
| 71 |
+
return aiohttp.ClientSession(
|
| 72 |
+
timeout=aiohttp_timeout, read_bufsize=BENCH_AIOHTTP_READ_BUFSIZE_BYTES
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@dataclass
|
| 77 |
+
class RequestFuncInput:
|
| 78 |
+
prompt: Union[str, List[str], List[Dict[str, str]]]
|
| 79 |
+
api_url: str
|
| 80 |
+
prompt_len: int
|
| 81 |
+
output_len: int
|
| 82 |
+
model: str
|
| 83 |
+
lora_name: str
|
| 84 |
+
image_data: Optional[List[str]]
|
| 85 |
+
extra_request_body: Dict[str, Any]
|
| 86 |
+
timestamp: Optional[float] = None
|
| 87 |
+
routing_key: Optional[str] = None
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
@dataclass
|
| 91 |
+
class RequestFuncOutput:
|
| 92 |
+
generated_text: str = ""
|
| 93 |
+
success: bool = False
|
| 94 |
+
latency: float = 0.0
|
| 95 |
+
ttft: float = 0.0 # Time to first token
|
| 96 |
+
itl: List[float] = field(default_factory=list) # List of inter-token latencies
|
| 97 |
+
text_chunks: List[str] = field(default_factory=list)
|
| 98 |
+
prompt_len: int = 0
|
| 99 |
+
error: str = ""
|
| 100 |
+
output_len: int = 0
|
| 101 |
+
start_time: float = 0.0
|
| 102 |
+
|
| 103 |
+
@staticmethod
|
| 104 |
+
def init_new(request_func_input: RequestFuncInput):
|
| 105 |
+
output = RequestFuncOutput()
|
| 106 |
+
output.prompt_len = request_func_input.prompt_len
|
| 107 |
+
return output
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def get_auth_headers() -> Dict[str, str]:
|
| 111 |
+
openai_api_key = os.environ.get("OPENAI_API_KEY")
|
| 112 |
+
if openai_api_key:
|
| 113 |
+
return {"Authorization": f"Bearer {openai_api_key}"}
|
| 114 |
+
else:
|
| 115 |
+
api_key = os.environ.get("API_KEY")
|
| 116 |
+
if api_key:
|
| 117 |
+
return {"Authorization": f"{api_key}"}
|
| 118 |
+
return {}
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def get_request_headers() -> Dict[str, str]:
|
| 122 |
+
headers = get_auth_headers()
|
| 123 |
+
if h := getattr(args, "header", None):
|
| 124 |
+
headers.update(parse_custom_headers(h))
|
| 125 |
+
return headers
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def wait_for_endpoint(url: str, timeout_sec: int = 60) -> bool:
|
| 129 |
+
"""Wait for the server to become ready by polling the given URL."""
|
| 130 |
+
print(f"Waiting up to {timeout_sec}s for {url} to become ready...")
|
| 131 |
+
start_time = time.perf_counter()
|
| 132 |
+
headers = get_auth_headers()
|
| 133 |
+
while True:
|
| 134 |
+
try:
|
| 135 |
+
response = requests.get(url, headers=headers, timeout=5)
|
| 136 |
+
if response.status_code == 200:
|
| 137 |
+
elapsed = time.perf_counter() - start_time
|
| 138 |
+
print(f"Server ready in {elapsed:.1f}s.")
|
| 139 |
+
return True
|
| 140 |
+
except requests.exceptions.RequestException:
|
| 141 |
+
pass
|
| 142 |
+
elapsed = time.perf_counter() - start_time
|
| 143 |
+
if elapsed >= timeout_sec:
|
| 144 |
+
print(f"Server did not become ready within {timeout_sec}s timeout.")
|
| 145 |
+
return False
|
| 146 |
+
time.sleep(1)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
# trt llm does not support ignore_eos
|
| 150 |
+
# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505
|
| 151 |
+
async def async_request_trt_llm(
|
| 152 |
+
request_func_input: RequestFuncInput,
|
| 153 |
+
pbar: Optional[tqdm] = None,
|
| 154 |
+
) -> RequestFuncOutput:
|
| 155 |
+
api_url = request_func_input.api_url
|
| 156 |
+
assert api_url.endswith("generate_stream")
|
| 157 |
+
|
| 158 |
+
async with _create_bench_client_session() as session:
|
| 159 |
+
payload = {
|
| 160 |
+
"accumulate_tokens": True,
|
| 161 |
+
"text_input": request_func_input.prompt,
|
| 162 |
+
"temperature": 0.000001,
|
| 163 |
+
"top_p": 1.0,
|
| 164 |
+
"max_tokens": request_func_input.output_len,
|
| 165 |
+
"stream": True,
|
| 166 |
+
"min_length": request_func_input.output_len,
|
| 167 |
+
"end_id": 1048576,
|
| 168 |
+
**request_func_input.extra_request_body,
|
| 169 |
+
}
|
| 170 |
+
if args.disable_ignore_eos:
|
| 171 |
+
del payload["min_length"]
|
| 172 |
+
del payload["end_id"]
|
| 173 |
+
output = RequestFuncOutput.init_new(request_func_input)
|
| 174 |
+
|
| 175 |
+
ttft = 0.0
|
| 176 |
+
st = time.perf_counter()
|
| 177 |
+
most_recent_timestamp = st
|
| 178 |
+
try:
|
| 179 |
+
async with session.post(url=api_url, json=payload) as response:
|
| 180 |
+
if response.status == 200:
|
| 181 |
+
async for chunk_bytes in response.content:
|
| 182 |
+
chunk_bytes = chunk_bytes.strip()
|
| 183 |
+
if not chunk_bytes:
|
| 184 |
+
continue
|
| 185 |
+
|
| 186 |
+
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data:")
|
| 187 |
+
|
| 188 |
+
data = json.loads(chunk)
|
| 189 |
+
output.generated_text += data["text_output"]
|
| 190 |
+
timestamp = time.perf_counter()
|
| 191 |
+
# First token
|
| 192 |
+
if ttft == 0.0:
|
| 193 |
+
ttft = timestamp - st
|
| 194 |
+
output.ttft = ttft
|
| 195 |
+
|
| 196 |
+
# Decoding phase
|
| 197 |
+
else:
|
| 198 |
+
output.itl.append(timestamp - most_recent_timestamp)
|
| 199 |
+
|
| 200 |
+
most_recent_timestamp = timestamp
|
| 201 |
+
|
| 202 |
+
output.latency = most_recent_timestamp - st
|
| 203 |
+
output.success = True
|
| 204 |
+
output.output_len = request_func_input.output_len
|
| 205 |
+
|
| 206 |
+
else:
|
| 207 |
+
output.error = (
|
| 208 |
+
(response.reason or "") + ": " + (await response.text())
|
| 209 |
+
)
|
| 210 |
+
output.success = False
|
| 211 |
+
except Exception:
|
| 212 |
+
output.success = False
|
| 213 |
+
exc_info = sys.exc_info()
|
| 214 |
+
output.error = "".join(traceback.format_exception(*exc_info))
|
| 215 |
+
|
| 216 |
+
if pbar:
|
| 217 |
+
pbar.update(1)
|
| 218 |
+
return output
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
# set ignore_eos True by default
|
| 222 |
+
async def async_request_openai_completions(
|
| 223 |
+
request_func_input: RequestFuncInput,
|
| 224 |
+
pbar: Optional[tqdm] = None,
|
| 225 |
+
) -> RequestFuncOutput:
|
| 226 |
+
api_url = request_func_input.api_url
|
| 227 |
+
assert api_url.endswith(
|
| 228 |
+
"completions"
|
| 229 |
+
), "OpenAI Completions API URL must end with 'completions'."
|
| 230 |
+
|
| 231 |
+
prompt = request_func_input.prompt
|
| 232 |
+
|
| 233 |
+
async with _create_bench_client_session() as session:
|
| 234 |
+
# Build payload with defaults that can be overridden by extra_request_body
|
| 235 |
+
payload = {
|
| 236 |
+
"model": request_func_input.model,
|
| 237 |
+
"prompt": prompt,
|
| 238 |
+
"best_of": 1,
|
| 239 |
+
"max_tokens": request_func_input.output_len,
|
| 240 |
+
"stream": not args.disable_stream,
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
# Add temperature default only if not specified in extra_request_body
|
| 244 |
+
if "temperature" not in request_func_input.extra_request_body:
|
| 245 |
+
payload["temperature"] = 0.0
|
| 246 |
+
|
| 247 |
+
# Add ignore_eos default only if not specified in extra_request_body
|
| 248 |
+
if "ignore_eos" not in request_func_input.extra_request_body:
|
| 249 |
+
payload["ignore_eos"] = not args.disable_ignore_eos
|
| 250 |
+
|
| 251 |
+
# Merge in extra parameters - these will override defaults if present
|
| 252 |
+
payload.update(request_func_input.extra_request_body)
|
| 253 |
+
|
| 254 |
+
# hack to accommodate different LoRA conventions between SGLang and vLLM.
|
| 255 |
+
if request_func_input.lora_name:
|
| 256 |
+
payload["model"] = request_func_input.lora_name
|
| 257 |
+
payload["lora_path"] = request_func_input.lora_name
|
| 258 |
+
|
| 259 |
+
if request_func_input.image_data:
|
| 260 |
+
payload.update({"image_data": request_func_input.image_data})
|
| 261 |
+
|
| 262 |
+
headers = get_request_headers()
|
| 263 |
+
if request_func_input.routing_key:
|
| 264 |
+
headers[_ROUTING_KEY_HEADER] = request_func_input.routing_key
|
| 265 |
+
|
| 266 |
+
output = RequestFuncOutput.init_new(request_func_input)
|
| 267 |
+
|
| 268 |
+
generated_text = ""
|
| 269 |
+
output_len = request_func_input.output_len
|
| 270 |
+
ttft = 0.0
|
| 271 |
+
st = time.perf_counter()
|
| 272 |
+
output.start_time = st
|
| 273 |
+
most_recent_timestamp = st
|
| 274 |
+
try:
|
| 275 |
+
async with session.post(
|
| 276 |
+
url=api_url, json=payload, headers=headers
|
| 277 |
+
) as response:
|
| 278 |
+
if response.status == 200:
|
| 279 |
+
async for chunk_bytes in response.content:
|
| 280 |
+
chunk_bytes = chunk_bytes.strip()
|
| 281 |
+
if not chunk_bytes:
|
| 282 |
+
continue
|
| 283 |
+
|
| 284 |
+
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
|
| 285 |
+
latency = time.perf_counter() - st
|
| 286 |
+
if chunk == "[DONE]":
|
| 287 |
+
pass
|
| 288 |
+
else:
|
| 289 |
+
data = json.loads(chunk)
|
| 290 |
+
|
| 291 |
+
# NOTE: Some completion API might have a last
|
| 292 |
+
# usage summary response without a token so we
|
| 293 |
+
# want to check a token was generated
|
| 294 |
+
if data["choices"][0]["text"]:
|
| 295 |
+
timestamp = time.perf_counter()
|
| 296 |
+
# First token
|
| 297 |
+
if ttft == 0.0:
|
| 298 |
+
ttft = time.perf_counter() - st
|
| 299 |
+
output.ttft = ttft
|
| 300 |
+
|
| 301 |
+
# Decoding phase
|
| 302 |
+
else:
|
| 303 |
+
output.text_chunks.append(
|
| 304 |
+
data["choices"][0]["text"]
|
| 305 |
+
)
|
| 306 |
+
output.itl.append(timestamp - most_recent_timestamp)
|
| 307 |
+
|
| 308 |
+
most_recent_timestamp = timestamp
|
| 309 |
+
generated_text += data["choices"][0]["text"]
|
| 310 |
+
output_len = (data.get("usage") or {}).get(
|
| 311 |
+
"completion_tokens", output_len
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
output.generated_text = generated_text
|
| 315 |
+
output.success = True
|
| 316 |
+
output.latency = latency
|
| 317 |
+
output.output_len = output_len
|
| 318 |
+
else:
|
| 319 |
+
output.error = (
|
| 320 |
+
(response.reason or "") + ": " + (await response.text())
|
| 321 |
+
)
|
| 322 |
+
output.success = False
|
| 323 |
+
except Exception:
|
| 324 |
+
output.success = False
|
| 325 |
+
exc_info = sys.exc_info()
|
| 326 |
+
output.error = "".join(traceback.format_exception(*exc_info))
|
| 327 |
+
|
| 328 |
+
if pbar:
|
| 329 |
+
pbar.update(1)
|
| 330 |
+
return output
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
async def async_request_openai_chat_completions(
|
| 334 |
+
request_func_input: RequestFuncInput,
|
| 335 |
+
pbar: Optional[tqdm] = None,
|
| 336 |
+
) -> RequestFuncOutput:
|
| 337 |
+
"""Makes a request to the OpenAI Chat Completions API.
|
| 338 |
+
|
| 339 |
+
Handles both streaming and non-streaming responses, including support
|
| 340 |
+
for image data in messages. Calculates and returns various performance
|
| 341 |
+
metrics.
|
| 342 |
+
|
| 343 |
+
Args:
|
| 344 |
+
request_func_input: Input parameters for the request.
|
| 345 |
+
pbar: Optional tqdm progress bar to update.
|
| 346 |
+
|
| 347 |
+
Returns:
|
| 348 |
+
RequestFuncOutput: Output of the request, including generated text,
|
| 349 |
+
latency, TTFT, ITL, and success status.
|
| 350 |
+
"""
|
| 351 |
+
api_url = request_func_input.api_url
|
| 352 |
+
assert api_url.endswith(
|
| 353 |
+
"chat/completions"
|
| 354 |
+
), "OpenAI Chat Completions API URL must end with 'chat/completions'."
|
| 355 |
+
|
| 356 |
+
# TODO put it to other functions when `pbar` logic is refactored
|
| 357 |
+
if getattr(args, "print_requests", False):
|
| 358 |
+
rid = str(uuid.uuid4())
|
| 359 |
+
input_partial = deepcopy(request_func_input)
|
| 360 |
+
input_partial.prompt = "..."
|
| 361 |
+
request_start_time = time.time()
|
| 362 |
+
print(
|
| 363 |
+
f'rid={rid} time={request_start_time} message="request start" request_func_input="{str(input_partial)}"'
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
if isinstance(request_func_input.prompt, list):
|
| 367 |
+
messages = request_func_input.prompt
|
| 368 |
+
elif request_func_input.image_data:
|
| 369 |
+
# Build multi-image content: a list of image_url entries followed by the text
|
| 370 |
+
content_items = [
|
| 371 |
+
{
|
| 372 |
+
"type": "image_url",
|
| 373 |
+
"image_url": {"url": img_url},
|
| 374 |
+
}
|
| 375 |
+
for img_url in request_func_input.image_data
|
| 376 |
+
]
|
| 377 |
+
content_items.append({"type": "text", "text": request_func_input.prompt})
|
| 378 |
+
messages = [
|
| 379 |
+
{
|
| 380 |
+
"role": "user",
|
| 381 |
+
"content": content_items,
|
| 382 |
+
},
|
| 383 |
+
]
|
| 384 |
+
else:
|
| 385 |
+
messages = [{"role": "user", "content": request_func_input.prompt}]
|
| 386 |
+
|
| 387 |
+
async with _create_bench_client_session() as session:
|
| 388 |
+
# Build payload with defaults that can be overridden by extra_request_body
|
| 389 |
+
payload = {
|
| 390 |
+
"model": request_func_input.model,
|
| 391 |
+
"messages": messages,
|
| 392 |
+
"max_completion_tokens": request_func_input.output_len,
|
| 393 |
+
"stream": not args.disable_stream,
|
| 394 |
+
}
|
| 395 |
+
|
| 396 |
+
# Add temperature default only if not specified in extra_request_body
|
| 397 |
+
if "temperature" not in request_func_input.extra_request_body:
|
| 398 |
+
payload["temperature"] = 0.0
|
| 399 |
+
|
| 400 |
+
# Add ignore_eos default only if not specified in extra_request_body
|
| 401 |
+
# Default to False for more realistic behavior (respect EOS tokens)
|
| 402 |
+
if "ignore_eos" not in request_func_input.extra_request_body:
|
| 403 |
+
payload["ignore_eos"] = not args.disable_ignore_eos
|
| 404 |
+
|
| 405 |
+
# Merge in extra parameters (tools, temperature, top_p, etc.)
|
| 406 |
+
# These will override defaults if present
|
| 407 |
+
payload.update(request_func_input.extra_request_body)
|
| 408 |
+
|
| 409 |
+
# hack to accommodate different LoRA conventions between SGLang and vLLM.
|
| 410 |
+
if request_func_input.lora_name:
|
| 411 |
+
payload["model"] = request_func_input.lora_name
|
| 412 |
+
payload["lora_path"] = request_func_input.lora_name
|
| 413 |
+
|
| 414 |
+
headers = get_request_headers()
|
| 415 |
+
if request_func_input.routing_key:
|
| 416 |
+
headers[_ROUTING_KEY_HEADER] = request_func_input.routing_key
|
| 417 |
+
|
| 418 |
+
output = RequestFuncOutput.init_new(request_func_input)
|
| 419 |
+
|
| 420 |
+
generated_text = ""
|
| 421 |
+
output_len = request_func_input.output_len
|
| 422 |
+
ttft = 0.0
|
| 423 |
+
st = time.perf_counter()
|
| 424 |
+
output.start_time = st
|
| 425 |
+
most_recent_timestamp = st
|
| 426 |
+
try:
|
| 427 |
+
async with session.post(
|
| 428 |
+
url=api_url, json=payload, headers=headers
|
| 429 |
+
) as response:
|
| 430 |
+
if response.status == 200:
|
| 431 |
+
if args.disable_stream:
|
| 432 |
+
# Non-streaming response
|
| 433 |
+
response_json = await response.json()
|
| 434 |
+
output.generated_text = response_json["choices"][0]["message"][
|
| 435 |
+
"content"
|
| 436 |
+
]
|
| 437 |
+
output.success = True
|
| 438 |
+
output.latency = time.perf_counter() - st
|
| 439 |
+
output.ttft = (
|
| 440 |
+
output.latency
|
| 441 |
+
) # For non-streaming, TTFT = total latency
|
| 442 |
+
output.output_len = response_json.get("usage", {}).get(
|
| 443 |
+
"completion_tokens", output_len
|
| 444 |
+
)
|
| 445 |
+
else:
|
| 446 |
+
# Streaming response
|
| 447 |
+
async for chunk_bytes in response.content:
|
| 448 |
+
chunk_bytes = chunk_bytes.strip()
|
| 449 |
+
if not chunk_bytes:
|
| 450 |
+
continue
|
| 451 |
+
|
| 452 |
+
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
|
| 453 |
+
latency = time.perf_counter() - st
|
| 454 |
+
if chunk == "[DONE]":
|
| 455 |
+
pass
|
| 456 |
+
else:
|
| 457 |
+
data = json.loads(chunk)
|
| 458 |
+
|
| 459 |
+
# Check if this chunk contains content
|
| 460 |
+
delta = data.get("choices", [{}])[0].get("delta", {})
|
| 461 |
+
content = delta.get("content", "")
|
| 462 |
+
|
| 463 |
+
if content:
|
| 464 |
+
timestamp = time.perf_counter()
|
| 465 |
+
# First token
|
| 466 |
+
if ttft == 0.0:
|
| 467 |
+
ttft = timestamp - st
|
| 468 |
+
output.ttft = ttft
|
| 469 |
+
|
| 470 |
+
# Decoding phase
|
| 471 |
+
else:
|
| 472 |
+
output.text_chunks.append(content)
|
| 473 |
+
output.itl.append(
|
| 474 |
+
timestamp - most_recent_timestamp
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
most_recent_timestamp = timestamp
|
| 478 |
+
generated_text += content
|
| 479 |
+
|
| 480 |
+
# Check for usage info in final chunk
|
| 481 |
+
output_len = (data.get("usage") or {}).get(
|
| 482 |
+
"completion_tokens", output_len
|
| 483 |
+
)
|
| 484 |
+
|
| 485 |
+
output.generated_text = generated_text
|
| 486 |
+
output.success = True
|
| 487 |
+
output.latency = latency
|
| 488 |
+
output.output_len = output_len
|
| 489 |
+
else:
|
| 490 |
+
output.error = (
|
| 491 |
+
(response.reason or "") + ": " + (await response.text())
|
| 492 |
+
)
|
| 493 |
+
output.success = False
|
| 494 |
+
except Exception:
|
| 495 |
+
output.success = False
|
| 496 |
+
exc_info = sys.exc_info()
|
| 497 |
+
output.error = "".join(traceback.format_exception(*exc_info))
|
| 498 |
+
|
| 499 |
+
# TODO put it to other functions when `pbar` logic is refactored
|
| 500 |
+
if getattr(args, "print_requests", False):
|
| 501 |
+
curr_t = time.time()
|
| 502 |
+
output_partial = deepcopy(output)
|
| 503 |
+
output_partial.generated_text = "..."
|
| 504 |
+
print(
|
| 505 |
+
f'rid={rid} time={curr_t} time_delta={curr_t - request_start_time} message="request end" output="{str(output_partial)}"'
|
| 506 |
+
)
|
| 507 |
+
|
| 508 |
+
if pbar:
|
| 509 |
+
pbar.update(1)
|
| 510 |
+
return output
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
async def async_request_truss(
|
| 514 |
+
request_func_input: RequestFuncInput,
|
| 515 |
+
pbar: Optional[tqdm] = None,
|
| 516 |
+
) -> RequestFuncOutput:
|
| 517 |
+
api_url = request_func_input.api_url
|
| 518 |
+
|
| 519 |
+
prompt = request_func_input.prompt
|
| 520 |
+
|
| 521 |
+
async with _create_bench_client_session() as session:
|
| 522 |
+
payload = {
|
| 523 |
+
"model": request_func_input.model,
|
| 524 |
+
"prompt": prompt,
|
| 525 |
+
"temperature": 0.0,
|
| 526 |
+
"best_of": 1,
|
| 527 |
+
"max_tokens": request_func_input.output_len,
|
| 528 |
+
"stream": not args.disable_stream,
|
| 529 |
+
"ignore_eos": not args.disable_ignore_eos,
|
| 530 |
+
**request_func_input.extra_request_body,
|
| 531 |
+
}
|
| 532 |
+
headers = get_request_headers()
|
| 533 |
+
|
| 534 |
+
output = RequestFuncOutput.init_new(request_func_input)
|
| 535 |
+
|
| 536 |
+
generated_text = ""
|
| 537 |
+
ttft = 0.0
|
| 538 |
+
st = time.perf_counter()
|
| 539 |
+
most_recent_timestamp = st
|
| 540 |
+
try:
|
| 541 |
+
async with session.post(
|
| 542 |
+
url=api_url, json=payload, headers=headers
|
| 543 |
+
) as response:
|
| 544 |
+
if response.status == 200:
|
| 545 |
+
async for chunk_bytes in response.content:
|
| 546 |
+
chunk_bytes = chunk_bytes.strip()
|
| 547 |
+
if not chunk_bytes:
|
| 548 |
+
continue
|
| 549 |
+
|
| 550 |
+
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
|
| 551 |
+
latency = time.perf_counter() - st
|
| 552 |
+
if chunk == "[DONE]":
|
| 553 |
+
pass
|
| 554 |
+
else:
|
| 555 |
+
data = json.loads(chunk)
|
| 556 |
+
|
| 557 |
+
# NOTE: Some completion API might have a last
|
| 558 |
+
# usage summary response without a token so we
|
| 559 |
+
# want to check a token was generated
|
| 560 |
+
if data["choices"][0]["text"]:
|
| 561 |
+
timestamp = time.perf_counter()
|
| 562 |
+
# First token
|
| 563 |
+
if ttft == 0.0:
|
| 564 |
+
ttft = time.perf_counter() - st
|
| 565 |
+
output.ttft = ttft
|
| 566 |
+
|
| 567 |
+
# Decoding phase
|
| 568 |
+
else:
|
| 569 |
+
output.itl.append(timestamp - most_recent_timestamp)
|
| 570 |
+
|
| 571 |
+
most_recent_timestamp = timestamp
|
| 572 |
+
generated_text += data["choices"][0]["text"]
|
| 573 |
+
|
| 574 |
+
output.generated_text = generated_text
|
| 575 |
+
output.success = True
|
| 576 |
+
output.latency = latency
|
| 577 |
+
output.output_len = request_func_input.output_len
|
| 578 |
+
else:
|
| 579 |
+
output.error = (
|
| 580 |
+
(response.reason or "") + ": " + (await response.text())
|
| 581 |
+
)
|
| 582 |
+
output.success = False
|
| 583 |
+
except Exception:
|
| 584 |
+
output.success = False
|
| 585 |
+
exc_info = sys.exc_info()
|
| 586 |
+
output.error = "".join(traceback.format_exception(*exc_info))
|
| 587 |
+
|
| 588 |
+
if pbar:
|
| 589 |
+
pbar.update(1)
|
| 590 |
+
return output
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
async def async_request_sglang_generate(
|
| 594 |
+
request_func_input: RequestFuncInput,
|
| 595 |
+
pbar: Optional[tqdm] = None,
|
| 596 |
+
) -> RequestFuncOutput:
|
| 597 |
+
api_url = request_func_input.api_url
|
| 598 |
+
prompt = request_func_input.prompt
|
| 599 |
+
|
| 600 |
+
async with _create_bench_client_session() as session:
|
| 601 |
+
payload = {
|
| 602 |
+
("text" if isinstance(prompt, str) else "input_ids"): prompt,
|
| 603 |
+
"sampling_params": {
|
| 604 |
+
"temperature": 0.0,
|
| 605 |
+
"max_new_tokens": request_func_input.output_len,
|
| 606 |
+
"ignore_eos": not args.disable_ignore_eos,
|
| 607 |
+
},
|
| 608 |
+
"stream": not args.disable_stream,
|
| 609 |
+
"lora_path": request_func_input.lora_name,
|
| 610 |
+
"return_logprob": args.return_logprob,
|
| 611 |
+
"return_routed_experts": args.return_routed_experts,
|
| 612 |
+
"logprob_start_len": -1,
|
| 613 |
+
**request_func_input.extra_request_body,
|
| 614 |
+
}
|
| 615 |
+
|
| 616 |
+
# Add image data if available (list of image urls/base64)
|
| 617 |
+
if request_func_input.image_data:
|
| 618 |
+
payload["image_data"] = request_func_input.image_data
|
| 619 |
+
|
| 620 |
+
headers = get_request_headers()
|
| 621 |
+
if request_func_input.routing_key:
|
| 622 |
+
headers[_ROUTING_KEY_HEADER] = request_func_input.routing_key
|
| 623 |
+
|
| 624 |
+
output = RequestFuncOutput.init_new(request_func_input)
|
| 625 |
+
|
| 626 |
+
generated_text = ""
|
| 627 |
+
output_len = request_func_input.output_len
|
| 628 |
+
ttft = 0.0
|
| 629 |
+
st = time.perf_counter()
|
| 630 |
+
output.start_time = st
|
| 631 |
+
most_recent_timestamp = st
|
| 632 |
+
last_output_len = 0
|
| 633 |
+
try:
|
| 634 |
+
async with session.post(
|
| 635 |
+
url=api_url, json=payload, headers=headers
|
| 636 |
+
) as response:
|
| 637 |
+
if response.status == 200:
|
| 638 |
+
async for chunk_bytes in response.content:
|
| 639 |
+
chunk_bytes = chunk_bytes.strip()
|
| 640 |
+
if not chunk_bytes:
|
| 641 |
+
continue
|
| 642 |
+
|
| 643 |
+
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
|
| 644 |
+
latency = time.perf_counter() - st
|
| 645 |
+
if chunk == "[DONE]":
|
| 646 |
+
pass
|
| 647 |
+
else:
|
| 648 |
+
data = json.loads(chunk)
|
| 649 |
+
|
| 650 |
+
# NOTE: Some completion API might have a last
|
| 651 |
+
# usage summary response without a token so we
|
| 652 |
+
# want to check a token was generated
|
| 653 |
+
if "text" in data and data["text"]:
|
| 654 |
+
timestamp = time.perf_counter()
|
| 655 |
+
generated_text = data["text"]
|
| 656 |
+
output_len = data["meta_info"]["completion_tokens"]
|
| 657 |
+
|
| 658 |
+
# First token
|
| 659 |
+
if ttft == 0.0:
|
| 660 |
+
ttft = time.perf_counter() - st
|
| 661 |
+
output.ttft = ttft
|
| 662 |
+
|
| 663 |
+
# Decoding phase
|
| 664 |
+
else:
|
| 665 |
+
num_new_tokens = output_len - last_output_len
|
| 666 |
+
if num_new_tokens == 0:
|
| 667 |
+
continue
|
| 668 |
+
chunk_gap = timestamp - most_recent_timestamp
|
| 669 |
+
adjust_itl = chunk_gap / num_new_tokens
|
| 670 |
+
output.itl.extend([adjust_itl] * num_new_tokens)
|
| 671 |
+
|
| 672 |
+
most_recent_timestamp = timestamp
|
| 673 |
+
last_output_len = output_len
|
| 674 |
+
|
| 675 |
+
output.generated_text = generated_text
|
| 676 |
+
output.success = True
|
| 677 |
+
output.latency = latency
|
| 678 |
+
output.output_len = output_len
|
| 679 |
+
else:
|
| 680 |
+
output.error = (
|
| 681 |
+
(response.reason or "") + ": " + (await response.text())
|
| 682 |
+
)
|
| 683 |
+
output.success = False
|
| 684 |
+
except Exception:
|
| 685 |
+
output.success = False
|
| 686 |
+
exc_info = sys.exc_info()
|
| 687 |
+
output.error = "".join(traceback.format_exception(*exc_info))
|
| 688 |
+
print(f"{output.error=}")
|
| 689 |
+
|
| 690 |
+
if pbar:
|
| 691 |
+
pbar.update(1)
|
| 692 |
+
return output
|
| 693 |
+
|
| 694 |
+
|
| 695 |
+
async def async_request_gserver(
|
| 696 |
+
request_func_input: RequestFuncInput,
|
| 697 |
+
pbar: Optional[tqdm] = None,
|
| 698 |
+
) -> RequestFuncOutput:
|
| 699 |
+
raise NotImplementedError()
|
| 700 |
+
|
| 701 |
+
|
| 702 |
+
async def async_request_profile(api_url: str) -> RequestFuncOutput:
|
| 703 |
+
async with _create_bench_client_session() as session:
|
| 704 |
+
output = RequestFuncOutput()
|
| 705 |
+
try:
|
| 706 |
+
if api_url.endswith("/start_profile"):
|
| 707 |
+
num_steps = getattr(args, "profile_num_steps", None)
|
| 708 |
+
profile_by_stage = getattr(args, "profile_by_stage", None)
|
| 709 |
+
if profile_by_stage and num_steps is None:
|
| 710 |
+
num_steps = 5
|
| 711 |
+
|
| 712 |
+
output_dir = getattr(args, "profile_output_dir", None)
|
| 713 |
+
if output_dir is None:
|
| 714 |
+
output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
|
| 715 |
+
output_dir = Path(os.path.abspath(os.path.normpath(output_dir))) / str(
|
| 716 |
+
time.time()
|
| 717 |
+
)
|
| 718 |
+
output_dir.mkdir(exist_ok=True, parents=True)
|
| 719 |
+
output_dir = str(output_dir)
|
| 720 |
+
|
| 721 |
+
body = {
|
| 722 |
+
"activities": getattr(args, "profile_activities", []),
|
| 723 |
+
"num_steps": num_steps,
|
| 724 |
+
"profile_by_stage": profile_by_stage,
|
| 725 |
+
"profile_stages": getattr(args, "profile_stages", None),
|
| 726 |
+
"output_dir": output_dir,
|
| 727 |
+
"profile_prefix": getattr(args, "profile_prefix", None),
|
| 728 |
+
}
|
| 729 |
+
else:
|
| 730 |
+
# stop_profile doesn't need any parameters
|
| 731 |
+
body = {}
|
| 732 |
+
print(f"async_request_profile {api_url=} {body=}")
|
| 733 |
+
# Add optional profiling parameters if provided
|
| 734 |
+
if (
|
| 735 |
+
hasattr(args, "profile_start_step")
|
| 736 |
+
and args.profile_start_step is not None
|
| 737 |
+
):
|
| 738 |
+
body["start_step"] = str(args.profile_start_step)
|
| 739 |
+
if hasattr(args, "profile_steps") and args.profile_steps is not None:
|
| 740 |
+
body["num_steps"] = str(args.profile_steps)
|
| 741 |
+
async with session.post(url=api_url, json=body) as response:
|
| 742 |
+
if response.status == 200:
|
| 743 |
+
output.success = True
|
| 744 |
+
else:
|
| 745 |
+
output.error = (
|
| 746 |
+
(response.reason or "") + ": " + (await response.text())
|
| 747 |
+
)
|
| 748 |
+
output.success = False
|
| 749 |
+
except Exception:
|
| 750 |
+
output.success = False
|
| 751 |
+
exc_info = sys.exc_info()
|
| 752 |
+
output.error = "".join(traceback.format_exception(*exc_info))
|
| 753 |
+
|
| 754 |
+
return output
|
| 755 |
+
|
| 756 |
+
|
| 757 |
+
def _build_profile_urls(
|
| 758 |
+
profile_prefill_url: Optional[List[str]],
|
| 759 |
+
profile_decode_url: Optional[List[str]],
|
| 760 |
+
) -> List[Tuple[str, str]]:
|
| 761 |
+
"""Build profile URLs list from prefill/decode URL arguments.
|
| 762 |
+
|
| 763 |
+
Returns:
|
| 764 |
+
List of (worker_type, url) tuples. e.g., [("Prefill-0", "http://..."), ("Decode-0", "http://...")]
|
| 765 |
+
"""
|
| 766 |
+
profile_urls = []
|
| 767 |
+
if profile_prefill_url:
|
| 768 |
+
for idx, url in enumerate(profile_prefill_url):
|
| 769 |
+
profile_urls.append((f"Prefill-{idx}", url))
|
| 770 |
+
if profile_decode_url:
|
| 771 |
+
for idx, url in enumerate(profile_decode_url):
|
| 772 |
+
profile_urls.append((f"Decode-{idx}", url))
|
| 773 |
+
return profile_urls
|
| 774 |
+
|
| 775 |
+
|
| 776 |
+
async def _call_profile_pd(profile_urls: List[Tuple[str, str]], mode: str) -> None:
|
| 777 |
+
"""Call profile endpoint (start/stop) on PD separated workers.
|
| 778 |
+
|
| 779 |
+
Args:
|
| 780 |
+
profile_urls: List of (worker_type, url) tuples
|
| 781 |
+
mode: "start" or "stop"
|
| 782 |
+
"""
|
| 783 |
+
endpoint = "/start_profile" if mode == "start" else "/stop_profile"
|
| 784 |
+
action = "Starting" if mode == "start" else "Stopping"
|
| 785 |
+
action_past = "started" if mode == "start" else "stopped"
|
| 786 |
+
|
| 787 |
+
print(f"{action} profiler...")
|
| 788 |
+
|
| 789 |
+
for worker_type, url in profile_urls:
|
| 790 |
+
profile_output = await async_request_profile(api_url=url + endpoint)
|
| 791 |
+
if profile_output.success:
|
| 792 |
+
print(f"Profiler {action_past} for {worker_type} worker at {url}")
|
| 793 |
+
else:
|
| 794 |
+
print(
|
| 795 |
+
f"Failed to {mode} profiler for {worker_type} worker at {url}: {profile_output.error}"
|
| 796 |
+
)
|
| 797 |
+
|
| 798 |
+
|
| 799 |
+
ASYNC_REQUEST_FUNCS = {
|
| 800 |
+
"sglang": async_request_sglang_generate,
|
| 801 |
+
"sglang-native": async_request_sglang_generate,
|
| 802 |
+
"sglang-oai": async_request_openai_completions,
|
| 803 |
+
"sglang-oai-chat": async_request_openai_chat_completions,
|
| 804 |
+
"vllm": async_request_openai_completions,
|
| 805 |
+
"vllm-chat": async_request_openai_chat_completions,
|
| 806 |
+
"lmdeploy": async_request_openai_completions,
|
| 807 |
+
"lmdeploy-chat": async_request_openai_chat_completions,
|
| 808 |
+
"trt": async_request_trt_llm,
|
| 809 |
+
"gserver": async_request_gserver,
|
| 810 |
+
"truss": async_request_truss,
|
| 811 |
+
}
|
| 812 |
+
|
| 813 |
+
|
| 814 |
+
@dataclass
|
| 815 |
+
class BenchmarkMetrics:
|
| 816 |
+
completed: int
|
| 817 |
+
total_input: int
|
| 818 |
+
total_input_text: int
|
| 819 |
+
total_input_vision: int
|
| 820 |
+
total_output: int
|
| 821 |
+
total_output_retokenized: int
|
| 822 |
+
request_throughput: float
|
| 823 |
+
input_throughput: float
|
| 824 |
+
output_throughput: float
|
| 825 |
+
output_throughput_retokenized: float
|
| 826 |
+
total_throughput: float
|
| 827 |
+
total_throughput_retokenized: float
|
| 828 |
+
mean_ttft_ms: float
|
| 829 |
+
median_ttft_ms: float
|
| 830 |
+
std_ttft_ms: float
|
| 831 |
+
p99_ttft_ms: float
|
| 832 |
+
mean_tpot_ms: float
|
| 833 |
+
median_tpot_ms: float
|
| 834 |
+
std_tpot_ms: float
|
| 835 |
+
p99_tpot_ms: float
|
| 836 |
+
mean_itl_ms: float
|
| 837 |
+
median_itl_ms: float
|
| 838 |
+
std_itl_ms: float
|
| 839 |
+
p95_itl_ms: float
|
| 840 |
+
p99_itl_ms: float
|
| 841 |
+
max_itl_ms: float
|
| 842 |
+
mean_e2e_latency_ms: float
|
| 843 |
+
median_e2e_latency_ms: float
|
| 844 |
+
std_e2e_latency_ms: float
|
| 845 |
+
p90_e2e_latency_ms: float
|
| 846 |
+
p99_e2e_latency_ms: float
|
| 847 |
+
concurrency: float
|
| 848 |
+
max_output_tokens_per_s: float = 0.0
|
| 849 |
+
max_concurrent_requests: int = 0
|
| 850 |
+
|
| 851 |
+
|
| 852 |
+
async def get_request(
|
| 853 |
+
input_requests: List[DatasetRow],
|
| 854 |
+
request_rate: float,
|
| 855 |
+
use_trace_timestamps: bool = False,
|
| 856 |
+
slowdown_factor: float = 1.0,
|
| 857 |
+
) -> AsyncGenerator[DatasetRow, None]:
|
| 858 |
+
if use_trace_timestamps:
|
| 859 |
+
print(
|
| 860 |
+
f"Using trace timestamps for request generation with slowdown factor {slowdown_factor}."
|
| 861 |
+
)
|
| 862 |
+
# Sort requests by timestamp for correct replay
|
| 863 |
+
input_requests.sort(key=lambda r: r.timestamp)
|
| 864 |
+
|
| 865 |
+
start_time = time.perf_counter()
|
| 866 |
+
trace_start_time_ms = input_requests[0].timestamp if input_requests else 0
|
| 867 |
+
|
| 868 |
+
for request in input_requests:
|
| 869 |
+
trace_time_s = (request.timestamp - trace_start_time_ms) / 1000.0
|
| 870 |
+
target_arrival_time = start_time + (trace_time_s * slowdown_factor)
|
| 871 |
+
|
| 872 |
+
sleep_duration = target_arrival_time - time.perf_counter()
|
| 873 |
+
if sleep_duration > 0:
|
| 874 |
+
await asyncio.sleep(sleep_duration)
|
| 875 |
+
|
| 876 |
+
yield request
|
| 877 |
+
else:
|
| 878 |
+
input_requests_iter = iter(input_requests)
|
| 879 |
+
for request in input_requests_iter:
|
| 880 |
+
yield request
|
| 881 |
+
|
| 882 |
+
if request_rate == float("inf"):
|
| 883 |
+
# If the request rate is infinity, then we don't need to wait.
|
| 884 |
+
continue
|
| 885 |
+
|
| 886 |
+
# Sample the request interval from the exponential distribution.
|
| 887 |
+
interval = np.random.exponential(1.0 / request_rate)
|
| 888 |
+
# The next request will be sent after the interval.
|
| 889 |
+
await asyncio.sleep(interval)
|
| 890 |
+
|
| 891 |
+
|
| 892 |
+
def calculate_metrics(
|
| 893 |
+
input_requests: Optional[List[DatasetRow]],
|
| 894 |
+
outputs: List[RequestFuncOutput],
|
| 895 |
+
dur_s: float,
|
| 896 |
+
tokenizer: PreTrainedTokenizerBase,
|
| 897 |
+
backend: str,
|
| 898 |
+
accept_length: Optional[float] = None,
|
| 899 |
+
plot_throughput: bool = False,
|
| 900 |
+
) -> Tuple[BenchmarkMetrics, List[int]]:
|
| 901 |
+
output_lens: List[int] = []
|
| 902 |
+
retokenized_output_lens: List[int] = []
|
| 903 |
+
total_input = 0
|
| 904 |
+
total_input_text = 0
|
| 905 |
+
total_input_vision = 0
|
| 906 |
+
completed = 0
|
| 907 |
+
itls: List[float] = []
|
| 908 |
+
tpots: List[float] = []
|
| 909 |
+
ttfts: List[float] = []
|
| 910 |
+
e2e_latencies: List[float] = []
|
| 911 |
+
retokenized_itls: List[float] = []
|
| 912 |
+
|
| 913 |
+
use_retokenized_itl = (
|
| 914 |
+
accept_length is not None
|
| 915 |
+
and accept_length > 0
|
| 916 |
+
and backend in ("sglang-oai", "sglang-oai-chat")
|
| 917 |
+
)
|
| 918 |
+
|
| 919 |
+
for i in range(len(outputs)):
|
| 920 |
+
if outputs[i].success:
|
| 921 |
+
output_len = outputs[i].output_len
|
| 922 |
+
output_lens.append(output_len)
|
| 923 |
+
retokenized_output_len = len(
|
| 924 |
+
tokenizer.encode(outputs[i].generated_text, add_special_tokens=False)
|
| 925 |
+
)
|
| 926 |
+
retokenized_output_lens.append(retokenized_output_len)
|
| 927 |
+
if input_requests is not None:
|
| 928 |
+
total_input += input_requests[i].prompt_len
|
| 929 |
+
total_input_text += input_requests[i].text_prompt_len
|
| 930 |
+
total_input_vision += input_requests[i].vision_prompt_len
|
| 931 |
+
if output_len > 1:
|
| 932 |
+
tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1))
|
| 933 |
+
if use_retokenized_itl:
|
| 934 |
+
for k, itl in enumerate(outputs[i].itl):
|
| 935 |
+
num_tokens = len(
|
| 936 |
+
tokenizer.encode(
|
| 937 |
+
outputs[i].text_chunks[k], add_special_tokens=False
|
| 938 |
+
)
|
| 939 |
+
)
|
| 940 |
+
adjusted_itl = itl / num_tokens
|
| 941 |
+
retokenized_itls.extend([adjusted_itl] * num_tokens)
|
| 942 |
+
else:
|
| 943 |
+
itls += outputs[i].itl
|
| 944 |
+
ttfts.append(outputs[i].ttft)
|
| 945 |
+
|
| 946 |
+
e2e_latencies.append(outputs[i].latency)
|
| 947 |
+
|
| 948 |
+
completed += 1
|
| 949 |
+
else:
|
| 950 |
+
output_lens.append(0)
|
| 951 |
+
retokenized_output_lens.append(0)
|
| 952 |
+
|
| 953 |
+
if completed == 0:
|
| 954 |
+
warnings.warn(
|
| 955 |
+
"All requests failed. This is likely due to a misconfiguration "
|
| 956 |
+
"on the benchmark arguments.",
|
| 957 |
+
stacklevel=2,
|
| 958 |
+
)
|
| 959 |
+
|
| 960 |
+
max_output_tokens_per_s = 0.0
|
| 961 |
+
max_concurrent_requests = 0
|
| 962 |
+
|
| 963 |
+
successful_outputs = [output for output in outputs if output.success]
|
| 964 |
+
if successful_outputs:
|
| 965 |
+
min_start_time = min(output.start_time for output in successful_outputs)
|
| 966 |
+
max_end_time = max(
|
| 967 |
+
output.start_time + output.latency for output in successful_outputs
|
| 968 |
+
)
|
| 969 |
+
|
| 970 |
+
duration_seconds = int(np.ceil(max_end_time - min_start_time)) + 1
|
| 971 |
+
tokens_per_second = np.zeros(duration_seconds)
|
| 972 |
+
concurrent_requests_per_second = np.zeros(duration_seconds)
|
| 973 |
+
|
| 974 |
+
for output in outputs:
|
| 975 |
+
if not output.success:
|
| 976 |
+
continue
|
| 977 |
+
|
| 978 |
+
token_times = [output.start_time + output.ttft]
|
| 979 |
+
current_time = token_times[0]
|
| 980 |
+
for itl_value in output.itl:
|
| 981 |
+
current_time += itl_value
|
| 982 |
+
token_times.append(current_time)
|
| 983 |
+
|
| 984 |
+
for token_time in token_times:
|
| 985 |
+
second_bucket = int(token_time - min_start_time)
|
| 986 |
+
if 0 <= second_bucket < duration_seconds:
|
| 987 |
+
tokens_per_second[second_bucket] += 1
|
| 988 |
+
|
| 989 |
+
request_start_second = int(output.start_time - min_start_time)
|
| 990 |
+
request_end_second = int(
|
| 991 |
+
(output.start_time + output.latency) - min_start_time
|
| 992 |
+
)
|
| 993 |
+
|
| 994 |
+
for second in range(
|
| 995 |
+
request_start_second, min(request_end_second + 1, duration_seconds)
|
| 996 |
+
):
|
| 997 |
+
concurrent_requests_per_second[second] += 1
|
| 998 |
+
|
| 999 |
+
if len(tokens_per_second) > 0:
|
| 1000 |
+
max_output_tokens_per_s = float(np.max(tokens_per_second))
|
| 1001 |
+
max_concurrent_requests = int(np.max(concurrent_requests_per_second))
|
| 1002 |
+
|
| 1003 |
+
if plot_throughput:
|
| 1004 |
+
if TERM_PLOTLIB_AVAILABLE:
|
| 1005 |
+
import termplotlib as tpl
|
| 1006 |
+
|
| 1007 |
+
fig = tpl.figure()
|
| 1008 |
+
fig.plot(
|
| 1009 |
+
np.arange(len(tokens_per_second)),
|
| 1010 |
+
tokens_per_second,
|
| 1011 |
+
title="Output tokens per second",
|
| 1012 |
+
xlabel="Time (s)",
|
| 1013 |
+
)
|
| 1014 |
+
fig.plot(
|
| 1015 |
+
np.arange(len(concurrent_requests_per_second)),
|
| 1016 |
+
concurrent_requests_per_second,
|
| 1017 |
+
title="Concurrent requests per second",
|
| 1018 |
+
xlabel="Time (s)",
|
| 1019 |
+
)
|
| 1020 |
+
fig.show()
|
| 1021 |
+
else:
|
| 1022 |
+
print("tip: install termplotlib and gnuplot to plot the metrics")
|
| 1023 |
+
|
| 1024 |
+
itls = retokenized_itls if use_retokenized_itl else itls
|
| 1025 |
+
metrics = BenchmarkMetrics(
|
| 1026 |
+
completed=completed,
|
| 1027 |
+
total_input=total_input,
|
| 1028 |
+
total_input_text=total_input_text,
|
| 1029 |
+
total_input_vision=total_input_vision,
|
| 1030 |
+
total_output=sum(output_lens),
|
| 1031 |
+
total_output_retokenized=sum(retokenized_output_lens),
|
| 1032 |
+
request_throughput=completed / dur_s,
|
| 1033 |
+
input_throughput=total_input / dur_s,
|
| 1034 |
+
output_throughput=sum(output_lens) / dur_s,
|
| 1035 |
+
output_throughput_retokenized=sum(retokenized_output_lens) / dur_s,
|
| 1036 |
+
total_throughput=(total_input + sum(output_lens)) / dur_s,
|
| 1037 |
+
total_throughput_retokenized=(total_input + sum(retokenized_output_lens))
|
| 1038 |
+
/ dur_s,
|
| 1039 |
+
mean_ttft_ms=np.mean(ttfts or 0)
|
| 1040 |
+
* 1000, # ttfts is empty if streaming is not supported by backend
|
| 1041 |
+
median_ttft_ms=np.median(ttfts or 0) * 1000,
|
| 1042 |
+
std_ttft_ms=np.std(ttfts or 0) * 1000,
|
| 1043 |
+
p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
|
| 1044 |
+
mean_tpot_ms=np.mean(tpots or 0) * 1000,
|
| 1045 |
+
median_tpot_ms=np.median(tpots or 0) * 1000,
|
| 1046 |
+
std_tpot_ms=np.std(tpots or 0) * 1000,
|
| 1047 |
+
p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
|
| 1048 |
+
mean_itl_ms=np.mean(itls or 0) * 1000,
|
| 1049 |
+
median_itl_ms=np.median(itls or 0) * 1000,
|
| 1050 |
+
std_itl_ms=np.std(itls or 0) * 1000,
|
| 1051 |
+
p95_itl_ms=np.percentile(itls or 0, 95) * 1000,
|
| 1052 |
+
p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
|
| 1053 |
+
max_itl_ms=np.max(itls or 0) * 1000,
|
| 1054 |
+
mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000,
|
| 1055 |
+
median_e2e_latency_ms=np.median(e2e_latencies) * 1000,
|
| 1056 |
+
std_e2e_latency_ms=np.std(e2e_latencies) * 1000,
|
| 1057 |
+
p90_e2e_latency_ms=np.percentile(e2e_latencies, 90) * 1000,
|
| 1058 |
+
p99_e2e_latency_ms=np.percentile(e2e_latencies, 99) * 1000,
|
| 1059 |
+
concurrency=np.sum(e2e_latencies) / dur_s,
|
| 1060 |
+
max_output_tokens_per_s=max_output_tokens_per_s,
|
| 1061 |
+
max_concurrent_requests=max_concurrent_requests,
|
| 1062 |
+
)
|
| 1063 |
+
|
| 1064 |
+
return metrics, output_lens
|
| 1065 |
+
|
| 1066 |
+
|
| 1067 |
+
MULTI_TURN_BACKENDS = {"sglang-oai-chat", "vllm-chat", "lmdeploy-chat"}
|
| 1068 |
+
|
| 1069 |
+
|
| 1070 |
+
def wrap_multi_turn_request_func(request_func: Callable, backend: str) -> Callable:
|
| 1071 |
+
assert (
|
| 1072 |
+
backend in MULTI_TURN_BACKENDS
|
| 1073 |
+
), f"Multi-turn only supports chat backends: {MULTI_TURN_BACKENDS}, got {backend}"
|
| 1074 |
+
|
| 1075 |
+
async def f(
|
| 1076 |
+
request_func_input: RequestFuncInput,
|
| 1077 |
+
pbar: Optional[tqdm] = None,
|
| 1078 |
+
) -> List[RequestFuncOutput]:
|
| 1079 |
+
prompts: List[str] = request_func_input.prompt
|
| 1080 |
+
prev_messages: List[Dict[str, str]] = []
|
| 1081 |
+
outputs = []
|
| 1082 |
+
|
| 1083 |
+
for round_index in range(len(prompts)):
|
| 1084 |
+
prev_messages.append({"role": "user", "content": prompts[round_index]})
|
| 1085 |
+
|
| 1086 |
+
inner_input = replace(
|
| 1087 |
+
copy.deepcopy(request_func_input), prompt=copy.deepcopy(prev_messages)
|
| 1088 |
+
)
|
| 1089 |
+
output = await request_func(
|
| 1090 |
+
inner_input, pbar=pbar if round_index == len(prompts) - 1 else None
|
| 1091 |
+
)
|
| 1092 |
+
outputs.append(output)
|
| 1093 |
+
|
| 1094 |
+
prev_messages.append(
|
| 1095 |
+
{"role": "assistant", "content": output.generated_text}
|
| 1096 |
+
)
|
| 1097 |
+
|
| 1098 |
+
return outputs
|
| 1099 |
+
|
| 1100 |
+
return f
|
| 1101 |
+
|
| 1102 |
+
|
| 1103 |
+
async def benchmark(
|
| 1104 |
+
backend: str,
|
| 1105 |
+
api_url: str,
|
| 1106 |
+
base_url: str,
|
| 1107 |
+
model_id: str,
|
| 1108 |
+
tokenizer: PreTrainedTokenizerBase,
|
| 1109 |
+
input_requests: List[DatasetRow],
|
| 1110 |
+
request_rate: float,
|
| 1111 |
+
max_concurrency: Optional[int],
|
| 1112 |
+
disable_tqdm: bool,
|
| 1113 |
+
lora_names: List[str],
|
| 1114 |
+
lora_request_distribution: Optional[str],
|
| 1115 |
+
lora_zipf_alpha: Optional[float],
|
| 1116 |
+
extra_request_body: Dict[str, Any],
|
| 1117 |
+
profile: bool,
|
| 1118 |
+
pd_separated: bool = False,
|
| 1119 |
+
flush_cache: bool = False,
|
| 1120 |
+
warmup_requests: int = 1,
|
| 1121 |
+
use_trace_timestamps: bool = False,
|
| 1122 |
+
mooncake_slowdown_factor=1.0,
|
| 1123 |
+
mooncake_num_rounds=1,
|
| 1124 |
+
profile_prefill_url: Optional[List[str]] = None,
|
| 1125 |
+
profile_decode_url: Optional[List[str]] = None,
|
| 1126 |
+
):
|
| 1127 |
+
if backend in ASYNC_REQUEST_FUNCS:
|
| 1128 |
+
request_func = ASYNC_REQUEST_FUNCS[backend]
|
| 1129 |
+
else:
|
| 1130 |
+
raise ValueError(f"Unknown backend: {backend}")
|
| 1131 |
+
|
| 1132 |
+
# Check for multi-turn: prompt is a list of strings (not OpenAI messages dicts)
|
| 1133 |
+
# Multi-turn format: ["turn1", "turn2", ...] - list of strings
|
| 1134 |
+
# OpenAI format: [{"role": "user", "content": "..."}, ...] - list of dicts
|
| 1135 |
+
first_prompt = input_requests[0].prompt
|
| 1136 |
+
is_multi_turn = (
|
| 1137 |
+
isinstance(first_prompt, list)
|
| 1138 |
+
and len(first_prompt) > 0
|
| 1139 |
+
and isinstance(first_prompt[0], str)
|
| 1140 |
+
)
|
| 1141 |
+
if is_multi_turn:
|
| 1142 |
+
request_func = wrap_multi_turn_request_func(request_func, backend=backend)
|
| 1143 |
+
|
| 1144 |
+
# Limit concurrency
|
| 1145 |
+
# From https://github.com/vllm-project/vllm/pull/9390
|
| 1146 |
+
semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
|
| 1147 |
+
|
| 1148 |
+
async def limited_request_func(request_func_input, pbar):
|
| 1149 |
+
if semaphore is None:
|
| 1150 |
+
return await request_func(request_func_input=request_func_input, pbar=pbar)
|
| 1151 |
+
async with semaphore:
|
| 1152 |
+
return await request_func(request_func_input=request_func_input, pbar=pbar)
|
| 1153 |
+
|
| 1154 |
+
# Warmup
|
| 1155 |
+
print(f"Starting warmup with {warmup_requests} sequences...")
|
| 1156 |
+
|
| 1157 |
+
# Handle the data structure difference for the warmup request
|
| 1158 |
+
if args.dataset_name == "mooncake":
|
| 1159 |
+
# For mooncake, input_requests is a list of dicts.
|
| 1160 |
+
# We need to build a temporary DatasetRow for the warmup phase.
|
| 1161 |
+
warmup_record = input_requests[0]
|
| 1162 |
+
|
| 1163 |
+
# Build prompt from hash_ids, just like in the async generator
|
| 1164 |
+
hash_ids = warmup_record.get("hash_ids", [])
|
| 1165 |
+
prompt_text = ""
|
| 1166 |
+
for hash_id in hash_ids:
|
| 1167 |
+
prompt_text += f"{hash_id}" + " ".join(["hi"] * 512)
|
| 1168 |
+
prompt_text += "Can you tell me a detailed story in 1000 words?"
|
| 1169 |
+
|
| 1170 |
+
output_len = warmup_record.get("output_length", 32)
|
| 1171 |
+
prompt_len = len(tokenizer.encode(prompt_text))
|
| 1172 |
+
|
| 1173 |
+
# Create a temporary DatasetRow object for warmup
|
| 1174 |
+
test_request = DatasetRow(
|
| 1175 |
+
prompt=prompt_text,
|
| 1176 |
+
prompt_len=prompt_len,
|
| 1177 |
+
output_len=output_len,
|
| 1178 |
+
image_data=None, # Mooncake doesn't have image data
|
| 1179 |
+
)
|
| 1180 |
+
else:
|
| 1181 |
+
# For all other datasets, input_requests is a list of DatasetRow objects
|
| 1182 |
+
test_request = input_requests[0]
|
| 1183 |
+
|
| 1184 |
+
if lora_names is not None and len(lora_names) != 0:
|
| 1185 |
+
lora_name = lora_names[0]
|
| 1186 |
+
else:
|
| 1187 |
+
lora_name = None
|
| 1188 |
+
|
| 1189 |
+
# Create the test input once
|
| 1190 |
+
test_input = RequestFuncInput(
|
| 1191 |
+
model=model_id,
|
| 1192 |
+
prompt=test_request.prompt,
|
| 1193 |
+
api_url=api_url,
|
| 1194 |
+
prompt_len=test_request.prompt_len,
|
| 1195 |
+
output_len=min(test_request.output_len, 32),
|
| 1196 |
+
lora_name=lora_name,
|
| 1197 |
+
image_data=test_request.image_data,
|
| 1198 |
+
extra_request_body=extra_request_body,
|
| 1199 |
+
)
|
| 1200 |
+
|
| 1201 |
+
# Run warmup requests
|
| 1202 |
+
warmup_tasks = []
|
| 1203 |
+
for _ in range(warmup_requests):
|
| 1204 |
+
warmup_tasks.append(
|
| 1205 |
+
asyncio.create_task(request_func(request_func_input=test_input))
|
| 1206 |
+
)
|
| 1207 |
+
|
| 1208 |
+
warmup_outputs = await asyncio.gather(*warmup_tasks)
|
| 1209 |
+
if is_multi_turn:
|
| 1210 |
+
warmup_outputs = [x for output in warmup_outputs for x in output]
|
| 1211 |
+
|
| 1212 |
+
# Check if at least one warmup request succeeded
|
| 1213 |
+
if warmup_requests > 0 and not any(output.success for output in warmup_outputs):
|
| 1214 |
+
raise ValueError(
|
| 1215 |
+
"Warmup failed - Please make sure benchmark arguments "
|
| 1216 |
+
f"are correctly specified. Error: {warmup_outputs[0].error}"
|
| 1217 |
+
)
|
| 1218 |
+
else:
|
| 1219 |
+
print(
|
| 1220 |
+
f"Warmup completed with {args.warmup_requests} sequences. Starting main benchmark run..."
|
| 1221 |
+
)
|
| 1222 |
+
|
| 1223 |
+
# Flush cache
|
| 1224 |
+
if ("sglang" in backend and _get_bool_env_var("SGLANG_IS_IN_CI")) or flush_cache:
|
| 1225 |
+
requests.post(base_url + "/flush_cache", headers=get_auth_headers())
|
| 1226 |
+
|
| 1227 |
+
time.sleep(1.0)
|
| 1228 |
+
|
| 1229 |
+
# Build profile URLs for PD separated mode (do this once at the beginning)
|
| 1230 |
+
pd_profile_urls = []
|
| 1231 |
+
if profile and pd_separated:
|
| 1232 |
+
pd_profile_urls = _build_profile_urls(profile_prefill_url, profile_decode_url)
|
| 1233 |
+
if not pd_profile_urls:
|
| 1234 |
+
print(
|
| 1235 |
+
"Warning: PD separated mode requires --profile-prefill-url or --profile-decode-url"
|
| 1236 |
+
)
|
| 1237 |
+
print("Skipping profiler start. Please specify worker URLs for profiling.")
|
| 1238 |
+
|
| 1239 |
+
# Start profiler
|
| 1240 |
+
if profile:
|
| 1241 |
+
if pd_separated:
|
| 1242 |
+
if pd_profile_urls:
|
| 1243 |
+
await _call_profile_pd(pd_profile_urls, "start")
|
| 1244 |
+
else:
|
| 1245 |
+
print("Starting profiler...")
|
| 1246 |
+
profile_output = await async_request_profile(
|
| 1247 |
+
api_url=base_url + "/start_profile"
|
| 1248 |
+
)
|
| 1249 |
+
if profile_output.success:
|
| 1250 |
+
print("Profiler started")
|
| 1251 |
+
|
| 1252 |
+
# Run all requests
|
| 1253 |
+
benchmark_start_time = time.perf_counter()
|
| 1254 |
+
tasks: List[asyncio.Task] = []
|
| 1255 |
+
pbar_total = len(input_requests)
|
| 1256 |
+
if (
|
| 1257 |
+
backend == "sglang" and args.dataset_name == "mooncake"
|
| 1258 |
+
): # Assuming mooncake is mainly for sglang or similar backends
|
| 1259 |
+
print("Using time-based Mooncake request scheduler, ignoring --request-rate.")
|
| 1260 |
+
request_generator = get_mooncake_request_over_time(
|
| 1261 |
+
input_requests, tokenizer, mooncake_slowdown_factor, mooncake_num_rounds
|
| 1262 |
+
)
|
| 1263 |
+
print(
|
| 1264 |
+
f"Starting Mooncake trace replay. Sessions: {len(input_requests)}, Rounds per session: {mooncake_num_rounds}. Slowdown factor: {mooncake_slowdown_factor}"
|
| 1265 |
+
)
|
| 1266 |
+
pbar_total *= args.mooncake_num_rounds
|
| 1267 |
+
else:
|
| 1268 |
+
request_generator = get_request(input_requests, request_rate)
|
| 1269 |
+
|
| 1270 |
+
# Prepare LoRA request distribution parameters
|
| 1271 |
+
if lora_request_distribution == "distinct":
|
| 1272 |
+
lora_idx = 0
|
| 1273 |
+
elif lora_request_distribution == "skewed":
|
| 1274 |
+
weights = np.array([lora_zipf_alpha**-i for i in range(len(lora_names))])
|
| 1275 |
+
lora_probs = weights / np.sum(weights)
|
| 1276 |
+
else:
|
| 1277 |
+
lora_idx = None
|
| 1278 |
+
lora_probs = None
|
| 1279 |
+
|
| 1280 |
+
pbar = None if disable_tqdm else tqdm(total=pbar_total)
|
| 1281 |
+
async for request in request_generator:
|
| 1282 |
+
if lora_names is not None and len(lora_names) != 0:
|
| 1283 |
+
if lora_request_distribution == "uniform":
|
| 1284 |
+
lora_name = random.choice(lora_names)
|
| 1285 |
+
elif lora_request_distribution == "distinct":
|
| 1286 |
+
lora_name = lora_names[lora_idx]
|
| 1287 |
+
lora_idx = (lora_idx + 1) % len(lora_names)
|
| 1288 |
+
else:
|
| 1289 |
+
assert (
|
| 1290 |
+
lora_request_distribution == "skewed"
|
| 1291 |
+
), f"Unexpected lora_request_distribution: {lora_request_distribution}. Expected 'skewed'."
|
| 1292 |
+
|
| 1293 |
+
lora_name = np.random.choice(lora_names, p=lora_probs)
|
| 1294 |
+
else:
|
| 1295 |
+
lora_name = None
|
| 1296 |
+
|
| 1297 |
+
# Merge global extra_request_body with per-request extras
|
| 1298 |
+
# Per-request parameters take precedence over global ones
|
| 1299 |
+
merged_extra_body = {**extra_request_body, **request.extra_request_body}
|
| 1300 |
+
|
| 1301 |
+
request_func_input = RequestFuncInput(
|
| 1302 |
+
model=model_id,
|
| 1303 |
+
prompt=request.prompt,
|
| 1304 |
+
api_url=api_url,
|
| 1305 |
+
prompt_len=request.prompt_len,
|
| 1306 |
+
output_len=request.output_len,
|
| 1307 |
+
lora_name=lora_name,
|
| 1308 |
+
image_data=request.image_data,
|
| 1309 |
+
extra_request_body=merged_extra_body,
|
| 1310 |
+
timestamp=request.timestamp,
|
| 1311 |
+
routing_key=request.routing_key,
|
| 1312 |
+
)
|
| 1313 |
+
|
| 1314 |
+
tasks.append(
|
| 1315 |
+
asyncio.create_task(
|
| 1316 |
+
limited_request_func(request_func_input=request_func_input, pbar=pbar)
|
| 1317 |
+
)
|
| 1318 |
+
)
|
| 1319 |
+
outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
|
| 1320 |
+
if is_multi_turn:
|
| 1321 |
+
outputs = [x for output in outputs for x in output]
|
| 1322 |
+
|
| 1323 |
+
# Stop profiler (only if profile_steps was not provided, as it auto-stops)
|
| 1324 |
+
if profile and not (
|
| 1325 |
+
hasattr(args, "profile_steps") and args.profile_steps is not None
|
| 1326 |
+
):
|
| 1327 |
+
if pd_separated:
|
| 1328 |
+
if pd_profile_urls:
|
| 1329 |
+
await _call_profile_pd(pd_profile_urls, "stop")
|
| 1330 |
+
else:
|
| 1331 |
+
if getattr(args, "profile_num_steps", None) is None:
|
| 1332 |
+
print("Stopping profiler...")
|
| 1333 |
+
profile_output = await async_request_profile(
|
| 1334 |
+
api_url=base_url + "/stop_profile"
|
| 1335 |
+
)
|
| 1336 |
+
if profile_output.success:
|
| 1337 |
+
print("Profiler stopped")
|
| 1338 |
+
|
| 1339 |
+
if pbar is not None:
|
| 1340 |
+
pbar.close()
|
| 1341 |
+
|
| 1342 |
+
if "sglang" in backend:
|
| 1343 |
+
server_info = requests.get(
|
| 1344 |
+
base_url + "/get_server_info", headers=get_auth_headers()
|
| 1345 |
+
)
|
| 1346 |
+
if server_info.status_code == 200:
|
| 1347 |
+
server_info_json = server_info.json()
|
| 1348 |
+
if "decode" in server_info_json:
|
| 1349 |
+
server_info_json = server_info_json["decode"][0]
|
| 1350 |
+
if (
|
| 1351 |
+
"internal_states" in server_info_json
|
| 1352 |
+
and server_info_json["internal_states"]
|
| 1353 |
+
):
|
| 1354 |
+
accept_length = server_info_json["internal_states"][0].get(
|
| 1355 |
+
"avg_spec_accept_length", None
|
| 1356 |
+
)
|
| 1357 |
+
else:
|
| 1358 |
+
accept_length = None
|
| 1359 |
+
else:
|
| 1360 |
+
accept_length = None
|
| 1361 |
+
else:
|
| 1362 |
+
accept_length = None
|
| 1363 |
+
|
| 1364 |
+
# Compute metrics and print results
|
| 1365 |
+
benchmark_duration = time.perf_counter() - benchmark_start_time
|
| 1366 |
+
metrics, output_lens = calculate_metrics(
|
| 1367 |
+
input_requests=None if is_multi_turn else input_requests,
|
| 1368 |
+
outputs=outputs,
|
| 1369 |
+
dur_s=benchmark_duration,
|
| 1370 |
+
tokenizer=tokenizer,
|
| 1371 |
+
backend=backend,
|
| 1372 |
+
accept_length=accept_length,
|
| 1373 |
+
plot_throughput=args.plot_throughput,
|
| 1374 |
+
)
|
| 1375 |
+
|
| 1376 |
+
print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
|
| 1377 |
+
print("{:<40} {:<10}".format("Backend:", backend))
|
| 1378 |
+
print(
|
| 1379 |
+
"{:<40} {:<10}".format(
|
| 1380 |
+
"Traffic request rate:", "trace" if use_trace_timestamps else request_rate
|
| 1381 |
+
)
|
| 1382 |
+
)
|
| 1383 |
+
print(
|
| 1384 |
+
"{:<40} {:<10}".format(
|
| 1385 |
+
"Max request concurrency:",
|
| 1386 |
+
max_concurrency if max_concurrency else "not set",
|
| 1387 |
+
)
|
| 1388 |
+
)
|
| 1389 |
+
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
|
| 1390 |
+
print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
|
| 1391 |
+
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
|
| 1392 |
+
print("{:<40} {:<10}".format("Total input text tokens:", metrics.total_input_text))
|
| 1393 |
+
if args.dataset_name in ["image", "mmmu"]:
|
| 1394 |
+
print(
|
| 1395 |
+
"{:<40} {:<10}".format(
|
| 1396 |
+
"Total input vision tokens:", metrics.total_input_vision
|
| 1397 |
+
)
|
| 1398 |
+
)
|
| 1399 |
+
print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
|
| 1400 |
+
print(
|
| 1401 |
+
"{:<40} {:<10}".format(
|
| 1402 |
+
"Total generated tokens (retokenized):", metrics.total_output_retokenized
|
| 1403 |
+
)
|
| 1404 |
+
)
|
| 1405 |
+
print(
|
| 1406 |
+
"{:<40} {:<10.2f}".format(
|
| 1407 |
+
"Request throughput (req/s):", metrics.request_throughput
|
| 1408 |
+
)
|
| 1409 |
+
)
|
| 1410 |
+
print(
|
| 1411 |
+
"{:<40} {:<10.2f}".format(
|
| 1412 |
+
"Input token throughput (tok/s):", metrics.input_throughput
|
| 1413 |
+
)
|
| 1414 |
+
)
|
| 1415 |
+
print(
|
| 1416 |
+
"{:<40} {:<10.2f}".format(
|
| 1417 |
+
"Output token throughput (tok/s):", metrics.output_throughput
|
| 1418 |
+
)
|
| 1419 |
+
)
|
| 1420 |
+
print(
|
| 1421 |
+
"{:<40} {:<10.2f}".format(
|
| 1422 |
+
"Peak output token throughput (tok/s):", metrics.max_output_tokens_per_s
|
| 1423 |
+
)
|
| 1424 |
+
)
|
| 1425 |
+
print(
|
| 1426 |
+
"{:<40} {:<10}".format(
|
| 1427 |
+
"Peak concurrent requests:", metrics.max_concurrent_requests
|
| 1428 |
+
)
|
| 1429 |
+
)
|
| 1430 |
+
print(
|
| 1431 |
+
"{:<40} {:<10.2f}".format(
|
| 1432 |
+
"Total token throughput (tok/s):", metrics.total_throughput
|
| 1433 |
+
)
|
| 1434 |
+
)
|
| 1435 |
+
print("{:<40} {:<10.2f}".format("Concurrency:", metrics.concurrency))
|
| 1436 |
+
if accept_length:
|
| 1437 |
+
print("{:<40} {:<10.2f}".format("Accept length:", accept_length))
|
| 1438 |
+
print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
|
| 1439 |
+
print(
|
| 1440 |
+
"{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
|
| 1441 |
+
)
|
| 1442 |
+
print(
|
| 1443 |
+
"{:<40} {:<10.2f}".format(
|
| 1444 |
+
"Median E2E Latency (ms):", metrics.median_e2e_latency_ms
|
| 1445 |
+
)
|
| 1446 |
+
)
|
| 1447 |
+
print(
|
| 1448 |
+
"{:<40} {:<10.2f}".format("P90 E2E Latency (ms):", metrics.p90_e2e_latency_ms)
|
| 1449 |
+
)
|
| 1450 |
+
print(
|
| 1451 |
+
"{:<40} {:<10.2f}".format("P99 E2E Latency (ms):", metrics.p99_e2e_latency_ms)
|
| 1452 |
+
)
|
| 1453 |
+
print("{s:{c}^{n}}".format(s="Time to First Token", n=50, c="-"))
|
| 1454 |
+
print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
|
| 1455 |
+
print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
|
| 1456 |
+
print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
|
| 1457 |
+
print(
|
| 1458 |
+
"{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-")
|
| 1459 |
+
)
|
| 1460 |
+
print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
|
| 1461 |
+
print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms))
|
| 1462 |
+
print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
|
| 1463 |
+
print("{s:{c}^{n}}".format(s="Inter-Token Latency", n=50, c="-"))
|
| 1464 |
+
print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
|
| 1465 |
+
print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
|
| 1466 |
+
print("{:<40} {:<10.2f}".format("P95 ITL (ms):", metrics.p95_itl_ms))
|
| 1467 |
+
print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
|
| 1468 |
+
print("{:<40} {:<10.2f}".format("Max ITL (ms):", metrics.max_itl_ms))
|
| 1469 |
+
print("=" * 50)
|
| 1470 |
+
|
| 1471 |
+
resp = requests.get(base_url + "/get_server_info", headers=get_auth_headers())
|
| 1472 |
+
server_info = resp.json() if resp.status_code == 200 else None
|
| 1473 |
+
|
| 1474 |
+
if (
|
| 1475 |
+
metrics.median_ttft_ms is not None
|
| 1476 |
+
and metrics.mean_itl_ms is not None
|
| 1477 |
+
and metrics.output_throughput is not None
|
| 1478 |
+
):
|
| 1479 |
+
result = {
|
| 1480 |
+
# Arguments
|
| 1481 |
+
"tag": getattr(args, "tag", None),
|
| 1482 |
+
"backend": args.backend,
|
| 1483 |
+
"dataset_name": args.dataset_name,
|
| 1484 |
+
"request_rate": "trace" if use_trace_timestamps else request_rate,
|
| 1485 |
+
"max_concurrency": max_concurrency,
|
| 1486 |
+
"sharegpt_output_len": args.sharegpt_output_len,
|
| 1487 |
+
"random_input_len": args.random_input_len,
|
| 1488 |
+
"random_output_len": args.random_output_len,
|
| 1489 |
+
"random_range_ratio": args.random_range_ratio,
|
| 1490 |
+
# Information
|
| 1491 |
+
"server_info": server_info,
|
| 1492 |
+
# Results
|
| 1493 |
+
"duration": benchmark_duration,
|
| 1494 |
+
"completed": metrics.completed,
|
| 1495 |
+
"total_input_tokens": metrics.total_input,
|
| 1496 |
+
"total_input_text_tokens": metrics.total_input_text,
|
| 1497 |
+
"total_input_vision_tokens": metrics.total_input_vision,
|
| 1498 |
+
"total_output_tokens": metrics.total_output,
|
| 1499 |
+
"total_output_tokens_retokenized": metrics.total_output_retokenized,
|
| 1500 |
+
"request_throughput": metrics.request_throughput,
|
| 1501 |
+
"input_throughput": metrics.input_throughput,
|
| 1502 |
+
"output_throughput": metrics.output_throughput,
|
| 1503 |
+
"total_throughput": metrics.total_throughput,
|
| 1504 |
+
"mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
|
| 1505 |
+
"median_e2e_latency_ms": metrics.median_e2e_latency_ms,
|
| 1506 |
+
"std_e2e_latency_ms": metrics.std_e2e_latency_ms,
|
| 1507 |
+
"p90_e2e_latency_ms": metrics.p90_e2e_latency_ms,
|
| 1508 |
+
"p99_e2e_latency_ms": metrics.p99_e2e_latency_ms,
|
| 1509 |
+
"mean_ttft_ms": metrics.mean_ttft_ms,
|
| 1510 |
+
"median_ttft_ms": metrics.median_ttft_ms,
|
| 1511 |
+
"std_ttft_ms": metrics.std_ttft_ms,
|
| 1512 |
+
"p99_ttft_ms": metrics.p99_ttft_ms,
|
| 1513 |
+
"mean_tpot_ms": metrics.mean_tpot_ms,
|
| 1514 |
+
"median_tpot_ms": metrics.median_tpot_ms,
|
| 1515 |
+
"std_tpot_ms": metrics.std_tpot_ms,
|
| 1516 |
+
"p99_tpot_ms": metrics.p99_tpot_ms,
|
| 1517 |
+
"mean_itl_ms": metrics.mean_itl_ms,
|
| 1518 |
+
"median_itl_ms": metrics.median_itl_ms,
|
| 1519 |
+
"std_itl_ms": metrics.std_itl_ms,
|
| 1520 |
+
"p95_itl_ms": metrics.p95_itl_ms,
|
| 1521 |
+
"p99_itl_ms": metrics.p99_itl_ms,
|
| 1522 |
+
"concurrency": metrics.concurrency,
|
| 1523 |
+
"accept_length": accept_length,
|
| 1524 |
+
"max_output_tokens_per_s": metrics.max_output_tokens_per_s,
|
| 1525 |
+
"max_concurrent_requests": metrics.max_concurrent_requests,
|
| 1526 |
+
}
|
| 1527 |
+
else:
|
| 1528 |
+
print(f"Error running benchmark for request rate: {request_rate}")
|
| 1529 |
+
print("-" * 30)
|
| 1530 |
+
|
| 1531 |
+
# Determine output file name
|
| 1532 |
+
if args.output_file:
|
| 1533 |
+
output_file_name = args.output_file
|
| 1534 |
+
else:
|
| 1535 |
+
now = datetime.now().strftime("%m%d")
|
| 1536 |
+
if args.dataset_name == "image":
|
| 1537 |
+
output_file_name = (
|
| 1538 |
+
f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_"
|
| 1539 |
+
f"{args.random_output_len}_{args.image_count}imgs_"
|
| 1540 |
+
f"{args.image_resolution}.jsonl"
|
| 1541 |
+
)
|
| 1542 |
+
elif args.dataset_name.startswith("random"):
|
| 1543 |
+
output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"
|
| 1544 |
+
else:
|
| 1545 |
+
output_file_name = (
|
| 1546 |
+
f"{args.backend}_{now}_{args.num_prompts}_{args.dataset_name}.jsonl"
|
| 1547 |
+
)
|
| 1548 |
+
|
| 1549 |
+
result_details = {
|
| 1550 |
+
"input_lens": [output.prompt_len for output in outputs],
|
| 1551 |
+
"output_lens": output_lens,
|
| 1552 |
+
"ttfts": [output.ttft for output in outputs],
|
| 1553 |
+
"itls": [output.itl for output in outputs],
|
| 1554 |
+
"generated_texts": [output.generated_text for output in outputs],
|
| 1555 |
+
"errors": [output.error for output in outputs],
|
| 1556 |
+
}
|
| 1557 |
+
|
| 1558 |
+
# Append results to a JSONL file
|
| 1559 |
+
with open(output_file_name, "a") as file:
|
| 1560 |
+
if args.output_details:
|
| 1561 |
+
result_for_dump = result | result_details
|
| 1562 |
+
else:
|
| 1563 |
+
result_for_dump = result
|
| 1564 |
+
file.write(json.dumps(result_for_dump) + "\n")
|
| 1565 |
+
|
| 1566 |
+
return result | result_details
|
| 1567 |
+
|
| 1568 |
+
|
| 1569 |
+
def check_chat_template(model_path):
|
| 1570 |
+
try:
|
| 1571 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
| 1572 |
+
return "chat_template" in tokenizer.init_kwargs
|
| 1573 |
+
except Exception as e:
|
| 1574 |
+
print(f"Fail to load tokenizer config with error={e}")
|
| 1575 |
+
return False
|
| 1576 |
+
|
| 1577 |
+
|
| 1578 |
+
def set_global_args(args_: argparse.Namespace):
|
| 1579 |
+
"""Set the global args."""
|
| 1580 |
+
global args
|
| 1581 |
+
args = args_
|
| 1582 |
+
|
| 1583 |
+
|
| 1584 |
+
def run_benchmark(args_: argparse.Namespace):
|
| 1585 |
+
global args
|
| 1586 |
+
args = args_
|
| 1587 |
+
|
| 1588 |
+
# Set default value for max_concurrency if not present
|
| 1589 |
+
if not hasattr(args, "max_concurrency"):
|
| 1590 |
+
args.max_concurrency = None
|
| 1591 |
+
|
| 1592 |
+
# Set default value for warmup_requests if not present
|
| 1593 |
+
if not hasattr(args, "warmup_requests"):
|
| 1594 |
+
args.warmup_requests = 1
|
| 1595 |
+
|
| 1596 |
+
if not hasattr(args, "output_details"):
|
| 1597 |
+
args.output_details = False
|
| 1598 |
+
|
| 1599 |
+
if not hasattr(args, "tokenize_prompt"):
|
| 1600 |
+
args.tokenize_prompt = False
|
| 1601 |
+
|
| 1602 |
+
if not hasattr(args, "plot_throughput"):
|
| 1603 |
+
args.plot_throughput = False
|
| 1604 |
+
|
| 1605 |
+
if not hasattr(args, "use_trace_timestamps"):
|
| 1606 |
+
args.use_trace_timestamps = False
|
| 1607 |
+
if not hasattr(args, "mooncake_slowdown_factor"):
|
| 1608 |
+
args.mooncake_slowdown_factor = 1.0
|
| 1609 |
+
|
| 1610 |
+
if not hasattr(args, "mooncake_slowdown_factor"):
|
| 1611 |
+
args.mooncake_slowdown_factor = 1.0
|
| 1612 |
+
|
| 1613 |
+
if not hasattr(args, "mooncake_num_rounds"):
|
| 1614 |
+
args.mooncake_num_rounds = 1
|
| 1615 |
+
|
| 1616 |
+
if not hasattr(args, "served_model_name"):
|
| 1617 |
+
args.served_model_name = None
|
| 1618 |
+
|
| 1619 |
+
if getattr(args, "print_requests", False):
|
| 1620 |
+
assert args.backend == "sglang-oai-chat" # only support this now
|
| 1621 |
+
|
| 1622 |
+
print(f"benchmark_args={args}")
|
| 1623 |
+
|
| 1624 |
+
# Set global environments
|
| 1625 |
+
set_ulimit()
|
| 1626 |
+
random.seed(args.seed)
|
| 1627 |
+
np.random.seed(args.seed)
|
| 1628 |
+
|
| 1629 |
+
extra_request_body = {}
|
| 1630 |
+
if args.extra_request_body:
|
| 1631 |
+
extra_request_body = json.loads(args.extra_request_body)
|
| 1632 |
+
|
| 1633 |
+
if args.tokenize_prompt:
|
| 1634 |
+
assert (
|
| 1635 |
+
args.backend == "sglang"
|
| 1636 |
+
), "`--tokenize-prompt` only compatible with `--backend sglang` currently"
|
| 1637 |
+
|
| 1638 |
+
# Set url
|
| 1639 |
+
if args.port is None:
|
| 1640 |
+
args.port = {
|
| 1641 |
+
"sglang": 30000,
|
| 1642 |
+
"sglang-native": 30000,
|
| 1643 |
+
"sglang-oai": 30000,
|
| 1644 |
+
"lmdeploy": 23333,
|
| 1645 |
+
"vllm": 8000,
|
| 1646 |
+
"trt": 8000,
|
| 1647 |
+
"gserver": 9988,
|
| 1648 |
+
"truss": 8080,
|
| 1649 |
+
}.get(args.backend, 30000)
|
| 1650 |
+
|
| 1651 |
+
model_url = (
|
| 1652 |
+
f"{args.base_url}/v1/models"
|
| 1653 |
+
if args.base_url
|
| 1654 |
+
else f"http://{args.host}:{args.port}/v1/models"
|
| 1655 |
+
)
|
| 1656 |
+
|
| 1657 |
+
if args.backend in ["sglang", "sglang-native"]:
|
| 1658 |
+
api_url = (
|
| 1659 |
+
f"{args.base_url}/generate"
|
| 1660 |
+
if args.base_url
|
| 1661 |
+
else f"http://{args.host}:{args.port}/generate"
|
| 1662 |
+
)
|
| 1663 |
+
elif args.backend in ["sglang-oai", "vllm", "lmdeploy"]:
|
| 1664 |
+
api_url = (
|
| 1665 |
+
f"{args.base_url}/v1/completions"
|
| 1666 |
+
if args.base_url
|
| 1667 |
+
else f"http://{args.host}:{args.port}/v1/completions"
|
| 1668 |
+
)
|
| 1669 |
+
elif args.backend in ["sglang-oai-chat", "vllm-chat", "lmdeploy-chat"]:
|
| 1670 |
+
api_url = (
|
| 1671 |
+
f"{args.base_url}/v1/chat/completions"
|
| 1672 |
+
if args.base_url
|
| 1673 |
+
else f"http://{args.host}:{args.port}/v1/chat/completions"
|
| 1674 |
+
)
|
| 1675 |
+
elif args.backend == "trt":
|
| 1676 |
+
api_url = (
|
| 1677 |
+
f"{args.base_url}/v2/models/ensemble/generate_stream"
|
| 1678 |
+
if args.base_url
|
| 1679 |
+
else f"http://{args.host}:{args.port}/v2/models/ensemble/generate_stream"
|
| 1680 |
+
)
|
| 1681 |
+
if args.model is None:
|
| 1682 |
+
print("Please provide a model using `--model` when using `trt` backend.")
|
| 1683 |
+
sys.exit(1)
|
| 1684 |
+
elif args.backend == "gserver":
|
| 1685 |
+
api_url = args.base_url if args.base_url else f"{args.host}:{args.port}"
|
| 1686 |
+
args.model = args.model or "default"
|
| 1687 |
+
elif args.backend == "truss":
|
| 1688 |
+
api_url = (
|
| 1689 |
+
f"{args.base_url}/v1/models/model:predict"
|
| 1690 |
+
if args.base_url
|
| 1691 |
+
else f"http://{args.host}:{args.port}/v1/models/model:predict"
|
| 1692 |
+
)
|
| 1693 |
+
base_url = (
|
| 1694 |
+
f"http://{args.host}:{args.port}" if args.base_url is None else args.base_url
|
| 1695 |
+
)
|
| 1696 |
+
|
| 1697 |
+
# Wait for server to be ready
|
| 1698 |
+
if args.ready_check_timeout_sec > 0:
|
| 1699 |
+
health_url = model_url if args.backend not in ("trt", "gserver") else base_url
|
| 1700 |
+
if not wait_for_endpoint(health_url, args.ready_check_timeout_sec):
|
| 1701 |
+
print(f"Server at {health_url} is not ready. Exiting.")
|
| 1702 |
+
sys.exit(1)
|
| 1703 |
+
|
| 1704 |
+
# Get model name
|
| 1705 |
+
if args.model is None:
|
| 1706 |
+
if args.backend == "truss":
|
| 1707 |
+
print(
|
| 1708 |
+
"Please provide a model with `--model` when using truss backend. e.g. --model meta-llama/Llama-3.1-8B-Instruct"
|
| 1709 |
+
)
|
| 1710 |
+
sys.exit(1)
|
| 1711 |
+
try:
|
| 1712 |
+
response = requests.get(model_url, headers=get_auth_headers())
|
| 1713 |
+
model_list = response.json().get("data", [])
|
| 1714 |
+
args.model = model_list[0]["id"] if model_list else None
|
| 1715 |
+
except Exception as e:
|
| 1716 |
+
print(f"Failed to fetch model from {model_url}. Error: {e}")
|
| 1717 |
+
print(
|
| 1718 |
+
"Please specify the correct host and port using `--host` and `--port`."
|
| 1719 |
+
)
|
| 1720 |
+
sys.exit(1)
|
| 1721 |
+
|
| 1722 |
+
if args.model is None:
|
| 1723 |
+
print("No model specified or found. Please provide a model using `--model`.")
|
| 1724 |
+
sys.exit(1)
|
| 1725 |
+
|
| 1726 |
+
if not check_chat_template(args.model):
|
| 1727 |
+
print(
|
| 1728 |
+
"\nWARNING It is recommended to use the `Chat` or `Instruct` model for benchmarking.\n"
|
| 1729 |
+
"Because when the tokenizer counts the output tokens, if there is gibberish, it might count incorrectly.\n"
|
| 1730 |
+
)
|
| 1731 |
+
|
| 1732 |
+
if args.dataset_name in ["image", "mmmu"]:
|
| 1733 |
+
args.apply_chat_template = True
|
| 1734 |
+
assert (
|
| 1735 |
+
not args.tokenize_prompt
|
| 1736 |
+
), "`--tokenize-prompt` not compatible with image dataset"
|
| 1737 |
+
|
| 1738 |
+
if args.lora_request_distribution in ["distinct", "skewed"]:
|
| 1739 |
+
assert (
|
| 1740 |
+
args.lora_name is not None and len(args.lora_name) > 1
|
| 1741 |
+
), "More than 1 LoRA adapter must be specified via --lora-name to use 'distinct' or 'skewed' request distribution."
|
| 1742 |
+
|
| 1743 |
+
assert (
|
| 1744 |
+
args.lora_zipf_alpha > 1
|
| 1745 |
+
), f"Got invalid value for --lora-zipf-alpha of {args.lora_zipf_alpha}. It must be greater than 1."
|
| 1746 |
+
|
| 1747 |
+
print(f"{args}\n")
|
| 1748 |
+
|
| 1749 |
+
# Read dataset
|
| 1750 |
+
backend = args.backend
|
| 1751 |
+
model_id = args.served_model_name or args.model
|
| 1752 |
+
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
|
| 1753 |
+
tokenizer = get_tokenizer(tokenizer_id)
|
| 1754 |
+
input_requests = get_dataset(args, tokenizer, model_id)
|
| 1755 |
+
|
| 1756 |
+
# compatible with SimpleNamespace
|
| 1757 |
+
if not hasattr(args, "flush_cache"):
|
| 1758 |
+
args.flush_cache = False
|
| 1759 |
+
|
| 1760 |
+
# Prepare LoRA arguments
|
| 1761 |
+
lora_request_distribution = (
|
| 1762 |
+
args.lora_request_distribution if args.lora_name is not None else None
|
| 1763 |
+
)
|
| 1764 |
+
|
| 1765 |
+
lora_zipf_alpha = (
|
| 1766 |
+
args.lora_zipf_alpha
|
| 1767 |
+
if args.lora_name is not None and args.lora_request_distribution == "skewed"
|
| 1768 |
+
else None
|
| 1769 |
+
)
|
| 1770 |
+
|
| 1771 |
+
return asyncio.run(
|
| 1772 |
+
benchmark(
|
| 1773 |
+
backend=backend,
|
| 1774 |
+
api_url=api_url,
|
| 1775 |
+
base_url=base_url,
|
| 1776 |
+
model_id=model_id,
|
| 1777 |
+
tokenizer=tokenizer,
|
| 1778 |
+
input_requests=input_requests,
|
| 1779 |
+
request_rate=args.request_rate,
|
| 1780 |
+
max_concurrency=args.max_concurrency,
|
| 1781 |
+
disable_tqdm=args.disable_tqdm,
|
| 1782 |
+
lora_names=args.lora_name,
|
| 1783 |
+
lora_request_distribution=lora_request_distribution,
|
| 1784 |
+
lora_zipf_alpha=lora_zipf_alpha,
|
| 1785 |
+
extra_request_body=extra_request_body,
|
| 1786 |
+
profile=args.profile,
|
| 1787 |
+
pd_separated=args.pd_separated,
|
| 1788 |
+
flush_cache=args.flush_cache,
|
| 1789 |
+
warmup_requests=args.warmup_requests,
|
| 1790 |
+
use_trace_timestamps=args.use_trace_timestamps,
|
| 1791 |
+
mooncake_slowdown_factor=args.mooncake_slowdown_factor,
|
| 1792 |
+
mooncake_num_rounds=args.mooncake_num_rounds,
|
| 1793 |
+
profile_prefill_url=getattr(args, "profile_prefill_url", None),
|
| 1794 |
+
profile_decode_url=getattr(args, "profile_decode_url", None),
|
| 1795 |
+
)
|
| 1796 |
+
)
|
| 1797 |
+
|
| 1798 |
+
|
| 1799 |
+
class LoRAPathAction(argparse.Action):
|
| 1800 |
+
def __call__(self, parser, namespace, values, option_string=None):
|
| 1801 |
+
setattr(namespace, self.dest, [])
|
| 1802 |
+
for lora_name in values:
|
| 1803 |
+
getattr(namespace, self.dest).append(lora_name)
|
| 1804 |
+
|
| 1805 |
+
|
| 1806 |
+
if __name__ == "__main__":
|
| 1807 |
+
parser = ArgumentParser(description="Benchmark the online serving throughput.")
|
| 1808 |
+
parser.add_argument(
|
| 1809 |
+
"--backend",
|
| 1810 |
+
type=str,
|
| 1811 |
+
choices=list(ASYNC_REQUEST_FUNCS.keys()),
|
| 1812 |
+
default="sglang",
|
| 1813 |
+
help="Must specify a backend, depending on the LLM Inference Engine.",
|
| 1814 |
+
)
|
| 1815 |
+
parser.add_argument(
|
| 1816 |
+
"--base-url",
|
| 1817 |
+
type=str,
|
| 1818 |
+
default=None,
|
| 1819 |
+
help="Server or API base url if not using http host and port.",
|
| 1820 |
+
)
|
| 1821 |
+
parser.add_argument(
|
| 1822 |
+
"--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0."
|
| 1823 |
+
)
|
| 1824 |
+
parser.add_argument(
|
| 1825 |
+
"--port",
|
| 1826 |
+
type=int,
|
| 1827 |
+
help="If not set, the default port is configured according to its default value for different LLM Inference Engines.",
|
| 1828 |
+
)
|
| 1829 |
+
parser.add_argument(
|
| 1830 |
+
"--ready-check-timeout-sec",
|
| 1831 |
+
type=int,
|
| 1832 |
+
default=60,
|
| 1833 |
+
help="Maximum time in seconds to wait for the server to be ready before benchmarking. Set to 0 to skip. Default: 60.",
|
| 1834 |
+
)
|
| 1835 |
+
parser.add_argument(
|
| 1836 |
+
"--dataset-name",
|
| 1837 |
+
type=str,
|
| 1838 |
+
default="sharegpt",
|
| 1839 |
+
choices=[
|
| 1840 |
+
"sharegpt",
|
| 1841 |
+
"custom",
|
| 1842 |
+
"openai",
|
| 1843 |
+
"random",
|
| 1844 |
+
"random-ids",
|
| 1845 |
+
"generated-shared-prefix",
|
| 1846 |
+
"mmmu",
|
| 1847 |
+
"image",
|
| 1848 |
+
"mooncake",
|
| 1849 |
+
],
|
| 1850 |
+
help="Name of the dataset to benchmark on.",
|
| 1851 |
+
)
|
| 1852 |
+
parser.add_argument(
|
| 1853 |
+
"--dataset-path", type=str, default="", help="Path to the dataset."
|
| 1854 |
+
)
|
| 1855 |
+
parser.add_argument(
|
| 1856 |
+
"--model",
|
| 1857 |
+
type=str,
|
| 1858 |
+
help="Name or path of the model. If not set, the default model will request /v1/models for conf.",
|
| 1859 |
+
)
|
| 1860 |
+
parser.add_argument(
|
| 1861 |
+
"--served-model-name",
|
| 1862 |
+
type=str,
|
| 1863 |
+
help="The name of the model as served by the serving service. If not set, this defaults to the value of --model.",
|
| 1864 |
+
)
|
| 1865 |
+
parser.add_argument(
|
| 1866 |
+
"--tokenizer",
|
| 1867 |
+
type=str,
|
| 1868 |
+
help="Name or path of the tokenizer. If not set, using the model conf.",
|
| 1869 |
+
)
|
| 1870 |
+
parser.add_argument(
|
| 1871 |
+
"--num-prompts",
|
| 1872 |
+
type=int,
|
| 1873 |
+
default=1000,
|
| 1874 |
+
help="Number of prompts to process. Default is 1000.",
|
| 1875 |
+
)
|
| 1876 |
+
parser.add_argument(
|
| 1877 |
+
"--sharegpt-output-len",
|
| 1878 |
+
type=int,
|
| 1879 |
+
default=None,
|
| 1880 |
+
help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
|
| 1881 |
+
)
|
| 1882 |
+
parser.add_argument(
|
| 1883 |
+
"--sharegpt-context-len",
|
| 1884 |
+
type=int,
|
| 1885 |
+
default=None,
|
| 1886 |
+
help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
|
| 1887 |
+
)
|
| 1888 |
+
parser.add_argument(
|
| 1889 |
+
"--random-input-len",
|
| 1890 |
+
type=int,
|
| 1891 |
+
default=1024,
|
| 1892 |
+
help="Number of input tokens per request, used only for random and image dataset.",
|
| 1893 |
+
)
|
| 1894 |
+
parser.add_argument(
|
| 1895 |
+
"--random-output-len",
|
| 1896 |
+
default=1024,
|
| 1897 |
+
type=int,
|
| 1898 |
+
help="Number of output tokens per request, used only for random and image dataset.",
|
| 1899 |
+
)
|
| 1900 |
+
parser.add_argument(
|
| 1901 |
+
"--random-range-ratio",
|
| 1902 |
+
type=float,
|
| 1903 |
+
default=0.0,
|
| 1904 |
+
help="Range of sampled ratio of input/output length, "
|
| 1905 |
+
"used only for random and image dataset.",
|
| 1906 |
+
)
|
| 1907 |
+
# image dataset args
|
| 1908 |
+
parser.add_argument(
|
| 1909 |
+
"--image-count",
|
| 1910 |
+
type=int,
|
| 1911 |
+
default=1,
|
| 1912 |
+
help="Number of images per request (only available with the image dataset)",
|
| 1913 |
+
)
|
| 1914 |
+
parser.add_argument(
|
| 1915 |
+
"--image-resolution",
|
| 1916 |
+
type=str,
|
| 1917 |
+
default="1080p",
|
| 1918 |
+
help=(
|
| 1919 |
+
"Resolution of images for image dataset. "
|
| 1920 |
+
"Supports presets 4k/1080p/720p/360p or custom 'heightxwidth' (e.g., 1080x1920)."
|
| 1921 |
+
),
|
| 1922 |
+
)
|
| 1923 |
+
parser.add_argument(
|
| 1924 |
+
"--random-image-count",
|
| 1925 |
+
action="store_true",
|
| 1926 |
+
help="Enable Random Image Count",
|
| 1927 |
+
)
|
| 1928 |
+
parser.add_argument(
|
| 1929 |
+
"--image-format",
|
| 1930 |
+
type=str,
|
| 1931 |
+
default="jpeg",
|
| 1932 |
+
help=("Format of images for image dataset. " "Supports jpeg and png."),
|
| 1933 |
+
)
|
| 1934 |
+
parser.add_argument(
|
| 1935 |
+
"--image-content",
|
| 1936 |
+
type=str,
|
| 1937 |
+
default="random",
|
| 1938 |
+
help=("Content for images for image dataset. " "Supports random and blank."),
|
| 1939 |
+
)
|
| 1940 |
+
parser.add_argument(
|
| 1941 |
+
"--request-rate",
|
| 1942 |
+
type=float,
|
| 1943 |
+
default=float("inf"),
|
| 1944 |
+
help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
|
| 1945 |
+
"Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.",
|
| 1946 |
+
)
|
| 1947 |
+
parser.add_argument(
|
| 1948 |
+
"--use-trace-timestamps",
|
| 1949 |
+
action="store_true",
|
| 1950 |
+
help="Use timestamps from the trace file for request scheduling. Only valid for 'mooncake' dataset.",
|
| 1951 |
+
)
|
| 1952 |
+
parser.add_argument(
|
| 1953 |
+
"--max-concurrency",
|
| 1954 |
+
type=int,
|
| 1955 |
+
default=None,
|
| 1956 |
+
help="Maximum number of concurrent requests. This can be used "
|
| 1957 |
+
"to help simulate an environment where a higher level component "
|
| 1958 |
+
"is enforcing a maximum number of concurrent requests. While the "
|
| 1959 |
+
"--request-rate argument controls the rate at which requests are "
|
| 1960 |
+
"initiated, this argument will control how many are actually allowed "
|
| 1961 |
+
"to execute at a time. This means that when used in combination, the "
|
| 1962 |
+
"actual request rate may be lower than specified with --request-rate, "
|
| 1963 |
+
"if the server is not processing requests fast enough to keep up.",
|
| 1964 |
+
)
|
| 1965 |
+
parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
|
| 1966 |
+
parser.add_argument(
|
| 1967 |
+
"--output-details", action="store_true", help="Output details of benchmarking."
|
| 1968 |
+
)
|
| 1969 |
+
parser.add_argument(
|
| 1970 |
+
"--print-requests",
|
| 1971 |
+
action="store_true",
|
| 1972 |
+
help="Print requests immediately during benchmarking. Useful to quickly realize issues.",
|
| 1973 |
+
)
|
| 1974 |
+
parser.add_argument(
|
| 1975 |
+
"--disable-tqdm",
|
| 1976 |
+
action="store_true",
|
| 1977 |
+
help="Specify to disable tqdm progress bar.",
|
| 1978 |
+
)
|
| 1979 |
+
parser.add_argument(
|
| 1980 |
+
"--disable-stream",
|
| 1981 |
+
action="store_true",
|
| 1982 |
+
help="Disable streaming mode.",
|
| 1983 |
+
)
|
| 1984 |
+
parser.add_argument(
|
| 1985 |
+
"--return-logprob",
|
| 1986 |
+
action="store_true",
|
| 1987 |
+
help="Return logprob.",
|
| 1988 |
+
)
|
| 1989 |
+
parser.add_argument(
|
| 1990 |
+
"--return-routed-experts",
|
| 1991 |
+
action="store_true",
|
| 1992 |
+
help="Return routed experts.",
|
| 1993 |
+
)
|
| 1994 |
+
parser.add_argument("--seed", type=int, default=1, help="The random seed.")
|
| 1995 |
+
parser.add_argument(
|
| 1996 |
+
"--disable-ignore-eos",
|
| 1997 |
+
action="store_true",
|
| 1998 |
+
help="Disable ignoring EOS.",
|
| 1999 |
+
)
|
| 2000 |
+
parser.add_argument(
|
| 2001 |
+
"--extra-request-body",
|
| 2002 |
+
metavar='{"key1": "value1", "key2": "value2"}',
|
| 2003 |
+
type=str,
|
| 2004 |
+
help="Append given JSON object to the request payload. You can use this to specify"
|
| 2005 |
+
"additional generate params like sampling params.",
|
| 2006 |
+
)
|
| 2007 |
+
parser.add_argument(
|
| 2008 |
+
"--apply-chat-template",
|
| 2009 |
+
action="store_true",
|
| 2010 |
+
help="Apply chat template",
|
| 2011 |
+
)
|
| 2012 |
+
parser.add_argument(
|
| 2013 |
+
"--profile",
|
| 2014 |
+
action="store_true",
|
| 2015 |
+
help="Use Torch Profiler. The endpoint must be launched with "
|
| 2016 |
+
"SGLANG_TORCH_PROFILER_DIR to enable profiler.",
|
| 2017 |
+
)
|
| 2018 |
+
parser.add_argument(
|
| 2019 |
+
"--plot-throughput",
|
| 2020 |
+
action="store_true",
|
| 2021 |
+
help="Plot throughput and concurrent requests over time. Requires termplotlib and gnuplot.",
|
| 2022 |
+
)
|
| 2023 |
+
# TODO unify all these
|
| 2024 |
+
parser.add_argument(
|
| 2025 |
+
"--profile-activities",
|
| 2026 |
+
type=str,
|
| 2027 |
+
nargs="+",
|
| 2028 |
+
default=["CPU", "GPU"],
|
| 2029 |
+
choices=["CPU", "GPU", "CUDA_PROFILER", "XPU"],
|
| 2030 |
+
help="Profiler activities to capture: CPU, GPU, XPU, CUDA_PROFILER.",
|
| 2031 |
+
)
|
| 2032 |
+
parser.add_argument(
|
| 2033 |
+
"--profile-start-step",
|
| 2034 |
+
type=int,
|
| 2035 |
+
default=None,
|
| 2036 |
+
help="Start profiling after this many forward steps. Useful for warmup.",
|
| 2037 |
+
)
|
| 2038 |
+
parser.add_argument(
|
| 2039 |
+
"--profile-steps",
|
| 2040 |
+
type=int,
|
| 2041 |
+
default=None,
|
| 2042 |
+
help="Number of steps to profile. If specified, profiling stops automatically after this many steps.",
|
| 2043 |
+
)
|
| 2044 |
+
parser.add_argument("--profile-num-steps", type=int, default=None)
|
| 2045 |
+
parser.add_argument("--profile-by-stage", action="store_true", default=False)
|
| 2046 |
+
parser.add_argument("--profile-stages", nargs="+", default=None)
|
| 2047 |
+
parser.add_argument(
|
| 2048 |
+
"--profile-output-dir",
|
| 2049 |
+
type=str,
|
| 2050 |
+
default=None,
|
| 2051 |
+
help="Output directory for profile traces.",
|
| 2052 |
+
)
|
| 2053 |
+
parser.add_argument(
|
| 2054 |
+
"--profile-prefix",
|
| 2055 |
+
type=str,
|
| 2056 |
+
default=None,
|
| 2057 |
+
help="Prefix for profile trace filenames.",
|
| 2058 |
+
)
|
| 2059 |
+
parser.add_argument(
|
| 2060 |
+
"--lora-name",
|
| 2061 |
+
type=str,
|
| 2062 |
+
nargs="*",
|
| 2063 |
+
default=None,
|
| 2064 |
+
action=LoRAPathAction,
|
| 2065 |
+
help="The names of LoRA adapters. You can provide a list of names in the format {name} {name} {name}...",
|
| 2066 |
+
)
|
| 2067 |
+
parser.add_argument(
|
| 2068 |
+
"--lora-request-distribution",
|
| 2069 |
+
type=str,
|
| 2070 |
+
default="uniform",
|
| 2071 |
+
choices=[
|
| 2072 |
+
"uniform",
|
| 2073 |
+
"distinct",
|
| 2074 |
+
"skewed",
|
| 2075 |
+
],
|
| 2076 |
+
help="What distribution to sample the LoRA adapters specified in --lora-name. Borrowed from the Punica paper. "
|
| 2077 |
+
"'distinct' distribution means selecting a new LoRA adapter for every request. "
|
| 2078 |
+
"'skewed' distribution follows the Zipf distribution, where the number of requests "
|
| 2079 |
+
"to model i specified in --lora-name is α times the number of requests for model i+1, "
|
| 2080 |
+
"where α > 1.",
|
| 2081 |
+
)
|
| 2082 |
+
parser.add_argument(
|
| 2083 |
+
"--lora-zipf-alpha",
|
| 2084 |
+
type=float,
|
| 2085 |
+
default=1.5,
|
| 2086 |
+
help="The parameter to use for the Zipf distribution when --lora-request-distribution='skewed'.",
|
| 2087 |
+
)
|
| 2088 |
+
parser.add_argument(
|
| 2089 |
+
"--prompt-suffix",
|
| 2090 |
+
type=str,
|
| 2091 |
+
default="",
|
| 2092 |
+
help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
|
| 2093 |
+
)
|
| 2094 |
+
parser.add_argument(
|
| 2095 |
+
"--pd-separated",
|
| 2096 |
+
action="store_true",
|
| 2097 |
+
help="Benchmark PD disaggregation server",
|
| 2098 |
+
)
|
| 2099 |
+
|
| 2100 |
+
# Create a mutually exclusive group for profiling URLs
|
| 2101 |
+
# In PD separated mode, prefill and decode workers must be profiled separately
|
| 2102 |
+
profile_url_group = parser.add_mutually_exclusive_group()
|
| 2103 |
+
profile_url_group.add_argument(
|
| 2104 |
+
"--profile-prefill-url",
|
| 2105 |
+
type=str,
|
| 2106 |
+
nargs="*",
|
| 2107 |
+
default=None,
|
| 2108 |
+
help="URL(s) of the prefill worker(s) for profiling in PD separated mode. "
|
| 2109 |
+
"Can specify multiple URLs: --profile-prefill-url http://localhost:30000 http://localhost:30001. "
|
| 2110 |
+
"NOTE: Cannot be used together with --profile-decode-url. "
|
| 2111 |
+
"In PD separated mode, prefill and decode workers must be profiled separately.",
|
| 2112 |
+
)
|
| 2113 |
+
profile_url_group.add_argument(
|
| 2114 |
+
"--profile-decode-url",
|
| 2115 |
+
type=str,
|
| 2116 |
+
nargs="*",
|
| 2117 |
+
default=None,
|
| 2118 |
+
help="URL(s) of the decode worker(s) for profiling in PD separated mode. "
|
| 2119 |
+
"Can specify multiple URLs: --profile-decode-url http://localhost:30010 http://localhost:30011. "
|
| 2120 |
+
"NOTE: Cannot be used together with --profile-prefill-url. "
|
| 2121 |
+
"In PD separated mode, prefill and decode workers must be profiled separately.",
|
| 2122 |
+
)
|
| 2123 |
+
parser.add_argument(
|
| 2124 |
+
"--flush-cache",
|
| 2125 |
+
action="store_true",
|
| 2126 |
+
help="Flush the cache before running the benchmark",
|
| 2127 |
+
)
|
| 2128 |
+
parser.add_argument(
|
| 2129 |
+
"--warmup-requests",
|
| 2130 |
+
type=int,
|
| 2131 |
+
default=1,
|
| 2132 |
+
help="Number of warmup requests to run before the benchmark",
|
| 2133 |
+
)
|
| 2134 |
+
parser.add_argument(
|
| 2135 |
+
"--tokenize-prompt",
|
| 2136 |
+
action="store_true",
|
| 2137 |
+
help="Use integer ids instead of string for inputs. Useful to control prompt lengths accurately",
|
| 2138 |
+
)
|
| 2139 |
+
|
| 2140 |
+
group = parser.add_argument_group("generated-shared-prefix dataset arguments")
|
| 2141 |
+
group.add_argument(
|
| 2142 |
+
"--gsp-num-groups",
|
| 2143 |
+
type=int,
|
| 2144 |
+
default=64,
|
| 2145 |
+
help="Number of system prompt groups for generated-shared-prefix dataset",
|
| 2146 |
+
)
|
| 2147 |
+
group.add_argument(
|
| 2148 |
+
"--gsp-prompts-per-group",
|
| 2149 |
+
type=int,
|
| 2150 |
+
default=16,
|
| 2151 |
+
help="Number of prompts per system prompt group for generated-shared-prefix dataset",
|
| 2152 |
+
)
|
| 2153 |
+
group.add_argument(
|
| 2154 |
+
"--gsp-system-prompt-len",
|
| 2155 |
+
type=int,
|
| 2156 |
+
default=2048,
|
| 2157 |
+
help="Target length in tokens for system prompts in generated-shared-prefix dataset",
|
| 2158 |
+
)
|
| 2159 |
+
group.add_argument(
|
| 2160 |
+
"--gsp-question-len",
|
| 2161 |
+
type=int,
|
| 2162 |
+
default=128,
|
| 2163 |
+
help="Target length in tokens for questions in generated-shared-prefix dataset",
|
| 2164 |
+
)
|
| 2165 |
+
group.add_argument(
|
| 2166 |
+
"--gsp-output-len",
|
| 2167 |
+
type=int,
|
| 2168 |
+
default=256,
|
| 2169 |
+
help="Target length in tokens for outputs in generated-shared-prefix dataset",
|
| 2170 |
+
)
|
| 2171 |
+
parser.add_argument(
|
| 2172 |
+
"--gsp-range-ratio",
|
| 2173 |
+
type=float,
|
| 2174 |
+
# WARN: The default 1.0 is for backward compatibility, and is different from the default 0.0 for random dataset
|
| 2175 |
+
default=1.0,
|
| 2176 |
+
help="Range of sampled ratio of input/output length, used only for gsp dataset.",
|
| 2177 |
+
)
|
| 2178 |
+
group.add_argument(
|
| 2179 |
+
"--gsp-fast-prepare",
|
| 2180 |
+
action="store_true",
|
| 2181 |
+
help="Speedup preparing by removing statistics computation, which will make some output statistics inaccurate but suitable for pressure tests.",
|
| 2182 |
+
)
|
| 2183 |
+
group.add_argument(
|
| 2184 |
+
"--gsp-send-routing-key",
|
| 2185 |
+
action="store_true",
|
| 2186 |
+
help="Send routing key in requests via X-SMG-Routing-Key header. Requests with the same prefix share the same routing key.",
|
| 2187 |
+
)
|
| 2188 |
+
group.add_argument(
|
| 2189 |
+
"--gsp-num-turns",
|
| 2190 |
+
type=int,
|
| 2191 |
+
default=1,
|
| 2192 |
+
help="Number of turns for multi-turn conversations. If > 1, each prompt becomes a list of questions sharing the same system prefix.",
|
| 2193 |
+
)
|
| 2194 |
+
group.add_argument(
|
| 2195 |
+
"--gsp-ordered",
|
| 2196 |
+
action="store_true",
|
| 2197 |
+
help="Keep requests in order without shuffling. By default, requests are shuffled randomly.",
|
| 2198 |
+
)
|
| 2199 |
+
mooncake_group = parser.add_argument_group("mooncake dataset arguments")
|
| 2200 |
+
mooncake_group.add_argument(
|
| 2201 |
+
"--mooncake-slowdown-factor",
|
| 2202 |
+
type=float,
|
| 2203 |
+
default=1.0,
|
| 2204 |
+
help="Slowdown factor for replaying the mooncake trace. "
|
| 2205 |
+
"A value of 2.0 means the replay is twice as slow. "
|
| 2206 |
+
"NOTE: --request-rate is IGNORED in mooncake mode.",
|
| 2207 |
+
)
|
| 2208 |
+
mooncake_group.add_argument(
|
| 2209 |
+
"--mooncake-num-rounds",
|
| 2210 |
+
type=int,
|
| 2211 |
+
default=1,
|
| 2212 |
+
help="Number of conversation rounds for each session in the mooncake dataset. "
|
| 2213 |
+
"A value > 1 will enable true multi-turn session benchmarking.",
|
| 2214 |
+
)
|
| 2215 |
+
mooncake_group.add_argument(
|
| 2216 |
+
"--mooncake-workload",
|
| 2217 |
+
type=str,
|
| 2218 |
+
default="conversation",
|
| 2219 |
+
choices=[
|
| 2220 |
+
"mooncake",
|
| 2221 |
+
"conversation",
|
| 2222 |
+
"synthetic",
|
| 2223 |
+
"toolagent",
|
| 2224 |
+
],
|
| 2225 |
+
help="Underlying workload for the mooncake dataset.",
|
| 2226 |
+
)
|
| 2227 |
+
parser.add_argument(
|
| 2228 |
+
"--tag", type=str, default=None, help="The tag to be dumped to output."
|
| 2229 |
+
)
|
| 2230 |
+
parser.add_argument(
|
| 2231 |
+
"--header",
|
| 2232 |
+
type=str,
|
| 2233 |
+
nargs="+",
|
| 2234 |
+
default=None,
|
| 2235 |
+
help="Custom HTTP headers in Key=Value format. Example: --header MyHeader=MY_VALUE MyAnotherHeader=myanothervalue",
|
| 2236 |
+
)
|
| 2237 |
+
args = parser.parse_args()
|
| 2238 |
+
run_benchmark(args)
|
sglang/python/sglang/benchmark/__init__.py
ADDED
|
File without changes
|
sglang/python/sglang/benchmark/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (169 Bytes). View file
|
|
|
sglang/python/sglang/benchmark/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (7.65 kB). View file
|
|
|
sglang/python/sglang/benchmark/datasets/__init__.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Type
|
| 2 |
+
|
| 3 |
+
from sglang.benchmark.datasets.common import BaseDataset, DatasetRow
|
| 4 |
+
from sglang.benchmark.datasets.custom import CustomDataset
|
| 5 |
+
from sglang.benchmark.datasets.generated_shared_prefix import (
|
| 6 |
+
GeneratedSharedPrefixDataset,
|
| 7 |
+
)
|
| 8 |
+
from sglang.benchmark.datasets.image import ImageDataset
|
| 9 |
+
from sglang.benchmark.datasets.mmmu import MMMUDataset
|
| 10 |
+
from sglang.benchmark.datasets.mooncake import MooncakeDataset
|
| 11 |
+
from sglang.benchmark.datasets.openai_dataset import OpenAIDataset
|
| 12 |
+
from sglang.benchmark.datasets.random import RandomDataset
|
| 13 |
+
from sglang.benchmark.datasets.sharegpt import ShareGPTDataset
|
| 14 |
+
|
| 15 |
+
DATASET_MAPPING: Dict[str, Type[BaseDataset]] = {
|
| 16 |
+
"sharegpt": ShareGPTDataset,
|
| 17 |
+
"custom": CustomDataset,
|
| 18 |
+
"openai": OpenAIDataset,
|
| 19 |
+
# TODO: "random" vs "random-ids" should be a flag (e.g. --random-source=sharegpt|integers),
|
| 20 |
+
# not two separate dataset names sharing the same class.
|
| 21 |
+
"random": RandomDataset,
|
| 22 |
+
"random-ids": RandomDataset,
|
| 23 |
+
"generated-shared-prefix": GeneratedSharedPrefixDataset,
|
| 24 |
+
"mmmu": MMMUDataset,
|
| 25 |
+
"image": ImageDataset,
|
| 26 |
+
"mooncake": MooncakeDataset,
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def get_dataset(args, tokenizer, model_id=None):
|
| 31 |
+
dataset_name = args.dataset_name
|
| 32 |
+
if dataset_name.startswith("random") and dataset_name not in DATASET_MAPPING:
|
| 33 |
+
dataset_name = "random-ids"
|
| 34 |
+
|
| 35 |
+
if dataset_name not in DATASET_MAPPING:
|
| 36 |
+
raise ValueError(f"Unknown dataset: {args.dataset_name}")
|
| 37 |
+
|
| 38 |
+
dataset_cls = DATASET_MAPPING[dataset_name]
|
| 39 |
+
dataset = dataset_cls.from_args(args)
|
| 40 |
+
return dataset.load(tokenizer=tokenizer, model_id=model_id)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
__all__ = [
|
| 44 |
+
"DATASET_MAPPING",
|
| 45 |
+
"DatasetRow",
|
| 46 |
+
"get_dataset",
|
| 47 |
+
]
|
sglang/python/sglang/benchmark/datasets/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (2.1 kB). View file
|
|
|
sglang/python/sglang/benchmark/datasets/__pycache__/common.cpython-311.pyc
ADDED
|
Binary file (5.24 kB). View file
|
|
|
sglang/python/sglang/benchmark/datasets/__pycache__/custom.cpython-311.pyc
ADDED
|
Binary file (6.7 kB). View file
|
|
|
sglang/python/sglang/benchmark/datasets/__pycache__/generated_shared_prefix.cpython-311.pyc
ADDED
|
Binary file (11.9 kB). View file
|
|
|
sglang/python/sglang/benchmark/datasets/__pycache__/image.cpython-311.pyc
ADDED
|
Binary file (12.6 kB). View file
|
|
|
sglang/python/sglang/benchmark/datasets/__pycache__/mmmu.cpython-311.pyc
ADDED
|
Binary file (5.98 kB). View file
|
|
|