Zip Ye commited on
Commit
fa1aa1c
·
1 Parent(s): 399c281

initial commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +8 -11
  2. draft_probe_suite/draft_model/config.json +38 -0
  3. draft_probe_suite/draft_model/model.safetensors +3 -0
  4. draft_probe_suite/pretrained_draft_model/config.json +38 -0
  5. draft_probe_suite/pretrained_draft_model/model.safetensors +3 -0
  6. draft_probe_suite/probe/config.json +1 -0
  7. draft_probe_suite/probe/state_dict.pth +3 -0
  8. sglang/README.md +17 -0
  9. sglang/__init__.py +83 -0
  10. sglang/bench_offline_throughput.py +476 -0
  11. sglang/bench_one_batch.py +795 -0
  12. sglang/bench_one_batch_server.py +605 -0
  13. sglang/bench_serving.py +0 -0
  14. sglang/check_env.py +433 -0
  15. sglang/cli/__init__.py +0 -0
  16. sglang/cli/generate.py +33 -0
  17. sglang/cli/main.py +26 -0
  18. sglang/cli/serve.py +75 -0
  19. sglang/cli/utils.py +152 -0
  20. sglang/compile_deep_gemm.py +191 -0
  21. sglang/eval/llama3_eval.py +315 -0
  22. sglang/eval/loogle_eval.py +164 -0
  23. sglang/global_config.py +29 -0
  24. sglang/jit_kernel/.clang-format +19 -0
  25. sglang/jit_kernel/__pycache__/hicache.cpython-311.pyc +0 -0
  26. sglang/jit_kernel/__pycache__/utils.cpython-311.pyc +0 -0
  27. sglang/jit_kernel/csrc/cuda_wait_value.cuh +38 -0
  28. sglang/jit_kernel/csrc/hicache.cuh +264 -0
  29. sglang/jit_kernel/cuda_wait_value.py +79 -0
  30. sglang/jit_kernel/hicache.py +138 -0
  31. sglang/jit_kernel/include/sgl_kernel/tensor.h +487 -0
  32. sglang/jit_kernel/include/sgl_kernel/utils.cuh +101 -0
  33. sglang/jit_kernel/include/sgl_kernel/utils.h +88 -0
  34. sglang/jit_kernel/include/sgl_kernel/warp.cuh +145 -0
  35. sglang/jit_kernel/utils.py +103 -0
  36. sglang/lang/__pycache__/api.cpython-311.pyc +0 -0
  37. sglang/lang/__pycache__/chat_template.cpython-311.pyc +0 -0
  38. sglang/lang/__pycache__/choices.cpython-311.pyc +0 -0
  39. sglang/lang/__pycache__/interpreter.cpython-311.pyc +0 -0
  40. sglang/lang/__pycache__/ir.cpython-311.pyc +0 -0
  41. sglang/lang/api.py +292 -0
  42. sglang/lang/backend/__pycache__/base_backend.cpython-311.pyc +0 -0
  43. sglang/lang/backend/__pycache__/runtime_endpoint.cpython-311.pyc +0 -0
  44. sglang/lang/backend/anthropic.py +73 -0
  45. sglang/lang/backend/base_backend.py +82 -0
  46. sglang/lang/backend/litellm.py +90 -0
  47. sglang/lang/backend/openai.py +475 -0
  48. sglang/lang/backend/runtime_endpoint.py +527 -0
  49. sglang/lang/backend/vertexai.py +148 -0
  50. sglang/lang/chat_template.py +668 -0
README.md CHANGED
@@ -1,11 +1,10 @@
1
  <div align="center">
2
- <img src="assets/logo.png" alt="SEAGLE Logo" width="200"/>
3
  </div>
4
 
5
- # SEAGLE: Safety-Aware EAGLE
6
 
7
  **SEAGLE** is a safety-aware speculative decoding policy based on [SGLang](https://github.com/sgl-project/sglang). It embeds a lightweight probe model into the draft loop of [EAGLE-3](https://github.com/SafeAILab/EAGLE) speculative decoding, performs real-time safety monitoring on each decoding step, dynamically adjusts draft tokens, and triggers a fallback mechanism when unsafe content is continuously detected.
8
- > [HERE](https://www.modelscope.cn/models/Alibaba-AAIG/SEAGLE) is the link to the source code and model weights.
9
 
10
  ![Leaderboard](assets/leaderboard.jpg)
11
 
@@ -76,7 +75,7 @@ from sglang.srt.server_args import ServerArgs
76
  from sglang.srt.entrypoints.http_server import launch_server as _launch_server
77
 
78
  # =========================================================
79
- # Launch SGlang Server with Safety-Aware Eagle3 Decoding
80
  # =========================================================
81
  MODEL_PATH = "your_qwen3_235b_a22b_instruct_2507_path"
82
  DRAFT_MODEL_PATH = "draft_probe_suite/draft_model"
@@ -246,7 +245,7 @@ We begin by evaluating the acceleration performance of our draft models, encompa
246
 
247
  > **Note:** Our pre-trained draft model can be found [here](https://www.modelscope.cn/models/Alibaba-AAIG/SEAGLE/tree/master/draft_probe_suite/pretrained_draft_model). Compared to the [Meituan](https://modelscope.cn/models/lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge-Meituan) version, our Eagle Head has undergone accelerated training specifically for Chinese. The pre-trained version can be used standalone as an Eagle Head for Qwen3-235B-A22B-Instruct-2507, delivering outstanding acceleration performance in both Chinese and English.
248
 
249
- **Launch with Standard SGLang CLI:**
250
 
251
  ```bash
252
  export SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1
@@ -282,11 +281,9 @@ Evaluate the probe's impact on normal chatting data (query safety & response saf
282
  | :--- | :---: | :---: | :---: |
283
  | FuseChat-Mixture | 50,000 | 0.99506 | 0.00494 |
284
 
285
- > **Note:** Even if the probe occasionally produces false positives, the safety-aware speculative decoding mechanism still ensures that the generated responses are meaningful and valuable.
286
 
287
- ### ⚖️ 3.3 End-to-End Utility and Safety
288
-
289
- The trained probe is integrated into the Eagle3 decoding pipeline. We evaluate the end-to-end utility and safety of the SafeAware decoding strategy using an SGLang single-request configuration.
290
 
291
  #### (1) Utility Performance
292
 
@@ -315,8 +312,8 @@ Safety scores are evaluated based on the discriminative reward model (DRM), gene
315
  | 📎 [Chinese: 100 High-Risk](assets/valuesTest_zh_hard_100.jsonl) | DRM Score | 0.43 | 0.49 | **0.83** |
316
  | | QwQ Score | 0.70 | 0.70 | **0.92** |
317
  | | GRM Score | 0.23 | 0.31 | **0.81** |
318
- | 📊 [English Log](assets/GRM_judge_log_en.xlsx) | Logs | ✅ | - | ✅ |
319
- | 📊 [Chinese Log](assets/GRM_judge_log_zh.xlsx) | Logs | ✅ | - | ✅ |
320
 
321
  ---
322
 
 
1
  <div align="center">
2
+ <img src="assets/logo.png" alt="SEAGLE Logo" width="250"/>
3
  </div>
4
 
5
+ # SEAGLE: Safety-Aware EAGLE
6
 
7
  **SEAGLE** is a safety-aware speculative decoding policy based on [SGLang](https://github.com/sgl-project/sglang). It embeds a lightweight probe model into the draft loop of [EAGLE-3](https://github.com/SafeAILab/EAGLE) speculative decoding, performs real-time safety monitoring on each decoding step, dynamically adjusts draft tokens, and triggers a fallback mechanism when unsafe content is continuously detected.
 
8
 
9
  ![Leaderboard](assets/leaderboard.jpg)
10
 
 
75
  from sglang.srt.entrypoints.http_server import launch_server as _launch_server
76
 
77
  # =========================================================
78
+ # Launch SGLang Server with Safety-Aware Eagle3 Decoding
79
  # =========================================================
80
  MODEL_PATH = "your_qwen3_235b_a22b_instruct_2507_path"
81
  DRAFT_MODEL_PATH = "draft_probe_suite/draft_model"
 
245
 
246
  > **Note:** Our pre-trained draft model can be found [here](https://www.modelscope.cn/models/Alibaba-AAIG/SEAGLE/tree/master/draft_probe_suite/pretrained_draft_model). Compared to the [Meituan](https://modelscope.cn/models/lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge-Meituan) version, our Eagle Head has undergone accelerated training specifically for Chinese. The pre-trained version can be used standalone as an Eagle Head for Qwen3-235B-A22B-Instruct-2507, delivering outstanding acceleration performance in both Chinese and English.
247
 
248
+ **Launch with Standard SGLang Command:**
249
 
250
  ```bash
251
  export SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1
 
281
  | :--- | :---: | :---: | :---: |
282
  | FuseChat-Mixture | 50,000 | 0.99506 | 0.00494 |
283
 
284
+ ### ⚖️ 3.3 Utility and Safety
285
 
286
+ The trained probe is integrated into the Eagle3 decoding pipeline. We evaluate the end-to-end utility and safety of the safety-aware decoding strategy using an SGLang single-request configuration.
 
 
287
 
288
  #### (1) Utility Performance
289
 
 
312
  | 📎 [Chinese: 100 High-Risk](assets/valuesTest_zh_hard_100.jsonl) | DRM Score | 0.43 | 0.49 | **0.83** |
313
  | | QwQ Score | 0.70 | 0.70 | **0.92** |
314
  | | GRM Score | 0.23 | 0.31 | **0.81** |
315
+ | 📊 [English Log](assets/GRM_judge_log_en.xlsx) | Evaluation | ✅ | - | ✅ |
316
+ | 📊 [Chinese Log](assets/GRM_judge_log_zh.xlsx) | Evaluation | ✅ | - | ✅ |
317
 
318
  ---
319
 
draft_probe_suite/draft_model/config.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "LlamaForCausalLMEagle3"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 151643,
8
+ "draft_vocab_size": 32000,
9
+ "dtype": "bfloat16",
10
+ "eagle_config": {
11
+ "eagle_aux_hidden_state_layer_ids": [
12
+ 1,
13
+ 46,
14
+ 90
15
+ ],
16
+ "use_aux_hidden_state": true
17
+ },
18
+ "eos_token_id": 151645,
19
+ "head_dim": 128,
20
+ "hidden_act": "silu",
21
+ "hidden_size": 4096,
22
+ "initializer_range": 0.02,
23
+ "intermediate_size": 24576,
24
+ "max_position_embeddings": 40960,
25
+ "mlp_bias": false,
26
+ "model_type": "llama",
27
+ "num_attention_heads": 64,
28
+ "num_hidden_layers": 1,
29
+ "num_key_value_heads": 4,
30
+ "pretraining_tp": 1,
31
+ "rms_norm_eps": 1e-06,
32
+ "rope_scaling": null,
33
+ "rope_theta": 1000000.0,
34
+ "tie_word_embeddings": false,
35
+ "transformers_version": "4.57.1",
36
+ "use_cache": true,
37
+ "vocab_size": 151936
38
+ }
draft_probe_suite/draft_model/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:855968a019ea7ce5d33cdc503737c4fe19fd3aa8a176440d818506bf018f130d
3
+ size 1185333104
draft_probe_suite/pretrained_draft_model/config.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "LlamaForCausalLMEagle3"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 151643,
8
+ "draft_vocab_size": 32000,
9
+ "dtype": "bfloat16",
10
+ "eagle_config": {
11
+ "eagle_aux_hidden_state_layer_ids": [
12
+ 1,
13
+ 46,
14
+ 90
15
+ ],
16
+ "use_aux_hidden_state": true
17
+ },
18
+ "eos_token_id": 151645,
19
+ "head_dim": 128,
20
+ "hidden_act": "silu",
21
+ "hidden_size": 4096,
22
+ "initializer_range": 0.02,
23
+ "intermediate_size": 24576,
24
+ "max_position_embeddings": 40960,
25
+ "mlp_bias": false,
26
+ "model_type": "llama",
27
+ "num_attention_heads": 64,
28
+ "num_hidden_layers": 1,
29
+ "num_key_value_heads": 4,
30
+ "pretraining_tp": 1,
31
+ "rms_norm_eps": 1e-06,
32
+ "rope_scaling": null,
33
+ "rope_theta": 1000000.0,
34
+ "tie_word_embeddings": false,
35
+ "transformers_version": "4.57.1",
36
+ "use_cache": true,
37
+ "vocab_size": 151936
38
+ }
draft_probe_suite/pretrained_draft_model/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b4b0c81dfb283d27ab78fbdffffe1b0962d6fc3ef4bce16dbc4d6561ef19a9b1
3
+ size 1185333104
draft_probe_suite/probe/config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"input_size": 4096, "output_size": 1, "intermediate_size": 1024}
draft_probe_suite/probe/state_dict.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bee888de61324d261765237d29a23f7d4858fb6c571bbe58726321a8be0b8d26
3
+ size 16803511
sglang/README.md ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code Structure
2
+
3
+ - `eval`: The evaluation utilities.
4
+ - `lang`: The frontend language.
5
+ - `srt`: The backend engine for running local models. (SRT = SGLang Runtime).
6
+ - `test`: The test utilities.
7
+ - `api.py`: The public APIs.
8
+ - `bench_offline_throughput.py`: Benchmark the performance in the offline mode.
9
+ - `bench_one_batch.py`: Benchmark the latency of running a single static batch without a server.
10
+ - `bench_one_batch_server.py`: Benchmark the latency of running a single batch with a server.
11
+ - `bench_serving.py`: Benchmark online serving with dynamic requests.
12
+ - `check_env.py`: Check the environment variables and dependencies.
13
+ - `global_config.py`: The global configs and constants.
14
+ - `launch_server.py`: The entry point for launching a local server.
15
+ - `profiler.py`: The profiling entry point to send profile requests.
16
+ - `utils.py`: Common utilities.
17
+ - `version.py`: Version info.
sglang/__init__.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SGLang public APIs
2
+
3
+ # Frontend Language APIs
4
+ from sglang.global_config import global_config
5
+ from sglang.lang.api import (
6
+ Engine,
7
+ Runtime,
8
+ assistant,
9
+ assistant_begin,
10
+ assistant_end,
11
+ flush_cache,
12
+ function,
13
+ gen,
14
+ gen_int,
15
+ gen_string,
16
+ get_server_info,
17
+ image,
18
+ select,
19
+ separate_reasoning,
20
+ set_default_backend,
21
+ system,
22
+ system_begin,
23
+ system_end,
24
+ user,
25
+ user_begin,
26
+ user_end,
27
+ video,
28
+ )
29
+ from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
30
+ from sglang.lang.choices import (
31
+ greedy_token_selection,
32
+ token_length_normalized,
33
+ unconditional_likelihood_normalized,
34
+ )
35
+
36
+ # Lazy import some libraries
37
+ from sglang.utils import LazyImport
38
+ from sglang.version import __version__
39
+
40
+ Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
41
+ LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
42
+ OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
43
+ VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")
44
+
45
+ # Runtime Engine APIs
46
+ ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs")
47
+ Engine = LazyImport("sglang.srt.entrypoints.engine", "Engine")
48
+
49
+ __all__ = [
50
+ "Engine",
51
+ "Runtime",
52
+ "assistant",
53
+ "assistant_begin",
54
+ "assistant_end",
55
+ "flush_cache",
56
+ "function",
57
+ "gen",
58
+ "gen_int",
59
+ "gen_string",
60
+ "get_server_info",
61
+ "image",
62
+ "select",
63
+ "separate_reasoning",
64
+ "set_default_backend",
65
+ "system",
66
+ "system_begin",
67
+ "system_end",
68
+ "user",
69
+ "user_begin",
70
+ "user_end",
71
+ "video",
72
+ "RuntimeEndpoint",
73
+ "greedy_token_selection",
74
+ "token_length_normalized",
75
+ "unconditional_likelihood_normalized",
76
+ "ServerArgs",
77
+ "Anthropic",
78
+ "LiteLLM",
79
+ "OpenAI",
80
+ "VertexAI",
81
+ "global_config",
82
+ "__version__",
83
+ ]
sglang/bench_offline_throughput.py ADDED
@@ -0,0 +1,476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Benchmark the throughput in the offline mode.
3
+ It accepts server arguments (the same as launch_server.py) and benchmark arguments (the same as bench_serving.py).
4
+
5
+ # Usage
6
+ ## Sharegpt dataset with default args
7
+ python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --num-prompts 10
8
+
9
+ ## Random dataset with default args
10
+ python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024
11
+ """
12
+
13
+ import argparse
14
+ import asyncio
15
+ import dataclasses
16
+ import inspect
17
+ import json
18
+ import logging
19
+ import os
20
+ import random
21
+ import time
22
+ from typing import Dict, List, Optional
23
+
24
+ import numpy as np
25
+
26
+ from sglang.bench_serving import (
27
+ DatasetRow,
28
+ get_dataset,
29
+ get_tokenizer,
30
+ sample_random_requests,
31
+ set_ulimit,
32
+ )
33
+ from sglang.lang.backend.runtime_endpoint import Runtime
34
+ from sglang.srt.entrypoints.engine import Engine
35
+ from sglang.srt.server_args import ServerArgs
36
+
37
+
38
@dataclasses.dataclass
class BenchArgs:
    """CLI-configurable arguments for the offline throughput benchmark.

    The field defaults are the single source of truth for the argparse
    defaults (see ``add_cli_args``), so each option is defined in one place.
    """

    backend: str = "engine"
    result_filename: str = ""
    dataset_name: str = "sharegpt"
    dataset_path: str = ""
    num_prompts: int = 1000
    sharegpt_output_len: Optional[int] = None
    sharegpt_context_len: Optional[int] = None
    random_input_len: int = 1024
    random_output_len: int = 1024
    random_range_ratio: float = 0.0
    gsp_num_groups: int = 64
    gsp_prompts_per_group: int = 16
    gsp_system_prompt_len: int = 2048
    gsp_question_len: int = 128
    gsp_output_len: int = 256
    seed: int = 1
    disable_ignore_eos: bool = False
    extra_request_body: Optional[str] = None
    apply_chat_template: bool = False
    profile: bool = False
    skip_warmup: bool = False
    do_not_exit: bool = False
    prompt_suffix: str = ""
    return_logprob: bool = False
    logprob_start_len: int = -1

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register all benchmark options on *parser*.

        Fixes vs. the previous version: all defaults consistently reference
        the dataclass fields (several options hardcoded duplicate literals),
        and the implicitly-concatenated help strings that were missing a
        separating space ("usedonly ...", "specifyadditional ...") are
        repaired.
        """
        parser.add_argument("--backend", type=str, default=BenchArgs.backend)
        parser.add_argument(
            "--result-filename", type=str, default=BenchArgs.result_filename
        )
        parser.add_argument(
            "--dataset-name",
            type=str,
            default=BenchArgs.dataset_name,
            choices=["sharegpt", "random", "generated-shared-prefix"],
            help="Name of the dataset to benchmark on.",
        )
        parser.add_argument(
            "--dataset-path",
            type=str,
            default=BenchArgs.dataset_path,
            help="Path to the dataset.",
        )
        parser.add_argument(
            "--num-prompts",
            type=int,
            default=BenchArgs.num_prompts,
            help="Number of prompts to process. Default is 1000.",
        )
        parser.add_argument(
            "--sharegpt-output-len",
            type=int,
            default=BenchArgs.sharegpt_output_len,
            help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
        )
        parser.add_argument(
            "--sharegpt-context-len",
            type=int,
            default=BenchArgs.sharegpt_context_len,
            help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
        )
        parser.add_argument(
            "--random-input-len",
            type=int,
            default=BenchArgs.random_input_len,
            help="Number of input tokens per request, used only for random dataset.",
        )
        parser.add_argument(
            "--random-output-len",
            type=int,
            default=BenchArgs.random_output_len,
            help="Number of output tokens per request, used only for random dataset.",
        )
        parser.add_argument(
            "--random-range-ratio",
            type=float,
            default=BenchArgs.random_range_ratio,
            help="Range of sampled ratio of input/output length, used only for random dataset.",
        )
        parser.add_argument(
            "--gsp-num-groups",
            type=int,
            default=BenchArgs.gsp_num_groups,
            help="Number of groups with shared prefix, used only for generate-shared-prefix",
        )
        parser.add_argument(
            "--gsp-prompts-per-group",
            type=int,
            default=BenchArgs.gsp_prompts_per_group,
            help="Number of prompts per group of shared prefix, used only for generate-shared-prefix",
        )
        parser.add_argument(
            "--gsp-system-prompt-len",
            type=int,
            default=BenchArgs.gsp_system_prompt_len,
            help="System prompt length, used only for generate-shared-prefix",
        )
        parser.add_argument(
            "--gsp-question-len",
            type=int,
            default=BenchArgs.gsp_question_len,
            help="Question length, used only for generate-shared-prefix",
        )
        parser.add_argument(
            "--gsp-output-len",
            type=int,
            default=BenchArgs.gsp_output_len,
            help="Target length in tokens for outputs in generated-shared-prefix dataset",
        )
        parser.add_argument(
            "--seed", type=int, default=BenchArgs.seed, help="The random seed."
        )
        parser.add_argument(
            "--disable-ignore-eos",
            action="store_true",
            help="Disable ignore EOS token",
        )
        parser.add_argument(
            "--extra-request-body",
            metavar='{"key1": "value1", "key2": "value2"}',
            type=str,
            default=BenchArgs.extra_request_body,
            help="Append given JSON object to the request payload. You can use this to specify "
            "additional generate params like sampling params.",
        )
        parser.add_argument(
            "--apply-chat-template",
            action="store_true",
            help="Apply chat template",
        )
        parser.add_argument(
            "--profile",
            action="store_true",
            help="Use Torch Profiler. The endpoint must be launched with "
            "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
        )
        parser.add_argument(
            "--skip-warmup",
            action="store_true",
            help="Skip the warmup batches.",
        )
        parser.add_argument(
            "--do-not-exit",
            action="store_true",
            help="Do not exit the program. This is useful for nsys profile with --duration and --delay.",
        )
        parser.add_argument(
            "--prompt-suffix",
            type=str,
            default=BenchArgs.prompt_suffix,
            help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
        )
        parser.add_argument(
            "--return-logprob",
            action="store_true",
            help="Enable returning log probabilities.",
        )
        parser.add_argument(
            "--logprob-start-len",
            type=int,
            default=BenchArgs.logprob_start_len,
            help="Start length for logprob. -1 means only return logprobs for output tokens (default). 0 means return logprobs for all tokens including input.",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a BenchArgs from a parsed namespace, field by field."""
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})
208
+
209
+
210
def throughput_test_once(
    backend_name: str,
    backend,
    reqs: List[DatasetRow],
    ignore_eos: bool,
    extra_request_body: Dict,
    profile: bool,
    return_logprob: bool = False,
    logprob_start_len: int = -1,
):
    """Run one batched generation pass and return throughput metrics.

    Args:
        backend_name: "engine" or "runtime"; controls how results are decoded.
        backend: Object exposing ``generate`` / ``get_server_info`` (and
            ``start_profile`` / ``stop_profile`` when *profile* is set).
        reqs: Dataset rows providing prompts and per-request output lengths.
        ignore_eos: Forwarded into sampling params; keeps generating past EOS.
        extra_request_body: Extra keys merged into every sampling-params dict.
        profile: Enable the torch profiler (requires SGLANG_TORCH_PROFILER_DIR).
        return_logprob: Request log probabilities from the backend.
        logprob_start_len: Start position for logprobs; -1 = output tokens only.

    Returns:
        Dict with total latency and request/input/output/total throughputs.
    """
    measurement_results = {
        "backend": backend_name,
        "successful_requests": len(reqs),
        "total_latency": -1,
        "total_input_tokens": sum(r.prompt_len for r in reqs),
        "total_output_tokens": -1,
        "request_throughput": -1,
        "input_throughput": -1,
        "output_throughput": -1,
        "total_throughput": -1,
    }

    prompt = [r.prompt for r in reqs]
    sampling_params = [
        {
            "temperature": 0,
            "max_new_tokens": r.output_len,
            "ignore_eos": ignore_eos,
            **extra_request_body,
        }
        for r in reqs
    ]

    if profile:
        assert (
            "SGLANG_TORCH_PROFILER_DIR" in os.environ
        ), "Please set SGLANG_TORCH_PROFILER_DIR."
        os.makedirs(os.environ["SGLANG_TORCH_PROFILER_DIR"], exist_ok=True)
        backend.start_profile()

    st = time.perf_counter()
    gen_out = backend.generate(
        prompt=prompt,
        sampling_params=sampling_params,
        return_logprob=return_logprob,
        logprob_start_len=logprob_start_len,
    )
    latency = time.perf_counter() - st

    if profile:
        # Renamed from `dir` to avoid shadowing the builtin of the same name.
        trace_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR")
        known_files = set(os.listdir(trace_dir))
        backend.stop_profile()
        # Block until the newly written trace file stops growing.
        monitor_trace_file(known_files, trace_dir)

    if backend_name == "runtime":
        # The runtime endpoint returns a JSON string rather than a list.
        gen_out = json.loads(gen_out)

    server_info = backend.get_server_info()

    measurement_results["total_latency"] = latency
    measurement_results["total_output_tokens"] = sum(
        o["meta_info"]["completion_tokens"] for o in gen_out
    )
    measurement_results["request_throughput"] = (
        measurement_results["successful_requests"] / latency
    )
    measurement_results["input_throughput"] = (
        measurement_results["total_input_tokens"] / latency
    )
    measurement_results["output_throughput"] = (
        measurement_results["total_output_tokens"] / latency
    )
    measurement_results["total_throughput"] = (
        measurement_results["total_input_tokens"]
        + measurement_results["total_output_tokens"]
    ) / latency

    if inspect.isawaitable(server_info):
        # Engine.get_server_info may return a coroutine; resolve it here.
        server_info = asyncio.run(server_info)

    measurement_results["last_gen_throughput"] = server_info["internal_states"][0][
        "last_gen_throughput"
    ]

    return measurement_results
296
+
297
+
298
def monitor_trace_file(known_files, directory, interval=1):
    """Block until a newly created file in *directory* stops growing.

    Polls *directory* every *interval* seconds for files that are not in
    *known_files*. Once the size of one such file is observed not to grow
    between two consecutive polls, the function returns. A new file that
    disappears while being watched is reported and polling continues.
    """
    print(f"Monitoring {directory} for new trace files...")

    done = False
    while not done:
        time.sleep(interval)
        arrivals = set(os.listdir(directory)) - known_files

        for new_file in arrivals:
            watched_path = os.path.join(directory, new_file)
            print(f"New file detected: {new_file}")

            last_size = 0
            while True:
                try:
                    size_now = os.path.getsize(watched_path)
                except FileNotFoundError:
                    print(f"File {new_file} is no longer accessible.")
                    break

                if size_now <= last_size:
                    # Size stopped growing: the trace dump has finished.
                    done = True
                    break

                last_size = size_now
                time.sleep(interval)
328
+
329
+
330
def throughput_test(
    server_args: ServerArgs,
    bench_args: BenchArgs,
):
    """Run the offline throughput benchmark and print/record a summary.

    Launches the selected backend, optionally runs a warmup batch, benchmarks
    the requested dataset once, then shuts the backend down.

    Args:
        server_args: Arguments used to construct the Engine/Runtime backend.
        bench_args: Benchmark configuration (backend kind, dataset, sizes).

    Returns:
        The measurement-results dict from the benchmarked run.

    Raises:
        ValueError: If ``bench_args.backend`` is not "engine" or "runtime".
    """
    if bench_args.backend == "engine":
        backend = Engine(**dataclasses.asdict(server_args))
        if not backend:
            raise ValueError("Please provide valid engine arguments")
    elif bench_args.backend == "runtime":
        backend = Runtime(**dataclasses.asdict(server_args))
    else:
        raise ValueError('Please set backend to either "engine" or "runtime"')

    tokenizer_id = server_args.tokenizer_path or server_args.model_path
    tokenizer = get_tokenizer(tokenizer_id)

    # Set global environments
    set_ulimit()
    random.seed(bench_args.seed)
    np.random.seed(bench_args.seed)

    # Parse args
    extra_request_body = {}
    if bench_args.extra_request_body:
        # BUG FIX: previously read the module-global `args`, which only exists
        # when this file is executed as a script; reading it here raised
        # NameError when throughput_test() was called as a library function.
        extra_request_body = json.loads(bench_args.extra_request_body)

    # Read dataset
    input_requests = get_dataset(bench_args, tokenizer)

    warmup_requests = sample_random_requests(
        input_len=256,
        output_len=16,
        num_prompts=min(bench_args.num_prompts, 16),
        range_ratio=1.0,
        tokenizer=tokenizer,
        dataset_path=bench_args.dataset_path,
    )

    # Warm up
    if not bench_args.skip_warmup:
        logging.info("\nWarmup...")
        throughput_test_once(
            backend_name=bench_args.backend,
            backend=backend,
            reqs=warmup_requests,
            ignore_eos=not bench_args.disable_ignore_eos,
            extra_request_body=extra_request_body,
            profile=False,
            return_logprob=bench_args.return_logprob,
            logprob_start_len=bench_args.logprob_start_len,
        )
        time.sleep(0.5)

    logging.info("\nBenchmark...")
    result = throughput_test_once(
        backend_name=bench_args.backend,
        backend=backend,
        reqs=input_requests,
        ignore_eos=not bench_args.disable_ignore_eos,
        extra_request_body=extra_request_body,
        profile=bench_args.profile,
        return_logprob=bench_args.return_logprob,
        logprob_start_len=bench_args.logprob_start_len,
    )
    backend.shutdown()

    if bench_args.result_filename:
        with open(bench_args.result_filename, "a") as fout:
            fout.write(json.dumps(result) + "\n")

    print(
        "\n{s:{c}^{n}}".format(s=" Offline Throughput Benchmark Result ", n=50, c="=")
    )
    print("{:<40} {:<10}".format("Backend:", result["backend"]))
    print("{:<40} {:<10}".format("Successful requests:", result["successful_requests"]))
    print("{:<40} {:<10.2f}".format("Benchmark duration (s):", result["total_latency"]))
    print("{:<40} {:<10}".format("Total input tokens:", result["total_input_tokens"]))
    print(
        "{:<40} {:<10}".format("Total generated tokens:", result["total_output_tokens"])
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Last generation throughput (tok/s):", result["last_gen_throughput"]
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Request throughput (req/s):", result["request_throughput"]
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Input token throughput (tok/s):", result["input_throughput"]
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Output token throughput (tok/s):", result["output_throughput"]
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Total token throughput (tok/s):", result["total_throughput"]
        )
    )
    print("=" * 50)

    return result
438
+
439
+
440
if __name__ == "__main__":
    # Parse combined server + benchmark arguments from the command line.
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    BenchArgs.add_cli_args(parser)
    args = parser.parse_args()

    # handling ModelScope model downloads
    if os.getenv("SGLANG_USE_MODELSCOPE", "false").lower() in ("true", "1"):
        if os.path.exists(args.model_path):
            print(f"Using local model path: {args.model_path}")
        else:
            try:
                from modelscope import snapshot_download

                print(f"Using ModelScope to download model: {args.model_path}")

                # download the model and replace args.model_path
                args.model_path = snapshot_download(
                    args.model_path,
                )
                print(f"Model downloaded to: {args.model_path}")
            except Exception as e:
                print(f"ModelScope download failed: {str(e)}")
                raise e

    server_args = ServerArgs.from_cli_args(args)
    bench_args = BenchArgs.from_cli_args(args)

    logging.basicConfig(
        level=getattr(logging, server_args.log_level.upper()),
        format="%(message)s",
    )

    throughput_test(server_args, bench_args)

    # Keep the process alive for nsys profiling with --duration/--delay.
    # FIX: the original `while ...: pass` busy-wait pinned a CPU core at
    # 100%; sleeping achieves the same "do not exit" behavior cheaply.
    while bench_args.do_not_exit:
        time.sleep(1)
sglang/bench_one_batch.py ADDED
@@ -0,0 +1,795 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Benchmark the latency of running a single static batch without a server.
3
+
4
+ This script does not launch a server and uses the low-level APIs.
5
+ It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths).
6
+
7
+ # Usage (latency test)
8
+ ## with dummy weights:
9
+ python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
10
+ ## sweep through multiple data points and store (append) the results in a jsonl file:
11
+ python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --run-name test_run
12
+ ## run with profiling:
13
+ python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --profile
14
+ ## run with profiling to custom directory:
15
+ export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log
16
+ python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 --input-len 256 --profile
17
+ ## run with CUDA profiler (nsys):
18
+ nsys profile --force-overwrite=true -o bench_one_batch python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 --input-len 256 --profile --profile-activities CUDA_PROFILER
19
+ # Usage (correctness test):
20
+ python -m sglang.bench_one_batch --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
21
+
22
+ ## Reference output (of the correctness test above, can be gpu dependent):
23
+ input_ids=[[1, 450, 7483, 310, 3444, 338], [1, 450, 7483, 310, 278, 3303, 13187, 290, 338], [1, 20628, 338, 263, 6575, 1460, 2462, 322, 306, 763]]
24
+
25
+ prefill logits (first half): tensor([[-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633],
26
+ [-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633],
27
+ [ -9.1875, -10.2500, 2.7129, ..., -4.3359, -4.0664, -4.1328]],
28
+ device='cuda:0')
29
+
30
+ prefill logits (final): tensor([[-8.3125, -7.1172, 3.3457, ..., -4.9570, -4.1328, -3.4141],
31
+ [-8.9141, -9.0156, 4.1445, ..., -4.9922, -4.4961, -4.0781],
32
+ [-9.6328, -9.0547, 4.0195, ..., -5.3047, -4.7148, -4.4570]],
33
+ device='cuda:0')
34
+
35
+ ========== Prompt 0 ==========
36
+ <s> The capital of France is Paris.
37
+ The capital of the United States is Washington, D.C.
38
+
39
+
40
+ ========== Prompt 1 ==========
41
+ <s> The capital of the United Kindom is London.
42
+ The capital of the United Kingdom is London.
43
+ The capital of the
44
+
45
+ ========== Prompt 2 ==========
46
+ <s> Today is a sunny day and I like to go for a walk in the park.
47
+ I'm going to the park
48
+ """
49
+
50
+ import argparse
51
+ import copy
52
+ import dataclasses
53
+ import itertools
54
+ import json
55
+ import logging
56
+ import multiprocessing
57
+ import os
58
+ import time
59
+ from types import SimpleNamespace
60
+ from typing import Tuple
61
+
62
+ import numpy as np
63
+ import torch
64
+ import torch.distributed as dist
65
+
66
+ from sglang.srt.configs.model_config import ModelConfig
67
+ from sglang.srt.distributed.parallel_state import destroy_distributed_environment
68
+ from sglang.srt.entrypoints.engine import _set_envs_and_config
69
+ from sglang.srt.layers.moe import initialize_moe_config
70
+ from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
71
+ from sglang.srt.managers.scheduler_dp_attn_mixin import prepare_mlp_sync_batch_raw
72
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
73
+ from sglang.srt.model_executor.model_runner import ModelRunner
74
+ from sglang.srt.sampling.sampling_params import SamplingParams
75
+ from sglang.srt.server_args import PortArgs, ServerArgs
76
+ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
77
+ from sglang.srt.utils import (
78
+ configure_logger,
79
+ get_bool_env_var,
80
+ is_cuda_alike,
81
+ is_xpu,
82
+ kill_process_tree,
83
+ maybe_reindex_device_id,
84
+ require_mlp_sync,
85
+ require_mlp_tp_gather,
86
+ set_gpu_proc_affinity,
87
+ suppress_other_loggers,
88
+ )
89
+ from sglang.srt.utils.hf_transformers_utils import get_tokenizer
90
+
91
# Default torch-profiler activity list: always trace CPU ops, plus whichever
# GPU backend(s) this machine actually has (CUDA-alike devices and/or Intel XPU).
profile_activities = [torch.profiler.ProfilerActivity.CPU] + [
    profiler_activity
    for available, profiler_activity in [
        (is_cuda_alike(), torch.profiler.ProfilerActivity.CUDA),
        (is_xpu(), torch.profiler.ProfilerActivity.XPU),
    ]
    if available
]
99
+
100
+
101
def start_profile(profile_activities, profile_record_shapes=False, rank_print=print):
    """Begin profiling according to *profile_activities*.

    If "CUDA_PROFILER" is requested, start the CUDA runtime profiler (for
    nsys capture) and return ``None``. Otherwise build and start a
    ``torch.profiler.profile`` covering the requested CPU/GPU activities and
    return it. Returns ``None`` when no activity is selected.
    """
    if "CUDA_PROFILER" in profile_activities:
        try:
            torch.cuda.cudart().cudaProfilerStart()
            rank_print("CUDA Profiler started (nsys will begin capturing)")
        except Exception as e:
            rank_print(f"Failed to start CUDA profiler: {e}")
        return None

    activities = []
    if "CPU" in profile_activities:
        activities.append(torch.profiler.ProfilerActivity.CPU)
    if "GPU" in profile_activities:
        activities.append(torch.profiler.ProfilerActivity.CUDA)
    if not activities:
        return None

    profiler = torch.profiler.profile(
        activities=activities,
        with_stack=True,
        record_shapes=profile_record_shapes,
    )
    profiler.start()
    return profiler
128
+
129
+
130
def stop_profile(
    profiler,
    profile_activities,
    rank_print=print,
    save_trace=False,
    trace_filename=None,
    stage=None,
):
    """Stop a profiling session started by ``start_profile``.

    Stops either the CUDA runtime profiler or the given torch profiler, and,
    when *save_trace* is set, exports the torch trace (if any) and prints
    completion messages.
    """
    use_cuda_profiler = "CUDA_PROFILER" in profile_activities
    if use_cuda_profiler:
        try:
            torch.cuda.cudart().cudaProfilerStop()
            rank_print("CUDA Profiler stopped (nsys should dump traces)")
        except Exception as e:
            rank_print(f"Failed to stop CUDA profiler: {e}")
    elif profiler is not None:
        profiler.stop()

    if not save_trace:
        return

    if profiler is not None:
        if trace_filename:
            _save_profile_trace_results(profiler, trace_filename)
        stage_desc = f"for {stage}" if stage else ""
        rank_print(
            f"torch profiler chrome trace {stage_desc} saved to {trace_filename}"
        )
    if use_cuda_profiler:
        rank_print(f"CUDA profiler trace for {stage} completed")
161
+
162
+
163
@dataclasses.dataclass
class BenchArgs:
    """Benchmark-only arguments (server-side knobs live in ServerArgs)."""

    run_name: str = "default"
    batch_size: Tuple[int, ...] = (1,)
    input_len: Tuple[int, ...] = (1024,)
    output_len: Tuple[int, ...] = (16,)
    prompt_filename: str = ""
    result_filename: str = "result.jsonl"
    correctness_test: bool = False
    # This is only used for correctness test
    cut_len: int = 4
    log_decode_step: int = 0
    profile: bool = False
    profile_record_shapes: bool = False
    profile_activities: Tuple[str, ...] = ("CPU", "GPU")
    profile_stage: str = "all"
    profile_filename_prefix: str = "profile"

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register all benchmark arguments on *parser*."""
        parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
        parser.add_argument(
            "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
        )
        parser.add_argument(
            "--input-len", type=int, nargs="+", default=BenchArgs.input_len
        )
        parser.add_argument(
            "--output-len", type=int, nargs="+", default=BenchArgs.output_len
        )
        parser.add_argument(
            "--prompt-filename", type=str, default=BenchArgs.prompt_filename
        )
        parser.add_argument(
            "--result-filename", type=str, default=BenchArgs.result_filename
        )
        parser.add_argument("--correctness-test", action="store_true")
        parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
        parser.add_argument(
            "--log-decode-step",
            type=int,
            default=BenchArgs.log_decode_step,
            help="Log decode latency by step, default is set to zero to disable.",
        )
        parser.add_argument("--profile", action="store_true", help="Enable profiling.")
        parser.add_argument(
            "--profile-record-shapes",
            action="store_true",
            help="Record tensor shapes in profiling results.",
        )
        parser.add_argument(
            "--profile-activities",
            type=str,
            nargs="+",
            # Consistency fix: reuse the dataclass default (the original
            # hard-coded ["CPU", "GPU"] here, which could silently drift
            # from the field default above).
            default=list(BenchArgs.profile_activities),
            choices=["CPU", "GPU", "CUDA_PROFILER"],
            help="Profiler activities: CPU, GPU, CUDA_PROFILER. If CPU/GPU, use torch profiler. If CUDA_PROFILER, use CUDA profiler.",
        )
        parser.add_argument(
            "--profile-stage",
            type=str,
            default=BenchArgs.profile_stage,
            choices=["all", "prefill", "decode"],
            help="Which stage to profile: all, prefill, or decode only.",
        )
        parser.add_argument(
            "--profile-filename-prefix",
            type=str,
            default=BenchArgs.profile_filename_prefix,
            help="Prefix of the profiling file names. The full profiling result file(s) be "
            '"[profile_filename_prefix]_batch[batch_size]_input[input_len]_output[output_len].trace.json.gz"',
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a BenchArgs from parsed CLI args.

        argparse returns lists for `nargs="+"` options; casting each value
        through the type of the field's default converts them back to tuples.
        """
        # use the default value's type to cast the args into correct types.
        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
        return cls(
            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
        )
243
+
244
+
245
def load_model(server_args, port_args, gpu_id, tp_rank):
    """Build a ModelRunner and tokenizer for one tensor-parallel rank.

    Only rank 0 prints; when tp_size > 1 all ranks hit a barrier at the end
    so model loading finishes everywhere before benchmarking starts.
    """
    suppress_other_loggers()
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
    # Each expert-parallel group spans tp_size // ep_size TP ranks.
    moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)

    model_config = ModelConfig.from_server_args(server_args)
    model_runner = ModelRunner(
        model_config=model_config,
        mem_fraction_static=server_args.mem_fraction_static,
        gpu_id=gpu_id,
        tp_rank=tp_rank,
        tp_size=server_args.tp_size,
        moe_ep_rank=moe_ep_rank,
        moe_ep_size=server_args.ep_size,
        pp_rank=0,  # this benchmark does not use pipeline parallelism
        pp_size=1,
        nccl_port=port_args.nccl_port,
        server_args=server_args,
    )
    rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
    tokenizer = get_tokenizer(
        server_args.tokenizer_path,
        tokenizer_mode=server_args.tokenizer_mode,
        trust_remote_code=server_args.trust_remote_code,
    )
    if server_args.tp_size > 1:
        dist.barrier()
    return model_runner, tokenizer
