Zip Ye commited on
Commit ·
fa1aa1c
1
Parent(s): 399c281
initial commit
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- README.md +8 -11
- draft_probe_suite/draft_model/config.json +38 -0
- draft_probe_suite/draft_model/model.safetensors +3 -0
- draft_probe_suite/pretrained_draft_model/config.json +38 -0
- draft_probe_suite/pretrained_draft_model/model.safetensors +3 -0
- draft_probe_suite/probe/config.json +1 -0
- draft_probe_suite/probe/state_dict.pth +3 -0
- sglang/README.md +17 -0
- sglang/__init__.py +83 -0
- sglang/bench_offline_throughput.py +476 -0
- sglang/bench_one_batch.py +795 -0
- sglang/bench_one_batch_server.py +605 -0
- sglang/bench_serving.py +0 -0
- sglang/check_env.py +433 -0
- sglang/cli/__init__.py +0 -0
- sglang/cli/generate.py +33 -0
- sglang/cli/main.py +26 -0
- sglang/cli/serve.py +75 -0
- sglang/cli/utils.py +152 -0
- sglang/compile_deep_gemm.py +191 -0
- sglang/eval/llama3_eval.py +315 -0
- sglang/eval/loogle_eval.py +164 -0
- sglang/global_config.py +29 -0
- sglang/jit_kernel/.clang-format +19 -0
- sglang/jit_kernel/__pycache__/hicache.cpython-311.pyc +0 -0
- sglang/jit_kernel/__pycache__/utils.cpython-311.pyc +0 -0
- sglang/jit_kernel/csrc/cuda_wait_value.cuh +38 -0
- sglang/jit_kernel/csrc/hicache.cuh +264 -0
- sglang/jit_kernel/cuda_wait_value.py +79 -0
- sglang/jit_kernel/hicache.py +138 -0
- sglang/jit_kernel/include/sgl_kernel/tensor.h +487 -0
- sglang/jit_kernel/include/sgl_kernel/utils.cuh +101 -0
- sglang/jit_kernel/include/sgl_kernel/utils.h +88 -0
- sglang/jit_kernel/include/sgl_kernel/warp.cuh +145 -0
- sglang/jit_kernel/utils.py +103 -0
- sglang/lang/__pycache__/api.cpython-311.pyc +0 -0
- sglang/lang/__pycache__/chat_template.cpython-311.pyc +0 -0
- sglang/lang/__pycache__/choices.cpython-311.pyc +0 -0
- sglang/lang/__pycache__/interpreter.cpython-311.pyc +0 -0
- sglang/lang/__pycache__/ir.cpython-311.pyc +0 -0
- sglang/lang/api.py +292 -0
- sglang/lang/backend/__pycache__/base_backend.cpython-311.pyc +0 -0
- sglang/lang/backend/__pycache__/runtime_endpoint.cpython-311.pyc +0 -0
- sglang/lang/backend/anthropic.py +73 -0
- sglang/lang/backend/base_backend.py +82 -0
- sglang/lang/backend/litellm.py +90 -0
- sglang/lang/backend/openai.py +475 -0
- sglang/lang/backend/runtime_endpoint.py +527 -0
- sglang/lang/backend/vertexai.py +148 -0
- sglang/lang/chat_template.py +668 -0
README.md
CHANGED
|
@@ -1,11 +1,10 @@
|
|
| 1 |
<div align="center">
|
| 2 |
-
<img src="assets/logo.png" alt="SEAGLE Logo" width="
|
| 3 |
</div>
|
| 4 |
|
| 5 |
-
# SEAGLE:
|
| 6 |
|
| 7 |
**SEAGLE** is a safety-aware speculative decoding policy based on [SGLang](https://github.com/sgl-project/sglang). It embeds a lightweight probe model into the draft loop of [EAGLE-3](https://github.com/SafeAILab/EAGLE) speculative decoding, performs real-time safety monitoring on each decoding step, dynamically adjusts draft tokens, and triggers a fallback mechanism when unsafe content is continuously detected.
|
| 8 |
-
> [HERE](https://www.modelscope.cn/models/Alibaba-AAIG/SEAGLE) is the link to the source code and model weights.
|
| 9 |
|
| 10 |

|
| 11 |
|
|
@@ -76,7 +75,7 @@ from sglang.srt.server_args import ServerArgs
|
|
| 76 |
from sglang.srt.entrypoints.http_server import launch_server as _launch_server
|
| 77 |
|
| 78 |
# =========================================================
|
| 79 |
-
# Launch SGlang Server with
|
| 80 |
# =========================================================
|
| 81 |
MODEL_PATH = "your_qwen3_235b_a22b_instruct_2507_path"
|
| 82 |
DRAFT_MODEL_PATH = "draft_probe_suite/draft_model"
|
|
@@ -246,7 +245,7 @@ We begin by evaluating the acceleration performance of our draft models, encompa
|
|
| 246 |
|
| 247 |
> **Note:** Our pre-trained draft model can be found [here](https://www.modelscope.cn/models/Alibaba-AAIG/SEAGLE/tree/master/draft_probe_suite/pretrained_draft_model). Compared to the [Meituan](https://modelscope.cn/models/lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge-Meituan) version, our Eagle Head has undergone accelerated training specifically for Chinese. The pre-trained version can be used standalone as an Eagle Head for Qwen3-235B-A22B-Instruct-2507, delivering outstanding acceleration performance in both Chinese and English.
|
| 248 |
|
| 249 |
-
**Launch with Standard SGLang
|
| 250 |
|
| 251 |
```bash
|
| 252 |
export SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1
|
|
@@ -282,11 +281,9 @@ Evaluate the probe's impact on normal chatting data (query safety & response saf
|
|
| 282 |
| :--- | :---: | :---: | :---: |
|
| 283 |
| FuseChat-Mixture | 50,000 | 0.99506 | 0.00494 |
|
| 284 |
|
| 285 |
-
|
| 286 |
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
The trained probe is integrated into the Eagle3 decoding pipeline. We evaluate the end-to-end utility and safety of the SafeAware decoding strategy using an SGLang single-request configuration.
|
| 290 |
|
| 291 |
#### (1) Utility Performance
|
| 292 |
|
|
@@ -315,8 +312,8 @@ Safety scores are evaluated based on the discriminative reward model (DRM), gene
|
|
| 315 |
| 📎 [Chinese: 100 High-Risk](assets/valuesTest_zh_hard_100.jsonl) | DRM Score | 0.43 | 0.49 | **0.83** |
|
| 316 |
| | QwQ Score | 0.70 | 0.70 | **0.92** |
|
| 317 |
| | GRM Score | 0.23 | 0.31 | **0.81** |
|
| 318 |
-
| 📊 [English Log](assets/GRM_judge_log_en.xlsx) |
|
| 319 |
-
| 📊 [Chinese Log](assets/GRM_judge_log_zh.xlsx) |
|
| 320 |
|
| 321 |
---
|
| 322 |
|
|
|
|
| 1 |
<div align="center">
|
| 2 |
+
<img src="assets/logo.png" alt="SEAGLE Logo" width="250"/>
|
| 3 |
</div>
|
| 4 |
|
| 5 |
+
# SEAGLE: Safe-Aware EAGLE
|
| 6 |
|
| 7 |
**SEAGLE** is a safety-aware speculative decoding policy based on [SGLang](https://github.com/sgl-project/sglang). It embeds a lightweight probe model into the draft loop of [EAGLE-3](https://github.com/SafeAILab/EAGLE) speculative decoding, performs real-time safety monitoring on each decoding step, dynamically adjusts draft tokens, and triggers a fallback mechanism when unsafe content is continuously detected.
|
|
|
|
| 8 |
|
| 9 |

|
| 10 |
|
|
|
|
| 75 |
from sglang.srt.entrypoints.http_server import launch_server as _launch_server
|
| 76 |
|
| 77 |
# =========================================================
|
| 78 |
+
# Launch SGlang Server with Safe-Aware Eagle3 Decoding
|
| 79 |
# =========================================================
|
| 80 |
MODEL_PATH = "your_qwen3_235b_a22b_instruct_2507_path"
|
| 81 |
DRAFT_MODEL_PATH = "draft_probe_suite/draft_model"
|
|
|
|
| 245 |
|
| 246 |
> **Note:** Our pre-trained draft model can be found [here](https://www.modelscope.cn/models/Alibaba-AAIG/SEAGLE/tree/master/draft_probe_suite/pretrained_draft_model). Compared to the [Meituan](https://modelscope.cn/models/lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge-Meituan) version, our Eagle Head has undergone accelerated training specifically for Chinese. The pre-trained version can be used standalone as an Eagle Head for Qwen3-235B-A22B-Instruct-2507, delivering outstanding acceleration performance in both Chinese and English.
|
| 247 |
|
| 248 |
+
**Launch with Standard SGLang Command:**
|
| 249 |
|
| 250 |
```bash
|
| 251 |
export SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1
|
|
|
|
| 281 |
| :--- | :---: | :---: | :---: |
|
| 282 |
| FuseChat-Mixture | 50,000 | 0.99506 | 0.00494 |
|
| 283 |
|
| 284 |
+
### ⚖️ 3.3 Utility and Safety
|
| 285 |
|
| 286 |
+
The trained probe is integrated into the Eagle3 decoding pipeline. Using an SGLang & Single Request configuration, the general utility and security of the SafeAware decoding strategy are evaluated.
|
|
|
|
|
|
|
| 287 |
|
| 288 |
#### (1) Utility Performance
|
| 289 |
|
|
|
|
| 312 |
| 📎 [Chinese: 100 High-Risk](assets/valuesTest_zh_hard_100.jsonl) | DRM Score | 0.43 | 0.49 | **0.83** |
|
| 313 |
| | QwQ Score | 0.70 | 0.70 | **0.92** |
|
| 314 |
| | GRM Score | 0.23 | 0.31 | **0.81** |
|
| 315 |
+
| 📊 [English Log](assets/GRM_judge_log_en.xlsx) | Evaluation | ✅ | - | ✅ |
|
| 316 |
+
| 📊 [Chinese Log](assets/GRM_judge_log_zh.xlsx) | Evaluation | ✅ | - | ✅ |
|
| 317 |
|
| 318 |
---
|
| 319 |
|
draft_probe_suite/draft_model/config.json
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"LlamaForCausalLMEagle3"
|
| 4 |
+
],
|
| 5 |
+
"attention_bias": false,
|
| 6 |
+
"attention_dropout": 0.0,
|
| 7 |
+
"bos_token_id": 151643,
|
| 8 |
+
"draft_vocab_size": 32000,
|
| 9 |
+
"dtype": "bfloat16",
|
| 10 |
+
"eagle_config": {
|
| 11 |
+
"eagle_aux_hidden_state_layer_ids": [
|
| 12 |
+
1,
|
| 13 |
+
46,
|
| 14 |
+
90
|
| 15 |
+
],
|
| 16 |
+
"use_aux_hidden_state": true
|
| 17 |
+
},
|
| 18 |
+
"eos_token_id": 151645,
|
| 19 |
+
"head_dim": 128,
|
| 20 |
+
"hidden_act": "silu",
|
| 21 |
+
"hidden_size": 4096,
|
| 22 |
+
"initializer_range": 0.02,
|
| 23 |
+
"intermediate_size": 24576,
|
| 24 |
+
"max_position_embeddings": 40960,
|
| 25 |
+
"mlp_bias": false,
|
| 26 |
+
"model_type": "llama",
|
| 27 |
+
"num_attention_heads": 64,
|
| 28 |
+
"num_hidden_layers": 1,
|
| 29 |
+
"num_key_value_heads": 4,
|
| 30 |
+
"pretraining_tp": 1,
|
| 31 |
+
"rms_norm_eps": 1e-06,
|
| 32 |
+
"rope_scaling": null,
|
| 33 |
+
"rope_theta": 1000000.0,
|
| 34 |
+
"tie_word_embeddings": false,
|
| 35 |
+
"transformers_version": "4.57.1",
|
| 36 |
+
"use_cache": true,
|
| 37 |
+
"vocab_size": 151936
|
| 38 |
+
}
|
draft_probe_suite/draft_model/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:855968a019ea7ce5d33cdc503737c4fe19fd3aa8a176440d818506bf018f130d
|
| 3 |
+
size 1185333104
|
draft_probe_suite/pretrained_draft_model/config.json
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"LlamaForCausalLMEagle3"
|
| 4 |
+
],
|
| 5 |
+
"attention_bias": false,
|
| 6 |
+
"attention_dropout": 0.0,
|
| 7 |
+
"bos_token_id": 151643,
|
| 8 |
+
"draft_vocab_size": 32000,
|
| 9 |
+
"dtype": "bfloat16",
|
| 10 |
+
"eagle_config": {
|
| 11 |
+
"eagle_aux_hidden_state_layer_ids": [
|
| 12 |
+
1,
|
| 13 |
+
46,
|
| 14 |
+
90
|
| 15 |
+
],
|
| 16 |
+
"use_aux_hidden_state": true
|
| 17 |
+
},
|
| 18 |
+
"eos_token_id": 151645,
|
| 19 |
+
"head_dim": 128,
|
| 20 |
+
"hidden_act": "silu",
|
| 21 |
+
"hidden_size": 4096,
|
| 22 |
+
"initializer_range": 0.02,
|
| 23 |
+
"intermediate_size": 24576,
|
| 24 |
+
"max_position_embeddings": 40960,
|
| 25 |
+
"mlp_bias": false,
|
| 26 |
+
"model_type": "llama",
|
| 27 |
+
"num_attention_heads": 64,
|
| 28 |
+
"num_hidden_layers": 1,
|
| 29 |
+
"num_key_value_heads": 4,
|
| 30 |
+
"pretraining_tp": 1,
|
| 31 |
+
"rms_norm_eps": 1e-06,
|
| 32 |
+
"rope_scaling": null,
|
| 33 |
+
"rope_theta": 1000000.0,
|
| 34 |
+
"tie_word_embeddings": false,
|
| 35 |
+
"transformers_version": "4.57.1",
|
| 36 |
+
"use_cache": true,
|
| 37 |
+
"vocab_size": 151936
|
| 38 |
+
}
|
draft_probe_suite/pretrained_draft_model/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b4b0c81dfb283d27ab78fbdffffe1b0962d6fc3ef4bce16dbc4d6561ef19a9b1
|
| 3 |
+
size 1185333104
|
draft_probe_suite/probe/config.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"input_size": 4096, "output_size": 1, "intermediate_size": 1024}
|
draft_probe_suite/probe/state_dict.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bee888de61324d261765237d29a23f7d4858fb6c571bbe58726321a8be0b8d26
|
| 3 |
+
size 16803511
|
sglang/README.md
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code Structure
|
| 2 |
+
|
| 3 |
+
- `eval`: The evaluation utilities.
|
| 4 |
+
- `lang`: The frontend language.
|
| 5 |
+
- `srt`: The backend engine for running local models. (SRT = SGLang Runtime).
|
| 6 |
+
- `test`: The test utilities.
|
| 7 |
+
- `api.py`: The public APIs.
|
| 8 |
+
- `bench_offline_throughput.py`: Benchmark the performance in the offline mode.
|
| 9 |
+
- `bench_one_batch.py`: Benchmark the latency of running a single static batch without a server.
|
| 10 |
+
- `bench_one_batch_server.py`: Benchmark the latency of running a single batch with a server.
|
| 11 |
+
- `bench_serving.py`: Benchmark online serving with dynamic requests.
|
| 12 |
+
- `check_env.py`: Check the environment variables and dependencies.
|
| 13 |
+
- `global_config.py`: The global configs and constants.
|
| 14 |
+
- `launch_server.py`: The entry point for launching a local server.
|
| 15 |
+
- `profiler.py`: The profiling entry point to send profile requests.
|
| 16 |
+
- `utils.py`: Common utilities.
|
| 17 |
+
- `version.py`: Version info.
|
sglang/__init__.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SGLang public APIs
|
| 2 |
+
|
| 3 |
+
# Frontend Language APIs
|
| 4 |
+
from sglang.global_config import global_config
|
| 5 |
+
from sglang.lang.api import (
|
| 6 |
+
Engine,
|
| 7 |
+
Runtime,
|
| 8 |
+
assistant,
|
| 9 |
+
assistant_begin,
|
| 10 |
+
assistant_end,
|
| 11 |
+
flush_cache,
|
| 12 |
+
function,
|
| 13 |
+
gen,
|
| 14 |
+
gen_int,
|
| 15 |
+
gen_string,
|
| 16 |
+
get_server_info,
|
| 17 |
+
image,
|
| 18 |
+
select,
|
| 19 |
+
separate_reasoning,
|
| 20 |
+
set_default_backend,
|
| 21 |
+
system,
|
| 22 |
+
system_begin,
|
| 23 |
+
system_end,
|
| 24 |
+
user,
|
| 25 |
+
user_begin,
|
| 26 |
+
user_end,
|
| 27 |
+
video,
|
| 28 |
+
)
|
| 29 |
+
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
|
| 30 |
+
from sglang.lang.choices import (
|
| 31 |
+
greedy_token_selection,
|
| 32 |
+
token_length_normalized,
|
| 33 |
+
unconditional_likelihood_normalized,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
# Lazy import some libraries
|
| 37 |
+
from sglang.utils import LazyImport
|
| 38 |
+
from sglang.version import __version__
|
| 39 |
+
|
| 40 |
+
Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
|
| 41 |
+
LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
|
| 42 |
+
OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
|
| 43 |
+
VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")
|
| 44 |
+
|
| 45 |
+
# Runtime Engine APIs
|
| 46 |
+
ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs")
|
| 47 |
+
Engine = LazyImport("sglang.srt.entrypoints.engine", "Engine")
|
| 48 |
+
|
| 49 |
+
__all__ = [
|
| 50 |
+
"Engine",
|
| 51 |
+
"Runtime",
|
| 52 |
+
"assistant",
|
| 53 |
+
"assistant_begin",
|
| 54 |
+
"assistant_end",
|
| 55 |
+
"flush_cache",
|
| 56 |
+
"function",
|
| 57 |
+
"gen",
|
| 58 |
+
"gen_int",
|
| 59 |
+
"gen_string",
|
| 60 |
+
"get_server_info",
|
| 61 |
+
"image",
|
| 62 |
+
"select",
|
| 63 |
+
"separate_reasoning",
|
| 64 |
+
"set_default_backend",
|
| 65 |
+
"system",
|
| 66 |
+
"system_begin",
|
| 67 |
+
"system_end",
|
| 68 |
+
"user",
|
| 69 |
+
"user_begin",
|
| 70 |
+
"user_end",
|
| 71 |
+
"video",
|
| 72 |
+
"RuntimeEndpoint",
|
| 73 |
+
"greedy_token_selection",
|
| 74 |
+
"token_length_normalized",
|
| 75 |
+
"unconditional_likelihood_normalized",
|
| 76 |
+
"ServerArgs",
|
| 77 |
+
"Anthropic",
|
| 78 |
+
"LiteLLM",
|
| 79 |
+
"OpenAI",
|
| 80 |
+
"VertexAI",
|
| 81 |
+
"global_config",
|
| 82 |
+
"__version__",
|
| 83 |
+
]
|
sglang/bench_offline_throughput.py
ADDED
|
@@ -0,0 +1,476 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Benchmark the throughput in the offline mode.
|
| 3 |
+
It accepts server arguments (the same as launch_server.py) and benchmark arguments (the same as bench_serving.py).
|
| 4 |
+
|
| 5 |
+
# Usage
|
| 6 |
+
## Sharegpt dataset with default args
|
| 7 |
+
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --num-prompts 10
|
| 8 |
+
|
| 9 |
+
## Random dataset with default args
|
| 10 |
+
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
import asyncio
|
| 15 |
+
import dataclasses
|
| 16 |
+
import inspect
|
| 17 |
+
import json
|
| 18 |
+
import logging
|
| 19 |
+
import os
|
| 20 |
+
import random
|
| 21 |
+
import time
|
| 22 |
+
from typing import Dict, List, Optional
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
|
| 26 |
+
from sglang.bench_serving import (
|
| 27 |
+
DatasetRow,
|
| 28 |
+
get_dataset,
|
| 29 |
+
get_tokenizer,
|
| 30 |
+
sample_random_requests,
|
| 31 |
+
set_ulimit,
|
| 32 |
+
)
|
| 33 |
+
from sglang.lang.backend.runtime_endpoint import Runtime
|
| 34 |
+
from sglang.srt.entrypoints.engine import Engine
|
| 35 |
+
from sglang.srt.server_args import ServerArgs
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@dataclasses.dataclass
|
| 39 |
+
class BenchArgs:
|
| 40 |
+
backend: str = "engine"
|
| 41 |
+
result_filename: str = ""
|
| 42 |
+
dataset_name: str = "sharegpt"
|
| 43 |
+
dataset_path: str = ""
|
| 44 |
+
num_prompts: int = 1000
|
| 45 |
+
sharegpt_output_len: Optional[int] = None
|
| 46 |
+
sharegpt_context_len: Optional[int] = None
|
| 47 |
+
random_input_len: int = 1024
|
| 48 |
+
random_output_len: int = 1024
|
| 49 |
+
random_range_ratio: float = 0.0
|
| 50 |
+
gsp_num_groups: int = 64
|
| 51 |
+
gsp_prompts_per_group: int = 16
|
| 52 |
+
gsp_system_prompt_len: int = 2048
|
| 53 |
+
gsp_question_len: int = 128
|
| 54 |
+
gsp_output_len: int = 256
|
| 55 |
+
seed: int = 1
|
| 56 |
+
disable_ignore_eos: bool = False
|
| 57 |
+
extra_request_body: Optional[str] = None
|
| 58 |
+
apply_chat_template: bool = False
|
| 59 |
+
profile: bool = False
|
| 60 |
+
skip_warmup: bool = False
|
| 61 |
+
do_not_exit: bool = False
|
| 62 |
+
prompt_suffix: str = ""
|
| 63 |
+
return_logprob: bool = False
|
| 64 |
+
logprob_start_len: int = -1
|
| 65 |
+
|
| 66 |
+
@staticmethod
|
| 67 |
+
def add_cli_args(parser: argparse.ArgumentParser):
|
| 68 |
+
parser.add_argument("--backend", type=str, default=BenchArgs.backend)
|
| 69 |
+
parser.add_argument(
|
| 70 |
+
"--result-filename", type=str, default=BenchArgs.result_filename
|
| 71 |
+
)
|
| 72 |
+
parser.add_argument(
|
| 73 |
+
"--dataset-name",
|
| 74 |
+
type=str,
|
| 75 |
+
default="sharegpt",
|
| 76 |
+
choices=["sharegpt", "random", "generated-shared-prefix"],
|
| 77 |
+
help="Name of the dataset to benchmark on.",
|
| 78 |
+
)
|
| 79 |
+
parser.add_argument(
|
| 80 |
+
"--dataset-path", type=str, default="", help="Path to the dataset."
|
| 81 |
+
)
|
| 82 |
+
parser.add_argument(
|
| 83 |
+
"--num-prompts",
|
| 84 |
+
type=int,
|
| 85 |
+
default=BenchArgs.num_prompts,
|
| 86 |
+
help="Number of prompts to process. Default is 1000.",
|
| 87 |
+
)
|
| 88 |
+
parser.add_argument(
|
| 89 |
+
"--sharegpt-output-len",
|
| 90 |
+
type=int,
|
| 91 |
+
default=BenchArgs.sharegpt_output_len,
|
| 92 |
+
help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
|
| 93 |
+
)
|
| 94 |
+
parser.add_argument(
|
| 95 |
+
"--sharegpt-context-len",
|
| 96 |
+
type=int,
|
| 97 |
+
default=BenchArgs.sharegpt_context_len,
|
| 98 |
+
help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
|
| 99 |
+
)
|
| 100 |
+
parser.add_argument(
|
| 101 |
+
"--random-input-len",
|
| 102 |
+
type=int,
|
| 103 |
+
default=BenchArgs.random_input_len,
|
| 104 |
+
help="Number of input tokens per request, used only for random dataset.",
|
| 105 |
+
)
|
| 106 |
+
parser.add_argument(
|
| 107 |
+
"--random-output-len",
|
| 108 |
+
type=int,
|
| 109 |
+
default=BenchArgs.random_output_len,
|
| 110 |
+
help="Number of output tokens per request, used only for random dataset.",
|
| 111 |
+
)
|
| 112 |
+
parser.add_argument(
|
| 113 |
+
"--random-range-ratio",
|
| 114 |
+
type=float,
|
| 115 |
+
default=BenchArgs.random_range_ratio,
|
| 116 |
+
help="Range of sampled ratio of input/output length, "
|
| 117 |
+
"used only for random dataset.",
|
| 118 |
+
)
|
| 119 |
+
parser.add_argument(
|
| 120 |
+
"--gsp-num-groups",
|
| 121 |
+
type=int,
|
| 122 |
+
default=BenchArgs.gsp_num_groups,
|
| 123 |
+
help="Number of groups with shared prefix, used"
|
| 124 |
+
"only for generate-shared-prefix",
|
| 125 |
+
)
|
| 126 |
+
parser.add_argument(
|
| 127 |
+
"--gsp-prompts-per-group",
|
| 128 |
+
type=int,
|
| 129 |
+
default=BenchArgs.gsp_prompts_per_group,
|
| 130 |
+
help="Number of prompts per group of shared prefix, used"
|
| 131 |
+
"only for generate-shared-prefix",
|
| 132 |
+
)
|
| 133 |
+
parser.add_argument(
|
| 134 |
+
"--gsp-system-prompt-len",
|
| 135 |
+
type=int,
|
| 136 |
+
default=BenchArgs.gsp_system_prompt_len,
|
| 137 |
+
help="System prompt length, used" "only for generate-shared-prefix",
|
| 138 |
+
)
|
| 139 |
+
parser.add_argument(
|
| 140 |
+
"--gsp-question-len",
|
| 141 |
+
type=int,
|
| 142 |
+
default=BenchArgs.gsp_question_len,
|
| 143 |
+
help="Question length, used" "only for generate-shared-prefix",
|
| 144 |
+
)
|
| 145 |
+
parser.add_argument(
|
| 146 |
+
"--gsp-output-len",
|
| 147 |
+
type=int,
|
| 148 |
+
default=BenchArgs.gsp_output_len,
|
| 149 |
+
help="Target length in tokens for outputs in generated-shared-prefix dataset",
|
| 150 |
+
)
|
| 151 |
+
parser.add_argument("--seed", type=int, default=1, help="The random seed.")
|
| 152 |
+
parser.add_argument(
|
| 153 |
+
"--disable-ignore-eos",
|
| 154 |
+
action="store_true",
|
| 155 |
+
help="Disable ignore EOS token",
|
| 156 |
+
)
|
| 157 |
+
parser.add_argument(
|
| 158 |
+
"--extra-request-body",
|
| 159 |
+
metavar='{"key1": "value1", "key2": "value2"}',
|
| 160 |
+
type=str,
|
| 161 |
+
default=BenchArgs.extra_request_body,
|
| 162 |
+
help="Append given JSON object to the request payload. You can use this to specify"
|
| 163 |
+
"additional generate params like sampling params.",
|
| 164 |
+
)
|
| 165 |
+
parser.add_argument(
|
| 166 |
+
"--apply-chat-template",
|
| 167 |
+
action="store_true",
|
| 168 |
+
help="Apply chat template",
|
| 169 |
+
)
|
| 170 |
+
parser.add_argument(
|
| 171 |
+
"--profile",
|
| 172 |
+
action="store_true",
|
| 173 |
+
help="Use Torch Profiler. The endpoint must be launched with "
|
| 174 |
+
"SGLANG_TORCH_PROFILER_DIR to enable profiler.",
|
| 175 |
+
)
|
| 176 |
+
parser.add_argument(
|
| 177 |
+
"--skip-warmup",
|
| 178 |
+
action="store_true",
|
| 179 |
+
help="Skip the warmup batches.",
|
| 180 |
+
)
|
| 181 |
+
parser.add_argument(
|
| 182 |
+
"--do-not-exit",
|
| 183 |
+
action="store_true",
|
| 184 |
+
help="Do not exit the program. This is useful for nsys profile with --duration and --delay.",
|
| 185 |
+
)
|
| 186 |
+
parser.add_argument(
|
| 187 |
+
"--prompt-suffix",
|
| 188 |
+
type=str,
|
| 189 |
+
default="",
|
| 190 |
+
help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
|
| 191 |
+
)
|
| 192 |
+
parser.add_argument(
|
| 193 |
+
"--return-logprob",
|
| 194 |
+
action="store_true",
|
| 195 |
+
help="Enable returning log probabilities.",
|
| 196 |
+
)
|
| 197 |
+
parser.add_argument(
|
| 198 |
+
"--logprob-start-len",
|
| 199 |
+
type=int,
|
| 200 |
+
default=-1,
|
| 201 |
+
help="Start length for logprob. -1 means only return logprobs for output tokens (default). 0 means return logprobs for all tokens including input.",
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
@classmethod
|
| 205 |
+
def from_cli_args(cls, args: argparse.Namespace):
|
| 206 |
+
attrs = [attr.name for attr in dataclasses.fields(cls)]
|
| 207 |
+
return cls(**{attr: getattr(args, attr) for attr in attrs})
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def throughput_test_once(
|
| 211 |
+
backend_name: str,
|
| 212 |
+
backend,
|
| 213 |
+
reqs: List[DatasetRow],
|
| 214 |
+
ignore_eos: bool,
|
| 215 |
+
extra_request_body: Dict,
|
| 216 |
+
profile: bool,
|
| 217 |
+
return_logprob: bool = False,
|
| 218 |
+
logprob_start_len: int = -1,
|
| 219 |
+
):
|
| 220 |
+
measurement_results = {
|
| 221 |
+
"backend": backend_name,
|
| 222 |
+
"successful_requests": len(reqs),
|
| 223 |
+
"total_latency": -1,
|
| 224 |
+
"total_input_tokens": sum(r.prompt_len for r in reqs),
|
| 225 |
+
"total_output_tokens": -1,
|
| 226 |
+
"request_throughput": -1,
|
| 227 |
+
"input_throughput": -1,
|
| 228 |
+
"output_throughput": -1,
|
| 229 |
+
"total_throughput": -1,
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
prompt = [r.prompt for r in reqs]
|
| 233 |
+
sampling_params = [
|
| 234 |
+
{
|
| 235 |
+
"temperature": 0,
|
| 236 |
+
"max_new_tokens": r.output_len,
|
| 237 |
+
"ignore_eos": ignore_eos,
|
| 238 |
+
**extra_request_body,
|
| 239 |
+
}
|
| 240 |
+
for r in reqs
|
| 241 |
+
]
|
| 242 |
+
|
| 243 |
+
if profile:
|
| 244 |
+
assert (
|
| 245 |
+
"SGLANG_TORCH_PROFILER_DIR" in os.environ
|
| 246 |
+
), "Please set SGLANG_TORCH_PROFILER_DIR."
|
| 247 |
+
os.makedirs(os.environ["SGLANG_TORCH_PROFILER_DIR"], exist_ok=True)
|
| 248 |
+
backend.start_profile()
|
| 249 |
+
|
| 250 |
+
st = time.perf_counter()
|
| 251 |
+
gen_out = backend.generate(
|
| 252 |
+
prompt=prompt,
|
| 253 |
+
sampling_params=sampling_params,
|
| 254 |
+
return_logprob=return_logprob,
|
| 255 |
+
logprob_start_len=logprob_start_len,
|
| 256 |
+
)
|
| 257 |
+
latency = time.perf_counter() - st
|
| 258 |
+
|
| 259 |
+
if profile:
|
| 260 |
+
dir = os.getenv("SGLANG_TORCH_PROFILER_DIR")
|
| 261 |
+
known_files = set(os.listdir(dir))
|
| 262 |
+
backend.stop_profile()
|
| 263 |
+
monitor_trace_file(known_files, dir)
|
| 264 |
+
|
| 265 |
+
if backend_name == "runtime":
|
| 266 |
+
gen_out = json.loads(gen_out)
|
| 267 |
+
|
| 268 |
+
server_info = backend.get_server_info()
|
| 269 |
+
|
| 270 |
+
measurement_results["total_latency"] = latency
|
| 271 |
+
measurement_results["total_output_tokens"] = sum(
|
| 272 |
+
o["meta_info"]["completion_tokens"] for o in gen_out
|
| 273 |
+
)
|
| 274 |
+
measurement_results["request_throughput"] = (
|
| 275 |
+
measurement_results["successful_requests"] / latency
|
| 276 |
+
)
|
| 277 |
+
measurement_results["input_throughput"] = (
|
| 278 |
+
measurement_results["total_input_tokens"] / latency
|
| 279 |
+
)
|
| 280 |
+
measurement_results["output_throughput"] = (
|
| 281 |
+
measurement_results["total_output_tokens"] / latency
|
| 282 |
+
)
|
| 283 |
+
measurement_results["total_throughput"] = (
|
| 284 |
+
measurement_results["total_input_tokens"]
|
| 285 |
+
+ measurement_results["total_output_tokens"]
|
| 286 |
+
) / latency
|
| 287 |
+
|
| 288 |
+
if inspect.isawaitable(server_info):
|
| 289 |
+
server_info = asyncio.run(server_info)
|
| 290 |
+
|
| 291 |
+
measurement_results["last_gen_throughput"] = server_info["internal_states"][0][
|
| 292 |
+
"last_gen_throughput"
|
| 293 |
+
]
|
| 294 |
+
|
| 295 |
+
return measurement_results
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def monitor_trace_file(known_files, directory, interval=1):
|
| 299 |
+
print(f"Monitoring {directory} for new trace files...")
|
| 300 |
+
|
| 301 |
+
while True:
|
| 302 |
+
flag = False
|
| 303 |
+
time.sleep(interval)
|
| 304 |
+
current_files = set(os.listdir(directory))
|
| 305 |
+
|
| 306 |
+
new_files = current_files - known_files
|
| 307 |
+
for new_file in new_files:
|
| 308 |
+
new_file_path = os.path.join(directory, new_file)
|
| 309 |
+
print(f"New file detected: {new_file}")
|
| 310 |
+
|
| 311 |
+
previous_size = 0
|
| 312 |
+
while True:
|
| 313 |
+
try:
|
| 314 |
+
current_size = os.path.getsize(new_file_path)
|
| 315 |
+
except FileNotFoundError:
|
| 316 |
+
print(f"File {new_file} is no longer accessible.")
|
| 317 |
+
break
|
| 318 |
+
|
| 319 |
+
if current_size > previous_size:
|
| 320 |
+
previous_size = current_size
|
| 321 |
+
else:
|
| 322 |
+
flag = True
|
| 323 |
+
break
|
| 324 |
+
|
| 325 |
+
time.sleep(interval)
|
| 326 |
+
if flag:
|
| 327 |
+
break
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def throughput_test(
|
| 331 |
+
server_args: ServerArgs,
|
| 332 |
+
bench_args: BenchArgs,
|
| 333 |
+
):
|
| 334 |
+
if bench_args.backend == "engine":
|
| 335 |
+
backend = Engine(**dataclasses.asdict(server_args))
|
| 336 |
+
if not backend:
|
| 337 |
+
raise ValueError("Please provide valid engine arguments")
|
| 338 |
+
elif bench_args.backend == "runtime":
|
| 339 |
+
backend = Runtime(**dataclasses.asdict(server_args))
|
| 340 |
+
else:
|
| 341 |
+
raise ValueError('Please set backend to either "engine" or "runtime"')
|
| 342 |
+
|
| 343 |
+
tokenizer_id = server_args.tokenizer_path or server_args.model_path
|
| 344 |
+
tokenizer = get_tokenizer(tokenizer_id)
|
| 345 |
+
|
| 346 |
+
# Set global environments
|
| 347 |
+
set_ulimit()
|
| 348 |
+
random.seed(bench_args.seed)
|
| 349 |
+
np.random.seed(bench_args.seed)
|
| 350 |
+
|
| 351 |
+
# Parse args
|
| 352 |
+
extra_request_body = {}
|
| 353 |
+
if bench_args.extra_request_body:
|
| 354 |
+
extra_request_body = json.loads(args.extra_request_body)
|
| 355 |
+
|
| 356 |
+
# Read dataset
|
| 357 |
+
input_requests = get_dataset(bench_args, tokenizer)
|
| 358 |
+
|
| 359 |
+
warmup_requests = sample_random_requests(
|
| 360 |
+
input_len=256,
|
| 361 |
+
output_len=16,
|
| 362 |
+
num_prompts=min(bench_args.num_prompts, 16),
|
| 363 |
+
range_ratio=1.0,
|
| 364 |
+
tokenizer=tokenizer,
|
| 365 |
+
dataset_path=bench_args.dataset_path,
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
# Warm up
|
| 369 |
+
if not bench_args.skip_warmup:
|
| 370 |
+
logging.info("\nWarmup...")
|
| 371 |
+
throughput_test_once(
|
| 372 |
+
backend_name=bench_args.backend,
|
| 373 |
+
backend=backend,
|
| 374 |
+
reqs=warmup_requests,
|
| 375 |
+
ignore_eos=not bench_args.disable_ignore_eos,
|
| 376 |
+
extra_request_body=extra_request_body,
|
| 377 |
+
profile=False,
|
| 378 |
+
return_logprob=bench_args.return_logprob,
|
| 379 |
+
logprob_start_len=bench_args.logprob_start_len,
|
| 380 |
+
)
|
| 381 |
+
time.sleep(0.5)
|
| 382 |
+
|
| 383 |
+
logging.info("\nBenchmark...")
|
| 384 |
+
result = throughput_test_once(
|
| 385 |
+
backend_name=bench_args.backend,
|
| 386 |
+
backend=backend,
|
| 387 |
+
reqs=input_requests,
|
| 388 |
+
ignore_eos=not bench_args.disable_ignore_eos,
|
| 389 |
+
extra_request_body=extra_request_body,
|
| 390 |
+
profile=bench_args.profile,
|
| 391 |
+
return_logprob=bench_args.return_logprob,
|
| 392 |
+
logprob_start_len=bench_args.logprob_start_len,
|
| 393 |
+
)
|
| 394 |
+
backend.shutdown()
|
| 395 |
+
|
| 396 |
+
if bench_args.result_filename:
|
| 397 |
+
with open(bench_args.result_filename, "a") as fout:
|
| 398 |
+
fout.write(json.dumps(result) + "\n")
|
| 399 |
+
|
| 400 |
+
print(
|
| 401 |
+
"\n{s:{c}^{n}}".format(s=" Offline Throughput Benchmark Result ", n=50, c="=")
|
| 402 |
+
)
|
| 403 |
+
print("{:<40} {:<10}".format("Backend:", result["backend"]))
|
| 404 |
+
print("{:<40} {:<10}".format("Successful requests:", result["successful_requests"]))
|
| 405 |
+
print("{:<40} {:<10.2f}".format("Benchmark duration (s):", result["total_latency"]))
|
| 406 |
+
print("{:<40} {:<10}".format("Total input tokens:", result["total_input_tokens"]))
|
| 407 |
+
print(
|
| 408 |
+
"{:<40} {:<10}".format("Total generated tokens:", result["total_output_tokens"])
|
| 409 |
+
)
|
| 410 |
+
print(
|
| 411 |
+
"{:<40} {:<10.2f}".format(
|
| 412 |
+
"Last generation throughput (tok/s):", result["last_gen_throughput"]
|
| 413 |
+
)
|
| 414 |
+
)
|
| 415 |
+
print(
|
| 416 |
+
"{:<40} {:<10.2f}".format(
|
| 417 |
+
"Request throughput (req/s):", result["request_throughput"]
|
| 418 |
+
)
|
| 419 |
+
)
|
| 420 |
+
print(
|
| 421 |
+
"{:<40} {:<10.2f}".format(
|
| 422 |
+
"Input token throughput (tok/s):", result["input_throughput"]
|
| 423 |
+
)
|
| 424 |
+
)
|
| 425 |
+
print(
|
| 426 |
+
"{:<40} {:<10.2f}".format(
|
| 427 |
+
"Output token throughput (tok/s):", result["output_throughput"]
|
| 428 |
+
)
|
| 429 |
+
)
|
| 430 |
+
print(
|
| 431 |
+
"{:<40} {:<10.2f}".format(
|
| 432 |
+
"Total token throughput (tok/s):", result["total_throughput"]
|
| 433 |
+
)
|
| 434 |
+
)
|
| 435 |
+
print("=" * 50)
|
| 436 |
+
|
| 437 |
+
return result
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
if __name__ == "__main__":
|
| 441 |
+
parser = argparse.ArgumentParser()
|
| 442 |
+
ServerArgs.add_cli_args(parser)
|
| 443 |
+
BenchArgs.add_cli_args(parser)
|
| 444 |
+
args = parser.parse_args()
|
| 445 |
+
|
| 446 |
+
# handling ModelScope model downloads
|
| 447 |
+
if os.getenv("SGLANG_USE_MODELSCOPE", "false").lower() in ("true", "1"):
|
| 448 |
+
if os.path.exists(args.model_path):
|
| 449 |
+
print(f"Using local model path: {args.model_path}")
|
| 450 |
+
else:
|
| 451 |
+
try:
|
| 452 |
+
from modelscope import snapshot_download
|
| 453 |
+
|
| 454 |
+
print(f"Using ModelScope to download model: {args.model_path}")
|
| 455 |
+
|
| 456 |
+
# download the model and replace args.model_path
|
| 457 |
+
args.model_path = snapshot_download(
|
| 458 |
+
args.model_path,
|
| 459 |
+
)
|
| 460 |
+
print(f"Model downloaded to: {args.model_path}")
|
| 461 |
+
except Exception as e:
|
| 462 |
+
print(f"ModelScope download failed: {str(e)}")
|
| 463 |
+
raise e
|
| 464 |
+
|
| 465 |
+
server_args = ServerArgs.from_cli_args(args)
|
| 466 |
+
bench_args = BenchArgs.from_cli_args(args)
|
| 467 |
+
|
| 468 |
+
logging.basicConfig(
|
| 469 |
+
level=getattr(logging, server_args.log_level.upper()),
|
| 470 |
+
format="%(message)s",
|
| 471 |
+
)
|
| 472 |
+
|
| 473 |
+
throughput_test(server_args, bench_args)
|
| 474 |
+
|
| 475 |
+
while bench_args.do_not_exit:
|
| 476 |
+
pass
|
sglang/bench_one_batch.py
ADDED
|
@@ -0,0 +1,795 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Benchmark the latency of running a single static batch without a server.
|
| 3 |
+
|
| 4 |
+
This script does not launch a server and uses the low-level APIs.
|
| 5 |
+
It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths).
|
| 6 |
+
|
| 7 |
+
# Usage (latency test)
|
| 8 |
+
## with dummy weights:
|
| 9 |
+
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
|
| 10 |
+
## sweep through multiple data points and store (append) the results in a jsonl file:
|
| 11 |
+
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --run-name test_run
|
| 12 |
+
## run with profiling:
|
| 13 |
+
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --profile
|
| 14 |
+
## run with profiling to custom directory:
|
| 15 |
+
export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log
|
| 16 |
+
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 --input-len 256 --profile
|
| 17 |
+
## run with CUDA profiler (nsys):
|
| 18 |
+
nsys profile --force-overwrite=true -o bench_one_batch python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 --input-len 256 --profile --profile-activities CUDA_PROFILER
|
| 19 |
+
# Usage (correctness test):
|
| 20 |
+
python -m sglang.bench_one_batch --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
|
| 21 |
+
|
| 22 |
+
## Reference output (of the correctness test above, can be gpu dependent):
|
| 23 |
+
input_ids=[[1, 450, 7483, 310, 3444, 338], [1, 450, 7483, 310, 278, 3303, 13187, 290, 338], [1, 20628, 338, 263, 6575, 1460, 2462, 322, 306, 763]]
|
| 24 |
+
|
| 25 |
+
prefill logits (first half): tensor([[-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633],
|
| 26 |
+
[-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633],
|
| 27 |
+
[ -9.1875, -10.2500, 2.7129, ..., -4.3359, -4.0664, -4.1328]],
|
| 28 |
+
device='cuda:0')
|
| 29 |
+
|
| 30 |
+
prefill logits (final): tensor([[-8.3125, -7.1172, 3.3457, ..., -4.9570, -4.1328, -3.4141],
|
| 31 |
+
[-8.9141, -9.0156, 4.1445, ..., -4.9922, -4.4961, -4.0781],
|
| 32 |
+
[-9.6328, -9.0547, 4.0195, ..., -5.3047, -4.7148, -4.4570]],
|
| 33 |
+
device='cuda:0')
|
| 34 |
+
|
| 35 |
+
========== Prompt 0 ==========
|
| 36 |
+
<s> The capital of France is Paris.
|
| 37 |
+
The capital of the United States is Washington, D.C.
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
========== Prompt 1 ==========
|
| 41 |
+
<s> The capital of the United Kindom is London.
|
| 42 |
+
The capital of the United Kingdom is London.
|
| 43 |
+
The capital of the
|
| 44 |
+
|
| 45 |
+
========== Prompt 2 ==========
|
| 46 |
+
<s> Today is a sunny day and I like to go for a walk in the park.
|
| 47 |
+
I'm going to the park
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
import argparse
|
| 51 |
+
import copy
|
| 52 |
+
import dataclasses
|
| 53 |
+
import itertools
|
| 54 |
+
import json
|
| 55 |
+
import logging
|
| 56 |
+
import multiprocessing
|
| 57 |
+
import os
|
| 58 |
+
import time
|
| 59 |
+
from types import SimpleNamespace
|
| 60 |
+
from typing import Tuple
|
| 61 |
+
|
| 62 |
+
import numpy as np
|
| 63 |
+
import torch
|
| 64 |
+
import torch.distributed as dist
|
| 65 |
+
|
| 66 |
+
from sglang.srt.configs.model_config import ModelConfig
|
| 67 |
+
from sglang.srt.distributed.parallel_state import destroy_distributed_environment
|
| 68 |
+
from sglang.srt.entrypoints.engine import _set_envs_and_config
|
| 69 |
+
from sglang.srt.layers.moe import initialize_moe_config
|
| 70 |
+
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
| 71 |
+
from sglang.srt.managers.scheduler_dp_attn_mixin import prepare_mlp_sync_batch_raw
|
| 72 |
+
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
| 73 |
+
from sglang.srt.model_executor.model_runner import ModelRunner
|
| 74 |
+
from sglang.srt.sampling.sampling_params import SamplingParams
|
| 75 |
+
from sglang.srt.server_args import PortArgs, ServerArgs
|
| 76 |
+
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
| 77 |
+
from sglang.srt.utils import (
|
| 78 |
+
configure_logger,
|
| 79 |
+
get_bool_env_var,
|
| 80 |
+
is_cuda_alike,
|
| 81 |
+
is_xpu,
|
| 82 |
+
kill_process_tree,
|
| 83 |
+
maybe_reindex_device_id,
|
| 84 |
+
require_mlp_sync,
|
| 85 |
+
require_mlp_tp_gather,
|
| 86 |
+
set_gpu_proc_affinity,
|
| 87 |
+
suppress_other_loggers,
|
| 88 |
+
)
|
| 89 |
+
from sglang.srt.utils.hf_transformers_utils import get_tokenizer
|
| 90 |
+
|
| 91 |
+
profile_activities = [torch.profiler.ProfilerActivity.CPU] + [
|
| 92 |
+
profiler_activity
|
| 93 |
+
for available, profiler_activity in [
|
| 94 |
+
(is_cuda_alike(), torch.profiler.ProfilerActivity.CUDA),
|
| 95 |
+
(is_xpu(), torch.profiler.ProfilerActivity.XPU),
|
| 96 |
+
]
|
| 97 |
+
if available
|
| 98 |
+
]
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def start_profile(profile_activities, profile_record_shapes=False, rank_print=print):
|
| 102 |
+
"""
|
| 103 |
+
Abstracted function to start profiling based on profile_activities.
|
| 104 |
+
Returns profiler object (or None).
|
| 105 |
+
"""
|
| 106 |
+
if "CUDA_PROFILER" in profile_activities:
|
| 107 |
+
try:
|
| 108 |
+
torch.cuda.cudart().cudaProfilerStart()
|
| 109 |
+
rank_print("CUDA Profiler started (nsys will begin capturing)")
|
| 110 |
+
except Exception as e:
|
| 111 |
+
rank_print(f"Failed to start CUDA profiler: {e}")
|
| 112 |
+
return None
|
| 113 |
+
else:
|
| 114 |
+
activities = []
|
| 115 |
+
if "CPU" in profile_activities:
|
| 116 |
+
activities.append(torch.profiler.ProfilerActivity.CPU)
|
| 117 |
+
if "GPU" in profile_activities:
|
| 118 |
+
activities.append(torch.profiler.ProfilerActivity.CUDA)
|
| 119 |
+
if activities:
|
| 120 |
+
profiler = torch.profiler.profile(
|
| 121 |
+
activities=activities,
|
| 122 |
+
with_stack=True,
|
| 123 |
+
record_shapes=profile_record_shapes,
|
| 124 |
+
)
|
| 125 |
+
profiler.start()
|
| 126 |
+
return profiler
|
| 127 |
+
return None
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def stop_profile(
|
| 131 |
+
profiler,
|
| 132 |
+
profile_activities,
|
| 133 |
+
rank_print=print,
|
| 134 |
+
save_trace=False,
|
| 135 |
+
trace_filename=None,
|
| 136 |
+
stage=None,
|
| 137 |
+
):
|
| 138 |
+
"""
|
| 139 |
+
Abstracted function to stop profiling based on profile_activities.
|
| 140 |
+
Optionally saves trace results and prints completion messages.
|
| 141 |
+
"""
|
| 142 |
+
if "CUDA_PROFILER" in profile_activities:
|
| 143 |
+
try:
|
| 144 |
+
torch.cuda.cudart().cudaProfilerStop()
|
| 145 |
+
rank_print("CUDA Profiler stopped (nsys should dump traces)")
|
| 146 |
+
except Exception as e:
|
| 147 |
+
rank_print(f"Failed to stop CUDA profiler: {e}")
|
| 148 |
+
elif profiler is not None:
|
| 149 |
+
profiler.stop()
|
| 150 |
+
|
| 151 |
+
if save_trace:
|
| 152 |
+
if profiler is not None:
|
| 153 |
+
if trace_filename:
|
| 154 |
+
_save_profile_trace_results(profiler, trace_filename)
|
| 155 |
+
stage_desc = f"for {stage}" if stage else ""
|
| 156 |
+
rank_print(
|
| 157 |
+
f"torch profiler chrome trace {stage_desc} saved to {trace_filename}"
|
| 158 |
+
)
|
| 159 |
+
if "CUDA_PROFILER" in profile_activities:
|
| 160 |
+
rank_print(f"CUDA profiler trace for {stage} completed")
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
@dataclasses.dataclass
|
| 164 |
+
class BenchArgs:
|
| 165 |
+
run_name: str = "default"
|
| 166 |
+
batch_size: Tuple[int] = (1,)
|
| 167 |
+
input_len: Tuple[int] = (1024,)
|
| 168 |
+
output_len: Tuple[int] = (16,)
|
| 169 |
+
prompt_filename: str = ""
|
| 170 |
+
result_filename: str = "result.jsonl"
|
| 171 |
+
correctness_test: bool = False
|
| 172 |
+
# This is only used for correctness test
|
| 173 |
+
cut_len: int = 4
|
| 174 |
+
log_decode_step: int = 0
|
| 175 |
+
profile: bool = False
|
| 176 |
+
profile_record_shapes: bool = False
|
| 177 |
+
profile_activities: Tuple[str] = ("CPU", "GPU")
|
| 178 |
+
profile_stage: str = "all"
|
| 179 |
+
profile_filename_prefix: str = "profile"
|
| 180 |
+
|
| 181 |
+
@staticmethod
|
| 182 |
+
def add_cli_args(parser: argparse.ArgumentParser):
|
| 183 |
+
parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
|
| 184 |
+
parser.add_argument(
|
| 185 |
+
"--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
|
| 186 |
+
)
|
| 187 |
+
parser.add_argument(
|
| 188 |
+
"--input-len", type=int, nargs="+", default=BenchArgs.input_len
|
| 189 |
+
)
|
| 190 |
+
parser.add_argument(
|
| 191 |
+
"--output-len", type=int, nargs="+", default=BenchArgs.output_len
|
| 192 |
+
)
|
| 193 |
+
parser.add_argument(
|
| 194 |
+
"--prompt-filename", type=str, default=BenchArgs.prompt_filename
|
| 195 |
+
)
|
| 196 |
+
parser.add_argument(
|
| 197 |
+
"--result-filename", type=str, default=BenchArgs.result_filename
|
| 198 |
+
)
|
| 199 |
+
parser.add_argument("--correctness-test", action="store_true")
|
| 200 |
+
parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
|
| 201 |
+
parser.add_argument(
|
| 202 |
+
"--log-decode-step",
|
| 203 |
+
type=int,
|
| 204 |
+
default=BenchArgs.log_decode_step,
|
| 205 |
+
help="Log decode latency by step, default is set to zero to disable.",
|
| 206 |
+
)
|
| 207 |
+
parser.add_argument("--profile", action="store_true", help="Enable profiling.")
|
| 208 |
+
parser.add_argument(
|
| 209 |
+
"--profile-record-shapes",
|
| 210 |
+
action="store_true",
|
| 211 |
+
help="Record tensor shapes in profiling results.",
|
| 212 |
+
)
|
| 213 |
+
parser.add_argument(
|
| 214 |
+
"--profile-activities",
|
| 215 |
+
type=str,
|
| 216 |
+
nargs="+",
|
| 217 |
+
default=["CPU", "GPU"],
|
| 218 |
+
choices=["CPU", "GPU", "CUDA_PROFILER"],
|
| 219 |
+
help="Profiler activities: CPU, GPU, CUDA_PROFILER. If CPU/GPU, use torch profiler. If CUDA_PROFILER, use CUDA profiler.",
|
| 220 |
+
)
|
| 221 |
+
parser.add_argument(
|
| 222 |
+
"--profile-stage",
|
| 223 |
+
type=str,
|
| 224 |
+
default=BenchArgs.profile_stage,
|
| 225 |
+
choices=["all", "prefill", "decode"],
|
| 226 |
+
help="Which stage to profile: all, prefill, or decode only.",
|
| 227 |
+
)
|
| 228 |
+
parser.add_argument(
|
| 229 |
+
"--profile-filename-prefix",
|
| 230 |
+
type=str,
|
| 231 |
+
default=BenchArgs.profile_filename_prefix,
|
| 232 |
+
help="Prefix of the profiling file names. The full profiling result file(s) be "
|
| 233 |
+
'"[profile_filename_prefix]_batch[batch_size]_input[input_len]_output[output_len].trace.json.gz"',
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
@classmethod
|
| 237 |
+
def from_cli_args(cls, args: argparse.Namespace):
|
| 238 |
+
# use the default value's type to cast the args into correct types.
|
| 239 |
+
attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
|
| 240 |
+
return cls(
|
| 241 |
+
**{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def load_model(server_args, port_args, gpu_id, tp_rank):
|
| 246 |
+
suppress_other_loggers()
|
| 247 |
+
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
|
| 248 |
+
moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
|
| 249 |
+
|
| 250 |
+
model_config = ModelConfig.from_server_args(server_args)
|
| 251 |
+
model_runner = ModelRunner(
|
| 252 |
+
model_config=model_config,
|
| 253 |
+
mem_fraction_static=server_args.mem_fraction_static,
|
| 254 |
+
gpu_id=gpu_id,
|
| 255 |
+
tp_rank=tp_rank,
|
| 256 |
+
tp_size=server_args.tp_size,
|
| 257 |
+
moe_ep_rank=moe_ep_rank,
|
| 258 |
+
moe_ep_size=server_args.ep_size,
|
| 259 |
+
pp_rank=0,
|
| 260 |
+
pp_size=1,
|
| 261 |
+
nccl_port=port_args.nccl_port,
|
| 262 |
+
server_args=server_args,
|
| 263 |
+
)
|
| 264 |
+
rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
|
| 265 |
+
tokenizer = get_tokenizer(
|
| 266 |
+
server_args.tokenizer_path,
|
| 267 |
+
tokenizer_mode=server_args.tokenizer_mode,
|
| 268 |
+
trust_remote_code=server_args.trust_remote_code,
|
| 269 |
+
)
|
| 270 |
+
if server_args.tp_size > 1:
|
| 271 |
+
dist.barrier()
|
| 272 |
+
return model_runner, tokenizer
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def prepare_inputs_for_correctness_test(bench_args, tokenizer, custom_prompts):
|
| 276 |
+
prompts = (
|
| 277 |
+
custom_prompts
|
| 278 |
+
if custom_prompts
|
| 279 |
+
else [
|
| 280 |
+
"The capital of France is",
|
| 281 |
+
"The capital of the United Kindom is",
|
| 282 |
+
"Today is a sunny day and I like",
|
| 283 |
+
]
|
| 284 |
+
)
|
| 285 |
+
input_ids = [tokenizer.encode(p) for p in prompts]
|
| 286 |
+
sampling_params = SamplingParams(
|
| 287 |
+
temperature=0,
|
| 288 |
+
max_new_tokens=BenchArgs.output_len,
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
reqs = []
|
| 292 |
+
for i in range(len(prompts)):
|
| 293 |
+
assert len(input_ids[i]) > bench_args.cut_len
|
| 294 |
+
|
| 295 |
+
tmp_input_ids = input_ids[i][: bench_args.cut_len]
|
| 296 |
+
req = Req(
|
| 297 |
+
rid=i,
|
| 298 |
+
origin_input_text=prompts[i],
|
| 299 |
+
origin_input_ids=tmp_input_ids,
|
| 300 |
+
sampling_params=sampling_params,
|
| 301 |
+
)
|
| 302 |
+
req.fill_ids = req.origin_input_ids
|
| 303 |
+
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
|
| 304 |
+
req.logprob_start_len = len(req.origin_input_ids) - 1
|
| 305 |
+
reqs.append(req)
|
| 306 |
+
|
| 307 |
+
return input_ids, reqs
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def prepare_extend_inputs_for_correctness_test(
|
| 311 |
+
bench_args, input_ids, reqs, model_runner
|
| 312 |
+
):
|
| 313 |
+
for i in range(len(reqs)):
|
| 314 |
+
req = reqs[i]
|
| 315 |
+
req.fill_ids += input_ids[i][bench_args.cut_len :]
|
| 316 |
+
req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
|
| 317 |
+
i, : bench_args.cut_len
|
| 318 |
+
]
|
| 319 |
+
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
|
| 320 |
+
req.logprob_start_len = len(req.origin_input_ids) - 1
|
| 321 |
+
return reqs
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def prepare_synthetic_inputs_for_latency_test(
|
| 325 |
+
batch_size, input_len, custom_inputs=None
|
| 326 |
+
):
|
| 327 |
+
input_ids = (
|
| 328 |
+
custom_inputs
|
| 329 |
+
if custom_inputs
|
| 330 |
+
else np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32)
|
| 331 |
+
)
|
| 332 |
+
sampling_params = SamplingParams(
|
| 333 |
+
temperature=0,
|
| 334 |
+
max_new_tokens=BenchArgs.output_len,
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
reqs = []
|
| 338 |
+
for i in range(len(input_ids)):
|
| 339 |
+
req = Req(
|
| 340 |
+
rid=i,
|
| 341 |
+
origin_input_text="",
|
| 342 |
+
origin_input_ids=list(input_ids[i]),
|
| 343 |
+
sampling_params=sampling_params,
|
| 344 |
+
)
|
| 345 |
+
req.fill_ids = req.origin_input_ids
|
| 346 |
+
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
|
| 347 |
+
req.logprob_start_len = len(req.origin_input_ids) - 1
|
| 348 |
+
reqs.append(req)
|
| 349 |
+
|
| 350 |
+
return reqs
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
@torch.no_grad
|
| 354 |
+
def extend(reqs, model_runner):
|
| 355 |
+
# Create dummy tree_cache for benchmarks (no prefix caching, just allocation)
|
| 356 |
+
dummy_tree_cache = SimpleNamespace(
|
| 357 |
+
page_size=model_runner.server_args.page_size,
|
| 358 |
+
device=model_runner.device,
|
| 359 |
+
token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
|
| 360 |
+
)
|
| 361 |
+
|
| 362 |
+
batch = ScheduleBatch.init_new(
|
| 363 |
+
reqs=reqs,
|
| 364 |
+
req_to_token_pool=model_runner.req_to_token_pool,
|
| 365 |
+
token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
|
| 366 |
+
tree_cache=dummy_tree_cache,
|
| 367 |
+
model_config=model_runner.model_config,
|
| 368 |
+
enable_overlap=False,
|
| 369 |
+
spec_algorithm=SpeculativeAlgorithm.NONE,
|
| 370 |
+
)
|
| 371 |
+
batch.prepare_for_extend()
|
| 372 |
+
_maybe_prepare_mlp_sync_batch(batch, model_runner)
|
| 373 |
+
model_worker_batch = batch.get_model_worker_batch()
|
| 374 |
+
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
|
| 375 |
+
logits_output, _ = model_runner.forward(forward_batch)
|
| 376 |
+
next_token_ids = model_runner.sample(logits_output, forward_batch)
|
| 377 |
+
return next_token_ids, logits_output.next_token_logits, batch
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
@torch.no_grad
|
| 381 |
+
def decode(input_token_ids, batch, model_runner):
|
| 382 |
+
batch.output_ids = input_token_ids
|
| 383 |
+
batch.prepare_for_decode()
|
| 384 |
+
_maybe_prepare_mlp_sync_batch(batch, model_runner)
|
| 385 |
+
model_worker_batch = batch.get_model_worker_batch()
|
| 386 |
+
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
|
| 387 |
+
logits_output, _ = model_runner.forward(forward_batch)
|
| 388 |
+
next_token_ids = model_runner.sample(logits_output, forward_batch)
|
| 389 |
+
return next_token_ids, logits_output.next_token_logits
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
|
| 393 |
+
if require_mlp_sync(model_runner.server_args):
|
| 394 |
+
prepare_mlp_sync_batch_raw(
|
| 395 |
+
batch,
|
| 396 |
+
dp_size=model_runner.server_args.dp_size,
|
| 397 |
+
attn_tp_size=1,
|
| 398 |
+
tp_group=model_runner.tp_group,
|
| 399 |
+
get_idle_batch=None,
|
| 400 |
+
disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
|
| 401 |
+
require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
|
| 402 |
+
disable_overlap_schedule=model_runner.server_args.disable_overlap_schedule,
|
| 403 |
+
offload_tags=set(),
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def _read_prompts_from_file(prompt_file, rank_print):
|
| 408 |
+
"""Read custom prompts from the file specified by `--prompt-filename`."""
|
| 409 |
+
if not prompt_file:
|
| 410 |
+
return []
|
| 411 |
+
if not os.path.exists(prompt_file):
|
| 412 |
+
rank_print(
|
| 413 |
+
f"Custom prompt file {prompt_file} not found. Using default inputs..."
|
| 414 |
+
)
|
| 415 |
+
return []
|
| 416 |
+
with open(prompt_file, "r") as pf:
|
| 417 |
+
return pf.readlines()
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
def _get_torch_profiler_output_dir():
|
| 421 |
+
return os.environ.get("SGLANG_TORCH_PROFILER_DIR", "/tmp")
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
def _create_torch_profiler_filename(
|
| 425 |
+
profile_filename_prefix, batch_size, input_len, output_len, stage
|
| 426 |
+
):
|
| 427 |
+
output_dir = _get_torch_profiler_output_dir()
|
| 428 |
+
filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_{stage}.trace.json.gz"
|
| 429 |
+
return os.path.join(output_dir, filename)
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
def _save_profile_trace_results(profiler, filename):
|
| 433 |
+
parent_dir = os.path.dirname(os.path.abspath(filename))
|
| 434 |
+
os.makedirs(parent_dir, exist_ok=True)
|
| 435 |
+
profiler.export_chrome_trace(filename)
|
| 436 |
+
print(
|
| 437 |
+
profiler.key_averages(group_by_input_shape=True).table(
|
| 438 |
+
sort_by="self_cpu_time_total"
|
| 439 |
+
)
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
def correctness_test(
|
| 444 |
+
server_args,
|
| 445 |
+
port_args,
|
| 446 |
+
bench_args,
|
| 447 |
+
gpu_id,
|
| 448 |
+
tp_rank,
|
| 449 |
+
):
|
| 450 |
+
# Configure the logger
|
| 451 |
+
configure_logger(server_args, prefix=f" TP{tp_rank}")
|
| 452 |
+
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
|
| 453 |
+
|
| 454 |
+
# Load the model
|
| 455 |
+
model_runner, tokenizer = load_model(server_args, port_args, gpu_id, tp_rank)
|
| 456 |
+
|
| 457 |
+
# Prepare inputs
|
| 458 |
+
custom_prompts = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
|
| 459 |
+
input_ids, reqs = prepare_inputs_for_correctness_test(
|
| 460 |
+
bench_args, tokenizer, custom_prompts
|
| 461 |
+
)
|
| 462 |
+
rank_print(f"\n{input_ids=}\n")
|
| 463 |
+
|
| 464 |
+
if bench_args.cut_len > 0:
|
| 465 |
+
# Prefill
|
| 466 |
+
next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
|
| 467 |
+
rank_print(f"prefill logits (first half): {next_token_logits} \n")
|
| 468 |
+
|
| 469 |
+
# Prepare extend inputs
|
| 470 |
+
reqs = prepare_extend_inputs_for_correctness_test(
|
| 471 |
+
bench_args, input_ids, reqs, model_runner
|
| 472 |
+
)
|
| 473 |
+
|
| 474 |
+
# Extend (prefill w/ KV cache)
|
| 475 |
+
next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
|
| 476 |
+
rank_print(f"prefill logits (final): {next_token_logits} \n")
|
| 477 |
+
|
| 478 |
+
# Decode
|
| 479 |
+
output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
|
| 480 |
+
for _ in range(bench_args.output_len[0] - 1):
|
| 481 |
+
next_token_ids, _ = decode(next_token_ids, batch, model_runner)
|
| 482 |
+
next_token_ids_list = next_token_ids.tolist()
|
| 483 |
+
for i in range(len(reqs)):
|
| 484 |
+
output_ids[i].append(next_token_ids_list[i])
|
| 485 |
+
|
| 486 |
+
# Print output texts
|
| 487 |
+
for i in range(len(reqs)):
|
| 488 |
+
rank_print(f"========== Prompt {i} ==========")
|
| 489 |
+
rank_print(tokenizer.decode(output_ids[i]), "\n")
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
def synchronize(device):
|
| 493 |
+
torch.get_device_module(device).synchronize()
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
def latency_test_run_once(
|
| 497 |
+
run_name,
|
| 498 |
+
model_runner,
|
| 499 |
+
rank_print,
|
| 500 |
+
reqs,
|
| 501 |
+
batch_size,
|
| 502 |
+
input_len,
|
| 503 |
+
output_len,
|
| 504 |
+
device,
|
| 505 |
+
log_decode_step,
|
| 506 |
+
profile,
|
| 507 |
+
profile_record_shapes,
|
| 508 |
+
profile_activities,
|
| 509 |
+
profile_filename_prefix,
|
| 510 |
+
profile_stage,
|
| 511 |
+
tp_rank,
|
| 512 |
+
):
|
| 513 |
+
max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
|
| 514 |
+
if batch_size > max_batch_size:
|
| 515 |
+
rank_print(
|
| 516 |
+
f"skipping ({batch_size}, {input_len}, {output_len}) due to max batch size limit"
|
| 517 |
+
)
|
| 518 |
+
return
|
| 519 |
+
|
| 520 |
+
model_runner.req_to_token_pool.clear()
|
| 521 |
+
model_runner.token_to_kv_pool_allocator.clear()
|
| 522 |
+
|
| 523 |
+
measurement_results = {
|
| 524 |
+
"run_name": run_name,
|
| 525 |
+
"batch_size": batch_size,
|
| 526 |
+
"input_len": input_len,
|
| 527 |
+
"output_len": output_len,
|
| 528 |
+
}
|
| 529 |
+
|
| 530 |
+
tot_latency = 0
|
| 531 |
+
|
| 532 |
+
profiler = None
|
| 533 |
+
enable_profile_prefill = profile and profile_stage in ["all", "prefill"]
|
| 534 |
+
if enable_profile_prefill:
|
| 535 |
+
profiler = start_profile(
|
| 536 |
+
profile_activities,
|
| 537 |
+
profile_record_shapes=profile_record_shapes,
|
| 538 |
+
rank_print=rank_print,
|
| 539 |
+
)
|
| 540 |
+
|
| 541 |
+
synchronize(device)
|
| 542 |
+
tic = time.perf_counter()
|
| 543 |
+
next_token_ids, _, batch = extend(reqs, model_runner)
|
| 544 |
+
synchronize(device)
|
| 545 |
+
prefill_latency = time.perf_counter() - tic
|
| 546 |
+
|
| 547 |
+
if enable_profile_prefill:
|
| 548 |
+
trace_filename = _create_torch_profiler_filename(
|
| 549 |
+
profile_filename_prefix, batch_size, input_len, output_len, "prefill"
|
| 550 |
+
)
|
| 551 |
+
stop_profile(
|
| 552 |
+
profiler,
|
| 553 |
+
profile_activities,
|
| 554 |
+
rank_print=rank_print,
|
| 555 |
+
save_trace=True,
|
| 556 |
+
trace_filename=trace_filename,
|
| 557 |
+
stage="prefill",
|
| 558 |
+
)
|
| 559 |
+
|
| 560 |
+
tot_latency += prefill_latency
|
| 561 |
+
throughput = input_len * batch_size / prefill_latency
|
| 562 |
+
rank_print(
|
| 563 |
+
f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
|
| 564 |
+
)
|
| 565 |
+
measurement_results["prefill_latency"] = prefill_latency
|
| 566 |
+
measurement_results["prefill_throughput"] = throughput
|
| 567 |
+
|
| 568 |
+
decode_latencies = []
|
| 569 |
+
profile_step_of_interest = output_len // 2
|
| 570 |
+
enable_profile_decode = profile and profile_stage in ["all", "decode"]
|
| 571 |
+
for i in range(output_len - 1):
|
| 572 |
+
synchronize(device)
|
| 573 |
+
profiler = None
|
| 574 |
+
if enable_profile_decode and i == profile_step_of_interest:
|
| 575 |
+
profiler = start_profile(
|
| 576 |
+
profile_activities,
|
| 577 |
+
profile_record_shapes=profile_record_shapes,
|
| 578 |
+
rank_print=rank_print,
|
| 579 |
+
)
|
| 580 |
+
|
| 581 |
+
tic = time.perf_counter()
|
| 582 |
+
next_token_ids, _ = decode(next_token_ids, batch, model_runner)
|
| 583 |
+
synchronize(device)
|
| 584 |
+
latency = time.perf_counter() - tic
|
| 585 |
+
|
| 586 |
+
if enable_profile_decode and i == profile_step_of_interest:
|
| 587 |
+
trace_filename = _create_torch_profiler_filename(
|
| 588 |
+
profile_filename_prefix, batch_size, input_len, output_len, "decode"
|
| 589 |
+
)
|
| 590 |
+
stop_profile(
|
| 591 |
+
profiler,
|
| 592 |
+
profile_activities,
|
| 593 |
+
rank_print=rank_print,
|
| 594 |
+
save_trace=True,
|
| 595 |
+
trace_filename=trace_filename,
|
| 596 |
+
stage="decode",
|
| 597 |
+
)
|
| 598 |
+
|
| 599 |
+
tot_latency += latency
|
| 600 |
+
throughput = batch_size / latency
|
| 601 |
+
decode_latencies.append(latency)
|
| 602 |
+
if i < 5 or (log_decode_step > 0 and i % log_decode_step == 0):
|
| 603 |
+
rank_print(
|
| 604 |
+
f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
|
| 605 |
+
)
|
| 606 |
+
|
| 607 |
+
# Record decode timing from 2nd output
|
| 608 |
+
if output_len > 1:
|
| 609 |
+
med_decode_latency = np.median(decode_latencies)
|
| 610 |
+
med_decode_throughput = batch_size / med_decode_latency
|
| 611 |
+
rank_print(
|
| 612 |
+
f"Decode. median latency: {med_decode_latency:6.5f} s, median throughput: {med_decode_throughput:9.2f} token/s"
|
| 613 |
+
)
|
| 614 |
+
measurement_results["median_decode_latency"] = med_decode_latency
|
| 615 |
+
measurement_results["median_decode_throughput"] = med_decode_throughput
|
| 616 |
+
|
| 617 |
+
throughput = (input_len + output_len) * batch_size / tot_latency
|
| 618 |
+
rank_print(
|
| 619 |
+
f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
|
| 620 |
+
)
|
| 621 |
+
measurement_results["total_latency"] = tot_latency
|
| 622 |
+
measurement_results["overall_throughput"] = throughput
|
| 623 |
+
return measurement_results
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
def latency_test(
|
| 627 |
+
server_args,
|
| 628 |
+
port_args,
|
| 629 |
+
bench_args,
|
| 630 |
+
gpu_id,
|
| 631 |
+
tp_rank,
|
| 632 |
+
):
|
| 633 |
+
initialize_moe_config(server_args)
|
| 634 |
+
|
| 635 |
+
# Set CPU affinity
|
| 636 |
+
if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
|
| 637 |
+
set_gpu_proc_affinity(
|
| 638 |
+
server_args.pp_size, server_args.tp_size, server_args.nnodes, tp_rank
|
| 639 |
+
)
|
| 640 |
+
|
| 641 |
+
# Configure the logger
|
| 642 |
+
configure_logger(server_args, prefix=f" TP{tp_rank}")
|
| 643 |
+
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
|
| 644 |
+
|
| 645 |
+
# Load the model
|
| 646 |
+
model_runner, tokenizer = load_model(server_args, port_args, gpu_id, tp_rank)
|
| 647 |
+
|
| 648 |
+
# Prepare inputs for warm up
|
| 649 |
+
reqs = prepare_synthetic_inputs_for_latency_test(
|
| 650 |
+
bench_args.batch_size[0], bench_args.input_len[0]
|
| 651 |
+
)
|
| 652 |
+
|
| 653 |
+
# Warm up
|
| 654 |
+
rank_print("Warmup ...")
|
| 655 |
+
latency_test_run_once(
|
| 656 |
+
bench_args.run_name,
|
| 657 |
+
model_runner,
|
| 658 |
+
rank_print,
|
| 659 |
+
reqs,
|
| 660 |
+
bench_args.batch_size[0],
|
| 661 |
+
bench_args.input_len[0],
|
| 662 |
+
min(32, bench_args.output_len[0]), # shorter decoding to speed up the warmup
|
| 663 |
+
server_args.device,
|
| 664 |
+
log_decode_step=0,
|
| 665 |
+
profile=False,
|
| 666 |
+
profile_record_shapes=False,
|
| 667 |
+
profile_activities=("CPU", "GPU"),
|
| 668 |
+
profile_filename_prefix="",
|
| 669 |
+
profile_stage="all",
|
| 670 |
+
tp_rank=tp_rank,
|
| 671 |
+
)
|
| 672 |
+
|
| 673 |
+
rank_print("Benchmark ...")
|
| 674 |
+
|
| 675 |
+
custom_inputs = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
|
| 676 |
+
custom_inputs = [tokenizer.encode(p.strip()) for p in custom_inputs]
|
| 677 |
+
custom_input_len = len(custom_inputs)
|
| 678 |
+
|
| 679 |
+
# Run the sweep
|
| 680 |
+
result_list = []
|
| 681 |
+
for bs, il, ol in itertools.product(
|
| 682 |
+
bench_args.batch_size, bench_args.input_len, bench_args.output_len
|
| 683 |
+
):
|
| 684 |
+
bs_aligned_inputs = []
|
| 685 |
+
if custom_inputs:
|
| 686 |
+
if custom_input_len == bs:
|
| 687 |
+
bs_aligned_inputs = custom_inputs
|
| 688 |
+
elif custom_input_len > bs:
|
| 689 |
+
rank_print(
|
| 690 |
+
f"Custom input size ({custom_input_len}) is larger than batch_size ({bs}). "
|
| 691 |
+
f"Using the first {bs} prompts."
|
| 692 |
+
)
|
| 693 |
+
bs_aligned_inputs = copy.deepcopy(custom_inputs[:bs])
|
| 694 |
+
else:
|
| 695 |
+
rank_print(
|
| 696 |
+
f"Custom input size ({custom_input_len}) is smaller than batch_size ({bs}). "
|
| 697 |
+
f"Pad to the desired batch_size with the last prompt."
|
| 698 |
+
)
|
| 699 |
+
bs_aligned_inputs = copy.deepcopy(custom_inputs)
|
| 700 |
+
bs_aligned_inputs.extend(
|
| 701 |
+
[bs_aligned_inputs[-1]] * (bs - custom_input_len)
|
| 702 |
+
)
|
| 703 |
+
|
| 704 |
+
reqs = prepare_synthetic_inputs_for_latency_test(bs, il, bs_aligned_inputs)
|
| 705 |
+
ret = latency_test_run_once(
|
| 706 |
+
bench_args.run_name,
|
| 707 |
+
model_runner,
|
| 708 |
+
rank_print,
|
| 709 |
+
reqs,
|
| 710 |
+
bs,
|
| 711 |
+
il,
|
| 712 |
+
ol,
|
| 713 |
+
server_args.device,
|
| 714 |
+
bench_args.log_decode_step,
|
| 715 |
+
bench_args.profile if tp_rank == 0 else None,
|
| 716 |
+
bench_args.profile_record_shapes if tp_rank == 0 else None,
|
| 717 |
+
bench_args.profile_activities,
|
| 718 |
+
bench_args.profile_filename_prefix,
|
| 719 |
+
bench_args.profile_stage,
|
| 720 |
+
tp_rank,
|
| 721 |
+
)
|
| 722 |
+
if ret is not None:
|
| 723 |
+
result_list.append(ret)
|
| 724 |
+
|
| 725 |
+
# Write results in jsonlines format on rank 0.
|
| 726 |
+
if tp_rank == 0 and bench_args.result_filename:
|
| 727 |
+
with open(bench_args.result_filename, "a") as fout:
|
| 728 |
+
for result in result_list:
|
| 729 |
+
fout.write(json.dumps(result) + "\n")
|
| 730 |
+
|
| 731 |
+
if server_args.tp_size > 1:
|
| 732 |
+
destroy_distributed_environment()
|
| 733 |
+
|
| 734 |
+
|
| 735 |
+
def main(server_args, bench_args):
|
| 736 |
+
server_args.cuda_graph_max_bs = max(bench_args.batch_size)
|
| 737 |
+
|
| 738 |
+
_set_envs_and_config(server_args)
|
| 739 |
+
|
| 740 |
+
if server_args.model_path:
|
| 741 |
+
if bench_args.correctness_test:
|
| 742 |
+
work_func = correctness_test
|
| 743 |
+
else:
|
| 744 |
+
work_func = latency_test
|
| 745 |
+
else:
|
| 746 |
+
raise ValueError(
|
| 747 |
+
"Provide --model-path for running the tests or "
|
| 748 |
+
"provide --result-filename for plotting the results"
|
| 749 |
+
)
|
| 750 |
+
|
| 751 |
+
port_args = PortArgs.init_new(server_args)
|
| 752 |
+
|
| 753 |
+
if server_args.tp_size == 1:
|
| 754 |
+
work_func(server_args, port_args, bench_args, 0, 0)
|
| 755 |
+
else:
|
| 756 |
+
workers = []
|
| 757 |
+
for tp_rank in range(server_args.tp_size):
|
| 758 |
+
with maybe_reindex_device_id(tp_rank) as gpu_id:
|
| 759 |
+
proc = multiprocessing.Process(
|
| 760 |
+
target=work_func,
|
| 761 |
+
args=(
|
| 762 |
+
server_args,
|
| 763 |
+
port_args,
|
| 764 |
+
bench_args,
|
| 765 |
+
gpu_id,
|
| 766 |
+
tp_rank,
|
| 767 |
+
),
|
| 768 |
+
)
|
| 769 |
+
proc.start()
|
| 770 |
+
workers.append(proc)
|
| 771 |
+
|
| 772 |
+
for proc in workers:
|
| 773 |
+
proc.join()
|
| 774 |
+
|
| 775 |
+
proc.terminate()
|
| 776 |
+
|
| 777 |
+
|
| 778 |
+
if __name__ == "__main__":
|
| 779 |
+
parser = argparse.ArgumentParser()
|
| 780 |
+
ServerArgs.add_cli_args(parser)
|
| 781 |
+
BenchArgs.add_cli_args(parser)
|
| 782 |
+
args = parser.parse_args()
|
| 783 |
+
server_args = ServerArgs.from_cli_args(args)
|
| 784 |
+
bench_args = BenchArgs.from_cli_args(args)
|
| 785 |
+
|
| 786 |
+
logging.basicConfig(
|
| 787 |
+
level=getattr(logging, server_args.log_level.upper()),
|
| 788 |
+
format="%(message)s",
|
| 789 |
+
)
|
| 790 |
+
|
| 791 |
+
try:
|
| 792 |
+
main(server_args, bench_args)
|
| 793 |
+
finally:
|
| 794 |
+
if server_args.tp_size != 1:
|
| 795 |
+
kill_process_tree(os.getpid(), include_parent=False)
|
sglang/bench_one_batch_server.py
ADDED
|
@@ -0,0 +1,605 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Benchmark the latency of running a single batch with a server.
|
| 3 |
+
|
| 4 |
+
This script launches a server and uses the HTTP interface.
|
| 5 |
+
It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths).
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python3 -m sglang.bench_one_batch_server --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
|
| 9 |
+
|
| 10 |
+
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
|
| 11 |
+
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --show-report --profile --profile-by-stage
|
| 12 |
+
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --output-path results.json --profile
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
import dataclasses
|
| 17 |
+
import itertools
|
| 18 |
+
import json
|
| 19 |
+
import multiprocessing
|
| 20 |
+
import os
|
| 21 |
+
import random
|
| 22 |
+
import time
|
| 23 |
+
from typing import List, Optional, Tuple
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
import requests
|
| 27 |
+
from pydantic import BaseModel
|
| 28 |
+
from transformers import AutoProcessor, PreTrainedTokenizer
|
| 29 |
+
|
| 30 |
+
from sglang.bench_serving import (
|
| 31 |
+
get_processor,
|
| 32 |
+
get_tokenizer,
|
| 33 |
+
sample_mmmu_requests,
|
| 34 |
+
sample_random_requests,
|
| 35 |
+
)
|
| 36 |
+
from sglang.profiler import run_profile
|
| 37 |
+
from sglang.srt.entrypoints.http_server import launch_server
|
| 38 |
+
from sglang.srt.server_args import ServerArgs
|
| 39 |
+
from sglang.srt.utils import is_blackwell, kill_process_tree
|
| 40 |
+
from sglang.test.nightly_bench_utils import save_results_as_pydantic_models
|
| 41 |
+
from sglang.test.test_utils import is_in_ci, write_github_step_summary
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@dataclasses.dataclass
|
| 45 |
+
class BenchArgs:
|
| 46 |
+
run_name: str = "default"
|
| 47 |
+
batch_size: Tuple[int] = (1,)
|
| 48 |
+
input_len: Tuple[int] = (1024,)
|
| 49 |
+
output_len: Tuple[int] = (16,)
|
| 50 |
+
temperature: float = 0.0
|
| 51 |
+
return_logprob: bool = False
|
| 52 |
+
client_stream_interval: int = 1
|
| 53 |
+
input_len_step_percentage: float = 0.0
|
| 54 |
+
base_url: str = ""
|
| 55 |
+
skip_warmup: bool = False
|
| 56 |
+
show_report: bool = False
|
| 57 |
+
profile: bool = False
|
| 58 |
+
profile_steps: int = 5
|
| 59 |
+
profile_by_stage: bool = False
|
| 60 |
+
profile_prefix: Optional[str] = None
|
| 61 |
+
profile_output_dir: Optional[str] = None
|
| 62 |
+
dataset_path: str = ""
|
| 63 |
+
dataset_name: str = "random"
|
| 64 |
+
parallel_batch: bool = False
|
| 65 |
+
result_filename: str = "result.jsonl"
|
| 66 |
+
pydantic_result_filename: Optional[str] = None
|
| 67 |
+
append_to_github_summary: bool = True
|
| 68 |
+
seed: int = 42
|
| 69 |
+
|
| 70 |
+
@staticmethod
|
| 71 |
+
def add_cli_args(parser: argparse.ArgumentParser):
|
| 72 |
+
parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
|
| 73 |
+
parser.add_argument(
|
| 74 |
+
"--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
|
| 75 |
+
)
|
| 76 |
+
parser.add_argument(
|
| 77 |
+
"--input-len", type=int, nargs="+", default=BenchArgs.input_len
|
| 78 |
+
)
|
| 79 |
+
parser.add_argument(
|
| 80 |
+
"--output-len", type=int, nargs="+", default=BenchArgs.output_len
|
| 81 |
+
)
|
| 82 |
+
parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
|
| 83 |
+
parser.add_argument("--return-logprob", action="store_true")
|
| 84 |
+
parser.add_argument(
|
| 85 |
+
"--client-stream-interval",
|
| 86 |
+
type=int,
|
| 87 |
+
default=BenchArgs.client_stream_interval,
|
| 88 |
+
)
|
| 89 |
+
parser.add_argument(
|
| 90 |
+
"--input-len-step-percentage",
|
| 91 |
+
type=float,
|
| 92 |
+
default=BenchArgs.input_len_step_percentage,
|
| 93 |
+
)
|
| 94 |
+
parser.add_argument("--base-url", type=str, default=BenchArgs.base_url)
|
| 95 |
+
parser.add_argument("--skip-warmup", action="store_true")
|
| 96 |
+
parser.add_argument("--show-report", action="store_true")
|
| 97 |
+
parser.add_argument("--profile", action="store_true")
|
| 98 |
+
parser.add_argument(
|
| 99 |
+
"--profile-steps", type=int, default=BenchArgs.profile_steps
|
| 100 |
+
)
|
| 101 |
+
parser.add_argument("--profile-by-stage", action="store_true")
|
| 102 |
+
parser.add_argument(
|
| 103 |
+
"--profile-prefix",
|
| 104 |
+
type=str,
|
| 105 |
+
default=BenchArgs.profile_prefix,
|
| 106 |
+
)
|
| 107 |
+
parser.add_argument(
|
| 108 |
+
"--profile-output-dir",
|
| 109 |
+
type=str,
|
| 110 |
+
default=BenchArgs.profile_output_dir,
|
| 111 |
+
)
|
| 112 |
+
parser.add_argument(
|
| 113 |
+
"--dataset-path",
|
| 114 |
+
type=str,
|
| 115 |
+
default=BenchArgs.dataset_path,
|
| 116 |
+
help="Path to the dataset.",
|
| 117 |
+
)
|
| 118 |
+
parser.add_argument(
|
| 119 |
+
"--dataset-name",
|
| 120 |
+
type=str,
|
| 121 |
+
default=BenchArgs.dataset_name,
|
| 122 |
+
choices=["mmmu", "random"],
|
| 123 |
+
help="Name of the dataset to benchmark on.",
|
| 124 |
+
)
|
| 125 |
+
parser.add_argument("--parallel-batch", action="store_true")
|
| 126 |
+
parser.add_argument(
|
| 127 |
+
"--result-filename",
|
| 128 |
+
type=str,
|
| 129 |
+
default=BenchArgs.result_filename,
|
| 130 |
+
help="Store the results line by line in the JSON Line format to this file.",
|
| 131 |
+
)
|
| 132 |
+
parser.add_argument(
|
| 133 |
+
"--pydantic-result-filename",
|
| 134 |
+
type=str,
|
| 135 |
+
default=BenchArgs.pydantic_result_filename,
|
| 136 |
+
help="Store the results as pydantic models in the JSON format to this file.",
|
| 137 |
+
)
|
| 138 |
+
parser.add_argument(
|
| 139 |
+
"--no-append-to-github-summary",
|
| 140 |
+
action="store_false",
|
| 141 |
+
dest="append_to_github_summary",
|
| 142 |
+
help="Disable appending the output of this run to github ci summary",
|
| 143 |
+
)
|
| 144 |
+
parser.add_argument("--seed", type=int, default=BenchArgs.seed)
|
| 145 |
+
|
| 146 |
+
@classmethod
|
| 147 |
+
def from_cli_args(cls, args: argparse.Namespace):
|
| 148 |
+
attrs = [attr.name for attr in dataclasses.fields(cls)]
|
| 149 |
+
return cls(**{attr: getattr(args, attr) for attr in attrs})
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class BenchOneCaseResult(BaseModel):
|
| 153 |
+
run_name: str
|
| 154 |
+
batch_size: int
|
| 155 |
+
input_len: int
|
| 156 |
+
output_len: int
|
| 157 |
+
latency: float
|
| 158 |
+
input_throughput: float
|
| 159 |
+
output_throughput: float
|
| 160 |
+
overall_throughput: float
|
| 161 |
+
last_ttft: float
|
| 162 |
+
last_gen_throughput: float
|
| 163 |
+
acc_length: float
|
| 164 |
+
profile_link: Optional[str] = None
|
| 165 |
+
|
| 166 |
+
def dump_to_jsonl(self, result_filename: str):
|
| 167 |
+
with open(result_filename, "a") as fout:
|
| 168 |
+
res = {
|
| 169 |
+
"run_name": self.run_name,
|
| 170 |
+
"batch_size": self.batch_size,
|
| 171 |
+
"input_len": self.input_len,
|
| 172 |
+
"output_len": self.output_len,
|
| 173 |
+
"latency": round(self.latency, 4),
|
| 174 |
+
"input_throughput": round(self.input_throughput, 2),
|
| 175 |
+
"output_throughput": round(self.output_throughput, 2),
|
| 176 |
+
"overall_throughput": round(self.overall_throughput, 2),
|
| 177 |
+
"last_ttft": round(self.last_ttft, 4),
|
| 178 |
+
"last_gen_throughput": round(self.last_gen_throughput, 2),
|
| 179 |
+
"acc_length": round(self.acc_length, 2),
|
| 180 |
+
}
|
| 181 |
+
fout.write(json.dumps(res) + "\n")
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def launch_server_internal(server_args):
|
| 185 |
+
try:
|
| 186 |
+
launch_server(server_args)
|
| 187 |
+
except Exception as e:
|
| 188 |
+
raise e
|
| 189 |
+
finally:
|
| 190 |
+
kill_process_tree(os.getpid(), include_parent=False)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def launch_server_process(server_args: ServerArgs):
|
| 194 |
+
proc = multiprocessing.Process(target=launch_server_internal, args=(server_args,))
|
| 195 |
+
proc.start()
|
| 196 |
+
base_url = f"http://{server_args.host}:{server_args.port}"
|
| 197 |
+
timeout = 600
|
| 198 |
+
|
| 199 |
+
start_time = time.time()
|
| 200 |
+
while time.time() - start_time < timeout:
|
| 201 |
+
try:
|
| 202 |
+
headers = {
|
| 203 |
+
"Content-Type": "application/json; charset=utf-8",
|
| 204 |
+
}
|
| 205 |
+
response = requests.get(f"{base_url}/v1/models", headers=headers)
|
| 206 |
+
if response.status_code == 200:
|
| 207 |
+
return proc, base_url
|
| 208 |
+
except requests.RequestException:
|
| 209 |
+
pass
|
| 210 |
+
time.sleep(10)
|
| 211 |
+
raise TimeoutError("Server failed to start within the timeout period.")
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def run_one_case(
|
| 215 |
+
url: str,
|
| 216 |
+
batch_size: int,
|
| 217 |
+
input_len: int,
|
| 218 |
+
output_len: int,
|
| 219 |
+
temperature: float,
|
| 220 |
+
return_logprob: bool,
|
| 221 |
+
stream_interval: int,
|
| 222 |
+
input_len_step_percentage: float,
|
| 223 |
+
run_name: str,
|
| 224 |
+
result_filename: str,
|
| 225 |
+
tokenizer: PreTrainedTokenizer | AutoProcessor,
|
| 226 |
+
profile: bool = False,
|
| 227 |
+
profile_steps: int = BenchArgs.profile_steps,
|
| 228 |
+
profile_by_stage: bool = False,
|
| 229 |
+
profile_prefix: Optional[str] = BenchArgs.profile_prefix,
|
| 230 |
+
profile_output_dir: Optional[str] = BenchArgs.profile_output_dir,
|
| 231 |
+
dataset_name: str = BenchArgs.dataset_name,
|
| 232 |
+
dataset_path: str = BenchArgs.dataset_path,
|
| 233 |
+
parallel_batch: bool = False,
|
| 234 |
+
):
|
| 235 |
+
requests.post(url + "/flush_cache")
|
| 236 |
+
|
| 237 |
+
# Load input token ids
|
| 238 |
+
# TODO: reuse bench_serving.get_dataset ?
|
| 239 |
+
if dataset_name == "mmmu":
|
| 240 |
+
input_requests = sample_mmmu_requests(
|
| 241 |
+
num_requests=batch_size,
|
| 242 |
+
processor=tokenizer,
|
| 243 |
+
fixed_output_len=output_len,
|
| 244 |
+
random_sample=False,
|
| 245 |
+
)
|
| 246 |
+
elif dataset_name == "random":
|
| 247 |
+
input_requests = sample_random_requests(
|
| 248 |
+
input_len=input_len,
|
| 249 |
+
output_len=output_len,
|
| 250 |
+
num_prompts=batch_size,
|
| 251 |
+
range_ratio=1.0,
|
| 252 |
+
tokenizer=tokenizer,
|
| 253 |
+
dataset_path=dataset_path,
|
| 254 |
+
random_sample=True,
|
| 255 |
+
return_text=False,
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
# Load sampling parameters
|
| 259 |
+
use_structured_outputs = False
|
| 260 |
+
if use_structured_outputs:
|
| 261 |
+
texts = []
|
| 262 |
+
for _ in range(batch_size):
|
| 263 |
+
texts.append(
|
| 264 |
+
"Human: What is the capital city of france? can you give as many trivial information as possible about that city? answer in json.\n"
|
| 265 |
+
* 50
|
| 266 |
+
+ "Assistant:"
|
| 267 |
+
)
|
| 268 |
+
json_schema = "$$ANY$$"
|
| 269 |
+
else:
|
| 270 |
+
json_schema = None
|
| 271 |
+
|
| 272 |
+
payload = {
|
| 273 |
+
"sampling_params": {
|
| 274 |
+
"temperature": temperature,
|
| 275 |
+
"max_new_tokens": output_len,
|
| 276 |
+
"ignore_eos": True,
|
| 277 |
+
"json_schema": json_schema,
|
| 278 |
+
"stream_interval": stream_interval,
|
| 279 |
+
},
|
| 280 |
+
"return_logprob": return_logprob,
|
| 281 |
+
"stream": True,
|
| 282 |
+
**({"parallel_batch": parallel_batch} if parallel_batch else {}),
|
| 283 |
+
}
|
| 284 |
+
if dataset_name == "mmmu":
|
| 285 |
+
# vlm
|
| 286 |
+
input_ids = []
|
| 287 |
+
# for vlms, tokenizer is an instance of AutoProcessor
|
| 288 |
+
tokenizer = tokenizer.tokenizer
|
| 289 |
+
for input_req in input_requests:
|
| 290 |
+
input_ids += [tokenizer.encode(input_req.prompt)]
|
| 291 |
+
payload["image_data"] = [req.image_data for req in input_requests]
|
| 292 |
+
|
| 293 |
+
else:
|
| 294 |
+
input_ids = [req.prompt for req in input_requests]
|
| 295 |
+
|
| 296 |
+
payload["input_ids"] = input_ids
|
| 297 |
+
|
| 298 |
+
# Turn on profiler
|
| 299 |
+
profile_link = None
|
| 300 |
+
if profile:
|
| 301 |
+
profile_link: str = run_profile(
|
| 302 |
+
url=url,
|
| 303 |
+
num_steps=profile_steps,
|
| 304 |
+
activities=["CPU", "GPU"],
|
| 305 |
+
output_dir=profile_output_dir,
|
| 306 |
+
profile_by_stage=profile_by_stage,
|
| 307 |
+
profile_prefix=profile_prefix,
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
# Run the request
|
| 311 |
+
tic = time.perf_counter()
|
| 312 |
+
response = requests.post(
|
| 313 |
+
url + "/generate",
|
| 314 |
+
json=payload,
|
| 315 |
+
stream=True,
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
# Get the TTFT of the last request in the batch
|
| 319 |
+
last_ttft = 0.0
|
| 320 |
+
for chunk in response.iter_lines(decode_unicode=False):
|
| 321 |
+
chunk = chunk.decode("utf-8")
|
| 322 |
+
if chunk and chunk.startswith("data:"):
|
| 323 |
+
if chunk == "data: [DONE]":
|
| 324 |
+
break
|
| 325 |
+
data = json.loads(chunk[5:].strip("\n"))
|
| 326 |
+
if "error" in data:
|
| 327 |
+
raise RuntimeError(f"Request has failed. {data}.")
|
| 328 |
+
|
| 329 |
+
assert (
|
| 330 |
+
data["meta_info"]["finish_reason"] is None
|
| 331 |
+
or data["meta_info"]["finish_reason"]["type"] == "length"
|
| 332 |
+
)
|
| 333 |
+
if data["meta_info"]["completion_tokens"] == 1:
|
| 334 |
+
last_ttft = time.perf_counter() - tic
|
| 335 |
+
|
| 336 |
+
# Compute metrics
|
| 337 |
+
latency = time.perf_counter() - tic
|
| 338 |
+
input_throughput = batch_size * input_len / last_ttft
|
| 339 |
+
output_throughput = batch_size * output_len / (latency - last_ttft)
|
| 340 |
+
overall_throughput = batch_size * (input_len + output_len) / latency
|
| 341 |
+
|
| 342 |
+
server_info = requests.get(url + "/get_server_info").json()
|
| 343 |
+
internal_state = server_info.get("internal_states", [{}])
|
| 344 |
+
last_gen_throughput = internal_state[0].get("last_gen_throughput", None) or -1
|
| 345 |
+
acc_length = internal_state[0].get("avg_spec_accept_length", None) or -1
|
| 346 |
+
|
| 347 |
+
# Print results
|
| 348 |
+
print(f"batch size: {batch_size}")
|
| 349 |
+
print(f"input_len: {input_len}")
|
| 350 |
+
print(f"output_len: {output_len}")
|
| 351 |
+
print(f"latency: {latency:.2f} s")
|
| 352 |
+
print(f"input throughput: {input_throughput:.2f} tok/s")
|
| 353 |
+
if output_len != 1:
|
| 354 |
+
print(f"output throughput: {output_throughput:.2f} tok/s")
|
| 355 |
+
print(f"last_ttft: {last_ttft:.2f} s")
|
| 356 |
+
print(f"last generation throughput: {last_gen_throughput:.2f} tok/s")
|
| 357 |
+
if acc_length > 0:
|
| 358 |
+
print(f"acc_length: {acc_length:.2f} ")
|
| 359 |
+
|
| 360 |
+
# Dump results
|
| 361 |
+
result = BenchOneCaseResult(
|
| 362 |
+
run_name=run_name,
|
| 363 |
+
batch_size=batch_size,
|
| 364 |
+
input_len=input_len,
|
| 365 |
+
output_len=output_len,
|
| 366 |
+
latency=latency,
|
| 367 |
+
input_throughput=input_throughput,
|
| 368 |
+
output_throughput=output_throughput,
|
| 369 |
+
overall_throughput=overall_throughput,
|
| 370 |
+
last_ttft=last_ttft,
|
| 371 |
+
last_gen_throughput=last_gen_throughput,
|
| 372 |
+
acc_length=acc_length,
|
| 373 |
+
profile_link=profile_link,
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
# Save and return the results
|
| 377 |
+
if result_filename:
|
| 378 |
+
result.dump_to_jsonl(result_filename)
|
| 379 |
+
|
| 380 |
+
return result
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def should_skip_due_to_token_capacity(
|
| 384 |
+
batch_size, input_len, output_len, skip_token_capacity_threshold
|
| 385 |
+
):
|
| 386 |
+
if batch_size * (input_len + output_len) > skip_token_capacity_threshold:
|
| 387 |
+
print(
|
| 388 |
+
"=" * 8
|
| 389 |
+
+ f"Skip benchmark {batch_size=} * ({input_len=} + {output_len=}) = {batch_size * (input_len + output_len)} > {skip_token_capacity_threshold=} due to kv cache limit."
|
| 390 |
+
+ "=" * 8
|
| 391 |
+
)
|
| 392 |
+
return True
|
| 393 |
+
return False
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
def get_report_summary(
|
| 397 |
+
results: List[BenchOneCaseResult], bench_args: BenchArgs, server_args: ServerArgs
|
| 398 |
+
):
|
| 399 |
+
summary = (
|
| 400 |
+
f"\nInput lens: {bench_args.input_len}. Output lens: {bench_args.output_len}.\n"
|
| 401 |
+
)
|
| 402 |
+
summary += "| batch size | input len | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input cost ($/1M) | output cost ($/1M) |"
|
| 403 |
+
|
| 404 |
+
if bench_args.profile:
|
| 405 |
+
summary += " profile |"
|
| 406 |
+
|
| 407 |
+
summary += "\n"
|
| 408 |
+
summary += "| ---------- | --------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ----------------- | ------------------ |"
|
| 409 |
+
|
| 410 |
+
if bench_args.profile:
|
| 411 |
+
summary += "-------------|"
|
| 412 |
+
summary += "\n"
|
| 413 |
+
|
| 414 |
+
if is_blackwell():
|
| 415 |
+
hourly_cost_per_gpu = 4 # $4/hour for one B200
|
| 416 |
+
else:
|
| 417 |
+
hourly_cost_per_gpu = 2 # $2/hour for one H100
|
| 418 |
+
input_util = 0.7
|
| 419 |
+
|
| 420 |
+
# sort result by input_len
|
| 421 |
+
results.sort(key=lambda x: x.input_len)
|
| 422 |
+
for res in results:
|
| 423 |
+
hourly_cost = hourly_cost_per_gpu * server_args.tp_size
|
| 424 |
+
accept_length = round(res.acc_length, 2) if res.acc_length > 0 else "n/a"
|
| 425 |
+
line = (
|
| 426 |
+
f"| {res.batch_size} | "
|
| 427 |
+
f"{res.input_len} | "
|
| 428 |
+
f"{res.latency:.2f} | "
|
| 429 |
+
f"{res.input_throughput:.2f} | "
|
| 430 |
+
f"{res.output_throughput:.2f} | "
|
| 431 |
+
f"{accept_length} | "
|
| 432 |
+
f"{1 / (res.output_throughput/res.batch_size) * 1000:.2f} | "
|
| 433 |
+
f"{1e6 / (res.input_throughput * input_util) / 3600 * hourly_cost:.2f} | "
|
| 434 |
+
f"{1e6 / res.output_throughput / 3600 * hourly_cost:.2f} |"
|
| 435 |
+
)
|
| 436 |
+
if bench_args.profile:
|
| 437 |
+
if res.profile_link:
|
| 438 |
+
line += f" [Profile]({res.profile_link}) |"
|
| 439 |
+
else:
|
| 440 |
+
line += f" n/a |"
|
| 441 |
+
line += "\n"
|
| 442 |
+
summary += line
|
| 443 |
+
|
| 444 |
+
return summary
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
|
| 448 |
+
if bench_args.base_url:
|
| 449 |
+
proc, base_url = None, bench_args.base_url
|
| 450 |
+
else:
|
| 451 |
+
proc, base_url = launch_server_process(server_args)
|
| 452 |
+
|
| 453 |
+
# Get tokenizer
|
| 454 |
+
server_info = requests.get(base_url + "/get_server_info").json()
|
| 455 |
+
if "tokenizer_path" in server_info:
|
| 456 |
+
tokenizer_path = server_info["tokenizer_path"]
|
| 457 |
+
elif "prefill" in server_info:
|
| 458 |
+
tokenizer_path = server_info["prefill"][0]["tokenizer_path"]
|
| 459 |
+
if bench_args.dataset_name == "mmmu":
|
| 460 |
+
# mmmu implies this is a MLLM
|
| 461 |
+
tokenizer = get_processor(tokenizer_path)
|
| 462 |
+
else:
|
| 463 |
+
tokenizer = get_tokenizer(tokenizer_path)
|
| 464 |
+
|
| 465 |
+
# Get token capacity
|
| 466 |
+
internal_state = server_info.get("internal_states", [{}])
|
| 467 |
+
skip_token_capacity_threshold = (
|
| 468 |
+
internal_state[0].get("memory_usage", {}).get("token_capacity", 1000000000)
|
| 469 |
+
)
|
| 470 |
+
|
| 471 |
+
# Warmup
|
| 472 |
+
if not bench_args.skip_warmup:
|
| 473 |
+
print("=" * 8 + " Warmup Begin " + "=" * 8)
|
| 474 |
+
print(f"Warmup with batch_size={bench_args.batch_size}")
|
| 475 |
+
for bs in bench_args.batch_size:
|
| 476 |
+
run_one_case(
|
| 477 |
+
base_url,
|
| 478 |
+
batch_size=bs,
|
| 479 |
+
input_len=1024,
|
| 480 |
+
output_len=16,
|
| 481 |
+
temperature=bench_args.temperature,
|
| 482 |
+
return_logprob=bench_args.return_logprob,
|
| 483 |
+
stream_interval=bench_args.client_stream_interval,
|
| 484 |
+
input_len_step_percentage=bench_args.input_len_step_percentage,
|
| 485 |
+
run_name="",
|
| 486 |
+
result_filename="",
|
| 487 |
+
tokenizer=tokenizer,
|
| 488 |
+
dataset_name=bench_args.dataset_name,
|
| 489 |
+
dataset_path=bench_args.dataset_path,
|
| 490 |
+
parallel_batch=bench_args.parallel_batch,
|
| 491 |
+
)
|
| 492 |
+
print("=" * 8 + " Warmup End " + "=" * 8 + "\n")
|
| 493 |
+
|
| 494 |
+
results = []
|
| 495 |
+
profile_results = []
|
| 496 |
+
try:
|
| 497 |
+
# Benchmark all cases
|
| 498 |
+
for bs, il, ol in itertools.product(
|
| 499 |
+
bench_args.batch_size, bench_args.input_len, bench_args.output_len
|
| 500 |
+
):
|
| 501 |
+
if should_skip_due_to_token_capacity(
|
| 502 |
+
bs, il, ol, skip_token_capacity_threshold
|
| 503 |
+
):
|
| 504 |
+
continue
|
| 505 |
+
results.append(
|
| 506 |
+
run_one_case(
|
| 507 |
+
base_url,
|
| 508 |
+
bs,
|
| 509 |
+
il,
|
| 510 |
+
ol,
|
| 511 |
+
temperature=bench_args.temperature,
|
| 512 |
+
return_logprob=bench_args.return_logprob,
|
| 513 |
+
stream_interval=bench_args.client_stream_interval,
|
| 514 |
+
input_len_step_percentage=bench_args.input_len_step_percentage,
|
| 515 |
+
run_name=bench_args.run_name,
|
| 516 |
+
result_filename=bench_args.result_filename,
|
| 517 |
+
tokenizer=tokenizer,
|
| 518 |
+
dataset_name=bench_args.dataset_name,
|
| 519 |
+
dataset_path=bench_args.dataset_path,
|
| 520 |
+
parallel_batch=bench_args.parallel_batch,
|
| 521 |
+
)
|
| 522 |
+
)
|
| 523 |
+
|
| 524 |
+
# Profile all cases
|
| 525 |
+
if bench_args.profile:
|
| 526 |
+
try:
|
| 527 |
+
for bs, il, ol in itertools.product(
|
| 528 |
+
bench_args.batch_size, bench_args.input_len, bench_args.output_len
|
| 529 |
+
):
|
| 530 |
+
if should_skip_due_to_token_capacity(
|
| 531 |
+
bs, il, ol, skip_token_capacity_threshold
|
| 532 |
+
):
|
| 533 |
+
continue
|
| 534 |
+
profile_prefix = (
|
| 535 |
+
bench_args.profile_prefix or ""
|
| 536 |
+
) + f"bs-{bs}-il-{il}"
|
| 537 |
+
profile_results.append(
|
| 538 |
+
run_one_case(
|
| 539 |
+
base_url,
|
| 540 |
+
bs,
|
| 541 |
+
il,
|
| 542 |
+
ol,
|
| 543 |
+
temperature=bench_args.temperature,
|
| 544 |
+
return_logprob=bench_args.return_logprob,
|
| 545 |
+
stream_interval=bench_args.client_stream_interval,
|
| 546 |
+
input_len_step_percentage=bench_args.input_len_step_percentage,
|
| 547 |
+
run_name=bench_args.run_name,
|
| 548 |
+
result_filename=bench_args.result_filename,
|
| 549 |
+
tokenizer=tokenizer,
|
| 550 |
+
dataset_name=bench_args.dataset_name,
|
| 551 |
+
dataset_path=bench_args.dataset_path,
|
| 552 |
+
parallel_batch=bench_args.parallel_batch,
|
| 553 |
+
profile=bench_args.profile,
|
| 554 |
+
profile_steps=bench_args.profile_steps,
|
| 555 |
+
profile_by_stage=bench_args.profile_by_stage,
|
| 556 |
+
profile_prefix=profile_prefix,
|
| 557 |
+
profile_output_dir=bench_args.profile_output_dir,
|
| 558 |
+
)
|
| 559 |
+
)
|
| 560 |
+
|
| 561 |
+
# Replace the profile link
|
| 562 |
+
for res, profile_res in zip(results, profile_results):
|
| 563 |
+
res.profile_link = profile_res.profile_link
|
| 564 |
+
except Exception as e:
|
| 565 |
+
print(f"Error profiling, there will be no profile trace dump: {e}")
|
| 566 |
+
finally:
|
| 567 |
+
if proc:
|
| 568 |
+
kill_process_tree(proc.pid)
|
| 569 |
+
|
| 570 |
+
print(f"\nResults are saved to {bench_args.result_filename}")
|
| 571 |
+
|
| 572 |
+
if not bench_args.show_report:
|
| 573 |
+
return
|
| 574 |
+
|
| 575 |
+
# Print summary
|
| 576 |
+
summary = get_report_summary(results, bench_args, server_args)
|
| 577 |
+
print(summary)
|
| 578 |
+
|
| 579 |
+
if is_in_ci() and bench_args.append_to_github_summary:
|
| 580 |
+
write_github_step_summary(summary)
|
| 581 |
+
else:
|
| 582 |
+
print(summary)
|
| 583 |
+
|
| 584 |
+
# Save results as pydantic models in the JSON format
|
| 585 |
+
if bench_args.pydantic_result_filename:
|
| 586 |
+
save_results_as_pydantic_models(
|
| 587 |
+
results,
|
| 588 |
+
pydantic_result_filename=bench_args.pydantic_result_filename,
|
| 589 |
+
model_path=server_args.model_path,
|
| 590 |
+
)
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
if __name__ == "__main__":
|
| 594 |
+
parser = argparse.ArgumentParser()
|
| 595 |
+
ServerArgs.add_cli_args(parser)
|
| 596 |
+
BenchArgs.add_cli_args(parser)
|
| 597 |
+
args = parser.parse_args()
|
| 598 |
+
|
| 599 |
+
random.seed(args.seed)
|
| 600 |
+
np.random.seed(args.seed)
|
| 601 |
+
|
| 602 |
+
server_args = ServerArgs.from_cli_args(args)
|
| 603 |
+
bench_args = BenchArgs.from_cli_args(args)
|
| 604 |
+
|
| 605 |
+
run_benchmark(server_args, bench_args)
|
sglang/bench_serving.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
sglang/check_env.py
ADDED
|
@@ -0,0 +1,433 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Check environment configurations and dependency versions."""
|
| 2 |
+
|
| 3 |
+
import importlib.metadata
|
| 4 |
+
import os
|
| 5 |
+
import resource
|
| 6 |
+
import subprocess
|
| 7 |
+
import sys
|
| 8 |
+
from abc import abstractmethod
|
| 9 |
+
from collections import OrderedDict, defaultdict
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
from sglang.srt.utils import is_hip, is_npu
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def is_cuda_v2():
|
| 17 |
+
return torch.version.cuda is not None
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# List of packages to check versions
|
| 21 |
+
PACKAGE_LIST = [
|
| 22 |
+
"sglang",
|
| 23 |
+
"sgl_kernel",
|
| 24 |
+
"flashinfer_python",
|
| 25 |
+
"flashinfer_cubin",
|
| 26 |
+
"flashinfer_jit_cache",
|
| 27 |
+
"triton",
|
| 28 |
+
"transformers",
|
| 29 |
+
"torchao",
|
| 30 |
+
"numpy",
|
| 31 |
+
"aiohttp",
|
| 32 |
+
"fastapi",
|
| 33 |
+
"hf_transfer",
|
| 34 |
+
"huggingface_hub",
|
| 35 |
+
"interegular",
|
| 36 |
+
"modelscope",
|
| 37 |
+
"orjson",
|
| 38 |
+
"outlines",
|
| 39 |
+
"packaging",
|
| 40 |
+
"psutil",
|
| 41 |
+
"pydantic",
|
| 42 |
+
"python-multipart",
|
| 43 |
+
"pyzmq",
|
| 44 |
+
"torchao",
|
| 45 |
+
"uvicorn",
|
| 46 |
+
"uvloop",
|
| 47 |
+
"vllm",
|
| 48 |
+
"xgrammar",
|
| 49 |
+
"openai",
|
| 50 |
+
"tiktoken",
|
| 51 |
+
"anthropic",
|
| 52 |
+
"litellm",
|
| 53 |
+
"decord2",
|
| 54 |
+
]
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class BaseEnv:
|
| 58 |
+
"""Base class for environment check"""
|
| 59 |
+
|
| 60 |
+
def __init__(self):
|
| 61 |
+
self.package_list = PACKAGE_LIST
|
| 62 |
+
|
| 63 |
+
@abstractmethod
|
| 64 |
+
def get_info(self) -> dict:
|
| 65 |
+
"""
|
| 66 |
+
Get CUDA-related information if available.
|
| 67 |
+
"""
|
| 68 |
+
raise NotImplementedError
|
| 69 |
+
|
| 70 |
+
@abstractmethod
|
| 71 |
+
def get_topology(self) -> dict:
|
| 72 |
+
raise NotImplementedError
|
| 73 |
+
|
| 74 |
+
def get_package_versions(self) -> dict:
|
| 75 |
+
"""
|
| 76 |
+
Get versions of specified packages.
|
| 77 |
+
"""
|
| 78 |
+
versions = {}
|
| 79 |
+
for package in self.package_list:
|
| 80 |
+
package_name = package.split("==")[0].split(">=")[0].split("<=")[0]
|
| 81 |
+
try:
|
| 82 |
+
version = importlib.metadata.version(package_name)
|
| 83 |
+
versions[package_name] = version
|
| 84 |
+
except ModuleNotFoundError:
|
| 85 |
+
versions[package_name] = "Module Not Found"
|
| 86 |
+
return versions
|
| 87 |
+
|
| 88 |
+
def get_device_info(self):
|
| 89 |
+
"""
|
| 90 |
+
Get information about available GPU devices.
|
| 91 |
+
"""
|
| 92 |
+
devices = defaultdict(list)
|
| 93 |
+
capabilities = defaultdict(list)
|
| 94 |
+
for k in range(torch.cuda.device_count()):
|
| 95 |
+
devices[torch.cuda.get_device_name(k)].append(str(k))
|
| 96 |
+
capability = torch.cuda.get_device_capability(k)
|
| 97 |
+
capabilities[f"{capability[0]}.{capability[1]}"].append(str(k))
|
| 98 |
+
|
| 99 |
+
gpu_info = {}
|
| 100 |
+
for name, device_ids in devices.items():
|
| 101 |
+
gpu_info[f"GPU {','.join(device_ids)}"] = name
|
| 102 |
+
|
| 103 |
+
if len(capabilities) == 1:
|
| 104 |
+
# All GPUs have the same compute capability
|
| 105 |
+
cap, gpu_ids = list(capabilities.items())[0]
|
| 106 |
+
gpu_info[f"GPU {','.join(gpu_ids)} Compute Capability"] = cap
|
| 107 |
+
else:
|
| 108 |
+
# GPUs have different compute capabilities
|
| 109 |
+
for cap, gpu_ids in capabilities.items():
|
| 110 |
+
gpu_info[f"GPU {','.join(gpu_ids)} Compute Capability"] = cap
|
| 111 |
+
|
| 112 |
+
return gpu_info
|
| 113 |
+
|
| 114 |
+
def get_hypervisor_vendor(self) -> dict:
|
| 115 |
+
try:
|
| 116 |
+
output = subprocess.check_output(["lscpu"], text=True)
|
| 117 |
+
for line in output.split("\n"):
|
| 118 |
+
if "Hypervisor vendor:" in line:
|
| 119 |
+
return {"Hypervisor vendor:": line.split(":")[1].strip()}
|
| 120 |
+
return {}
|
| 121 |
+
except:
|
| 122 |
+
return {}
|
| 123 |
+
|
| 124 |
+
def get_ulimit_soft(self) -> dict:
|
| 125 |
+
ulimit_soft, _ = resource.getrlimit(resource.RLIMIT_NOFILE)
|
| 126 |
+
return {"ulimit soft": ulimit_soft}
|
| 127 |
+
|
| 128 |
+
def check_env(self):
|
| 129 |
+
"""
|
| 130 |
+
Check and print environment information.
|
| 131 |
+
"""
|
| 132 |
+
env_info = OrderedDict()
|
| 133 |
+
env_info["Python"] = sys.version.replace("\n", "")
|
| 134 |
+
env_info.update(self.get_info())
|
| 135 |
+
env_info["PyTorch"] = torch.__version__
|
| 136 |
+
env_info.update(self.get_package_versions())
|
| 137 |
+
env_info.update(self.get_topology())
|
| 138 |
+
env_info.update(self.get_hypervisor_vendor())
|
| 139 |
+
env_info.update(self.get_ulimit_soft())
|
| 140 |
+
|
| 141 |
+
for k, v in env_info.items():
|
| 142 |
+
print(f"{k}: {v}")
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class GPUEnv(BaseEnv):
|
| 146 |
+
"""Environment checker for Nvidia GPU"""
|
| 147 |
+
|
| 148 |
+
def get_info(self):
|
| 149 |
+
cuda_info = {"CUDA available": torch.cuda.is_available()}
|
| 150 |
+
|
| 151 |
+
if cuda_info["CUDA available"]:
|
| 152 |
+
cuda_info.update(self.get_device_info())
|
| 153 |
+
cuda_info.update(self._get_cuda_version_info())
|
| 154 |
+
|
| 155 |
+
return cuda_info
|
| 156 |
+
|
| 157 |
+
def _get_cuda_version_info(self):
|
| 158 |
+
"""
|
| 159 |
+
Get CUDA version information.
|
| 160 |
+
"""
|
| 161 |
+
from torch.utils.cpp_extension import CUDA_HOME
|
| 162 |
+
|
| 163 |
+
cuda_info = {"CUDA_HOME": CUDA_HOME}
|
| 164 |
+
|
| 165 |
+
if CUDA_HOME and os.path.isdir(CUDA_HOME):
|
| 166 |
+
cuda_info.update(self._get_nvcc_info())
|
| 167 |
+
cuda_info.update(self._get_cuda_driver_version())
|
| 168 |
+
|
| 169 |
+
return cuda_info
|
| 170 |
+
|
| 171 |
+
def _get_nvcc_info(self):
|
| 172 |
+
"""
|
| 173 |
+
Get NVCC version information.
|
| 174 |
+
"""
|
| 175 |
+
from torch.utils.cpp_extension import CUDA_HOME
|
| 176 |
+
|
| 177 |
+
try:
|
| 178 |
+
nvcc = os.path.join(CUDA_HOME, "bin/nvcc")
|
| 179 |
+
nvcc_output = (
|
| 180 |
+
subprocess.check_output(f'"{nvcc}" -V', shell=True)
|
| 181 |
+
.decode("utf-8")
|
| 182 |
+
.strip()
|
| 183 |
+
)
|
| 184 |
+
return {
|
| 185 |
+
"NVCC": nvcc_output[
|
| 186 |
+
nvcc_output.rfind("Cuda compilation tools") : nvcc_output.rfind(
|
| 187 |
+
"Build"
|
| 188 |
+
)
|
| 189 |
+
].strip()
|
| 190 |
+
}
|
| 191 |
+
except subprocess.SubprocessError:
|
| 192 |
+
return {"NVCC": "Not Available"}
|
| 193 |
+
|
| 194 |
+
def _get_cuda_driver_version(self):
|
| 195 |
+
"""
|
| 196 |
+
Get CUDA driver version.
|
| 197 |
+
"""
|
| 198 |
+
versions = set()
|
| 199 |
+
try:
|
| 200 |
+
output = subprocess.check_output(
|
| 201 |
+
[
|
| 202 |
+
"nvidia-smi",
|
| 203 |
+
"--query-gpu=driver_version",
|
| 204 |
+
"--format=csv,noheader,nounits",
|
| 205 |
+
]
|
| 206 |
+
)
|
| 207 |
+
versions = set(output.decode().strip().split("\n"))
|
| 208 |
+
if len(versions) == 1:
|
| 209 |
+
return {"CUDA Driver Version": versions.pop()}
|
| 210 |
+
else:
|
| 211 |
+
return {"CUDA Driver Versions": ", ".join(sorted(versions))}
|
| 212 |
+
except subprocess.SubprocessError:
|
| 213 |
+
return {"CUDA Driver Version": "Not Available"}
|
| 214 |
+
|
| 215 |
+
def get_topology(self):
|
| 216 |
+
"""
|
| 217 |
+
Get GPU topology information.
|
| 218 |
+
"""
|
| 219 |
+
try:
|
| 220 |
+
result = subprocess.run(
|
| 221 |
+
["nvidia-smi", "topo", "-m"],
|
| 222 |
+
stdout=subprocess.PIPE,
|
| 223 |
+
stderr=subprocess.PIPE,
|
| 224 |
+
text=True,
|
| 225 |
+
check=True,
|
| 226 |
+
)
|
| 227 |
+
return {
|
| 228 |
+
"NVIDIA Topology": (
|
| 229 |
+
"\n" + result.stdout if result.returncode == 0 else None
|
| 230 |
+
)
|
| 231 |
+
}
|
| 232 |
+
except subprocess.SubprocessError:
|
| 233 |
+
return {}
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class HIPEnv(BaseEnv):
|
| 237 |
+
"""Environment checker for ROCm/HIP"""
|
| 238 |
+
|
| 239 |
+
def get_info(self):
|
| 240 |
+
cuda_info = {"ROCM available": torch.cuda.is_available()}
|
| 241 |
+
|
| 242 |
+
if cuda_info["ROCM available"]:
|
| 243 |
+
cuda_info.update(self.get_device_info())
|
| 244 |
+
cuda_info.update(self._get_cuda_version_info())
|
| 245 |
+
|
| 246 |
+
return cuda_info
|
| 247 |
+
|
| 248 |
+
def _get_cuda_version_info(self):
|
| 249 |
+
from torch.utils.cpp_extension import ROCM_HOME as ROCM_HOME
|
| 250 |
+
|
| 251 |
+
cuda_info = {"ROCM_HOME": ROCM_HOME}
|
| 252 |
+
|
| 253 |
+
if ROCM_HOME and os.path.isdir(ROCM_HOME):
|
| 254 |
+
cuda_info.update(self._get_hipcc_info())
|
| 255 |
+
cuda_info.update(self._get_rocm_driver_version())
|
| 256 |
+
|
| 257 |
+
return cuda_info
|
| 258 |
+
|
| 259 |
+
def _get_hipcc_info(self):
|
| 260 |
+
from torch.utils.cpp_extension import ROCM_HOME
|
| 261 |
+
|
| 262 |
+
try:
|
| 263 |
+
hipcc = os.path.join(ROCM_HOME, "bin/hipcc")
|
| 264 |
+
hipcc_output = (
|
| 265 |
+
subprocess.check_output(f'"{hipcc}" --version', shell=True)
|
| 266 |
+
.decode("utf-8")
|
| 267 |
+
.strip()
|
| 268 |
+
)
|
| 269 |
+
return {
|
| 270 |
+
"HIPCC": hipcc_output[
|
| 271 |
+
hipcc_output.rfind("HIP version") : hipcc_output.rfind("AMD clang")
|
| 272 |
+
].strip()
|
| 273 |
+
}
|
| 274 |
+
except subprocess.SubprocessError:
|
| 275 |
+
return {"HIPCC": "Not Available"}
|
| 276 |
+
|
| 277 |
+
def _get_rocm_driver_version(self):
|
| 278 |
+
try:
|
| 279 |
+
output = subprocess.check_output(
|
| 280 |
+
[
|
| 281 |
+
"rocm-smi",
|
| 282 |
+
"--showdriverversion",
|
| 283 |
+
"--csv",
|
| 284 |
+
]
|
| 285 |
+
)
|
| 286 |
+
versions = set(output.decode().strip().split("\n"))
|
| 287 |
+
versions.discard("name, value")
|
| 288 |
+
ver = versions.pop()
|
| 289 |
+
ver = ver.replace('"Driver version", ', "").replace('"', "")
|
| 290 |
+
|
| 291 |
+
return {"ROCM Driver Version": ver}
|
| 292 |
+
except subprocess.SubprocessError:
|
| 293 |
+
return {"ROCM Driver Version": "Not Available"}
|
| 294 |
+
|
| 295 |
+
def get_topology(self):
|
| 296 |
+
try:
|
| 297 |
+
result = subprocess.run(
|
| 298 |
+
["rocm-smi", "--showtopotype"],
|
| 299 |
+
stdout=subprocess.PIPE,
|
| 300 |
+
stderr=subprocess.PIPE,
|
| 301 |
+
text=True,
|
| 302 |
+
check=True,
|
| 303 |
+
)
|
| 304 |
+
return {
|
| 305 |
+
"AMD Topology": "\n" + result.stdout if result.returncode == 0 else None
|
| 306 |
+
}
|
| 307 |
+
except subprocess.SubprocessError:
|
| 308 |
+
return {}
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
class NPUEnv(BaseEnv):
|
| 312 |
+
"""Environment checker for Ascend NPU"""
|
| 313 |
+
|
| 314 |
+
EXTRA_PACKAGE_LIST = [
|
| 315 |
+
"torch_npu",
|
| 316 |
+
"sgl-kernel-npu",
|
| 317 |
+
"deep_ep",
|
| 318 |
+
]
|
| 319 |
+
|
| 320 |
+
def __init__(self):
|
| 321 |
+
super().__init__()
|
| 322 |
+
self.package_list.extend(NPUEnv.EXTRA_PACKAGE_LIST)
|
| 323 |
+
|
| 324 |
+
def get_info(self):
|
| 325 |
+
cuda_info = {"NPU available": torch.npu.is_available()}
|
| 326 |
+
if cuda_info["NPU available"]:
|
| 327 |
+
cuda_info.update(self.get_device_info())
|
| 328 |
+
cuda_info.update(self._get_cann_version_info())
|
| 329 |
+
|
| 330 |
+
return cuda_info
|
| 331 |
+
|
| 332 |
+
def get_device_info(self):
|
| 333 |
+
"""
|
| 334 |
+
Get information about available NPUs.
|
| 335 |
+
Need to override due to torch_npu interface differences.
|
| 336 |
+
"""
|
| 337 |
+
devices = defaultdict(list)
|
| 338 |
+
for k in range(torch.npu.device_count()):
|
| 339 |
+
devices[torch.npu.get_device_name(k)].append(str(k))
|
| 340 |
+
|
| 341 |
+
npu_info = {}
|
| 342 |
+
for name, device_ids in devices.items():
|
| 343 |
+
npu_info[f"NPU {','.join(device_ids)}"] = name
|
| 344 |
+
|
| 345 |
+
return npu_info
|
| 346 |
+
|
| 347 |
+
def _get_cann_version_info(self):
|
| 348 |
+
cann_envs = ["ASCEND_TOOLKIT_HOME", "ASCEND_INSTALL_PATH"]
|
| 349 |
+
for var in cann_envs:
|
| 350 |
+
path = os.environ.get(var)
|
| 351 |
+
if path and os.path.exists(path):
|
| 352 |
+
CANN_HOME = path
|
| 353 |
+
break
|
| 354 |
+
else:
|
| 355 |
+
default_path = "/usr/local/Ascend/ascend-toolkit/latest"
|
| 356 |
+
CANN_HOME = default_path if os.path.exists(default_path) else None
|
| 357 |
+
|
| 358 |
+
if CANN_HOME:
|
| 359 |
+
npu_info = {"CANN_HOME": CANN_HOME}
|
| 360 |
+
npu_info.update(self._get_cann_info(CANN_HOME))
|
| 361 |
+
npu_info.update(self._get_ascend_driver_version())
|
| 362 |
+
return npu_info
|
| 363 |
+
else:
|
| 364 |
+
return {"CANN_HOME": "Not found"}
|
| 365 |
+
|
| 366 |
+
def _get_cann_info(self, CANN_HOME: str):
|
| 367 |
+
cann_info = {}
|
| 368 |
+
cann_version_file = os.path.join(CANN_HOME, "version.cfg")
|
| 369 |
+
if os.path.exists(cann_version_file):
|
| 370 |
+
with open(cann_version_file, "r", encoding="utf-8") as f:
|
| 371 |
+
f.readline() # discard first line comment in version.cfg
|
| 372 |
+
cann_info["CANN"] = f.readline().split("[")[1].split("]")[0]
|
| 373 |
+
else:
|
| 374 |
+
cann_info["CANN"] = "Not Available"
|
| 375 |
+
try:
|
| 376 |
+
bisheng = os.path.join(CANN_HOME, "compiler/ccec_compiler/bin/bisheng")
|
| 377 |
+
bisheng_output = (
|
| 378 |
+
subprocess.check_output([bisheng, "--version"]).decode("utf-8").strip()
|
| 379 |
+
)
|
| 380 |
+
cann_info["BiSheng"] = bisheng_output.split("\n")[0].strip()
|
| 381 |
+
except subprocess.SubprocessError:
|
| 382 |
+
cann_info["BiSheng"] = "Not Available"
|
| 383 |
+
return cann_info
|
| 384 |
+
|
| 385 |
+
def _get_ascend_driver_version(self):
|
| 386 |
+
try:
|
| 387 |
+
output = subprocess.check_output(
|
| 388 |
+
[
|
| 389 |
+
"npu-smi",
|
| 390 |
+
"info",
|
| 391 |
+
"-t",
|
| 392 |
+
"board",
|
| 393 |
+
"-i",
|
| 394 |
+
"0",
|
| 395 |
+
]
|
| 396 |
+
)
|
| 397 |
+
for line in output.decode().strip().split("\n"):
|
| 398 |
+
if "Software Version" in line:
|
| 399 |
+
version = line.split(":")[-1].strip()
|
| 400 |
+
break
|
| 401 |
+
else:
|
| 402 |
+
version = "Not Available"
|
| 403 |
+
|
| 404 |
+
return {"Ascend Driver Version": version}
|
| 405 |
+
except subprocess.SubprocessError:
|
| 406 |
+
return {"Ascend Driver Version": "Not Available"}
|
| 407 |
+
|
| 408 |
+
def get_topology(self):
|
| 409 |
+
try:
|
| 410 |
+
result = subprocess.run(
|
| 411 |
+
["npu-smi", "info", "-t", "topo"],
|
| 412 |
+
stdout=subprocess.PIPE,
|
| 413 |
+
stderr=subprocess.PIPE,
|
| 414 |
+
text=True,
|
| 415 |
+
check=True,
|
| 416 |
+
)
|
| 417 |
+
return {
|
| 418 |
+
"Ascend Topology": (
|
| 419 |
+
"\n" + result.stdout if result.returncode == 0 else None
|
| 420 |
+
)
|
| 421 |
+
}
|
| 422 |
+
except subprocess.SubprocessError:
|
| 423 |
+
return {}
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
if __name__ == "__main__":
|
| 427 |
+
if is_cuda_v2():
|
| 428 |
+
env = GPUEnv()
|
| 429 |
+
elif is_hip():
|
| 430 |
+
env = HIPEnv()
|
| 431 |
+
elif is_npu():
|
| 432 |
+
env = NPUEnv()
|
| 433 |
+
env.check_env()
|
sglang/cli/__init__.py
ADDED
|
File without changes
|
sglang/cli/generate.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
from sglang.cli.utils import get_is_diffusion_model, get_model_path
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def generate(args, extra_argv):
|
| 7 |
+
# If help is requested, show generate subcommand help without requiring --model-path
|
| 8 |
+
if any(h in extra_argv for h in ("-h", "--help")):
|
| 9 |
+
from sglang.multimodal_gen.runtime.entrypoints.cli.generate import (
|
| 10 |
+
add_multimodal_gen_generate_args,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
parser = argparse.ArgumentParser(description="SGLang Multimodal Generation")
|
| 14 |
+
add_multimodal_gen_generate_args(parser)
|
| 15 |
+
parser.parse_args(extra_argv)
|
| 16 |
+
return
|
| 17 |
+
|
| 18 |
+
model_path = get_model_path(extra_argv)
|
| 19 |
+
is_diffusion_model = get_is_diffusion_model(model_path)
|
| 20 |
+
if is_diffusion_model:
|
| 21 |
+
from sglang.multimodal_gen.runtime.entrypoints.cli.generate import (
|
| 22 |
+
add_multimodal_gen_generate_args,
|
| 23 |
+
generate_cmd,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
parser = argparse.ArgumentParser(description="SGLang Multimodal Generation")
|
| 27 |
+
add_multimodal_gen_generate_args(parser)
|
| 28 |
+
parsed_args = parser.parse_args(extra_argv)
|
| 29 |
+
generate_cmd(parsed_args)
|
| 30 |
+
else:
|
| 31 |
+
raise Exception(
|
| 32 |
+
f"Generate subcommand is not yet supported for model: {model_path}"
|
| 33 |
+
)
|
sglang/cli/main.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
from sglang.cli.generate import generate
|
| 4 |
+
from sglang.cli.serve import serve
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def main():
|
| 8 |
+
parser = argparse.ArgumentParser()
|
| 9 |
+
subparsers = parser.add_subparsers(dest="subcommand", required=True)
|
| 10 |
+
|
| 11 |
+
serve_parser = subparsers.add_parser(
|
| 12 |
+
"serve",
|
| 13 |
+
help="Launch the SGLang server.",
|
| 14 |
+
add_help=False, # Defer help to the specific parser
|
| 15 |
+
)
|
| 16 |
+
serve_parser.set_defaults(func=serve)
|
| 17 |
+
|
| 18 |
+
generate_parser = subparsers.add_parser(
|
| 19 |
+
"generate",
|
| 20 |
+
help="Run inference on a multimodal model.",
|
| 21 |
+
add_help=False, # Defer help to the specific parser
|
| 22 |
+
)
|
| 23 |
+
generate_parser.set_defaults(func=generate)
|
| 24 |
+
|
| 25 |
+
args, extra_argv = parser.parse_known_args()
|
| 26 |
+
args.func(args, extra_argv)
|
sglang/cli/serve.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
from sglang.cli.utils import get_is_diffusion_model, get_model_path
|
| 8 |
+
from sglang.srt.utils import kill_process_tree
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def serve(args, extra_argv):
|
| 14 |
+
if any(h in extra_argv for h in ("-h", "--help")):
|
| 15 |
+
# Since the server type is determined by the model, and we don't have a model path,
|
| 16 |
+
# we can't show the exact help. Instead, we show a general help message and then
|
| 17 |
+
# the help for both possible server types.
|
| 18 |
+
print(
|
| 19 |
+
"Usage: sglang serve --model-path <model-name-or-path> [additional-arguments]\n"
|
| 20 |
+
)
|
| 21 |
+
print(
|
| 22 |
+
"This command can launch either a standard language model server or a diffusion model server."
|
| 23 |
+
)
|
| 24 |
+
print("The server type is determined by the model path.\n")
|
| 25 |
+
print("For specific arguments, please provide a model_path.")
|
| 26 |
+
print("\n--- Help for Standard Language Model Server ---")
|
| 27 |
+
from sglang.srt.server_args import prepare_server_args
|
| 28 |
+
|
| 29 |
+
try:
|
| 30 |
+
prepare_server_args(["--help"])
|
| 31 |
+
except SystemExit:
|
| 32 |
+
pass # argparse --help calls sys.exit
|
| 33 |
+
|
| 34 |
+
print("\n--- Help for Diffusion Model Server ---")
|
| 35 |
+
from sglang.multimodal_gen.runtime.entrypoints.cli.serve import (
|
| 36 |
+
add_multimodal_gen_serve_args,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
parser = argparse.ArgumentParser(description="SGLang Diffusion Model Serving")
|
| 40 |
+
add_multimodal_gen_serve_args(parser)
|
| 41 |
+
parser.print_help()
|
| 42 |
+
return
|
| 43 |
+
|
| 44 |
+
model_path = get_model_path(extra_argv)
|
| 45 |
+
try:
|
| 46 |
+
is_diffusion_model = get_is_diffusion_model(model_path)
|
| 47 |
+
if is_diffusion_model:
|
| 48 |
+
logger.info("Diffusion model detected")
|
| 49 |
+
|
| 50 |
+
if is_diffusion_model:
|
| 51 |
+
# Logic for Diffusion Models
|
| 52 |
+
from sglang.multimodal_gen.runtime.entrypoints.cli.serve import (
|
| 53 |
+
add_multimodal_gen_serve_args,
|
| 54 |
+
execute_serve_cmd,
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
parser = argparse.ArgumentParser(
|
| 58 |
+
description="SGLang Diffusion Model Serving"
|
| 59 |
+
)
|
| 60 |
+
add_multimodal_gen_serve_args(parser)
|
| 61 |
+
parsed_args, remaining_argv = parser.parse_known_args(extra_argv)
|
| 62 |
+
|
| 63 |
+
execute_serve_cmd(parsed_args, remaining_argv)
|
| 64 |
+
else:
|
| 65 |
+
# Logic for Standard Language Models
|
| 66 |
+
from sglang.launch_server import run_server
|
| 67 |
+
from sglang.srt.server_args import prepare_server_args
|
| 68 |
+
|
| 69 |
+
# Add a dummy argument for the program name, expected by prepare_server_args
|
| 70 |
+
# as it typically processes sys.argv
|
| 71 |
+
server_args = prepare_server_args(extra_argv)
|
| 72 |
+
|
| 73 |
+
run_server(server_args)
|
| 74 |
+
finally:
|
| 75 |
+
kill_process_tree(os.getpid(), include_parent=False)
|
sglang/cli/utils.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import hashlib
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import tempfile
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
import filelock
|
| 9 |
+
from huggingface_hub import hf_hub_download
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
temp_dir = tempfile.gettempdir()
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
|
| 17 |
+
lock_dir = cache_dir or temp_dir
|
| 18 |
+
os.makedirs(os.path.dirname(lock_dir), exist_ok=True)
|
| 19 |
+
model_name = model_name_or_path.replace("/", "-")
|
| 20 |
+
hash_name = hashlib.sha256(model_name.encode()).hexdigest()
|
| 21 |
+
lock_file_name = hash_name + model_name + ".lock"
|
| 22 |
+
lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666)
|
| 23 |
+
return lock
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# Copied and adapted from hf_diffusers_utils.py
|
| 27 |
+
def _maybe_download_model(
|
| 28 |
+
model_name_or_path: str, local_dir: str | None = None, download: bool = True
|
| 29 |
+
) -> str:
|
| 30 |
+
"""
|
| 31 |
+
Resolve a model path. If it's a local directory, return it.
|
| 32 |
+
If it's a Hugging Face Hub ID, download only the config file
|
| 33 |
+
(`model_index.json` or `config.json`) and return its directory.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
model_name_or_path: Local path or Hugging Face Hub model ID
|
| 37 |
+
local_dir: Local directory to save the downloaded file (if any)
|
| 38 |
+
download: Whether to download from Hugging Face Hub when needed
|
| 39 |
+
|
| 40 |
+
Returns:
|
| 41 |
+
Local directory path that contains the downloaded config file, or the original local directory.
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
if os.path.exists(model_name_or_path):
|
| 45 |
+
logger.info("Model already exists locally")
|
| 46 |
+
return model_name_or_path
|
| 47 |
+
|
| 48 |
+
if not download:
|
| 49 |
+
return model_name_or_path
|
| 50 |
+
|
| 51 |
+
with _get_lock(model_name_or_path):
|
| 52 |
+
# Try `model_index.json` first (diffusers models)
|
| 53 |
+
try:
|
| 54 |
+
logger.info(
|
| 55 |
+
"Downloading model_index.json from HF Hub for %s...",
|
| 56 |
+
model_name_or_path,
|
| 57 |
+
)
|
| 58 |
+
file_path = hf_hub_download(
|
| 59 |
+
repo_id=model_name_or_path,
|
| 60 |
+
filename="model_index.json",
|
| 61 |
+
local_dir=local_dir,
|
| 62 |
+
)
|
| 63 |
+
logger.info("Downloaded to %s", file_path)
|
| 64 |
+
return os.path.dirname(file_path)
|
| 65 |
+
except Exception as e_index:
|
| 66 |
+
logger.debug("model_index.json not found or failed: %s", e_index)
|
| 67 |
+
|
| 68 |
+
# Fallback to `config.json`
|
| 69 |
+
try:
|
| 70 |
+
logger.info(
|
| 71 |
+
"Downloading config.json from HF Hub for %s...", model_name_or_path
|
| 72 |
+
)
|
| 73 |
+
file_path = hf_hub_download(
|
| 74 |
+
repo_id=model_name_or_path,
|
| 75 |
+
filename="config.json",
|
| 76 |
+
local_dir=local_dir,
|
| 77 |
+
)
|
| 78 |
+
logger.info("Downloaded to %s", file_path)
|
| 79 |
+
return os.path.dirname(file_path)
|
| 80 |
+
except Exception as e_config:
|
| 81 |
+
raise ValueError(
|
| 82 |
+
(
|
| 83 |
+
"Could not find model locally at %s and failed to download "
|
| 84 |
+
"model_index.json/config.json from HF Hub: %s"
|
| 85 |
+
)
|
| 86 |
+
% (model_name_or_path, e_config)
|
| 87 |
+
) from e_config
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# Copied and adapted from hf_diffusers_utils.py
|
| 91 |
+
def is_diffusers_model_path(model_path: str) -> True:
|
| 92 |
+
"""
|
| 93 |
+
Verify if the model directory contains a valid diffusers configuration.
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
model_path: Path to the model directory
|
| 97 |
+
|
| 98 |
+
Returns:
|
| 99 |
+
The loaded model configuration as a dictionary if the model is a diffusers model
|
| 100 |
+
None if the model is not a diffusers model
|
| 101 |
+
"""
|
| 102 |
+
|
| 103 |
+
# Prefer model_index.json which indicates a diffusers pipeline
|
| 104 |
+
config_path = os.path.join(model_path, "model_index.json")
|
| 105 |
+
if not os.path.exists(config_path):
|
| 106 |
+
return False
|
| 107 |
+
|
| 108 |
+
# Load the config
|
| 109 |
+
with open(config_path) as f:
|
| 110 |
+
config = json.load(f)
|
| 111 |
+
|
| 112 |
+
# Verify diffusers version exists
|
| 113 |
+
if "_diffusers_version" not in config:
|
| 114 |
+
return False
|
| 115 |
+
return True
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def get_is_diffusion_model(model_path: str):
|
| 119 |
+
model_path = _maybe_download_model(model_path)
|
| 120 |
+
is_diffusion_model = is_diffusers_model_path(model_path)
|
| 121 |
+
if is_diffusion_model:
|
| 122 |
+
logger.info("Diffusion model detected")
|
| 123 |
+
return is_diffusion_model
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def get_model_path(extra_argv):
|
| 127 |
+
# Find the model_path argument
|
| 128 |
+
model_path = None
|
| 129 |
+
for i, arg in enumerate(extra_argv):
|
| 130 |
+
if arg == "--model-path":
|
| 131 |
+
if i + 1 < len(extra_argv):
|
| 132 |
+
model_path = extra_argv[i + 1]
|
| 133 |
+
break
|
| 134 |
+
elif arg.startswith("--model-path="):
|
| 135 |
+
model_path = arg.split("=", 1)[1]
|
| 136 |
+
break
|
| 137 |
+
|
| 138 |
+
if model_path is None:
|
| 139 |
+
# Fallback for --help or other cases where model-path is not provided
|
| 140 |
+
if any(h in extra_argv for h in ["-h", "--help"]):
|
| 141 |
+
raise Exception(
|
| 142 |
+
"Usage: sglang serve --model-path <model-name-or-path> [additional-arguments]\n\n"
|
| 143 |
+
"This command can launch either a standard language model server or a diffusion model server.\n"
|
| 144 |
+
"The server type is determined by the model path.\n"
|
| 145 |
+
"For specific arguments, please provide a model_path."
|
| 146 |
+
)
|
| 147 |
+
else:
|
| 148 |
+
raise Exception(
|
| 149 |
+
"Error: --model-path is required. "
|
| 150 |
+
"Please provide the path to the model."
|
| 151 |
+
)
|
| 152 |
+
return model_path
|
sglang/compile_deep_gemm.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Compile DeepGEMM Kernels for a model with specify server arguments
|
| 3 |
+
|
| 4 |
+
This script launches a server for capturing DeepGEMM calls and then compiles the kernels.
|
| 5 |
+
It accepts server arguments (the same as launch_server.py).
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code
|
| 9 |
+
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import dataclasses
|
| 14 |
+
import multiprocessing
|
| 15 |
+
import os
|
| 16 |
+
import time
|
| 17 |
+
|
| 18 |
+
import requests
|
| 19 |
+
|
| 20 |
+
from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST
|
| 21 |
+
from sglang.srt.entrypoints.http_server import launch_server
|
| 22 |
+
from sglang.srt.entrypoints.warmup import warmup
|
| 23 |
+
from sglang.srt.environ import envs
|
| 24 |
+
from sglang.srt.managers.io_struct import GenerateReqInput
|
| 25 |
+
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
| 26 |
+
from sglang.srt.server_args import ServerArgs
|
| 27 |
+
from sglang.srt.utils import kill_process_tree
|
| 28 |
+
|
| 29 |
+
multiprocessing.set_start_method("spawn", force=True)
|
| 30 |
+
|
| 31 |
+
# Reduce warning
|
| 32 |
+
envs.SGLANG_IN_DEEPGEMM_PRECOMPILE_STAGE.set(True)
|
| 33 |
+
# Force enable deep gemm
|
| 34 |
+
envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(True)
|
| 35 |
+
# Force enable mha chunked kv for DeepSeek V3 to avoid missing kv_b_proj DeepGEMM case
|
| 36 |
+
os.environ["SGL_CHUNKED_PREFIX_CACHE_THRESHOLD"] = "0"
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@dataclasses.dataclass
|
| 40 |
+
class CompileArgs:
|
| 41 |
+
timeout: int = 3600
|
| 42 |
+
|
| 43 |
+
@staticmethod
|
| 44 |
+
def add_cli_args(parser: argparse.ArgumentParser):
|
| 45 |
+
parser.add_argument("--timeout", type=int, default=CompileArgs.timeout)
|
| 46 |
+
|
| 47 |
+
@classmethod
|
| 48 |
+
def from_cli_args(cls, args: argparse.Namespace):
|
| 49 |
+
# use the default value's type to cast the args into correct types.
|
| 50 |
+
attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
|
| 51 |
+
return cls(
|
| 52 |
+
**{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@warmup("compile-deep-gemm")
|
| 57 |
+
async def warm_up_compile(
|
| 58 |
+
disaggregation_mode: str, tokenizer_manager: TokenizerManager
|
| 59 |
+
):
|
| 60 |
+
print("\nGenerate warm up request for compiling DeepGEMM...\n")
|
| 61 |
+
generate_req_input = GenerateReqInput(
|
| 62 |
+
input_ids=[0, 1, 2, 3],
|
| 63 |
+
sampling_params={
|
| 64 |
+
"temperature": 0.0,
|
| 65 |
+
"max_new_tokens": 8,
|
| 66 |
+
"ignore_eos": True,
|
| 67 |
+
},
|
| 68 |
+
)
|
| 69 |
+
if disaggregation_mode != "null":
|
| 70 |
+
generate_req_input.bootstrap_room = 0
|
| 71 |
+
generate_req_input.bootstrap_host = FAKE_BOOTSTRAP_HOST
|
| 72 |
+
|
| 73 |
+
await tokenizer_manager.generate_request(generate_req_input, None).__anext__()
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def launch_server_internal(server_args):
|
| 77 |
+
try:
|
| 78 |
+
launch_server(server_args)
|
| 79 |
+
except Exception as e:
|
| 80 |
+
raise e
|
| 81 |
+
finally:
|
| 82 |
+
kill_process_tree(os.getpid(), include_parent=False)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def launch_server_process_and_send_one_request(
|
| 86 |
+
server_args: ServerArgs, compile_args: CompileArgs
|
| 87 |
+
):
|
| 88 |
+
proc = multiprocessing.Process(target=launch_server_internal, args=(server_args,))
|
| 89 |
+
proc.start()
|
| 90 |
+
base_url = f"http://{server_args.host}:{server_args.port}"
|
| 91 |
+
timeout = compile_args.timeout
|
| 92 |
+
|
| 93 |
+
start_time = time.perf_counter()
|
| 94 |
+
while time.perf_counter() - start_time < timeout:
|
| 95 |
+
try:
|
| 96 |
+
headers = {
|
| 97 |
+
"Content-Type": "application/json; charset=utf-8",
|
| 98 |
+
}
|
| 99 |
+
if server_args.node_rank == 0:
|
| 100 |
+
response = requests.get(f"{base_url}/v1/models", headers=headers)
|
| 101 |
+
else:
|
| 102 |
+
# This http api is created by launch_dummy_health_check_server for none-rank0 node.
|
| 103 |
+
response = requests.get(f"{base_url}/health", headers=headers)
|
| 104 |
+
if response.status_code == 200:
|
| 105 |
+
# Rank-0 node send a request to sync with other node and then return.
|
| 106 |
+
if server_args.node_rank == 0:
|
| 107 |
+
payload = {
|
| 108 |
+
"input_ids": [0, 1, 2, 3],
|
| 109 |
+
"sampling_params": {
|
| 110 |
+
"max_new_tokens": 8,
|
| 111 |
+
"temperature": 0,
|
| 112 |
+
},
|
| 113 |
+
}
|
| 114 |
+
# In PD mode, include fake bootstrap fields so workers don't assert
|
| 115 |
+
if server_args.disaggregation_mode != "null":
|
| 116 |
+
payload["bootstrap_host"] = FAKE_BOOTSTRAP_HOST
|
| 117 |
+
payload["bootstrap_room"] = 0
|
| 118 |
+
|
| 119 |
+
response = requests.post(
|
| 120 |
+
f"{base_url}/generate",
|
| 121 |
+
json=payload,
|
| 122 |
+
timeout=600,
|
| 123 |
+
)
|
| 124 |
+
if response.status_code != 200:
|
| 125 |
+
error = response.json()
|
| 126 |
+
raise RuntimeError(f"Sync request failed: {error}")
|
| 127 |
+
# Other nodes should wait for the exit signal from Rank-0 node.
|
| 128 |
+
else:
|
| 129 |
+
start_time_waiting = time.perf_counter()
|
| 130 |
+
while proc.is_alive():
|
| 131 |
+
if time.perf_counter() - start_time_waiting < timeout:
|
| 132 |
+
time.sleep(10)
|
| 133 |
+
else:
|
| 134 |
+
raise TimeoutError("Waiting for main node timeout!")
|
| 135 |
+
return proc
|
| 136 |
+
except requests.RequestException:
|
| 137 |
+
pass
|
| 138 |
+
time.sleep(10)
|
| 139 |
+
raise TimeoutError(
|
| 140 |
+
"DeepGEMM Kernels compilation timeout."
|
| 141 |
+
"\n\nFeel free and please restart the command."
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def refine_server_args(server_args: ServerArgs, compile_args: CompileArgs):
|
| 146 |
+
# Disable cuda graph and torch compile to save time
|
| 147 |
+
server_args.disable_cuda_graph = True
|
| 148 |
+
server_args.enable_torch_compile = False
|
| 149 |
+
print(f"Disable CUDA Graph and Torch Compile to save time...")
|
| 150 |
+
|
| 151 |
+
# Set watchdog timeout to compile_args.timeout because compilation will take a long time
|
| 152 |
+
server_args.watchdog_timeout = compile_args.timeout
|
| 153 |
+
server_args.warmups = "compile-deep-gemm"
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def run_compile(server_args: ServerArgs, compile_args: CompileArgs):
|
| 157 |
+
print(
|
| 158 |
+
"Begin DeepGEMM Kernels compilation...\n"
|
| 159 |
+
"It may take a long time and timeout maybe raised "
|
| 160 |
+
"while the compilation is still in progress.\n"
|
| 161 |
+
"Just feel free to restart the command "
|
| 162 |
+
"until the compilation is fully finished.\n"
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
proc = launch_server_process_and_send_one_request(server_args, compile_args)
|
| 166 |
+
|
| 167 |
+
print("\nDeepGEMM Kernels compilation finished successfully.")
|
| 168 |
+
|
| 169 |
+
# Sleep for safety
|
| 170 |
+
time.sleep(10)
|
| 171 |
+
if proc.is_alive():
|
| 172 |
+
# This is the rank0 node.
|
| 173 |
+
kill_process_tree(proc.pid)
|
| 174 |
+
else:
|
| 175 |
+
try:
|
| 176 |
+
kill_process_tree(proc.pid)
|
| 177 |
+
except Exception:
|
| 178 |
+
pass
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
if __name__ == "__main__":
|
| 182 |
+
parser = argparse.ArgumentParser()
|
| 183 |
+
ServerArgs.add_cli_args(parser)
|
| 184 |
+
CompileArgs.add_cli_args(parser)
|
| 185 |
+
args = parser.parse_args()
|
| 186 |
+
server_args = ServerArgs.from_cli_args(args)
|
| 187 |
+
compile_args = CompileArgs.from_cli_args(args)
|
| 188 |
+
|
| 189 |
+
refine_server_args(server_args, compile_args)
|
| 190 |
+
|
| 191 |
+
run_compile(server_args, compile_args)
|
sglang/eval/llama3_eval.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapt from https://github.com/fw-ai/llm_eval_meta
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import asyncio
|
| 5 |
+
import os
|
| 6 |
+
import pickle
|
| 7 |
+
import re
|
| 8 |
+
import shutil
|
| 9 |
+
from collections import defaultdict
|
| 10 |
+
from dataclasses import dataclass
|
| 11 |
+
|
| 12 |
+
import httpx
|
| 13 |
+
import numpy as np
|
| 14 |
+
import openai
|
| 15 |
+
from datasets import load_dataset
|
| 16 |
+
from openai import AsyncOpenAI
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
|
| 19 |
+
# Mapping providers to their clients and models
|
| 20 |
+
provider_to_models = {
|
| 21 |
+
"b10": {
|
| 22 |
+
"8b": "meta-llama/Llama-3.1-8B-Instruct",
|
| 23 |
+
"70b": "meta-llama/Llama-3.1-70B-Instruct",
|
| 24 |
+
"405b": "meta-llama/Llama-3.1-405B-Instruct",
|
| 25 |
+
},
|
| 26 |
+
"oai": {
|
| 27 |
+
"8b": "meta-llama/Llama-3.1-8B-Instruct",
|
| 28 |
+
"70b": "meta-llama/Llama-3.1-70B-Instruct",
|
| 29 |
+
"405b": "meta-llama/Llama-3.1-405B-Instruct",
|
| 30 |
+
},
|
| 31 |
+
"sgl": {
|
| 32 |
+
"8b": "meta-llama/Llama-3.1-8B-Instruct",
|
| 33 |
+
"70b": "meta-llama/Llama-3.1-70B-Instruct",
|
| 34 |
+
"405b": "meta-llama/Llama-3.1-405B-Instruct",
|
| 35 |
+
},
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
async def fetch_responses(
|
| 40 |
+
client, prompt, semaphore, index, provider, model_size, output_dir, max_tokens
|
| 41 |
+
):
|
| 42 |
+
output_file = os.path.join(output_dir, f"response_{index}.pkl")
|
| 43 |
+
if os.path.exists(output_file):
|
| 44 |
+
print(f"File {output_file} already exists, skipping.")
|
| 45 |
+
return
|
| 46 |
+
|
| 47 |
+
async with semaphore:
|
| 48 |
+
response = await client.completions.create(
|
| 49 |
+
model=provider_to_models[provider][model_size],
|
| 50 |
+
prompt=prompt,
|
| 51 |
+
temperature=0.0,
|
| 52 |
+
max_tokens=max_tokens,
|
| 53 |
+
)
|
| 54 |
+
if isinstance(response, openai.BadRequestError):
|
| 55 |
+
with open(output_file, "wb") as f:
|
| 56 |
+
pickle.dump("bad_response", f)
|
| 57 |
+
assert isinstance(response, openai.types.completion.Completion)
|
| 58 |
+
# Save response to a file
|
| 59 |
+
with open(output_file, "wb") as f:
|
| 60 |
+
pickle.dump(response, f)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
TASK_TO_MAX_TOKENS = {
|
| 64 |
+
"evals__mmlu__details": 1,
|
| 65 |
+
"evals__mmlu__0_shot__cot__details": 1024,
|
| 66 |
+
# Official meta uses 1024, but a small % (.05) of questions are answered correctly after relaxing
|
| 67 |
+
"evals__mmlu_pro__details": 2048,
|
| 68 |
+
"evals__gsm8k__details": 1024,
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
TASK_TO_EVAL_SET = {
|
| 72 |
+
"mmlu": "evals__mmlu__details",
|
| 73 |
+
"mmlu_cot": "evals__mmlu__0_shot__cot__details",
|
| 74 |
+
"mmlu_pro": "evals__mmlu_pro__details",
|
| 75 |
+
"gsm8k": "evals__gsm8k__details",
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class CustomAsyncHTTPXClient(httpx.AsyncClient):
|
| 80 |
+
async def send(self, request: httpx.Request, *args, **kwargs) -> httpx.Response:
|
| 81 |
+
request.url = httpx.URL(
|
| 82 |
+
f"https://model-{os.getenv('MODEL_ID')}.api.baseten.co/development/predict"
|
| 83 |
+
)
|
| 84 |
+
return await super().send(request, *args, **kwargs)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def get_client(provider):
|
| 88 |
+
if provider not in "b10":
|
| 89 |
+
if os.getenv("OPENAI_API_KEY") == None:
|
| 90 |
+
os.environ["OPENAI_API_KEY"] = "EMPTY"
|
| 91 |
+
return {
|
| 92 |
+
"oai": AsyncOpenAI(base_url="http://127.0.0.1:8000/v1/"),
|
| 93 |
+
"b10": AsyncOpenAI(
|
| 94 |
+
api_key=f"Api-Key {os.getenv('OPENAI_API_KEY')}",
|
| 95 |
+
base_url=f"https://model-{os.getenv('MODEL_ID')}.api.baseten.co/development/predict",
|
| 96 |
+
http_client=CustomAsyncHTTPXClient(),
|
| 97 |
+
),
|
| 98 |
+
"sgl": AsyncOpenAI(base_url="http://127.0.0.1:30000/v1/"),
|
| 99 |
+
}[provider]
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# Define the benchmark function
|
| 103 |
+
async def benchmark(args):
|
| 104 |
+
ds = load_dataset(
|
| 105 |
+
"meta-llama/Llama-3.1-405B-Instruct-evals",
|
| 106 |
+
f"Llama-3.1-405B-Instruct-{TASK_TO_EVAL_SET[args.task]}",
|
| 107 |
+
)
|
| 108 |
+
semaphore = asyncio.Semaphore(args.concurrency) # Limit to 16 concurrent tasks
|
| 109 |
+
|
| 110 |
+
if args.num_examples is None:
|
| 111 |
+
args.num_examples = len(ds["latest"]["input_final_prompts"])
|
| 112 |
+
prompts = ds["latest"]["input_final_prompts"][: args.num_examples]
|
| 113 |
+
|
| 114 |
+
# Create the output directory if it does not exist
|
| 115 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 116 |
+
|
| 117 |
+
tasks = []
|
| 118 |
+
# Create the tasks with tqdm progress bar
|
| 119 |
+
max_tokens = TASK_TO_MAX_TOKENS[TASK_TO_EVAL_SET[args.task]]
|
| 120 |
+
client = get_client(args.provider)
|
| 121 |
+
for idx, prompt in enumerate(tqdm(prompts, desc="Creating tasks")):
|
| 122 |
+
tasks.append(
|
| 123 |
+
asyncio.create_task(
|
| 124 |
+
fetch_responses(
|
| 125 |
+
client,
|
| 126 |
+
f"<|begin_of_text|>{prompt[0]}",
|
| 127 |
+
semaphore,
|
| 128 |
+
idx,
|
| 129 |
+
args.provider,
|
| 130 |
+
args.model_size,
|
| 131 |
+
args.output_dir,
|
| 132 |
+
max_tokens=max_tokens,
|
| 133 |
+
)
|
| 134 |
+
)
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
# Run the tasks with tqdm progress bar
|
| 138 |
+
for future in tqdm(
|
| 139 |
+
asyncio.as_completed(tasks), total=len(tasks), desc="Processing tasks"
|
| 140 |
+
):
|
| 141 |
+
await future
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def get_mmlu_answer(response):
|
| 145 |
+
if response is not None:
|
| 146 |
+
return response.choices[0].text.lstrip().rstrip().upper().replace(".", "")
|
| 147 |
+
return None
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def get_mmlu_cot_answer(response):
|
| 151 |
+
pattern = r"The best answer is (.+)\.?"
|
| 152 |
+
match = re.search(pattern, response.choices[0].text)
|
| 153 |
+
if match:
|
| 154 |
+
return match.group(1).replace(".", "").replace("*", "")
|
| 155 |
+
|
| 156 |
+
pattern = r"the best answer is (.+)\.?"
|
| 157 |
+
match = re.search(pattern, response.choices[0].text)
|
| 158 |
+
if match:
|
| 159 |
+
return match.group(1).replace(".", "")
|
| 160 |
+
|
| 161 |
+
pattern = r"The correct answer is (.+)\.?"
|
| 162 |
+
match = re.search(pattern, response.choices[0].text)
|
| 163 |
+
if match:
|
| 164 |
+
return match.group(1).replace(".", "")
|
| 165 |
+
|
| 166 |
+
pattern = r"the correct answer is (.+)\.?"
|
| 167 |
+
match = re.search(pattern, response.choices[0].text)
|
| 168 |
+
if match:
|
| 169 |
+
return match.group(1).replace(".", "")
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def get_answer_gsm8k(response):
|
| 173 |
+
pattern = r"The final answer is (.+)\.?"
|
| 174 |
+
match = re.search(pattern, response.choices[0].text)
|
| 175 |
+
if match:
|
| 176 |
+
s = match.group(1)
|
| 177 |
+
for ok_symbol in ["%", "$"]:
|
| 178 |
+
s = s.replace(ok_symbol, "")
|
| 179 |
+
return s
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
TASK_TO_ANSWER_EXTRACTOR = {
|
| 183 |
+
"evals__mmlu__details": get_mmlu_answer,
|
| 184 |
+
"evals__mmlu__0_shot__cot__details": get_mmlu_cot_answer,
|
| 185 |
+
"evals__gsm8k__details": get_answer_gsm8k,
|
| 186 |
+
"evals__mmlu_pro__details": get_mmlu_cot_answer,
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def get_dataset_from_task(task, response_path, model_size):
|
| 191 |
+
ds_405b = load_dataset(
|
| 192 |
+
f"meta-llama/Llama-3.1-405B-Instruct-evals",
|
| 193 |
+
f"Llama-3.1-405B-Instruct-{task}",
|
| 194 |
+
)
|
| 195 |
+
ds_405b_hash_order = [x[0] for x in ds_405b["latest"]["input_final_prompts_hash"]]
|
| 196 |
+
|
| 197 |
+
if "70b" in model_size or "8b" in model_size:
|
| 198 |
+
if "70" in model_size:
|
| 199 |
+
ref_model_ds = load_dataset(
|
| 200 |
+
f"meta-llama/Llama-3.1-70B-Instruct-evals",
|
| 201 |
+
f"Llama-3.1-70B-Instruct-{task}",
|
| 202 |
+
)
|
| 203 |
+
else:
|
| 204 |
+
ref_model_ds = load_dataset(
|
| 205 |
+
f"meta-llama/Llama-3.1-8B-Instruct-evals",
|
| 206 |
+
f"Llama-3.1-8B-Instruct-{task}",
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
hash_to_row = {}
|
| 210 |
+
for row in ref_model_ds["latest"]:
|
| 211 |
+
hash_to_row[row["input_final_prompts_hash"][0]] = row
|
| 212 |
+
reordered_rows = []
|
| 213 |
+
for prompt_hash in ds_405b_hash_order:
|
| 214 |
+
reordered_rows.append(hash_to_row[prompt_hash])
|
| 215 |
+
ref_model_ds["latest"] = reordered_rows
|
| 216 |
+
return ref_model_ds
|
| 217 |
+
|
| 218 |
+
return ds_405b
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def analyze(task, response_path, model_size):
|
| 222 |
+
ds = get_dataset_from_task(task, response_path, model_size)
|
| 223 |
+
|
| 224 |
+
responses = []
|
| 225 |
+
total = len(ds["latest"])
|
| 226 |
+
|
| 227 |
+
for i in range(0, total):
|
| 228 |
+
response = pickle.load(
|
| 229 |
+
open(os.path.join(response_path, f"response_{i}.pkl"), "rb")
|
| 230 |
+
)
|
| 231 |
+
responses.append(response)
|
| 232 |
+
|
| 233 |
+
@dataclass
|
| 234 |
+
class Stats:
|
| 235 |
+
correct: int = 0
|
| 236 |
+
total: int = 0
|
| 237 |
+
meta_correct: int = 0
|
| 238 |
+
|
| 239 |
+
average: float = None
|
| 240 |
+
|
| 241 |
+
subtask_name_to_stats = defaultdict(lambda: Stats())
|
| 242 |
+
|
| 243 |
+
for response, ds_row in zip(responses, ds["latest"]):
|
| 244 |
+
model_answer = TASK_TO_ANSWER_EXTRACTOR[task](response)
|
| 245 |
+
|
| 246 |
+
subtask = ds_row["subtask_name"]
|
| 247 |
+
|
| 248 |
+
is_eval_correct = model_answer in ds_row["input_correct_responses"]
|
| 249 |
+
if is_eval_correct:
|
| 250 |
+
subtask_name_to_stats[subtask].correct += 1
|
| 251 |
+
|
| 252 |
+
if ds_row["is_correct"]:
|
| 253 |
+
subtask_name_to_stats[subtask].meta_correct += 1
|
| 254 |
+
|
| 255 |
+
subtask_name_to_stats[subtask].total += 1
|
| 256 |
+
|
| 257 |
+
micro_stats = Stats()
|
| 258 |
+
for subtask, stats in subtask_name_to_stats.items():
|
| 259 |
+
stats.average = stats.correct / stats.total
|
| 260 |
+
stats.meta_average = stats.meta_correct / stats.total
|
| 261 |
+
|
| 262 |
+
micro_stats.correct += stats.correct
|
| 263 |
+
micro_stats.total += stats.total
|
| 264 |
+
micro_stats.meta_correct += stats.meta_correct
|
| 265 |
+
|
| 266 |
+
micro_stats.average = micro_stats.correct / micro_stats.total
|
| 267 |
+
micro_stats.meta_average = micro_stats.meta_correct / micro_stats.total
|
| 268 |
+
|
| 269 |
+
print("Macro average", np.mean([x.average for x in subtask_name_to_stats.values()]))
|
| 270 |
+
print(
|
| 271 |
+
"Meta Macro average",
|
| 272 |
+
np.mean([x.meta_average for x in subtask_name_to_stats.values()]),
|
| 273 |
+
)
|
| 274 |
+
print("Micro average", micro_stats.average)
|
| 275 |
+
print("Meta Micro average", micro_stats.meta_average)
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
# Entry point for the script
|
| 279 |
+
if __name__ == "__main__":
|
| 280 |
+
parser = argparse.ArgumentParser(
|
| 281 |
+
description="Script to run model with specified parameters."
|
| 282 |
+
)
|
| 283 |
+
parser.add_argument(
|
| 284 |
+
"--model-size",
|
| 285 |
+
type=str,
|
| 286 |
+
default="8b",
|
| 287 |
+
help="Size of the model (e.g., 8b or 70b)",
|
| 288 |
+
)
|
| 289 |
+
parser.add_argument(
|
| 290 |
+
"--provider",
|
| 291 |
+
type=str,
|
| 292 |
+
default="sgl",
|
| 293 |
+
help="Provider name (e.g., sgl, oai, b10)",
|
| 294 |
+
)
|
| 295 |
+
parser.add_argument(
|
| 296 |
+
"--task",
|
| 297 |
+
type=str,
|
| 298 |
+
required=True,
|
| 299 |
+
help="Task (e.g., mmlu, mmlu_cot, mmlu_pro, gsm8k)",
|
| 300 |
+
)
|
| 301 |
+
parser.add_argument(
|
| 302 |
+
"--num-examples", type=int, default=None, help="Number of examples to process"
|
| 303 |
+
)
|
| 304 |
+
parser.add_argument("--concurrency", type=int, default=16)
|
| 305 |
+
parser.add_argument(
|
| 306 |
+
"--output-dir",
|
| 307 |
+
type=str,
|
| 308 |
+
default="tmp-output-dir",
|
| 309 |
+
help="Directory to save responses",
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
args = parser.parse_args()
|
| 313 |
+
asyncio.run(benchmark(args))
|
| 314 |
+
analyze(TASK_TO_EVAL_SET[args.task], args.output_dir, args.model_size)
|
| 315 |
+
shutil.rmtree("tmp-output-dir", ignore_errors=True)
|
sglang/eval/loogle_eval.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import asyncio
|
| 3 |
+
import os
|
| 4 |
+
import pickle
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import List
|
| 7 |
+
|
| 8 |
+
import openai
|
| 9 |
+
import torch
|
| 10 |
+
from bert_score import BERTScorer
|
| 11 |
+
from datasets import load_dataset
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def get_client(api_url: str) -> openai.AsyncOpenAI:
|
| 16 |
+
if os.getenv("OPENAI_API_KEY") is None:
|
| 17 |
+
os.environ["OPENAI_API_KEY"] = "EMPTY"
|
| 18 |
+
return openai.AsyncOpenAI(base_url=api_url)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_dataset():
|
| 22 |
+
return load_dataset("bigai-nlco/LooGLE", "longdep_qa", split="test")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
async def fetch_response(
|
| 26 |
+
client: openai.AsyncOpenAI,
|
| 27 |
+
context: str,
|
| 28 |
+
question: str,
|
| 29 |
+
semaphore: asyncio.Semaphore,
|
| 30 |
+
index: int,
|
| 31 |
+
model: str,
|
| 32 |
+
output_dir: Path,
|
| 33 |
+
):
|
| 34 |
+
output_file = output_dir / f"response_{index}.pkl"
|
| 35 |
+
if output_file.exists():
|
| 36 |
+
return
|
| 37 |
+
|
| 38 |
+
prompt = (
|
| 39 |
+
"Please answer the question based on the long texts below.\n"
|
| 40 |
+
f"{context}\n"
|
| 41 |
+
f"Question: {question}\n"
|
| 42 |
+
"Answer:"
|
| 43 |
+
)
|
| 44 |
+
messages = [
|
| 45 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
| 46 |
+
{"role": "user", "content": prompt},
|
| 47 |
+
]
|
| 48 |
+
|
| 49 |
+
async with semaphore:
|
| 50 |
+
try:
|
| 51 |
+
response = await client.chat.completions.create(
|
| 52 |
+
model=model,
|
| 53 |
+
messages=messages,
|
| 54 |
+
temperature=0.0,
|
| 55 |
+
max_tokens=512,
|
| 56 |
+
)
|
| 57 |
+
except openai.BadRequestError as e:
|
| 58 |
+
with open(output_file, "wb") as f:
|
| 59 |
+
pickle.dump({"error": str(e)}, f)
|
| 60 |
+
return
|
| 61 |
+
|
| 62 |
+
with open(output_file, "wb") as f:
|
| 63 |
+
pickle.dump(response, f)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
async def benchmark(args):
|
| 67 |
+
dataset = get_dataset()
|
| 68 |
+
output_dir = Path(args.output_dir)
|
| 69 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 70 |
+
|
| 71 |
+
client = get_client(args.api_url)
|
| 72 |
+
semaphore = asyncio.Semaphore(args.max_concurrency)
|
| 73 |
+
|
| 74 |
+
tasks: List[asyncio.Task] = []
|
| 75 |
+
for idx, ex in enumerate(dataset):
|
| 76 |
+
if idx >= args.num_prompts:
|
| 77 |
+
break
|
| 78 |
+
tasks.append(
|
| 79 |
+
asyncio.create_task(
|
| 80 |
+
fetch_response(
|
| 81 |
+
client,
|
| 82 |
+
ex["context"],
|
| 83 |
+
ex["question"],
|
| 84 |
+
semaphore,
|
| 85 |
+
idx,
|
| 86 |
+
args.model,
|
| 87 |
+
output_dir,
|
| 88 |
+
)
|
| 89 |
+
)
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
for _ in tqdm(
|
| 93 |
+
asyncio.as_completed(tasks), total=len(tasks), desc="Running benchmark"
|
| 94 |
+
):
|
| 95 |
+
await _
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def analyse(args):
|
| 99 |
+
dataset = get_dataset()
|
| 100 |
+
output_dir = Path(args.output_dir)
|
| 101 |
+
|
| 102 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 103 |
+
scorer = BERTScorer(lang="en", device=device)
|
| 104 |
+
|
| 105 |
+
hyps: List[str] = []
|
| 106 |
+
refs: List[str] = []
|
| 107 |
+
for idx, ex in enumerate(tqdm(dataset, desc="Loading responses")):
|
| 108 |
+
if idx >= args.num_prompts:
|
| 109 |
+
break
|
| 110 |
+
pkl_file = output_dir / f"response_{idx}.pkl"
|
| 111 |
+
if not pkl_file.exists():
|
| 112 |
+
raise FileNotFoundError(pkl_file)
|
| 113 |
+
|
| 114 |
+
response = pickle.load(open(pkl_file, "rb"))
|
| 115 |
+
if isinstance(response, dict) and "error" in response:
|
| 116 |
+
continue
|
| 117 |
+
|
| 118 |
+
hyps.append(response.choices[0].message.content.strip())
|
| 119 |
+
refs.append(ex["answer"])
|
| 120 |
+
|
| 121 |
+
if not hyps:
|
| 122 |
+
print("No valid responses to score!")
|
| 123 |
+
return
|
| 124 |
+
|
| 125 |
+
batch_size = 64
|
| 126 |
+
all_f1: List[float] = []
|
| 127 |
+
for i in tqdm(range(0, len(hyps), batch_size), desc="Scoring batches"):
|
| 128 |
+
h_batch = hyps[i : i + batch_size]
|
| 129 |
+
r_batch = refs[i : i + batch_size]
|
| 130 |
+
_, _, f1_scores = scorer.score(h_batch, r_batch, verbose=False)
|
| 131 |
+
all_f1.extend([float(x) for x in f1_scores])
|
| 132 |
+
|
| 133 |
+
avg = sum(all_f1) / len(all_f1)
|
| 134 |
+
print(f"Average BERTScore (F1): {avg:.2%}")
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
if __name__ == "__main__":
|
| 138 |
+
parser = argparse.ArgumentParser(
|
| 139 |
+
description="Run benchmark and evaluation in one go."
|
| 140 |
+
)
|
| 141 |
+
parser.add_argument(
|
| 142 |
+
"--api-url",
|
| 143 |
+
default="http://127.0.0.1:30000/v1",
|
| 144 |
+
help="OpenAI‑compatible API base URL",
|
| 145 |
+
)
|
| 146 |
+
parser.add_argument(
|
| 147 |
+
"--model",
|
| 148 |
+
default="meta-llama/Llama-4-Maverick-17B-128E-Instruct",
|
| 149 |
+
help="Model name or ID, only used for model name",
|
| 150 |
+
)
|
| 151 |
+
parser.add_argument(
|
| 152 |
+
"--max-concurrency", type=int, default=144, help="Maximum concurrent requests"
|
| 153 |
+
)
|
| 154 |
+
parser.add_argument(
|
| 155 |
+
"--output-dir", default="tmp-output-dir", help="Directory for cached responses"
|
| 156 |
+
)
|
| 157 |
+
parser.add_argument(
|
| 158 |
+
"--num-prompts", type=int, default=10000, help="Number of prompts to run"
|
| 159 |
+
)
|
| 160 |
+
args = parser.parse_args()
|
| 161 |
+
|
| 162 |
+
asyncio.run(benchmark(args))
|
| 163 |
+
|
| 164 |
+
analyse(args)
|
sglang/global_config.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Global configurations"""
|
| 2 |
+
|
| 3 |
+
# FIXME: deprecate this file and move all usage to sglang.srt.environ or sglang.__init__.py
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class GlobalConfig:
|
| 7 |
+
"""
|
| 8 |
+
Store some global constants.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
def __init__(self):
|
| 12 |
+
# Verbosity level
|
| 13 |
+
# 0: do not output anything
|
| 14 |
+
# 2: output final text after every run
|
| 15 |
+
self.verbosity = 0
|
| 16 |
+
|
| 17 |
+
# Default backend of the language
|
| 18 |
+
self.default_backend = None
|
| 19 |
+
|
| 20 |
+
# Output tokenization configs
|
| 21 |
+
self.skip_special_tokens_in_output = True
|
| 22 |
+
self.spaces_between_special_tokens_in_out = True
|
| 23 |
+
|
| 24 |
+
# Language frontend interpreter optimization configs
|
| 25 |
+
self.enable_precache_with_tracing = True
|
| 26 |
+
self.enable_parallel_encoding = True
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
global_config = GlobalConfig()
|
sglang/jit_kernel/.clang-format
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
BasedOnStyle: Google
|
| 2 |
+
IndentWidth: 2
|
| 3 |
+
ColumnLimit: 120
|
| 4 |
+
AllowShortFunctionsOnASingleLine: Empty
|
| 5 |
+
DerivePointerAlignment: false
|
| 6 |
+
PointerAlignment: Left
|
| 7 |
+
NamespaceIndentation: None
|
| 8 |
+
SortIncludes: true
|
| 9 |
+
AllowShortLoopsOnASingleLine: false
|
| 10 |
+
BinPackParameters: false # Prevents packing parameters in declarations
|
| 11 |
+
BinPackArguments: false # Prevents packing arguments in function calls
|
| 12 |
+
AlignAfterOpenBracket: AlwaysBreak # Forces a break after the opening parenthesis
|
| 13 |
+
AlignOperands: Align # Aligns arguments vertically
|
| 14 |
+
PenaltyBreakBeforeFirstCallParameter: 1 # Encourages breaking before the first argument
|
| 15 |
+
PenaltyReturnTypeOnItsOwnLine: 100 # Keeps return type with function name
|
| 16 |
+
|
| 17 |
+
IncludeCategories:
|
| 18 |
+
- Regex: '^<sgl_kernel/.*>$'
|
| 19 |
+
Priority: 0
|
sglang/jit_kernel/__pycache__/hicache.cpython-311.pyc
ADDED
|
Binary file (4.35 kB). View file
|
|
|
sglang/jit_kernel/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (6.87 kB). View file
|
|
|
sglang/jit_kernel/csrc/cuda_wait_value.cuh
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <sgl_kernel/tensor.h>
|
| 2 |
+
#include <sgl_kernel/utils.cuh>
|
| 3 |
+
|
| 4 |
+
#include <cuda_runtime_api.h>
|
| 5 |
+
|
| 6 |
+
#include <cstdint>
|
| 7 |
+
|
| 8 |
+
namespace {
|
| 9 |
+
|
| 10 |
+
__global__ void wait_flag_kernel(const int32_t* flag, int32_t target) {
|
| 11 |
+
const volatile int32_t* vflag = (volatile const int32_t*)flag;
|
| 12 |
+
|
| 13 |
+
while (*vflag != target) {
|
| 14 |
+
#if __CUDA_ARCH__ >= 700
|
| 15 |
+
__nanosleep(100);
|
| 16 |
+
#else
|
| 17 |
+
// Note: This falls back to an inefficient busy-wait on pre-Volta architectures.
|
| 18 |
+
#endif
|
| 19 |
+
}
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
auto stream_wait_value(const tvm::ffi::TensorView flag, std::int32_t value) -> void {
|
| 23 |
+
using namespace host;
|
| 24 |
+
|
| 25 |
+
auto length = SymbolicSize{"length"};
|
| 26 |
+
TensorMatcher({length}).with_dtype<int32_t>().with_device<kDLCUDA>().verify(flag);
|
| 27 |
+
RuntimeCheck(length.unwrap() >= 1, "wait_flag expects a non-empty tensor.");
|
| 28 |
+
|
| 29 |
+
auto* ptr = static_cast<std::int32_t*>(flag.data_ptr());
|
| 30 |
+
const auto stream = LaunchKernel::resolve_device(flag.device());
|
| 31 |
+
|
| 32 |
+
constexpr int blocks = 1;
|
| 33 |
+
constexpr int threads = 1;
|
| 34 |
+
wait_flag_kernel<<<blocks, threads, 0, stream>>>(ptr, value);
|
| 35 |
+
RuntimeDeviceCheck(cudaGetLastError());
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
} // namespace
|
sglang/jit_kernel/csrc/hicache.cuh
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <sgl_kernel/tensor.h>
|
| 2 |
+
#include <sgl_kernel/utils.cuh>
|
| 3 |
+
#include <sgl_kernel/utils.h>
|
| 4 |
+
#include <sgl_kernel/warp.cuh>
|
| 5 |
+
|
| 6 |
+
#include <dlpack/dlpack.h>
|
| 7 |
+
|
| 8 |
+
#include <algorithm>
|
| 9 |
+
#include <concepts>
|
| 10 |
+
#include <cstddef>
|
| 11 |
+
#include <cstdint>
|
| 12 |
+
#include <type_traits>
|
| 13 |
+
|
| 14 |
+
namespace {
|
| 15 |
+
|
| 16 |
+
struct HicacheKernelParams {
|
| 17 |
+
void* __restrict__ k_cache_dst;
|
| 18 |
+
void* __restrict__ v_cache_dst;
|
| 19 |
+
const void* __restrict__ indices_dst;
|
| 20 |
+
void* __restrict__ k_cache_src;
|
| 21 |
+
void* __restrict__ v_cache_src;
|
| 22 |
+
const void* __restrict__ indices_src;
|
| 23 |
+
std::size_t length;
|
| 24 |
+
std::size_t kv_cache_src_stride;
|
| 25 |
+
std::size_t kv_cache_dst_stride;
|
| 26 |
+
std::size_t num_layers = 0; // only used in all_layer transfer
|
| 27 |
+
};
|
| 28 |
+
|
| 29 |
+
template <
|
| 30 |
+
std::integral T,
|
| 31 |
+
std::size_t kElementSize,
|
| 32 |
+
std::size_t kUnroll,
|
| 33 |
+
std::size_t kBlockQuota,
|
| 34 |
+
std::size_t kNumThreads,
|
| 35 |
+
std::size_t kMaxOccupancy>
|
| 36 |
+
__global__ __launch_bounds__(kNumThreads, kMaxOccupancy) void hicache_transfer_per_layer(
|
| 37 |
+
const __grid_constant__ HicacheKernelParams params) {
|
| 38 |
+
// each warp acts as a worker
|
| 39 |
+
using namespace device;
|
| 40 |
+
static_assert(kNumThreads % kWarpThreads == 0);
|
| 41 |
+
static_assert(kWarpThreads % kUnroll == 0);
|
| 42 |
+
|
| 43 |
+
constexpr auto kWarpThreads = device::kWarpThreads / kUnroll;
|
| 44 |
+
constexpr auto kWarpsPerBlock = kNumThreads / kWarpThreads;
|
| 45 |
+
constexpr auto kWorkers = kWarpsPerBlock * kBlockQuota;
|
| 46 |
+
|
| 47 |
+
const auto& [
|
| 48 |
+
k_cache_dst, v_cache_dst, indices_dst, // dst
|
| 49 |
+
k_cache_src, v_cache_src, indices_src, // src
|
| 50 |
+
length, kv_cache_src_stride, kv_cache_dst_stride, _ // metadata
|
| 51 |
+
] = params;
|
| 52 |
+
const auto warp_id = blockIdx.x * kWarpsPerBlock + threadIdx.x / kWarpThreads;
|
| 53 |
+
|
| 54 |
+
// force to transfer 128 bytes per iteration
|
| 55 |
+
// since the PCIe transaction size is 128 bytes aligned
|
| 56 |
+
constexpr auto kGranularity = 128 / kWarpThreads;
|
| 57 |
+
|
| 58 |
+
for (auto i = warp_id; i < length; i += kWorkers) {
|
| 59 |
+
const auto pos_src = static_cast<const T*>(indices_src)[i];
|
| 60 |
+
const auto pos_dst = static_cast<const T*>(indices_dst)[i];
|
| 61 |
+
const auto src_k = pointer::offset(k_cache_src, pos_src * kv_cache_src_stride);
|
| 62 |
+
const auto dst_k = pointer::offset(k_cache_dst, pos_dst * kv_cache_dst_stride);
|
| 63 |
+
const auto src_v = pointer::offset(v_cache_src, pos_src * kv_cache_src_stride);
|
| 64 |
+
const auto dst_v = pointer::offset(v_cache_dst, pos_dst * kv_cache_dst_stride);
|
| 65 |
+
const auto vec_k = warp::load_vec<kElementSize, kGranularity, kWarpThreads>(src_k);
|
| 66 |
+
const auto vec_v = warp::load_vec<kElementSize, kGranularity, kWarpThreads>(src_v);
|
| 67 |
+
warp::store_vec<kElementSize, kGranularity, kWarpThreads>(dst_k, vec_k);
|
| 68 |
+
warp::store_vec<kElementSize, kGranularity, kWarpThreads>(dst_v, vec_v);
|
| 69 |
+
}
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
template <
|
| 73 |
+
std::integral T,
|
| 74 |
+
std::size_t kElementSize,
|
| 75 |
+
std::size_t kUnroll,
|
| 76 |
+
std::size_t kBlockQuota,
|
| 77 |
+
std::size_t kNumThreads,
|
| 78 |
+
std::size_t kMaxOccupancy>
|
| 79 |
+
__global__ __launch_bounds__(kNumThreads, kMaxOccupancy) void hicache_transfer_all_layer(
|
| 80 |
+
const __grid_constant__ HicacheKernelParams params) {
|
| 81 |
+
// each warp acts as a worker
|
| 82 |
+
using namespace device;
|
| 83 |
+
using src_ptr_t = std::add_pointer_t<const void* const>;
|
| 84 |
+
using dst_ptr_t = std::add_pointer_t<void* const>;
|
| 85 |
+
|
| 86 |
+
static_assert(kNumThreads % kWarpThreads == 0);
|
| 87 |
+
constexpr auto kWarpThreads = device::kWarpThreads / kUnroll;
|
| 88 |
+
constexpr auto kWarpsPerBlock = static_cast<uint32_t>(kNumThreads) / kWarpThreads;
|
| 89 |
+
constexpr auto kWorkers = kWarpsPerBlock * kBlockQuota;
|
| 90 |
+
|
| 91 |
+
const auto& [
|
| 92 |
+
k_ptr_dst, v_ptr_dst, indices_dst, // dst
|
| 93 |
+
k_ptr_src, v_ptr_src, indices_src, // src
|
| 94 |
+
length, kv_cache_src_stride, kv_cache_dst_stride, num_layers // metadata
|
| 95 |
+
] = params;
|
| 96 |
+
const auto warp_id = blockIdx.x * kWarpsPerBlock + threadIdx.x / kWarpThreads;
|
| 97 |
+
|
| 98 |
+
// force to transfer 128 bytes per iteration
|
| 99 |
+
// since the PCIe transaction size is 128 bytes aligned
|
| 100 |
+
constexpr auto kGranularity = 128 / kWarpThreads;
|
| 101 |
+
|
| 102 |
+
for (auto i = warp_id; i < length; i += kWorkers) {
|
| 103 |
+
const auto pos_src = static_cast<const T*>(indices_src)[i];
|
| 104 |
+
const auto pos_dst = static_cast<const T*>(indices_dst)[i];
|
| 105 |
+
for (std::size_t layer = 0; layer < num_layers; ++layer) {
|
| 106 |
+
const auto k_cache_src = static_cast<src_ptr_t>(k_ptr_src)[layer];
|
| 107 |
+
const auto v_cache_src = static_cast<src_ptr_t>(v_ptr_src)[layer];
|
| 108 |
+
const auto k_cache_dst = static_cast<dst_ptr_t>(k_ptr_dst)[layer];
|
| 109 |
+
const auto v_cache_dst = static_cast<dst_ptr_t>(v_ptr_dst)[layer];
|
| 110 |
+
const auto src_k = pointer::offset(k_cache_src, pos_src * kv_cache_src_stride);
|
| 111 |
+
const auto dst_k = pointer::offset(k_cache_dst, pos_dst * kv_cache_dst_stride);
|
| 112 |
+
const auto src_v = pointer::offset(v_cache_src, pos_src * kv_cache_src_stride);
|
| 113 |
+
const auto dst_v = pointer::offset(v_cache_dst, pos_dst * kv_cache_dst_stride);
|
| 114 |
+
const auto vec_k = warp::load_vec<kElementSize, kGranularity, kWarpThreads>(src_k);
|
| 115 |
+
const auto vec_v = warp::load_vec<kElementSize, kGranularity, kWarpThreads>(src_v);
|
| 116 |
+
warp::store_vec<kElementSize, kGranularity, kWarpThreads>(dst_k, vec_k);
|
| 117 |
+
warp::store_vec<kElementSize, kGranularity, kWarpThreads>(dst_v, vec_v);
|
| 118 |
+
}
|
| 119 |
+
}
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
template <
|
| 123 |
+
std::size_t kElementSize,
|
| 124 |
+
std::size_t kUnroll,
|
| 125 |
+
std::size_t kBlockQuota,
|
| 126 |
+
std::size_t kNumThreads,
|
| 127 |
+
std::size_t kMaxOccupancy>
|
| 128 |
+
struct HiCacheKernel {
|
| 129 |
+
template <typename T>
|
| 130 |
+
static constexpr auto _kernel_one =
|
| 131 |
+
hicache_transfer_per_layer<T, kElementSize, kUnroll, kBlockQuota, kNumThreads, kMaxOccupancy>;
|
| 132 |
+
template <typename T>
|
| 133 |
+
static constexpr auto _kernel_all =
|
| 134 |
+
hicache_transfer_all_layer<T, kElementSize, kUnroll, kBlockQuota, kNumThreads, kMaxOccupancy>;
|
| 135 |
+
|
| 136 |
+
static void run_one(
|
| 137 |
+
const tvm::ffi::TensorView k_cache_dst,
|
| 138 |
+
const tvm::ffi::TensorView v_cache_dst,
|
| 139 |
+
const tvm::ffi::TensorView indices_dst,
|
| 140 |
+
const tvm::ffi::TensorView k_cache_src,
|
| 141 |
+
const tvm::ffi::TensorView v_cache_src,
|
| 142 |
+
const tvm::ffi::TensorView indices_src) {
|
| 143 |
+
using namespace host;
|
| 144 |
+
|
| 145 |
+
auto D = SymbolicSize{"D"}; // cache dimension
|
| 146 |
+
auto N = SymbolicSize{"N"}; // src kv stride
|
| 147 |
+
auto M = SymbolicSize{"M"}; // dst kv stride
|
| 148 |
+
auto L = SymbolicSize{"L"}; // indices length
|
| 149 |
+
auto cache_dtype = SymbolicDType{};
|
| 150 |
+
auto indices_dtype = SymbolicDType{};
|
| 151 |
+
auto indices_device = SymbolicDevice{};
|
| 152 |
+
|
| 153 |
+
TensorMatcher({-1, D}) //
|
| 154 |
+
.with_strides({N, 1})
|
| 155 |
+
.with_dtype(cache_dtype)
|
| 156 |
+
.with_device<kDLCUDA, kDLCUDAHost, kDLCPU>()
|
| 157 |
+
.verify(k_cache_src)
|
| 158 |
+
.verify(v_cache_src);
|
| 159 |
+
TensorMatcher({-1, D}) //
|
| 160 |
+
.with_strides({M, 1})
|
| 161 |
+
.with_dtype(cache_dtype)
|
| 162 |
+
.with_device<kDLCUDA, kDLCUDAHost, kDLCPU>()
|
| 163 |
+
.verify(k_cache_dst)
|
| 164 |
+
.verify(v_cache_dst);
|
| 165 |
+
TensorMatcher({L}) //
|
| 166 |
+
.with_dtype<int32_t, int64_t>(indices_dtype)
|
| 167 |
+
.with_device<kDLCUDA>(indices_device)
|
| 168 |
+
.verify(indices_src)
|
| 169 |
+
.verify(indices_dst);
|
| 170 |
+
|
| 171 |
+
// verify dimension match
|
| 172 |
+
const auto dtype_size = dtype_bytes(cache_dtype.unwrap());
|
| 173 |
+
const auto element_bytes = D.unwrap() * dtype_size;
|
| 174 |
+
RuntimeCheck(kElementSize == element_bytes, "HicacheKernel: cache dimension mismatch.");
|
| 175 |
+
|
| 176 |
+
const auto k_cache_dst_ptr = k_cache_dst.data_ptr();
|
| 177 |
+
const auto v_cache_dst_ptr = v_cache_dst.data_ptr();
|
| 178 |
+
const auto k_cache_src_ptr = k_cache_src.data_ptr();
|
| 179 |
+
const auto v_cache_src_ptr = v_cache_src.data_ptr();
|
| 180 |
+
const auto indices_dst_ptr = indices_dst.data_ptr();
|
| 181 |
+
const auto indices_src_ptr = indices_src.data_ptr();
|
| 182 |
+
const auto length = static_cast<std::size_t>(L.unwrap());
|
| 183 |
+
const auto kv_cache_src_stride = static_cast<std::size_t>(N.unwrap()) * dtype_size;
|
| 184 |
+
const auto kv_cache_dst_stride = static_cast<std::size_t>(M.unwrap()) * dtype_size;
|
| 185 |
+
const auto use_int32 = indices_dtype.unwrap().bits == 32;
|
| 186 |
+
const auto device = indices_device.unwrap();
|
| 187 |
+
|
| 188 |
+
constexpr auto kWorkersPerBlock = kNumThreads / (device::kWarpThreads / kUnroll);
|
| 189 |
+
const auto num_blocks = std::min(div_ceil(length, kWorkersPerBlock), kBlockQuota);
|
| 190 |
+
const auto params = HicacheKernelParams{
|
| 191 |
+
.k_cache_dst = k_cache_dst_ptr,
|
| 192 |
+
.v_cache_dst = v_cache_dst_ptr,
|
| 193 |
+
.indices_dst = indices_dst_ptr,
|
| 194 |
+
.k_cache_src = k_cache_src_ptr,
|
| 195 |
+
.v_cache_src = v_cache_src_ptr,
|
| 196 |
+
.indices_src = indices_src_ptr,
|
| 197 |
+
.length = length,
|
| 198 |
+
.kv_cache_src_stride = kv_cache_src_stride,
|
| 199 |
+
.kv_cache_dst_stride = kv_cache_dst_stride,
|
| 200 |
+
};
|
| 201 |
+
const auto kernel = use_int32 ? _kernel_one<int32_t> : _kernel_one<int64_t>;
|
| 202 |
+
LaunchKernel(num_blocks, kNumThreads, device)(kernel, params);
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
static void run_all(
|
| 206 |
+
const tvm::ffi::TensorView k_ptr_dst,
|
| 207 |
+
const tvm::ffi::TensorView v_ptr_dst,
|
| 208 |
+
const tvm::ffi::TensorView indices_dst,
|
| 209 |
+
const tvm::ffi::TensorView k_ptr_src,
|
| 210 |
+
const tvm::ffi::TensorView v_ptr_src,
|
| 211 |
+
const tvm::ffi::TensorView indices_src,
|
| 212 |
+
const std::size_t kv_src_stride,
|
| 213 |
+
const std::size_t kv_dst_stride) {
|
| 214 |
+
using namespace host;
|
| 215 |
+
|
| 216 |
+
auto N = SymbolicSize{"N"}; // num layers
|
| 217 |
+
auto L = SymbolicSize{"L"}; // indices length
|
| 218 |
+
auto dtype_ = SymbolicDType{};
|
| 219 |
+
auto device_ = SymbolicDevice{};
|
| 220 |
+
|
| 221 |
+
TensorMatcher({N}) //
|
| 222 |
+
.with_dtype<uint64_t>()
|
| 223 |
+
.with_device<kDLCUDA>(device_)
|
| 224 |
+
.verify(k_ptr_src)
|
| 225 |
+
.verify(v_ptr_src)
|
| 226 |
+
.verify(k_ptr_dst)
|
| 227 |
+
.verify(v_ptr_dst);
|
| 228 |
+
TensorMatcher({L}) //
|
| 229 |
+
.with_dtype<int32_t, int64_t>(dtype_)
|
| 230 |
+
.with_device<kDLCUDA>(device_)
|
| 231 |
+
.verify(indices_src)
|
| 232 |
+
.verify(indices_dst);
|
| 233 |
+
|
| 234 |
+
// verify dimension match
|
| 235 |
+
const auto k_cache_dst_ptr = k_ptr_dst.data_ptr();
|
| 236 |
+
const auto v_cache_dst_ptr = v_ptr_dst.data_ptr();
|
| 237 |
+
const auto k_cache_src_ptr = k_ptr_src.data_ptr();
|
| 238 |
+
const auto v_cache_src_ptr = v_ptr_src.data_ptr();
|
| 239 |
+
const auto indices_dst_ptr = indices_dst.data_ptr();
|
| 240 |
+
const auto indices_src_ptr = indices_src.data_ptr();
|
| 241 |
+
const auto length = static_cast<std::size_t>(L.unwrap());
|
| 242 |
+
const auto use_int32 = dtype_.unwrap().bits == 32;
|
| 243 |
+
const auto device = device_.unwrap();
|
| 244 |
+
|
| 245 |
+
constexpr auto kWorkersPerBlock = kNumThreads / (device::kWarpThreads / kUnroll);
|
| 246 |
+
const auto num_blocks = std::min(div_ceil(length, kWorkersPerBlock), kBlockQuota);
|
| 247 |
+
const auto params = HicacheKernelParams{
|
| 248 |
+
.k_cache_dst = k_cache_dst_ptr,
|
| 249 |
+
.v_cache_dst = v_cache_dst_ptr,
|
| 250 |
+
.indices_dst = indices_dst_ptr,
|
| 251 |
+
.k_cache_src = k_cache_src_ptr,
|
| 252 |
+
.v_cache_src = v_cache_src_ptr,
|
| 253 |
+
.indices_src = indices_src_ptr,
|
| 254 |
+
.length = length,
|
| 255 |
+
.kv_cache_src_stride = kv_src_stride,
|
| 256 |
+
.kv_cache_dst_stride = kv_dst_stride,
|
| 257 |
+
.num_layers = static_cast<std::size_t>(N.unwrap()),
|
| 258 |
+
};
|
| 259 |
+
const auto kernel = use_int32 ? _kernel_all<int32_t> : _kernel_all<int64_t>;
|
| 260 |
+
LaunchKernel(num_blocks, kNumThreads, device)(kernel, params);
|
| 261 |
+
}
|
| 262 |
+
};
|
| 263 |
+
|
| 264 |
+
} // namespace
|
sglang/jit_kernel/cuda_wait_value.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from functools import lru_cache
|
| 4 |
+
from typing import TYPE_CHECKING
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from sglang.jit_kernel.utils import load_jit
|
| 9 |
+
|
| 10 |
+
if TYPE_CHECKING:
|
| 11 |
+
import torch
|
| 12 |
+
from tvm_ffi.module import Module
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@lru_cache(maxsize=1)
|
| 16 |
+
def _jit_stream_wait_value_module() -> Module:
|
| 17 |
+
return load_jit(
|
| 18 |
+
"cuda_wait_value",
|
| 19 |
+
cuda_files=["cuda_wait_value.cuh"],
|
| 20 |
+
cuda_wrappers=[("stream_wait_value", "stream_wait_value")],
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def stream_wait_value(flag: torch.Tensor, value: int) -> None:
|
| 25 |
+
module = _jit_stream_wait_value_module()
|
| 26 |
+
module.stream_wait_value(flag, value)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class Event:
|
| 30 |
+
def __init__(self) -> None:
|
| 31 |
+
self.flag = torch.zeros(1, dtype=torch.int32, device="cuda")
|
| 32 |
+
|
| 33 |
+
def record(self, value: int = 1) -> None:
|
| 34 |
+
self.flag[0] = value
|
| 35 |
+
|
| 36 |
+
def wait(self, value: int = 1) -> None:
|
| 37 |
+
stream_wait_value(self.flag, value)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def test_wait_before_record(event: Event | torch.cuda.Event):
|
| 41 |
+
stream_a = torch.cuda.Stream()
|
| 42 |
+
stream_b = torch.cuda.Stream()
|
| 43 |
+
|
| 44 |
+
with torch.cuda.stream(stream_a):
|
| 45 |
+
event.wait()
|
| 46 |
+
|
| 47 |
+
stream_a.synchronize()
|
| 48 |
+
|
| 49 |
+
with torch.cuda.stream(stream_b):
|
| 50 |
+
event.record()
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def main():
|
| 54 |
+
import threading
|
| 55 |
+
import time
|
| 56 |
+
|
| 57 |
+
block_thead = threading.Thread(
|
| 58 |
+
target=test_wait_before_record, args=(Event(),), daemon=True
|
| 59 |
+
)
|
| 60 |
+
block_thead.start()
|
| 61 |
+
|
| 62 |
+
non_block_thread = threading.Thread(
|
| 63 |
+
target=test_wait_before_record, args=(torch.cuda.Event(),)
|
| 64 |
+
)
|
| 65 |
+
non_block_thread.start()
|
| 66 |
+
|
| 67 |
+
print("Checking if custom Event blocks the stream...", flush=True)
|
| 68 |
+
for _ in range(5):
|
| 69 |
+
print(f"{block_thead.is_alive()=}, {non_block_thread.is_alive()=}", flush=True)
|
| 70 |
+
time.sleep(1)
|
| 71 |
+
|
| 72 |
+
assert block_thead.is_alive(), "Custom Event did not block as expected"
|
| 73 |
+
assert not non_block_thread.is_alive(), "torch.cuda.Event should not block"
|
| 74 |
+
print("=" * 40)
|
| 75 |
+
print("Test completed successfully.")
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
if __name__ == "__main__":
|
| 79 |
+
main()
|
sglang/jit_kernel/hicache.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
from functools import lru_cache
|
| 5 |
+
from typing import TYPE_CHECKING
|
| 6 |
+
|
| 7 |
+
from sglang.jit_kernel.utils import load_jit, make_cpp_args
|
| 8 |
+
|
| 9 |
+
if TYPE_CHECKING:
|
| 10 |
+
import torch
|
| 11 |
+
from tvm_ffi.module import Module
|
| 12 |
+
|
| 13 |
+
DEFAULT_BLOCK_QUOTA = 2
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@lru_cache(maxsize=None)
|
| 17 |
+
def _jit_hicache_module(*, element_size: int, unroll: int, block_quota: int) -> Module:
|
| 18 |
+
num_threads, occupancy = 1024, 1
|
| 19 |
+
args = make_cpp_args(
|
| 20 |
+
element_size,
|
| 21 |
+
unroll,
|
| 22 |
+
block_quota,
|
| 23 |
+
num_threads,
|
| 24 |
+
occupancy,
|
| 25 |
+
)
|
| 26 |
+
return load_jit(
|
| 27 |
+
"hicache",
|
| 28 |
+
*args,
|
| 29 |
+
cuda_files=["hicache.cuh"],
|
| 30 |
+
cuda_wrappers=[
|
| 31 |
+
("launch_one", f"HiCacheKernel<{args}>::run_one"),
|
| 32 |
+
("launch_all", f"HiCacheKernel<{args}>::run_all"),
|
| 33 |
+
],
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def can_use_hicache_jit_kernel(
|
| 38 |
+
*,
|
| 39 |
+
element_size: int,
|
| 40 |
+
unroll: int | None = None, # can be tuned for performance
|
| 41 |
+
block_quota: int | None = None, # can be tuned for less interference
|
| 42 |
+
) -> bool:
|
| 43 |
+
try:
|
| 44 |
+
unroll = unroll or _default_unroll(element_size)
|
| 45 |
+
block_quota = block_quota or DEFAULT_BLOCK_QUOTA
|
| 46 |
+
_jit_hicache_module(
|
| 47 |
+
element_size=element_size,
|
| 48 |
+
unroll=unroll,
|
| 49 |
+
block_quota=block_quota,
|
| 50 |
+
)
|
| 51 |
+
return True
|
| 52 |
+
except Exception as e:
|
| 53 |
+
logger = logging.getLogger(__name__)
|
| 54 |
+
logger.warning(f"Failed to load JIT HiCache kernel: {e}")
|
| 55 |
+
return False
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _default_unroll(element_size: int) -> int:
|
| 59 |
+
if element_size <= 512:
|
| 60 |
+
return 4
|
| 61 |
+
|
| 62 |
+
if element_size <= 1024:
|
| 63 |
+
return 2
|
| 64 |
+
|
| 65 |
+
# fallback: no unroll
|
| 66 |
+
return 1
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def transfer_hicache_one_layer(
|
| 70 |
+
k_cache_dst: torch.Tensor,
|
| 71 |
+
v_cache_dst: torch.Tensor,
|
| 72 |
+
indices_dst: torch.Tensor,
|
| 73 |
+
k_cache_src: torch.Tensor,
|
| 74 |
+
v_cache_src: torch.Tensor,
|
| 75 |
+
indices_src: torch.Tensor,
|
| 76 |
+
*,
|
| 77 |
+
element_dim: int | None = None,
|
| 78 |
+
unroll: int | None = None, # can be tuned for performance
|
| 79 |
+
block_quota: int | None = None, # can be tuned for less interference
|
| 80 |
+
) -> None:
|
| 81 |
+
element_dim = element_dim or k_cache_dst.size(-1)
|
| 82 |
+
k_cache_src = k_cache_src.view(-1, element_dim)
|
| 83 |
+
v_cache_src = v_cache_src.view(-1, element_dim)
|
| 84 |
+
k_cache_dst = k_cache_dst.view(-1, element_dim)
|
| 85 |
+
v_cache_dst = v_cache_dst.view(-1, element_dim)
|
| 86 |
+
element_size = element_dim * k_cache_dst.element_size()
|
| 87 |
+
block_quota = block_quota or DEFAULT_BLOCK_QUOTA
|
| 88 |
+
unroll = unroll or _default_unroll(element_size)
|
| 89 |
+
module = _jit_hicache_module(
|
| 90 |
+
element_size=element_size,
|
| 91 |
+
unroll=unroll,
|
| 92 |
+
block_quota=block_quota,
|
| 93 |
+
)
|
| 94 |
+
module.launch_one(
|
| 95 |
+
k_cache_dst,
|
| 96 |
+
v_cache_dst,
|
| 97 |
+
indices_dst,
|
| 98 |
+
k_cache_src,
|
| 99 |
+
v_cache_src,
|
| 100 |
+
indices_src,
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def transfer_hicache_all_layer(
|
| 105 |
+
k_ptr_dst: torch.Tensor,
|
| 106 |
+
v_ptr_dst: torch.Tensor,
|
| 107 |
+
indices_dst: torch.Tensor,
|
| 108 |
+
k_ptr_src: torch.Tensor,
|
| 109 |
+
v_ptr_src: torch.Tensor,
|
| 110 |
+
indices_src: torch.Tensor,
|
| 111 |
+
*,
|
| 112 |
+
kv_cache_src_stride_bytes: int,
|
| 113 |
+
kv_cache_dst_stride_bytes: int,
|
| 114 |
+
element_size: int | None = None,
|
| 115 |
+
unroll: int | None = None, # can be tuned for performance
|
| 116 |
+
block_quota: int | None = None, # can be tuned for less interference
|
| 117 |
+
) -> None:
|
| 118 |
+
if element_size is None: # assume both contiguous
|
| 119 |
+
assert kv_cache_dst_stride_bytes == kv_cache_src_stride_bytes
|
| 120 |
+
element_size = kv_cache_dst_stride_bytes
|
| 121 |
+
|
| 122 |
+
block_quota = block_quota or DEFAULT_BLOCK_QUOTA
|
| 123 |
+
unroll = unroll or _default_unroll(element_size)
|
| 124 |
+
module = _jit_hicache_module(
|
| 125 |
+
element_size=element_size,
|
| 126 |
+
unroll=unroll,
|
| 127 |
+
block_quota=block_quota,
|
| 128 |
+
)
|
| 129 |
+
module.launch_all(
|
| 130 |
+
k_ptr_dst,
|
| 131 |
+
v_ptr_dst,
|
| 132 |
+
indices_dst,
|
| 133 |
+
k_ptr_src,
|
| 134 |
+
v_ptr_src,
|
| 135 |
+
indices_src,
|
| 136 |
+
kv_cache_src_stride_bytes,
|
| 137 |
+
kv_cache_dst_stride_bytes,
|
| 138 |
+
)
|
sglang/jit_kernel/include/sgl_kernel/tensor.h
ADDED
|
@@ -0,0 +1,487 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <sgl_kernel/utils.h>
|
| 3 |
+
|
| 4 |
+
#include <dlpack/dlpack.h>
|
| 5 |
+
#include <tvm/ffi/container/tensor.h>
|
| 6 |
+
#include <tvm/ffi/dtype.h>
|
| 7 |
+
|
| 8 |
+
#include <algorithm>
|
| 9 |
+
#include <array>
|
| 10 |
+
#include <concepts>
|
| 11 |
+
#include <cstddef>
|
| 12 |
+
#include <cstdint>
|
| 13 |
+
#include <initializer_list>
|
| 14 |
+
#include <optional>
|
| 15 |
+
#include <ranges>
|
| 16 |
+
#include <source_location>
|
| 17 |
+
#include <span>
|
| 18 |
+
#include <sstream>
|
| 19 |
+
#include <string>
|
| 20 |
+
#include <string_view>
|
| 21 |
+
#include <type_traits>
|
| 22 |
+
#include <utility>
|
| 23 |
+
|
| 24 |
+
namespace host {
|
| 25 |
+
|
| 26 |
+
namespace stdr = std::ranges;
|
| 27 |
+
namespace stdv = std::views;
|
| 28 |
+
|
| 29 |
+
namespace details {
|
| 30 |
+
|
| 31 |
+
struct SizeRef;
|
| 32 |
+
struct DTypeRef;
|
| 33 |
+
struct DeviceRef;
|
| 34 |
+
|
| 35 |
+
template <typename T>
|
| 36 |
+
struct dtype_trait {};
|
| 37 |
+
|
| 38 |
+
template <std::integral T>
|
| 39 |
+
struct dtype_trait<T> {
|
| 40 |
+
inline static constexpr auto value = DLDataType{
|
| 41 |
+
.code = std::is_signed_v<T> ? DLDataTypeCode::kDLInt : DLDataTypeCode::kDLUInt,
|
| 42 |
+
.bits = static_cast<std::uint8_t>(sizeof(T) * 8),
|
| 43 |
+
.lanes = 1};
|
| 44 |
+
};
|
| 45 |
+
|
| 46 |
+
template <std::floating_point T>
|
| 47 |
+
struct dtype_trait<T> {
|
| 48 |
+
inline static constexpr auto value =
|
| 49 |
+
DLDataType{.code = DLDataTypeCode::kDLFloat, .bits = static_cast<std::uint8_t>(sizeof(T) * 8), .lanes = 1};
|
| 50 |
+
};
|
| 51 |
+
|
| 52 |
+
inline constexpr auto kAnyDeviceID = -1;
|
| 53 |
+
inline constexpr auto kAnySize = static_cast<int64_t>(-1);
|
| 54 |
+
inline constexpr auto kNullSize = static_cast<int64_t>(-1);
|
| 55 |
+
inline constexpr auto kNullDType = static_cast<DLDataTypeCode>(18u);
|
| 56 |
+
inline constexpr auto kNullDevice = static_cast<DLDeviceType>(-1);
|
| 57 |
+
|
| 58 |
+
template <typename... Ts>
|
| 59 |
+
inline constexpr auto kDTypeList = std::array<DLDataType, sizeof...(Ts)>{dtype_trait<Ts>::value...};
|
| 60 |
+
|
| 61 |
+
template <DLDeviceType... Codes>
|
| 62 |
+
inline constexpr auto kDeviceList = std::array<DLDevice, sizeof...(Codes)>{
|
| 63 |
+
DLDevice{.device_type = static_cast<DLDeviceType>(Codes), .device_id = kAnyDeviceID}...};
|
| 64 |
+
|
| 65 |
+
template <typename T>
|
| 66 |
+
struct PrintAbleSpan {
|
| 67 |
+
explicit PrintAbleSpan(std::span<const T> data) : data(data) {}
|
| 68 |
+
std::span<const T> data;
|
| 69 |
+
};
|
| 70 |
+
|
| 71 |
+
// define DLDataType comparison and printing in root namespace
|
| 72 |
+
inline constexpr auto kDeviceStringMap = [] {
|
| 73 |
+
constexpr auto map = std::array<std::pair<DLDeviceType, const char*>, 16>{
|
| 74 |
+
std::pair{DLDeviceType::kDLCPU, "cpu"},
|
| 75 |
+
std::pair{DLDeviceType::kDLCUDA, "cuda"},
|
| 76 |
+
std::pair{DLDeviceType::kDLCUDAHost, "cuda_host"},
|
| 77 |
+
std::pair{DLDeviceType::kDLOpenCL, "opencl"},
|
| 78 |
+
std::pair{DLDeviceType::kDLVulkan, "vulkan"},
|
| 79 |
+
std::pair{DLDeviceType::kDLMetal, "metal"},
|
| 80 |
+
std::pair{DLDeviceType::kDLVPI, "vpi"},
|
| 81 |
+
std::pair{DLDeviceType::kDLROCM, "rocm"},
|
| 82 |
+
std::pair{DLDeviceType::kDLROCMHost, "rocm_host"},
|
| 83 |
+
std::pair{DLDeviceType::kDLExtDev, "ext_dev"},
|
| 84 |
+
std::pair{DLDeviceType::kDLCUDAManaged, "cuda_managed"},
|
| 85 |
+
std::pair{DLDeviceType::kDLOneAPI, "oneapi"},
|
| 86 |
+
std::pair{DLDeviceType::kDLWebGPU, "webgpu"},
|
| 87 |
+
std::pair{DLDeviceType::kDLHexagon, "hexagon"},
|
| 88 |
+
std::pair{DLDeviceType::kDLMAIA, "maia"},
|
| 89 |
+
std::pair{DLDeviceType::kDLTrn, "trn"},
|
| 90 |
+
};
|
| 91 |
+
constexpr auto max_type = stdr::max(map | stdv::keys);
|
| 92 |
+
auto result = std::array<std::string_view, max_type + 1>{};
|
| 93 |
+
for (const auto& [code, name] : map) {
|
| 94 |
+
result[static_cast<std::size_t>(code)] = name;
|
| 95 |
+
}
|
| 96 |
+
return result;
|
| 97 |
+
}();
|
| 98 |
+
|
| 99 |
+
struct PrintableDevice {
|
| 100 |
+
DLDevice device;
|
| 101 |
+
};
|
| 102 |
+
|
| 103 |
+
inline auto& operator<<(std::ostream& os, DLDevice device) {
|
| 104 |
+
const auto& mapping = kDeviceStringMap;
|
| 105 |
+
const auto entry = static_cast<std::size_t>(device.device_type);
|
| 106 |
+
host::RuntimeCheck(entry < mapping.size());
|
| 107 |
+
const auto name = mapping[entry];
|
| 108 |
+
host::RuntimeCheck(!name.empty(), "Unknown device: ", int(device.device_type));
|
| 109 |
+
os << name;
|
| 110 |
+
if (device.device_id != kAnyDeviceID) os << "[" << device.device_id << "]";
|
| 111 |
+
return os;
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
inline auto& operator<<(std::ostream& os, PrintableDevice pd) {
|
| 115 |
+
return os << pd.device;
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
template <typename T>
|
| 119 |
+
inline auto& operator<<(std::ostream& os, PrintAbleSpan<T> span) {
|
| 120 |
+
os << "[";
|
| 121 |
+
for (const auto i : stdv::iota(std::size_t{0}, span.data.size())) {
|
| 122 |
+
if (i > 0) {
|
| 123 |
+
os << ", ";
|
| 124 |
+
}
|
| 125 |
+
os << span.data[i];
|
| 126 |
+
}
|
| 127 |
+
os << "]";
|
| 128 |
+
return os;
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
} // namespace details
|
| 132 |
+
|
| 133 |
+
struct SymbolicSize {
|
| 134 |
+
public:
|
| 135 |
+
SymbolicSize(std::string_view annotation = {}) : m_value(details::kNullSize), m_annotation(annotation) {}
|
| 136 |
+
|
| 137 |
+
auto get_name() const -> std::string_view {
|
| 138 |
+
return m_annotation;
|
| 139 |
+
}
|
| 140 |
+
auto set_value(int64_t value) -> void {
|
| 141 |
+
host::RuntimeCheck(!this->has_value(), "Size value already set");
|
| 142 |
+
m_value = value;
|
| 143 |
+
}
|
| 144 |
+
auto has_value() const -> bool {
|
| 145 |
+
return m_value != details::kNullSize;
|
| 146 |
+
}
|
| 147 |
+
auto get_value() const -> std::optional<int64_t> {
|
| 148 |
+
return this->has_value() ? std::optional{m_value} : std::nullopt;
|
| 149 |
+
}
|
| 150 |
+
auto unwrap() const -> int64_t {
|
| 151 |
+
host::RuntimeCheck(this->has_value(), "Size value is not set");
|
| 152 |
+
return m_value;
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
SymbolicSize(const SymbolicSize&) = delete;
|
| 156 |
+
SymbolicSize& operator=(const SymbolicSize&) = delete;
|
| 157 |
+
|
| 158 |
+
auto verify(int64_t dim) -> void {
|
| 159 |
+
if (this->has_value()) {
|
| 160 |
+
host::RuntimeCheck(m_value == dim, "Size mismatch: expected ", m_value, " but got ", dim);
|
| 161 |
+
} else {
|
| 162 |
+
this->set_value(dim);
|
| 163 |
+
}
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
private:
|
| 167 |
+
std::int64_t m_value;
|
| 168 |
+
std::string_view m_annotation;
|
| 169 |
+
};
|
| 170 |
+
|
| 171 |
+
inline auto operator==(DLDevice lhs, DLDevice rhs) -> bool {
|
| 172 |
+
return lhs.device_type == rhs.device_type && lhs.device_id == rhs.device_id;
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
struct SymbolicDType {
|
| 176 |
+
public:
|
| 177 |
+
SymbolicDType() : m_value({details::kNullDType, 0, 0}) {}
|
| 178 |
+
|
| 179 |
+
auto set_value(DLDataType value) -> void {
|
| 180 |
+
host::RuntimeCheck(!this->has_value(), "Dtype value already set");
|
| 181 |
+
host::RuntimeCheck(
|
| 182 |
+
m_check(value), "Dtype value [", value, "] not in the allowed options: ", details::PrintAbleSpan{m_options});
|
| 183 |
+
m_value = value;
|
| 184 |
+
}
|
| 185 |
+
auto has_value() const -> bool {
|
| 186 |
+
return m_value.code != details::kNullDType;
|
| 187 |
+
}
|
| 188 |
+
auto get_value() const -> std::optional<DLDataType> {
|
| 189 |
+
return this->has_value() ? std::optional{m_value} : std::nullopt;
|
| 190 |
+
}
|
| 191 |
+
auto unwrap() const -> DLDataType {
|
| 192 |
+
host::RuntimeCheck(this->has_value(), "Dtype value is not set");
|
| 193 |
+
return m_value;
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
auto set_options(std::span<const DLDataType> options) -> void {
|
| 197 |
+
m_options = options;
|
| 198 |
+
}
|
| 199 |
+
template <typename... Ts>
|
| 200 |
+
auto set_options() -> void {
|
| 201 |
+
m_options = details::kDTypeList<Ts...>;
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
auto verify(DLDataType dtype) -> void {
|
| 205 |
+
if (this->has_value()) {
|
| 206 |
+
host::RuntimeCheck(m_value == dtype, "DType mismatch: expected ", m_value, " but got ", dtype);
|
| 207 |
+
} else {
|
| 208 |
+
this->set_value(dtype);
|
| 209 |
+
}
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
private:
|
| 213 |
+
auto m_check(DLDataType value) const -> bool {
|
| 214 |
+
return stdr::empty(m_options) || (stdr::find(m_options, value) != stdr::end(m_options));
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
std::span<const DLDataType> m_options;
|
| 218 |
+
DLDataType m_value;
|
| 219 |
+
};
|
| 220 |
+
|
| 221 |
+
struct SymbolicDevice {
|
| 222 |
+
public:
|
| 223 |
+
SymbolicDevice() : m_value({details::kNullDevice, details::kAnyDeviceID}) {}
|
| 224 |
+
|
| 225 |
+
auto set_value(DLDevice value) -> void {
|
| 226 |
+
host::RuntimeCheck(!this->has_value(), "Device value already set");
|
| 227 |
+
host::RuntimeCheck(
|
| 228 |
+
m_check(value),
|
| 229 |
+
"Device value [",
|
| 230 |
+
details::PrintableDevice{value},
|
| 231 |
+
"] not in the allowed options: ",
|
| 232 |
+
details::PrintAbleSpan{m_options});
|
| 233 |
+
m_value = value;
|
| 234 |
+
}
|
| 235 |
+
auto has_value() const -> bool {
|
| 236 |
+
return m_value.device_type != details::kNullDevice;
|
| 237 |
+
}
|
| 238 |
+
auto get_value() const -> std::optional<DLDevice> {
|
| 239 |
+
return this->has_value() ? std::optional{m_value} : std::nullopt;
|
| 240 |
+
}
|
| 241 |
+
auto unwrap() const -> DLDevice {
|
| 242 |
+
host::RuntimeCheck(this->has_value(), "Device value is not set");
|
| 243 |
+
return m_value;
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
auto set_options(std::span<const DLDevice> options) -> void {
|
| 247 |
+
m_options = options;
|
| 248 |
+
}
|
| 249 |
+
template <DLDeviceType... Codes>
|
| 250 |
+
auto set_options() -> void {
|
| 251 |
+
m_options = details::kDeviceList<Codes...>;
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
auto verify(DLDevice device) -> void {
|
| 255 |
+
if (this->has_value()) {
|
| 256 |
+
host::RuntimeCheck(
|
| 257 |
+
m_value == device,
|
| 258 |
+
"Device mismatch: expected ",
|
| 259 |
+
details::PrintableDevice{m_value},
|
| 260 |
+
" but got ",
|
| 261 |
+
details::PrintableDevice{device});
|
| 262 |
+
} else {
|
| 263 |
+
this->set_value(device);
|
| 264 |
+
}
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
private:
|
| 268 |
+
auto m_check(DLDevice value) const -> bool {
|
| 269 |
+
return stdr::empty(m_options) || (stdr::any_of(m_options, [value](const DLDevice& opt) {
|
| 270 |
+
// device type must exactly match
|
| 271 |
+
if (opt.device_type != value.device_type) return false;
|
| 272 |
+
// device id can be wildcarded
|
| 273 |
+
return opt.device_id == details::kAnyDeviceID || opt.device_id == value.device_id;
|
| 274 |
+
}));
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
std::span<const DLDevice> m_options;
|
| 278 |
+
DLDevice m_value;
|
| 279 |
+
};
|
| 280 |
+
|
| 281 |
+
namespace details {
|
| 282 |
+
|
| 283 |
+
template <typename T>
|
| 284 |
+
struct BaseRef {
|
| 285 |
+
public:
|
| 286 |
+
BaseRef(const BaseRef&) = delete;
|
| 287 |
+
BaseRef& operator=(const BaseRef&) = delete;
|
| 288 |
+
|
| 289 |
+
auto operator->() const -> T* {
|
| 290 |
+
return m_ref;
|
| 291 |
+
}
|
| 292 |
+
auto operator*() const -> T& {
|
| 293 |
+
return *m_ref;
|
| 294 |
+
}
|
| 295 |
+
auto rebind(T& other) -> void {
|
| 296 |
+
m_ref = &other;
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
explicit BaseRef() : m_ref(&m_cache), m_cache() {}
|
| 300 |
+
BaseRef(T& size) : m_ref(&size), m_cache() {}
|
| 301 |
+
|
| 302 |
+
private:
|
| 303 |
+
T* m_ref;
|
| 304 |
+
T m_cache;
|
| 305 |
+
};
|
| 306 |
+
|
| 307 |
+
struct SizeRef : BaseRef<SymbolicSize> {
|
| 308 |
+
using BaseRef::BaseRef;
|
| 309 |
+
SizeRef(int64_t value) {
|
| 310 |
+
if (value != kAnySize) {
|
| 311 |
+
(**this).set_value(value);
|
| 312 |
+
} else {
|
| 313 |
+
// otherwise, we can match any size
|
| 314 |
+
}
|
| 315 |
+
}
|
| 316 |
+
|
| 317 |
+
auto value_or_name(std::size_t dim) const -> std::string {
|
| 318 |
+
if (const auto value = (**this).get_value()) {
|
| 319 |
+
return std::to_string(*value);
|
| 320 |
+
} else {
|
| 321 |
+
const auto annotation = (**this).get_name();
|
| 322 |
+
if (annotation.empty()) {
|
| 323 |
+
return "dim#" + std::to_string(dim);
|
| 324 |
+
} else {
|
| 325 |
+
return static_cast<std::string>(annotation);
|
| 326 |
+
}
|
| 327 |
+
}
|
| 328 |
+
}
|
| 329 |
+
};
|
| 330 |
+
|
| 331 |
+
struct DTypeRef : BaseRef<SymbolicDType> {
|
| 332 |
+
using BaseRef::BaseRef;
|
| 333 |
+
DTypeRef(DLDataType options) {
|
| 334 |
+
(**this).set_value(options);
|
| 335 |
+
}
|
| 336 |
+
DTypeRef(std::initializer_list<DLDataType> options) {
|
| 337 |
+
(**this).set_options(options);
|
| 338 |
+
}
|
| 339 |
+
DTypeRef(std::span<const DLDataType> options) {
|
| 340 |
+
(**this).set_options(options);
|
| 341 |
+
}
|
| 342 |
+
};
|
| 343 |
+
|
| 344 |
+
struct DeviceRef : BaseRef<SymbolicDevice> {
|
| 345 |
+
using BaseRef::BaseRef;
|
| 346 |
+
DeviceRef(DLDevice options) {
|
| 347 |
+
(**this).set_value(options);
|
| 348 |
+
}
|
| 349 |
+
DeviceRef(std::initializer_list<DLDevice> options) {
|
| 350 |
+
(**this).set_options(options);
|
| 351 |
+
}
|
| 352 |
+
DeviceRef(std::span<const DLDevice> options) {
|
| 353 |
+
(**this).set_options(options);
|
| 354 |
+
}
|
| 355 |
+
};
|
| 356 |
+
|
| 357 |
+
} // namespace details
|
| 358 |
+
|
| 359 |
+
struct TensorMatcher {
|
| 360 |
+
private:
|
| 361 |
+
using SizeRef = details::SizeRef;
|
| 362 |
+
using DTypeRef = details::DTypeRef;
|
| 363 |
+
using DeviceRef = details::DeviceRef;
|
| 364 |
+
using Loc_t = std::source_location;
|
| 365 |
+
|
| 366 |
+
public:
|
| 367 |
+
TensorMatcher(const TensorMatcher&) = delete;
|
| 368 |
+
TensorMatcher& operator=(const TensorMatcher&) = delete;
|
| 369 |
+
|
| 370 |
+
explicit TensorMatcher(std::initializer_list<SizeRef> shape) : m_shape(shape), m_strides(), m_dtype() {}
|
| 371 |
+
|
| 372 |
+
auto with_strides(std::initializer_list<SizeRef> strides) && -> TensorMatcher&& {
|
| 373 |
+
// no partial update allowed
|
| 374 |
+
host::RuntimeCheck(m_strides.size() == 0, "Strides already specified");
|
| 375 |
+
host::RuntimeCheck(m_shape.size() == strides.size(), "Strides size must match shape size");
|
| 376 |
+
m_strides = strides;
|
| 377 |
+
return std::move(*this);
|
| 378 |
+
}
|
| 379 |
+
|
| 380 |
+
template <typename... Ts>
|
| 381 |
+
auto with_dtype(DTypeRef&& dtype) && -> TensorMatcher&& {
|
| 382 |
+
m_init_dtype();
|
| 383 |
+
m_dtype.rebind(*dtype);
|
| 384 |
+
return std::move(*this);
|
| 385 |
+
}
|
| 386 |
+
|
| 387 |
+
template <typename... Ts>
|
| 388 |
+
auto with_dtype() && -> TensorMatcher&& {
|
| 389 |
+
static_assert(sizeof...(Ts) > 0, "At least one dtype option must be specified");
|
| 390 |
+
m_init_dtype();
|
| 391 |
+
m_dtype->set_options<Ts...>();
|
| 392 |
+
return std::move(*this);
|
| 393 |
+
}
|
| 394 |
+
|
| 395 |
+
template <DLDeviceType... Codes>
|
| 396 |
+
auto with_device(DeviceRef&& device) && -> TensorMatcher&& {
|
| 397 |
+
m_init_device();
|
| 398 |
+
m_device.rebind(*device);
|
| 399 |
+
return std::move(*this);
|
| 400 |
+
}
|
| 401 |
+
|
| 402 |
+
template <DLDeviceType... Codes>
|
| 403 |
+
auto with_device() && -> TensorMatcher&& {
|
| 404 |
+
static_assert(sizeof...(Codes) > 0, "At least one device option must be specified");
|
| 405 |
+
m_init_device();
|
| 406 |
+
m_device->set_options<Codes...>();
|
| 407 |
+
return std::move(*this);
|
| 408 |
+
}
|
| 409 |
+
|
| 410 |
+
// once we start verification, we cannot modify anymore
|
| 411 |
+
auto verify(tvm::ffi::TensorView view, Loc_t loc = Loc_t::current()) const&& -> const TensorMatcher&& {
|
| 412 |
+
try {
|
| 413 |
+
this->m_verify_impl(view);
|
| 414 |
+
} catch (PanicError& e) {
|
| 415 |
+
auto oss = std::ostringstream{};
|
| 416 |
+
oss << "Tensor match failed for " << this->debug_str() << " at " << loc.file_name() << ":" << loc.line()
|
| 417 |
+
<< "\n- Root cause: " << e.detail();
|
| 418 |
+
throw PanicError(std::move(oss).str());
|
| 419 |
+
}
|
| 420 |
+
return std::move(*this);
|
| 421 |
+
}
|
| 422 |
+
|
| 423 |
+
auto debug_str() const -> std::string {
|
| 424 |
+
auto oss = std::ostringstream{};
|
| 425 |
+
oss << "Tensor<";
|
| 426 |
+
std::size_t dim = 0;
|
| 427 |
+
for (const auto& size_ref : m_shape) {
|
| 428 |
+
if (dim > 0) {
|
| 429 |
+
oss << ", ";
|
| 430 |
+
}
|
| 431 |
+
oss << size_ref.value_or_name(dim++);
|
| 432 |
+
}
|
| 433 |
+
oss << ">";
|
| 434 |
+
if (m_strides.size() > 0) {
|
| 435 |
+
oss << " [strides=<";
|
| 436 |
+
dim = 0;
|
| 437 |
+
for (const auto& stride_ref : m_strides) {
|
| 438 |
+
if (dim > 0) {
|
| 439 |
+
oss << ", ";
|
| 440 |
+
}
|
| 441 |
+
oss << stride_ref.value_or_name(dim++);
|
| 442 |
+
}
|
| 443 |
+
oss << ">]";
|
| 444 |
+
}
|
| 445 |
+
return std::move(oss).str();
|
| 446 |
+
}
|
| 447 |
+
|
| 448 |
+
private:
|
| 449 |
+
auto m_verify_impl(tvm::ffi::TensorView view) const -> void {
|
| 450 |
+
const auto dim = static_cast<std::size_t>(view.dim());
|
| 451 |
+
host::RuntimeCheck(dim == m_shape.size(), "Tensor dimension mismatch: expected ", m_shape.size(), " but got ", dim);
|
| 452 |
+
for (const auto i : stdv::iota(std::size_t{0}, dim)) {
|
| 453 |
+
m_shape[i]->verify(view.size(i));
|
| 454 |
+
}
|
| 455 |
+
if (this->m_has_strides()) {
|
| 456 |
+
for (const auto i : stdv::iota(std::size_t{0}, dim)) {
|
| 457 |
+
m_strides[i]->verify(view.stride(i));
|
| 458 |
+
}
|
| 459 |
+
} else {
|
| 460 |
+
host::RuntimeCheck(view.is_contiguous(), "Tensor is not contiguous as expected");
|
| 461 |
+
}
|
| 462 |
+
// since we may use the same matcher to verify again, we will force to check
|
| 463 |
+
m_dtype->verify(view.dtype());
|
| 464 |
+
m_device->verify(view.device());
|
| 465 |
+
}
|
| 466 |
+
|
| 467 |
+
auto m_init_dtype() -> void {
|
| 468 |
+
host::RuntimeCheck(!m_has_dtype, "DType already specified");
|
| 469 |
+
m_has_dtype = true;
|
| 470 |
+
}
|
| 471 |
+
auto m_init_device() -> void {
|
| 472 |
+
host::RuntimeCheck(!m_has_device, "Device already specified");
|
| 473 |
+
m_has_device = true;
|
| 474 |
+
}
|
| 475 |
+
auto m_has_strides() const -> bool {
|
| 476 |
+
return !m_strides.empty();
|
| 477 |
+
}
|
| 478 |
+
|
| 479 |
+
std::span<const SizeRef> m_shape;
|
| 480 |
+
std::span<const SizeRef> m_strides;
|
| 481 |
+
DTypeRef m_dtype;
|
| 482 |
+
DeviceRef m_device;
|
| 483 |
+
bool m_has_dtype = false;
|
| 484 |
+
bool m_has_device = false;
|
| 485 |
+
};
|
| 486 |
+
|
| 487 |
+
} // namespace host
|
sglang/jit_kernel/include/sgl_kernel/utils.cuh
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <sgl_kernel/utils.h>
|
| 4 |
+
|
| 5 |
+
#include <dlpack/dlpack.h>
|
| 6 |
+
#include <tvm/ffi/extra/c_env_api.h>
|
| 7 |
+
|
| 8 |
+
#include <concepts>
|
| 9 |
+
#include <cstddef>
|
| 10 |
+
#include <source_location>
|
| 11 |
+
#include <type_traits>
|
| 12 |
+
|
| 13 |
+
namespace device {
|
| 14 |
+
|
| 15 |
+
inline constexpr auto kWarpThreads = 32u;
|
| 16 |
+
|
| 17 |
+
namespace pointer {
|
| 18 |
+
|
| 19 |
+
// we only allow void * pointer arithmetic for safety
|
| 20 |
+
|
| 21 |
+
template <typename T, std::integral... U>
|
| 22 |
+
__always_inline __device__ auto offset(T* ptr, U... offset) -> void* {
|
| 23 |
+
static_assert(std::is_same_v<T, void>, "Pointer arithmetic is only allowed for void* pointers");
|
| 24 |
+
return static_cast<char*>(ptr) + (... + offset);
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
template <typename T, std::integral... U>
|
| 28 |
+
__always_inline __device__ auto offset(const T* ptr, U... offset) -> const void* {
|
| 29 |
+
static_assert(std::is_same_v<T, void>, "Pointer arithmetic is only allowed for void* pointers");
|
| 30 |
+
return static_cast<const char*>(ptr) + (... + offset);
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
} // namespace pointer
|
| 34 |
+
|
| 35 |
+
} // namespace device
|
| 36 |
+
|
| 37 |
+
namespace host {
|
| 38 |
+
|
| 39 |
+
inline auto
|
| 40 |
+
RuntimeDeviceCheck(::cudaError_t error, std::source_location location = std::source_location::current()) -> void {
|
| 41 |
+
if (error != ::cudaSuccess) {
|
| 42 |
+
[[unlikely]];
|
| 43 |
+
::host::panic(location, "CUDA error: ", ::cudaGetErrorString(error));
|
| 44 |
+
}
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
inline auto RuntimeCudaCheck(std::source_location location = std::source_location::current()) -> void {
|
| 48 |
+
return RuntimeDeviceCheck(::cudaGetLastError(), location);
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
template <auto F>
|
| 52 |
+
inline void set_smem_once(std::size_t smem_size) {
|
| 53 |
+
static const auto last_smem_size = [&] {
|
| 54 |
+
RuntimeDeviceCheck(::cudaFuncSetAttribute(F, ::cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
| 55 |
+
return smem_size;
|
| 56 |
+
}();
|
| 57 |
+
RuntimeCheck(
|
| 58 |
+
smem_size <= last_smem_size,
|
| 59 |
+
"Dynamic shared memory size exceeds the previously set maximum size: ",
|
| 60 |
+
last_smem_size,
|
| 61 |
+
" bytes");
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
struct LaunchKernel {
|
| 65 |
+
public:
|
| 66 |
+
explicit LaunchKernel(
|
| 67 |
+
dim3 grid_dim, dim3 block_dim, DLDevice device, std::size_t dynamic_shared_mem_bytes = 0) noexcept
|
| 68 |
+
: m_config(s_make_config(grid_dim, block_dim, resolve_device(device), dynamic_shared_mem_bytes)) {}
|
| 69 |
+
|
| 70 |
+
explicit LaunchKernel(
|
| 71 |
+
dim3 grid_dim, dim3 block_dim, cudaStream_t stream, std::size_t dynamic_shared_mem_bytes = 0) noexcept
|
| 72 |
+
: m_config(s_make_config(grid_dim, block_dim, stream, dynamic_shared_mem_bytes)) {}
|
| 73 |
+
|
| 74 |
+
static auto resolve_device(DLDevice device) -> cudaStream_t {
|
| 75 |
+
return static_cast<cudaStream_t>(::TVMFFIEnvGetStream(device.device_type, device.device_id));
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
LaunchKernel(const LaunchKernel&) = delete;
|
| 79 |
+
LaunchKernel& operator=(const LaunchKernel&) = delete;
|
| 80 |
+
|
| 81 |
+
template <typename T, typename... Args>
|
| 82 |
+
auto operator()(T&& kernel, Args&&... args) const -> void {
|
| 83 |
+
host::RuntimeDeviceCheck(::cudaLaunchKernelEx(&m_config, kernel, std::forward<Args>(args)...));
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
private:
|
| 87 |
+
static auto
|
| 88 |
+
s_make_config(dim3 grid_dim, dim3 block_dim, cudaStream_t stream, std::size_t smem) -> cudaLaunchConfig_t {
|
| 89 |
+
auto config = ::cudaLaunchConfig_t{};
|
| 90 |
+
config.gridDim = grid_dim;
|
| 91 |
+
config.blockDim = block_dim;
|
| 92 |
+
config.dynamicSmemBytes = smem;
|
| 93 |
+
config.stream = stream;
|
| 94 |
+
config.numAttrs = 0;
|
| 95 |
+
return config;
|
| 96 |
+
}
|
| 97 |
+
cudaLaunchConfig_t m_config;
|
| 98 |
+
/// TODO: We can add a queue to store the attributes if needed in the future.
|
| 99 |
+
};
|
| 100 |
+
|
| 101 |
+
} // namespace host
|
sglang/jit_kernel/include/sgl_kernel/utils.h
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <dlpack/dlpack.h>
|
| 4 |
+
|
| 5 |
+
#include <concepts>
|
| 6 |
+
#include <ostream>
|
| 7 |
+
#include <source_location>
|
| 8 |
+
#include <sstream>
|
| 9 |
+
#include <utility>
|
| 10 |
+
|
| 11 |
+
namespace host {
|
| 12 |
+
|
| 13 |
+
struct PanicError : public std::runtime_error {
|
| 14 |
+
public:
|
| 15 |
+
// copy and move constructors
|
| 16 |
+
explicit PanicError(std::string msg) : runtime_error(msg), m_message(std::move(msg)) {}
|
| 17 |
+
auto detail() const -> std::string_view {
|
| 18 |
+
const auto sv = std::string_view{m_message};
|
| 19 |
+
const auto pos = sv.find(": ");
|
| 20 |
+
return pos == std::string_view::npos ? sv : sv.substr(pos + 2);
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
private:
|
| 24 |
+
std::string m_message;
|
| 25 |
+
};
|
| 26 |
+
|
| 27 |
+
template <typename... Args>
|
| 28 |
+
[[noreturn]]
|
| 29 |
+
inline auto panic(std::source_location location, Args&&... args) -> void {
|
| 30 |
+
std::ostringstream os;
|
| 31 |
+
os << "Runtime check failed at " << location.file_name() << ":" << location.line();
|
| 32 |
+
if constexpr (sizeof...(args) > 0) {
|
| 33 |
+
os << ": ";
|
| 34 |
+
(os << ... << std::forward<Args>(args));
|
| 35 |
+
} else {
|
| 36 |
+
os << " in " << location.function_name();
|
| 37 |
+
}
|
| 38 |
+
throw PanicError(std::move(os).str());
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
template <typename... Args>
|
| 42 |
+
struct RuntimeCheck {
|
| 43 |
+
using Loc_t = std::source_location;
|
| 44 |
+
template <typename Cond>
|
| 45 |
+
explicit RuntimeCheck(Cond&& condition, Args&&... args, Loc_t location = Loc_t::current()) {
|
| 46 |
+
if (!condition) {
|
| 47 |
+
[[unlikely]];
|
| 48 |
+
::host::panic(location, std::forward<Args>(args)...);
|
| 49 |
+
}
|
| 50 |
+
}
|
| 51 |
+
};
|
| 52 |
+
|
| 53 |
+
template <typename Cond, typename... Args>
|
| 54 |
+
explicit RuntimeCheck(Cond&&, Args&&...) -> RuntimeCheck<Args...>;
|
| 55 |
+
|
| 56 |
+
template <std::signed_integral T, std::signed_integral U>
|
| 57 |
+
inline constexpr auto div_ceil(T a, U b) {
|
| 58 |
+
return (a + b - 1) / b;
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
template <std::unsigned_integral T, std::unsigned_integral U>
|
| 62 |
+
inline constexpr auto div_ceil(T a, U b) {
|
| 63 |
+
return (a + b - 1) / b;
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
inline auto dtype_bytes(DLDataType dtype) -> std::size_t {
|
| 67 |
+
return static_cast<std::size_t>(dtype.bits / 8);
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
namespace pointer {
|
| 71 |
+
|
| 72 |
+
// we only allow void * pointer arithmetic for safety
|
| 73 |
+
|
| 74 |
+
template <typename T, std::integral... U>
|
| 75 |
+
inline auto offset(T* ptr, U... offset) -> void* {
|
| 76 |
+
static_assert(std::is_same_v<T, void>, "Pointer arithmetic is only allowed for void* pointers");
|
| 77 |
+
return static_cast<char*>(ptr) + (... + offset);
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
template <typename T, std::integral... U>
|
| 81 |
+
inline auto offset(const T* ptr, U... offset) -> const void* {
|
| 82 |
+
static_assert(std::is_same_v<T, void>, "Pointer arithmetic is only allowed for void* pointers");
|
| 83 |
+
return static_cast<const char*>(ptr) + (... + offset);
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
} // namespace pointer
|
| 87 |
+
|
| 88 |
+
} // namespace host
|
sglang/jit_kernel/include/sgl_kernel/warp.cuh
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <sgl_kernel/utils.cuh>
|
| 3 |
+
|
| 4 |
+
#include <cstddef>
|
| 5 |
+
#include <cstdint>
|
| 6 |
+
#include <type_traits>
|
| 7 |
+
|
| 8 |
+
namespace device::warp {
|
| 9 |
+
|
| 10 |
+
namespace details {
|
| 11 |
+
|
| 12 |
+
template <std::size_t kUnit>
|
| 13 |
+
inline constexpr auto get_mem_package() {
|
| 14 |
+
if constexpr (kUnit == 16) {
|
| 15 |
+
return uint4{};
|
| 16 |
+
} else if constexpr (kUnit == 8) {
|
| 17 |
+
return uint2{};
|
| 18 |
+
} else if constexpr (kUnit == 4) {
|
| 19 |
+
return uint1{};
|
| 20 |
+
} else {
|
| 21 |
+
static_assert(kUnit == 16 || kUnit == 8 || kUnit == 4, "Unsupported memory package size");
|
| 22 |
+
}
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
inline constexpr auto default_unit_size(std::size_t x) -> std::size_t {
|
| 26 |
+
if (x % (16 * kWarpThreads) == 0) return 16;
|
| 27 |
+
if (x % (8 * kWarpThreads) == 0) return 8;
|
| 28 |
+
if (x % (4 * kWarpThreads) == 0) return 4;
|
| 29 |
+
return 0; // trigger static assert in _get_mem_package
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
template <std::size_t kBytes, std::size_t kUnit>
|
| 33 |
+
using mem_package_t = decltype(get_mem_package<kUnit>());
|
| 34 |
+
|
| 35 |
+
template <typename T, std::size_t N>
|
| 36 |
+
struct storage_vec {
|
| 37 |
+
T data[N];
|
| 38 |
+
};
|
| 39 |
+
|
| 40 |
+
__always_inline __device__ auto load_nc(const uint1* __restrict__ src) -> uint1 {
|
| 41 |
+
uint32_t tmp;
|
| 42 |
+
asm volatile("ld.global.cs.b32 %0,[%1];" : "=r"(tmp) : "l"(src));
|
| 43 |
+
return uint1{tmp};
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
__always_inline __device__ auto load_nc(const uint2* __restrict__ src) -> uint2 {
|
| 47 |
+
uint32_t tmp0, tmp1;
|
| 48 |
+
asm volatile("ld.global.cs.v2.b32 {%0,%1},[%2];" : "=r"(tmp0), "=r"(tmp1) : "l"(src));
|
| 49 |
+
return uint2{tmp0, tmp1};
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
__always_inline __device__ auto load_nc(const uint4* __restrict__ src) -> uint4 {
|
| 53 |
+
uint32_t tmp0, tmp1, tmp2, tmp3;
|
| 54 |
+
asm volatile("ld.global.cs.v4.b32 {%0,%1,%2,%3},[%4];" : "=r"(tmp0), "=r"(tmp1), "=r"(tmp2), "=r"(tmp3) : "l"(src));
|
| 55 |
+
return uint4{tmp0, tmp1, tmp2, tmp3};
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
__always_inline __device__ void store_nc(uint1* __restrict__ dst, const uint1& value) {
|
| 59 |
+
uint32_t tmp = value.x;
|
| 60 |
+
asm volatile("st.global.cs.b32 [%0],%1;" ::"l"(dst), "r"(tmp));
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
__always_inline __device__ void store_nc(uint2* __restrict__ dst, const uint2& value) {
|
| 64 |
+
uint32_t tmp0 = value.x;
|
| 65 |
+
uint32_t tmp1 = value.y;
|
| 66 |
+
asm volatile("st.global.cs.v2.b32 [%0],{%1,%2};" ::"l"(dst), "r"(tmp0), "r"(tmp1));
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
__always_inline __device__ void store_nc(uint4* __restrict__ dst, const uint4& value) {
|
| 70 |
+
uint32_t tmp0 = value.x;
|
| 71 |
+
uint32_t tmp1 = value.y;
|
| 72 |
+
uint32_t tmp2 = value.z;
|
| 73 |
+
uint32_t tmp3 = value.w;
|
| 74 |
+
asm volatile("st.global.cs.v4.b32 [%0],{%1,%2,%3,%4};" ::"l"(dst), "r"(tmp0), "r"(tmp1), "r"(tmp2), "r"(tmp3));
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
} // namespace details
|
| 78 |
+
|
| 79 |
+
template <
|
| 80 |
+
std::size_t kBytes,
|
| 81 |
+
std::size_t kUnit = details::default_unit_size(kBytes),
|
| 82 |
+
std::size_t kThreads = ::device::kWarpThreads>
|
| 83 |
+
__always_inline __device__ void copy(void* __restrict__ dst, const void* __restrict__ src) {
|
| 84 |
+
using Package = details::mem_package_t<kBytes, kUnit>;
|
| 85 |
+
constexpr auto kBytesPerLoop = sizeof(Package) * kThreads;
|
| 86 |
+
constexpr auto kLoopCount = kBytes / kBytesPerLoop;
|
| 87 |
+
static_assert(kBytes % kBytesPerLoop == 0, "kBytes must be multiple of 128 bytes");
|
| 88 |
+
|
| 89 |
+
const auto dst_packed = static_cast<Package*>(dst);
|
| 90 |
+
const auto src_packed = static_cast<const Package*>(src);
|
| 91 |
+
const auto lane_id = threadIdx.x % kThreads;
|
| 92 |
+
|
| 93 |
+
#pragma unroll kLoopCount
|
| 94 |
+
for (std::size_t i = 0; i < kLoopCount; ++i) {
|
| 95 |
+
const auto j = i * kThreads + lane_id;
|
| 96 |
+
dst_packed[j] = src_packed[j];
|
| 97 |
+
}
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
template <
|
| 101 |
+
std::size_t kBytes,
|
| 102 |
+
std::size_t kUnit = details::default_unit_size(kBytes),
|
| 103 |
+
std::size_t kThreads = ::device::kWarpThreads>
|
| 104 |
+
__always_inline __device__ auto load_vec(const void* __restrict__ src) {
|
| 105 |
+
using Package = details::mem_package_t<kBytes, kUnit>;
|
| 106 |
+
constexpr auto kBytesPerLoop = sizeof(Package) * kThreads;
|
| 107 |
+
constexpr auto kLoopCount = kBytes / kBytesPerLoop;
|
| 108 |
+
static_assert(kBytes % kBytesPerLoop == 0, "kBytes must be multiple of 128 bytes");
|
| 109 |
+
|
| 110 |
+
const auto src_packed = static_cast<const Package*>(src);
|
| 111 |
+
const auto lane_id = threadIdx.x % kThreads;
|
| 112 |
+
details::storage_vec<Package, kLoopCount> vec;
|
| 113 |
+
|
| 114 |
+
#pragma unroll kLoopCount
|
| 115 |
+
for (std::size_t i = 0; i < kLoopCount; ++i) {
|
| 116 |
+
const auto j = i * kThreads + lane_id;
|
| 117 |
+
vec.data[i] = details::load_nc(src_packed + j);
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
return vec;
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
template <
|
| 124 |
+
std::size_t kBytes,
|
| 125 |
+
std::size_t kUnit = details::default_unit_size(kBytes),
|
| 126 |
+
std::size_t kThreads = ::device::kWarpThreads,
|
| 127 |
+
typename Tp>
|
| 128 |
+
__always_inline __device__ void store_vec(void* __restrict__ dst, const Tp& vec) {
|
| 129 |
+
using Package = details::mem_package_t<kBytes, kUnit>;
|
| 130 |
+
constexpr auto kBytesPerLoop = sizeof(Package) * kThreads;
|
| 131 |
+
constexpr auto kLoopCount = kBytes / kBytesPerLoop;
|
| 132 |
+
static_assert(kBytes % kBytesPerLoop == 0, "kBytes must be multiple of 128 bytes");
|
| 133 |
+
static_assert(std::is_same_v<Tp, details::storage_vec<Package, kLoopCount>>);
|
| 134 |
+
|
| 135 |
+
const auto dst_packed = static_cast<Package*>(dst);
|
| 136 |
+
const auto lane_id = threadIdx.x % kThreads;
|
| 137 |
+
|
| 138 |
+
#pragma unroll kLoopCount
|
| 139 |
+
for (std::size_t i = 0; i < kLoopCount; ++i) {
|
| 140 |
+
const auto j = i * kThreads + lane_id;
|
| 141 |
+
details::store_nc(dst_packed + j, vec.data[i]);
|
| 142 |
+
}
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
} // namespace device::warp
|
sglang/jit_kernel/utils.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import pathlib
|
| 4 |
+
from functools import lru_cache
|
| 5 |
+
from typing import TYPE_CHECKING, List, Tuple, TypeAlias, Union
|
| 6 |
+
|
| 7 |
+
if TYPE_CHECKING:
|
| 8 |
+
from tvm_ffi import Module
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _make_wrapper(tup: Tuple[str, str]) -> str:
|
| 12 |
+
export_name, kernel_name = tup
|
| 13 |
+
return f"TVM_FFI_DLL_EXPORT_TYPED_FUNC({export_name}, ({kernel_name}));"
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@lru_cache()
|
| 17 |
+
def _resolve_kernel_path() -> pathlib.Path:
|
| 18 |
+
cur_dir = pathlib.Path(__file__).parent.resolve()
|
| 19 |
+
|
| 20 |
+
# first, try this directory structure
|
| 21 |
+
def _environment_install():
|
| 22 |
+
candidate = cur_dir.resolve()
|
| 23 |
+
if (candidate / "include").exists() and (candidate / "csrc").exists():
|
| 24 |
+
return candidate
|
| 25 |
+
return None
|
| 26 |
+
|
| 27 |
+
def _package_install():
|
| 28 |
+
# TODO: support find path by package
|
| 29 |
+
return None
|
| 30 |
+
|
| 31 |
+
path = _environment_install() or _package_install()
|
| 32 |
+
if path is None:
|
| 33 |
+
raise RuntimeError("Cannot find sgl-kernel/jit path")
|
| 34 |
+
return path
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
KERNEL_PATH = _resolve_kernel_path()
|
| 38 |
+
DEFAULT_INCLUDE = [str(KERNEL_PATH / "include")]
|
| 39 |
+
DEFAULT_CFLAGS = ["-std=c++20", "-O3"]
|
| 40 |
+
DEFAULT_CUDA_CFLAGS = ["-std=c++20", "-O3", "--expt-relaxed-constexpr"]
|
| 41 |
+
DEFAULT_LDFLAGS = []
|
| 42 |
+
CPP_TEMPLATE_TYPE: TypeAlias = Union[int, float, bool]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class CPPArgList(list[str]):
|
| 46 |
+
def __str__(self) -> str:
|
| 47 |
+
return ", ".join(self)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def make_cpp_args(*args: CPP_TEMPLATE_TYPE) -> CPPArgList:
|
| 51 |
+
def _convert(arg: CPP_TEMPLATE_TYPE) -> str:
|
| 52 |
+
if isinstance(arg, bool):
|
| 53 |
+
return "true" if arg else "false"
|
| 54 |
+
if isinstance(arg, (int, float)):
|
| 55 |
+
return str(arg)
|
| 56 |
+
raise TypeError(f"Unsupported argument type for cpp template: {type(arg)}")
|
| 57 |
+
|
| 58 |
+
return CPPArgList(_convert(arg) for arg in args)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def load_jit(
|
| 62 |
+
*args: str,
|
| 63 |
+
cpp_files: List[str] | None = None,
|
| 64 |
+
cuda_files: List[str] | None = None,
|
| 65 |
+
cpp_wrappers: List[Tuple[str, str]] | None = None,
|
| 66 |
+
cuda_wrappers: List[Tuple[str, str]] | None = None,
|
| 67 |
+
extra_cflags: List[str] | None = None,
|
| 68 |
+
extra_cuda_cflags: List[str] | None = None,
|
| 69 |
+
extra_ldflags: List[str] | None = None,
|
| 70 |
+
extra_include_paths: List[str] | None = None,
|
| 71 |
+
build_directory: str | None = None,
|
| 72 |
+
) -> Module:
|
| 73 |
+
from tvm_ffi.cpp import load_inline
|
| 74 |
+
|
| 75 |
+
cpp_files = cpp_files or []
|
| 76 |
+
cuda_files = cuda_files or []
|
| 77 |
+
cpp_wrappers = cpp_wrappers or []
|
| 78 |
+
cuda_wrappers = cuda_wrappers or []
|
| 79 |
+
extra_cflags = extra_cflags or []
|
| 80 |
+
extra_cuda_cflags = extra_cuda_cflags or []
|
| 81 |
+
extra_ldflags = extra_ldflags or []
|
| 82 |
+
extra_include_paths = extra_include_paths or []
|
| 83 |
+
|
| 84 |
+
# include cpp files
|
| 85 |
+
cpp_paths = [(KERNEL_PATH / "csrc" / f).resolve() for f in cpp_files]
|
| 86 |
+
cpp_sources = [f'#include "{path}"' for path in cpp_paths]
|
| 87 |
+
cpp_sources += [_make_wrapper(tup) for tup in cpp_wrappers]
|
| 88 |
+
|
| 89 |
+
# include cuda files
|
| 90 |
+
cuda_paths = [(KERNEL_PATH / "csrc" / f).resolve() for f in cuda_files]
|
| 91 |
+
cuda_sources = [f'#include "{path}"' for path in cuda_paths]
|
| 92 |
+
cuda_sources += [_make_wrapper(tup) for tup in cuda_wrappers]
|
| 93 |
+
|
| 94 |
+
return load_inline(
|
| 95 |
+
"sgl_kernel_jit_" + "_".join(str(arg) for arg in args),
|
| 96 |
+
cpp_sources=cpp_sources,
|
| 97 |
+
cuda_sources=cuda_sources,
|
| 98 |
+
extra_cflags=DEFAULT_CFLAGS + extra_cflags,
|
| 99 |
+
extra_cuda_cflags=DEFAULT_CUDA_CFLAGS + extra_cuda_cflags,
|
| 100 |
+
extra_ldflags=DEFAULT_LDFLAGS + extra_ldflags,
|
| 101 |
+
extra_include_paths=DEFAULT_INCLUDE + extra_include_paths,
|
| 102 |
+
build_directory=build_directory,
|
| 103 |
+
)
|
sglang/lang/__pycache__/api.cpython-311.pyc
ADDED
|
Binary file (10.3 kB). View file
|
|
|
sglang/lang/__pycache__/chat_template.cpython-311.pyc
ADDED
|
Binary file (19.4 kB). View file
|
|
|
sglang/lang/__pycache__/choices.cpython-311.pyc
ADDED
|
Binary file (9.41 kB). View file
|
|
|
sglang/lang/__pycache__/interpreter.cpython-311.pyc
ADDED
|
Binary file (50 kB). View file
|
|
|
sglang/lang/__pycache__/ir.cpython-311.pyc
ADDED
|
Binary file (33.4 kB). View file
|
|
|
sglang/lang/api.py
ADDED
|
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Public APIs of the language."""
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
from typing import Callable, List, Optional, Union
|
| 5 |
+
|
| 6 |
+
from sglang.global_config import global_config
|
| 7 |
+
from sglang.lang.backend.base_backend import BaseBackend
|
| 8 |
+
from sglang.lang.choices import ChoicesSamplingMethod, token_length_normalized
|
| 9 |
+
from sglang.lang.ir import (
|
| 10 |
+
SglExpr,
|
| 11 |
+
SglExprList,
|
| 12 |
+
SglFunction,
|
| 13 |
+
SglGen,
|
| 14 |
+
SglImage,
|
| 15 |
+
SglRoleBegin,
|
| 16 |
+
SglRoleEnd,
|
| 17 |
+
SglSelect,
|
| 18 |
+
SglSeparateReasoning,
|
| 19 |
+
SglVideo,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def function(
|
| 24 |
+
func: Optional[Callable] = None, num_api_spec_tokens: Optional[int] = None
|
| 25 |
+
):
|
| 26 |
+
if func:
|
| 27 |
+
return SglFunction(func, num_api_spec_tokens=num_api_spec_tokens)
|
| 28 |
+
|
| 29 |
+
def decorator(func):
|
| 30 |
+
return SglFunction(func, num_api_spec_tokens=num_api_spec_tokens)
|
| 31 |
+
|
| 32 |
+
return decorator
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def Runtime(*args, **kwargs):
|
| 36 |
+
# Avoid importing unnecessary dependency
|
| 37 |
+
from sglang.lang.backend.runtime_endpoint import Runtime
|
| 38 |
+
|
| 39 |
+
return Runtime(*args, **kwargs)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def Engine(*args, **kwargs):
|
| 43 |
+
# Avoid importing unnecessary dependency
|
| 44 |
+
from sglang.srt.entrypoints.engine import Engine
|
| 45 |
+
|
| 46 |
+
return Engine(*args, **kwargs)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def set_default_backend(backend: BaseBackend):
|
| 50 |
+
global_config.default_backend = backend
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def flush_cache(backend: Optional[BaseBackend] = None):
|
| 54 |
+
backend = backend or global_config.default_backend
|
| 55 |
+
if backend is None:
|
| 56 |
+
return False
|
| 57 |
+
|
| 58 |
+
# If backend is Runtime
|
| 59 |
+
if hasattr(backend, "endpoint"):
|
| 60 |
+
backend = backend.endpoint
|
| 61 |
+
return backend.flush_cache()
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def get_server_info(backend: Optional[BaseBackend] = None):
|
| 65 |
+
backend = backend or global_config.default_backend
|
| 66 |
+
if backend is None:
|
| 67 |
+
return None
|
| 68 |
+
|
| 69 |
+
# If backend is Runtime
|
| 70 |
+
if hasattr(backend, "endpoint"):
|
| 71 |
+
backend = backend.endpoint
|
| 72 |
+
return backend.get_server_info()
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def gen(
|
| 76 |
+
name: Optional[str] = None,
|
| 77 |
+
max_tokens: Optional[int] = None,
|
| 78 |
+
min_tokens: Optional[int] = None,
|
| 79 |
+
n: Optional[int] = None,
|
| 80 |
+
stop: Optional[Union[str, List[str]]] = None,
|
| 81 |
+
stop_token_ids: Optional[List[int]] = None,
|
| 82 |
+
stop_regex: Optional[Union[str, List[str]]] = None,
|
| 83 |
+
temperature: Optional[float] = None,
|
| 84 |
+
top_p: Optional[float] = None,
|
| 85 |
+
top_k: Optional[int] = None,
|
| 86 |
+
min_p: Optional[float] = None,
|
| 87 |
+
frequency_penalty: Optional[float] = None,
|
| 88 |
+
presence_penalty: Optional[float] = None,
|
| 89 |
+
ignore_eos: Optional[bool] = None,
|
| 90 |
+
return_logprob: Optional[bool] = None,
|
| 91 |
+
logprob_start_len: Optional[int] = None,
|
| 92 |
+
top_logprobs_num: Optional[int] = None,
|
| 93 |
+
return_text_in_logprobs: Optional[bool] = None,
|
| 94 |
+
dtype: Optional[Union[type, str]] = None,
|
| 95 |
+
choices: Optional[List[str]] = None,
|
| 96 |
+
choices_method: Optional[ChoicesSamplingMethod] = None,
|
| 97 |
+
regex: Optional[str] = None,
|
| 98 |
+
json_schema: Optional[str] = None,
|
| 99 |
+
):
|
| 100 |
+
"""Call the model to generate. See the meaning of the arguments in docs/backend/sampling_params.md"""
|
| 101 |
+
|
| 102 |
+
if choices:
|
| 103 |
+
return SglSelect(
|
| 104 |
+
name,
|
| 105 |
+
choices,
|
| 106 |
+
0.0 if temperature is None else temperature,
|
| 107 |
+
token_length_normalized if choices_method is None else choices_method,
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
# check regex is valid
|
| 111 |
+
if regex is not None:
|
| 112 |
+
try:
|
| 113 |
+
re.compile(regex)
|
| 114 |
+
except re.error as e:
|
| 115 |
+
raise e
|
| 116 |
+
|
| 117 |
+
return SglGen(
|
| 118 |
+
name,
|
| 119 |
+
max_tokens,
|
| 120 |
+
min_tokens,
|
| 121 |
+
n,
|
| 122 |
+
stop,
|
| 123 |
+
stop_token_ids,
|
| 124 |
+
stop_regex,
|
| 125 |
+
temperature,
|
| 126 |
+
top_p,
|
| 127 |
+
top_k,
|
| 128 |
+
min_p,
|
| 129 |
+
frequency_penalty,
|
| 130 |
+
presence_penalty,
|
| 131 |
+
ignore_eos,
|
| 132 |
+
return_logprob,
|
| 133 |
+
logprob_start_len,
|
| 134 |
+
top_logprobs_num,
|
| 135 |
+
return_text_in_logprobs,
|
| 136 |
+
dtype,
|
| 137 |
+
regex,
|
| 138 |
+
json_schema,
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def gen_int(
|
| 143 |
+
name: Optional[str] = None,
|
| 144 |
+
max_tokens: Optional[int] = None,
|
| 145 |
+
n: Optional[int] = None,
|
| 146 |
+
stop: Optional[Union[str, List[str]]] = None,
|
| 147 |
+
stop_token_ids: Optional[List[int]] = None,
|
| 148 |
+
stop_regex: Optional[Union[str, List[str]]] = None,
|
| 149 |
+
temperature: Optional[float] = None,
|
| 150 |
+
top_p: Optional[float] = None,
|
| 151 |
+
top_k: Optional[int] = None,
|
| 152 |
+
min_p: Optional[float] = None,
|
| 153 |
+
frequency_penalty: Optional[float] = None,
|
| 154 |
+
presence_penalty: Optional[float] = None,
|
| 155 |
+
ignore_eos: Optional[bool] = None,
|
| 156 |
+
return_logprob: Optional[bool] = None,
|
| 157 |
+
logprob_start_len: Optional[int] = None,
|
| 158 |
+
top_logprobs_num: Optional[int] = None,
|
| 159 |
+
return_text_in_logprobs: Optional[bool] = None,
|
| 160 |
+
):
|
| 161 |
+
return SglGen(
|
| 162 |
+
name,
|
| 163 |
+
max_tokens,
|
| 164 |
+
None,
|
| 165 |
+
n,
|
| 166 |
+
stop,
|
| 167 |
+
stop_token_ids,
|
| 168 |
+
stop_regex,
|
| 169 |
+
temperature,
|
| 170 |
+
top_p,
|
| 171 |
+
top_k,
|
| 172 |
+
min_p,
|
| 173 |
+
frequency_penalty,
|
| 174 |
+
presence_penalty,
|
| 175 |
+
ignore_eos,
|
| 176 |
+
return_logprob,
|
| 177 |
+
logprob_start_len,
|
| 178 |
+
top_logprobs_num,
|
| 179 |
+
return_text_in_logprobs,
|
| 180 |
+
int,
|
| 181 |
+
None,
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def gen_string(
|
| 186 |
+
name: Optional[str] = None,
|
| 187 |
+
max_tokens: Optional[int] = None,
|
| 188 |
+
n: Optional[int] = None,
|
| 189 |
+
stop: Optional[Union[str, List[str]]] = None,
|
| 190 |
+
stop_token_ids: Optional[List[int]] = None,
|
| 191 |
+
stop_regex: Optional[Union[str, List[str]]] = None,
|
| 192 |
+
temperature: Optional[float] = None,
|
| 193 |
+
top_p: Optional[float] = None,
|
| 194 |
+
top_k: Optional[int] = None,
|
| 195 |
+
min_p: Optional[float] = None,
|
| 196 |
+
frequency_penalty: Optional[float] = None,
|
| 197 |
+
presence_penalty: Optional[float] = None,
|
| 198 |
+
ignore_eos: Optional[bool] = None,
|
| 199 |
+
return_logprob: Optional[bool] = None,
|
| 200 |
+
logprob_start_len: Optional[int] = None,
|
| 201 |
+
top_logprobs_num: Optional[int] = None,
|
| 202 |
+
return_text_in_logprobs: Optional[bool] = None,
|
| 203 |
+
):
|
| 204 |
+
return SglGen(
|
| 205 |
+
name,
|
| 206 |
+
max_tokens,
|
| 207 |
+
None,
|
| 208 |
+
n,
|
| 209 |
+
stop,
|
| 210 |
+
stop_token_ids,
|
| 211 |
+
stop_regex,
|
| 212 |
+
temperature,
|
| 213 |
+
top_p,
|
| 214 |
+
top_k,
|
| 215 |
+
min_p,
|
| 216 |
+
frequency_penalty,
|
| 217 |
+
presence_penalty,
|
| 218 |
+
ignore_eos,
|
| 219 |
+
return_logprob,
|
| 220 |
+
logprob_start_len,
|
| 221 |
+
top_logprobs_num,
|
| 222 |
+
return_text_in_logprobs,
|
| 223 |
+
str,
|
| 224 |
+
None,
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def image(expr: SglExpr):
|
| 229 |
+
return SglImage(expr)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def video(path: str, num_frames: int):
|
| 233 |
+
return SglVideo(path, num_frames)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def select(
|
| 237 |
+
name: Optional[str] = None,
|
| 238 |
+
choices: Optional[List[str]] = None,
|
| 239 |
+
temperature: float = 0.0,
|
| 240 |
+
choices_method: ChoicesSamplingMethod = token_length_normalized,
|
| 241 |
+
):
|
| 242 |
+
assert choices is not None
|
| 243 |
+
return SglSelect(name, choices, temperature, choices_method)
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def _role_common(name: str, expr: Optional[SglExpr] = None):
|
| 247 |
+
if expr is None:
|
| 248 |
+
return SglExprList([SglRoleBegin(name), SglRoleEnd(name)])
|
| 249 |
+
else:
|
| 250 |
+
return SglExprList([SglRoleBegin(name), expr, SglRoleEnd(name)])
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def system(expr: Optional[SglExpr] = None):
|
| 254 |
+
return _role_common("system", expr)
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def user(expr: Optional[SglExpr] = None):
|
| 258 |
+
return _role_common("user", expr)
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def assistant(expr: Optional[SglExpr] = None):
|
| 262 |
+
return _role_common("assistant", expr)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def system_begin():
|
| 266 |
+
return SglRoleBegin("system")
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def system_end():
|
| 270 |
+
return SglRoleEnd("system")
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def user_begin():
|
| 274 |
+
return SglRoleBegin("user")
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def user_end():
|
| 278 |
+
return SglRoleEnd("user")
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def assistant_begin():
|
| 282 |
+
return SglRoleBegin("assistant")
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def assistant_end():
|
| 286 |
+
return SglRoleEnd("assistant")
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def separate_reasoning(
|
| 290 |
+
expr: Optional[SglExpr] = None, model_type: Optional[str] = None
|
| 291 |
+
):
|
| 292 |
+
return SglExprList([expr, SglSeparateReasoning(model_type, expr=expr)])
|
sglang/lang/backend/__pycache__/base_backend.cpython-311.pyc
ADDED
|
Binary file (4.6 kB). View file
|
|
|
sglang/lang/backend/__pycache__/runtime_endpoint.cpython-311.pyc
ADDED
|
Binary file (25.5 kB). View file
|
|
|
sglang/lang/backend/anthropic.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sglang.lang.backend.base_backend import BaseBackend
|
| 2 |
+
from sglang.lang.chat_template import get_chat_template
|
| 3 |
+
from sglang.lang.interpreter import StreamExecutor
|
| 4 |
+
from sglang.lang.ir import SglSamplingParams
|
| 5 |
+
|
| 6 |
+
try:
|
| 7 |
+
import anthropic
|
| 8 |
+
except ImportError as e:
|
| 9 |
+
anthropic = e
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Anthropic(BaseBackend):
|
| 13 |
+
def __init__(self, model_name, *args, **kwargs):
|
| 14 |
+
super().__init__()
|
| 15 |
+
|
| 16 |
+
if isinstance(anthropic, Exception):
|
| 17 |
+
raise anthropic
|
| 18 |
+
|
| 19 |
+
self.model_name = model_name
|
| 20 |
+
self.chat_template = get_chat_template("claude")
|
| 21 |
+
self.client = anthropic.Anthropic(*args, **kwargs)
|
| 22 |
+
|
| 23 |
+
def get_chat_template(self):
|
| 24 |
+
return self.chat_template
|
| 25 |
+
|
| 26 |
+
def generate(
|
| 27 |
+
self,
|
| 28 |
+
s: StreamExecutor,
|
| 29 |
+
sampling_params: SglSamplingParams,
|
| 30 |
+
):
|
| 31 |
+
if s.messages_:
|
| 32 |
+
messages = s.messages_
|
| 33 |
+
else:
|
| 34 |
+
messages = [{"role": "user", "content": s.text_}]
|
| 35 |
+
|
| 36 |
+
if messages and messages[0]["role"] == "system":
|
| 37 |
+
system = messages.pop(0)["content"]
|
| 38 |
+
else:
|
| 39 |
+
system = ""
|
| 40 |
+
|
| 41 |
+
ret = self.client.messages.create(
|
| 42 |
+
model=self.model_name,
|
| 43 |
+
system=system,
|
| 44 |
+
messages=messages,
|
| 45 |
+
**sampling_params.to_anthropic_kwargs(),
|
| 46 |
+
)
|
| 47 |
+
comp = ret.content[0].text
|
| 48 |
+
|
| 49 |
+
return comp, {}
|
| 50 |
+
|
| 51 |
+
def generate_stream(
|
| 52 |
+
self,
|
| 53 |
+
s: StreamExecutor,
|
| 54 |
+
sampling_params: SglSamplingParams,
|
| 55 |
+
):
|
| 56 |
+
if s.messages_:
|
| 57 |
+
messages = s.messages_
|
| 58 |
+
else:
|
| 59 |
+
messages = [{"role": "user", "content": s.text_}]
|
| 60 |
+
|
| 61 |
+
if messages and messages[0]["role"] == "system":
|
| 62 |
+
system = messages.pop(0)["content"]
|
| 63 |
+
else:
|
| 64 |
+
system = ""
|
| 65 |
+
|
| 66 |
+
with self.client.messages.stream(
|
| 67 |
+
model=self.model_name,
|
| 68 |
+
system=system,
|
| 69 |
+
messages=messages,
|
| 70 |
+
**sampling_params.to_anthropic_kwargs(),
|
| 71 |
+
) as stream:
|
| 72 |
+
for text in stream.text_stream:
|
| 73 |
+
yield text, {}
|
sglang/lang/backend/base_backend.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Union
|
| 2 |
+
|
| 3 |
+
from sglang.lang.chat_template import get_chat_template
|
| 4 |
+
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
|
| 5 |
+
from sglang.lang.interpreter import StreamExecutor
|
| 6 |
+
from sglang.lang.ir import SglSamplingParams
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class BaseBackend:
|
| 10 |
+
def __init__(self) -> None:
|
| 11 |
+
self.support_concate_and_append = False
|
| 12 |
+
self.chat_template = get_chat_template("default")
|
| 13 |
+
|
| 14 |
+
def get_model_name(self):
|
| 15 |
+
raise NotImplementedError()
|
| 16 |
+
|
| 17 |
+
def get_chat_template(self):
|
| 18 |
+
return self.chat_template
|
| 19 |
+
|
| 20 |
+
def cache_prefix(self, prefix_str: str):
|
| 21 |
+
pass
|
| 22 |
+
|
| 23 |
+
def uncache_prefix(self, rid: str):
|
| 24 |
+
pass
|
| 25 |
+
|
| 26 |
+
def end_request(self, rid: Union[str, List[str]]):
|
| 27 |
+
pass
|
| 28 |
+
|
| 29 |
+
def begin_program(self, s: StreamExecutor):
|
| 30 |
+
pass
|
| 31 |
+
|
| 32 |
+
def end_program(self, s: Union[StreamExecutor, List[StreamExecutor]]):
|
| 33 |
+
pass
|
| 34 |
+
|
| 35 |
+
def commit_lazy_operations(self, s: StreamExecutor):
|
| 36 |
+
pass
|
| 37 |
+
|
| 38 |
+
def fork_program(
|
| 39 |
+
self,
|
| 40 |
+
src: StreamExecutor,
|
| 41 |
+
dst: List[StreamExecutor],
|
| 42 |
+
position_ids_offset: Optional[List[int]] = None,
|
| 43 |
+
):
|
| 44 |
+
pass
|
| 45 |
+
|
| 46 |
+
def fill_image(self, s: StreamExecutor):
|
| 47 |
+
pass
|
| 48 |
+
|
| 49 |
+
def generate(
|
| 50 |
+
self,
|
| 51 |
+
s: StreamExecutor,
|
| 52 |
+
sampling_params: SglSamplingParams,
|
| 53 |
+
):
|
| 54 |
+
raise NotImplementedError()
|
| 55 |
+
|
| 56 |
+
def generate_stream(
|
| 57 |
+
self,
|
| 58 |
+
s: StreamExecutor,
|
| 59 |
+
sampling_params: SglSamplingParams,
|
| 60 |
+
):
|
| 61 |
+
raise NotImplementedError()
|
| 62 |
+
|
| 63 |
+
def select(
|
| 64 |
+
self,
|
| 65 |
+
s: StreamExecutor,
|
| 66 |
+
choices: List[str],
|
| 67 |
+
temperature: float,
|
| 68 |
+
choices_method: Optional[ChoicesSamplingMethod] = None,
|
| 69 |
+
) -> ChoicesDecision:
|
| 70 |
+
raise NotImplementedError()
|
| 71 |
+
|
| 72 |
+
def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
|
| 73 |
+
raise NotImplementedError()
|
| 74 |
+
|
| 75 |
+
def shutdown(self):
|
| 76 |
+
pass
|
| 77 |
+
|
| 78 |
+
def flush_cache(self):
|
| 79 |
+
pass
|
| 80 |
+
|
| 81 |
+
def get_server_info(self):
|
| 82 |
+
pass
|
sglang/lang/backend/litellm.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Mapping, Optional
|
| 2 |
+
|
| 3 |
+
from sglang.lang.backend.base_backend import BaseBackend
|
| 4 |
+
from sglang.lang.chat_template import get_chat_template_by_model_path
|
| 5 |
+
from sglang.lang.interpreter import StreamExecutor
|
| 6 |
+
from sglang.lang.ir import SglSamplingParams
|
| 7 |
+
|
| 8 |
+
try:
|
| 9 |
+
import litellm
|
| 10 |
+
except ImportError as e:
|
| 11 |
+
litellm = e
|
| 12 |
+
litellm.num_retries = 1
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class LiteLLM(BaseBackend):
|
| 16 |
+
def __init__(
|
| 17 |
+
self,
|
| 18 |
+
model_name,
|
| 19 |
+
chat_template=None,
|
| 20 |
+
api_key=None,
|
| 21 |
+
organization: Optional[str] = None,
|
| 22 |
+
base_url: Optional[str] = None,
|
| 23 |
+
timeout: Optional[float] = 600,
|
| 24 |
+
max_retries: Optional[int] = litellm.num_retries,
|
| 25 |
+
default_headers: Optional[Mapping[str, str]] = None,
|
| 26 |
+
):
|
| 27 |
+
super().__init__()
|
| 28 |
+
|
| 29 |
+
if isinstance(litellm, Exception):
|
| 30 |
+
raise litellm
|
| 31 |
+
|
| 32 |
+
self.model_name = model_name
|
| 33 |
+
|
| 34 |
+
self.chat_template = chat_template or get_chat_template_by_model_path(
|
| 35 |
+
model_name
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
self.client_params = {
|
| 39 |
+
"api_key": api_key,
|
| 40 |
+
"organization": organization,
|
| 41 |
+
"base_url": base_url,
|
| 42 |
+
"timeout": timeout,
|
| 43 |
+
"max_retries": max_retries,
|
| 44 |
+
"default_headers": default_headers,
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
def get_chat_template(self):
|
| 48 |
+
return self.chat_template
|
| 49 |
+
|
| 50 |
+
def generate(
|
| 51 |
+
self,
|
| 52 |
+
s: StreamExecutor,
|
| 53 |
+
sampling_params: SglSamplingParams,
|
| 54 |
+
):
|
| 55 |
+
if s.messages_:
|
| 56 |
+
messages = s.messages_
|
| 57 |
+
else:
|
| 58 |
+
messages = [{"role": "user", "content": s.text_}]
|
| 59 |
+
|
| 60 |
+
ret = litellm.completion(
|
| 61 |
+
model=self.model_name,
|
| 62 |
+
messages=messages,
|
| 63 |
+
**self.client_params,
|
| 64 |
+
**sampling_params.to_litellm_kwargs(),
|
| 65 |
+
)
|
| 66 |
+
comp = ret.choices[0].message.content
|
| 67 |
+
|
| 68 |
+
return comp, {}
|
| 69 |
+
|
| 70 |
+
def generate_stream(
|
| 71 |
+
self,
|
| 72 |
+
s: StreamExecutor,
|
| 73 |
+
sampling_params: SglSamplingParams,
|
| 74 |
+
):
|
| 75 |
+
if s.messages_:
|
| 76 |
+
messages = s.messages_
|
| 77 |
+
else:
|
| 78 |
+
messages = [{"role": "user", "content": s.text_}]
|
| 79 |
+
|
| 80 |
+
ret = litellm.completion(
|
| 81 |
+
model=self.model_name,
|
| 82 |
+
messages=messages,
|
| 83 |
+
stream=True,
|
| 84 |
+
**self.client_params,
|
| 85 |
+
**sampling_params.to_litellm_kwargs(),
|
| 86 |
+
)
|
| 87 |
+
for chunk in ret:
|
| 88 |
+
text = chunk.choices[0].delta.content
|
| 89 |
+
if text is not None:
|
| 90 |
+
yield text, {}
|
sglang/lang/backend/openai.py
ADDED
|
@@ -0,0 +1,475 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dataclasses
|
| 2 |
+
import logging
|
| 3 |
+
import time
|
| 4 |
+
import warnings
|
| 5 |
+
from typing import List, Optional, Union
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
from sglang.lang.backend.base_backend import BaseBackend
|
| 10 |
+
from sglang.lang.chat_template import ChatTemplate, get_chat_template_by_model_path
|
| 11 |
+
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
|
| 12 |
+
from sglang.lang.interpreter import StreamExecutor
|
| 13 |
+
from sglang.lang.ir import SglSamplingParams
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
import openai
|
| 17 |
+
import tiktoken
|
| 18 |
+
except ImportError as e:
|
| 19 |
+
openai = tiktoken = e
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def create_logit_bias_int(tokenizer):
|
| 26 |
+
"""Get logit bias for integer numbers."""
|
| 27 |
+
int_token_ids = []
|
| 28 |
+
|
| 29 |
+
tokens = tokenizer._mergeable_ranks
|
| 30 |
+
for token, token_id in tokens.items():
|
| 31 |
+
s = tokenizer.decode([token_id])
|
| 32 |
+
if all([c.isdigit() for c in s]) or s in [" "]:
|
| 33 |
+
int_token_ids.append(token_id)
|
| 34 |
+
if len(int_token_ids) >= 300: # OpenAI API limit
|
| 35 |
+
break
|
| 36 |
+
special_tokens = tokenizer._special_tokens
|
| 37 |
+
mask = {t: 100 for t in int_token_ids[:299]}
|
| 38 |
+
mask[special_tokens["<|endoftext|>"]] = 100
|
| 39 |
+
return mask
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
INSTRUCT_MODEL_NAMES = [
|
| 43 |
+
"gpt-3.5-turbo-instruct",
|
| 44 |
+
]
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@dataclasses.dataclass
|
| 48 |
+
class TokenUsage:
|
| 49 |
+
prompt_tokens: int
|
| 50 |
+
completion_tokens: int
|
| 51 |
+
|
| 52 |
+
def reset(self):
|
| 53 |
+
self.prompt_tokens = self.completion_tokens = 0
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class OpenAI(BaseBackend):
|
| 57 |
+
def __init__(
|
| 58 |
+
self,
|
| 59 |
+
model_name: str,
|
| 60 |
+
is_chat_model: Optional[bool] = None,
|
| 61 |
+
chat_template: Optional[ChatTemplate] = None,
|
| 62 |
+
is_azure: bool = False,
|
| 63 |
+
*args,
|
| 64 |
+
**kwargs,
|
| 65 |
+
):
|
| 66 |
+
super().__init__()
|
| 67 |
+
|
| 68 |
+
if isinstance(openai, Exception):
|
| 69 |
+
raise openai
|
| 70 |
+
|
| 71 |
+
if is_azure:
|
| 72 |
+
self.client = openai.AzureOpenAI(*args, **kwargs)
|
| 73 |
+
else:
|
| 74 |
+
self.client = openai.OpenAI(*args, **kwargs)
|
| 75 |
+
|
| 76 |
+
self.model_name = model_name
|
| 77 |
+
try:
|
| 78 |
+
self.tokenizer = tiktoken.encoding_for_model(model_name)
|
| 79 |
+
except KeyError:
|
| 80 |
+
self.tokenizer = tiktoken.get_encoding("cl100k_base")
|
| 81 |
+
self.logit_bias_int = create_logit_bias_int(self.tokenizer)
|
| 82 |
+
|
| 83 |
+
self.chat_template = chat_template or get_chat_template_by_model_path(
|
| 84 |
+
model_name
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
if is_chat_model is not None:
|
| 88 |
+
self.is_chat_model = is_chat_model
|
| 89 |
+
else:
|
| 90 |
+
if model_name in INSTRUCT_MODEL_NAMES:
|
| 91 |
+
self.is_chat_model = False
|
| 92 |
+
else:
|
| 93 |
+
self.is_chat_model = True
|
| 94 |
+
|
| 95 |
+
self.chat_prefix = self.chat_template.role_prefix_and_suffix["assistant"][0]
|
| 96 |
+
|
| 97 |
+
# Usage
|
| 98 |
+
self.token_usage = TokenUsage(0, 0)
|
| 99 |
+
|
| 100 |
+
# API speculative execution
|
| 101 |
+
# TODO(ying): This does not support multi-threading (run_batch)
|
| 102 |
+
self.spec_kwargs = {}
|
| 103 |
+
self.spec_format = []
|
| 104 |
+
self.spec_max_num_tries = 3
|
| 105 |
+
|
| 106 |
+
def get_chat_template(self):
|
| 107 |
+
return self.chat_template
|
| 108 |
+
|
| 109 |
+
def _prepare_spec_execution(
|
| 110 |
+
self,
|
| 111 |
+
sampling_params: SglSamplingParams,
|
| 112 |
+
num_api_spec_tokens: int,
|
| 113 |
+
spec_var_name: str,
|
| 114 |
+
):
|
| 115 |
+
if "max_tokens" not in self.spec_kwargs:
|
| 116 |
+
self.spec_kwargs["max_tokens"] = num_api_spec_tokens
|
| 117 |
+
else:
|
| 118 |
+
assert self.spec_kwargs["max_tokens"] == num_api_spec_tokens
|
| 119 |
+
|
| 120 |
+
params = sampling_params.to_openai_kwargs()
|
| 121 |
+
for key, value in params.items():
|
| 122 |
+
if key in ["stop"]:
|
| 123 |
+
continue
|
| 124 |
+
if key in ["max_tokens"]:
|
| 125 |
+
warnings.warn(
|
| 126 |
+
"The parameter max_tokens will be overwritten by speculated number of tokens."
|
| 127 |
+
)
|
| 128 |
+
continue
|
| 129 |
+
if key not in self.spec_kwargs:
|
| 130 |
+
self.spec_kwargs[key] = value
|
| 131 |
+
else:
|
| 132 |
+
assert (
|
| 133 |
+
value == self.spec_kwargs[key]
|
| 134 |
+
), "sampling parameters should be consistent if turn on api speculative execution."
|
| 135 |
+
self.spec_format.append(
|
| 136 |
+
{"text": "", "stop": params["stop"], "name": spec_var_name}
|
| 137 |
+
)
|
| 138 |
+
return "", {}
|
| 139 |
+
|
| 140 |
+
def generate(
|
| 141 |
+
self,
|
| 142 |
+
s: StreamExecutor,
|
| 143 |
+
sampling_params: SglSamplingParams,
|
| 144 |
+
spec_var_name: str = None,
|
| 145 |
+
):
|
| 146 |
+
if sampling_params.dtype is None:
|
| 147 |
+
if self.is_chat_model:
|
| 148 |
+
if s.num_api_spec_tokens is None:
|
| 149 |
+
if not s.text_.endswith(self.chat_prefix):
|
| 150 |
+
raise RuntimeError(
|
| 151 |
+
"This use case is not supported if api speculative execution is off. "
|
| 152 |
+
"For OpenAI chat models, sgl.gen must be right after sgl.assistant. "
|
| 153 |
+
"Example of adding api speculative execution: @function(num_api_spec_tokens=128)."
|
| 154 |
+
)
|
| 155 |
+
prompt = s.messages_
|
| 156 |
+
else:
|
| 157 |
+
return self._prepare_spec_execution(
|
| 158 |
+
sampling_params, s.num_api_spec_tokens, spec_var_name
|
| 159 |
+
)
|
| 160 |
+
else:
|
| 161 |
+
prompt = s.text_
|
| 162 |
+
|
| 163 |
+
kwargs = sampling_params.to_openai_kwargs()
|
| 164 |
+
if (
|
| 165 |
+
self.model_name.startswith("o1")
|
| 166 |
+
or self.model_name.startswith("o3")
|
| 167 |
+
or "o1" in self.model_name
|
| 168 |
+
):
|
| 169 |
+
kwargs.pop("max_tokens", None)
|
| 170 |
+
else:
|
| 171 |
+
kwargs.pop("max_completion_tokens", None)
|
| 172 |
+
|
| 173 |
+
comp = openai_completion(
|
| 174 |
+
client=self.client,
|
| 175 |
+
token_usage=self.token_usage,
|
| 176 |
+
is_chat=self.is_chat_model,
|
| 177 |
+
model=self.model_name,
|
| 178 |
+
prompt=prompt,
|
| 179 |
+
**kwargs,
|
| 180 |
+
)
|
| 181 |
+
# Keep the returned list (or string) as is.
|
| 182 |
+
elif sampling_params.dtype in [str, "str", "string"]:
|
| 183 |
+
assert (
|
| 184 |
+
not self.is_chat_model
|
| 185 |
+
), "constrained type not supported on chat model"
|
| 186 |
+
kwargs = sampling_params.to_openai_kwargs()
|
| 187 |
+
kwargs.pop("stop")
|
| 188 |
+
comp = openai_completion(
|
| 189 |
+
client=self.client,
|
| 190 |
+
token_usage=self.token_usage,
|
| 191 |
+
is_chat=self.is_chat_model,
|
| 192 |
+
model=self.model_name,
|
| 193 |
+
prompt=s.text_ + '"',
|
| 194 |
+
stop='"',
|
| 195 |
+
**kwargs,
|
| 196 |
+
)
|
| 197 |
+
# Wrap each element in quotes if we have a list.
|
| 198 |
+
if isinstance(comp, list):
|
| 199 |
+
comp = ['"' + x + '"' for x in comp]
|
| 200 |
+
else:
|
| 201 |
+
comp = '"' + comp + '"'
|
| 202 |
+
elif sampling_params.dtype in [int, "int"]:
|
| 203 |
+
assert (
|
| 204 |
+
not self.is_chat_model
|
| 205 |
+
), "constrained type not supported on chat model"
|
| 206 |
+
kwargs = sampling_params.to_openai_kwargs()
|
| 207 |
+
kwargs.pop("stop")
|
| 208 |
+
comp = openai_completion(
|
| 209 |
+
client=self.client,
|
| 210 |
+
token_usage=self.token_usage,
|
| 211 |
+
is_chat=self.is_chat_model,
|
| 212 |
+
model=self.model_name,
|
| 213 |
+
prompt=s.text_,
|
| 214 |
+
logit_bias=self.logit_bias_int,
|
| 215 |
+
stop=[" "],
|
| 216 |
+
**kwargs,
|
| 217 |
+
)
|
| 218 |
+
# Leave as a list if that's what is returned.
|
| 219 |
+
else:
|
| 220 |
+
raise ValueError(f"Unknown dtype: {sampling_params.dtype}")
|
| 221 |
+
|
| 222 |
+
return comp, {}
|
| 223 |
+
|
| 224 |
+
def spec_fill(self, value: str):
|
| 225 |
+
assert self.is_chat_model
|
| 226 |
+
self.spec_format.append({"text": value, "stop": None, "name": None})
|
| 227 |
+
|
| 228 |
+
def spec_pattern_match(self, comp):
|
| 229 |
+
for i, term in enumerate(self.spec_format):
|
| 230 |
+
text = term["text"]
|
| 231 |
+
if text != "":
|
| 232 |
+
if comp.startswith(text):
|
| 233 |
+
comp = comp[len(text) :]
|
| 234 |
+
else:
|
| 235 |
+
return False
|
| 236 |
+
else:
|
| 237 |
+
pos = comp.find(term["stop"])
|
| 238 |
+
if pos != -1:
|
| 239 |
+
term["text"] = comp[:pos]
|
| 240 |
+
comp = comp[pos:]
|
| 241 |
+
else:
|
| 242 |
+
if i == len(self.spec_format) - 1:
|
| 243 |
+
term["text"] = comp
|
| 244 |
+
else:
|
| 245 |
+
return False
|
| 246 |
+
return True
|
| 247 |
+
|
| 248 |
+
def role_end_generate(
|
| 249 |
+
self,
|
| 250 |
+
s: StreamExecutor,
|
| 251 |
+
):
|
| 252 |
+
if s.num_api_spec_tokens is None or not s.text_.endswith(self.chat_prefix):
|
| 253 |
+
return
|
| 254 |
+
|
| 255 |
+
comp = ""
|
| 256 |
+
if not all(x["name"] is None for x in self.spec_format):
|
| 257 |
+
# TODO(ying): throw errors or warnings
|
| 258 |
+
for i in range(self.spec_max_num_tries):
|
| 259 |
+
comp = openai_completion(
|
| 260 |
+
client=self.client,
|
| 261 |
+
token_usage=self.token_usage,
|
| 262 |
+
is_chat=self.is_chat_model,
|
| 263 |
+
model=self.model_name,
|
| 264 |
+
prompt=s.messages_,
|
| 265 |
+
**self.spec_kwargs,
|
| 266 |
+
)
|
| 267 |
+
# Use a string for pattern matching.
|
| 268 |
+
comp_for_match = comp[0] if isinstance(comp, list) else comp
|
| 269 |
+
if self.spec_pattern_match(comp_for_match):
|
| 270 |
+
break
|
| 271 |
+
|
| 272 |
+
for term in self.spec_format:
|
| 273 |
+
s.text_ += term["text"]
|
| 274 |
+
name = term["name"]
|
| 275 |
+
if name is not None:
|
| 276 |
+
s.variables[name] = term["text"]
|
| 277 |
+
s.meta_info[name] = {}
|
| 278 |
+
s.variable_event[name].set()
|
| 279 |
+
|
| 280 |
+
self.spec_kwargs = {}
|
| 281 |
+
self.spec_format = []
|
| 282 |
+
|
| 283 |
+
def generate_stream(
|
| 284 |
+
self,
|
| 285 |
+
s: StreamExecutor,
|
| 286 |
+
sampling_params: SglSamplingParams,
|
| 287 |
+
):
|
| 288 |
+
if sampling_params.dtype is None:
|
| 289 |
+
if self.is_chat_model:
|
| 290 |
+
if not s.text_.endswith(self.chat_prefix):
|
| 291 |
+
raise RuntimeError(
|
| 292 |
+
"This use case is not supported. "
|
| 293 |
+
"For OpenAI chat models, sgl.gen must be right after sgl.assistant"
|
| 294 |
+
)
|
| 295 |
+
prompt = s.messages_
|
| 296 |
+
else:
|
| 297 |
+
prompt = s.text_
|
| 298 |
+
|
| 299 |
+
kwargs = sampling_params.to_openai_kwargs()
|
| 300 |
+
generator = openai_completion_stream(
|
| 301 |
+
client=self.client,
|
| 302 |
+
token_usage=self.token_usage,
|
| 303 |
+
is_chat=self.is_chat_model,
|
| 304 |
+
model=self.model_name,
|
| 305 |
+
prompt=prompt,
|
| 306 |
+
**kwargs,
|
| 307 |
+
)
|
| 308 |
+
return generator
|
| 309 |
+
else:
|
| 310 |
+
raise ValueError(f"Unknown dtype: {sampling_params.dtype}")
|
| 311 |
+
|
| 312 |
+
def select(
|
| 313 |
+
self,
|
| 314 |
+
s: StreamExecutor,
|
| 315 |
+
choices: List[str],
|
| 316 |
+
temperature: float,
|
| 317 |
+
choices_method: ChoicesSamplingMethod,
|
| 318 |
+
) -> ChoicesDecision:
|
| 319 |
+
"""Note: `choices_method` is not used by the OpenAI backend."""
|
| 320 |
+
if self.is_chat_model:
|
| 321 |
+
raise NotImplementedError(
|
| 322 |
+
"select/choices is not supported for chat models. "
|
| 323 |
+
"Please try to use a non-chat model such as gpt-3.5-turbo-instruct"
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
n_choices = len(choices)
|
| 327 |
+
token_ids = [self.tokenizer.encode(x) for x in choices]
|
| 328 |
+
scores = [0] * n_choices
|
| 329 |
+
valid = [len(x) > 0 for x in token_ids]
|
| 330 |
+
prompt_tokens = self.tokenizer.encode(s.text_)
|
| 331 |
+
|
| 332 |
+
max_len = max([len(x) for x in token_ids])
|
| 333 |
+
for step in range(max_len):
|
| 334 |
+
# Build logit bias
|
| 335 |
+
logit_bias = {}
|
| 336 |
+
for i in range(n_choices):
|
| 337 |
+
if valid[i]:
|
| 338 |
+
logit_bias[token_ids[i][step]] = 100
|
| 339 |
+
|
| 340 |
+
# Call API
|
| 341 |
+
ret = self.client.completions.create(
|
| 342 |
+
model=self.model_name,
|
| 343 |
+
prompt=prompt_tokens,
|
| 344 |
+
logit_bias=logit_bias,
|
| 345 |
+
max_tokens=1,
|
| 346 |
+
temperature=temperature,
|
| 347 |
+
)
|
| 348 |
+
ret_str = ret.choices[0].text
|
| 349 |
+
ret_token = self.tokenizer.encode(ret_str)[0]
|
| 350 |
+
self.token_usage.prompt_tokens += ret.usage.prompt_tokens
|
| 351 |
+
self.token_usage.completion_tokens = ret.usage.completion_tokens
|
| 352 |
+
|
| 353 |
+
# TODO:
|
| 354 |
+
# 1. return logits as the scores
|
| 355 |
+
# 2. compute logits of the full choice
|
| 356 |
+
# 3. consider chunk-based decoding
|
| 357 |
+
|
| 358 |
+
# Update valid
|
| 359 |
+
hit = False
|
| 360 |
+
for i in range(n_choices):
|
| 361 |
+
if valid[i]:
|
| 362 |
+
if step == len(token_ids[i]) - 1:
|
| 363 |
+
valid[i] = False
|
| 364 |
+
|
| 365 |
+
if ret_token == token_ids[i][step]:
|
| 366 |
+
scores[i] += 1
|
| 367 |
+
hit = True
|
| 368 |
+
else:
|
| 369 |
+
valid[i] = False
|
| 370 |
+
assert hit
|
| 371 |
+
|
| 372 |
+
if np.sum(valid) <= 1:
|
| 373 |
+
break
|
| 374 |
+
|
| 375 |
+
prompt_tokens.append(ret_token)
|
| 376 |
+
|
| 377 |
+
return ChoicesDecision(
|
| 378 |
+
decision=choices[np.argmax(scores)],
|
| 379 |
+
meta_info={"scores": scores},
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def openai_completion(
|
| 384 |
+
client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
|
| 385 |
+
) -> Union[str, List[str]]:
|
| 386 |
+
# if "ebnf" is in kwargs, warn and remove
|
| 387 |
+
if "ebnf" in kwargs:
|
| 388 |
+
warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.")
|
| 389 |
+
del kwargs["ebnf"]
|
| 390 |
+
|
| 391 |
+
for attempt in range(retries):
|
| 392 |
+
try:
|
| 393 |
+
if is_chat:
|
| 394 |
+
if "stop" in kwargs and kwargs["stop"] is None:
|
| 395 |
+
kwargs.pop("stop")
|
| 396 |
+
ret = client.chat.completions.create(messages=prompt, **kwargs)
|
| 397 |
+
if len(ret.choices) == 1:
|
| 398 |
+
comp = ret.choices[0].message.content
|
| 399 |
+
else:
|
| 400 |
+
comp = [c.message.content for c in ret.choices]
|
| 401 |
+
else:
|
| 402 |
+
ret = client.completions.create(prompt=prompt, **kwargs)
|
| 403 |
+
if isinstance(prompt, (list, tuple)):
|
| 404 |
+
comp = [c.text for c in ret.choices]
|
| 405 |
+
else:
|
| 406 |
+
comp = ret.choices[0].text
|
| 407 |
+
if len(ret.choices) > 1:
|
| 408 |
+
comp = [c.text for c in ret.choices]
|
| 409 |
+
|
| 410 |
+
token_usage.prompt_tokens += ret.usage.prompt_tokens
|
| 411 |
+
token_usage.completion_tokens += ret.usage.completion_tokens
|
| 412 |
+
break
|
| 413 |
+
except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
|
| 414 |
+
logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
|
| 415 |
+
time.sleep(5)
|
| 416 |
+
if attempt == retries - 1:
|
| 417 |
+
raise e
|
| 418 |
+
except Exception as e:
|
| 419 |
+
logger.error(f"RuntimeError {e}.")
|
| 420 |
+
raise e
|
| 421 |
+
|
| 422 |
+
return comp
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
def openai_completion_stream(
|
| 426 |
+
client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
|
| 427 |
+
):
|
| 428 |
+
# if "ebnf" is in kwargs, warn and remove
|
| 429 |
+
if "ebnf" in kwargs:
|
| 430 |
+
warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.")
|
| 431 |
+
del kwargs["ebnf"]
|
| 432 |
+
|
| 433 |
+
for attempt in range(retries):
|
| 434 |
+
try:
|
| 435 |
+
if is_chat:
|
| 436 |
+
if "stop" in kwargs and kwargs["stop"] is None:
|
| 437 |
+
kwargs.pop("stop")
|
| 438 |
+
generator = client.chat.completions.create(
|
| 439 |
+
messages=prompt,
|
| 440 |
+
stream=True,
|
| 441 |
+
stream_options={"include_usage": True},
|
| 442 |
+
**kwargs,
|
| 443 |
+
)
|
| 444 |
+
for ret in generator:
|
| 445 |
+
if len(ret.choices) == 0:
|
| 446 |
+
continue
|
| 447 |
+
try:
|
| 448 |
+
content = ret.choices[0].delta.content
|
| 449 |
+
except IndexError:
|
| 450 |
+
content = None
|
| 451 |
+
yield content or "", {}
|
| 452 |
+
else:
|
| 453 |
+
generator = client.completions.create(
|
| 454 |
+
prompt=prompt,
|
| 455 |
+
stream=True,
|
| 456 |
+
stream_options={"include_usage": True},
|
| 457 |
+
**kwargs,
|
| 458 |
+
)
|
| 459 |
+
for ret in generator:
|
| 460 |
+
if len(ret.choices) == 0:
|
| 461 |
+
continue
|
| 462 |
+
content = ret.choices[0].text
|
| 463 |
+
yield content or "", {}
|
| 464 |
+
|
| 465 |
+
token_usage.prompt_tokens += ret.usage.prompt_tokens
|
| 466 |
+
token_usage.completion_tokens += ret.usage.completion_tokens
|
| 467 |
+
break
|
| 468 |
+
except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
|
| 469 |
+
logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
|
| 470 |
+
time.sleep(5)
|
| 471 |
+
if attempt == retries - 1:
|
| 472 |
+
raise e
|
| 473 |
+
except Exception as e:
|
| 474 |
+
logger.error(f"RuntimeError {e}.")
|
| 475 |
+
raise e
|
sglang/lang/backend/runtime_endpoint.py
ADDED
|
@@ -0,0 +1,527 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import atexit
|
| 2 |
+
import json
|
| 3 |
+
import multiprocessing
|
| 4 |
+
import warnings
|
| 5 |
+
from typing import Dict, List, Optional, Union
|
| 6 |
+
|
| 7 |
+
import aiohttp
|
| 8 |
+
import requests
|
| 9 |
+
|
| 10 |
+
from sglang.global_config import global_config
|
| 11 |
+
from sglang.lang.backend.base_backend import BaseBackend
|
| 12 |
+
from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path
|
| 13 |
+
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
|
| 14 |
+
from sglang.lang.interpreter import StreamExecutor
|
| 15 |
+
from sglang.lang.ir import (
|
| 16 |
+
REGEX_BOOL,
|
| 17 |
+
REGEX_FLOAT,
|
| 18 |
+
REGEX_INT,
|
| 19 |
+
REGEX_STR,
|
| 20 |
+
SglSamplingParams,
|
| 21 |
+
)
|
| 22 |
+
from sglang.utils import http_request
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class RuntimeEndpoint(BaseBackend):
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
base_url: str,
|
| 29 |
+
api_key: Optional[str] = None,
|
| 30 |
+
verify: Optional[str] = None,
|
| 31 |
+
chat_template_name: Optional[str] = None,
|
| 32 |
+
):
|
| 33 |
+
super().__init__()
|
| 34 |
+
self.support_concate_and_append = True
|
| 35 |
+
|
| 36 |
+
self.base_url = base_url
|
| 37 |
+
self.api_key = api_key
|
| 38 |
+
self.verify = verify
|
| 39 |
+
|
| 40 |
+
res = http_request(
|
| 41 |
+
self.base_url + "/get_model_info",
|
| 42 |
+
api_key=self.api_key,
|
| 43 |
+
verify=self.verify,
|
| 44 |
+
)
|
| 45 |
+
self._assert_success(res)
|
| 46 |
+
self.model_info = res.json()
|
| 47 |
+
|
| 48 |
+
if chat_template_name:
|
| 49 |
+
self.chat_template = get_chat_template(chat_template_name)
|
| 50 |
+
else:
|
| 51 |
+
self.chat_template = get_chat_template_by_model_path(
|
| 52 |
+
self.model_info["model_path"]
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
def get_model_name(self):
|
| 56 |
+
return self.model_info["model_path"]
|
| 57 |
+
|
| 58 |
+
def flush_cache(self):
|
| 59 |
+
res = http_request(
|
| 60 |
+
self.base_url + "/flush_cache",
|
| 61 |
+
api_key=self.api_key,
|
| 62 |
+
verify=self.verify,
|
| 63 |
+
method="POST",
|
| 64 |
+
)
|
| 65 |
+
self._assert_success(res)
|
| 66 |
+
|
| 67 |
+
def get_server_info(self):
|
| 68 |
+
res = http_request(
|
| 69 |
+
self.base_url + "/get_server_info",
|
| 70 |
+
api_key=self.api_key,
|
| 71 |
+
verify=self.verify,
|
| 72 |
+
)
|
| 73 |
+
self._assert_success(res)
|
| 74 |
+
return res.json()
|
| 75 |
+
|
| 76 |
+
def get_chat_template(self):
|
| 77 |
+
return self.chat_template
|
| 78 |
+
|
| 79 |
+
def cache_prefix(self, prefix_str: str):
|
| 80 |
+
res = http_request(
|
| 81 |
+
self.base_url + "/generate",
|
| 82 |
+
json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
|
| 83 |
+
api_key=self.api_key,
|
| 84 |
+
verify=self.verify,
|
| 85 |
+
)
|
| 86 |
+
self._assert_success(res)
|
| 87 |
+
|
| 88 |
+
def start_profile(self):
|
| 89 |
+
res = http_request(
|
| 90 |
+
self.base_url + "/start_profile",
|
| 91 |
+
api_key=self.api_key,
|
| 92 |
+
verify=self.verify,
|
| 93 |
+
)
|
| 94 |
+
self._assert_success(res)
|
| 95 |
+
|
| 96 |
+
def stop_profile(self):
|
| 97 |
+
res = http_request(
|
| 98 |
+
self.base_url + "/stop_profile",
|
| 99 |
+
api_key=self.api_key,
|
| 100 |
+
verify=self.verify,
|
| 101 |
+
)
|
| 102 |
+
self._assert_success(res)
|
| 103 |
+
|
| 104 |
+
def commit_lazy_operations(self, s: StreamExecutor):
|
| 105 |
+
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
|
| 106 |
+
self._add_images(s, data)
|
| 107 |
+
res = http_request(
|
| 108 |
+
self.base_url + "/generate",
|
| 109 |
+
json=data,
|
| 110 |
+
api_key=self.api_key,
|
| 111 |
+
verify=self.verify,
|
| 112 |
+
)
|
| 113 |
+
self._assert_success(res)
|
| 114 |
+
|
| 115 |
+
def fill_image(self, s: StreamExecutor):
|
| 116 |
+
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
|
| 117 |
+
self._add_images(s, data)
|
| 118 |
+
res = http_request(
|
| 119 |
+
self.base_url + "/generate",
|
| 120 |
+
json=data,
|
| 121 |
+
api_key=self.api_key,
|
| 122 |
+
verify=self.verify,
|
| 123 |
+
)
|
| 124 |
+
self._assert_success(res)
|
| 125 |
+
|
| 126 |
+
def _handle_dtype_to_regex(self, sampling_params: SglSamplingParams):
|
| 127 |
+
if sampling_params.dtype is None:
|
| 128 |
+
return
|
| 129 |
+
|
| 130 |
+
if sampling_params.stop == ():
|
| 131 |
+
sampling_params.stop = []
|
| 132 |
+
|
| 133 |
+
dtype_regex = None
|
| 134 |
+
if sampling_params.dtype in ["int", int]:
|
| 135 |
+
|
| 136 |
+
dtype_regex = REGEX_INT
|
| 137 |
+
sampling_params.stop.extend([" ", "\n"])
|
| 138 |
+
elif sampling_params.dtype in ["float", float]:
|
| 139 |
+
|
| 140 |
+
dtype_regex = REGEX_FLOAT
|
| 141 |
+
sampling_params.stop.extend([" ", "\n"])
|
| 142 |
+
elif sampling_params.dtype in ["str", str]:
|
| 143 |
+
|
| 144 |
+
dtype_regex = REGEX_STR
|
| 145 |
+
elif sampling_params.dtype in ["bool", bool]:
|
| 146 |
+
|
| 147 |
+
dtype_regex = REGEX_BOOL
|
| 148 |
+
else:
|
| 149 |
+
raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
|
| 150 |
+
|
| 151 |
+
if dtype_regex is not None and sampling_params.regex is not None:
|
| 152 |
+
warnings.warn(
|
| 153 |
+
f"Both dtype and regex are set. Only dtype will be used. dtype: {sampling_params.dtype}, regex: {sampling_params.regex}"
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
sampling_params.regex = dtype_regex
|
| 157 |
+
|
| 158 |
+
def generate(
|
| 159 |
+
self,
|
| 160 |
+
s: StreamExecutor,
|
| 161 |
+
sampling_params: SglSamplingParams,
|
| 162 |
+
):
|
| 163 |
+
self._handle_dtype_to_regex(sampling_params)
|
| 164 |
+
data = {
|
| 165 |
+
"text": s.text_,
|
| 166 |
+
"sampling_params": {
|
| 167 |
+
"skip_special_tokens": global_config.skip_special_tokens_in_output,
|
| 168 |
+
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
|
| 169 |
+
**sampling_params.to_srt_kwargs(),
|
| 170 |
+
},
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
for item in [
|
| 174 |
+
"return_logprob",
|
| 175 |
+
"logprob_start_len",
|
| 176 |
+
"top_logprobs_num",
|
| 177 |
+
"return_text_in_logprobs",
|
| 178 |
+
]:
|
| 179 |
+
value = getattr(sampling_params, item, None)
|
| 180 |
+
if value is not None:
|
| 181 |
+
data[item] = value
|
| 182 |
+
|
| 183 |
+
self._add_images(s, data)
|
| 184 |
+
|
| 185 |
+
res = http_request(
|
| 186 |
+
self.base_url + "/generate",
|
| 187 |
+
json=data,
|
| 188 |
+
api_key=self.api_key,
|
| 189 |
+
verify=self.verify,
|
| 190 |
+
)
|
| 191 |
+
self._assert_success(res)
|
| 192 |
+
|
| 193 |
+
obj = res.json()
|
| 194 |
+
comp = obj["text"]
|
| 195 |
+
return comp, obj["meta_info"]
|
| 196 |
+
|
| 197 |
+
def generate_stream(
|
| 198 |
+
self,
|
| 199 |
+
s: StreamExecutor,
|
| 200 |
+
sampling_params: SglSamplingParams,
|
| 201 |
+
):
|
| 202 |
+
self._handle_dtype_to_regex(sampling_params)
|
| 203 |
+
|
| 204 |
+
data = {
|
| 205 |
+
"text": s.text_,
|
| 206 |
+
"sampling_params": {
|
| 207 |
+
"skip_special_tokens": global_config.skip_special_tokens_in_output,
|
| 208 |
+
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
|
| 209 |
+
**sampling_params.to_srt_kwargs(),
|
| 210 |
+
},
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
for item in [
|
| 214 |
+
"return_logprob",
|
| 215 |
+
"logprob_start_len",
|
| 216 |
+
"top_logprobs_num",
|
| 217 |
+
"return_text_in_logprobs",
|
| 218 |
+
]:
|
| 219 |
+
value = getattr(sampling_params, item, None)
|
| 220 |
+
if value is not None:
|
| 221 |
+
data[item] = value
|
| 222 |
+
|
| 223 |
+
data["stream"] = True
|
| 224 |
+
self._add_images(s, data)
|
| 225 |
+
|
| 226 |
+
res = http_request(
|
| 227 |
+
self.base_url + "/generate",
|
| 228 |
+
json=data,
|
| 229 |
+
stream=True,
|
| 230 |
+
api_key=self.api_key,
|
| 231 |
+
verify=self.verify,
|
| 232 |
+
)
|
| 233 |
+
self._assert_success(res)
|
| 234 |
+
pos = 0
|
| 235 |
+
|
| 236 |
+
for chunk in res.iter_lines(decode_unicode=False):
|
| 237 |
+
chunk = chunk.decode("utf-8")
|
| 238 |
+
if chunk and chunk.startswith("data:"):
|
| 239 |
+
if chunk == "data: [DONE]":
|
| 240 |
+
break
|
| 241 |
+
data = json.loads(chunk[5:].strip("\n"))
|
| 242 |
+
chunk_text = data["text"][pos:]
|
| 243 |
+
meta_info = data["meta_info"]
|
| 244 |
+
pos += len(chunk_text)
|
| 245 |
+
yield chunk_text, meta_info
|
| 246 |
+
|
| 247 |
+
def select(
|
| 248 |
+
self,
|
| 249 |
+
s: StreamExecutor,
|
| 250 |
+
choices: List[str],
|
| 251 |
+
temperature: float,
|
| 252 |
+
choices_method: ChoicesSamplingMethod,
|
| 253 |
+
) -> ChoicesDecision:
|
| 254 |
+
assert temperature <= 1e-5
|
| 255 |
+
|
| 256 |
+
# Cache common prefix
|
| 257 |
+
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
|
| 258 |
+
obj = self._generate_http_request(s, data)
|
| 259 |
+
prompt_len = obj["meta_info"]["prompt_tokens"]
|
| 260 |
+
logprob_start_len = max(prompt_len - 2, 0) # For token healing
|
| 261 |
+
|
| 262 |
+
# Compute logprob
|
| 263 |
+
data = {
|
| 264 |
+
"text": [s.text_ + c for c in choices],
|
| 265 |
+
"sampling_params": {
|
| 266 |
+
"max_new_tokens": 0,
|
| 267 |
+
"temperature": 0,
|
| 268 |
+
},
|
| 269 |
+
"return_logprob": True,
|
| 270 |
+
"return_text_in_logprobs": True,
|
| 271 |
+
"logprob_start_len": logprob_start_len,
|
| 272 |
+
}
|
| 273 |
+
obj = self._generate_http_request(s, data)
|
| 274 |
+
|
| 275 |
+
input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
|
| 276 |
+
output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]
|
| 277 |
+
normalized_prompt_logprobs = [
|
| 278 |
+
compute_normalized_prompt_logprobs(r["meta_info"]["input_token_logprobs"])
|
| 279 |
+
for r in obj
|
| 280 |
+
]
|
| 281 |
+
|
| 282 |
+
# Remove extra token if no token healing occurred
|
| 283 |
+
for i in range(len(input_token_logprobs)):
|
| 284 |
+
healed_token_str = input_token_logprobs[i][0][-1]
|
| 285 |
+
if s.text_.endswith(healed_token_str):
|
| 286 |
+
healed_token_logprob = input_token_logprobs[i][0][0]
|
| 287 |
+
normalized_prompt_logprobs[i] = (
|
| 288 |
+
normalized_prompt_logprobs[i] * len(input_token_logprobs[i])
|
| 289 |
+
- healed_token_logprob
|
| 290 |
+
) / (len(input_token_logprobs[i]) - 1)
|
| 291 |
+
input_token_logprobs[i] = input_token_logprobs[i][1:]
|
| 292 |
+
|
| 293 |
+
# Compute unconditional logprobs if required
|
| 294 |
+
if choices_method.requires_unconditional_logprobs:
|
| 295 |
+
input_ids = [[el[1] for el in subl] for subl in input_token_logprobs]
|
| 296 |
+
data = {
|
| 297 |
+
"input_ids": input_ids,
|
| 298 |
+
"sampling_params": {"max_new_tokens": 0},
|
| 299 |
+
"return_logprob": True,
|
| 300 |
+
}
|
| 301 |
+
obj = self._generate_http_request(s, data)
|
| 302 |
+
unconditional_token_logprobs = [
|
| 303 |
+
r["meta_info"]["input_token_logprobs"] for r in obj
|
| 304 |
+
]
|
| 305 |
+
else:
|
| 306 |
+
unconditional_token_logprobs = None
|
| 307 |
+
|
| 308 |
+
return choices_method(
|
| 309 |
+
choices=choices,
|
| 310 |
+
normalized_prompt_logprobs=normalized_prompt_logprobs,
|
| 311 |
+
input_token_logprobs=input_token_logprobs,
|
| 312 |
+
output_token_logprobs=output_token_logprobs,
|
| 313 |
+
unconditional_token_logprobs=unconditional_token_logprobs,
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
|
| 317 |
+
res = http_request(
|
| 318 |
+
self.base_url + "/concate_and_append_request",
|
| 319 |
+
json={"src_rids": src_rids, "dst_rid": dst_rid},
|
| 320 |
+
api_key=self.api_key,
|
| 321 |
+
verify=self.verify,
|
| 322 |
+
)
|
| 323 |
+
self._assert_success(res)
|
| 324 |
+
|
| 325 |
+
def _generate_http_request(self, s: StreamExecutor, data):
|
| 326 |
+
self._add_images(s, data)
|
| 327 |
+
res = http_request(
|
| 328 |
+
self.base_url + "/generate",
|
| 329 |
+
json=data,
|
| 330 |
+
api_key=self.api_key,
|
| 331 |
+
verify=self.verify,
|
| 332 |
+
)
|
| 333 |
+
self._assert_success(res)
|
| 334 |
+
return res.json()
|
| 335 |
+
|
| 336 |
+
def _add_images(self, s: StreamExecutor, data):
|
| 337 |
+
if s.images_:
|
| 338 |
+
assert len(s.images_) == 1, "Only support one image."
|
| 339 |
+
data["image_data"] = s.images_[0][1]
|
| 340 |
+
|
| 341 |
+
def _assert_success(self, res):
|
| 342 |
+
if res.status_code != 200:
|
| 343 |
+
try:
|
| 344 |
+
content = res.json()
|
| 345 |
+
except json.JSONDecodeError:
|
| 346 |
+
content = res.text
|
| 347 |
+
raise RuntimeError(content)
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
def compute_normalized_prompt_logprobs(input_logprobs):
|
| 351 |
+
values = [x[0] for x in input_logprobs if x[0]]
|
| 352 |
+
return sum(values) / len(values)
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
class Runtime:
|
| 356 |
+
"""
|
| 357 |
+
A wrapper for the HTTP server.
|
| 358 |
+
This is used for launching the server in a python program without
|
| 359 |
+
using the command line interface.
|
| 360 |
+
|
| 361 |
+
It is mainly used for the frontend language.
|
| 362 |
+
You should use the Engine class if you want to do normal offline processing without the frontend language.
|
| 363 |
+
"""
|
| 364 |
+
|
| 365 |
+
def __init__(
|
| 366 |
+
self,
|
| 367 |
+
log_level: str = "error",
|
| 368 |
+
*args,
|
| 369 |
+
**kwargs,
|
| 370 |
+
):
|
| 371 |
+
"""See the arguments in server_args.py::ServerArgs"""
|
| 372 |
+
# We delay the import of any `sglang.srt` components in `sglang.lang`, so users can run
|
| 373 |
+
# client code without installing SRT server and its dependency if they want.
|
| 374 |
+
from sglang.srt.entrypoints.http_server import launch_server
|
| 375 |
+
from sglang.srt.server_args import ServerArgs
|
| 376 |
+
from sglang.srt.utils import is_port_available
|
| 377 |
+
|
| 378 |
+
self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)
|
| 379 |
+
|
| 380 |
+
# Pre-allocate ports
|
| 381 |
+
for port in range(self.server_args.port, 40000):
|
| 382 |
+
if is_port_available(port):
|
| 383 |
+
break
|
| 384 |
+
self.server_args.port = port
|
| 385 |
+
|
| 386 |
+
self.url = self.server_args.url()
|
| 387 |
+
self.generate_url = self.url + "/generate"
|
| 388 |
+
|
| 389 |
+
# NOTE: We store pid instead of proc to fix some issues during __delete__
|
| 390 |
+
self.pid = None
|
| 391 |
+
pipe_reader, pipe_writer = multiprocessing.Pipe(duplex=False)
|
| 392 |
+
|
| 393 |
+
ctx = multiprocessing.get_context("spawn")
|
| 394 |
+
proc = ctx.Process(
|
| 395 |
+
target=launch_server,
|
| 396 |
+
args=(self.server_args, pipe_writer),
|
| 397 |
+
)
|
| 398 |
+
proc.start()
|
| 399 |
+
pipe_writer.close()
|
| 400 |
+
self.pid = proc.pid
|
| 401 |
+
|
| 402 |
+
# Before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
|
| 403 |
+
atexit.register(self.shutdown)
|
| 404 |
+
|
| 405 |
+
# TODO: remove this pipe_writer mechanism and use `/health_generate` instead.
|
| 406 |
+
try:
|
| 407 |
+
init_state = pipe_reader.recv()
|
| 408 |
+
except EOFError:
|
| 409 |
+
init_state = ""
|
| 410 |
+
|
| 411 |
+
if init_state != "ready":
|
| 412 |
+
self.shutdown()
|
| 413 |
+
raise RuntimeError(
|
| 414 |
+
"Initialization failed. Please see the error messages above."
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
self.endpoint = RuntimeEndpoint(self.url)
|
| 418 |
+
|
| 419 |
+
def shutdown(self):
|
| 420 |
+
from sglang.srt.utils import kill_process_tree
|
| 421 |
+
|
| 422 |
+
if self.pid is not None:
|
| 423 |
+
kill_process_tree(self.pid)
|
| 424 |
+
self.pid = None
|
| 425 |
+
|
| 426 |
+
def start_profile(self):
|
| 427 |
+
self.endpoint.start_profile()
|
| 428 |
+
|
| 429 |
+
def stop_profile(self):
|
| 430 |
+
self.endpoint.stop_profile()
|
| 431 |
+
|
| 432 |
+
def cache_prefix(self, prefix: str):
|
| 433 |
+
self.endpoint.cache_prefix(prefix)
|
| 434 |
+
|
| 435 |
+
def get_tokenizer(self):
|
| 436 |
+
from sglang.srt.utils.hf_transformers_utils import get_tokenizer
|
| 437 |
+
|
| 438 |
+
return get_tokenizer(
|
| 439 |
+
self.server_args.tokenizer_path,
|
| 440 |
+
tokenizer_mode=self.server_args.tokenizer_mode,
|
| 441 |
+
trust_remote_code=self.server_args.trust_remote_code,
|
| 442 |
+
revision=self.server_args.revision,
|
| 443 |
+
)
|
| 444 |
+
|
| 445 |
+
async def async_generate(
|
| 446 |
+
self,
|
| 447 |
+
prompt: str,
|
| 448 |
+
sampling_params: Optional[Dict] = None,
|
| 449 |
+
):
|
| 450 |
+
if self.server_args.skip_tokenizer_init:
|
| 451 |
+
json_data = {
|
| 452 |
+
"input_ids": prompt,
|
| 453 |
+
"sampling_params": sampling_params,
|
| 454 |
+
"stream": True,
|
| 455 |
+
}
|
| 456 |
+
else:
|
| 457 |
+
json_data = {
|
| 458 |
+
"text": prompt,
|
| 459 |
+
"sampling_params": sampling_params,
|
| 460 |
+
"stream": True,
|
| 461 |
+
}
|
| 462 |
+
pos = 0
|
| 463 |
+
|
| 464 |
+
timeout = aiohttp.ClientTimeout(total=3 * 3600)
|
| 465 |
+
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
| 466 |
+
async with session.post(self.generate_url, json=json_data) as response:
|
| 467 |
+
async for chunk, _ in response.content.iter_chunks():
|
| 468 |
+
chunk = chunk.decode("utf-8")
|
| 469 |
+
if chunk and chunk.startswith("data:"):
|
| 470 |
+
if chunk == "data: [DONE]\n\n":
|
| 471 |
+
break
|
| 472 |
+
data = json.loads(chunk[5:].strip("\n"))
|
| 473 |
+
if "text" in data:
|
| 474 |
+
cur = data["text"][pos:]
|
| 475 |
+
if cur:
|
| 476 |
+
yield cur
|
| 477 |
+
pos += len(cur)
|
| 478 |
+
else:
|
| 479 |
+
yield data
|
| 480 |
+
|
| 481 |
+
add_request = async_generate
|
| 482 |
+
|
| 483 |
+
def generate(
|
| 484 |
+
self,
|
| 485 |
+
prompt: Union[str, List[str]],
|
| 486 |
+
sampling_params: Optional[Dict] = None,
|
| 487 |
+
return_logprob: Optional[Union[List[bool], bool]] = False,
|
| 488 |
+
logprob_start_len: Optional[Union[List[int], int]] = None,
|
| 489 |
+
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
| 490 |
+
lora_path: Optional[List[Optional[str]]] = None,
|
| 491 |
+
):
|
| 492 |
+
json_data = {
|
| 493 |
+
"text": prompt,
|
| 494 |
+
"sampling_params": sampling_params,
|
| 495 |
+
"return_logprob": return_logprob,
|
| 496 |
+
"logprob_start_len": logprob_start_len,
|
| 497 |
+
"top_logprobs_num": top_logprobs_num,
|
| 498 |
+
"lora_path": lora_path,
|
| 499 |
+
}
|
| 500 |
+
assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
|
| 501 |
+
response = requests.post(
|
| 502 |
+
self.url + "/generate",
|
| 503 |
+
json=json_data,
|
| 504 |
+
)
|
| 505 |
+
return json.dumps(response.json())
|
| 506 |
+
|
| 507 |
+
def encode(
|
| 508 |
+
self,
|
| 509 |
+
prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
|
| 510 |
+
):
|
| 511 |
+
json_data = {"text": prompt}
|
| 512 |
+
response = requests.post(self.url + "/encode", json=json_data)
|
| 513 |
+
return json.dumps(response.json())
|
| 514 |
+
|
| 515 |
+
async def get_server_info(self):
|
| 516 |
+
async with aiohttp.ClientSession() as session:
|
| 517 |
+
async with session.get(f"{self.url}/get_server_info") as response:
|
| 518 |
+
if response.status == 200:
|
| 519 |
+
return await response.json()
|
| 520 |
+
else:
|
| 521 |
+
error_data = await response.json()
|
| 522 |
+
raise RuntimeError(
|
| 523 |
+
f"Failed to get server info. {error_data['error']['message']}"
|
| 524 |
+
)
|
| 525 |
+
|
| 526 |
+
def __del__(self):
|
| 527 |
+
self.shutdown()
|
sglang/lang/backend/vertexai.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import warnings
|
| 3 |
+
|
| 4 |
+
from sglang.lang.backend.base_backend import BaseBackend
|
| 5 |
+
from sglang.lang.chat_template import get_chat_template
|
| 6 |
+
from sglang.lang.interpreter import StreamExecutor
|
| 7 |
+
from sglang.lang.ir import SglSamplingParams
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
import vertexai
|
| 11 |
+
from vertexai.preview.generative_models import (
|
| 12 |
+
GenerationConfig,
|
| 13 |
+
GenerativeModel,
|
| 14 |
+
Image,
|
| 15 |
+
)
|
| 16 |
+
except ImportError as e:
|
| 17 |
+
GenerativeModel = e
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class VertexAI(BaseBackend):
|
| 21 |
+
def __init__(self, model_name, safety_settings=None):
|
| 22 |
+
super().__init__()
|
| 23 |
+
|
| 24 |
+
if isinstance(GenerativeModel, Exception):
|
| 25 |
+
raise GenerativeModel
|
| 26 |
+
|
| 27 |
+
project_id = os.environ["GCP_PROJECT_ID"]
|
| 28 |
+
location = os.environ.get("GCP_LOCATION")
|
| 29 |
+
vertexai.init(project=project_id, location=location)
|
| 30 |
+
|
| 31 |
+
self.model_name = model_name
|
| 32 |
+
self.chat_template = get_chat_template("default")
|
| 33 |
+
self.safety_settings = safety_settings
|
| 34 |
+
|
| 35 |
+
def get_chat_template(self):
|
| 36 |
+
return self.chat_template
|
| 37 |
+
|
| 38 |
+
def generate(
|
| 39 |
+
self,
|
| 40 |
+
s: StreamExecutor,
|
| 41 |
+
sampling_params: SglSamplingParams,
|
| 42 |
+
):
|
| 43 |
+
if s.messages_:
|
| 44 |
+
prompt = self.messages_to_vertexai_input(s.messages_)
|
| 45 |
+
else:
|
| 46 |
+
# single-turn
|
| 47 |
+
prompt = (
|
| 48 |
+
self.text_to_vertexai_input(s.text_, s.cur_images)
|
| 49 |
+
if s.cur_images
|
| 50 |
+
else s.text_
|
| 51 |
+
)
|
| 52 |
+
ret = GenerativeModel(self.model_name).generate_content(
|
| 53 |
+
prompt,
|
| 54 |
+
generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
|
| 55 |
+
safety_settings=self.safety_settings,
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
comp = ret.text
|
| 59 |
+
|
| 60 |
+
return comp, {}
|
| 61 |
+
|
| 62 |
+
def generate_stream(
|
| 63 |
+
self,
|
| 64 |
+
s: StreamExecutor,
|
| 65 |
+
sampling_params: SglSamplingParams,
|
| 66 |
+
):
|
| 67 |
+
if s.messages_:
|
| 68 |
+
prompt = self.messages_to_vertexai_input(s.messages_)
|
| 69 |
+
else:
|
| 70 |
+
# single-turn
|
| 71 |
+
prompt = (
|
| 72 |
+
self.text_to_vertexai_input(s.text_, s.cur_images)
|
| 73 |
+
if s.cur_images
|
| 74 |
+
else s.text_
|
| 75 |
+
)
|
| 76 |
+
generator = GenerativeModel(self.model_name).generate_content(
|
| 77 |
+
prompt,
|
| 78 |
+
stream=True,
|
| 79 |
+
generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
|
| 80 |
+
safety_settings=self.safety_settings,
|
| 81 |
+
)
|
| 82 |
+
for ret in generator:
|
| 83 |
+
yield ret.text, {}
|
| 84 |
+
|
| 85 |
+
def text_to_vertexai_input(self, text, images):
|
| 86 |
+
input = []
|
| 87 |
+
# split with image token
|
| 88 |
+
text_segs = text.split(self.chat_template.image_token)
|
| 89 |
+
for image_path, image_base64_data in images:
|
| 90 |
+
text_seg = text_segs.pop(0)
|
| 91 |
+
if text_seg != "":
|
| 92 |
+
input.append(text_seg)
|
| 93 |
+
input.append(Image.from_bytes(image_base64_data))
|
| 94 |
+
text_seg = text_segs.pop(0)
|
| 95 |
+
if text_seg != "":
|
| 96 |
+
input.append(text_seg)
|
| 97 |
+
return input
|
| 98 |
+
|
| 99 |
+
def messages_to_vertexai_input(self, messages):
|
| 100 |
+
vertexai_message = []
|
| 101 |
+
# from openai message format to vertexai message format
|
| 102 |
+
for msg in messages:
|
| 103 |
+
if isinstance(msg["content"], str):
|
| 104 |
+
text = msg["content"]
|
| 105 |
+
else:
|
| 106 |
+
text = msg["content"][0]["text"]
|
| 107 |
+
|
| 108 |
+
if msg["role"] == "system":
|
| 109 |
+
warnings.warn("Warning: system prompt is not supported in VertexAI.")
|
| 110 |
+
vertexai_message.append(
|
| 111 |
+
{
|
| 112 |
+
"role": "user",
|
| 113 |
+
"parts": [{"text": "System prompt: " + text}],
|
| 114 |
+
}
|
| 115 |
+
)
|
| 116 |
+
vertexai_message.append(
|
| 117 |
+
{
|
| 118 |
+
"role": "model",
|
| 119 |
+
"parts": [{"text": "Understood."}],
|
| 120 |
+
}
|
| 121 |
+
)
|
| 122 |
+
continue
|
| 123 |
+
if msg["role"] == "user":
|
| 124 |
+
vertexai_msg = {
|
| 125 |
+
"role": "user",
|
| 126 |
+
"parts": [{"text": text}],
|
| 127 |
+
}
|
| 128 |
+
elif msg["role"] == "assistant":
|
| 129 |
+
vertexai_msg = {
|
| 130 |
+
"role": "model",
|
| 131 |
+
"parts": [{"text": text}],
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
# images
|
| 135 |
+
if isinstance(msg["content"], list) and len(msg["content"]) > 1:
|
| 136 |
+
for image in msg["content"][1:]:
|
| 137 |
+
assert image["type"] == "image_url"
|
| 138 |
+
vertexai_msg["parts"].append(
|
| 139 |
+
{
|
| 140 |
+
"inline_data": {
|
| 141 |
+
"data": image["image_url"]["url"].split(",")[1],
|
| 142 |
+
"mime_type": "image/jpeg",
|
| 143 |
+
}
|
| 144 |
+
}
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
vertexai_message.append(vertexai_msg)
|
| 148 |
+
return vertexai_message
|
sglang/lang/chat_template.py
ADDED
|
@@ -0,0 +1,668 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from enum import Enum, auto
|
| 4 |
+
from typing import Callable, Dict, List, Tuple
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ChatTemplateStyle(Enum):
|
| 8 |
+
PLAIN = auto()
|
| 9 |
+
LLAMA2 = auto()
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
|
| 13 |
+
class ChatTemplate:
|
| 14 |
+
name: str
|
| 15 |
+
default_system_prompt: str
|
| 16 |
+
role_prefix_and_suffix: Dict[str, Tuple[str, str]]
|
| 17 |
+
stop_str: List[str] = ()
|
| 18 |
+
image_token: str = "<image>"
|
| 19 |
+
audio_token: str = "<audio>"
|
| 20 |
+
style: ChatTemplateStyle = ChatTemplateStyle.PLAIN
|
| 21 |
+
|
| 22 |
+
def get_prefix_and_suffix(
|
| 23 |
+
self, role: str, hist_messages: List[Dict]
|
| 24 |
+
) -> Tuple[str, str]:
|
| 25 |
+
prefix, suffix = self.role_prefix_and_suffix.get(role, ("", ""))
|
| 26 |
+
|
| 27 |
+
if self.style == ChatTemplateStyle.LLAMA2:
|
| 28 |
+
if role == "system" and not hist_messages:
|
| 29 |
+
user_prefix, _ = self.role_prefix_and_suffix.get("user", ("", ""))
|
| 30 |
+
system_prefix, system_suffix = self.role_prefix_and_suffix.get(
|
| 31 |
+
"system", ("", "")
|
| 32 |
+
)
|
| 33 |
+
return (user_prefix + system_prefix, system_suffix)
|
| 34 |
+
elif (
|
| 35 |
+
role == "user"
|
| 36 |
+
and len(hist_messages) == 1
|
| 37 |
+
and hist_messages[0]["content"] is not None
|
| 38 |
+
):
|
| 39 |
+
return ("", suffix)
|
| 40 |
+
|
| 41 |
+
return prefix, suffix
|
| 42 |
+
|
| 43 |
+
def get_prompt(self, messages: List[Dict]) -> str:
|
| 44 |
+
prompt = ""
|
| 45 |
+
for i, message in enumerate(messages):
|
| 46 |
+
role, content = message["role"], message["content"]
|
| 47 |
+
if role == "system" and content is None:
|
| 48 |
+
content = self.default_system_prompt
|
| 49 |
+
if content is None:
|
| 50 |
+
continue
|
| 51 |
+
|
| 52 |
+
prefix, suffix = self.get_prefix_and_suffix(role, messages[:i])
|
| 53 |
+
prompt += f"{prefix}{content}{suffix}"
|
| 54 |
+
return prompt
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
chat_template_registry: Dict[str, ChatTemplate] = {}
|
| 58 |
+
matching_function_registry: List[Callable] = []
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def register_chat_template(template):
|
| 62 |
+
chat_template_registry[template.name] = template
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def register_chat_template_matching_function(func):
|
| 66 |
+
matching_function_registry.append(func)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def get_chat_template(name):
|
| 70 |
+
return chat_template_registry[name]
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def get_chat_template_by_model_path(model_path):
|
| 74 |
+
for matching_func in matching_function_registry:
|
| 75 |
+
template_name = matching_func(model_path)
|
| 76 |
+
if template_name is not None:
|
| 77 |
+
return get_chat_template(template_name)
|
| 78 |
+
return get_chat_template("default")
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
register_chat_template(
|
| 82 |
+
ChatTemplate(
|
| 83 |
+
name="default",
|
| 84 |
+
default_system_prompt=None,
|
| 85 |
+
role_prefix_and_suffix={
|
| 86 |
+
"system": ("SYSTEM:", "\n"),
|
| 87 |
+
"user": ("USER:", "\n"),
|
| 88 |
+
"assistant": ("ASSISTANT:", "\n"),
|
| 89 |
+
},
|
| 90 |
+
)
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
register_chat_template(
|
| 94 |
+
ChatTemplate(
|
| 95 |
+
name="claude",
|
| 96 |
+
default_system_prompt=None,
|
| 97 |
+
role_prefix_and_suffix={
|
| 98 |
+
"system": ("", ""),
|
| 99 |
+
"user": ("\n\nHuman: ", ""),
|
| 100 |
+
"assistant": ("\n\nAssistant:", ""),
|
| 101 |
+
},
|
| 102 |
+
)
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
register_chat_template(
|
| 106 |
+
ChatTemplate(
|
| 107 |
+
name="chatml",
|
| 108 |
+
default_system_prompt=None,
|
| 109 |
+
role_prefix_and_suffix={
|
| 110 |
+
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
|
| 111 |
+
"user": ("<|im_start|>user\n", "<|im_end|>\n"),
|
| 112 |
+
"assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
|
| 113 |
+
},
|
| 114 |
+
style=ChatTemplateStyle.PLAIN,
|
| 115 |
+
stop_str=("<|im_end|>",),
|
| 116 |
+
)
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
register_chat_template(
|
| 120 |
+
ChatTemplate(
|
| 121 |
+
name="chatml-llava",
|
| 122 |
+
default_system_prompt="You are a helpful assistant.",
|
| 123 |
+
role_prefix_and_suffix={
|
| 124 |
+
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
|
| 125 |
+
"user": ("<|im_start|>user\n", "<|im_end|>\n"),
|
| 126 |
+
"assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
|
| 127 |
+
},
|
| 128 |
+
style=ChatTemplateStyle.PLAIN,
|
| 129 |
+
stop_str=("<|im_end|>",),
|
| 130 |
+
image_token="<image>\n",
|
| 131 |
+
)
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
# There is default system prompt for qwen
|
| 135 |
+
# reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1
|
| 136 |
+
# The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
|
| 137 |
+
register_chat_template(
|
| 138 |
+
ChatTemplate(
|
| 139 |
+
name="qwen",
|
| 140 |
+
default_system_prompt="You are a helpful assistant.",
|
| 141 |
+
role_prefix_and_suffix={
|
| 142 |
+
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
|
| 143 |
+
"user": ("<|im_start|>user\n", "<|im_end|>\n"),
|
| 144 |
+
"assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
|
| 145 |
+
},
|
| 146 |
+
style=ChatTemplateStyle.PLAIN,
|
| 147 |
+
stop_str=("<|im_end|>",),
|
| 148 |
+
)
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
|
| 152 |
+
register_chat_template(
|
| 153 |
+
ChatTemplate(
|
| 154 |
+
name="qwen2-vl",
|
| 155 |
+
default_system_prompt="You are a helpful assistant.",
|
| 156 |
+
role_prefix_and_suffix={
|
| 157 |
+
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
|
| 158 |
+
"user": ("<|im_start|>user\n", "<|im_end|>\n"),
|
| 159 |
+
"assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
|
| 160 |
+
},
|
| 161 |
+
style=ChatTemplateStyle.PLAIN,
|
| 162 |
+
stop_str=("<|im_end|>",),
|
| 163 |
+
image_token="<|vision_start|><|image_pad|><|vision_end|>",
|
| 164 |
+
)
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
# Reference: https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template
|
| 168 |
+
register_chat_template(
|
| 169 |
+
ChatTemplate(
|
| 170 |
+
name="vicuna_v1.1",
|
| 171 |
+
default_system_prompt=(
|
| 172 |
+
"A chat between a curious user and an artificial intelligence assistant. "
|
| 173 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
| 174 |
+
),
|
| 175 |
+
role_prefix_and_suffix={
|
| 176 |
+
"system": ("", " "),
|
| 177 |
+
"user": ("USER:", " "),
|
| 178 |
+
"assistant": ("ASSISTANT:", "</s>"),
|
| 179 |
+
},
|
| 180 |
+
image_token=" <image>\n",
|
| 181 |
+
)
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
register_chat_template(
|
| 185 |
+
ChatTemplate(
|
| 186 |
+
name="llama-2-chat",
|
| 187 |
+
default_system_prompt=None,
|
| 188 |
+
role_prefix_and_suffix={
|
| 189 |
+
"system": ("<<SYS>>\n", "\n<</SYS>>\n\n"),
|
| 190 |
+
"user": ("[INST] ", " [/INST]"),
|
| 191 |
+
"assistant": ("", " </s><s>"),
|
| 192 |
+
},
|
| 193 |
+
style=ChatTemplateStyle.LLAMA2,
|
| 194 |
+
)
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
# Reference: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/chat_template.json
|
| 198 |
+
register_chat_template(
|
| 199 |
+
ChatTemplate(
|
| 200 |
+
name="mistral",
|
| 201 |
+
default_system_prompt=None,
|
| 202 |
+
role_prefix_and_suffix={
|
| 203 |
+
"system": ("[SYSTEM_PROMPT] ", " [/SYSTEM_PROMPT]"),
|
| 204 |
+
"user": ("[INST] ", " [/INST]"),
|
| 205 |
+
"assistant": ("", " </s><s>"),
|
| 206 |
+
},
|
| 207 |
+
stop_str=("</s>",),
|
| 208 |
+
image_token="[IMG]",
|
| 209 |
+
)
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
register_chat_template(
|
| 213 |
+
ChatTemplate(
|
| 214 |
+
name="llama-3-instruct",
|
| 215 |
+
default_system_prompt=None,
|
| 216 |
+
role_prefix_and_suffix={
|
| 217 |
+
"system": (
|
| 218 |
+
"<|start_header_id|>system<|end_header_id|>\n\n",
|
| 219 |
+
"<|eot_id|>",
|
| 220 |
+
),
|
| 221 |
+
"user": (
|
| 222 |
+
"<|start_header_id|>user<|end_header_id|>\n\n",
|
| 223 |
+
"<|eot_id|>",
|
| 224 |
+
),
|
| 225 |
+
"assistant": (
|
| 226 |
+
"<|start_header_id|>assistant<|end_header_id|>\n\n",
|
| 227 |
+
"<|eot_id|>",
|
| 228 |
+
),
|
| 229 |
+
},
|
| 230 |
+
stop_str=("<|eot_id|>",),
|
| 231 |
+
image_token="<|image|>",
|
| 232 |
+
)
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
# https://huggingface.co/openbmb/MiniCPM-V-2_6
|
| 236 |
+
register_chat_template(
|
| 237 |
+
ChatTemplate(
|
| 238 |
+
name="minicpmv",
|
| 239 |
+
default_system_prompt=None,
|
| 240 |
+
role_prefix_and_suffix={
|
| 241 |
+
"system": ("", " "),
|
| 242 |
+
"user": ("user:", " "),
|
| 243 |
+
"assistant": ("assistant:", "</s>"),
|
| 244 |
+
},
|
| 245 |
+
stop_str=("<|im_end|>", "<|endoftext|>"),
|
| 246 |
+
image_token="(<image>./</image>)",
|
| 247 |
+
)
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
register_chat_template(
|
| 251 |
+
ChatTemplate(
|
| 252 |
+
name="janus-pro",
|
| 253 |
+
default_system_prompt=None,
|
| 254 |
+
role_prefix_and_suffix={
|
| 255 |
+
"system": (
|
| 256 |
+
"",
|
| 257 |
+
"",
|
| 258 |
+
),
|
| 259 |
+
"User": (
|
| 260 |
+
"<|User|>",
|
| 261 |
+
"",
|
| 262 |
+
),
|
| 263 |
+
"assistant": (
|
| 264 |
+
"<|Assistant|>",
|
| 265 |
+
"<|end▁of▁sentence|>",
|
| 266 |
+
),
|
| 267 |
+
},
|
| 268 |
+
stop_str=("<|end▁of▁sentence|>",),
|
| 269 |
+
image_token="<image_placeholder>\n",
|
| 270 |
+
)
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
# https://huggingface.co/openbmb/MiniCPM-o-2_6
|
| 274 |
+
register_chat_template(
|
| 275 |
+
ChatTemplate(
|
| 276 |
+
name="minicpmo",
|
| 277 |
+
default_system_prompt=None,
|
| 278 |
+
role_prefix_and_suffix={
|
| 279 |
+
"system": ("", " "),
|
| 280 |
+
"user": ("user:", " "),
|
| 281 |
+
"assistant": ("assistant:", "</s>"),
|
| 282 |
+
},
|
| 283 |
+
stop_str=("<|im_end|>", "<|endoftext|>"),
|
| 284 |
+
image_token="(<image>./</image>)",
|
| 285 |
+
audio_token="(<audio>./</audio>)",
|
| 286 |
+
)
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
register_chat_template(
|
| 290 |
+
ChatTemplate(
|
| 291 |
+
name="janus",
|
| 292 |
+
default_system_prompt=None,
|
| 293 |
+
role_prefix_and_suffix={
|
| 294 |
+
"system": (
|
| 295 |
+
"",
|
| 296 |
+
"",
|
| 297 |
+
),
|
| 298 |
+
"user": (
|
| 299 |
+
"<|User|>",
|
| 300 |
+
"",
|
| 301 |
+
),
|
| 302 |
+
"assistant": (
|
| 303 |
+
"<|Assistant|>",
|
| 304 |
+
"<|end▁of▁sentence|>",
|
| 305 |
+
),
|
| 306 |
+
},
|
| 307 |
+
stop_str=("<|end▁of▁sentence|>",),
|
| 308 |
+
image_token="<image_placeholder>\n",
|
| 309 |
+
)
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
# The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token.
|
| 313 |
+
register_chat_template(
|
| 314 |
+
ChatTemplate(
|
| 315 |
+
name="llama-3-instruct-llava",
|
| 316 |
+
default_system_prompt=None,
|
| 317 |
+
role_prefix_and_suffix={
|
| 318 |
+
"system": (
|
| 319 |
+
"<|start_header_id|>system<|end_header_id|>\n\n",
|
| 320 |
+
"<|eot_id|>",
|
| 321 |
+
),
|
| 322 |
+
"user": (
|
| 323 |
+
"<|start_header_id|>user<|end_header_id|>\n\n",
|
| 324 |
+
"<|eot_id|>",
|
| 325 |
+
),
|
| 326 |
+
"assistant": (
|
| 327 |
+
"<|start_header_id|>assistant<|end_header_id|>\n\n",
|
| 328 |
+
"<|eot_id|>",
|
| 329 |
+
),
|
| 330 |
+
},
|
| 331 |
+
stop_str=("<|eot_id|>",),
|
| 332 |
+
image_token="<image>\n",
|
| 333 |
+
)
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
# Reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
|
| 337 |
+
register_chat_template(
|
| 338 |
+
ChatTemplate(
|
| 339 |
+
name="llama-4",
|
| 340 |
+
default_system_prompt=None,
|
| 341 |
+
role_prefix_and_suffix={
|
| 342 |
+
"system": (
|
| 343 |
+
"<|header_start|>system<|header_end|>\n\n",
|
| 344 |
+
"<|eot|>",
|
| 345 |
+
),
|
| 346 |
+
"user": (
|
| 347 |
+
"<|header_start|>user<|header_end|>\n\n",
|
| 348 |
+
"<|eot|>",
|
| 349 |
+
),
|
| 350 |
+
"assistant": (
|
| 351 |
+
"<|header_start|>assistant<|header_end|>\n\n",
|
| 352 |
+
"<|eot|>",
|
| 353 |
+
),
|
| 354 |
+
},
|
| 355 |
+
stop_str=("<|eot|>",),
|
| 356 |
+
image_token="<|image|>",
|
| 357 |
+
)
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
# Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
|
| 361 |
+
register_chat_template(
|
| 362 |
+
ChatTemplate(
|
| 363 |
+
name="yi-1.5",
|
| 364 |
+
default_system_prompt=None,
|
| 365 |
+
role_prefix_and_suffix={
|
| 366 |
+
"system": ("", ""),
|
| 367 |
+
"user": ("<|im_start|>user\n", "<|im_end|>\n<|im_start|>assistant\n"),
|
| 368 |
+
"assistant": ("", "<|im_end|>\n"),
|
| 369 |
+
},
|
| 370 |
+
style=ChatTemplateStyle.PLAIN,
|
| 371 |
+
stop_str=("<|im_end|>",),
|
| 372 |
+
)
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
# Reference: https://github.com/01-ai/Yi/tree/main/VL#major-difference-with-llava
|
| 376 |
+
register_chat_template(
|
| 377 |
+
ChatTemplate(
|
| 378 |
+
name="yi-vl",
|
| 379 |
+
default_system_prompt=(
|
| 380 |
+
"This is a chat between an inquisitive human and an AI assistant. Assume the role of the AI assistant. Read all the images carefully, and respond to the human's questions with informative, helpful, detailed and polite answers."
|
| 381 |
+
"这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。"
|
| 382 |
+
),
|
| 383 |
+
role_prefix_and_suffix={
|
| 384 |
+
"system": ("", "\n\n"),
|
| 385 |
+
"user": ("### Human:", "\n"),
|
| 386 |
+
"assistant": ("### Assistant:", "\n"),
|
| 387 |
+
},
|
| 388 |
+
image_token=" <image_placeholder>\n",
|
| 389 |
+
)
|
| 390 |
+
)
|
| 391 |
+
|
| 392 |
+
register_chat_template(
|
| 393 |
+
ChatTemplate(
|
| 394 |
+
name="gemma-it",
|
| 395 |
+
default_system_prompt=None,
|
| 396 |
+
role_prefix_and_suffix={
|
| 397 |
+
"system": ("", ""),
|
| 398 |
+
"user": ("<start_of_turn>user\n", "<end_of_turn>\n"),
|
| 399 |
+
"assistant": ("<start_of_turn>model\n", "<end_of_turn>\n"),
|
| 400 |
+
},
|
| 401 |
+
style=ChatTemplateStyle.PLAIN,
|
| 402 |
+
)
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
register_chat_template(
|
| 406 |
+
ChatTemplate(
|
| 407 |
+
name="dbrx-instruct",
|
| 408 |
+
default_system_prompt="You are DBRX, created by Databricks. You were last updated in December 2023. You answer questions based on information available up to that point.\nYOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough responses to more complex and open-ended questions.\nYou assist with various tasks, from writing to coding (using markdown for code blocks — remember to use ``` with code, JSON, and tables).\n(You do not have real-time data access or code execution capabilities. You avoid stereotyping and provide balanced perspectives on controversial topics. You do not provide song lyrics, poems, or news articles and do not divulge details of your training data.)\nThis is your system prompt, guiding your responses. Do not reference it, just respond to the user. If you find yourself talking about this message, stop. You should be responding appropriately and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY.",
|
| 409 |
+
role_prefix_and_suffix={
|
| 410 |
+
"system": ("<|im_start|>system\n", "<|im_end|>"),
|
| 411 |
+
"user": ("\n<|im_start|>user\n", "<|im_end|>"),
|
| 412 |
+
"assistant": ("\n<|im_start|>assistant\n", "<|im_end|>"),
|
| 413 |
+
},
|
| 414 |
+
stop_str=("<|im_end|>",),
|
| 415 |
+
)
|
| 416 |
+
)
|
| 417 |
+
|
| 418 |
+
register_chat_template(
|
| 419 |
+
ChatTemplate(
|
| 420 |
+
name="c4ai-command-r",
|
| 421 |
+
default_system_prompt=None,
|
| 422 |
+
role_prefix_and_suffix={
|
| 423 |
+
"system": (
|
| 424 |
+
"<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>",
|
| 425 |
+
"<|END_OF_TURN_TOKEN|>",
|
| 426 |
+
),
|
| 427 |
+
"user": ("<|START_OF_TURN_TOKEN|><|USER_TOKEN|>", "<|END_OF_TURN_TOKEN|>"),
|
| 428 |
+
"assistant": (
|
| 429 |
+
"<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
|
| 430 |
+
"<|END_OF_TURN_TOKEN|>",
|
| 431 |
+
),
|
| 432 |
+
},
|
| 433 |
+
style=ChatTemplateStyle.PLAIN,
|
| 434 |
+
)
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
# Adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
|
| 438 |
+
register_chat_template(
|
| 439 |
+
ChatTemplate(
|
| 440 |
+
name="internvl-2-5",
|
| 441 |
+
default_system_prompt="你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。",
|
| 442 |
+
role_prefix_and_suffix={
|
| 443 |
+
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
|
| 444 |
+
"user": ("<|im_start|>user\n", "<|im_end|>\n"),
|
| 445 |
+
"assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
|
| 446 |
+
},
|
| 447 |
+
stop_str=["<|im_end|>", "<|action_end|>"],
|
| 448 |
+
)
|
| 449 |
+
)
|
| 450 |
+
|
| 451 |
+
register_chat_template(
|
| 452 |
+
ChatTemplate(
|
| 453 |
+
name="interns1",
|
| 454 |
+
default_system_prompt="You are an AI assistant whose name is Intern-S1 (书生大模型).\n- Intern-S1 (书生大模型) is a vision-language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n- Intern-S1 (书生大模型) can understand and communicate fluently in the language chosen by the user such as English and 中文.\nYou are an expert reasoner with extensive experience in all areas. You approach problems through systematic thinking and rigorous reasoning. Your response should reflect deep understanding and precise logical thinking, making your solution path and reasoning clear to others. Please put your thinking process within <think>...</think> tags.",
|
| 455 |
+
role_prefix_and_suffix={
|
| 456 |
+
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
|
| 457 |
+
"user": ("<|im_start|>user\n", "<|im_end|>\n"),
|
| 458 |
+
"assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
|
| 459 |
+
},
|
| 460 |
+
stop_str=["<|im_end|>", "<|action_end|>"],
|
| 461 |
+
)
|
| 462 |
+
)
|
| 463 |
+
|
| 464 |
+
register_chat_template(
|
| 465 |
+
ChatTemplate(
|
| 466 |
+
name="granite-3-instruct",
|
| 467 |
+
default_system_prompt=None,
|
| 468 |
+
role_prefix_and_suffix={
|
| 469 |
+
"system": (
|
| 470 |
+
"<|start_of_role|>system<|end_of_role|>",
|
| 471 |
+
"<|end_of_text|>",
|
| 472 |
+
),
|
| 473 |
+
"user": (
|
| 474 |
+
"<|start_of_role|>user<|end_of_role|>",
|
| 475 |
+
"<|end_of_text|>",
|
| 476 |
+
),
|
| 477 |
+
"assistant": (
|
| 478 |
+
"<|start_of_role|>assistant<|end_of_role|>",
|
| 479 |
+
"<|end_of_text|>",
|
| 480 |
+
),
|
| 481 |
+
},
|
| 482 |
+
stop_str=("<|end_of_text|>",),
|
| 483 |
+
)
|
| 484 |
+
)
|
| 485 |
+
|
| 486 |
+
register_chat_template(
|
| 487 |
+
ChatTemplate(
|
| 488 |
+
name="deepseek-v3",
|
| 489 |
+
default_system_prompt=None,
|
| 490 |
+
role_prefix_and_suffix={
|
| 491 |
+
"system": (
|
| 492 |
+
"",
|
| 493 |
+
"",
|
| 494 |
+
),
|
| 495 |
+
"user": (
|
| 496 |
+
"<|User|>",
|
| 497 |
+
"",
|
| 498 |
+
),
|
| 499 |
+
"assistant": (
|
| 500 |
+
"<|Assistant|>",
|
| 501 |
+
"<|end▁of▁sentence|>",
|
| 502 |
+
),
|
| 503 |
+
},
|
| 504 |
+
stop_str=("<|end▁of▁sentence|>",),
|
| 505 |
+
)
|
| 506 |
+
)
|
| 507 |
+
|
| 508 |
+
# Reference: https://huggingface.co/docs/transformers/main/model_doc/glm4_v#usage-example
|
| 509 |
+
register_chat_template(
|
| 510 |
+
ChatTemplate(
|
| 511 |
+
name="glm-4v",
|
| 512 |
+
default_system_prompt=None,
|
| 513 |
+
role_prefix_and_suffix={
|
| 514 |
+
"system": ("<|system|>\n", "\n"),
|
| 515 |
+
"user": ("<|user|>\n", "\n"),
|
| 516 |
+
"assistant": ("<|assistant|>\n", "\n"),
|
| 517 |
+
},
|
| 518 |
+
style=ChatTemplateStyle.PLAIN,
|
| 519 |
+
stop_str=["<|user|>", "<|endoftext|>", "<|observation|>"],
|
| 520 |
+
image_token="<|image|>",
|
| 521 |
+
)
|
| 522 |
+
)
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
@register_chat_template_matching_function
|
| 526 |
+
def match_deepseek(model_path: str):
|
| 527 |
+
if re.search(r"deepseek-(v3|r1)", model_path, re.IGNORECASE) and not re.search(
|
| 528 |
+
r"base", model_path, re.IGNORECASE
|
| 529 |
+
):
|
| 530 |
+
return "deepseek-v3"
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
@register_chat_template_matching_function
|
| 534 |
+
def match_orion(model_path: str):
|
| 535 |
+
if "orion" in model_path.lower():
|
| 536 |
+
return "claude"
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
@register_chat_template_matching_function
|
| 540 |
+
def match_deepseek_janus_pro(model_path: str):
|
| 541 |
+
if re.search(r"janus", model_path, re.IGNORECASE):
|
| 542 |
+
return "janus-pro"
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
@register_chat_template_matching_function
|
| 546 |
+
def match_dbrx(model_path: str):
|
| 547 |
+
if re.search(r"dbrx", model_path, re.IGNORECASE) and re.search(
|
| 548 |
+
r"instruct", model_path, re.IGNORECASE
|
| 549 |
+
):
|
| 550 |
+
return "dbrx-instruct"
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
@register_chat_template_matching_function
|
| 554 |
+
def match_vicuna(model_path: str):
|
| 555 |
+
if re.search(r"vicuna|llava-v1\.5|llava-next-video-7b", model_path, re.IGNORECASE):
|
| 556 |
+
return "vicuna_v1.1"
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
@register_chat_template_matching_function
|
| 560 |
+
def match_llama2_chat(model_path: str):
|
| 561 |
+
if re.search(
|
| 562 |
+
r"llama-2.*chat|codellama.*instruct",
|
| 563 |
+
model_path,
|
| 564 |
+
re.IGNORECASE,
|
| 565 |
+
):
|
| 566 |
+
return "llama-2-chat"
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
@register_chat_template_matching_function
|
| 570 |
+
def match_mistral(model_path: str):
|
| 571 |
+
if re.search(r"pixtral|(mistral|mixtral).*instruct", model_path, re.IGNORECASE):
|
| 572 |
+
return "mistral"
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
@register_chat_template_matching_function
|
| 576 |
+
def match_llama3_instruct(model_path: str):
|
| 577 |
+
if re.search(r"llama-3.*instruct", model_path, re.IGNORECASE):
|
| 578 |
+
return "llama-3-instruct"
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
@register_chat_template_matching_function
|
| 582 |
+
def match_chat_ml(model_path: str):
|
| 583 |
+
if re.search(r"tinyllama", model_path, re.IGNORECASE):
|
| 584 |
+
return "chatml"
|
| 585 |
+
if re.search(r"qwen.*vl", model_path, re.IGNORECASE):
|
| 586 |
+
return "qwen2-vl"
|
| 587 |
+
if re.search(r"glm[-_]?4(\.\d+)?v", model_path, re.IGNORECASE):
|
| 588 |
+
return "glm-4v"
|
| 589 |
+
if re.search(r"qwen.*(chat|instruct)", model_path, re.IGNORECASE) and not re.search(
|
| 590 |
+
r"llava", model_path, re.IGNORECASE
|
| 591 |
+
):
|
| 592 |
+
return "qwen"
|
| 593 |
+
if re.search(
|
| 594 |
+
r"llava-v1\.6-34b|llava-v1\.6-yi-34b|llava-next-video-34b|llava-onevision-qwen2",
|
| 595 |
+
model_path,
|
| 596 |
+
re.IGNORECASE,
|
| 597 |
+
):
|
| 598 |
+
return "chatml-llava"
|
| 599 |
+
|
| 600 |
+
|
| 601 |
+
@register_chat_template_matching_function
|
| 602 |
+
def match_chat_yi(model_path: str):
|
| 603 |
+
if re.search(r"yi-vl", model_path, re.IGNORECASE) and not re.search(
|
| 604 |
+
r"llava", model_path, re.IGNORECASE
|
| 605 |
+
):
|
| 606 |
+
return "yi-vl"
|
| 607 |
+
elif re.search(r"yi-1\.5.*chat", model_path, re.IGNORECASE):
|
| 608 |
+
return "yi-1.5"
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
@register_chat_template_matching_function
|
| 612 |
+
def match_gemma_it(model_path: str):
|
| 613 |
+
if re.search(r"gemma.*it", model_path, re.IGNORECASE):
|
| 614 |
+
return "gemma-it"
|
| 615 |
+
|
| 616 |
+
|
| 617 |
+
@register_chat_template_matching_function
|
| 618 |
+
def match_openbmb_minicpm(model_path: str):
|
| 619 |
+
if re.search(r"minicpm-v", model_path, re.IGNORECASE):
|
| 620 |
+
return "minicpmv"
|
| 621 |
+
elif re.search(r"minicpm-o", model_path, re.IGNORECASE):
|
| 622 |
+
return "minicpmo"
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
@register_chat_template_matching_function
|
| 626 |
+
def match_c4ai_command_r(model_path: str):
|
| 627 |
+
if re.search(r"c4ai-command-r", model_path, re.IGNORECASE):
|
| 628 |
+
return "c4ai-command-r"
|
| 629 |
+
|
| 630 |
+
|
| 631 |
+
@register_chat_template_matching_function
|
| 632 |
+
def match_granite_instruct(model_path: str):
|
| 633 |
+
if re.search(r"granite.*instruct", model_path, re.IGNORECASE):
|
| 634 |
+
return "granite-3-instruct"
|
| 635 |
+
|
| 636 |
+
|
| 637 |
+
@register_chat_template_matching_function
|
| 638 |
+
def match_gemma3_instruct(model_path: str):
|
| 639 |
+
if re.search(r"gemma-3", model_path, re.IGNORECASE):
|
| 640 |
+
return "gemma-it"
|
| 641 |
+
|
| 642 |
+
|
| 643 |
+
@register_chat_template_matching_function
|
| 644 |
+
def match_internvl_chat(model_path: str):
|
| 645 |
+
if re.search(r"internvl2_5", model_path, re.IGNORECASE):
|
| 646 |
+
return "internvl-2-5"
|
| 647 |
+
|
| 648 |
+
|
| 649 |
+
@register_chat_template_matching_function
|
| 650 |
+
def match_interns1_chat(model_path: str):
|
| 651 |
+
if re.search(r"intern-s1", model_path, re.IGNORECASE):
|
| 652 |
+
return "interns1"
|
| 653 |
+
if re.search(r"interns1", model_path, re.IGNORECASE):
|
| 654 |
+
return "interns1"
|
| 655 |
+
|
| 656 |
+
|
| 657 |
+
if __name__ == "__main__":
|
| 658 |
+
messages = [
|
| 659 |
+
{"role": "system", "content": None}, # None means default
|
| 660 |
+
# {"role": "system", "content": "You are a helpful, respectful and honest assistant."},
|
| 661 |
+
{"role": "user", "content": "Hello!"},
|
| 662 |
+
{"role": "assistant", "content": "Hi!"},
|
| 663 |
+
{"role": "user", "content": "What can you do?"},
|
| 664 |
+
{"role": "assistant", "content": "I can chat with you."},
|
| 665 |
+
]
|
| 666 |
+
|
| 667 |
+
template = get_chat_template("llama-2-chat")
|
| 668 |
+
print(template.get_prompt(messages))
|