Lekr0 commited on
Commit
5513247
·
verified ·
1 Parent(s): d02d576

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. sglang/3rdparty/amd/profiling/PROFILING.md +425 -0
  2. sglang/3rdparty/amd/profiling/client.sh +27 -0
  3. sglang/3rdparty/amd/profiling/install_rpd.sh +10 -0
  4. sglang/3rdparty/amd/profiling/loadTracer.sh +43 -0
  5. sglang/3rdparty/amd/profiling/rpd.patch +12 -0
  6. sglang/3rdparty/amd/profiling/rpd_profile_server_enable.patch +49 -0
  7. sglang/3rdparty/amd/profiling/rpd_profile_server_enable_wCPU_activities.patch +126 -0
  8. sglang/3rdparty/amd/profiling/server.sh +20 -0
  9. sglang/3rdparty/amd/profiling/torch_profiler.patch +25 -0
  10. sglang/3rdparty/amd/sgl-kernel/CMakeLists_rocm.txt +159 -0
  11. sglang/3rdparty/amd/sgl-kernel/build_rocm.sh +123 -0
  12. sglang/3rdparty/amd/sgl-kernel/rename_wheels_rocm.sh +30 -0
  13. sglang/3rdparty/amd/sgl-kernel/rocm_hipify.py +40 -0
  14. sglang/3rdparty/amd/tuning/TUNING.md +118 -0
  15. sglang/3rdparty/amd/tuning/benchmark_moe_rocm.py +378 -0
  16. sglang/docs/supported_models/extending/modelscope.md +28 -0
  17. sglang/docs/supported_models/extending/support_new_models.md +320 -0
  18. sglang/docs/supported_models/retrieval_ranking/classify_models.md +162 -0
  19. sglang/docs/supported_models/retrieval_ranking/embedding_models.md +126 -0
  20. sglang/docs/supported_models/retrieval_ranking/rerank_models.md +313 -0
  21. sglang/docs/supported_models/specialized/index.rst +9 -0
  22. sglang/docs/supported_models/specialized/reward_models.md +28 -0
  23. sglang/docs/supported_models/text_generation/diffusion_language_models.md +111 -0
  24. sglang/docs/supported_models/text_generation/generative_models.md +72 -0
  25. sglang/docs/supported_models/text_generation/index.rst +11 -0
  26. sglang/docs/supported_models/text_generation/multimodal_language_models.md +136 -0
  27. sglang/python/sglang/srt/__pycache__/constants.cpython-311.pyc +0 -0
  28. sglang/python/sglang/srt/__pycache__/environ.cpython-311.pyc +0 -0
  29. sglang/python/sglang/srt/batch_overlap/__pycache__/operations.cpython-311.pyc +0 -0
  30. sglang/python/sglang/srt/batch_overlap/__pycache__/operations_strategy.cpython-311.pyc +0 -0
  31. sglang/python/sglang/srt/batch_overlap/__pycache__/single_batch_overlap.cpython-311.pyc +0 -0
  32. sglang/python/sglang/srt/batch_overlap/__pycache__/two_batch_overlap.cpython-311.pyc +0 -0
  33. sglang/python/sglang/srt/batch_overlap/operations.py +213 -0
  34. sglang/python/sglang/srt/batch_overlap/operations_strategy.py +302 -0
  35. sglang/python/sglang/srt/batch_overlap/single_batch_overlap.py +144 -0
  36. sglang/python/sglang/srt/batch_overlap/two_batch_overlap.py +1082 -0
  37. sglang/python/sglang/srt/checkpoint_engine/__init__.py +9 -0
  38. sglang/python/sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +143 -0
  39. sglang/python/sglang/srt/checkpoint_engine/update.py +317 -0
  40. sglang/python/sglang/srt/compilation/__pycache__/compilation_config.cpython-311.pyc +0 -0
  41. sglang/python/sglang/srt/compilation/__pycache__/compile.cpython-311.pyc +0 -0
  42. sglang/python/sglang/srt/compilation/__pycache__/piecewise_context_manager.cpython-311.pyc +0 -0
  43. sglang/python/sglang/srt/compilation/backend.py +472 -0
  44. sglang/python/sglang/srt/compilation/compilation_config.py +45 -0
  45. sglang/python/sglang/srt/compilation/compilation_counter.py +47 -0
  46. sglang/python/sglang/srt/compilation/compile.py +203 -0
  47. sglang/python/sglang/srt/compilation/compiler_interface.py +504 -0
  48. sglang/python/sglang/srt/compilation/cuda_piecewise_backend.py +206 -0
  49. sglang/python/sglang/srt/compilation/fix_functionalization.py +134 -0
  50. sglang/python/sglang/srt/compilation/fx_utils.py +83 -0