273
+
274
+
275
def prepare_inputs_for_correctness_test(bench_args, tokenizer, custom_prompts):
    """Tokenize prompts and build truncated Req objects for the first prefill.

    Each request initially carries only the first ``bench_args.cut_len`` tokens;
    the remainder is appended later by prepare_extend_inputs_for_correctness_test.
    Returns ``(full_input_ids, reqs)``.
    """
    prompts = (
        custom_prompts
        if custom_prompts
        else [
            "The capital of France is",
            "The capital of the United Kindom is",
            "Today is a sunny day and I like",
        ]
    )
    input_ids = [tokenizer.encode(p) for p in prompts]
    sampling_params = SamplingParams(
        temperature=0,
        # Bug fix: the original passed BenchArgs.output_len (the class-level
        # default, a tuple) instead of the per-run CLI value, which both
        # ignored --output-len and handed a tuple to an int-valued field.
        max_new_tokens=bench_args.output_len[0],
    )

    reqs = []
    for i in range(len(prompts)):
        # Each prompt must be longer than the cut point, otherwise the
        # two-phase prefill test is meaningless.
        assert len(input_ids[i]) > bench_args.cut_len

        tmp_input_ids = input_ids[i][: bench_args.cut_len]
        req = Req(
            rid=i,
            origin_input_text=prompts[i],
            origin_input_ids=tmp_input_ids,
            sampling_params=sampling_params,
        )
        req.fill_ids = req.origin_input_ids
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        req.logprob_start_len = len(req.origin_input_ids) - 1
        reqs.append(req)

    return input_ids, reqs
308
+
309
+
310
def prepare_extend_inputs_for_correctness_test(
    bench_args, input_ids, reqs, model_runner
):
    """Append the tokens after the cut point so the next extend completes each prompt.

    The first `cut_len` tokens are already cached, so they become the prefix and
    only the remainder counts toward `extend_input_len`.
    """
    cut = bench_args.cut_len
    req_to_token = model_runner.req_to_token_pool.req_to_token
    for idx, req in enumerate(reqs):
        req.fill_ids += input_ids[idx][cut:]
        req.prefix_indices = req_to_token[idx, :cut]
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        req.logprob_start_len = len(req.origin_input_ids) - 1
    return reqs
322
+
323
+
324
def prepare_synthetic_inputs_for_latency_test(
    batch_size, input_len, custom_inputs=None
):
    """Build `batch_size` Req objects with random (or caller-supplied) token ids.

    Args:
        batch_size: number of requests when generating random inputs.
        input_len: tokens per request when generating random inputs.
        custom_inputs: optional pre-tokenized prompts; when truthy they are
            used verbatim instead of random ids.
    """
    input_ids = (
        custom_inputs
        if custom_inputs
        else np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32)
    )
    sampling_params = SamplingParams(
        temperature=0,
        # Bug fix: the original passed BenchArgs.output_len itself (a tuple);
        # pass the scalar default. The actual decode length is driven by the
        # caller's loop in latency_test_run_once, not by this value.
        max_new_tokens=BenchArgs.output_len[0],
    )

    reqs = []
    for i in range(len(input_ids)):
        req = Req(
            rid=i,
            origin_input_text="",
            origin_input_ids=list(input_ids[i]),
            sampling_params=sampling_params,
        )
        req.fill_ids = req.origin_input_ids
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        req.logprob_start_len = len(req.origin_input_ids) - 1
        reqs.append(req)

    return reqs
351
+
352
+
353
@torch.no_grad
def extend(reqs, model_runner):
    """Run one prefill (extend) forward pass for `reqs` and sample next tokens.

    Returns (next_token_ids, next_token_logits, batch); `batch` is kept and
    reused by subsequent `decode` calls.
    """
    # Create dummy tree_cache for benchmarks (no prefix caching, just allocation)
    dummy_tree_cache = SimpleNamespace(
        page_size=model_runner.server_args.page_size,
        device=model_runner.device,
        token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
    )

    batch = ScheduleBatch.init_new(
        reqs=reqs,
        req_to_token_pool=model_runner.req_to_token_pool,
        token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
        tree_cache=dummy_tree_cache,
        model_config=model_runner.model_config,
        enable_overlap=False,
        spec_algorithm=SpeculativeAlgorithm.NONE,
    )
    batch.prepare_for_extend()
    # DP-attention deployments must sync batch metadata across ranks before
    # building the worker batch.
    _maybe_prepare_mlp_sync_batch(batch, model_runner)
    model_worker_batch = batch.get_model_worker_batch()
    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
    logits_output, _ = model_runner.forward(forward_batch)
    next_token_ids = model_runner.sample(logits_output, forward_batch)
    return next_token_ids, logits_output.next_token_logits, batch
378
+
379
+
380
@torch.no_grad
def decode(input_token_ids, batch, model_runner):
    """Run one decode step: feed the previously sampled tokens and sample again.

    Mutates `batch` in place (prepare_for_decode advances it by one token per
    request); returns (next_token_ids, next_token_logits).
    """
    batch.output_ids = input_token_ids
    batch.prepare_for_decode()
    _maybe_prepare_mlp_sync_batch(batch, model_runner)
    model_worker_batch = batch.get_model_worker_batch()
    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
    logits_output, _ = model_runner.forward(forward_batch)
    next_token_ids = model_runner.sample(logits_output, forward_batch)
    return next_token_ids, logits_output.next_token_logits
390
+
391
+
392
def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
    """Sync batch metadata across ranks when the server config requires MLP sync.

    No-op for plain TP; only runs for configurations where require_mlp_sync()
    is true (e.g. DP attention).
    """
    if require_mlp_sync(model_runner.server_args):
        prepare_mlp_sync_batch_raw(
            batch,
            dp_size=model_runner.server_args.dp_size,
            attn_tp_size=1,  # this benchmark always runs with attention TP of 1
            tp_group=model_runner.tp_group,
            get_idle_batch=None,
            disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
            require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
            disable_overlap_schedule=model_runner.server_args.disable_overlap_schedule,
            offload_tags=set(),
        )
405
+
406
+
407
+ def _read_prompts_from_file(prompt_file, rank_print):
408
+ """Read custom prompts from the file specified by `--prompt-filename`."""
409
+ if not prompt_file:
410
+ return []
411
+ if not os.path.exists(prompt_file):
412
+ rank_print(
413
+ f"Custom prompt file {prompt_file} not found. Using default inputs..."
414
+ )
415
+ return []
416
+ with open(prompt_file, "r") as pf:
417
+ return pf.readlines()
418
+
419
+
420
+ def _get_torch_profiler_output_dir():
421
+ return os.environ.get("SGLANG_TORCH_PROFILER_DIR", "/tmp")
422
+
423
+
424
def _create_torch_profiler_filename(
    profile_filename_prefix, batch_size, input_len, output_len, stage
):
    """Build the trace path: <dir>/<prefix>_batch<b>_input<i>_output<o>_<stage>.trace.json.gz."""
    basename = (
        f"{profile_filename_prefix}_batch{batch_size}"
        f"_input{input_len}_output{output_len}_{stage}.trace.json.gz"
    )
    return os.path.join(_get_torch_profiler_output_dir(), basename)
430
+
431
+
432
def _save_profile_trace_results(profiler, filename):
    """Export *profiler* as a chrome trace to *filename* and print a CPU-time summary."""
    os.makedirs(os.path.dirname(os.path.abspath(filename)), exist_ok=True)
    profiler.export_chrome_trace(filename)
    summary_table = profiler.key_averages(group_by_input_shape=True).table(
        sort_by="self_cpu_time_total"
    )
    print(summary_table)
441
+
442
+
443
def correctness_test(
    server_args,
    port_args,
    bench_args,
    gpu_id,
    tp_rank,
):
    """Greedy-decode a few fixed prompts and print logits/completions for eyeballing.

    Runs a truncated prefill (first `cut_len` tokens), an extend that completes
    the prompts against the cached prefix, then a manual decode loop. The
    expected reference output is reproduced in the module docstring.
    """
    # Configure the logger
    configure_logger(server_args, prefix=f" TP{tp_rank}")
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

    # Load the model
    model_runner, tokenizer = load_model(server_args, port_args, gpu_id, tp_rank)

    # Prepare inputs
    custom_prompts = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
    input_ids, reqs = prepare_inputs_for_correctness_test(
        bench_args, tokenizer, custom_prompts
    )
    rank_print(f"\n{input_ids=}\n")

    if bench_args.cut_len > 0:
        # Prefill
        next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
        rank_print(f"prefill logits (first half): {next_token_logits} \n")

    # Prepare extend inputs
    reqs = prepare_extend_inputs_for_correctness_test(
        bench_args, input_ids, reqs, model_runner
    )

    # Extend (prefill w/ KV cache)
    next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
    rank_print(f"prefill logits (final): {next_token_logits} \n")

    # Decode
    # NOTE(review): next_token_ids here appears to be a tensor, so the first
    # appended element is a 0-d tensor while later ones are ints from
    # .tolist() — tokenizer.decode below seems to accept both; confirm.
    output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
    for _ in range(bench_args.output_len[0] - 1):
        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
        next_token_ids_list = next_token_ids.tolist()
        for i in range(len(reqs)):
            output_ids[i].append(next_token_ids_list[i])

    # Print output texts
    for i in range(len(reqs)):
        rank_print(f"========== Prompt {i} ==========")
        rank_print(tokenizer.decode(output_ids[i]), "\n")
490
+
491
+
492
+ def synchronize(device):
493
+ torch.get_device_module(device).synchronize()
494
+
495
+
496
def latency_test_run_once(
    run_name,
    model_runner,
    rank_print,
    reqs,
    batch_size,
    input_len,
    output_len,
    device,
    log_decode_step,
    profile,
    profile_record_shapes,
    profile_activities,
    profile_filename_prefix,
    profile_stage,
    tp_rank,
):
    """Time one prefill plus `output_len - 1` decode steps for a single batch.

    Returns a dict of measurements (latencies and throughputs), or None when
    the requested batch does not fit in the KV cache. Optionally wraps the
    prefill and/or one mid-run decode step in a profiler.
    """
    # Skip configurations whose KV-cache footprint exceeds the pool.
    max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
    if batch_size > max_batch_size:
        rank_print(
            f"skipping ({batch_size}, {input_len}, {output_len}) due to max batch size limit"
        )
        return

    # Start each run from an empty KV cache so runs are independent.
    model_runner.req_to_token_pool.clear()
    model_runner.token_to_kv_pool_allocator.clear()

    measurement_results = {
        "run_name": run_name,
        "batch_size": batch_size,
        "input_len": input_len,
        "output_len": output_len,
    }

    tot_latency = 0

    profiler = None
    enable_profile_prefill = profile and profile_stage in ["all", "prefill"]
    if enable_profile_prefill:
        profiler = start_profile(
            profile_activities,
            profile_record_shapes=profile_record_shapes,
            rank_print=rank_print,
        )

    # Prefill: synchronize on both sides of the forward so the wall-clock
    # delta measures only this batch's work.
    synchronize(device)
    tic = time.perf_counter()
    next_token_ids, _, batch = extend(reqs, model_runner)
    synchronize(device)
    prefill_latency = time.perf_counter() - tic

    if enable_profile_prefill:
        trace_filename = _create_torch_profiler_filename(
            profile_filename_prefix, batch_size, input_len, output_len, "prefill"
        )
        stop_profile(
            profiler,
            profile_activities,
            rank_print=rank_print,
            save_trace=True,
            trace_filename=trace_filename,
            stage="prefill",
        )

    tot_latency += prefill_latency
    throughput = input_len * batch_size / prefill_latency
    rank_print(
        f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
    )
    measurement_results["prefill_latency"] = prefill_latency
    measurement_results["prefill_throughput"] = throughput

    # Decode loop: profile only the middle step (steady state) when requested.
    decode_latencies = []
    profile_step_of_interest = output_len // 2
    enable_profile_decode = profile and profile_stage in ["all", "decode"]
    for i in range(output_len - 1):
        synchronize(device)
        profiler = None
        if enable_profile_decode and i == profile_step_of_interest:
            profiler = start_profile(
                profile_activities,
                profile_record_shapes=profile_record_shapes,
                rank_print=rank_print,
            )

        tic = time.perf_counter()
        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
        synchronize(device)
        latency = time.perf_counter() - tic

        if enable_profile_decode and i == profile_step_of_interest:
            trace_filename = _create_torch_profiler_filename(
                profile_filename_prefix, batch_size, input_len, output_len, "decode"
            )
            stop_profile(
                profiler,
                profile_activities,
                rank_print=rank_print,
                save_trace=True,
                trace_filename=trace_filename,
                stage="decode",
            )

        tot_latency += latency
        throughput = batch_size / latency
        decode_latencies.append(latency)
        if i < 5 or (log_decode_step > 0 and i % log_decode_step == 0):
            rank_print(
                f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
            )

    # Record decode timing from 2nd output
    if output_len > 1:
        # Median is robust to the first (cold) decode steps.
        med_decode_latency = np.median(decode_latencies)
        med_decode_throughput = batch_size / med_decode_latency
        rank_print(
            f"Decode. median latency: {med_decode_latency:6.5f} s, median throughput: {med_decode_throughput:9.2f} token/s"
        )
        measurement_results["median_decode_latency"] = med_decode_latency
        measurement_results["median_decode_throughput"] = med_decode_throughput

    throughput = (input_len + output_len) * batch_size / tot_latency
    rank_print(
        f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
    )
    measurement_results["total_latency"] = tot_latency
    measurement_results["overall_throughput"] = throughput
    return measurement_results
624
+
625
+
626
def latency_test(
    server_args,
    port_args,
    bench_args,
    gpu_id,
    tp_rank,
):
    """Sweep (batch_size, input_len, output_len) combinations and record latency.

    Loads the model once, runs a short warmup, then runs
    latency_test_run_once for every combination and appends the results (rank
    0 only) to `bench_args.result_filename` in jsonlines format.
    """
    initialize_moe_config(server_args)

    # Set CPU affinity
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(
            server_args.pp_size, server_args.tp_size, server_args.nnodes, tp_rank
        )

    # Configure the logger
    configure_logger(server_args, prefix=f" TP{tp_rank}")
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

    # Load the model
    model_runner, tokenizer = load_model(server_args, port_args, gpu_id, tp_rank)

    # Prepare inputs for warm up
    reqs = prepare_synthetic_inputs_for_latency_test(
        bench_args.batch_size[0], bench_args.input_len[0]
    )

    # Warm up
    rank_print("Warmup ...")
    latency_test_run_once(
        bench_args.run_name,
        model_runner,
        rank_print,
        reqs,
        bench_args.batch_size[0],
        bench_args.input_len[0],
        min(32, bench_args.output_len[0]),  # shorter decoding to speed up the warmup
        server_args.device,
        log_decode_step=0,
        profile=False,
        profile_record_shapes=False,
        profile_activities=("CPU", "GPU"),
        profile_filename_prefix="",
        profile_stage="all",
        tp_rank=tp_rank,
    )

    rank_print("Benchmark ...")

    custom_inputs = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
    custom_inputs = [tokenizer.encode(p.strip()) for p in custom_inputs]
    custom_input_len = len(custom_inputs)

    # Run the sweep
    result_list = []
    for bs, il, ol in itertools.product(
        bench_args.batch_size, bench_args.input_len, bench_args.output_len
    ):
        # Align the number of custom prompts with the batch size:
        # truncate when too many, pad with the last prompt when too few.
        bs_aligned_inputs = []
        if custom_inputs:
            if custom_input_len == bs:
                bs_aligned_inputs = custom_inputs
            elif custom_input_len > bs:
                rank_print(
                    f"Custom input size ({custom_input_len}) is larger than batch_size ({bs}). "
                    f"Using the first {bs} prompts."
                )
                bs_aligned_inputs = copy.deepcopy(custom_inputs[:bs])
            else:
                rank_print(
                    f"Custom input size ({custom_input_len}) is smaller than batch_size ({bs}). "
                    f"Pad to the desired batch_size with the last prompt."
                )
                bs_aligned_inputs = copy.deepcopy(custom_inputs)
                bs_aligned_inputs.extend(
                    [bs_aligned_inputs[-1]] * (bs - custom_input_len)
                )

        reqs = prepare_synthetic_inputs_for_latency_test(bs, il, bs_aligned_inputs)
        ret = latency_test_run_once(
            bench_args.run_name,
            model_runner,
            rank_print,
            reqs,
            bs,
            il,
            ol,
            server_args.device,
            bench_args.log_decode_step,
            # Only rank 0 profiles; other ranks pass None (falsy).
            bench_args.profile if tp_rank == 0 else None,
            bench_args.profile_record_shapes if tp_rank == 0 else None,
            bench_args.profile_activities,
            bench_args.profile_filename_prefix,
            bench_args.profile_stage,
            tp_rank,
        )
        if ret is not None:
            result_list.append(ret)

    # Write results in jsonlines format on rank 0.
    if tp_rank == 0 and bench_args.result_filename:
        with open(bench_args.result_filename, "a") as fout:
            for result in result_list:
                fout.write(json.dumps(result) + "\n")

    if server_args.tp_size > 1:
        destroy_distributed_environment()
733
+
734
+
735
def main(server_args, bench_args):
    """Dispatch to correctness_test or latency_test, one process per TP rank.

    For tp_size == 1 the test runs in-process; otherwise one worker process is
    spawned per rank and joined at the end.
    """
    server_args.cuda_graph_max_bs = max(bench_args.batch_size)

    _set_envs_and_config(server_args)

    if server_args.model_path:
        if bench_args.correctness_test:
            work_func = correctness_test
        else:
            work_func = latency_test
    else:
        raise ValueError(
            "Provide --model-path for running the tests or "
            "provide --result-filename for plotting the results"
        )

    port_args = PortArgs.init_new(server_args)

    if server_args.tp_size == 1:
        work_func(server_args, port_args, bench_args, 0, 0)
    else:
        workers = []
        for tp_rank in range(server_args.tp_size):
            with maybe_reindex_device_id(tp_rank) as gpu_id:
                proc = multiprocessing.Process(
                    target=work_func,
                    args=(
                        server_args,
                        port_args,
                        bench_args,
                        gpu_id,
                        tp_rank,
                    ),
                )
                proc.start()
                workers.append(proc)

        for proc in workers:
            proc.join()

        # Bug fix: the original called proc.terminate() once after the join
        # loop, which only targeted the last worker. Terminate every worker
        # as a safety net (a no-op for processes that already exited).
        for proc in workers:
            proc.terminate()
