Lekr0 committed on
Commit
d02d576
·
verified ·
1 Parent(s): a402b9b

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. sglang/.claude/skills/add-jit-kernel/SKILL.md +553 -0
  2. sglang/.claude/skills/add-sgl-kernel/SKILL.md +358 -0
  3. sglang/.claude/skills/sglang-bisect-ci-regression/SKILL.md +219 -0
  4. sglang/.claude/skills/write-sglang-test/SKILL.md +248 -0
  5. sglang/benchmark/json_jump_forward/README.md +88 -0
  6. sglang/benchmark/json_jump_forward/bench_other.py +288 -0
  7. sglang/benchmark/json_jump_forward/bench_sglang.py +143 -0
  8. sglang/benchmark/json_jump_forward/build_dataset.py +58 -0
  9. sglang/benchmark/json_jump_forward/dataset.txt +50 -0
  10. sglang/benchmark/multi_turn_chat/bench_other.py +93 -0
  11. sglang/benchmark/multi_turn_chat/data_gen.py +29 -0
  12. sglang/benchmark/tree_of_thought_deep/README.md +51 -0
  13. sglang/benchmark/tree_of_thought_deep/bench_other.py +222 -0
  14. sglang/benchmark/tree_of_thought_deep/bench_sglang.py +171 -0
  15. sglang/docker/configs/.zshrc +27 -0
  16. sglang/docker/configs/opt/.gitconfig +30 -0
  17. sglang/docker/configs/opt/.tmux.conf +27 -0
  18. sglang/docker/configs/opt/.vimrc +45 -0
  19. sglang/docker/configs/yank +12 -0
  20. sglang/python/sglang.egg-info/PKG-INFO +120 -0
  21. sglang/python/sglang.egg-info/SOURCES.txt +0 -0
  22. sglang/python/sglang.egg-info/dependency_links.txt +1 -0
  23. sglang/python/sglang.egg-info/entry_points.txt +2 -0
  24. sglang/python/sglang.egg-info/requires.txt +121 -0
  25. sglang/python/sglang.egg-info/top_level.txt +1 -0
  26. sglang/python/sglang/README.md +18 -0
  27. sglang/python/sglang/__init__.py +83 -0
  28. sglang/python/sglang/__pycache__/__init__.cpython-311.pyc +0 -0
  29. sglang/python/sglang/__pycache__/_version.cpython-311.pyc +0 -0
  30. sglang/python/sglang/__pycache__/bench_serving.cpython-311.pyc +0 -0
  31. sglang/python/sglang/__pycache__/check_env.cpython-311.pyc +0 -0
  32. sglang/python/sglang/__pycache__/global_config.cpython-311.pyc +0 -0
  33. sglang/python/sglang/__pycache__/launch_server.cpython-311.pyc +0 -0
  34. sglang/python/sglang/__pycache__/utils.cpython-311.pyc +0 -0
  35. sglang/python/sglang/__pycache__/version.cpython-311.pyc +0 -0
  36. sglang/python/sglang/_version.py +34 -0
  37. sglang/python/sglang/bench_offline_throughput.py +543 -0
  38. sglang/python/sglang/bench_one_batch.py +837 -0
  39. sglang/python/sglang/bench_one_batch_server.py +49 -0
  40. sglang/python/sglang/bench_serving.py +2238 -0
  41. sglang/python/sglang/benchmark/__init__.py +0 -0
  42. sglang/python/sglang/benchmark/__pycache__/__init__.cpython-311.pyc +0 -0
  43. sglang/python/sglang/benchmark/__pycache__/utils.cpython-311.pyc +0 -0
  44. sglang/python/sglang/benchmark/datasets/__init__.py +47 -0
  45. sglang/python/sglang/benchmark/datasets/__pycache__/__init__.cpython-311.pyc +0 -0
  46. sglang/python/sglang/benchmark/datasets/__pycache__/common.cpython-311.pyc +0 -0
  47. sglang/python/sglang/benchmark/datasets/__pycache__/custom.cpython-311.pyc +0 -0
  48. sglang/python/sglang/benchmark/datasets/__pycache__/generated_shared_prefix.cpython-311.pyc +0 -0
  49. sglang/python/sglang/benchmark/datasets/__pycache__/image.cpython-311.pyc +0 -0
  50. sglang/python/sglang/benchmark/datasets/__pycache__/mmmu.cpython-311.pyc +0 -0
sglang/.claude/skills/add-jit-kernel/SKILL.md ADDED
@@ -0,0 +1,553 @@
1
+ ---
2
+ name: add-jit-kernel
3
+ description: Step-by-step tutorial for adding a new lightweight JIT CUDA kernel to sglang's jit_kernel module
4
+ ---
5
+
6
+ # Tutorial: Adding a New JIT Kernel to SGLang
7
+
8
+ This tutorial walks through adding a simple element-wise scale operation as a JIT kernel. We'll implement `scale(x, factor) = x * factor` to demonstrate the complete workflow.
9
+
10
+ ## Goal
11
+
12
+ Add a new operation that scales each element of a tensor by a scalar factor:
13
+
14
+ - Input: tensor `x` (CUDA) and scalar `factor` (float, encoded as an integer ratio and passed as C++ template arguments)
15
+ - Output: `x * factor` (element-wise), allocated internally unless a pre-allocated `out` tensor is passed
16
+ - Supported dtypes: **FP16 (`torch.float16`), BF16 (`torch.bfloat16`), FP32 (`torch.float32`)**
17
+
18
+ ## When to use JIT vs AOT (`sgl-kernel`)
19
+
20
+ - **JIT (`jit_kernel`)**: lightweight, few dependencies, rapid iteration, compiled on first use
21
+ - **AOT (`sgl-kernel`)**: depends on CUTLASS / FlashInfer / DeepGEMM, needs pre-built wheel
22
+
23
+ ---
24
+
25
+ ## Common Abstractions in `python/sglang/jit_kernel/include/sgl_kernel/`
26
+
27
+ **Always prefer these abstractions over raw CUDA primitives.** They provide safety, readability, and consistency with the rest of the codebase.
28
+
29
+ ### `utils.h` — Host-side utilities
30
+
31
+ ```cpp
32
+ #include <sgl_kernel/utils.h>
33
+ ```
34
+
35
+ - **`host::RuntimeCheck(cond, args...)`** — Assert a condition at runtime; throws `PanicError` with file/line info on failure. Prefer this over bare `assert`.
36
+ - **`host::Panic(args...)`** — Unconditionally throw a `PanicError` with a descriptive message.
37
+ - **`host::div_ceil(a, b)`** — Integer ceiling division `(a + b - 1) / b`.
38
+ - **`host::irange(n)`** / **`host::irange(start, end)`** — Range views for cleaner loops.
39
+ - **`host::pointer::offset(ptr, offsets...)`** — Byte-safe pointer arithmetic on `void*`. Use this instead of raw casts.
40
+
41
+ ### `utils.cuh` — Device-side utilities + `LaunchKernel`
42
+
43
+ ```cpp
44
+ #include <sgl_kernel/utils.cuh>
45
+ ```
46
+
47
+ - **Type aliases**: `fp16_t`, `bf16_t`, `fp32_t`, `fp8_e4m3_t`, `fp8_e5m2_t` and their packed variants `fp16x2_t`, `bf16x2_t`, `fp32x2_t`, etc.
48
+ - **`SGL_DEVICE`** — Expands to `__forceinline__ __device__`. Use on all device functions.
49
+ - **`device::kWarpThreads`** — Constant `32`.
50
+ - **`device::load_as<T>(ptr, offset)`** / **`device::store_as<T>(ptr, val, offset)`** — Type-safe loads/stores from `void*`.
51
+ - **`device::pointer::offset(ptr, offsets...)`** — Pointer arithmetic on device.
52
+ - **`host::LaunchKernel(grid, block, device_or_stream [, smem])`** — RAII kernel launcher that:
+   - Resolves the CUDA stream from a `DLDevice` via TVM-FFI automatically.
+   - Checks the CUDA error with file/line info after launch via `operator()(kernel, args...)`.
+   - Supports `.enable_pdl(bool)` for PDL (Programmatic Dependent Launch, SM90+).
56
+ - **`host::RuntimeDeviceCheck(cudaError_t)`** — Check a CUDA error; throw on failure.
57
+
58
+ ### `tensor.h` — Tensor validation (`TensorMatcher`, Symbolic types)
59
+
60
+ ```cpp
61
+ #include <sgl_kernel/tensor.h>
62
+ ```
63
+
64
+ This is the **primary validation API** for all kernel launchers. Use it to validate every `tvm::ffi::TensorView` argument.
65
+
66
+ - **`host::SymbolicSize{"name"}`** — A named symbolic dimension. Call `.set_value(n)` to pin it, `.unwrap()` to extract after verification.
67
+ - **`host::SymbolicDType`** — Symbolic dtype. Use `.set_options<Ts...>()` to restrict allowed types.
68
+ - **`host::SymbolicDevice`** — Symbolic device. Use `.set_options<kDLCUDA>()` to restrict to CUDA.
69
+ - **`host::TensorMatcher({dims...})`** — Fluent builder for tensor validation:
+   - `.with_dtype<T>()` — require a specific C++ type (e.g. `fp16_t`)
+   - `.with_dtype<T1, T2, ...>()` — allow a set of types
+   - `.with_device<kDLCUDA>(device_sym)` — require CUDA, bind device to symbol
+   - `.with_strides({strides...})` — validate strides (omit to require contiguous)
+   - `.verify(tensor_view)` — execute the check; throws `PanicError` with full context on failure; **chainable** (`verify(a).verify(b)` to check multiple tensors with the same shape)
75
+
76
+ **Typical pattern:**
77
+ ```cpp
+ auto N = SymbolicSize{"num_elements"};
+ auto device = SymbolicDevice{};
+ device.set_options<kDLCUDA>();
+ TensorMatcher({N})  //
+     .with_dtype<fp16_t>()
+     .with_device(device)
+     .verify(dst)
+     .verify(src);  // same shape, dtype, device as dst
+ const size_t n = N.unwrap();
+ const DLDevice dev = device.unwrap();
+ ```
89
+
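To make the binding behavior of symbolic sizes concrete, here is a small Python model. It is not the C++ API: the class names `SymbolicSizeModel` and `MatcherModel` are hypothetical stand-ins, and the sketch only assumes what the text above states, namely that the first verified tensor pins each named dimension and later tensors must agree with it.

```python
class SymbolicSizeModel:
    """Hypothetical stand-in for a named dimension that is pinned on first use."""

    def __init__(self, name):
        self.name = name
        self.value = None

    def bind(self, n):
        # The first call pins the value; later calls must match it.
        if self.value is None:
            self.value = n
        elif self.value != n:
            raise ValueError(f"{self.name}: expected {self.value}, got {n}")

    def unwrap(self):
        if self.value is None:
            raise ValueError(f"{self.name} was never bound")
        return self.value


class MatcherModel:
    """Hypothetical stand-in for a shape matcher with a chainable verify()."""

    def __init__(self, dims):
        self.dims = dims

    def verify(self, shape):
        if len(shape) != len(self.dims):
            raise ValueError(
                f"rank mismatch: expected {len(self.dims)}, got {len(shape)}"
            )
        for sym, n in zip(self.dims, shape):
            sym.bind(n)
        return self  # returning self is what allows verify(a).verify(b)


N = SymbolicSizeModel("num_elements")
MatcherModel([N]).verify((1024,)).verify((1024,))
print(N.unwrap())
```

A second tensor with a different length would raise at the second `verify` call, which is the behavior the "same shape as dst" comment in the pattern relies on.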
90
+ ### `type.cuh` — `dtype_trait<T>` and `packed_t<T>`
91
+
92
+ ```cpp
93
+ #include <sgl_kernel/type.cuh>
94
+ ```
95
+
96
+ - **`dtype_trait<T>`** — Static trait struct for each scalar type. Provides:
+   - `dtype_trait<T>::from(value)` — convert from another type (e.g. `fp32_t` → `fp16_t`)
+   - `dtype_trait<T>::abs/sqrt/rsqrt/max/min(x)` — type-dispatched math (for `fp32_t`)
99
+ - **`packed_t<T>`** — Two-element packed alias: `packed_t<fp16_t>` = `fp16x2_t`, `packed_t<bf16_t>` = `bf16x2_t`, `packed_t<fp32_t>` = `fp32x2_t`. Use for vectorized loads/stores.
100
+ - **`device::cast<To, From>(value)`** — Type-safe cast using `dtype_trait`, e.g. `cast<fp32x2_t, fp16x2_t>(v)`.
101
+
102
+ ### `vec.cuh` — Vectorized memory access (`AlignedVector`)
103
+
104
+ ```cpp
105
+ #include <sgl_kernel/vec.cuh>
106
+ ```
107
+
108
+ - **`device::AlignedVector<T, N>`** — Aligned storage for N elements of type T. N must be a power of two, `sizeof(T)*N <= 32`. Enables 128-bit vector loads/stores for bandwidth efficiency.
+   - `.load(ptr, offset)` — vectorized load from `ptr[offset]`
+   - `.store(ptr, offset)` — vectorized store to `ptr[offset]`
+   - `.fill(value)` — fill all lanes
+   - `operator[](i)` — element access
113
+
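The size constraints on `AlignedVector<T, N>` can be checked with a few lines of arithmetic. The sketch below is a plain Python model of the two rules stated above (power-of-two `N`, at most 32 bytes); the helper names are hypothetical and not part of the library.

```python
def is_valid_vector(elem_bytes: int, n: int, max_bytes: int = 32) -> bool:
    """Check the two stated constraints: N is a power of two and fits in max_bytes."""
    is_pow2 = n > 0 and (n & (n - 1)) == 0
    return is_pow2 and elem_bytes * n <= max_bytes


def elems_per_128bit(elem_bytes: int) -> int:
    """Number of elements that fill one 128-bit (16-byte) vector."""
    return 16 // elem_bytes


for name, size in [("fp16_t", 2), ("bf16_t", 2), ("fp32_t", 4)]:
    n = elems_per_128bit(size)
    print(name, n, is_valid_vector(size, n))
```

For 2-byte types this gives 8 elements per 128-bit access and for 4-byte types 4 elements, which are the widths the `scale` launcher later selects with `16 / sizeof(T)`.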
114
+ ### `tile.cuh` — `tile::Memory` (strided memory access pattern)
115
+
116
+ ```cpp
117
+ #include <sgl_kernel/tile.cuh>
118
+ ```
119
+
120
+ - **`device::tile::Memory<T>::cta(blockDim.x)`** — Creates a tile accessor where each thread handles `tid = threadIdx.x` with stride `blockDim.x`. Common for loops over a 1D array.
+   - **`.load(ptr, offset)`** — loads `ptr[tid + offset * blockDim.x]`
+   - **`.store(ptr, val, offset)`** — stores to `ptr[tid + offset * blockDim.x]`
+   - **`.in_bound(n, offset)`** — boundary check
124
+
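The indexing rule `tid + offset * blockDim.x` is easy to trace by hand. The following Python sketch lists the indices one thread would touch; it assumes that the boundary check means `tid + offset * block_dim < n`, which is an interpretation of the description above rather than something confirmed from the header.

```python
def tile_indices(tid: int, block_dim: int, n: int) -> list:
    """Indices touched by thread `tid`: tid + offset * block_dim while in bounds."""
    out = []
    offset = 0
    while tid + offset * block_dim < n:  # assumed meaning of in_bound(n, offset)
        out.append(tid + offset * block_dim)
        offset += 1
    return out


block_dim, n = 4, 10
for tid in range(block_dim):
    print(tid, tile_indices(tid, block_dim, n))
```

With 4 threads and 10 elements, neighbouring threads access neighbouring addresses at every step, and the threads together cover each index exactly once.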
125
+ ### `math.cuh` — Device math (`device::math::`)
126
+
127
+ ```cpp
128
+ #include <sgl_kernel/math.cuh>
129
+ ```
130
+
131
+ - `device::math::max/min<T>(a, b)`, `device::math::abs/sqrt/rsqrt<T>(x)` — type-dispatched math via `dtype_trait`
132
+ - `device::math::exp/sin/cos(float)` — fast float math wrappers
133
+
134
+ ### `warp.cuh` — Warp-level primitives
135
+
136
+ ```cpp
137
+ #include <sgl_kernel/warp.cuh>
138
+ ```
139
+
140
+ - `device::warp::reduce_sum<T>(value)` — warp-level sum reduction via `__shfl_xor_sync`
141
+ - `device::warp::reduce_max<T>(value)` — warp-level max reduction
142
+
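An xor-shuffle reduction is sometimes called a butterfly reduction. The Python simulation below models it for a 32-lane warp: at each step every lane adds the value held by the lane whose index differs in one bit. This is a sketch of the general technique, not the contents of `warp.cuh`.

```python
def warp_reduce_sum(lanes):
    """Simulate an xor-shuffle (butterfly) sum across a 32-lane warp."""
    assert len(lanes) == 32
    vals = list(lanes)
    offset = 16
    while offset >= 1:
        # Each lane i reads from lane i ^ offset, as __shfl_xor_sync would.
        vals = [vals[i] + vals[i ^ offset] for i in range(32)]
        offset //= 2
    return vals


result = warp_reduce_sum(range(32))
print(result[0], len(set(result)))
```

After five steps every lane holds the full sum, so no separate broadcast is needed, unlike a shuffle-down reduction where only lane 0 ends up with the result.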
143
+ ### `cta.cuh` — CTA-level primitives
144
+
145
+ ```cpp
146
+ #include <sgl_kernel/cta.cuh>
147
+ ```
148
+
149
+ - `device::cta::reduce_max<T>(value, smem, min_value)` — CTA-wide max using shared memory + warp reduction. The result is written to `smem[0]`; the caller must issue a `__syncthreads()` afterwards before reading it.
150
+
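A block-wide reduction of this kind is typically done in two stages. The Python model below shows that structure under stated assumptions: each warp reduces its own values, writes one partial result to a shared buffer, and the partials are then reduced again. The real header may organize the stages differently.

```python
def cta_reduce_max(values, min_value, warp_size=32):
    """Model a block-wide max: per-warp max, stash in smem, then reduce smem."""
    smem = []
    for start in range(0, len(values), warp_size):
        warp = values[start:start + warp_size]
        # Stage 1: one partial maximum per warp.
        smem.append(max(warp, default=min_value))
    # Stage 2: combine the per-warp partials; min_value covers the empty case.
    return max(smem, default=min_value)


values = [float(i % 97) for i in range(256)]
print(cta_reduce_max(values, float("-inf")))
```

The `min_value` argument plays the role of the identity element for `max`, which is why the C++ signature takes it explicitly.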
151
+ ### `atomic.cuh` — Atomic operations
152
+
153
+ ```cpp
154
+ #include <sgl_kernel/atomic.cuh>
155
+ ```
156
+
157
+ - `device::atomic::max(float* addr, float value)` — float atomic max (handles negative values correctly via bit tricks).
158
+
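The "bit tricks" refer to the fact that IEEE 754 floats can be ordered through their integer bit patterns. The sketch below models one widely used formulation in Python: a signed-integer max when the new value is non-negative, and an unsigned-integer min when it is negative. Whether `atomic.cuh` uses exactly this formulation is an assumption.

```python
import struct


def float_bits(x: float) -> int:
    """Raw 32-bit pattern of a float32 value, as an unsigned integer."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def as_signed(bits: int) -> int:
    """Reinterpret an unsigned 32-bit pattern as a signed 32-bit integer."""
    return bits - (1 << 32) if bits & 0x80000000 else bits


def atomic_max_model(current: float, value: float) -> float:
    """Signed-int max for value >= 0, unsigned-int min otherwise."""
    cur_bits, val_bits = float_bits(current), float_bits(value)
    if value >= 0.0:
        keep_current = as_signed(cur_bits) >= as_signed(val_bits)
        new_bits = cur_bits if keep_current else val_bits
    else:
        # Among negative floats, a larger bit pattern means a smaller value.
        new_bits = min(cur_bits, val_bits)
    return struct.unpack("<f", struct.pack("<I", new_bits))[0]


print(atomic_max_model(-2.0, -0.5), atomic_max_model(1.5, -2.0))
```

The two branches matter because the sign bit reverses the ordering of negative values; a plain integer max on the raw bits would pick the more negative number.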
159
+ ### `runtime.cuh` — Occupancy and device info
160
+
161
+ ```cpp
162
+ #include <sgl_kernel/runtime.cuh>
163
+ ```
164
+
165
+ - `host::runtime::get_blocks_per_sm(kernel, block_dim)` — max active blocks per SM (occupancy)
166
+ - `host::runtime::get_sm_count(device_id)` — number of SMs on the device
167
+ - `host::runtime::get_cc_major(device_id)` — compute capability major version
168
+
169
+ **Persistent kernel pattern** (cap blocks to SM count × occupancy):
170
+ ```cpp
171
+ static const uint32_t max_occ = runtime::get_blocks_per_sm(kernel, kBlockSize);
172
+ static const uint32_t num_sm = runtime::get_sm_count(device.unwrap().device_id);
173
+ const auto num_blocks = std::min(num_sm * max_occ, div_ceil(n, kBlockSize));
174
+ LaunchKernel(num_blocks, kBlockSize, device.unwrap())(kernel, params);
175
+ ```
176
+
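The grid-size arithmetic of the persistent pattern can be worked through in Python. The SM count and occupancy used below are made-up example values, not measurements of any particular GPU.

```python
def div_ceil(a: int, b: int) -> int:
    """Integer ceiling division, (a + b - 1) // b."""
    return (a + b - 1) // b


def persistent_grid(n: int, block_size: int, num_sm: int, blocks_per_sm: int) -> int:
    """Cap the grid at the number of blocks that can be resident at once."""
    return min(num_sm * blocks_per_sm, div_ceil(n, block_size))


# Hypothetical device: 132 SMs, 8 resident blocks per SM.
print(persistent_grid(1_000, 256, 132, 8))
print(persistent_grid(10_000_000, 256, 132, 8))
```

Small inputs launch only as many blocks as they need, while large inputs are capped at the resident-block limit and rely on a loop inside the kernel to cover the remaining elements.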
177
+ ---
178
+
179
+ ## Step 0 (optional): Generate a `.clangd` config for better IDE support
180
+
181
+ ```bash
182
+ python -m sglang.jit_kernel
183
+ ```
184
+
185
+ ---
186
+
187
+ ## Step 1: Implement the CUDA kernel in `jit_kernel/csrc/`
188
+
189
+ Create `python/sglang/jit_kernel/csrc/elementwise/scale.cuh`.
190
+
191
+ The implementation fully uses the project abstractions described above:
192
+
193
+ ```cpp
+ #include <sgl_kernel/tensor.h>   // TensorMatcher, SymbolicSize, SymbolicDevice
+ #include <sgl_kernel/type.cuh>   // dtype_trait, packed_t, device::cast
+ #include <sgl_kernel/utils.h>    // RuntimeCheck, div_ceil
+ #include <sgl_kernel/utils.cuh>  // LaunchKernel, SGL_DEVICE, fp16_t, bf16_t, fp32_t
+ #include <sgl_kernel/vec.cuh>    // AlignedVector
+
+ #include <dlpack/dlpack.h>
+ #include <tvm/ffi/container/tensor.h>
+
+ #include <algorithm>  // std::max
+ #include <cstdint>    // uint32_t, int32_t
+
+ namespace {
+
+ // ----------------------------------------------------------------
+ // Kernel: element-wise scale using vectorized 128-bit loads/stores
+ //   T       = fp16_t | bf16_t | fp32_t
+ //   kVecN   = number of elements per vector load (e.g. 8 for fp16)
+ //   kFactor = scale factor encoded as kFactorNumer / kFactorDenom
+ // ----------------------------------------------------------------
+ template <typename T, int kVecN, int32_t kFactorNumer, int32_t kFactorDenom>
+ __global__ void scale_kernel(T* __restrict__ dst,
+                              const T* __restrict__ src,
+                              uint32_t n_vecs,
+                              uint32_t n_remainder,
+                              uint32_t n_total) {
+     constexpr float kFactor = static_cast<float>(kFactorNumer)
+                             / static_cast<float>(kFactorDenom);
+
+     using vec_t = device::AlignedVector<T, kVecN>;
+
+     // --- vectorised body (grid-stride loop over whole vectors) ---
+     const uint32_t vec_stride = blockDim.x * gridDim.x;
+     for (uint32_t vi = blockIdx.x * blockDim.x + threadIdx.x;
+          vi < n_vecs;
+          vi += vec_stride) {
+         vec_t v;
+         v.load(src, vi);
+ #pragma unroll
+         for (int i = 0; i < kVecN; ++i) {
+             v[i] = static_cast<T>(static_cast<float>(v[i]) * kFactor);
+         }
+         v.store(dst, vi);
+     }
+
+     // --- scalar tail (the n % kVecN elements after the last full vector) ---
+     const uint32_t base = n_vecs * kVecN;
+     const uint32_t scalar_stride = blockDim.x * gridDim.x;
+     for (uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;
+          i < n_remainder;
+          i += scalar_stride) {
+         dst[base + i] = static_cast<T>(static_cast<float>(src[base + i]) * kFactor);
+     }
+ }
+
+ // ----------------------------------------------------------------
+ // Launcher: validates tensors, selects vector width, launches kernel
+ // ----------------------------------------------------------------
+ template <typename T, int32_t kFactorNumer, int32_t kFactorDenom>
+ void scale(tvm::ffi::TensorView dst, tvm::ffi::TensorView src) {
+     using namespace host;
+
+     // 1. Validate input tensors with TensorMatcher
+     SymbolicSize N = {"num_elements"};
+     SymbolicDevice device_;
+     device_.set_options<kDLCUDA>();
+
+     TensorMatcher({N})  //
+         .with_dtype<T>()
+         .with_device(device_)
+         .verify(dst)
+         .verify(src);  // same shape / dtype / device as dst
+
+     const uint32_t n = static_cast<uint32_t>(N.unwrap());
+     const DLDevice device = device_.unwrap();
+
+     RuntimeCheck(n > 0, "scale: num_elements must be > 0, got ", n);
+
+     // 2. Choose vector width for 128-bit loads (16 bytes)
+     //    fp16/bf16: 8 elements × 2 bytes = 16 bytes
+     //    fp32:      4 elements × 4 bytes = 16 bytes
+     constexpr int kVecN = 16 / sizeof(T);
+     const uint32_t n_vecs = n / kVecN;
+     const uint32_t n_remainder = n % kVecN;
+
+     // 3. Launch (at least one block, since n > 0)
+     constexpr uint32_t kBlockSize = 256;
+     const uint32_t grid = div_ceil(std::max(n_vecs, n_remainder), kBlockSize);
+
+     LaunchKernel(grid, kBlockSize, device)(
+         scale_kernel<T, kVecN, kFactorNumer, kFactorDenom>,
+         static_cast<T*>(dst.data_ptr()),
+         static_cast<const T*>(src.data_ptr()),
+         n_vecs,
+         n_remainder,
+         n);
+ }
+
+ }  // namespace
+ ```
291
+
292
+ **Key points:**
293
+
294
+ - Include headers from `sgl_kernel/` — **not** raw CUDA headers for anything already covered
295
+ - Use `TensorMatcher` for all tensor validation; never manually check shape/dtype/device
296
+ - Use `AlignedVector` for vectorised 128-bit loads/stores — significant bandwidth win
297
+ - Use `LaunchKernel` — it resolves the stream and checks errors automatically
298
+ - Use `RuntimeCheck` for runtime assertions with useful error messages
299
+ - `fp16_t` / `bf16_t` / `fp32_t` are the project's type aliases (from `utils.cuh`)
300
+ - `device::cast<To, From>` or `dtype_trait<T>::from(val)` for cross-type conversions
301
+ - `device::math::` functions for device math instead of bare `__` intrinsics
302
+
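The launcher's arithmetic (vector width, full vectors, scalar tail, grid size) and the kernel's two grid-stride loops can be replayed on the CPU to confirm that every element is written exactly once. The Python model below mirrors the C++ code in Step 1; the function names are illustrative only.

```python
def div_ceil(a: int, b: int) -> int:
    return (a + b - 1) // b


def launch_geometry(n: int, elem_bytes: int, block_size: int = 256):
    """Replicate the launcher: 16-byte vectors, remainder, and grid size."""
    vec_n = 16 // elem_bytes
    n_vecs, n_rem = divmod(n, vec_n)
    grid = div_ceil(max(n_vecs, n_rem), block_size)
    return vec_n, n_vecs, n_rem, grid


def covered_elements(n: int, elem_bytes: int, block_size: int = 256):
    """Replay both grid-stride loops and collect every element index written."""
    vec_n, n_vecs, n_rem, grid = launch_geometry(n, elem_bytes, block_size)
    stride = grid * block_size  # blockDim.x * gridDim.x
    seen = []
    for tid in range(stride):  # one iteration per global thread id
        for vi in range(tid, n_vecs, stride):  # vectorised body
            seen.extend(range(vi * vec_n, (vi + 1) * vec_n))
        for i in range(tid, n_rem, stride):  # scalar tail
            seen.append(n_vecs * vec_n + i)
    return sorted(seen)


print(launch_geometry(4097, 2))
print(covered_elements(4097, 2) == list(range(4097)))
```

For 4097 FP16 elements this yields 512 full vectors plus a one-element tail. Taking `max(n_vecs, n_remainder)` keeps the grid non-empty when the input is shorter than one vector.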
303
+ ---
304
+
305
+ ## Step 2: Add the Python wrapper in `jit_kernel/`
306
+
307
+ Create `python/sglang/jit_kernel/scale.py`:
308
+
309
+ ```python
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING
+
+ import torch
+
+ from sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args
+
+ if TYPE_CHECKING:
+     from tvm_ffi.module import Module
+
+
+ @cache_once
+ def _jit_scale_module(dtype: torch.dtype, factor_numer: int, factor_denom: int) -> Module:
+     """Compile and cache the JIT scale module for a given dtype and factor."""
+     args = make_cpp_args(dtype, factor_numer, factor_denom)
+     return load_jit(
+         "scale",
+         *args,
+         cuda_files=["elementwise/scale.cuh"],
+         cuda_wrappers=[("scale", f"scale<{args}>")],
+     )
+
+
+ def scale(src: torch.Tensor, factor: float, out: torch.Tensor | None = None) -> torch.Tensor:
+     """
+     Element-wise scale: dst = src * factor.
+
+     Supported dtypes: torch.float16, torch.bfloat16, torch.float32.
+
+     Parameters
+     ----------
+     src    : CUDA tensor (FP16 / BF16 / FP32)
+     factor : scale factor
+     out    : optional pre-allocated output tensor (same shape/dtype as src)
+
+     Returns
+     -------
+     Scaled tensor (dst = src * factor).
+     """
+     assert src.is_cuda, "src must be a CUDA tensor"
+     assert src.dtype in (torch.float16, torch.bfloat16, torch.float32), (
+         f"Unsupported dtype {src.dtype}. Supported: float16, bfloat16, float32"
+     )
+     if out is None:
+         out = torch.empty_like(src)
+     else:
+         assert out.shape == src.shape, "out shape must match src"
+         assert out.dtype == src.dtype, "out dtype must match src"
+
+     # Encode factor as integer ratio; denom=1000 gives 3 decimal places of precision
+     factor_denom = 1000
+     factor_numer = round(factor * factor_denom)
+
+     module = _jit_scale_module(src.dtype, factor_numer, factor_denom)
+     module.scale(out, src)
+     return out
+ ```
368
+
369
+ **Key points:**
370
+
371
+ - Use `cache_once` — **not** `functools.lru_cache` (incompatible with `torch.compile`)
372
+ - `load_jit` first arg(s) form the unique build marker; same marker = same cached binary
373
+ - `cuda_wrappers`: `(export_name, kernel_symbol)` — `export_name` is called from Python
374
+ - `make_cpp_args(dtype, ...)` converts `torch.dtype` to C++ type alias:
375
+
376
+ | `torch.dtype` | C++ type |
377
+ |--------------------|------------|
378
+ | `torch.float16` | `fp16_t` |
379
+ | `torch.bfloat16` | `bf16_t` |
380
+ | `torch.float32` | `fp32_t` |
381
+
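Because the factor travels as template arguments, the cache key of the compiled module is effectively `(dtype, numer, denom)`. The Python sketch below models that consequence without importing sglang. `get_module` and its dict-based cache are hypothetical stand-ins for `_jit_scale_module` and `cache_once`, and the returned string only imitates an instantiated symbol name.

```python
def encode_factor(factor: float, denom: int = 1000):
    """Encode a float as an integer ratio, as the wrapper does with denom=1000."""
    return round(factor * denom), denom


compile_log = []  # records every simulated compilation
_cache = {}       # hypothetical stand-in for the cache_once storage


def get_module(dtype_name: str, numer: int, denom: int) -> str:
    """Return the cached 'module' for this key, compiling it on first use."""
    key = (dtype_name, numer, denom)
    if key not in _cache:
        compile_log.append(key)
        _cache[key] = f"scale<{dtype_name}, {numer}, {denom}>"
    return _cache[key]


m1 = get_module("fp16_t", *encode_factor(0.5))
m2 = get_module("fp16_t", *encode_factor(0.5004))  # rounds to the same ratio
m3 = get_module("fp16_t", *encode_factor(2.0))
print(m1, m2, m3, len(compile_log))
```

Two things follow from this design: factors that differ by less than the 1/1000 resolution collapse onto the same kernel, and each new distinct factor triggers one more JIT compilation. That is acceptable for a handful of fixed factors but worth keeping in mind if the factor varies at runtime.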
382
+ ---
383
+
384
+ ## Step 3 (optional): Tune JIT build flags
385
+
386
+ ```python
+ return load_jit(
+     "scale",
+     *args,
+     cuda_files=["elementwise/scale.cuh"],
+     cuda_wrappers=[("scale", f"scale<{args}>")],
+     extra_cuda_cflags=["-O3", "--use_fast_math"],
+ )
+ ```
395
+
396
+ If your kernel requires SM90+, raise a clear Python error before calling `load_jit`:
397
+
398
+ ```python
399
+ if torch.cuda.get_device_capability()[0] < 9:
+     raise RuntimeError("This kernel requires SM90 (Hopper) or later")
401
+ ```
402
+
403
+ ---
404
+
405
+ ## Step 4: Write tests (required)
406
+
407
+ Create `python/sglang/jit_kernel/tests/test_scale.py`:
408
+
409
+ ```python
+ import pytest
+ import torch
+ from sglang.jit_kernel.scale import scale
+
+
+ @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
+ @pytest.mark.parametrize("size", [1, 127, 128, 1024, 4097])  # cover tail remainder
+ @pytest.mark.parametrize("factor", [0.5, 1.0, 2.0, 3.0])
+ def test_scale_correctness(dtype, size, factor):
+     src = torch.randn(size, dtype=dtype, device="cuda")
+     out = scale(src, factor)
+     expected = src * factor
+
+     rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-2, 1e-2)
+     torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
+
+
+ @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
+ def test_scale_out_param(dtype):
+     src = torch.randn(1024, dtype=dtype, device="cuda")
+     out = torch.empty_like(src)
+     result = scale(src, 2.0, out=out)
+     assert result is out
+     torch.testing.assert_close(out, src * 2.0, rtol=1e-2, atol=1e-2)
+
+
+ def test_scale_cpu_error():
+     src = torch.randn(128, dtype=torch.float16)  # CPU tensor
+     with pytest.raises(AssertionError, match="CUDA"):
+         scale(src, 2.0)
+
+
+ def test_scale_unsupported_dtype():
+     src = torch.randint(0, 10, (128,), dtype=torch.int32, device="cuda")
+     with pytest.raises(AssertionError, match="Unsupported dtype"):
+         scale(src, 2.0)
+
+
+ if __name__ == "__main__":
+     pytest.main([__file__, "-v", "-s"])
+ ```
451
+
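The tolerances in the tests can be sanity-checked against the precision of each format. The sketch below computes the worst-case relative error of one round-to-nearest step from the mantissa widths of the IEEE 754 half and single formats and of bfloat16. It is a back-of-the-envelope check, not a derivation of the exact values chosen in the tests.

```python
# Explicit mantissa (fraction) bits of each format.
MANTISSA_BITS = {"float16": 10, "bfloat16": 7, "float32": 23}

# Relative tolerances used in test_scale_correctness.
TEST_RTOL = {"float16": 1e-2, "bfloat16": 1e-2, "float32": 1e-5}


def unit_roundoff(dtype_name: str) -> float:
    """Worst-case relative error of one round-to-nearest: 2 ** -(bits + 1)."""
    return 2.0 ** -(MANTISSA_BITS[dtype_name] + 1)


for name in MANTISSA_BITS:
    u = unit_roundoff(name)
    print(f"{name}: unit roundoff {u:.3e}, rtol {TEST_RTOL[name]:.0e}")
```

Even if the kernel and the PyTorch reference round in opposite directions, the difference stays around twice the unit roundoff, which is below the chosen `rtol` for all three dtypes.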
452
+ ---
453
+
454
+ ## Step 5: Add a benchmark (required)
455
+
456
+ Create `python/sglang/jit_kernel/benchmark/bench_scale.py`:
457
+
458
+ ```python
+ import itertools
+
+ import torch
+ import triton
+ import triton.testing
+
+ from sglang.jit_kernel.benchmark.utils import (
+     DEFAULT_DEVICE,
+     DEFAULT_DTYPE,
+     get_benchmark_range,
+     run_benchmark,
+ )
+ from sglang.jit_kernel.scale import scale as jit_scale
+
+
+ SIZE_LIST = get_benchmark_range(
+     full_range=[2**n for n in range(10, 20)],  # 1K … 512K elements
+     ci_range=[4096, 65536],
+ )
+
+ configs = list(itertools.product(SIZE_LIST))
+
+
+ @triton.testing.perf_report(
+     triton.testing.Benchmark(
+         x_names=["size"],
+         x_vals=configs,
+         line_arg="provider",
+         line_vals=["jit", "torch"],
+         line_names=["SGL JIT Kernel", "PyTorch"],
+         styles=[("blue", "-"), ("red", "--")],
+         ylabel="us",
+         plot_name="scale-performance",
+         args={},
+     )
+ )
+ def benchmark(size: int, provider: str):
+     src = torch.randn(size, dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE)
+     factor = 2.0
+
+     if provider == "jit":
+         fn = lambda: jit_scale(src, factor)
+     else:
+         fn = lambda: src * factor
+
+     return run_benchmark(fn)
+
+
+ if __name__ == "__main__":
+     benchmark.run(print_data=True)
+ ```
510
+
511
+ Run:
512
+
513
+ ```bash
514
+ python python/sglang/jit_kernel/benchmark/bench_scale.py
515
+ ```
516
+
517
+ ---
518
+
519
+ ## Troubleshooting
520
+
521
+ - **JIT compilation fails**: ensure the `.cuh` file is under `python/sglang/jit_kernel/csrc/`; reduce template argument combinations
522
+ - **CUDA crash / illegal memory access**: rerun with `CUDA_LAUNCH_BLOCKING=1` so the error is reported at the failing launch, or run under `compute-sanitizer --tool memcheck python ...`
523
+ - **Unstable benchmark results**: `run_benchmark` uses CUDA-graph-based timing by default
524
+
525
+ ---
526
+
527
+ ## References
528
+
529
+ - `docs/developer_guide/development_jit_kernel_guide.md`
530
+ - `python/sglang/jit_kernel/utils.py` — `cache_once`, `load_jit`, `make_cpp_args`
531
+ - `python/sglang/jit_kernel/include/sgl_kernel/tensor.h` — `TensorMatcher`, `SymbolicSize/DType/Device`
532
+ - `python/sglang/jit_kernel/include/sgl_kernel/utils.cuh` — type aliases, `LaunchKernel`, `SGL_DEVICE`
533
+ - `python/sglang/jit_kernel/include/sgl_kernel/vec.cuh` — `AlignedVector`
534
+ - `python/sglang/jit_kernel/include/sgl_kernel/tile.cuh` — `tile::Memory`
535
+ - `python/sglang/jit_kernel/include/sgl_kernel/type.cuh` — `dtype_trait`, `packed_t`, `device::cast`
536
+ - `python/sglang/jit_kernel/include/sgl_kernel/math.cuh` — `device::math::`
537
+ - `python/sglang/jit_kernel/include/sgl_kernel/warp.cuh` — `warp::reduce_sum/max`
538
+ - `python/sglang/jit_kernel/include/sgl_kernel/cta.cuh` — `cta::reduce_max`
539
+ - `python/sglang/jit_kernel/include/sgl_kernel/atomic.cuh` — `atomic::max`
540
+ - `python/sglang/jit_kernel/include/sgl_kernel/runtime.cuh` — occupancy / SM count helpers
541
+ - `python/sglang/jit_kernel/csrc/add_constant.cuh` — minimal runnable reference
542
+ - `python/sglang/jit_kernel/csrc/elementwise/rmsnorm.cuh` — real example using `TensorMatcher` + `LaunchKernel` + `tile::Memory`
543
+ - `python/sglang/jit_kernel/csrc/elementwise/qknorm.cuh` — real example using `runtime::get_blocks_per_sm` + persistent kernel pattern
544
+ - `python/sglang/jit_kernel/benchmark/utils.py` — benchmark helpers
545
+
546
+ ## Summary of Files Created
547
+
548
+ ```
549
+ python/sglang/jit_kernel/csrc/elementwise/scale.cuh # NEW: CUDA kernel
550
+ python/sglang/jit_kernel/scale.py # NEW: Python wrapper
551
+ python/sglang/jit_kernel/tests/test_scale.py # NEW: Tests
552
+ python/sglang/jit_kernel/benchmark/bench_scale.py # NEW: Benchmark
553
+ ```
sglang/.claude/skills/add-sgl-kernel/SKILL.md ADDED
@@ -0,0 +1,358 @@
1
+ ---
2
+ name: add-sgl-kernel
3
+ description: Step-by-step tutorial for adding a heavyweight AOT CUDA/C++ kernel to sgl-kernel (including tests & benchmarks)
4
+ ---
5
+
6
+ # Tutorial: Adding a New Kernel to `sgl-kernel` (AOT / Heavyweight)
7
+
8
+ This tutorial walks through adding a simple element-wise scale operation as an AOT kernel. We'll implement `scale(x, factor) = x * factor` to demonstrate the complete workflow.
9
+
10
+ ## Goal
11
+
12
+ Add a new operation that scales each element of a tensor by a scalar factor:
13
+
14
+ - Input: tensor `x` (CUDA) and scalar `factor` (float)
15
+ - Output: `x * factor` (element-wise, in-place or into pre-allocated `out`)
16
+ - Supported dtypes: **FP16 (`torch.float16`), BF16 (`torch.bfloat16`), FP32 (`torch.float32`)**
17
+ - Dispatched via `DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16` macro (defined in `sgl-kernel/include/utils.h`)
18
+
19
+ ## Two rules of thumb (must follow)
20
+
21
+ 1. **Heavyweight kernels go to `sgl-kernel`.** If it depends on CUTLASS / FlashInfer / DeepGEMM (or similarly heavy stacks), implement it in `sgl-kernel/`.
22
+ 2. **Lightweight kernels go to `python/sglang/jit_kernel`.** If it is small, has few dependencies, and benefits from rapid iteration, implement it as a JIT kernel instead.
23
+
24
+ In addition, every new kernel must ship with:
25
+
26
+ - **Tests** (pytest)
27
+ - **A benchmark script** (triton.testing)
28
+
29
+ ---
30
+
31
+ ## Repository integration map
32
+
33
+ You will typically touch these files/areas:
34
+
35
+ - Implementation: `sgl-kernel/csrc/elementwise/scale.cu` (pick the right subdirectory)
36
+ - Public declarations: `sgl-kernel/include/sgl_kernel_ops.h`
37
+ - Torch extension registration: `sgl-kernel/csrc/common_extension.cc`
38
+ - Build: `sgl-kernel/CMakeLists.txt` (`set(SOURCES ...)`)
39
+ - Python API: `sgl-kernel/python/sgl_kernel/` and `sgl-kernel/python/sgl_kernel/__init__.py`
40
+ - Tests: `sgl-kernel/tests/test_scale.py`
41
+ - Benchmarks: `sgl-kernel/benchmark/bench_scale.py`
42
+
43
+ ---
44
+
45
+ ## Step 1: Implement the kernel in `csrc/`
46
+
47
+ Pick the right subdirectory:
48
+
49
+ - `csrc/elementwise/` — for element-wise ops (our example)
50
+ - `csrc/gemm/`, `csrc/attention/`, `csrc/moe/` — for other categories
51
+
52
+ Create `sgl-kernel/csrc/elementwise/scale.cu`:
53
+
54
+ ```cpp
+ #include <ATen/cuda/CUDAContext.h>
+ #include <c10/cuda/CUDAGuard.h>
+ #include <torch/all.h>
+
+ #include "utils.h"  // DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16
+
+ // scale_kernel: out[i] = input[i] * factor
+ // Supports float, half (__half), __nv_bfloat16 via template T
+ template <typename T>
+ __global__ void scale_kernel(T* __restrict__ out,
+                              const T* __restrict__ input,
+                              float factor,
+                              int64_t n) {
+     int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+     if (idx < n) {
+         out[idx] = static_cast<T>(static_cast<float>(input[idx]) * factor);
+     }
+ }
+
+ void scale(at::Tensor& out, const at::Tensor& input, double factor) {
+     TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
+     TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
+     TORCH_CHECK(out.is_cuda(), "out must be a CUDA tensor");
+     TORCH_CHECK(out.is_contiguous(), "out must be contiguous");
+     TORCH_CHECK(out.sizes() == input.sizes(), "out and input must have the same shape");
+     TORCH_CHECK(out.scalar_type() == input.scalar_type(),
+                 "out and input must have the same dtype");
+
+     const int64_t n = input.numel();
+     const int threads = 256;
+     const int blocks = (n + threads - 1) / threads;
+
+     // Switch to the input's device first, then query that device's current stream.
+     const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+     const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+     // Dispatches over float, float16, bfloat16
+     DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
+         scale_kernel<c_type><<<blocks, threads, 0, stream>>>(
+             static_cast<c_type*>(out.data_ptr()),
+             static_cast<const c_type*>(input.data_ptr()),
+             static_cast<float>(factor),
+             n);
+         cudaError_t status = cudaGetLastError();
+         TORCH_CHECK(status == cudaSuccess,
+                     "scale_kernel launch failed: ", cudaGetErrorString(status));
+         return true;
+     });
+ }
+ ```
104
+
105
+ **Key points:**
106
+
107
+ - Use `at::Tensor` (PyTorch tensors), `TORCH_CHECK` for validation, `at::cuda::getCurrentCUDAStream()` for stream
108
+ - `DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16` covers `float`, `half` (FP16), `__nv_bfloat16` (BF16)
109
+ - Add device error checking after every kernel launch
110
+ - If a kernel only works on certain architectures, enforce that with `TORCH_CHECK` and skip logic in tests
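
The launcher sizes its grid with ceiling division so that a final, partially filled block is still launched. The sketch below reproduces that arithmetic in Python, assuming the 256-thread block size used in the launcher; the helper name `grid_size` is made up for illustration and is not part of sgl-kernel.

```python
def grid_size(n: int, threads: int = 256) -> int:
    """Number of blocks needed to cover n elements (ceiling division)."""
    return (n + threads - 1) // threads


for n in (1, 255, 256, 257, 65536):
    blocks = grid_size(n)
    # Every element index in [0, n) is covered by blocks * 256 threads;
    # threads past n are masked off by the `idx < n` guard in the kernel.
    assert blocks * 256 >= n > (blocks - 1) * 256
    print(n, blocks)
```

Note that `blocks` is a 32-bit `int` in the launcher and the x-dimension of a CUDA grid is capped at 2^31 - 1 blocks, so very large tensors would call for a grid-stride loop instead.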
111
+
112
+ ---
113
+
114
+ ## Step 2: Add a C++ declaration in `include/sgl_kernel_ops.h`
115
+
116
+ Edit `sgl-kernel/include/sgl_kernel_ops.h`, add to the elementwise section:
117
+
118
+ ```cpp
119
+ void scale(at::Tensor& out, const at::Tensor& input, double factor);
120
+ ```
121
+
122
+ ---
123
+
124
+ ## Step 3: Register the op in `csrc/common_extension.cc`
125
+
126
+ Edit `sgl-kernel/csrc/common_extension.cc`, inside `TORCH_LIBRARY_FRAGMENT(sgl_kernel, m)`:
127
+
128
+ ```cpp
129
+ // From csrc/elementwise
130
+ m.def("scale(Tensor! out, Tensor input, float factor) -> ()");
131
+ m.impl("scale", torch::kCUDA, &scale);
132
+ ```
133
+
134
+ **Key points:**
135
+
136
+ - `Tensor!` means in-place / mutable output argument
137
+ - The schema is important for `torch.compile` and for consistent call signatures
138
+ - If your underlying C++ API uses `float` but PyTorch bindings expect `double`, the implicit cast is fine for scalars; use shims if needed for other types
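
The schema type `float` arrives in C++ as a `double`, and the launcher in Step 1 narrows it with `static_cast<float>(factor)`. The sketch below shows what that narrowing does to a few factors, using only the standard library; `to_float32` is a hypothetical helper written for this illustration.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float (a C double) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


for factor in (0.5, 2.0, 0.1):
    narrowed = to_float32(factor)
    # Powers of two survive the cast unchanged; 0.1 does not.
    print(factor, narrowed, "exact" if narrowed == factor else "rounded")
```

The rounding error is far below the FP16/BF16 tolerances used in the tests, which is why the implicit cast is usually acceptable for a scalar like `factor`.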
139
+
140
+ ---
141
+
142
+ ## Step 4: Add the new source file to `CMakeLists.txt`
143
+
144
+ Edit `sgl-kernel/CMakeLists.txt`, add to `set(SOURCES ...)`:
145
+
146
+ ```cmake
147
+ csrc/elementwise/scale.cu
148
+ ```
149
+
150
+ **Key points:**
151
+
152
+ - Keep the list **alphabetically sorted** (the file explicitly requires this)
153
+ - If the kernel has arch constraints, reflect that in tests/benchmarks via skip logic
154
+
155
+ ---
156
+
157
+ ## Step 5: Expose a Python API under `sgl-kernel/python/sgl_kernel/`
158
+
159
+ In `sgl-kernel/python/sgl_kernel/__init__.py`, add:
160
+
161
+ ```python
162
+ import torch
+
+
+ def scale(out: torch.Tensor, input: torch.Tensor, factor: float) -> None:
+     """
+     Element-wise scale: out = input * factor (result is written into out).
+
+     Supported dtypes: torch.float16, torch.bfloat16, torch.float32.
+
+     Parameters
+     ----------
+     out : pre-allocated CUDA output tensor (same shape/dtype as input)
+     input : CUDA input tensor
+     factor : scale factor (float)
+     """
+     torch.ops.sgl_kernel.scale(out, input, factor)
177
+ ```
178
+
179
+ Or export it from the existing module organisation — follow the pattern already used by similar ops in `__init__.py`.
180
+
181
+ ---
182
+
183
+ ## Step 6: Write tests (required)
184
+
185
+ Create `sgl-kernel/tests/test_scale.py`:
186
+
187
+ ```python
188
+ import pytest
+ import torch
+ import sgl_kernel
+
+
+ @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
+ @pytest.mark.parametrize("size", [128, 1024, 4096, 65536])
+ @pytest.mark.parametrize("factor", [0.5, 1.0, 2.0])
+ def test_scale_correctness(dtype, size, factor):
+     input = torch.randn(size, dtype=dtype, device="cuda")
+     out = torch.empty_like(input)
+
+     sgl_kernel.scale(out, input, factor)
+
+     expected = input * factor
+     rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-2, 1e-2)
+     torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
+
+
+ def test_scale_shape_mismatch():
+     input = torch.randn(128, dtype=torch.float16, device="cuda")
+     out = torch.empty(256, dtype=torch.float16, device="cuda")
+     with pytest.raises(RuntimeError, match="same shape"):
+         sgl_kernel.scale(out, input, 2.0)
+
+
+ def test_scale_cpu_input():
+     input = torch.randn(128, dtype=torch.float16)  # CPU
+     out = torch.empty_like(input)
+     with pytest.raises(RuntimeError, match="CUDA"):
+         sgl_kernel.scale(out, input, 2.0)
+
+
+ if __name__ == "__main__":
+     pytest.main([__file__, "-q"])
223
+ ```
224
+
225
+ Run:
226
+
227
+ ```bash
228
+ pytest sgl-kernel/tests/test_scale.py -q
229
+ ```
230
+
231
+ ---
232
+
233
+ ## Step 7: Add a benchmark (required)
234
+
235
+ Create `sgl-kernel/benchmark/bench_scale.py`:
236
+
237
+ ```python
238
+ import itertools
+ import os
+
+ import torch
+ import triton
+ import triton.testing
+
+ import sgl_kernel
+
+ IS_CI = (
+     os.getenv("CI", "false").lower() == "true"
+     or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+ )
+
+ dtypes = [torch.float16] if IS_CI else [torch.float16, torch.bfloat16, torch.float32]
+ sizes = [4096] if IS_CI else [2**n for n in range(10, 20)]  # 1K … 512K
+ factors = [2.0]
+
+ configs = list(itertools.product(dtypes, sizes))
+
+
+ def torch_scale(input: torch.Tensor, factor: float) -> torch.Tensor:
+     return input * factor
+
+
+ @triton.testing.perf_report(
+     triton.testing.Benchmark(
+         x_names=["dtype", "size"],
+         x_vals=configs,
+         line_arg="provider",
+         line_vals=["sglang", "torch"],
+         line_names=["SGL Kernel", "PyTorch"],
+         styles=[("green", "-"), ("red", "--")],
+         ylabel="µs (median)",
+         plot_name="scale-performance",
+         args={},
+     )
+ )
+ def benchmark(dtype, size, provider):
+     input = torch.randn(size, dtype=dtype, device="cuda")
+     out = torch.empty_like(input)
+     factor = 2.0
+
+     if provider == "sglang":
+         fn = lambda: sgl_kernel.scale(out, input, factor)
+     else:
+         fn = lambda: torch_scale(input, factor)
+
+     # quantiles=[0.5, 0.2, 0.8] yields (median, 20th percentile, 80th percentile)
+     ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
+         fn, quantiles=[0.5, 0.2, 0.8]
+     )
+     # perf_report expects (value, min, max); convert milliseconds to microseconds
+     return 1000 * ms, 1000 * min_ms, 1000 * max_ms
+
+
+ if __name__ == "__main__":
+     benchmark.run(print_data=True)
294
+ ```
295
+
296
+ Run:
297
+
298
+ ```bash
299
+ python sgl-kernel/benchmark/bench_scale.py
300
+ ```
301
+
302
+ ---
303
+
304
+ ## Step 8: Build and validate
305
+
306
+ Build:
307
+
308
+ ```bash
309
+ cd sgl-kernel
310
+ make build -j16
311
+ ```
312
+
313
+ If you need to limit host resource usage:
314
+
315
+ ```bash
316
+ cd sgl-kernel
317
+ make build -j1 MAX_JOBS=2 CMAKE_ARGS="-DSGL_KERNEL_COMPILE_THREADS=1"
318
+ ```
319
+
320
+ Validate:
321
+
322
+ ```bash
323
+ pytest sgl-kernel/tests/test_scale.py -q
324
+ python sgl-kernel/benchmark/bench_scale.py
325
+ ```
326
+
327
+ ---
328
+
329
+ ## Troubleshooting
330
+
331
+ - **Async CUDA errors**: `CUDA_LAUNCH_BLOCKING=1`
332
+ - **Memory errors**: `compute-sanitizer --tool memcheck python ...`
333
+ - **Build is too slow / OOM**: reduce `MAX_JOBS` and `SGL_KERNEL_COMPILE_THREADS`
334
+ - **Binary bloat**: use `sgl-kernel/analyze_whl_kernel_sizes.py`
335
+ - **CMake sources list**: if your `.cu` file is missing from `SOURCES`, the symbol will be undefined at link time
336
+
337
+ ---
338
+
339
+ ## References
340
+
341
+ - `sgl-kernel/README.md`
342
+ - `sgl-kernel/include/sgl_kernel_ops.h`
343
+ - `sgl-kernel/csrc/common_extension.cc`
344
+ - `sgl-kernel/CMakeLists.txt`
345
+ - `sgl-kernel/include/utils.h` — `DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16` macro and friends
346
+ - `sgl-kernel/csrc/elementwise/activation.cu` — reference for the FP16/BF16/FP32 dispatch pattern
347
+
348
+ ## Summary of Files Created/Modified
349
+
350
+ ```
351
+ sgl-kernel/csrc/elementwise/scale.cu # NEW: CUDA kernel + launcher
352
+ sgl-kernel/include/sgl_kernel_ops.h # MODIFIED: C++ declaration
353
+ sgl-kernel/csrc/common_extension.cc # MODIFIED: schema + dispatch registration
354
+ sgl-kernel/CMakeLists.txt # MODIFIED: add source file (alphabetical)
355
+ sgl-kernel/python/sgl_kernel/__init__.py # MODIFIED: export Python API
356
+ sgl-kernel/tests/test_scale.py # NEW: tests
357
+ sgl-kernel/benchmark/bench_scale.py # NEW: benchmark
358
+ ```
sglang/.claude/skills/sglang-bisect-ci-regression/SKILL.md ADDED
@@ -0,0 +1,219 @@
1
+ # SGLang Bisect CI Regression
2
+
3
+ Investigate a consistently failing CI test to find the root cause: a code regression from a specific PR, a hardware- or runner-specific issue, or an environment change. Optionally reproduce the failure on a remote GPU server.
4
+
5
+ ## Slash Command
6
+
7
+ `/sglang-bisect-ci-regression <test_name_or_ci_url> [ssh_target] [docker_container]`
8
+
9
+ ## When to Use This Skill
10
+
11
+ - A CI test is failing consistently on main (scheduled runs)
12
+ - You need to find which PR introduced a regression
13
+ - You suspect a runner-specific or GPU-specific issue
14
+ - You want to reproduce a CI failure on a remote server
15
+
16
+ ## Arguments
17
+
18
+ - **First argument (required)**: Test file name (e.g. `test_lora_tp.py`) or a GitHub Actions job URL
19
+ - **Second argument (optional)**: SSH target for remote reproduction (e.g. `user@host`)
20
+ - **Third argument (optional)**: Docker container name on the SSH target (e.g. `sglang_dev`)
21
+
22
+ If SSH target and docker container are not provided, the skill will only perform the CI log analysis and bisection, without remote reproduction. **Ask the user** for these if reproduction is needed and they weren't provided.
23
+
24
+ ## Background: Scheduled CI Runs
25
+
26
+ SGLang uses the `pr-test.yml` workflow with **scheduled runs** (cron-triggered) to periodically test the `main` branch. These runs are the primary data source for detecting regressions:
27
+
28
+ - **Workflow**: `pr-test.yml` with `event: schedule`
29
+ - **Branch**: `main`
30
+ - **Dashboard**: https://github.com/sgl-project/sglang/actions/workflows/pr-test.yml?query=event%3Aschedule
31
+ - **Frequency**: Runs multiple times daily, each pinned to the HEAD of `main` at trigger time
32
+ - **Purpose**: Catches regressions that slip through PR-level CI (e.g., interaction bugs between merged PRs, hardware-specific issues)
33
+
34
+ Always use these scheduled runs (not PR-triggered runs) when bisecting regressions on `main`. The `--event schedule` filter in `gh run list` ensures you only see these periodic main-branch runs.
35
+
36
+ ## Workflow
37
+
38
+ ### Phase 1: Extract the Failure Signature
39
+
40
+ 1. **Get the failing test details from CI logs.** If given a URL, fetch logs directly. If given a test name, find recent scheduled runs of `pr-test.yml` on `main` that failed:
41
+
42
+ ```bash
43
+ # List recent scheduled runs targeting main (the primary source of truth for regressions)
44
+ # These are cron-triggered runs visible at:
45
+ # https://github.com/sgl-project/sglang/actions/workflows/pr-test.yml?query=event%3Aschedule
46
+ gh run list --repo sgl-project/sglang --workflow="pr-test.yml" --event schedule --branch main --limit 20 --json databaseId,conclusion,createdAt,headSha
47
+
48
+ # Find the job containing the test
49
+ gh run view {RUN_ID} --repo sgl-project/sglang --json jobs --jq '.jobs[] | select(.conclusion == "failure") | {name, conclusion, databaseId}'
50
+
51
+ # Get the failure details
52
+ gh run view {RUN_ID} --repo sgl-project/sglang --job {JOB_ID} --log 2>&1 | grep -E -B 5 -A 30 "AssertionError|FAIL|Error|{TEST_NAME}"
53
+ ```
54
+
55
+ 2. **Record the failure signature:**
56
+ - Exact error message and assertion
57
+ - Affected test method name
58
+ - Model/config involved
59
+ - Numeric values (e.g., tolerance diffs, scores)
60
+ - Whether the failure is deterministic (same values across runs)
61
+
62
+ ### Phase 2: Temporal Bisection
63
+
64
+ 3. **Find the boundary between passing and failing runs.** Walk through the scheduled run history (from the `pr-test.yml` schedule runs on `main`) to identify:
65
+ - Last known PASSING run (sha + date)
66
+ - First known FAILING run (sha + date)
67
+
68
+ ```bash
69
+ # For each scheduled run, check the specific partition/job status
70
+ gh run view {RUN_ID} --repo sgl-project/sglang --json jobs --jq '.jobs[] | select(.name == "{JOB_NAME}") | {conclusion, databaseId}'
71
+
72
+ # Verify a specific test passed or failed in a run
73
+ gh run view {RUN_ID} --repo sgl-project/sglang --job {JOB_ID} --log 2>&1 | grep -E "{TEST_NAME}|PASSED|FAILED|logprobs mismatch" | head -10
74
+ ```
75
+
76
+ 4. **List commits between the boundary:**
77
+
78
+ ```bash
79
+ git log --oneline {LAST_PASS_SHA}..{FIRST_FAIL_SHA}
80
+ ```
81
+
82
+ 5. **Filter for relevant commits** that touch files related to the failing test (model layers, kernels, test utilities, etc.):
83
+
84
+ ```bash
85
+ git log --oneline {LAST_PASS_SHA}..{FIRST_FAIL_SHA} -- {relevant_paths}
86
+ ```
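
Steps 3 to 5 boil down to scanning the run history for the most recent pass followed directly by a fail. The sketch below does that over records shaped like the `gh run list --json databaseId,conclusion,createdAt,headSha` output from Phase 1. It assumes `conclusion` already reflects the specific job you care about (not the whole run), and both the function name `find_boundary` and the sample records are made up for illustration.

```python
import json


def find_boundary(runs):
    """Return (last_pass, first_fail) from run records, or None if no boundary.

    Each record needs "createdAt" (ISO 8601, UTC), "conclusion", and "headSha".
    Runs that are neither "success" nor "failure" (e.g. cancelled) are skipped.
    """
    ordered = sorted(runs, key=lambda r: r["createdAt"])  # oldest first
    decided = [r for r in ordered if r["conclusion"] in ("success", "failure")]
    boundary = None
    for prev, cur in zip(decided, decided[1:]):
        if prev["conclusion"] == "success" and cur["conclusion"] == "failure":
            boundary = (prev, cur)  # keep the most recent pass -> fail flip
    return boundary


# Made-up sample data in the shape that `gh run list --json ...` prints.
sample = json.loads("""
[
  {"databaseId": 3, "conclusion": "failure", "createdAt": "2025-01-03T00:00:00Z", "headSha": "ccc"},
  {"databaseId": 2, "conclusion": "success", "createdAt": "2025-01-02T00:00:00Z", "headSha": "bbb"},
  {"databaseId": 1, "conclusion": "success", "createdAt": "2025-01-01T00:00:00Z", "headSha": "aaa"}
]
""")

last_pass, first_fail = find_boundary(sample)
print(f"git log --oneline {last_pass['headSha']}..{first_fail['headSha']}")
```

If the history flips between pass and fail several times, treat that as a hint of flakiness or runner dependence rather than a single clean regression.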
87
+
88
+ ### Phase 3: Runner/Hardware Analysis
89
+
90
+ 6. **Check if the failure is runner-specific.** Extract the runner identity from each failing and passing run:
91
+
92
+ ```bash
93
+ # Get runner name and machine
94
+ gh run view {RUN_ID} --repo sgl-project/sglang --job {JOB_ID} --log 2>&1 | grep -E "Runner name|Machine name" | head -5
95
+
96
+ # Get GPU/driver info
97
+ gh run view {RUN_ID} --repo sgl-project/sglang --job {JOB_ID} --log 2>&1 | grep -i -E "NVIDIA-SMI|Driver Version|CUDA Version" | head -5
98
+
99
+ # Get package versions
100
+ gh run view {RUN_ID} --repo sgl-project/sglang --job {JOB_ID} --log 2>&1 | grep -E "sgl.kernel.*==|flashinfer.*==" | head -5
101
+ ```
102
+
103
+ 7. **Correlate runners with pass/fail outcomes.** Build a table:
104
+
105
+ | Run ID | Date | Runner | GPU Type | Driver | Result |
106
+ |--------|------|--------|----------|--------|--------|
107
+
108
+ If all failures map to a specific runner type/GPU and all passes map to another, the issue is **hardware-specific**, not a code regression.
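
The correlation rule above can be checked mechanically once the table is filled in. This is a minimal sketch, assuming each table row has been reduced to a `(gpu_type, result)` pair; the function name and the returned labels are illustrative and not part of any SGLang tooling.

```python
from collections import defaultdict


def classify_by_runner(rows):
    """Group results by GPU type and decide whether failures track hardware.

    rows: iterable of (gpu_type, result) pairs, result being "pass" or "fail".
    """
    outcomes = defaultdict(set)
    for gpu_type, result in rows:
        outcomes[gpu_type].add(result)
    failing = {g for g, seen in outcomes.items() if seen == {"fail"}}
    passing = {g for g, seen in outcomes.items() if seen == {"pass"}}
    mixed = set(outcomes) - failing - passing
    if failing and passing and not mixed:
        return "hardware-specific"  # clean split between GPU types
    if failing and not passing and not mixed:
        return "fails everywhere"  # points toward a code regression
    return "inconclusive"  # same GPU type both passes and fails, or no failures


# Made-up observations for illustration.
rows = [("H200", "fail"), ("H200", "fail"), ("H100", "pass")]
print(classify_by_runner(rows))
```

An "inconclusive" result is the cue to gather more runs or move on to the flakiness checks before blaming a commit.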
109
+
110
+ ### Phase 4: Code Analysis
111
+
112
+ 8. **If a code regression is suspected** (failures not runner-specific), examine the candidate commits:
113
+ - Read the changed files
114
+ - Understand how the changes could affect the failing test
115
+ - Look for prefill-vs-decode differences, TP-specific paths, kernel changes
116
+
117
+ 9. **If a hardware issue is suspected**, analyze:
118
+ - Kernel compatibility (CUDA compute capability)
119
+ - Driver version differences
120
+ - All-reduce / NCCL behavior differences
121
+ - CUDA graph capture differences across GPU architectures
122
+
123
+ ### Phase 5: Remote Reproduction (Optional)
124
+
125
+ Only if SSH target and docker container were provided.
126
+
127
+ 10. **Verify the remote environment:**
128
+
129
+ ```bash
130
+ ssh {SSH_TARGET} "docker exec {CONTAINER} nvidia-smi --query-gpu=name,driver_version --format=csv"
131
+ ssh {SSH_TARGET} "docker exec {CONTAINER} pip show sgl-kernel sglang flashinfer-python 2>&1 | grep -E 'Name:|Version:'"
132
+ ```
133
+
134
+ 11. **Ensure latest code is installed.** If the container is stale, update:
135
+
136
+ ```bash
137
+ # Try fetching latest main
138
+ ssh {SSH_TARGET} "docker exec {CONTAINER} bash -c 'cd /path/to/sglang && git fetch origin main && git checkout origin/main'"
139
+ # Or download and install from tarball if git auth fails
140
+ ssh {SSH_TARGET} "docker exec {CONTAINER} bash -c 'cd /tmp && curl -L https://github.com/sgl-project/sglang/archive/refs/heads/main.tar.gz | tar xz && cd sglang-main && pip install -e \"python[all]\"'"
141
+ # Reinstall (after git fetch)
142
+ ssh {SSH_TARGET} "docker exec {CONTAINER} bash -c 'cd /path/to/sglang && pip install -e \"python[all]\"'"
143
+ # Install test dependencies if needed
144
+ ssh {SSH_TARGET} "docker exec {CONTAINER} pip install peft rouge-score"
145
+ ```
146
+
147
+ 12. **Create a minimal reproduction script** that:
148
+ - Uses `if __name__ == '__main__'` with `mp.set_start_method("spawn")`
149
+ - Runs the specific failing test configuration
150
+ - Prints key metrics (diffs, scores, outputs)
151
+ - Exits with code 1 on failure
152
+
153
+ 13. **Copy and run the reproduction script:**
154
+
155
+ ```bash
156
+ scp /tmp/repro_script.py {SSH_TARGET}:/tmp/
157
+ ssh {SSH_TARGET} "docker cp /tmp/repro_script.py {CONTAINER}:/tmp/"
158
+ ssh {SSH_TARGET} "docker exec -e CUDA_VISIBLE_DEVICES=0,1 {CONTAINER} python3 /tmp/repro_script.py"
159
+ ```
160
+
161
+ 14. **Run control experiments** to isolate the variable:
162
+ - If suspecting TP issue: run with TP=1 as control
163
+ - If suspecting GPU issue: compare same code on different GPU
164
+ - If suspecting a specific commit: test before/after that commit
165
+
166
+ ### Phase 6: Report
167
+
168
+ 15. **Produce a structured report:**
169
+
170
+ ```markdown
171
+ ## CI Regression Bisection Report
172
+
173
+ ### Failure Signature
174
+ - **Test**: {test_file}::{test_method}
175
+ - **Error**: {exact error message}
176
+ - **Key metrics**: {numeric values}
177
+ - **Deterministic**: Yes/No
178
+
179
+ ### Root Cause Classification
180
+ One of:
181
+ - **Code Regression**: PR #{number} introduced the bug
182
+ - **Hardware-Specific**: Fails on {GPU_TYPE}, passes on others
183
+ - **Environment Change**: New runner/driver/package version
184
+ - **Pre-existing Flakiness**: Intermittent, not a new regression
185
+
186
+ ### Evidence
187
+ | Condition | Result |
188
+ |-----------|--------|
189
+ | {condition1} | PASS/FAIL |
190
+ | {condition2} | PASS/FAIL |
191
+
192
+ ### Timeline
193
+ - {date}: Last known pass ({sha}, {runner})
194
+ - {date}: First known fail ({sha}, {runner})
195
+ - {date}: Confirmed reproduction on {server}
196
+
197
+ ### Recommended Fix
198
+ - **Short-term**: {workaround}
199
+ - **Long-term**: {proper fix}
200
+ ```
201
+
202
+ ## Key Patterns to Recognize
203
+
204
+ | Pattern | Diagnosis |
205
+ |---------|-----------|
206
+ | Same SHA passes on runner A, fails on runner B | Hardware/runner-specific |
207
+ | All runners fail after commit X | Code regression from commit X |
208
+ | Intermittent - same runner sometimes passes/fails | Flaky test or race condition |
209
+ | Prefill OK but decode fails | TP/all-reduce issue in decode path |
210
+ | Works with TP=1, fails with TP>1 | Tensor parallelism bug |
211
+ | Exact same numeric diff every time | Deterministic bug, not flakiness |
212
+
213
+ ## Important Notes
214
+
215
+ - **Always check runner identity** before concluding it's a code regression. Many "consistent" failures are actually runner-specific.
216
+ - **Test partition assignments change over time** as tests are added/removed. A test may move between partitions, landing on different runner types.
217
+ - **H200 runners** use `/root/actions-runner/` path and machine names like `gpu-h200-worker-*`. Non-H200 runners use `/public_sglang_ci/runner-*` paths.
218
+ - When running remote reproduction, use `run_in_background` for long-running tests and check output with `TaskOutput`.
219
+ - Container environments may be stale - always verify package versions match CI before drawing conclusions.
sglang/.claude/skills/write-sglang-test/SKILL.md ADDED
@@ -0,0 +1,248 @@
1
+ ---
2
+ name: write-sglang-test
3
+ description: Guide for writing SGLang CI/UT tests following project conventions. Covers CustomTestCase, CI registration, server fixtures, model selection, and test placement. Use when creating new tests, adding CI test cases, writing unit tests, or when the user asks to add tests for SGLang features.
4
+ ---
5
+
6
+ # Writing SGLang CI / UT Tests
7
+
8
+ ## Core Rules
9
+
10
+ 1. **Always use `CustomTestCase`** — never raw `unittest.TestCase`
11
+ 2. **Place tests in `test/registered/<category>/`** — only use `test/manual/` for debugging / non-CI tests
12
+ 3. **Reuse server fixtures** — inherit from `DefaultServerBase` or write `setUpClass`/`tearDownClass` with `popen_launch_server`
13
+ 4. **Smallest model for model-agnostic functionality** — use `DEFAULT_SMALL_MODEL_NAME_FOR_TEST` (Llama-3.2-1B-Instruct) for basic features that don't depend on model size
14
+ 5. **8B for general performance** — use `DEFAULT_MODEL_NAME_FOR_TEST` (Llama-3.1-8B-Instruct, single-node) for performance tests that don't involve spec / DP / parallelism
15
+ 6. **Bigger features → discuss case by case** — spec, DP attention, tensor/pipeline parallelism etc. may need multi-GPU suites and specific models
16
+
17
+ ---
18
+
19
+ ## Test File Template
20
+
21
+ ### Functional correctness test (small model)
22
+
23
+ ```python
24
+ import unittest
+
+ import requests
+
+ from sglang.srt.utils import kill_process_tree
+ from sglang.test.ci.ci_register import register_cuda_ci
+ from sglang.test.test_utils import (
+     DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+     DEFAULT_URL_FOR_TEST,
+     CustomTestCase,
+     popen_launch_server,
+ )
+
+ register_cuda_ci(est_time=60, suite="stage-b-test-small-1-gpu")
+
+
+ class TestMyFeature(CustomTestCase):
+     @classmethod
+     def setUpClass(cls):
+         cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+         cls.base_url = DEFAULT_URL_FOR_TEST
+         cls.process = popen_launch_server(
+             cls.model,
+             cls.base_url,
+             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+             other_args=["--arg1", "value1"],  # feature-specific args
+         )
+
+     @classmethod
+     def tearDownClass(cls):
+         kill_process_tree(cls.process.pid)
+
+     def test_basic_functionality(self):
+         response = requests.post(
+             self.base_url + "/generate",
+             json={"text": "Hello", "sampling_params": {"max_new_tokens": 32}},
+         )
+         self.assertEqual(response.status_code, 200)
+
+
+ if __name__ == "__main__":
+     unittest.main(verbosity=3)
67
+ ```
68
+
69
+ ### General performance test (8B model, single node, no spec/DP/parallelism)
70
+
71
+ ```python
72
+ import time
+ import unittest
+
+ import requests
+
+ from sglang.srt.utils import kill_process_tree
+ from sglang.test.ci.ci_register import register_cuda_ci
+ from sglang.test.test_utils import (
+     DEFAULT_MODEL_NAME_FOR_TEST,
+     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+     DEFAULT_URL_FOR_TEST,
+     CustomTestCase,
+     popen_launch_server,
+ )
+
+ register_cuda_ci(est_time=300, suite="stage-b-test-large-1-gpu")
+
+
+ class TestMyFeaturePerf(CustomTestCase):
+     @classmethod
+     def setUpClass(cls):
+         cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+         cls.base_url = DEFAULT_URL_FOR_TEST
+         cls.process = popen_launch_server(
+             cls.model,
+             cls.base_url,
+             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+         )
+
+     @classmethod
+     def tearDownClass(cls):
+         kill_process_tree(cls.process.pid)
+
+     def test_latency(self):
+         start = time.perf_counter()
+         response = requests.post(
+             self.base_url + "/generate",
+             json={"text": "Hello", "sampling_params": {"max_new_tokens": 128}},
+         )
+         elapsed = time.perf_counter() - start
+         self.assertEqual(response.status_code, 200)
+         self.assertLess(elapsed, 5.0, "Latency exceeded threshold")
+
+
+ if __name__ == "__main__":
+     unittest.main(verbosity=3)
118
+ ```
119
+
120
+ ---
121
+
122
+ ## Server Fixture Reuse
123
+
124
+ For tests that only need a standard server, inherit from `DefaultServerBase` and override class attributes:
125
+
126
+ ```python
127
+ from sglang.test.server_fixtures.default_fixture import DefaultServerBase
+ from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+
+
+ class TestMyFeature(DefaultServerBase):
+     model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+     other_args = ["--enable-my-feature"]
+
+     def test_something(self):
+         ...
135
+ ```
136
+
137
+ Available fixtures in `python/sglang/test/server_fixtures/`:
138
+
139
+ | Fixture | Use case |
140
+ |---------|----------|
141
+ | `DefaultServerBase` | Standard single-server tests |
142
+ | `EagleServerBase` | EAGLE speculative decoding |
143
+ | `PDDisaggregationServerBase` | Disaggregated prefill/decode |
144
+ | `MMMUServerBase` | Multimodal VLM tests |
145
+
146
+ ---
147
+
148
+ ## CI Registration
149
+
150
+ Every test file in `test/registered/` **must** call a registration function at module level:
151
+
152
+ ```python
153
+ from sglang.test.ci.ci_register import register_cuda_ci, register_amd_ci
154
+
155
+ register_cuda_ci(est_time=60, suite="stage-b-test-small-1-gpu")
156
+ register_amd_ci(est_time=60, suite="stage-b-test-small-1-gpu-amd") # optional
157
+ ```
158
+
159
+ Parameters:
160
+ - `est_time`: estimated runtime in seconds (used for CI partitioning)
161
+ - `suite`: which CI suite to run in (see below)
162
+ - `nightly=True`: for nightly-only tests (default `False` = per-commit)
163
+ - `disabled="reason"`: temporarily disable with explanation
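
To see why a realistic `est_time` matters, here is a sketch of how runtime estimates could drive partitioning. It is a generic greedy "longest first" split and is only an assumption about the general idea, not the algorithm `run_suite.py` actually uses; the function name, test names, and timings are all made up.

```python
import heapq


def partition_by_est_time(tests, num_partitions):
    """Greedy longest-first split of (name, est_time) pairs into partitions."""
    heap = [(0, i) for i in range(num_partitions)]  # (total seconds, index)
    heapq.heapify(heap)
    partitions = [[] for _ in range(num_partitions)]
    for name, est_time in sorted(tests, key=lambda t: -t[1]):
        total, idx = heapq.heappop(heap)  # currently lightest partition
        partitions[idx].append(name)
        heapq.heappush(heap, (total + est_time, idx))
    return partitions


# Made-up test names and estimates, in seconds.
tests = [("test_a.py", 300), ("test_b.py", 60), ("test_c.py", 120), ("test_d.py", 200)]
for i, names in enumerate(partition_by_est_time(tests, 2)):
    print(i, names)
```

Under any scheme of this kind, a badly underestimated `est_time` lets one partition run much longer than the others, which is the reason to measure locally before registering.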
164
+
165
+ ### Suite selection guide
166
+
167
+ **Default cases (1 GPU):**
168
+
169
+ | Scenario | Model | Suite |
170
+ |----------|-------|-------|
171
+ | Model-agnostic basic functionality | 1B (smallest) | `stage-b-test-small-1-gpu` |
172
+ | General performance (no spec/DP/parallelism) | 8B | `stage-b-test-large-1-gpu` |
173
+
174
+ **Bigger features (case by case):**
175
+
176
+ | Scenario | Suite |
177
+ |----------|-------|
178
+ | 2 GPU (e.g. TP=2) | `stage-b-test-large-2-gpu` |
179
+ | 4 GPU (H100) | `stage-c-test-4-gpu-h100` |
180
+ | 8 GPU (H200) | `stage-c-test-8-gpu-h200` |
181
+ | Nightly, 1 GPU | `nightly-1-gpu` |
182
+ | Nightly, 8 GPU | `nightly-8-gpu` |
183
+
184
+ For spec, DP attention, parallelism, disaggregation, etc., discuss with the team to determine the appropriate suite and GPU configuration.
185
+
186
+ ---
187
+
188
+ ## Model Constants
189
+
190
+ All defined in `python/sglang/test/test_utils.py`:
191
+
192
+ | Constant | Model | When to use |
193
+ |----------|-------|-------------|
194
+ | `DEFAULT_SMALL_MODEL_NAME_FOR_TEST` | Llama-3.2-1B-Instruct | Model-agnostic basic functionality |
195
+ | `DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE` | Llama-3.2-1B | Base (non-instruct) model tests |
196
+ | `DEFAULT_MODEL_NAME_FOR_TEST` | Llama-3.1-8B-Instruct | General performance (single node) |
197
+ | `DEFAULT_MOE_MODEL_NAME_FOR_TEST` | Mixtral-8x7B-Instruct | MoE-specific tests |
198
+ | `DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST` | — | Embedding tests |
199
+ | `DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST` | — | Vision-language tests |
200
+
201
+ ---
202
+
203
+ ## Test Placement
204
+
205
+ ```
206
+ test/
207
+ ├── registered/ # CI tests (auto-discovered by run_suite.py)
208
+ │ ├── sampling/ # test_penalty.py, test_sampling_params.py ...
209
+ │ ├── sessions/ # test_session_control.py ...
210
+ │ ├── openai_server/ # basic/, features/, validation/ ...
211
+ │ ├── spec/ # eagle/, utils/ ...
212
+ │ ├── models/ # model-specific accuracy tests
213
+ │ ├── perf/ # performance benchmarks
214
+ │ └── <category>/ # create new category if needed
215
+ ├── manual/ # Non-CI: debugging, one-off, manual verification
216
+ └── run_suite.py # CI runner (scans registered/ only)
217
+ ```
218
+
219
+ **Decision rule**: if the test should run in CI → `registered/`. If it's for local debugging or requires special hardware not in CI → `manual/`.
220
+
221
+ ---
222
+
223
+ ## Key Utilities
224
+
225
+ ```python
226
+ from sglang.test.test_utils import (
227
+ CustomTestCase, # base class with retry logic
228
+ popen_launch_server, # launch server subprocess
229
+ DEFAULT_URL_FOR_TEST, # auto-configured base URL
230
+ DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, # 600s default
231
+ run_bench_serving, # benchmark helper (launch + bench)
232
+ )
233
+ from sglang.srt.utils import kill_process_tree # cleanup server
234
+ ```
235
+
236
+ ---
237
+
238
+ ## Checklist
239
+
240
+ Before submitting a test:
241
+
242
+ - [ ] Inherits from `CustomTestCase` (not `unittest.TestCase`)
243
+ - [ ] Has `register_*_ci(...)` call at module level
244
+ - [ ] Placed in `test/registered/<category>/`
245
+ - [ ] Model selection: smallest for model-agnostic features, 8B for general perf, case-by-case for other complex features
246
+ - [ ] `setUpClass` launches server, `tearDownClass` kills it
247
+ - [ ] Has `if __name__ == "__main__": unittest.main(verbosity=3)`
248
+ - [ ] `est_time` is reasonable (measure locally)
sglang/benchmark/json_jump_forward/README.md ADDED
@@ -0,0 +1,88 @@
1
+ ## Run benchmark
2
+
3
+ ### Dependencies
4
+
5
+ ```
6
+ llama_cpp_python 0.2.38
7
+ guidance 0.1.10
8
+ vllm 0.2.7
9
+ outlines 0.0.25
10
+ ```
11
+
12
+ ### Build dataset
13
+
14
+ When benchmarking long document information retrieval, run the following command to build the dataset:
15
+
16
+ ```bash
17
+ pip install wikipedia
18
+ python3 build_dataset.py
19
+ ```
20
+
21
+ ### Benchmark sglang
22
+
23
+ Run Llama-7B
24
+
25
+ ```bash
26
+ python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
27
+ ```
28
+
29
+ Benchmark Character Generation
30
+
31
+ ```bash
32
+ python3 bench_sglang.py --mode character
33
+ ```
34
+
35
+ Benchmark City Information Retrieval
36
+
37
+ ```bash
38
+ python3 bench_sglang.py --mode city
39
+ ```
40
+
41
+
42
+ ### Benchmark Outlines + vLLM
43
+
44
+ Run Llama-7B
45
+
46
+ ```bash
47
+ python3 -m outlines.serve.serve --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
48
+ ```
49
+
50
+ Benchmark Character Generation
51
+
52
+ ```bash
53
+ python3 bench_other.py --mode character --backend outlines
54
+ ```
55
+
56
+ Benchmark City Information Retrieval
57
+
58
+ ```bash
59
+ python3 bench_other.py --mode city --backend outlines
60
+ ```
61
+
62
+ ### Benchmark guidance
63
+
64
+ Run Llama-7B and benchmark character generation
65
+
66
+ ```bash
67
+ python3 bench_other.py --mode character --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
68
+ ```
69
+
70
+ Run Llama-7B and benchmark city information retrieval
71
+
72
+ ```bash
73
+ python3 bench_other.py --mode city --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
74
+ ```
75
+
76
+ ### Benchmark lmql
77
+
78
+ Run Llama-7B and benchmark character generation
79
+
80
+ ```bash
81
+ python3 bench_other.py --mode character --backend lmql --parallel 1
82
+ ```
83
+
84
+ Run Llama-7B and benchmark city information retrieval
85
+
86
+ ```bash
87
+ python3 bench_other.py --mode city --backend lmql --parallel 1
88
+ ```
sglang/benchmark/json_jump_forward/bench_other.py ADDED
@@ -0,0 +1,288 @@
1
+ import argparse
2
+ import json
3
+ import time
4
+ from concurrent.futures import ThreadPoolExecutor
5
+ from functools import partial
6
+
7
+ import guidance
8
+ from tqdm import tqdm
9
+
10
+ from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
11
+ from sglang.utils import dump_state_text, read_jsonl
12
+
13
+ # The JSON regex converted from a pydantic model triggers some FSM bugs,
14
+ # so a hand-written regex string is used instead of:
15
+ # regex_string = build_regex_from_object(HarryPoterRole)
16
+ character_regex = (
17
+ r"""\{\n"""
18
+ + r""" "name": "[\w\d\s]{1,16}",\n"""
19
+ + r""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n"""
20
+ + r""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n"""
21
+ + r""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n"""
22
+ + r""" "wand": \{\n"""
23
+ + r""" "wood": "[\w\d\s]{1,16}",\n"""
24
+ + r""" "core": "[\w\d\s]{1,16}",\n"""
25
+ + r""" "length": [0-9]{1,2}\.[0-9]{0,2}\n"""
26
+ + r""" \},\n"""
27
+ + r""" "alive": "(Alive|Deceased)",\n"""
28
+ + r""" "patronus": "[\w\d\s]{1,16}",\n"""
29
+ + r""" "bogart": "[\w\d\s]{1,16}"\n"""
30
+ + r"""\}"""
31
+ )
32
+
33
+ city_regex = (
34
+ r"""\{\n"""
35
+ + r""" "name": "[\w\d\s]{1,16}",\n"""
36
+ + r""" "country": "[\w\d\s]{1,16}",\n"""
37
+ + r""" "latitude": [-+]?[0-9]*\.?[0-9]{0,2},\n"""
38
+ + r""" "population": [-+]?[0-9]{1,9},\n"""
39
+ + r""" "top 3 landmarks": \["[\w\d\s]{1,16}", "[\w\d\s]{1,16}", "[\w\d\s]{1,16}"\]\n"""
40
+ + r"""\}"""
41
+ )
42
+
43
+ # fmt: off
44
+ def character_gen(name, generate):
45
+ s = name + " is a character in Harry Potter. Please fill in the following information about this character.\n"
46
+ s += generate(s, max_tokens=256, regex=character_regex)
47
+ return s
48
+ # fmt: on
49
+
50
+ # fmt: off
51
+ def city_gen(document, generate):
52
+ s = "Please extract the information of a city from the following wikipedia page.\n"
53
+ s += "Page begin.\n" + document + "Page end.\n"
54
+ s += "Here is the name, country, and symbol of the city in JSON format.\n"
55
+ s += generate(s, max_tokens=256, regex=city_regex)
56
+ return s
57
+ # fmt: on
58
+
59
+
60
+ @guidance
61
+ def character_maker(lm, name):
62
+ regex_str_no_quote = r"[\w\d\s]+"
63
+ regex_float = r"[0-9]+\.[0-9]+"
64
+ lm += f"""\
65
+ {name} is a character in Harry Potter. Please fill in the following information about this character.
66
+ {{
67
+ "name": "{guidance.gen("name", max_tokens=16, regex=regex_str_no_quote)}",
68
+ "house": "{guidance.select(options=['Gryffindor', 'Slytherin', 'Ravenclaw', 'Hufflepuff'], name='house')}",
69
+ "blood status": "{guidance.select(options=['Pure-blood', 'Half-blood', 'Muggle-born'], name='blood status')}",
70
+ "occupation": "{guidance.select(options=['student', 'teacher', 'auror', 'ministry of magic', 'death eater', 'order of the phoenix'], name='occupation')}",
71
+ "wand": {{
72
+ "wood": "{guidance.gen("wood", max_tokens=16, regex=regex_str_no_quote)}",
73
+ "core": "{guidance.gen('core', max_tokens=16, regex=regex_str_no_quote)}",
74
+ "length": {guidance.gen('length', max_tokens=10, regex=regex_float)}
75
+ }},
76
+ "alive": "{guidance.select(options=['Alive', 'Deceased'], name='alive')}",
77
+ "patronus": "{guidance.gen('patronus', max_tokens=16, regex=regex_str_no_quote)}",
78
+ "bogart": "{guidance.gen('bogart', max_tokens=16, regex=regex_str_no_quote)}"
79
+ }}
80
+ """
81
+
82
+ return lm
83
+
84
+
85
+ async def call_generate_lmql(
86
+ prompt, temperature, max_tokens, regex, max_len=4096, model=None, **kwargs
87
+ ):
88
+ assert model is not None
89
+ import lmql
90
+
91
+ @lmql.query(model=model)
92
+ async def program(question, max_tokens, regex):
93
+ '''lmql
94
+ """{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens and REGEX(ANSWER, regex)
95
+ return ANSWER
96
+ '''
97
+
98
+ return await program(
99
+ question=prompt,
100
+ temperature=temperature,
101
+ max_tokens=max_tokens,
102
+ max_len=max_len,
103
+ regex=regex,
104
+ **kwargs,
105
+ )
106
+
107
+
108
+ @guidance
109
+ def city_maker(lm, document):
110
+ regex_str_no_quote = r"[\w\d\s]+"
111
+ regex_float = r"[0-9]+\.[0-9]+"
112
+ lm += f"""\
113
+ Please extract the information of a city from the following wikipedia page.
114
+ Page begin.
115
+ {document}
116
+ Page end.
117
+ Here is the name, country, and symbol of the city in JSON format.
118
+ {{
119
+ "name": "{guidance.gen("name", max_tokens=16, regex=regex_str_no_quote)}",
120
+ "country": "{guidance.gen("country", max_tokens=16, regex=regex_str_no_quote)}",
121
+ "latitude": {guidance.gen("latitude", max_tokens=10, regex=regex_float)},
122
+ "population": {guidance.gen("population", max_tokens=10, regex=r"[0-9]+")},
123
+ "top 3 landmarks": [
124
+ "{guidance.gen("landmark1", max_tokens=16, regex=regex_str_no_quote)}", "{guidance.gen("landmark2", max_tokens=16, regex=regex_str_no_quote)}", "{guidance.gen("landmark3", max_tokens=16, regex=regex_str_no_quote)}"
125
+ ]
126
+ }}
127
+ """
128
+
129
+ return lm
130
+
131
+
132
+ def bench_character(args):
133
+ arguments = []
134
+ with open(args.data_path, "r") as f:
135
+ for line in f:
136
+ arguments.append({"name": line.strip()})
137
+ arguments = arguments[: args.num_jsons]
138
+
139
+ states = [None] * len(arguments)
140
+
141
+ # Select backend
142
+ if args.backend == "outlines":
143
+ call_generate = partial(get_call_generate(args), temperature=0)
144
+
145
+ def get_one_answer(i):
146
+ states[i] = character_gen(**arguments[i], generate=call_generate)
147
+
148
+ elif args.backend == "guidance":
149
+ model = guidance.models.LlamaCpp(
150
+ args.model_path,
151
+ n_gpu_layers=-1,
152
+ n_ctx=args.n_ctx,
153
+ )
154
+
155
+ def get_one_answer(i):
156
+ lm = model + character_maker(**arguments[i])
157
+ states[i] = lm
158
+
159
+ elif args.backend == "lmql":
160
+ import asyncio
161
+
162
+ import lmql
163
+
164
+ model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}")
165
+ call_generate = partial(
166
+ call_generate_lmql,
167
+ model=model,
168
+ max_tokens=256,
169
+ regex=character_regex,
170
+ )
171
+
172
+ async def get_one_answer_async(i):
173
+ states[i] = await call_generate(prompt=arguments[i]["name"], temperature=0)
174
+
175
+ else:
176
+ raise ValueError(f"Invalid backend: {args.backend}")
177
+
178
+ tic = time.perf_counter()
179
+
180
+ if args.backend != "lmql":
181
+ if args.parallel == 1:
182
+ for i in tqdm(range(len(arguments))):
183
+ get_one_answer(i)
184
+ else:
185
+ with ThreadPoolExecutor(args.parallel) as executor:
186
+ rets = list(
187
+ tqdm(
188
+ executor.map(get_one_answer, list(range(len(arguments)))),
189
+ total=len(arguments),
190
+ )
191
+ )
192
+ for _ in rets:
193
+ pass
194
+ else:
195
+ batches = []
196
+ for i in range(0, len(arguments), args.parallel):
197
+ batches.append(list(range(i, min(i + args.parallel, len(arguments)))))
198
+ loop = asyncio.get_event_loop()
199
+
200
+ for bt in tqdm(batches):
201
+ loop.run_until_complete(
202
+ asyncio.gather(*[get_one_answer_async(i) for i in bt])
203
+ )
204
+
205
+ latency = time.perf_counter() - tic
206
+
207
+ return states, latency
208
+
209
+
210
+ def bench_city_doc(args):
211
+ arguments = []
212
+ for line in read_jsonl(args.data_path):
213
+ arguments.append({"document": line["document"]})
214
+ arguments = arguments[: args.num_jsons]
215
+
216
+ states = [None] * len(arguments)
217
+
218
+ # Select backend
219
+ if args.backend == "outlines":
220
+ call_generate = partial(get_call_generate(args), temperature=0)
221
+
222
+ def get_one_answer(i):
223
+ states[i] = city_gen(**arguments[i], generate=call_generate)
224
+
225
+ elif args.backend == "guidance":
226
+ model = guidance.models.LlamaCpp(
227
+ args.model_path,
228
+ n_gpu_layers=-1,
229
+ n_ctx=args.n_ctx,
230
+ )
231
+
232
+ def get_one_answer(i):
233
+ lm = model + city_maker(**arguments[i])
234
+ states[i] = lm
235
+
236
+ else:
237
+ raise ValueError(f"Invalid backend: {args.backend}")
238
+
239
+ tic = time.perf_counter()
240
+ if args.parallel == 1:
241
+ for i in tqdm(range(len(arguments))):
242
+ get_one_answer(i)
243
+ else:
244
+ with ThreadPoolExecutor(args.parallel) as executor:
245
+ rets = executor.map(get_one_answer, list(range(len(arguments))))
246
+ for _ in rets:
247
+ pass
248
+
249
+ latency = time.perf_counter() - tic
250
+
251
+ return states, latency
252
+
253
+
254
+ def main(args):
255
+ if args.mode == "character":
256
+ args.data_path = "dataset.txt"
257
+ states, latency = bench_character(args)
258
+ elif args.mode == "city":
259
+ args.data_path = "questions.jsonl"
260
+ states, latency = bench_city_doc(args)
261
+
262
+ # Report latency (this benchmark does not compute accuracy)
263
+ print(f"Latency: {latency:.3f}")
264
+
265
+ # Write results
266
+ dump_state_text(f"tmp_output_{args.backend}_{args.mode}.txt", states)
267
+
268
+ with open(args.result_file, "a") as fout:
269
+ value = {
270
+ "task": "json_jump_forward",
271
+ "backend": args.backend,
272
+ "latency": round(latency, 3),
273
+ "num_jsons": args.num_jsons,
274
+ "mode": args.mode,
275
+ "parallel": args.parallel,
276
+ }
277
+ fout.write(json.dumps(value) + "\n")
278
+
279
+
280
+ if __name__ == "__main__":
281
+ parser = argparse.ArgumentParser()
282
+ parser.add_argument("--data-path", type=str)
283
+ parser.add_argument("--num-jsons", type=int, default=50)
284
+ parser.add_argument(
285
+ "--mode", type=str, default="character", choices=["character", "city"]
286
+ )
287
+ args = add_common_other_args_and_parse(parser)
288
+ main(args)
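The benchmark above constrains generation with a regular expression. As a rough illustration of what such a constraint accepts, the sketch below checks candidate strings against a simplified city pattern using Python's standard `re` module. `CITY_PATTERN`, `is_valid_city_json`, and the sample strings are hypothetical names and data introduced for this example; the whitespace layout of the pattern is chosen for the example and is not claimed to match `city_regex` in the script.

```python
import json
import re

# Hypothetical, simplified pattern in the spirit of city_regex (an assumption,
# not the regex used by the benchmark scripts).
CITY_PATTERN = re.compile(
    r"\{\n"
    r'  "name": "[\w\d\s]{1,16}",\n'
    r'  "country": "[\w\d\s]{1,16}",\n'
    r'  "population": [0-9]{1,9}\n'
    r"\}"
)


def is_valid_city_json(text: str) -> bool:
    """Return True only if the whole string matches the pattern."""
    return CITY_PATTERN.fullmatch(text) is not None


# Illustrative inputs: one that follows the layout exactly, one that does not.
good = '{\n  "name": "Tokyo",\n  "country": "Japan",\n  "population": 13960000\n}'
bad = '{"name": "Tokyo"}'

print(is_valid_city_json(good))
print(is_valid_city_json(bad))
# A string accepted by the pattern is also parseable as JSON here.
print(json.loads(good)["country"])
```

Because the pattern fixes the key order and punctuation, any accepted output can be parsed directly with `json.loads` without a repair step.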
sglang/benchmark/json_jump_forward/bench_sglang.py ADDED
@@ -0,0 +1,143 @@
1
+ import argparse
2
+ import json
3
+ import time
4
+
5
+ import sglang as sgl
6
+ from sglang.test.test_utils import (
7
+ add_common_sglang_args_and_parse,
8
+ select_sglang_backend,
9
+ )
10
+ from sglang.utils import dump_state_text, read_jsonl
11
+
12
+ # The JSON regex converted from a pydantic model triggers some FSM bugs,
13
+ # so a hand-written regex string is used instead of:
14
+ # regex_string = build_regex_from_object(HarryPoterRole)
15
+ character_regex = (
16
+ r"""\{\n"""
17
+ + r""" "name": "[\w\d\s]{1,16}",\n"""
18
+ + r""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n"""
19
+ + r""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n"""
20
+ + r""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n"""
21
+ + r""" "wand": \{\n"""
22
+ + r""" "wood": "[\w\d\s]{1,16}",\n"""
23
+ + r""" "core": "[\w\d\s]{1,16}",\n"""
24
+ + r""" "length": [0-9]{1,2}\.[0-9]{0,2}\n"""
25
+ + r""" \},\n"""
26
+ + r""" "alive": "(Alive|Deceased)",\n"""
27
+ + r""" "patronus": "[\w\d\s]{1,16}",\n"""
28
+ + r""" "bogart": "[\w\d\s]{1,16}"\n"""
29
+ + r"""\}"""
30
+ )
31
+
32
+ city_regex = (
33
+ r"""\{\n"""
34
+ + r""" "name": "[\w\d\s]{1,16}",\n"""
35
+ + r""" "country": "[\w\d\s]{1,16}",\n"""
36
+ + r""" "latitude": [-+]?[0-9]*\.?[0-9]{0,2},\n"""
37
+ + r""" "population": [-+]?[0-9]{1,9},\n"""
38
+ + r""" "top 3 landmarks": \["[\w\d\s]{1,16}", "[\w\d\s]{1,16}", "[\w\d\s]{1,16}"\]\n"""
39
+ + r"""\}"""
40
+ )
41
+
42
+ # fmt: off
43
+ @sgl.function
44
+ def character_gen(s, name):
45
+ s += name + " is a character in Harry Potter. Please fill in the following information about this character.\n"
46
+ s += sgl.gen("json_output", max_tokens=256, regex=character_regex)
47
+ # fmt: on
48
+
49
+ # fmt: off
50
+ @sgl.function
51
+ def city_gen(s, document):
52
+ s += "Please extract the information of a city from the following wikipedia page.\n"
53
+ s += "Page begin.\n" + document + "Page end.\n"
54
+ s += "Here is the name, country, and symbol of the city in JSON format.\n"
55
+ s += sgl.gen("json_output",max_tokens=256, regex=city_regex)
56
+ # fmt: on
57
+
58
+
59
+ def bench_city_doc(args):
60
+ arguments = []
61
+ for line in read_jsonl(args.data_path):
62
+ arguments.append({"document": line["document"]})
63
+ arguments = arguments[: args.num_jsons]
64
+
65
+ # Select backend
66
+ backend = select_sglang_backend(args)
67
+ sgl.set_default_backend(backend)
68
+
69
+ # Run requests
70
+ tic = time.perf_counter()
71
+ states = city_gen.run_batch(
72
+ arguments,
73
+ temperature=0,
74
+ num_threads=args.parallel,
75
+ progress_bar=True,
76
+ )
77
+ latency = time.perf_counter() - tic
78
+
79
+ return states, latency
80
+
81
+
82
+ def bench_character(args):
83
+ arguments = []
84
+ with open(args.data_path, "r") as f:
85
+ for line in f:
86
+ arguments.append({"name": line.strip()})
87
+ arguments = arguments[: args.num_jsons]
88
+
89
+ # Select backend
90
+ backend = select_sglang_backend(args)
91
+ sgl.set_default_backend(backend)
92
+
93
+ # Run requests
94
+ tic = time.perf_counter()
95
+ states = character_gen.run_batch(
96
+ arguments,
97
+ temperature=0,
98
+ num_threads=args.parallel,
99
+ progress_bar=True,
100
+ )
101
+ latency = time.perf_counter() - tic
102
+
103
+ return states, latency
104
+
105
+
106
+ def main(args):
107
+ if args.mode == "character":
108
+ args.data_path = "dataset.txt"
109
+ states, latency = bench_character(args)
110
+ elif args.mode == "city":
111
+ args.data_path = "questions.jsonl"
112
+ states, latency = bench_city_doc(args)
113
+
114
+ # Report latency (this benchmark does not compute accuracy)
115
+ print(f"Latency: {latency:.3f}")
116
+
117
+ # Write results
118
+ dump_state_text(f"tmp_output_{args.backend}_{args.mode}.txt", states)
119
+ with open(f"{args.backend}_{args.mode}.json", "w") as fout:
120
+ for state in states:
121
+ fout.write(state["json_output"] + "\n")
122
+
123
+ with open(args.result_file, "a") as fout:
124
+ value = {
125
+ "task": "json_jump_forward",
126
+ "backend": args.backend,
127
+ "latency": round(latency, 3),
128
+ "num_jsons": args.num_jsons,
129
+ "mode": args.mode,
130
+ "parallel": args.parallel,
131
+ }
132
+ fout.write(json.dumps(value) + "\n")
133
+
134
+
135
+ if __name__ == "__main__":
136
+ parser = argparse.ArgumentParser()
137
+ parser.add_argument("--data-path", type=str)
138
+ parser.add_argument("--num-jsons", type=int, default=50)
139
+ parser.add_argument(
140
+ "--mode", type=str, default="character", choices=["character", "city"]
141
+ )
142
+ args = add_common_sglang_args_and_parse(parser)
143
+ main(args)
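Both json_jump_forward scripts append one JSON record per run to `args.result_file`, with the fields `task`, `backend`, `latency`, `num_jsons`, `mode`, and `parallel`. A minimal sketch of comparing runs from such a file follows. The helper `summarize`, the backend names, and the latency values are placeholders made up for illustration; they are not measurements.

```python
import json
from collections import defaultdict


def summarize(lines):
    """Return {(backend, mode): lowest latency} from JSONL result lines."""
    best = defaultdict(lambda: float("inf"))
    for line in lines:
        line = line.strip()
        if not line:
            continue  # tolerate blank lines in the result file
        rec = json.loads(line)
        key = (rec["backend"], rec["mode"])
        best[key] = min(best[key], rec["latency"])
    return dict(best)


def make_record(backend, latency):
    # Same field names as the records written by the benchmark scripts.
    return json.dumps(
        {
            "task": "json_jump_forward",
            "backend": backend,
            "latency": latency,
            "num_jsons": 50,
            "mode": "city",
            "parallel": 1,
        }
    )


# Placeholder records with made-up latencies (not real results).
sample = [
    make_record("backend_a", 2.0),
    make_record("backend_a", 1.5),
    make_record("backend_b", 3.0),
]

for (backend, mode), latency in sorted(summarize(sample).items()):
    print(f"{backend:10s} {mode:10s} {latency:.3f}")
```

In practice the lines would come from `open(result_file)` rather than an in-memory list; keeping the lowest latency per key is one possible choice, and a mean or median would work the same way.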
sglang/benchmark/json_jump_forward/build_dataset.py ADDED
@@ -0,0 +1,58 @@
+ import json
+
+ import transformers
+ import wikipedia
+
+ model_path = "meta-llama/Llama-2-7b-chat-hf"
+ t = transformers.AutoTokenizer.from_pretrained(model_path)
+ city_names = [
+     "los angeles",
+     "london",
+     "tokyo",
+     "beijing",
+     "singapore",
+     "paris",
+     "dubai",
+     "sydney",
+     "moscow",
+     "rome",
+     "toronto",
+     "rio de janeiro",
+     "istanbul",
+     "berlin",
+     "auckland",
+     "buenos aires",
+     "mexico city",
+     "mumbai",
+     "seoul",
+     "bangkok",
+     "cairo",
+     "athens",
+     "jerusalem",
+ ]
+
+
+ def get_content(city_name):
+     content = str(wikipedia.page(city_name).content)
+     content = content.replace("\n\n", "\n")
+
+     tokens = t.encode(content)
+
+     # Truncate by characters, scaling by the average characters per token,
+     # so the result is only approximately expected_tokens tokens long.
+     expected_tokens = 3000
+     truncate_len = int((expected_tokens / len(tokens)) * len(content))
+     truncate_content = content[:truncate_len]
+     truncate_tokens = t.encode(truncate_content)
+
+     # Report token counts before and after truncation
+     print(
+         f"city_name: {city_name}, #tokens: {len(tokens)}, #truncate tokens: {len(truncate_tokens)}"
+     )
+
+     return truncate_content
+
+
+ if __name__ == "__main__":
+     with open("questions.jsonl", "w") as fout:
+         for city_name in city_names:
+             truncate_content = get_content(city_name)
+             fout.write(json.dumps({"document": truncate_content}) + "\n")
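`get_content` truncates each page by characters, scaling the cut point by the ratio of the token budget to the page's token count. The sketch below isolates that estimate so it can be run without downloading a tokenizer. `truncate_to_token_budget` and `whitespace_encode` are hypothetical names introduced here, and the whitespace tokenizer is a toy stand-in for the Llama tokenizer, so the counts it produces say nothing about real token counts.

```python
def truncate_to_token_budget(content, encode, expected_tokens):
    """Cut content to roughly expected_tokens tokens via a character-ratio estimate."""
    tokens = encode(content)
    if len(tokens) <= expected_tokens:
        # Already within budget; the original script slices past the end of
        # the string in this case, which also returns the content unchanged.
        return content
    # Assume characters are spread evenly across tokens.
    truncate_len = int((expected_tokens / len(tokens)) * len(content))
    return content[:truncate_len]


def whitespace_encode(text):
    # Toy tokenizer (an assumption): one token per whitespace-separated word.
    return text.split()


doc = " ".join(["word"] * 100)
short = truncate_to_token_budget(doc, whitespace_encode, 25)

print(len(whitespace_encode(doc)), len(whitespace_encode(short)))
```

With a real subword tokenizer the characters-per-token ratio varies across the text, so the truncated document lands near the budget rather than exactly on it, which is why the script prints both counts.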
sglang/benchmark/json_jump_forward/dataset.txt ADDED
@@ -0,0 +1,50 @@
+ Harry Potter
+ Hermione Granger
+ Ron Weasley
+ Albus Dumbledore
+ Severus Snape
+ Rubeus Hagrid
+ Draco Malfoy
+ Ginny Weasley
+ Fred Weasley
+ George Weasley
+ Percy Weasley
+ Sirius Black
+ Remus Lupin
+ Neville Longbottom
+ Luna Lovegood
+ Cedric Diggory
+ Cho Chang
+ Lord Voldemort
+ Minerva McGonagall
+ Filius Flitwick
+ Dolores Umbridge
+ Bellatrix Lestrange
+ Lucius Malfoy
+ Molly Weasley
+ Arthur Weasley
+ Nymphadora Tonks
+ Dobby
+ Moaning Myrtle
+ Peter Pettigrew
+ Alastor 'Mad-Eye' Moody
+ Horace Slughorn
+ Vernon Dursley
+ Petunia Dursley
+ Dudley Dursley
+ Argus Filch
+ Sybill Trelawney
+ Gilderoy Lockhart
+ Fleur Delacour
+ Viktor Krum
+ Bill Weasley
+ Oliver Wood
+ Cornelius Fudge
+ Barty Crouch Sr.
+ Barty Crouch Jr.
+ Kingsley Shacklebolt
+ Quirinus Quirrell
+ Nearly Headless Nick
+ Aunt Marge
+ Griphook
+ Ludo Bagman
sglang/benchmark/multi_turn_chat/bench_other.py ADDED
@@ -0,0 +1,93 @@
+ import json
+ import time
+ from argparse import ArgumentParser
+ from concurrent.futures import ThreadPoolExecutor
+ from functools import partial
+
+ from data_gen import gen_arguments
+ from tqdm import tqdm
+ from vllm.transformers_utils.tokenizer import get_tokenizer
+
+ from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
+ from sglang.utils import dump_state_text
+
+
+ def multi_turns(generate, qas):
+     s = ""
+     for qa in qas:
+         s += qa["prompt"]
+         s += generate(s, max_tokens=qa["new_tokens"])
+
+     return s
+
+
+ def main(args):
+     print(args)
+
+     tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
+
+     multi_qas = gen_arguments(args, tokenizer)
+
+     states = [None] * args.num_qa
+
+     call_generate = partial(get_call_generate(args), temperature=0)
+
+     def get_one_answer(i):
+         states[i] = multi_turns(generate=call_generate, **multi_qas[i])
+
+     tic = time.perf_counter()
+     if args.parallel == 1:
+         for i in tqdm(range(len(multi_qas))):
+             get_one_answer(i)
+     else:
+         with ThreadPoolExecutor(args.parallel) as executor:
+             rets = list(
+                 tqdm(
+                     executor.map(get_one_answer, list(range(len(multi_qas)))),
+                     total=len(multi_qas),
+                 )
+             )
+             for _ in rets:
+                 pass
+
+     latency = time.perf_counter() - tic
+
+     # Report latency (this benchmark does not compute accuracy)
+     print(f"Latency: {latency:.3f}")
+
+     dump_state_text(f"tmp_output_{args.backend}.txt", states)
+
+     with open(args.result_file, "a") as fout:
+         value = {
+             "task": "multi_turn_chat",
+             "backend": args.backend,
+             "num_gpus": 1,
+             "latency": round(latency, 3),
+             "num_requests": args.num_qa,
+             "num_turns": args.turns,
+             "other": {
+                 "parallel": args.parallel,
+                 "output_mode": "long" if args.long else "short",
+             },
+         }
+         fout.write(json.dumps(value) + "\n")
+
+
+ if __name__ == "__main__":
+     parser = ArgumentParser()
+     parser.add_argument("--turns", type=int, default=4)
+     parser.add_argument("--num-qa", type=int, default=20)
+     parser.add_argument("--min-len-q", type=int, default=256)
+     parser.add_argument("--max-len-q", type=int, default=512)
+     parser.add_argument("--min-len-a", type=int, default=4)
+     parser.add_argument("--max-len-a", type=int, default=8)
+     parser.add_argument("--tokenizer", type=str, required=True)
+     parser.add_argument("--trust-remote-code", action="store_true")
+     parser.add_argument("--long", action="store_true")
+     args = add_common_other_args_and_parse(parser)
+
+     if args.long:
+         args.min_len_a = 256
+         args.max_len_a = 512
+         args.num_qa = 20
+     main(args)
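The `multi_turns` loop in this script sends the whole accumulated conversation to the backend on every turn. The sketch below reproduces that loop with a stub in place of a real backend so the growth of the prompt is visible. `echo_generate` and the two-turn `qas` list are hypothetical stand-ins introduced for this example; a real `generate` would return model output.

```python
def multi_turns(generate, qas):
    """Run a multi-turn exchange, feeding the full history back each turn."""
    s = ""
    for qa in qas:
        s += qa["prompt"]
        s += generate(s, max_tokens=qa["new_tokens"])
    return s


def echo_generate(prompt, max_tokens):
    # Hypothetical stub: reports the prompt length it received and the budget,
    # instead of calling a model.
    return f"<{len(prompt)}:{max_tokens}>"


qas = [
    {"prompt": "Q1 ", "new_tokens": 4},
    {"prompt": "Q2 ", "new_tokens": 8},
]
transcript = multi_turns(echo_generate, qas)
print(transcript)
```

The second call receives the first question and its answer as a prefix, so a backend that can reuse work for a shared prefix has less to recompute on later turns.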
sglang/benchmark/multi_turn_chat/data_gen.py ADDED
@@ -0,0 +1,29 @@
+ import random
+ import string
+
+ random.seed(42)
+
+
+ def gen_prompt(tokenizer, token_num):
+     cha_set = string.ascii_letters + string.digits
+     ret = "".join(random.choices(cha_set, k=token_num))
+     while len(tokenizer(ret).input_ids) < token_num:
+         ret += random.choice(cha_set)
+     return ret
+
+
+ def gen_arguments(args, tokenizer):
+     multi_qas = [{"qas": []} for _ in range(args.num_qa)]
+     for i in range(args.num_qa):
+         qas = multi_qas[i]["qas"]
+         for _ in range(args.turns):
+             prompt_len = random.randint(args.min_len_q, args.max_len_q)
+             new_tokens = random.randint(args.min_len_a, args.max_len_a)
+             qas.append(
+                 {
+                     "prompt": gen_prompt(tokenizer, prompt_len),
+                     "new_tokens": new_tokens,
+                 }
+             )
+
+     return multi_qas
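`gen_prompt` starts from `token_num` random characters and keeps appending one character at a time until the tokenizer reports at least `token_num` tokens. A minimal sketch of that loop with a toy tokenizer follows, so it runs without vLLM or a model download. `toy_tokenizer` (one token per four characters) and the explicit `rng` parameter are assumptions made for this example; real tokenizers split text differently.

```python
import random
import string
from types import SimpleNamespace


def toy_tokenizer(text):
    # Assumption for illustration: every 4 characters form one token.
    n = (len(text) + 3) // 4
    return SimpleNamespace(input_ids=list(range(n)))


def gen_prompt(tokenizer, token_num, rng):
    """Grow a random alphanumeric string until it has at least token_num tokens."""
    cha_set = string.ascii_letters + string.digits
    ret = "".join(rng.choices(cha_set, k=token_num))
    while len(tokenizer(ret).input_ids) < token_num:
        ret += rng.choice(cha_set)
    return ret


rng = random.Random(42)  # local generator instead of the module-level seed
prompt = gen_prompt(toy_tokenizer, 16, rng)
print(len(prompt), len(toy_tokenizer(prompt).input_ids))
```

Since the loop re-tokenizes the whole string after each appended character, its cost grows with the prompt length; that is acceptable for the few hundred tokens used by the benchmark defaults.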
sglang/benchmark/tree_of_thought_deep/README.md ADDED
@@ -0,0 +1,51 @@
+ ## Download data
+ ```
+ wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
+ ```
+
+ ## Run benchmark
+
+ NOTE: This is an implementation for throughput/latency benchmark purposes. The prompts are not tuned to achieve good accuracy on the GSM-8K tasks.
+
+ ### Benchmark sglang
+ ```
+ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
+ ```
+
+ ```
+ python3 bench_sglang.py --num-questions 32
+ python3 bench_sglang.py --num-questions 16 --parallel 1
+ ```
+
+
+ ### Benchmark vllm
+ ```
+ python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
+ ```
+
+ ```
+ python3 bench_other.py --num-questions 32 --backend vllm
+ ```
+
+
+ ### Benchmark lightllm
+ ```
+ # A10G
+ python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000
+ ```
+
+ ```
+ python3 bench_other.py --num-questions 32 --backend lightllm
+ ```
+
+
+ ### Benchmark guidance
+ ```
+ python3 bench_other.py --num-questions 8 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
+ ```
+
+ ### Benchmark lmql
+
+ ```
+ python3 bench_other.py --num-questions 8 --backend lmql --parallel 1
+ ```
sglang/benchmark/tree_of_thought_deep/bench_other.py ADDED
@@ -0,0 +1,222 @@
1
+ import argparse
2
+ import ast
3
+ import json
4
+ import re
5
+ import time
6
+ from collections import Counter
7
+ from concurrent.futures import ThreadPoolExecutor
8
+
9
+ import numpy as np
10
+ from tqdm import tqdm
11
+
12
+ from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
13
+ from sglang.utils import dump_state_text, read_jsonl
14
+
15
+ INVALID = -9999999
16
+
17
+
18
+ def get_answer_value(answer_str):
19
+ answer_str = answer_str.replace(",", "")
20
+ numbers = re.findall(r"\d+", answer_str)
21
+ if len(numbers) < 1:
22
+ return INVALID
23
+ try:
24
+ return ast.literal_eval(numbers[-1])
25
+ except SyntaxError:
26
+ return INVALID
27
+
28
+
29
+ def most_frequent_number(numbers):
30
+ if not numbers:
31
+ return None
32
+
33
+ frequency = Counter(numbers)
34
+ most_frequent = max(frequency, key=frequency.get)
35
+ return most_frequent
36
+
37
+
38
+ USER_PREFIX = "[INST] "
39
+ USER_SUFFIX = " [/INST]"
40
+ ASSISTANT_PREFIX = ""
41
+ ASSISTANT_SUFFIX = " </s><s>"
42
+
43
+ # Use a low temperature to make the results more deterministic and the comparison fairer.
44
+ temp = 0.001
45
+
46
+
47
+ def propose_plan(s, question, num_branches, call_generate):
48
+ s += (
49
+ USER_PREFIX
50
+ + """Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """
51
+ + question
52
+ + USER_SUFFIX
53
+ )
54
+
55
+ s += ASSISTANT_PREFIX
56
+ comps = call_generate(
57
+ s, max_tokens=256, temperature=temp, stop=None, n=num_branches
58
+ )
59
+ return [s + comp + ASSISTANT_SUFFIX for comp in comps]
60
+
61
+
62
+ def execute_plan(s, num_branches, call_generate):
63
+ s += (
64
+ USER_PREFIX
65
+ + """The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short."""
66
+ + USER_SUFFIX
67
+ )
68
+ s += ASSISTANT_PREFIX
69
+ comps = call_generate(
70
+ s, max_tokens=256, temperature=temp, stop=None, n=num_branches
71
+ )
72
+ return [s + comp + ASSISTANT_SUFFIX for comp in comps]
73
+
74
+
75
+ def reflect_solution(s, num_branches, call_generate):
76
+ s += (
77
+ USER_PREFIX
78
+ + """Okay. Now, evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness."""
79
+ + USER_SUFFIX
80
+ )
81
+ s += ASSISTANT_PREFIX
82
+ comps = call_generate(
83
+ s, max_tokens=256, temperature=temp, stop=None, n=num_branches
84
+ )
85
+ return [s + comp + ASSISTANT_SUFFIX for comp in comps]
86
+
87
+
88
+ def get_final_answer(s, num_branches, call_generate):
89
+ s += (
90
+ USER_PREFIX
91
+ + """Based on your reflection, do you change your mind? Now, give me the final answer after careful consideration."""
92
+ + USER_SUFFIX
93
+ )
94
+ s += ASSISTANT_PREFIX
95
+ comps = call_generate(
96
+ s, max_tokens=256, temperature=temp, stop=None, n=num_branches
97
+ )
98
+ return [s + comp + ASSISTANT_SUFFIX for comp in comps]
99
+
100
+
101
+ def tree_search(question, num_branches, call_generate):
102
+ plan_forks = propose_plan("", question, num_branches, call_generate)
103
+
104
+ sol_states = []
105
+ for plan in plan_forks:
106
+ forks = execute_plan(plan, num_branches, call_generate)
107
+ sol_states.extend(forks)
108
+
109
+ ref_states = []
110
+ for sol in sol_states:
111
+ forks = reflect_solution(sol, num_branches, call_generate)
112
+ ref_states.extend(forks)
113
+
114
+ solutions = []
115
+ for sol in ref_states:
116
+ ans = get_final_answer(sol, num_branches, call_generate)
117
+ solutions.append(ans)
118
+
119
+ return solutions
120
+
121
+
122
+ def main(args):
123
+ lines = list(read_jsonl(args.data_path))
124
+
125
+ # Construct prompts
126
+ num_branches = 2
127
+ questions = []
128
+ labels = []
129
+ for i in range(len(lines[: args.num_questions])):
130
+ questions.append(lines[i]["question"])
131
+ labels.append(get_answer_value(lines[i]["answer"]))
132
+ assert all(l != INVALID for l in labels)
133
+ arguments = [{"question": q, "num_branches": num_branches} for q in questions]
134
+
135
+ # Select backend
136
+ call_generate = get_call_generate(args)
137
+
138
+ # Run requests
139
+ states = [None] * len(questions)
140
+
141
+ tic = time.perf_counter()
142
+ if args.backend != "lmql":
143
+
144
+ def get_one_answer(i):
145
+ states[i] = tree_search(**arguments[i], call_generate=call_generate)
146
+
147
+ if args.parallel == 1:
148
+ for i in tqdm(range(len(questions))):
149
+ get_one_answer(i)
150
+ else:
151
+ with ThreadPoolExecutor(args.parallel) as executor:
152
+ list(
153
+ tqdm(
154
+ executor.map(get_one_answer, list(range(len(questions)))),
155
+ total=len(questions),
156
+ )
157
+ )
158
+
159
+ else:
160
+ import asyncio
161
+
162
+ from lmql_funcs import tree_search_async
163
+
164
+ async def get_one_answer_async(i):
165
+ states[i] = await tree_search_async(
166
+ **arguments[i], call_generate=call_generate
167
+ )
168
+
169
+ batches = [
170
+ [] for _ in range((len(questions) + args.parallel - 1) // args.parallel)
171
+ ]
172
+ for i in range(len(questions)):
173
+ batches[i // args.parallel].append(i)
174
+
175
+ loop = asyncio.get_event_loop()
176
+ for bt in tqdm(batches):
177
+ tasks = [get_one_answer_async(k) for k in bt]
178
+ loop.run_until_complete(asyncio.gather(*tasks))
179
+
180
+ latency = time.perf_counter() - tic
181
+
182
+ answers_text = []
183
+ for s in states:
184
+ answers_text.append([x for xs in s for x in xs])
185
+
186
+ preds = []
187
+ for i in range(len(states)):
188
+ answers = [get_answer_value(v) for v in answers_text[i]]
189
+ preds.append(most_frequent_number(answers))
190
+
191
+ # Compute accuracy
192
+ acc = np.mean(np.array(preds) == np.array(labels))
193
+ invalid = np.mean(np.array(preds) == INVALID)
194
+ print(f"Latency: {latency:.3f}")
195
+ print(f"Invalid: {invalid:.3f}")
196
+ print(f"Accuracy: {acc:.3f}")
197
+
198
+ # Write results
199
+ dump_state_text(f"tmp_output_{args.backend}.txt", answers_text)
200
+
201
+ with open(args.result_file, "a") as fout:
202
+ value = {
203
+ "task": "tree_of_thought_gsm8k",
204
+ "backend": args.backend,
205
+ "num_gpus": 1,
206
+ "latency": round(latency, 3),
207
+ "accuracy": round(acc, 3),
208
+ "num_requests": args.num_questions,
209
+ "other": {
210
+ "num_questions": args.num_questions,
211
+ "parallel": args.parallel,
212
+ },
213
+ }
214
+ fout.write(json.dumps(value) + "\n")
215
+
216
+
217
+ if __name__ == "__main__":
218
+ parser = argparse.ArgumentParser()
219
+ parser.add_argument("--data-path", type=str, default="test.jsonl")
220
+ parser.add_argument("--num-questions", type=int, default=200)
221
+ args = add_common_other_args_and_parse(parser)
222
+ main(args)
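The script above expands each question into a tree with four stages (plan, execute, reflect, final answer), pulls the last integer out of every final answer, and takes a majority vote. The sketch below isolates the extraction and voting steps and counts the leaves of the tree. It uses `int` and `Counter.most_common` in place of the script's `ast.literal_eval` and `max(..., key=...)`, and `num_final_answers` and the sample answers are hypothetical additions for illustration.

```python
import re
from collections import Counter

INVALID = -9999999


def get_answer_value(answer_str):
    """Return the last integer in the text, ignoring thousands separators."""
    numbers = re.findall(r"\d+", answer_str.replace(",", ""))
    return int(numbers[-1]) if numbers else INVALID


def most_frequent_number(numbers):
    """Majority vote over extracted values; None for an empty list."""
    if not numbers:
        return None
    return Counter(numbers).most_common(1)[0][0]


def num_final_answers(num_branches, depth=4):
    # Each of the four stages forks every state num_branches ways.
    return num_branches**depth


# Made-up final answers standing in for model output.
answers = ["The total is 1,234 apples.", "So the answer is 1234.", "I get 18."]
values = [get_answer_value(a) for a in answers]

print(values)
print(most_frequent_number(values))
print(num_final_answers(2))
```

With `num_branches = 2`, as set in `main`, every question yields 16 final answers, so the number of backend calls grows quickly with the branching factor.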
sglang/benchmark/tree_of_thought_deep/bench_sglang.py ADDED
@@ -0,0 +1,171 @@
1
+ import argparse
2
+ import ast
3
+ import json
4
+ import re
5
+ import time
6
+ from collections import Counter
7
+
8
+ import numpy as np
9
+
10
+ import sglang as sgl
11
+ from sglang.test.test_utils import (
12
+ add_common_sglang_args_and_parse,
13
+ select_sglang_backend,
14
+ )
15
+ from sglang.utils import dump_state_text, read_jsonl
16
+
17
+ INVALID = -9999999
18
+
19
+
20
+ def get_answer_value(answer_str):
21
+ answer_str = answer_str.replace(",", "")
22
+ numbers = re.findall(r"\d+", answer_str)
23
+ if len(numbers) < 1:
24
+ return INVALID
25
+ try:
26
+ return ast.literal_eval(numbers[-1])
27
+ except SyntaxError:
28
+ return INVALID
29
+
30
+
31
+ def most_frequent_number(numbers):
32
+ if not numbers:
33
+ return None
34
+
35
+ frequency = Counter(numbers)
36
+ most_frequent = max(frequency, key=frequency.get)
37
+ return most_frequent
38
+
39
+
40
+ # Use a low temperature to make the results more deterministic and the comparison fairer.
41
+ temp = 0.001
42
+
43
+
44
+ def propose_plan(s, question, num_branches):
45
+ s += sgl.user(
46
+ """Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """
47
+ + question
48
+ )
49
+ forks = s.fork(num_branches)
50
+ forks += sgl.assistant(sgl.gen("plan", max_tokens=256, temperature=temp))
51
+ return forks
52
+
53
+
54
+ def execute_plan(s, num_branches):
55
+ s += sgl.user(
56
+ """The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short."""
57
+ )
58
+ forks = s.fork(num_branches)
59
+ forks += sgl.assistant(sgl.gen("answer", max_tokens=256, temperature=temp))
60
+ return forks
61
+
62
+
63
+ def reflect_solution(s, num_branches):
64
+ s += sgl.user(
65
+ """Okay. Now, evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness."""
66
+ )
67
+ forks = s.fork(num_branches)
68
+ forks += sgl.assistant(sgl.gen("score", max_tokens=256, temperature=temp))
69
+ return forks
70
+
71
+
72
+ def get_final_answer(s, num_branches):
73
+ s += sgl.user(
74
+ """Based on your reflection, do you change your mind? Now, give me the final answer after careful consideration."""
75
+ )
76
+ forks = s.fork(num_branches)
77
+ forks += sgl.assistant(sgl.gen("final_answer", max_tokens=256, temperature=temp))
78
+ return forks
79
+
80
+
+ @sgl.function
+ def tree_search(s, question, num_branches):
+     plan_forks = propose_plan(s, question, num_branches)
+
+     sol_states = []
+     for plan in plan_forks:
+         forks = execute_plan(plan, num_branches)
+         sol_states.extend(forks)
+
+     ref_states = []
+     for sol in sol_states:
+         forks = reflect_solution(sol, num_branches)
+         ref_states.extend(forks)
+
+     solutions = []
+     for sol in ref_states:
+         forks = get_final_answer(sol, num_branches)
+         solutions.append(forks)
+     solutions = [[s.text() for s in forks] for forks in solutions]
+
+     return solutions
+
+
+ def main(args):
+     lines = read_jsonl(args.data_path)
+     lines = list(lines)
+
+     # Construct prompts
+     num_branches = 2
+     questions = []
+     labels = []
+     for i in range(len(lines[: args.num_questions])):
+         questions.append(lines[i]["question"])
+         labels.append(get_answer_value(lines[i]["answer"]))
+     assert all(l != INVALID for l in labels)
+     arguments = [{"question": q, "num_branches": num_branches} for q in questions]
+
+     # Select backend
+     backend = select_sglang_backend(args)
+
+     # Run requests
+     tic = time.perf_counter()
+     states = tree_search.run_batch(
+         arguments,
+         temperature=0,
+         backend=backend,
+         num_threads=args.parallel,
+         progress_bar=True,
+     )
+     latency = time.perf_counter() - tic
+     answers_text = []
+     for s in states:
+         answers_text.append([x for xs in s.ret_value for x in xs])
+
+     preds = []
+     for i in range(len(states)):
+         answers = [get_answer_value(v) for v in answers_text[i]]
+         preds.append(most_frequent_number(answers))
+
+     # Compute accuracy
+     acc = np.mean(np.array(preds) == np.array(labels))
+     invalid = np.mean(np.array(preds) == INVALID)
+     print(f"Latency: {latency:.3f}")
+     print(f"Invalid: {invalid:.3f}")
+     print(f"Accuracy: {acc:.3f}")
+
+     # Write results
+     dump_state_text(f"tmp_output_{args.backend}.txt", answers_text)
+
+     with open(args.result_file, "a") as fout:
+         value = {
+             "task": "tree_of_thought_gsm8k",
+             "backend": args.backend,
+             "num_gpus": 1,
+             "latency": round(latency, 3),
+             "accuracy": round(acc, 3),
+             "num_requests": args.num_questions,
+             "other": {
+                 "num_questions": args.num_questions,
+                 "parallel": args.parallel,
+             },
+         }
+         fout.write(json.dumps(value) + "\n")
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--data-path", type=str, default="test.jsonl")
+     parser.add_argument("--num-questions", type=int, default=200)
+     args = add_common_sglang_args_and_parse(parser)
+     main(args)
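The scoring path of this benchmark (take the last integer in each completion, then majority-vote across all leaves of the tree) can be tried without a model. The following is a minimal sketch under stated assumptions: plain strings stand in for model completions, and `extract_last_int`, `majority_vote`, and `leaves_per_question` are illustrative names that do not appear in the script. It also uses `int()` where the script uses `ast.literal_eval`, so digit strings with leading zeros parse here instead of yielding the sentinel.

```python
import re
from collections import Counter

INVALID = -9999999  # same sentinel value the benchmark script uses


def extract_last_int(text):
    """Return the last run of digits in `text` as an int, or INVALID."""
    numbers = re.findall(r"\d+", text.replace(",", ""))
    return int(numbers[-1]) if numbers else INVALID


def majority_vote(values):
    """Return the most common value, or None for an empty list."""
    if not values:
        return None
    return Counter(values).most_common(1)[0][0]


def leaves_per_question(num_branches, depth=4):
    """Plan, solve, reflect, and final-answer each fork every state."""
    return num_branches**depth


# Hypothetical completions standing in for model output.
completions = ["So the total is 1,234.", "I get 1234 dollars", "Answer: 99", "no idea"]
votes = [extract_last_int(c) for c in completions]
print(votes, majority_vote(votes), leaves_per_question(2))
```

With `num_branches = 2`, as in `main`, the four forking stages produce 16 final answers per question, which is the pool the vote runs over.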
sglang/docker/configs/.zshrc ADDED
@@ -0,0 +1,27 @@
+ export ZSH="/root/.oh-my-zsh"
+
+ # Theme
+ ZSH_THEME="robbyrussell"
+
+ # Plugins
+ plugins=(
+     git
+     z
+     zsh-autosuggestions
+     zsh-syntax-highlighting
+ )
+
+ source $ZSH/oh-my-zsh.sh
+
+ # Aliases
+ alias ll='ls -alF'
+ alias la='ls -A'
+ alias l='ls -CF'
+ alias vi='vim'
+
+ # Enhanced history
+ HISTSIZE=10000
+ SAVEHIST=10000
+ setopt HIST_IGNORE_ALL_DUPS
+ setopt HIST_FIND_NO_DUPS
+ setopt INC_APPEND_HISTORY
sglang/docker/configs/opt/.gitconfig ADDED
@@ -0,0 +1,30 @@
+ [core]
+     editor = vim
+     whitespace = fix,-indent-with-non-tab,trailing-space,cr-at-eol
+     pager = diff-so-fancy | less --tabs=4 -RFX
+
+ [color]
+     ui = true
+
+ [color "diff-highlight"]
+     oldNormal = red bold
+     oldHighlight = red bold 52
+     newNormal = green bold
+     newHighlight = green bold 22
+
+ [color "diff"]
+     meta = 11
+     frag = magenta bold
+     commit = yellow bold
+     old = red bold
+     new = green bold
+     whitespace = red reverse
+
+ [alias]
+     lg = log --color --graph --pretty=format:'%Cred%h%Creset - %s %Cgreen(%cr) %C(bold blue)<%an>%Creset%C(auto)%d%Creset' --abbrev-commit --
+
+ [http]
+     sslVerify = false
+
+ [pull]
+     rebase = true
sglang/docker/configs/opt/.tmux.conf ADDED
@@ -0,0 +1,27 @@
+ # Pane border styling
+ set -g pane-border-style fg='#742727',bg=black
+ set -g pane-active-border-style fg=red,bg=black
+
+ # Status bar styling
+ set -g status-style bg='#0C8A92',fg=black
+
+ # Change prefix key to backtick
+ set-option -g prefix `
+ unbind C-b
+ bind-key ` send-prefix
+
+ # Split panes using - and = with current path
+ unbind '"'
+ bind - splitw -v -c '#{pane_current_path}'
+ unbind '%'
+ bind = splitw -h -c '#{pane_current_path}'
+
+ # Vi mode settings
+ bind-key -T copy-mode-vi Y send-keys -X copy-pipe 'yank > #{pane_tty}'
+ set-window-option -g mode-keys vi
+
+ # Other settings
+ set-option -g escape-time 0
+ set-option -g base-index 1
+ set-window-option -g mouse on
+ set -g history-limit 100000
sglang/docker/configs/opt/.vimrc ADDED
@@ -0,0 +1,45 @@
+ function! Yank(text) abort
+     let escape = system('yank', a:text)
+     if v:shell_error
+         echoerr escape
+     else
+         call writefile([escape], '/dev/tty', 'b')
+     endif
+ endfunction
+
+ noremap <silent> <Leader>y y:<C-U>call Yank(@0)<CR>
+
+ " automatically run yank(1) whenever yanking in Vim
+ function! CopyYank() abort
+     call Yank(join(v:event.regcontents, "\n"))
+ endfunction
+
+ autocmd TextYankPost * call CopyYank()
+
+ " Basic settings
+ set number
+ syntax on
+ set mouse=a
+ filetype indent on
+
+ " Indentation
+ set autoindent nosmartindent
+ set smarttab
+ set expandtab
+ set shiftwidth=4
+ set softtabstop=4
+
+ " Visual guides
+ set colorcolumn=120
+ highlight ColorColumn ctermbg=5
+
+ " Status line
+ set laststatus=2
+ set statusline=%<%f\ %h%m%r%=%{\"[\".(&fenc==\"\"?&enc:&fenc).((exists(\"+bomb\")\ &&\ &bomb)?\",B\":\"\").\"]\ \"}%k\ %-14.(%l,%c%V%)\ %P
+
+ " Backspace behavior
+ set backspace=2
+
+ " Encoding
+ set encoding=utf-8
+ set fileencoding=utf-8
sglang/docker/configs/yank ADDED
@@ -0,0 +1,12 @@
+ #!/bin/bash
+ put() {
+   esc=$1
+   test -n "$TMUX" -o -z "${TERM##screen*}" && esc="\033Ptmux;\033$esc\033\\"
+   printf "$esc"
+ }
+ put "\033]52;c;!\a"
+ buf=$( cat "$@" )
+ len=$( printf %s "$buf" | wc -c ) max=74994
+ test $len -gt $max && echo "$0: input is $(( len - max )) bytes too long" >&2
+ put "\033]52;c;$( printf %s "$buf" | head -c $max | base64 | tr -d '\r\n' )\a"
+ test -n "$TMUX" && tmux set-buffer "$buf" ||:
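The escape sequence this script emits can be reproduced in a few lines, which makes the wire format easier to inspect. The sketch below is an illustration rather than a replacement for the script: `osc52` and its `in_tmux` flag are hypothetical names, and it only mirrors what the shell code above does (base64-encode at most 74994 bytes, wrap them in `ESC ] 52 ; c ; ... BEL`, and add the tmux passthrough wrapper when requested).

```python
import base64

MAX_BYTES = 74994  # payload cap taken from the script above


def osc52(data: bytes, in_tmux: bool = False) -> str:
    """Build an OSC 52 'set clipboard' sequence for `data`."""
    payload = base64.b64encode(data[:MAX_BYTES]).decode("ascii")
    seq = f"\033]52;c;{payload}\a"
    if in_tmux:
        # Same wrapping as put(): DCS "tmux;" prefix, an extra ESC
        # before the inner sequence, then the ESC \ terminator.
        seq = f"\033Ptmux;\033{seq}\033\\"
    return seq


plain = osc52(b"hi")
wrapped = osc52(b"hi", in_tmux=True)
print(repr(plain))
print(repr(wrapped))
```

Writing the returned string to the terminal is what places the text on the clipboard, provided the terminal emulator supports OSC 52.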
sglang/python/sglang.egg-info/PKG-INFO ADDED
@@ -0,0 +1,120 @@
+ Metadata-Version: 2.4
+ Name: sglang
+ Version: 0.5.9
+ Summary: SGLang is a fast serving framework for large language models and vision language models.
+ Project-URL: Homepage, https://github.com/sgl-project/sglang
+ Project-URL: Bug Tracker, https://github.com/sgl-project/sglang/issues
+ Classifier: Programming Language :: Python :: 3
+ Classifier: License :: OSI Approved :: Apache Software License
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ Requires-Dist: IPython
+ Requires-Dist: aiohttp
+ Requires-Dist: apache-tvm-ffi<0.2,>=0.1.5
+ Requires-Dist: anthropic>=0.20.0
+ Requires-Dist: blobfile==3.0.0
+ Requires-Dist: build
+ Requires-Dist: compressed-tensors
+ Requires-Dist: cuda-python==12.9
+ Requires-Dist: decord2
+ Requires-Dist: datasets
+ Requires-Dist: einops
+ Requires-Dist: fastapi
+ Requires-Dist: flashinfer_python==0.6.4
+ Requires-Dist: flashinfer_cubin==0.6.4
+ Requires-Dist: gguf
+ Requires-Dist: hf_transfer
+ Requires-Dist: huggingface_hub
+ Requires-Dist: interegular
+ Requires-Dist: llguidance<0.8.0,>=0.7.11
+ Requires-Dist: modelscope
+ Requires-Dist: msgspec
+ Requires-Dist: ninja
+ Requires-Dist: numpy
+ Requires-Dist: nvidia-cutlass-dsl>=4.3.4
+ Requires-Dist: nvidia-ml-py
+ Requires-Dist: openai-harmony==0.0.4
+ Requires-Dist: openai==2.6.1
+ Requires-Dist: orjson
+ Requires-Dist: outlines==0.1.11
+ Requires-Dist: packaging
+ Requires-Dist: partial_json_parser
+ Requires-Dist: pillow
+ Requires-Dist: prometheus-client>=0.20.0
+ Requires-Dist: psutil
+ Requires-Dist: py-spy
+ Requires-Dist: pybase64
+ Requires-Dist: pydantic
+ Requires-Dist: python-multipart
+ Requires-Dist: pyzmq>=25.1.2
+ Requires-Dist: quack-kernels==0.2.4
+ Requires-Dist: requests
+ Requires-Dist: scipy
+ Requires-Dist: sentencepiece
+ Requires-Dist: setproctitle
+ Requires-Dist: sgl-fa4==4.0.3
+ Requires-Dist: sgl-kernel==0.3.21
+ Requires-Dist: soundfile==0.13.1
+ Requires-Dist: tiktoken
+ Requires-Dist: timm==1.0.16
+ Requires-Dist: torch_memory_saver==0.0.9
+ Requires-Dist: torch==2.9.1
+ Requires-Dist: torchao==0.9.0
+ Requires-Dist: torchaudio==2.9.1
+ Requires-Dist: torchcodec==0.8.0; sys_platform != "linux" or (sys_platform == "linux" and platform_machine != "aarch64" and platform_machine != "arm64" and platform_machine != "armv7l")
+ Requires-Dist: torchvision
+ Requires-Dist: tqdm
+ Requires-Dist: transformers==4.57.1
+ Requires-Dist: uvicorn
+ Requires-Dist: uvloop
+ Requires-Dist: watchfiles
+ Requires-Dist: xgrammar==0.1.27
+ Requires-Dist: smg-grpc-proto>=0.4.1
+ Requires-Dist: grpcio>=1.78.0
+ Requires-Dist: grpcio-reflection>=1.78.0
+ Requires-Dist: grpcio-health-checking>=1.78.0
+ Provides-Extra: checkpoint-engine
+ Requires-Dist: checkpoint-engine==0.1.2; extra == "checkpoint-engine"
+ Provides-Extra: diffusion
+ Requires-Dist: PyYAML==6.0.1; extra == "diffusion"
+ Requires-Dist: cloudpickle==3.1.2; extra == "diffusion"
+ Requires-Dist: diffusers==0.36.0; extra == "diffusion"
+ Requires-Dist: imageio==2.36.0; extra == "diffusion"
+ Requires-Dist: imageio-ffmpeg==0.5.1; extra == "diffusion"
+ Requires-Dist: moviepy>=2.0.0; extra == "diffusion"
+ Requires-Dist: opencv-python-headless==4.10.0.84; extra == "diffusion"
+ Requires-Dist: remote-pdb==2.1.0; extra == "diffusion"
+ Requires-Dist: st_attn==0.0.7; (platform_machine != "aarch64" and platform_machine != "arm64") and extra == "diffusion"
+ Requires-Dist: vsa==0.0.4; (platform_machine != "aarch64" and platform_machine != "arm64") and extra == "diffusion"
+ Requires-Dist: runai_model_streamer>=0.15.5; extra == "diffusion"
+ Requires-Dist: cache-dit==1.2.3; extra == "diffusion"
+ Requires-Dist: addict==2.4.0; extra == "diffusion"
+ Requires-Dist: av==16.1.0; extra == "diffusion"
+ Requires-Dist: scikit-image==0.25.2; extra == "diffusion"
+ Requires-Dist: trimesh>=4.0.0; extra == "diffusion"
+ Requires-Dist: xatlas; extra == "diffusion"
+ Provides-Extra: ray
+ Requires-Dist: ray[default]>=2.54.0; extra == "ray"
+ Provides-Extra: tracing
+ Requires-Dist: opentelemetry-api; extra == "tracing"
+ Requires-Dist: opentelemetry-exporter-otlp; extra == "tracing"
+ Requires-Dist: opentelemetry-exporter-otlp-proto-grpc; extra == "tracing"
+ Requires-Dist: opentelemetry-sdk; extra == "tracing"
+ Provides-Extra: test
+ Requires-Dist: accelerate; extra == "test"
+ Requires-Dist: bitsandbytes; extra == "test"
+ Requires-Dist: expecttest; extra == "test"
+ Requires-Dist: jsonlines; extra == "test"
+ Requires-Dist: lm-eval[api]>=0.4.9.2; extra == "test"
+ Requires-Dist: matplotlib; extra == "test"
+ Requires-Dist: pandas; extra == "test"
+ Requires-Dist: parameterized; extra == "test"
+ Requires-Dist: peft; extra == "test"
+ Requires-Dist: pytest; extra == "test"
+ Requires-Dist: sentence_transformers; extra == "test"
+ Requires-Dist: tabulate; extra == "test"
+ Provides-Extra: dev
+ Requires-Dist: sglang[test]; extra == "dev"
+ Provides-Extra: all
+ Requires-Dist: sglang[diffusion]; extra == "all"
+ Requires-Dist: sglang[tracing]; extra == "all"
sglang/python/sglang.egg-info/SOURCES.txt ADDED
The diff for this file is too large to render. See raw diff
 
sglang/python/sglang.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
+
sglang/python/sglang.egg-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
+ [console_scripts]
+ sglang = sglang.cli.main:main
sglang/python/sglang.egg-info/requires.txt ADDED
@@ -0,0 +1,121 @@
+ IPython
+ aiohttp
+ apache-tvm-ffi<0.2,>=0.1.5
+ anthropic>=0.20.0
+ blobfile==3.0.0
+ build
+ compressed-tensors
+ cuda-python==12.9
+ decord2
+ datasets
+ einops
+ fastapi
+ flashinfer_python==0.6.4
+ flashinfer_cubin==0.6.4
+ gguf
+ hf_transfer
+ huggingface_hub
+ interegular
+ llguidance<0.8.0,>=0.7.11
+ modelscope
+ msgspec
+ ninja
+ numpy
+ nvidia-cutlass-dsl>=4.3.4
+ nvidia-ml-py
+ openai-harmony==0.0.4
+ openai==2.6.1
+ orjson
+ outlines==0.1.11
+ packaging
+ partial_json_parser
+ pillow
+ prometheus-client>=0.20.0
+ psutil
+ py-spy
+ pybase64
+ pydantic
+ python-multipart
+ pyzmq>=25.1.2
+ quack-kernels==0.2.4
+ requests
+ scipy
+ sentencepiece
+ setproctitle
+ sgl-fa4==4.0.3
+ sgl-kernel==0.3.21
+ soundfile==0.13.1
+ tiktoken
+ timm==1.0.16
+ torch_memory_saver==0.0.9
+ torch==2.9.1
+ torchao==0.9.0
+ torchaudio==2.9.1
+ torchvision
+ tqdm
+ transformers==4.57.1
+ uvicorn
+ uvloop
+ watchfiles
+ xgrammar==0.1.27
+ smg-grpc-proto>=0.4.1
+ grpcio>=1.78.0
+ grpcio-reflection>=1.78.0
+ grpcio-health-checking>=1.78.0
+
+ [:sys_platform != "linux" or (sys_platform == "linux" and platform_machine != "aarch64" and platform_machine != "arm64" and platform_machine != "armv7l")]
+ torchcodec==0.8.0
+
+ [all]
+ sglang[diffusion]
+ sglang[tracing]
+
+ [checkpoint-engine]
+ checkpoint-engine==0.1.2
+
+ [dev]
+ sglang[test]
+
+ [diffusion]
+ PyYAML==6.0.1
+ cloudpickle==3.1.2
+ diffusers==0.36.0
+ imageio==2.36.0
+ imageio-ffmpeg==0.5.1
+ moviepy>=2.0.0
+ opencv-python-headless==4.10.0.84
+ remote-pdb==2.1.0
+ runai_model_streamer>=0.15.5
+ cache-dit==1.2.3
+ addict==2.4.0
+ av==16.1.0
+ scikit-image==0.25.2
+ trimesh>=4.0.0
+ xatlas
+
+ [diffusion:platform_machine != "aarch64" and platform_machine != "arm64"]
+ st_attn==0.0.7
+ vsa==0.0.4
+
+ [ray]
+ ray[default]>=2.54.0
+
+ [test]
+ accelerate
+ bitsandbytes
+ expecttest
+ jsonlines
+ lm-eval[api]>=0.4.9.2
+ matplotlib
+ pandas
+ parameterized
+ peft
+ pytest
+ sentence_transformers
+ tabulate
+
+ [tracing]
+ opentelemetry-api
+ opentelemetry-exporter-otlp
+ opentelemetry-exporter-otlp-proto-grpc
+ opentelemetry-sdk
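The section headers in this `requires.txt` carry two pieces of information: the extra name before the colon and an environment marker after it, with either part allowed to be empty. A small sketch of how those headers could be split follows; `parse_requires` is a hypothetical helper written for illustration, it is not part of setuptools, and it assumes the header layout seen in this file.

```python
def parse_requires(text):
    """Group requirement lines by their (extra, marker) section header."""
    sections = {}
    key = (None, None)  # top-level, unconditional requirements
    for raw in text.splitlines():
        line = raw.strip()
        if not line:
            continue
        if line.startswith("[") and line.endswith("]"):
            # "[extra:marker]"; either side of the colon may be empty.
            extra, _, marker = line[1:-1].partition(":")
            key = (extra or None, marker or None)
            continue
        sections.setdefault(key, []).append(line)
    return sections


# A shortened sample in the same layout as the file above.
sample = """numpy
[:sys_platform != "linux"]
torchcodec==0.8.0
[ray]
ray[default]>=2.54.0
[diffusion:platform_machine != "aarch64"]
vsa==0.0.4
"""
parsed = parse_requires(sample)
for key, reqs in parsed.items():
    print(key, reqs)
```

A requirement such as `ray[default]>=2.54.0` is not mistaken for a header because it does not start with `[`.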
sglang/python/sglang.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
+ sglang
sglang/python/sglang/README.md ADDED
@@ -0,0 +1,18 @@
+ # Code Structure
+
+ - `eval`: The evaluation utilities.
+ - `lang`: The frontend language.
+ - `multimodal_gen`: Inference framework for accelerated image/video generation.
+ - `srt`: The backend engine for running local models. (SRT = SGLang Runtime).
+ - `test`: The test utilities.
+ - `api.py`: The public APIs.
+ - `bench_offline_throughput.py`: Benchmark the performance in the offline mode.
+ - `bench_one_batch.py`: Benchmark the latency of running a single static batch without a server.
+ - `bench_one_batch_server.py`: Benchmark the latency of running a single batch with a server.
+ - `bench_serving.py`: Benchmark online serving with dynamic requests.
+ - `check_env.py`: Check the environment variables and dependencies.
+ - `global_config.py`: The global configs and constants.
+ - `launch_server.py`: The entry point for launching a local server.
+ - `profiler.py`: The profiling entry point to send profile requests.
+ - `utils.py`: Common utilities.
+ - `version.py`: Version info.
sglang/python/sglang/__init__.py ADDED
@@ -0,0 +1,83 @@
+ # SGLang public APIs
+
+ # Frontend Language APIs
+ from sglang.global_config import global_config
+ from sglang.lang.api import (
+     Engine,
+     Runtime,
+     assistant,
+     assistant_begin,
+     assistant_end,
+     flush_cache,
+     function,
+     gen,
+     gen_int,
+     gen_string,
+     get_server_info,
+     image,
+     select,
+     separate_reasoning,
+     set_default_backend,
+     system,
+     system_begin,
+     system_end,
+     user,
+     user_begin,
+     user_end,
+     video,
+ )
+ from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
+ from sglang.lang.choices import (
+     greedy_token_selection,
+     token_length_normalized,
+     unconditional_likelihood_normalized,
+ )
+
+ # Lazy import some libraries
+ from sglang.utils import LazyImport
+ from sglang.version import __version__
+
+ Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
+ LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
+ OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
+ VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")
+
+ # Runtime Engine APIs
+ ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs")
+ Engine = LazyImport("sglang.srt.entrypoints.engine", "Engine")
+
+ __all__ = [
+     "Engine",
+     "Runtime",
+     "assistant",
+     "assistant_begin",
+     "assistant_end",
+     "flush_cache",
+     "function",
+     "gen",
+     "gen_int",
+     "gen_string",
+     "get_server_info",
+     "image",
+     "select",
+     "separate_reasoning",
+     "set_default_backend",
+     "system",
+     "system_begin",
+     "system_end",
+     "user",
+     "user_begin",
+     "user_end",
+     "video",
+     "RuntimeEndpoint",
+     "greedy_token_selection",
+     "token_length_normalized",
+     "unconditional_likelihood_normalized",
+     "ServerArgs",
+     "Anthropic",
+     "LiteLLM",
+     "OpenAI",
+     "VertexAI",
+     "global_config",
+     "__version__",
+ ]
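The `LazyImport(module, name)` calls above postpone importing heavy or optional backends until the name is first used. The implementation of `sglang.utils.LazyImport` is not shown in this diff, so the following is only a sketch of the general technique under that caveat; `LazyAttr` is a hypothetical stand-in, and the standard-library `json` module is used so the example runs anywhere.

```python
import importlib


class LazyAttr:
    """Defer importing `module_name` until the attribute is first used."""

    def __init__(self, module_name, attr_name):
        self._module_name = module_name
        self._attr_name = attr_name
        self._target = None

    def _load(self):
        # Import on first use and cache the resolved object.
        if self._target is None:
            module = importlib.import_module(self._module_name)
            self._target = getattr(module, self._attr_name)
        return self._target

    def __getattr__(self, name):
        # Reached only for names not set in __init__, so no recursion.
        return getattr(self._load(), name)

    def __call__(self, *args, **kwargs):
        return self._load()(*args, **kwargs)


dumps = LazyAttr("json", "dumps")
loaded_before = dumps._target is not None  # nothing resolved yet
encoded = dumps({"a": 1})  # first call triggers the import
print(loaded_before, encoded)
```

The trade-off of this pattern is that an import error for a missing optional dependency surfaces at first use rather than at `import sglang` time.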
sglang/python/sglang/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (2.08 kB). View file
 
sglang/python/sglang/__pycache__/_version.cpython-311.pyc ADDED
Binary file (872 Bytes). View file
 
sglang/python/sglang/__pycache__/bench_serving.cpython-311.pyc ADDED
Binary file (95.3 kB). View file
 
sglang/python/sglang/__pycache__/check_env.cpython-311.pyc ADDED
Binary file (24.9 kB). View file
 
sglang/python/sglang/__pycache__/global_config.cpython-311.pyc ADDED
Binary file (969 Bytes). View file
 
sglang/python/sglang/__pycache__/launch_server.cpython-311.pyc ADDED
Binary file (2.62 kB). View file
 
sglang/python/sglang/__pycache__/utils.cpython-311.pyc ADDED
Binary file (34.6 kB). View file
 
sglang/python/sglang/__pycache__/version.cpython-311.pyc ADDED
Binary file (1.31 kB). View file
 
sglang/python/sglang/_version.py ADDED
@@ -0,0 +1,34 @@
+ # file generated by setuptools-scm
+ # don't change, don't track in version control
+
+ __all__ = [
+     "__version__",
+     "__version_tuple__",
+     "version",
+     "version_tuple",
+     "__commit_id__",
+     "commit_id",
+ ]
+
+ TYPE_CHECKING = False
+ if TYPE_CHECKING:
+     from typing import Tuple
+     from typing import Union
+
+     VERSION_TUPLE = Tuple[Union[int, str], ...]
+     COMMIT_ID = Union[str, None]
+ else:
+     VERSION_TUPLE = object
+     COMMIT_ID = object
+
+ version: str
+ __version__: str
+ __version_tuple__: VERSION_TUPLE
+ version_tuple: VERSION_TUPLE
+ commit_id: COMMIT_ID
+ __commit_id__: COMMIT_ID
+
+ __version__ = version = '0.5.9'
+ __version_tuple__ = version_tuple = (0, 5, 9)
+
+ __commit_id__ = commit_id = 'gbbe9c7eeb'
sglang/python/sglang/bench_offline_throughput.py ADDED
@@ -0,0 +1,543 @@
+ """
+ Benchmark the throughput in the offline mode.
+ It accepts server arguments (the same as launch_server.py) and benchmark arguments (the same as bench_serving.py).
+
+ # Usage
+ ## Sharegpt dataset with default args
+ python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --num-prompts 10
+
+ ## Random dataset with default args
+ python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024
+ """
+
+ import argparse
+ import asyncio
+ import dataclasses
+ import inspect
+ import json
+ import logging
+ import os
+ import random
+ import time
+ from typing import Dict, List, Optional
+
+ import numpy as np
+
+ from sglang.benchmark.datasets import DatasetRow, get_dataset
+ from sglang.benchmark.datasets.random import sample_random_requests
+ from sglang.benchmark.utils import get_tokenizer, set_ulimit
+ from sglang.lang.backend.runtime_endpoint import Runtime
+ from sglang.srt.entrypoints.engine import Engine
+ from sglang.srt.server_args import ServerArgs
+
+
+ @dataclasses.dataclass
+ class BenchArgs:
+     backend: str = "engine"
+     result_filename: str = ""
+     dataset_name: str = "sharegpt"
+     dataset_path: str = ""
+     num_prompts: int = 1000
+     sharegpt_output_len: Optional[int] = None
+     sharegpt_context_len: Optional[int] = None
+     random_input_len: int = 1024
+     random_output_len: int = 1024
+     random_range_ratio: float = 0.0
+     gsp_num_groups: int = 64
+     gsp_prompts_per_group: int = 16
+     gsp_system_prompt_len: int = 2048
+     gsp_question_len: int = 128
+     gsp_output_len: int = 256
+     seed: int = 1
+     disable_ignore_eos: bool = False
+     extra_request_body: Optional[str] = None
+     apply_chat_template: bool = False
+     profile: bool = False
+     skip_warmup: bool = False
+     do_not_exit: bool = False
+     prompt_suffix: str = ""
+     return_logprob: bool = False
+     logprob_start_len: int = -1
+
+     @staticmethod
+     def add_cli_args(parser: argparse.ArgumentParser):
+         parser.add_argument("--backend", type=str, default=BenchArgs.backend)
+         parser.add_argument(
+             "--result-filename", type=str, default=BenchArgs.result_filename
+         )
+         parser.add_argument(
+             "--dataset-name",
+             type=str,
+             default="sharegpt",
+             choices=["sharegpt", "random", "generated-shared-prefix"],
+             help="Name of the dataset to benchmark on.",
+         )
+         parser.add_argument(
+             "--dataset-path", type=str, default="", help="Path to the dataset."
+         )
+         parser.add_argument(
+             "--num-prompts",
+             type=int,
+             default=BenchArgs.num_prompts,
+             help="Number of prompts to process. Default is 1000.",
+         )
+         parser.add_argument(
+             "--sharegpt-output-len",
+             type=int,
+             default=BenchArgs.sharegpt_output_len,
+             help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
+         )
+         parser.add_argument(
+             "--sharegpt-context-len",
+             type=int,
+             default=BenchArgs.sharegpt_context_len,
+             help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
+         )
+         parser.add_argument(
+             "--random-input-len",
+             type=int,
+             default=BenchArgs.random_input_len,
+             help="Number of input tokens per request, used only for random dataset.",
+         )
+         parser.add_argument(
+             "--random-output-len",
+             type=int,
+             default=BenchArgs.random_output_len,
+             help="Number of output tokens per request, used only for random dataset.",
+         )
+         parser.add_argument(
+             "--random-range-ratio",
+             type=float,
+             default=BenchArgs.random_range_ratio,
+             help="Range of sampled ratio of input/output length, "
+             "used only for random dataset.",
+         )
+         parser.add_argument(
+             "--gsp-num-groups",
+             type=int,
+             default=BenchArgs.gsp_num_groups,
+             help="Number of groups with shared prefix, used "
+             "only for generated-shared-prefix",
+         )
+         parser.add_argument(
+             "--gsp-prompts-per-group",
+             type=int,
+             default=BenchArgs.gsp_prompts_per_group,
+             help="Number of prompts per group of shared prefix, used "
+             "only for generated-shared-prefix",
+         )
+         parser.add_argument(
+             "--gsp-system-prompt-len",
+             type=int,
+             default=BenchArgs.gsp_system_prompt_len,
+             help="System prompt length, used only for generated-shared-prefix",
+         )
+         parser.add_argument(
+             "--gsp-question-len",
+             type=int,
+             default=BenchArgs.gsp_question_len,
+             help="Question length, used only for generated-shared-prefix",
+         )
+         parser.add_argument(
+             "--gsp-output-len",
+             type=int,
+             default=BenchArgs.gsp_output_len,
+             help="Target length in tokens for outputs in generated-shared-prefix dataset",
+         )
+         parser.add_argument("--seed", type=int, default=1, help="The random seed.")
+         parser.add_argument(
+             "--disable-ignore-eos",
+             action="store_true",
+             help="Disable ignore EOS token",
+         )
+         parser.add_argument(
+             "--extra-request-body",
+             metavar='{"key1": "value1", "key2": "value2"}',
+             type=str,
+             default=BenchArgs.extra_request_body,
+             help="Append given JSON object to the request payload. You can use this to specify "
+             "additional generate params like sampling params.",
+         )
+         parser.add_argument(
+             "--apply-chat-template",
+             action="store_true",
+             help="Apply chat template",
+         )
+         parser.add_argument(
+             "--profile",
+             action="store_true",
+             help="Use Torch Profiler. The endpoint must be launched with "
+             "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
+         )
+         parser.add_argument(
+             "--skip-warmup",
+             action="store_true",
+             help="Skip the warmup batches.",
+         )
+         parser.add_argument(
+             "--do-not-exit",
+             action="store_true",
+             help="Do not exit the program. This is useful for nsys profile with --duration and --delay.",
+         )
+         parser.add_argument(
+             "--prompt-suffix",
+             type=str,
+             default="",
+             help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
+         )
+         parser.add_argument(
+             "--return-logprob",
+             action="store_true",
+             help="Enable returning log probabilities.",
+         )
+         parser.add_argument(
+             "--logprob-start-len",
+             type=int,
+             default=-1,
+             help="Start length for logprob. -1 means only return logprobs for output tokens (default). 0 means return logprobs for all tokens including input.",
+         )
+
+     @classmethod
+     def from_cli_args(cls, args: argparse.Namespace):
+         attrs = [attr.name for attr in dataclasses.fields(cls)]
+         return cls(**{attr: getattr(args, attr) for attr in attrs})
+
+
+ def throughput_test_once(
+     backend_name: str,
+     backend,
+     reqs: List[DatasetRow],
+     ignore_eos: bool,
+     extra_request_body: Dict,
+     profile: bool,
+     return_logprob: bool = False,
+     logprob_start_len: int = -1,
+ ):
+     measurement_results = {
+         "backend": backend_name,
+         "successful_requests": len(reqs),
+         "total_latency": -1,
+         "total_input_tokens": sum(r.prompt_len for r in reqs),
+         "total_output_tokens": -1,
+         "request_throughput": -1,
+         "input_throughput": -1,
+         "output_throughput": -1,
+         "total_throughput": -1,
+     }
+
+     prompt = [r.prompt for r in reqs]
+     sampling_params = [
+         {
+             "temperature": 0,
+             "max_new_tokens": r.output_len,
+             "ignore_eos": ignore_eos,
+             **extra_request_body,
+         }
+         for r in reqs
+     ]
+
+     if profile:
+         assert (
+             "SGLANG_TORCH_PROFILER_DIR" in os.environ
+         ), "Please set SGLANG_TORCH_PROFILER_DIR."
+         os.makedirs(os.environ["SGLANG_TORCH_PROFILER_DIR"], exist_ok=True)
+         backend.start_profile()
+
+     st = time.perf_counter()
+     gen_out = backend.generate(
+         prompt=prompt,
+         sampling_params=sampling_params,
+         return_logprob=return_logprob,
+         logprob_start_len=logprob_start_len,
+     )
+     latency = time.perf_counter() - st
+
+     if profile:
+         dir = os.getenv("SGLANG_TORCH_PROFILER_DIR")
+         known_files = set(os.listdir(dir))
+         backend.stop_profile()
+         monitor_trace_file(known_files, dir)
+
+     if backend_name == "runtime":
+         gen_out = json.loads(gen_out)
+
+     server_info = backend.get_server_info()
+
+     measurement_results["total_latency"] = latency
+     measurement_results["total_output_tokens"] = sum(
+         o["meta_info"]["completion_tokens"] for o in gen_out
+     )
+     measurement_results["request_throughput"] = (
+         measurement_results["successful_requests"] / latency
+     )
+     measurement_results["input_throughput"] = (
+         measurement_results["total_input_tokens"] / latency
+     )
+     measurement_results["output_throughput"] = (
+         measurement_results["total_output_tokens"] / latency
+     )
+     measurement_results["total_throughput"] = (
+         measurement_results["total_input_tokens"]
+         + measurement_results["total_output_tokens"]
+     ) / latency
+
+     if inspect.isawaitable(server_info):
+         server_info = asyncio.run(server_info)
+
+     measurement_results["last_gen_throughput"] = server_info["internal_states"][0][
+         "last_gen_throughput"
+     ]
+
+     return measurement_results
+
+
294
+ def monitor_trace_file(known_files, directory, interval=1):
295
+ print(f"Monitoring {directory} for new trace files...")
296
+
297
+ while True:
298
+ flag = False
299
+ time.sleep(interval)
300
+ current_files = set(os.listdir(directory))
301
+
302
+ new_files = current_files - known_files
303
+ for new_file in new_files:
304
+ new_file_path = os.path.join(directory, new_file)
305
+ print(f"New file detected: {new_file}")
306
+
307
+ previous_size = 0
308
+ while True:
309
+ try:
310
+ current_size = os.path.getsize(new_file_path)
311
+ except FileNotFoundError:
312
+ print(f"File {new_file} is no longer accessible.")
313
+ break
314
+
315
+ if current_size > previous_size:
316
+ previous_size = current_size
317
+ else:
318
+ flag = True
319
+ break
320
+
321
+ time.sleep(interval)
322
+ if flag:
323
+ break
324
+
325
+
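The polling idea behind `monitor_trace_file` (treat a file as finished once its size stops growing between polls) can be sketched in isolation. This is a minimal illustration, not the code from the commit: `wait_until_size_stable`, `max_checks`, and the timeout behaviour are hypothetical additions, and unlike the original loop this version gives up after a bounded number of polls.

```python
import os
import tempfile
import time


def wait_until_size_stable(path, interval=0.05, max_checks=100):
    """Poll `path` until two consecutive polls report the same size.

    Hypothetical helper for illustration only. Returns the final size in
    bytes, or raises TimeoutError if the file is still growing after
    `max_checks` polls.
    """
    previous_size = -1
    for _ in range(max_checks):
        current_size = os.path.getsize(path)
        if current_size == previous_size:
            # No growth since the last poll: assume the writer is done.
            return current_size
        previous_size = current_size
        time.sleep(interval)
    raise TimeoutError(f"{path} was still growing after {max_checks} polls")


# Usage: a file that is already fully written is reported stable on the second poll.
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(b"x" * 10)
    demo_path = tmp.name
final_size = wait_until_size_stable(demo_path, interval=0.01)
os.remove(demo_path)
```

A size check is only a heuristic: a writer that pauses for longer than `interval` would be reported as finished too early, so the interval should be chosen with the writer's behaviour in mind.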
326
+ def _create_ray_engine_backend(server_args: ServerArgs):
327
+ """Create a RayEngine inside a Ray actor on a placement group.
328
+
329
+ RayEngine requires a placement group, so we launch it inside a Ray actor
330
+ and return a lightweight proxy that forwards calls via ray.get().
331
+ """
332
+ import ray
333
+ from ray.runtime_env import RuntimeEnv
334
+ from ray.util.placement_group import placement_group
335
+ from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
336
+
337
+ env_vars = {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"}
338
+ if os.environ.get("HF_TOKEN"):
339
+ env_vars["HF_TOKEN"] = os.environ["HF_TOKEN"]
340
+ if not ray.is_initialized():
341
+ ray.init(runtime_env=RuntimeEnv(env_vars=env_vars))
342
+
343
+ total_gpus = server_args.tp_size * server_args.pp_size
344
+ pg = placement_group([{"CPU": 1, "GPU": total_gpus}], strategy="STRICT_PACK")
345
+ ray.get(pg.ready())
346
+
347
+ @ray.remote
348
+ class _EngineActor:
349
+ def __init__(self, **kwargs):
350
+ from sglang.srt.ray.engine import RayEngine
351
+
352
+ self.engine = RayEngine(**kwargs)
353
+
354
+ def call(self, method, **kwargs):
355
+ return getattr(self.engine, method)(**kwargs)
356
+
357
+ actor = _EngineActor.options(
358
+ num_cpus=1,
359
+ num_gpus=0,
360
+ scheduling_strategy=PlacementGroupSchedulingStrategy(
361
+ placement_group=pg,
362
+ placement_group_bundle_index=0,
363
+ ),
364
+ ).remote(**dataclasses.asdict(server_args))
365
+
366
+ class _Proxy:
367
+ """Forwards method calls to the remote RayEngine actor."""
368
+
369
+ def generate(self, **kwargs):
370
+ return ray.get(actor.call.remote("generate", **kwargs))
371
+
372
+ def get_server_info(self, **kwargs):
373
+ return ray.get(actor.call.remote("get_server_info", **kwargs))
374
+
375
+ def start_profile(self, **kwargs):
376
+ return ray.get(actor.call.remote("start_profile", **kwargs))
377
+
378
+ def stop_profile(self, **kwargs):
379
+ return ray.get(actor.call.remote("stop_profile", **kwargs))
380
+
381
+ def shutdown(self):
382
+ try:
383
+ ray.get(actor.call.remote("shutdown"), timeout=60)
384
+ except Exception:
385
+ pass
386
+ try:
387
+ ray.util.remove_placement_group(pg)
388
+ except Exception:
389
+ pass
390
+
391
+ return _Proxy()
392
+
393
+
394
+ def throughput_test(
395
+ server_args: ServerArgs,
396
+ bench_args: BenchArgs,
397
+ ):
398
+ if bench_args.backend == "engine":
399
+ if server_args.use_ray:
400
+ backend = _create_ray_engine_backend(server_args)
401
+ else:
402
+ backend = Engine(**dataclasses.asdict(server_args))
403
+ if not backend:
404
+ raise ValueError("Please provide valid engine arguments")
405
+ elif bench_args.backend == "runtime":
406
+ backend = Runtime(**dataclasses.asdict(server_args))
407
+ else:
408
+ raise ValueError('Please set backend to either "engine" or "runtime"')
409
+
410
+ tokenizer_id = server_args.tokenizer_path or server_args.model_path
411
+ tokenizer = get_tokenizer(tokenizer_id)
412
+
413
+ # Set global environments
414
+ set_ulimit()
415
+ random.seed(bench_args.seed)
416
+ np.random.seed(bench_args.seed)
417
+
418
+ # Parse args
419
+ extra_request_body = {}
420
+ if bench_args.extra_request_body:
421
+ extra_request_body = json.loads(bench_args.extra_request_body)
422
+
423
+ # Read dataset
424
+ input_requests = get_dataset(bench_args, tokenizer)
425
+
426
+ warmup_requests = sample_random_requests(
427
+ input_len=256,
428
+ output_len=16,
429
+ num_prompts=min(bench_args.num_prompts, 16),
430
+ range_ratio=1.0,
431
+ tokenizer=tokenizer,
432
+ dataset_path=bench_args.dataset_path,
433
+ )
434
+
435
+ # Warm up
436
+ if not bench_args.skip_warmup:
437
+ logging.info("\nWarmup...")
438
+ throughput_test_once(
439
+ backend_name=bench_args.backend,
440
+ backend=backend,
441
+ reqs=warmup_requests,
442
+ ignore_eos=not bench_args.disable_ignore_eos,
443
+ extra_request_body=extra_request_body,
444
+ profile=False,
445
+ return_logprob=bench_args.return_logprob,
446
+ logprob_start_len=bench_args.logprob_start_len,
447
+ )
448
+ time.sleep(0.5)
449
+
450
+ logging.info("\nBenchmark...")
451
+ result = throughput_test_once(
452
+ backend_name=bench_args.backend,
453
+ backend=backend,
454
+ reqs=input_requests,
455
+ ignore_eos=not bench_args.disable_ignore_eos,
456
+ extra_request_body=extra_request_body,
457
+ profile=bench_args.profile,
458
+ return_logprob=bench_args.return_logprob,
459
+ logprob_start_len=bench_args.logprob_start_len,
460
+ )
461
+ backend.shutdown()
462
+
463
+ if bench_args.result_filename:
464
+ with open(bench_args.result_filename, "a") as fout:
465
+ fout.write(json.dumps(result) + "\n")
466
+
467
+ print(
468
+ "\n{s:{c}^{n}}".format(s=" Offline Throughput Benchmark Result ", n=50, c="=")
469
+ )
470
+ print("{:<40} {:<10}".format("Backend:", result["backend"]))
471
+ print("{:<40} {:<10}".format("Successful requests:", result["successful_requests"]))
472
+ print("{:<40} {:<10.2f}".format("Benchmark duration (s):", result["total_latency"]))
473
+ print("{:<40} {:<10}".format("Total input tokens:", result["total_input_tokens"]))
474
+ print(
475
+ "{:<40} {:<10}".format("Total generated tokens:", result["total_output_tokens"])
476
+ )
477
+ print(
478
+ "{:<40} {:<10.2f}".format(
479
+ "Last generation throughput (tok/s):", result["last_gen_throughput"]
480
+ )
481
+ )
482
+ print(
483
+ "{:<40} {:<10.2f}".format(
484
+ "Request throughput (req/s):", result["request_throughput"]
485
+ )
486
+ )
487
+ print(
488
+ "{:<40} {:<10.2f}".format(
489
+ "Input token throughput (tok/s):", result["input_throughput"]
490
+ )
491
+ )
492
+ print(
493
+ "{:<40} {:<10.2f}".format(
494
+ "Output token throughput (tok/s):", result["output_throughput"]
495
+ )
496
+ )
497
+ print(
498
+ "{:<40} {:<10.2f}".format(
499
+ "Total token throughput (tok/s):", result["total_throughput"]
500
+ )
501
+ )
502
+ print("=" * 50)
503
+
504
+ return result
505
+
506
+
507
+ if __name__ == "__main__":
508
+ parser = argparse.ArgumentParser()
509
+ ServerArgs.add_cli_args(parser)
510
+ BenchArgs.add_cli_args(parser)
511
+ args = parser.parse_args()
512
+
513
+ # handling ModelScope model downloads
514
+ if os.getenv("SGLANG_USE_MODELSCOPE", "false").lower() in ("true", "1"):
515
+ if os.path.exists(args.model_path):
516
+ print(f"Using local model path: {args.model_path}")
517
+ else:
518
+ try:
519
+ from modelscope import snapshot_download
520
+
521
+ print(f"Using ModelScope to download model: {args.model_path}")
522
+
523
+ # download the model and replace args.model_path
524
+ args.model_path = snapshot_download(
525
+ args.model_path,
526
+ )
527
+ print(f"Model downloaded to: {args.model_path}")
528
+ except Exception as e:
529
+ print(f"ModelScope download failed: {str(e)}")
530
+ raise e
531
+
532
+ server_args = ServerArgs.from_cli_args(args)
533
+ bench_args = BenchArgs.from_cli_args(args)
534
+
535
+ logging.basicConfig(
536
+ level=getattr(logging, server_args.log_level.upper()),
537
+ format="%(message)s",
538
+ )
539
+
540
+ throughput_test(server_args, bench_args)
541
+
542
+ while bench_args.do_not_exit:
543
+ time.sleep(1)
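The four throughput figures reported by `throughput_test_once` are all simple ratios over the same wall-clock latency. The sketch below restates that arithmetic as a standalone function; `summarize_throughput` and its parameter names are hypothetical and not part of the commit, and the sample numbers are made up for illustration.

```python
def summarize_throughput(num_requests, input_tokens, output_tokens, latency_s):
    """Derive request and token throughput from one timed batch.

    Hypothetical helper: mirrors the ratios computed in
    throughput_test_once, with an explicit guard on the latency.
    """
    if latency_s <= 0:
        raise ValueError("latency_s must be positive")
    return {
        "request_throughput": num_requests / latency_s,  # req/s
        "input_throughput": input_tokens / latency_s,  # tok/s
        "output_throughput": output_tokens / latency_s,  # tok/s
        "total_throughput": (input_tokens + output_tokens) / latency_s,  # tok/s
    }


# Usage: 16 requests, 4096 prompt tokens, 256 generated tokens in 2 seconds.
summary = summarize_throughput(16, 4096, 256, 2.0)
for name, value in summary.items():
    print(f"{name:<24}{value:<10.2f}")
```

Because every figure shares one denominator, the total token throughput is always the sum of the input and output token throughputs.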
sglang/python/sglang/bench_one_batch.py ADDED
@@ -0,0 +1,837 @@
1
+ """
2
+ Benchmark the latency of running a single static batch without a server.
3
+
4
+ This script does not launch a server and uses the low-level APIs.
5
+ It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths).
6
+
7
+ # Usage (latency test)
8
+ ## with dummy weights:
9
+ python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
10
+ ## sweep through multiple data points and store (append) the results in a jsonl file:
11
+ python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --run-name test_run
12
+ ## run with profiling:
13
+ python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --profile
14
+ ## run with profiling to custom directory:
15
+ export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log
16
+ python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 --input-len 256 --profile
17
+ ## run with CUDA profiler (nsys):
18
+ nsys profile --force-overwrite=true -o bench_one_batch python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 --input-len 256 --profile --profile-activities CUDA_PROFILER
19
+ # Usage (correctness test):
20
+ python -m sglang.bench_one_batch --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
21
+
22
+ ## Reference output (of the correctness test above; can be GPU-dependent):
23
+ input_ids=[[1, 450, 7483, 310, 3444, 338], [1, 450, 7483, 310, 278, 3303, 13187, 290, 338], [1, 20628, 338, 263, 6575, 1460, 2462, 322, 306, 763]]
24
+
25
+ prefill logits (first half): tensor([[-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633],
26
+ [-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633],
27
+ [ -9.1875, -10.2500, 2.7129, ..., -4.3359, -4.0664, -4.1328]],
28
+ device='cuda:0')
29
+
30
+ prefill logits (final): tensor([[-8.3125, -7.1172, 3.3457, ..., -4.9570, -4.1328, -3.4141],
31
+ [-8.9141, -9.0156, 4.1445, ..., -4.9922, -4.4961, -4.0781],
32
+ [-9.6328, -9.0547, 4.0195, ..., -5.3047, -4.7148, -4.4570]],
33
+ device='cuda:0')
34
+
35
+ ========== Prompt 0 ==========
36
+ <s> The capital of France is Paris.
37
+ The capital of the United States is Washington, D.C.
38
+
39
+
40
+ ========== Prompt 1 ==========
41
+ <s> The capital of the United Kindom is London.
42
+ The capital of the United Kingdom is London.
43
+ The capital of the
44
+
45
+ ========== Prompt 2 ==========
46
+ <s> Today is a sunny day and I like to go for a walk in the park.
47
+ I'm going to the park
48
+ """
49
+
50
+ import argparse
51
+ import copy
52
+ import dataclasses
53
+ import itertools
54
+ import json
55
+ import logging
56
+ import multiprocessing
57
+ import os
58
+ import time
59
+ from types import SimpleNamespace
60
+ from typing import Optional, Tuple
61
+
62
+ import numpy as np
63
+ import torch
64
+ import torch.distributed as dist
65
+
66
+ from sglang.srt.configs.model_config import ModelConfig
67
+ from sglang.srt.distributed.parallel_state import destroy_distributed_environment
68
+ from sglang.srt.entrypoints.engine import _set_envs_and_config
69
+ from sglang.srt.layers.moe import initialize_moe_config
70
+ from sglang.srt.layers.quantization.fp4_utils import initialize_fp4_gemm_config
71
+ from sglang.srt.layers.quantization.fp8_utils import initialize_fp8_gemm_config
72
+ from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
73
+ from sglang.srt.managers.scheduler_dp_attn_mixin import prepare_mlp_sync_batch_raw
74
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
75
+ from sglang.srt.model_executor.model_runner import ModelRunner
76
+ from sglang.srt.sampling.sampling_params import SamplingParams
77
+ from sglang.srt.server_args import PortArgs, ServerArgs
78
+ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
79
+ from sglang.srt.utils import (
80
+ configure_logger,
81
+ get_bool_env_var,
82
+ kill_process_tree,
83
+ maybe_reindex_device_id,
84
+ require_mlp_sync,
85
+ require_mlp_tp_gather,
86
+ set_gpu_proc_affinity,
87
+ suppress_other_loggers,
88
+ )
89
+ from sglang.srt.utils.hf_transformers_utils import get_tokenizer
90
+
91
+
92
+ def start_profile(profile_activities, profile_record_shapes=False, rank_print=print):
93
+ """
94
+ Abstracted function to start profiling based on profile_activities.
95
+ Returns profiler object (or None).
96
+ """
97
+ if "CUDA_PROFILER" in profile_activities:
98
+ try:
99
+ torch.cuda.cudart().cudaProfilerStart()
100
+ rank_print("CUDA Profiler started (nsys will begin capturing)")
101
+ except Exception as e:
102
+ rank_print(f"Failed to start CUDA profiler: {e}")
103
+ return None
104
+ else:
105
+ activities = []
106
+ if "CPU" in profile_activities:
107
+ activities.append(torch.profiler.ProfilerActivity.CPU)
108
+ if "GPU" in profile_activities:
109
+ activities.append(torch.profiler.ProfilerActivity.CUDA)
110
+ if "XPU" in profile_activities:
111
+ activities.append(torch.profiler.ProfilerActivity.XPU)
112
+ if activities:
113
+ profiler = torch.profiler.profile(
114
+ activities=activities,
115
+ with_stack=True,
116
+ record_shapes=profile_record_shapes,
117
+ )
118
+ profiler.start()
119
+ return profiler
120
+ return None
121
+
122
+
123
+ def stop_profile(
124
+ profiler,
125
+ profile_activities,
126
+ rank_print=print,
127
+ save_trace=False,
128
+ trace_filename=None,
129
+ stage=None,
130
+ ):
131
+ """
132
+ Abstracted function to stop profiling based on profile_activities.
133
+ Optionally saves trace results and prints completion messages.
134
+ """
135
+ if "CUDA_PROFILER" in profile_activities:
136
+ try:
137
+ torch.cuda.cudart().cudaProfilerStop()
138
+ rank_print("CUDA Profiler stopped (nsys should dump traces)")
139
+ except Exception as e:
140
+ rank_print(f"Failed to stop CUDA profiler: {e}")
141
+ elif profiler is not None:
142
+ profiler.stop()
143
+
144
+ if save_trace:
145
+ if profiler is not None:
146
+ if trace_filename:
147
+ _save_profile_trace_results(profiler, trace_filename)
148
+ stage_desc = f"for {stage}" if stage else ""
149
+ rank_print(
150
+ f"torch profiler chrome trace {stage_desc} saved to {trace_filename}"
151
+ )
152
+ if "CUDA_PROFILER" in profile_activities:
153
+ rank_print(f"CUDA profiler trace for {stage} completed")
154
+
155
+
156
+ @dataclasses.dataclass
157
+ class BenchArgs:
158
+ run_name: str = "default"
159
+ batch_size: Tuple[int] = (1,)
160
+ input_len: Tuple[int] = (1024,)
161
+ output_len: Tuple[int] = (16,)
162
+ prompt_filename: str = ""
163
+ result_filename: str = "result.jsonl"
164
+ correctness_test: bool = False
165
+ # This is only used for correctness test
166
+ cut_len: int = 4
167
+ log_decode_step: int = 0
168
+ profile: bool = False
169
+ profile_record_shapes: bool = False
170
+ profile_activities: Tuple[str] = ("CPU", "GPU")
171
+ profile_stage: str = "all"
172
+ profile_filename_prefix: str = "profile"
173
+ profile_start_step: Optional[int] = None
174
+ profile_steps: Optional[int] = None
175
+
176
+ @staticmethod
177
+ def add_cli_args(parser: argparse.ArgumentParser):
178
+ parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
179
+ parser.add_argument(
180
+ "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
181
+ )
182
+ parser.add_argument(
183
+ "--input-len", type=int, nargs="+", default=BenchArgs.input_len
184
+ )
185
+ parser.add_argument(
186
+ "--output-len", type=int, nargs="+", default=BenchArgs.output_len
187
+ )
188
+ parser.add_argument(
189
+ "--prompt-filename", type=str, default=BenchArgs.prompt_filename
190
+ )
191
+ parser.add_argument(
192
+ "--result-filename", type=str, default=BenchArgs.result_filename
193
+ )
194
+ parser.add_argument("--correctness-test", action="store_true")
195
+ parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
196
+ parser.add_argument(
197
+ "--log-decode-step",
198
+ type=int,
199
+ default=BenchArgs.log_decode_step,
200
+ help="Log decode latency by step, default is set to zero to disable.",
201
+ )
202
+ parser.add_argument("--profile", action="store_true", help="Enable profiling.")
203
+ parser.add_argument(
204
+ "--profile-record-shapes",
205
+ action="store_true",
206
+ help="Record tensor shapes in profiling results.",
207
+ )
208
+ parser.add_argument(
209
+ "--profile-activities",
210
+ type=str,
211
+ nargs="+",
212
+ default=["CPU", "GPU"],
213
+ choices=["CPU", "GPU", "CUDA_PROFILER", "XPU"],
214
+ help="Profiler activities: CPU, GPU, XPU, CUDA_PROFILER. If CPU/GPU/XPU, use torch profiler. If CUDA_PROFILER, use CUDA profiler.",
215
+ )
216
+ parser.add_argument(
217
+ "--profile-stage",
218
+ type=str,
219
+ default=BenchArgs.profile_stage,
220
+ choices=["all", "prefill", "decode"],
221
+ help="Which stage to profile: all, prefill, or decode only.",
222
+ )
223
+ parser.add_argument(
224
+ "--profile-filename-prefix",
225
+ type=str,
226
+ default=BenchArgs.profile_filename_prefix,
227
+ help="Prefix of the profiling file names. The full profiling result file(s) will be "
228
+ '"[profile_filename_prefix]_batch[batch_size]_input[input_len]_output[output_len]_[stage].trace.json.gz"',
229
+ )
230
+ parser.add_argument(
231
+ "--profile-start-step",
232
+ type=int,
233
+ default=None,
234
+ help="Decode step at which to start profiling (0-indexed). If not specified, defaults to output_len // 2.",
235
+ )
236
+ parser.add_argument(
237
+ "--profile-steps",
238
+ type=int,
239
+ default=None,
240
+ help="Number of decode steps to profile starting from profile-start-step. If not specified, profiles only one step.",
241
+ )
242
+
243
+ @classmethod
244
+ def from_cli_args(cls, args: argparse.Namespace):
245
+ # use the default value's type to cast the args into correct types.
246
+ attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
247
+ result = {}
248
+ for attr, attr_type in attrs:
249
+ value = getattr(args, attr)
250
+ # Handle None values - don't try to cast them
251
+ if value is None or attr_type == type(None):
252
+ result[attr] = value
253
+ else:
254
+ result[attr] = attr_type(value)
255
+ return cls(**result)
256
+
257
+
258
+ def load_model(server_args, port_args, gpu_id, tp_rank):
259
+ suppress_other_loggers()
260
+ rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
261
+ moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
262
+
263
+ model_config = ModelConfig.from_server_args(server_args)
264
+ model_runner = ModelRunner(
265
+ model_config=model_config,
266
+ mem_fraction_static=server_args.mem_fraction_static,
267
+ gpu_id=gpu_id,
268
+ tp_rank=tp_rank,
269
+ tp_size=server_args.tp_size,
270
+ moe_ep_rank=moe_ep_rank,
271
+ moe_ep_size=server_args.ep_size,
272
+ pp_rank=0,
273
+ pp_size=1,
274
+ nccl_port=port_args.nccl_port,
275
+ server_args=server_args,
276
+ )
277
+ rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
278
+ tokenizer = get_tokenizer(
279
+ server_args.tokenizer_path,
280
+ tokenizer_mode=server_args.tokenizer_mode,
281
+ trust_remote_code=server_args.trust_remote_code,
282
+ )
283
+ if server_args.tp_size > 1:
284
+ dist.barrier()
285
+ return model_runner, tokenizer
286
+
287
+
288
+ def prepare_inputs_for_correctness_test(bench_args, tokenizer, custom_prompts):
289
+ prompts = (
290
+ custom_prompts
291
+ if custom_prompts
292
+ else [
293
+ "The capital of France is",
294
+ "The capital of the United Kindom is",
295
+ "Today is a sunny day and I like",
296
+ ]
297
+ )
298
+ input_ids = [tokenizer.encode(p) for p in prompts]
299
+ sampling_params = SamplingParams(
300
+ temperature=0,
301
+ max_new_tokens=bench_args.output_len[0],
302
+ )
303
+
304
+ reqs = []
305
+ for i in range(len(prompts)):
306
+ assert len(input_ids[i]) > bench_args.cut_len
307
+
308
+ tmp_input_ids = input_ids[i][: bench_args.cut_len]
309
+ req = Req(
310
+ rid=i,
311
+ origin_input_text=prompts[i],
312
+ origin_input_ids=tmp_input_ids,
313
+ sampling_params=sampling_params,
314
+ )
315
+ req.fill_ids = req.origin_input_ids
316
+ req.logprob_start_len = -1
317
+ req.set_extend_input_len(len(req.fill_ids) - len(req.prefix_indices))
318
+ reqs.append(req)
319
+
320
+ return input_ids, reqs
321
+
322
+
323
+ def prepare_extend_inputs_for_correctness_test(
324
+ bench_args, input_ids, reqs, model_runner
325
+ ):
326
+ for i in range(len(reqs)):
327
+ req: Req = reqs[i]
328
+ req.fill_ids += input_ids[i][bench_args.cut_len :]
329
+ req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
330
+ i, : bench_args.cut_len
331
+ ].to(req.prefix_indices.dtype)
332
+ req.logprob_start_len = -1
333
+ req.set_extend_input_len(len(req.fill_ids) - len(req.prefix_indices))
334
+ return reqs
335
+
336
+
337
+ def prepare_synthetic_inputs_for_latency_test(
338
+ batch_size, input_len, custom_inputs=None
339
+ ):
340
+ input_ids = (
341
+ custom_inputs
342
+ if custom_inputs
343
+ else np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32)
344
+ )
345
+ sampling_params = SamplingParams(
346
+ temperature=0,
347
+ max_new_tokens=BenchArgs.output_len,
348
+ )
349
+
350
+ reqs = []
351
+ for i in range(len(input_ids)):
352
+ req = Req(
353
+ rid=i,
354
+ origin_input_text="",
355
+ origin_input_ids=list(input_ids[i]),
356
+ sampling_params=sampling_params,
357
+ )
358
+ req.fill_ids = req.origin_input_ids
359
+ req.logprob_start_len = -1
360
+ req.set_extend_input_len(len(req.fill_ids) - len(req.prefix_indices))
361
+ reqs.append(req)
362
+
363
+ return reqs
364
+
365
+
366
+ class TreeCacheNamespace(SimpleNamespace):
367
+ def supports_swa(self) -> bool:
368
+ return False
369
+
370
+ def supports_mamba(self) -> bool:
371
+ return False
372
+
373
+ def is_chunk_cache(self) -> bool:
374
+ return False
375
+
376
+ def is_tree_cache(self) -> bool:
377
+ return not self.is_chunk_cache()
378
+
379
+
380
+ @torch.no_grad
381
+ def extend(reqs, model_runner):
382
+ # Create dummy tree_cache for benchmarks (no prefix caching, just allocation)
383
+ dummy_tree_cache = TreeCacheNamespace(
384
+ page_size=model_runner.server_args.page_size,
385
+ device=model_runner.device,
386
+ token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
387
+ )
388
+
389
+ batch = ScheduleBatch.init_new(
390
+ reqs=reqs,
391
+ req_to_token_pool=model_runner.req_to_token_pool,
392
+ token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
393
+ tree_cache=dummy_tree_cache,
394
+ model_config=model_runner.model_config,
395
+ enable_overlap=False,
396
+ spec_algorithm=SpeculativeAlgorithm.NONE,
397
+ )
398
+ batch.prepare_for_extend()
399
+ _maybe_prepare_mlp_sync_batch(batch, model_runner)
400
+ model_worker_batch = batch.get_model_worker_batch()
401
+ forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
402
+ logits_output = model_runner.forward(forward_batch).logits_output
403
+ next_token_ids = model_runner.sample(logits_output, forward_batch)
404
+ return next_token_ids, logits_output.next_token_logits, batch
405
+
406
+
407
+ @torch.no_grad
408
+ def decode(input_token_ids, batch, model_runner):
409
+ batch.output_ids = input_token_ids
410
+ batch.prepare_for_decode()
411
+ _maybe_prepare_mlp_sync_batch(batch, model_runner)
412
+ model_worker_batch = batch.get_model_worker_batch()
413
+ forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
414
+ logits_output = model_runner.forward(forward_batch).logits_output
415
+ next_token_ids = model_runner.sample(logits_output, forward_batch)
416
+ return next_token_ids, logits_output.next_token_logits
417
+
418
+
419
+ def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
420
+ if require_mlp_sync(model_runner.server_args):
421
+ prepare_mlp_sync_batch_raw(
422
+ batch,
423
+ dp_size=model_runner.server_args.dp_size,
424
+ attn_tp_size=1,
425
+ tp_group=model_runner.tp_group,
426
+ get_idle_batch=None,
427
+ disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
428
+ require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
429
+ disable_overlap_schedule=model_runner.server_args.disable_overlap_schedule,
430
+ offload_tags=set(),
431
+ )
432
+
433
+
434
+ def _read_prompts_from_file(prompt_file, rank_print):
435
+ """Read custom prompts from the file specified by `--prompt-filename`."""
436
+ if not prompt_file:
437
+ return []
438
+ if not os.path.exists(prompt_file):
439
+ rank_print(
440
+ f"Custom prompt file {prompt_file} not found. Using default inputs..."
441
+ )
442
+ return []
443
+ with open(prompt_file, "r") as pf:
444
+ return pf.readlines()
445
+
446
+
447
+ def _get_torch_profiler_output_dir():
448
+ return os.environ.get("SGLANG_TORCH_PROFILER_DIR", "/tmp")
449
+
450
+
451
+ def _create_torch_profiler_filename(
452
+ profile_filename_prefix, batch_size, input_len, output_len, stage
453
+ ):
454
+ output_dir = _get_torch_profiler_output_dir()
455
+ filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_{stage}.trace.json.gz"
456
+ return os.path.join(output_dir, filename)
457
+
458
+
459
+ def _save_profile_trace_results(profiler, filename):
460
+ parent_dir = os.path.dirname(os.path.abspath(filename))
461
+ os.makedirs(parent_dir, exist_ok=True)
462
+ profiler.export_chrome_trace(filename)
463
+ print(
464
+ profiler.key_averages(group_by_input_shape=True).table(
465
+ sort_by="self_cpu_time_total"
466
+ )
467
+ )
468
+
469
+
470
+ def correctness_test(
471
+ server_args,
472
+ port_args,
473
+ bench_args,
474
+ gpu_id,
475
+ tp_rank,
476
+ ):
477
+ # Configure the logger
478
+ configure_logger(server_args, prefix=f" TP{tp_rank}")
479
+ rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
480
+
481
+ # Load the model
482
+ model_runner, tokenizer = load_model(server_args, port_args, gpu_id, tp_rank)
483
+
484
+ # Prepare inputs
485
+ custom_prompts = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
486
+ input_ids, reqs = prepare_inputs_for_correctness_test(
487
+ bench_args, tokenizer, custom_prompts
488
+ )
489
+ rank_print(f"\n{input_ids=}\n")
490
+
491
+ if bench_args.cut_len > 0:
492
+ # Prefill
493
+ next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
494
+ rank_print(f"prefill logits (first half): {next_token_logits} \n")
495
+
496
+ # Prepare extend inputs
497
+ reqs = prepare_extend_inputs_for_correctness_test(
498
+ bench_args, input_ids, reqs, model_runner
499
+ )
500
+
501
+ # Extend (prefill w/ KV cache)
502
+ next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
503
+ rank_print(f"prefill logits (final): {next_token_logits} \n")
504
+
505
+ # Decode
506
+ output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
507
+ for _ in range(bench_args.output_len[0] - 1):
508
+ next_token_ids, _ = decode(next_token_ids, batch, model_runner)
509
+ next_token_ids_list = next_token_ids.tolist()
510
+ for i in range(len(reqs)):
511
+ output_ids[i].append(next_token_ids_list[i])
512
+
513
+ # Print output texts
514
+ for i in range(len(reqs)):
515
+ rank_print(f"========== Prompt {i} ==========")
516
+ rank_print(tokenizer.decode(output_ids[i]), "\n")
517
+
518
+
519
+ def synchronize(device):
520
+ torch.get_device_module(device).synchronize()
521
+
522
+
523
+ def latency_test_run_once(
524
+ run_name,
525
+ model_runner,
526
+ rank_print,
527
+ reqs,
528
+ batch_size,
529
+ input_len,
530
+ output_len,
531
+ device,
532
+ log_decode_step,
533
+ profile,
534
+ profile_record_shapes,
535
+ profile_activities,
536
+ profile_filename_prefix,
537
+ profile_stage,
538
+ tp_rank,
539
+ profile_start_step=None,
540
+ profile_steps=None,
541
+ ):
542
+ max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
543
+ if batch_size > max_batch_size:
544
+ rank_print(
545
+ f"skipping ({batch_size}, {input_len}, {output_len}) due to max batch size limit"
546
+ )
547
+ return
548
+
549
+ model_runner.req_to_token_pool.clear()
550
+ model_runner.token_to_kv_pool_allocator.clear()
551
+
552
+ measurement_results = {
553
+ "run_name": run_name,
554
+ "batch_size": batch_size,
555
+ "input_len": input_len,
556
+ "output_len": output_len,
557
+ }
558
+
559
+ tot_latency = 0
560
+
561
+ profiler = None
562
+ enable_profile_prefill = profile and profile_stage in ["all", "prefill"]
563
+ if enable_profile_prefill:
564
+ profiler = start_profile(
565
+ profile_activities,
566
+ profile_record_shapes=profile_record_shapes,
567
+ rank_print=rank_print,
568
+ )
569
+
570
+ synchronize(device)
571
+ tic = time.perf_counter()
572
+ next_token_ids, _, batch = extend(reqs, model_runner)
573
+ synchronize(device)
574
+ prefill_latency = time.perf_counter() - tic
575
+
576
+ if enable_profile_prefill:
577
+ trace_filename = _create_torch_profiler_filename(
578
+ profile_filename_prefix, batch_size, input_len, output_len, "prefill"
579
+ )
580
+ stop_profile(
581
+ profiler,
582
+ profile_activities,
583
+ rank_print=rank_print,
584
+ save_trace=True,
585
+ trace_filename=trace_filename,
586
+ stage="prefill",
587
+ )
588
+
589
+ tot_latency += prefill_latency
590
+ throughput = input_len * batch_size / prefill_latency
591
+ rank_print(
592
+ f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
593
+ )
594
+ measurement_results["prefill_latency"] = prefill_latency
595
+ measurement_results["prefill_throughput"] = throughput
596
+
597
+ decode_latencies = []
598
+ # Determine profiling start step and end step
599
+ profile_start = (
600
+ profile_start_step if profile_start_step is not None else (output_len // 2)
601
+ )
602
+ profile_end = profile_start + (profile_steps if profile_steps is not None else 1)
603
+ enable_profile_decode = profile and profile_stage in ["all", "decode"]
604
+ profiler = None
605
+ for i in range(output_len - 1):
606
+ synchronize(device)
607
+ # Start profiler at the specified step
608
+ if enable_profile_decode and i == profile_start:
609
+ profiler = start_profile(
610
+ profile_activities,
611
+ profile_record_shapes=profile_record_shapes,
612
+ rank_print=rank_print,
613
+ )
614
+
615
+ tic = time.perf_counter()
616
+ next_token_ids, _ = decode(next_token_ids, batch, model_runner)
617
+ synchronize(device)
618
+ latency = time.perf_counter() - tic
619
+
620
+ # Stop profiler after the specified number of steps
621
+ if enable_profile_decode and profiler is not None and i >= profile_end - 1:
622
+ trace_filename = _create_torch_profiler_filename(
623
+ profile_filename_prefix, batch_size, input_len, output_len, "decode"
624
+ )
625
+ stop_profile(
626
+ profiler,
627
+ profile_activities,
628
+ rank_print=rank_print,
629
+ save_trace=True,
630
+ trace_filename=trace_filename,
631
+ stage="decode",
632
+ )
633
+ profiler = None
634
+
635
+ tot_latency += latency
636
+ throughput = batch_size / latency
637
+ decode_latencies.append(latency)
638
+ if i < 5 or (log_decode_step > 0 and i % log_decode_step == 0):
639
+ rank_print(
640
+ f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
641
+ )
642
+
643
+ # Record decode timing from 2nd output
644
+ if output_len > 1:
645
+ med_decode_latency = np.median(decode_latencies)
646
+ med_decode_throughput = batch_size / med_decode_latency
647
+ rank_print(
648
+ f"Decode. median latency: {med_decode_latency:6.5f} s, median throughput: {med_decode_throughput:9.2f} token/s"
649
+ )
650
+ measurement_results["median_decode_latency"] = med_decode_latency
651
+ measurement_results["median_decode_throughput"] = med_decode_throughput
652
+
653
+ throughput = (input_len + output_len) * batch_size / tot_latency
654
+ rank_print(
655
+ f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
656
+ )
657
+ measurement_results["total_latency"] = tot_latency
658
+ measurement_results["overall_throughput"] = throughput
659
+ return measurement_results
660
+
661
+
662
+ def latency_test(
663
+ server_args,
664
+ port_args,
665
+ bench_args,
666
+ gpu_id,
667
+ tp_rank,
668
+ ):
669
+ initialize_moe_config(server_args)
670
+ initialize_fp8_gemm_config(server_args)
671
+ initialize_fp4_gemm_config(server_args)
672
+
673
+ # Set CPU affinity
674
+ if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
675
+ set_gpu_proc_affinity(
676
+ server_args.pp_size, server_args.tp_size, server_args.nnodes, tp_rank
677
+ )
678
+
679
+ # Configure the logger
680
+ configure_logger(server_args, prefix=f" TP{tp_rank}")
681
+ rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
682
+
683
+ # Load the model
684
+ model_runner, tokenizer = load_model(server_args, port_args, gpu_id, tp_rank)
685
+
686
+ # Prepare inputs for warm up
687
+ reqs = prepare_synthetic_inputs_for_latency_test(
688
+ bench_args.batch_size[0], bench_args.input_len[0]
689
+ )
690
+
691
+ # Warm up
692
+ rank_print("Warmup ...")
693
+ latency_test_run_once(
694
+ bench_args.run_name,
695
+ model_runner,
696
+ rank_print,
697
+ reqs,
698
+ bench_args.batch_size[0],
699
+ bench_args.input_len[0],
700
+ min(32, bench_args.output_len[0]), # shorter decoding to speed up the warmup
701
+ server_args.device,
702
+ log_decode_step=0,
703
+ profile=False,
704
+ profile_record_shapes=False,
705
+ profile_activities=("CPU", "GPU"),
706
+ profile_filename_prefix="",
707
+ profile_stage="all",
708
+ tp_rank=tp_rank,
709
+ profile_start_step=None,
710
+ profile_steps=None,
711
+ )
712
+
713
+ rank_print("Benchmark ...")
714
+
715
+ custom_inputs = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
716
+ custom_inputs = [tokenizer.encode(p.strip()) for p in custom_inputs]
717
+ custom_input_len = len(custom_inputs)
718
+
719
+ # Run the sweep
720
+ result_list = []
721
+ for bs, il, ol in itertools.product(
722
+ bench_args.batch_size, bench_args.input_len, bench_args.output_len
723
+ ):
724
+ bs_aligned_inputs = []
725
+ if custom_inputs:
726
+ if custom_input_len == bs:
727
+ bs_aligned_inputs = custom_inputs
728
+ elif custom_input_len > bs:
729
+ rank_print(
730
+ f"Custom input size ({custom_input_len}) is larger than batch_size ({bs}). "
731
+ f"Using the first {bs} prompts."
732
+ )
733
+ bs_aligned_inputs = copy.deepcopy(custom_inputs[:bs])
734
+ else:
735
+ rank_print(
736
+ f"Custom input size ({custom_input_len}) is smaller than batch_size ({bs}). "
737
+ f"Pad to the desired batch_size with the last prompt."
738
+ )
739
+ bs_aligned_inputs = copy.deepcopy(custom_inputs)
740
+ bs_aligned_inputs.extend(
741
+ [bs_aligned_inputs[-1]] * (bs - custom_input_len)
742
+ )
743
+
744
+ reqs = prepare_synthetic_inputs_for_latency_test(bs, il, bs_aligned_inputs)
745
+ ret = latency_test_run_once(
746
+ bench_args.run_name,
747
+ model_runner,
748
+ rank_print,
749
+ reqs,
750
+ bs,
751
+ il,
752
+ ol,
753
+ server_args.device,
754
+ bench_args.log_decode_step,
755
+ bench_args.profile if tp_rank == 0 else None,
756
+ bench_args.profile_record_shapes if tp_rank == 0 else None,
757
+ bench_args.profile_activities,
758
+ bench_args.profile_filename_prefix,
759
+ bench_args.profile_stage,
760
+ tp_rank,
761
+ bench_args.profile_start_step,
762
+ bench_args.profile_steps,
763
+ )
764
+ if ret is not None:
765
+ result_list.append(ret)
766
+
767
+ # Write results in jsonlines format on rank 0.
768
+ if tp_rank == 0 and bench_args.result_filename:
769
+ with open(bench_args.result_filename, "a") as fout:
770
+ for result in result_list:
771
+ fout.write(json.dumps(result) + "\n")
772
+
773
+ if server_args.tp_size > 1:
774
+ destroy_distributed_environment()
775
+
776
+
777
+ def main(server_args, bench_args):
778
+ server_args.cuda_graph_max_bs = max(bench_args.batch_size)
779
+
780
+ _set_envs_and_config(server_args)
781
+
782
+ if server_args.model_path:
783
+ if bench_args.correctness_test:
784
+ work_func = correctness_test
785
+ else:
786
+ work_func = latency_test
787
+ else:
788
+ raise ValueError(
789
+ "Provide --model-path for running the tests or "
790
+ "provide --result-filename for plotting the results"
791
+ )
792
+
793
+ port_args = PortArgs.init_new(server_args)
794
+
795
+ if server_args.tp_size == 1:
796
+ work_func(server_args, port_args, bench_args, 0, 0)
797
+ else:
798
+ workers = []
799
+ for tp_rank in range(server_args.tp_size):
800
+ with maybe_reindex_device_id(tp_rank) as gpu_id:
801
+ proc = multiprocessing.Process(
802
+ target=work_func,
803
+ args=(
804
+ server_args,
805
+ port_args,
806
+ bench_args,
807
+ gpu_id,
808
+ tp_rank,
809
+ ),
810
+ )
811
+ proc.start()
812
+ workers.append(proc)
813
+
814
+ for proc in workers:
815
+ proc.join()
816
+
817
+ proc.terminate()
818
+
819
+
820
+ if __name__ == "__main__":
821
+ parser = argparse.ArgumentParser()
822
+ ServerArgs.add_cli_args(parser)
823
+ BenchArgs.add_cli_args(parser)
824
+ args = parser.parse_args()
825
+ server_args = ServerArgs.from_cli_args(args)
826
+ bench_args = BenchArgs.from_cli_args(args)
827
+
828
+ logging.basicConfig(
829
+ level=getattr(logging, server_args.log_level.upper()),
830
+ format="%(message)s",
831
+ )
832
+
833
+ try:
834
+ main(server_args, bench_args)
835
+ finally:
836
+ if server_args.tp_size != 1:
837
+ kill_process_tree(os.getpid(), include_parent=False)
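The throughput figures printed by `latency_test_run_once` above follow directly from the measured latencies. Below is a minimal standalone sketch of that arithmetic. It assumes a hypothetical helper named `summarize_latencies`, which is not part of sglang, and uses the standard library's `statistics.median` in place of `np.median`.

```python
from statistics import median
from typing import Dict, List


def summarize_latencies(
    batch_size: int,
    input_len: int,
    output_len: int,
    prefill_latency: float,
    decode_latencies: List[float],
) -> Dict[str, float]:
    """Hypothetical helper mirroring the metrics computed in the benchmark loop."""
    results = {
        "prefill_latency": prefill_latency,
        # Prefill processes every prompt token of every request in one pass.
        "prefill_throughput": input_len * batch_size / prefill_latency,
    }
    if decode_latencies:
        med = median(decode_latencies)
        results["median_decode_latency"] = med
        # Each decode step emits one token per request in the batch.
        results["median_decode_throughput"] = batch_size / med
    total = prefill_latency + sum(decode_latencies)
    results["total_latency"] = total
    # Overall throughput counts both prompt and generated tokens.
    results["overall_throughput"] = (input_len + output_len) * batch_size / total
    return results
```

As in the loop above, `decode_latencies` is expected to hold `output_len - 1` entries, because the first output token is produced by the prefill step.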
sglang/python/sglang/bench_one_batch_server.py ADDED
@@ -0,0 +1,49 @@
1
+ """
2
+ Benchmark the latency of running a single batch with a server.
3
+
4
+ This script launches a server and uses the HTTP interface.
5
+ It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths).
6
+
7
+ Usage:
8
+ python3 -m sglang.bench_one_batch_server --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
9
+
10
+ python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
11
+ python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --show-report --profile --profile-by-stage
12
+ python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --output-path results.json --profile
13
+ """
14
+
15
+ import argparse
16
+
17
+ from sglang.srt.server_args import ServerArgs
18
+ from sglang.test.bench_one_batch_server_internal import (
19
+ BenchArgs,
20
+ run_benchmark_internal,
21
+ )
22
+ from sglang.test.nightly_bench_utils import save_results_as_pydantic_models
23
+
24
+
25
+ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
26
+ results, server_info = run_benchmark_internal(server_args, bench_args)
27
+
28
+ # Save results as pydantic models in the JSON format
29
+ if bench_args.pydantic_result_filename:
30
+ save_results_as_pydantic_models(
31
+ results,
32
+ pydantic_result_filename=bench_args.pydantic_result_filename,
33
+ model_path=server_args.model_path,
34
+ server_args=bench_args.server_args_for_metrics,
35
+ )
36
+
37
+ return results, server_info
38
+
39
+
40
+ if __name__ == "__main__":
41
+ parser = argparse.ArgumentParser()
42
+ ServerArgs.add_cli_args(parser)
43
+ BenchArgs.add_cli_args(parser)
44
+ args = parser.parse_args()
45
+
46
+ server_args = ServerArgs.from_cli_args(args)
47
+ bench_args = BenchArgs.from_cli_args(args)
48
+
49
+ run_benchmark(server_args, bench_args)
sglang/python/sglang/bench_serving.py ADDED
@@ -0,0 +1,2238 @@
1
+ # Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/backend_request_func.py
2
+ # Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/benchmark_serving.py
3
+
4
+ """
5
+ Benchmark online serving with dynamic requests.
6
+
7
+ Usage:
8
+ python3 -m sglang.bench_serving --backend sglang --num-prompts 10
9
+
10
+ python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input 1024 --random-output 1024 --random-range-ratio 0.5
11
+ """
12
+
13
+ import argparse
14
+ import asyncio
15
+ import copy
16
+ import importlib.util
17
+ import json
18
+ import os
19
+ import random
20
+ import shutil
21
+ import sys
22
+ import time
23
+ import traceback
24
+ import uuid
25
+ import warnings
26
+ from argparse import ArgumentParser
27
+ from copy import deepcopy
28
+ from dataclasses import dataclass, field, replace
29
+ from datetime import datetime
30
+ from pathlib import Path
31
+ from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Tuple, Union
32
+
33
+ import aiohttp
34
+ import numpy as np
35
+ import requests
36
+ from tqdm.asyncio import tqdm
37
+ from transformers import AutoTokenizer, PreTrainedTokenizerBase
38
+
39
+ from sglang.benchmark.datasets import DatasetRow, get_dataset
40
+ from sglang.benchmark.datasets.mooncake import get_mooncake_request_over_time
41
+ from sglang.benchmark.utils import (
42
+ get_tokenizer,
43
+ parse_custom_headers,
44
+ remove_prefix,
45
+ set_ulimit,
46
+ )
47
+
48
+ _ROUTING_KEY_HEADER = "X-SMG-Routing-Key"
49
+
50
+ TERM_PLOTLIB_AVAILABLE = (importlib.util.find_spec("termplotlib") is not None) and (
51
+ shutil.which("gnuplot") is not None
52
+ )
53
+
54
+ global args
55
+
56
+
57
+ # don't want to import sglang package here
58
+ def _get_bool_env_var(name: str, default: str = "false") -> bool:
59
+ value = os.getenv(name, default)
60
+ return value.lower() in ("true", "1")
61
+
62
+
63
+ def _create_bench_client_session():
64
+ # Under heavy load, the read buffer can fill up before aiohttp reads
65
+ # the content, so we raise read_bufsize from the 64 KB default to 10 MB.
66
+ # Define constants for timeout and buffer size for clarity and maintainability
67
+ BENCH_AIOHTTP_TIMEOUT_SECONDS = 6 * 60 * 60 # 6 hours
68
+ BENCH_AIOHTTP_READ_BUFSIZE_BYTES = 10 * 1024**2 # 10 MB
69
+
70
+ aiohttp_timeout = aiohttp.ClientTimeout(total=BENCH_AIOHTTP_TIMEOUT_SECONDS)
71
+ return aiohttp.ClientSession(
72
+ timeout=aiohttp_timeout, read_bufsize=BENCH_AIOHTTP_READ_BUFSIZE_BYTES
73
+ )
74
+
75
+
76
+ @dataclass
77
+ class RequestFuncInput:
78
+ prompt: Union[str, List[str], List[Dict[str, str]]]
79
+ api_url: str
80
+ prompt_len: int
81
+ output_len: int
82
+ model: str
83
+ lora_name: str
84
+ image_data: Optional[List[str]]
85
+ extra_request_body: Dict[str, Any]
86
+ timestamp: Optional[float] = None
87
+ routing_key: Optional[str] = None
88
+
89
+
90
+ @dataclass
91
+ class RequestFuncOutput:
92
+ generated_text: str = ""
93
+ success: bool = False
94
+ latency: float = 0.0
95
+ ttft: float = 0.0 # Time to first token
96
+ itl: List[float] = field(default_factory=list) # List of inter-token latencies
97
+ text_chunks: List[str] = field(default_factory=list)
98
+ prompt_len: int = 0
99
+ error: str = ""
100
+ output_len: int = 0
101
+ start_time: float = 0.0
102
+
103
+ @staticmethod
104
+ def init_new(request_func_input: RequestFuncInput):
105
+ output = RequestFuncOutput()
106
+ output.prompt_len = request_func_input.prompt_len
107
+ return output
108
+
109
+
110
+ def get_auth_headers() -> Dict[str, str]:
111
+ openai_api_key = os.environ.get("OPENAI_API_KEY")
112
+ if openai_api_key:
113
+ return {"Authorization": f"Bearer {openai_api_key}"}
114
+ else:
115
+ api_key = os.environ.get("API_KEY")
116
+ if api_key:
117
+ return {"Authorization": f"{api_key}"}
118
+ return {}
119
+
120
+
121
+ def get_request_headers() -> Dict[str, str]:
122
+ headers = get_auth_headers()
123
+ if h := getattr(args, "header", None):
124
+ headers.update(parse_custom_headers(h))
125
+ return headers
126
+
127
+
128
+ def wait_for_endpoint(url: str, timeout_sec: int = 60) -> bool:
129
+ """Wait for the server to become ready by polling the given URL."""
130
+ print(f"Waiting up to {timeout_sec}s for {url} to become ready...")
131
+ start_time = time.perf_counter()
132
+ headers = get_auth_headers()
133
+ while True:
134
+ try:
135
+ response = requests.get(url, headers=headers, timeout=5)
136
+ if response.status_code == 200:
137
+ elapsed = time.perf_counter() - start_time
138
+ print(f"Server ready in {elapsed:.1f}s.")
139
+ return True
140
+ except requests.exceptions.RequestException:
141
+ pass
142
+ elapsed = time.perf_counter() - start_time
143
+ if elapsed >= timeout_sec:
144
+ print(f"Server did not become ready within {timeout_sec}s timeout.")
145
+ return False
146
+ time.sleep(1)
147
+
148
+
149
+ # trt llm does not support ignore_eos
150
+ # https://github.com/triton-inference-server/tensorrtllm_backend/issues/505
151
+ async def async_request_trt_llm(
152
+ request_func_input: RequestFuncInput,
153
+ pbar: Optional[tqdm] = None,
154
+ ) -> RequestFuncOutput:
155
+ api_url = request_func_input.api_url
156
+ assert api_url.endswith("generate_stream")
157
+
158
+ async with _create_bench_client_session() as session:
159
+ payload = {
160
+ "accumulate_tokens": True,
161
+ "text_input": request_func_input.prompt,
162
+ "temperature": 0.000001,
163
+ "top_p": 1.0,
164
+ "max_tokens": request_func_input.output_len,
165
+ "stream": True,
166
+ "min_length": request_func_input.output_len,
167
+ "end_id": 1048576,
168
+ **request_func_input.extra_request_body,
169
+ }
170
+ if args.disable_ignore_eos:
171
+ del payload["min_length"]
172
+ del payload["end_id"]
173
+ output = RequestFuncOutput.init_new(request_func_input)
174
+
175
+ ttft = 0.0
176
+ st = time.perf_counter()
177
+ most_recent_timestamp = st
178
+ try:
179
+ async with session.post(url=api_url, json=payload) as response:
180
+ if response.status == 200:
181
+ async for chunk_bytes in response.content:
182
+ chunk_bytes = chunk_bytes.strip()
183
+ if not chunk_bytes:
184
+ continue
185
+
186
+ chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data:")
187
+
188
+ data = json.loads(chunk)
189
+ output.generated_text += data["text_output"]
190
+ timestamp = time.perf_counter()
191
+ # First token
192
+ if ttft == 0.0:
193
+ ttft = timestamp - st
194
+ output.ttft = ttft
195
+
196
+ # Decoding phase
197
+ else:
198
+ output.itl.append(timestamp - most_recent_timestamp)
199
+
200
+ most_recent_timestamp = timestamp
201
+
202
+ output.latency = most_recent_timestamp - st
203
+ output.success = True
204
+ output.output_len = request_func_input.output_len
205
+
206
+ else:
207
+ output.error = (
208
+ (response.reason or "") + ": " + (await response.text())
209
+ )
210
+ output.success = False
211
+ except Exception:
212
+ output.success = False
213
+ exc_info = sys.exc_info()
214
+ output.error = "".join(traceback.format_exception(*exc_info))
215
+
216
+ if pbar:
217
+ pbar.update(1)
218
+ return output
219
+
220
+
221
+ # set ignore_eos True by default
222
+ async def async_request_openai_completions(
223
+ request_func_input: RequestFuncInput,
224
+ pbar: Optional[tqdm] = None,
225
+ ) -> RequestFuncOutput:
226
+ api_url = request_func_input.api_url
227
+ assert api_url.endswith(
228
+ "completions"
229
+ ), "OpenAI Completions API URL must end with 'completions'."
230
+
231
+ prompt = request_func_input.prompt
232
+
233
+ async with _create_bench_client_session() as session:
234
+ # Build payload with defaults that can be overridden by extra_request_body
235
+ payload = {
236
+ "model": request_func_input.model,
237
+ "prompt": prompt,
238
+ "best_of": 1,
239
+ "max_tokens": request_func_input.output_len,
240
+ "stream": not args.disable_stream,
241
+ }
242
+
243
+ # Add temperature default only if not specified in extra_request_body
244
+ if "temperature" not in request_func_input.extra_request_body:
245
+ payload["temperature"] = 0.0
246
+
247
+ # Add ignore_eos default only if not specified in extra_request_body
248
+ if "ignore_eos" not in request_func_input.extra_request_body:
249
+ payload["ignore_eos"] = not args.disable_ignore_eos
250
+
251
+ # Merge in extra parameters - these will override defaults if present
252
+ payload.update(request_func_input.extra_request_body)
253
+
254
+ # hack to accommodate different LoRA conventions between SGLang and vLLM.
255
+ if request_func_input.lora_name:
256
+ payload["model"] = request_func_input.lora_name
257
+ payload["lora_path"] = request_func_input.lora_name
258
+
259
+ if request_func_input.image_data:
260
+ payload.update({"image_data": request_func_input.image_data})
261
+
262
+ headers = get_request_headers()
263
+ if request_func_input.routing_key:
264
+ headers[_ROUTING_KEY_HEADER] = request_func_input.routing_key
265
+
266
+ output = RequestFuncOutput.init_new(request_func_input)
267
+
268
+ generated_text = ""
269
+ output_len = request_func_input.output_len
270
+ ttft = 0.0
271
+ st = time.perf_counter()
272
+ output.start_time = st
273
+ most_recent_timestamp = st
274
+ try:
275
+ async with session.post(
276
+ url=api_url, json=payload, headers=headers
277
+ ) as response:
278
+ if response.status == 200:
279
+ async for chunk_bytes in response.content:
280
+ chunk_bytes = chunk_bytes.strip()
281
+ if not chunk_bytes:
282
+ continue
283
+
284
+ chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
285
+ latency = time.perf_counter() - st
286
+ if chunk == "[DONE]":
287
+ pass
288
+ else:
289
+ data = json.loads(chunk)
290
+
291
+ # NOTE: Some completion API might have a last
292
+ # usage summary response without a token so we
293
+ # want to check a token was generated
294
+ if data["choices"][0]["text"]:
295
+ timestamp = time.perf_counter()
296
+ # First token
297
+ if ttft == 0.0:
298
+ ttft = time.perf_counter() - st
299
+ output.ttft = ttft
300
+
301
+ # Decoding phase
302
+ else:
303
+ output.text_chunks.append(
304
+ data["choices"][0]["text"]
305
+ )
306
+ output.itl.append(timestamp - most_recent_timestamp)
307
+
308
+ most_recent_timestamp = timestamp
309
+ generated_text += data["choices"][0]["text"]
310
+ output_len = (data.get("usage") or {}).get(
311
+ "completion_tokens", output_len
312
+ )
313
+
314
+ output.generated_text = generated_text
315
+ output.success = True
316
+ output.latency = latency
317
+ output.output_len = output_len
318
+ else:
319
+ output.error = (
320
+ (response.reason or "") + ": " + (await response.text())
321
+ )
322
+ output.success = False
323
+ except Exception:
324
+ output.success = False
325
+ exc_info = sys.exc_info()
326
+ output.error = "".join(traceback.format_exception(*exc_info))
327
+
328
+ if pbar:
329
+ pbar.update(1)
330
+ return output
331
+
332
+
333
+ async def async_request_openai_chat_completions(
334
+ request_func_input: RequestFuncInput,
335
+ pbar: Optional[tqdm] = None,
336
+ ) -> RequestFuncOutput:
337
+ """Makes a request to the OpenAI Chat Completions API.
338
+
339
+ Handles both streaming and non-streaming responses, including support
340
+ for image data in messages. Calculates and returns various performance
341
+ metrics.
342
+
343
+ Args:
344
+ request_func_input: Input parameters for the request.
345
+ pbar: Optional tqdm progress bar to update.
346
+
347
+ Returns:
348
+ RequestFuncOutput: Output of the request, including generated text,
349
+ latency, TTFT, ITL, and success status.
350
+ """
351
+ api_url = request_func_input.api_url
352
+ assert api_url.endswith(
353
+ "chat/completions"
354
+ ), "OpenAI Chat Completions API URL must end with 'chat/completions'."
355
+
356
+ # TODO put it to other functions when `pbar` logic is refactored
357
+ if getattr(args, "print_requests", False):
358
+ rid = str(uuid.uuid4())
359
+ input_partial = deepcopy(request_func_input)
360
+ input_partial.prompt = "..."
361
+ request_start_time = time.time()
362
+ print(
363
+ f'rid={rid} time={request_start_time} message="request start" request_func_input="{str(input_partial)}"'
364
+ )
365
+
366
+ if isinstance(request_func_input.prompt, list):
367
+ messages = request_func_input.prompt
368
+ elif request_func_input.image_data:
369
+ # Build multi-image content: a list of image_url entries followed by the text
370
+ content_items = [
371
+ {
372
+ "type": "image_url",
373
+ "image_url": {"url": img_url},
374
+ }
375
+ for img_url in request_func_input.image_data
376
+ ]
377
+ content_items.append({"type": "text", "text": request_func_input.prompt})
378
+ messages = [
379
+ {
380
+ "role": "user",
381
+ "content": content_items,
382
+ },
383
+ ]
384
+ else:
385
+ messages = [{"role": "user", "content": request_func_input.prompt}]
386
+
387
+ async with _create_bench_client_session() as session:
388
+ # Build payload with defaults that can be overridden by extra_request_body
389
+ payload = {
390
+ "model": request_func_input.model,
391
+ "messages": messages,
392
+ "max_completion_tokens": request_func_input.output_len,
393
+ "stream": not args.disable_stream,
394
+ }
395
+
396
+ # Add temperature default only if not specified in extra_request_body
397
+ if "temperature" not in request_func_input.extra_request_body:
398
+ payload["temperature"] = 0.0
399
+
400
+ # Add ignore_eos default only if not specified in extra_request_body
401
+ # Defaults to True (EOS ignored) unless args.disable_ignore_eos is set
402
+ if "ignore_eos" not in request_func_input.extra_request_body:
403
+ payload["ignore_eos"] = not args.disable_ignore_eos
404
+
405
+ # Merge in extra parameters (tools, temperature, top_p, etc.)
406
+ # These will override defaults if present
407
+ payload.update(request_func_input.extra_request_body)
408
+
409
+ # hack to accommodate different LoRA conventions between SGLang and vLLM.
410
+ if request_func_input.lora_name:
411
+ payload["model"] = request_func_input.lora_name
412
+ payload["lora_path"] = request_func_input.lora_name
413
+
414
+ headers = get_request_headers()
415
+ if request_func_input.routing_key:
416
+ headers[_ROUTING_KEY_HEADER] = request_func_input.routing_key
417
+
418
+ output = RequestFuncOutput.init_new(request_func_input)
419
+
420
+ generated_text = ""
421
+ output_len = request_func_input.output_len
422
+ ttft = 0.0
423
+ st = time.perf_counter()
424
+ output.start_time = st
425
+ most_recent_timestamp = st
426
+ try:
427
+ async with session.post(
428
+ url=api_url, json=payload, headers=headers
429
+ ) as response:
430
+ if response.status == 200:
431
+ if args.disable_stream:
432
+ # Non-streaming response
433
+ response_json = await response.json()
434
+ output.generated_text = response_json["choices"][0]["message"][
435
+ "content"
436
+ ]
437
+ output.success = True
438
+ output.latency = time.perf_counter() - st
439
+ output.ttft = (
440
+ output.latency
441
+ ) # For non-streaming, TTFT = total latency
442
+ output.output_len = (response_json.get("usage") or {}).get(
443
+ "completion_tokens", output_len
444
+ )
445
+ else:
446
+ # Streaming response
447
+ async for chunk_bytes in response.content:
448
+ chunk_bytes = chunk_bytes.strip()
449
+ if not chunk_bytes:
450
+ continue
451
+
452
+ chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
453
+ latency = time.perf_counter() - st
454
+ if chunk == "[DONE]":
455
+ pass
456
+ else:
457
+ data = json.loads(chunk)
458
+
459
+ # Check if this chunk contains content
460
+ delta = data.get("choices", [{}])[0].get("delta", {})
461
+ content = delta.get("content", "")
462
+
463
+ if content:
464
+ timestamp = time.perf_counter()
465
+ # First token
466
+ if ttft == 0.0:
467
+ ttft = timestamp - st
468
+ output.ttft = ttft
469
+
470
+ # Decoding phase
471
+ else:
472
+ output.text_chunks.append(content)
473
+ output.itl.append(
474
+ timestamp - most_recent_timestamp
475
+ )
476
+
477
+ most_recent_timestamp = timestamp
478
+ generated_text += content
479
+
480
+ # Check for usage info in final chunk
481
+ output_len = (data.get("usage") or {}).get(
482
+ "completion_tokens", output_len
483
+ )
484
+
485
+ output.generated_text = generated_text
486
+ output.success = True
487
+ output.latency = latency
488
+ output.output_len = output_len
489
+ else:
490
+ output.error = (
491
+ (response.reason or "") + ": " + (await response.text())
492
+ )
493
+ output.success = False
494
+ except Exception:
495
+ output.success = False
496
+ exc_info = sys.exc_info()
497
+ output.error = "".join(traceback.format_exception(*exc_info))
498
+
499
+ # TODO put it to other functions when `pbar` logic is refactored
500
+ if getattr(args, "print_requests", False):
501
+ curr_t = time.time()
502
+ output_partial = deepcopy(output)
503
+ output_partial.generated_text = "..."
504
+ print(
505
+ f'rid={rid} time={curr_t} time_delta={curr_t - request_start_time} message="request end" output="{str(output_partial)}"'
506
+ )
507
+
508
+ if pbar:
509
+ pbar.update(1)
510
+ return output
511
+
512
+
513
+ async def async_request_truss(
514
+ request_func_input: RequestFuncInput,
515
+ pbar: Optional[tqdm] = None,
516
+ ) -> RequestFuncOutput:
517
+ api_url = request_func_input.api_url
518
+
519
+ prompt = request_func_input.prompt
520
+
521
+ async with _create_bench_client_session() as session:
522
+ payload = {
523
+ "model": request_func_input.model,
524
+ "prompt": prompt,
525
+ "temperature": 0.0,
526
+ "best_of": 1,
527
+ "max_tokens": request_func_input.output_len,
528
+ "stream": not args.disable_stream,
529
+ "ignore_eos": not args.disable_ignore_eos,
530
+ **request_func_input.extra_request_body,
531
+ }
532
+ headers = get_request_headers()
533
+
534
+ output = RequestFuncOutput.init_new(request_func_input)
535
+
536
+ generated_text = ""
537
+ ttft = 0.0
538
+ st = time.perf_counter()
539
+ most_recent_timestamp = st
540
+ try:
541
+ async with session.post(
542
+ url=api_url, json=payload, headers=headers
543
+ ) as response:
544
+ if response.status == 200:
545
+ async for chunk_bytes in response.content:
546
+ chunk_bytes = chunk_bytes.strip()
547
+ if not chunk_bytes:
548
+ continue
549
+
550
+ chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
551
+ latency = time.perf_counter() - st
552
+ if chunk == "[DONE]":
553
+ pass
554
+ else:
555
+ data = json.loads(chunk)
556
+
557
+ # NOTE: Some completion APIs send a final usage-summary
557
+ # chunk that carries no token, so check that a token
558
+ # was actually generated.
560
+ if data["choices"][0]["text"]:
561
+ timestamp = time.perf_counter()
562
+ # First token
563
+ if ttft == 0.0:
564
+ ttft = timestamp - st
565
+ output.ttft = ttft
566
+
567
+ # Decoding phase
568
+ else:
569
+ output.itl.append(timestamp - most_recent_timestamp)
570
+
571
+ most_recent_timestamp = timestamp
572
+ generated_text += data["choices"][0]["text"]
573
+
574
+ output.generated_text = generated_text
575
+ output.success = True
576
+ output.latency = latency
577
+ output.output_len = request_func_input.output_len
578
+ else:
579
+ output.error = (
580
+ (response.reason or "") + ": " + (await response.text())
581
+ )
582
+ output.success = False
583
+ except Exception:
584
+ output.success = False
585
+ exc_info = sys.exc_info()
586
+ output.error = "".join(traceback.format_exception(*exc_info))
587
+
588
+ if pbar:
589
+ pbar.update(1)
590
+ return output
591
+
592
+
593
+ async def async_request_sglang_generate(
594
+ request_func_input: RequestFuncInput,
595
+ pbar: Optional[tqdm] = None,
596
+ ) -> RequestFuncOutput:
597
+ api_url = request_func_input.api_url
598
+ prompt = request_func_input.prompt
599
+
600
+ async with _create_bench_client_session() as session:
601
+ payload = {
602
+ ("text" if isinstance(prompt, str) else "input_ids"): prompt,
603
+ "sampling_params": {
604
+ "temperature": 0.0,
605
+ "max_new_tokens": request_func_input.output_len,
606
+ "ignore_eos": not args.disable_ignore_eos,
607
+ },
608
+ "stream": not args.disable_stream,
609
+ "lora_path": request_func_input.lora_name,
610
+ "return_logprob": args.return_logprob,
611
+ "return_routed_experts": args.return_routed_experts,
612
+ "logprob_start_len": -1,
613
+ **request_func_input.extra_request_body,
614
+ }
615
+
616
+ # Add image data if available (list of image urls/base64)
617
+ if request_func_input.image_data:
618
+ payload["image_data"] = request_func_input.image_data
619
+
620
+ headers = get_request_headers()
621
+ if request_func_input.routing_key:
622
+ headers[_ROUTING_KEY_HEADER] = request_func_input.routing_key
623
+
624
+ output = RequestFuncOutput.init_new(request_func_input)
625
+
626
+ generated_text = ""
627
+ output_len = request_func_input.output_len
628
+ ttft = 0.0
629
+ st = time.perf_counter()
630
+ output.start_time = st
631
+ most_recent_timestamp = st
632
+ last_output_len = 0
633
+ try:
634
+ async with session.post(
635
+ url=api_url, json=payload, headers=headers
636
+ ) as response:
637
+ if response.status == 200:
638
+ async for chunk_bytes in response.content:
639
+ chunk_bytes = chunk_bytes.strip()
640
+ if not chunk_bytes:
641
+ continue
642
+
643
+ chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
644
+ latency = time.perf_counter() - st
645
+ if chunk == "[DONE]":
646
+ pass
647
+ else:
648
+ data = json.loads(chunk)
649
+
650
+ # NOTE: Some completion APIs send a final usage-summary
650
+ # chunk that carries no token, so check that a token
651
+ # was actually generated.
653
+ if "text" in data and data["text"]:
654
+ timestamp = time.perf_counter()
655
+ generated_text = data["text"]
656
+ output_len = data["meta_info"]["completion_tokens"]
657
+
658
+ # First token
659
+ if ttft == 0.0:
660
+ ttft = timestamp - st
661
+ output.ttft = ttft
662
+
663
+ # Decoding phase
664
+ else:
665
+ num_new_tokens = output_len - last_output_len
666
+ if num_new_tokens == 0:
667
+ continue
668
+ chunk_gap = timestamp - most_recent_timestamp
669
+ adjust_itl = chunk_gap / num_new_tokens
670
+ output.itl.extend([adjust_itl] * num_new_tokens)
671
+
672
+ most_recent_timestamp = timestamp
673
+ last_output_len = output_len
674
+
675
+ output.generated_text = generated_text
676
+ output.success = True
677
+ output.latency = latency
678
+ output.output_len = output_len
679
+ else:
680
+ output.error = (
681
+ (response.reason or "") + ": " + (await response.text())
682
+ )
683
+ output.success = False
684
+ except Exception:
685
+ output.success = False
686
+ exc_info = sys.exc_info()
687
+ output.error = "".join(traceback.format_exception(*exc_info))
688
+ print(f"{output.error=}")
689
+
690
+ if pbar:
691
+ pbar.update(1)
692
+ return output
693
+
694
+
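The decoding branch of `async_request_sglang_generate` above spreads each chunk's arrival gap evenly over the tokens that chunk carried, and skips chunks that report no new tokens. A minimal standalone sketch of that bookkeeping, assuming a hypothetical helper `amortize_itl` and a made-up list of `(timestamp, cumulative_token_count)` pairs in place of the real response stream:

```python
from typing import List, Tuple


def amortize_itl(chunks: List[Tuple[float, int]]) -> List[float]:
    """Return per-token inter-token latencies for a chunked stream.

    `chunks` holds (timestamp, cumulative_token_count) pairs; the first
    entry is the chunk that delivered the first token (hypothetical input).
    """
    itl: List[float] = []
    last_time, last_len = chunks[0]
    for timestamp, output_len in chunks[1:]:
        num_new_tokens = output_len - last_len
        if num_new_tokens == 0:
            # No new tokens: keep the old timestamp so the gap carries over.
            continue
        gap = timestamp - last_time
        # Spread the gap evenly over every token the chunk delivered.
        itl.extend([gap / num_new_tokens] * num_new_tokens)
        last_time, last_len = timestamp, output_len
    return itl


# Three tokens arrive together, one chunk is empty, then two more arrive.
example_itl = amortize_itl([(0.0, 1), (0.06, 4), (0.08, 4), (0.12, 6)])
print(example_itl)
```

One entry is recorded per generated token, so percentile statistics over `itl` are weighted by token count rather than by chunk count.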
695
+ async def async_request_gserver(
696
+ request_func_input: RequestFuncInput,
697
+ pbar: Optional[tqdm] = None,
698
+ ) -> RequestFuncOutput:
699
+ raise NotImplementedError()
700
+
701
+
702
+ async def async_request_profile(api_url: str) -> RequestFuncOutput:
703
+ async with _create_bench_client_session() as session:
704
+ output = RequestFuncOutput()
705
+ try:
706
+ if api_url.endswith("/start_profile"):
707
+ num_steps = getattr(args, "profile_num_steps", None)
708
+ profile_by_stage = getattr(args, "profile_by_stage", None)
709
+ if profile_by_stage and num_steps is None:
710
+ num_steps = 5
711
+
712
+ output_dir = getattr(args, "profile_output_dir", None)
713
+ if output_dir is None:
714
+ output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
715
+ output_dir = Path(os.path.abspath(os.path.normpath(output_dir))) / str(
716
+ time.time()
717
+ )
718
+ output_dir.mkdir(exist_ok=True, parents=True)
719
+ output_dir = str(output_dir)
720
+
721
+ body = {
722
+ "activities": getattr(args, "profile_activities", []),
723
+ "num_steps": num_steps,
724
+ "profile_by_stage": profile_by_stage,
725
+ "profile_stages": getattr(args, "profile_stages", None),
726
+ "output_dir": output_dir,
727
+ "profile_prefix": getattr(args, "profile_prefix", None),
728
+ }
729
+ else:
730
+ # stop_profile doesn't need any parameters
731
+ body = {}
732
+ print(f"async_request_profile {api_url=} {body=}")
733
+ # Add optional profiling parameters if provided
734
+ if (
735
+ hasattr(args, "profile_start_step")
736
+ and args.profile_start_step is not None
737
+ ):
738
+ body["start_step"] = str(args.profile_start_step)
739
+ if hasattr(args, "profile_steps") and args.profile_steps is not None:
740
+ body["num_steps"] = str(args.profile_steps)
741
+ async with session.post(url=api_url, json=body) as response:
742
+ if response.status == 200:
743
+ output.success = True
744
+ else:
745
+ output.error = (
746
+ (response.reason or "") + ": " + (await response.text())
747
+ )
748
+ output.success = False
749
+ except Exception:
750
+ output.success = False
751
+ exc_info = sys.exc_info()
752
+ output.error = "".join(traceback.format_exception(*exc_info))
753
+
754
+ return output
755
+
756
+
757
+ def _build_profile_urls(
758
+ profile_prefill_url: Optional[List[str]],
759
+ profile_decode_url: Optional[List[str]],
760
+ ) -> List[Tuple[str, str]]:
761
+ """Build profile URLs list from prefill/decode URL arguments.
762
+
763
+ Returns:
764
+ List of (worker_type, url) tuples, e.g. [("Prefill-0", "http://..."), ("Decode-0", "http://...")]
765
+ """
766
+ profile_urls = []
767
+ if profile_prefill_url:
768
+ for idx, url in enumerate(profile_prefill_url):
769
+ profile_urls.append((f"Prefill-{idx}", url))
770
+ if profile_decode_url:
771
+ for idx, url in enumerate(profile_decode_url):
772
+ profile_urls.append((f"Decode-{idx}", url))
773
+ return profile_urls
774
+
775
+
776
+ async def _call_profile_pd(profile_urls: List[Tuple[str, str]], mode: str) -> None:
777
+ """Call profile endpoint (start/stop) on PD separated workers.
778
+
779
+ Args:
780
+ profile_urls: List of (worker_type, url) tuples
781
+ mode: "start" or "stop"
782
+ """
783
+ endpoint = "/start_profile" if mode == "start" else "/stop_profile"
784
+ action = "Starting" if mode == "start" else "Stopping"
785
+ action_past = "started" if mode == "start" else "stopped"
786
+
787
+ print(f"{action} profiler...")
788
+
789
+ for worker_type, url in profile_urls:
790
+ profile_output = await async_request_profile(api_url=url + endpoint)
791
+ if profile_output.success:
792
+ print(f"Profiler {action_past} for {worker_type} worker at {url}")
793
+ else:
794
+ print(
795
+ f"Failed to {mode} profiler for {worker_type} worker at {url}: {profile_output.error}"
796
+ )
797
+
798
+
799
+ ASYNC_REQUEST_FUNCS = {
800
+ "sglang": async_request_sglang_generate,
801
+ "sglang-native": async_request_sglang_generate,
802
+ "sglang-oai": async_request_openai_completions,
803
+ "sglang-oai-chat": async_request_openai_chat_completions,
804
+ "vllm": async_request_openai_completions,
805
+ "vllm-chat": async_request_openai_chat_completions,
806
+ "lmdeploy": async_request_openai_completions,
807
+ "lmdeploy-chat": async_request_openai_chat_completions,
808
+ "trt": async_request_trt_llm,
809
+ "gserver": async_request_gserver,
810
+ "truss": async_request_truss,
811
+ }
812
+
813
+
814
+ @dataclass
815
+ class BenchmarkMetrics:
816
+ completed: int
817
+ total_input: int
818
+ total_input_text: int
819
+ total_input_vision: int
820
+ total_output: int
821
+ total_output_retokenized: int
822
+ request_throughput: float
823
+ input_throughput: float
824
+ output_throughput: float
825
+ output_throughput_retokenized: float
826
+ total_throughput: float
827
+ total_throughput_retokenized: float
828
+ mean_ttft_ms: float
829
+ median_ttft_ms: float
830
+ std_ttft_ms: float
831
+ p99_ttft_ms: float
832
+ mean_tpot_ms: float
833
+ median_tpot_ms: float
834
+ std_tpot_ms: float
835
+ p99_tpot_ms: float
836
+ mean_itl_ms: float
837
+ median_itl_ms: float
838
+ std_itl_ms: float
839
+ p95_itl_ms: float
840
+ p99_itl_ms: float
841
+ max_itl_ms: float
842
+ mean_e2e_latency_ms: float
843
+ median_e2e_latency_ms: float
844
+ std_e2e_latency_ms: float
845
+ p90_e2e_latency_ms: float
846
+ p99_e2e_latency_ms: float
847
+ concurrency: float
848
+ max_output_tokens_per_s: float = 0.0
849
+ max_concurrent_requests: int = 0
850
+
851
+
852
+ async def get_request(
853
+ input_requests: List[DatasetRow],
854
+ request_rate: float,
855
+ use_trace_timestamps: bool = False,
856
+ slowdown_factor: float = 1.0,
857
+ ) -> AsyncGenerator[DatasetRow, None]:
858
+ if use_trace_timestamps:
859
+ print(
860
+ f"Using trace timestamps for request generation with slowdown factor {slowdown_factor}."
861
+ )
862
+ # Sort requests by timestamp for correct replay
863
+ input_requests.sort(key=lambda r: r.timestamp)
864
+
865
+ start_time = time.perf_counter()
866
+ trace_start_time_ms = input_requests[0].timestamp if input_requests else 0
867
+
868
+ for request in input_requests:
869
+ trace_time_s = (request.timestamp - trace_start_time_ms) / 1000.0
870
+ target_arrival_time = start_time + (trace_time_s * slowdown_factor)
871
+
872
+ sleep_duration = target_arrival_time - time.perf_counter()
873
+ if sleep_duration > 0:
874
+ await asyncio.sleep(sleep_duration)
875
+
876
+ yield request
877
+ else:
878
+ input_requests_iter = iter(input_requests)
879
+ for request in input_requests_iter:
880
+ yield request
881
+
882
+ if request_rate == float("inf"):
883
+ # If the request rate is infinity, then we don't need to wait.
884
+ continue
885
+
886
+ # Sample the request interval from the exponential distribution.
887
+ interval = np.random.exponential(1.0 / request_rate)
888
+ # The next request will be sent after the interval.
889
+ await asyncio.sleep(interval)
890
+
891
+
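When trace timestamps are not used, `get_request` above draws the wait between requests from an exponential distribution with mean `1 / request_rate`, which makes arrivals a Poisson process. A small sketch of that sampling, swapping in the standard library's `random.expovariate` for `np.random.exponential` so it runs without NumPy; the helper name `sample_intervals` and the seed are assumptions made for illustration:

```python
import random
from typing import List


def sample_intervals(request_rate: float, n: int, seed: int = 0) -> List[float]:
    """Sample n inter-arrival gaps (seconds) for a Poisson arrival process."""
    if request_rate == float("inf"):
        # Infinite rate: send everything back to back, no waiting.
        return [0.0] * n
    rng = random.Random(seed)
    # expovariate takes the rate (lambda); the mean gap is 1 / request_rate.
    return [rng.expovariate(request_rate) for _ in range(n)]


gaps = sample_intervals(request_rate=10.0, n=20000)
print(f"mean gap: {sum(gaps) / len(gaps):.4f} s")
```

Note the parameter convention: `random.expovariate` takes the rate, while `np.random.exponential` takes the scale, which is why the original passes `1.0 / request_rate`.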
892
+ def calculate_metrics(
893
+ input_requests: Optional[List[DatasetRow]],
894
+ outputs: List[RequestFuncOutput],
895
+ dur_s: float,
896
+ tokenizer: PreTrainedTokenizerBase,
897
+ backend: str,
898
+ accept_length: Optional[float] = None,
899
+ plot_throughput: bool = False,
900
+ ) -> Tuple[BenchmarkMetrics, List[int]]:
901
+ output_lens: List[int] = []
902
+ retokenized_output_lens: List[int] = []
903
+ total_input = 0
904
+ total_input_text = 0
905
+ total_input_vision = 0
906
+ completed = 0
907
+ itls: List[float] = []
908
+ tpots: List[float] = []
909
+ ttfts: List[float] = []
910
+ e2e_latencies: List[float] = []
911
+ retokenized_itls: List[float] = []
912
+
913
+ use_retokenized_itl = (
914
+ accept_length is not None
915
+ and accept_length > 0
916
+ and backend in ("sglang-oai", "sglang-oai-chat")
917
+ )
918
+
919
+ for i in range(len(outputs)):
920
+ if outputs[i].success:
921
+ output_len = outputs[i].output_len
922
+ output_lens.append(output_len)
923
+ retokenized_output_len = len(
924
+ tokenizer.encode(outputs[i].generated_text, add_special_tokens=False)
925
+ )
926
+ retokenized_output_lens.append(retokenized_output_len)
927
+ if input_requests is not None:
928
+ total_input += input_requests[i].prompt_len
929
+ total_input_text += input_requests[i].text_prompt_len
930
+ total_input_vision += input_requests[i].vision_prompt_len
931
+ if output_len > 1:
932
+ tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1))
933
+ if use_retokenized_itl:
934
+ for k, itl in enumerate(outputs[i].itl):
935
+ num_tokens = len(
936
+ tokenizer.encode(
937
+ outputs[i].text_chunks[k], add_special_tokens=False
938
+ )
939
+ )
940
+ adjusted_itl = itl / max(num_tokens, 1)  # avoid ZeroDivisionError on empty chunks
941
+ retokenized_itls.extend([adjusted_itl] * num_tokens)
942
+ else:
943
+ itls += outputs[i].itl
944
+ ttfts.append(outputs[i].ttft)
945
+
946
+ e2e_latencies.append(outputs[i].latency)
947
+
948
+ completed += 1
949
+ else:
950
+ output_lens.append(0)
951
+ retokenized_output_lens.append(0)
952
+
953
+ if completed == 0:
954
+ warnings.warn(
955
+ "All requests failed. This is likely due to a misconfiguration "
956
+ "on the benchmark arguments.",
957
+ stacklevel=2,
958
+ )
959
+
960
+ max_output_tokens_per_s = 0.0
961
+ max_concurrent_requests = 0
962
+
963
+ successful_outputs = [output for output in outputs if output.success]
964
+ if successful_outputs:
965
+ min_start_time = min(output.start_time for output in successful_outputs)
966
+ max_end_time = max(
967
+ output.start_time + output.latency for output in successful_outputs
968
+ )
969
+
970
+ duration_seconds = int(np.ceil(max_end_time - min_start_time)) + 1
971
+ tokens_per_second = np.zeros(duration_seconds)
972
+ concurrent_requests_per_second = np.zeros(duration_seconds)
973
+
974
+ for output in outputs:
975
+ if not output.success:
976
+ continue
977
+
978
+ token_times = [output.start_time + output.ttft]
979
+ current_time = token_times[0]
980
+ for itl_value in output.itl:
981
+ current_time += itl_value
982
+ token_times.append(current_time)
983
+
984
+ for token_time in token_times:
985
+ second_bucket = int(token_time - min_start_time)
986
+ if 0 <= second_bucket < duration_seconds:
987
+ tokens_per_second[second_bucket] += 1
988
+
989
+ request_start_second = int(output.start_time - min_start_time)
990
+ request_end_second = int(
991
+ (output.start_time + output.latency) - min_start_time
992
+ )
993
+
994
+ for second in range(
995
+ request_start_second, min(request_end_second + 1, duration_seconds)
996
+ ):
997
+ concurrent_requests_per_second[second] += 1
998
+
999
+ if len(tokens_per_second) > 0:
1000
+ max_output_tokens_per_s = float(np.max(tokens_per_second))
1001
+ max_concurrent_requests = int(np.max(concurrent_requests_per_second))
1002
+
1003
+ if plot_throughput:
1004
+ if TERM_PLOTLIB_AVAILABLE:
1005
+ import termplotlib as tpl
1006
+
1007
+ fig = tpl.figure()
1008
+ fig.plot(
1009
+ np.arange(len(tokens_per_second)),
1010
+ tokens_per_second,
1011
+ title="Output tokens per second",
1012
+ xlabel="Time (s)",
1013
+ )
1014
+ fig.plot(
1015
+ np.arange(len(concurrent_requests_per_second)),
1016
+ concurrent_requests_per_second,
1017
+ title="Concurrent requests per second",
1018
+ xlabel="Time (s)",
1019
+ )
1020
+ fig.show()
1021
+ else:
1022
+ print("tip: install termplotlib and gnuplot to plot the metrics")
1023
+
1024
+ itls = retokenized_itls if use_retokenized_itl else itls
1025
+ metrics = BenchmarkMetrics(
1026
+ completed=completed,
1027
+ total_input=total_input,
1028
+ total_input_text=total_input_text,
1029
+ total_input_vision=total_input_vision,
1030
+ total_output=sum(output_lens),
1031
+ total_output_retokenized=sum(retokenized_output_lens),
1032
+ request_throughput=completed / dur_s,
1033
+ input_throughput=total_input / dur_s,
1034
+ output_throughput=sum(output_lens) / dur_s,
1035
+ output_throughput_retokenized=sum(retokenized_output_lens) / dur_s,
1036
+ total_throughput=(total_input + sum(output_lens)) / dur_s,
1037
+ total_throughput_retokenized=(total_input + sum(retokenized_output_lens))
1038
+ / dur_s,
1039
+ mean_ttft_ms=np.mean(ttfts or 0)
1040
+ * 1000, # ttfts is empty if streaming is not supported by backend
1041
+ median_ttft_ms=np.median(ttfts or 0) * 1000,
1042
+ std_ttft_ms=np.std(ttfts or 0) * 1000,
1043
+ p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
1044
+ mean_tpot_ms=np.mean(tpots or 0) * 1000,
1045
+ median_tpot_ms=np.median(tpots or 0) * 1000,
1046
+ std_tpot_ms=np.std(tpots or 0) * 1000,
1047
+ p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
1048
+ mean_itl_ms=np.mean(itls or 0) * 1000,
1049
+ median_itl_ms=np.median(itls or 0) * 1000,
1050
+ std_itl_ms=np.std(itls or 0) * 1000,
1051
+ p95_itl_ms=np.percentile(itls or 0, 95) * 1000,
1052
+ p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
1053
+ max_itl_ms=np.max(itls or 0) * 1000,
1054
+ mean_e2e_latency_ms=np.mean(e2e_latencies or 0) * 1000,
1054
+ median_e2e_latency_ms=np.median(e2e_latencies or 0) * 1000,
1055
+ std_e2e_latency_ms=np.std(e2e_latencies or 0) * 1000,
1056
+ p90_e2e_latency_ms=np.percentile(e2e_latencies or 0, 90) * 1000,
1057
+ p99_e2e_latency_ms=np.percentile(e2e_latencies or 0, 99) * 1000,
1059
+ concurrency=np.sum(e2e_latencies) / dur_s,
1060
+ max_output_tokens_per_s=max_output_tokens_per_s,
1061
+ max_concurrent_requests=max_concurrent_requests,
1062
+ )
1063
+
1064
+ return metrics, output_lens
1065
+
1066
+
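`calculate_metrics` above derives the peak output throughput by rebuilding each token's arrival time from `start_time`, `ttft`, and the `itl` list, then counting tokens per one-second bucket. A simplified sketch of that bucketing, assuming a hypothetical helper `peak_tokens_per_second` that takes plain `(start_time, ttft, itl)` tuples and uses the last token time as the end of a request (the original uses `start_time + latency`):

```python
import math
from typing import List, Tuple


def peak_tokens_per_second(requests: List[Tuple[float, float, List[float]]]) -> int:
    """Return the largest number of tokens that landed in any 1-second bucket."""
    if not requests:
        return 0
    min_start = min(start for start, _, _ in requests)
    max_end = max(start + ttft + sum(itl) for start, ttft, itl in requests)
    buckets = [0] * (int(math.ceil(max_end - min_start)) + 1)
    for start, ttft, itl in requests:
        token_time = start + ttft  # arrival time of the first token
        buckets[int(token_time - min_start)] += 1
        for gap in itl:
            token_time += gap  # each later token arrives one ITL after the previous
            buckets[int(token_time - min_start)] += 1
    return max(buckets)


# Two overlapping requests (made-up timings, in seconds).
print(peak_tokens_per_second([(0.0, 0.5, [0.2, 0.2, 0.2]), (0.1, 0.5, [0.5, 0.5])]))
```

Because buckets are aligned to the earliest request start, the reported peak depends on where the one-second boundaries fall; it is a coarse estimate rather than a sliding-window maximum.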
1067
+ MULTI_TURN_BACKENDS = {"sglang-oai-chat", "vllm-chat", "lmdeploy-chat"}
1068
+
1069
+
1070
+ def wrap_multi_turn_request_func(request_func: Callable, backend: str) -> Callable:
1071
+ assert (
1072
+ backend in MULTI_TURN_BACKENDS
1073
+ ), f"Multi-turn only supports chat backends: {MULTI_TURN_BACKENDS}, got {backend}"
1074
+
1075
+ async def f(
1076
+ request_func_input: RequestFuncInput,
1077
+ pbar: Optional[tqdm] = None,
1078
+ ) -> List[RequestFuncOutput]:
1079
+ prompts: List[str] = request_func_input.prompt
1080
+ prev_messages: List[Dict[str, str]] = []
1081
+ outputs = []
1082
+
1083
+ for round_index in range(len(prompts)):
1084
+ prev_messages.append({"role": "user", "content": prompts[round_index]})
1085
+
1086
+ inner_input = replace(
1087
+ copy.deepcopy(request_func_input), prompt=copy.deepcopy(prev_messages)
1088
+ )
1089
+ output = await request_func(
1090
+ inner_input, pbar=pbar if round_index == len(prompts) - 1 else None
1091
+ )
1092
+ outputs.append(output)
1093
+
1094
+ prev_messages.append(
1095
+ {"role": "assistant", "content": output.generated_text}
1096
+ )
1097
+
1098
+ return outputs
1099
+
1100
+ return f
1101
+
1102
+
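`wrap_multi_turn_request_func` above replays a list of user turns one at a time, appending each assistant reply to the history before sending the next turn. A minimal sketch of that loop, assuming a hypothetical `echo_backend` coroutine in place of a real chat endpoint and plain strings in place of `RequestFuncOutput`:

```python
import asyncio
from typing import Awaitable, Callable, Dict, List

Message = Dict[str, str]


async def run_multi_turn(
    prompts: List[str],
    request_func: Callable[[List[Message]], Awaitable[str]],
) -> List[str]:
    """Send each user turn with the full conversation history so far."""
    messages: List[Message] = []
    replies: List[str] = []
    for prompt in prompts:
        messages.append({"role": "user", "content": prompt})
        # Pass a copy so the callee cannot mutate the running history.
        reply = await request_func(list(messages))
        replies.append(reply)
        messages.append({"role": "assistant", "content": reply})
    return replies


async def echo_backend(messages: List[Message]) -> str:
    # Hypothetical stand-in for a chat completion request.
    return f"reply to {len(messages)} message(s)"


replies = asyncio.run(run_multi_turn(["hi", "and then?"], echo_backend))
print(replies)
```

Each round sees a strictly longer prompt than the previous one, which is why the multi-turn path returns one output per round instead of one per conversation.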
1103
+ async def benchmark(
1104
+ backend: str,
1105
+ api_url: str,
1106
+ base_url: str,
1107
+ model_id: str,
1108
+ tokenizer: PreTrainedTokenizerBase,
1109
+ input_requests: List[DatasetRow],
1110
+ request_rate: float,
1111
+ max_concurrency: Optional[int],
1112
+ disable_tqdm: bool,
1113
+ lora_names: List[str],
1114
+ lora_request_distribution: Optional[str],
1115
+ lora_zipf_alpha: Optional[float],
1116
+ extra_request_body: Dict[str, Any],
1117
+ profile: bool,
1118
+ pd_separated: bool = False,
1119
+ flush_cache: bool = False,
1120
+ warmup_requests: int = 1,
1121
+ use_trace_timestamps: bool = False,
1122
+ mooncake_slowdown_factor=1.0,
1123
+ mooncake_num_rounds=1,
1124
+ profile_prefill_url: Optional[List[str]] = None,
1125
+ profile_decode_url: Optional[List[str]] = None,
1126
+ ):
1127
+ if backend in ASYNC_REQUEST_FUNCS:
1128
+ request_func = ASYNC_REQUEST_FUNCS[backend]
1129
+ else:
1130
+ raise ValueError(f"Unknown backend: {backend}")
1131
+
1132
+ # Check for multi-turn: prompt is a list of strings (not OpenAI messages dicts)
1133
+ # Multi-turn format: ["turn1", "turn2", ...] - list of strings
1134
+ # OpenAI format: [{"role": "user", "content": "..."}, ...] - list of dicts
1135
+ first_prompt = input_requests[0].prompt
1136
+ is_multi_turn = (
1137
+ isinstance(first_prompt, list)
1138
+ and len(first_prompt) > 0
1139
+ and isinstance(first_prompt[0], str)
1140
+ )
1141
+ if is_multi_turn:
1142
+ request_func = wrap_multi_turn_request_func(request_func, backend=backend)
1143
+
1144
+ # Limit concurrency
1145
+ # From https://github.com/vllm-project/vllm/pull/9390
1146
+ semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
1147
+
1148
+ async def limited_request_func(request_func_input, pbar):
1149
+ if semaphore is None:
1150
+ return await request_func(request_func_input=request_func_input, pbar=pbar)
1151
+ async with semaphore:
1152
+ return await request_func(request_func_input=request_func_input, pbar=pbar)
1153
+
1154
+ # Warmup
1155
+ print(f"Starting warmup with {warmup_requests} sequences...")
1156
+
1157
+ # Handle the data structure difference for the warmup request
1158
+ if args.dataset_name == "mooncake":
1159
+ # For mooncake, input_requests is a list of dicts.
1160
+ # We need to build a temporary DatasetRow for the warmup phase.
1161
+ warmup_record = input_requests[0]
1162
+
1163
+ # Build prompt from hash_ids, just like in the async generator
1164
+ hash_ids = warmup_record.get("hash_ids", [])
1165
+ prompt_text = ""
1166
+ for hash_id in hash_ids:
1167
+ prompt_text += f"{hash_id}" + " ".join(["hi"] * 512)
1168
+ prompt_text += "Can you tell me a detailed story in 1000 words?"
1169
+
1170
+ output_len = warmup_record.get("output_length", 32)
1171
+ prompt_len = len(tokenizer.encode(prompt_text))
1172
+
1173
+ # Create a temporary DatasetRow object for warmup
1174
+ test_request = DatasetRow(
1175
+ prompt=prompt_text,
1176
+ prompt_len=prompt_len,
1177
+ output_len=output_len,
1178
+ image_data=None, # Mooncake doesn't have image data
1179
+ )
1180
+ else:
1181
+ # For all other datasets, input_requests is a list of DatasetRow objects
1182
+ test_request = input_requests[0]
1183
+
1184
+ if lora_names is not None and len(lora_names) != 0:
1185
+ lora_name = lora_names[0]
1186
+ else:
1187
+ lora_name = None
1188
+
1189
+ # Create the test input once
1190
+ test_input = RequestFuncInput(
1191
+ model=model_id,
1192
+ prompt=test_request.prompt,
1193
+ api_url=api_url,
1194
+ prompt_len=test_request.prompt_len,
1195
+ output_len=min(test_request.output_len, 32),
1196
+ lora_name=lora_name,
1197
+ image_data=test_request.image_data,
1198
+ extra_request_body=extra_request_body,
1199
+ )
1200
+
1201
+ # Run warmup requests
1202
+ warmup_tasks = []
1203
+ for _ in range(warmup_requests):
1204
+ warmup_tasks.append(
1205
+ asyncio.create_task(request_func(request_func_input=test_input))
1206
+ )
1207
+
1208
+ warmup_outputs = await asyncio.gather(*warmup_tasks)
1209
+ if is_multi_turn:
1210
+ warmup_outputs = [x for output in warmup_outputs for x in output]
1211
+
1212
+ # Check if at least one warmup request succeeded
1213
+ if warmup_requests > 0 and not any(output.success for output in warmup_outputs):
1214
+ raise ValueError(
1215
+ "Warmup failed - Please make sure benchmark arguments "
1216
+ f"are correctly specified. Error: {warmup_outputs[0].error}"
1217
+ )
1218
+ else:
1219
+ print(
1220
+ f"Warmup completed with {warmup_requests} sequences. Starting main benchmark run..."
1221
+ )
1222
+
1223
+ # Flush cache
1224
+ if ("sglang" in backend and _get_bool_env_var("SGLANG_IS_IN_CI")) or flush_cache:
1225
+ requests.post(base_url + "/flush_cache", headers=get_auth_headers())
1226
+
1227
+ time.sleep(1.0)
1228
+
1229
+ # Build profile URLs for PD separated mode (do this once at the beginning)
1230
+ pd_profile_urls = []
1231
+ if profile and pd_separated:
1232
+ pd_profile_urls = _build_profile_urls(profile_prefill_url, profile_decode_url)
1233
+ if not pd_profile_urls:
1234
+ print(
1235
+ "Warning: PD separated mode requires --profile-prefill-url or --profile-decode-url"
1236
+ )
1237
+ print("Skipping profiler start. Please specify worker URLs for profiling.")
1238
+
1239
+ # Start profiler
1240
+ if profile:
1241
+ if pd_separated:
1242
+ if pd_profile_urls:
1243
+ await _call_profile_pd(pd_profile_urls, "start")
1244
+ else:
1245
+ print("Starting profiler...")
1246
+ profile_output = await async_request_profile(
1247
+ api_url=base_url + "/start_profile"
1248
+ )
1249
+ if profile_output.success:
1250
+ print("Profiler started")
1251
+
1252
+ # Run all requests
1253
+ benchmark_start_time = time.perf_counter()
1254
+ tasks: List[asyncio.Task] = []
1255
+ pbar_total = len(input_requests)
1256
+ if (
1257
+ backend == "sglang" and args.dataset_name == "mooncake"
1258
+ ): # Assuming mooncake is mainly for sglang or similar backends
1259
+ print("Using time-based Mooncake request scheduler, ignoring --request-rate.")
1260
+ request_generator = get_mooncake_request_over_time(
1261
+ input_requests, tokenizer, mooncake_slowdown_factor, mooncake_num_rounds
1262
+ )
1263
+ print(
1264
+ f"Starting Mooncake trace replay. Sessions: {len(input_requests)}, Rounds per session: {mooncake_num_rounds}. Slowdown factor: {mooncake_slowdown_factor}"
1265
+ )
1266
+ pbar_total *= mooncake_num_rounds
1267
+ else:
1268
+ request_generator = get_request(input_requests, request_rate)
1269
+
1270
+ # Prepare LoRA request distribution parameters
1271
+ if lora_request_distribution == "distinct":
1272
+ lora_idx = 0
1273
+ elif lora_request_distribution == "skewed":
1274
+ weights = np.array([lora_zipf_alpha**-i for i in range(len(lora_names))])
1275
+ lora_probs = weights / np.sum(weights)
1276
+ else:
1277
+ lora_idx = None
1278
+ lora_probs = None
1279
+
1280
+ pbar = None if disable_tqdm else tqdm(total=pbar_total)
1281
+ async for request in request_generator:
1282
+ if lora_names is not None and len(lora_names) != 0:
1283
+ if lora_request_distribution == "uniform":
1284
+ lora_name = random.choice(lora_names)
1285
+ elif lora_request_distribution == "distinct":
1286
+ lora_name = lora_names[lora_idx]
1287
+ lora_idx = (lora_idx + 1) % len(lora_names)
1288
+ else:
1289
+ assert (
1290
+ lora_request_distribution == "skewed"
1291
+ ), f"Unexpected lora_request_distribution: {lora_request_distribution}. Expected 'skewed'."
1292
+
1293
+ lora_name = np.random.choice(lora_names, p=lora_probs)
1294
+ else:
1295
+ lora_name = None
1296
+
1297
+ # Merge global extra_request_body with per-request extras
1298
+ # Per-request parameters take precedence over global ones
1299
+ merged_extra_body = {**extra_request_body, **request.extra_request_body}
1300
+
1301
+ request_func_input = RequestFuncInput(
1302
+ model=model_id,
1303
+ prompt=request.prompt,
1304
+ api_url=api_url,
1305
+ prompt_len=request.prompt_len,
1306
+ output_len=request.output_len,
1307
+ lora_name=lora_name,
1308
+ image_data=request.image_data,
1309
+ extra_request_body=merged_extra_body,
1310
+ timestamp=request.timestamp,
1311
+ routing_key=request.routing_key,
1312
+ )
1313
+
1314
+ tasks.append(
1315
+ asyncio.create_task(
1316
+ limited_request_func(request_func_input=request_func_input, pbar=pbar)
1317
+ )
1318
+ )
1319
+ outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
1320
+ if is_multi_turn:
1321
+ outputs = [x for output in outputs for x in output]
1322
+
1323
+ # Stop profiler (only if profile_steps was not provided, as it auto-stops)
1324
+ if profile and not (
1325
+ hasattr(args, "profile_steps") and args.profile_steps is not None
1326
+ ):
1327
+ if pd_separated:
1328
+ if pd_profile_urls:
1329
+ await _call_profile_pd(pd_profile_urls, "stop")
1330
+ else:
1331
+ if getattr(args, "profile_num_steps", None) is None:
1332
+ print("Stopping profiler...")
1333
+ profile_output = await async_request_profile(
1334
+ api_url=base_url + "/stop_profile"
1335
+ )
1336
+ if profile_output.success:
1337
+ print("Profiler stopped")
1338
+
1339
+ if pbar is not None:
1340
+ pbar.close()
1341
+
1342
+ if "sglang" in backend:
1343
+ server_info = requests.get(
1344
+ base_url + "/get_server_info", headers=get_auth_headers()
1345
+ )
1346
+ if server_info.status_code == 200:
1347
+ server_info_json = server_info.json()
1348
+ if "decode" in server_info_json:
1349
+ server_info_json = server_info_json["decode"][0]
1350
+ if (
1351
+ "internal_states" in server_info_json
1352
+ and server_info_json["internal_states"]
1353
+ ):
1354
+ accept_length = server_info_json["internal_states"][0].get(
1355
+ "avg_spec_accept_length", None
1356
+ )
1357
+ else:
1358
+ accept_length = None
1359
+ else:
1360
+ accept_length = None
1361
+ else:
1362
+ accept_length = None
1363
+
1364
+ # Compute metrics and print results
1365
+ benchmark_duration = time.perf_counter() - benchmark_start_time
1366
+ metrics, output_lens = calculate_metrics(
1367
+ input_requests=None if is_multi_turn else input_requests,
1368
+ outputs=outputs,
1369
+ dur_s=benchmark_duration,
1370
+ tokenizer=tokenizer,
1371
+ backend=backend,
1372
+ accept_length=accept_length,
1373
+ plot_throughput=args.plot_throughput,
1374
+ )
1375
+
1376
+ print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
1377
+ print("{:<40} {:<10}".format("Backend:", backend))
1378
+ print(
1379
+ "{:<40} {:<10}".format(
1380
+ "Traffic request rate:", "trace" if use_trace_timestamps else request_rate
1381
+ )
1382
+ )
1383
+ print(
1384
+ "{:<40} {:<10}".format(
1385
+ "Max request concurrency:",
1386
+ max_concurrency if max_concurrency else "not set",
1387
+ )
1388
+ )
1389
+ print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
1390
+ print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
1391
+ print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
1392
+ print("{:<40} {:<10}".format("Total input text tokens:", metrics.total_input_text))
1393
+ if args.dataset_name in ["image", "mmmu"]:
1394
+ print(
1395
+ "{:<40} {:<10}".format(
1396
+ "Total input vision tokens:", metrics.total_input_vision
1397
+ )
1398
+ )
1399
+ print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
1400
+ print(
1401
+ "{:<40} {:<10}".format(
1402
+ "Total generated tokens (retokenized):", metrics.total_output_retokenized
1403
+ )
1404
+ )
1405
+ print(
1406
+ "{:<40} {:<10.2f}".format(
1407
+ "Request throughput (req/s):", metrics.request_throughput
1408
+ )
1409
+ )
1410
+ print(
1411
+ "{:<40} {:<10.2f}".format(
1412
+ "Input token throughput (tok/s):", metrics.input_throughput
1413
+ )
1414
+ )
1415
+ print(
1416
+ "{:<40} {:<10.2f}".format(
1417
+ "Output token throughput (tok/s):", metrics.output_throughput
1418
+ )
1419
+ )
1420
+ print(
1421
+ "{:<40} {:<10.2f}".format(
1422
+ "Peak output token throughput (tok/s):", metrics.max_output_tokens_per_s
1423
+ )
1424
+ )
1425
+ print(
1426
+ "{:<40} {:<10}".format(
1427
+ "Peak concurrent requests:", metrics.max_concurrent_requests
1428
+ )
1429
+ )
1430
+ print(
1431
+ "{:<40} {:<10.2f}".format(
1432
+ "Total token throughput (tok/s):", metrics.total_throughput
1433
+ )
1434
+ )
1435
+ print("{:<40} {:<10.2f}".format("Concurrency:", metrics.concurrency))
1436
+ if accept_length:
1437
+ print("{:<40} {:<10.2f}".format("Accept length:", accept_length))
1438
+ print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
1439
+ print(
1440
+ "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
1441
+ )
1442
+ print(
1443
+ "{:<40} {:<10.2f}".format(
1444
+ "Median E2E Latency (ms):", metrics.median_e2e_latency_ms
1445
+ )
1446
+ )
1447
+ print(
1448
+ "{:<40} {:<10.2f}".format("P90 E2E Latency (ms):", metrics.p90_e2e_latency_ms)
1449
+ )
1450
+ print(
1451
+ "{:<40} {:<10.2f}".format("P99 E2E Latency (ms):", metrics.p99_e2e_latency_ms)
1452
+ )
1453
+ print("{s:{c}^{n}}".format(s="Time to First Token", n=50, c="-"))
1454
+ print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
1455
+ print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
1456
+ print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
1457
+ print(
1458
+ "{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-")
1459
+ )
1460
+ print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
1461
+ print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms))
1462
+ print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
1463
+ print("{s:{c}^{n}}".format(s="Inter-Token Latency", n=50, c="-"))
1464
+ print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
1465
+ print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
1466
+ print("{:<40} {:<10.2f}".format("P95 ITL (ms):", metrics.p95_itl_ms))
1467
+ print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
1468
+ print("{:<40} {:<10.2f}".format("Max ITL (ms):", metrics.max_itl_ms))
1469
+ print("=" * 50)
1470
+
1471
+ resp = requests.get(base_url + "/get_server_info", headers=get_auth_headers())
1472
+ server_info = resp.json() if resp.status_code == 200 else None
1473
+
1474
+ if (
1475
+ metrics.median_ttft_ms is not None
1476
+ and metrics.mean_itl_ms is not None
1477
+ and metrics.output_throughput is not None
1478
+ ):
1479
+ result = {
1480
+ # Arguments
1481
+ "tag": getattr(args, "tag", None),
1482
+ "backend": args.backend,
1483
+ "dataset_name": args.dataset_name,
1484
+ "request_rate": "trace" if use_trace_timestamps else request_rate,
1485
+ "max_concurrency": max_concurrency,
1486
+ "sharegpt_output_len": args.sharegpt_output_len,
1487
+ "random_input_len": args.random_input_len,
1488
+ "random_output_len": args.random_output_len,
1489
+ "random_range_ratio": args.random_range_ratio,
1490
+ # Information
1491
+ "server_info": server_info,
1492
+ # Results
1493
+ "duration": benchmark_duration,
1494
+ "completed": metrics.completed,
1495
+ "total_input_tokens": metrics.total_input,
1496
+ "total_input_text_tokens": metrics.total_input_text,
1497
+ "total_input_vision_tokens": metrics.total_input_vision,
1498
+ "total_output_tokens": metrics.total_output,
1499
+ "total_output_tokens_retokenized": metrics.total_output_retokenized,
1500
+ "request_throughput": metrics.request_throughput,
1501
+ "input_throughput": metrics.input_throughput,
1502
+ "output_throughput": metrics.output_throughput,
1503
+ "total_throughput": metrics.total_throughput,
1504
+ "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
1505
+ "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
1506
+ "std_e2e_latency_ms": metrics.std_e2e_latency_ms,
1507
+ "p90_e2e_latency_ms": metrics.p90_e2e_latency_ms,
1508
+ "p99_e2e_latency_ms": metrics.p99_e2e_latency_ms,
1509
+ "mean_ttft_ms": metrics.mean_ttft_ms,
1510
+ "median_ttft_ms": metrics.median_ttft_ms,
1511
+ "std_ttft_ms": metrics.std_ttft_ms,
1512
+ "p99_ttft_ms": metrics.p99_ttft_ms,
1513
+ "mean_tpot_ms": metrics.mean_tpot_ms,
1514
+ "median_tpot_ms": metrics.median_tpot_ms,
1515
+ "std_tpot_ms": metrics.std_tpot_ms,
1516
+ "p99_tpot_ms": metrics.p99_tpot_ms,
1517
+ "mean_itl_ms": metrics.mean_itl_ms,
1518
+ "median_itl_ms": metrics.median_itl_ms,
1519
+ "std_itl_ms": metrics.std_itl_ms,
1520
+ "p95_itl_ms": metrics.p95_itl_ms,
1521
+ "p99_itl_ms": metrics.p99_itl_ms,
1522
+ "concurrency": metrics.concurrency,
1523
+ "accept_length": accept_length,
1524
+ "max_output_tokens_per_s": metrics.max_output_tokens_per_s,
1525
+ "max_concurrent_requests": metrics.max_concurrent_requests,
1526
+ }
1527
+ else:
1528
+ print(f"Error running benchmark for request rate: {request_rate}")
1529
+ print("-" * 30)
1530
+
1531
+ # Determine output file name
1532
+ if args.output_file:
1533
+ output_file_name = args.output_file
1534
+ else:
1535
+ now = datetime.now().strftime("%m%d")
1536
+ if args.dataset_name == "image":
1537
+ output_file_name = (
1538
+ f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_"
1539
+ f"{args.random_output_len}_{args.image_count}imgs_"
1540
+ f"{args.image_resolution}.jsonl"
1541
+ )
1542
+ elif args.dataset_name.startswith("random"):
1543
+ output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"
1544
+ else:
1545
+ output_file_name = (
1546
+ f"{args.backend}_{now}_{args.num_prompts}_{args.dataset_name}.jsonl"
1547
+ )
1548
+
1549
+ result_details = {
1550
+ "input_lens": [output.prompt_len for output in outputs],
1551
+ "output_lens": output_lens,
1552
+ "ttfts": [output.ttft for output in outputs],
1553
+ "itls": [output.itl for output in outputs],
1554
+ "generated_texts": [output.generated_text for output in outputs],
1555
+ "errors": [output.error for output in outputs],
1556
+ }
1557
+
1558
+ # Append results to a JSONL file
1559
+ with open(output_file_name, "a") as file:
1560
+ if args.output_details:
1561
+ result_for_dump = result | result_details
1562
+ else:
1563
+ result_for_dump = result
1564
+ file.write(json.dumps(result_for_dump) + "\n")
1565
+
1566
+ return result | result_details
1567
+
1568
+
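The tail of `benchmark` above appends one JSON object per run to a JSONL file and merges the per-request details into the record only when `--output-details` is set. A minimal sketch of that append-and-merge pattern follows. The helper names `append_result` and `read_results` are hypothetical and do not exist in the source; the merge uses `{**a, **b}`, which gives the same result as the `result | result_details` expression in the diff.

```python
import json
import os
import tempfile


def append_result(path, summary, details=None):
    """Hypothetical helper: append one benchmark record as a single JSON line."""
    # Merge the optional per-request details into the summary record.
    record = {**summary, **details} if details else dict(summary)
    # Mode "a" keeps earlier runs, so the file accumulates one line per run.
    with open(path, "a") as f:
        f.write(json.dumps(record) + "\n")
    return record


def read_results(path):
    """Hypothetical helper: load every record back from the JSONL file."""
    with open(path) as f:
        return [json.loads(line) for line in f if line.strip()]


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        out = os.path.join(tmp, "bench.jsonl")
        append_result(out, {"completed": 3})
        append_result(out, {"completed": 5}, {"ttfts": [0.1, 0.2]})
        for row in read_results(out):
            print(row)
```

Because each run is one line, results from several request rates can be collected in the same file and loaded later for comparison.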
1569
+ def check_chat_template(model_path):
1570
+ try:
1571
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
1572
+ return "chat_template" in tokenizer.init_kwargs
1573
+ except Exception as e:
1574
+ print(f"Failed to load tokenizer config with error={e}")
1575
+ return False
1576
+
1577
+
1578
+ def set_global_args(args_: argparse.Namespace):
1579
+ """Set the global args."""
1580
+ global args
1581
+ args = args_
1582
+
1583
+
1584
+ def run_benchmark(args_: argparse.Namespace):
1585
+ global args
1586
+ args = args_
1587
+
1588
+ # Set default value for max_concurrency if not present
1589
+ if not hasattr(args, "max_concurrency"):
1590
+ args.max_concurrency = None
1591
+
1592
+ # Set default value for warmup_requests if not present
1593
+ if not hasattr(args, "warmup_requests"):
1594
+ args.warmup_requests = 1
1595
+
1596
+ if not hasattr(args, "output_details"):
1597
+ args.output_details = False
1598
+
1599
+ if not hasattr(args, "tokenize_prompt"):
1600
+ args.tokenize_prompt = False
1601
+
1602
+ if not hasattr(args, "plot_throughput"):
1603
+ args.plot_throughput = False
1604
+
1605
+ if not hasattr(args, "use_trace_timestamps"):
1606
+ args.use_trace_timestamps = False
1607
+ if not hasattr(args, "mooncake_slowdown_factor"):
1608
+ args.mooncake_slowdown_factor = 1.0
1609
+
1610
+ if not hasattr(args, "mooncake_slowdown_factor"):
1611
+ args.mooncake_slowdown_factor = 1.0
1612
+
1613
+ if not hasattr(args, "mooncake_num_rounds"):
1614
+ args.mooncake_num_rounds = 1
1615
+
1616
+ if not hasattr(args, "served_model_name"):
1617
+ args.served_model_name = None
1618
+
1619
+ if getattr(args, "print_requests", False):
1620
+ assert args.backend == "sglang-oai-chat" # only support this now
1621
+
1622
+ print(f"benchmark_args={args}")
1623
+
1624
+ # Set global environments
1625
+ set_ulimit()
1626
+ random.seed(args.seed)
1627
+ np.random.seed(args.seed)
1628
+
1629
+ extra_request_body = {}
1630
+ if args.extra_request_body:
1631
+ extra_request_body = json.loads(args.extra_request_body)
1632
+
1633
+ if args.tokenize_prompt:
1634
+ assert (
1635
+ args.backend == "sglang"
1636
+ ), "`--tokenize-prompt` only compatible with `--backend sglang` currently"
1637
+
1638
+ # Set url
1639
+ if args.port is None:
1640
+ args.port = {
1641
+ "sglang": 30000,
1642
+ "sglang-native": 30000,
1643
+ "sglang-oai": 30000,
1644
+ "lmdeploy": 23333,
1645
+ "vllm": 8000,
1646
+ "trt": 8000,
1647
+ "gserver": 9988,
1648
+ "truss": 8080,
1649
+ }.get(args.backend, 30000)
1650
+
1651
+ model_url = (
1652
+ f"{args.base_url}/v1/models"
1653
+ if args.base_url
1654
+ else f"http://{args.host}:{args.port}/v1/models"
1655
+ )
1656
+
1657
+ if args.backend in ["sglang", "sglang-native"]:
1658
+ api_url = (
1659
+ f"{args.base_url}/generate"
1660
+ if args.base_url
1661
+ else f"http://{args.host}:{args.port}/generate"
1662
+ )
1663
+ elif args.backend in ["sglang-oai", "vllm", "lmdeploy"]:
1664
+ api_url = (
1665
+ f"{args.base_url}/v1/completions"
1666
+ if args.base_url
1667
+ else f"http://{args.host}:{args.port}/v1/completions"
1668
+ )
1669
+ elif args.backend in ["sglang-oai-chat", "vllm-chat", "lmdeploy-chat"]:
1670
+ api_url = (
1671
+ f"{args.base_url}/v1/chat/completions"
1672
+ if args.base_url
1673
+ else f"http://{args.host}:{args.port}/v1/chat/completions"
1674
+ )
1675
+ elif args.backend == "trt":
1676
+ api_url = (
1677
+ f"{args.base_url}/v2/models/ensemble/generate_stream"
1678
+ if args.base_url
1679
+ else f"http://{args.host}:{args.port}/v2/models/ensemble/generate_stream"
1680
+ )
1681
+ if args.model is None:
1682
+ print("Please provide a model using `--model` when using `trt` backend.")
1683
+ sys.exit(1)
1684
+ elif args.backend == "gserver":
1685
+ api_url = args.base_url if args.base_url else f"{args.host}:{args.port}"
1686
+ args.model = args.model or "default"
1687
+ elif args.backend == "truss":
1688
+ api_url = (
1689
+ f"{args.base_url}/v1/models/model:predict"
1690
+ if args.base_url
1691
+ else f"http://{args.host}:{args.port}/v1/models/model:predict"
1692
+ )
1693
+ base_url = (
1694
+ f"http://{args.host}:{args.port}" if args.base_url is None else args.base_url
1695
+ )
1696
+
1697
+ # Wait for server to be ready
1698
+ if args.ready_check_timeout_sec > 0:
1699
+ health_url = model_url if args.backend not in ("trt", "gserver") else base_url
1700
+ if not wait_for_endpoint(health_url, args.ready_check_timeout_sec):
1701
+ print(f"Server at {health_url} is not ready. Exiting.")
1702
+ sys.exit(1)
1703
+
1704
+ # Get model name
1705
+ if args.model is None:
1706
+ if args.backend == "truss":
1707
+ print(
1708
+ "Please provide a model with `--model` when using truss backend. e.g. --model meta-llama/Llama-3.1-8B-Instruct"
1709
+ )
1710
+ sys.exit(1)
1711
+ try:
1712
+ response = requests.get(model_url, headers=get_auth_headers())
1713
+ model_list = response.json().get("data", [])
1714
+ args.model = model_list[0]["id"] if model_list else None
1715
+ except Exception as e:
1716
+ print(f"Failed to fetch model from {model_url}. Error: {e}")
1717
+ print(
1718
+ "Please specify the correct host and port using `--host` and `--port`."
1719
+ )
1720
+ sys.exit(1)
1721
+
1722
+ if args.model is None:
1723
+ print("No model specified or found. Please provide a model using `--model`.")
1724
+ sys.exit(1)
1725
+
1726
+ if not check_chat_template(args.model):
1727
+ print(
1728
+ "\nWARNING: It is recommended to use a `Chat` or `Instruct` model for benchmarking.\n"
1729
+ "Otherwise the output may contain gibberish, and the tokenizer might then count the output tokens incorrectly.\n"
1730
+ )
1731
+
1732
+ if args.dataset_name in ["image", "mmmu"]:
1733
+ args.apply_chat_template = True
1734
+ assert (
1735
+ not args.tokenize_prompt
1736
+ ), "`--tokenize-prompt` not compatible with image dataset"
1737
+
1738
+ if args.lora_request_distribution in ["distinct", "skewed"]:
1739
+ assert (
1740
+ args.lora_name is not None and len(args.lora_name) > 1
1741
+ ), "More than 1 LoRA adapter must be specified via --lora-name to use 'distinct' or 'skewed' request distribution."
1742
+
1743
+ assert (
1744
+ args.lora_zipf_alpha > 1
1745
+ ), f"Got invalid value for --lora-zipf-alpha of {args.lora_zipf_alpha}. It must be greater than 1."
1746
+
1747
+ print(f"{args}\n")
1748
+
1749
+ # Read dataset
1750
+ backend = args.backend
1751
+ model_id = args.served_model_name or args.model
1752
+ tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
1753
+ tokenizer = get_tokenizer(tokenizer_id)
1754
+ input_requests = get_dataset(args, tokenizer, model_id)
1755
+
1756
+ # compatible with SimpleNamespace
1757
+ if not hasattr(args, "flush_cache"):
1758
+ args.flush_cache = False
1759
+
1760
+ # Prepare LoRA arguments
1761
+ lora_request_distribution = (
1762
+ args.lora_request_distribution if args.lora_name is not None else None
1763
+ )
1764
+
1765
+ lora_zipf_alpha = (
1766
+ args.lora_zipf_alpha
1767
+ if args.lora_name is not None and args.lora_request_distribution == "skewed"
1768
+ else None
1769
+ )
1770
+
1771
+ return asyncio.run(
1772
+ benchmark(
1773
+ backend=backend,
1774
+ api_url=api_url,
1775
+ base_url=base_url,
1776
+ model_id=model_id,
1777
+ tokenizer=tokenizer,
1778
+ input_requests=input_requests,
1779
+ request_rate=args.request_rate,
1780
+ max_concurrency=args.max_concurrency,
1781
+ disable_tqdm=args.disable_tqdm,
1782
+ lora_names=args.lora_name,
1783
+ lora_request_distribution=lora_request_distribution,
1784
+ lora_zipf_alpha=lora_zipf_alpha,
1785
+ extra_request_body=extra_request_body,
1786
+ profile=args.profile,
1787
+ pd_separated=args.pd_separated,
1788
+ flush_cache=args.flush_cache,
1789
+ warmup_requests=args.warmup_requests,
1790
+ use_trace_timestamps=args.use_trace_timestamps,
1791
+ mooncake_slowdown_factor=args.mooncake_slowdown_factor,
1792
+ mooncake_num_rounds=args.mooncake_num_rounds,
1793
+ profile_prefill_url=getattr(args, "profile_prefill_url", None),
1794
+ profile_decode_url=getattr(args, "profile_decode_url", None),
1795
+ )
1796
+ )
1797
+
1798
+
1799
+ class LoRAPathAction(argparse.Action):
1800
+ def __call__(self, parser, namespace, values, option_string=None):
1801
+ setattr(namespace, self.dest, [])
1802
+ for lora_name in values:
1803
+ getattr(namespace, self.dest).append(lora_name)
1804
+
1805
+
1806
+ if __name__ == "__main__":
1807
+ parser = ArgumentParser(description="Benchmark the online serving throughput.")
1808
+ parser.add_argument(
1809
+ "--backend",
1810
+ type=str,
1811
+ choices=list(ASYNC_REQUEST_FUNCS.keys()),
1812
+ default="sglang",
1813
+ help="Backend to benchmark, matching the LLM inference engine being served. Default is sglang.",
1814
+ )
1815
+ parser.add_argument(
1816
+ "--base-url",
1817
+ type=str,
1818
+ default=None,
1819
+ help="Server or API base url if not using http host and port.",
1820
+ )
1821
+ parser.add_argument(
1822
+ "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0."
1823
+ )
1824
+ parser.add_argument(
1825
+ "--port",
1826
+ type=int,
1827
+ help="If not set, the port defaults to the usual port of the selected LLM inference engine.",
1828
+ )
1829
+ parser.add_argument(
1830
+ "--ready-check-timeout-sec",
1831
+ type=int,
1832
+ default=60,
1833
+ help="Maximum time in seconds to wait for the server to be ready before benchmarking. Set to 0 to skip. Default: 60.",
1834
+ )
1835
+ parser.add_argument(
1836
+ "--dataset-name",
1837
+ type=str,
1838
+ default="sharegpt",
1839
+ choices=[
1840
+ "sharegpt",
1841
+ "custom",
1842
+ "openai",
1843
+ "random",
1844
+ "random-ids",
1845
+ "generated-shared-prefix",
1846
+ "mmmu",
1847
+ "image",
1848
+ "mooncake",
1849
+ ],
1850
+ help="Name of the dataset to benchmark on.",
1851
+ )
1852
+ parser.add_argument(
1853
+ "--dataset-path", type=str, default="", help="Path to the dataset."
1854
+ )
1855
+ parser.add_argument(
1856
+ "--model",
1857
+ type=str,
1858
+ help="Name or path of the model. If not set, the model name is fetched from the server's /v1/models endpoint.",
1859
+ )
1860
+ parser.add_argument(
1861
+ "--served-model-name",
1862
+ type=str,
1863
+ help="The name of the model as served by the serving service. If not set, this defaults to the value of --model.",
1864
+ )
1865
+ parser.add_argument(
1866
+ "--tokenizer",
1867
+ type=str,
1868
+ help="Name or path of the tokenizer. If not set, the value of --model is used.",
1869
+ )
1870
+ parser.add_argument(
1871
+ "--num-prompts",
1872
+ type=int,
1873
+ default=1000,
1874
+ help="Number of prompts to process. Default is 1000.",
1875
+ )
1876
+ parser.add_argument(
1877
+ "--sharegpt-output-len",
1878
+ type=int,
1879
+ default=None,
1880
+ help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
1881
+ )
1882
+ parser.add_argument(
1883
+ "--sharegpt-context-len",
1884
+ type=int,
1885
+ default=None,
1886
+ help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
1887
+ )
1888
+ parser.add_argument(
1889
+ "--random-input-len",
1890
+ type=int,
1891
+ default=1024,
1892
+ help="Number of input tokens per request, used only for random and image dataset.",
1893
+ )
1894
+ parser.add_argument(
1895
+ "--random-output-len",
1896
+ default=1024,
1897
+ type=int,
1898
+ help="Number of output tokens per request, used only for random and image dataset.",
1899
+ )
1900
+ parser.add_argument(
1901
+ "--random-range-ratio",
1902
+ type=float,
1903
+ default=0.0,
1904
+ help="Range of sampled ratio of input/output length, "
1905
+ "used only for random and image dataset.",
1906
+ )
1907
+ # image dataset args
1908
+ parser.add_argument(
1909
+ "--image-count",
1910
+ type=int,
1911
+ default=1,
1912
+ help="Number of images per request (only available with the image dataset)",
1913
+ )
1914
+ parser.add_argument(
1915
+ "--image-resolution",
1916
+ type=str,
1917
+ default="1080p",
1918
+ help=(
1919
+ "Resolution of images for image dataset. "
1920
+ "Supports presets 4k/1080p/720p/360p or custom 'heightxwidth' (e.g., 1080x1920)."
1921
+ ),
1922
+ )
1923
+ parser.add_argument(
1924
+ "--random-image-count",
1925
+ action="store_true",
1926
+ help="Enable Random Image Count",
1927
+ )
1928
+ parser.add_argument(
1929
+ "--image-format",
1930
+ type=str,
1931
+ default="jpeg",
1932
+ help=("Format of images for image dataset. " "Supports jpeg and png."),
1933
+ )
1934
+ parser.add_argument(
1935
+ "--image-content",
1936
+ type=str,
1937
+ default="random",
1938
+ help=("Content for images for image dataset. " "Supports random and blank."),
1939
+ )
1940
+ parser.add_argument(
1941
+ "--request-rate",
1942
+ type=float,
1943
+ default=float("inf"),
1944
+ help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
1945
+ "Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.",
1946
+ )
1947
+ parser.add_argument(
1948
+ "--use-trace-timestamps",
1949
+ action="store_true",
1950
+ help="Use timestamps from the trace file for request scheduling. Only valid for 'mooncake' dataset.",
1951
+ )
1952
+ parser.add_argument(
1953
+ "--max-concurrency",
1954
+ type=int,
1955
+ default=None,
1956
+ help="Maximum number of concurrent requests. This can be used "
1957
+ "to help simulate an environment where a higher level component "
1958
+ "is enforcing a maximum number of concurrent requests. While the "
1959
+ "--request-rate argument controls the rate at which requests are "
1960
+ "initiated, this argument will control how many are actually allowed "
1961
+ "to execute at a time. This means that when used in combination, the "
1962
+ "actual request rate may be lower than specified with --request-rate, "
1963
+ "if the server is not processing requests fast enough to keep up.",
1964
+ )
1965
+ parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
1966
+ parser.add_argument(
1967
+ "--output-details", action="store_true", help="Output details of benchmarking."
1968
+ )
1969
+ parser.add_argument(
1970
+ "--print-requests",
1971
+ action="store_true",
1972
+ help="Print requests immediately during benchmarking. Useful for spotting issues quickly.",
1973
+ )
1974
+ parser.add_argument(
1975
+ "--disable-tqdm",
1976
+ action="store_true",
1977
+ help="Specify to disable tqdm progress bar.",
1978
+ )
1979
+ parser.add_argument(
1980
+ "--disable-stream",
1981
+ action="store_true",
1982
+ help="Disable streaming mode.",
1983
+ )
1984
+ parser.add_argument(
1985
+ "--return-logprob",
1986
+ action="store_true",
1987
+ help="Return logprob.",
1988
+ )
1989
+ parser.add_argument(
1990
+ "--return-routed-experts",
1991
+ action="store_true",
1992
+ help="Return routed experts.",
1993
+ )
1994
+ parser.add_argument("--seed", type=int, default=1, help="The random seed.")
1995
+ parser.add_argument(
1996
+ "--disable-ignore-eos",
1997
+ action="store_true",
1998
+ help="Disable ignoring EOS.",
1999
+ )
2000
+ parser.add_argument(
2001
+ "--extra-request-body",
2002
+ metavar='{"key1": "value1", "key2": "value2"}',
2003
+ type=str,
2004
+ help="Append given JSON object to the request payload. You can use this to specify "
2005
+ "additional generate params like sampling params.",
2006
+ )
2007
+ parser.add_argument(
2008
+ "--apply-chat-template",
2009
+ action="store_true",
2010
+ help="Apply chat template",
2011
+ )
2012
+ parser.add_argument(
2013
+ "--profile",
2014
+ action="store_true",
2015
+ help="Use Torch Profiler. The endpoint must be launched with "
2016
+ "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
2017
+ )
2018
+ parser.add_argument(
2019
+ "--plot-throughput",
2020
+ action="store_true",
2021
+ help="Plot throughput and concurrent requests over time. Requires termplotlib and gnuplot.",
2022
+ )
2023
+ # TODO unify all these
2024
+ parser.add_argument(
2025
+ "--profile-activities",
2026
+ type=str,
2027
+ nargs="+",
2028
+ default=["CPU", "GPU"],
2029
+ choices=["CPU", "GPU", "CUDA_PROFILER", "XPU"],
2030
+ help="Profiler activities to capture: CPU, GPU, XPU, CUDA_PROFILER.",
2031
+ )
2032
+ parser.add_argument(
2033
+ "--profile-start-step",
2034
+ type=int,
2035
+ default=None,
2036
+ help="Start profiling after this many forward steps. Useful for warmup.",
2037
+ )
2038
+ parser.add_argument(
2039
+ "--profile-steps",
2040
+ type=int,
2041
+ default=None,
2042
+ help="Number of steps to profile. If specified, profiling stops automatically after this many steps.",
2043
+ )
2044
+ parser.add_argument("--profile-num-steps", type=int, default=None)
2045
+ parser.add_argument("--profile-by-stage", action="store_true", default=False)
2046
+ parser.add_argument("--profile-stages", nargs="+", default=None)
2047
+ parser.add_argument(
2048
+ "--profile-output-dir",
2049
+ type=str,
2050
+ default=None,
2051
+ help="Output directory for profile traces.",
2052
+ )
2053
+ parser.add_argument(
2054
+ "--profile-prefix",
2055
+ type=str,
2056
+ default=None,
2057
+ help="Prefix for profile trace filenames.",
2058
+ )
2059
+ parser.add_argument(
2060
+ "--lora-name",
2061
+ type=str,
2062
+ nargs="*",
2063
+ default=None,
2064
+ action=LoRAPathAction,
2065
+ help="The names of LoRA adapters. You can provide a list of names in the format {name} {name} {name}...",
2066
+ )
2067
+ parser.add_argument(
2068
+ "--lora-request-distribution",
2069
+ type=str,
2070
+ default="uniform",
2071
+ choices=[
2072
+ "uniform",
2073
+ "distinct",
2074
+ "skewed",
2075
+ ],
2076
+ help="Distribution used to sample the LoRA adapters specified in --lora-name. Borrowed from the Punica paper. "
2077
+ "'distinct' distribution means selecting a new LoRA adapter for every request. "
2078
+ "'skewed' distribution follows the Zipf distribution, where the number of requests "
2079
+ "to model i specified in --lora-name is α times the number of requests for model i+1, "
2080
+ "where α > 1.",
2081
+ )
2082
+ parser.add_argument(
2083
+ "--lora-zipf-alpha",
2084
+ type=float,
2085
+ default=1.5,
2086
+ help="The parameter to use for the Zipf distribution when --lora-request-distribution='skewed'.",
2087
+ )
2088
+ parser.add_argument(
2089
+ "--prompt-suffix",
2090
+ type=str,
2091
+ default="",
2092
+ help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
2093
+ )
2094
+ parser.add_argument(
2095
+ "--pd-separated",
2096
+ action="store_true",
2097
+ help="Benchmark PD disaggregation server",
2098
+ )
2099
+
2100
+ # Create a mutually exclusive group for profiling URLs
2101
+ # In PD separated mode, prefill and decode workers must be profiled separately
2102
+ profile_url_group = parser.add_mutually_exclusive_group()
2103
+ profile_url_group.add_argument(
2104
+ "--profile-prefill-url",
2105
+ type=str,
2106
+ nargs="*",
2107
+ default=None,
2108
+ help="URL(s) of the prefill worker(s) for profiling in PD separated mode. "
2109
+ "Can specify multiple URLs: --profile-prefill-url http://localhost:30000 http://localhost:30001. "
2110
+ "NOTE: Cannot be used together with --profile-decode-url. "
2111
+ "In PD separated mode, prefill and decode workers must be profiled separately.",
2112
+ )
2113
+ profile_url_group.add_argument(
2114
+ "--profile-decode-url",
2115
+ type=str,
2116
+ nargs="*",
2117
+ default=None,
2118
+ help="URL(s) of the decode worker(s) for profiling in PD separated mode. "
2119
+ "Can specify multiple URLs: --profile-decode-url http://localhost:30010 http://localhost:30011. "
2120
+ "NOTE: Cannot be used together with --profile-prefill-url. "
2121
+ "In PD separated mode, prefill and decode workers must be profiled separately.",
2122
+ )
2123
+ parser.add_argument(
2124
+ "--flush-cache",
2125
+ action="store_true",
2126
+ help="Flush the cache before running the benchmark",
2127
+ )
2128
+ parser.add_argument(
2129
+ "--warmup-requests",
2130
+ type=int,
2131
+ default=1,
2132
+ help="Number of warmup requests to run before the benchmark",
2133
+ )
2134
+ parser.add_argument(
2135
+ "--tokenize-prompt",
2136
+ action="store_true",
2137
+ help="Use integer ids instead of string for inputs. Useful to control prompt lengths accurately",
2138
+ )
2139
+
2140
+ group = parser.add_argument_group("generated-shared-prefix dataset arguments")
2141
+ group.add_argument(
2142
+ "--gsp-num-groups",
2143
+ type=int,
2144
+ default=64,
2145
+ help="Number of system prompt groups for generated-shared-prefix dataset",
2146
+ )
2147
+ group.add_argument(
2148
+ "--gsp-prompts-per-group",
2149
+ type=int,
2150
+ default=16,
2151
+ help="Number of prompts per system prompt group for generated-shared-prefix dataset",
2152
+ )
2153
+ group.add_argument(
2154
+ "--gsp-system-prompt-len",
2155
+ type=int,
2156
+ default=2048,
2157
+ help="Target length in tokens for system prompts in generated-shared-prefix dataset",
2158
+ )
2159
+ group.add_argument(
2160
+ "--gsp-question-len",
2161
+ type=int,
2162
+ default=128,
2163
+ help="Target length in tokens for questions in generated-shared-prefix dataset",
2164
+ )
2165
+ group.add_argument(
2166
+ "--gsp-output-len",
2167
+ type=int,
2168
+ default=256,
2169
+ help="Target length in tokens for outputs in generated-shared-prefix dataset",
2170
+ )
2171
+ group.add_argument(
2172
+ "--gsp-range-ratio",
2173
+ type=float,
2174
+ # WARN: The default 1.0 is for backward compatibility, and is different from the default 0.0 for random dataset
2175
+ default=1.0,
2176
+ help="Range of sampled ratio of input/output length, used only for gsp dataset.",
2177
+ )
2178
+ group.add_argument(
2179
+ "--gsp-fast-prepare",
2180
+ action="store_true",
2181
+ help="Speed up preparation by skipping statistics computation, which makes some output statistics inaccurate but is suitable for pressure tests.",
2182
+ )
2183
+ group.add_argument(
2184
+ "--gsp-send-routing-key",
2185
+ action="store_true",
2186
+ help="Send routing key in requests via X-SMG-Routing-Key header. Requests with the same prefix share the same routing key.",
2187
+ )
2188
+ group.add_argument(
2189
+ "--gsp-num-turns",
2190
+ type=int,
2191
+ default=1,
2192
+ help="Number of turns for multi-turn conversations. If > 1, each prompt becomes a list of questions sharing the same system prefix.",
2193
+ )
2194
+ group.add_argument(
2195
+ "--gsp-ordered",
2196
+ action="store_true",
2197
+ help="Keep requests in order without shuffling. By default, requests are shuffled randomly.",
2198
+ )
2199
+ mooncake_group = parser.add_argument_group("mooncake dataset arguments")
2200
+ mooncake_group.add_argument(
2201
+ "--mooncake-slowdown-factor",
2202
+ type=float,
2203
+ default=1.0,
2204
+ help="Slowdown factor for replaying the mooncake trace. "
2205
+ "A value of 2.0 means the replay is twice as slow. "
2206
+ "NOTE: --request-rate is IGNORED in mooncake mode.",
2207
+ )
2208
+ mooncake_group.add_argument(
2209
+ "--mooncake-num-rounds",
2210
+ type=int,
2211
+ default=1,
2212
+ help="Number of conversation rounds for each session in the mooncake dataset. "
2213
+ "A value > 1 will enable true multi-turn session benchmarking.",
2214
+ )
2215
+ mooncake_group.add_argument(
2216
+ "--mooncake-workload",
2217
+ type=str,
2218
+ default="conversation",
2219
+ choices=[
2220
+ "mooncake",
2221
+ "conversation",
2222
+ "synthetic",
2223
+ "toolagent",
2224
+ ],
2225
+ help="Underlying workload for the mooncake dataset.",
2226
+ )
2227
+ parser.add_argument(
2228
+ "--tag", type=str, default=None, help="The tag to be dumped to output."
2229
+ )
2230
+ parser.add_argument(
2231
+ "--header",
2232
+ type=str,
2233
+ nargs="+",
2234
+ default=None,
2235
+ help="Custom HTTP headers in Key=Value format. Example: --header MyHeader=MY_VALUE MyAnotherHeader=myanothervalue",
2236
+ )
2237
+ args = parser.parse_args()
2238
+ run_benchmark(args)
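The URL setup in `run_benchmark` above picks a default port per backend and then builds the request URL from either `--base-url` or `--host`/`--port`. The following is a minimal table-driven sketch of the same resolution. The helper `resolve_api_url` and the two dictionaries are hypothetical names introduced here for illustration; the ports and paths are copied from the diff. Unlike the original if/elif chain, this sketch raises `KeyError` for a backend with no known path.

```python
# Default ports per backend, copied from the table in run_benchmark.
DEFAULT_PORTS = {
    "sglang": 30000,
    "sglang-native": 30000,
    "sglang-oai": 30000,
    "lmdeploy": 23333,
    "vllm": 8000,
    "trt": 8000,
    "gserver": 9988,
    "truss": 8080,
}

# Request path per backend, copied from the if/elif chain in run_benchmark.
API_PATHS = {
    "sglang": "/generate",
    "sglang-native": "/generate",
    "sglang-oai": "/v1/completions",
    "vllm": "/v1/completions",
    "lmdeploy": "/v1/completions",
    "sglang-oai-chat": "/v1/chat/completions",
    "vllm-chat": "/v1/chat/completions",
    "lmdeploy-chat": "/v1/chat/completions",
    "trt": "/v2/models/ensemble/generate_stream",
    "truss": "/v1/models/model:predict",
}


def resolve_api_url(backend, host="0.0.0.0", port=None, base_url=None):
    """Hypothetical helper: return the request URL for a backend."""
    if port is None:
        # Backends missing from the table fall back to 30000, as in the diff.
        port = DEFAULT_PORTS.get(backend, 30000)
    if backend == "gserver":
        # gserver uses a bare host:port target with no scheme or path.
        return base_url if base_url else f"{host}:{port}"
    root = base_url if base_url else f"http://{host}:{port}"
    return root + API_PATHS[backend]


if __name__ == "__main__":
    print(resolve_api_url("vllm"))
    print(resolve_api_url("sglang-oai-chat", base_url="http://example.test"))
    print(resolve_api_url("gserver", host="localhost"))
```

Keeping the mapping in data rather than in branches makes it easier to see at a glance which backends share an endpoint.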
sglang/python/sglang/benchmark/__init__.py ADDED
File without changes
sglang/python/sglang/benchmark/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (169 Bytes). View file
 
sglang/python/sglang/benchmark/__pycache__/utils.cpython-311.pyc ADDED
Binary file (7.65 kB). View file
 
sglang/python/sglang/benchmark/datasets/__init__.py ADDED
@@ -0,0 +1,47 @@
1
+ from typing import Dict, Type
2
+
3
+ from sglang.benchmark.datasets.common import BaseDataset, DatasetRow
4
+ from sglang.benchmark.datasets.custom import CustomDataset
5
+ from sglang.benchmark.datasets.generated_shared_prefix import (
6
+ GeneratedSharedPrefixDataset,
7
+ )
8
+ from sglang.benchmark.datasets.image import ImageDataset
9
+ from sglang.benchmark.datasets.mmmu import MMMUDataset
10
+ from sglang.benchmark.datasets.mooncake import MooncakeDataset
11
+ from sglang.benchmark.datasets.openai_dataset import OpenAIDataset
12
+ from sglang.benchmark.datasets.random import RandomDataset
13
+ from sglang.benchmark.datasets.sharegpt import ShareGPTDataset
14
+
15
+ DATASET_MAPPING: Dict[str, Type[BaseDataset]] = {
16
+ "sharegpt": ShareGPTDataset,
17
+ "custom": CustomDataset,
18
+ "openai": OpenAIDataset,
19
+ # TODO: "random" vs "random-ids" should be a flag (e.g. --random-source=sharegpt|integers),
20
+ # not two separate dataset names sharing the same class.
21
+ "random": RandomDataset,
22
+ "random-ids": RandomDataset,
23
+ "generated-shared-prefix": GeneratedSharedPrefixDataset,
24
+ "mmmu": MMMUDataset,
25
+ "image": ImageDataset,
26
+ "mooncake": MooncakeDataset,
27
+ }
28
+
29
+
30
+ def get_dataset(args, tokenizer, model_id=None):
31
+ dataset_name = args.dataset_name
32
+ if dataset_name.startswith("random") and dataset_name not in DATASET_MAPPING:
33
+ dataset_name = "random-ids"
34
+
35
+ if dataset_name not in DATASET_MAPPING:
36
+ raise ValueError(f"Unknown dataset: {args.dataset_name}")
37
+
38
+ dataset_cls = DATASET_MAPPING[dataset_name]
39
+ dataset = dataset_cls.from_args(args)
40
+ return dataset.load(tokenizer=tokenizer, model_id=model_id)
41
+
42
+
43
+ __all__ = [
44
+ "DATASET_MAPPING",
45
+ "DatasetRow",
46
+ "get_dataset",
47
+ ]
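`get_dataset` above dispatches on `args.dataset_name` through a name-to-class registry, with a fallback that maps unregistered names starting with `random` to `random-ids`. A self-contained sketch of that pattern follows. The `Toy*` classes, `TOY_MAPPING`, and `get_toy_dataset` are stand-ins invented for illustration and are not part of sglang; only the lookup logic mirrors the diff.

```python
from types import SimpleNamespace
from typing import Dict, Type


class ToyDataset:
    """Stand-in for BaseDataset (illustrative only)."""

    def __init__(self, num_prompts):
        self.num_prompts = num_prompts

    @classmethod
    def from_args(cls, args):
        # Build the dataset object from parsed CLI arguments.
        return cls(num_prompts=args.num_prompts)

    def load(self, tokenizer=None, model_id=None):
        # Return placeholder rows tagged with the concrete class name.
        return [f"{type(self).__name__}-{i}" for i in range(self.num_prompts)]


class ToyShareGPT(ToyDataset):
    pass


class ToyRandom(ToyDataset):
    pass


TOY_MAPPING: Dict[str, Type[ToyDataset]] = {
    "sharegpt": ToyShareGPT,
    "random": ToyRandom,
    "random-ids": ToyRandom,
}


def get_toy_dataset(args, tokenizer=None, model_id=None):
    name = args.dataset_name
    # Unregistered names that start with "random" fall back to "random-ids".
    if name.startswith("random") and name not in TOY_MAPPING:
        name = "random-ids"
    if name not in TOY_MAPPING:
        raise ValueError(f"Unknown dataset: {args.dataset_name}")
    dataset = TOY_MAPPING[name].from_args(args)
    return dataset.load(tokenizer=tokenizer, model_id=model_id)


if __name__ == "__main__":
    args = SimpleNamespace(dataset_name="random-foo", num_prompts=2)
    print(get_toy_dataset(args))
```

Adding a dataset under this design only requires a new class and one registry entry; the dispatch function itself stays unchanged.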
sglang/python/sglang/benchmark/datasets/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (2.1 kB). View file
 
sglang/python/sglang/benchmark/datasets/__pycache__/common.cpython-311.pyc ADDED
Binary file (5.24 kB). View file
 
sglang/python/sglang/benchmark/datasets/__pycache__/custom.cpython-311.pyc ADDED
Binary file (6.7 kB). View file
 
sglang/python/sglang/benchmark/datasets/__pycache__/generated_shared_prefix.cpython-311.pyc ADDED
Binary file (11.9 kB). View file
 
sglang/python/sglang/benchmark/datasets/__pycache__/image.cpython-311.pyc ADDED
Binary file (12.6 kB). View file
 
sglang/python/sglang/benchmark/datasets/__pycache__/mmmu.cpython-311.pyc ADDED
Binary file (5.98 kB). View file