sglang/3rdparty/amd/profiling/PROFILING.md ADDED
@@ -0,0 +1,425 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Profiling SGLang Infer System with AMD GPUs
2
+ This AppNote describes the SGLang profiling techniques, code augmentation, and running steps for systems with AMD Instinct GPUs; the same procedure may also work with Nvidia GPUs.
3
+ Examples and steps are provided in detail so that the procedure is easy to reproduce and can be used to localize performance problems for optimization.
4
+ Two primary methods are covered:
5
+ - [RPD](https://github.com/ROCm/rocmProfileData.git)
6
+ - [PyTorch Profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html)
7
+
8
+ ### Profiling SGLang Infer System with RPD Profiler
9
+ The RPD profiler is a low-overhead, cross-platform profiler. Therefore, the same RPD code augmentation works for profiling on ROCm/AMD GPUs as well as on CUDA/Nvidia GPUs. To do RPD profiling on the SGLang repository, please use the scripts and patch files included in this directory and follow the steps below:
10
+ 1. Install RPD using install_rpd.sh, which applies rpd.patch during installation; both files are in this directory.
11
+
12
+ install_rpd.sh
13
+
14
+ ```bash
15
+ # download and install RPD
16
+ apt update && apt install -y sqlite3 libsqlite3-dev libfmt-dev
17
+
18
+ # install rpd module
19
+ git clone https://github.com/ROCmSoftwarePlatform/rocmProfileData
20
+ cd rocmProfileData
21
+ git checkout 976899e9c6dbc6dd2bccf770818e4e44125590ac
22
+ git apply rpd.patch
23
+ make && make install
24
+ cd rocpd_python && python setup.py install && cd ..
25
+ cd rpd_tracer && make clean;make install && python setup.py install && cd ..
26
+ ```
27
+
28
+ rpd.patch
29
+
30
+ ```bash
31
+ diff --git a/rpd_tracer/Makefile b/rpd_tracer/Makefile
32
+ index e9d9feb..b2e9e1a 100644
33
+ --- a/rpd_tracer/Makefile
34
+ +++ b/rpd_tracer/Makefile
35
+ @@ -16,7 +16,7 @@ ifneq (,$(HIP_PATH))
36
+ $(info Building with roctracer)
37
+ RPD_LIBS += -L/opt/rocm/lib -lroctracer64 -lroctx64 -lamdhip64 -lrocm_smi64
38
+ RPD_INCLUDES += -I/opt/rocm/include -I/opt/rocm/include/roctracer -I/opt/rocm/include/hsa
39
+ - RPD_SRCS += RoctracerDataSource.cpp RocmSmiDataSource.cpp
40
+ + RPD_SRCS += RoctracerDataSource.cpp
41
+ RPD_INCLUDES += -D__HIP_PLATFORM_AMD__
42
+ endif
43
+ ```
44
+ 2. Add loadTracer.sh file included in this directory to /sglang/python/sglang.
45
+
46
+ loadTracer.sh
47
+
48
+ ```bash
49
+ #!/bin/bash
50
+ ################################################################################
51
+ # Copyright (c) 2021 - 2023 Advanced Micro Devices, Inc. All rights reserved.
52
+ #
53
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
54
+ # of this software and associated documentation files (the "Software"), to deal
55
+ # in the Software without restriction, including without limitation the rights
56
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
57
+ # copies of the Software, and to permit persons to whom the Software is
58
+ # furnished to do so, subject to the following conditions:
59
+ #
60
+ # The above copyright notice and this permission notice shall be included in
61
+ # all copies or substantial portions of the Software.
62
+ #
63
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
64
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
65
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
66
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
67
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
68
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
69
+ # THE SOFTWARE.
70
+ ################################################################################
71
+ OUTPUT_FILE="trace.rpd"
72
+
73
+ if [ "$1" = "-o" ] ; then
74
+ OUTPUT_FILE=$2
75
+ shift
76
+ shift
77
+ fi
78
+
79
+ if [ -e ${OUTPUT_FILE} ] ; then
80
+ rm ${OUTPUT_FILE}
81
+ fi
82
+
83
+ python3 -m rocpd.schema --create ${OUTPUT_FILE}
84
+ if [ $? != 0 ] ; then
85
+ echo "Error: Could not create rpd file. Please run 'python setup.py install' from the rocpd_python dir"
86
+ exit
87
+ fi
88
+
89
+ export RPDT_FILENAME=${OUTPUT_FILE}
90
+ export RPDT_AUTOSTART=0
91
+ LD_PRELOAD=librocm-smi_64:librpd_tracer.so "$@"
92
+ ```
93
+ 3. Apply the patch (provided in this directory) with "git apply rpd_profile_server_enable.patch" if the main purpose of profiling is to collect information on GPU kernels plus limited CPU activity.
94
+
95
+ #### Common Notes 1
96
+ Please note that although we are doing TP=8 in the example, we purposely only log RPD profiling on 2 ranks in the patch file (i.e. tp_rank=0/1) for profiling/visualization convenience, as even Perfetto streaming mode can only load a JSON file of at most 8GB for visualization. With 2 ranks logged in RPD profiling, we can still check whether there are issues among ranks (e.g. load imbalance or NCCL issues), and at the same time we can log a relatively longer time duration before the JSON file generated from the RPD file hits the 8GB size limit.
97
+
98
+ rpd_profile_server_enable.patch
99
+
100
+ ```bash
101
+ diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
102
+ index 62d1ff9..9021c01 100644
103
+ --- a/python/sglang/srt/managers/scheduler.py
104
+ +++ b/python/sglang/srt/managers/scheduler.py
105
+ @@ -71,6 +71,8 @@ from sglang.srt.utils import (
106
+ suppress_other_loggers,
107
+ )
108
+ from sglang.utils import get_exception_traceback
109
+ +from rpdTracerControl import rpdTracerControl
110
+ +rpdTracerControl.skipCreate()
111
+
112
+ logger = logging.getLogger(__name__)
113
+
114
+ @@ -245,6 +247,7 @@ class Scheduler:
115
+ ],
116
+ with_stack=True,
117
+ )
118
+ + self.rpd = rpdTracerControl()
119
+
120
+ @torch.inference_mode()
121
+ def event_loop(self):
122
+ @@ -1027,15 +1030,24 @@ class Scheduler:
123
+ def start_profile(self) -> None:
124
+ if self.profiler is None:
125
+ raise RuntimeError("Profiler is not enabled.")
126
+ - self.profiler.start()
127
+ + #self.profiler.start() #block pytorch profiler for rpd profiler enabling
128
+ + if self.tp_rank == 0 or self.tp_rank == 1:
129
+ + self.rpd.start()
130
+ + self.rpd.rangePush("", "rpd profile range", "")
131
+ + logger.info("rpd is enabled")
132
+
133
+ def stop_profile(self) -> None:
134
+ if self.profiler is None:
135
+ raise RuntimeError("Profiler is not enabled.")
136
+ - self.profiler.stop()
137
+ - self.profiler.export_chrome_trace(
138
+ - self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
139
+ - )
140
+ + #self.profiler.stop()
141
+ + #self.profiler.export_chrome_trace(
142
+ + # self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
143
+ + #)
144
+ + if self.tp_rank ==0 or self.tp_rank ==1:
145
+ + self.rpd.rangePop()
146
+ + self.rpd.stop()
147
+ + self.rpd.flush()
148
+ + logger.info("rpd is done")
149
+ logger.info("Profiler is done")
150
+ ```
151
+
152
+ #### Advanced Debugging with RPD Profiler
153
+ Sometimes we want to use the RPD profiler to capture more CPU and Python activity in order to debug challenging issues (e.g. the root cause of load imbalance across GPU processes, the root cause of bubbles, etc.). Only in such cases do we need to apply the patch with "git apply rpd_profile_server_enable_wCPU_activities.patch", which modifies 3 files.
154
+
155
+ rpd_profile_server_enable_wCPU_activities.patch
156
+
157
+ ```bash
158
+ diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
159
+ index 62d1ff9..2edb427 100644
160
+ --- a/python/sglang/srt/managers/scheduler.py
161
+ +++ b/python/sglang/srt/managers/scheduler.py
162
+ @@ -71,6 +71,8 @@ from sglang.srt.utils import (
163
+ suppress_other_loggers,
164
+ )
165
+ from sglang.utils import get_exception_traceback
166
+ +from rpdTracerControl import rpdTracerControl
167
+ +rpdTracerControl.skipCreate()
168
+
169
+ logger = logging.getLogger(__name__)
170
+
171
+ @@ -245,6 +247,7 @@ class Scheduler:
172
+ ],
173
+ with_stack=True,
174
+ )
175
+ + self.rpd = rpdTracerControl()
176
+
177
+ @torch.inference_mode()
178
+ def event_loop(self):
179
+ @@ -1027,15 +1030,26 @@ class Scheduler:
180
+ def start_profile(self) -> None:
181
+ if self.profiler is None:
182
+ raise RuntimeError("Profiler is not enabled.")
183
+ - self.profiler.start()
184
+ + #self.profiler.start()
185
+ + logger.info("torch profiler is disabled")
186
+ + if self.tp_rank == 0 or self.tp_rank == 1:
187
+ + self.rpd.setPythonTrace(True)
188
+ + self.rpd.start()
189
+ + self.rpd.rangePush("", "scheduler", "")
190
+ + logger.info("rpd is enabled inside scheduler profiling")
191
+
192
+ def stop_profile(self) -> None:
193
+ if self.profiler is None:
194
+ raise RuntimeError("Profiler is not enabled.")
195
+ - self.profiler.stop()
196
+ - self.profiler.export_chrome_trace(
197
+ - self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
198
+ - )
199
+ + #self.profiler.stop()
200
+ + #self.profiler.export_chrome_trace(
201
+ + # self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
202
+ + #)
203
+ + if self.tp_rank ==0 or self.tp_rank ==1:
204
+ + self.rpd.rangePop()
205
+ + self.rpd.stop()
206
+ + self.rpd.flush()
207
+ + logger.info("rpd is done inside scheduler")
208
+ logger.info("Profiler is done")
209
+
210
+
211
+ diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
212
+ index 2621ccd..181df85 100644
213
+ --- a/python/sglang/srt/managers/tokenizer_manager.py
214
+ +++ b/python/sglang/srt/managers/tokenizer_manager.py
215
+ @@ -58,6 +58,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams
216
+ from sglang.srt.server_args import PortArgs, ServerArgs
217
+ from sglang.srt.utils import is_generation_model, is_multimodal_model
218
+
219
+ +from rpdTracerControl import rpdTracerControl
220
+ +rpdTracerControl.skipCreate()
221
+ +
222
+ +
223
+ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
224
+
225
+ logger = logging.getLogger(__name__)
226
+ @@ -514,10 +518,20 @@ class TokenizerManager:
227
+ self.send_to_scheduler.send_pyobj(req)
228
+
229
+ def start_profile(self):
230
+ + rpd = rpdTracerControl()
231
+ + rpd.setPythonTrace(True)
232
+ + rpd.start()
233
+ + rpd.rangePush("", "tokenizer_manager", "")
234
+ + logger.info("tokenizer_manager rpd profiling started!")
235
+ req = ProfileReq.START_PROFILE
236
+ self.send_to_scheduler.send_pyobj(req)
237
+
238
+ def stop_profile(self):
239
+ + rpd = rpdTracerControl()
240
+ + rpd.rangePop()
241
+ + rpd.stop()
242
+ + rpd.flush()
243
+ + logger.info("rpd profiling is done inside tokenizer_manager!")
244
+ req = ProfileReq.STOP_PROFILE
245
+ self.send_to_scheduler.send_pyobj(req)
246
+
247
+ diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
248
+ index 7111c93..2bd722c 100644
249
+ --- a/python/sglang/srt/server.py
250
+ +++ b/python/sglang/srt/server.py
251
+ @@ -30,6 +30,8 @@ import threading
252
+ import time
253
+ from http import HTTPStatus
254
+ from typing import Dict, List, Optional, Union
255
+ +from rpdTracerControl import rpdTracerControl
256
+ +rpdTracerControl.skipCreate()
257
+
258
+ # Fix a bug of Python threading
259
+ setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
260
+ @@ -152,6 +154,11 @@ async def flush_cache():
261
+ @app.post("/start_profile")
262
+ async def start_profile():
263
+ """Start profiling."""
264
+ + rpd = rpdTracerControl()
265
+ + rpd.setPythonTrace(True)
266
+ + rpd.start()
267
+ + rpd.rangePush("", "server rpd profile range", "")
268
+ + logger.info("rpd profiling started in server.py!")
269
+ tokenizer_manager.start_profile()
270
+ return Response(
271
+ content="Start profiling.\n",
272
+ @@ -164,6 +171,11 @@ async def start_profile():
273
+ async def stop_profile():
274
+ """Stop profiling."""
275
+ tokenizer_manager.stop_profile()
276
+ + rpd = rpdTracerControl()
277
+ + rpd.rangePop()
278
+ + rpd.stop()
279
+ + rpd.flush()
280
+ + logger.info("rpd profiling is done in server.py!")
281
+ return Response(
282
+ content="Stop profiling. This will take some time.\n",
283
+ status_code=200,
284
+ ```
285
+
286
+ 4. As an example for grok1 profiling, we create a dummy_grok1 directory containing a config.json (see content below), then copy that directory to the path given to "--model-path" if you want to use the example server.sh file provided.
287
+ ```bash
288
+ cat ../dummy_grok1/config.json
289
+ {
290
+ "architectures": [
291
+ "Grok1ModelForCausalLM"
292
+ ],
293
+ "embedding_multiplier_scale": 78.38367176906169,
294
+ "output_multiplier_scale": 0.5773502691896257,
295
+ "vocab_size": 131072,
296
+ "hidden_size": 6144,
297
+ "intermediate_size": 32768,
298
+ "max_position_embeddings": 8192,
299
+ "num_experts_per_tok": 2,
300
+ "num_local_experts": 8,
301
+ "num_attention_heads": 48,
302
+ "num_hidden_layers": 64,
303
+ "num_key_value_heads": 8,
304
+ "head_dim": 128,
305
+ "rms_norm_eps": 1e-05,
306
+ "rope_theta": 10000.0,
307
+ "model_type": "mixtral",
308
+ "torch_dtype": "bfloat16"
309
+ }
310
+ ```
311
+ 5. Launch server with rpd enabled script ./server.sh in one terminal inside the docker container.
312
+
313
+ #### Common Notes 2
314
+ - Remember to change model-path to the correct path
315
+ - loadTracer.sh is needed to conduct profiling
316
+ - SGLANG_TORCH_PROFILER_DIR is used for default torch profiler
317
+ - Do not use loadTracer.sh if you are using the torch profiler; simply use python3 -m sglang.launch_server.
318
+
319
+
320
+ server.sh
321
+
322
+ ```bash
323
+ #!/bin/bash
324
+
325
+ # export SGLANG_TORCH_PROFILER_DIR=/data/sglang/
326
+ export SGLANG_TORCH_PROFILER_DIR=/sgl-workspace/sglang/profile/
327
+
328
+ # Get the current timestamp
329
+ TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
330
+
331
+ # Define the log file with a timestamp
332
+ LOGFILE="sglang_server_log_$TIMESTAMP.json"
333
+
334
+ # Run the Python command and save the output to the log file
335
+ loadTracer.sh python3 -m sglang.launch_server \
336
+ --model-path /sgl-workspace/sglang/dummy_grok1 \
337
+ --tokenizer-path Xenova/grok-1-tokenizer \
338
+ --load-format dummy \
339
+ --quantization fp8 \
340
+ --tp 8 \
341
+ --port 30000 \
342
+ --disable-radix-cache 2>&1 | tee "$LOGFILE"
343
+ ```
344
+ 6. Open another terminal for the same docker container, and run the rpd enabled ./client.sh after you see "The server is fired up and is ready to roll!" message from server side terminal.
345
+
346
+ #### Common Notes 3
347
+ - Use "curl http://localhost:30000/start_profile" and "curl http://localhost:30000/stop_profile" to control the start and end of profiling. Check sglang/python/sglang/srt/managers/scheduler.py for more details.
348
+ - Please don't use RPD profiler together with PyTorch profiler to avoid interference.
349
+ - The rocmProfileData/tools/rpd2tracing.py file is used to generate json file from RPD file.
350
+
351
+ client.sh
352
+
353
+ ```bash
354
+ #!/bin/bash
355
+
356
+ # Start profiling via API
357
+ curl http://localhost:30000/start_profile -H "Content-Type: application/json"
358
+
359
+ # Benchmark serving using sglang with random dataset and tokenizer
360
+ # Define the log file with a timestamp
361
+ TIMESTAMP=$(date +%Y%m%d_%H%M%S)
362
+ LOGFILE="sglang_client_log_$TIMESTAMP.json"
363
+
364
+ # Run the benchmark with specified parameters and save logs
365
+ python3 -m sglang.bench_serving \
366
+ --backend sglang \
367
+ --tokenizer Xenova/grok-1-tokenizer \
368
+ --dataset-name random \
369
+ --random-input 1024\
370
+ --random-output 1024 \
371
+ --num-prompts 120 \
372
+ --request-rate 8 \
373
+ --output-file online.jsonl 2>&1 | tee "$LOGFILE"
374
+
375
+ # Stop profiling via API
376
+ curl http://localhost:30000/stop_profile -H "Content-Type: application/json"
377
+
378
+ # Convert tracing file to csv & json
379
+ sqlite3 trace.rpd ".mode csv" ".header on" ".output trace.csv" "select * from top;" ".output stdout"
380
+ python3 ./rocmProfileData/tools/rpd2tracing.py trace.rpd trace.json
381
+ ```
382
+ 7. Follow [Perfetto docs](https://perfetto.dev/docs/visualization/large-traces) to visualize large json files. Try to adjust parameters so that the trace.json file size is less than 9GB.
383
+
384
+ ### Profiling SGLang Infer System with PyTorch Profiler
385
+
386
+ Please use the steps as follows:
387
+
388
+ 1. Apply the patch torch_profiler.patch. Note that you can modify "if self.tp_rank == 0" in the patch to allow more ranks to be recorded in profiling.
389
+
390
+ torch_profiler.patch
391
+ ```bash
392
+ diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
393
+ index 62d1ff9..6ecd78c 100644
394
+ --- a/python/sglang/srt/managers/scheduler.py
395
+ +++ b/python/sglang/srt/managers/scheduler.py
396
+ @@ -240,7 +240,6 @@ class Scheduler:
397
+ )
398
+ self.profiler = torch.profiler.profile(
399
+ activities=[
400
+ - torch.profiler.ProfilerActivity.CPU,
401
+ torch.profiler.ProfilerActivity.CUDA,
402
+ ],
403
+ with_stack=True,
404
+ @@ -1033,9 +1032,11 @@ class Scheduler:
405
+ if self.profiler is None:
406
+ raise RuntimeError("Profiler is not enabled.")
407
+ self.profiler.stop()
408
+ - self.profiler.export_chrome_trace(
409
+ - self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
410
+ - )
411
+ + if self.tp_rank == 0:
412
+ + with open(f"stats_repro_{int(time.time())}.txt", "w") as f:
413
+ + print(self.profiler.key_averages(group_by_input_shape=True).table(sort_by="cuda_time_total", row_limit=-1), file=f)
414
+ + print("Profiling stats done.")
415
+ +
416
+ logger.info("Profiler is done")
417
+ ```
418
+
419
+ 2. Create the model path directory and copy it to the right path for "--model-path" if you want to use the server.sh file provided.
420
+
421
+ 3. Modify the included server.sh by removing "loadTracer.sh" before python command and launch script ./server.sh in one terminal inside the docker container.
422
+
423
+ 4. Similar to step 6 in the RPD profiling section, but remove the last 2 lines in client.sh, which convert the RPD file into CSV and JSON files. Run the modified client.sh for PyTorch profiling.
424
+ -------
425
+ - [Torch Profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html)
sglang/3rdparty/amd/profiling/client.sh ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # Start profiling via API
4
+ curl http://localhost:30000/start_profile -H "Content-Type: application/json"
5
+
6
+ # Benchmark serving using sglang with random dataset and tokenizer
7
+ # Define the log file with a timestamp
8
+ TIMESTAMP=$(date +%Y%m%d_%H%M%S)
9
+ LOGFILE="sglang_client_log_$TIMESTAMP.json"
10
+
11
+ # Run the benchmark with specified parameters and save logs
12
+ python3 -m sglang.bench_serving \
13
+ --backend sglang \
14
+ --tokenizer Xenova/grok-1-tokenizer \
15
+ --dataset-name random \
16
+ --random-input 1024\
17
+ --random-output 1024 \
18
+ --num-prompts 240 \
19
+ --request-rate 8 \
20
+ --output-file online.jsonl 2>&1 | tee "$LOGFILE"
21
+
22
+ # Stop profiling via API
23
+ curl http://localhost:30000/stop_profile -H "Content-Type: application/json"
24
+
25
+ # Convert tracing file to csv & json
26
+ sqlite3 trace.rpd ".mode csv" ".header on" ".output trace.csv" "select * from top;" ".output stdout"
27
+ python3 /sgl-workspace/rocmProfileData/tools/rpd2tracing.py trace.rpd trace.json
sglang/3rdparty/amd/profiling/install_rpd.sh ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # download and install RPD
2
+ apt update && apt install -y sqlite3 libsqlite3-dev libfmt-dev
3
+
4
+ # install rpd module
5
+ git clone https://github.com/ROCmSoftwarePlatform/rocmProfileData
6
+ cd rocmProfileData
7
+ git apply rpd.patch
8
+ make && make install
9
+ cd rocpd_python && python setup.py install && cd ..
10
+ cd rpd_tracer && make clean;make install && python setup.py install && cd ..
sglang/3rdparty/amd/profiling/loadTracer.sh ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ ################################################################################
3
+ # Copyright (c) 2021 - 2023 Advanced Micro Devices, Inc. All rights reserved.
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+ #
12
+ # The above copyright notice and this permission notice shall be included in
13
+ # all copies or substantial portions of the Software.
14
+ #
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21
+ # THE SOFTWARE.
22
+ ################################################################################
23
+ OUTPUT_FILE="trace.rpd"
24
+
25
+ if [ "$1" = "-o" ] ; then
26
+ OUTPUT_FILE=$2
27
+ shift
28
+ shift
29
+ fi
30
+
31
+ if [ -e ${OUTPUT_FILE} ] ; then
32
+ rm ${OUTPUT_FILE}
33
+ fi
34
+
35
+ python3 -m rocpd.schema --create ${OUTPUT_FILE}
36
+ if [ $? != 0 ] ; then
37
+ echo "Error: Could not create rpd file. Please run 'python setup.py install' from the rocpd_python dir"
38
+ exit
39
+ fi
40
+
41
+ export RPDT_FILENAME=${OUTPUT_FILE}
42
+ export RPDT_AUTOSTART=0
43
+ LD_PRELOAD=librocm-smi_64:librpd_tracer.so "$@"
sglang/3rdparty/amd/profiling/rpd.patch ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diff --git a/rpd_tracer/Makefile b/rpd_tracer/Makefile
2
+ index e9d9feb..b2e9e1a 100644
3
+ --- a/rpd_tracer/Makefile
4
+ +++ b/rpd_tracer/Makefile
5
+ @@ -16,7 +16,7 @@ ifneq (,$(HIP_PATH))
6
+ $(info Building with roctracer)
7
+ RPD_LIBS += -L/opt/rocm/lib -lroctracer64 -lroctx64 -lamdhip64 -lrocm_smi64
8
+ RPD_INCLUDES += -I/opt/rocm/include -I/opt/rocm/include/roctracer -I/opt/rocm/include/hsa
9
+ - RPD_SRCS += RoctracerDataSource.cpp RocmSmiDataSource.cpp
10
+ + RPD_SRCS += RoctracerDataSource.cpp
11
+ RPD_INCLUDES += -D__HIP_PLATFORM_AMD__
12
+ endif
sglang/3rdparty/amd/profiling/rpd_profile_server_enable.patch ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
2
+ index 62d1ff9..9021c01 100644
3
+ --- a/python/sglang/srt/managers/scheduler.py
4
+ +++ b/python/sglang/srt/managers/scheduler.py
5
+ @@ -71,6 +71,8 @@ from sglang.srt.utils import (
6
+ suppress_other_loggers,
7
+ )
8
+ from sglang.utils import get_exception_traceback
9
+ +from rpdTracerControl import rpdTracerControl
10
+ +rpdTracerControl.skipCreate()
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+ @@ -245,6 +247,7 @@ class Scheduler:
15
+ ],
16
+ with_stack=True,
17
+ )
18
+ + self.rpd = rpdTracerControl()
19
+
20
+ @torch.inference_mode()
21
+ def event_loop(self):
22
+ @@ -1027,15 +1030,24 @@ class Scheduler:
23
+ def start_profile(self) -> None:
24
+ if self.profiler is None:
25
+ raise RuntimeError("Profiler is not enabled.")
26
+ - self.profiler.start()
27
+ + #self.profiler.start() #block pytorch profiler for rpd profiler enabling
28
+ + if self.tp_rank == 0 or self.tp_rank == 1:
29
+ + self.rpd.start()
30
+ + self.rpd.rangePush("", "rpd profile range", "")
31
+ + logger.info("rpd is enabled")
32
+
33
+ def stop_profile(self) -> None:
34
+ if self.profiler is None:
35
+ raise RuntimeError("Profiler is not enabled.")
36
+ - self.profiler.stop()
37
+ - self.profiler.export_chrome_trace(
38
+ - self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
39
+ - )
40
+ + #self.profiler.stop()
41
+ + #self.profiler.export_chrome_trace(
42
+ + # self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
43
+ + #)
44
+ + if self.tp_rank ==0 or self.tp_rank ==1:
45
+ + self.rpd.rangePop()
46
+ + self.rpd.stop()
47
+ + self.rpd.flush()
48
+ + logger.info("rpd is done")
49
+ logger.info("Profiler is done")
sglang/3rdparty/amd/profiling/rpd_profile_server_enable_wCPU_activities.patch ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
2
+ index 62d1ff9..2edb427 100644
3
+ --- a/python/sglang/srt/managers/scheduler.py
4
+ +++ b/python/sglang/srt/managers/scheduler.py
5
+ @@ -71,6 +71,8 @@ from sglang.srt.utils import (
6
+ suppress_other_loggers,
7
+ )
8
+ from sglang.utils import get_exception_traceback
9
+ +from rpdTracerControl import rpdTracerControl
10
+ +rpdTracerControl.skipCreate()
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+ @@ -245,6 +247,7 @@ class Scheduler:
15
+ ],
16
+ with_stack=True,
17
+ )
18
+ + self.rpd = rpdTracerControl()
19
+
20
+ @torch.inference_mode()
21
+ def event_loop(self):
22
+ @@ -1027,15 +1030,26 @@ class Scheduler:
23
+ def start_profile(self) -> None:
24
+ if self.profiler is None:
25
+ raise RuntimeError("Profiler is not enabled.")
26
+ - self.profiler.start()
27
+ + #self.profiler.start()
28
+ + logger.info("torch profiler is disabled")
29
+ + if self.tp_rank == 0 or self.tp_rank == 1:
30
+ + self.rpd.setPythonTrace(True)
31
+ + self.rpd.start()
32
+ + self.rpd.rangePush("", "scheduler", "")
33
+ + logger.info("rpd is enabled inside scheduler profiling")
34
+
35
+ def stop_profile(self) -> None:
36
+ if self.profiler is None:
37
+ raise RuntimeError("Profiler is not enabled.")
38
+ - self.profiler.stop()
39
+ - self.profiler.export_chrome_trace(
40
+ - self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
41
+ - )
42
+ + #self.profiler.stop()
43
+ + #self.profiler.export_chrome_trace(
44
+ + # self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
45
+ + #)
46
+ + if self.tp_rank ==0 or self.tp_rank ==1:
47
+ + self.rpd.rangePop()
48
+ + self.rpd.stop()
49
+ + self.rpd.flush()
50
+ + logger.info("rpd is done inside scheduler")
51
+ logger.info("Profiler is done")
52
+
53
+
54
+ diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
55
+ index 2621ccd..181df85 100644
56
+ --- a/python/sglang/srt/managers/tokenizer_manager.py
57
+ +++ b/python/sglang/srt/managers/tokenizer_manager.py
58
+ @@ -58,6 +58,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams
59
+ from sglang.srt.server_args import PortArgs, ServerArgs
60
+ from sglang.srt.utils import is_generation_model, is_multimodal_model
61
+
62
+ +from rpdTracerControl import rpdTracerControl
63
+ +rpdTracerControl.skipCreate()
64
+ +
65
+ +
66
+ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
67
+
68
+ logger = logging.getLogger(__name__)
69
+ @@ -514,10 +518,20 @@ class TokenizerManager:
70
+ self.send_to_scheduler.send_pyobj(req)
71
+
72
+ def start_profile(self):
73
+ + rpd = rpdTracerControl()
74
+ + rpd.setPythonTrace(True)
75
+ + rpd.start()
76
+ + rpd.rangePush("", "tokenizer_manager", "")
77
+ + logger.info("tokenizer_manager rpd profiling started!")
78
+ req = ProfileReq.START_PROFILE
79
+ self.send_to_scheduler.send_pyobj(req)
80
+
81
+ def stop_profile(self):
82
+ + rpd = rpdTracerControl()
83
+ + rpd.rangePop()
84
+ + rpd.stop()
85
+ + rpd.flush()
86
+ + logger.info("rpd profiling is done inside tokenizer_manager!")
87
+ req = ProfileReq.STOP_PROFILE
88
+ self.send_to_scheduler.send_pyobj(req)
89
+
90
+ diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
91
+ index 7111c93..2bd722c 100644
92
+ --- a/python/sglang/srt/server.py
93
+ +++ b/python/sglang/srt/server.py
94
+ @@ -30,6 +30,8 @@ import threading
95
+ import time
96
+ from http import HTTPStatus
97
+ from typing import Dict, List, Optional, Union
98
+ +from rpdTracerControl import rpdTracerControl
99
+ +rpdTracerControl.skipCreate()
100
+
101
+ # Fix a bug of Python threading
102
+ setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
103
+ @@ -152,6 +154,11 @@ async def flush_cache():
104
+ @app.post("/start_profile")
105
+ async def start_profile():
106
+ """Start profiling."""
107
+ + rpd = rpdTracerControl()
108
+ + rpd.setPythonTrace(True)
109
+ + rpd.start()
110
+ + rpd.rangePush("", "server rpd profile range", "")
111
+ + logger.info("rpd profiling started in server.py!")
112
+ tokenizer_manager.start_profile()
113
+ return Response(
114
+ content="Start profiling.\n",
115
+ @@ -164,6 +171,11 @@ async def start_profile():
116
+ async def stop_profile():
117
+ """Stop profiling."""
118
+ tokenizer_manager.stop_profile()
119
+ + rpd = rpdTracerControl()
120
+ + rpd.rangePop()
121
+ + rpd.stop()
122
+ + rpd.flush()
123
+ + logger.info("rpd profiling is done in server.py!")
124
+ return Response(
125
+ content="Stop profiling. This will take some time.\n",
126
+ status_code=200,
sglang/3rdparty/amd/profiling/server.sh ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # export SGLANG_TORCH_PROFILER_DIR=/data/sglang/
4
+ export SGLANG_TORCH_PROFILER_DIR=/sgl-workspace/sglang/profile/
5
+
6
+ # Get the current timestamp
7
+ TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
8
+
9
+ # Define the log file with a timestamp
10
+ LOGFILE="sglang_server_log_$TIMESTAMP.json"
11
+
12
+ # Run the Python command and save the output to the log file
13
+ loadTracer.sh python3 -m sglang.launch_server \
14
+ --model-path /sgl-workspace/sglang/dummy_grok1 \
15
+ --tokenizer-path Xenova/grok-1-tokenizer \
16
+ --load-format dummy \
17
+ --quantization fp8 \
18
+ --tp 8 \
19
+ --port 30000 \
20
+ --disable-radix-cache 2>&1 | tee "$LOGFILE"
sglang/3rdparty/amd/profiling/torch_profiler.patch ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
2
+ index 62d1ff9..6ecd78c 100644
3
+ --- a/python/sglang/srt/managers/scheduler.py
4
+ +++ b/python/sglang/srt/managers/scheduler.py
5
+ @@ -240,7 +240,6 @@ class Scheduler:
6
+ )
7
+ self.profiler = torch.profiler.profile(
8
+ activities=[
9
+ - torch.profiler.ProfilerActivity.CPU,
10
+ torch.profiler.ProfilerActivity.CUDA,
11
+ ],
12
+ with_stack=True,
13
+ @@ -1033,9 +1032,11 @@ class Scheduler:
14
+ if self.profiler is None:
15
+ raise RuntimeError("Profiler is not enabled.")
16
+ self.profiler.stop()
17
+ - self.profiler.export_chrome_trace(
18
+ - self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
19
+ - )
20
+ + if self.tp_rank == 0:
21
+ + with open(f"stats_repro_{int(time.time())}.txt", "w") as f:
22
+ + print(self.profiler.key_averages(group_by_input_shape=True).table(sort_by="cuda_time_total", row_limit=-1), file=f)
23
+ + print("Profiling stats done.")
24
+ +
25
+ logger.info("Profiler is done")
sglang/3rdparty/amd/sgl-kernel/CMakeLists_rocm.txt ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ cmake_minimum_required(VERSION 3.24 FATAL_ERROR)
2
+ project(sgl_kernel LANGUAGES CXX)
3
+
4
+ # Cmake
5
+ set(CMAKE_CXX_STANDARD 17)
6
+ set(CMAKE_CXX_STANDARD_REQUIRED ON)
7
+ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
8
+ set(CMAKE_SHARED_LIBRARY_PREFIX "")
9
+
10
+ set(CMAKE_COLOR_DIAGNOSTICS ON)
11
+ set(CMAKE_VERBOSE_MAKEFILE ON CACHE BOOL "ON")
12
+
13
+ # Python / Torch
14
+ find_package(Python COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED)
15
+
16
+ execute_process(
17
+ COMMAND ${Python_EXECUTABLE} -c "import torch; print(torch.utils.cmake_prefix_path)"
18
+ OUTPUT_VARIABLE TORCH_PY_PREFIX
19
+ OUTPUT_STRIP_TRAILING_WHITESPACE
20
+ )
21
+
22
+ set(Torch_DIR "${TORCH_PY_PREFIX}/Torch")
23
+ list(APPEND CMAKE_PREFIX_PATH "${TORCH_PY_PREFIX}/Torch")
24
+ find_package(Torch REQUIRED)
25
+
26
+ execute_process(
27
+ COMMAND ${Python_EXECUTABLE} -c "import torch; print(int(torch._C._GLIBCXX_USE_CXX11_ABI))"
28
+ OUTPUT_VARIABLE TORCH_CXX11_ABI
29
+ OUTPUT_STRIP_TRAILING_WHITESPACE
30
+ )
31
+ if(TORCH_CXX11_ABI STREQUAL "0")
32
+ add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
33
+ else()
34
+ add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=1)
35
+ endif()
36
+
37
+ # ROCm/HIP
38
+ enable_language(HIP)
39
+ find_package(hip REQUIRED CONFIG)
40
+
41
+ # Determine AMDGPU target from environment variable or default to gfx942
42
+ set(AMDGPU_TARGET_ENV "$ENV{AMDGPU_TARGET}")
43
+
44
+ if(AMDGPU_TARGET_ENV)
45
+ # Use environment variable if specified
46
+ set(AMDGPU_TARGETS "${AMDGPU_TARGET_ENV}")
47
+ message(STATUS "Using AMDGPU_TARGET from environment: ${AMDGPU_TARGETS}")
48
+ else()
49
+ # Default to gfx942 only
50
+ set(AMDGPU_TARGETS "gfx942")
51
+ message(STATUS "AMDGPU_TARGET not set, defaulting to gfx942")
52
+ endif()
53
+
54
+ # Set HIP architectures
55
+ set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS})
56
+
57
+ # FP8 macro selection
58
+ # Always define HIP_FP8_TYPE_FNUZ=1 (for gfx942 and host compilation)
59
+ # Additionally define HIP_FP8_TYPE_E4M3=1 when building for gfx950
60
+ # The existing utils.h logic will pick the right one based on architecture
61
+ set(SGL_FP8_MACROS "-DHIP_FP8_TYPE_FNUZ=1")
62
+
63
+ if(AMDGPU_TARGETS MATCHES "gfx950")
64
+ list(APPEND SGL_FP8_MACROS "-DHIP_FP8_TYPE_E4M3=1")
65
+ message(STATUS "Multi-arch build: Enabling both HIP_FP8_TYPE_FNUZ (gfx942) and HIP_FP8_TYPE_E4M3 (gfx950)")
66
+ elseif(AMDGPU_TARGETS MATCHES "gfx942")
67
+ message(STATUS "Single-arch build: Enabling HIP_FP8_TYPE_FNUZ for gfx942")
68
+ else()
69
+ message(FATAL_ERROR "Unsupported AMDGPU_TARGET '${AMDGPU_TARGETS}'. Expected 'gfx942' or 'gfx950' or both.")
70
+ endif()
71
+
72
+ # TopK dynamic smem bytes
73
+ # Dynamic shared-memory budget for the TopK kernels.
74
+ # - gfx942 (MI300/MI325): LDS is typically 64KB per workgroup -> keep dynamic smem <= ~48KB
75
+ # (leaves room for static shared allocations in the kernel).
76
+ # - gfx95x (MI350): LDS is larger (e.g. 160KB per CU) -> allow the original 128KB dynamic smem.
77
+ if(AMDGPU_TARGET_ONE STREQUAL "gfx942")
78
+ math(EXPR SGL_TOPK_DYNAMIC_SMEM_BYTES "48 * 1024")
79
+ else()
80
+ math(EXPR SGL_TOPK_DYNAMIC_SMEM_BYTES "32 * 1024 * 4")
81
+ endif()
82
+
83
+ set(SGL_TOPK_MACROS "-DSGL_TOPK_DYNAMIC_SMEM_BYTES=${SGL_TOPK_DYNAMIC_SMEM_BYTES}")
84
+
85
+ # Paths / includes
86
+ set(PROJ_ROOT ${CMAKE_CURRENT_LIST_DIR})
87
+ set(SGL_INCLUDE_DIRS
88
+ ${PROJ_ROOT}/include
89
+ ${PROJ_ROOT}/include/impl
90
+ ${PROJ_ROOT}/csrc
91
+ ${TORCH_INCLUDE_DIRS}
92
+ )
93
+
94
+ # Platform-specific library directory
95
+ set(PLAT_LIB_DIR "/usr/lib/x86_64-linux-gnu")
96
+ link_directories(${PLAT_LIB_DIR})
97
+
98
+ # Sources
99
+ set(SOURCES
100
+ ${PROJ_ROOT}/csrc/allreduce/custom_all_reduce.hip
101
+ ${PROJ_ROOT}/csrc/allreduce/deterministic_all_reduce.hip
102
+ ${PROJ_ROOT}/csrc/allreduce/quick_all_reduce.hip
103
+ ${PROJ_ROOT}/csrc/common_extension_rocm.cc
104
+ ${PROJ_ROOT}/csrc/elementwise/activation.hip
105
+ ${PROJ_ROOT}/csrc/elementwise/pos_enc.hip
106
+ ${PROJ_ROOT}/csrc/elementwise/topk.hip
107
+ ${PROJ_ROOT}/csrc/grammar/apply_token_bitmask_inplace_hip.hip
108
+ ${PROJ_ROOT}/csrc/kvcacheio/transfer.hip
109
+ ${PROJ_ROOT}/csrc/moe/moe_align_kernel.hip
110
+ ${PROJ_ROOT}/csrc/moe/moe_topk_softmax_kernels.hip
111
+ ${PROJ_ROOT}/csrc/moe/moe_topk_sigmoid_kernels.hip
112
+ ${PROJ_ROOT}/csrc/speculative/eagle_utils.hip
113
+ )
114
+ set_source_files_properties(
115
+ ${SOURCES}
116
+ PROPERTIES
117
+ LANGUAGE HIP
118
+ )
119
+
120
+ # Compile / Link flags
121
+ add_compile_options($<$<COMPILE_LANGUAGE:CXX>:-O3>)
122
+
123
+ set(SGL_HIP_FLAGS
124
+ -DNDEBUG
125
+ -DOPERATOR_NAMESPACE=sgl_kernel
126
+ -O3
127
+ -std=c++17
128
+ -DENABLE_BF16
129
+ -DENABLE_FP8
130
+ ${SGL_FP8_MACROS}
131
+ -Wno-pass-failed
132
+ -Wundefined-internal
133
+ ${SGL_TOPK_MACROS}
134
+ )
135
+
136
+ # Python extension
137
+ Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
138
+ target_include_directories(common_ops PRIVATE ${SGL_INCLUDE_DIRS})
139
+
140
+ # Apply per-language flags
141
+ target_compile_options(common_ops PRIVATE
142
+ $<$<COMPILE_LANGUAGE:HIP>:${SGL_HIP_FLAGS}>
143
+ )
144
+
145
+ target_link_libraries(common_ops PRIVATE
146
+ ${TORCH_LIBRARIES}
147
+ hip::device
148
+ hip::host
149
+ hiprtc
150
+ amdhip64
151
+ )
152
+
153
+ target_link_options(common_ops PRIVATE
154
+ "SHELL:-Wl,-rpath,'\$ORIGIN/../../torch/lib'"
155
+ )
156
+
157
+ install(TARGETS common_ops
158
+ LIBRARY DESTINATION sgl_kernel
159
+ )
sglang/3rdparty/amd/sgl-kernel/build_rocm.sh ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ set -euo pipefail
3
+ ROCM_VERSION=$1
4
+
5
+ PYTHON_ROOT_PATH="/opt/venv/bin"
6
+ AMDGPU_TARGET="gfx942;gfx950"
7
+
8
+ echo "Python root path is: $PYTHON_ROOT_PATH"
9
+
10
+ # Get version from git tags
11
+ SGLANG_VERSION="v0.5.6" # Default version, will be overridden if git tags are found
12
+
13
+ # Fetch tags from origin to ensure we have the latest
14
+ if git fetch --tags origin; then
15
+ # Get the latest version tag sorted by version number (e.g., v0.5.7)
16
+ VERSION_FROM_TAG=$(git tag -l 'v[0-9]*' --sort=-v:refname | head -1)
17
+ if [ -n "$VERSION_FROM_TAG" ]; then
18
+ SGLANG_VERSION="$VERSION_FROM_TAG"
19
+ echo "Using SGLang version from git tags: $SGLANG_VERSION"
20
+ else
21
+ echo "Warning: No version tags found; using default $SGLANG_VERSION" >&2
22
+ fi
23
+ else
24
+ echo "Warning: Failed to fetch tags from origin; using default $SGLANG_VERSION" >&2
25
+ fi
26
+
27
+ # Default base tags (can be overridden by command line arguments)
28
+ DEFAULT_MI30X_BASE_TAG="${SGLANG_VERSION}-rocm700-mi30x"
29
+ DEFAULT_MI35X_BASE_TAG="${SGLANG_VERSION}-rocm700-mi35x"
30
+
31
+ # Parse command line arguments
32
+ MI30X_BASE_TAG="${DEFAULT_MI30X_BASE_TAG}"
33
+ MI35X_BASE_TAG="${DEFAULT_MI35X_BASE_TAG}"
34
+
35
+ # Detect GPU architecture from the Kubernetes runner hostname
36
+ HOSTNAME_VALUE=$(hostname)
37
+ GPU_ARCH="mi30x" # default
38
+
39
+ # Host names look like: linux-mi35x-gpu-1-xxxxx-runner-zzzzz
40
+ if [[ "${HOSTNAME_VALUE}" =~ ^linux-(mi[0-9]+[a-z]*)-gpu-[0-9]+ ]]; then
41
+ GPU_ARCH="${BASH_REMATCH[1]}"
42
+ echo "Detected GPU architecture from hostname: ${GPU_ARCH}"
43
+ else
44
+ echo "Warning: could not parse GPU architecture from '${HOSTNAME_VALUE}', defaulting to ${GPU_ARCH}"
45
+ fi
46
+
47
+ case "${GPU_ARCH}" in
48
+ mi35x)
49
+ echo "Runner uses ${GPU_ARCH}; will fetch mi35x image."
50
+ ;;
51
+ mi30x|mi300|mi325)
52
+ echo "Runner uses ${GPU_ARCH}; will fetch mi30x image."
53
+ GPU_ARCH="mi30x"
54
+ ;;
55
+ *)
56
+ echo "Runner architecture '${GPU_ARCH}' unrecognised; defaulting to mi30x image." >&2
57
+ GPU_ARCH="mi30x"
58
+ ;;
59
+ esac
60
+
61
+ if [[ -f /etc/podinfo/gha-render-devices ]]; then
62
+ DEVICE_FLAG=$(cat /etc/podinfo/gha-render-devices)
63
+ else
64
+ DEVICE_FLAG="--device /dev/dri"
65
+ fi
66
+
67
+ # Find the latest image
68
+ find_latest_image() {
69
+ local gpu_arch=$1
70
+ local base_tag days_back image_tag
71
+
72
+ case "${gpu_arch}" in
73
+ mi30x) base_tag="${MI30X_BASE_TAG}" ;;
74
+ mi35x) base_tag="${MI35X_BASE_TAG}" ;;
75
+ *) echo "Error: unsupported GPU architecture '${gpu_arch}'" >&2; return 1 ;;
76
+ esac
77
+
78
+ for days_back in {0..6}; do
79
+ image_tag="${base_tag}-$(date -d "${days_back} days ago" +%Y%m%d)"
80
+ echo "Checking for image: rocm/sgl-dev:${image_tag}" >&2
81
+ if docker manifest inspect "rocm/sgl-dev:${image_tag}" >/dev/null 2>&1; then
82
+ echo "Found available image: rocm/sgl-dev:${image_tag}" >&2
83
+ echo "rocm/sgl-dev:${image_tag}"
84
+ return 0
85
+ fi
86
+ done
87
+
88
+ echo "Error: no ${gpu_arch} image found in the last 7 days for base ${base_tag}" >&2
89
+ echo "Using hard-coded fallback…" >&2
90
+ if [[ "${gpu_arch}" == "mi35x" ]]; then
91
+ echo "rocm/sgl-dev:v0.5.3-rocm700-mi35x-20251009"
92
+ else
93
+ echo "rocm/sgl-dev:v0.5.3-rocm700-mi30x-20251009"
94
+ fi
95
+ }
96
+
97
+ # Pull and run the latest image
98
+ IMAGE=$(find_latest_image "${GPU_ARCH}")
99
+ echo "Pulling Docker image: ${IMAGE}"
100
+ docker pull "${IMAGE}"
101
+
102
+ docker run --rm \
103
+ -v $(pwd):/sgl-kernel \
104
+ -e AMDGPU_TARGET="${AMDGPU_TARGET}" \
105
+ ${IMAGE} \
106
+ bash -c "
107
+ # Install CMake (version >= 3.26) - Robust Installation
108
+ export CMAKE_VERSION_MAJOR=3.31
109
+ export CMAKE_VERSION_MINOR=1
110
+ echo \"Downloading CMake from: https://cmake.org/files/v\${CMAKE_VERSION_MAJOR}/cmake-\${CMAKE_VERSION_MAJOR}.\${CMAKE_VERSION_MINOR}-linux-x86_64.tar.gz\"
111
+ wget https://cmake.org/files/v\${CMAKE_VERSION_MAJOR}/cmake-\${CMAKE_VERSION_MAJOR}.\${CMAKE_VERSION_MINOR}-linux-x86_64.tar.gz
112
+ tar -xzf cmake-\${CMAKE_VERSION_MAJOR}.\${CMAKE_VERSION_MINOR}-linux-x86_64.tar.gz
113
+ mv cmake-\${CMAKE_VERSION_MAJOR}.\${CMAKE_VERSION_MINOR}-linux-x86_64 /opt/cmake
114
+ export PATH=/opt/cmake/bin:\$PATH
115
+
116
+ ${PYTHON_ROOT_PATH}/pip install --no-cache-dir ninja setuptools wheel numpy uv scikit-build-core && \
117
+
118
+ cd /sgl-kernel && \
119
+ rm -rf CMakeLists.txt && mv CMakeLists_rocm.txt CMakeLists.txt && \
120
+ ${PYTHON_ROOT_PATH}/python rocm_hipify.py && \
121
+ ${PYTHON_ROOT_PATH}/python -m uv build --wheel -Cbuild-dir=build . --color=always --no-build-isolation && \
122
+ ./rename_wheels_rocm.sh
123
+ "
sglang/3rdparty/amd/sgl-kernel/rename_wheels_rocm.sh ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -ex
3
+
4
+ WHEEL_DIR="dist"
5
+
6
+ wheel_files=($WHEEL_DIR/*.whl)
7
+ for wheel in "${wheel_files[@]}"; do
8
+ intermediate_wheel="${wheel/linux/manylinux2014}"
9
+
10
+ # Extract the current python version from the wheel name
11
+ if [[ $intermediate_wheel =~ -cp([0-9]+)- ]]; then
12
+ cp_version="${BASH_REMATCH[1]}"
13
+ else
14
+ echo "Could not extract Python version from wheel name: $intermediate_wheel"
15
+ continue
16
+ fi
17
+
18
+ # Detect ROCm version and add appropriate suffix
19
+ if ls /opt | grep -q "7.0"; then
20
+ new_wheel="${intermediate_wheel/-cp${cp_version}/+rocm700-cp${cp_version}}"
21
+ else
22
+ new_wheel="$intermediate_wheel"
23
+ fi
24
+
25
+ if [[ "$wheel" != "$new_wheel" ]]; then
26
+ echo "Renaming $wheel to $new_wheel"
27
+ mv -- "$wheel" "$new_wheel"
28
+ fi
29
+ done
30
+ echo "Wheel renaming completed."
sglang/3rdparty/amd/sgl-kernel/rocm_hipify.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import torch
4
+ from torch.utils.cpp_extension import CUDAExtension
5
+
6
+ root = Path(__file__).parent.resolve()
7
+
8
+ include_dirs = [
9
+ root / "include",
10
+ root / "include" / "impl",
11
+ root / "csrc",
12
+ ]
13
+
14
+ sources = [
15
+ "csrc/allreduce/custom_all_reduce.hip",
16
+ "csrc/allreduce/deterministic_all_reduce.hip",
17
+ "csrc/allreduce/quick_all_reduce.cu",
18
+ "csrc/common_extension_rocm.cc",
19
+ "csrc/elementwise/activation.cu",
20
+ "csrc/elementwise/pos_enc.cu",
21
+ "csrc/elementwise/topk.cu",
22
+ "csrc/grammar/apply_token_bitmask_inplace_cuda.cu",
23
+ "csrc/kvcacheio/transfer.cu",
24
+ "csrc/moe/moe_align_kernel.cu",
25
+ "csrc/moe/moe_topk_softmax_kernels.cu",
26
+ "csrc/moe/moe_topk_sigmoid_kernels.cu",
27
+ "csrc/speculative/eagle_utils.cu",
28
+ ]
29
+
30
+ libraries = ["hiprtc", "amdhip64", "c10", "torch", "torch_python"]
31
+
32
+ ext_modules = [
33
+ CUDAExtension(
34
+ name="sgl_kernel.common_ops",
35
+ sources=sources,
36
+ include_dirs=include_dirs,
37
+ libraries=libraries,
38
+ py_limited_api=False,
39
+ ),
40
+ ]
sglang/3rdparty/amd/tuning/TUNING.md ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Tuning SGLang Infer System with AMD GPUs
2
+ This AppNote describes SGLang performance-tuning techniques, the code harness, and the running steps for systems with AMD Instinct GPUs.
3
+ Harness code, examples, and steps are provided in detail so that results are easy to reproduce and the tuning can be applied to your own workloads.
4
+ Four primary runtime areas are covered:
5
+
6
+ ## 1. Triton Kernels
7
+ To maximize Triton kernel efficiency, several strategies can be employed:
8
+
9
+ ### Key Environment Variables:
10
+ - **num_stages**: Adjusts the number of pipeline stages to optimize kernel efficiency based on the specific type of operations (e.g., General Matrix Multiplication - GEMM).
11
+ - **waves_per_eu**: Controls the usage of Vector General Purpose Registers (VGPR) to enhance occupancy, thereby improving latency or throughput.
12
+ - **BLOCK_M, BLOCK_N, BLOCK_K**: Tunable tile sizes that assist in balancing memory transfer and computational efficiency.
13
+ - **matrix_instr_nonkdim**: Optimizes the usage of Matrix-Fused Multiply-Add (MFMA) instructions for specific kernel types, such as Flash Attention.
14
+ - **OPTIMIZE_EPILOGUE**: An environment variable that can be set to `1` to enhance performance by eliminating the `convert_layout` operation in the kernel's epilogue.
15
+ ```python
16
+ @triton.autotune(configs=[
17
+ triton.Config({'waves_per_eu': 1}, num_warps=4, num_stages=1),
18
+ triton.Config({'waves_per_eu': 1}, num_warps=8, num_stages=1),
19
+ triton.Config({'waves_per_eu': 1}, num_warps=16, num_stages=1),
20
+ triton.Config({'waves_per_eu': 2}, num_warps=4, num_stages=1),
21
+ triton.Config({'waves_per_eu': 2}, num_warps=8, num_stages=1),
22
+ triton.Config({'waves_per_eu': 2}, num_warps=16, num_stages=1),
23
+ triton.Config({'waves_per_eu': 4}, num_warps=4, num_stages=1),
24
+ triton.Config({'waves_per_eu': 4}, num_warps=8, num_stages=1),
25
+ triton.Config({'waves_per_eu': 4}, num_warps=16, num_stages=1),
26
+ ], key=['BLOCK_N', 'NUM_TOKEN_BLKS'], use_cuda_graph=True)
27
+ @triton.jit
28
+ def _triton_kernel_function():
29
+ ...
30
+ ```
31
+ ## 2. Torch Tunable Operations
32
+ **TunableOp** is a feature in PyTorch that allows for the definition and optimization of custom kernels with tunable parameters. This feature is particularly useful for enhancing the performance of kernels by experimenting with different configurations.
33
+
34
+ ### Key Environment Variables:
35
+ 1. **PYTORCH_TUNABLEOP_ENABLED**:
36
+ - Default: `0`
37
+ - Set to `1` to enable TunableOp.
38
+
39
+ 2. **PYTORCH_TUNABLEOP_TUNING**:
40
+ - Default: `1`
41
+ - Set to `0` to disable tuning. When tuning is left enabled and PYTORCH_TUNABLEOP_ENABLED is `1`, any operation without a tuned entry goes through the tuning step and the resulting entry is recorded.
42
+
43
+ 3. **PYTORCH_TUNABLEOP_VERBOSE**:
44
+ - Default: `0`
45
+ - Set to `1` to enable verbose output for TunableOp.
46
+
47
+ ### Usage Example:
48
+ To enable TunableOp and tuning, and optionally enable verbose mode, you can run the following command in your terminal:
49
+
50
+ ```bash
51
+ #Tuning
52
+ PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=1 your_script.sh
53
+
54
+ #Inference with tuning op
55
+ PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=0 your_script.sh
56
+
57
+ #Print out the log
58
+ PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=0 PYTORCH_TUNABLEOP_VERBOSE=1 your_script.sh
59
+
60
+ ```
61
+ ## 3. Torch Compilation
62
+
63
+
64
+ The following are suggestions for optimizing matrix multiplication (GEMM) and convolution (conv) operations in PyTorch using Inductor, a part of the PyTorch compilation framework. The goal is to leverage Triton to achieve better performance.
65
+
66
+ To tune Triton kernels with GEMM and convolution ops (conv), use the `torch.compile` function with the max-autotune mode. This benchmarks a predefined list of Triton configurations and selects the fastest one for each shape.
67
+
68
+ ### Key Configurations:
69
+ 1. **Max Autotune**:
70
+ - Set `torch._inductor.config.max_autotune = True` or `TORCHINDUCTOR_MAX_AUTOTUNE=1`.
71
+
72
+ 2. **Fine-Grained Control**:
73
+ - Enable GEMM tuning: `torch._inductor.config.max_autotune_gemm = True`.
74
+ - Enable tuning for pointwise/reduction ops: `torch._inductor.config.max_autotune_pointwise = True`.
75
+
76
+ 3. **Backend Selection**:
77
+ - Use `torch._inductor.config.max_autotune_gemm_backends` to limit backends to TRITON for better performance.
78
+
79
+ 4. **Freezing for Inference**:
80
+ - Use `torch._inductor.config.freezing=True` to enable constant folding optimizations.
81
+
82
+ 5. **Debugging**:
83
+ - Set `TORCH_COMPILE_DEBUG=1` to extract Triton kernels generated by Inductor.
84
+
85
+ ### Example Code Block:
86
+ ```bash
87
+ #Gemm Tuning
88
+ TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 your_script.sh
89
+
90
+ #Specify your backend to TRITON for Gemm Tuning
91
+ TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS=TRITON your_script.sh
92
+
93
+ #Inference with large improvement on AMD GPU
94
+ TORCHINDUCTOR_FREEZING=1 your_script.sh
95
+ ```
96
+ ## 4. Fused MOE kernel
97
+ To maximize MoE kernel efficiency, use the script below to find the best launch configuration.
98
+
99
+ ### Key parameters:
100
+ - **--model**: the MoE model to tune for; d_model, model_intermediate_size, and num_layers are determined automatically from the model's config
101
+ - **--tp-size**: the tensor-parallel size of the full model run, used to derive the correct per-GPU (sharded) dimension sizes
102
+ - **--batch**: M dimension size of moe kernel, for prefill moe kernel the value is batch*input_len, for decode moe kernel the value is batch
103
+ - **--dtype**: computation type
104
+
105
+ ```bash
106
+ #Tuning
107
+ # For example, suppose we want to run "python3 -m sglang.bench_latency --model dummy_grok1/ --load-format dummy --tokenizer-path Xenova/grok-1-tokenizer --tp 8 --batch-size 32 --input 1024 --output 8 --attention-backend triton --sampling-backend pytorch --quantization fp8". It uses batch size 32, input length 1024, and output length 8. From the MoE kernel's point of view (the "--batch" value), the prefill batch is 32*1024 = 32768 and the decode batch is 32*1 = 32 (only one output token is generated per step).
108
+ # So we can tune the decode MoE kernel with the command below
109
+ python benchmark_moe_rocm.py --model grok1 --tp-size 8 --dtype float8 --batch "32"
110
+ # and use this command to tune prefill moe
111
+ python benchmark_moe_rocm.py --model grok1 --tp-size 8 --dtype float8 --batch "32768"
112
+ ```
113
+
114
+ ## Reference
115
+
116
+ For more detailed information on tuning SGLang performance with AMD GPUs, please refer to the following link:
117
+
118
+ [ROCm Documentation: Triton Kernel Performance Optimization](https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/workload.html#triton-kernel-performance-optimization)
sglang/3rdparty/amd/tuning/benchmark_moe_rocm.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ import sys
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import triton
9
+ import triton.language as tl
10
+ from tqdm import tqdm
11
+ from transformers import AutoConfig
12
+
13
+ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
14
+ fused_moe,
15
+ get_config_file_name,
16
+ )
17
+
18
+ padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
19
+
20
+
21
+ def main(model, tp_size, dtype: str, batches):
22
+ method = fused_moe
23
+
24
+ for bs in batches:
25
+ run_grid(int(bs), model=model, method=method, tp_size=tp_size, dtype=dtype)
26
+
27
+
28
+ def prune_configs(M, N, K, configs):
29
+ pruned_configs = []
30
+ elemBytes_a = 1 # [DV Note] Hard-coded for float16 (2 bytes)
31
+ elemBytes_b = 1 # [DV Note] Hard-coded for float16 (2 bytes)
32
+
33
+ mfma = 16 if M < 32 or N < 32 else 32
34
+
35
+ # TODO (zhanglx): figure out the boundary between large and small gemms
36
+ large_gemm = False
37
+ if M >= 2048 and N >= 2048:
38
+ large_gemm = True
39
+
40
+ for config in configs:
41
+ BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
42
+ BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
43
+ BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
44
+ num_warps = config.get("num_warps")
45
+ matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
46
+ # kpack = config.get("kpack")
47
+ if matrix_instr_nonkdim > mfma:
48
+ continue
49
+ if mfma == 4 and BLOCK_SIZE_K < 64:
50
+ continue
51
+ # some layouts could not work properly in case
52
+ # number elements per thread is less 1
53
+ if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
54
+ continue
55
+ SPLIT_K = 1 # config.get("SPLIT_K")
56
+ GROUP_M = config.get("GROUP_SIZE_M")
57
+ if matrix_instr_nonkdim > BLOCK_SIZE_M or matrix_instr_nonkdim > BLOCK_SIZE_N:
58
+ continue
59
+ if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M:
60
+ continue
61
+ if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N:
62
+ continue
63
+ # Skip BLOCK_SIZE that is too large compare to M/N
64
+ # unless BLOCK_SIZE is already small enough
65
+ if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
66
+ continue
67
+ if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
68
+ continue
69
+ # skip large split_k when not necessary
70
+ if SPLIT_K != 1 and not need_split_k(M, N, K):
71
+ continue
72
+ # skip split_k that leads to EVEN_K = false
73
+ leap = SPLIT_K * BLOCK_SIZE_K
74
+ modv = K % leap
75
+ if modv != 0:
76
+ continue
77
+ # skip large GROUP_M
78
+ if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
79
+ continue
80
+ # out of shared memory resource
81
+ # TODO (zhanglx): This does not consider the LDS usage in the epilogue
82
+ LDS = (
83
+ BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a
84
+ + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
85
+ )
86
+ if LDS > 65536:
87
+ continue
88
+ # Skip small block sizes and num_warps for large gemm
89
+ # For fp16 and f8, we want to only use BLOCK_SIZE >= 64
90
+ if large_gemm:
91
+ if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
92
+ continue
93
+ if BLOCK_SIZE_K < 64:
94
+ continue
95
+ if num_warps < 4:
96
+ continue
97
+
98
+ pruned_configs.append(config)
99
+
100
+ return pruned_configs
101
+
102
+
103
+ def union_of_list_of_dicts(l1, l2):
104
+ result = []
105
+ temp_list = l1.copy()
106
+ temp_list.extend(l2)
107
+ for myDict in temp_list:
108
+ if myDict not in result:
109
+ result.append(myDict)
110
+
111
+ return result
112
+
113
+
114
+ def run_grid(bs, model, method, tp_size, dtype: str):
115
+
116
+ config = AutoConfig.from_pretrained(model)
117
+
118
+ top_k = config.num_experts_per_tok
119
+ d_model = config.hidden_size
120
+ model_intermediate_size = config.intermediate_size
121
+ num_layers = config.num_hidden_layers
122
+ hidden_states_dtype = config.torch_dtype
123
+
124
+ if config.num_experts_per_tok:
125
+ if config.architectures[0] == "Grok1ModelForCausalLM":
126
+ num_total_experts = config.num_experts
127
+ else:
128
+ num_total_experts = config.num_local_experts
129
+ else:
130
+ raise ValueError(f"Unsupported Mixtral model {model}")
131
+
132
+ # tp_size = 2
133
+ num_warmup_calls = 10
134
+ num_calls = 30
135
+
136
+ num_warmup_trials = 1
137
+ num_trials = 1
138
+
139
+ full_configs = []
140
+
141
+ block_m_range = [16, 32, 64, 128, 256]
142
+ block_n_range = [16, 32, 64, 128, 256]
143
+ block_k_range = [32, 64, 128, 256] # MUST >= 32
144
+ num_warps_range = [1, 2, 4, 8]
145
+ group_m_range = [1, 4, 8, 16, 32]
146
+ # For now we see better perf with num_stages=0 for all gemm configs we care
147
+ # But keep this explicit so that we do not forget we may need to set it to
148
+ # other values in the future
149
+ num_stage_range = [2]
150
+ waves_per_eu_range = [0, 1, 2, 4, 8]
151
+ # Remove 32 because of triton compiling error
152
+ matrix_instr_nonkdim_range = [16]
153
+ kpack_range = [1, 2]
154
+
155
+ for block_size_m in block_m_range:
156
+ for block_size_n in block_n_range:
157
+ for block_size_k in block_k_range:
158
+ for group_size_m in group_m_range:
159
+ for num_warps in num_warps_range:
160
+ for num_stages in num_stage_range:
161
+ for waves_per_eu in waves_per_eu_range:
162
+ for matrix_instr_nonkdim in matrix_instr_nonkdim_range:
163
+ for kpack in kpack_range:
164
+ full_configs.append(
165
+ {
166
+ "BLOCK_SIZE_M": block_size_m,
167
+ "BLOCK_SIZE_N": block_size_n,
168
+ "BLOCK_SIZE_K": block_size_k,
169
+ "GROUP_SIZE_M": group_size_m,
170
+ "num_warps": num_warps,
171
+ "num_stages": num_stages,
172
+ "waves_per_eu": waves_per_eu,
173
+ "matrix_instr_nonkdim": matrix_instr_nonkdim,
174
+ "kpack": kpack,
175
+ }
176
+ )
177
+
178
+ M1 = bs * 2
179
+ N1 = model_intermediate_size * 2 // tp_size
180
+ K1 = d_model
181
+ prune_configs_1 = prune_configs(M1, N1, K1, full_configs)
182
+
183
+ M2 = bs * 2
184
+ N2 = d_model
185
+ K2 = model_intermediate_size // tp_size
186
+ prune_configs_2 = prune_configs(M2, N2, K2, full_configs)
187
+
188
+ configs = union_of_list_of_dicts(prune_configs_1, prune_configs_2)
189
+
190
+ print(f"{bs=} || {len(full_configs)=} | {len(prune_configs_1)=} | \
191
+ {len(prune_configs_2)=} | {len(configs)=}")
192
+
193
+ best_config = None
194
+ best_time_us = 1e20
195
+
196
+ print(f"{tp_size=} {bs=}")
197
+
198
+ for config in tqdm(configs):
199
+ # warmup
200
+ try:
201
+ print(config)
202
+ for _ in range(num_warmup_trials):
203
+ run_timing(
204
+ num_calls=num_warmup_calls,
205
+ bs=bs,
206
+ d_model=d_model,
207
+ num_total_experts=num_total_experts,
208
+ top_k=top_k,
209
+ tp_size=tp_size,
210
+ model_intermediate_size=model_intermediate_size,
211
+ method=method,
212
+ config=config,
213
+ dtype=dtype,
214
+ hidden_states_dtype=hidden_states_dtype,
215
+ )
216
+ except triton.runtime.autotuner.OutOfResources:
217
+ continue
218
+
219
+ # trial
220
+ for _ in range(num_trials):
221
+ kernel_dur_ms = run_timing(
222
+ num_calls=num_calls,
223
+ bs=bs,
224
+ d_model=d_model,
225
+ num_total_experts=num_total_experts,
226
+ top_k=top_k,
227
+ tp_size=tp_size,
228
+ model_intermediate_size=model_intermediate_size,
229
+ method=method,
230
+ config=config,
231
+ dtype=dtype,
232
+ hidden_states_dtype=hidden_states_dtype,
233
+ )
234
+
235
+ kernel_dur_us = 1000 * kernel_dur_ms
236
+ model_dur_ms = kernel_dur_ms * num_layers
237
+
238
+ if kernel_dur_us < best_time_us:
239
+ best_config = config
240
+ best_time_us = kernel_dur_us
241
+
242
+ tqdm.write(
243
+ f"{kernel_dur_us=:.1f} {model_dur_ms=:.1f}"
244
+ f" {bs=} {tp_size=} {top_k=} {num_total_experts=} "
245
+ f"{d_model=} {model_intermediate_size=} {num_layers=}"
246
+ )
247
+
248
+ print("best_time_us", best_time_us)
249
+ print("best_config", best_config)
250
+
251
+ # holds Dict[str, Dict[str, int]]
252
+ filename = get_config_file_name(
253
+ num_total_experts,
254
+ model_intermediate_size // tp_size,
255
+ "float8" if dtype == "float8" else None,
256
+ )
257
+ print(f"writing config to file {filename}")
258
+ existing_content = {}
259
+ if os.path.exists(filename):
260
+ with open(filename, "r") as f:
261
+ existing_content = json.load(f)
262
+ existing_content[str(bs)] = best_config
263
+ with open(filename, "w") as f:
264
+ json.dump(existing_content, f, indent=4)
265
+ f.write("\n")
266
+
267
+
268
+ def run_timing(
269
+ num_calls: int,
270
+ bs: int,
271
+ d_model: int,
272
+ num_total_experts: int,
273
+ top_k: int,
274
+ tp_size: int,
275
+ model_intermediate_size: int,
276
+ method,
277
+ config,
278
+ dtype: str,
279
+ hidden_states_dtype,
280
+ ) -> float:
281
+ shard_intermediate_size = model_intermediate_size // tp_size
282
+
283
+ hidden_states = torch.rand(
284
+ (bs, d_model),
285
+ device="cuda:0",
286
+ dtype=hidden_states_dtype,
287
+ )
288
+
289
+ w1 = torch.rand(
290
+ (num_total_experts, 2 * shard_intermediate_size, d_model + padding_size),
291
+ device=hidden_states.device,
292
+ dtype=hidden_states.dtype,
293
+ )
294
+
295
+ w2 = torch.rand(
296
+ (num_total_experts, d_model, shard_intermediate_size + padding_size),
297
+ device=hidden_states.device,
298
+ dtype=hidden_states.dtype,
299
+ )
300
+
301
+ w1_scale = None
302
+ w2_scale = None
303
+ a1_scale = None
304
+ a2_scale = None
305
+
306
+ if dtype == "float8":
307
+ w1 = w1.to(torch.float8_e4m3fnuz)
308
+ w2 = w2.to(torch.float8_e4m3fnuz)
309
+ w1_scale = torch.ones(
310
+ num_total_experts, device=hidden_states.device, dtype=torch.float32
311
+ )
312
+ w2_scale = torch.ones(
313
+ num_total_experts, device=hidden_states.device, dtype=torch.float32
314
+ )
315
+ a1_scale = torch.ones(1, device=hidden_states.device, dtype=torch.float32)
316
+ a2_scale = torch.ones(1, device=hidden_states.device, dtype=torch.float32)
317
+
318
+ gating_output = F.softmax(
319
+ torch.rand(
320
+ (num_calls, bs, num_total_experts),
321
+ device=hidden_states.device,
322
+ dtype=torch.float32,
323
+ ),
324
+ dim=-1,
325
+ )
326
+
327
+ ##################################
328
+
329
+ start_event = torch.cuda.Event(enable_timing=True)
330
+ end_event = torch.cuda.Event(enable_timing=True)
331
+
332
+ start_event.record()
333
+ for i in range(num_calls):
334
+ hidden_states = method(
335
+ hidden_states=hidden_states,
336
+ w1=w1,
337
+ w2=w2,
338
+ w1_scale=w1_scale,
339
+ w2_scale=w2_scale,
340
+ a1_scale=a1_scale,
341
+ a2_scale=a2_scale,
342
+ gating_output=gating_output[0],
343
+ topk=top_k,
344
+ renormalize=True,
345
+ inplace=True,
346
+ override_config=config,
347
+ use_fp8=dtype == "float8",
348
+ )
349
+
350
+ end_event.record()
351
+ end_event.synchronize()
352
+
353
+ dur_ms = start_event.elapsed_time(end_event) / num_calls
354
+ return dur_ms
355
+
356
+
357
+ if __name__ == "__main__":
358
+ parser = argparse.ArgumentParser(
359
+ prog="benchmark_mixtral_moe",
360
+ description="Benchmark and tune the fused_moe kernel",
361
+ )
362
+ parser.add_argument(
363
+ "--dtype",
364
+ type=str,
365
+ default="auto",
366
+ choices=["float8", "float16", "bfloat16"],
367
+ help="Data type used for fused_moe kernel computations",
368
+ )
369
+ parser.add_argument("--model", type=str, default="hpcai-tech/grok-1")
370
+
371
+ parser.add_argument("--tp-size", type=int, default=2, help="Tensor paralleli size")
372
+ parser.add_argument("-b", "--batches", type=str)
373
+
374
+ args = parser.parse_args()
375
+
376
+ batches = args.batches.split(",")
377
+
378
+ sys.exit(main(args.model, args.tp_size, args.dtype, batches))
sglang/docs/supported_models/extending/modelscope.md ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use Models From ModelScope
2
+
3
+ To use a model from [ModelScope](https://www.modelscope.cn), set the environment variable `SGLANG_USE_MODELSCOPE`.
4
+
5
+ ```bash
6
+ export SGLANG_USE_MODELSCOPE=true
7
+ ```
8
+
9
+ We take [Qwen2-7B-Instruct](https://www.modelscope.cn/models/qwen/qwen2-7b-instruct) as an example.
10
+
11
+ Launch the Server:
12
+ ```bash
13
+ python -m sglang.launch_server --model-path qwen/Qwen2-7B-Instruct --port 30000
14
+ ```
15
+
16
+ Or start it by docker:
17
+
18
+ ```bash
19
+ docker run --gpus all \
20
+ -p 30000:30000 \
21
+ -v ~/.cache/modelscope:/root/.cache/modelscope \
22
+ --env "SGLANG_USE_MODELSCOPE=true" \
23
+ --ipc=host \
24
+ lmsysorg/sglang:latest \
25
+ python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --host 0.0.0.0 --port 30000
26
+ ```
27
+
28
+ Note that ModelScope uses a different cache directory than Hugging Face. You may need to set it manually to avoid running out of disk space.
sglang/docs/supported_models/extending/support_new_models.md ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # How to Support New Models
2
+
3
+ This document explains how to add support for new language models and multimodal large language models (MLLMs) in
4
+ SGLang. It also covers how to test new models and register external implementations.
5
+
6
+ ## How to Support a New Language Model
7
+
8
+ To support a new model in SGLang, you only need to add a single file under
9
+ the [SGLang Models Directory](https://github.com/sgl-project/sglang/tree/main/python/sglang/srt/models). You can learn
10
+ from existing model implementations and create a new file for your model. For most models, you should be able to find a
11
+ similar model to start with (e.g., starting from Llama). Also see how
12
+ to [port a Model from vLLM to SGLang](#port-a-model-from-vllm-to-sglang).
13
+
14
+ ## How to Support a New Multimodal Large Language Model
15
+
16
+ To support a new multimodal large language model (MLLM) in SGLang, there are several key components in addition to the
17
+ standard LLM support:
18
+
19
+ 1. **Register your new model as multimodal**:
20
+ Extend `is_multimodal_model`
21
+ in [model_config.py](https://github.com/sgl-project/sglang/blob/0ab3f437aba729b348a683ab32b35b214456efc7/python/sglang/srt/configs/model_config.py#L561)
22
+ to return `True` for your model.
23
+
24
+ 2. **Register a new chat-template**:
25
+ Only when your default chat-template is unable to accept images as input: Register a new chat template in [conversation.py](https://github.com/sgl-project/sglang/tree/main/python/sglang/srt/conversation.py) and the corresponding matching function.
26
+
27
+ 3. **Multimodal Data Processor**:
28
+ Define a new `Processor` class that inherits from `BaseMultimodalProcessor` and register this processor as your
29
+ model’s dedicated processor.
30
+ See [multimodal_processor.py](https://github.com/sgl-project/sglang/tree/main/python/sglang/srt/multimodal/processors)
31
+ for more details.
32
+
33
+ 4. **Handle Multimodal Tokens**:
34
+ Implement a `pad_input_ids` function for your new model. In this function, multimodal tokens in the prompt should be
35
+ expanded (if necessary) and padded with multimodal-data-hashes so that SGLang can recognize different multimodal data
36
+ with `RadixAttention`.
37
+
38
+ 5. **Handle Image Feature Extraction**:
39
+ Implement a `get_image_feature` function for your new model, which extracts image features from raw image data and converts them into the embeddings used by the language model.
40
+
41
+ 6. **Adapt to Vision Attention**:
42
+ Adapt the multi-headed `Attention` of ViT with SGLang’s `VisionAttention`.
43
+
44
+ You can refer to [Qwen2VL](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/qwen2_vl.py) or
45
+ other mllm implementations. These models demonstrate how to correctly handle both multimodal and textual inputs.
46
+
47
+ ## Testing and Debugging
48
+
49
+ Please note all your testing and benchmarking results in the PR description.
50
+
51
+ ### Interactive Debugging
52
+
53
+ For interactive debugging, compare the outputs of Hugging Face/Transformers and SGLang. The following two commands
54
+ should give the same text output and very similar prefill logits:
55
+
56
+ - Get the reference output:
57
+ ```bash
58
+ python3 scripts/playground/reference_hf.py --model-path [new model] --model-type {text,mllm}
59
+ ```
60
+ - Get the SGLang output:
61
+ ```bash
62
+ python3 -m sglang.bench_one_batch --correct --model [new model]
63
+ ```
64
+
65
+ ### Add the Model to the Test Suite
66
+
67
+ To ensure the new model is well maintained, add it to the test suite by including it in the `ALL_OTHER_MODELS` list in
68
+ the [test_generation_models.py](https://github.com/sgl-project/sglang/blob/main/test/srt/models/test_generation_models.py)
69
+ file, test the new model on your local machine and report the results on demonstrative benchmarks (GSM8K, MMLU, MMMU,
70
+ MMMU-Pro, etc.) in your PR. \\
71
+ For VLMs, also include a test in `test_vision_openai_server_{x}.py` (e.g. [test_vision_openai_server_a.py](https://github.com/sgl-project/sglang/blob/main/test/srt/test_vision_openai_server_a.py), [test_vision_openai_server_b.py](https://github.com/sgl-project/sglang/blob/main/test/srt/test_vision_openai_server_b.py)).
72
+
73
+ This is an example command to run to test a new model on your local machine:
74
+
75
+ ```bash
76
+ ONLY_RUN=Qwen/Qwen2-1.5B python3 -m unittest test_generation_models.TestGenerationModels.test_others
77
+ ```
78
+
79
+ ### Benchmark
80
+
81
+ - **(Required) MMMU**: follow the MMMU benchmark [README.md](https://github.com/sgl-project/sglang/blob/main/benchmark/mmmu/README.md) to get an SGLang vs. HF Transformers accuracy comparison. The accuracy score from the SGLang run should not be much lower than that from the HF Transformers run. Similarly, follow https://docs.sglang.io/developer_guide/benchmark_and_profiling.html to get a performance comparison: TTFT and throughput must meet or exceed baselines (e.g., HF Transformers).
82
+ - **(Optional) Other evals**: If you ran other evals, please note the results in the PR description.
83
+
84
+ ## Port a Model from vLLM to SGLang
85
+
86
+ The [vLLM Models Directory](https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models) is a valuable
87
+ resource, as vLLM covers many models. SGLang reuses vLLM’s interface and some layers, making it easier to port models
88
+ from vLLM to SGLang.
89
+
90
+ To port a model from vLLM to SGLang:
91
+
92
+ - Compare these two files for guidance:
93
+ - [SGLang Llama Implementation](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama.py)
94
+ - [vLLM Llama Implementation](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py)
95
+ - The major differences include:
96
+ - **Replace vLLM’s `Attention` with `RadixAttention`** (ensure you pass `layer_id` to `RadixAttention`).
97
+ - **Replace vLLM’s `LogitsProcessor` with SGLang’s `LogitsProcessor`.**
98
+ - **Replace the multi-headed `Attention` of ViT with SGLang’s `VisionAttention`.**
99
+ - **Replace other vLLM layers** (such as `RMSNorm`, `SiluAndMul`) with SGLang layers.
100
+ - **Remove `Sample`.**
101
+ - **Change the `forward()` functions** and add a `forward_batch()` method.
102
+ - **Add `EntryClass`** at the end.
103
+ - **Ensure that the new implementation uses only SGLang components** and does not rely on any vLLM components.
104
+
105
+ Note: make sure you add your new model to the supported models list in the supported models documentation.
106
+
107
+ ## Registering an External Model Implementation
108
+
109
+ In addition to the methods above, you can register your new model with the `ModelRegistry` before launching the server.
110
+ This allows you to integrate your model without modifying the source code.
111
+
112
+ For example:
113
+
114
+ ```python
115
+ from sglang.srt.models.registry import ModelRegistry
116
+ from sglang.srt.entrypoints.http_server import launch_server
117
+
118
+ # For a single model, add it to the registry:
119
+ ModelRegistry.models[model_name] = model_class
120
+
121
+ # For multiple models, you can imitate the import_model_classes() function:
122
+ from functools import lru_cache
123
+
124
+ @lru_cache()
125
+ def import_new_model_classes():
126
+ model_arch_name_to_cls = {}
127
+ # Populate model_arch_name_to_cls with your new model classes.
128
+ ...
129
+ return model_arch_name_to_cls
130
+
131
+ ModelRegistry.models.update(import_new_model_classes())
132
+
133
+ # Launch the server with your server arguments:
134
+ launch_server(server_args)
135
+ ```
136
+
137
+ ## Example: Implementing and Serving a Llama Wrapper Model
138
+
139
+ Below is an introductory, step-by-step walkthrough on how to implement a new model end-to-end in SGLang and then run it via the [Offline Engine](https://github.com/sgl-project/sglang/blob/main/docs/basic_usage/offline_engine_api.ipynb).
140
+
141
+ ### Implementing Our Model
142
+
143
+ To keep things simple, this new model will be a simple wrapper around [Llama 3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct), and our goal will be just to bias the output logits for each `forward` call by taking the square root of each positive logit (non-positive logits are left unchanged).
144
+
145
+ Let's start by defining our model in a file called `llama_wrapper.py`.
146
+ The first step is to import the necessary libraries from SRT, which is SGLang's internal backend.
147
+
148
+ ```python
149
+ # In the file `llama_wrapper.py`
150
+
151
+ import torch
152
+ from transformers import LlamaConfig
153
+ from typing import Optional
154
+ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
155
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
156
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
157
+
158
+ from sglang.srt.models.llama import LlamaForCausalLM
159
+ ```
160
+
161
+ Next, we declare a new `class` for our model and have it inherit from `LlamaForCausalLM`, which allows our model to access `LlamaForCausalLM`'s predefined modules and layers, such as `LlamaAttention` and `LlamaMLP`.
162
+ Note that almost all model implementations take in `config` and `quant_config` as arguments for their `__init__` method; `config` and `quant_config` are passed in via [`model_loader/loader.py`](https://github.com/sgl-project/sglang/blob/bf72b80122fd888bf619d17b96fa3e323ab809fc/python/sglang/srt/model_loader/loader.py#L219).
163
+ Because we have inherited from `LlamaForCausalLM`, we can pass our parameters directly to its constructor, which will set the member variables for us.
164
+
165
+ ```python
166
+ class LlamaWrapper(LlamaForCausalLM):
167
+ def __init__(
168
+ self,
169
+ config: LlamaConfig,
170
+ quant_config: Optional[QuantizationConfig] = None,
171
+ prefix: str = "",
172
+ ) -> None:
173
+ super().__init__(config=config, quant_config=quant_config, prefix=prefix)
174
+ ```
175
+
176
+ Now, we want to define the `forward` method, which is what will be called at inference time.
177
+ Note that the signature for `forward` is essentially the same for any model; you can take a look at the other models defined in the [`models` directory](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/) for references.
178
+ To see where exactly `forward` is called in the SGLang runtime's internals, take a look at [`forward_decode`](https://github.com/sgl-project/sglang/blob/bf72b80122fd888bf619d17b96fa3e323ab809fc/python/sglang/srt/model_executor/model_runner.py#L1705) and [`forward_extend`](https://github.com/sgl-project/sglang/blob/bf72b80122fd888bf619d17b96fa3e323ab809fc/python/sglang/srt/model_executor/model_runner.py#L1724) in the [`ModelRunner` class](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/model_executor/model_runner.py).
179
+
180
+ ```python
181
+ @torch.no_grad()
182
+ def forward(
183
+ self,
184
+ input_ids: torch.Tensor,
185
+ positions: torch.Tensor,
186
+ forward_batch: ForwardBatch,
187
+ pp_proxy_tensors: Optional[PPProxyTensors] = None,
188
+ input_embeds: Optional[torch.Tensor] = None,
189
+ get_embedding: bool = False,
190
+ ) -> LogitsProcessorOutput:
191
+ ```
192
+
193
+ We now call the `__call__` method for `self.model` (which is a member variable that `LlamaForCausalLM` defines in its `__init__` method), which eventually calls the `forward` method of that inner model.
194
+ After that, we feed the `hidden_states` into our model's `LogitsProcessor` (again defined in `LlamaForCausalLM`).
195
+
196
+ ```python
197
+ hidden_states = self.model(
198
+ input_ids,
199
+ positions,
200
+ forward_batch,
201
+ input_embeds,
202
+ pp_proxy_tensors=pp_proxy_tensors,
203
+ )
204
+
205
+ res: LogitsProcessorOutput = self.logits_processor(
206
+ input_ids,
207
+ hidden_states,
208
+ self.lm_head,
209
+ forward_batch,
210
+ )
211
+ ```
212
+
213
+ After receiving the logits for the next token, we can finally perform our biasing step.
214
+
215
+ ```python
216
+ orig_logits = res.next_token_logits
217
+ res.next_token_logits = torch.where(
218
+ orig_logits > 0,
219
+ orig_logits.sqrt(),
220
+ orig_logits
221
+ )
222
+
223
+ return res
224
+ ```
225
+
226
+ Now, our `LlamaWrapper` model is created and ready to be served!
227
+
228
+ ### Serving Our Model Via SGLang's Offline Engine
229
+
230
+ The next step of this walkthrough involves hosting our new model offline, so that it can be served locally and without an HTTP server.
231
+
232
+ First, create a new file called `run.py`.
233
+ Now, we must ensure that SGLang's `ModelRegistry` can find our model.
234
+ To do this, we first download the model's configuration and weights from Huggingface.
235
+
236
+ ```python
237
+ # In the file `run.py`
238
+
239
+ import asyncio
240
+ from functools import lru_cache
241
+ from huggingface_hub import snapshot_download
242
+ from llama_wrapper import LlamaWrapper # Make sure to import our new model!
243
+ import sglang as sgl
244
+ from sglang.srt.models.registry import ModelRegistry
245
+
246
+ # Make sure to request access to this model on Huggingface, then export your
247
+ # `HF_TOKEN` to download the model snapshot
248
+ llama_dir = snapshot_download(
249
+ repo_id="meta-llama/Llama-3.1-8B-Instruct",
250
+ local_dir="./llama_ckpt",
251
+ )
252
+ ```
253
+
254
+ Now that we have our model on disk, we want to point it to `LlamaWrapper` by changing the `architectures` field in `./llama_ckpt/config.json` to be `LlamaWrapper`.
255
+ That way, when we pass in the path of our model checkpoint to SGLang, it will know that we want to use "LlamaWrapper" instead of "LlamaForCausalLM" as our model.
256
+
257
+ ```python
258
+ {
259
+ "architectures": [
260
+ # "LlamaForCausalLM"
261
+ "LlamaWrapper"
262
+ ],
263
+ ...
264
+ }
265
+ ```
266
+
267
+ However, if we don't link our `LlamaWrapper` class to the "LlamaWrapper" registry keyword, then SGLang won't be able to find our model.
268
+ Thus, to register our `LlamaWrapper`, we want to follow the steps in the above section titled "Registering an External Model Implementation".
269
+
270
+ ```python
271
+ @lru_cache()
272
+ def import_new_model_classes():
273
+ model_arch_name_to_cls = {"LlamaWrapper": LlamaWrapper}
274
+ return model_arch_name_to_cls
275
+
276
+ ModelRegistry.models.update(import_new_model_classes())
277
+ ```
278
+
279
+ Lastly, when we create our `Engine`, we just pass in the path to the local model directory.
280
+ Then, our `LlamaWrapper` is ready to be served; for this walkthrough, we will use SGLang `Engine`'s non-streaming asynchronous generation endpoint.
281
+
282
+ ```python
283
+ def main():
284
+ llm = sgl.Engine(model_path="./llama_ckpt")
285
+ sampling_params = {"temperature": 0.2, "top_k": 5}
286
+ prompts = [
287
+ "Write a short, neutral self-introduction for a fictional character. Hello, my name is",
288
+ "Provide a concise factual statement about France’s capital city. The capital of France is",
289
+ "Explain possible future trends in artificial intelligence. The future of AI is",
290
+ ]
291
+
292
+ asyncio.run(run_llm(llm, sampling_params, prompts))
293
+
294
+ llm.shutdown()
295
+
296
+ async def run_llm(
297
+ llm,
298
+ sampling_params,
299
+ prompts,
300
+ ) -> None:
301
+ outputs = await llm.async_generate(prompts, sampling_params)
302
+
303
+ for prompt, output in zip(prompts, outputs):
304
+ print(f"\nPrompt: {prompt}")
305
+ print(f"Generated text: {output['text']}")
306
+
307
+ if __name__ == "__main__":
308
+ main()
309
+ ```
310
+
311
+ Now, when we call `python run.py`, we will get the outputs of our newly created model!
312
+
313
+ ## Documentation
314
+
315
+ Add your model to the table of supported models in [generative_models.md](../text_generation/generative_models.md) or [multimodal_language_models.md](../text_generation/multimodal_language_models.md).
316
+
317
+ ---
318
+
319
+ By following these guidelines, you can add support for new language models and multimodal large language models in
320
+ SGLang and ensure they are thoroughly tested and easily integrated into the system.
sglang/docs/supported_models/retrieval_ranking/classify_models.md ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Classification API
2
+
3
+ This document describes the `/v1/classify` API endpoint implementation in SGLang, which is compatible with vLLM's classification API format.
4
+
5
+ ## Overview
6
+
7
+ The classification API allows you to classify text inputs using classification models. This implementation follows the same format as the classification API in vLLM 0.7.0.
8
+
9
+ ## API Endpoint
10
+
11
+ ```
12
+ POST /v1/classify
13
+ ```
14
+
15
+ ## Request Format
16
+
17
+ ```json
18
+ {
19
+ "model": "model_name",
20
+ "input": "text to classify"
21
+ }
22
+ ```
23
+
24
+ ### Parameters
25
+
26
+ - `model` (string, required): The name of the classification model to use
27
+ - `input` (string, required): The text to classify
28
+ - `user` (string, optional): User identifier for tracking
29
+ - `rid` (string, optional): Request ID for tracking
30
+ - `priority` (integer, optional): Request priority
31
+
32
+ ## Response Format
33
+
34
+ ```json
35
+ {
36
+ "id": "classify-9bf17f2847b046c7b2d5495f4b4f9682",
37
+ "object": "list",
38
+ "created": 1745383213,
39
+ "model": "jason9693/Qwen2.5-1.5B-apeach",
40
+ "data": [
41
+ {
42
+ "index": 0,
43
+ "label": "Default",
44
+ "probs": [0.565970778465271, 0.4340292513370514],
45
+ "num_classes": 2
46
+ }
47
+ ],
48
+ "usage": {
49
+ "prompt_tokens": 10,
50
+ "total_tokens": 10,
51
+ "completion_tokens": 0,
52
+ "prompt_tokens_details": null
53
+ }
54
+ }
55
+ ```
56
+
57
+ ### Response Fields
58
+
59
+ - `id`: Unique identifier for the classification request
60
+ - `object`: Always "list"
61
+ - `created`: Unix timestamp when the request was created
62
+ - `model`: The model used for classification
63
+ - `data`: Array of classification results
64
+ - `index`: Index of the result
65
+ - `label`: Predicted class label
66
+ - `probs`: Array of probabilities for each class
67
+ - `num_classes`: Total number of classes
68
+ - `usage`: Token usage information
69
+ - `prompt_tokens`: Number of input tokens
70
+ - `total_tokens`: Total number of tokens
71
+ - `completion_tokens`: Number of completion tokens (always 0 for classification)
72
+ - `prompt_tokens_details`: Additional token details (optional)
73
+
74
+ ## Example Usage
75
+
76
+ ### Using curl
77
+
78
+ ```bash
79
+ curl -v "http://127.0.0.1:8000/v1/classify" \
80
+ -H "Content-Type: application/json" \
81
+ -d '{
82
+ "model": "jason9693/Qwen2.5-1.5B-apeach",
83
+ "input": "Loved the new café—coffee was great."
84
+ }'
85
+ ```
86
+
87
+ ### Using Python
88
+
89
+ ```python
90
+ import requests
91
+ import json
92
+
93
+ # Make classification request
94
+ response = requests.post(
95
+ "http://127.0.0.1:8000/v1/classify",
96
+ headers={"Content-Type": "application/json"},
97
+ json={
98
+ "model": "jason9693/Qwen2.5-1.5B-apeach",
99
+ "input": "Loved the new café—coffee was great."
100
+ }
101
+ )
102
+
103
+ # Parse response
104
+ result = response.json()
105
+ print(json.dumps(result, indent=2))
106
+ ```
107
+
108
+ ## Supported Models
109
+
110
+ The classification API works with any classification model supported by SGLang, including:
111
+
112
+ ### Classification Models (Multi-class)
113
+ - `LlamaForSequenceClassification` - Multi-class classification
114
+ - `Qwen2ForSequenceClassification` - Multi-class classification
115
+ - `Qwen3ForSequenceClassification` - Multi-class classification
116
+ - `BertForSequenceClassification` - Multi-class classification
117
+ - `Gemma2ForSequenceClassification` - Multi-class classification
118
+
119
+ **Label Mapping**: The API automatically uses the `id2label` mapping from the model's `config.json` file to provide meaningful label names instead of generic class names. If `id2label` is not available, it falls back to `LABEL_0`, `LABEL_1`, etc., or `Class_0`, `Class_1` as a last resort.
120
+
121
+ ### Reward Models (Single score)
122
+ - `InternLM2ForRewardModel` - Single reward score
123
+ - `Qwen2ForRewardModel` - Single reward score
124
+ - `LlamaForSequenceClassificationWithNormal_Weights` - Special reward model
125
+
126
+ **Note**: The `/classify` endpoint in SGLang was originally designed for reward models but now supports all non-generative models. Our `/v1/classify` endpoint provides a standardized vLLM-compatible interface for classification tasks.
127
+
128
+ ## Error Handling
129
+
130
+ The API returns appropriate HTTP status codes and error messages:
131
+
132
+ - `400 Bad Request`: Invalid request format or missing required fields
133
+ - `500 Internal Server Error`: Server-side processing error
134
+
135
+ Error response format:
136
+ ```json
137
+ {
138
+ "error": "Error message",
139
+ "type": "error_type",
140
+ "code": 400
141
+ }
142
+ ```
143
+
144
+ ## Implementation Details
145
+
146
+ The classification API is implemented using:
147
+
148
+ 1. **Rust Model Gateway**: Handles routing and request/response models in `sgl-model-gateway/src/protocols/spec.rs`
149
+ 2. **Python HTTP Server**: Implements the actual endpoint in `python/sglang/srt/entrypoints/http_server.py`
150
+ 3. **Classification Service**: Handles the classification logic in `python/sglang/srt/entrypoints/openai/serving_classify.py`
151
+
152
+ ## Testing
153
+
154
+ Use the provided test script to verify the implementation:
155
+
156
+ ```bash
157
+ python test_classify_api.py
158
+ ```
159
+
160
+ ## Compatibility
161
+
162
+ This implementation is compatible with vLLM's classification API format, allowing seamless migration from vLLM to SGLang for classification tasks.
sglang/docs/supported_models/retrieval_ranking/embedding_models.md ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Embedding Models
2
+
3
+ SGLang provides robust support for embedding models by integrating efficient serving mechanisms with its flexible programming interface. This integration allows for streamlined handling of embedding tasks, facilitating faster and more accurate retrieval and semantic search operations. SGLang's architecture enables better resource utilization and reduced latency in embedding model deployment.
4
+
5
+ ```{important}
6
+ Embedding models are executed with the `--is-embedding` flag, and some may require `--trust-remote-code`.
7
+ ```
8
+
9
+ ## Quick Start
10
+
11
+ ### Launch Server
12
+
13
+ ```shell
14
+ python3 -m sglang.launch_server \
15
+ --model-path Qwen/Qwen3-Embedding-4B \
16
+ --is-embedding \
17
+ --host 0.0.0.0 \
18
+ --port 30000
19
+ ```
20
+
21
+ ### Client Request
22
+
23
+ ```python
24
+ import requests
25
+
26
+ url = "http://127.0.0.1:30000"
27
+
28
+ payload = {
29
+ "model": "Qwen/Qwen3-Embedding-4B",
30
+ "input": "What is the capital of France?",
31
+ "encoding_format": "float"
32
+ }
33
+
34
+ response = requests.post(url + "/v1/embeddings", json=payload).json()
35
+ print("Embedding:", response["data"][0]["embedding"])
36
+ ```
37
+
38
+
39
+
40
+ ## Multimodal Embedding Example
41
+
42
+ For multimodal models like GME that support both text and images:
43
+
44
+ ```shell
45
+ python3 -m sglang.launch_server \
46
+ --model-path Alibaba-NLP/gme-Qwen2-VL-2B-Instruct \
47
+ --is-embedding \
48
+ --chat-template gme-qwen2-vl \
49
+ --host 0.0.0.0 \
50
+ --port 30000
51
+ ```
52
+
53
+ ```python
54
+ import requests
55
+
56
+ url = "http://127.0.0.1:30000"
57
+
58
+ text_input = "Represent this image in embedding space."
59
+ image_path = "https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild/resolve/main/images/023.jpg"
60
+
61
+ payload = {
62
+ "model": "gme-qwen2-vl",
63
+ "input": [
64
+ {
65
+ "text": text_input
66
+ },
67
+ {
68
+ "image": image_path
69
+ }
70
+ ],
71
+ }
72
+
73
+ response = requests.post(url + "/v1/embeddings", json=payload).json()
74
+
75
+ print("Embeddings:", [x.get("embedding") for x in response.get("data", [])])
76
+ ```
77
+
78
+ ## Matryoshka Embedding Example
79
+
80
+ [Matryoshka Embeddings](https://sbert.net/examples/sentence_transformer/training/matryoshka/README.html#matryoshka-embeddings) or [Matryoshka Representation Learning (MRL)](https://arxiv.org/abs/2205.13147) is a technique used in training embedding models. It allows users to trade off between performance and cost.
81
+
82
+ ### 1. Launch a Matryoshka‑capable model
83
+
84
+ If the model config already includes `matryoshka_dimensions` or `is_matryoshka` then no override is needed. Otherwise, you can use `--json-model-override-args` as below:
85
+
86
+ ```shell
87
+ python3 -m sglang.launch_server \
88
+ --model-path Qwen/Qwen3-Embedding-0.6B \
89
+ --is-embedding \
90
+ --host 0.0.0.0 \
91
+ --port 30000 \
92
+ --json-model-override-args '{"matryoshka_dimensions": [128, 256, 512, 1024, 1536]}'
93
+ ```
94
+
95
+ 1. Setting `"is_matryoshka": true` allows truncating to any dimension. Otherwise, the server will validate that the specified dimension in the request is one of `matryoshka_dimensions`.
96
+ 2. Omitting `dimensions` in a request returns the full vector.
97
+
98
+ ### 2. Make requests with different output dimensions
99
+
100
+ ```python
101
+ import requests
102
+
103
+ url = "http://127.0.0.1:30000"
104
+
105
+ # Request a truncated (Matryoshka) embedding by specifying a supported dimension.
106
+ payload = {
107
+ "model": "Qwen/Qwen3-Embedding-0.6B",
108
+ "input": "Explain diffusion models simply.",
109
+ "dimensions": 512 # change to 128 / 1024 / omit for full size
110
+ }
111
+
112
+ response = requests.post(url + "/v1/embeddings", json=payload).json()
113
+ print("Embedding:", response["data"][0]["embedding"])
114
+ ```
115
+
116
+
117
+ ## Supported Models
118
+
119
+ | Model Family | Example Model | Chat Template | Description |
120
+ | ------------------------------------------ | -------------------------------------- | ------------- | --------------------------------------------------------------------------- |
121
+ | **E5 (Llama/Mistral based)** | `intfloat/e5-mistral-7b-instruct` | N/A | High-quality text embeddings based on Mistral/Llama architectures |
122
+ | **GTE-Qwen2** | `Alibaba-NLP/gte-Qwen2-7B-instruct` | N/A | Alibaba's text embedding model with multilingual support |
123
+ | **Qwen3-Embedding** | `Qwen/Qwen3-Embedding-4B` | N/A | Latest Qwen3-based text embedding model for semantic representation |
124
+ | **BGE** | `BAAI/bge-large-en-v1.5` | N/A | BAAI's text embeddings (requires `attention-backend` triton/torch_native) |
125
+ | **GME (Multimodal)** | `Alibaba-NLP/gme-Qwen2-VL-2B-Instruct`| `gme-qwen2-vl`| Multimodal embedding for text and image cross-modal tasks |
126
+ | **CLIP** | `openai/clip-vit-large-patch14-336` | N/A | OpenAI's CLIP for image and text embeddings |
sglang/docs/supported_models/retrieval_ranking/rerank_models.md ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Rerank Models
2
+
3
+ SGLang offers comprehensive support for rerank models by incorporating optimized serving frameworks with a flexible programming interface. This setup enables efficient processing of cross-encoder reranking tasks, improving the accuracy and relevance of search result ordering. SGLang’s design ensures high throughput and low latency during reranker model deployment, making it ideal for semantic-based result refinement in large-scale retrieval systems.
4
+
5
+ ```{important}
6
+ Rerank models in SGLang fall into two categories:
7
+
8
+ - **Cross-encoder rerank models**: run with `--is-embedding` (embedding runner).
9
+ - **Decoder-only rerank models**: run **without** `--is-embedding` and use next-token logprob scoring (yes/no).
10
+ - Text-only (e.g. Qwen3-Reranker)
11
+ - Multimodal (e.g. Qwen3-VL-Reranker): also supports image/video content
12
+
13
+ Some models may require `--trust-remote-code`.
14
+ ```
15
+
16
+ ## Supported rerank models
17
+
18
+ | Model Family (Rerank) | Example HuggingFace Identifier | Chat Template | Description |
19
+ |------------------------------------------------|--------------------------------------|---------------|----------------------------------------------------------------------------------------------------------------------------------|
20
+ | **BGE-Reranker (BgeRerankModel)**              | `BAAI/bge-reranker-v2-m3`            | N/A           | Currently supports only the `triton` and `torch_native` values of `--attention-backend`. High-performance cross-encoder reranker model from BAAI, suitable for reranking search results based on semantic relevance. |
21
+ | **Qwen3-Reranker (decoder-only yes/no)** | `Qwen/Qwen3-Reranker-8B` | `examples/chat_template/qwen3_reranker.jinja` | Decoder-only reranker using next-token logprob scoring for labels (yes/no). Launch **without** `--is-embedding`. |
22
+ | **Qwen3-VL-Reranker (multimodal yes/no)** | `Qwen/Qwen3-VL-Reranker-2B` | `examples/chat_template/qwen3_vl_reranker.jinja` | Multimodal decoder-only reranker supporting text, images, and videos. Uses yes/no logprob scoring. Launch **without** `--is-embedding`. |
23
+
24
+
25
+ ## Cross-Encoder Rerank (embedding runner)
26
+
27
+ ### Launch Command
28
+
29
+ ```shell
30
+ python3 -m sglang.launch_server \
31
+ --model-path BAAI/bge-reranker-v2-m3 \
32
+ --host 0.0.0.0 \
33
+ --disable-radix-cache \
34
+ --chunked-prefill-size -1 \
35
+ --attention-backend triton \
36
+ --is-embedding \
37
+ --port 30000
38
+ ```
39
+
40
+ ### Example Client Request
41
+
42
+ ```python
43
+ import requests
44
+
45
+ url = "http://127.0.0.1:30000/v1/rerank"
46
+
47
+ payload = {
48
+ "model": "BAAI/bge-reranker-v2-m3",
49
+ "query": "what is panda?",
50
+ "documents": [
51
+ "hi",
52
+ "The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China."
53
+ ],
54
+ "top_n": 1,
55
+ "return_documents": True
56
+ }
57
+
58
+ response = requests.post(url, json=payload)
59
+ response_json = response.json()
60
+
61
+ for item in response_json:
62
+ if item.get("document"):
63
+ print(f"Score: {item['score']:.2f} - Document: '{item['document']}'")
64
+ else:
65
+ print(f"Score: {item['score']:.2f} - Index: {item['index']}")
66
+ ```
67
+
68
+ **Request Parameters:**
69
+
70
+ - `query` (required): The query text to rank documents against
71
+ - `documents` (required): List of documents to be ranked
72
+ - `model` (required): Model to use for reranking
73
+ - `top_n` (optional): Maximum number of documents to return. Defaults to returning all documents. If specified value is greater than the total number of documents, all documents will be returned.
74
+ - `return_documents` (optional): Whether to return documents in the response. Defaults to `True`.
75
+
76
+ ## Qwen3-Reranker (decoder-only yes/no rerank)
77
+
78
+ ### Launch Command
79
+
80
+ ```shell
81
+ python3 -m sglang.launch_server \
82
+ --model-path Qwen/Qwen3-Reranker-0.6B \
83
+ --trust-remote-code \
84
+ --disable-radix-cache \
85
+ --host 0.0.0.0 \
86
+ --port 8001 \
87
+ --chat-template examples/chat_template/qwen3_reranker.jinja
88
+ ```
89
+
90
+ ```{note}
91
+ Qwen3-Reranker uses decoder-only logprob scoring (yes/no). Do NOT launch it with `--is-embedding`.
92
+ ```
93
+
94
+ ### Example Client Request (supports optional instruct, top_n, and return_documents)
95
+
96
+ ```shell
97
+ curl -X POST http://127.0.0.1:8001/v1/rerank \
98
+ -H "Content-Type: application/json" \
99
+ -d '{
100
+ "model": "Qwen3-Reranker-0.6B",
101
+ "query": "法国首都是哪里?",
102
+ "documents": [
103
+ "法国的首都是巴黎。",
104
+ "德国的首都是柏林。",
105
+ "香蕉是黄色的水果。"
106
+ ],
107
+ "instruct": "Given a web search query, retrieve relevant passages that answer the query.",
108
+ "top_n": 2,
109
+ "return_documents": true
110
+ }'
111
+ ```
112
+
113
+ **Request Parameters:**
114
+
115
+ - `query` (required): The query text to rank documents against
116
+ - `documents` (required): List of documents to be ranked
117
+ - `model` (required): Model to use for reranking
118
+ - `instruct` (optional): Instruction text for the reranker
119
+ - `top_n` (optional): Maximum number of documents to return. Defaults to returning all documents. If specified value is greater than the total number of documents, all documents will be returned.
120
+ - `return_documents` (optional): Whether to return documents in the response. Defaults to `True`.
121
+
122
+ ### Response Format
123
+
124
+ `/v1/rerank` returns a list of objects (sorted by descending score):
125
+
126
+ - `score`: float, higher means more relevant
127
+ - `document`: the original document string (only included when `return_documents` is `true`)
128
+ - `index`: the original index in the input `documents`
129
+ - `meta_info`: optional debug/usage info (may be present for some models)
130
+
131
+ The number of returned results is controlled by the `top_n` parameter. If `top_n` is not specified or is greater than the total number of documents, all documents are returned.
132
+
133
+ Example (with `return_documents: true`):
134
+
135
+ ```json
136
+ [
137
+ {"score": 0.99, "document": "法国的首都是巴黎。", "index": 0},
138
+ {"score": 0.01, "document": "德国的首都是柏林。", "index": 1},
139
+ {"score": 0.00, "document": "香蕉是黄色的水果。", "index": 2}
140
+ ]
141
+ ```
142
+
143
+ Example (with `return_documents: false`):
144
+
145
+ ```json
146
+ [
147
+ {"score": 0.99, "index": 0},
148
+ {"score": 0.01, "index": 1},
149
+ {"score": 0.00, "index": 2}
150
+ ]
151
+ ```
152
+
153
+ Example (with `top_n: 2`):
154
+
155
+ ```json
156
+ [
157
+ {"score": 0.99, "document": "法国的首都是巴黎。", "index": 0},
158
+ {"score": 0.01, "document": "德国的首都是柏林。", "index": 1}
159
+ ]
160
+ ```
161
+
162
+ ### Common Pitfalls
163
+
164
+ - If you launch Qwen3-Reranker with `--is-embedding`, `/v1/rerank` cannot compute yes/no logprob scores. Relaunch **without** `--is-embedding`.
165
+ - If you see a validation error like "score should be a valid number" and the backend returned a list, upgrade to a version that coerces `embedding[0]` into `score` for rerank responses.
166
+
167
+ ## Qwen3-VL-Reranker (multimodal decoder-only rerank)
168
+
169
+ Qwen3-VL-Reranker extends the Qwen3-Reranker to support multimodal content, allowing reranking of documents containing text, images, and videos.
170
+
171
+ ### Launch Command
172
+
173
+ ```shell
174
+ python3 -m sglang.launch_server \
175
+ --model-path Qwen/Qwen3-VL-Reranker-2B \
176
+ --trust-remote-code \
177
+ --disable-radix-cache \
178
+ --host 0.0.0.0 \
179
+ --port 30000 \
180
+ --chat-template examples/chat_template/qwen3_vl_reranker.jinja
181
+ ```
182
+
183
+ ```{note}
184
+ Qwen3-VL-Reranker uses decoder-only logprob scoring (yes/no) like Qwen3-Reranker. Do NOT launch it with `--is-embedding`.
185
+ ```
186
+
187
+ ### Text-Only Reranking (backward compatible)
188
+
189
+ ```python
190
+ import requests
191
+
192
+ url = "http://127.0.0.1:30000/v1/rerank"
193
+
194
+ payload = {
195
+ "model": "Qwen3-VL-Reranker-2B",
196
+ "query": "What is machine learning?",
197
+ "documents": [
198
+ "Machine learning is a branch of artificial intelligence that enables computers to learn from data.",
199
+ "The weather in Paris is usually mild with occasional rain.",
200
+ "Deep learning is a subset of machine learning using neural networks with many layers.",
201
+ ],
202
+ "instruct": "Retrieve passages that answer the question.",
203
+ "return_documents": True
204
+ }
205
+
206
+ response = requests.post(url, json=payload)
207
+ results = response.json()
208
+
209
+ for item in results:
210
+ print(f"Score: {item['score']:.4f} - {item['document'][:60]}...")
211
+ ```
212
+
213
+ ### Image Reranking (text query, image/mixed documents)
214
+
215
+ ```python
216
+ import requests
217
+
218
+ url = "http://127.0.0.1:30000/v1/rerank"
219
+
220
+ payload = {
221
+ "query": "A woman playing with her dog on a beach at sunset.",
222
+ "documents": [
223
+ # Document 1: Text description
224
+ "A woman shares a joyful moment with her golden retriever on a sun-drenched beach at sunset.",
225
+ # Document 2: Image URL
226
+ [
227
+ {
228
+ "type": "image_url",
229
+ "image_url": {
230
+ "url": "https://example.com/beach_dog.jpeg"
231
+ }
232
+ }
233
+ ],
234
+ # Document 3: Text + Image (mixed)
235
+ [
236
+ {"type": "text", "text": "A joyful scene at the beach:"},
237
+ {
238
+ "type": "image_url",
239
+ "image_url": {
240
+ "url": "https://example.com/beach_dog.jpeg"
241
+ }
242
+ }
243
+ ]
244
+ ],
245
+ "instruct": "Retrieve images or text relevant to the user's query.",
246
+ "return_documents": False
247
+ }
248
+
249
+ response = requests.post(url, json=payload)
250
+ results = response.json()
251
+
252
+ for item in results:
253
+ print(f"Index: {item['index']}, Score: {item['score']:.4f}")
254
+ ```
255
+
256
+ ### Multimodal Query Reranking (query with image)
257
+
258
+ ```python
259
+ import requests
260
+
261
+ url = "http://127.0.0.1:30000/v1/rerank"
262
+
263
+ payload = {
264
+ # Query with text and image
265
+ "query": [
266
+ {"type": "text", "text": "Find similar images to this:"},
267
+ {
268
+ "type": "image_url",
269
+ "image_url": {
270
+ "url": "https://example.com/reference_image.jpeg"
271
+ }
272
+ }
273
+ ],
274
+ "documents": [
275
+ "A cat sleeping on a couch.",
276
+ "A woman and her dog enjoying the sunset at the beach.",
277
+ "A busy city street with cars and pedestrians.",
278
+ [
279
+ {
280
+ "type": "image_url",
281
+ "image_url": {
282
+ "url": "https://example.com/similar_image.jpeg"
283
+ }
284
+ }
285
+ ]
286
+ ],
287
+ "instruct": "Find images or descriptions similar to the query image."
288
+ }
289
+
290
+ response = requests.post(url, json=payload)
291
+ results = response.json()
292
+
293
+ for item in results:
294
+ print(f"Index: {item['index']}, Score: {item['score']:.4f}")
295
+ ```
296
+
297
+ ### Request Parameters (Multimodal)
298
+
299
+ - `query` (required): Can be a string (text-only) or a list of content parts:
300
+ - `{"type": "text", "text": "..."}` for text
301
+ - `{"type": "image_url", "image_url": {"url": "..."}}` for images
302
+ - `{"type": "video_url", "video_url": {"url": "..."}}` for videos
303
+ - `documents` (required): List where each document can be a string or list of content parts (same format as query)
304
+ - `instruct` (optional): Instruction text for the reranker
305
+ - `top_n` (optional): Maximum number of documents to return
306
+ - `return_documents` (optional): Whether to return documents in the response. Defaults to `True`.
307
+
308
+ ### Common Pitfalls
309
+
310
+ - Always use `--chat-template examples/chat_template/qwen3_vl_reranker.jinja` for Qwen3-VL-Reranker.
311
+ - Do NOT launch with `--is-embedding`.
312
+ - For best results, use `--disable-radix-cache` to avoid caching issues with multimodal content.
313
+ - **Note**: Currently only `Qwen3-VL-Reranker-2B` is tested and supported. The 8B model may have different behavior and is not guaranteed to work with this template.
sglang/docs/supported_models/specialized/index.rst ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ Specialized Models
2
+ ==================
3
+
4
+ Models for specialized tasks like reward modeling.
5
+
6
+ .. toctree::
7
+ :maxdepth: 1
8
+
9
+ reward_models.md
sglang/docs/supported_models/specialized/reward_models.md ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Reward Models
2
+
3
+ These models output a scalar reward score or classification result, often used in reinforcement learning or content moderation tasks.
4
+
5
+ ```{important}
6
+ They are executed with `--is-embedding` and some may require `--trust-remote-code`.
7
+ ```
8
+
9
+ ## Example Launch Command
10
+
11
+ ```shell
12
+ python3 -m sglang.launch_server \
13
+ --model-path Qwen/Qwen2.5-Math-RM-72B \ # example HF/local path
14
+ --is-embedding \
15
+ --host 0.0.0.0 \
16
+ --tp-size=4 \ # set for tensor parallelism
17
+ --port 30000
18
+ ```
19
+
20
+ ## Supported models
21
+
22
+ | Model Family (Reward) | Example HuggingFace Identifier | Description |
23
+ |---------------------------------------------------------------------------|-----------------------------------------------------|---------------------------------------------------------------------------------|
24
+ | **Llama (3.1 Reward / `LlamaForSequenceClassification`)** | `Skywork/Skywork-Reward-Llama-3.1-8B-v0.2` | Reward model (preference classifier) based on Llama 3.1 (8B) for scoring and ranking responses for RLHF. |
25
+ | **Gemma 2 (27B Reward / `Gemma2ForSequenceClassification`)** | `Skywork/Skywork-Reward-Gemma-2-27B-v0.2` | Derived from Gemma‑2 (27B), this model provides human preference scoring for RLHF and multilingual tasks. |
26
+ | **InternLM 2 (Reward / `InternLM2ForRewardModel`)** | `internlm/internlm2-7b-reward` | InternLM 2 (7B)–based reward model used in alignment pipelines to guide outputs toward preferred behavior. |
27
+ | **Qwen2.5 (Reward - Math / `Qwen2ForRewardModel`)** | `Qwen/Qwen2.5-Math-RM-72B` | A 72B math-specialized RLHF reward model from the Qwen2.5 series, tuned for evaluating and refining responses. |
28
+ | **Qwen2.5 (Reward - Sequence / `Qwen2ForSequenceClassification`)** | `jason9693/Qwen2.5-1.5B-apeach` | A smaller Qwen2.5 variant used for sequence classification, offering an alternative RLHF scoring mechanism. |
sglang/docs/supported_models/text_generation/diffusion_language_models.md ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Diffusion Language Models
2
+
3
+ Diffusion language models have shown promise for non-autoregressive text generation with parallel decoding capabilities. Unlike auto-regressive language models, different diffusion language models require different decoding strategies.
4
+
5
+ ## Example Launch Command
6
+
7
+ SGLang supports different DLLM algorithms such as `LowConfidence` and `JointThreshold`.
8
+
9
+ ```shell
10
+ python3 -m sglang.launch_server \
11
+ --model-path inclusionAI/LLaDA2.0-mini \ # example HF/local path
12
+ --dllm-algorithm LowConfidence \
13
+ --dllm-algorithm-config ./config.yaml \ # Optional. Uses the algorithm's default if not set.
14
+ --host 0.0.0.0 \
15
+ --port 30000
16
+ ```
17
+
18
+ ## Example Configuration File
19
+
20
+ Depending on the algorithm selected, the configuration parameters vary.
21
+
22
+ LowConfidence Config:
23
+
24
+ ```yaml
25
+ # Confidence threshold for accepting predicted tokens
26
+ # - Higher values: More conservative, better quality but slower
27
+ # - Lower values: More aggressive, faster but potentially lower quality
28
+ # Range: 0.0 - 1.0
29
+ threshold: 0.95
30
+
31
+ # Default: 32, for LLaDA2MoeModelLM
32
+ block_size: 32
33
+ ```
34
+
35
+ JointThreshold Config:
36
+
37
+ ```yaml
38
+ # Decoding threshold for Mask-to-Token (M2T) phase
39
+ # - Higher values: More conservative, better quality but slower
40
+ # - Lower values: More aggressive, faster but potentially lower quality
41
+ # Range: 0.0 - 1.0
42
+ threshold: 0.5
43
+ # Decoding threshold for Token-to-Token (T2T) phase
44
+ # Range: 0.0 - 1.0
45
+ # Setting to 0.0 allows full editing (recommended for most cases).
46
+ edit_threshold: 0.0
47
+ # Max extra T2T steps after all masks are removed. Prevents infinite loops.
48
+ max_post_edit_steps: 16
49
+ # 2-gram repetition penalty (default 0).
50
+ # An empirical value of 3 is often sufficient to mitigate most repetitions.
51
+ penalty_lambda: 0
52
+ ```
53
+
54
+ ## Example Client Code Snippet
55
+
56
+ Just like other supported models, diffusion language models can be used via the REST API or Python client.
57
+
58
+ Python client example for making a generation request to the launched server:
59
+
60
+ ```python
61
+ import sglang as sgl
62
+
63
+ def main():
64
+ llm = sgl.Engine(model_path="inclusionAI/LLaDA2.0-mini",
65
+ dllm_algorithm="LowConfidence",
66
+ max_running_requests=1,
67
+ trust_remote_code=True)
68
+
69
+ prompts = [
70
+ "<role>SYSTEM</role>detailed thinking off<|role_end|><role>HUMAN</role> Write a brief introduction of the great wall <|role_end|><role>ASSISTANT</role>"
71
+ ]
72
+
73
+ sampling_params = {
74
+ "temperature": 0,
75
+ "max_new_tokens": 1024,
76
+ }
77
+
78
+ outputs = llm.generate(prompts, sampling_params)
79
+ print(outputs)
80
+
81
+ if __name__ == '__main__':
82
+ main()
83
+ ```
84
+
85
+ Curl example for making a generation request to the launched server:
86
+
87
+ ```bash
88
+ curl -X POST "http://127.0.0.1:30000/generate" \
89
+ -H "Content-Type: application/json" \
90
+ -d '{
91
+ "text": [
92
+ "<role>SYSTEM</role>detailed thinking off<|role_end|><role>HUMAN</role> Write the number from 1 to 128 <|role_end|><role>ASSISTANT</role>",
93
+ "<role>SYSTEM</role>detailed thinking off<|role_end|><role>HUMAN</role> Write a brief introduction of the great wall <|role_end|><role>ASSISTANT</role>"
94
+ ],
95
+ "stream": true,
96
+ "sampling_params": {
97
+ "temperature": 0,
98
+ "max_new_tokens": 1024
99
+ }
100
+ }'
101
+ ```
102
+
103
+ ## Supported Models
104
+
105
+ The supported models are summarized in the table below.
106
+
107
+ | Model Family | Example Model | Description |
108
+ | -------------------------- | ---------------------------- | ---------------------------------------------------------------------------------------------------- |
109
+ | **LLaDA2.0 (mini, flash)** | `inclusionAI/LLaDA2.0-flash` | LLaDA2.0-flash is a diffusion language model featuring a 100B Mixture-of-Experts (MoE) architecture. |
110
+ | **SDAR (JetLM)** | `JetLM/SDAR-8B-Chat` | SDAR series diffusion language model (Chat), dense architecture. |
111
+ | **SDAR (JetLM)** | `JetLM/SDAR-30B-A3B-Chat` | SDAR series diffusion language model (Chat), MoE architecture. |
sglang/docs/supported_models/text_generation/generative_models.md ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Large Language Models
2
+
3
+ These models accept text input and produce text output (e.g., chat completions). They are primarily large language models (LLMs), some with mixture-of-experts (MoE) architectures for scaling.
4
+
5
+ ## Example Launch Command
6
+
7
+ ```shell
8
+ python3 -m sglang.launch_server \
9
+ --model-path meta-llama/Llama-3.2-1B-Instruct \ # example HF/local path
10
+ --host 0.0.0.0 \
11
+ --port 30000
12
+ ```
13
+
14
+ ## Supported models
15
+
16
+ The supported models are summarized in the table below.
17
+
18
+ If you are unsure if a specific architecture is implemented, you can search for it via GitHub. For example, to search for `Qwen3ForCausalLM`, use the expression:
19
+
20
+ ```
21
+ repo:sgl-project/sglang path:/^python\/sglang\/srt\/models\// Qwen3ForCausalLM
22
+ ```
23
+
24
+ in the GitHub search bar.
25
+
26
+ | Model Family (Variants) | Example HuggingFace Identifier | Description |
27
+ |-------------------------------------|--------------------------------------------------|----------------------------------------------------------------------------------------|
28
+ | **DeepSeek** (v1, v2, v3/R1) | `deepseek-ai/DeepSeek-R1` | Series of advanced reasoning-optimized models (including a 671B MoE) trained with reinforcement learning; top performance on complex reasoning, math, and code tasks. [SGLang provides Deepseek v3/R1 model-specific optimizations](../basic_usage/deepseek.md) and [Reasoning Parser](../advanced_features/separate_reasoning.ipynb)|
29
+ | **Kimi K2** (Thinking, Instruct) | `moonshotai/Kimi-K2-Instruct` | Moonshot AI's 1 trillion parameter MoE model (32B active) with 128K–256K context; state-of-the-art agentic intelligence with stable long-horizon agency across 200–300 sequential tool calls. Features MLA attention and native INT4 quantization. [See Reasoning Parser docs](../advanced_features/separate_reasoning.ipynb)|
30
+ | **Kimi Linear** (48B-A3B) | `moonshotai/Kimi-Linear-48B-A3B-Instruct` | Moonshot AI's hybrid linear attention model (48B total, 3B active) with 1M token context; features Kimi Delta Attention (KDA) for up to 6× faster decoding and 75% KV cache reduction vs full attention. |
31
+ | **GPT-OSS** | `openai/gpt-oss-20b`, `openai/gpt-oss-120b` | OpenAI’s latest GPT-OSS series for complex reasoning, agentic tasks, and versatile developer use cases.|
32
+ | **Qwen** (3.5, 3, 3MoE, 3Next, 2.5, 2 series) | `Qwen/Qwen3.5-397B-A17B`, `Qwen/Qwen3-0.6B`, `Qwen/Qwen3-30B-A3B` | Alibaba’s latest Qwen3 series for complex reasoning, language understanding, and generation tasks; Support for MoE variants along with previous generation 2.5, 2, etc. [SGLang provides Qwen3 specific reasoning parser](../advanced_features/separate_reasoning.ipynb)|
33
+ | **Llama** (2, 3.x, 4 series) | `meta-llama/Llama-4-Scout-17B-16E-Instruct` | Meta's open LLM series, spanning 7B to 400B parameters (Llama 2, 3, and new Llama 4) with well-recognized performance. [SGLang provides Llama-4 model-specific optimizations](../basic_usage/llama4.md) |
34
+ | **Mistral** (Mixtral, NeMo, Small3) | `mistralai/Mistral-7B-Instruct-v0.2` | Open 7B LLM by Mistral AI with strong performance; extended into MoE (“Mixtral”) and NeMo Megatron variants for larger scale. |
35
+ | **Gemma** (v1, v2, v3) | `google/gemma-3-1b-it` | Google’s family of efficient multilingual models (1B–27B); Gemma 3 offers a 128K context window, and its larger (4B+) variants support vision input. |
36
+ | **Phi** (Phi-1.5, Phi-2, Phi-3, Phi-4, Phi-MoE series) | `microsoft/Phi-4-multimodal-instruct`, `microsoft/Phi-3.5-MoE-instruct` | Microsoft’s Phi family of small models (1.3B–5.6B); Phi-4-multimodal (5.6B) processes text, images, and speech, Phi-4-mini is a high-accuracy text model and Phi-3.5-MoE is a mixture-of-experts model. |
37
+ | **MiniCPM** (v3, 4B) | `openbmb/MiniCPM3-4B` | OpenBMB’s series of compact LLMs for edge devices; MiniCPM 3 (4B) achieves GPT-3.5-level results in text tasks. |
38
+ | **OLMo** (2, 3) | `allenai/OLMo-3-1125-32B`, `allenai/OLMo-3-32B-Think`, `allenai/OLMo-2-1124-7B-Instruct` | Allen AI’s series of Open Language Models designed to enable the science of language models. |
39
+ | **OLMoE** (Open MoE) | `allenai/OLMoE-1B-7B-0924` | Allen AI’s open Mixture-of-Experts model (7B total, 1B active parameters) delivering state-of-the-art results with sparse expert activation. |
40
+ | **MiniMax-M2** (M2, M2.1, M2.5) | `MiniMaxAI/MiniMax-M2.5`, `MiniMaxAI/MiniMax-M2.1`, `MiniMaxAI/MiniMax-M2` | MiniMax's SOTA LLM for coding & agentic workflows. |
41
+ | **StableLM** (3B, 7B) | `stabilityai/stablelm-tuned-alpha-7b` | StabilityAI’s early open-source LLM (3B & 7B) for general text generation; a demonstration model with basic instruction-following ability. |
42
+ | **Command-(R,A)** (Cohere) | `CohereLabs/c4ai-command-r-v01`, `CohereLabs/c4ai-command-r7b-12-2024`, `CohereLabs/c4ai-command-a-03-2025` | Cohere’s open conversational LLM (Command series) optimized for long context, retrieval-augmented generation, and tool use. |
43
+ | **DBRX** (Databricks) | `databricks/dbrx-instruct` | Databricks’ 132B-parameter MoE model (36B active) trained on 12T tokens; competes with GPT-3.5 quality as a fully open foundation model. |
44
+ | **Grok** (xAI) | `xai-org/grok-1` | xAI’s Grok-1 model, known for its vast size (314B parameters) and high quality; integrated in SGLang for high-performance inference. |
45
+ | **ChatGLM** (GLM-130B family) | `THUDM/chatglm2-6b` | Zhipu AI’s bilingual chat model (6B) excelling at Chinese-English dialogue; fine-tuned for conversational quality and alignment. |
46
+ | **InternLM 2** (7B, 20B) | `internlm/internlm2-7b` | Next-gen InternLM (7B and 20B) from SenseTime, offering strong reasoning and ultra-long context support (up to 200K tokens). |
47
+ | **ExaONE 3** (Korean-English) | `LGAI-EXAONE/EXAONE-3.5-7.8B-Instruct` | LG AI Research’s Korean-English model (7.8B) trained on 8T tokens; provides high-quality bilingual understanding and generation. |
48
+ | **Baichuan 2** (7B, 13B) | `baichuan-inc/Baichuan2-13B-Chat` | BaichuanAI’s second-generation Chinese-English LLM (7B/13B) with improved performance and an open commercial license. |
49
+ | **XVERSE** (MoE) | `xverse/XVERSE-MoE-A36B` | Yuanxiang’s open MoE LLM (XVERSE-MoE-A36B: 255B total, 36B active) supporting ~40 languages; delivers 100B+ dense-level performance via expert routing. |
50
+ | **SmolLM** (135M–1.7B) | `HuggingFaceTB/SmolLM-1.7B` | Hugging Face’s ultra-small LLM series (135M–1.7B params) offering surprisingly strong results, enabling advanced AI on mobile/edge devices. |
51
+ | **GLM-4** (Multilingual 9B) | `ZhipuAI/glm-4-9b-chat` | Zhipu’s GLM-4 series (up to 9B parameters) – open multilingual models with support for up to 1M-token context (GLM-4-9B-Chat-1M) and a multimodal variant (GLM-4V-9B). |
52
+ | **MiMo** (7B series) | `XiaomiMiMo/MiMo-7B-RL` | Xiaomi's reasoning-optimized model series, leverages Multiple-Token Prediction for faster inference. |
53
+ | **ERNIE-4.5** (4.5, 4.5MoE series) | `baidu/ERNIE-4.5-21B-A3B-PT` | Baidu's ERNIE-4.5 series which consists of MoE with 47B and 3B active parameters, with the largest model having 424B total parameters, as well as a 0.3B dense model. |
54
+ | **Arcee AFM-4.5B** | `arcee-ai/AFM-4.5B-Base` | Arcee's foundational model series for real world reliability and edge deployments. |
55
+ | **Persimmon** (8B) | `adept/persimmon-8b-chat` | Adept’s open 8B model with a 16K context window and fast inference; trained for broad usability and licensed under Apache 2.0. |
56
+ | **Solar** (10.7B) | `upstage/SOLAR-10.7B-Instruct-v1.0` | Upstage's 10.7B parameter model, optimized for instruction-following tasks. This architecture incorporates a depth-up scaling methodology, enhancing model performance. |
57
+ | **Tele FLM** (52B-1T) | `CofeAI/Tele-FLM` | BAAI & TeleAI's multilingual model, available in 52-billion and 1-trillion parameter variants. It is a decoder-only transformer trained on ~2T tokens. |
58
+ | **Ling** (16.8B–290B) | `inclusionAI/Ling-lite`, `inclusionAI/Ling-plus` | InclusionAI’s open MoE models. Ling-Lite has 16.8B total / 2.75B active parameters, and Ling-Plus has 290B total / 28.8B active parameters. They are designed for high performance on NLP and complex reasoning tasks. |
59
+ | **Granite 3.0, 3.1** (IBM) | `ibm-granite/granite-3.1-8b-instruct` | IBM's open dense foundation models optimized for reasoning, code, and business AI use cases. Integrated with Red Hat and watsonx systems. |
60
+ | **Granite 3.0 MoE** (IBM) | `ibm-granite/granite-3.0-3b-a800m-instruct` | IBM’s Mixture-of-Experts models offering strong performance with cost-efficiency. MoE expert routing designed for enterprise deployment at scale. |
61
+ | **GPT-J** (6B) | `EleutherAI/gpt-j-6b` | EleutherAI's GPT-2-like causal language model (6B) trained on the [Pile](https://pile.eleuther.ai/) dataset. |
62
+ | **Orion** (14B) | `OrionStarAI/Orion-14B-Base` | A series of open-source multilingual large language models by OrionStarAI, pretrained on a 2.5T-token multilingual corpus including Chinese, English, Japanese, Korean, etc.; it exhibits superior performance in these languages. |
63
+ | **Llama Nemotron Super** (v1, v1.5, NVIDIA) | `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, `nvidia/Llama-3_3-Nemotron-Super-49B-v1_5` | The [NVIDIA Nemotron](https://www.nvidia.com/en-us/ai-data-science/foundation-models/nemotron/) family of multimodal models provides state-of-the-art reasoning models specifically designed for enterprise-ready AI agents. |
64
+ | **Llama Nemotron Ultra** (v1, NVIDIA) | `nvidia/Llama-3_1-Nemotron-Ultra-253B-v1` | The [NVIDIA Nemotron](https://www.nvidia.com/en-us/ai-data-science/foundation-models/nemotron/) family of multimodal models provides state-of-the-art reasoning models specifically designed for enterprise-ready AI agents. |
65
+ | **NVIDIA Nemotron Nano 2.0** | `nvidia/NVIDIA-Nemotron-Nano-9B-v2` | The [NVIDIA Nemotron](https://www.nvidia.com/en-us/ai-data-science/foundation-models/nemotron/) family of multimodal models provides state-of-the-art reasoning models specifically designed for enterprise-ready AI agents. `Nemotron-Nano-9B-v2` is a hybrid Mamba-Transformer language model designed to increase throughput for reasoning workloads while achieving state-of-the-art accuracy compared to similarly-sized models. |
66
+ | **StarCoder2** (3B-15B) | `bigcode/starcoder2-7b` | StarCoder2 is a family of open large language models (LLMs) specialized for code generation and understanding. It is the successor to StarCoder, jointly developed by the BigCode project (a collaboration between Hugging Face, ServiceNow Research, and other contributors). |
67
+ | **Jet-Nemotron** | `jet-ai/Jet-Nemotron-2B` | Jet-Nemotron is a new family of hybrid-architecture language models that surpass state-of-the-art open-source full-attention language models, while achieving significant efficiency gains. |
68
+ | **Trinity** (Nano, Mini) | `arcee-ai/Trinity-Mini` | Arcee's foundational MoE Trinity family of models, open weights under Apache 2.0. |
69
+ | **Falcon-H1** (0.5B–34B) | `tiiuae/Falcon-H1-34B-Instruct` | TII's hybrid Mamba-Transformer architecture combining attention and state-space models for efficient long-context inference. |
70
+ | **Hunyuan-Large** (389B, MoE) | `tencent/Tencent-Hunyuan-Large` | Tencent's open-source MoE model with 389B total / 52B active parameters, featuring Cross-Layer Attention (CLA) for improved efficiency. |
71
+ | **IBM Granite 4.0 (Hybrid, Dense)** | `ibm-granite/granite-4.0-h-micro`, `ibm-granite/granite-4.0-micro` | IBM Granite 4.0 micro models: hybrid Mamba–MoE (`h-micro`) and dense (`micro`) variants. Enterprise-focused reasoning models. |
72
+ | **Sarvam 2** (30B-A2B, 105B-A10B) | `sarvamai/sarvam-2` | Sarvam's Mixture-of-Experts models. The 105B variant uses MLA (Multi-head Latent Attention) and the 30B variant uses GQA, both with 128 routed experts. |
sglang/docs/supported_models/text_generation/index.rst ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Text Generation
2
+ ===============
3
+
4
+ Models for generating text from text or multimodal inputs.
5
+
6
+ .. toctree::
7
+ :maxdepth: 1
8
+
9
+ generative_models.md
10
+ multimodal_language_models.md
11
+ diffusion_language_models.md
sglang/docs/supported_models/text_generation/multimodal_language_models.md ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Multimodal Language Models
2
+
3
+ These models accept multi-modal inputs (e.g., images and text) and generate text output. They augment language models with multimodal encoders.
4
+
5
+ ## Example Launch Command
6
+
7
+ ```shell
8
+ python3 -m sglang.launch_server \
9
+ --host 0.0.0.0 \
10
+ --port 30000 \
11
+ --model-path meta-llama/Llama-3.2-11B-Vision-Instruct  # example HF/local path
12
+ ```
13
+
14
+ > See the [OpenAI APIs section](https://docs.sglang.io/basic_usage/openai_api_vision.html) for how to send multimodal requests.
15
+
16
+ ## Supported models
17
+
18
+ Below the supported models are summarized in a table.
19
+
20
+ If you are unsure if a specific architecture is implemented, you can search for it via GitHub. For example, to search for `Qwen2_5_VLForConditionalGeneration`, use the expression:
21
+
22
+ ```
23
+ repo:sgl-project/sglang path:/^python\/sglang\/srt\/models\// Qwen2_5_VLForConditionalGeneration
24
+ ```
25
+
26
+ in the GitHub search bar.
27
+
28
+
29
+ | Model Family (Variants) | Example HuggingFace Identifier | Description | Notes |
30
+ |----------------------------|--------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------|
31
+ | **Qwen-VL** | `Qwen/Qwen3-VL-235B-A22B-Instruct` | Alibaba's vision-language extension of Qwen; for example, Qwen2.5-VL (7B and larger variants) can analyze and converse about image content. | |
32
+ | **DeepSeek-VL2** | `deepseek-ai/deepseek-vl2` | Vision-language variant of DeepSeek (with a dedicated image processor), enabling advanced multimodal reasoning on image and text inputs. | |
33
+ | **DeepSeek-OCR / OCR-2** | `deepseek-ai/DeepSeek-OCR-2` | OCR-focused DeepSeek models for document understanding and text extraction. | Use `--trust-remote-code`. |
34
+ | **Janus-Pro** (1B, 7B) | `deepseek-ai/Janus-Pro-7B` | DeepSeek's open-source multimodal model capable of both image understanding and generation. Janus-Pro employs a decoupled architecture for separate visual encoding paths, enhancing performance in both tasks. | |
35
+ | **MiniCPM-V / MiniCPM-o** | `openbmb/MiniCPM-V-2_6` | MiniCPM-V (2.6, ~8B) supports image inputs, and MiniCPM-o adds audio/video; these multimodal LLMs are optimized for end-side deployment on mobile/edge devices. | |
36
+ | **Llama 3.2 Vision** (11B) | `meta-llama/Llama-3.2-11B-Vision-Instruct` | Vision-enabled variant of Llama 3 (11B) that accepts image inputs for visual question answering and other multimodal tasks. | |
37
+ | **LLaVA** (v1.5 & v1.6) | *e.g.* `liuhaotian/llava-v1.5-13b` | Open vision-chat models that add an image encoder to LLaMA/Vicuna (e.g. LLaMA2 13B) for following multimodal instruction prompts. | |
38
+ | **LLaVA-NeXT** (8B, 72B) | `lmms-lab/llava-next-72b` | Improved LLaVA models (with an 8B Llama3 version and a 72B version) offering enhanced visual instruction-following and accuracy on multimodal benchmarks. | |
39
+ | **LLaVA-OneVision** | `lmms-lab/llava-onevision-qwen2-7b-ov` | Enhanced LLaVA variant integrating Qwen as the backbone; supports multiple images (and even video frames) as inputs via an OpenAI Vision API-compatible format. | |
40
+ | **Gemma 3 (Multimodal)** | `google/gemma-3-4b-it` | Gemma 3's larger models (4B, 12B, 27B) accept images (each image encoded as 256 tokens) alongside text in a combined 128K-token context. | |
41
+ | **Kimi-VL** (A3B) | `moonshotai/Kimi-VL-A3B-Instruct` | Kimi-VL is a multimodal model that can understand and generate text from images. | |
42
+ | **Mistral-Small-3.1-24B** | `mistralai/Mistral-Small-3.1-24B-Instruct-2503` | Mistral 3.1 is a multimodal model that can generate text from text or images input. It also supports tool calling and structured output. | |
43
+ | **Phi-4-multimodal-instruct** | `microsoft/Phi-4-multimodal-instruct` | Phi-4-multimodal-instruct is the multimodal variant of the Phi-4-mini model, enhanced with LoRA for improved multimodal capabilities. It supports text, vision and audio modalities in SGLang. | |
44
+ | **MiMo-VL** (7B) | `XiaomiMiMo/MiMo-VL-7B-RL` | Xiaomi's compact yet powerful vision-language model featuring a native resolution ViT encoder for fine-grained visual details, an MLP projector for cross-modal alignment, and the MiMo-7B language model optimized for complex reasoning tasks. | |
45
+ | **GLM-4.5V** (106B) / **GLM-4.1V**(9B) | `zai-org/GLM-4.5V` | GLM-4.5V and GLM-4.1V-Thinking: Towards Versatile Multimodal Reasoning with Scalable Reinforcement Learning | Use `--chat-template glm-4v` |
46
+ | **GLM-OCR** | `zai-org/GLM-OCR` | GLM-OCR: A fast and accurate general OCR model | |
47
+ | **DotsVLM** (General/OCR) | `rednote-hilab/dots.vlm1.inst` | RedNote's vision-language model built on a 1.2B vision encoder and DeepSeek V3 LLM, featuring NaViT vision encoder trained from scratch with dynamic resolution support and enhanced OCR capabilities through structured image data training. | |
48
+ | **DotsVLM-OCR** | `rednote-hilab/dots.ocr` | Specialized OCR variant of DotsVLM optimized for optical character recognition tasks with enhanced text extraction and document understanding capabilities. | Don't use `--trust-remote-code` |
49
+ | **NVILA** (8B, 15B, Lite-2B, Lite-8B, Lite-15B) | `Efficient-Large-Model/NVILA-8B` | NVILA explores the full stack efficiency of multi-modal design, achieving cheaper training, faster deployment and better performance. | Chat template: `chatml` |
50
+ | **NVIDIA Nemotron Nano 2.0 VL** | `nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16` | NVIDIA Nemotron Nano v2 VL enables multi-image reasoning and video understanding, along with strong document intelligence, visual Q&A and summarization capabilities. It builds on Nemotron Nano V2, a hybrid Mamba-Transformer LLM, in order to achieve higher inference throughput in long document and video scenarios. | Use `--trust-remote-code`. You may need to adjust `--max-mamba-cache-size` [default is 512] to fit memory constraints. |
51
+ | **Ernie4.5-VL** | `baidu/ERNIE-4.5-VL-28B-A3B-PT` | Baidu's vision-language models (28B, 424B). They support image and video comprehension, and also support thinking. | |
52
+ | **JetVLM** | | JetVLM is a vision-language model designed for high-performance multimodal understanding and generation tasks built upon Jet-Nemotron. | Coming soon |
53
+ | **Step3-VL** (10B) | `stepfun-ai/Step3-VL-10B` | StepFun's lightweight open-source 10B parameter VLM for multimodal intelligence, excelling in visual perception, complex reasoning, and human alignment. | |
54
+ | **Qwen3-Omni** | `Qwen/Qwen3-Omni-30B-A3B-Instruct` | Alibaba's omni-modal MoE model. Currently supports the **Thinker** component (multimodal understanding for text, images, audio, and video), while the **Talker** component (audio generation) is not yet supported. | |
55
+
56
+ ## Video Input Support
57
+
58
+ SGLang supports video input for Vision-Language Models (VLMs), enabling temporal reasoning tasks such as video question answering, captioning, and holistic scene understanding. Video clips are decoded, key frames are sampled, and the resulting tensors are batched together with the text prompt, allowing multimodal inference to integrate visual and linguistic context.
59
+
60
+ | Model Family | Example Identifier | Video notes |
61
+ |--------------|--------------------|-------------|
62
+ | **Qwen-VL** (Qwen2-VL, Qwen2.5-VL, Qwen3-VL, Qwen3-Omni) | `Qwen/Qwen3-VL-235B-A22B-Instruct` | The processor gathers `video_data`, runs Qwen's frame sampler, and merges the resulting features with text tokens before inference. |
63
+ | **GLM-4v** (4.5V, 4.1V, MOE) | `zai-org/GLM-4.5V` | Video clips are read with Decord, converted to tensors, and passed to the model alongside metadata for rotary-position handling. |
64
+ | **NVILA** (Full & Lite) | `Efficient-Large-Model/NVILA-8B` | The runtime samples eight frames per clip and attaches them to the multimodal request when `video_data` is present. |
65
+ | **LLaVA video variants** (LLaVA-NeXT-Video, LLaVA-OneVision) | `lmms-lab/LLaVA-NeXT-Video-7B` | The processor routes video prompts to the LlavaVid video-enabled architecture, and the provided example shows how to query it with `sgl.video(...)` clips. |
66
+ | **NVIDIA Nemotron Nano 2.0 VL** | `nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16` | The processor samples at 2 FPS, at a max of 128 frames, as per model training. The model uses [EVS](../../python/sglang/srt/multimodal/evs/README.md), a pruning method that removes redundant tokens from video embeddings. By default `video_pruning_rate=0.7`. Change this by providing: `--json-model-override-args '{"video_pruning_rate": 0.0}'` to disable EVS, for example. |
67
+ | **JetVLM** | | The runtime samples eight frames per clip and attaches them to the multimodal request when `video_data` is present. |
68
+
69
+ Use `sgl.video(path, num_frames)` when building prompts to attach clips from your SGLang programs.
70
+
71
+ Example OpenAI-compatible request that sends a video clip:
72
+
73
+ ```python
74
+ import requests
75
+
76
+ url = "http://localhost:30000/v1/chat/completions"
77
+
78
+ data = {
79
+ "model": "Qwen/Qwen3-VL-30B-A3B-Instruct",
80
+ "messages": [
81
+ {
82
+ "role": "user",
83
+ "content": [
84
+ {"type": "text", "text": "What’s happening in this video?"},
85
+ {
86
+ "type": "video_url",
87
+ "video_url": {
88
+ "url": "https://github.com/sgl-project/sgl-test-files/raw/refs/heads/main/videos/jobs_presenting_ipod.mp4"
89
+ },
90
+ },
91
+ ],
92
+ }
93
+ ],
94
+ "max_tokens": 300,
95
+ }
96
+
97
+ response = requests.post(url, json=data)
98
+ print(response.text)
99
+ ```
100
+
101
+ ## Usage Notes
102
+
103
+ ### Performance Optimization
104
+
105
+ For multimodal models, you can use the `--keep-mm-feature-on-device` flag to optimize for latency at the cost of increased GPU memory usage:
106
+
107
+ - **Default behavior**: Multimodal feature tensors are moved to CPU after processing to save GPU memory
108
+ - **With `--keep-mm-feature-on-device`**: Feature tensors remain on GPU, reducing device-to-host copy overhead and improving latency, but consuming more GPU memory
109
+
110
+ Use this flag when you have sufficient GPU memory and want to minimize latency for multimodal inference.
111
+
112
+ ### Multimodal Inputs Limitation
113
+
114
+ - **Use `--mm-process-config '{"image":{"max_pixels":1048576},"video":{"fps":3,"max_pixels":602112,"max_frames":60}}'`**: To set `image`, `video`, and `audio` input limits.
115
+
116
+ This can reduce GPU memory usage, improve inference speed, and help avoid OOM errors, but it may impact model performance, so choose values appropriate for your specific use case. Currently, only `qwen_vl` supports this config. Please refer to the [qwen_vl processor](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/multimodal/processors/qwen_vl.py) to understand the meaning of each parameter.
117
+
118
+ ### Bidirectional Attention in Multimodal Model Serving
119
+ **Note for serving the Gemma-3 multimodal model**:
120
+
121
+ As mentioned in [Welcome Gemma 3: Google's all new multimodal, multilingual, long context open LLM
122
+ ](https://huggingface.co/blog/gemma3#multimodality), Gemma-3 employs bidirectional attention between image tokens during the prefill phase. Currently, SGLang only supports bidirectional attention when using the Triton Attention Backend. Note, however, that SGLang's current bidirectional attention implementation is incompatible with both CUDA Graph and Chunked Prefill.
123
+
124
+ To enable bidirectional attention, you can use the `TritonAttnBackend` while disabling CUDA Graph and Chunked Prefill. Example launch command:
125
+ ```shell
126
+ python -m sglang.launch_server \
127
+ --model-path google/gemma-3-4b-it \
128
+ --host 0.0.0.0 --port 30000 \
129
+ --enable-multimodal \
130
+ --dtype bfloat16 --triton-attention-reduce-in-fp32 \
131
+ --attention-backend triton \
132
+ --disable-cuda-graph \
133
+ --chunked-prefill-size -1  # Triton attention backend, with CUDA Graph and Chunked Prefill disabled
134
+ ```
135
+
136
+ If higher serving performance is required and a certain degree of accuracy loss is acceptable, you may choose to use other attention backends, and you can also enable features like CUDA Graph and Chunked Prefill for better performance, but note that the model will fall back to using causal attention instead of bidirectional attention.
sglang/python/sglang/srt/__pycache__/constants.cpython-311.pyc ADDED
Binary file (348 Bytes). View file
 
sglang/python/sglang/srt/__pycache__/environ.cpython-311.pyc ADDED
Binary file (35.1 kB). View file
 
sglang/python/sglang/srt/batch_overlap/__pycache__/operations.cpython-311.pyc ADDED
Binary file (12.4 kB). View file
 
sglang/python/sglang/srt/batch_overlap/__pycache__/operations_strategy.cpython-311.pyc ADDED
Binary file (12.4 kB). View file
 
sglang/python/sglang/srt/batch_overlap/__pycache__/single_batch_overlap.cpython-311.pyc ADDED
Binary file (5.8 kB). View file
 
sglang/python/sglang/srt/batch_overlap/__pycache__/two_batch_overlap.cpython-311.pyc ADDED
Binary file (42.2 kB). View file
 
sglang/python/sglang/srt/batch_overlap/operations.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from contextlib import contextmanager
5
+ from dataclasses import dataclass
6
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Sequence, Union
7
+
8
+ import torch
9
+
10
+ from sglang.srt.layers.dp_attention import set_dp_buffer_len
11
+
12
+ if TYPE_CHECKING:
13
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
14
+
15
+ _ENABLE_PROFILE = bool(int(os.environ.get("SGLANG_OPERATIONS_ENABLE_PROFILE", "0")))
16
+
17
+ if _ENABLE_PROFILE:
18
+ import nvtx
19
+
20
+
21
def execute_operations(inputs, operations):
    """Run ``operations`` one stage at a time on ``inputs``.

    The flat operation list is split into stages, every stage is executed in
    order by a single executor, and the output of the last operation is
    returned.
    """
    runner = _StageExecutor(
        "primary", _convert_operations_to_stages(operations), inputs=inputs
    )
    remaining = runner.num_stages
    while remaining > 0:
        runner.next()
        remaining -= 1
    assert runner.done
    return runner.output
28
+
29
+
30
def execute_overlapped_operations(
    inputs_arr: Sequence,
    operations_arr: Sequence,
    delta_stages: Sequence[int],
) -> Sequence:
    """Interleave the stages of two sub-batches, the second lagging the first.

    Each argument must hold exactly two entries (one per sub-batch). The first
    sub-batch must have a stage offset of 0; the second starts
    ``delta_stages[1]`` stages later. Returns ``[output_a, output_b]``.
    """
    # Only two-way overlap is handled; the unpacking enforces exactly two entries.
    first_inputs, second_inputs = inputs_arr
    first_ops, second_ops = operations_arr
    first_delta, lag = delta_stages
    assert first_delta == 0

    first_stages = _convert_operations_to_stages(first_ops)
    second_stages = _convert_operations_to_stages(second_ops)
    executor_a = _StageExecutor("a", first_stages, inputs=first_inputs)
    executor_b = _StageExecutor("b", second_stages, inputs=second_inputs)

    # Warm-up: only the first sub-batch advances.
    for _ in range(lag):
        executor_a.next()

    # Steady state: both advance, always the first sub-batch before the second.
    overlapped_steps = executor_a.num_stages - lag
    for _ in range(overlapped_steps):
        executor_a.next()
        executor_b.next()

    # Drain: the second sub-batch catches up.
    for _ in range(lag):
        executor_b.next()

    assert executor_a.done and executor_b.done
    return [executor_a.output, executor_b.output]
59
+
60
+
61
class YieldOperation:
    """Marker placed in an operation list to separate two stages."""
63
+
64
+
65
@dataclass
class ExecutionOperation:
    """A callable operation paired with the name used to annotate it."""

    # Label passed to the profiling annotations around the call.
    debug_name: str
    # Called with ``state=...`` plus the previous operation's output as kwargs.
    fn: Callable
69
+
70
+
71
+ Operation = Union[YieldOperation, ExecutionOperation, Callable]
72
+ Stage = List[ExecutionOperation]
73
+
74
+
75
class _StageExecutor:
    """Runs a list of stages one at a time, threading each output forward.

    The output of every operation becomes the keyword arguments of the next
    one; ``inputs`` seeds the first operation. A shared ``_StateDict`` is
    passed to every operation as ``state``.
    """

    def __init__(self, debug_name: str, stages: List[Stage], inputs: dict):
        self._debug_name = debug_name
        self._stages = stages
        self._index = 0
        self._stage_state = _StateDict()
        self._stage_output = inputs

        # DP-attention settings are captured once and re-applied before every stage.
        forward_batch: ForwardBatch = inputs["forward_batch"]
        self._global_dp_buffer_len = forward_batch.global_dp_buffer_len
        self._local_dp_buffer_len = forward_batch.tbo_padded_len
        self._global_num_tokens = forward_batch.global_num_tokens_cpu
        self._is_dp_max_padding = forward_batch.dp_padding_mode.is_max_len()

    def next(self):
        """Execute the next pending stage; must not be called once ``done``."""
        assert not self.done

        current_stage = self._stages[self._index]

        # TODO: We currently always call set_dp_buffer_len here because sub-batches
        # may have different padded lengths. It can likely be removed after TBO slice &
        # pad logic is refactored.
        set_dp_buffer_len(
            self._global_dp_buffer_len,
            self._local_dp_buffer_len,
            self._is_dp_max_padding,
            self._global_num_tokens,
        )

        with _annotate_region(debug_name=f"{self._debug_name}{self._index}"):
            for op in current_stage:
                with _annotate_region(debug_name=op.debug_name):
                    # An operation returning None means "no kwargs" for the next one.
                    kwargs = {} if self._stage_output is None else self._stage_output
                    self._stage_output = op.fn(state=self._stage_state, **kwargs)

        self._index += 1

    @property
    def output(self):
        """Output of the last operation; only available once all stages ran."""
        assert self.done
        return self._stage_output

    @property
    def done(self):
        """Whether every stage has been executed."""
        return self._index >= self.num_stages

    @property
    def num_stages(self):
        """Total number of stages held by this executor."""
        return len(self._stages)
129
+
130
+
131
@contextmanager
def _annotate_region(debug_name):
    """Wrap the enclosed code in profiler/NVTX ranges when profiling is enabled.

    With ``_ENABLE_PROFILE`` unset this is a no-op context manager.
    """
    if not _ENABLE_PROFILE:
        yield
        return

    with torch.autograd.profiler.record_function(debug_name), nvtx.annotate(
        debug_name
    ):
        yield
139
+
140
+
141
class _MissingStateKeyError(AttributeError, KeyError):
    """Raised when a ``_StateDict`` entry does not exist.

    It is an ``AttributeError`` so that ``hasattr`` and ``getattr`` with a
    default behave as the attribute protocol requires, and also a ``KeyError``
    because that is what attribute access on ``_StateDict`` raised previously.
    """


class _StateDict:
    """Attribute-style scratch storage that refuses silent overwrites.

    Values are stored in an internal dict and exposed as attributes. Setting
    an attribute that already exists fails an assertion; entries must be
    removed (``del``/``pop``/``clear``) before being set again.
    """

    def __init__(self):
        self._data = {}

    def __setattr__(self, key, value):
        # `_data` is the only real instance attribute; everything else goes in the dict.
        if key == "_data":
            super().__setattr__(key, value)
            return
        assert (
            key not in self._data
        ), f"`{key}` already exist, are you sure you want to override it?"
        self._data[key] = value

    def __getattr__(self, item):
        # Only called when normal lookup fails. `_data` itself must never be
        # resolved through the dict: on an instance created without __init__
        # (e.g. by `copy`), that would recurse without bound.
        if item == "_data":
            raise AttributeError(item)
        try:
            return self._data[item]
        except KeyError:
            raise _MissingStateKeyError(item) from None

    def __delattr__(self, item):
        try:
            del self._data[item]
        except KeyError:
            raise _MissingStateKeyError(item) from None

    def pop(self, item):
        """Remove and return the entry ``item``; raises ``KeyError`` if absent."""
        return self._data.pop(item)

    def update(self, values: Dict[str, Any]):
        """Set every entry of ``values``, subject to the no-overwrite check."""
        for k, v in values.items():
            setattr(self, k, v)

    def get(self, item):
        """Return the entry ``item``, or ``None`` if it does not exist."""
        return self._data.get(item)

    def clear(self, expect_keys: Sequence[str]):
        """Drop all entries, after checking they are exactly ``expect_keys``.

        Raises:
            Exception: if the stored keys differ from ``expect_keys``.
        """
        if set(self._data.keys()) != set(expect_keys):
            raise Exception(
                f"Unexpected keys when clearing. This may indicate you do not release memory early enough but leave it until here. {list(self._data.keys())=} {expect_keys=}"
            )

        self._data.clear()
177
+
178
+
179
def _convert_operations_to_stages(operations: List[Operation]) -> List[Stage]:
    """Split a flat operation list into stages at every ``YieldOperation``.

    Each remaining operation is wrapped as an ``ExecutionOperation``. Empty
    stages (e.g. two consecutive yields) are rejected by assertion.
    """

    def _is_stage_break(op):
        return isinstance(op, YieldOperation)

    decorated = _decorate_operations(operations)
    stages = list(_chunk_by_separator(decorated, _is_stage_break))
    for stage in stages:
        assert len(stage) > 0
    return stages
186
+
187
+
188
def _chunk_by_separator(
    items: List[Any], is_separator: Callable[[Any], bool]
) -> Generator[List[Any], None, None]:
    """Yield runs of ``items`` delimited by elements matching ``is_separator``.

    Separators are dropped. A separator always closes the current run, even
    when that run is empty; a trailing run is yielded only if non-empty.
    """
    chunk: List[Any] = []
    for element in items:
        if not is_separator(element):
            chunk.append(element)
            continue
        yield chunk
        chunk = []
    if chunk:
        yield chunk
200
+
201
+
202
def _decorate_operations(operations: List[Operation], debug_name_prefix: str = ""):
    """Apply ``_decorate_operation`` to every entry, preserving order."""
    decorated = []
    for op in operations:
        decorated.append(_decorate_operation(op, debug_name_prefix))
    return decorated
204
+
205
+
206
def _decorate_operation(operation: Operation, debug_name_prefix: str):
    """Wrap a callable as an ``ExecutionOperation``; pass yields through as-is.

    The debug name is the prefix followed by the callable's ``__name__``
    (``"unknown"`` if it has none) with every ``"op_"`` substring removed.
    """
    if not isinstance(operation, YieldOperation):
        raw_name = getattr(operation, "__name__", "unknown")
        short_name = raw_name.replace("op_", "")
        return ExecutionOperation(
            debug_name=debug_name_prefix + short_name,
            fn=operation,
        )
    return operation
sglang/python/sglang/srt/batch_overlap/operations_strategy.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import List, Optional
3
+
4
+ import torch
5
+
6
+ from sglang.srt.batch_overlap import operations
7
+ from sglang.srt.batch_overlap.operations import Operation
8
+ from sglang.srt.layers.moe.token_dispatcher import DeepEPConfig
9
+ from sglang.srt.model_executor.forward_batch_info import ForwardMode
10
+ from sglang.srt.utils import is_hip
11
+
12
+ _is_hip = is_hip()
13
+
14
+
15
@dataclass
class OperationsStrategy:
    """An ordered operation list plus the settings that go with it."""

    # Flat list of operations; YieldOperation entries mark stage boundaries.
    operations: List[Operation]
    # Set by the prefill strategies in this file; None in the decode strategies.
    deep_gemm_num_sms: Optional[int] = None
    # 0 in the prefill strategies of this file, 2 in the decode strategies.
    tbo_delta_stages: Optional[int] = None

    @classmethod
    def concat(cls, items: List["OperationsStrategy"]) -> "OperationsStrategy":
        """Join several strategies into one.

        Operations are concatenated in order; ``deep_gemm_num_sms`` and
        ``tbo_delta_stages`` must be identical across all ``items``.
        """
        return OperationsStrategy(
            operations=[x for item in items for x in item.operations],
            deep_gemm_num_sms=_assert_all_same(
                [item.deep_gemm_num_sms for item in items]
            ),
            tbo_delta_stages=_assert_all_same(
                [item.tbo_delta_stages for item in items]
            ),
        )

    @staticmethod
    def init_new_tbo(
        layers: torch.nn.ModuleList,
        forward_mode: ForwardMode,
    ) -> "OperationsStrategy":
        """Build the combined strategy for ``layers`` under ``forward_mode``.

        The per-layer strategy function is chosen from the class name of the
        first layer and applied to every layer.

        Raises:
            NotImplementedError: if the layer class has no strategy.
        """
        layer_name = layers[0].__class__.__name__
        if layer_name == "DeepseekV2DecoderLayer":
            compute_strategy = _compute_moe_deepseek_layer_operations_strategy_tbo
        elif layer_name == "Qwen3MoeDecoderLayer":
            compute_strategy = _compute_moe_qwen3_layer_operations_strategy_tbo
        elif layer_name == "MiMoV2DecoderLayer":
            compute_strategy = _compute_moe_mimov2_layer_operations_strategy_tbo
        else:
            # Name the offending class, matching the per-mode errors in this module.
            raise NotImplementedError(f"Unsupported {layer_name=}")

        return OperationsStrategy.concat(
            [compute_strategy(layer, forward_mode) for layer in layers]
        )
68
+
69
+
70
def _assert_all_same(items: List):
    """Return the first element after asserting all elements equal it."""
    reference = items[0]
    for item in items:
        assert item == reference
    return reference
73
+
74
+
75
+ # -------------------------------- Strategy for DeepSeek ---------------------------------------
76
+
77
+
78
+ # TODO can refactor to make it more fancy if we have more complex strategies
79
def _compute_moe_deepseek_layer_operations_strategy_tbo(
    layer: torch.nn.Module,
    forward_mode: ForwardMode,
) -> OperationsStrategy:
    """Pick the DeepSeek prefill or decode strategy for ``forward_mode``.

    EXTEND maps to the prefill strategy; DECODE and TARGET_VERIFY map to the
    decode strategy. Any other mode raises ``NotImplementedError``.
    """
    assert layer.is_layer_sparse, "dense layer TBO not yet implemented"

    if forward_mode == ForwardMode.EXTEND:
        return _compute_moe_deepseek_blog_prefill(layer)
    if forward_mode in (ForwardMode.DECODE, ForwardMode.TARGET_VERIFY):
        return _compute_moe_deepseek_blog_decode(layer)
    raise NotImplementedError(f"Unsupported {forward_mode=}")
92
+
93
+
94
def _compute_moe_deepseek_blog_prefill(layer):
    """Build the DeepSeek prefill strategy: three stages, no stage offset.

    Outside HIP, ``deep_gemm_num_sms`` is the device's SM count minus the SMs
    configured in ``DeepEPConfig``; on HIP it stays ``None``.
    """
    sm_count = torch.cuda.get_device_properties(device="cuda").multi_processor_count
    gemm_sms = (
        None if _is_hip else sm_count - DeepEPConfig.get_instance().num_sms
    )

    attn = layer.self_attn
    mlp = layer.mlp
    stage_break = operations.YieldOperation

    op_sequence = [
        layer.op_comm_prepare_attn,
        attn.op_prepare,
        attn.op_core,
        layer.op_comm_prepare_mlp,
        mlp.op_gate,
        mlp.op_select_experts,
        mlp.op_dispatch_a,
        stage_break(),
        mlp.op_dispatch_b,
        mlp.op_experts,
        mlp.op_combine_a,
        stage_break(),
        mlp.op_shared_experts,
        mlp.op_combine_b,
        mlp.op_output,
        layer.op_comm_postprocess_layer,
    ]
    return OperationsStrategy(
        deep_gemm_num_sms=gemm_sms,
        tbo_delta_stages=0,
        operations=op_sequence,
    )
123
+
124
+
125
def _compute_moe_deepseek_blog_decode(layer):
    """Build the DeepSeek decode strategy: six stages with a stage offset of 2."""
    attn = layer.self_attn
    mlp = layer.mlp
    stage_break = operations.YieldOperation

    op_sequence = [
        layer.op_comm_prepare_attn,
        attn.op_prepare,
        stage_break(),
        attn.op_core,
        layer.op_comm_prepare_mlp,
        mlp.op_gate,
        mlp.op_select_experts,
        stage_break(),
        mlp.op_dispatch_a,
        mlp.op_shared_experts,
        stage_break(),
        mlp.op_dispatch_b,
        mlp.op_experts,
        mlp.op_combine_a,
        stage_break(),
        mlp.op_combine_b,
        stage_break(),
        mlp.op_output,
        layer.op_comm_postprocess_layer,
    ]
    return OperationsStrategy(
        deep_gemm_num_sms=None,
        tbo_delta_stages=2,
        operations=op_sequence,
    )
151
+
152
+
153
+ # -------------------------------- Strategy for Qwen3 ---------------------------------------
154
+
155
+
156
+ # TODO: unstable, current strategy is almost the same as DeepSeek, keep redundant code here for
157
+ # convenience to adjust strategy
158
def _compute_moe_qwen3_layer_operations_strategy_tbo(
    layer: torch.nn.Module,
    forward_mode: ForwardMode,
) -> OperationsStrategy:
    """Pick the Qwen3 MoE prefill or decode strategy for ``forward_mode``.

    EXTEND maps to the prefill strategy; DECODE and TARGET_VERIFY map to the
    decode strategy. Any other mode raises ``NotImplementedError``.
    """
    assert layer.is_layer_sparse, "qwen3 moe only support sparse layers"

    if forward_mode == ForwardMode.EXTEND:
        return _compute_moe_qwen3_prefill(layer)
    if forward_mode in (ForwardMode.DECODE, ForwardMode.TARGET_VERIFY):
        return _compute_moe_qwen3_decode(layer)
    raise NotImplementedError(f"Unsupported {forward_mode=}")
171
+
172
+
173
def _compute_moe_qwen3_prefill(layer):
    """Build the Qwen3 MoE prefill strategy: three stages, no stage offset.

    Outside HIP, ``deep_gemm_num_sms`` is the device's SM count minus the SMs
    configured in ``DeepEPConfig``; on HIP it stays ``None``.
    """
    sm_count = torch.cuda.get_device_properties(device="cuda").multi_processor_count
    gemm_sms = (
        None if _is_hip else sm_count - DeepEPConfig.get_instance().num_sms
    )

    attn = layer.self_attn
    mlp = layer.mlp
    stage_break = operations.YieldOperation

    op_sequence = [
        layer.op_comm_prepare_attn,
        attn.op_prepare,
        attn.op_core,
        layer.op_comm_prepare_mlp,
        mlp.op_gate,
        mlp.op_select_experts,
        mlp.op_dispatch_a,
        stage_break(),
        mlp.op_dispatch_b,
        mlp.op_experts,
        mlp.op_combine_a,
        stage_break(),
        mlp.op_combine_b,
        mlp.op_output,
        layer.op_comm_postprocess_layer,
    ]
    return OperationsStrategy(
        deep_gemm_num_sms=gemm_sms,
        tbo_delta_stages=0,
        operations=op_sequence,
    )
201
+
202
+
203
def _compute_moe_qwen3_decode(layer):
    """Build the Qwen3 MoE decode strategy: five stages with a stage offset of 2.

    The list ends with a stage break, so the layer's last stage is closed
    before whatever operations follow it.
    """
    attn = layer.self_attn
    mlp = layer.mlp
    stage_break = operations.YieldOperation

    op_sequence = [
        layer.op_comm_prepare_attn,
        attn.op_prepare,
        stage_break(),
        attn.op_core,
        layer.op_comm_prepare_mlp,
        mlp.op_gate,
        mlp.op_select_experts,
        stage_break(),
        mlp.op_dispatch_a,
        stage_break(),
        mlp.op_dispatch_b,
        mlp.op_experts,
        mlp.op_combine_a,
        stage_break(),
        mlp.op_combine_b,
        mlp.op_output,
        layer.op_comm_postprocess_layer,
        stage_break(),
    ]
    return OperationsStrategy(
        deep_gemm_num_sms=None,
        tbo_delta_stages=2,
        operations=op_sequence,
    )
228
+
229
+
230
+ # -------------------------------- Strategy for MiMoV2DecoderLayer ---------------------------------------
231
+
232
+
233
+ # TODO: unstable; current strategy matches DeepSeek for the common operations (MiMoV2 has no op_shared_experts),
234
+ # so we keep this redundant code here for convenience when adjusting the strategy
235
def _compute_moe_mimov2_layer_operations_strategy_tbo(
    layer: torch.nn.Module,
    forward_mode: ForwardMode,
) -> OperationsStrategy:
    """Pick the MiMoV2 prefill or decode strategy for ``forward_mode``.

    EXTEND maps to the prefill strategy; DECODE and TARGET_VERIFY map to the
    decode strategy. Any other mode raises ``NotImplementedError``.
    """
    assert layer.is_layer_sparse, "MiMoV2DecoderLayer moe only support sparse layers"

    if forward_mode == ForwardMode.EXTEND:
        return _compute_moe_mimov2_prefill(layer)
    if forward_mode in (ForwardMode.DECODE, ForwardMode.TARGET_VERIFY):
        return _compute_moe_mimov2_decode(layer)
    raise NotImplementedError(f"Unsupported {forward_mode=}")
248
+
249
+
250
def _compute_moe_mimov2_prefill(layer):
    """Build the MiMoV2 prefill strategy: three stages, no stage offset.

    Outside HIP, ``deep_gemm_num_sms`` is the device's SM count minus the SMs
    configured in ``DeepEPConfig``; on HIP it stays ``None``.
    """
    device_properties = torch.cuda.get_device_properties(device="cuda")
    total_num_sms = device_properties.multi_processor_count
    # Skip the DeepGEMM SM budget on HIP, as the DeepSeek and Qwen3 prefill
    # strategies in this module do.
    # NOTE(review): presumably DeepGEMM is not used on HIP builds — confirm.
    deep_gemm_num_sms = None
    if not _is_hip:
        deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms

    return OperationsStrategy(
        deep_gemm_num_sms=deep_gemm_num_sms,
        tbo_delta_stages=0,
        operations=[
            layer.op_comm_prepare_attn,
            layer.self_attn.op_prepare,
            layer.self_attn.op_core,
            layer.op_comm_prepare_mlp,
            layer.mlp.op_gate,
            layer.mlp.op_select_experts,
            layer.mlp.op_dispatch_a,
            operations.YieldOperation(),
            layer.mlp.op_dispatch_b,
            layer.mlp.op_experts,
            layer.mlp.op_combine_a,
            operations.YieldOperation(),
            layer.mlp.op_combine_b,
            layer.mlp.op_output,
            layer.op_comm_postprocess_layer,
        ],
    )
276
+
277
+
278
def _compute_moe_mimov2_decode(layer):
    """Build the MiMoV2 decode strategy: five stages with a stage offset of 2.

    The list ends with a stage break, so the layer's last stage is closed
    before whatever operations follow it.
    """
    attn = layer.self_attn
    mlp = layer.mlp
    stage_break = operations.YieldOperation

    op_sequence = [
        layer.op_comm_prepare_attn,
        attn.op_prepare,
        stage_break(),
        attn.op_core,
        layer.op_comm_prepare_mlp,
        mlp.op_gate,
        mlp.op_select_experts,
        stage_break(),
        mlp.op_dispatch_a,
        stage_break(),
        mlp.op_dispatch_b,
        mlp.op_experts,
        mlp.op_combine_a,
        stage_break(),
        mlp.op_combine_b,
        mlp.op_output,
        layer.op_comm_postprocess_layer,
        stage_break(),
    ]
    return OperationsStrategy(
        deep_gemm_num_sms=None,
        tbo_delta_stages=2,
        operations=op_sequence,
    )
sglang/python/sglang/srt/batch_overlap/single_batch_overlap.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
14
+
15
+ from __future__ import annotations
16
+
17
+ from dataclasses import dataclass
18
+ from typing import Optional
19
+
20
+ import torch
21
+
22
+ from sglang.srt.environ import envs
23
+ from sglang.srt.layers.moe import get_moe_runner_backend
24
+ from sglang.srt.layers.moe.utils import is_sbo_enabled
25
+ from sglang.srt.utils import is_blackwell
26
+
27
+
28
class SboFlags:
    """Switches that decide which single-batch-overlap (SBO) paths are used.

    Each switch is a classmethod that is re-evaluated on every call.
    """

    # TODO may have: "enable_dispatch_gateup_gemm_two_stream_overlap", ...

    @classmethod
    def enable_combine_down_gemm_two_stream_overlap(cls):
        """Whether combine may overlap with the down GEMM on two streams.

        Needs SBO enabled and either the flashinfer-cutedsl MoE runner
        backend, or the deep-gemm backend when not on Blackwell.
        """
        sbo_on = is_sbo_enabled()
        if not sbo_on:
            return sbo_on
        uses_cutedsl = get_moe_runner_backend().is_flashinfer_cutedsl()
        if uses_cutedsl:
            return uses_cutedsl
        return get_moe_runner_backend().is_deep_gemm() and not is_blackwell()

    @classmethod
    def enable_combine_shared_two_stream_overlap(cls):
        """Whether the two-stream combine/shared overlap is used.

        Needs SBO enabled, the one-stream dispatch/shared overlap disabled,
        and ``SGLANG_BLACKWELL_OVERLAP_SHARED_EXPERTS_OUTSIDE_SBO`` unset/false.
        """
        sbo_on = is_sbo_enabled()
        if not sbo_on:
            return sbo_on
        if cls.enable_dispatch_shared_one_stream_overlap():
            return False
        return not envs.SGLANG_BLACKWELL_OVERLAP_SHARED_EXPERTS_OUTSIDE_SBO.get()

    @classmethod
    def enable_dispatch_shared_one_stream_overlap(cls):
        """Whether the one-stream dispatch/shared overlap is used (SBO, non-Blackwell)."""
        sbo_on = is_sbo_enabled()
        return (not is_blackwell()) if sbo_on else sbo_on

    @classmethod
    def fuse_shared_experts_inside_sbo(cls):
        """True when either shared-expert overlap switch above is on."""
        combine_shared = cls.enable_combine_shared_two_stream_overlap()
        if combine_shared:
            return combine_shared
        return cls.enable_dispatch_shared_one_stream_overlap()
60
+
61
+
62
@dataclass
class CombineOverlapArgs:
    """Arguments handed to the combine step when SBO is active."""

    # NOTE: `overlap` specifically means "overlap with the down GEMM";
    # it does not describe the general two-stream overlap.
    overlap: bool
    # Stream and event the combine step is given by `compute_overlap_args`.
    stream: torch.cuda.Stream
    wait_event: torch.cuda.Event
    # Optional tuning knobs; the defaults below are used when not overridden.
    num_sms: Optional[int] = None
    signal: Optional[torch.Tensor] = None
    block_m: Optional[int] = 64
    threshold: Optional[int] = 0
72
+
73
+
74
@dataclass
class DownGemmOverlapArgs:
    """Arguments handed to the down GEMM when it overlaps with combine."""

    # SM budget for the GEMM.
    num_sms: int
    # Signal tensor; `compute_overlap_args` passes the same tensor to combine.
    signal: torch.Tensor
    # Event object; `compute_overlap_args` passes the combine wait event here.
    start_event: torch.cuda.Event
79
+
80
+
81
def compute_overlap_args(dispatch_output, alt_stream):
    """Assemble the SBO argument bundles for one MoE layer invocation.

    Returns a triple ``(combine_overlap_args, down_gemm_overlap_args,
    meta_overlap_args)``. When neither the combine/down-GEMM nor the
    combine/shared two-stream overlap is enabled the result is
    ``(None, None, {})``.

    ``dispatch_output.hidden_states`` must be 3-dimensional; its first two
    extents are used as the local expert count and the static token count.
    """
    overlap_down_gemm = SboFlags.enable_combine_down_gemm_two_stream_overlap()
    if not (
        overlap_down_gemm or SboFlags.enable_combine_shared_two_stream_overlap()
    ):
        return None, None, {}

    hidden_states = dispatch_output.hidden_states
    # The third extent is unpacked only to enforce the 3-D shape.
    num_local_experts, num_tokens_static, _hidden_dim = hidden_states.shape

    total_num_sms = torch.cuda.get_device_properties(
        device="cuda"
    ).multi_processor_count

    # SMs reserved for communication: env override, else a per-arch default.
    send_sms_env = envs.SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS
    if send_sms_env.is_set():
        communicate_num_sms = send_sms_env.get()
    elif is_blackwell():
        communicate_num_sms = 32
    else:
        communicate_num_sms = 3
    compute_num_sms = total_num_sms - communicate_num_sms

    assert alt_stream is not None
    combine_wait_event = torch.cuda.Event()
    combine_overlap_args = CombineOverlapArgs(
        overlap=False,
        num_sms=communicate_num_sms,
        stream=alt_stream,
        wait_event=combine_wait_event,
    )
    meta_overlap_args = {"compute_num_sms": compute_num_sms}

    if not overlap_down_gemm:
        meta_overlap_args["record_event_after_down"] = combine_wait_event
        return combine_overlap_args, None, meta_overlap_args

    # TODO use zero_allocator to remove this `torch.zeros` call
    # NOTE ours v2 use uint32 not int32 currently
    if is_blackwell():
        signal_len, signal_dtype = num_local_experts, torch.uint32
    else:
        MIN_BLOCK_M = 64
        blocks_per_expert = (num_tokens_static + MIN_BLOCK_M - 1) // MIN_BLOCK_M
        signal_len, signal_dtype = num_local_experts * blocks_per_expert, torch.int32
    combine_signal = torch.zeros(
        signal_len, dtype=signal_dtype, device=hidden_states.device
    )

    down_gemm_overlap_args = DownGemmOverlapArgs(
        signal=combine_signal,
        start_event=combine_wait_event,
        num_sms=compute_num_sms,
    )
    combine_overlap_args.overlap = True
    combine_overlap_args.signal = combine_signal
    combine_overlap_args.threshold = compute_num_sms

    return combine_overlap_args, down_gemm_overlap_args, meta_overlap_args
sglang/python/sglang/srt/batch_overlap/two_batch_overlap.py ADDED
@@ -0,0 +1,1082 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ import dataclasses
5
+ import logging
6
+ from dataclasses import replace
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence
8
+
9
+ import torch
10
+
11
+ from sglang.srt.batch_overlap.operations import (
12
+ execute_operations,
13
+ execute_overlapped_operations,
14
+ )
15
+ from sglang.srt.batch_overlap.operations_strategy import OperationsStrategy
16
+ from sglang.srt.layers import deep_gemm_wrapper
17
+ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
18
+ from sglang.srt.layers.communicator import (
19
+ CommunicateContext,
20
+ CommunicateSummableTensorPairFn,
21
+ ScatterMode,
22
+ )
23
+ from sglang.srt.layers.dp_attention import get_attention_tp_size
24
+ from sglang.srt.layers.moe import (
25
+ get_deepep_mode,
26
+ get_moe_a2a_backend,
27
+ get_tbo_token_distribution_threshold,
28
+ is_tbo_enabled,
29
+ )
30
+ from sglang.srt.layers.moe.token_dispatcher import (
31
+ DeepEPDispatcher,
32
+ MooncakeEPDispatcher,
33
+ MoriEPDispatcher,
34
+ )
35
+ from sglang.srt.layers.moe.token_dispatcher.base import BaseDispatcher
36
+ from sglang.srt.managers.schedule_batch import ScheduleBatch
37
+ from sglang.srt.model_executor.forward_batch_info import (
38
+ ForwardBatch,
39
+ ForwardMode,
40
+ compute_position,
41
+ )
42
+ from sglang.srt.server_args import get_global_server_args
43
+ from sglang.srt.speculative.spec_info import SpecInput
44
+ from sglang.srt.utils import BumpAllocator, empty_context, get_bool_env_var, is_hip
45
+
46
+ if TYPE_CHECKING:
47
+ from sglang.srt.batch_overlap.single_batch_overlap import CombineOverlapArgs
48
+ from sglang.srt.layers.moe.token_dispatcher import DispatchOutput
49
+ from sglang.srt.speculative.eagle_info import EagleVerifyInput
50
+
51
# Module logger.
logger = logging.getLogger(__name__)

# Result of `is_hip()`, evaluated once at import time.
_is_hip = is_hip()

# Extra TBO logging, switched on through the SGLANG_TBO_DEBUG env var.
_tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
56
+
57
+
58
+ # -------------------------------- Compute Basic Info ---------------------------------------
59
+
60
+
61
def get_token_num_per_seq(
    forward_mode: ForwardMode,
    spec_info: Optional[SpecInput] = None,
):
    """Return how many tokens each sequence contributes in ``forward_mode``.

    * target-verify: ``spec_info.draft_token_num``
    * decode: 1
    * idle: 0
    * anything else (e.g. extend): ``None`` -- a per-sequence token count
      should not be used there.
    """
    if forward_mode.is_target_verify():
        return spec_info.draft_token_num
    if forward_mode.is_decode():
        return 1
    if forward_mode.is_idle():
        return 0
    # For extend, we should not use `token_num_per_seq`.
    return None
74
+
75
+
76
+ # TODO: may smartly disable TBO when batch size is too small b/c it will slow down
77
def compute_split_seq_index(
    forward_mode: ForwardMode,
    num_tokens: int,
    extend_lens: Optional[Sequence[int]],
    token_num_per_seq: Optional[int],
) -> Optional[int]:
    """Pick the sequence index at which a batch is split into two halves.

    * EXTEND: delegated to ``_split_extend_seqs(extend_lens)``.
    * target-verify / decode: half of the sequence count, where the sequence
      count is ``num_tokens // token_num_per_seq``.
    * idle / prebuilt: 0 (``num_tokens`` must be 0).

    Raises:
        NotImplementedError: for any other forward mode.
    """
    if forward_mode == ForwardMode.EXTEND:
        assert extend_lens is not None
        return _split_extend_seqs(extend_lens)

    if forward_mode.is_target_verify() or forward_mode.is_decode():
        assert token_num_per_seq is not None
        num_seqs = num_tokens // token_num_per_seq
        return num_seqs // 2

    if forward_mode.is_idle() or forward_mode.is_prebuilt():
        assert num_tokens == 0
        return 0

    raise NotImplementedError()
94
+
95
+
96
def _is_two_chunk_split_enabled(extend_lens: Sequence[int]) -> bool:
    """Tell whether the balanced per-sequence split is too lopsided.

    The balanced split is computed first; if the token share on its left
    side falls outside ``[threshold, 1 - threshold]`` of the total (threshold
    from ``get_tbo_token_distribution_threshold()``, at most 0.5) the
    two-chunk split is enabled. ``None`` input yields ``False``.
    """
    if extend_lens is None:
        return False

    balanced_index = _split_array_by_balanced_sum(extend_lens)
    left_sum = sum(extend_lens[:balanced_index])
    overall_sum = sum(extend_lens)
    threshold = get_tbo_token_distribution_threshold()
    assert threshold <= 0.5, f"{threshold=}"

    lower_bound = overall_sum * threshold
    upper_bound = overall_sum * (1 - threshold)
    return not (lower_bound <= left_sum <= upper_bound)
108
+
109
+
110
def _split_extend_seqs(arr: Sequence[int]) -> int:
    """Choose the split index for extend lengths ``arr``.

    Uses the cumulative-half splitter when the two-chunk split is enabled for
    ``arr`` and the balanced-sum splitter otherwise.
    """
    splitter = (
        _split_array_by_cum_less_than_half
        if _is_two_chunk_split_enabled(arr)
        else _split_array_by_balanced_sum
    )
    return splitter(arr)
115
+
116
+
117
def _split_array_by_cum_less_than_half(arr: Sequence[int]) -> int:
    """Return the first index whose running total exceeds half of ``sum(arr)``.

    "Half" is the floor of ``sum(arr) / 2``. If no prefix ever exceeds it
    (empty or all-zero input) the result is 0.
    """
    half_sum = sum(arr) // 2
    running_total = 0
    for position, length in enumerate(arr):
        running_total += length
        if running_total > half_sum:
            return position
    return 0
130
+
131
+
132
def _split_array_by_balanced_sum(arr: Sequence[int]) -> int:
    """Return the split index that best balances left and right sums.

    Candidate indices ``1 .. len(arr) - 1`` are scanned in order; the scan
    stops at the first candidate whose imbalance is strictly worse than the
    best seen so far. Ties move the split to the right. Inputs with fewer
    than two elements give 0.
    """
    total = sum(arr)
    prefix = 0
    smallest_gap = float("inf")
    split_at = 0

    candidate = 1
    while candidate < len(arr):
        prefix += arr[candidate - 1]
        gap = abs(prefix - (total - prefix))
        if gap > smallest_gap:
            break
        smallest_gap, split_at = gap, candidate
        candidate += 1

    return split_at
149
+
150
+
151
def _update_device_and_sum_field_from_cpu_field(
    batch: ForwardBatch, cpu_field: str, device_field: str, sum_field: str = None
):
    """Refresh ``batch.<device_field>`` (and optionally a sum) from a CPU field.

    Nothing happens when either attribute is missing/``None`` or when the CPU
    value is neither a tensor nor a list. Otherwise the CPU value (lists are
    converted to a tensor with the old device value's dtype) is copied to the
    device named by the global server args and stored on ``device_field``.
    When ``sum_field`` is given, the total of the CPU value is stored there.
    """
    cpu_value = getattr(batch, cpu_field, None)
    old_device_value = getattr(batch, device_field, None)
    if cpu_value is None or old_device_value is None:
        return

    is_tensor = isinstance(cpu_value, torch.Tensor)
    if not is_tensor and not isinstance(cpu_value, list):
        return

    host_value = (
        cpu_value
        if is_tensor
        else torch.tensor(cpu_value, dtype=old_device_value.dtype)
    )
    setattr(
        batch,
        device_field,
        host_value.to(device=get_global_server_args().device, non_blocking=True),
    )

    if sum_field is None:
        return
    total = cpu_value.sum().item() if is_tensor else sum(cpu_value)
    setattr(batch, sum_field, total)
177
+
178
+
179
def _compute_mask_offset(seq_index: int, spec_info: Optional[EagleVerifyInput]) -> int:
    """Return the offset into ``custom_mask`` at which sequence ``seq_index`` starts.

    Each preceding sequence ``i`` contributes
    ``(seq_lens_cpu[i] + draft_token_num) * draft_token_num`` entries.
    ``seq_index`` is clamped to the number of entries in ``seq_lens_cpu``.
    """
    if seq_index == 0:
        return 0

    num_preceding = min(seq_index, spec_info.seq_lens_cpu.shape[0])
    return sum(
        (spec_info.seq_lens_cpu[i] + spec_info.draft_token_num)
        * spec_info.draft_token_num
        for i in range(num_preceding)
    )
190
+
191
+
192
def split_spec_info(
    spec_info: Optional[EagleVerifyInput],
    start_seq_index: int,
    end_seq_index: int,
    start_token_index: int,
    end_token_index: int,
):
    """Return a copy of ``spec_info`` restricted to one child batch.

    Token-indexed fields (``draft_token``, ``positions``) are sliced by the
    token range; sequence-indexed fields (``retrive_*``, ``seq_lens_cpu``) by
    the sequence range. ``custom_mask`` is sliced by offsets obtained from
    ``_compute_mask_offset`` when both it and ``draft_token`` are present,
    and is otherwise passed through unchanged. ``None`` fields stay ``None``;
    a ``None`` ``spec_info`` yields ``None``.
    """
    if spec_info is None:
        return None

    def _cut(value, lo, hi):
        # Slice when present, keep None as None.
        return None if value is None else value[lo:hi]

    custom_mask = spec_info.custom_mask
    if custom_mask is not None and spec_info.draft_token is not None:
        mask_lo = _compute_mask_offset(start_seq_index, spec_info)
        # The last child takes everything up to the end of the mask.
        if end_seq_index == spec_info.seq_lens_cpu.shape[0]:
            mask_hi = spec_info.custom_mask.shape[0]
        else:
            mask_hi = _compute_mask_offset(end_seq_index, spec_info)
        # An empty/inverted range falls back to the whole mask.
        if mask_hi > mask_lo:
            custom_mask = spec_info.custom_mask[mask_lo:mask_hi]

    seq_lens_cpu = _cut(spec_info.seq_lens_cpu, start_seq_index, end_seq_index)
    seq_lens_sum = None if seq_lens_cpu is None else seq_lens_cpu.sum()

    return replace(
        spec_info,
        custom_mask=custom_mask,
        draft_token=_cut(spec_info.draft_token, start_token_index, end_token_index),
        positions=_cut(spec_info.positions, start_token_index, end_token_index),
        retrive_index=_cut(spec_info.retrive_index, start_seq_index, end_seq_index),
        retrive_next_token=_cut(
            spec_info.retrive_next_token, start_seq_index, end_seq_index
        ),
        retrive_next_sibling=_cut(
            spec_info.retrive_next_sibling, start_seq_index, end_seq_index
        ),
        retrive_cum_len=_cut(
            spec_info.retrive_cum_len, start_seq_index, end_seq_index
        ),
        seq_lens_cpu=seq_lens_cpu,
        seq_lens_sum=seq_lens_sum,
    )
262
+
263
+
264
def compute_split_token_index(
    split_seq_index: int,
    forward_mode: "ForwardMode",
    extend_seq_lens: Optional[Sequence[int]],
    token_num_per_seq: Optional[int],
) -> int:
    """Translate a split sequence index into a split token index.

    * EXTEND: half of all tokens when the two-chunk split is enabled,
      otherwise the token count of the sequences left of the split.
    * target-verify / decode: ``split_seq_index * token_num_per_seq``.
    * idle: 0 (``split_seq_index`` must be 0).

    Raises:
        NotImplementedError: for any other forward mode.
    """
    if forward_mode == ForwardMode.EXTEND:
        assert extend_seq_lens is not None
        if not _is_two_chunk_split_enabled(extend_seq_lens):
            return sum(extend_seq_lens[:split_seq_index])
        return sum(extend_seq_lens) // 2

    if forward_mode.is_target_verify() or forward_mode.is_decode():
        assert token_num_per_seq is not None
        return split_seq_index * token_num_per_seq

    if forward_mode.is_idle():
        assert split_seq_index == 0
        return 0

    raise NotImplementedError
283
+
284
+
285
def compute_split_indices_for_cuda_graph_replay(
    forward_mode: ForwardMode,
    cuda_graph_num_tokens: int,
    spec_info: Optional[SpecInput],
):
    """Compute the TBO split indices used when replaying a CUDA graph.

    Args:
        forward_mode: Mode of the batch being replayed. ``IDLE`` is split as
            if it were ``DECODE``.
        cuda_graph_num_tokens: Token count of the captured graph.
        spec_info: Speculative-decoding input; read for target-verify mode.

    Returns:
        ``(tbo_split_seq_index, tbo_split_token_index)``.
    """
    forward_mode_for_tbo_split = (
        forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE
    )
    # Derive tokens-per-sequence from the *substituted* mode. Using the raw
    # mode would give 0 for IDLE, and the decode branch of
    # `compute_split_seq_index` divides by this value.
    token_num_per_seq = get_token_num_per_seq(
        forward_mode=forward_mode_for_tbo_split, spec_info=spec_info
    )
    tbo_split_seq_index = compute_split_seq_index(
        forward_mode=forward_mode_for_tbo_split,
        num_tokens=cuda_graph_num_tokens,
        extend_lens=None,
        token_num_per_seq=token_num_per_seq,
    )
    tbo_split_token_index = compute_split_token_index(
        split_seq_index=tbo_split_seq_index,
        forward_mode=forward_mode_for_tbo_split,
        extend_seq_lens=None,
        token_num_per_seq=token_num_per_seq,
    )
    return tbo_split_seq_index, tbo_split_token_index
309
+
310
+
311
+ # -------------------------------- Preparation ---------------------------------------
312
+
313
+
314
class TboCudaGraphRunnerPlugin:
    """Keeps the per-child non-padded token counts for CUDA-graph TBO runs.

    The two counts live in one persistent 2-element int32 tensor that is
    overwritten in place on capture and on replay.
    """

    def __init__(self):
        self._tbo_children_num_token_non_padded = torch.zeros((2,), dtype=torch.int32)

    def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
        """Prepare ``batch`` for TBO while capturing a graph of ``num_tokens`` tokens.

        Does nothing when TBO is disabled.
        """
        if not is_tbo_enabled():
            return

        per_seq = get_token_num_per_seq(
            forward_mode=batch.forward_mode, spec_info=batch.spec_info
        )
        batch.tbo_split_seq_index = compute_split_seq_index(
            forward_mode=batch.forward_mode,
            num_tokens=num_tokens,
            extend_lens=None,
            token_num_per_seq=per_seq,
        )
        # For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true
        assert batch.tbo_split_seq_index is not None, f"{num_tokens=}"

        counts = TboForwardBatchPreparer.compute_tbo_children_num_token_non_padded(
            batch
        )
        self._tbo_children_num_token_non_padded[...] = counts

        TboForwardBatchPreparer.prepare_raw(
            batch,
            tbo_children_num_token_non_padded=self._tbo_children_num_token_non_padded,
        )

    def replay_prepare(
        self,
        forward_mode: ForwardMode,
        bs: int,
        num_token_non_padded: int,
        spec_info: Optional[SpecInput],
    ):
        """Refresh the stored per-child token counts before a graph replay."""
        per_seq = get_token_num_per_seq(forward_mode=forward_mode, spec_info=spec_info)
        _split_seq_index, split_token_index = (
            compute_split_indices_for_cuda_graph_replay(
                forward_mode=forward_mode,
                cuda_graph_num_tokens=bs * per_seq,
                spec_info=spec_info,
            )
        )

        counts = TboForwardBatchPreparer.compute_tbo_children_num_token_non_padded_raw(
            tbo_split_token_index=split_token_index,
            num_token_non_padded=num_token_non_padded,
        )
        self._tbo_children_num_token_non_padded[...] = counts
367
+
368
+
369
class TboDPAttentionPreparer:
    """Two-step helper that decides whether this step can run TBO.

    ``prepare_all_gather`` computes the local verdict and forward mode;
    ``compute_output`` combines the per-rank values gathered by the caller.
    """

    def prepare_all_gather(
        self,
        local_batch: ScheduleBatch,
    ):
        """Return ``(local_can_run_tbo, local_forward_mode_value)`` for this rank.

        Also records ``enable_two_batch_overlap`` and
        ``local_tbo_split_seq_index`` on ``self`` for ``compute_output``.
        """
        deepep_mode = get_deepep_mode()
        enable_a2a_moe = not get_moe_a2a_backend().is_none()
        self.enable_two_batch_overlap = is_tbo_enabled()

        if local_batch is None:
            self.local_tbo_split_seq_index = 0
            local_can_run_tbo = True
        else:
            mode = local_batch.forward_mode
            token_num_per_seq = get_token_num_per_seq(
                forward_mode=mode, spec_info=local_batch.spec_info
            )

            if mode.is_target_verify() or mode.is_decode():
                num_tokens = local_batch.batch_size() * token_num_per_seq
            elif mode.is_prebuilt():
                num_tokens = 0
            else:
                num_tokens = local_batch.extend_num_tokens

            self.local_tbo_split_seq_index = compute_split_seq_index(
                forward_mode=mode,
                num_tokens=num_tokens,
                extend_lens=local_batch.extend_lens,
                token_num_per_seq=token_num_per_seq,
            )
            resolved_deepep_mode = deepep_mode.resolve(local_batch.is_extend_in_batch)

            # A2A MoE in low-latency mode rules out TBO for extend batches
            # that are not target-verify.
            is_plain_extend = (
                local_batch.forward_mode.is_extend()
                and not local_batch.forward_mode.is_target_verify()
            )
            blocked = (
                is_plain_extend
                and enable_a2a_moe
                and resolved_deepep_mode.is_low_latency()
            )
            local_can_run_tbo = (
                self.local_tbo_split_seq_index is not None
            ) and not blocked

        return local_can_run_tbo, self._compute_local_forward_mode(local_batch)

    def compute_output(self, partial_global_info):
        """Return ``(tbo_split_seq_index, global_forward_mode)``, or two ``None``s.

        Column 0 of ``partial_global_info`` holds each rank's can-run flag and
        column 1 its forward-mode value.
        """
        # Perform only one Device-to-Host (D2H) memory copy
        cpu_data = partial_global_info[:, :2].cpu()
        everyone_can_run = min(cpu_data[:, 0].tolist())
        forward_modes = cpu_data[:, 1].tolist()

        global_forward_mode, forward_mode_agree = self._compute_global_forward_mode(
            forward_modes
        )

        if not (
            self.enable_two_batch_overlap and everyone_can_run and forward_mode_agree
        ):
            return None, None
        return self.local_tbo_split_seq_index, global_forward_mode

    @staticmethod
    def _compute_local_forward_mode(local_batch):
        """Forward-mode value of ``local_batch``; IDLE when there is no batch."""
        mode = ForwardMode.IDLE if local_batch is None else local_batch.forward_mode
        return mode.value

    @staticmethod
    def _compute_global_forward_mode(forward_modes):
        """Return ``(mode, agree)`` over all ranks, ignoring IDLE and PREBUILT.

        With nothing left after filtering the result is ``(IDLE, False)``;
        when the remaining modes disagree it is ``(None, False)``.
        """
        ignored = (ForwardMode.IDLE.value, ForwardMode.PREBUILT.value)
        active_modes = [mode for mode in forward_modes if mode not in ignored]

        if not active_modes:
            return ForwardMode.IDLE, False

        if not TboDPAttentionPreparer._is_all_same(active_modes):
            return None, False
        return ForwardMode(active_modes[0]), True

    @staticmethod
    def _is_all_same(x):
        """True when every element equals the first (vacuously true if empty)."""
        return not any(item != x[0] for item in x)
469
+
470
+
471
+ class TboForwardBatchPreparer:
472
@classmethod
def prepare(cls, batch: ForwardBatch, is_draft_worker: bool = False):
    """Create the two TBO child batches on ``batch``.

    Skipped when ``batch`` has no split index or when called for a draft
    worker.
    """
    wants_tbo = batch.tbo_split_seq_index is not None and not is_draft_worker
    if not wants_tbo:
        return

    child_token_counts = cls.compute_tbo_children_num_token_non_padded(batch)
    cls.prepare_raw(batch, tbo_children_num_token_non_padded=child_token_counts)
483
+
484
@classmethod
def prepare_raw(
    cls, batch: ForwardBatch, tbo_children_num_token_non_padded: torch.Tensor
):
    """Split ``batch`` into two children and store them on ``batch.tbo_children``.

    ``tbo_children_num_token_non_padded`` must unpack into exactly two
    entries, one per child. ``batch.attn_backend`` must be a
    ``TboAttnBackend`` whose ``children`` provide the per-child backends.
    """
    from sglang.srt.layers.attention.tbo_backend import TboAttnBackend

    tbo_split_token_index = cls._compute_split_token_index(batch)

    is_enable_two_chunk = (
        batch.forward_mode == ForwardMode.EXTEND
        and _is_two_chunk_split_enabled(batch.extend_seq_lens_cpu)
    )

    if _tbo_debug:
        logger.info(
            f"TboForwardBatchPreparer.prepare "
            f"is_enable_two_chunk={is_enable_two_chunk} "
            f"tbo_split_seq_index={batch.tbo_split_seq_index} "
            f"tbo_split_token_index={tbo_split_token_index} "
            f"extend_seq_lens={batch.extend_seq_lens_cpu} "
            f"bs={batch.batch_size} "
            f"forward_mode={batch.forward_mode}"
        )

    assert isinstance(batch.attn_backend, TboAttnBackend)
    backend_a, backend_b = batch.attn_backend.children
    non_padded_a, non_padded_b = tbo_children_num_token_non_padded

    # With a two-chunk split the boundary sequence is shared: child A also
    # includes it, so its sequence range extends one past the split index.
    end_seq_index_a = batch.tbo_split_seq_index
    if is_enable_two_chunk:
        end_seq_index_a = batch.tbo_split_seq_index + 1

    child_a = cls.filter_batch(
        batch,
        start_token_index=0,
        end_token_index=tbo_split_token_index,
        start_seq_index=0,
        end_seq_index=end_seq_index_a,
        output_attn_backend=backend_a,
        out_num_token_non_padded=non_padded_a,
    )
    child_b = cls.filter_batch(
        batch,
        start_token_index=tbo_split_token_index,
        end_token_index=batch.input_ids.shape[0],
        start_seq_index=batch.tbo_split_seq_index,
        end_seq_index=batch.batch_size,
        output_attn_backend=backend_b,
        out_num_token_non_padded=non_padded_b,
    )

    if is_enable_two_chunk:
        cls.derive_fields_related_to_seq_len_for_two_chunk(
            batch,
            child_a=child_a,
            child_b=child_b,
            tbo_split_seq_index=batch.tbo_split_seq_index,
        )

    assert batch.tbo_children is None
    batch.tbo_children = [child_a, child_b]
548
+
549
@classmethod
def derive_fields_related_to_seq_len_for_two_chunk(
    cls,
    batch: ForwardBatch,
    *,
    child_a: ForwardBatch,
    child_b: ForwardBatch,
    tbo_split_seq_index: int,
):
    """Fix up length fields after one sequence was cut across both children.

    The sequence at ``tbo_split_seq_index`` is divided so that child A ends
    up with exactly half (floor) of all extend tokens. Child A's last
    sequence and child B's first sequence get their extend lengths adjusted;
    child A's last ``seq_lens`` entry and child B's first prefix length and
    ``extend_start_loc`` are recomputed to match.
    """
    parent_extend_lens = batch.extend_seq_lens_cpu
    half_seq_lens_sum = sum(parent_extend_lens) // 2
    # Tokens of the boundary sequence that stay in child A / move to child B.
    tokens_in_a = half_seq_lens_sum - sum(parent_extend_lens[:tbo_split_seq_index])
    tokens_in_b = parent_extend_lens[tbo_split_seq_index] - tokens_in_a

    # making deepcopy to be extra safe
    child_a.extend_seq_lens_cpu = copy.deepcopy(child_a.extend_seq_lens_cpu)
    child_a.extend_seq_lens_cpu[-1] = tokens_in_a
    child_b.extend_seq_lens_cpu = copy.deepcopy(child_b.extend_seq_lens_cpu)
    child_b.extend_seq_lens_cpu[0] = tokens_in_b
    for child in (child_a, child_b):
        _update_device_and_sum_field_from_cpu_field(
            batch=child,
            cpu_field="extend_seq_lens_cpu",
            device_field="extend_seq_lens",
            sum_field="extend_num_tokens",
        )

    assert (
        child_a.extend_num_tokens == half_seq_lens_sum
    ), f"{child_a.extend_num_tokens=}, {half_seq_lens_sum=}"

    # Child A sees only the first part of the boundary sequence.
    child_a.seq_lens_cpu = copy.deepcopy(child_a.seq_lens_cpu)
    child_a.seq_lens_cpu[-1] = (
        child_a.extend_seq_lens_cpu[-1] + child_a.extend_prefix_lens_cpu[-1]
    )
    _update_device_and_sum_field_from_cpu_field(
        batch=child_a,
        cpu_field="seq_lens_cpu",
        device_field="seq_lens",
        sum_field="seq_lens_sum",
    )

    # For child B the part handled by child A counts as additional prefix.
    child_b.extend_prefix_lens_cpu = copy.deepcopy(child_b.extend_prefix_lens_cpu)
    child_b.extend_prefix_lens_cpu[0] += tokens_in_a
    _update_device_and_sum_field_from_cpu_field(
        batch=child_b,
        cpu_field="extend_prefix_lens_cpu",
        device_field="extend_prefix_lens",
        sum_field=None,
    )
    _positions, child_b.extend_start_loc = compute_position(
        get_global_server_args().attention_backend,
        child_b.extend_prefix_lens,
        child_b.extend_seq_lens,
        child_b.extend_num_tokens,
    )
610
+
611
@classmethod
def filter_batch(
    cls,
    batch: ForwardBatch,
    *,
    start_token_index: int,
    end_token_index: int,
    start_seq_index: int,
    end_seq_index: int,
    output_attn_backend: AttentionBackend,
    out_num_token_non_padded: torch.Tensor,
):
    """Build a child ``ForwardBatch`` covering one slice of ``batch``.

    Token-indexed fields are cut to ``[start_token_index, end_token_index)``
    and sequence-indexed fields to ``[start_seq_index, end_seq_index)``. A
    fixed set of fields is copied unchanged and the remaining ones are set
    to explicit per-child values.

    Raises:
        Exception: if ``batch`` has a non-``None`` field that this method
            does not populate on the child.
    """
    assert (
        end_token_index >= start_token_index
    ), f"{end_token_index=}, {start_token_index=}, batch={batch}"
    num_tokens = batch.input_ids.shape[0]
    num_seqs = batch.batch_size

    token_span = slice(start_token_index, end_token_index)
    seq_span = slice(start_seq_index, end_seq_index)
    output_dict = {}

    # ---- fields with one entry per token ----
    for key in ("input_ids", "positions", "out_cache_loc"):
        old_value = getattr(batch, key)
        assert (
            old_value.shape[0] == num_tokens
        ), f"{key=} {old_value=} {num_tokens=} {batch=}"
        output_dict[key] = old_value[token_span]

    # Child token count rounded up to a multiple of the attention TP size.
    attention_tp_size = get_attention_tp_size()
    child_num_tokens = end_token_index - start_token_index
    output_dict["tbo_padded_len"] = (
        (child_num_tokens - 1) // attention_tp_size + 1
    ) * attention_tp_size

    # ---- fields with one entry per sequence ----
    extend_only_keys = {
        "extend_seq_lens",
        "extend_prefix_lens",
        "extend_start_loc",
        "extend_prefix_lens_cpu",
        "extend_seq_lens_cpu",
        "extend_logprob_start_lens_cpu",
    }
    for key in (
        "req_pool_indices",
        "seq_lens",
        "seq_lens_cpu",
        "extend_seq_lens",
        "extend_prefix_lens",
        "extend_start_loc",
        "extend_prefix_lens_cpu",
        "extend_seq_lens_cpu",
        "extend_logprob_start_lens_cpu",
        "lora_ids",
        "rids",
    ):
        old_value = getattr(batch, key)
        if old_value is None:
            continue
        # Extend-specific fields are dropped for target-verify batches.
        if key in extend_only_keys and batch.forward_mode.is_target_verify():
            output_dict[key] = None
            continue
        assert (
            len(old_value) == num_seqs
        ), f"{key=} {old_value=} {num_seqs=} {batch=}"
        output_dict[key] = old_value[seq_span]

    output_dict["spec_info"] = split_spec_info(
        spec_info=batch.spec_info,
        start_token_index=start_token_index,
        end_token_index=end_token_index,
        start_seq_index=start_seq_index,
        end_seq_index=end_seq_index,
    )

    # ---- fields copied from the parent unchanged ----
    for key in (
        "forward_mode",
        "is_extend_in_batch",
        "all_extend_in_batch",
        "return_logprob",
        "req_to_token_pool",
        "token_to_kv_pool",
        "can_run_dp_cuda_graph",
        "dp_padding_mode",
        "global_forward_mode",
        "is_prefill_only",
        "spec_algorithm",
        "capture_hidden_mode",
        "padded_static_len",
        "mrope_positions",  # only used by qwen2-vl, thus not care
        "split_index",  # for split prefill
        "orig_seq_lens",  # only used by qwen-1m, thus not care
    ):
        output_dict[key] = getattr(batch, key)

    if not batch.forward_mode.is_target_verify():
        assert (
            _compute_extend_num_tokens(batch.input_ids, batch.forward_mode)
            == batch.extend_num_tokens
        ), f"{batch=}"
    extend_num_tokens = _compute_extend_num_tokens(
        output_dict["input_ids"], output_dict["forward_mode"]
    )

    # TODO improve, e.g. unify w/ `init_raw`
    global_dp_buffer_len = None
    if (
        get_global_server_args().moe_dense_tp_size == 1
        and batch.global_dp_buffer_len is not None
    ):
        global_dp_buffer_len = end_token_index - start_token_index

    seq_lens_sum = (
        output_dict["seq_lens_cpu"].sum() if "seq_lens_cpu" in output_dict else None
    )

    # ---- fields with explicit per-child values ----
    output_dict.update(
        batch_size=end_seq_index - start_seq_index,
        seq_lens_sum=seq_lens_sum,
        extend_num_tokens=extend_num_tokens,
        attn_backend=output_attn_backend,
        num_token_non_padded=out_num_token_non_padded,
        # TODO: handle it when we need TBO + DeepSeek V3.2
        num_token_non_padded_cpu=None,
        tbo_split_seq_index=None,
        tbo_parent_token_range=(start_token_index, end_token_index),
        tbo_children=None,
        original_global_num_tokens_cpu=None,
        global_num_tokens_gpu=None,
        global_num_tokens_cpu=None,
        global_dp_buffer_len=global_dp_buffer_len,
        global_num_tokens_for_logprob_gpu=None,
        global_num_tokens_for_logprob_cpu=None,
        sampling_info=None,
        # For logits and logprobs post processing, thus we do not care
        temp_scaled_logprobs=False,
        temperature=None,
        top_p_normalized_logprobs=False,
        top_p=None,
        mm_inputs=None,
        top_logprobs_nums=None,
        token_ids_logprobs=None,
        next_token_logits_buffer=None,
        return_hidden_states_before_norm=False,
    )

    # Refuse to silently drop any populated parent field not handled above.
    errors = [
        f"Field {field.name} has value, but is not yet supported (value={getattr(batch, field.name)} batch={batch})"
        for field in dataclasses.fields(ForwardBatch)
        if getattr(batch, field.name) is not None and field.name not in output_dict
    ]
    if len(errors) > 0:
        raise Exception(f"{len(errors)} errors happen:\n" + "\n\n".join(errors))

    return ForwardBatch(**output_dict)
771
+
772
@classmethod
def compute_tbo_children_num_token_non_padded(cls, batch: ForwardBatch):
    """Return the per-child non-padded token counts for ``batch``.

    The split token index is derived from the batch itself and the total
    token count is ``len(batch.input_ids)``; both are forwarded to
    ``compute_tbo_children_num_token_non_padded_raw``.
    """
    split_token_index = cls._compute_split_token_index(batch)
    total_tokens = len(batch.input_ids)
    return cls.compute_tbo_children_num_token_non_padded_raw(
        tbo_split_token_index=split_token_index,
        num_token_non_padded=total_tokens,
    )
778
+
779
@classmethod
def compute_tbo_children_num_token_non_padded_raw(
    cls, tbo_split_token_index: int, num_token_non_padded: int
):
    """Build the two-element tensor of non-padded token counts per child.

    The first child gets ``min(split, total)`` tokens and the second gets the
    remainder, clamped at zero. The int32 tensor is created on the host and
    then moved to the configured server device with a non-blocking copy.
    """
    # TODO we may make padding on both sub-batches to make it slightly more balanced
    first_child_tokens = min(tbo_split_token_index, num_token_non_padded)
    second_child_tokens = max(0, num_token_non_padded - tbo_split_token_index)
    host_counts = torch.tensor(
        [first_child_tokens, second_child_tokens], dtype=torch.int32
    )
    return host_counts.to(
        device=get_global_server_args().device, non_blocking=True
    )
789
+
790
@classmethod
def _compute_split_token_index(cls, batch: ForwardBatch):
    """Translate the batch's sequence-level split index into a token index.

    Delegates to ``compute_split_token_index`` using the batch's forward mode,
    CPU extend sequence lengths and the per-sequence token count obtained from
    ``get_token_num_per_seq``.
    """
    mode = batch.forward_mode
    tokens_per_seq = get_token_num_per_seq(
        forward_mode=mode, spec_info=batch.spec_info
    )
    return compute_split_token_index(
        split_seq_index=batch.tbo_split_seq_index,
        forward_mode=mode,
        extend_seq_lens=batch.extend_seq_lens_cpu,
        token_num_per_seq=tokens_per_seq,
    )
801
+
802
+
803
def _compute_extend_num_tokens(input_ids, forward_mode: ForwardMode):
    """Return the number of extend tokens for the given forward mode.

    Returns:
        ``None`` for decode, idle and target-verify modes; ``input_ids.shape[0]``
        for extend mode.

    Raises:
        NotImplementedError: for any other forward mode.
    """
    has_no_extend_tokens = (
        forward_mode.is_decode()
        or forward_mode.is_idle()
        or forward_mode.is_target_verify()
    )
    if has_no_extend_tokens:
        return None
    if not forward_mode.is_extend():
        raise NotImplementedError
    return input_ids.shape[0]
813
+
814
+
815
+ # -------------------------------- Execution ---------------------------------------
816
+
817
+
818
def model_forward_maybe_tbo(
    layers,
    enable_tbo: bool,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    hidden_states: torch.Tensor,
    input_data_scatter_mode: ScatterMode,
    residual: Optional[torch.Tensor],
    zero_allocator: Optional[BumpAllocator] = None,
):
    """Run ``layers`` either as two overlapped sub-batches or as one batch.

    The operations strategy is always built from ``layers`` and the batch's
    global forward mode; ``enable_tbo`` only selects which executor is used.
    The layer input scatter mode is read from the first layer.
    """
    layer_inputs = {
        "positions": positions,
        "hidden_states": hidden_states,
        "forward_batch": forward_batch,
        "residual": residual,
        "zero_allocator": zero_allocator,
    }
    first_layer_input_mode = layers[0].layer_scatter_modes.layer_input_mode
    strategy = OperationsStrategy.init_new_tbo(
        layers, forward_batch.global_forward_mode
    )

    if not enable_tbo:
        return _model_forward_non_tbo(layer_inputs, strategy)

    return _model_forward_tbo(
        inputs=layer_inputs,
        operations_strategy=strategy,
        input_data_scatter_mode=input_data_scatter_mode,
        layer_input_scatter_mode=first_layer_input_mode,
    )
848
+
849
+
850
def _model_forward_tbo(
    inputs,
    operations_strategy: OperationsStrategy,
    input_data_scatter_mode: ScatterMode,
    layer_input_scatter_mode: ScatterMode,
):
    """Split ``inputs`` in two, run both halves overlapped, and merge the outputs.

    The same operation list is executed for both sub-batches; the second one
    is offset by ``operations_strategy.tbo_delta_stages`` stages.
    """
    split_inputs = _model_forward_tbo_split_inputs(
        **inputs,
        input_data_scatter_mode=input_data_scatter_mode,
        layer_input_scatter_mode=layer_input_scatter_mode,
    )
    # Remember the unsplit length so the merged output can be sized like the input.
    full_len = inputs["hidden_states"].shape[0]
    # Drop this frame's reference to the unsplit inputs before executing.
    del inputs

    if _is_hip:
        # The DeepGEMM SM configuration is skipped when _is_hip is set.
        execution_context = empty_context()
    else:
        execution_context = deep_gemm_wrapper.configure_deep_gemm_num_sms(
            operations_strategy.deep_gemm_num_sms
        )

    with execution_context:
        split_outputs = execute_overlapped_operations(
            inputs_arr=split_inputs,
            operations_arr=[operations_strategy.operations] * 2,
            delta_stages=[0, operations_strategy.tbo_delta_stages],
        )

    return _model_forward_tbo_merge_outputs(*split_outputs, full_len)
880
+
881
+
882
def _model_forward_non_tbo(inputs, operations_strategy: OperationsStrategy):
    """Execute the strategy's operations on ``inputs`` without overlapping.

    Returns:
        The ``(hidden_states, residual)`` pair taken from the execution outputs.
    """
    result = execute_operations(inputs, operations_strategy.operations)
    hidden_states = result["hidden_states"]
    residual = result["residual"]
    return hidden_states, residual
885
+
886
+
887
def _model_forward_tbo_split_inputs(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    zero_allocator: Optional[BumpAllocator],
    input_data_scatter_mode: ScatterMode,
    layer_input_scatter_mode: ScatterMode,
) -> List[Dict]:
    """Produce one layer-input dict per TBO child batch.

    Steps:
      1. Convert ``hidden_states``/``residual`` from ``input_data_scatter_mode``
         to ``ScatterMode.TP_ATTN_FULL``, the mode in which the split is done.
      2. Slice (and pad) the tensors per child batch.
      3. Convert each child's tensors to ``layer_input_scatter_mode`` using the
         child's own forward batch.
    """
    splitter_mode = ScatterMode.TP_ATTN_FULL
    comm_context = CommunicateContext.init_new()

    hidden_states, residual = CommunicateSummableTensorPairFn.execute(
        hidden_states_input_mode=input_data_scatter_mode,
        residual_input_mode=input_data_scatter_mode,
        output_mode=splitter_mode,
        hidden_states=hidden_states,
        residual=residual,
        forward_batch=forward_batch,
        context=comm_context,
    )

    raw_children = _model_forward_tbo_split_inputs_raw(
        hidden_states=hidden_states,
        residual=residual,
        positions=positions,
        forward_batch=forward_batch,
        zero_allocator=zero_allocator,
    )

    def _to_layer_input_mode(hidden_states, residual, forward_batch, **passthrough):
        # ``forward_batch`` here is the child batch, not the parent one.
        child_hidden, child_residual = CommunicateSummableTensorPairFn.execute(
            hidden_states_input_mode=splitter_mode,
            residual_input_mode=splitter_mode,
            output_mode=layer_input_scatter_mode,
            hidden_states=hidden_states,
            residual=residual,
            forward_batch=forward_batch,
            context=comm_context,
        )
        return dict(
            hidden_states=child_hidden,
            residual=child_residual,
            forward_batch=forward_batch,
            **passthrough,
        )

    children = []
    for raw_child in raw_children:
        children.append(_to_layer_input_mode(**raw_child))
    return children
935
+
936
+
937
def _model_forward_tbo_split_inputs_raw(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    zero_allocator: Optional[BumpAllocator],
) -> List[Dict]:
    """Slice the parent tensors into one input dict per ``forward_batch.tbo_children``.

    Each dict comes from ``_model_forward_filter_inputs``; the
    ``zero_allocator`` key is added only when an allocator was supplied.
    """
    allocator_kwargs = (
        {} if zero_allocator is None else {"zero_allocator": zero_allocator}
    )

    child_inputs = []
    for child_index, child_batch in enumerate(forward_batch.tbo_children):
        filtered = _model_forward_filter_inputs(
            hidden_states=hidden_states,
            residual=residual,
            positions=positions,
            output_forward_batch=child_batch,
            tbo_subbatch_index=child_index,
        )
        child_inputs.append(dict(**filtered, **allocator_kwargs))
    return child_inputs
963
+
964
+
965
def _model_forward_filter_inputs(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    positions: torch.Tensor,
    output_forward_batch: ForwardBatch,
    tbo_subbatch_index: int,
) -> Dict:
    """Cut out one child's token range and zero-pad it to ``tbo_padded_len``.

    The range comes from ``output_forward_batch.tbo_parent_token_range``.
    A ``None`` residual stays ``None``; tensors whose first dimension already
    equals the padded length are returned unchanged (no copy).
    """
    token_range = slice(*output_forward_batch.tbo_parent_token_range)
    child_hidden = hidden_states[token_range]
    child_residual = residual[token_range] if residual is not None else None
    child_positions = positions[token_range]

    assert output_forward_batch.tbo_padded_len is not None
    target_len = output_forward_batch.tbo_padded_len

    def _pad_rows(tensor):
        # Nothing to do for a missing tensor or one that is already long enough.
        if tensor is None or tensor.shape[0] == target_len:
            return tensor
        padded = torch.zeros(
            (target_len, *tensor.shape[1:]),
            dtype=tensor.dtype,
            device=tensor.device,
        )
        padded[: tensor.shape[0]] = tensor
        return padded

    return {
        "hidden_states": _pad_rows(child_hidden),
        "residual": _pad_rows(child_residual),
        "positions": _pad_rows(child_positions),
        "forward_batch": output_forward_batch,
        "tbo_subbatch_index": tbo_subbatch_index,
    }
997
+
998
+
999
def _model_forward_tbo_merge_outputs(output_a, output_b, original_len):
    """Stitch the two child outputs back into parent-sized tensors.

    For ``hidden_states`` and ``residual`` a zero tensor with ``original_len``
    rows is allocated, and each child's leading rows (padding rows are dropped)
    are written into that child's ``tbo_parent_token_range``. A key that is
    ``None`` in both outputs yields ``None``.

    Returns:
        ``(hidden_states, residual)``.
    """

    def _merge(key):
        part_a, part_b = output_a[key], output_b[key]
        # Either both children produced this tensor or neither did.
        assert (part_a is None) == (part_b is None)
        if part_a is None:
            return None

        start_a, end_a = output_a["forward_batch"].tbo_parent_token_range
        start_b, end_b = output_b["forward_batch"].tbo_parent_token_range
        merged = torch.zeros(
            (original_len, *part_a.shape[1:]),
            dtype=part_a.dtype,
            device=part_a.device,
        )
        merged[start_a:end_a] = part_a[: end_a - start_a]
        merged[start_b:end_b] = part_b[: end_b - start_b]
        return merged

    merged_hidden_states = _merge("hidden_states")
    merged_residual = _merge("residual")
    return merged_hidden_states, merged_residual
1018
+
1019
+
1020
+ # -------------------------------- Utilities and wrappers ---------------------------------------
1021
+
1022
+
1023
class MaybeTboDeepEPDispatcher(BaseDispatcher):
    """Dispatcher facade that owns one inner dispatcher per TBO sub-batch.

    Two inner dispatchers are created when TBO is enabled, otherwise one.
    Dispatch/combine calls are routed to the inner dispatcher selected by
    ``tbo_subbatch_index`` (the first one when the index is ``None`` or 0);
    configuration calls are applied to this object and to every inner one.
    """

    def __init__(self, **kwargs):
        super().__init__()
        inner_count = 2 if is_tbo_enabled() else 1

        inner_cls = None
        if get_moe_a2a_backend().is_deepep():
            inner_cls = DeepEPDispatcher
        elif get_moe_a2a_backend().is_mooncake():
            inner_cls = MooncakeEPDispatcher
        elif get_moe_a2a_backend().is_mori():
            inner_cls = MoriEPDispatcher

        # NOTE: for any other backend ``_inners`` is intentionally left unset,
        # exactly as before; later calls will then fail with AttributeError.
        if inner_cls is not None:
            self._inners = [inner_cls(**kwargs) for _ in range(inner_count)]

    def _execute(self, name, tbo_subbatch_index: Optional[int] = None, **kwargs):
        """Call method ``name`` on the inner dispatcher of the given sub-batch."""
        inner = self._inners[tbo_subbatch_index or 0]
        method = getattr(inner, name)
        return method(**kwargs)

    def dispatch(self, **kwargs) -> DispatchOutput:
        return self._execute("dispatch", **kwargs)

    def dispatch_a(self, **kwargs):
        return self._execute("dispatch_a", **kwargs)

    def dispatch_b(self, **kwargs):
        return self._execute("dispatch_b", **kwargs)

    def combine(self, **kwargs) -> torch.Tensor:
        return self._execute("combine", **kwargs)

    def combine_a(self, **kwargs):
        return self._execute("combine_a", **kwargs)

    def combine_b(self, **kwargs):
        return self._execute("combine_b", **kwargs)

    def register_deepep_dispatch_hook(self, hook):
        """Register ``hook`` on every inner dispatcher; return their handles."""
        return [
            inner.register_deepep_dispatch_hook(hook) for inner in self._inners
        ]

    def set_quant_config(self, quant_config: dict):
        super().set_quant_config(quant_config)
        for inner in self._inners:
            inner.set_quant_config(quant_config)

    def set_overlap_args(
        self, combine_overlap_args: CombineOverlapArgs, meta_overlap_args: dict
    ):
        super().set_overlap_args(combine_overlap_args, meta_overlap_args)
        for inner in self._inners:
            inner.set_overlap_args(combine_overlap_args, meta_overlap_args)

    def clear_overlap_args(self):
        super().clear_overlap_args()
        for inner in self._inners:
            inner.clear_overlap_args()
sglang/python/sglang/srt/checkpoint_engine/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Checkpoint engine module for SGLang.
3
+
4
+ This module provides functionality for updating model weights via checkpoint engine.
5
+ """
6
+
7
+ from sglang.srt.checkpoint_engine.update import main
8
+
9
+ __all__ = ["main"]
sglang/python/sglang/srt/checkpoint_engine/checkpoint_engine_worker.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023-2024 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
14
+ """
15
+ Checkpoint-engine integration for SGLang.
16
+ This module provides weight update functionality via IPC for checkpoint-engine compatibility.
17
+ """
18
+
19
+ import logging
20
+ from typing import Callable, Dict, Optional
21
+
22
+ import torch
23
+ import zmq
24
+
25
+ try:
26
+ from checkpoint_engine.worker import update_weights_from_ipc
27
+ except ImportError:
28
+ raise ImportError(
29
+ "checkpoint-engine is not installed. "
30
+ "Please install it with: pip install sglang[checkpoint-engine]"
31
+ )
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
class SGLangCheckpointEngineWorkerExtension:
    """
    Base worker extension enabling checkpoint-engine IPC weight updates in SGLang.

    Subclasses supply the device identity and the weight loader; this class
    wires them into ``checkpoint_engine.worker.update_weights_from_ipc``.
    """

    def __init__(self):
        # The ZMQ context is created lazily on the first weight update.
        self._zmq_ctx: Optional[zmq.Context] = None

    def get_device_uuid(self) -> str:
        """Return the UUID of the current device (must be overridden)."""
        raise NotImplementedError(
            "This method should be overridden by SGLang integration"
        )

    def get_device_id(self) -> int:
        """Return the device ID (must be overridden)."""
        raise NotImplementedError(
            "This method should be overridden by SGLang integration"
        )

    def get_model_loader(self) -> Callable:
        """Return the callable that loads model weights (must be overridden)."""
        raise NotImplementedError(
            "This method should be overridden by SGLang integration"
        )

    def get_post_hook(self) -> Optional[Callable]:
        """Return the hook to run after weight loading; none by default."""
        return None

    def update_weights_from_ipc(self, zmq_handles: Dict[str, str]):
        """
        Update weights through the checkpoint-engine IPC channel.

        Args:
            zmq_handles: Dict mapping device UUID to ZMQ socket path

        Raises:
            ValueError: if this device's UUID has no entry in ``zmq_handles``.
        """
        if self._zmq_ctx is None:
            self._zmq_ctx = zmq.Context()

        device_uuid = self.get_device_uuid()
        device_id = self.get_device_id()
        if device_uuid not in zmq_handles:
            raise ValueError(
                f"Device UUID {device_uuid} not found in zmq_handles: {list(zmq_handles.keys())}"
            )

        socket_path = zmq_handles[device_uuid]
        loader = self.get_model_loader()
        post_hook = self.get_post_hook()
        # Resolves to the module-level function imported from checkpoint_engine.worker.
        update_weights_from_ipc(
            self._zmq_ctx,
            socket_path,
            device_id=device_id,
            run=loader,
            post_hook=post_hook,
        )
90
+
91
+
92
class SGLangCheckpointEngineWorkerExtensionImpl(SGLangCheckpointEngineWorkerExtension):
    """
    Concrete worker extension backed by an SGLang model runner.

    Device identity comes from the current CUDA device, and weights are loaded
    through ``model_runner.model.load_weights``.
    """

    def __init__(self, model_runner):
        super().__init__()
        self.model_runner = model_runner

    def get_device_uuid(self) -> str:
        """Return ``"GPU-<uuid>"`` for the current CUDA device."""
        device_id = torch.cuda.current_device()
        try:
            properties = torch.cuda.get_device_properties(device_id)
            return f"GPU-{properties.uuid!s}"
        except AssertionError as e:
            raise ValueError(f"Failed to get GPU UUID for device {device_id}") from e

    def get_device_id(self) -> int:
        """Return the index of the current CUDA device."""
        return torch.cuda.current_device()

    def get_model_loader(self) -> Callable:
        """Return the model's ``load_weights`` bound method."""
        return self.model_runner.model.load_weights

    def get_post_hook(self) -> Optional[Callable]:
        """Return a hook that finalizes the weights after they were loaded."""

        def post_hook():
            # Best effort: any failure below is logged as a warning, not raised.
            try:
                from sglang.srt.model_loader.loader import device_loading_context

                # Let every quantization method post-process its module's weights.
                for _, module in self.model_runner.model.named_modules():
                    quant_method = getattr(module, "quant_method", None)
                    if quant_method is None:
                        continue
                    # The module is processed under a loading context targeting
                    # the current CUDA device.
                    target_device = torch.device("cuda", torch.cuda.current_device())
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)

                # Model-specific post-loading hook, when the model defines one.
                if hasattr(self.model_runner.model, "post_load_weights"):
                    self.model_runner.model.post_load_weights()
            except Exception as e:
                logger.warning(f"Post-hook processing failed: {e}")

        return post_hook
sglang/python/sglang/srt/checkpoint_engine/update.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Usage:
3
+ 1) Launch the server with wait-for-initial-weights option in one terminal:
4
+ python -m sglang.launch_server --model-path /workspace/Qwen/Qwen3-4B/ --tensor-parallel-size 2 --port 19730 --load-format dummy --checkpoint-engine-wait-weights-before-ready --mem-fraction-static 0.7
5
+
6
+ 2) Torchrun this script in another terminal:
7
+ torchrun --nproc-per-node 2 update.py --update-method broadcast --checkpoint-path /workspace/Qwen/Qwen3-4B/ --inference-parallel-size 2
8
+
9
+ Or use the integrated entry point:
10
+ python -m sglang.srt.checkpoint_engine.update --update-method broadcast --checkpoint-path /workspace/Qwen/Qwen3-4B/ --inference-parallel-size 2
11
+ """
12
+
13
+ import argparse
14
+ import json
15
+ import os
16
+ import pickle
17
+ import subprocess
18
+ import sys
19
+ import time
20
+ from collections import defaultdict
21
+ from collections.abc import Callable
22
+ from contextlib import contextmanager
23
+ from typing import Literal
24
+
25
+ import httpx
26
+ import torch
27
+ import torch.distributed as dist
28
+ from safetensors import safe_open
29
+
30
+ try:
31
+ from checkpoint_engine.ps import ParameterServer
32
+ from loguru import logger
33
+ except ImportError:
34
+ # Fallback for when checkpoint_engine is not available
35
+ ParameterServer = None
36
+ import logging
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+
41
@contextmanager
def timer(msg: str):
    """Log how long the wrapped block took, prefixed with ``msg``.

    The duration is logged only when the block finishes without raising.
    """
    started_at = time.perf_counter()
    yield
    finished_at = time.perf_counter()
    logger.info(f"{msg} duration: {finished_at - started_at:.2f} seconds")
47
+
48
+
49
def check_sglang_ready(
    endpoint: str, inference_parallel_size: int, uds: str | None = None
):
    """Block until the SGLang server answers ``GET {endpoint}/ping`` successfully.

    Only the first rank of each inference group (``RANK`` divisible by
    ``inference_parallel_size``) polls; every other rank returns immediately.

    Args:
        endpoint: Base URL of the SGLang HTTP server.
        inference_parallel_size: Number of ranks per inference group.
        uds: Optional Unix domain socket path used as the HTTP transport.

    Connection failures, timeouts and non-2xx responses are retried every
    100 ms without limit; a warning is logged on every 10th failed attempt.
    Any other error propagates to the caller.
    """
    rank = int(os.getenv("RANK", 0))
    if rank % inference_parallel_size != 0:
        return

    transport = httpx.HTTPTransport(uds=uds) if uds is not None else None
    # Timeouts are retried too: a server that accepts the connection but is
    # still busy starting up must not abort the polling rank.
    retryable_errors = (
        httpx.ConnectError,
        httpx.TimeoutException,
        httpx.HTTPStatusError,
    )
    retry_num = 0
    with httpx.Client(transport=transport) as client:
        while True:
            try:
                response = client.get(f"{endpoint}/ping", timeout=10)
                response.raise_for_status()
                break
            except retryable_errors as e:
                if retry_num % 10 == 0:
                    logger.warning(
                        f"fail to check sglang ready, retry {retry_num} times, error: {e}"
                    )
                retry_num += 1
                time.sleep(0.1)
72
+
73
+
74
def split_checkpoint_files(
    checkpoint_path: str, rank: int, world_size: int
) -> list[str]:
    """Return the ``.safetensors`` files of ``checkpoint_path`` assigned to ``rank``.

    The files are sorted by name and cut into ``world_size`` contiguous chunks
    of ``ceil(len(files) / world_size)`` entries; ``rank`` receives its chunk
    (possibly empty for the last ranks).

    Args:
        checkpoint_path: Directory containing the checkpoint shards.
        rank: Index of the calling process.
        world_size: Total number of processes sharing the files.

    Returns:
        Full paths of the files owned by ``rank``.
    """
    # os.listdir returns entries in arbitrary order; sorting makes every rank
    # see the same list so the chunks neither overlap nor leave gaps.
    checkpoint_files = [
        os.path.join(checkpoint_path, name)
        for name in sorted(os.listdir(checkpoint_path))
        if name.endswith(".safetensors")
    ]
    files_per_rank = (len(checkpoint_files) + world_size - 1) // world_size
    first = rank * files_per_rank
    return checkpoint_files[first : first + files_per_rank]
85
+
86
+
87
def split_tensors(
    checkpoint_path: str, rank: int, world_size: int
) -> dict[str, torch.Tensor]:
    """Load the tensors that ``rank`` owns according to the safetensors index.

    The ``weight_map`` of ``model.safetensors.index.json`` is cut, in file
    order, into ``world_size`` contiguous chunks; the tensors of this rank's
    chunk are read from their shard files, opening each shard once.

    Returns:
        Mapping from tensor name to the loaded tensor.
    """
    index_path = os.path.join(checkpoint_path, "model.safetensors.index.json")
    with open(index_path) as f:
        weight_map: dict[str, str] = json.load(f)["weight_map"]

    weights_per_rank = (len(weight_map) + world_size - 1) // world_size
    first = rank * weights_per_rank
    owned_entries = list(weight_map.items())[first : first + weights_per_rank]

    # Group tensor names by shard so each file is opened a single time.
    names_by_file: dict[str, list[str]] = defaultdict(list)
    for tensor_name, shard_file in owned_entries:
        names_by_file[shard_file].append(tensor_name)

    named_tensors = {}
    for shard_file, tensor_names in names_by_file.items():
        shard_path = os.path.join(checkpoint_path, shard_file)
        with safe_open(shard_path, framework="pt") as shard:
            for tensor_name in tensor_names:
                named_tensors[tensor_name] = shard.get_tensor(tensor_name)
    return named_tensors
106
+
107
+
108
+ def req_inference(
109
+ endpoint: str,
110
+ inference_parallel_size: int,
111
+ timeout: float = 300.0,
112
+ uds: str | None = None,
113
+ weight_version: str | None = None,
114
+ ) -> Callable[[list[tuple[str, str]]], None]:
115
+ rank = int(os.getenv("RANK", 0))
116
+ src = rank // inference_parallel_size * inference_parallel_size
117
+
118
+ def req_func(socket_paths: list[tuple[str, str]]):
119
+ if rank == src:
120
+ with httpx.Client(transport=httpx.HTTPTransport(uds=uds)) as client:
121
+ resp = client.post(
122
+ f"{endpoint}/update_weights_from_ipc",
123
+ json={
124
+ "zmq_handles": dict(
125
+ socket_paths[src : src + inference_parallel_size]
126
+ ),
127
+ "flush_cache": True,
128
+ "weight_version": weight_version,
129
+ },
130
+ timeout=timeout,
131
+ )
132
+ resp.raise_for_status()
133
+
134
+ return req_func
135
+
136
+
137
def update_weights(
    ps,
    checkpoint_name: str,
    checkpoint_files: list[str],
    named_tensors: dict[str, torch.Tensor],
    req_func: Callable[[list[tuple[str, str]]], None],
    inference_parallel_size: int,
    endpoint: str,
    save_metas_file: str | None = None,
    update_method: Literal["broadcast", "p2p", "all"] = "broadcast",
    uds: str | None = None,
):
    """Register a checkpoint with the parameter server and push it to SGLang.

    Args:
        ps: Parameter server object (a ``checkpoint_engine`` ``ParameterServer``
            in ``main``).
        checkpoint_name: Name under which the checkpoint is registered.
        checkpoint_files: Shard files owned by this rank.
        named_tensors: Already-loaded tensors owned by this rank.
        req_func: Callback passed to ``ps.update`` (see ``req_inference``).
        inference_parallel_size: Number of ranks per inference group.
        endpoint: Base URL of the SGLang HTTP server.
        save_metas_file: If set, rank 0 pickles ``ps.get_metas()`` to this path.
        update_method: ``"broadcast"``, ``"p2p"``, or ``"all"`` (both in turn).
            Any other value performs no update.
        uds: Optional Unix domain socket path for reaching the server.
    """
    ps.register_checkpoint(
        checkpoint_name, files=checkpoint_files, named_tensors=named_tensors
    )
    ps.init_process_group()
    check_sglang_ready(endpoint, inference_parallel_size, uds)
    dist.barrier()
    with timer("Gather metas"):
        ps.gather_metas(checkpoint_name)

    # Default RANK to 0 like the rest of this module, instead of failing with
    # a TypeError on int(None) when the variable is not set.
    if save_metas_file and int(os.getenv("RANK", 0)) == 0:
        with open(save_metas_file, "wb") as f:
            pickle.dump(ps.get_metas(), f)

    if update_method in ("broadcast", "all"):
        with timer("Update weights without setting ranks"):
            ps.update(checkpoint_name, req_func)

    if update_method in ("p2p", "all"):
        # sleep 2s to wait destroy process group
        time.sleep(2)
        with timer("Update weights with setting ranks"):
            ps.update(
                checkpoint_name, req_func, ranks=list(range(inference_parallel_size))
            )
173
+
174
+
175
def join(
    ps: ParameterServer,
    checkpoint_name: str,
    load_metas_file: str,
    req_func: Callable[[list[tuple[str, str]]], None],
    inference_parallel_size: int,
    endpoint: str,
    uds: str | None = None,
):
    """Update weights using metas previously saved to ``load_metas_file``.

    The metas are unpickled, the process group is initialized, and once the
    server is ready the weights are updated with ranks
    ``range(inference_parallel_size)``.
    """
    assert load_metas_file, "load_metas_file is required"
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # load metas files produced by a trusted run of this script.
    with open(load_metas_file, "rb") as f:
        saved_metas = pickle.load(f)

    ps.init_process_group()
    check_sglang_ready(endpoint, inference_parallel_size, uds)
    dist.barrier()

    with timer("Gather metas before join"):
        ps.gather_metas(checkpoint_name)
    ps.load_metas(saved_metas)

    target_ranks = list(range(inference_parallel_size))
    with timer(
        f"Update weights with setting ranks as range(0, {inference_parallel_size}) by using p2p"
    ):
        ps.update(checkpoint_name, req_func, ranks=target_ranks)
197
+
198
+
199
def run_with_torchrun():
    """Re-launch this script under ``torchrun`` and exit with its return code.

    ``--nproc-per-node`` is taken from the first ``--inference-parallel-size``
    option found on the command line (either ``--flag value`` or
    ``--flag=value``); it stays at 8 when the option is absent or not an
    integer. All original arguments are forwarded unchanged.
    """
    forwarded_args = sys.argv[1:]  # everything after the script name
    nproc_per_node = 8  # default
    size_flag = "--inference-parallel-size"

    for position, token in enumerate(forwarded_args):
        if token == size_flag and position + 1 < len(forwarded_args):
            raw_value = forwarded_args[position + 1]
        elif token.startswith(size_flag + "="):
            raw_value = token.split("=", 1)[1]
        else:
            continue
        # Only the first occurrence is considered, even if it is malformed.
        try:
            nproc_per_node = int(raw_value)
        except ValueError:
            pass
        break

    cmd = ["torchrun", f"--nproc-per-node={nproc_per_node}", __file__]
    cmd = cmd + forwarded_args

    print(f"Running: {' '.join(cmd)}", file=sys.stderr)

    try:
        completed = subprocess.run(cmd, check=False)
        sys.exit(completed.returncode)
    except FileNotFoundError:
        print(
            "Error: torchrun command not found. Please ensure PyTorch is installed.",
            file=sys.stderr,
        )
        sys.exit(1)
    except KeyboardInterrupt:
        print("\nInterrupted by user", file=sys.stderr)
        sys.exit(130)
238
+
239
+
240
def main():
    """Entry point: update SGLang weights through checkpoint-engine.

    When ``RANK`` is not set the script re-launches itself under ``torchrun``.
    Otherwise it parses the command line and either joins with previously
    saved metas (``--load-metas-file``) or registers and pushes the checkpoint
    found at ``--checkpoint-path``, then sleeps for ``--sleep-time`` seconds.
    """
    if os.getenv("RANK") is None:
        # Not running under torchrun, so invoke it
        run_with_torchrun()
        return

    # Running under torchrun, proceed with normal execution
    parser = argparse.ArgumentParser(description="Update weights example")
    option_specs = (
        ("--checkpoint-path", str, None),
        ("--save-metas-file", str, None),
        ("--load-metas-file", str, None),
        ("--sleep-time", int, 0),
        ("--endpoint", str, "http://localhost:19730"),
        ("--inference-parallel-size", int, 8),
        ("--checkpoint-name", str, "my-checkpoint-iter-0"),
        ("--update-method", str, "broadcast"),
        ("--uds", str, None),
        ("--weight-version", str, None),
    )
    for flag, value_type, default in option_specs:
        parser.add_argument(flag, type=value_type, default=default)
    args = parser.parse_args()

    # Rank and world size are read from the environment (set by torchrun)
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))

    req_func = req_inference(
        args.endpoint,
        args.inference_parallel_size,
        uds=args.uds,
        weight_version=args.weight_version,
    )

    if ParameterServer is None:
        print("Error: checkpoint_engine package not available", file=sys.stderr)
        sys.exit(1)

    ps = ParameterServer(auto_pg=True)
    ps._p2p_store = None

    if args.load_metas_file:
        join(
            ps,
            args.checkpoint_name,
            args.load_metas_file,
            req_func,
            args.inference_parallel_size,
            args.endpoint,
            args.uds,
        )
    else:
        # With an index file, tensors are split by name; otherwise whole shard
        # files are split across ranks. Without a path nothing is registered.
        checkpoint_files = []
        named_tensors = {}
        if args.checkpoint_path:
            index_file = os.path.join(
                args.checkpoint_path, "model.safetensors.index.json"
            )
            if os.path.exists(index_file):
                named_tensors = split_tensors(args.checkpoint_path, rank, world_size)
            else:
                checkpoint_files = split_checkpoint_files(
                    args.checkpoint_path, rank, world_size
                )
        update_weights(
            ps,
            args.checkpoint_name,
            checkpoint_files,
            named_tensors,
            req_func,
            args.inference_parallel_size,
            args.endpoint,
            args.save_metas_file,
            args.update_method,
            args.uds,
        )
    time.sleep(args.sleep_time)
314
+
315
+
316
+ if __name__ == "__main__":
317
+ main()
sglang/python/sglang/srt/compilation/__pycache__/compilation_config.cpython-311.pyc ADDED
Binary file (2.79 kB). View file
 
sglang/python/sglang/srt/compilation/__pycache__/compile.cpython-311.pyc ADDED
Binary file (12 kB). View file
 
sglang/python/sglang/srt/compilation/__pycache__/piecewise_context_manager.cpython-311.pyc ADDED
Binary file (5.35 kB). View file
 
sglang/python/sglang/srt/compilation/backend.py ADDED
@@ -0,0 +1,472 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/backend.py
2
+
3
+
4
+ import ast
5
+ import dataclasses
6
+ import logging
7
+ import os
8
+ import pprint
9
+ import time
10
+ from collections.abc import Sequence
11
+ from contextlib import contextmanager
12
+ from typing import Any, Callable, Optional
13
+
14
+ import torch
15
+ import torch.fx as fx
16
+ from torch._dispatch.python import enable_python_dispatcher
17
+
18
+ from sglang.srt.compilation.compilation_config import CompilationConfig
19
+ from sglang.srt.compilation.compilation_counter import compilation_counter
20
+ from sglang.srt.compilation.compiler_interface import EagerAdapter, InductorAdaptor
21
+ from sglang.srt.compilation.cuda_piecewise_backend import CUDAPiecewiseBackend
22
+ from sglang.srt.compilation.npu_piecewise_backend import NPUPiecewiseBackend
23
+ from sglang.srt.compilation.pass_manager import PostGradPassManager
24
+ from sglang.srt.utils.common import is_npu, rank0_log
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
def make_compiler(config: CompilationConfig):
    """Instantiate the compiler adapter named by ``config.compiler``.

    Returns:
        ``EagerAdapter()`` for ``"eager"`` and ``InductorAdaptor()`` for
        ``"inductor"``.

    Raises:
        ValueError: for any other compiler name.
    """
    compiler_name = config.compiler
    if compiler_name == "eager":
        return EagerAdapter()
    if compiler_name == "inductor":
        return InductorAdaptor()
    raise ValueError(f"Unknown compiler: {config.compiler}")
36
+
37
+
38
def make_backend(
    graph: fx.GraphModule,
    compile_config: CompilationConfig,
    inductor_config: dict[str, Any],
    graph_pool: Any,
    piecewise_compile_index: int,
    total_piecewise_compiles: int,
    sym_shape_indices: list[int],
    compiled_graph_for_general_shape: Callable,
    sglang_backend,
):
    """Create the piecewise backend for one sub-graph.

    ``NPUPiecewiseBackend`` is used when ``is_npu()`` is true and
    ``CUDAPiecewiseBackend`` otherwise; all arguments are forwarded
    positionally, in order, to the selected class.
    """
    if is_npu():
        backend_cls = NPUPiecewiseBackend
    else:
        backend_cls = CUDAPiecewiseBackend

    backend_args = (
        graph,
        compile_config,
        inductor_config,
        graph_pool,
        piecewise_compile_index,
        total_piecewise_compiles,
        sym_shape_indices,
        compiled_graph_for_general_shape,
        sglang_backend,
    )
    return backend_cls(*backend_args)
62
+
63
+
64
class CompilerManager:
    """Compile sub-graphs through the configured adapter and track cache handles.

    ``self.cache`` maps ``(runtime_shape, graph_index, compiler_name)`` to the
    handle returned by the adapter's ``compile``. The mapping is loaded from
    and written to ``<cache_dir>/sglang_compile_cache.py`` as a Python literal.
    """

    def __init__(
        self,
        config: CompilationConfig,
    ):
        self.cache = dict()
        self.is_cache_updated = False
        # Defaults so that save_to_file() does not fail with AttributeError
        # when it runs before initialize_cache().
        self.disable_cache = False
        self.cache_dir = None
        self.cache_file_path = None
        self.compiler = make_compiler(config)

    def compute_hash(self):
        """Return the hash string computed by the underlying compiler adapter."""
        return self.compiler.compute_hash()

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ):
        """Record the cache location, load an existing handle table, and
        forward the same settings to the compiler adapter.

        The on-disk table is only read when caching is enabled and the file
        already exists.
        """
        self.disable_cache = disable_cache
        self.cache_dir = cache_dir
        self.cache_file_path = os.path.join(cache_dir, "sglang_compile_cache.py")

        if not disable_cache and os.path.exists(self.cache_file_path):
            with open(self.cache_file_path) as f:
                # literal_eval only accepts Python literals, so the file
                # contents are never executed.
                self.cache = ast.literal_eval(f.read())

        self.compiler.initialize_cache(
            cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix
        )

    def save_to_file(self):
        """Write the handle table to disk if caching is on and it has changed."""
        if self.disable_cache or not self.is_cache_updated:
            return
        if self.cache_file_path is None:
            # initialize_cache() has not run yet, so there is no target file.
            return
        printer = pprint.PrettyPrinter(indent=4)
        data = printer.pformat(self.cache)
        with open(self.cache_file_path, "w") as f:
            f.write(data)

    def load(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        runtime_shape: Optional[int] = None,
    ) -> Optional[Callable]:
        """Load a previously compiled graph through its cached handle.

        Raises:
            KeyError: if no handle is cached for
                ``(runtime_shape, graph_index, compiler name)``.
        """
        handle = self.cache[(runtime_shape, graph_index, self.compiler.name)]
        compiled_graph = self.compiler.load(
            handle, graph, example_inputs, graph_index, runtime_shape
        )
        if runtime_shape is None:
            logger.debug(
                "Directly load the %s-th graph for dynamic shape from %s via "
                "handle %s",
                graph_index,
                self.compiler.name,
                handle,
            )
        else:
            logger.debug(
                "Directly load the %s-th graph for shape %s from %s via handle %s",
                graph_index,
                str(runtime_shape),
                self.compiler.name,
                handle,
            )
        return compiled_graph

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs,
        inductor_config: dict[str, Any],
        graph_index: int = 0,
        num_graphs: int = 1,
        runtime_shape: Optional[int] = None,
    ) -> Any:
        """Compile one sub-graph and remember its handle, if any.

        ``runtime_shape=None`` denotes the dynamic-shape compilation. The
        module-level ``compilation_start_time`` is reset when
        ``graph_index == 0`` and read when ``graph_index == num_graphs - 1``
        to log the elapsed wall-clock time.

        Returns:
            The compiled callable produced by the adapter.
        """
        if graph_index == 0:
            # before compiling the first graph, record the start time
            global compilation_start_time
            compilation_start_time = time.time()

        compilation_counter.num_backend_compilations += 1

        compiled_graph = None

        # TODO(Yuwei): support cache loading

        # no compiler cached the graph, or the cache is disabled,
        # we need to compile it
        if isinstance(self.compiler, InductorAdaptor):
            maybe_key = None
        else:
            maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
        compiled_graph, handle = self.compiler.compile(
            graph, example_inputs, inductor_config, runtime_shape, maybe_key
        )

        assert compiled_graph is not None, "Failed to compile the graph"

        # store the artifact in the cache
        if handle is not None:
            self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle
            compilation_counter.num_cache_entries_updated += 1
            self.is_cache_updated = True
            if graph_index == 0:
                # adds some info logging for the first graph
                if runtime_shape is None:
                    logger.info("Cache the graph for dynamic shape for later use")
                else:
                    logger.info(
                        "Cache the graph of shape %s for later use", str(runtime_shape)
                    )
            if runtime_shape is None:
                logger.debug(
                    "Store the %s-th graph for dynamic shape from %s via handle %s",
                    graph_index,
                    self.compiler.name,
                    handle,
                )
            else:
                logger.debug(
                    "Store the %s-th graph for shape %s from %s via handle %s",
                    graph_index,
                    str(runtime_shape),
                    self.compiler.name,
                    handle,
                )

        # after compiling the last graph, record the end time
        if graph_index == num_graphs - 1:
            now = time.time()
            elapsed = now - compilation_start_time
            if runtime_shape is None:
                logger.info("Compiling a graph for dynamic shape takes %.2f s", elapsed)
            else:
                logger.info(
                    "Compiling a graph for shape %s takes %.2f s",
                    runtime_shape,
                    elapsed,
                )

        return compiled_graph
203
+
204
+
205
@dataclasses.dataclass
class SplitItem:
    """Record describing one sub-module of a split graph."""

    # Name of the sub-module attribute on the split graph module.
    submod_name: str
    # Integer id of the piece, used for ordering.
    graph_id: int
    # Whether this piece was created for a splitting op.
    is_splitting_graph: bool
    # The sub-module object itself.
    graph: fx.GraphModule
211
+
212
+
213
def split_graph(
    graph: fx.GraphModule, ops: list[str]
) -> tuple[fx.GraphModule, list[SplitItem]]:
    """Split ``graph`` into sub-modules around the ops named in ``ops``.

    Every ``call_function`` node whose ``str(node.target)`` is in ``ops`` gets
    a sub-graph id of its own; the nodes between two such ops share an id.

    Returns:
        ``(split_gm, items)`` where ``split_gm`` is the module returned by
        ``split_module`` and ``items`` holds one ``SplitItem`` per direct
        ``submod_<id>`` child, sorted by integer id.
    """
    # Set lookups keep the per-node membership test O(1).
    split_ops = set(ops)
    subgraph_id = 0
    node_to_subgraph_id = {}
    split_op_graphs = set()
    for node in graph.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue
        if node.op == "call_function" and str(node.target) in split_ops:
            # Bump the id before and after so the splitting op sits alone.
            subgraph_id += 1
            node_to_subgraph_id[node] = subgraph_id
            split_op_graphs.add(subgraph_id)
            subgraph_id += 1
        else:
            node_to_subgraph_id[node] = subgraph_id

    # `keep_original_order` is important!
    # otherwise pytorch might reorder the nodes and
    # the semantics of the graph will change when we
    # have mutations in the graph
    split_gm = torch.fx.passes.split_module.split_module(
        graph, None, lambda node: node_to_subgraph_id[node], keep_original_order=True
    )

    outputs = []

    names = [name for name, _ in split_gm.named_modules()]

    for name in names:
        if "." in name or name == "":
            # recursive child module or the root module
            continue

        module = getattr(split_gm, name)

        graph_id = int(name.replace("submod_", ""))
        outputs.append(SplitItem(name, graph_id, (graph_id in split_op_graphs), module))

    # sort by integer graph_id, rather than string name
    outputs.sort(key=lambda x: x.graph_id)

    return split_gm, outputs
257
+
258
+
259
# A single graph pool is shared by all the backends.
global_graph_pool: Optional[Any] = None

# Timestamp written and read by CompilerManager.compile to log how long
# compiling a set of graphs took.
compilation_start_time: float = 0.0
263
+
264
+
265
class PiecewiseCompileInterpreter(torch.fx.Interpreter):
    """Interpreter that runs a split graph on fake tensors and compiles pieces.

    While interpreting, each sub-module listed in ``compile_submod_names`` is
    compiled for the dynamic shape and then replaced on the parent module by
    the object returned from ``make_backend``.
    """

    def __init__(
        self,
        module: torch.fx.GraphModule,
        compile_submod_names: list[str],
        inductor_config: dict[str, Any],
        graph_pool,
        compile_config: CompilationConfig,
        sglang_backend: "SGLangBackend",
    ):
        super().__init__(module)
        from torch._guards import detect_fake_mode

        self.fake_mode = detect_fake_mode()
        self.compile_submod_names = compile_submod_names
        self.inductor_config = inductor_config
        self.graph_pool = graph_pool
        self.compile_config = compile_config
        self.sglang_backend = sglang_backend
        # When True, it annoyingly dumps the torch.fx.Graph on errors.
        self.extra_traceback = False

    def run(self, *args):
        """Convert tensor arguments to fake tensors and interpret the module."""
        fake_args = []
        for arg in args:
            if isinstance(arg, torch.Tensor):
                arg = self.fake_mode.from_tensor(arg)
            fake_args.append(arg)
        with self.fake_mode, enable_python_dispatcher():
            return super().run(*fake_args)

    def call_module(
        self,
        target: torch.fx.node.Target,
        args: tuple[torch.fx.node.Argument, ...],
        kwargs: dict[str, Any],
    ) -> Any:
        """Run the sub-module, then compile and swap it in if it is selected."""
        assert isinstance(target, str)
        output = super().call_module(target, args, kwargs)

        if target not in self.compile_submod_names:
            return output

        index = self.compile_submod_names.index(target)
        num_graphs = len(self.compile_submod_names)
        submod = self.fetch_attr(target)
        # Positions of the symbolic-integer arguments of this sub-module.
        sym_shape_indices = []
        for position, value in enumerate(args):
            if isinstance(value, torch.SymInt):
                sym_shape_indices.append(position)

        compiler_manager = self.sglang_backend.compiler_manager
        compiled_graph_for_dynamic_shape = compiler_manager.compile(
            submod,
            args,
            self.inductor_config,
            graph_index=index,
            num_graphs=num_graphs,
            runtime_shape=None,
        )

        # Writing to __dict__ shadows the registered sub-module of the same
        # name with the piecewise backend.
        self.module.__dict__[target] = make_backend(
            submod,
            self.compile_config,
            self.inductor_config,
            self.graph_pool,
            index,
            num_graphs,
            sym_shape_indices,
            compiled_graph_for_dynamic_shape,
            self.sglang_backend,
        )

        compilation_counter.num_piecewise_capturable_graphs_seen += 1
        return output
337
+
338
+
339
# Tag naming the model part currently being compiled; SGLangBackend uses it
# as the last component of its local cache directory.
model_tag: str = "backbone"


@contextmanager
def set_model_tag(tag: str):
    """Temporarily switch the module-level ``model_tag`` to ``tag``.

    The previous tag is restored on exit, including when the body raises.
    Asserts that ``tag`` differs from the tag currently in effect.
    """
    global model_tag
    assert (
        tag != model_tag
    ), f"Model tag {tag} is the same as the current tag {model_tag}."
    previous_tag, model_tag = model_tag, tag
    try:
        yield
    finally:
        model_tag = previous_tag
355
+
356
+
357
class SGLangBackend:
    """``torch.compile`` backend that splits a graph and compiles the pieces.

    Calling the instance splits the incoming graph at
    ``compile_config.split_ops``, compiles every non-splitting piece through
    ``CompilerManager`` and returns the stitched ``split_gm``. An instance is
    single-use: a second call fails an assertion.
    """

    graph_pool: Any
    _called: bool = False
    # the graph we compiled
    graph: fx.GraphModule
    # the stitching graph module for all the piecewise graphs
    split_gm: fx.GraphModule
    piecewise_graphs: list[SplitItem]
    returned_callable: Callable
    # Inductor passes to run on the graph pre-defunctionalization
    post_grad_passes: Sequence[Callable]
    sym_tensor_indices: list[int]
    input_buffers: list[torch.Tensor]
    compiler_manager: CompilerManager

    def __init__(
        self,
        config: CompilationConfig,
        graph_pool: Any,
    ):
        rank0_log("Initializing SGLangBackend")
        assert graph_pool is not None
        self.graph_pool = graph_pool

        self.post_grad_pass_manager = PostGradPassManager()
        self.sym_tensor_indices = []
        self.input_buffers = []

        self.compiler_manager = CompilerManager(config)
        self.inductor_config = {
            "enable_auto_functionalized_v2": False,
        }
        self.compile_config = config

    def configure_post_pass(self):
        """Configure the post-grad pass manager and register it with Inductor."""
        self.post_grad_pass_manager.configure()
        self.inductor_config["post_grad_custom_post_pass"] = self.post_grad_pass_manager

    def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
        """Split ``graph``, compile its pieces and return the stitched module.

        On rank 0 the readable source of the split module is also written to
        the local cache directory.
        """
        # Reject a second call before touching the cache or the counters.
        assert not self._called, "SGLangBackend can only be called once"

        rank0_log("SGLangBackend __call__")
        base_cache_dir = os.path.expanduser(
            os.getenv("SGLANG_CACHE_DIR", "~/.cache/sglang/")
        )

        cache_hash = self.compiler_manager.compute_hash()
        cache_dir = os.path.join(
            base_cache_dir,
            "torch_compile_cache",
            cache_hash,
        )

        os.makedirs(cache_dir, exist_ok=True)
        # NOTE(review): the cache sub-directory is always "rank_0_0"; the real
        # rank is only consulted further down for the graph dump. Confirm
        # whether sharing one directory across ranks is intended.
        rank = 0
        dp_rank = 0
        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", model_tag)
        os.makedirs(local_cache_dir, exist_ok=True)
        self.compiler_manager.initialize_cache(
            local_cache_dir, disable_cache=False, prefix=""
        )
        compilation_counter.num_graphs_seen += 1

        self.graph = graph
        self.configure_post_pass()

        self.split_gm, self.piecewise_graphs = split_graph(
            graph,
            self.compile_config.split_ops,
        )
        from torch._dynamo.utils import lazy_format_graph_code

        # depyf will hook lazy_format_graph_code and dump the graph
        # for debugging, no need to print the graph here
        lazy_format_graph_code("before split", self.graph)
        lazy_format_graph_code("after split", self.split_gm)

        compilation_counter.num_piecewise_graphs_seen += len(self.piecewise_graphs)

        submod_names_to_compile = [
            item.submod_name
            for item in self.piecewise_graphs
            if not item.is_splitting_graph
        ]

        PiecewiseCompileInterpreter(
            self.split_gm,
            submod_names_to_compile,
            self.inductor_config,
            self.graph_pool,
            self.compile_config,
            self,
        ).run(*example_inputs)

        # get_rank() needs an initialised default process group; treat a
        # non-distributed run as rank 0 instead of failing.
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            rank = torch.distributed.get_rank()
        else:
            rank = 0

        if rank == 0:
            graph_path = os.path.join(
                local_cache_dir, f"computation_graph_{time.time()}.py"
            )
            if not os.path.exists(graph_path):
                # code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa
                # use `print_readable` because it can include submodules
                src = (
                    "from __future__ import annotations\nimport torch\n"
                    + self.split_gm.print_readable(print_output=False)
                )
                src = src.replace("<lambda>", "GraphModule")
                with open(graph_path, "w") as f:
                    f.write(src)

                rank0_log(f"Computation graph saved to {graph_path}")

        self._called = True
        return self.split_gm
sglang/python/sglang/srt/compilation/compilation_config.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compilation_config.py
2
+
3
+ from typing import Callable, List, Optional
4
+
5
# Names of the ops registered as split points; CompilationConfig copies this
# list when it is constructed.
SPLIT_OPS = []


def register_split_op(op_name: Optional[str] = None):
    """Return a decorator that registers a function as a split op.

    The entry appended to ``SPLIT_OPS`` is ``"sglang.<name>"``, where
    ``<name>`` is ``op_name`` when it is truthy and the function's
    ``__name__`` otherwise. The decorated function is returned unchanged.
    """

    def decorator(op_func: Callable):
        if op_name:
            name = op_name
        else:
            name = op_func.__name__
        SPLIT_OPS.append(f"sglang.{name}")
        return op_func

    return decorator
15
+
16
+
17
+ # TODO(Yuwei): support better compile config support
18
class CompilationConfig:
    """Settings for piecewise compilation.

    Holds the capture sizes, the compiler name, the debug-mode flag, the set
    of traced files and a per-instance list of split ops seeded from the
    module-level ``SPLIT_OPS``.
    """

    def __init__(
        self,
        capture_sizes: List[int],
        compiler: str = "eager",
        enable_debug_mode: bool = False,
    ):
        self.capture_sizes = capture_sizes
        self.compiler = compiler
        self.enable_debug_mode = enable_debug_mode
        self.traced_files = set()
        # Copy, so add_split_op() does not modify the module-level registry.
        self.split_ops = list(SPLIT_OPS)

    def add_split_op(self, op: str):
        """Append ``op`` to this config's split ops."""
        self.split_ops.append(op)

    def add_traced_file(self, file_path: str):
        """Add ``file_path`` to the set of traced files."""
        self.traced_files.add(file_path)

    def get_traced_files(self):
        """Return the set of traced files."""
        return self.traced_files

    def get_capture_sizes(self):
        """Return the capture sizes passed to the constructor."""
        return self.capture_sizes

    def get_enable_debug_mode(self):
        """Return the debug-mode flag."""
        return self.enable_debug_mode
sglang/python/sglang/srt/compilation/compilation_counter.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compilation_counter.py
2
+
3
+ import copy
4
+ import dataclasses
5
+ from contextlib import contextmanager
6
+
7
+
8
@dataclasses.dataclass
class CompilationCounter:
    """Integer counters for compilation events, all starting at zero."""

    num_models_seen: int = 0
    num_graphs_seen: int = 0
    # piecewise graphs, splitting ops included
    num_piecewise_graphs_seen: int = 0
    # piecewise graphs, splitting ops excluded
    num_piecewise_capturable_graphs_seen: int = 0
    num_backend_compilations: int = 0
    # attempts by gpu_model_runner to trigger CUDAGraphs capture
    num_gpu_runner_capture_triggers: int = 0
    # CUDAGraphs captured
    num_cudagraph_captured: int = 0
    # InductorAdapter.compile calls
    num_inductor_compiles: int = 0
    # EagerAdapter.compile calls
    num_eager_compiles: int = 0
    # times a compiler cache entry was updated
    num_cache_entries_updated: int = 0
    # standalone_compile compiled artifacts saved
    num_compiled_artifacts_saved: int = 0
    # times a model was loaded with CompilationLevel.DYNAMO_AS_IS
    dynamo_as_is_count: int = 0

    def clone(self) -> "CompilationCounter":
        """Return an independent deep copy of this counter."""
        return copy.deepcopy(self)

    @contextmanager
    def expect(self, **kwargs):
        """Assert that counters change by exactly the given deltas.

        Each keyword names a counter and gives the expected difference between
        its value after and before the ``with`` body. The checks run only if
        the body finishes without raising.
        """
        before = self.clone()
        yield
        for counter_name, expected_diff in kwargs.items():
            old_value = getattr(before, counter_name)
            new_value = getattr(self, counter_name)
            assert new_value - old_value == expected_diff, (
                f"{counter_name} not as expected, before it is {old_value}"
                f", after it is {new_value}, "
                f"expected diff is {expected_diff}"
            )


compilation_counter = CompilationCounter()
sglang/python/sglang/srt/compilation/compile.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import logging
3
+ import os
4
+ import sys
5
+ import types
6
+ from dataclasses import dataclass
7
+ from typing import Any, Callable, Optional, Union
8
+
9
+ import torch
10
+
11
+ from sglang.srt.compilation.compilation_config import CompilationConfig
12
+ from sglang.srt.compilation.piecewise_context_manager import is_in_piecewise_cuda_graph
13
+ from sglang.srt.utils.common import rank0_log
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
@dataclass
class IntermediateTensors:
    """For all pipeline stages except the last, we need to return the hidden
    states and residuals to be sent to the next stage. This data structure
    contains the hidden states and residuals for a request.

    Each stage also needs to handle its own finished_sending and
    finished_recving in case of kv transfer.
    """

    tensors: dict[str, torch.Tensor]
    # [req_ids]
    finished_sending: Optional[set[str]] = None
    finished_recving: Optional[set[str]] = None

    def __init__(self, tensors):
        # manually define this function, so that
        # Dynamo knows `IntermediateTensors()` comes from this file.
        # Otherwise, dataclass will generate this function by evaluating
        # a string, and we will lose the information about the source file.
        self.tensors = tensors

    def __getitem__(self, key: Union[str, slice]):
        """Return one tensor for a ``str`` key, or a new instance holding
        every tensor sliced by ``key`` for a ``slice``.

        Any other key type falls through and returns ``None``.
        """
        if isinstance(key, str):
            return self.tensors[key]
        elif isinstance(key, slice):
            return self.__class__({k: v[key] for k, v in self.tensors.items()})

    def __setitem__(self, key: str, value: torch.Tensor):
        self.tensors[key] = value

    def items(self):
        """Return the ``(name, tensor)`` view of the underlying dict."""
        return self.tensors.items()

    def __len__(self):
        return len(self.tensors)

    def __eq__(self, other: object):
        """Return True only when ``other`` holds the same names and equal tensors.

        The previous implementation returned ``self`` for any instance of the
        class, so the result depended on ``len(self)`` rather than on content.
        """
        if not isinstance(other, self.__class__):
            return False
        if self.tensors.keys() != other.tensors.keys():
            return False
        return all(
            torch.equal(tensor, other.tensors[name])
            for name, tensor in self.tensors.items()
        )

    def __repr__(self) -> str:
        return f"IntermediateTensors(tensors={self.tensors})"
60
+
61
+
62
+ def _normalize_dims(dims, ndim: int):
63
+ dims = [dims] if isinstance(dims, int) else list(dims)
64
+ return [d if d >= 0 else ndim + d for d in dims]
65
+
66
+
67
+ class _MaybeIntermediateTensors:
68
+ """Duck-typed check to support your IntermediateTensors without importing."""
69
+
70
+ def __init__(self, obj):
71
+ self.is_intermediate = hasattr(obj, "tensors") and isinstance(
72
+ getattr(obj, "tensors"), dict
73
+ )
74
+ self.obj = obj
75
+
76
+
77
def _mark_dynamic_on_value(val, dims):
    """Mark ``dims`` as possibly dynamic on every tensor reachable from ``val``.

    A tensor is marked directly; an object with a dict-valued ``tensors``
    attribute has each of its values marked. Anything else, including
    ``None``, is ignored.
    """
    if isinstance(val, torch.Tensor):
        targets = [val]
    else:
        wrapped = _MaybeIntermediateTensors(val)
        if wrapped.is_intermediate:
            targets = wrapped.obj.tensors.values()
        else:
            targets = ()

    for tensor in targets:
        torch._dynamo.maybe_mark_dynamic(tensor, _normalize_dims(dims, tensor.ndim))
86
+
87
+
88
+ def _infer_dynamic_arg_dims_from_annotations(forward_fn):
89
+ sig = inspect.signature(forward_fn)
90
+ dyn = {}
91
+ for name, p in sig.parameters.items():
92
+ ann = p.annotation
93
+ # Accept torch.Tensor / Optional[torch.Tensor] / your IntermediateTensors types by name
94
+ if (
95
+ ann is torch.Tensor
96
+ or getattr(getattr(ann, "__args__", [None])[0], "__name__", "") == "Tensor"
97
+ ):
98
+ dyn[name] = 0
99
+ elif getattr(ann, "__name__", "") in ("IntermediateTensors",) or any(
100
+ getattr(a, "__name__", "") == "IntermediateTensors"
101
+ for a in getattr(ann, "__args__", [])
102
+ ):
103
+ dyn[name] = 0
104
+ elif ann == "torch.Tensor" or ann == "Optional[torch.Tensor]":
105
+ # For future import annotations (e.g. from __future__ import annotations), the annotation is a string
106
+ dyn[name] = 0
107
+ if not dyn:
108
+ raise ValueError("No dynamic dims inferred; pass dynamic_arg_dims explicitly.")
109
+ return dyn
110
+
111
+
112
def install_torch_compiled(
    module: torch.nn.Module,
    *,
    dynamic_arg_dims: dict[str, Union[int, list[int]]] | None = None,
    backend_factory: Optional[Callable[[torch.fx.GraphModule, list], Callable]] = None,
    compile_config: CompilationConfig = None,
    fullgraph: bool = True,
    graph_pool: Any = None,
):
    """Replace ``module.forward`` with a trampoline that can run compiled code.

    The trampoline calls a ``torch.compile``-d version of the class's
    ``forward`` while ``is_in_piecewise_cuda_graph()`` is true and the
    original ``forward`` otherwise. Compilation is deferred until the first
    call that takes the compiled path.

    Args:
        module: module whose ``forward`` is replaced on the instance.
        dynamic_arg_dims: parameter name to dim(s) to mark dynamic; inferred
            from the ``forward`` annotations when falsy.
        backend_factory: backend passed to ``torch.compile``; defaults to a
            fresh ``SGLangBackend(compile_config, graph_pool)`` per graph.
        compile_config: configuration handed to the default backend.
        fullgraph: forwarded to ``torch.compile``.
        graph_pool: graph pool handed to the default backend.

    Returns:
        ``module`` itself.

    Raises:
        TypeError: if the class-level ``forward`` is not callable.
    """
    rank0_log("install_torch_compiled")
    unbound_fwd = module.__class__.forward
    if not callable(unbound_fwd):
        raise TypeError("module.__class__.forward must be callable")
    original_code = unbound_fwd.__code__

    dyn_map = dynamic_arg_dims or _infer_dynamic_arg_dims_from_annotations(unbound_fwd)

    if backend_factory is None:
        from sglang.srt.compilation.backend import SGLangBackend

        def backend_factory(gm, ex):
            return SGLangBackend(compile_config, graph_pool)(gm, ex)

    compiled_codes = []
    state = {"compiled": False, "compiled_callable": None}

    def bytecode_hook(old_code, new_code):
        """Collect bytecode Dynamo generates for this module's forward."""
        if old_code is not original_code:
            return
        # Walk up the stack looking for Dynamo's `_compile` frame.
        frame = sys._getframe()
        while frame and frame.f_back:
            frame = frame.f_back
            code = frame.f_code
            if (
                code.co_name == "_compile"
                and os.path.basename(code.co_filename) == "convert_frame.py"
            ):
                break
        try:
            dynamo_frame = frame.f_locals["frame"]
        except Exception:
            # No such local on the frame we stopped at: nothing to record.
            return
        if dynamo_frame.f_code is not old_code:
            return
        if dynamo_frame.f_locals.get("self") is not module:
            return
        compiled_codes.append(new_code)

    torch._dynamo.convert_frame.register_bytecode_hook(bytecode_hook)

    def _ensure_compiled(self, *args, **kwargs):
        """Compile on first use (with flag ON)."""
        if state["compiled"]:
            return
        # Mark dynamic dims only when we are about to compile
        bound_args = inspect.signature(unbound_fwd).bind(self, *args, **kwargs)
        bound_args.apply_defaults()
        for arg_name, dims in dyn_map.items():
            value = bound_args.arguments.get(arg_name)
            if value is not None:
                _mark_dynamic_on_value(value, dims)

        # Avoid cross-instance cache reuse
        torch._dynamo.eval_frame.remove_from_cache(unbound_fwd.__code__)

        compiled_callable = torch.compile(
            types.MethodType(unbound_fwd, self),
            fullgraph=fullgraph,
            backend=backend_factory,
        )

        # Trigger Dynamo so bytecode hook can capture
        compiled_callable(*args, **kwargs)

        state["compiled"] = True
        state["compiled_callable"] = compiled_callable

    def trampoline(self, *args, **kwargs):
        if not is_in_piecewise_cuda_graph():
            # Explicitly run the original uncompiled forward
            return unbound_fwd(self, *args, **kwargs)

        if not state["compiled"]:
            _ensure_compiled(self, *args, **kwargs)
        return state["compiled_callable"](*args, **kwargs)

    module.forward = types.MethodType(trampoline, module)
    return module
sglang/python/sglang/srt/compilation/compiler_interface.py ADDED
@@ -0,0 +1,504 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compiler_interface.py
2
+
3
+ import contextlib
4
+ import copy
5
+ import hashlib
6
+ import os
7
+ from contextlib import ExitStack
8
+ from typing import Any, Callable, Optional
9
+ from unittest.mock import patch
10
+
11
+ import torch
12
+ import torch._inductor.compile_fx
13
+ import torch.fx as fx
14
+
15
+ from sglang.srt.compilation.compilation_counter import compilation_counter
16
+ from sglang.srt.compilation.inductor_pass import pass_context
17
+ from sglang.srt.utils.common import torch_release
18
+
19
+
20
class CompilerInterface:
    """Base interface for the compilers used by the compilation backend.

    The defaults are no-ops: no cache, an empty hash, no compiled result,
    and no support for loading from a handle.
    """

    # Name of the compiler, e.g. "inductor". Set at class level by subclasses.
    name: str

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ):
        """Point the compiler at ``cache_dir``.

        An implementation would typically redirect its own cache into a
        sub-directory of ``cache_dir``.

        ``prefix`` lets several separately compiled model parts share one
        base directory, e.g.::

            cache_dir = "/path/to/dir/backbone", prefix = "backbone"
            cache_dir = "/path/to/dir/eagle_head", prefix = "eagle_head"

        The base implementation does nothing.
        """
        return None

    def compute_hash(self) -> str:
        """Return a hash of the compiler-specific state that affects caching.

        Only information specific to this compiler belongs here. The base
        implementation returns the empty string.
        """
        return ""

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        runtime_shape: Optional[int] = None,
        key: Optional[str] = None,
    ) -> tuple[Optional[Callable], Optional[Any]]:
        """Compile ``graph`` for ``example_inputs`` under ``compiler_config``.

        ``runtime_shape=None`` means the inputs have a dynamic shape;
        otherwise it gives the one variable size shared by all inputs (the
        number of tokens in the batch). Dynamo guarantees that
        ``graph(*example_inputs)`` is valid.

        Returns:
            ``(compiled_callable, handle)``. The handle is a plain Python
            object, preferably a string or file path, that ``load`` can use
            later. It is ``None`` when caching is unsupported, and the
            callable is ``None`` too when compilation fails.

        ``key`` is required by adapters that store the artifact themselves;
        it names the location ``cache_dir/key``.

        The base implementation returns ``(None, None)``.
        """
        return (None, None)

    def load(
        self,
        handle: Any,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        runtime_shape: Optional[int] = None,
    ) -> Callable:
        """Load a compiled function from ``handle``.

        ``handle`` is the second value returned by ``compile``; an invalid
        handle is an error.

        Raises:
            NotImplementedError: always, in the base implementation.
        """
        raise NotImplementedError("caching is not supported")
108
+
109
+
110
+ def get_inductor_factors() -> list[Any]:
111
+ factors: list[Any] = []
112
+ # summarize system state
113
+ from torch._inductor.codecache import CacheBase
114
+
115
+ system_factors = CacheBase.get_system()
116
+ factors.append(system_factors)
117
+
118
+ # summarize pytorch state
119
+ from torch._inductor.codecache import torch_key
120
+
121
+ torch_factors = torch_key()
122
+ factors.append(torch_factors)
123
+ return factors
124
+
125
+
126
class AlwaysHitShapeEnv:
    """Stand-in shape environment whose guard checks always succeed.

    Ordinary ``torch.compile`` runs one Dynamo bytecode compilation with one
    Inductor compilation nested inside it, and Inductor takes its dynamic
    shape information from that surrounding context.

    Here Dynamo runs once while Inductor runs several times: once for the
    general shape and again for specific shapes. The shape-specific runs
    happen outside the Dynamo context, so no shape environment is available
    and the Inductor code cache lookup would fail.

    Supplying this dummy environment makes that lookup always hit. The
    methods below were found by trial and error.
    """

    def __init__(self) -> None:
        self.guards: list[Any] = []

    def evaluate_guards_expression(self, *args, **kwargs):
        """Report every guard expression as satisfied."""
        return True

    def get_pruned_guards(self, *args, **kwargs):
        """Report that there are no guards."""
        return []

    def produce_guards_expression(self, *args, **kwargs):
        """Return an empty guard expression."""
        return ""
162
+
163
+
164
class InductorAdaptor(CompilerInterface):
    """
    The adaptor for the Inductor compiler, version 2.5, 2.6, 2.7.

    `compile` runs Inductor while temporarily monkey-patching several
    Inductor internals so that the cache key (hash) and the generated
    source file path of the compiled graph can be captured and returned as
    a handle. `load` takes such a handle and looks the compiled graph up
    in Inductor's FX graph cache.
    """

    name = "inductor"

    def compute_hash(self) -> str:
        """Return the first 10 hex digits of an md5 over the Inductor factors."""
        factors = get_inductor_factors()
        hash_str = hashlib.md5(
            str(factors).encode(), usedforsecurity=False
        ).hexdigest()[:10]
        return hash_str

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ):
        """
        Record the cache locations and, unless caching is disabled, point
        Inductor and Triton at sub-directories of the base cache directory.

        :param cache_dir: cache directory for this adaptor.
        :param disable_cache: when True, only the attributes are recorded;
            no directory is created and no environment variable is set.
        :param prefix: when non-empty, ``len(prefix)`` trailing characters
            are stripped from ``cache_dir`` to obtain the base directory.
        """
        self.cache_dir = cache_dir
        self.prefix = prefix
        self.base_cache_dir = cache_dir[: -len(prefix)] if prefix else cache_dir
        if disable_cache:
            return
        # redirect the cache directory to a sub-directory
        # set flags so that Inductor and Triton store their cache
        # in the cache_dir, then users only need to copy the cache_dir
        # to another machine to reuse the cache.
        inductor_cache = os.path.join(self.base_cache_dir, "inductor_cache")
        os.makedirs(inductor_cache, exist_ok=True)
        os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
        triton_cache = os.path.join(self.base_cache_dir, "triton_cache")
        os.makedirs(triton_cache, exist_ok=True)
        os.environ["TRITON_CACHE_DIR"] = triton_cache

    def _locate_compiled_file(self, compiled_fn: Callable) -> str:
        """
        Return the source file of the code Inductor generated for
        ``compiled_fn``.

        If ``compiled_fn`` itself does not live under ``base_cache_dir`` it
        is assumed to be a wrapper (hooked in the
        align_inputs_from_check_idxs function in torch/_inductor/utils.py),
        so its closure is searched for a callable whose code does. When
        nothing matches, the file of ``compiled_fn`` itself is returned.
        """
        file_path = compiled_fn.__code__.co_filename
        if file_path.startswith(self.base_cache_dir):
            return file_path
        # `__closure__` is None for functions without free variables.
        for cell in compiled_fn.__closure__ or ():
            candidate = cell.cell_contents
            if not callable(candidate):
                continue
            # not every callable (e.g. builtins) carries a code object
            code = getattr(candidate, "__code__", None)
            if code is not None and code.co_filename.startswith(
                self.base_cache_dir
            ):
                # this is the real file path compiled from Inductor
                return code.co_filename
        return file_path

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        runtime_shape: Optional[int] = None,
        key: Optional[str] = None,
    ) -> tuple[Optional[Callable], Optional[Any]]:
        """
        Compile ``graph`` with Inductor.

        :param graph: the FX graph module to compile (a deep copy is
            handed to Inductor).
        :param example_inputs: example inputs passed on to ``compile_fx``.
        :param compiler_config: Inductor config patches; may be None.
        :param runtime_shape: forwarded to ``set_inductor_config`` and
            ``pass_context``.
        :param key: not used by this implementation.
        :return: ``(compiled_graph, (hash_str, file_path))`` where the
            second element is the handle later accepted by `load`.
        :raises RuntimeError: if the installed torch is older than 2.5.
        """
        compilation_counter.num_inductor_compiles += 1
        from torch._inductor.compile_fx import compile_fx

        current_config = {}
        if compiler_config is not None:
            current_config.update(compiler_config)

        # enable the local FX graph cache and disable the remote cache
        current_config["fx_graph_cache"] = True
        current_config["fx_graph_remote_cache"] = False

        set_inductor_config(current_config, runtime_shape)

        # inductor can inplace modify the graph, so we need to copy it
        # see https://github.com/pytorch/pytorch/issues/138980
        graph = copy.deepcopy(graph)

        # it's the first time we compile this graph
        # the assumption is that we don't have nested Inductor compilation.
        # compiled_fx_graph_hash will only be called once, and we can hook
        # it to get the hash of the compiled graph directly.

        hash_str, file_path = None, None
        from torch._inductor.codecache import FxGraphCache, compiled_fx_graph_hash

        if torch_release[:2] == (2, 5):
            original_load = FxGraphCache.load
            original_load_name = "torch._inductor.codecache.FxGraphCache.load"

            def hijack_load(*args, **kwargs):
                nonlocal file_path
                inductor_compiled_graph = original_load(*args, **kwargs)
                file_path = self._locate_compiled_file(
                    inductor_compiled_graph.current_callable
                )
                return inductor_compiled_graph

            hijacked_compile_fx_inner = (
                torch._inductor.compile_fx.compile_fx_inner
            )  # noqa
        elif torch_release >= (2, 6):
            # function renamed in 2.6
            original_load_name = None

            def hijacked_compile_fx_inner(*args, **kwargs):
                nonlocal hash_str, file_path
                output = torch._inductor.compile_fx.compile_fx_inner(*args, **kwargs)
                inductor_compiled_graph = output
                if inductor_compiled_graph is not None:
                    file_path = self._locate_compiled_file(
                        inductor_compiled_graph.current_callable
                    )
                    hash_str = inductor_compiled_graph._fx_graph_cache_key
                return output

        else:
            # Without this branch the code below would fail with an
            # unbound-variable error that hides the real cause.
            raise RuntimeError(
                f"InductorAdaptor requires torch>=2.5, got {torch_release}"
            )

        def hijack_compiled_fx_graph_hash(*args, **kwargs):
            out = compiled_fx_graph_hash(*args, **kwargs)
            nonlocal hash_str
            hash_str = out[0]
            return out

        def _check_can_cache(*args, **kwargs):
            # no error means it can be cached.
            # Inductor refuses to cache the graph outside of Dynamo
            # tracing context, and also disables caching for graphs
            # with high-order ops.
            # For vLLM, in either case, we want to cache the graph.
            # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa
            return

        def _get_shape_env() -> AlwaysHitShapeEnv:
            return AlwaysHitShapeEnv()

        with ExitStack() as stack:
            # hijack to get the compiled graph itself
            if original_load_name is not None:
                stack.enter_context(patch(original_load_name, hijack_load))

            # for hijacking the hash of the compiled graph
            stack.enter_context(
                patch(
                    "torch._inductor.codecache.compiled_fx_graph_hash",
                    hijack_compiled_fx_graph_hash,
                )
            )

            # for providing a dummy shape environment
            stack.enter_context(
                patch(
                    "torch._inductor.codecache.FxGraphCache._get_shape_env",
                    _get_shape_env,
                )
            )

            from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache

            # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
            if hasattr(AOTAutogradCache, "_get_shape_env"):
                stack.enter_context(
                    patch(
                        "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
                        _get_shape_env,
                    )
                )

            # for forcing the graph to be cached
            stack.enter_context(
                patch(
                    "torch._inductor.codecache.FxGraphCache._check_can_cache",
                    _check_can_cache,
                )
            )

            # Dynamo metrics context, see method for more details.
            stack.enter_context(self.metrics_context())

            # Disable remote caching. When these are on, on remote cache-hit,
            # the monkey-patched functions never actually get called.
            # vLLM today assumes and requires the monkey-patched functions to
            # get hit.
            # TODO(zou3519): we're going to replace this all with
            # standalone_compile sometime.

            stack.enter_context(
                torch._inductor.config.patch(fx_graph_remote_cache=False)
            )
            # InductorAdaptor (unfortunately) requires AOTAutogradCache
            # to be turned off to run. It will fail to acquire the hash_str
            # and error if not.
            # StandaloneInductorAdaptor (PyTorch 2.8+) fixes this problem.
            stack.enter_context(
                torch._functorch.config.patch(enable_autograd_cache=False)
            )
            stack.enter_context(
                torch._functorch.config.patch(enable_remote_autograd_cache=False)
            )

            with pass_context(runtime_shape):
                compiled_graph = compile_fx(
                    graph,
                    example_inputs,
                    inner_compile=hijacked_compile_fx_inner,
                    config_patches=current_config,
                )
        return compiled_graph, (hash_str, file_path)

    def load(
        self,
        handle: Any,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        runtime_shape: Optional[int] = None,
    ) -> Callable:
        """
        Load a previously compiled graph from Inductor's FX graph cache.

        :param handle: ``(hash_str, file_path)`` tuple of two strings as
            returned by `compile`; only the hash is used for the lookup.
        :param graph: the FX graph module the handle was produced for.
        :param example_inputs: example inputs passed to the cache lookup.
        :param graph_index: not used by this implementation.
        :param runtime_shape: not used by this implementation.
        :return: a callable following the Dynamo calling convention
            ``f(*args)``.
        :raises RuntimeError: if the installed torch is older than 2.5.
        """
        assert isinstance(handle, tuple)
        assert isinstance(handle[0], str)
        assert isinstance(handle[1], str)
        hash_str = handle[0]

        from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
        from torch._inductor.codecache import FxGraphCache

        # The original message was built from two adjacent literals with no
        # separating space ("Please removethe cache directory").
        lookup_failed_msg = (
            "Inductor cache lookup failed. Please remove "
            "the cache directory and try again."
        )

        with ExitStack() as exit_stack:
            exit_stack.enter_context(
                patch(
                    "torch._inductor.codecache.FxGraphCache._get_shape_env",
                    lambda *args, **kwargs: AlwaysHitShapeEnv(),
                )
            )
            # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
            if hasattr(AOTAutogradCache, "_get_shape_env"):
                exit_stack.enter_context(
                    patch(
                        "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
                        lambda *args, **kwargs: AlwaysHitShapeEnv(),
                    )
                )

            # Dynamo metrics context, see method for more details.
            exit_stack.enter_context(self.metrics_context())

            if torch_release[:2] == (2, 5):
                inductor_compiled_graph = FxGraphCache._lookup_graph(
                    hash_str, example_inputs, True, False
                )
                assert inductor_compiled_graph is not None, lookup_failed_msg
            elif torch_release >= (2, 6):
                from torch._inductor.output_code import CompiledFxGraphConstantsWithGm

                constants = CompiledFxGraphConstantsWithGm(graph)
                inductor_compiled_graph, _ = FxGraphCache._lookup_graph(
                    hash_str, example_inputs, True, None, constants
                )
                assert inductor_compiled_graph is not None, lookup_failed_msg
            else:
                # Otherwise `inductor_compiled_graph` stays unbound and the
                # returned callable would only fail when first invoked.
                raise RuntimeError(
                    f"InductorAdaptor requires torch>=2.5, got {torch_release}"
                )

        # Inductor calling convention (function signature):
        # f(list) -> tuple
        # Dynamo calling convention (function signature):
        # f(*args) -> Any

        # need to know if the graph returns a tuple
        from torch._inductor.compile_fx import graph_returns_tuple

        returns_tuple = graph_returns_tuple(graph)

        # this is the callable we return to Dynamo to run
        def compiled_graph(*args):
            # convert args to list
            list_args = list(args)
            graph_output = inductor_compiled_graph(list_args)
            # unpack the tuple if needed
            if returns_tuple:
                return graph_output
            else:
                return graph_output[0]

        return compiled_graph

    def metrics_context(self) -> contextlib.AbstractContextManager:
        """
        This method returns the Dynamo metrics context (if it exists,
        otherwise a null context). It is used by various compile components.
        Present in torch>=2.6, it's used inside FxGraphCache in
        torch==2.6 (but not after). It might also be used in various other
        torch.compile internal functions.

        Because it is re-entrant, we always set it (even if entering via Dynamo
        and the context was already entered). We might want to revisit if it
        should be set at a different level of compilation.

        This is likely a bug in PyTorch: public APIs should not rely on
        manually setting up internal contexts. But we also rely on non-public
        APIs which might not provide these guarantees.
        """
        import torch._dynamo.utils

        return torch._dynamo.utils.get_metrics_context()
471
+
472
+
473
def set_inductor_config(config, runtime_shape):
    """Enable Inductor autotuning in ``config`` for a concrete integer shape.

    ``config`` is modified in place; nothing happens when ``runtime_shape``
    is not an ``int``.
    """
    if not isinstance(runtime_shape, int):
        return
    # for a specific batchsize, tuning triton kernel parameters
    # can be beneficial
    for option in ("max_autotune", "coordinate_descent_tuning"):
        config[option] = True
479
+
480
+
481
class EagerAdapter(CompilerInterface):
    """Compiler interface that performs no compilation at all."""

    name = "eager"

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        runtime_shape: Optional[int] = None,
        key: Optional[str] = None,
        num_graphs: int = 1,
    ) -> tuple[Optional[Callable], Optional[Any]]:
        """Return ``graph`` itself as the callable, with ``None`` as handle."""
        handle = None
        return graph, handle

    def load(
        self,
        handle: Any,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        runtime_shape: Optional[int] = None,
        num_graphs: int = 1,
    ) -> Callable:
        """Always raise ``NotImplementedError``."""
        raise NotImplementedError("eager compilation is not supported")
sglang/python/sglang/srt/compilation/cuda_piecewise_backend.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/cuda_piecewise_backend.py
2
+
3
+ import dataclasses
4
+ import logging
5
+ from contextlib import ExitStack
6
+ from typing import Any, Callable, Optional
7
+ from unittest.mock import patch
8
+
9
+ import torch
10
+ import torch.fx as fx
11
+
12
+ from sglang.srt.compilation.compilation_config import CompilationConfig
13
+ from sglang.srt.compilation.compilation_counter import compilation_counter
14
+ from sglang.srt.compilation.piecewise_context_manager import (
15
+ get_pcg_capture_stream,
16
+ is_in_pcg_torch_compile,
17
+ )
18
+ from sglang.srt.compilation.weak_ref_tensor import weak_ref_tensors
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
@dataclasses.dataclass
class ConcreteSizeEntry:
    """Per-shape bookkeeping for compilation and cudagraph capture."""

    runtime_shape: int
    # the size is in compile_sizes
    need_to_compile: bool
    # the size is in cudagraph_capture_sizes
    use_cudagraph: bool

    compiled: bool = dataclasses.field(default=False)
    runnable: Callable = dataclasses.field(default=None)  # type: ignore
    num_finished_warmup: int = dataclasses.field(default=0)
    cudagraph: Optional[torch.cuda.CUDAGraph] = dataclasses.field(default=None)
    output: Optional[Any] = dataclasses.field(default=None)

    # for cudagraph debugging, track the input addresses
    # during capture, and check if they are the same during replay
    input_addresses: Optional[list[int]] = dataclasses.field(default=None)
38
+
39
+
40
class CUDAPiecewiseBackend:
    """Callable wrapper around one piece of a piecewise-compiled graph."""

    def __init__(
        self,
        graph: fx.GraphModule,
        compile_config: CompilationConfig,
        inductor_config: dict[str, Any],
        graph_pool: Any,
        piecewise_compile_index: int,
        total_piecewise_compiles: int,
        sym_shape_indices: list[int],
        compiled_graph_for_general_shape: Callable,
        sglang_backend,
    ):
        """
        The backend for piecewise compilation.
        It mainly handles the compilation and cudagraph capturing.

        We will compile `self.graph` once for the general shape,
        and then compile for different shapes specified in
        `compilation_config.compile_sizes`.

        Independently, we will capture cudagraph for different shapes.

        If a shape needs both compilation and cudagraph, we will
        compile it first, and then capture cudagraph.
        """
        self.graph = graph
        self.inductor_config = inductor_config
        self.graph_pool = graph_pool
        self.sglang_backend = sglang_backend
        self.compile_config = compile_config
        self.sym_shape_indices = sym_shape_indices
        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape  # noqa

        # position of this piece among all pieces
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles
        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1

        self.first_run_finished = False

        # no shape-specific compile sizes are configured here; only the
        # cudagraph capture sizes come from the compile config
        self.compile_sizes: set[int] = set()
        self.cudagraph_capture_sizes: set[int] = set(compile_config.get_capture_sizes())

        # to_be_compiled_sizes tracks the remaining sizes to compile,
        # and updates during the compilation process, so we need to copy it
        self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()

        # the entries for different shapes that we need to either
        # compile or capture cudagraph
        self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
        all_sizes = self.compile_sizes.union(self.cudagraph_capture_sizes)
        for size in all_sizes:
            self.concrete_size_entries[size] = ConcreteSizeEntry(
                runtime_shape=size,
                need_to_compile=size in self.compile_sizes,
                use_cudagraph=size in self.cudagraph_capture_sizes,
            )

    @staticmethod
    def _tensor_addresses(args) -> list[int]:
        """Return ``data_ptr()`` of every ``torch.Tensor`` found in ``args``."""
        return [x.data_ptr() for x in args if isinstance(x, torch.Tensor)]

    def check_for_ending_compilation(self):
        """Save the compiler manager state once the last piece has nothing
        left to compile."""
        if not self.is_last_graph or self.to_be_compiled_sizes:
            return
        # no specific sizes to compile
        # save the hash of the inductor graph for the next run
        self.sglang_backend.compiler_manager.save_to_file()

    def __call__(self, *args) -> Any:
        """Run this piece for ``args``: general-shape graph, shape-specific
        graph, cudagraph capture, or cudagraph replay."""
        if not self.first_run_finished:
            self.first_run_finished = True
            self.check_for_ending_compilation()
            return self.compiled_graph_for_general_shape(*args)

        if not self.sym_shape_indices:
            return self.compiled_graph_for_general_shape(*args)

        runtime_shape = args[self.sym_shape_indices[0]]
        entry = self.concrete_size_entries.get(runtime_shape)
        if entry is None:
            # we don't need to do anything for this shape
            return self.compiled_graph_for_general_shape(*args)

        if entry.runnable is None:
            entry.runnable = self.compiled_graph_for_general_shape

        if entry.need_to_compile and not entry.compiled:
            entry.compiled = True
            self.to_be_compiled_sizes.remove(runtime_shape)
            # args are real arguments
            entry.runnable = self.sglang_backend.compiler_manager.compile(
                self.graph,
                args,
                self.inductor_config,
                graph_index=self.piecewise_compile_index,
                num_graphs=self.total_piecewise_compiles,
                runtime_shape=runtime_shape,
            )

            # finished compilations for all required shapes
            if self.is_last_graph and not self.to_be_compiled_sizes:
                self.check_for_ending_compilation()

        if is_in_pcg_torch_compile():
            return entry.runnable(*args)

        debug_mode = self.compile_config.get_enable_debug_mode

        if entry.cudagraph is None:
            # one eager warmup run before the capture
            if entry.num_finished_warmup < 1:  # noqa
                entry.num_finished_warmup += 1
                return entry.runnable(*args)

            if debug_mode():
                entry.input_addresses = self._tensor_addresses(args)
            cudagraph = torch.cuda.CUDAGraph()

            with ExitStack() as stack:
                if not self.is_first_graph:
                    # during every model forward, we will capture
                    # many pieces of cudagraphs (roughly one per layer).
                    # running gc again and again across layers will
                    # make the cudagraph capture very slow.
                    # therefore, we only run gc for the first graph,
                    # and disable gc for the rest of the graphs.
                    stack.enter_context(patch("gc.collect", lambda: None))
                    stack.enter_context(patch("torch.cuda.empty_cache", lambda: None))
                # mind-exploding: carefully manage the reference and memory.
                stream = get_pcg_capture_stream()
                assert (
                    stream is not None
                ), "PCG capture stream is not set, please check if runtime recompilation happened"
                with torch.cuda.graph(cudagraph, pool=self.graph_pool, stream=stream):
                    # `output` is managed by pytorch's cudagraph pool
                    output = entry.runnable(*args)
                    if self.is_last_graph:
                        # by converting it to weak ref,
                        # the original `output` will immediately be released
                        # to save memory. It is only safe to do this for
                        # the last graph, because the output of the last graph
                        # will not be used by any other cuda graph.
                        output = weak_ref_tensors(output)

            # here we always use weak ref for the output
            # to save memory
            entry.output = weak_ref_tensors(output)
            entry.cudagraph = cudagraph

            compilation_counter.num_cudagraph_captured += 1

            # important: we need to return the output, rather than
            # the weak ref of the output, so that pytorch can correctly
            # manage the memory during cuda graph capture
            return output

        if debug_mode():
            # check if the input addresses are the same
            new_input_addresses = self._tensor_addresses(args)
            assert new_input_addresses == entry.input_addresses, (
                "Input addresses for cudagraphs are different during replay."
                f" Expected {entry.input_addresses}, got {new_input_addresses}"
            )
        entry.cudagraph.replay()
        return entry.output
sglang/python/sglang/srt/compilation/fix_functionalization.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/fix_functionalization.py
2
+
3
+ import logging
4
+ import operator
5
+ from collections.abc import Iterable
6
+ from typing import Optional, Union
7
+
8
+ import torch
9
+ from torch._higher_order_ops.auto_functionalize import auto_functionalized
10
+
11
+ from sglang.srt.compilation.fx_utils import is_func
12
+ from sglang.srt.compilation.inductor_pass import SGLangInductorPass
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
class FixFunctionalizationPass(SGLangInductorPass):
    """
    This pass defunctionalizes certain nodes to avoid redundant tensor copies.
    After this pass, DCE (dead-code elimination) should never be run,
    as de-functionalized nodes may appear as dead code.

    To add new nodes to defunctionalize, add to the if-elif chain in __call__.
    """

    def __call__(self, graph: torch.fx.Graph):
        """Run the pass over ``graph`` and erase all staged nodes."""
        self.begin()
        self.dump_graph(graph, "before_fix_functionalization")

        self.nodes_to_remove: list[torch.fx.Node] = []
        # every auto_functionalized node is counted
        count = sum(1 for node in graph.nodes if is_func(node, auto_functionalized))

        self.dump_graph(graph, "before_fix_functionalization_cleanup")

        # Remove the nodes all at once
        count_removed = len(self.nodes_to_remove)
        for stale in self.nodes_to_remove:
            graph.erase_node(stale)

        logger.debug(
            "De-functionalized %s nodes, removed %s nodes", count, count_removed
        )
        self.dump_graph(graph, "after_fix_functionalization")
        self.end_and_log()

    def _remove(self, node_or_nodes: Union[torch.fx.Node, Iterable[torch.fx.Node]]):
        """
        Stage a node (or nodes) for removal at the end of the pass.
        """
        if not isinstance(node_or_nodes, torch.fx.Node):
            self.nodes_to_remove.extend(node_or_nodes)
            return
        self.nodes_to_remove.append(node_or_nodes)

    def defunctionalize(
        self,
        graph: torch.fx.Graph,
        node: torch.fx.Node,
        mutated_args: dict[int, Union[torch.fx.Node, str]],
        args: Optional[tuple[Union[torch.fx.Node, str], ...]] = None,
    ):
        """
        De-functionalize a node by replacing it with a call to the original.
        It also replaces the getitem users with the mutated arguments.
        See replace_users_with_mutated_args and insert_defunctionalized.
        """
        self.replace_users_with_mutated_args(node, mutated_args)
        self.insert_defunctionalized(graph, node, args=args)
        self._remove(node)

    def replace_users_with_mutated_args(
        self, node: torch.fx.Node, mutated_args: dict[int, Union[torch.fx.Node, str]]
    ):
        """
        Replace all getitem users of the auto-functionalized node with the
        mutated arguments.
        :param node: The auto-functionalized node
        :param mutated_args: The mutated arguments, indexed by getitem index.
            If the value of an arg is a string, `node.kwargs[arg]` is used.
        """
        for idx, getitem_node in self.getitem_users(node).items():
            replacement = mutated_args[idx]
            if isinstance(replacement, str):
                replacement = node.kwargs[replacement]
            getitem_node.replace_all_uses_with(replacement)
            self._remove(getitem_node)

    def getitem_users(self, node: torch.fx.Node) -> dict[int, torch.fx.Node]:
        """
        Returns the operator.getitem users of the auto-functionalized node,
        indexed by the index they are getting.
        """
        return {
            user.args[1]: user
            for user in node.users
            if is_func(user, operator.getitem)
        }

    def insert_defunctionalized(
        self,
        graph: torch.fx.Graph,
        node: torch.fx.Node,
        args: Optional[tuple[Union[torch.fx.Node, str], ...]] = None,
    ):
        """
        Insert a new defunctionalized node into the graph before node.
        If one of the kwargs is 'out', provide args directly,
        as node.kwargs cannot be used.
        See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351

        :param graph: Graph to insert the defunctionalized node into
        :param node: The auto-functionalized node to defunctionalize
        :param args: If we cannot use kwargs, specify args directly.
            If an arg is a string, `node.kwargs[arg]` is used.
        """  # noqa: E501
        assert is_func(
            node, auto_functionalized
        ), f"node must be auto-functionalized, is {node} instead"

        # Create a new call to the original function
        with graph.inserting_before(node):
            function = node.args[0]
            if args is not None:
                # Args passed as strings refer to items in node.kwargs
                resolved = tuple(
                    node.kwargs[arg] if isinstance(arg, str) else arg for arg in args
                )
                graph.call_function(function, args=resolved)
            else:
                graph.call_function(function, kwargs=node.kwargs)
sglang/python/sglang/srt/compilation/fx_utils.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/fx_utils.py
2
+
3
+ import operator
4
+ from collections.abc import Iterable, Iterator
5
+ from typing import Optional
6
+
7
+ from torch import fx
8
+ from torch._higher_order_ops.auto_functionalize import auto_functionalized
9
+ from torch._ops import OpOverload
10
+
11
+
12
def is_func(node: fx.Node, target) -> bool:
    """Tell whether ``node`` is a ``call_function`` node whose target equals ``target``."""
    if node.op != "call_function":
        return False
    return node.target == target
14
+
15
+
16
def is_auto_func(node: fx.Node, op: OpOverload) -> bool:
    """Tell whether ``node`` is an ``auto_functionalized`` call whose first
    positional argument equals ``op``."""
    if not is_func(node, auto_functionalized):
        return False
    return node.args[0] == op
18
+
19
+
20
def find_specified_fn_maybe(
    nodes: Iterable[fx.Node], op: OpOverload
) -> Optional[fx.Node]:
    """Return the first node in ``nodes`` whose target equals ``op``, or
    ``None`` when there is no such node."""
    matches = (candidate for candidate in nodes if candidate.target == op)
    return next(matches, None)
28
+
29
+
30
def find_specified_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
    """Return the first node in ``nodes`` whose target equals ``op``.

    Fails an assertion when there is no such node.
    """
    found = find_specified_fn_maybe(nodes, op)
    assert found is not None, f"Could not find {op} in nodes {nodes}"
    return found
35
+
36
+
37
def find_auto_fn_maybe(nodes: Iterable[fx.Node], op: OpOverload) -> Optional[fx.Node]:
    """Return the first ``auto_functionalized`` node in ``nodes`` that wraps
    ``op``, or ``None`` when there is no such node."""
    wrapped = (
        candidate
        for candidate in nodes
        if is_func(candidate, auto_functionalized) and candidate.args[0] == op
    )
    return next(wrapped, None)
43
+
44
+
45
def find_auto_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
    """Return the first ``auto_functionalized`` node in ``nodes`` that wraps ``op``.

    Fails an assertion when there is no such node.
    """
    found = find_auto_fn_maybe(nodes, op)
    assert found is not None, f"Could not find {op} in nodes {nodes}"
    return found
50
+
51
+
52
def find_getitem_maybe(node: fx.Node, idx: int) -> Optional[fx.Node]:
    """Return the ``operator.getitem`` user of ``node`` that extracts element
    ``idx``, or ``None`` when there is no such user."""
    getters = (
        user
        for user in node.users
        if is_func(user, operator.getitem) and user.args[1] == idx
    )
    return next(getters, None)
59
+
60
+
61
def find_getitem(node: fx.Node, idx: int) -> fx.Node:
    """Return the ``operator.getitem`` user of ``node`` that extracts element ``idx``.

    Fails an assertion when there is no such user.
    """
    getter = find_getitem_maybe(node, idx)
    assert getter is not None, f"Could not find getitem {idx} in node {node}"
    return getter
66
+
67
+
68
+ # An auto-functionalization-aware utility for finding nodes with a specific op
69
+ def find_op_nodes(op: OpOverload, graph: fx.Graph) -> Iterator[fx.Node]:
70
+ if not op._schema.is_mutable:
71
+ yield from graph.find_nodes(op="call_function", target=op)
72
+
73
+ for n in graph.find_nodes(op="call_function", target=auto_functionalized):
74
+ if n.args[0] == op:
75
+ yield n
76
+
77
+
78
def get_only_user(node: fx.Node) -> fx.Node:
    """Assert that ``node`` has exactly one user and return that user.

    Even if a node has only 1 user, it might share storage with another node,
    which might need to be taken into account.
    """
    users = node.users
    assert len(users) == 1
    for sole_user in users:
        return sole_user