776
+
777
+
778
+ if __name__ == "__main__":
779
+ parser = argparse.ArgumentParser()
780
+ ServerArgs.add_cli_args(parser)
781
+ BenchArgs.add_cli_args(parser)
782
+ args = parser.parse_args()
783
+ server_args = ServerArgs.from_cli_args(args)
784
+ bench_args = BenchArgs.from_cli_args(args)
785
+
786
+ logging.basicConfig(
787
+ level=getattr(logging, server_args.log_level.upper()),
788
+ format="%(message)s",
789
+ )
790
+
791
+ try:
792
+ main(server_args, bench_args)
793
+ finally:
794
+ if server_args.tp_size != 1:
795
+ kill_process_tree(os.getpid(), include_parent=False)
sglang/bench_one_batch_server.py ADDED
@@ -0,0 +1,605 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Benchmark the latency of running a single batch with a server.
3
+
4
+ This script launches a server and uses the HTTP interface.
5
+ It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths).
6
+
7
+ Usage:
8
+ python3 -m sglang.bench_one_batch_server --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
9
+
10
+ python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
11
+ python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --show-report --profile --profile-by-stage
12
+ python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --output-path results.json --profile
13
+ """
14
+
15
+ import argparse
16
+ import dataclasses
17
+ import itertools
18
+ import json
19
+ import multiprocessing
20
+ import os
21
+ import random
22
+ import time
23
+ from typing import List, Optional, Tuple
24
+
25
+ import numpy as np
26
+ import requests
27
+ from pydantic import BaseModel
28
+ from transformers import AutoProcessor, PreTrainedTokenizer
29
+
30
+ from sglang.bench_serving import (
31
+ get_processor,
32
+ get_tokenizer,
33
+ sample_mmmu_requests,
34
+ sample_random_requests,
35
+ )
36
+ from sglang.profiler import run_profile
37
+ from sglang.srt.entrypoints.http_server import launch_server
38
+ from sglang.srt.server_args import ServerArgs
39
+ from sglang.srt.utils import is_blackwell, kill_process_tree
40
+ from sglang.test.nightly_bench_utils import save_results_as_pydantic_models
41
+ from sglang.test.test_utils import is_in_ci, write_github_step_summary
42
+
43
+
44
@dataclasses.dataclass
class BenchArgs:
    """CLI-configurable settings for the one-batch server benchmark.

    The tuple-valued fields (batch_size, input_len, output_len) are swept with
    itertools.product in run_benchmark, so each combination becomes one case.
    """

    run_name: str = "default"
    # Sweep axes: one benchmark case is run per (batch_size, input_len, output_len) combo.
    batch_size: Tuple[int] = (1,)
    input_len: Tuple[int] = (1024,)
    output_len: Tuple[int] = (16,)
    temperature: float = 0.0
    return_logprob: bool = False
    client_stream_interval: int = 1
    input_len_step_percentage: float = 0.0
    # When non-empty, an already-running server at this URL is used instead of
    # launching a new one (see run_benchmark).
    base_url: str = ""
    skip_warmup: bool = False
    show_report: bool = False
    profile: bool = False
    profile_steps: int = 5
    profile_by_stage: bool = False
    profile_prefix: Optional[str] = None
    profile_output_dir: Optional[str] = None
    dataset_path: str = ""
    dataset_name: str = "random"
    parallel_batch: bool = False
    result_filename: str = "result.jsonl"
    pydantic_result_filename: Optional[str] = None
    append_to_github_summary: bool = True
    seed: int = 42

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register all benchmark flags on *parser* (defaults taken from the dataclass)."""
        parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
        parser.add_argument(
            "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
        )
        parser.add_argument(
            "--input-len", type=int, nargs="+", default=BenchArgs.input_len
        )
        parser.add_argument(
            "--output-len", type=int, nargs="+", default=BenchArgs.output_len
        )
        parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
        parser.add_argument("--return-logprob", action="store_true")
        parser.add_argument(
            "--client-stream-interval",
            type=int,
            default=BenchArgs.client_stream_interval,
        )
        parser.add_argument(
            "--input-len-step-percentage",
            type=float,
            default=BenchArgs.input_len_step_percentage,
        )
        parser.add_argument("--base-url", type=str, default=BenchArgs.base_url)
        parser.add_argument("--skip-warmup", action="store_true")
        parser.add_argument("--show-report", action="store_true")
        parser.add_argument("--profile", action="store_true")
        parser.add_argument(
            "--profile-steps", type=int, default=BenchArgs.profile_steps
        )
        parser.add_argument("--profile-by-stage", action="store_true")
        parser.add_argument(
            "--profile-prefix",
            type=str,
            default=BenchArgs.profile_prefix,
        )
        parser.add_argument(
            "--profile-output-dir",
            type=str,
            default=BenchArgs.profile_output_dir,
        )
        parser.add_argument(
            "--dataset-path",
            type=str,
            default=BenchArgs.dataset_path,
            help="Path to the dataset.",
        )
        parser.add_argument(
            "--dataset-name",
            type=str,
            default=BenchArgs.dataset_name,
            choices=["mmmu", "random"],
            help="Name of the dataset to benchmark on.",
        )
        parser.add_argument("--parallel-batch", action="store_true")
        parser.add_argument(
            "--result-filename",
            type=str,
            default=BenchArgs.result_filename,
            help="Store the results line by line in the JSON Line format to this file.",
        )
        parser.add_argument(
            "--pydantic-result-filename",
            type=str,
            default=BenchArgs.pydantic_result_filename,
            help="Store the results as pydantic models in the JSON format to this file.",
        )
        # store_false + dest: the flag's presence turns the summary OFF.
        parser.add_argument(
            "--no-append-to-github-summary",
            action="store_false",
            dest="append_to_github_summary",
            help="Disable appending the output of this run to github ci summary",
        )
        parser.add_argument("--seed", type=int, default=BenchArgs.seed)

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a BenchArgs from a parsed namespace by copying the matching fields."""
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})
150
+
151
+
152
class BenchOneCaseResult(BaseModel):
    """Metrics collected for one (batch_size, input_len, output_len) benchmark case."""

    run_name: str
    batch_size: int
    input_len: int
    output_len: int
    latency: float
    input_throughput: float
    output_throughput: float
    overall_throughput: float
    last_ttft: float
    last_gen_throughput: float
    acc_length: float
    profile_link: Optional[str] = None

    def dump_to_jsonl(self, result_filename: str):
        """Append this result as a single JSON line (profile_link is not persisted)."""
        # Decimal places retained for each numeric field when serializing.
        precision = {
            "latency": 4,
            "input_throughput": 2,
            "output_throughput": 2,
            "overall_throughput": 2,
            "last_ttft": 4,
            "last_gen_throughput": 2,
            "acc_length": 2,
        }
        record = {
            "run_name": self.run_name,
            "batch_size": self.batch_size,
            "input_len": self.input_len,
            "output_len": self.output_len,
        }
        for field_name, digits in precision.items():
            record[field_name] = round(getattr(self, field_name), digits)
        with open(result_filename, "a") as fout:
            fout.write(json.dumps(record) + "\n")
182
+
183
+
184
def launch_server_internal(server_args):
    """Run the SGLang HTTP server in this process, cleaning up children on exit.

    The try/finally guarantees that any subprocesses spawned by the server are
    killed whether launch_server returns normally or raises. The previous
    `except Exception as e: raise e` clause was a no-op re-raise and has been
    removed; exceptions still propagate to the caller unchanged.
    """
    try:
        launch_server(server_args)
    finally:
        # Kill all child processes but keep this process alive so the parent
        # multiprocessing.Process can observe the exit status.
        kill_process_tree(os.getpid(), include_parent=False)
191
+
192
+
193
def launch_server_process(server_args: ServerArgs):
    """Launch the server in a subprocess and poll until it answers /v1/models.

    Returns:
        (proc, base_url): the server process handle and its base URL.

    Raises:
        RuntimeError: if the server process dies before becoming healthy.
        TimeoutError: if the server does not come up within 600 seconds.
    """
    proc = multiprocessing.Process(target=launch_server_internal, args=(server_args,))
    proc.start()
    base_url = f"http://{server_args.host}:{server_args.port}"
    timeout = 600

    start_time = time.time()
    while time.time() - start_time < timeout:
        # Fail fast instead of polling a server process that already exited.
        if not proc.is_alive():
            raise RuntimeError("Server process terminated before becoming healthy.")
        try:
            headers = {
                "Content-Type": "application/json; charset=utf-8",
            }
            # Bound each health-check request so a hung connection cannot
            # stall this loop indefinitely (the unbounded GET previously
            # could block past the overall timeout).
            response = requests.get(
                f"{base_url}/v1/models", headers=headers, timeout=10
            )
            if response.status_code == 200:
                return proc, base_url
        except requests.RequestException:
            pass
        time.sleep(10)
    raise TimeoutError("Server failed to start within the timeout period.")
212
+
213
+
214
def run_one_case(
    url: str,
    batch_size: int,
    input_len: int,
    output_len: int,
    temperature: float,
    return_logprob: bool,
    stream_interval: int,
    input_len_step_percentage: float,
    run_name: str,
    result_filename: str,
    tokenizer: PreTrainedTokenizer | AutoProcessor,
    profile: bool = False,
    profile_steps: int = BenchArgs.profile_steps,
    profile_by_stage: bool = False,
    profile_prefix: Optional[str] = BenchArgs.profile_prefix,
    profile_output_dir: Optional[str] = BenchArgs.profile_output_dir,
    dataset_name: str = BenchArgs.dataset_name,
    dataset_path: str = BenchArgs.dataset_path,
    parallel_batch: bool = False,
) -> "BenchOneCaseResult":
    """Run a single benchmark case against a running server and return its metrics.

    Flushes the server's KV cache, samples `batch_size` requests from the chosen
    dataset, sends them as one streaming /generate call, and measures TTFT,
    latency, and derived throughputs. Optionally triggers a server-side profile
    first. Results are appended to `result_filename` when it is non-empty.
    """
    # Start from a cold cache so prefill cost is measured, not cache hits.
    requests.post(url + "/flush_cache")

    # Load input token ids
    # TODO: reuse bench_serving.get_dataset ?
    if dataset_name == "mmmu":
        input_requests = sample_mmmu_requests(
            num_requests=batch_size,
            processor=tokenizer,
            fixed_output_len=output_len,
            random_sample=False,
        )
    elif dataset_name == "random":
        input_requests = sample_random_requests(
            input_len=input_len,
            output_len=output_len,
            num_prompts=batch_size,
            range_ratio=1.0,
            tokenizer=tokenizer,
            dataset_path=dataset_path,
            random_sample=True,
            return_text=False,
        )

    # Load sampling parameters
    # NOTE: structured-output mode is a hard-coded experiment toggle; the
    # branch below is currently dead code (use_structured_outputs is False).
    use_structured_outputs = False
    if use_structured_outputs:
        texts = []
        for _ in range(batch_size):
            texts.append(
                "Human: What is the capital city of france? can you give as many trivial information as possible about that city? answer in json.\n"
                * 50
                + "Assistant:"
            )
        json_schema = "$$ANY$$"
    else:
        json_schema = None

    payload = {
        "sampling_params": {
            "temperature": temperature,
            "max_new_tokens": output_len,
            # ignore_eos forces exactly output_len tokens per request, which the
            # throughput math below relies on.
            "ignore_eos": True,
            "json_schema": json_schema,
            "stream_interval": stream_interval,
        },
        "return_logprob": return_logprob,
        "stream": True,
        **({"parallel_batch": parallel_batch} if parallel_batch else {}),
    }
    if dataset_name == "mmmu":
        # vlm
        input_ids = []
        # for vlms, tokenizer is an instance of AutoProcessor
        tokenizer = tokenizer.tokenizer
        for input_req in input_requests:
            input_ids += [tokenizer.encode(input_req.prompt)]
        payload["image_data"] = [req.image_data for req in input_requests]

    else:
        # For the random dataset, req.prompt already holds token ids
        # (return_text=False above) — presumably; verify against bench_serving.
        input_ids = [req.prompt for req in input_requests]

    payload["input_ids"] = input_ids

    # Turn on profiler
    profile_link = None
    if profile:
        profile_link: str = run_profile(
            url=url,
            num_steps=profile_steps,
            activities=["CPU", "GPU"],
            output_dir=profile_output_dir,
            profile_by_stage=profile_by_stage,
            profile_prefix=profile_prefix,
        )

    # Run the request
    tic = time.perf_counter()
    response = requests.post(
        url + "/generate",
        json=payload,
        stream=True,
    )

    # Get the TTFT of the last request in the batch
    last_ttft = 0.0
    for chunk in response.iter_lines(decode_unicode=False):
        chunk = chunk.decode("utf-8")
        if chunk and chunk.startswith("data:"):
            if chunk == "data: [DONE]":
                break
            data = json.loads(chunk[5:].strip("\n"))
            if "error" in data:
                raise RuntimeError(f"Request has failed. {data}.")

            # With ignore_eos, the only legitimate finish reason is "length".
            assert (
                data["meta_info"]["finish_reason"] is None
                or data["meta_info"]["finish_reason"]["type"] == "length"
            )
            # completion_tokens == 1 marks a request's first generated token;
            # the last such event observed is the batch's worst-case TTFT.
            if data["meta_info"]["completion_tokens"] == 1:
                last_ttft = time.perf_counter() - tic

    # Compute metrics
    # NOTE(review): if no completion_tokens==1 chunk was seen (e.g.
    # stream_interval > 1), last_ttft stays 0.0 and the division below raises
    # ZeroDivisionError — confirm the server always streams the first token.
    latency = time.perf_counter() - tic
    input_throughput = batch_size * input_len / last_ttft
    output_throughput = batch_size * output_len / (latency - last_ttft)
    overall_throughput = batch_size * (input_len + output_len) / latency

    # -1 is the sentinel for "not reported by the server".
    server_info = requests.get(url + "/get_server_info").json()
    internal_state = server_info.get("internal_states", [{}])
    last_gen_throughput = internal_state[0].get("last_gen_throughput", None) or -1
    acc_length = internal_state[0].get("avg_spec_accept_length", None) or -1

    # Print results
    print(f"batch size: {batch_size}")
    print(f"input_len: {input_len}")
    print(f"output_len: {output_len}")
    print(f"latency: {latency:.2f} s")
    print(f"input throughput: {input_throughput:.2f} tok/s")
    if output_len != 1:
        print(f"output throughput: {output_throughput:.2f} tok/s")
    print(f"last_ttft: {last_ttft:.2f} s")
    print(f"last generation throughput: {last_gen_throughput:.2f} tok/s")
    if acc_length > 0:
        print(f"acc_length: {acc_length:.2f} ")

    # Dump results
    result = BenchOneCaseResult(
        run_name=run_name,
        batch_size=batch_size,
        input_len=input_len,
        output_len=output_len,
        latency=latency,
        input_throughput=input_throughput,
        output_throughput=output_throughput,
        overall_throughput=overall_throughput,
        last_ttft=last_ttft,
        last_gen_throughput=last_gen_throughput,
        acc_length=acc_length,
        profile_link=profile_link,
    )

    # Save and return the results
    if result_filename:
        result.dump_to_jsonl(result_filename)

    return result
381
+
382
+
383
def should_skip_due_to_token_capacity(
    batch_size, input_len, output_len, skip_token_capacity_threshold
):
    """Return True (after logging a banner) when a case exceeds the KV-cache token capacity."""
    total_tokens = batch_size * (input_len + output_len)
    if total_tokens <= skip_token_capacity_threshold:
        return False
    banner = "=" * 8
    print(
        banner
        + f"Skip benchmark {batch_size=} * ({input_len=} + {output_len=}) = {batch_size * (input_len + output_len)} > {skip_token_capacity_threshold=} due to kv cache limit."
        + banner
    )
    return True
394
+
395
+
396
def get_report_summary(
    results: List[BenchOneCaseResult], bench_args: BenchArgs, server_args: ServerArgs
) -> str:
    """Render the benchmark results as a markdown table for reports/CI summaries.

    Sorts *results* by input_len in place. Cost columns use a flat hourly GPU
    price ($4/h on Blackwell, $2/h otherwise) multiplied by the TP size.
    """
    summary = (
        f"\nInput lens: {bench_args.input_len}. Output lens: {bench_args.output_len}.\n"
    )
    summary += "| batch size | input len | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input cost ($/1M) | output cost ($/1M) |"

    if bench_args.profile:
        summary += " profile |"

    summary += "\n"
    summary += "| ---------- | --------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ----------------- | ------------------ |"

    if bench_args.profile:
        summary += "-------------|"
    summary += "\n"

    if is_blackwell():
        hourly_cost_per_gpu = 4  # $4/hour for one B200
    else:
        hourly_cost_per_gpu = 2  # $2/hour for one H100
    # Assumed prefill utilization factor applied to the input-cost estimate.
    input_util = 0.7

    # sort result by input_len
    results.sort(key=lambda x: x.input_len)
    for res in results:
        hourly_cost = hourly_cost_per_gpu * server_args.tp_size
        # acc_length uses -1 as a "not available" sentinel (see run_one_case).
        accept_length = round(res.acc_length, 2) if res.acc_length > 0 else "n/a"
        # NOTE(review): the divisions below raise ZeroDivisionError if a
        # throughput is 0 — assumed impossible for a completed case; confirm.
        line = (
            f"| {res.batch_size} | "
            f"{res.input_len} | "
            f"{res.latency:.2f} | "
            f"{res.input_throughput:.2f} | "
            f"{res.output_throughput:.2f} | "
            f"{accept_length} | "
            f"{1 / (res.output_throughput/res.batch_size) * 1000:.2f} | "
            f"{1e6 / (res.input_throughput * input_util) / 3600 * hourly_cost:.2f} | "
            f"{1e6 / res.output_throughput / 3600 * hourly_cost:.2f} |"
        )
        if bench_args.profile:
            if res.profile_link:
                line += f" [Profile]({res.profile_link}) |"
            else:
                line += f" n/a |"
        line += "\n"
        summary += line

    return summary
445
+
446
+
447
def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
    """Drive the full benchmark: (optionally) launch a server, warm up, sweep all
    (batch_size, input_len, output_len) combinations, optionally profile, and
    report/persist the results.
    """
    # Reuse an externally managed server when --base-url is given; otherwise
    # launch our own (proc is None in the reuse case, so we never kill it).
    if bench_args.base_url:
        proc, base_url = None, bench_args.base_url
    else:
        proc, base_url = launch_server_process(server_args)

    # Get tokenizer
    # NOTE(review): if neither "tokenizer_path" nor "prefill" is present in
    # the server info, tokenizer_path is unbound below — confirm the server
    # always reports one of the two.
    server_info = requests.get(base_url + "/get_server_info").json()
    if "tokenizer_path" in server_info:
        tokenizer_path = server_info["tokenizer_path"]
    elif "prefill" in server_info:
        # Disaggregated deployment: take the tokenizer from the first prefill worker.
        tokenizer_path = server_info["prefill"][0]["tokenizer_path"]
    if bench_args.dataset_name == "mmmu":
        # mmmu implies this is a MLLM
        tokenizer = get_processor(tokenizer_path)
    else:
        tokenizer = get_tokenizer(tokenizer_path)

    # Get token capacity (falls back to a huge value so nothing is skipped
    # when the server does not report memory usage).
    internal_state = server_info.get("internal_states", [{}])
    skip_token_capacity_threshold = (
        internal_state[0].get("memory_usage", {}).get("token_capacity", 1000000000)
    )

    # Warmup: one small run per batch size so compilation/caches are hot
    # before timed measurements. Results are discarded (empty run_name/file).
    if not bench_args.skip_warmup:
        print("=" * 8 + " Warmup Begin " + "=" * 8)
        print(f"Warmup with batch_size={bench_args.batch_size}")
        for bs in bench_args.batch_size:
            run_one_case(
                base_url,
                batch_size=bs,
                input_len=1024,
                output_len=16,
                temperature=bench_args.temperature,
                return_logprob=bench_args.return_logprob,
                stream_interval=bench_args.client_stream_interval,
                input_len_step_percentage=bench_args.input_len_step_percentage,
                run_name="",
                result_filename="",
                tokenizer=tokenizer,
                dataset_name=bench_args.dataset_name,
                dataset_path=bench_args.dataset_path,
                parallel_batch=bench_args.parallel_batch,
            )
        print("=" * 8 + " Warmup End " + "=" * 8 + "\n")

    results = []
    profile_results = []
    try:
        # Benchmark all cases
        for bs, il, ol in itertools.product(
            bench_args.batch_size, bench_args.input_len, bench_args.output_len
        ):
            if should_skip_due_to_token_capacity(
                bs, il, ol, skip_token_capacity_threshold
            ):
                continue
            results.append(
                run_one_case(
                    base_url,
                    bs,
                    il,
                    ol,
                    temperature=bench_args.temperature,
                    return_logprob=bench_args.return_logprob,
                    stream_interval=bench_args.client_stream_interval,
                    input_len_step_percentage=bench_args.input_len_step_percentage,
                    run_name=bench_args.run_name,
                    result_filename=bench_args.result_filename,
                    tokenizer=tokenizer,
                    dataset_name=bench_args.dataset_name,
                    dataset_path=bench_args.dataset_path,
                    parallel_batch=bench_args.parallel_batch,
                )
            )

        # Profile all cases: re-run the same sweep with profiling enabled, then
        # attach the produced trace links to the un-profiled timing results.
        # Profiling failures are non-fatal (timings above are already collected).
        if bench_args.profile:
            try:
                for bs, il, ol in itertools.product(
                    bench_args.batch_size, bench_args.input_len, bench_args.output_len
                ):
                    if should_skip_due_to_token_capacity(
                        bs, il, ol, skip_token_capacity_threshold
                    ):
                        continue
                    profile_prefix = (
                        bench_args.profile_prefix or ""
                    ) + f"bs-{bs}-il-{il}"
                    profile_results.append(
                        run_one_case(
                            base_url,
                            bs,
                            il,
                            ol,
                            temperature=bench_args.temperature,
                            return_logprob=bench_args.return_logprob,
                            stream_interval=bench_args.client_stream_interval,
                            input_len_step_percentage=bench_args.input_len_step_percentage,
                            run_name=bench_args.run_name,
                            result_filename=bench_args.result_filename,
                            tokenizer=tokenizer,
                            dataset_name=bench_args.dataset_name,
                            dataset_path=bench_args.dataset_path,
                            parallel_batch=bench_args.parallel_batch,
                            profile=bench_args.profile,
                            profile_steps=bench_args.profile_steps,
                            profile_by_stage=bench_args.profile_by_stage,
                            profile_prefix=profile_prefix,
                            profile_output_dir=bench_args.profile_output_dir,
                        )
                    )

                # Replace the profile link (zip pairs runs in sweep order; both
                # loops skip the same cases, so the pairing lines up).
                for res, profile_res in zip(results, profile_results):
                    res.profile_link = profile_res.profile_link
            except Exception as e:
                print(f"Error profiling, there will be no profile trace dump: {e}")
    finally:
        # Only kill the server if this function launched it.
        if proc:
            kill_process_tree(proc.pid)

    print(f"\nResults are saved to {bench_args.result_filename}")

    if not bench_args.show_report:
        return

    # Print summary
    # NOTE(review): the summary is printed here and again in the else branch
    # below, so non-CI runs print it twice — likely unintended; confirm.
    summary = get_report_summary(results, bench_args, server_args)
    print(summary)

    if is_in_ci() and bench_args.append_to_github_summary:
        write_github_step_summary(summary)
    else:
        print(summary)

    # Save results as pydantic models in the JSON format
    if bench_args.pydantic_result_filename:
        save_results_as_pydantic_models(
            results,
            pydantic_result_filename=bench_args.pydantic_result_filename,
            model_path=server_args.model_path,
        )
591
+
592
+
593
if __name__ == "__main__":
    # Combine the server's CLI flags with the benchmark's own flags on one parser.
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    BenchArgs.add_cli_args(parser)
    args = parser.parse_args()

    # Seed both RNGs so dataset sampling is reproducible across runs.
    random.seed(args.seed)
    np.random.seed(args.seed)

    server_args = ServerArgs.from_cli_args(args)
    bench_args = BenchArgs.from_cli_args(args)

    run_benchmark(server_args, bench_args)
sglang/bench_serving.py ADDED
The diff for this file is too large to render. See raw diff
 
sglang/check_env.py ADDED
@@ -0,0 +1,433 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Check environment configurations and dependency versions."""
2
+
3
+ import importlib.metadata
4
+ import os
5
+ import resource
6
+ import subprocess
7
+ import sys
8
+ from abc import abstractmethod
9
+ from collections import OrderedDict, defaultdict
10
+
11
+ import torch
12
+
13
+ from sglang.srt.utils import is_hip, is_npu
14
+
15
+
16
def is_cuda_v2():
    """Return True when this torch build was compiled against CUDA."""
    cuda_version = torch.version.cuda
    return cuda_version is not None
18
+
19
+
20
# List of packages to check versions.
# Fix: "torchao" was listed twice, which made check_env print its version
# twice; the duplicate entry has been removed.
PACKAGE_LIST = [
    "sglang",
    "sgl_kernel",
    "flashinfer_python",
    "flashinfer_cubin",
    "flashinfer_jit_cache",
    "triton",
    "transformers",
    "torchao",
    "numpy",
    "aiohttp",
    "fastapi",
    "hf_transfer",
    "huggingface_hub",
    "interegular",
    "modelscope",
    "orjson",
    "outlines",
    "packaging",
    "psutil",
    "pydantic",
    "python-multipart",
    "pyzmq",
    "uvicorn",
    "uvloop",
    "vllm",
    "xgrammar",
    "openai",
    "tiktoken",
    "anthropic",
    "litellm",
    "decord2",
]
55
+
56
+
57
class BaseEnv:
    """Base class for environment check.

    Subclasses implement get_info()/get_topology() for a specific accelerator
    backend; the shared helpers here gather package versions, device names,
    hypervisor info, and ulimits, and check_env() prints everything.
    """

    def __init__(self):
        self.package_list = PACKAGE_LIST

    @abstractmethod
    def get_info(self) -> dict:
        """
        Get CUDA-related information if available.
        """
        raise NotImplementedError

    @abstractmethod
    def get_topology(self) -> dict:
        """Return device-interconnect topology info for this backend."""
        raise NotImplementedError

    def get_package_versions(self) -> dict:
        """
        Get versions of specified packages.
        """
        versions = {}
        for package in self.package_list:
            # Strip any version specifier so e.g. "foo>=1.0" looks up "foo".
            package_name = package.split("==")[0].split(">=")[0].split("<=")[0]
            try:
                version = importlib.metadata.version(package_name)
                versions[package_name] = version
            # PackageNotFoundError is the exact error importlib.metadata
            # raises for a missing distribution (it subclasses
            # ModuleNotFoundError, which was caught before).
            except importlib.metadata.PackageNotFoundError:
                versions[package_name] = "Module Not Found"
        return versions

    def get_device_info(self):
        """
        Get information about available GPU devices.
        """
        devices = defaultdict(list)
        capabilities = defaultdict(list)
        for k in range(torch.cuda.device_count()):
            devices[torch.cuda.get_device_name(k)].append(str(k))
            capability = torch.cuda.get_device_capability(k)
            capabilities[f"{capability[0]}.{capability[1]}"].append(str(k))

        gpu_info = {}
        for name, device_ids in devices.items():
            gpu_info[f"GPU {','.join(device_ids)}"] = name

        if len(capabilities) == 1:
            # All GPUs have the same compute capability
            cap, gpu_ids = list(capabilities.items())[0]
            gpu_info[f"GPU {','.join(gpu_ids)} Compute Capability"] = cap
        else:
            # GPUs have different compute capabilities
            for cap, gpu_ids in capabilities.items():
                gpu_info[f"GPU {','.join(gpu_ids)} Compute Capability"] = cap

        return gpu_info

    def get_hypervisor_vendor(self) -> dict:
        """Best-effort: parse the hypervisor vendor from `lscpu` output."""
        try:
            output = subprocess.check_output(["lscpu"], text=True)
            for line in output.split("\n"):
                if "Hypervisor vendor:" in line:
                    return {"Hypervisor vendor:": line.split(":")[1].strip()}
            return {}
        # Was a bare `except:`, which also swallowed SystemExit and
        # KeyboardInterrupt; Exception keeps the best-effort behavior while
        # letting those propagate.
        except Exception:
            return {}

    def get_ulimit_soft(self) -> dict:
        """Return the soft limit on open file descriptors (RLIMIT_NOFILE)."""
        ulimit_soft, _ = resource.getrlimit(resource.RLIMIT_NOFILE)
        return {"ulimit soft": ulimit_soft}

    def check_env(self):
        """
        Check and print environment information.
        """
        env_info = OrderedDict()
        env_info["Python"] = sys.version.replace("\n", "")
        env_info.update(self.get_info())
        env_info["PyTorch"] = torch.__version__
        env_info.update(self.get_package_versions())
        env_info.update(self.get_topology())
        env_info.update(self.get_hypervisor_vendor())
        env_info.update(self.get_ulimit_soft())

        for k, v in env_info.items():
            print(f"{k}: {v}")
143
+
144
+
145
class GPUEnv(BaseEnv):
    """Environment checker for Nvidia GPU"""

    def get_info(self):
        # Report CUDA availability first; only probe devices/toolkit when present.
        cuda_info = {"CUDA available": torch.cuda.is_available()}

        if cuda_info["CUDA available"]:
            cuda_info.update(self.get_device_info())
            cuda_info.update(self._get_cuda_version_info())

        return cuda_info

    def _get_cuda_version_info(self):
        """
        Get CUDA version information.
        """
        from torch.utils.cpp_extension import CUDA_HOME

        cuda_info = {"CUDA_HOME": CUDA_HOME}

        # Only query nvcc/driver when the toolkit directory actually exists.
        if CUDA_HOME and os.path.isdir(CUDA_HOME):
            cuda_info.update(self._get_nvcc_info())
            cuda_info.update(self._get_cuda_driver_version())

        return cuda_info

    def _get_nvcc_info(self):
        """
        Get NVCC version information.
        """
        from torch.utils.cpp_extension import CUDA_HOME

        try:
            nvcc = os.path.join(CUDA_HOME, "bin/nvcc")
            nvcc_output = (
                subprocess.check_output(f'"{nvcc}" -V', shell=True)
                .decode("utf-8")
                .strip()
            )
            # Extract the segment between "Cuda compilation tools" and "Build"
            # from `nvcc -V` output, i.e. the human-readable version line.
            return {
                "NVCC": nvcc_output[
                    nvcc_output.rfind("Cuda compilation tools") : nvcc_output.rfind(
                        "Build"
                    )
                ].strip()
            }
        except subprocess.SubprocessError:
            return {"NVCC": "Not Available"}

    def _get_cuda_driver_version(self):
        """
        Get CUDA driver version.
        """
        versions = set()
        try:
            output = subprocess.check_output(
                [
                    "nvidia-smi",
                    "--query-gpu=driver_version",
                    "--format=csv,noheader,nounits",
                ]
            )
            # One line per GPU; a set collapses identical driver versions.
            versions = set(output.decode().strip().split("\n"))
            if len(versions) == 1:
                return {"CUDA Driver Version": versions.pop()}
            else:
                # Mixed-driver machines: report all distinct versions.
                return {"CUDA Driver Versions": ", ".join(sorted(versions))}
        except subprocess.SubprocessError:
            return {"CUDA Driver Version": "Not Available"}

    def get_topology(self):
        """
        Get GPU topology information.
        """
        try:
            result = subprocess.run(
                ["nvidia-smi", "topo", "-m"],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                check=True,
            )
            # check=True means returncode is always 0 here; the conditional is
            # defensive and effectively always takes the stdout branch.
            return {
                "NVIDIA Topology": (
                    "\n" + result.stdout if result.returncode == 0 else None
                )
            }
        except subprocess.SubprocessError:
            return {}
234
+
235
+
236
class HIPEnv(BaseEnv):
    """Environment checker for ROCm/HIP"""

    def get_info(self):
        # On ROCm builds, torch.cuda.is_available() reports HIP availability.
        cuda_info = {"ROCM available": torch.cuda.is_available()}

        if cuda_info["ROCM available"]:
            cuda_info.update(self.get_device_info())
            cuda_info.update(self._get_cuda_version_info())

        return cuda_info

    def _get_cuda_version_info(self):
        """Collect ROCm toolkit location, hipcc, and driver version info."""
        from torch.utils.cpp_extension import ROCM_HOME as ROCM_HOME

        cuda_info = {"ROCM_HOME": ROCM_HOME}

        if ROCM_HOME and os.path.isdir(ROCM_HOME):
            cuda_info.update(self._get_hipcc_info())
            cuda_info.update(self._get_rocm_driver_version())

        return cuda_info

    def _get_hipcc_info(self):
        """Parse the HIP version line from `hipcc --version` output."""
        from torch.utils.cpp_extension import ROCM_HOME

        try:
            hipcc = os.path.join(ROCM_HOME, "bin/hipcc")
            hipcc_output = (
                subprocess.check_output(f'"{hipcc}" --version', shell=True)
                .decode("utf-8")
                .strip()
            )
            # Extract the segment between "HIP version" and "AMD clang".
            return {
                "HIPCC": hipcc_output[
                    hipcc_output.rfind("HIP version") : hipcc_output.rfind("AMD clang")
                ].strip()
            }
        except subprocess.SubprocessError:
            return {"HIPCC": "Not Available"}

    def _get_rocm_driver_version(self):
        """Parse the driver version from `rocm-smi --showdriverversion --csv`."""
        try:
            output = subprocess.check_output(
                [
                    "rocm-smi",
                    "--showdriverversion",
                    "--csv",
                ]
            )
            versions = set(output.decode().strip().split("\n"))
            versions.discard("name, value")  # drop the CSV header row
            # Fix: if the output contained only the header (or was otherwise
            # unexpected), the set is empty and .pop() would raise a KeyError
            # that the SubprocessError handler below does not catch.
            if not versions:
                return {"ROCM Driver Version": "Not Available"}
            ver = versions.pop()
            ver = ver.replace('"Driver version", ', "").replace('"', "")

            return {"ROCM Driver Version": ver}
        except subprocess.SubprocessError:
            return {"ROCM Driver Version": "Not Available"}

    def get_topology(self):
        """Return the AMD GPU interconnect topology from `rocm-smi --showtopotype`."""
        try:
            result = subprocess.run(
                ["rocm-smi", "--showtopotype"],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                check=True,
            )
            return {
                "AMD Topology": "\n" + result.stdout if result.returncode == 0 else None
            }
        except subprocess.SubprocessError:
            return {}
309
+
310
+
311
class NPUEnv(BaseEnv):
    """Environment checker for Ascend NPU"""

    # NPU-specific packages reported in addition to the shared PACKAGE_LIST.
    EXTRA_PACKAGE_LIST = [
        "torch_npu",
        "sgl-kernel-npu",
        "deep_ep",
    ]

    def __init__(self):
        super().__init__()
        self.package_list.extend(NPUEnv.EXTRA_PACKAGE_LIST)

    def get_info(self):
        cuda_info = {"NPU available": torch.npu.is_available()}
        if cuda_info["NPU available"]:
            cuda_info.update(self.get_device_info())
            cuda_info.update(self._get_cann_version_info())

        return cuda_info

    def get_device_info(self):
        """
        Get information about available NPUs.
        Need to override due to torch_npu interface differences.
        """
        devices = defaultdict(list)
        for k in range(torch.npu.device_count()):
            devices[torch.npu.get_device_name(k)].append(str(k))

        npu_info = {}
        for name, device_ids in devices.items():
            npu_info[f"NPU {','.join(device_ids)}"] = name

        return npu_info

    def _get_cann_version_info(self):
        """Locate the CANN toolkit (env vars first, then the default install path)
        and collect its version, compiler, and driver information."""
        cann_envs = ["ASCEND_TOOLKIT_HOME", "ASCEND_INSTALL_PATH"]
        for var in cann_envs:
            path = os.environ.get(var)
            if path and os.path.exists(path):
                CANN_HOME = path
                break
        else:
            # for-else: runs only when no env var pointed at an existing path.
            default_path = "/usr/local/Ascend/ascend-toolkit/latest"
            CANN_HOME = default_path if os.path.exists(default_path) else None

        if CANN_HOME:
            npu_info = {"CANN_HOME": CANN_HOME}
            npu_info.update(self._get_cann_info(CANN_HOME))
            npu_info.update(self._get_ascend_driver_version())
            return npu_info
        else:
            return {"CANN_HOME": "Not found"}

    def _get_cann_info(self, CANN_HOME: str):
        """Read the CANN version from version.cfg and query the BiSheng compiler."""
        cann_info = {}
        cann_version_file = os.path.join(CANN_HOME, "version.cfg")
        if os.path.exists(cann_version_file):
            with open(cann_version_file, "r", encoding="utf-8") as f:
                f.readline()  # discard first line comment in version.cfg
                # NOTE(review): assumes the second line contains "[version]";
                # a malformed file would raise IndexError here — confirm the
                # version.cfg format is stable across CANN releases.
                cann_info["CANN"] = f.readline().split("[")[1].split("]")[0]
        else:
            cann_info["CANN"] = "Not Available"
        try:
            bisheng = os.path.join(CANN_HOME, "compiler/ccec_compiler/bin/bisheng")
            bisheng_output = (
                subprocess.check_output([bisheng, "--version"]).decode("utf-8").strip()
            )
            # First line of `bisheng --version` carries the version string.
            cann_info["BiSheng"] = bisheng_output.split("\n")[0].strip()
        except subprocess.SubprocessError:
            cann_info["BiSheng"] = "Not Available"
        return cann_info

    def _get_ascend_driver_version(self):
        """Parse the driver version from `npu-smi info -t board -i 0`."""
        try:
            output = subprocess.check_output(
                [
                    "npu-smi",
                    "info",
                    "-t",
                    "board",
                    "-i",
                    "0",
                ]
            )
            for line in output.decode().strip().split("\n"):
                if "Software Version" in line:
                    version = line.split(":")[-1].strip()
                    break
            else:
                # for-else: no "Software Version" line found in the output.
                version = "Not Available"

            return {"Ascend Driver Version": version}
        except subprocess.SubprocessError:
            return {"Ascend Driver Version": "Not Available"}

    def get_topology(self):
        """Return the NPU interconnect topology from `npu-smi info -t topo`."""
        try:
            result = subprocess.run(
                ["npu-smi", "info", "-t", "topo"],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                check=True,
            )
            # check=True means returncode is 0 here; the conditional is defensive.
            return {
                "Ascend Topology": (
                    "\n" + result.stdout if result.returncode == 0 else None
                )
            }
        except subprocess.SubprocessError:
            return {}
424
+
425
+
426
if __name__ == "__main__":
    # Pick the environment checker matching the detected accelerator backend.
    if is_cuda_v2():
        env = GPUEnv()
    elif is_hip():
        env = HIPEnv()
    elif is_npu():
        env = NPUEnv()
    else:
        # Fix: previously `env` was left unbound when no backend matched,
        # producing a confusing NameError on env.check_env(); fail with an
        # explicit message instead.
        raise RuntimeError(
            "No supported accelerator backend detected (CUDA, ROCm, or NPU)."
        )
    env.check_env()
sglang/cli/__init__.py ADDED
File without changes
sglang/cli/generate.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ from sglang.cli.utils import get_is_diffusion_model, get_model_path
4
+
5
+
6
def generate(args, extra_argv):
    """Handle the `sglang generate` subcommand.

    Builds the real argument parser lazily (it lives in the multimodal_gen
    package) so that --help works without requiring --model-path, then
    dispatches to the diffusion generate command when the model is a
    diffusion model.
    """
    # If help is requested, show generate subcommand help without requiring --model-path
    if any(h in extra_argv for h in ("-h", "--help")):
        from sglang.multimodal_gen.runtime.entrypoints.cli.generate import (
            add_multimodal_gen_generate_args,
        )

        parser = argparse.ArgumentParser(description="SGLang Multimodal Generation")
        add_multimodal_gen_generate_args(parser)
        # parse_args prints the help text and exits the process here.
        parser.parse_args(extra_argv)
        return

    model_path = get_model_path(extra_argv)
    is_diffusion_model = get_is_diffusion_model(model_path)
    if is_diffusion_model:
        from sglang.multimodal_gen.runtime.entrypoints.cli.generate import (
            add_multimodal_gen_generate_args,
            generate_cmd,
        )

        parser = argparse.ArgumentParser(description="SGLang Multimodal Generation")
        add_multimodal_gen_generate_args(parser)
        parsed_args = parser.parse_args(extra_argv)
        generate_cmd(parsed_args)
    else:
        # Only diffusion models are supported by `generate` for now.
        raise Exception(
            f"Generate subcommand is not yet supported for model: {model_path}"
        )
sglang/cli/main.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ from sglang.cli.generate import generate
4
+ from sglang.cli.serve import serve
5
+
6
+
7
def main():
    """Top-level ``sglang`` CLI dispatcher: routes to ``serve`` or ``generate``."""
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(dest="subcommand", required=True)

    # add_help=False on each stub: every subcommand defers -h/--help handling
    # to its own, model-dependent parser.
    for name, help_text, handler in (
        ("serve", "Launch the SGLang server.", serve),
        ("generate", "Run inference on a multimodal model.", generate),
    ):
        sub = subparsers.add_parser(name, help=help_text, add_help=False)
        sub.set_defaults(func=handler)

    args, extra_argv = parser.parse_known_args()
    args.func(args, extra_argv)
sglang/cli/serve.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import argparse
4
+ import logging
5
+ import os
6
+
7
+ from sglang.cli.utils import get_is_diffusion_model, get_model_path
8
+ from sglang.srt.utils import kill_process_tree
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
def serve(args, extra_argv):
    """Entry point for ``sglang serve``.

    Dispatches to either the diffusion-model server or the standard language
    model server depending on the model found at ``--model-path``. Child
    processes are always torn down on exit via the ``finally`` clause.
    """
    if any(h in extra_argv for h in ("-h", "--help")):
        # Without a model path we cannot know which server would launch, so
        # print general usage followed by both servers' help texts.
        print(
            "Usage: sglang serve --model-path <model-name-or-path> [additional-arguments]\n"
        )
        print(
            "This command can launch either a standard language model server or a diffusion model server."
        )
        print("The server type is determined by the model path.\n")
        print("For specific arguments, please provide a model_path.")
        print("\n--- Help for Standard Language Model Server ---")
        from sglang.srt.server_args import prepare_server_args

        try:
            prepare_server_args(["--help"])
        except SystemExit:
            pass  # argparse --help calls sys.exit

        print("\n--- Help for Diffusion Model Server ---")
        from sglang.multimodal_gen.runtime.entrypoints.cli.serve import (
            add_multimodal_gen_serve_args,
        )

        parser = argparse.ArgumentParser(description="SGLang Diffusion Model Serving")
        add_multimodal_gen_serve_args(parser)
        parser.print_help()
        return

    model_path = get_model_path(extra_argv)
    try:
        # Fix: the detection result was previously tested twice in a row
        # (once only to log "Diffusion model detected", which
        # get_is_diffusion_model() already logs itself); collapse to one check.
        if get_is_diffusion_model(model_path):
            # Logic for Diffusion Models
            from sglang.multimodal_gen.runtime.entrypoints.cli.serve import (
                add_multimodal_gen_serve_args,
                execute_serve_cmd,
            )

            parser = argparse.ArgumentParser(
                description="SGLang Diffusion Model Serving"
            )
            add_multimodal_gen_serve_args(parser)
            parsed_args, remaining_argv = parser.parse_known_args(extra_argv)

            execute_serve_cmd(parsed_args, remaining_argv)
        else:
            # Logic for Standard Language Models
            from sglang.launch_server import run_server
            from sglang.srt.server_args import prepare_server_args

            # prepare_server_args takes the argv tail directly (no program
            # name), matching how extra_argv is collected by the dispatcher.
            server_args = prepare_server_args(extra_argv)

            run_server(server_args)
    finally:
        # Ensure any worker/child processes are terminated even on error.
        kill_process_tree(os.getpid(), include_parent=False)
sglang/cli/utils.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import json
3
+ import logging
4
+ import os
5
+ import tempfile
6
+ from typing import Optional
7
+
8
+ import filelock
9
+ from huggingface_hub import hf_hub_download
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+ temp_dir = tempfile.gettempdir()
14
+
15
+
16
def _get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
    """Build an inter-process file lock guarding downloads of one model.

    Args:
        model_name_or_path: Model ID or path; used to derive a unique lock name.
        cache_dir: Directory for the lock file; defaults to the system temp dir.

    Returns:
        A ``filelock.FileLock`` scoped to this model.
    """
    lock_dir = cache_dir or temp_dir
    # Bug fix: this previously created os.path.dirname(lock_dir) — the
    # *parent* of the lock directory — leaving lock_dir itself possibly
    # missing and FileLock creation failing for a fresh cache_dir.
    os.makedirs(lock_dir, exist_ok=True)
    model_name = model_name_or_path.replace("/", "-")
    hash_name = hashlib.sha256(model_name.encode()).hexdigest()
    lock_file_name = hash_name + model_name + ".lock"
    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666)
    return lock
24
+
25
+
26
+ # Copied and adapted from hf_diffusers_utils.py
27
def _maybe_download_model(
    model_name_or_path: str, local_dir: str | None = None, download: bool = True
) -> str:
    """Resolve a model path to a local directory containing its config.

    An existing local path is returned unchanged. Otherwise the name is
    treated as a Hugging Face Hub ID and only the pipeline/model config file
    (``model_index.json`` first, then ``config.json``) is fetched, under a
    per-model file lock.

    Args:
        model_name_or_path: Local path or Hugging Face Hub model ID.
        local_dir: Local directory to save the downloaded file (if any).
        download: Whether to download from Hugging Face Hub when needed.

    Returns:
        Directory containing the downloaded config file, or the input path.

    Raises:
        ValueError: if neither config file can be downloaded.
    """
    if os.path.exists(model_name_or_path):
        logger.info("Model already exists locally")
        return model_name_or_path

    if not download:
        return model_name_or_path

    with _get_lock(model_name_or_path):
        # Attempt 1: model_index.json (present for diffusers pipelines).
        try:
            logger.info(
                "Downloading model_index.json from HF Hub for %s...",
                model_name_or_path,
            )
            file_path = hf_hub_download(
                repo_id=model_name_or_path,
                filename="model_index.json",
                local_dir=local_dir,
            )
        except Exception as e_index:
            logger.debug("model_index.json not found or failed: %s", e_index)
        else:
            logger.info("Downloaded to %s", file_path)
            return os.path.dirname(file_path)

        # Attempt 2: config.json (standard transformers models).
        try:
            logger.info(
                "Downloading config.json from HF Hub for %s...", model_name_or_path
            )
            file_path = hf_hub_download(
                repo_id=model_name_or_path,
                filename="config.json",
                local_dir=local_dir,
            )
        except Exception as e_config:
            raise ValueError(
                (
                    "Could not find model locally at %s and failed to download "
                    "model_index.json/config.json from HF Hub: %s"
                )
                % (model_name_or_path, e_config)
            ) from e_config
        logger.info("Downloaded to %s", file_path)
        return os.path.dirname(file_path)
88
+
89
+
90
+ # Copied and adapted from hf_diffusers_utils.py
91
def is_diffusers_model_path(model_path: str) -> bool:
    """Check whether a local directory holds a diffusers pipeline.

    A diffusers pipeline is identified by a ``model_index.json`` file that
    carries a ``_diffusers_version`` key.

    Args:
        model_path: Path to the model directory.

    Returns:
        True if the directory is a diffusers model, False otherwise.
        (Fixed: the annotation was ``-> True`` and the docstring described a
        dict/None return, but the function has always returned a bool.)
    """
    # Prefer model_index.json which indicates a diffusers pipeline
    config_path = os.path.join(model_path, "model_index.json")
    if not os.path.exists(config_path):
        return False

    # Load the config
    with open(config_path) as f:
        config = json.load(f)

    # A genuine diffusers export always records the library version.
    return "_diffusers_version" in config
116
+
117
+
118
def get_is_diffusion_model(model_path: str):
    """Resolve ``model_path`` (downloading its config if needed) and report
    whether it points at a diffusion (diffusers) model."""
    resolved_path = _maybe_download_model(model_path)
    result = is_diffusers_model_path(resolved_path)
    if result:
        logger.info("Diffusion model detected")
    return result
124
+
125
+
126
def get_model_path(extra_argv):
    """Extract the value of ``--model-path`` from a raw argv list.

    Supports both ``--model-path <value>`` and ``--model-path=<value>``.

    Raises:
        Exception: when no model path is present; the message is a usage
            summary if help was requested, or a plain error otherwise.
    """
    model_path = None
    for pos, token in enumerate(extra_argv):
        if token.startswith("--model-path="):
            model_path = token.split("=", 1)[1]
            break
        if token == "--model-path" and pos + 1 < len(extra_argv):
            model_path = extra_argv[pos + 1]
            break

    if model_path is not None:
        return model_path

    # No model path found: tailor the error to whether help was requested.
    if "-h" in extra_argv or "--help" in extra_argv:
        raise Exception(
            "Usage: sglang serve --model-path <model-name-or-path> [additional-arguments]\n\n"
            "This command can launch either a standard language model server or a diffusion model server.\n"
            "The server type is determined by the model path.\n"
            "For specific arguments, please provide a model_path."
        )
    raise Exception(
        "Error: --model-path is required. "
        "Please provide the path to the model."
    )
sglang/compile_deep_gemm.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Compile DeepGEMM Kernels for a model with specify server arguments
3
+
4
+ This script launches a server for capturing DeepGEMM calls and then compiles the kernels.
5
+ It accepts server arguments (the same as launch_server.py).
6
+
7
+ Usage:
8
+ python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code
9
+
10
+ """
11
+
12
+ import argparse
13
+ import dataclasses
14
+ import multiprocessing
15
+ import os
16
+ import time
17
+
18
+ import requests
19
+
20
+ from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST
21
+ from sglang.srt.entrypoints.http_server import launch_server
22
+ from sglang.srt.entrypoints.warmup import warmup
23
+ from sglang.srt.environ import envs
24
+ from sglang.srt.managers.io_struct import GenerateReqInput
25
+ from sglang.srt.managers.tokenizer_manager import TokenizerManager
26
+ from sglang.srt.server_args import ServerArgs
27
+ from sglang.srt.utils import kill_process_tree
28
+
29
# Use spawn so server child processes start from a clean interpreter state.
multiprocessing.set_start_method("spawn", force=True)

# Reduce warning
envs.SGLANG_IN_DEEPGEMM_PRECOMPILE_STAGE.set(True)
# Force enable deep gemm
envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(True)
# Force enable mha chunked kv for DeepSeek V3 to avoid missing kv_b_proj DeepGEMM case
os.environ["SGL_CHUNKED_PREFIX_CACHE_THRESHOLD"] = "0"
37
+
38
+
39
@dataclasses.dataclass
class CompileArgs:
    """CLI arguments specific to the DeepGEMM pre-compilation driver."""

    # Overall time budget (seconds) for server startup plus compilation.
    timeout: int = 3600

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register compile-specific flags on an existing parser."""
        parser.add_argument("--timeout", type=int, default=CompileArgs.timeout)

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a CompileArgs from parsed args, coercing each field's value
        to the type of its declared default."""
        values = {}
        for field in dataclasses.fields(cls):
            caster = type(field.default)
            values[field.name] = caster(getattr(args, field.name))
        return cls(**values)
54
+
55
+
56
+ @warmup("compile-deep-gemm")
57
+ async def warm_up_compile(
58
+ disaggregation_mode: str, tokenizer_manager: TokenizerManager
59
+ ):
60
+ print("\nGenerate warm up request for compiling DeepGEMM...\n")
61
+ generate_req_input = GenerateReqInput(
62
+ input_ids=[0, 1, 2, 3],
63
+ sampling_params={
64
+ "temperature": 0.0,
65
+ "max_new_tokens": 8,
66
+ "ignore_eos": True,
67
+ },
68
+ )
69
+ if disaggregation_mode != "null":
70
+ generate_req_input.bootstrap_room = 0
71
+ generate_req_input.bootstrap_host = FAKE_BOOTSTRAP_HOST
72
+
73
+ await tokenizer_manager.generate_request(generate_req_input, None).__anext__()
74
+
75
+
76
def launch_server_internal(server_args):
    """Run launch_server() and always tear down this process's children.

    Cleanup note: the previous ``except Exception as e: raise e`` was a no-op
    re-raise that only truncated tracebacks, so it was removed; the
    ``finally`` clause is the part that matters.
    """
    try:
        launch_server(server_args)
    finally:
        kill_process_tree(os.getpid(), include_parent=False)
83
+
84
+
85
def launch_server_process_and_send_one_request(
    server_args: ServerArgs, compile_args: CompileArgs
):
    """Start the server in a child process and drive one generate request.

    Rank-0 nodes poll ``/v1/models`` until the server responds, then POST a
    tiny ``/generate`` request (which exercises the compile warmup path).
    Non-rank-0 nodes poll ``/health`` and then wait for the server process to
    exit on its own.

    Returns:
        The child ``multiprocessing.Process`` running the server.

    Raises:
        TimeoutError: if the server does not come up / finish within
            ``compile_args.timeout`` seconds.
        RuntimeError: if the sync generate request returns a non-200 status.
    """
    proc = multiprocessing.Process(target=launch_server_internal, args=(server_args,))
    proc.start()
    base_url = f"http://{server_args.host}:{server_args.port}"
    timeout = compile_args.timeout

    start_time = time.perf_counter()
    while time.perf_counter() - start_time < timeout:
        try:
            headers = {
                "Content-Type": "application/json; charset=utf-8",
            }
            if server_args.node_rank == 0:
                response = requests.get(f"{base_url}/v1/models", headers=headers)
            else:
                # This http api is created by launch_dummy_health_check_server for none-rank0 node.
                response = requests.get(f"{base_url}/health", headers=headers)
            if response.status_code == 200:
                # Rank-0 node send a request to sync with other node and then return.
                if server_args.node_rank == 0:
                    payload = {
                        "input_ids": [0, 1, 2, 3],
                        "sampling_params": {
                            "max_new_tokens": 8,
                            "temperature": 0,
                        },
                    }
                    # In PD mode, include fake bootstrap fields so workers don't assert
                    if server_args.disaggregation_mode != "null":
                        payload["bootstrap_host"] = FAKE_BOOTSTRAP_HOST
                        payload["bootstrap_room"] = 0

                    response = requests.post(
                        f"{base_url}/generate",
                        json=payload,
                        timeout=600,
                    )
                    if response.status_code != 200:
                        error = response.json()
                        raise RuntimeError(f"Sync request failed: {error}")
                # Other nodes should wait for the exit signal from Rank-0 node.
                else:
                    start_time_waiting = time.perf_counter()
                    while proc.is_alive():
                        if time.perf_counter() - start_time_waiting < timeout:
                            time.sleep(10)
                        else:
                            raise TimeoutError("Waiting for main node timeout!")
                return proc
        except requests.RequestException:
            # Server not accepting connections yet; retry after a pause.
            pass
        time.sleep(10)
    raise TimeoutError(
        "DeepGEMM Kernels compilation timeout."
        "\n\nFeel free and please restart the command."
    )
143
+
144
+
145
def refine_server_args(server_args: "ServerArgs", compile_args: "CompileArgs"):
    """Adjust server arguments in place for a compilation-only run.

    Disables CUDA graph capture and torch.compile (slow to set up and not
    needed just to compile kernels), extends the watchdog to the compile
    timeout, and enables the "compile-deep-gemm" warmup hook.

    Args:
        server_args: Server arguments to mutate (annotations are quoted so
            the function is importable without the sglang runtime).
        compile_args: Source of the compilation timeout.
    """
    # Disable cuda graph and torch compile to save time
    server_args.disable_cuda_graph = True
    server_args.enable_torch_compile = False
    # Fixed: this was an f-string with no placeholders.
    print("Disable CUDA Graph and Torch Compile to save time...")

    # Set watchdog timeout to compile_args.timeout because compilation will take a long time
    server_args.watchdog_timeout = compile_args.timeout
    server_args.warmups = "compile-deep-gemm"
155
+
156
def run_compile(server_args: ServerArgs, compile_args: CompileArgs):
    """Launch the server, trigger compilation via one request, then shut down.

    Raises:
        TimeoutError: propagated from the launch helper when compilation does
            not finish within the configured timeout.
    """
    print(
        "Begin DeepGEMM Kernels compilation...\n"
        "It may take a long time and timeout maybe raised "
        "while the compilation is still in progress.\n"
        "Just feel free to restart the command "
        "until the compilation is fully finished.\n"
    )

    proc = launch_server_process_and_send_one_request(server_args, compile_args)

    print("\nDeepGEMM Kernels compilation finished successfully.")

    # Sleep for safety
    time.sleep(10)
    if proc.is_alive():
        # This is the rank0 node.
        kill_process_tree(proc.pid)
    else:
        # Non-rank0 nodes: the process may already be gone, so ignore
        # cleanup failures here.
        try:
            kill_process_tree(proc.pid)
        except Exception:
            pass
180
+
181
+ if __name__ == "__main__":
182
+ parser = argparse.ArgumentParser()
183
+ ServerArgs.add_cli_args(parser)
184
+ CompileArgs.add_cli_args(parser)
185
+ args = parser.parse_args()
186
+ server_args = ServerArgs.from_cli_args(args)
187
+ compile_args = CompileArgs.from_cli_args(args)
188
+
189
+ refine_server_args(server_args, compile_args)
190
+
191
+ run_compile(server_args, compile_args)
sglang/eval/llama3_eval.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapt from https://github.com/fw-ai/llm_eval_meta
2
+
3
+ import argparse
4
+ import asyncio
5
+ import os
6
+ import pickle
7
+ import re
8
+ import shutil
9
+ from collections import defaultdict
10
+ from dataclasses import dataclass
11
+
12
+ import httpx
13
+ import numpy as np
14
+ import openai
15
+ from datasets import load_dataset
16
+ from openai import AsyncOpenAI
17
+ from tqdm import tqdm
18
+
19
# Maps provider key -> model-size key -> HF model ID sent in requests.
# Providers (see get_client): "b10" = Baseten, "oai" = local OpenAI-compatible
# server on :8000, "sgl" = local SGLang server on :30000.
provider_to_models = {
    "b10": {
        "8b": "meta-llama/Llama-3.1-8B-Instruct",
        "70b": "meta-llama/Llama-3.1-70B-Instruct",
        "405b": "meta-llama/Llama-3.1-405B-Instruct",
    },
    "oai": {
        "8b": "meta-llama/Llama-3.1-8B-Instruct",
        "70b": "meta-llama/Llama-3.1-70B-Instruct",
        "405b": "meta-llama/Llama-3.1-405B-Instruct",
    },
    "sgl": {
        "8b": "meta-llama/Llama-3.1-8B-Instruct",
        "70b": "meta-llama/Llama-3.1-70B-Instruct",
        "405b": "meta-llama/Llama-3.1-405B-Instruct",
    },
}
37
+
38
+
39
async def fetch_responses(
    client, prompt, semaphore, index, provider, model_size, output_dir, max_tokens
):
    """Issue one completion request and cache the raw response to disk.

    Skips work when ``response_<index>.pkl`` already exists, which makes an
    interrupted benchmark run resumable.
    """
    output_file = os.path.join(output_dir, f"response_{index}.pkl")
    if os.path.exists(output_file):
        print(f"File {output_file} already exists, skipping.")
        return

    async with semaphore:
        response = await client.completions.create(
            model=provider_to_models[provider][model_size],
            prompt=prompt,
            temperature=0.0,
            max_tokens=max_tokens,
        )
        if isinstance(response, openai.BadRequestError):
            with open(output_file, "wb") as f:
                pickle.dump("bad_response", f)
            # Bug fix: previously fell through to the assert below (which
            # would fail) and then overwrote the "bad_response" sentinel.
            return
        assert isinstance(response, openai.types.completion.Completion)
        # Save response to a file
        with open(output_file, "wb") as f:
            pickle.dump(response, f)
62
+
63
# Max completion tokens per eval set: plain MMLU needs only the single answer
# letter; chain-of-thought variants need room for the reasoning text.
TASK_TO_MAX_TOKENS = {
    "evals__mmlu__details": 1,
    "evals__mmlu__0_shot__cot__details": 1024,
    # Official meta uses 1024, but a small % (.05) of questions are answered correctly after relaxing
    "evals__mmlu_pro__details": 2048,
    "evals__gsm8k__details": 1024,
}

# Maps the short CLI --task name to the dataset config suffix on HF Hub.
TASK_TO_EVAL_SET = {
    "mmlu": "evals__mmlu__details",
    "mmlu_cot": "evals__mmlu__0_shot__cot__details",
    "mmlu_pro": "evals__mmlu_pro__details",
    "gsm8k": "evals__gsm8k__details",
}
77
+
78
+
79
class CustomAsyncHTTPXClient(httpx.AsyncClient):
    """HTTPX client that rewrites every outgoing request URL to the Baseten
    predict endpoint derived from the MODEL_ID environment variable."""

    async def send(self, request: httpx.Request, *args, **kwargs) -> httpx.Response:
        # Baseten exposes a single predict endpoint; ignore the path the
        # OpenAI client built and target that endpoint directly.
        request.url = httpx.URL(
            f"https://model-{os.getenv('MODEL_ID')}.api.baseten.co/development/predict"
        )
        return await super().send(request, *args, **kwargs)
85
+
86
+
87
def get_client(provider):
    """Return an AsyncOpenAI client configured for ``provider``.

    For non-Baseten providers, a placeholder OPENAI_API_KEY is installed if
    none is set, since the local servers do not validate it.
    """
    if provider != "b10":
        # Bug fix: this was `provider not in "b10"`, a substring test that
        # only worked by accident for the known provider names (e.g. a
        # provider "b1" would silently skip the key setup).
        if os.getenv("OPENAI_API_KEY") is None:
            os.environ["OPENAI_API_KEY"] = "EMPTY"
    return {
        "oai": AsyncOpenAI(base_url="http://127.0.0.1:8000/v1/"),
        "b10": AsyncOpenAI(
            api_key=f"Api-Key {os.getenv('OPENAI_API_KEY')}",
            base_url=f"https://model-{os.getenv('MODEL_ID')}.api.baseten.co/development/predict",
            http_client=CustomAsyncHTTPXClient(),
        ),
        "sgl": AsyncOpenAI(base_url="http://127.0.0.1:30000/v1/"),
    }[provider]
100
+
101
+
102
+ # Define the benchmark function
103
async def benchmark(args):
    """Fan out completion requests for the selected eval set.

    Responses are cached one pickle per prompt in ``args.output_dir``, so an
    interrupted run can be resumed (fetch_responses skips existing files).
    """
    ds = load_dataset(
        "meta-llama/Llama-3.1-405B-Instruct-evals",
        f"Llama-3.1-405B-Instruct-{TASK_TO_EVAL_SET[args.task]}",
    )
    semaphore = asyncio.Semaphore(args.concurrency)  # Cap in-flight requests

    if args.num_examples is None:
        args.num_examples = len(ds["latest"]["input_final_prompts"])
    prompts = ds["latest"]["input_final_prompts"][: args.num_examples]

    # Create the output directory if it does not exist
    os.makedirs(args.output_dir, exist_ok=True)

    tasks = []
    # Create the tasks with tqdm progress bar
    max_tokens = TASK_TO_MAX_TOKENS[TASK_TO_EVAL_SET[args.task]]
    client = get_client(args.provider)
    for idx, prompt in enumerate(tqdm(prompts, desc="Creating tasks")):
        tasks.append(
            asyncio.create_task(
                fetch_responses(
                    client,
                    f"<|begin_of_text|>{prompt[0]}",
                    semaphore,
                    idx,
                    args.provider,
                    args.model_size,
                    args.output_dir,
                    max_tokens=max_tokens,
                )
            )
        )

    # Run the tasks with tqdm progress bar
    for future in tqdm(
        asyncio.as_completed(tasks), total=len(tasks), desc="Processing tasks"
    ):
        await future
142
+
143
+
144
def get_mmlu_answer(response):
    """Extract an MMLU letter answer: trimmed, upper-cased, periods removed.

    Returns None when the response is missing.
    """
    if response is None:
        return None
    raw = response.choices[0].text
    return raw.strip().upper().replace(".", "")
148
+
149
+
150
def get_mmlu_cot_answer(response):
    """Pull the final answer out of a chain-of-thought MMLU completion.

    Tries several "best/correct answer is ..." phrasings in order. The first
    phrasing strips periods and asterisks from the match; the remaining ones
    strip only periods (preserving the original per-phrasing behavior).
    Returns None when nothing matches.
    """
    text = response.choices[0].text

    match = re.search(r"The best answer is (.+)\.?", text)
    if match:
        return match.group(1).replace(".", "").replace("*", "")

    for pattern in (
        r"the best answer is (.+)\.?",
        r"The correct answer is (.+)\.?",
        r"the correct answer is (.+)\.?",
    ):
        match = re.search(pattern, text)
        if match:
            return match.group(1).replace(".", "")
170
+
171
+
172
def get_answer_gsm8k(response):
    """Extract the numeric answer from a GSM8K completion.

    Strips "%" and "$" from the matched value; returns None when no
    "The final answer is ..." phrase is present.
    """
    match = re.search(r"The final answer is (.+)\.?", response.choices[0].text)
    if not match:
        return None
    answer = match.group(1)
    for symbol in ("%", "$"):
        answer = answer.replace(symbol, "")
    return answer
180
+
181
+
182
# Dispatch table: eval-set name -> function that parses the model's answer
# out of a raw completion response.
TASK_TO_ANSWER_EXTRACTOR = {
    "evals__mmlu__details": get_mmlu_answer,
    "evals__mmlu__0_shot__cot__details": get_mmlu_cot_answer,
    "evals__gsm8k__details": get_answer_gsm8k,
    "evals__mmlu_pro__details": get_mmlu_cot_answer,
}
188
+
189
+
190
def get_dataset_from_task(task, response_path, model_size):
    """Load the reference eval dataset for ``task`` and ``model_size``.

    The 405B eval set defines the canonical prompt order. For 8B/70B models,
    the rows of the model-specific eval set are re-ordered to align with the
    405B prompt hashes before returning.

    Note: ``response_path`` is unused here; kept for call-site symmetry.
    """
    ds_405b = load_dataset(
        f"meta-llama/Llama-3.1-405B-Instruct-evals",
        f"Llama-3.1-405B-Instruct-{task}",
    )
    # Canonical prompt ordering, keyed by prompt hash.
    ds_405b_hash_order = [x[0] for x in ds_405b["latest"]["input_final_prompts_hash"]]

    if "70b" in model_size or "8b" in model_size:
        if "70" in model_size:
            ref_model_ds = load_dataset(
                f"meta-llama/Llama-3.1-70B-Instruct-evals",
                f"Llama-3.1-70B-Instruct-{task}",
            )
        else:
            ref_model_ds = load_dataset(
                f"meta-llama/Llama-3.1-8B-Instruct-evals",
                f"Llama-3.1-8B-Instruct-{task}",
            )

        # Re-order the reference rows to match the 405B prompt order.
        hash_to_row = {}
        for row in ref_model_ds["latest"]:
            hash_to_row[row["input_final_prompts_hash"][0]] = row
        reordered_rows = []
        for prompt_hash in ds_405b_hash_order:
            reordered_rows.append(hash_to_row[prompt_hash])
        # NOTE(review): this replaces the "latest" split with a plain list of
        # dicts; analyze() only len()s and iterates it, so that is fine here.
        ref_model_ds["latest"] = reordered_rows
        return ref_model_ds

    return ds_405b
219
+
220
+
221
def analyze(task, response_path, model_size):
    """Score cached responses for ``task`` and print macro/micro accuracy.

    Compares this run's extracted answers against the dataset's accepted
    responses ("average") and also re-reports Meta's own recorded correctness
    ("meta_*") for reference.

    Args:
        task: Eval-set name (a key of TASK_TO_ANSWER_EXTRACTOR).
        response_path: Directory containing response_<i>.pkl files.
        model_size: Model size tag selecting the reference dataset.
    """
    ds = get_dataset_from_task(task, response_path, model_size)

    responses = []
    total = len(ds["latest"])

    for i in range(0, total):
        # Bug fix: use a context manager instead of leaking the file handle
        # via pickle.load(open(...)).
        with open(os.path.join(response_path, f"response_{i}.pkl"), "rb") as f:
            responses.append(pickle.load(f))

    @dataclass
    class Stats:
        correct: int = 0
        total: int = 0
        meta_correct: int = 0

        average: float = None

    subtask_name_to_stats = defaultdict(lambda: Stats())

    for response, ds_row in zip(responses, ds["latest"]):
        model_answer = TASK_TO_ANSWER_EXTRACTOR[task](response)

        subtask = ds_row["subtask_name"]

        is_eval_correct = model_answer in ds_row["input_correct_responses"]
        if is_eval_correct:
            subtask_name_to_stats[subtask].correct += 1

        # Meta's own correctness label for this row, tallied separately.
        if ds_row["is_correct"]:
            subtask_name_to_stats[subtask].meta_correct += 1

        subtask_name_to_stats[subtask].total += 1

    micro_stats = Stats()
    for subtask, stats in subtask_name_to_stats.items():
        stats.average = stats.correct / stats.total
        stats.meta_average = stats.meta_correct / stats.total

        micro_stats.correct += stats.correct
        micro_stats.total += stats.total
        micro_stats.meta_correct += stats.meta_correct

    micro_stats.average = micro_stats.correct / micro_stats.total
    micro_stats.meta_average = micro_stats.meta_correct / micro_stats.total

    # Macro = mean over subtasks; micro = pooled over all examples.
    print("Macro average", np.mean([x.average for x in subtask_name_to_stats.values()]))
    print(
        "Meta Macro average",
        np.mean([x.meta_average for x in subtask_name_to_stats.values()]),
    )
    print("Micro average", micro_stats.average)
    print("Meta Micro average", micro_stats.meta_average)
277
+
278
+ # Entry point for the script
279
+ if __name__ == "__main__":
280
+ parser = argparse.ArgumentParser(
281
+ description="Script to run model with specified parameters."
282
+ )
283
+ parser.add_argument(
284
+ "--model-size",
285
+ type=str,
286
+ default="8b",
287
+ help="Size of the model (e.g., 8b or 70b)",
288
+ )
289
+ parser.add_argument(
290
+ "--provider",
291
+ type=str,
292
+ default="sgl",
293
+ help="Provider name (e.g., sgl, oai, b10)",
294
+ )
295
+ parser.add_argument(
296
+ "--task",
297
+ type=str,
298
+ required=True,
299
+ help="Task (e.g., mmlu, mmlu_cot, mmlu_pro, gsm8k)",
300
+ )
301
+ parser.add_argument(
302
+ "--num-examples", type=int, default=None, help="Number of examples to process"
303
+ )
304
+ parser.add_argument("--concurrency", type=int, default=16)
305
+ parser.add_argument(
306
+ "--output-dir",
307
+ type=str,
308
+ default="tmp-output-dir",
309
+ help="Directory to save responses",
310
+ )
311
+
312
+ args = parser.parse_args()
313
+ asyncio.run(benchmark(args))
314
+ analyze(TASK_TO_EVAL_SET[args.task], args.output_dir, args.model_size)
315
+ shutil.rmtree("tmp-output-dir", ignore_errors=True)
sglang/eval/loogle_eval.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import asyncio
3
+ import os
4
+ import pickle
5
+ from pathlib import Path
6
+ from typing import List
7
+
8
+ import openai
9
+ import torch
10
+ from bert_score import BERTScorer
11
+ from datasets import load_dataset
12
+ from tqdm import tqdm
13
+
14
+
15
def get_client(api_url: str) -> openai.AsyncOpenAI:
    """Create an async OpenAI client for ``api_url``.

    Installs a placeholder OPENAI_API_KEY when none is set, since local
    OpenAI-compatible servers do not validate the key.
    """
    if os.getenv("OPENAI_API_KEY") is None:
        os.environ["OPENAI_API_KEY"] = "EMPTY"
    return openai.AsyncOpenAI(base_url=api_url)
19
+
20
+
21
def get_dataset():
    """Load the LooGLE long-dependency QA test split from HF Hub."""
    return load_dataset("bigai-nlco/LooGLE", "longdep_qa", split="test")
23
+
24
+
25
async def fetch_response(
    client: openai.AsyncOpenAI,
    context: str,
    question: str,
    semaphore: asyncio.Semaphore,
    index: int,
    model: str,
    output_dir: Path,
):
    """Ask one LooGLE question and cache the raw response to disk.

    Skips the request when ``response_<index>.pkl`` already exists, making
    reruns resumable. A BadRequestError (e.g. context too long) is recorded
    as ``{"error": ...}`` instead of raising, so analyse() can skip it.
    """
    output_file = output_dir / f"response_{index}.pkl"
    if output_file.exists():
        return

    prompt = (
        "Please answer the question based on the long texts below.\n"
        f"{context}\n"
        f"Question: {question}\n"
        "Answer:"
    )
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]

    async with semaphore:
        try:
            response = await client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=0.0,
                max_tokens=512,
            )
        except openai.BadRequestError as e:
            # Record the failure so downstream scoring can skip this example.
            with open(output_file, "wb") as f:
                pickle.dump({"error": str(e)}, f)
            return

        with open(output_file, "wb") as f:
            pickle.dump(response, f)
+ pickle.dump(response, f)
64
+
65
+
66
async def benchmark(args):
    """Send up to ``args.num_prompts`` LooGLE questions concurrently,
    caching each response under ``args.output_dir``."""
    dataset = get_dataset()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    client = get_client(args.api_url)
    # Bounds the number of in-flight requests.
    semaphore = asyncio.Semaphore(args.max_concurrency)

    tasks: List[asyncio.Task] = []
    for idx, ex in enumerate(dataset):
        if idx >= args.num_prompts:
            break
        tasks.append(
            asyncio.create_task(
                fetch_response(
                    client,
                    ex["context"],
                    ex["question"],
                    semaphore,
                    idx,
                    args.model,
                    output_dir,
                )
            )
        )

    # Await tasks as they finish so the progress bar advances in real time.
    for _ in tqdm(
        asyncio.as_completed(tasks), total=len(tasks), desc="Running benchmark"
    ):
        await _
96
+
97
+
98
def analyse(args):
    """Score cached responses against reference answers with BERTScore F1.

    Loads each ``response_<idx>.pkl`` produced by ``benchmark`` (skipping
    entries recorded as errors), computes BERTScore in batches, and prints the
    average F1.

    Raises:
        FileNotFoundError: if any expected response file is missing.
    """
    dataset = get_dataset()
    output_dir = Path(args.output_dir)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    scorer = BERTScorer(lang="en", device=device)

    hyps: List[str] = []
    refs: List[str] = []
    for idx, ex in enumerate(tqdm(dataset, desc="Loading responses")):
        if idx >= args.num_prompts:
            break
        pkl_file = output_dir / f"response_{idx}.pkl"
        if not pkl_file.exists():
            raise FileNotFoundError(pkl_file)

        # Bug fix: use a context manager instead of leaking the file handle
        # via pickle.load(open(...)).
        with open(pkl_file, "rb") as f:
            response = pickle.load(f)
        if isinstance(response, dict) and "error" in response:
            continue

        hyps.append(response.choices[0].message.content.strip())
        refs.append(ex["answer"])

    if not hyps:
        print("No valid responses to score!")
        return

    batch_size = 64
    all_f1: List[float] = []
    for i in tqdm(range(0, len(hyps), batch_size), desc="Scoring batches"):
        h_batch = hyps[i : i + batch_size]
        r_batch = refs[i : i + batch_size]
        _, _, f1_scores = scorer.score(h_batch, r_batch, verbose=False)
        all_f1.extend([float(x) for x in f1_scores])

    avg = sum(all_f1) / len(all_f1)
    print(f"Average BERTScore (F1): {avg:.2%}")
135
+
136
+
137
+ if __name__ == "__main__":
138
+ parser = argparse.ArgumentParser(
139
+ description="Run benchmark and evaluation in one go."
140
+ )
141
+ parser.add_argument(
142
+ "--api-url",
143
+ default="http://127.0.0.1:30000/v1",
144
+ help="OpenAI‑compatible API base URL",
145
+ )
146
+ parser.add_argument(
147
+ "--model",
148
+ default="meta-llama/Llama-4-Maverick-17B-128E-Instruct",
149
+ help="Model name or ID, only used for model name",
150
+ )
151
+ parser.add_argument(
152
+ "--max-concurrency", type=int, default=144, help="Maximum concurrent requests"
153
+ )
154
+ parser.add_argument(
155
+ "--output-dir", default="tmp-output-dir", help="Directory for cached responses"
156
+ )
157
+ parser.add_argument(
158
+ "--num-prompts", type=int, default=10000, help="Number of prompts to run"
159
+ )
160
+ args = parser.parse_args()
161
+
162
+ asyncio.run(benchmark(args))
163
+
164
+ analyse(args)
sglang/global_config.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Global configurations"""
2
+
3
+ # FIXME: deprecate this file and move all usage to sglang.srt.environ or sglang.__init__.py
4
+
5
+
6
class GlobalConfig:
    """
    Store some global constants.
    """

    def __init__(self):
        # Verbosity: 0 = silent, 2 = print the final text after every run.
        self.verbosity = 0
        # Default language-frontend backend; populated elsewhere at runtime.
        self.default_backend = None
        # Detokenization behavior for generated output.
        self.skip_special_tokens_in_output = True
        self.spaces_between_special_tokens_in_out = True
        # Frontend interpreter optimizations.
        self.enable_precache_with_tracing = True
        self.enable_parallel_encoding = True


# Process-wide singleton instance.
global_config = GlobalConfig()
sglang/jit_kernel/.clang-format ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ BasedOnStyle: Google
2
+ IndentWidth: 2
3
+ ColumnLimit: 120
4
+ AllowShortFunctionsOnASingleLine: Empty
5
+ DerivePointerAlignment: false
6
+ PointerAlignment: Left
7
+ NamespaceIndentation: None
8
+ SortIncludes: true
9
+ AllowShortLoopsOnASingleLine: false
10
+ BinPackParameters: false # Prevents packing parameters in declarations
11
+ BinPackArguments: false # Prevents packing arguments in function calls
12
+ AlignAfterOpenBracket: AlwaysBreak # Forces a break after the opening parenthesis
13
+ AlignOperands: Align # Aligns operands of binary/ternary expressions vertically
14
+ PenaltyBreakBeforeFirstCallParameter: 1 # Encourages breaking before the first argument
15
+ PenaltyReturnTypeOnItsOwnLine: 100 # Keeps return type with function name
16
+
17
+ IncludeCategories:
18
+ - Regex: '^<sgl_kernel/.*>$'
19
+ Priority: 0
sglang/jit_kernel/__pycache__/hicache.cpython-311.pyc ADDED
Binary file (4.35 kB). View file
 
sglang/jit_kernel/__pycache__/utils.cpython-311.pyc ADDED
Binary file (6.87 kB). View file
 
sglang/jit_kernel/csrc/cuda_wait_value.cuh ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <sgl_kernel/tensor.h>
2
+ #include <sgl_kernel/utils.cuh>
3
+
4
+ #include <cuda_runtime_api.h>
5
+
6
+ #include <cstdint>
7
+
8
namespace {

// Spin on a device-visible flag word until it equals `target`.
// The volatile alias forces a fresh global-memory load on every poll.
__global__ void wait_flag_kernel(const int32_t* flag, int32_t target) {
  const volatile int32_t* vflag = (volatile const int32_t*)flag;

  while (*vflag != target) {
#if __CUDA_ARCH__ >= 700
    // Back off briefly between polls to reduce memory-system pressure.
    __nanosleep(100);
#else
    // Note: This falls back to an inefficient busy-wait on pre-Volta architectures.
#endif
  }
}

// Enqueue a single-thread kernel that blocks the stream until flag[0] == value.
// `flag` must be a non-empty 1-D int32 CUDA tensor; only element 0 is polled.
auto stream_wait_value(const tvm::ffi::TensorView flag, std::int32_t value) -> void {
  using namespace host;

  // Validate shape (1-D), dtype (int32) and device (CUDA); bind the length.
  auto length = SymbolicSize{"length"};
  TensorMatcher({length}).with_dtype<int32_t>().with_device<kDLCUDA>().verify(flag);
  RuntimeCheck(length.unwrap() >= 1, "wait_flag expects a non-empty tensor.");

  auto* ptr = static_cast<std::int32_t*>(flag.data_ptr());
  // NOTE(review): presumably resolves the active stream for flag's device — confirm.
  const auto stream = LaunchKernel::resolve_device(flag.device());

  // A single thread suffices: the kernel only polls one 32-bit word.
  constexpr int blocks = 1;
  constexpr int threads = 1;
  wait_flag_kernel<<<blocks, threads, 0, stream>>>(ptr, value);
  RuntimeDeviceCheck(cudaGetLastError());
}

} // namespace
sglang/jit_kernel/csrc/hicache.cuh ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <sgl_kernel/tensor.h>
2
+ #include <sgl_kernel/utils.cuh>
3
+ #include <sgl_kernel/utils.h>
4
+ #include <sgl_kernel/warp.cuh>
5
+
6
+ #include <dlpack/dlpack.h>
7
+
8
+ #include <algorithm>
9
+ #include <concepts>
10
+ #include <cstddef>
11
+ #include <cstdint>
12
+ #include <type_traits>
13
+
14
+ namespace {
15
+
16
// Launch-argument bundle shared by both transfer kernels (passed by value as
// a __grid_constant__). For the per-layer kernel the k/v pointers address the
// caches directly; for the all-layer kernel they address device arrays of
// per-layer cache pointers. Strides are in bytes.
struct HicacheKernelParams {
  void* __restrict__ k_cache_dst;
  void* __restrict__ v_cache_dst;
  const void* __restrict__ indices_dst;
  void* __restrict__ k_cache_src;
  void* __restrict__ v_cache_src;
  const void* __restrict__ indices_src;
  std::size_t length;               // number of rows (indices) to move
  std::size_t kv_cache_src_stride;  // bytes between consecutive source rows
  std::size_t kv_cache_dst_stride;  // bytes between consecutive destination rows
  std::size_t num_layers = 0;  // only used in all_layer transfer
};
28
+
29
// Gather/scatter one layer of K/V cache rows: for each i < length, copy row
// indices_src[i] of the source caches to row indices_dst[i] of the destination
// caches. Each logical warp (hardware warp shrunk by kUnroll) is one worker
// that moves a whole row cooperatively.
template <
    std::integral T,           // index element type (int32_t or int64_t)
    std::size_t kElementSize,  // bytes per cache row
    std::size_t kUnroll,       // divides the lanes that form one worker
    std::size_t kBlockQuota,   // upper bound on blocks launched
    std::size_t kNumThreads,   // threads per block
    std::size_t kMaxOccupancy>
__global__ __launch_bounds__(kNumThreads, kMaxOccupancy) void hicache_transfer_per_layer(
    const __grid_constant__ HicacheKernelParams params) {
  // each warp acts as a worker
  using namespace device;
  // NOTE(review): at this point the unqualified kWarpThreads still names
  // device::kWarpThreads (the shadowing alias below is not yet declared) —
  // confirm that is the intended operand of these asserts.
  static_assert(kNumThreads % kWarpThreads == 0);
  static_assert(kWarpThreads % kUnroll == 0);

  // Local alias shadows device::kWarpThreads: a "warp" here is the hardware
  // warp divided by the unroll factor, so each lane carries more bytes.
  constexpr auto kWarpThreads = device::kWarpThreads / kUnroll;
  constexpr auto kWarpsPerBlock = kNumThreads / kWarpThreads;
  constexpr auto kWorkers = kWarpsPerBlock * kBlockQuota;

  const auto& [
    k_cache_dst, v_cache_dst, indices_dst, // dst
    k_cache_src, v_cache_src, indices_src, // src
    length, kv_cache_src_stride, kv_cache_dst_stride, _ // metadata
  ] = params;
  const auto warp_id = blockIdx.x * kWarpsPerBlock + threadIdx.x / kWarpThreads;

  // force to transfer 128 bytes per iteration
  // since the PCIe transaction size is 128 bytes aligned
  constexpr auto kGranularity = 128 / kWarpThreads;

  // Grid-sized stride over the index list; each worker handles rows
  // warp_id, warp_id + kWorkers, warp_id + 2*kWorkers, ...
  for (auto i = warp_id; i < length; i += kWorkers) {
    const auto pos_src = static_cast<const T*>(indices_src)[i];
    const auto pos_dst = static_cast<const T*>(indices_dst)[i];
    const auto src_k = pointer::offset(k_cache_src, pos_src * kv_cache_src_stride);
    const auto dst_k = pointer::offset(k_cache_dst, pos_dst * kv_cache_dst_stride);
    const auto src_v = pointer::offset(v_cache_src, pos_src * kv_cache_src_stride);
    const auto dst_v = pointer::offset(v_cache_dst, pos_dst * kv_cache_dst_stride);
    // Load both rows fully before storing; K and V use identical layouts.
    const auto vec_k = warp::load_vec<kElementSize, kGranularity, kWarpThreads>(src_k);
    const auto vec_v = warp::load_vec<kElementSize, kGranularity, kWarpThreads>(src_v);
    warp::store_vec<kElementSize, kGranularity, kWarpThreads>(dst_k, vec_k);
    warp::store_vec<kElementSize, kGranularity, kWarpThreads>(dst_v, vec_v);
  }
}
71
+
72
// Same row gather/scatter as hicache_transfer_per_layer, but repeated across
// num_layers layers: the k/v params are device arrays of per-layer cache base
// pointers rather than the caches themselves.
template <
    std::integral T,           // index element type (int32_t or int64_t)
    std::size_t kElementSize,  // bytes per cache row
    std::size_t kUnroll,       // divides the lanes that form one worker
    std::size_t kBlockQuota,   // upper bound on blocks launched
    std::size_t kNumThreads,   // threads per block
    std::size_t kMaxOccupancy>
__global__ __launch_bounds__(kNumThreads, kMaxOccupancy) void hicache_transfer_all_layer(
    const __grid_constant__ HicacheKernelParams params) {
  // each warp acts as a worker
  using namespace device;
  // Element types of the per-layer pointer tables.
  using src_ptr_t = std::add_pointer_t<const void* const>;
  using dst_ptr_t = std::add_pointer_t<void* const>;

  // NOTE(review): this assert uses device::kWarpThreads (declared before the
  // shadowing alias below); unlike the per-layer kernel there is no
  // kWarpThreads % kUnroll check here — confirm that omission is intended.
  static_assert(kNumThreads % kWarpThreads == 0);
  constexpr auto kWarpThreads = device::kWarpThreads / kUnroll;
  constexpr auto kWarpsPerBlock = static_cast<uint32_t>(kNumThreads) / kWarpThreads;
  constexpr auto kWorkers = kWarpsPerBlock * kBlockQuota;

  const auto& [
    k_ptr_dst, v_ptr_dst, indices_dst, // dst
    k_ptr_src, v_ptr_src, indices_src, // src
    length, kv_cache_src_stride, kv_cache_dst_stride, num_layers // metadata
  ] = params;
  const auto warp_id = blockIdx.x * kWarpsPerBlock + threadIdx.x / kWarpThreads;

  // force to transfer 128 bytes per iteration
  // since the PCIe transaction size is 128 bytes aligned
  constexpr auto kGranularity = 128 / kWarpThreads;

  // Outer loop strides over indices; inner loop walks the layer tables and
  // copies the same (pos_src -> pos_dst) row in every layer.
  for (auto i = warp_id; i < length; i += kWorkers) {
    const auto pos_src = static_cast<const T*>(indices_src)[i];
    const auto pos_dst = static_cast<const T*>(indices_dst)[i];
    for (std::size_t layer = 0; layer < num_layers; ++layer) {
      const auto k_cache_src = static_cast<src_ptr_t>(k_ptr_src)[layer];
      const auto v_cache_src = static_cast<src_ptr_t>(v_ptr_src)[layer];
      const auto k_cache_dst = static_cast<dst_ptr_t>(k_ptr_dst)[layer];
      const auto v_cache_dst = static_cast<dst_ptr_t>(v_ptr_dst)[layer];
      const auto src_k = pointer::offset(k_cache_src, pos_src * kv_cache_src_stride);
      const auto dst_k = pointer::offset(k_cache_dst, pos_dst * kv_cache_dst_stride);
      const auto src_v = pointer::offset(v_cache_src, pos_src * kv_cache_src_stride);
      const auto dst_v = pointer::offset(v_cache_dst, pos_dst * kv_cache_dst_stride);
      const auto vec_k = warp::load_vec<kElementSize, kGranularity, kWarpThreads>(src_k);
      const auto vec_v = warp::load_vec<kElementSize, kGranularity, kWarpThreads>(src_v);
      warp::store_vec<kElementSize, kGranularity, kWarpThreads>(dst_k, vec_k);
      warp::store_vec<kElementSize, kGranularity, kWarpThreads>(dst_v, vec_v);
    }
  }
}
121
+
122
// Host-side launcher. The five non-type template parameters select a compiled
// kernel instantiation (they mirror the Python JIT cache key); the index
// element type (int32/int64) is chosen at run time from the indices dtype.
template <
    std::size_t kElementSize,
    std::size_t kUnroll,
    std::size_t kBlockQuota,
    std::size_t kNumThreads,
    std::size_t kMaxOccupancy>
struct HiCacheKernel {
  // Kernel handles, parameterized only by the index element type.
  template <typename T>
  static constexpr auto _kernel_one =
      hicache_transfer_per_layer<T, kElementSize, kUnroll, kBlockQuota, kNumThreads, kMaxOccupancy>;
  template <typename T>
  static constexpr auto _kernel_all =
      hicache_transfer_all_layer<T, kElementSize, kUnroll, kBlockQuota, kNumThreads, kMaxOccupancy>;

  // Validate tensors and launch the single-layer transfer.
  // Caches are 2-D (rows, D) with unit inner stride; K/V of each side must
  // share dtype and row stride. Indices are 1-D CUDA tensors of equal length
  // and a common dtype (int32 or int64).
  static void run_one(
      const tvm::ffi::TensorView k_cache_dst,
      const tvm::ffi::TensorView v_cache_dst,
      const tvm::ffi::TensorView indices_dst,
      const tvm::ffi::TensorView k_cache_src,
      const tvm::ffi::TensorView v_cache_src,
      const tvm::ffi::TensorView indices_src) {
    using namespace host;

    auto D = SymbolicSize{"D"}; // cache dimension
    auto N = SymbolicSize{"N"}; // src kv stride
    auto M = SymbolicSize{"M"}; // dst kv stride
    auto L = SymbolicSize{"L"}; // indices length
    auto cache_dtype = SymbolicDType{};
    auto indices_dtype = SymbolicDType{};
    auto indices_device = SymbolicDevice{};

    // -1 leaves the row count unconstrained; src and dst may differ in rows
    // but must agree on D and (pairwise) on their strides and dtype.
    TensorMatcher({-1, D}) //
        .with_strides({N, 1})
        .with_dtype(cache_dtype)
        .with_device<kDLCUDA, kDLCUDAHost, kDLCPU>()
        .verify(k_cache_src)
        .verify(v_cache_src);
    TensorMatcher({-1, D}) //
        .with_strides({M, 1})
        .with_dtype(cache_dtype)
        .with_device<kDLCUDA, kDLCUDAHost, kDLCPU>()
        .verify(k_cache_dst)
        .verify(v_cache_dst);
    TensorMatcher({L}) //
        .with_dtype<int32_t, int64_t>(indices_dtype)
        .with_device<kDLCUDA>(indices_device)
        .verify(indices_src)
        .verify(indices_dst);

    // verify dimension match
    // The kernel's compile-time row size must equal D * sizeof(dtype).
    const auto dtype_size = dtype_bytes(cache_dtype.unwrap());
    const auto element_bytes = D.unwrap() * dtype_size;
    RuntimeCheck(kElementSize == element_bytes, "HicacheKernel: cache dimension mismatch.");

    const auto k_cache_dst_ptr = k_cache_dst.data_ptr();
    const auto v_cache_dst_ptr = v_cache_dst.data_ptr();
    const auto k_cache_src_ptr = k_cache_src.data_ptr();
    const auto v_cache_src_ptr = v_cache_src.data_ptr();
    const auto indices_dst_ptr = indices_dst.data_ptr();
    const auto indices_src_ptr = indices_src.data_ptr();
    const auto length = static_cast<std::size_t>(L.unwrap());
    // Convert element strides to byte strides for the kernel.
    const auto kv_cache_src_stride = static_cast<std::size_t>(N.unwrap()) * dtype_size;
    const auto kv_cache_dst_stride = static_cast<std::size_t>(M.unwrap()) * dtype_size;
    const auto use_int32 = indices_dtype.unwrap().bits == 32;
    const auto device = indices_device.unwrap();

    // Enough blocks to give every index a worker, capped by kBlockQuota.
    constexpr auto kWorkersPerBlock = kNumThreads / (device::kWarpThreads / kUnroll);
    const auto num_blocks = std::min(div_ceil(length, kWorkersPerBlock), kBlockQuota);
    const auto params = HicacheKernelParams{
        .k_cache_dst = k_cache_dst_ptr,
        .v_cache_dst = v_cache_dst_ptr,
        .indices_dst = indices_dst_ptr,
        .k_cache_src = k_cache_src_ptr,
        .v_cache_src = v_cache_src_ptr,
        .indices_src = indices_src_ptr,
        .length = length,
        .kv_cache_src_stride = kv_cache_src_stride,
        .kv_cache_dst_stride = kv_cache_dst_stride,
    };
    const auto kernel = use_int32 ? _kernel_one<int32_t> : _kernel_one<int64_t>;
    LaunchKernel(num_blocks, kNumThreads, device)(kernel, params);
  }

  // Validate tensors and launch the all-layer transfer.
  // k/v arguments are 1-D uint64 CUDA tensors holding per-layer cache base
  // addresses; strides are given directly in bytes by the caller. All pointer
  // tables and indices must live on the same CUDA device (shared device_).
  static void run_all(
      const tvm::ffi::TensorView k_ptr_dst,
      const tvm::ffi::TensorView v_ptr_dst,
      const tvm::ffi::TensorView indices_dst,
      const tvm::ffi::TensorView k_ptr_src,
      const tvm::ffi::TensorView v_ptr_src,
      const tvm::ffi::TensorView indices_src,
      const std::size_t kv_src_stride,
      const std::size_t kv_dst_stride) {
    using namespace host;

    auto N = SymbolicSize{"N"}; // num layers
    auto L = SymbolicSize{"L"}; // indices length
    auto dtype_ = SymbolicDType{};
    auto device_ = SymbolicDevice{};

    TensorMatcher({N}) //
        .with_dtype<uint64_t>()
        .with_device<kDLCUDA>(device_)
        .verify(k_ptr_src)
        .verify(v_ptr_src)
        .verify(k_ptr_dst)
        .verify(v_ptr_dst);
    TensorMatcher({L}) //
        .with_dtype<int32_t, int64_t>(dtype_)
        .with_device<kDLCUDA>(device_)
        .verify(indices_src)
        .verify(indices_dst);

    // Unpack raw pointers and runtime dimensions for the launch.
    const auto k_cache_dst_ptr = k_ptr_dst.data_ptr();
    const auto v_cache_dst_ptr = v_ptr_dst.data_ptr();
    const auto k_cache_src_ptr = k_ptr_src.data_ptr();
    const auto v_cache_src_ptr = v_ptr_src.data_ptr();
    const auto indices_dst_ptr = indices_dst.data_ptr();
    const auto indices_src_ptr = indices_src.data_ptr();
    const auto length = static_cast<std::size_t>(L.unwrap());
    const auto use_int32 = dtype_.unwrap().bits == 32;
    const auto device = device_.unwrap();

    // Same block-count policy as run_one.
    constexpr auto kWorkersPerBlock = kNumThreads / (device::kWarpThreads / kUnroll);
    const auto num_blocks = std::min(div_ceil(length, kWorkersPerBlock), kBlockQuota);
    const auto params = HicacheKernelParams{
        .k_cache_dst = k_cache_dst_ptr,
        .v_cache_dst = v_cache_dst_ptr,
        .indices_dst = indices_dst_ptr,
        .k_cache_src = k_cache_src_ptr,
        .v_cache_src = v_cache_src_ptr,
        .indices_src = indices_src_ptr,
        .length = length,
        .kv_cache_src_stride = kv_src_stride,
        .kv_cache_dst_stride = kv_dst_stride,
        .num_layers = static_cast<std::size_t>(N.unwrap()),
    };
    const auto kernel = use_int32 ? _kernel_all<int32_t> : _kernel_all<int64_t>;
    LaunchKernel(num_blocks, kNumThreads, device)(kernel, params);
  }
};
263
+
264
+ } // namespace
sglang/jit_kernel/cuda_wait_value.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from functools import lru_cache
4
+ from typing import TYPE_CHECKING
5
+
6
+ import torch
7
+
8
+ from sglang.jit_kernel.utils import load_jit
9
+
10
+ if TYPE_CHECKING:
11
+ import torch
12
+ from tvm_ffi.module import Module
13
+
14
+
15
@lru_cache(maxsize=1)
def _jit_stream_wait_value_module() -> Module:
    """Compile (at most once) the JIT module exposing ``stream_wait_value``."""
    sources = ["cuda_wait_value.cuh"]
    wrappers = [("stream_wait_value", "stream_wait_value")]
    return load_jit("cuda_wait_value", cuda_files=sources, cuda_wrappers=wrappers)
22
+
23
+
24
def stream_wait_value(flag: torch.Tensor, value: int) -> None:
    """Block the current CUDA stream until ``flag`` holds ``value``."""
    _jit_stream_wait_value_module().stream_wait_value(flag, value)
27
+
28
+
29
class Event:
    """Minimal CUDA event substitute backed by a one-element int32 flag tensor."""

    def __init__(self) -> None:
        # Flag starts at 0, i.e. "not recorded".
        self.flag = torch.zeros(1, dtype=torch.int32, device="cuda")

    def record(self, value: int = 1) -> None:
        """Write ``value`` into the flag on the current stream."""
        self.flag.fill_(value)

    def wait(self, value: int = 1) -> None:
        """Enqueue a spin-wait on the current stream until the flag equals ``value``."""
        stream_wait_value(self.flag, value)
38
+
39
+
40
def test_wait_before_record(event: Event | torch.cuda.Event):
    """Wait on ``event`` from one stream before it is ever recorded, then record.

    Used by ``main`` to contrast the two event types: the custom Event's wait
    spins on the device, while torch.cuda.Event treats an un-recorded wait as
    a no-op.
    """
    waiter = torch.cuda.Stream()
    recorder = torch.cuda.Stream()

    with torch.cuda.stream(waiter):
        event.wait()

    waiter.synchronize()

    with torch.cuda.stream(recorder):
        event.record()
51
+
52
+
53
def main():
    """Smoke-test that the custom Event blocks an un-recorded wait.

    Runs ``test_wait_before_record`` in two threads: one with the custom
    ``Event`` (expected to stay blocked in ``synchronize``) and one with
    ``torch.cuda.Event`` (expected to finish). Asserts both outcomes after
    polling for five seconds.
    """
    import threading
    import time

    # Fix: identifier was previously misspelled as ``block_thead``; the typo
    # also leaked into the printed f"{...=}" debug line below.
    # Daemon thread: it is expected to block forever and must not prevent exit.
    block_thread = threading.Thread(
        target=test_wait_before_record, args=(Event(),), daemon=True
    )
    block_thread.start()

    non_block_thread = threading.Thread(
        target=test_wait_before_record, args=(torch.cuda.Event(),)
    )
    non_block_thread.start()

    print("Checking if custom Event blocks the stream...", flush=True)
    for _ in range(5):
        print(
            f"{block_thread.is_alive()=}, {non_block_thread.is_alive()=}", flush=True
        )
        time.sleep(1)

    assert block_thread.is_alive(), "Custom Event did not block as expected"
    assert not non_block_thread.is_alive(), "torch.cuda.Event should not block"
    print("=" * 40)
    print("Test completed successfully.")
76
+
77
+
78
+ if __name__ == "__main__":
79
+ main()
sglang/jit_kernel/hicache.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from functools import lru_cache
5
+ from typing import TYPE_CHECKING
6
+
7
+ from sglang.jit_kernel.utils import load_jit, make_cpp_args
8
+
9
+ if TYPE_CHECKING:
10
+ import torch
11
+ from tvm_ffi.module import Module
12
+
13
+ DEFAULT_BLOCK_QUOTA = 2
14
+
15
+
16
@lru_cache(maxsize=None)
def _jit_hicache_module(*, element_size: int, unroll: int, block_quota: int) -> Module:
    """Compile (and memoize) the HiCache JIT module for one template config.

    The keyword arguments become C++ non-type template parameters of
    ``HiCacheKernel``; each distinct combination yields a separately compiled
    and cached module.
    """
    # Fixed launch shape: 1024 threads per block, min-occupancy hint of 1.
    num_threads, occupancy = 1024, 1
    # NOTE(review): `args` is interpolated both into load_jit(*args) and into
    # the template-argument list f"HiCacheKernel<{args}>" — presumably
    # make_cpp_args returns an object whose str() is a comma-separated C++
    # argument list; confirm against sglang.jit_kernel.utils.make_cpp_args.
    args = make_cpp_args(
        element_size,
        unroll,
        block_quota,
        num_threads,
        occupancy,
    )
    return load_jit(
        "hicache",
        *args,
        cuda_files=["hicache.cuh"],
        cuda_wrappers=[
            ("launch_one", f"HiCacheKernel<{args}>::run_one"),
            ("launch_all", f"HiCacheKernel<{args}>::run_all"),
        ],
    )
35
+
36
+
37
def can_use_hicache_jit_kernel(
    *,
    element_size: int,
    unroll: int | None = None,  # can be tuned for performance
    block_quota: int | None = None,  # can be tuned for less interference
) -> bool:
    """Probe whether the HiCache JIT kernel compiles for this configuration.

    Returns True on success; logs a warning and returns False on any failure
    (best-effort capability check, never raises).
    """
    try:
        _jit_hicache_module(
            element_size=element_size,
            unroll=unroll or _default_unroll(element_size),
            block_quota=block_quota or DEFAULT_BLOCK_QUOTA,
        )
    except Exception as e:
        logging.getLogger(__name__).warning(
            f"Failed to load JIT HiCache kernel: {e}"
        )
        return False
    return True
56
+
57
+
58
+ def _default_unroll(element_size: int) -> int:
59
+ if element_size <= 512:
60
+ return 4
61
+
62
+ if element_size <= 1024:
63
+ return 2
64
+
65
+ # fallback: no unroll
66
+ return 1
67
+
68
+
69
def transfer_hicache_one_layer(
    k_cache_dst: torch.Tensor,
    v_cache_dst: torch.Tensor,
    indices_dst: torch.Tensor,
    k_cache_src: torch.Tensor,
    v_cache_src: torch.Tensor,
    indices_src: torch.Tensor,
    *,
    element_dim: int | None = None,
    unroll: int | None = None,  # can be tuned for performance
    block_quota: int | None = None,  # can be tuned for less interference
) -> None:
    """Copy K/V cache rows selected by ``indices_src`` into the rows selected
    by ``indices_dst`` for a single layer, via the JIT HiCache kernel.
    """
    element_dim = element_dim or k_cache_dst.size(-1)
    # Flatten every cache to (rows, element_dim) before handing it to the kernel.
    k_src, v_src, k_dst, v_dst = (
        t.view(-1, element_dim)
        for t in (k_cache_src, v_cache_src, k_cache_dst, v_cache_dst)
    )
    element_size = element_dim * k_dst.element_size()
    module = _jit_hicache_module(
        element_size=element_size,
        unroll=unroll or _default_unroll(element_size),
        block_quota=block_quota or DEFAULT_BLOCK_QUOTA,
    )
    module.launch_one(k_dst, v_dst, indices_dst, k_src, v_src, indices_src)
102
+
103
+
104
def transfer_hicache_all_layer(
    k_ptr_dst: torch.Tensor,
    v_ptr_dst: torch.Tensor,
    indices_dst: torch.Tensor,
    k_ptr_src: torch.Tensor,
    v_ptr_src: torch.Tensor,
    indices_src: torch.Tensor,
    *,
    kv_cache_src_stride_bytes: int,
    kv_cache_dst_stride_bytes: int,
    element_size: int | None = None,
    unroll: int | None = None,  # can be tuned for performance
    block_quota: int | None = None,  # can be tuned for less interference
) -> None:
    """Copy the indexed K/V rows across *all* layers at once.

    The ``*_ptr_*`` tensors hold per-layer cache base pointers; strides are
    given in bytes. When ``element_size`` is omitted, both caches are assumed
    contiguous and the (equal) stride is used as the row size.
    """
    if element_size is None:  # assume both contiguous
        assert kv_cache_dst_stride_bytes == kv_cache_src_stride_bytes
        element_size = kv_cache_dst_stride_bytes

    module = _jit_hicache_module(
        element_size=element_size,
        unroll=unroll or _default_unroll(element_size),
        block_quota=block_quota or DEFAULT_BLOCK_QUOTA,
    )
    module.launch_all(
        k_ptr_dst,
        v_ptr_dst,
        indices_dst,
        k_ptr_src,
        v_ptr_src,
        indices_src,
        kv_cache_src_stride_bytes,
        kv_cache_dst_stride_bytes,
    )
sglang/jit_kernel/include/sgl_kernel/tensor.h ADDED
@@ -0,0 +1,487 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <sgl_kernel/utils.h>
3
+
4
+ #include <dlpack/dlpack.h>
5
+ #include <tvm/ffi/container/tensor.h>
6
+ #include <tvm/ffi/dtype.h>
7
+
8
+ #include <algorithm>
9
+ #include <array>
10
+ #include <concepts>
11
+ #include <cstddef>
12
+ #include <cstdint>
13
+ #include <initializer_list>
14
+ #include <optional>
15
+ #include <ranges>
16
+ #include <source_location>
17
+ #include <span>
18
+ #include <sstream>
19
+ #include <string>
20
+ #include <string_view>
21
+ #include <type_traits>
22
+ #include <utility>
23
+
24
+ namespace host {
25
+
26
+ namespace stdr = std::ranges;
27
+ namespace stdv = std::views;
28
+
29
namespace details {

// Forward declarations for the *Ref wrappers defined later in this header.
struct SizeRef;
struct DTypeRef;
struct DeviceRef;

// Maps a C++ scalar type to its DLPack DLDataType descriptor.
template <typename T>
struct dtype_trait {};

template <std::integral T>
struct dtype_trait<T> {
  inline static constexpr auto value = DLDataType{
      .code = std::is_signed_v<T> ? DLDataTypeCode::kDLInt : DLDataTypeCode::kDLUInt,
      .bits = static_cast<std::uint8_t>(sizeof(T) * 8),
      .lanes = 1};
};

template <std::floating_point T>
struct dtype_trait<T> {
  inline static constexpr auto value =
      DLDataType{.code = DLDataTypeCode::kDLFloat, .bits = static_cast<std::uint8_t>(sizeof(T) * 8), .lanes = 1};
};

// Sentinels for "unset"/"wildcard" states of the Symbolic* classes.
// NOTE(review): kNullDType is 18u, presumably a value outside the defined
// DLDataTypeCode range — confirm against the dlpack version in use.
inline constexpr auto kAnyDeviceID = -1;
inline constexpr auto kAnySize = static_cast<int64_t>(-1);
inline constexpr auto kNullSize = static_cast<int64_t>(-1);
inline constexpr auto kNullDType = static_cast<DLDataTypeCode>(18u);
inline constexpr auto kNullDevice = static_cast<DLDeviceType>(-1);

// Static option arrays built from type/device-code template packs.
template <typename... Ts>
inline constexpr auto kDTypeList = std::array<DLDataType, sizeof...(Ts)>{dtype_trait<Ts>::value...};

template <DLDeviceType... Codes>
inline constexpr auto kDeviceList = std::array<DLDevice, sizeof...(Codes)>{
    DLDevice{.device_type = static_cast<DLDeviceType>(Codes), .device_id = kAnyDeviceID}...};

// Adapter that gives a span a streaming operator (used in error messages).
template <typename T>
struct PrintAbleSpan {
  explicit PrintAbleSpan(std::span<const T> data) : data(data) {}
  std::span<const T> data;
};

// DLDeviceType -> human-readable name, as a dense table indexed by the enum
// value (empty string_view marks an unknown/unmapped code).
inline constexpr auto kDeviceStringMap = [] {
  constexpr auto map = std::array<std::pair<DLDeviceType, const char*>, 16>{
      std::pair{DLDeviceType::kDLCPU, "cpu"},
      std::pair{DLDeviceType::kDLCUDA, "cuda"},
      std::pair{DLDeviceType::kDLCUDAHost, "cuda_host"},
      std::pair{DLDeviceType::kDLOpenCL, "opencl"},
      std::pair{DLDeviceType::kDLVulkan, "vulkan"},
      std::pair{DLDeviceType::kDLMetal, "metal"},
      std::pair{DLDeviceType::kDLVPI, "vpi"},
      std::pair{DLDeviceType::kDLROCM, "rocm"},
      std::pair{DLDeviceType::kDLROCMHost, "rocm_host"},
      std::pair{DLDeviceType::kDLExtDev, "ext_dev"},
      std::pair{DLDeviceType::kDLCUDAManaged, "cuda_managed"},
      std::pair{DLDeviceType::kDLOneAPI, "oneapi"},
      std::pair{DLDeviceType::kDLWebGPU, "webgpu"},
      std::pair{DLDeviceType::kDLHexagon, "hexagon"},
      std::pair{DLDeviceType::kDLMAIA, "maia"},
      std::pair{DLDeviceType::kDLTrn, "trn"},
  };
  constexpr auto max_type = stdr::max(map | stdv::keys);
  auto result = std::array<std::string_view, max_type + 1>{};
  for (const auto& [code, name] : map) {
    result[static_cast<std::size_t>(code)] = name;
  }
  return result;
}();

// Wrapper so a DLDevice can be streamed unambiguously from other namespaces.
struct PrintableDevice {
  DLDevice device;
};

// Prints e.g. "cuda[0]"; the device id is omitted when it is the wildcard.
// Fails a RuntimeCheck on device types missing from kDeviceStringMap.
inline auto& operator<<(std::ostream& os, DLDevice device) {
  const auto& mapping = kDeviceStringMap;
  const auto entry = static_cast<std::size_t>(device.device_type);
  host::RuntimeCheck(entry < mapping.size());
  const auto name = mapping[entry];
  host::RuntimeCheck(!name.empty(), "Unknown device: ", int(device.device_type));
  os << name;
  if (device.device_id != kAnyDeviceID) os << "[" << device.device_id << "]";
  return os;
}

inline auto& operator<<(std::ostream& os, PrintableDevice pd) {
  return os << pd.device;
}

// Prints a span as "[a, b, c]".
template <typename T>
inline auto& operator<<(std::ostream& os, PrintAbleSpan<T> span) {
  os << "[";
  for (const auto i : stdv::iota(std::size_t{0}, span.data.size())) {
    if (i > 0) {
      os << ", ";
    }
    os << span.data[i];
  }
  os << "]";
  return os;
}

} // namespace details
132
+
133
// A write-once symbolic dimension. Starts unbound; the first verify() binds
// it to the observed extent, later verify() calls must match it. Non-copyable
// so that all users share one binding.
struct SymbolicSize {
 public:
  // `annotation` is only used for diagnostics (e.g. "D", "length").
  SymbolicSize(std::string_view annotation = {}) : m_value(details::kNullSize), m_annotation(annotation) {}

  auto get_name() const -> std::string_view {
    return m_annotation;
  }
  // Binds the size; fails if already bound.
  auto set_value(int64_t value) -> void {
    host::RuntimeCheck(!this->has_value(), "Size value already set");
    m_value = value;
  }
  auto has_value() const -> bool {
    return m_value != details::kNullSize;
  }
  auto get_value() const -> std::optional<int64_t> {
    return this->has_value() ? std::optional{m_value} : std::nullopt;
  }
  // Returns the bound size; fails if still unbound.
  auto unwrap() const -> int64_t {
    host::RuntimeCheck(this->has_value(), "Size value is not set");
    return m_value;
  }

  SymbolicSize(const SymbolicSize&) = delete;
  SymbolicSize& operator=(const SymbolicSize&) = delete;

  // First call binds `dim`; subsequent calls require equality.
  auto verify(int64_t dim) -> void {
    if (this->has_value()) {
      host::RuntimeCheck(m_value == dim, "Size mismatch: expected ", m_value, " but got ", dim);
    } else {
      this->set_value(dim);
    }
  }

 private:
  std::int64_t m_value;         // kNullSize (-1) while unbound
  std::string_view m_annotation;  // diagnostic label, not owned
};
170
+
171
// Exact equality on both device type and device id (no wildcard handling;
// wildcard-aware matching lives in SymbolicDevice::m_check).
inline auto operator==(DLDevice lhs, DLDevice rhs) -> bool {
  return lhs.device_type == rhs.device_type && lhs.device_id == rhs.device_id;
}
174
+
175
// A write-once symbolic dtype, optionally restricted to a set of allowed
// DLDataType options. First verify() binds; later ones must match exactly.
struct SymbolicDType {
 public:
  SymbolicDType() : m_value({details::kNullDType, 0, 0}) {}

  // Binds the dtype; fails if already bound or outside the allowed options.
  auto set_value(DLDataType value) -> void {
    host::RuntimeCheck(!this->has_value(), "Dtype value already set");
    host::RuntimeCheck(
        m_check(value), "Dtype value [", value, "] not in the allowed options: ", details::PrintAbleSpan{m_options});
    m_value = value;
  }
  auto has_value() const -> bool {
    return m_value.code != details::kNullDType;
  }
  auto get_value() const -> std::optional<DLDataType> {
    return this->has_value() ? std::optional{m_value} : std::nullopt;
  }
  // Returns the bound dtype; fails if still unbound.
  auto unwrap() const -> DLDataType {
    host::RuntimeCheck(this->has_value(), "Dtype value is not set");
    return m_value;
  }

  // Restricts acceptable dtypes. The span is stored by reference — the
  // caller's storage must outlive this object.
  auto set_options(std::span<const DLDataType> options) -> void {
    m_options = options;
  }
  // Convenience overload building the option list from C++ types.
  template <typename... Ts>
  auto set_options() -> void {
    m_options = details::kDTypeList<Ts...>;
  }

  // First call binds `dtype`; subsequent calls require equality.
  auto verify(DLDataType dtype) -> void {
    if (this->has_value()) {
      host::RuntimeCheck(m_value == dtype, "DType mismatch: expected ", m_value, " but got ", dtype);
    } else {
      this->set_value(dtype);
    }
  }

 private:
  // Empty option list means "anything goes".
  auto m_check(DLDataType value) const -> bool {
    return stdr::empty(m_options) || (stdr::find(m_options, value) != stdr::end(m_options));
  }

  std::span<const DLDataType> m_options;  // allowed dtypes (non-owning)
  DLDataType m_value;                     // code == kNullDType while unbound
};
220
+
221
// A write-once symbolic device, optionally restricted to allowed DLDevice
// options. Options may wildcard the device id via kAnyDeviceID; the device
// type must always match exactly.
struct SymbolicDevice {
 public:
  SymbolicDevice() : m_value({details::kNullDevice, details::kAnyDeviceID}) {}

  // Binds the device; fails if already bound or not among the options.
  auto set_value(DLDevice value) -> void {
    host::RuntimeCheck(!this->has_value(), "Device value already set");
    host::RuntimeCheck(
        m_check(value),
        "Device value [",
        details::PrintableDevice{value},
        "] not in the allowed options: ",
        details::PrintAbleSpan{m_options});
    m_value = value;
  }
  auto has_value() const -> bool {
    return m_value.device_type != details::kNullDevice;
  }
  auto get_value() const -> std::optional<DLDevice> {
    return this->has_value() ? std::optional{m_value} : std::nullopt;
  }
  // Returns the bound device; fails if still unbound.
  auto unwrap() const -> DLDevice {
    host::RuntimeCheck(this->has_value(), "Device value is not set");
    return m_value;
  }

  // Restricts acceptable devices. The span is stored by reference — the
  // caller's storage must outlive this object.
  auto set_options(std::span<const DLDevice> options) -> void {
    m_options = options;
  }
  // Convenience overload: listed device types with wildcard device ids.
  template <DLDeviceType... Codes>
  auto set_options() -> void {
    m_options = details::kDeviceList<Codes...>;
  }

  // First call binds `device`; subsequent calls require exact equality
  // (type and id) — i.e. all verified tensors share one device.
  auto verify(DLDevice device) -> void {
    if (this->has_value()) {
      host::RuntimeCheck(
          m_value == device,
          "Device mismatch: expected ",
          details::PrintableDevice{m_value},
          " but got ",
          details::PrintableDevice{device});
    } else {
      this->set_value(device);
    }
  }

 private:
  auto m_check(DLDevice value) const -> bool {
    return stdr::empty(m_options) || (stdr::any_of(m_options, [value](const DLDevice& opt) {
             // device type must exactly match
             if (opt.device_type != value.device_type) return false;
             // device id can be wildcarded
             return opt.device_id == details::kAnyDeviceID || opt.device_id == value.device_id;
           }));
  }

  std::span<const DLDevice> m_options;  // allowed devices (non-owning)
  DLDevice m_value;                     // device_type == kNullDevice while unbound
};
280
+
281
+ namespace details {
282
+
283
+ template <typename T>
284
+ struct BaseRef {
285
+ public:
286
+ BaseRef(const BaseRef&) = delete;
287
+ BaseRef& operator=(const BaseRef&) = delete;
288
+
289
+ auto operator->() const -> T* {
290
+ return m_ref;
291
+ }
292
+ auto operator*() const -> T& {
293
+ return *m_ref;
294
+ }
295
+ auto rebind(T& other) -> void {
296
+ m_ref = &other;
297
+ }
298
+
299
+ explicit BaseRef() : m_ref(&m_cache), m_cache() {}
300
+ BaseRef(T& size) : m_ref(&size), m_cache() {}
301
+
302
+ private:
303
+ T* m_ref;
304
+ T m_cache;
305
+ };
306
+
307
// Reference to a SymbolicSize, also constructible from an int64_t literal.
// kAnySize leaves the dimension unconstrained.
struct SizeRef : BaseRef<SymbolicSize> {
  using BaseRef::BaseRef;
  SizeRef(int64_t value) {
    if (value != kAnySize) {
      (**this).set_value(value);
    } else {
      // otherwise, we can match any size
    }
  }

  // Human-readable label for dimension `dim`: the pinned value if set, else the
  // symbol's annotation name, else a positional "dim#<i>" placeholder.
  auto value_or_name(std::size_t dim) const -> std::string {
    if (const auto value = (**this).get_value()) {
      return std::to_string(*value);
    } else {
      const auto annotation = (**this).get_name();
      if (annotation.empty()) {
        return "dim#" + std::to_string(dim);
      } else {
        return static_cast<std::string>(annotation);
      }
    }
  }
};
330
+
331
// Reference to a SymbolicDType; constructible from a single DLDataType (pins
// the value) or from a list/span of acceptable options.
struct DTypeRef : BaseRef<SymbolicDType> {
  using BaseRef::BaseRef;
  DTypeRef(DLDataType options) {
    (**this).set_value(options);
  }
  DTypeRef(std::initializer_list<DLDataType> options) {
    (**this).set_options(options);
  }
  DTypeRef(std::span<const DLDataType> options) {
    (**this).set_options(options);
  }
};
343
+
344
// Reference to a SymbolicDevice; constructible from a single DLDevice (pins the
// value) or from a list/span of acceptable device options.
struct DeviceRef : BaseRef<SymbolicDevice> {
  using BaseRef::BaseRef;
  DeviceRef(DLDevice options) {
    (**this).set_value(options);
  }
  DeviceRef(std::initializer_list<DLDevice> options) {
    (**this).set_options(options);
  }
  DeviceRef(std::span<const DLDevice> options) {
    (**this).set_options(options);
  }
};
356
+
357
+ } // namespace details
358
+
359
// Fluent validator for tvm::ffi::TensorView, intended to be built and consumed
// within a single full expression, e.g.:
//   TensorMatcher({rows, kAnySize}).with_dtype<...>().with_device<...>().verify(t);
// NOTE(review): m_shape/m_strides are std::spans over the caller's initializer
// lists, and with_dtype/with_device may rebind to temporaries — this is only
// safe while those temporaries live, i.e. until the end of the full expression.
// Do not store a TensorMatcher; confirm no caller does.
struct TensorMatcher {
private:
  using SizeRef = details::SizeRef;
  using DTypeRef = details::DTypeRef;
  using DeviceRef = details::DeviceRef;
  using Loc_t = std::source_location;

public:
  TensorMatcher(const TensorMatcher&) = delete;
  TensorMatcher& operator=(const TensorMatcher&) = delete;

  // Shape spec: each entry is a literal, kAnySize, or a named SymbolicSize ref.
  explicit TensorMatcher(std::initializer_list<SizeRef> shape) : m_shape(shape), m_strides(), m_dtype() {}

  // Optional stride spec; without it, verify() requires a contiguous tensor.
  auto with_strides(std::initializer_list<SizeRef> strides) && -> TensorMatcher&& {
    // no partial update allowed
    host::RuntimeCheck(m_strides.size() == 0, "Strides already specified");
    host::RuntimeCheck(m_shape.size() == strides.size(), "Strides size must match shape size");
    m_strides = strides;
    return std::move(*this);
  }

  // Bind the dtype slot to a caller-provided (possibly temporary) DTypeRef.
  // NOTE(review): the unused `typename... Ts` pack looks like copy-paste from
  // the nullary overload below — confirm and remove.
  template <typename... Ts>
  auto with_dtype(DTypeRef&& dtype) && -> TensorMatcher&& {
    m_init_dtype();
    m_dtype.rebind(*dtype);
    return std::move(*this);
  }

  // Restrict dtype to a compile-time list of options.
  template <typename... Ts>
  auto with_dtype() && -> TensorMatcher&& {
    static_assert(sizeof...(Ts) > 0, "At least one dtype option must be specified");
    m_init_dtype();
    m_dtype->set_options<Ts...>();
    return std::move(*this);
  }

  // Bind the device slot to a caller-provided (possibly temporary) DeviceRef.
  // NOTE(review): the `DLDeviceType... Codes` pack is unused here as well.
  template <DLDeviceType... Codes>
  auto with_device(DeviceRef&& device) && -> TensorMatcher&& {
    m_init_device();
    m_device.rebind(*device);
    return std::move(*this);
  }

  // Restrict device to a compile-time list of device-type codes.
  template <DLDeviceType... Codes>
  auto with_device() && -> TensorMatcher&& {
    static_assert(sizeof...(Codes) > 0, "At least one device option must be specified");
    m_init_device();
    m_device->set_options<Codes...>();
    return std::move(*this);
  }

  // once we start verification, we cannot modify anymore
  // On mismatch, rethrows PanicError with the matcher description and the
  // caller's source location prepended to the root cause.
  auto verify(tvm::ffi::TensorView view, Loc_t loc = Loc_t::current()) const&& -> const TensorMatcher&& {
    try {
      this->m_verify_impl(view);
    } catch (PanicError& e) {
      auto oss = std::ostringstream{};
      oss << "Tensor match failed for " << this->debug_str() << " at " << loc.file_name() << ":" << loc.line()
          << "\n- Root cause: " << e.detail();
      throw PanicError(std::move(oss).str());
    }
    return std::move(*this);
  }

  // Render the expected shape (and strides, if specified) as e.g.
  // "Tensor<128, dim#1> [strides=<...>]".
  auto debug_str() const -> std::string {
    auto oss = std::ostringstream{};
    oss << "Tensor<";
    std::size_t dim = 0;
    for (const auto& size_ref : m_shape) {
      if (dim > 0) {
        oss << ", ";
      }
      oss << size_ref.value_or_name(dim++);
    }
    oss << ">";
    if (m_strides.size() > 0) {
      oss << " [strides=<";
      dim = 0;
      for (const auto& stride_ref : m_strides) {
        if (dim > 0) {
          oss << ", ";
        }
        oss << stride_ref.value_or_name(dim++);
      }
      oss << ">]";
    }
    return std::move(oss).str();
  }

private:
  // Checks rank, per-dim sizes, strides (or contiguity), dtype and device.
  // Verifying also *pins* previously-unset symbolic values via verify().
  auto m_verify_impl(tvm::ffi::TensorView view) const -> void {
    const auto dim = static_cast<std::size_t>(view.dim());
    host::RuntimeCheck(dim == m_shape.size(), "Tensor dimension mismatch: expected ", m_shape.size(), " but got ", dim);
    for (const auto i : stdv::iota(std::size_t{0}, dim)) {
      m_shape[i]->verify(view.size(i));
    }
    if (this->m_has_strides()) {
      for (const auto i : stdv::iota(std::size_t{0}, dim)) {
        m_strides[i]->verify(view.stride(i));
      }
    } else {
      host::RuntimeCheck(view.is_contiguous(), "Tensor is not contiguous as expected");
    }
    // since we may use the same matcher to verify again, we will force to check
    m_dtype->verify(view.dtype());
    m_device->verify(view.device());
  }

  // One-shot guards: dtype/device may each be specified at most once.
  auto m_init_dtype() -> void {
    host::RuntimeCheck(!m_has_dtype, "DType already specified");
    m_has_dtype = true;
  }
  auto m_init_device() -> void {
    host::RuntimeCheck(!m_has_device, "Device already specified");
    m_has_device = true;
  }
  auto m_has_strides() const -> bool {
    return !m_strides.empty();
  }

  std::span<const SizeRef> m_shape;    // view over caller's initializer list
  std::span<const SizeRef> m_strides;  // empty => require contiguous
  DTypeRef m_dtype;
  DeviceRef m_device;
  bool m_has_dtype = false;
  bool m_has_device = false;
};
486
+
487
+ } // namespace host
sglang/jit_kernel/include/sgl_kernel/utils.cuh ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <sgl_kernel/utils.h>
4
+
5
+ #include <dlpack/dlpack.h>
6
+ #include <tvm/ffi/extra/c_env_api.h>
7
+
8
+ #include <concepts>
9
+ #include <cstddef>
10
+ #include <source_location>
11
+ #include <type_traits>
12
+
13
+ namespace device {
14
+
15
+ inline constexpr auto kWarpThreads = 32u;
16
+
17
+ namespace pointer {
18
+
19
+ // we only allow void * pointer arithmetic for safety
20
+
21
// Byte-wise pointer offset; T is constrained to void so typed arithmetic
// cannot slip in. All offsets are folded into a single byte displacement.
template <typename T, std::integral... U>
__always_inline __device__ auto offset(T* ptr, U... offset) -> void* {
  static_assert(std::is_same_v<T, void>, "Pointer arithmetic is only allowed for void* pointers");
  return static_cast<char*>(ptr) + (... + offset);
}

// const overload of the byte-wise offset above.
template <typename T, std::integral... U>
__always_inline __device__ auto offset(const T* ptr, U... offset) -> const void* {
  static_assert(std::is_same_v<T, void>, "Pointer arithmetic is only allowed for void* pointers");
  return static_cast<const char*>(ptr) + (... + offset);
}
32
+
33
+ } // namespace pointer
34
+
35
+ } // namespace device
36
+
37
+ namespace host {
38
+
39
// Throw host::PanicError (tagged with the call site) if a CUDA API call failed.
inline auto
RuntimeDeviceCheck(::cudaError_t error, std::source_location location = std::source_location::current()) -> void {
  if (error != ::cudaSuccess) {
    [[unlikely]];
    ::host::panic(location, "CUDA error: ", ::cudaGetErrorString(error));
  }
}

// Check the sticky last-error state (e.g. right after a kernel launch).
inline auto RuntimeCudaCheck(std::source_location location = std::source_location::current()) -> void {
  return RuntimeDeviceCheck(::cudaGetLastError(), location);
}
50
+
51
// Raise kernel F's dynamic shared-memory attribute exactly once per template
// instantiation. The static local captures the *first* smem_size passed in;
// subsequent calls may only request <= that size (enforced below), since the
// attribute is never re-set.
template <auto F>
inline void set_smem_once(std::size_t smem_size) {
  static const auto last_smem_size = [&] {
    RuntimeDeviceCheck(::cudaFuncSetAttribute(F, ::cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    return smem_size;
  }();
  RuntimeCheck(
      smem_size <= last_smem_size,
      "Dynamic shared memory size exceeds the previously set maximum size: ",
      last_smem_size,
      " bytes");
}
63
+
64
// Thin, non-copyable wrapper around cudaLaunchKernelEx: capture the launch
// configuration once, then launch via operator()(kernel, args...).
struct LaunchKernel {
public:
  // Resolve the stream from the DLPack device via the TVM-FFI environment.
  explicit LaunchKernel(
      dim3 grid_dim, dim3 block_dim, DLDevice device, std::size_t dynamic_shared_mem_bytes = 0) noexcept
      : m_config(s_make_config(grid_dim, block_dim, resolve_device(device), dynamic_shared_mem_bytes)) {}

  // Use an explicit CUDA stream.
  explicit LaunchKernel(
      dim3 grid_dim, dim3 block_dim, cudaStream_t stream, std::size_t dynamic_shared_mem_bytes = 0) noexcept
      : m_config(s_make_config(grid_dim, block_dim, stream, dynamic_shared_mem_bytes)) {}

  // Map a DLDevice to the current stream registered with the TVM-FFI env.
  static auto resolve_device(DLDevice device) -> cudaStream_t {
    return static_cast<cudaStream_t>(::TVMFFIEnvGetStream(device.device_type, device.device_id));
  }

  LaunchKernel(const LaunchKernel&) = delete;
  LaunchKernel& operator=(const LaunchKernel&) = delete;

  // Launch `kernel` with the stored config; throws PanicError on launch failure.
  template <typename T, typename... Args>
  auto operator()(T&& kernel, Args&&... args) const -> void {
    host::RuntimeDeviceCheck(::cudaLaunchKernelEx(&m_config, kernel, std::forward<Args>(args)...));
  }

private:
  static auto
  s_make_config(dim3 grid_dim, dim3 block_dim, cudaStream_t stream, std::size_t smem) -> cudaLaunchConfig_t {
    auto config = ::cudaLaunchConfig_t{};
    config.gridDim = grid_dim;
    config.blockDim = block_dim;
    config.dynamicSmemBytes = smem;
    config.stream = stream;
    config.numAttrs = 0;  // no extended launch attributes
    return config;
  }
  cudaLaunchConfig_t m_config;
  /// TODO: We can add a queue to store the attributes if needed in the future.
};
100
+
101
+ } // namespace host
sglang/jit_kernel/include/sgl_kernel/utils.h ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <dlpack/dlpack.h>
4
+
5
+ #include <concepts>
6
+ #include <ostream>
7
+ #include <source_location>
8
+ #include <sstream>
9
+ #include <utility>
10
+
11
+ namespace host {
12
+
13
// Exception thrown by host::panic. Stores the full message; detail() strips the
// "Runtime check failed at <file>:<line>: " prefix added by panic() to expose
// only the caller-supplied part.
struct PanicError : public std::runtime_error {
public:
  // copy and move constructors
  explicit PanicError(std::string msg) : runtime_error(msg), m_message(std::move(msg)) {}
  auto detail() const -> std::string_view {
    const auto sv = std::string_view{m_message};
    // everything after the first ": " is the user message (see host::panic)
    const auto pos = sv.find(": ");
    return pos == std::string_view::npos ? sv : sv.substr(pos + 2);
  }

private:
  // NOTE(review): duplicates the text already held by std::runtime_error
  // (what()); kept to guarantee a stable string_view source for detail().
  std::string m_message;
};
26
+
27
// Format "Runtime check failed at <file>:<line>[: <args...>]" and throw it as
// PanicError. With no message args, the enclosing function name is appended
// instead. Never returns.
template <typename... Args>
[[noreturn]]
inline auto panic(std::source_location location, Args&&... args) -> void {
  std::ostringstream os;
  os << "Runtime check failed at " << location.file_name() << ":" << location.line();
  if constexpr (sizeof...(args) > 0) {
    os << ": ";
    // stream every message fragment in order via a fold expression
    (os << ... << std::forward<Args>(args));
  } else {
    os << " in " << location.function_name();
  }
  throw PanicError(std::move(os).str());
}
40
+
41
// Assertion helper used as `RuntimeCheck(cond, "msg", parts...)`.
// It is a class template (not a function) so that the defaulted
// std::source_location parameter can come *after* the variadic message
// arguments; the deduction guide below makes the call-site syntax work.
template <typename... Args>
struct RuntimeCheck {
  using Loc_t = std::source_location;
  template <typename Cond>
  explicit RuntimeCheck(Cond&& condition, Args&&... args, Loc_t location = Loc_t::current()) {
    if (!condition) {
      [[unlikely]];
      ::host::panic(location, std::forward<Args>(args)...);
    }
  }
};

// Deduce Args... from the trailing message arguments at the call site.
template <typename Cond, typename... Args>
explicit RuntimeCheck(Cond&&, Args&&...) -> RuntimeCheck<Args...>;
55
+
56
// Ceiling division for signed integers.
// NOTE(review): the (a + b - 1) / b formula is only correct for a >= 0, b > 0;
// truncation toward zero gives a wrong result for negative `a` — confirm
// callers only pass non-negative sizes.
template <std::signed_integral T, std::signed_integral U>
inline constexpr auto div_ceil(T a, U b) {
  return (a + b - 1) / b;
}

// Ceiling division for unsigned integers; keeping separate signed/unsigned
// overloads prevents accidental signed/unsigned mixing at the call site.
template <std::unsigned_integral T, std::unsigned_integral U>
inline constexpr auto div_ceil(T a, U b) {
  return (a + b - 1) / b;
}
65
+
66
// Bytes per element of a DLPack dtype.
// NOTE(review): ignores dtype.lanes and truncates sub-byte widths (bits < 8
// yields 0) — confirm only single-lane, byte-aligned dtypes reach this.
inline auto dtype_bytes(DLDataType dtype) -> std::size_t {
  return static_cast<std::size_t>(dtype.bits / 8);
}
69
+
70
+ namespace pointer {
71
+
72
+ // we only allow void * pointer arithmetic for safety
73
+
74
// Byte-wise pointer offset (host mirror of device::pointer::offset);
// T is constrained to void so typed arithmetic cannot slip in.
template <typename T, std::integral... U>
inline auto offset(T* ptr, U... offset) -> void* {
  static_assert(std::is_same_v<T, void>, "Pointer arithmetic is only allowed for void* pointers");
  // fold all offsets into one byte displacement
  return static_cast<char*>(ptr) + (... + offset);
}

// const overload of the byte-wise offset above.
template <typename T, std::integral... U>
inline auto offset(const T* ptr, U... offset) -> const void* {
  static_assert(std::is_same_v<T, void>, "Pointer arithmetic is only allowed for void* pointers");
  return static_cast<const char*>(ptr) + (... + offset);
}
85
+
86
+ } // namespace pointer
87
+
88
+ } // namespace host
sglang/jit_kernel/include/sgl_kernel/warp.cuh ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <sgl_kernel/utils.cuh>
3
+
4
+ #include <cstddef>
5
+ #include <cstdint>
6
+ #include <type_traits>
7
+
8
+ namespace device::warp {
9
+
10
+ namespace details {
11
+
12
// Select the CUDA vector type used for a kUnit-byte memory transaction.
template <std::size_t kUnit>
inline constexpr auto get_mem_package() {
  if constexpr (kUnit == 16) {
    return uint4{};
  } else if constexpr (kUnit == 8) {
    return uint2{};
  } else if constexpr (kUnit == 4) {
    return uint1{};
  } else {
    static_assert(kUnit == 16 || kUnit == 8 || kUnit == 4, "Unsupported memory package size");
  }
}

// Largest unit (16/8/4 bytes) such that a full warp's worth of packages
// divides x evenly; 0 if none fits.
inline constexpr auto default_unit_size(std::size_t x) -> std::size_t {
  if (x % (16 * kWarpThreads) == 0) return 16;
  if (x % (8 * kWarpThreads) == 0) return 8;
  if (x % (4 * kWarpThreads) == 0) return 4;
  return 0;  // trigger static assert in get_mem_package
}

// Package type used to move kBytes in kUnit-byte chunks.
template <std::size_t kBytes, std::size_t kUnit>
using mem_package_t = decltype(get_mem_package<kUnit>());

// Per-thread register storage produced by load_vec / consumed by store_vec.
template <typename T, std::size_t N>
struct storage_vec {
  T data[N];
};
39
+
40
// Vectorized global loads/stores with the .cs (cache-streaming, evict-first)
// PTX qualifier, for 4/8/16-byte packages.
// NOTE(review): despite the "_nc" names these emit ld/st.global.cs, not
// ld.global.nc — confirm which cache policy is actually intended.

__always_inline __device__ auto load_nc(const uint1* __restrict__ src) -> uint1 {
  uint32_t tmp;
  asm volatile("ld.global.cs.b32 %0,[%1];" : "=r"(tmp) : "l"(src));
  return uint1{tmp};
}

__always_inline __device__ auto load_nc(const uint2* __restrict__ src) -> uint2 {
  uint32_t tmp0, tmp1;
  asm volatile("ld.global.cs.v2.b32 {%0,%1},[%2];" : "=r"(tmp0), "=r"(tmp1) : "l"(src));
  return uint2{tmp0, tmp1};
}

__always_inline __device__ auto load_nc(const uint4* __restrict__ src) -> uint4 {
  uint32_t tmp0, tmp1, tmp2, tmp3;
  asm volatile("ld.global.cs.v4.b32 {%0,%1,%2,%3},[%4];" : "=r"(tmp0), "=r"(tmp1), "=r"(tmp2), "=r"(tmp3) : "l"(src));
  return uint4{tmp0, tmp1, tmp2, tmp3};
}

__always_inline __device__ void store_nc(uint1* __restrict__ dst, const uint1& value) {
  uint32_t tmp = value.x;
  asm volatile("st.global.cs.b32 [%0],%1;" ::"l"(dst), "r"(tmp));
}

__always_inline __device__ void store_nc(uint2* __restrict__ dst, const uint2& value) {
  uint32_t tmp0 = value.x;
  uint32_t tmp1 = value.y;
  asm volatile("st.global.cs.v2.b32 [%0],{%1,%2};" ::"l"(dst), "r"(tmp0), "r"(tmp1));
}

__always_inline __device__ void store_nc(uint4* __restrict__ dst, const uint4& value) {
  uint32_t tmp0 = value.x;
  uint32_t tmp1 = value.y;
  uint32_t tmp2 = value.z;
  uint32_t tmp3 = value.w;
  asm volatile("st.global.cs.v4.b32 [%0],{%1,%2,%3,%4};" ::"l"(dst), "r"(tmp0), "r"(tmp1), "r"(tmp2), "r"(tmp3));
}
76
+
77
+ } // namespace details
78
+
79
+ template <
80
+ std::size_t kBytes,
81
+ std::size_t kUnit = details::default_unit_size(kBytes),
82
+ std::size_t kThreads = ::device::kWarpThreads>
83
+ __always_inline __device__ void copy(void* __restrict__ dst, const void* __restrict__ src) {
84
+ using Package = details::mem_package_t<kBytes, kUnit>;
85
+ constexpr auto kBytesPerLoop = sizeof(Package) * kThreads;
86
+ constexpr auto kLoopCount = kBytes / kBytesPerLoop;
87
+ static_assert(kBytes % kBytesPerLoop == 0, "kBytes must be multiple of 128 bytes");
88
+
89
+ const auto dst_packed = static_cast<Package*>(dst);
90
+ const auto src_packed = static_cast<const Package*>(src);
91
+ const auto lane_id = threadIdx.x % kThreads;
92
+
93
+ #pragma unroll kLoopCount
94
+ for (std::size_t i = 0; i < kLoopCount; ++i) {
95
+ const auto j = i * kThreads + lane_id;
96
+ dst_packed[j] = src_packed[j];
97
+ }
98
+ }
99
+
100
+ template <
101
+ std::size_t kBytes,
102
+ std::size_t kUnit = details::default_unit_size(kBytes),
103
+ std::size_t kThreads = ::device::kWarpThreads>
104
+ __always_inline __device__ auto load_vec(const void* __restrict__ src) {
105
+ using Package = details::mem_package_t<kBytes, kUnit>;
106
+ constexpr auto kBytesPerLoop = sizeof(Package) * kThreads;
107
+ constexpr auto kLoopCount = kBytes / kBytesPerLoop;
108
+ static_assert(kBytes % kBytesPerLoop == 0, "kBytes must be multiple of 128 bytes");
109
+
110
+ const auto src_packed = static_cast<const Package*>(src);
111
+ const auto lane_id = threadIdx.x % kThreads;
112
+ details::storage_vec<Package, kLoopCount> vec;
113
+
114
+ #pragma unroll kLoopCount
115
+ for (std::size_t i = 0; i < kLoopCount; ++i) {
116
+ const auto j = i * kThreads + lane_id;
117
+ vec.data[i] = details::load_nc(src_packed + j);
118
+ }
119
+
120
+ return vec;
121
+ }
122
+
123
+ template <
124
+ std::size_t kBytes,
125
+ std::size_t kUnit = details::default_unit_size(kBytes),
126
+ std::size_t kThreads = ::device::kWarpThreads,
127
+ typename Tp>
128
+ __always_inline __device__ void store_vec(void* __restrict__ dst, const Tp& vec) {
129
+ using Package = details::mem_package_t<kBytes, kUnit>;
130
+ constexpr auto kBytesPerLoop = sizeof(Package) * kThreads;
131
+ constexpr auto kLoopCount = kBytes / kBytesPerLoop;
132
+ static_assert(kBytes % kBytesPerLoop == 0, "kBytes must be multiple of 128 bytes");
133
+ static_assert(std::is_same_v<Tp, details::storage_vec<Package, kLoopCount>>);
134
+
135
+ const auto dst_packed = static_cast<Package*>(dst);
136
+ const auto lane_id = threadIdx.x % kThreads;
137
+
138
+ #pragma unroll kLoopCount
139
+ for (std::size_t i = 0; i < kLoopCount; ++i) {
140
+ const auto j = i * kThreads + lane_id;
141
+ details::store_nc(dst_packed + j, vec.data[i]);
142
+ }
143
+ }
144
+
145
+ } // namespace device::warp
sglang/jit_kernel/utils.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import pathlib
4
+ from functools import lru_cache
5
+ from typing import TYPE_CHECKING, List, Tuple, TypeAlias, Union
6
+
7
+ if TYPE_CHECKING:
8
+ from tvm_ffi import Module
9
+
10
+
11
+ def _make_wrapper(tup: Tuple[str, str]) -> str:
12
+ export_name, kernel_name = tup
13
+ return f"TVM_FFI_DLL_EXPORT_TYPED_FUNC({export_name}, ({kernel_name}));"
14
+
15
+
16
@lru_cache()
def _resolve_kernel_path() -> pathlib.Path:
    """Locate the sgl-kernel JIT root: the directory containing include/ and csrc/.

    Cached for the process lifetime via lru_cache.

    Raises:
        RuntimeError: if neither lookup strategy finds the sources.
    """
    cur_dir = pathlib.Path(__file__).parent.resolve()

    # first, try this directory structure
    def _environment_install():
        # In-tree layout: include/ and csrc/ live next to this file.
        candidate = cur_dir.resolve()
        if (candidate / "include").exists() and (candidate / "csrc").exists():
            return candidate
        return None

    def _package_install():
        # TODO: support find path by package
        return None

    path = _environment_install() or _package_install()
    if path is None:
        raise RuntimeError("Cannot find sgl-kernel/jit path")
    return path
35
+
36
+
37
# Resolved once at import time; raises RuntimeError if the sources are missing.
KERNEL_PATH = _resolve_kernel_path()
DEFAULT_INCLUDE = [str(KERNEL_PATH / "include")]
DEFAULT_CFLAGS = ["-std=c++20", "-O3"]
DEFAULT_CUDA_CFLAGS = ["-std=c++20", "-O3", "--expt-relaxed-constexpr"]
DEFAULT_LDFLAGS = []
# Python value types that can be rendered as C++ template arguments.
CPP_TEMPLATE_TYPE: TypeAlias = Union[int, float, bool]
43
+
44
+
45
+ class CPPArgList(list[str]):
46
+ def __str__(self) -> str:
47
+ return ", ".join(self)
48
+
49
+
50
+ def make_cpp_args(*args: CPP_TEMPLATE_TYPE) -> CPPArgList:
51
+ def _convert(arg: CPP_TEMPLATE_TYPE) -> str:
52
+ if isinstance(arg, bool):
53
+ return "true" if arg else "false"
54
+ if isinstance(arg, (int, float)):
55
+ return str(arg)
56
+ raise TypeError(f"Unsupported argument type for cpp template: {type(arg)}")
57
+
58
+ return CPPArgList(_convert(arg) for arg in args)
59
+
60
+
61
def load_jit(
    *args: str,
    cpp_files: List[str] | None = None,
    cuda_files: List[str] | None = None,
    cpp_wrappers: List[Tuple[str, str]] | None = None,
    cuda_wrappers: List[Tuple[str, str]] | None = None,
    extra_cflags: List[str] | None = None,
    extra_cuda_cflags: List[str] | None = None,
    extra_ldflags: List[str] | None = None,
    extra_include_paths: List[str] | None = None,
    build_directory: str | None = None,
) -> Module:
    """Compile and load an sgl-kernel JIT module via tvm_ffi's load_inline.

    Args:
        *args: name components; joined into ``sgl_kernel_jit_<a>_<b>_...``.
        cpp_files / cuda_files: file names under ``csrc/`` to #include into the
            generated C++ / CUDA translation units.
        cpp_wrappers / cuda_wrappers: (export_name, kernel_name) pairs rendered
            as TVM_FFI_DLL_EXPORT_TYPED_FUNC lines (see _make_wrapper).
        extra_*: appended to the corresponding DEFAULT_* flag/path lists.
        build_directory: forwarded to load_inline; None lets tvm_ffi choose.

    Returns:
        The compiled-and-loaded tvm_ffi Module.
    """
    from tvm_ffi.cpp import load_inline

    cpp_files = cpp_files or []
    cuda_files = cuda_files or []
    cpp_wrappers = cpp_wrappers or []
    cuda_wrappers = cuda_wrappers or []
    extra_cflags = extra_cflags or []
    extra_cuda_cflags = extra_cuda_cflags or []
    extra_ldflags = extra_ldflags or []
    extra_include_paths = extra_include_paths or []

    # include cpp files
    cpp_paths = [(KERNEL_PATH / "csrc" / f).resolve() for f in cpp_files]
    cpp_sources = [f'#include "{path}"' for path in cpp_paths]
    cpp_sources += [_make_wrapper(tup) for tup in cpp_wrappers]

    # include cuda files
    cuda_paths = [(KERNEL_PATH / "csrc" / f).resolve() for f in cuda_files]
    cuda_sources = [f'#include "{path}"' for path in cuda_paths]
    cuda_sources += [_make_wrapper(tup) for tup in cuda_wrappers]

    return load_inline(
        "sgl_kernel_jit_" + "_".join(str(arg) for arg in args),
        cpp_sources=cpp_sources,
        cuda_sources=cuda_sources,
        extra_cflags=DEFAULT_CFLAGS + extra_cflags,
        extra_cuda_cflags=DEFAULT_CUDA_CFLAGS + extra_cuda_cflags,
        extra_ldflags=DEFAULT_LDFLAGS + extra_ldflags,
        extra_include_paths=DEFAULT_INCLUDE + extra_include_paths,
        build_directory=build_directory,
    )
sglang/lang/__pycache__/api.cpython-311.pyc ADDED
Binary file (10.3 kB). View file
 
sglang/lang/__pycache__/chat_template.cpython-311.pyc ADDED
Binary file (19.4 kB). View file
 
sglang/lang/__pycache__/choices.cpython-311.pyc ADDED
Binary file (9.41 kB). View file
 
sglang/lang/__pycache__/interpreter.cpython-311.pyc ADDED
Binary file (50 kB). View file
 
sglang/lang/__pycache__/ir.cpython-311.pyc ADDED
Binary file (33.4 kB). View file
 
sglang/lang/api.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Public APIs of the language."""
2
+
3
+ import re
4
+ from typing import Callable, List, Optional, Union
5
+
6
+ from sglang.global_config import global_config
7
+ from sglang.lang.backend.base_backend import BaseBackend
8
+ from sglang.lang.choices import ChoicesSamplingMethod, token_length_normalized
9
+ from sglang.lang.ir import (
10
+ SglExpr,
11
+ SglExprList,
12
+ SglFunction,
13
+ SglGen,
14
+ SglImage,
15
+ SglRoleBegin,
16
+ SglRoleEnd,
17
+ SglSelect,
18
+ SglSeparateReasoning,
19
+ SglVideo,
20
+ )
21
+
22
+
23
def function(
    func: Optional[Callable] = None, num_api_spec_tokens: Optional[int] = None
):
    """Decorator turning a plain callable into an SglFunction.

    Usable both bare (``@function``) and parameterized
    (``@function(num_api_spec_tokens=...)``).
    """
    if func:
        # Bare usage: wrap immediately.
        return SglFunction(func, num_api_spec_tokens=num_api_spec_tokens)

    # Parameterized usage: return the actual decorator.
    def decorator(inner_func):
        return SglFunction(inner_func, num_api_spec_tokens=num_api_spec_tokens)

    return decorator
33
+
34
+
35
def Runtime(*args, **kwargs):
    """Construct a lang Runtime, deferring the heavy import until first use."""
    # Avoid importing unnecessary dependency
    from sglang.lang.backend.runtime_endpoint import Runtime as _Runtime

    return _Runtime(*args, **kwargs)
40
+
41
+
42
def Engine(*args, **kwargs):
    """Construct an srt Engine, deferring the heavy import until first use."""
    # Avoid importing unnecessary dependency
    from sglang.srt.entrypoints.engine import Engine as _Engine

    return _Engine(*args, **kwargs)
47
+
48
+
49
def set_default_backend(backend: BaseBackend):
    """Register *backend* as the process-wide default for lang calls."""
    global_config.default_backend = backend
51
+
52
+
53
def flush_cache(backend: Optional[BaseBackend] = None):
    """Flush the cache of *backend* (or the default backend).

    Returns False when no backend is configured, otherwise whatever the
    backend's flush_cache() reports.
    """
    backend = backend or global_config.default_backend
    if backend is None:
        return False

    # A Runtime wrapper exposes the real backend via `.endpoint`.
    target = getattr(backend, "endpoint", backend)
    return target.flush_cache()
62
+
63
+
64
def get_server_info(backend: Optional[BaseBackend] = None):
    """Return server info from *backend* (or the default backend).

    Returns None when no backend is configured.
    """
    backend = backend or global_config.default_backend
    if backend is None:
        return None

    # A Runtime wrapper exposes the real backend via `.endpoint`.
    target = getattr(backend, "endpoint", backend)
    return target.get_server_info()
73
+
74
+
75
def gen(
    name: Optional[str] = None,
    max_tokens: Optional[int] = None,
    min_tokens: Optional[int] = None,
    n: Optional[int] = None,
    stop: Optional[Union[str, List[str]]] = None,
    stop_token_ids: Optional[List[int]] = None,
    stop_regex: Optional[Union[str, List[str]]] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    min_p: Optional[float] = None,
    frequency_penalty: Optional[float] = None,
    presence_penalty: Optional[float] = None,
    ignore_eos: Optional[bool] = None,
    return_logprob: Optional[bool] = None,
    logprob_start_len: Optional[int] = None,
    top_logprobs_num: Optional[int] = None,
    return_text_in_logprobs: Optional[bool] = None,
    dtype: Optional[Union[type, str]] = None,
    choices: Optional[List[str]] = None,
    choices_method: Optional[ChoicesSamplingMethod] = None,
    regex: Optional[str] = None,
    json_schema: Optional[str] = None,
):
    """Call the model to generate. See the meaning of the arguments in docs/backend/sampling_params.md"""

    # `choices` turns the call into a constrained selection; the free-form
    # sampling arguments are not forwarded in that case.
    if choices:
        return SglSelect(
            name,
            choices,
            0.0 if temperature is None else temperature,
            token_length_normalized if choices_method is None else choices_method,
        )

    # Validate the regex eagerly so an invalid pattern raises re.error here, at
    # construction time, rather than deep inside the backend. (The previous
    # `except re.error as e: raise e` wrapper was a no-op and has been removed.)
    if regex is not None:
        re.compile(regex)

    return SglGen(
        name,
        max_tokens,
        min_tokens,
        n,
        stop,
        stop_token_ids,
        stop_regex,
        temperature,
        top_p,
        top_k,
        min_p,
        frequency_penalty,
        presence_penalty,
        ignore_eos,
        return_logprob,
        logprob_start_len,
        top_logprobs_num,
        return_text_in_logprobs,
        dtype,
        regex,
        json_schema,
    )
140
+
141
+
142
def gen_int(
    name: Optional[str] = None,
    max_tokens: Optional[int] = None,
    n: Optional[int] = None,
    stop: Optional[Union[str, List[str]]] = None,
    stop_token_ids: Optional[List[int]] = None,
    stop_regex: Optional[Union[str, List[str]]] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    min_p: Optional[float] = None,
    frequency_penalty: Optional[float] = None,
    presence_penalty: Optional[float] = None,
    ignore_eos: Optional[bool] = None,
    return_logprob: Optional[bool] = None,
    logprob_start_len: Optional[int] = None,
    top_logprobs_num: Optional[int] = None,
    return_text_in_logprobs: Optional[bool] = None,
):
    """Generate an integer: gen() specialized with dtype=int and no min_tokens/regex.

    NOTE(review): SglGen is called positionally and, unlike gen(), omits the
    trailing json_schema argument — presumably SglGen defaults it; verify
    against SglGen's signature.
    """
    return SglGen(
        name,
        max_tokens,
        None,  # min_tokens
        n,
        stop,
        stop_token_ids,
        stop_regex,
        temperature,
        top_p,
        top_k,
        min_p,
        frequency_penalty,
        presence_penalty,
        ignore_eos,
        return_logprob,
        logprob_start_len,
        top_logprobs_num,
        return_text_in_logprobs,
        int,  # dtype
        None,  # regex
    )
183
+
184
+
185
def gen_string(
    name: Optional[str] = None,
    max_tokens: Optional[int] = None,
    n: Optional[int] = None,
    stop: Optional[Union[str, List[str]]] = None,
    stop_token_ids: Optional[List[int]] = None,
    stop_regex: Optional[Union[str, List[str]]] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    min_p: Optional[float] = None,
    frequency_penalty: Optional[float] = None,
    presence_penalty: Optional[float] = None,
    ignore_eos: Optional[bool] = None,
    return_logprob: Optional[bool] = None,
    logprob_start_len: Optional[int] = None,
    top_logprobs_num: Optional[int] = None,
    return_text_in_logprobs: Optional[bool] = None,
):
    """Generate a string literal: gen() specialized with dtype=str and no min_tokens/regex.

    NOTE(review): mirrors gen_int — SglGen is called positionally and omits the
    trailing json_schema argument that gen() passes; verify against SglGen.
    """
    return SglGen(
        name,
        max_tokens,
        None,  # min_tokens
        n,
        stop,
        stop_token_ids,
        stop_regex,
        temperature,
        top_p,
        top_k,
        min_p,
        frequency_penalty,
        presence_penalty,
        ignore_eos,
        return_logprob,
        logprob_start_len,
        top_logprobs_num,
        return_text_in_logprobs,
        str,  # dtype
        None,  # regex
    )
226
+
227
+
228
def image(expr: SglExpr):
    """Wrap an expression as an image input."""
    return SglImage(expr)


def video(path: str, num_frames: int):
    """Reference a video file by path, sampled to *num_frames* frames."""
    return SglVideo(path, num_frames)


def select(
    name: Optional[str] = None,
    choices: Optional[List[str]] = None,
    temperature: float = 0.0,
    choices_method: ChoicesSamplingMethod = token_length_normalized,
):
    """Ask the model to pick one of *choices* (see SglSelect).

    Raises:
        ValueError: if *choices* is not provided.
    """
    # Raise instead of `assert`: asserts are stripped under `python -O`.
    if choices is None:
        raise ValueError("`choices` must be provided for select()")
    return SglSelect(name, choices, temperature, choices_method)
244
+
245
+
246
def _role_common(name: str, expr: Optional[SglExpr] = None):
    """Wrap *expr* (if any) between begin/end markers for the given role."""
    parts = [SglRoleBegin(name)]
    if expr is not None:
        parts.append(expr)
    parts.append(SglRoleEnd(name))
    return SglExprList(parts)


def system(expr: Optional[SglExpr] = None):
    """A system-role segment."""
    return _role_common("system", expr)


def user(expr: Optional[SglExpr] = None):
    """A user-role segment."""
    return _role_common("user", expr)


def assistant(expr: Optional[SglExpr] = None):
    """An assistant-role segment."""
    return _role_common("assistant", expr)


def system_begin():
    """Open a system-role segment."""
    return SglRoleBegin("system")


def system_end():
    """Close a system-role segment."""
    return SglRoleEnd("system")


def user_begin():
    """Open a user-role segment."""
    return SglRoleBegin("user")


def user_end():
    """Close a user-role segment."""
    return SglRoleEnd("user")


def assistant_begin():
    """Open an assistant-role segment."""
    return SglRoleBegin("assistant")


def assistant_end():
    """Close an assistant-role segment."""
    return SglRoleEnd("assistant")
287
+
288
+
289
def separate_reasoning(
    expr: Optional[SglExpr] = None, model_type: Optional[str] = None
):
    """Append a reasoning-separation marker after *expr*.

    NOTE(review): when expr is None the returned list still contains the None
    placeholder — confirm SglExprList tolerates None entries, or guard like
    _role_common does.
    """
    return SglExprList([expr, SglSeparateReasoning(model_type, expr=expr)])
sglang/lang/backend/__pycache__/base_backend.cpython-311.pyc ADDED
Binary file (4.6 kB). View file
 
sglang/lang/backend/__pycache__/runtime_endpoint.cpython-311.pyc ADDED
Binary file (25.5 kB). View file
 
sglang/lang/backend/anthropic.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sglang.lang.backend.base_backend import BaseBackend
2
+ from sglang.lang.chat_template import get_chat_template
3
+ from sglang.lang.interpreter import StreamExecutor
4
+ from sglang.lang.ir import SglSamplingParams
5
+
6
+ try:
7
+ import anthropic
8
+ except ImportError as e:
9
+ anthropic = e
10
+
11
+
12
class Anthropic(BaseBackend):
    """Frontend-language backend backed by the Anthropic Messages API."""

    def __init__(self, model_name, *args, **kwargs):
        super().__init__()

        # The module-level import stored the ImportError when the anthropic
        # package is missing; surface it only when the backend is used.
        if isinstance(anthropic, Exception):
            raise anthropic

        self.model_name = model_name
        self.chat_template = get_chat_template("claude")
        self.client = anthropic.Anthropic(*args, **kwargs)

    def get_chat_template(self):
        return self.chat_template

    def _build_request_messages(self, s: StreamExecutor):
        """Return ``(system, messages)`` for an API call.

        Operates on a copy of ``s.messages_`` so that popping a leading
        system turn does not mutate the executor's conversation state
        (the previous implementation popped from the shared list in place,
        permanently dropping the system message from ``s.messages_``).
        """
        if s.messages_:
            messages = list(s.messages_)
        else:
            messages = [{"role": "user", "content": s.text_}]

        # Anthropic takes the system prompt as a separate argument, not as
        # a message entry.
        if messages and messages[0]["role"] == "system":
            system = messages.pop(0)["content"]
        else:
            system = ""
        return system, messages

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run one non-streaming completion; returns (text, meta_info)."""
        system, messages = self._build_request_messages(s)

        ret = self.client.messages.create(
            model=self.model_name,
            system=system,
            messages=messages,
            **sampling_params.to_anthropic_kwargs(),
        )
        comp = ret.content[0].text

        return comp, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Stream a completion, yielding (text_chunk, meta_info) pairs."""
        system, messages = self._build_request_messages(s)

        with self.client.messages.stream(
            model=self.model_name,
            system=system,
            messages=messages,
            **sampling_params.to_anthropic_kwargs(),
        ) as stream:
            for text in stream.text_stream:
                yield text, {}
sglang/lang/backend/base_backend.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Union
2
+
3
+ from sglang.lang.chat_template import get_chat_template
4
+ from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
5
+ from sglang.lang.interpreter import StreamExecutor
6
+ from sglang.lang.ir import SglSamplingParams
7
+
8
+
9
class BaseBackend:
    """Interface that every frontend-language backend implements.

    Most hooks are no-ops here; concrete backends override the subset they
    support. ``generate``/``generate_stream``/``select`` are the required
    core operations and raise ``NotImplementedError`` by default.
    """

    def __init__(self) -> None:
        # Whether the backend supports concatenate_and_append (only the
        # RuntimeEndpoint backend sets this to True).
        self.support_concate_and_append = False
        self.chat_template = get_chat_template("default")

    def get_model_name(self):
        """Return the served model's name/path. Must be overridden."""
        raise NotImplementedError()

    def get_chat_template(self):
        """Return the chat template used to format role sections."""
        return self.chat_template

    def cache_prefix(self, prefix_str: str):
        """Optionally warm the backend's prefix cache. No-op by default."""
        pass

    def uncache_prefix(self, rid: str):
        """Optionally release a cached prefix. No-op by default."""
        pass

    def end_request(self, rid: Union[str, List[str]]):
        """Optionally finalize one or more requests. No-op by default."""
        pass

    def begin_program(self, s: StreamExecutor):
        """Hook called when a program starts executing. No-op by default."""
        pass

    def end_program(self, s: Union[StreamExecutor, List[StreamExecutor]]):
        """Hook called when a program finishes. No-op by default."""
        pass

    def commit_lazy_operations(self, s: StreamExecutor):
        """Flush any buffered/lazy operations for *s*. No-op by default."""
        pass

    def fork_program(
        self,
        src: StreamExecutor,
        dst: List[StreamExecutor],
        position_ids_offset: Optional[List[int]] = None,
    ):
        """Hook called when *src* forks into the *dst* executors. No-op."""
        pass

    def fill_image(self, s: StreamExecutor):
        """Upload/attach pending image data for *s*. No-op by default."""
        pass

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Produce a completion; must return (text, meta_info)."""
        raise NotImplementedError()

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Yield (text_chunk, meta_info) pairs for a streamed completion."""
        raise NotImplementedError()

    def select(
        self,
        s: StreamExecutor,
        choices: List[str],
        temperature: float,
        choices_method: Optional[ChoicesSamplingMethod] = None,
    ) -> ChoicesDecision:
        """Score *choices* and return the backend's decision."""
        raise NotImplementedError()

    def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
        """Merge source requests into *dst_rid* (RuntimeEndpoint only)."""
        raise NotImplementedError()

    def shutdown(self):
        """Release backend resources. No-op by default."""
        pass

    def flush_cache(self):
        """Clear any server-side cache. No-op by default."""
        pass

    def get_server_info(self):
        """Return backend/server metadata, if available. No-op by default."""
        pass
sglang/lang/backend/litellm.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Mapping, Optional
2
+
3
+ from sglang.lang.backend.base_backend import BaseBackend
4
+ from sglang.lang.chat_template import get_chat_template_by_model_path
5
+ from sglang.lang.interpreter import StreamExecutor
6
+ from sglang.lang.ir import SglSamplingParams
7
+
8
+ try:
9
+ import litellm
10
+ except ImportError as e:
11
+ litellm = e
12
+ litellm.num_retries = 1
13
+
14
+
15
class LiteLLM(BaseBackend):
    """Frontend-language backend that routes requests through LiteLLM."""

    def __init__(
        self,
        model_name,
        chat_template=None,
        api_key=None,
        organization: Optional[str] = None,
        base_url: Optional[str] = None,
        timeout: Optional[float] = 600,
        max_retries: Optional[int] = litellm.num_retries,
        default_headers: Optional[Mapping[str, str]] = None,
    ):
        super().__init__()

        # The module-level import stashed the ImportError if litellm is not
        # installed; raise it lazily, on first use.
        if isinstance(litellm, Exception):
            raise litellm

        self.model_name = model_name
        self.chat_template = chat_template or get_chat_template_by_model_path(
            model_name
        )

        # Connection parameters forwarded verbatim to every completion call.
        self.client_params = {
            "api_key": api_key,
            "organization": organization,
            "base_url": base_url,
            "timeout": timeout,
            "max_retries": max_retries,
            "default_headers": default_headers,
        }

    def get_chat_template(self):
        return self.chat_template

    def _conversation(self, s: StreamExecutor):
        """Return the chat history, or the raw text as a single user turn."""
        if s.messages_:
            return s.messages_
        return [{"role": "user", "content": s.text_}]

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run one non-streaming completion; returns (text, meta_info)."""
        response = litellm.completion(
            model=self.model_name,
            messages=self._conversation(s),
            **self.client_params,
            **sampling_params.to_litellm_kwargs(),
        )
        return response.choices[0].message.content, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Stream a completion, yielding (text_chunk, meta_info) pairs."""
        response = litellm.completion(
            model=self.model_name,
            messages=self._conversation(s),
            stream=True,
            **self.client_params,
            **sampling_params.to_litellm_kwargs(),
        )
        for chunk in response:
            piece = chunk.choices[0].delta.content
            if piece is not None:
                yield piece, {}
sglang/lang/backend/openai.py ADDED
@@ -0,0 +1,475 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import logging
3
+ import time
4
+ import warnings
5
+ from typing import List, Optional, Union
6
+
7
+ import numpy as np
8
+
9
+ from sglang.lang.backend.base_backend import BaseBackend
10
+ from sglang.lang.chat_template import ChatTemplate, get_chat_template_by_model_path
11
+ from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
12
+ from sglang.lang.interpreter import StreamExecutor
13
+ from sglang.lang.ir import SglSamplingParams
14
+
15
+ try:
16
+ import openai
17
+ import tiktoken
18
+ except ImportError as e:
19
+ openai = tiktoken = e
20
+
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
def create_logit_bias_int(tokenizer):
    """Build a logit-bias mask that strongly favors integer-like tokens.

    Scans the tokenizer's mergeable ranks for tokens whose decoded text is
    all digits (or a single space), stopping after 300 candidates
    (the OpenAI API limit on logit_bias entries). The first 299 of them,
    plus the <|endoftext|> special token, receive bias weight 100.
    """
    digit_ids = []
    for _, token_id in tokenizer._mergeable_ranks.items():
        decoded = tokenizer.decode([token_id])
        # all() is True for the empty string as well, matching the
        # original behavior.
        if all(ch.isdigit() for ch in decoded) or decoded == " ":
            digit_ids.append(token_id)
            if len(digit_ids) >= 300:  # OpenAI API limit
                break

    mask = dict.fromkeys(digit_ids[:299], 100)
    mask[tokenizer._special_tokens["<|endoftext|>"]] = 100
    return mask
40
+
41
+
42
# Model names served through the legacy /completions endpoint rather than
# the chat endpoint; used to auto-detect is_chat_model in OpenAI.__init__.
INSTRUCT_MODEL_NAMES = [
    "gpt-3.5-turbo-instruct",
]
45
+
46
+
47
@dataclasses.dataclass
class TokenUsage:
    """Running totals of tokens consumed by a backend."""

    prompt_tokens: int
    completion_tokens: int

    def reset(self):
        """Zero both counters."""
        self.prompt_tokens = 0
        self.completion_tokens = 0
55
+
56
class OpenAI(BaseBackend):
    """Frontend-language backend for OpenAI (and Azure OpenAI) endpoints.

    Supports chat and legacy completion models, integer-constrained
    decoding via logit bias, greedy ``select`` over choices, and API-level
    speculative execution for chat models.
    """

    def __init__(
        self,
        model_name: str,
        is_chat_model: Optional[bool] = None,
        chat_template: Optional[ChatTemplate] = None,
        is_azure: bool = False,
        *args,
        **kwargs,
    ):
        super().__init__()

        # The module-level import stored the ImportError when openai/tiktoken
        # are missing; surface it only when the backend is actually used.
        if isinstance(openai, Exception):
            raise openai

        if is_azure:
            self.client = openai.AzureOpenAI(*args, **kwargs)
        else:
            self.client = openai.OpenAI(*args, **kwargs)

        self.model_name = model_name
        try:
            self.tokenizer = tiktoken.encoding_for_model(model_name)
        except KeyError:
            # Unknown model name: fall back to the cl100k_base encoding.
            self.tokenizer = tiktoken.get_encoding("cl100k_base")
        self.logit_bias_int = create_logit_bias_int(self.tokenizer)

        self.chat_template = chat_template or get_chat_template_by_model_path(
            model_name
        )

        if is_chat_model is not None:
            self.is_chat_model = is_chat_model
        else:
            # Default to chat unless the model is a known instruct model.
            if model_name in INSTRUCT_MODEL_NAMES:
                self.is_chat_model = False
            else:
                self.is_chat_model = True

        self.chat_prefix = self.chat_template.role_prefix_and_suffix["assistant"][0]

        # Cumulative token usage across all calls made by this backend.
        self.token_usage = TokenUsage(0, 0)

        # API speculative execution state.
        # TODO(ying): This does not support multi-threading (run_batch)
        self.spec_kwargs = {}
        self.spec_format = []
        self.spec_max_num_tries = 3

    def get_chat_template(self):
        return self.chat_template

    def _prepare_spec_execution(
        self,
        sampling_params: SglSamplingParams,
        num_api_spec_tokens: int,
        spec_var_name: Optional[str] = None,
    ):
        """Record a pending gen() for speculative execution.

        Instead of calling the API now, buffer the sampling parameters and
        the variable name; the real call happens in role_end_generate.
        Returns an empty ("", {}) placeholder.
        """
        if "max_tokens" not in self.spec_kwargs:
            self.spec_kwargs["max_tokens"] = num_api_spec_tokens
        else:
            assert self.spec_kwargs["max_tokens"] == num_api_spec_tokens

        params = sampling_params.to_openai_kwargs()
        for key, value in params.items():
            if key in ["stop"]:
                continue
            if key in ["max_tokens"]:
                warnings.warn(
                    "The parameter max_tokens will be overwritten by speculated number of tokens."
                )
                continue
            if key not in self.spec_kwargs:
                self.spec_kwargs[key] = value
            else:
                # All buffered gen() calls must share one parameter set.
                assert (
                    value == self.spec_kwargs[key]
                ), "sampling parameters should be consistent if turn on api speculative execution."
        self.spec_format.append(
            {"text": "", "stop": params["stop"], "name": spec_var_name}
        )
        return "", {}

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
        spec_var_name: Optional[str] = None,
    ):
        """Run one non-streaming completion; returns (text, meta_info).

        dtype=None uses the plain chat/completion path; dtype=str/int use
        constrained decoding tricks that only work on non-chat models.
        """
        if sampling_params.dtype is None:
            if self.is_chat_model:
                if s.num_api_spec_tokens is None:
                    if not s.text_.endswith(self.chat_prefix):
                        raise RuntimeError(
                            "This use case is not supported if api speculative execution is off. "
                            "For OpenAI chat models, sgl.gen must be right after sgl.assistant. "
                            "Example of adding api speculative execution: @function(num_api_spec_tokens=128)."
                        )
                    prompt = s.messages_
                else:
                    # Defer the API call; see _prepare_spec_execution.
                    return self._prepare_spec_execution(
                        sampling_params, s.num_api_spec_tokens, spec_var_name
                    )
            else:
                prompt = s.text_

            kwargs = sampling_params.to_openai_kwargs()
            # Reasoning models (o1/o3 family) reject max_tokens and require
            # max_completion_tokens; other models are the opposite.
            if (
                self.model_name.startswith("o1")
                or self.model_name.startswith("o3")
                or "o1" in self.model_name
            ):
                kwargs.pop("max_tokens", None)
            else:
                kwargs.pop("max_completion_tokens", None)

            comp = openai_completion(
                client=self.client,
                token_usage=self.token_usage,
                is_chat=self.is_chat_model,
                model=self.model_name,
                prompt=prompt,
                **kwargs,
            )
            # Keep the returned list (or string) as is.
        elif sampling_params.dtype in [str, "str", "string"]:
            assert (
                not self.is_chat_model
            ), "constrained type not supported on chat model"
            kwargs = sampling_params.to_openai_kwargs()
            kwargs.pop("stop")
            # Force a quoted string: open the quote ourselves and stop at
            # the closing quote.
            comp = openai_completion(
                client=self.client,
                token_usage=self.token_usage,
                is_chat=self.is_chat_model,
                model=self.model_name,
                prompt=s.text_ + '"',
                stop='"',
                **kwargs,
            )
            # Wrap each element in quotes if we have a list.
            if isinstance(comp, list):
                comp = ['"' + x + '"' for x in comp]
            else:
                comp = '"' + comp + '"'
        elif sampling_params.dtype in [int, "int"]:
            assert (
                not self.is_chat_model
            ), "constrained type not supported on chat model"
            kwargs = sampling_params.to_openai_kwargs()
            kwargs.pop("stop")
            # Bias decoding toward digit tokens and stop at whitespace.
            comp = openai_completion(
                client=self.client,
                token_usage=self.token_usage,
                is_chat=self.is_chat_model,
                model=self.model_name,
                prompt=s.text_,
                logit_bias=self.logit_bias_int,
                stop=[" "],
                **kwargs,
            )
            # Leave as a list if that's what is returned.
        else:
            raise ValueError(f"Unknown dtype: {sampling_params.dtype}")

        return comp, {}

    def spec_fill(self, value: str):
        """Buffer a literal fill between speculative gen() calls."""
        assert self.is_chat_model
        self.spec_format.append({"text": value, "stop": None, "name": None})

    def spec_pattern_match(self, comp):
        """Try to align *comp* with the buffered spec template.

        Fills in the "text" of each gen() slot from the completion; returns
        False when the completion does not match the template.
        """
        for i, term in enumerate(self.spec_format):
            text = term["text"]
            if text != "":
                # Literal fill: the completion must start with it.
                if comp.startswith(text):
                    comp = comp[len(text) :]
                else:
                    return False
            else:
                # gen() slot: consume up to the slot's stop string.
                pos = comp.find(term["stop"])
                if pos != -1:
                    term["text"] = comp[:pos]
                    comp = comp[pos:]
                else:
                    # Only the last slot may run to the end of the text.
                    if i == len(self.spec_format) - 1:
                        term["text"] = comp
                    else:
                        return False
        return True

    def role_end_generate(
        self,
        s: StreamExecutor,
    ):
        """Execute buffered speculative gen() calls at the end of a role.

        Makes up to spec_max_num_tries API calls until one completion
        matches the buffered template, then commits the slot texts into the
        executor's variables.
        """
        if s.num_api_spec_tokens is None or not s.text_.endswith(self.chat_prefix):
            return

        comp = ""
        if not all(x["name"] is None for x in self.spec_format):
            # TODO(ying): throw errors or warnings
            for i in range(self.spec_max_num_tries):
                comp = openai_completion(
                    client=self.client,
                    token_usage=self.token_usage,
                    is_chat=self.is_chat_model,
                    model=self.model_name,
                    prompt=s.messages_,
                    **self.spec_kwargs,
                )
                # Use a string for pattern matching.
                comp_for_match = comp[0] if isinstance(comp, list) else comp
                if self.spec_pattern_match(comp_for_match):
                    break

        for term in self.spec_format:
            s.text_ += term["text"]
            name = term["name"]
            if name is not None:
                s.variables[name] = term["text"]
                s.meta_info[name] = {}
                s.variable_event[name].set()

        self.spec_kwargs = {}
        self.spec_format = []

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Stream a completion, yielding (text_chunk, meta_info) pairs."""
        if sampling_params.dtype is None:
            if self.is_chat_model:
                if not s.text_.endswith(self.chat_prefix):
                    raise RuntimeError(
                        "This use case is not supported. "
                        "For OpenAI chat models, sgl.gen must be right after sgl.assistant"
                    )
                prompt = s.messages_
            else:
                prompt = s.text_

            kwargs = sampling_params.to_openai_kwargs()
            generator = openai_completion_stream(
                client=self.client,
                token_usage=self.token_usage,
                is_chat=self.is_chat_model,
                model=self.model_name,
                prompt=prompt,
                **kwargs,
            )
            return generator
        else:
            raise ValueError(f"Unknown dtype: {sampling_params.dtype}")

    def select(
        self,
        s: StreamExecutor,
        choices: List[str],
        temperature: float,
        choices_method: ChoicesSamplingMethod,
    ) -> ChoicesDecision:
        """Note: `choices_method` is not used by the OpenAI backend."""
        if self.is_chat_model:
            raise NotImplementedError(
                "select/choices is not supported for chat models. "
                "Please try to use a non-chat model such as gpt-3.5-turbo-instruct"
            )

        n_choices = len(choices)
        token_ids = [self.tokenizer.encode(x) for x in choices]
        scores = [0] * n_choices
        valid = [len(x) > 0 for x in token_ids]
        prompt_tokens = self.tokenizer.encode(s.text_)

        max_len = max([len(x) for x in token_ids])
        for step in range(max_len):
            # Build logit bias restricting the model to the still-valid
            # choices' next tokens.
            logit_bias = {}
            for i in range(n_choices):
                if valid[i]:
                    logit_bias[token_ids[i][step]] = 100

            # Call API
            ret = self.client.completions.create(
                model=self.model_name,
                prompt=prompt_tokens,
                logit_bias=logit_bias,
                max_tokens=1,
                temperature=temperature,
            )
            ret_str = ret.choices[0].text
            ret_token = self.tokenizer.encode(ret_str)[0]
            self.token_usage.prompt_tokens += ret.usage.prompt_tokens
            # Fix: accumulate with += (was `=`, which overwrote the running
            # total, inconsistent with prompt_tokens just above and with
            # every other usage update in this module).
            self.token_usage.completion_tokens += ret.usage.completion_tokens

            # TODO:
            # 1. return logits as the scores
            # 2. compute logits of the full choice
            # 3. consider chunk-based decoding

            # Update valid
            hit = False
            for i in range(n_choices):
                if valid[i]:
                    if step == len(token_ids[i]) - 1:
                        valid[i] = False

                    if ret_token == token_ids[i][step]:
                        scores[i] += 1
                        hit = True
                    else:
                        valid[i] = False
            assert hit

            if np.sum(valid) <= 1:
                break

            prompt_tokens.append(ret_token)

        return ChoicesDecision(
            decision=choices[np.argmax(scores)],
            meta_info={"scores": scores},
        )
381
+
382
+
383
def openai_completion(
    client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
) -> Union[str, List[str]]:
    """Call the OpenAI completion API with simple retry-on-error handling.

    Returns a single string for a single choice, or a list of strings when
    multiple choices (or a batched prompt) are returned. ``token_usage`` is
    updated in place with the usage reported by the API.
    """
    if "ebnf" in kwargs:
        warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.")
        del kwargs["ebnf"]

    for attempt in range(retries):
        try:
            if is_chat:
                # The chat endpoint rejects an explicit None stop.
                if kwargs.get("stop", "") is None:
                    kwargs.pop("stop")
                ret = client.chat.completions.create(messages=prompt, **kwargs)
                contents = [c.message.content for c in ret.choices]
                comp = contents[0] if len(contents) == 1 else contents
            else:
                ret = client.completions.create(prompt=prompt, **kwargs)
                texts = [c.text for c in ret.choices]
                if isinstance(prompt, (list, tuple)) or len(texts) > 1:
                    comp = texts
                else:
                    comp = texts[0]

            token_usage.prompt_tokens += ret.usage.prompt_tokens
            token_usage.completion_tokens += ret.usage.completion_tokens
            break
        except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
            logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
            time.sleep(5)
            if attempt == retries - 1:
                raise e
        except Exception as e:
            logger.error(f"RuntimeError {e}.")
            raise e

    return comp
423
+
424
+
425
def openai_completion_stream(
    client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
):
    """Stream an OpenAI completion, yielding (text_chunk, {}) pairs.

    Uses stream_options={"include_usage": True} so the final chunk carries
    usage data, which is accumulated into ``token_usage`` after the stream
    ends.

    NOTE(review): a retry after an exception mid-stream would re-yield text
    already delivered to the caller; and if the stream yields no chunks at
    all, ``ret`` is unbound when usage is read — confirm whether these
    cases can occur in practice.
    """
    # if "ebnf" is in kwargs, warn and remove
    if "ebnf" in kwargs:
        warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.")
        del kwargs["ebnf"]

    for attempt in range(retries):
        try:
            if is_chat:
                # The chat endpoint rejects an explicit None stop.
                if "stop" in kwargs and kwargs["stop"] is None:
                    kwargs.pop("stop")
                generator = client.chat.completions.create(
                    messages=prompt,
                    stream=True,
                    stream_options={"include_usage": True},
                    **kwargs,
                )
                for ret in generator:
                    # The usage-only final chunk has no choices.
                    if len(ret.choices) == 0:
                        continue
                    try:
                        content = ret.choices[0].delta.content
                    except IndexError:
                        content = None
                    yield content or "", {}
            else:
                generator = client.completions.create(
                    prompt=prompt,
                    stream=True,
                    stream_options={"include_usage": True},
                    **kwargs,
                )
                for ret in generator:
                    if len(ret.choices) == 0:
                        continue
                    content = ret.choices[0].text
                    yield content or "", {}

            # `ret` is the last chunk seen; with include_usage it carries
            # the final usage totals.
            token_usage.prompt_tokens += ret.usage.prompt_tokens
            token_usage.completion_tokens += ret.usage.completion_tokens
            break
        except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
            logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
            time.sleep(5)
            if attempt == retries - 1:
                raise e
        except Exception as e:
            logger.error(f"RuntimeError {e}.")
            raise e
sglang/lang/backend/runtime_endpoint.py ADDED
@@ -0,0 +1,527 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import atexit
2
+ import json
3
+ import multiprocessing
4
+ import warnings
5
+ from typing import Dict, List, Optional, Union
6
+
7
+ import aiohttp
8
+ import requests
9
+
10
+ from sglang.global_config import global_config
11
+ from sglang.lang.backend.base_backend import BaseBackend
12
+ from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path
13
+ from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
14
+ from sglang.lang.interpreter import StreamExecutor
15
+ from sglang.lang.ir import (
16
+ REGEX_BOOL,
17
+ REGEX_FLOAT,
18
+ REGEX_INT,
19
+ REGEX_STR,
20
+ SglSamplingParams,
21
+ )
22
+ from sglang.utils import http_request
23
+
24
+
25
class RuntimeEndpoint(BaseBackend):
    """Frontend-language backend that talks to a running SGLang HTTP server.

    All operations are plain HTTP requests against ``base_url``; payload
    shapes follow the server's /generate protocol.
    """

    def __init__(
        self,
        base_url: str,
        api_key: Optional[str] = None,
        verify: Optional[str] = None,
        chat_template_name: Optional[str] = None,
    ):
        super().__init__()
        # This backend supports server-side request concatenation.
        self.support_concate_and_append = True

        self.base_url = base_url
        self.api_key = api_key
        self.verify = verify

        # Probe the server for model metadata; fails fast if unreachable.
        res = http_request(
            self.base_url + "/get_model_info",
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)
        self.model_info = res.json()

        if chat_template_name:
            self.chat_template = get_chat_template(chat_template_name)
        else:
            self.chat_template = get_chat_template_by_model_path(
                self.model_info["model_path"]
            )

    def get_model_name(self):
        """Return the served model's path as reported by the server."""
        return self.model_info["model_path"]

    def flush_cache(self):
        """Ask the server to drop its radix/prefix cache."""
        res = http_request(
            self.base_url + "/flush_cache",
            api_key=self.api_key,
            verify=self.verify,
            method="POST",
        )
        self._assert_success(res)

    def get_server_info(self):
        """Return the server's info payload as a dict."""
        res = http_request(
            self.base_url + "/get_server_info",
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)
        return res.json()

    def get_chat_template(self):
        return self.chat_template

    def cache_prefix(self, prefix_str: str):
        """Warm the server cache by generating 0 tokens for the prefix."""
        res = http_request(
            self.base_url + "/generate",
            json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)

    def start_profile(self):
        """Start server-side profiling."""
        res = http_request(
            self.base_url + "/start_profile",
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)

    def stop_profile(self):
        """Stop server-side profiling."""
        res = http_request(
            self.base_url + "/stop_profile",
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)

    def commit_lazy_operations(self, s: StreamExecutor):
        """Flush buffered text to the server (0-token generate call)."""
        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
        self._add_images(s, data)
        res = http_request(
            self.base_url + "/generate",
            json=data,
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)

    def fill_image(self, s: StreamExecutor):
        """Upload the executor's pending image via a 0-token generate call."""
        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
        self._add_images(s, data)
        res = http_request(
            self.base_url + "/generate",
            json=data,
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)

    def _handle_dtype_to_regex(self, sampling_params: SglSamplingParams):
        """Translate a requested dtype into a decoding regex (in place).

        Also extends the stop list with whitespace for numeric dtypes so
        generation ends at the first separator.
        """
        if sampling_params.dtype is None:
            return

        if sampling_params.stop == ():
            sampling_params.stop = []

        dtype_regex = None
        if sampling_params.dtype in ["int", int]:

            dtype_regex = REGEX_INT
            sampling_params.stop.extend([" ", "\n"])
        elif sampling_params.dtype in ["float", float]:

            dtype_regex = REGEX_FLOAT
            sampling_params.stop.extend([" ", "\n"])
        elif sampling_params.dtype in ["str", str]:

            dtype_regex = REGEX_STR
        elif sampling_params.dtype in ["bool", bool]:

            dtype_regex = REGEX_BOOL
        else:
            raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")

        if dtype_regex is not None and sampling_params.regex is not None:
            # dtype takes precedence over a user-supplied regex.
            warnings.warn(
                f"Both dtype and regex are set. Only dtype will be used. dtype: {sampling_params.dtype}, regex: {sampling_params.regex}"
            )

        sampling_params.regex = dtype_regex

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run one non-streaming generation; returns (text, meta_info)."""
        self._handle_dtype_to_regex(sampling_params)
        data = {
            "text": s.text_,
            "sampling_params": {
                "skip_special_tokens": global_config.skip_special_tokens_in_output,
                "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                **sampling_params.to_srt_kwargs(),
            },
        }

        # Logprob-related options live at the top level of the request,
        # not inside sampling_params.
        for item in [
            "return_logprob",
            "logprob_start_len",
            "top_logprobs_num",
            "return_text_in_logprobs",
        ]:
            value = getattr(sampling_params, item, None)
            if value is not None:
                data[item] = value

        self._add_images(s, data)

        res = http_request(
            self.base_url + "/generate",
            json=data,
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)

        obj = res.json()
        comp = obj["text"]
        return comp, obj["meta_info"]

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Stream a generation, yielding (text_chunk, meta_info) pairs.

        Parses the server's SSE-style "data: ..." lines; the server sends
        cumulative text, so `pos` tracks what has already been yielded.
        """
        self._handle_dtype_to_regex(sampling_params)

        data = {
            "text": s.text_,
            "sampling_params": {
                "skip_special_tokens": global_config.skip_special_tokens_in_output,
                "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                **sampling_params.to_srt_kwargs(),
            },
        }

        for item in [
            "return_logprob",
            "logprob_start_len",
            "top_logprobs_num",
            "return_text_in_logprobs",
        ]:
            value = getattr(sampling_params, item, None)
            if value is not None:
                data[item] = value

        data["stream"] = True
        self._add_images(s, data)

        res = http_request(
            self.base_url + "/generate",
            json=data,
            stream=True,
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)
        pos = 0

        for chunk in res.iter_lines(decode_unicode=False):
            chunk = chunk.decode("utf-8")
            if chunk and chunk.startswith("data:"):
                if chunk == "data: [DONE]":
                    break
                data = json.loads(chunk[5:].strip("\n"))
                chunk_text = data["text"][pos:]
                meta_info = data["meta_info"]
                pos += len(chunk_text)
                yield chunk_text, meta_info

    def select(
        self,
        s: StreamExecutor,
        choices: List[str],
        temperature: float,
        choices_method: ChoicesSamplingMethod,
    ) -> ChoicesDecision:
        """Pick among *choices* by scoring each continuation's logprobs.

        Only greedy selection is supported (temperature must be ~0).
        """
        assert temperature <= 1e-5

        # Cache common prefix
        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
        obj = self._generate_http_request(s, data)
        prompt_len = obj["meta_info"]["prompt_tokens"]
        # Start logprobs slightly before the prompt end so a token that
        # spans the prompt/choice boundary (token healing) is covered.
        logprob_start_len = max(prompt_len - 2, 0)  # For token healing

        # Compute logprob
        data = {
            "text": [s.text_ + c for c in choices],
            "sampling_params": {
                "max_new_tokens": 0,
                "temperature": 0,
            },
            "return_logprob": True,
            "return_text_in_logprobs": True,
            "logprob_start_len": logprob_start_len,
        }
        obj = self._generate_http_request(s, data)

        input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
        output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]
        normalized_prompt_logprobs = [
            compute_normalized_prompt_logprobs(r["meta_info"]["input_token_logprobs"])
            for r in obj
        ]

        # Remove extra token if no token healing occurred
        for i in range(len(input_token_logprobs)):
            healed_token_str = input_token_logprobs[i][0][-1]
            if s.text_.endswith(healed_token_str):
                # No healing happened: drop the boundary token's logprob
                # from the normalized average.
                healed_token_logprob = input_token_logprobs[i][0][0]
                normalized_prompt_logprobs[i] = (
                    normalized_prompt_logprobs[i] * len(input_token_logprobs[i])
                    - healed_token_logprob
                ) / (len(input_token_logprobs[i]) - 1)
                input_token_logprobs[i] = input_token_logprobs[i][1:]

        # Compute unconditional logprobs if required
        if choices_method.requires_unconditional_logprobs:
            input_ids = [[el[1] for el in subl] for subl in input_token_logprobs]
            data = {
                "input_ids": input_ids,
                "sampling_params": {"max_new_tokens": 0},
                "return_logprob": True,
            }
            obj = self._generate_http_request(s, data)
            unconditional_token_logprobs = [
                r["meta_info"]["input_token_logprobs"] for r in obj
            ]
        else:
            unconditional_token_logprobs = None

        return choices_method(
            choices=choices,
            normalized_prompt_logprobs=normalized_prompt_logprobs,
            input_token_logprobs=input_token_logprobs,
            output_token_logprobs=output_token_logprobs,
            unconditional_token_logprobs=unconditional_token_logprobs,
        )

    def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
        """Server-side merge of the source requests into *dst_rid*."""
        res = http_request(
            self.base_url + "/concate_and_append_request",
            json={"src_rids": src_rids, "dst_rid": dst_rid},
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)

    def _generate_http_request(self, s: StreamExecutor, data):
        """POST *data* to /generate (with images attached) and return JSON."""
        self._add_images(s, data)
        res = http_request(
            self.base_url + "/generate",
            json=data,
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)
        return res.json()

    def _add_images(self, s: StreamExecutor, data):
        """Attach the executor's image payload to the request dict."""
        if s.images_:
            assert len(s.images_) == 1, "Only support one image."
            data["image_data"] = s.images_[0][1]

    def _assert_success(self, res):
        """Raise RuntimeError with the server's error payload on non-200."""
        if res.status_code != 200:
            try:
                content = res.json()
            except json.JSONDecodeError:
                content = res.text
            raise RuntimeError(content)
348
+
349
+
350
def compute_normalized_prompt_logprobs(input_logprobs):
    """Return the mean per-token logprob of a prompt.

    Args:
        input_logprobs: A sequence of (logprob, token_id, token_text) entries,
            as found in ``meta_info["input_token_logprobs"]``. The first
            entry's logprob may be None (no logprob exists for the first token).

    Returns:
        The average of all non-None logprobs.

    Raises:
        ZeroDivisionError: If no entry carries a logprob.
    """
    # Use an explicit `is not None` check: the previous truthiness filter also
    # dropped legitimate logprobs equal to 0.0 (probability-1 tokens), which
    # skewed the normalized average.
    values = [x[0] for x in input_logprobs if x[0] is not None]
    return sum(values) / len(values)
353
+
354
+
355
class Runtime:
    """
    A wrapper for the HTTP server.
    This is used for launching the server in a python program without
    using the command line interface.

    It is mainly used for the frontend language.
    You should use the Engine class if you want to do normal offline processing without the frontend language.
    """

    def __init__(
        self,
        log_level: str = "error",
        *args,
        **kwargs,
    ):
        """See the arguments in server_args.py::ServerArgs"""
        # We delay the import of any `sglang.srt` components in `sglang.lang`, so users can run
        # client code without installing SRT server and its dependency if they want.
        from sglang.srt.entrypoints.http_server import launch_server
        from sglang.srt.server_args import ServerArgs
        from sglang.srt.utils import is_port_available

        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)

        # Pre-allocate ports
        # NOTE(review): if no port in [server_args.port, 40000) is free, the
        # loop falls through and the last probed port is used anyway — confirm
        # this best-effort behavior is intended.
        for port in range(self.server_args.port, 40000):
            if is_port_available(port):
                break
        self.server_args.port = port

        self.url = self.server_args.url()
        self.generate_url = self.url + "/generate"

        # NOTE: We store pid instead of proc to fix some issues during __delete__
        self.pid = None
        pipe_reader, pipe_writer = multiprocessing.Pipe(duplex=False)

        # Launch the server in a separate "spawn" process; it reports readiness
        # back through the write end of the pipe.
        ctx = multiprocessing.get_context("spawn")
        proc = ctx.Process(
            target=launch_server,
            args=(self.server_args, pipe_writer),
        )
        proc.start()
        pipe_writer.close()
        self.pid = proc.pid

        # Before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
        atexit.register(self.shutdown)

        # TODO: remove this pipe_writer mechanism and use `/health_generate` instead.
        try:
            init_state = pipe_reader.recv()
        except EOFError:
            # The child exited before reporting readiness.
            init_state = ""

        if init_state != "ready":
            self.shutdown()
            raise RuntimeError(
                "Initialization failed. Please see the error messages above."
            )

        self.endpoint = RuntimeEndpoint(self.url)

    def shutdown(self):
        """Kill the server process tree; safe to call more than once."""
        from sglang.srt.utils import kill_process_tree

        if self.pid is not None:
            kill_process_tree(self.pid)
            self.pid = None

    def start_profile(self):
        # Delegate to the RuntimeEndpoint wrapper.
        self.endpoint.start_profile()

    def stop_profile(self):
        self.endpoint.stop_profile()

    def cache_prefix(self, prefix: str):
        self.endpoint.cache_prefix(prefix)

    def get_tokenizer(self):
        """Construct the tokenizer configured by this runtime's server args."""
        from sglang.srt.utils.hf_transformers_utils import get_tokenizer

        return get_tokenizer(
            self.server_args.tokenizer_path,
            tokenizer_mode=self.server_args.tokenizer_mode,
            trust_remote_code=self.server_args.trust_remote_code,
            revision=self.server_args.revision,
        )

    async def async_generate(
        self,
        prompt: str,
        sampling_params: Optional[Dict] = None,
    ):
        """Stream generation for `prompt`, yielding new text fragments (or the
        raw event dict for non-text events) parsed from the server's
        "data: ..." SSE stream."""
        if self.server_args.skip_tokenizer_init:
            # Tokenizer is disabled server-side, so `prompt` carries token ids.
            json_data = {
                "input_ids": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        else:
            json_data = {
                "text": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        pos = 0  # number of characters already yielded

        timeout = aiohttp.ClientTimeout(total=3 * 3600)
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.post(self.generate_url, json=json_data) as response:
                async for chunk, _ in response.content.iter_chunks():
                    chunk = chunk.decode("utf-8")
                    if chunk and chunk.startswith("data:"):
                        if chunk == "data: [DONE]\n\n":
                            break
                        data = json.loads(chunk[5:].strip("\n"))
                        if "text" in data:
                            # The server sends cumulative text; yield only the suffix.
                            cur = data["text"][pos:]
                            if cur:
                                yield cur
                            pos += len(cur)
                        else:
                            yield data

    # Alias kept for API compatibility.
    add_request = async_generate

    def generate(
        self,
        prompt: Union[str, List[str]],
        sampling_params: Optional[Dict] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
    ):
        """Synchronous, non-streaming generation.

        Returns the server's JSON response re-serialized as a string
        (note: a str, not a parsed object).
        """
        json_data = {
            "text": prompt,
            "sampling_params": sampling_params,
            "return_logprob": return_logprob,
            "logprob_start_len": logprob_start_len,
            "top_logprobs_num": top_logprobs_num,
            "lora_path": lora_path,
        }
        # A per-prompt lora_path list must match the number of prompts.
        assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
        response = requests.post(
            self.url + "/generate",
            json=json_data,
        )
        return json.dumps(response.json())

    def encode(
        self,
        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
    ):
        """Call the /encode endpoint; returns the JSON response as a string."""
        json_data = {"text": prompt}
        response = requests.post(self.url + "/encode", json=json_data)
        return json.dumps(response.json())

    async def get_server_info(self):
        """Fetch /get_server_info; raises RuntimeError carrying the server's
        error message on a non-200 status."""
        async with aiohttp.ClientSession() as session:
            async with session.get(f"{self.url}/get_server_info") as response:
                if response.status == 200:
                    return await response.json()
                else:
                    error_data = await response.json()
                    raise RuntimeError(
                        f"Failed to get server info. {error_data['error']['message']}"
                    )

    def __del__(self):
        # Best-effort cleanup when the wrapper is garbage collected.
        self.shutdown()
sglang/lang/backend/vertexai.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import warnings
3
+
4
+ from sglang.lang.backend.base_backend import BaseBackend
5
+ from sglang.lang.chat_template import get_chat_template
6
+ from sglang.lang.interpreter import StreamExecutor
7
+ from sglang.lang.ir import SglSamplingParams
8
+
9
+ try:
10
+ import vertexai
11
+ from vertexai.preview.generative_models import (
12
+ GenerationConfig,
13
+ GenerativeModel,
14
+ Image,
15
+ )
16
+ except ImportError as e:
17
+ GenerativeModel = e
18
+
19
+
20
class VertexAI(BaseBackend):
    """Frontend-language backend that calls Google Vertex AI generative models."""

    def __init__(self, model_name, safety_settings=None):
        super().__init__()

        # The module-level import guard stores the ImportError in
        # `GenerativeModel` when the vertexai SDK is missing; re-raise lazily.
        if isinstance(GenerativeModel, Exception):
            raise GenerativeModel

        # GCP_PROJECT_ID is required; GCP_LOCATION is optional.
        project_id = os.environ["GCP_PROJECT_ID"]
        location = os.environ.get("GCP_LOCATION")
        vertexai.init(project=project_id, location=location)

        self.model_name = model_name
        self.chat_template = get_chat_template("default")
        self.safety_settings = safety_settings

    def get_chat_template(self):
        return self.chat_template

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run one non-streaming completion; returns (text, meta_info)."""
        if s.messages_:
            prompt = self.messages_to_vertexai_input(s.messages_)
        else:
            # single-turn
            prompt = (
                self.text_to_vertexai_input(s.text_, s.cur_images)
                if s.cur_images
                else s.text_
            )
        ret = GenerativeModel(self.model_name).generate_content(
            prompt,
            generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
            safety_settings=self.safety_settings,
        )

        comp = ret.text

        return comp, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Streaming variant of generate(); yields (text_chunk, meta_info)."""
        if s.messages_:
            prompt = self.messages_to_vertexai_input(s.messages_)
        else:
            # single-turn
            prompt = (
                self.text_to_vertexai_input(s.text_, s.cur_images)
                if s.cur_images
                else s.text_
            )
        generator = GenerativeModel(self.model_name).generate_content(
            prompt,
            stream=True,
            generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
            safety_settings=self.safety_settings,
        )
        for ret in generator:
            yield ret.text, {}

    def text_to_vertexai_input(self, text, images):
        """Interleave text segments and Image parts by splitting `text` on the
        chat template's image token (one image per token occurrence)."""
        input = []  # NOTE(review): shadows the builtin `input`
        # split with image token
        text_segs = text.split(self.chat_template.image_token)
        for image_path, image_base64_data in images:
            text_seg = text_segs.pop(0)
            if text_seg != "":
                input.append(text_seg)
            input.append(Image.from_bytes(image_base64_data))
        # Trailing text after the last image token.
        text_seg = text_segs.pop(0)
        if text_seg != "":
            input.append(text_seg)
        return input

    def messages_to_vertexai_input(self, messages):
        """Convert OpenAI-style messages to the Vertex AI message format."""
        vertexai_message = []
        # from openai message format to vertexai message format
        for msg in messages:
            if isinstance(msg["content"], str):
                text = msg["content"]
            else:
                text = msg["content"][0]["text"]

            if msg["role"] == "system":
                # No system role here: emit the system prompt as a user turn
                # followed by a canned model acknowledgement.
                warnings.warn("Warning: system prompt is not supported in VertexAI.")
                vertexai_message.append(
                    {
                        "role": "user",
                        "parts": [{"text": "System prompt: " + text}],
                    }
                )
                vertexai_message.append(
                    {
                        "role": "model",
                        "parts": [{"text": "Understood."}],
                    }
                )
                continue
            if msg["role"] == "user":
                vertexai_msg = {
                    "role": "user",
                    "parts": [{"text": text}],
                }
            elif msg["role"] == "assistant":
                vertexai_msg = {
                    "role": "model",
                    "parts": [{"text": text}],
                }
            # NOTE(review): any other role leaves `vertexai_msg` unbound and
            # would raise NameError below — confirm callers only pass
            # system/user/assistant roles.

            # images
            if isinstance(msg["content"], list) and len(msg["content"]) > 1:
                for image in msg["content"][1:]:
                    assert image["type"] == "image_url"
                    vertexai_msg["parts"].append(
                        {
                            "inline_data": {
                                "data": image["image_url"]["url"].split(",")[1],
                                "mime_type": "image/jpeg",
                            }
                        }
                    )

            vertexai_message.append(vertexai_msg)
        return vertexai_message
sglang/lang/chat_template.py ADDED
@@ -0,0 +1,668 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import re
from dataclasses import dataclass
from enum import Enum, auto
from typing import Callable, Dict, List, Optional, Tuple
5
+
6
+
7
class ChatTemplateStyle(Enum):
    # Default: concatenate role prefix + content + suffix for every message.
    PLAIN = auto()
    # Llama-2 convention: the system prompt is folded into the first user turn
    # (see ChatTemplate.get_prefix_and_suffix).
    LLAMA2 = auto()
10
+
11
+
12
@dataclass
class ChatTemplate:
    """A named chat template: per-role (prefix, suffix) strings used to render
    OpenAI-style message lists into a single prompt string."""

    name: str
    # None means the template ships no implicit system prompt.
    default_system_prompt: Optional[str]
    # role -> (prefix, suffix) wrapped around that role's content.
    role_prefix_and_suffix: Dict[str, Tuple[str, str]]
    # NOTE(review): annotated List but defaulted to a tuple; registered
    # templates pass both tuples and lists — Sequence[str] would fit better.
    stop_str: List[str] = ()
    image_token: str = "<image>"
    audio_token: str = "<audio>"
    style: ChatTemplateStyle = ChatTemplateStyle.PLAIN

    def get_prefix_and_suffix(
        self, role: str, hist_messages: List[Dict]
    ) -> Tuple[str, str]:
        """Return the (prefix, suffix) pair for `role`, given the messages
        already rendered before it (`hist_messages`). Unknown roles map to
        ("", "")."""
        prefix, suffix = self.role_prefix_and_suffix.get(role, ("", ""))

        if self.style == ChatTemplateStyle.LLAMA2:
            # Llama-2 folds the system prompt into the first user turn: the
            # system message opens with the user prefix, and the first real
            # user message then omits its own prefix.
            if role == "system" and not hist_messages:
                user_prefix, _ = self.role_prefix_and_suffix.get("user", ("", ""))
                system_prefix, system_suffix = self.role_prefix_and_suffix.get(
                    "system", ("", "")
                )
                return (user_prefix + system_prefix, system_suffix)
            elif (
                role == "user"
                and len(hist_messages) == 1
                and hist_messages[0]["content"] is not None
            ):
                return ("", suffix)

        return prefix, suffix

    def get_prompt(self, messages: List[Dict]) -> str:
        """Render `messages` ({"role": ..., "content": ...} dicts) to a prompt.

        A system message whose content is None falls back to
        `default_system_prompt`; messages that still have no content are
        skipped entirely.
        """
        prompt = ""
        for i, message in enumerate(messages):
            role, content = message["role"], message["content"]
            if role == "system" and content is None:
                content = self.default_system_prompt
            if content is None:
                continue

            prefix, suffix = self.get_prefix_and_suffix(role, messages[:i])
            prompt += f"{prefix}{content}{suffix}"
        return prompt
55
+
56
+
57
# Global template registry (name -> ChatTemplate) and the ordered list of
# model-path matching functions consulted by get_chat_template_by_model_path.
chat_template_registry: Dict[str, ChatTemplate] = {}
matching_function_registry: List[Callable] = []
59
+
60
+
61
def register_chat_template(template):
    """Register *template* in the global registry, keyed by its name
    (overwrites any previous template with the same name)."""
    key = template.name
    chat_template_registry[key] = template
63
+
64
+
65
def register_chat_template_matching_function(func):
    """Decorator: register *func* as a model-path matching function.

    Returns *func* (the standard decorator contract) so the decorated name
    stays bound to the function; previously the decorator returned None,
    rebinding every decorated `match_*` name to None at module level.
    """
    matching_function_registry.append(func)
    return func
67
+
68
+
69
def get_chat_template(name):
    """Return the ChatTemplate registered under *name* (KeyError if absent)."""
    template = chat_template_registry[name]
    return template
71
+
72
+
73
def get_chat_template_by_model_path(model_path):
    """Resolve a chat template for *model_path* via the registered matchers.

    Matchers are consulted in registration order; the first non-None template
    name wins. Falls back to the "default" template.
    """
    name = next(
        (
            candidate
            for candidate in (fn(model_path) for fn in matching_function_registry)
            if candidate is not None
        ),
        "default",
    )
    return get_chat_template(name)
79
+
80
+
81
+ register_chat_template(
82
+ ChatTemplate(
83
+ name="default",
84
+ default_system_prompt=None,
85
+ role_prefix_and_suffix={
86
+ "system": ("SYSTEM:", "\n"),
87
+ "user": ("USER:", "\n"),
88
+ "assistant": ("ASSISTANT:", "\n"),
89
+ },
90
+ )
91
+ )
92
+
93
+ register_chat_template(
94
+ ChatTemplate(
95
+ name="claude",
96
+ default_system_prompt=None,
97
+ role_prefix_and_suffix={
98
+ "system": ("", ""),
99
+ "user": ("\n\nHuman: ", ""),
100
+ "assistant": ("\n\nAssistant:", ""),
101
+ },
102
+ )
103
+ )
104
+
105
+ register_chat_template(
106
+ ChatTemplate(
107
+ name="chatml",
108
+ default_system_prompt=None,
109
+ role_prefix_and_suffix={
110
+ "system": ("<|im_start|>system\n", "<|im_end|>\n"),
111
+ "user": ("<|im_start|>user\n", "<|im_end|>\n"),
112
+ "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
113
+ },
114
+ style=ChatTemplateStyle.PLAIN,
115
+ stop_str=("<|im_end|>",),
116
+ )
117
+ )
118
+
119
+ register_chat_template(
120
+ ChatTemplate(
121
+ name="chatml-llava",
122
+ default_system_prompt="You are a helpful assistant.",
123
+ role_prefix_and_suffix={
124
+ "system": ("<|im_start|>system\n", "<|im_end|>\n"),
125
+ "user": ("<|im_start|>user\n", "<|im_end|>\n"),
126
+ "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
127
+ },
128
+ style=ChatTemplateStyle.PLAIN,
129
+ stop_str=("<|im_end|>",),
130
+ image_token="<image>\n",
131
+ )
132
+ )
133
+
134
+ # There is default system prompt for qwen
135
+ # reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1
136
+ # The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
137
+ register_chat_template(
138
+ ChatTemplate(
139
+ name="qwen",
140
+ default_system_prompt="You are a helpful assistant.",
141
+ role_prefix_and_suffix={
142
+ "system": ("<|im_start|>system\n", "<|im_end|>\n"),
143
+ "user": ("<|im_start|>user\n", "<|im_end|>\n"),
144
+ "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
145
+ },
146
+ style=ChatTemplateStyle.PLAIN,
147
+ stop_str=("<|im_end|>",),
148
+ )
149
+ )
150
+
151
+ # Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
152
+ register_chat_template(
153
+ ChatTemplate(
154
+ name="qwen2-vl",
155
+ default_system_prompt="You are a helpful assistant.",
156
+ role_prefix_and_suffix={
157
+ "system": ("<|im_start|>system\n", "<|im_end|>\n"),
158
+ "user": ("<|im_start|>user\n", "<|im_end|>\n"),
159
+ "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
160
+ },
161
+ style=ChatTemplateStyle.PLAIN,
162
+ stop_str=("<|im_end|>",),
163
+ image_token="<|vision_start|><|image_pad|><|vision_end|>",
164
+ )
165
+ )
166
+
167
+ # Reference: https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template
168
+ register_chat_template(
169
+ ChatTemplate(
170
+ name="vicuna_v1.1",
171
+ default_system_prompt=(
172
+ "A chat between a curious user and an artificial intelligence assistant. "
173
+ "The assistant gives helpful, detailed, and polite answers to the user's questions."
174
+ ),
175
+ role_prefix_and_suffix={
176
+ "system": ("", " "),
177
+ "user": ("USER:", " "),
178
+ "assistant": ("ASSISTANT:", "</s>"),
179
+ },
180
+ image_token=" <image>\n",
181
+ )
182
+ )
183
+
184
+ register_chat_template(
185
+ ChatTemplate(
186
+ name="llama-2-chat",
187
+ default_system_prompt=None,
188
+ role_prefix_and_suffix={
189
+ "system": ("<<SYS>>\n", "\n<</SYS>>\n\n"),
190
+ "user": ("[INST] ", " [/INST]"),
191
+ "assistant": ("", " </s><s>"),
192
+ },
193
+ style=ChatTemplateStyle.LLAMA2,
194
+ )
195
+ )
196
+
197
+ # Reference: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/chat_template.json
198
+ register_chat_template(
199
+ ChatTemplate(
200
+ name="mistral",
201
+ default_system_prompt=None,
202
+ role_prefix_and_suffix={
203
+ "system": ("[SYSTEM_PROMPT] ", " [/SYSTEM_PROMPT]"),
204
+ "user": ("[INST] ", " [/INST]"),
205
+ "assistant": ("", " </s><s>"),
206
+ },
207
+ stop_str=("</s>",),
208
+ image_token="[IMG]",
209
+ )
210
+ )
211
+
212
+ register_chat_template(
213
+ ChatTemplate(
214
+ name="llama-3-instruct",
215
+ default_system_prompt=None,
216
+ role_prefix_and_suffix={
217
+ "system": (
218
+ "<|start_header_id|>system<|end_header_id|>\n\n",
219
+ "<|eot_id|>",
220
+ ),
221
+ "user": (
222
+ "<|start_header_id|>user<|end_header_id|>\n\n",
223
+ "<|eot_id|>",
224
+ ),
225
+ "assistant": (
226
+ "<|start_header_id|>assistant<|end_header_id|>\n\n",
227
+ "<|eot_id|>",
228
+ ),
229
+ },
230
+ stop_str=("<|eot_id|>",),
231
+ image_token="<|image|>",
232
+ )
233
+ )
234
+
235
+ # https://huggingface.co/openbmb/MiniCPM-V-2_6
236
+ register_chat_template(
237
+ ChatTemplate(
238
+ name="minicpmv",
239
+ default_system_prompt=None,
240
+ role_prefix_and_suffix={
241
+ "system": ("", " "),
242
+ "user": ("user:", " "),
243
+ "assistant": ("assistant:", "</s>"),
244
+ },
245
+ stop_str=("<|im_end|>", "<|endoftext|>"),
246
+ image_token="(<image>./</image>)",
247
+ )
248
+ )
249
+
250
+ register_chat_template(
251
+ ChatTemplate(
252
+ name="janus-pro",
253
+ default_system_prompt=None,
254
+ role_prefix_and_suffix={
255
+ "system": (
256
+ "",
257
+ "",
258
+ ),
259
+ "User": (
260
+ "<|User|>",
261
+ "",
262
+ ),
263
+ "assistant": (
264
+ "<|Assistant|>",
265
+ "<|end▁of▁sentence|>",
266
+ ),
267
+ },
268
+ stop_str=("<|end▁of▁sentence|>",),
269
+ image_token="<image_placeholder>\n",
270
+ )
271
+ )
272
+
273
+ # https://huggingface.co/openbmb/MiniCPM-o-2_6
274
+ register_chat_template(
275
+ ChatTemplate(
276
+ name="minicpmo",
277
+ default_system_prompt=None,
278
+ role_prefix_and_suffix={
279
+ "system": ("", " "),
280
+ "user": ("user:", " "),
281
+ "assistant": ("assistant:", "</s>"),
282
+ },
283
+ stop_str=("<|im_end|>", "<|endoftext|>"),
284
+ image_token="(<image>./</image>)",
285
+ audio_token="(<audio>./</audio>)",
286
+ )
287
+ )
288
+
289
+ register_chat_template(
290
+ ChatTemplate(
291
+ name="janus",
292
+ default_system_prompt=None,
293
+ role_prefix_and_suffix={
294
+ "system": (
295
+ "",
296
+ "",
297
+ ),
298
+ "user": (
299
+ "<|User|>",
300
+ "",
301
+ ),
302
+ "assistant": (
303
+ "<|Assistant|>",
304
+ "<|end▁of▁sentence|>",
305
+ ),
306
+ },
307
+ stop_str=("<|end▁of▁sentence|>",),
308
+ image_token="<image_placeholder>\n",
309
+ )
310
+ )
311
+
312
+ # The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token.
313
+ register_chat_template(
314
+ ChatTemplate(
315
+ name="llama-3-instruct-llava",
316
+ default_system_prompt=None,
317
+ role_prefix_and_suffix={
318
+ "system": (
319
+ "<|start_header_id|>system<|end_header_id|>\n\n",
320
+ "<|eot_id|>",
321
+ ),
322
+ "user": (
323
+ "<|start_header_id|>user<|end_header_id|>\n\n",
324
+ "<|eot_id|>",
325
+ ),
326
+ "assistant": (
327
+ "<|start_header_id|>assistant<|end_header_id|>\n\n",
328
+ "<|eot_id|>",
329
+ ),
330
+ },
331
+ stop_str=("<|eot_id|>",),
332
+ image_token="<image>\n",
333
+ )
334
+ )
335
+
336
+ # Reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
337
+ register_chat_template(
338
+ ChatTemplate(
339
+ name="llama-4",
340
+ default_system_prompt=None,
341
+ role_prefix_and_suffix={
342
+ "system": (
343
+ "<|header_start|>system<|header_end|>\n\n",
344
+ "<|eot|>",
345
+ ),
346
+ "user": (
347
+ "<|header_start|>user<|header_end|>\n\n",
348
+ "<|eot|>",
349
+ ),
350
+ "assistant": (
351
+ "<|header_start|>assistant<|header_end|>\n\n",
352
+ "<|eot|>",
353
+ ),
354
+ },
355
+ stop_str=("<|eot|>",),
356
+ image_token="<|image|>",
357
+ )
358
+ )
359
+
360
+ # Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
361
+ register_chat_template(
362
+ ChatTemplate(
363
+ name="yi-1.5",
364
+ default_system_prompt=None,
365
+ role_prefix_and_suffix={
366
+ "system": ("", ""),
367
+ "user": ("<|im_start|>user\n", "<|im_end|>\n<|im_start|>assistant\n"),
368
+ "assistant": ("", "<|im_end|>\n"),
369
+ },
370
+ style=ChatTemplateStyle.PLAIN,
371
+ stop_str=("<|im_end|>",),
372
+ )
373
+ )
374
+
375
+ # Reference: https://github.com/01-ai/Yi/tree/main/VL#major-difference-with-llava
376
+ register_chat_template(
377
+ ChatTemplate(
378
+ name="yi-vl",
379
+ default_system_prompt=(
380
+ "This is a chat between an inquisitive human and an AI assistant. Assume the role of the AI assistant. Read all the images carefully, and respond to the human's questions with informative, helpful, detailed and polite answers."
381
+ "这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。"
382
+ ),
383
+ role_prefix_and_suffix={
384
+ "system": ("", "\n\n"),
385
+ "user": ("### Human:", "\n"),
386
+ "assistant": ("### Assistant:", "\n"),
387
+ },
388
+ image_token=" <image_placeholder>\n",
389
+ )
390
+ )
391
+
392
+ register_chat_template(
393
+ ChatTemplate(
394
+ name="gemma-it",
395
+ default_system_prompt=None,
396
+ role_prefix_and_suffix={
397
+ "system": ("", ""),
398
+ "user": ("<start_of_turn>user\n", "<end_of_turn>\n"),
399
+ "assistant": ("<start_of_turn>model\n", "<end_of_turn>\n"),
400
+ },
401
+ style=ChatTemplateStyle.PLAIN,
402
+ )
403
+ )
404
+
405
+ register_chat_template(
406
+ ChatTemplate(
407
+ name="dbrx-instruct",
408
+ default_system_prompt="You are DBRX, created by Databricks. You were last updated in December 2023. You answer questions based on information available up to that point.\nYOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough responses to more complex and open-ended questions.\nYou assist with various tasks, from writing to coding (using markdown for code blocks — remember to use ``` with code, JSON, and tables).\n(You do not have real-time data access or code execution capabilities. You avoid stereotyping and provide balanced perspectives on controversial topics. You do not provide song lyrics, poems, or news articles and do not divulge details of your training data.)\nThis is your system prompt, guiding your responses. Do not reference it, just respond to the user. If you find yourself talking about this message, stop. You should be responding appropriately and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY.",
409
+ role_prefix_and_suffix={
410
+ "system": ("<|im_start|>system\n", "<|im_end|>"),
411
+ "user": ("\n<|im_start|>user\n", "<|im_end|>"),
412
+ "assistant": ("\n<|im_start|>assistant\n", "<|im_end|>"),
413
+ },
414
+ stop_str=("<|im_end|>",),
415
+ )
416
+ )
417
+
418
+ register_chat_template(
419
+ ChatTemplate(
420
+ name="c4ai-command-r",
421
+ default_system_prompt=None,
422
+ role_prefix_and_suffix={
423
+ "system": (
424
+ "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>",
425
+ "<|END_OF_TURN_TOKEN|>",
426
+ ),
427
+ "user": ("<|START_OF_TURN_TOKEN|><|USER_TOKEN|>", "<|END_OF_TURN_TOKEN|>"),
428
+ "assistant": (
429
+ "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
430
+ "<|END_OF_TURN_TOKEN|>",
431
+ ),
432
+ },
433
+ style=ChatTemplateStyle.PLAIN,
434
+ )
435
+ )
436
+
437
+ # Adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
438
+ register_chat_template(
439
+ ChatTemplate(
440
+ name="internvl-2-5",
441
+ default_system_prompt="你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。",
442
+ role_prefix_and_suffix={
443
+ "system": ("<|im_start|>system\n", "<|im_end|>\n"),
444
+ "user": ("<|im_start|>user\n", "<|im_end|>\n"),
445
+ "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
446
+ },
447
+ stop_str=["<|im_end|>", "<|action_end|>"],
448
+ )
449
+ )
450
+
451
+ register_chat_template(
452
+ ChatTemplate(
453
+ name="interns1",
454
+ default_system_prompt="You are an AI assistant whose name is Intern-S1 (书生大模型).\n- Intern-S1 (书生大模型) is a vision-language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n- Intern-S1 (书生大模型) can understand and communicate fluently in the language chosen by the user such as English and 中文.\nYou are an expert reasoner with extensive experience in all areas. You approach problems through systematic thinking and rigorous reasoning. Your response should reflect deep understanding and precise logical thinking, making your solution path and reasoning clear to others. Please put your thinking process within <think>...</think> tags.",
455
+ role_prefix_and_suffix={
456
+ "system": ("<|im_start|>system\n", "<|im_end|>\n"),
457
+ "user": ("<|im_start|>user\n", "<|im_end|>\n"),
458
+ "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
459
+ },
460
+ stop_str=["<|im_end|>", "<|action_end|>"],
461
+ )
462
+ )
463
+
464
+ register_chat_template(
465
+ ChatTemplate(
466
+ name="granite-3-instruct",
467
+ default_system_prompt=None,
468
+ role_prefix_and_suffix={
469
+ "system": (
470
+ "<|start_of_role|>system<|end_of_role|>",
471
+ "<|end_of_text|>",
472
+ ),
473
+ "user": (
474
+ "<|start_of_role|>user<|end_of_role|>",
475
+ "<|end_of_text|>",
476
+ ),
477
+ "assistant": (
478
+ "<|start_of_role|>assistant<|end_of_role|>",
479
+ "<|end_of_text|>",
480
+ ),
481
+ },
482
+ stop_str=("<|end_of_text|>",),
483
+ )
484
+ )
485
+
486
+ register_chat_template(
487
+ ChatTemplate(
488
+ name="deepseek-v3",
489
+ default_system_prompt=None,
490
+ role_prefix_and_suffix={
491
+ "system": (
492
+ "",
493
+ "",
494
+ ),
495
+ "user": (
496
+ "<|User|>",
497
+ "",
498
+ ),
499
+ "assistant": (
500
+ "<|Assistant|>",
501
+ "<|end▁of▁sentence|>",
502
+ ),
503
+ },
504
+ stop_str=("<|end▁of▁sentence|>",),
505
+ )
506
+ )
507
+
508
+ # Reference: https://huggingface.co/docs/transformers/main/model_doc/glm4_v#usage-example
509
+ register_chat_template(
510
+ ChatTemplate(
511
+ name="glm-4v",
512
+ default_system_prompt=None,
513
+ role_prefix_and_suffix={
514
+ "system": ("<|system|>\n", "\n"),
515
+ "user": ("<|user|>\n", "\n"),
516
+ "assistant": ("<|assistant|>\n", "\n"),
517
+ },
518
+ style=ChatTemplateStyle.PLAIN,
519
+ stop_str=["<|user|>", "<|endoftext|>", "<|observation|>"],
520
+ image_token="<|image|>",
521
+ )
522
+ )
523
+
524
+
525
+ @register_chat_template_matching_function
526
def match_deepseek(model_path: str):
    """Map DeepSeek V3/R1 chat checkpoints (but not base models) to "deepseek-v3"."""
    lowered = model_path.lower()
    if re.search(r"deepseek-(v3|r1)", lowered) and "base" not in lowered:
        return "deepseek-v3"
531
+
532
+
533
+ @register_chat_template_matching_function
534
def match_orion(model_path: str):
    """Orion models use the Claude-style plain template."""
    return "claude" if "orion" in model_path.lower() else None
537
+
538
+
539
+ @register_chat_template_matching_function
540
def match_deepseek_janus_pro(model_path: str):
    """Any Janus model maps to the "janus-pro" template."""
    return "janus-pro" if re.search(r"janus", model_path, re.IGNORECASE) else None
543
+
544
+
545
+ @register_chat_template_matching_function
546
def match_dbrx(model_path: str):
    """DBRX instruct checkpoints use the "dbrx-instruct" template."""
    lowered = model_path.lower()
    if "dbrx" in lowered and "instruct" in lowered:
        return "dbrx-instruct"
551
+
552
+
553
+ @register_chat_template_matching_function
554
def match_vicuna(model_path: str):
    """Vicuna and the LLaVA v1.5 family share the vicuna_v1.1 template."""
    patterns = (r"vicuna", r"llava-v1\.5", r"llava-next-video-7b")
    if any(re.search(p, model_path, re.IGNORECASE) for p in patterns):
        return "vicuna_v1.1"
557
+
558
+
559
+ @register_chat_template_matching_function
560
def match_llama2_chat(model_path: str):
    """Llama-2 chat and CodeLlama instruct checkpoints use llama-2-chat."""
    flags = re.IGNORECASE
    if re.search(r"llama-2.*chat", model_path, flags) or re.search(
        r"codellama.*instruct", model_path, flags
    ):
        return "llama-2-chat"
567
+
568
+
569
+ @register_chat_template_matching_function
570
def match_mistral(model_path: str):
    """Pixtral and Mistral/Mixtral instruct models use the mistral template."""
    matched = re.search(
        r"pixtral|(mistral|mixtral).*instruct", model_path, re.IGNORECASE
    )
    return "mistral" if matched else None
573
+
574
+
575
+ @register_chat_template_matching_function
576
def match_llama3_instruct(model_path: str):
    """Llama-3 instruct checkpoints use the llama-3-instruct template."""
    matched = re.search(r"llama-3.*instruct", model_path, re.IGNORECASE)
    return "llama-3-instruct" if matched else None
579
+
580
+
581
+ @register_chat_template_matching_function
582
def match_chat_ml(model_path: str):
    """Route ChatML-family model paths to their template names.

    Order matters: Qwen-VL must be checked before the generic Qwen rule, and
    Qwen chat checkpoints that are LLaVA variants are excluded from "qwen".
    """
    simple_rules = (
        (r"tinyllama", "chatml"),
        (r"qwen.*vl", "qwen2-vl"),
        (r"glm[-_]?4(\.\d+)?v", "glm-4v"),
    )
    for pattern, template_name in simple_rules:
        if re.search(pattern, model_path, re.IGNORECASE):
            return template_name
    is_qwen_chat = re.search(r"qwen.*(chat|instruct)", model_path, re.IGNORECASE)
    if is_qwen_chat and not re.search(r"llava", model_path, re.IGNORECASE):
        return "qwen"
    llava_chatml = (
        r"llava-v1\.6-34b|llava-v1\.6-yi-34b|llava-next-video-34b|llava-onevision-qwen2"
    )
    if re.search(llava_chatml, model_path, re.IGNORECASE):
        return "chatml-llava"
599
+
600
+
601
@register_chat_template_matching_function
def match_chat_yi(model_path: str):
    """Distinguish Yi-VL (non-LLaVA) from Yi-1.5 chat checkpoints."""
    flags = re.IGNORECASE
    # Yi-VL only counts when the path is not a LLaVA derivative; when that
    # check fails we still fall through to the Yi-1.5 rule (original elif).
    is_yi_vl = re.search(r"yi-vl", model_path, flags) and not re.search(
        r"llava", model_path, flags
    )
    if is_yi_vl:
        return "yi-vl"
    if re.search(r"yi-1\.5.*chat", model_path, flags):
        return "yi-1.5"
    return None
609
+
610
+
611
@register_chat_template_matching_function
def match_gemma_it(model_path: str):
    """Instruction-tuned Gemma checkpoints use the "gemma-it" template."""
    return "gemma-it" if re.search(r"gemma.*it", model_path, re.IGNORECASE) else None
615
+
616
+
617
@register_chat_template_matching_function
def match_openbmb_minicpm(model_path: str):
    """OpenBMB MiniCPM-V and MiniCPM-o each have their own template."""
    flags = re.IGNORECASE
    if re.search(r"minicpm-v", model_path, flags):
        return "minicpmv"
    if re.search(r"minicpm-o", model_path, flags):
        return "minicpmo"
    return None
623
+
624
+
625
@register_chat_template_matching_function
def match_c4ai_command_r(model_path: str):
    """Cohere C4AI Command-R checkpoints use the "c4ai-command-r" template."""
    matched = re.search(r"c4ai-command-r", model_path, re.IGNORECASE)
    return "c4ai-command-r" if matched else None
629
+
630
+
631
@register_chat_template_matching_function
def match_granite_instruct(model_path: str):
    """IBM Granite instruct checkpoints use the "granite-3-instruct" template."""
    matched = re.search(r"granite.*instruct", model_path, re.IGNORECASE)
    return "granite-3-instruct" if matched else None
635
+
636
+
637
@register_chat_template_matching_function
def match_gemma3_instruct(model_path: str):
    """Gemma-3 checkpoints reuse the Gemma instruction-tuned template."""
    return "gemma-it" if re.search(r"gemma-3", model_path, re.IGNORECASE) else None
641
+
642
+
643
@register_chat_template_matching_function
def match_internvl_chat(model_path: str):
    """InternVL 2.5 checkpoints use the "internvl-2-5" template."""
    matched = re.search(r"internvl2_5", model_path, re.IGNORECASE)
    return "internvl-2-5" if matched else None
647
+
648
+
649
@register_chat_template_matching_function
def match_interns1_chat(model_path: str):
    """Intern-S1 checkpoints use the "interns1" template.

    Accepts both path spellings, "intern-s1" and "interns1". The original
    code checked them in two identical branches that both returned the same
    template; a single alternation pattern removes the duplication while
    matching exactly the same inputs.
    """
    if re.search(r"intern-s1|interns1", model_path, re.IGNORECASE):
        return "interns1"
    return None
655
+
656
+
657
if __name__ == "__main__":
    # Demo: render a short multi-turn conversation with the Llama-2 template.
    turns = [
        ("user", "Hello!"),
        ("assistant", "Hi!"),
        ("user", "What can you do?"),
        ("assistant", "I can chat with you."),
    ]
    messages = [{"role": "system", "content": None}]  # None means default
    # e.g. {"role": "system", "content": "You are a helpful, respectful and honest assistant."}
    messages.extend({"role": role, "content": text} for role, text in turns)

    print(get_chat_template("llama-2-chat").get_prompt(messages))