Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- sglang/3rdparty/amd/profiling/PROFILING.md +425 -0
- sglang/3rdparty/amd/profiling/client.sh +27 -0
- sglang/3rdparty/amd/profiling/install_rpd.sh +10 -0
- sglang/3rdparty/amd/profiling/loadTracer.sh +43 -0
- sglang/3rdparty/amd/profiling/rpd.patch +12 -0
- sglang/3rdparty/amd/profiling/rpd_profile_server_enable.patch +49 -0
- sglang/3rdparty/amd/profiling/rpd_profile_server_enable_wCPU_activities.patch +126 -0
- sglang/3rdparty/amd/profiling/server.sh +20 -0
- sglang/3rdparty/amd/profiling/torch_profiler.patch +25 -0
- sglang/3rdparty/amd/sgl-kernel/CMakeLists_rocm.txt +159 -0
- sglang/3rdparty/amd/sgl-kernel/build_rocm.sh +123 -0
- sglang/3rdparty/amd/sgl-kernel/rename_wheels_rocm.sh +30 -0
- sglang/3rdparty/amd/sgl-kernel/rocm_hipify.py +40 -0
- sglang/3rdparty/amd/tuning/TUNING.md +118 -0
- sglang/3rdparty/amd/tuning/benchmark_moe_rocm.py +378 -0
- sglang/docs/supported_models/extending/modelscope.md +28 -0
- sglang/docs/supported_models/extending/support_new_models.md +320 -0
- sglang/docs/supported_models/retrieval_ranking/classify_models.md +162 -0
- sglang/docs/supported_models/retrieval_ranking/embedding_models.md +126 -0
- sglang/docs/supported_models/retrieval_ranking/rerank_models.md +313 -0
- sglang/docs/supported_models/specialized/index.rst +9 -0
- sglang/docs/supported_models/specialized/reward_models.md +28 -0
- sglang/docs/supported_models/text_generation/diffusion_language_models.md +111 -0
- sglang/docs/supported_models/text_generation/generative_models.md +72 -0
- sglang/docs/supported_models/text_generation/index.rst +11 -0
- sglang/docs/supported_models/text_generation/multimodal_language_models.md +136 -0
- sglang/python/sglang/srt/__pycache__/constants.cpython-311.pyc +0 -0
- sglang/python/sglang/srt/__pycache__/environ.cpython-311.pyc +0 -0
- sglang/python/sglang/srt/batch_overlap/__pycache__/operations.cpython-311.pyc +0 -0
- sglang/python/sglang/srt/batch_overlap/__pycache__/operations_strategy.cpython-311.pyc +0 -0
- sglang/python/sglang/srt/batch_overlap/__pycache__/single_batch_overlap.cpython-311.pyc +0 -0
- sglang/python/sglang/srt/batch_overlap/__pycache__/two_batch_overlap.cpython-311.pyc +0 -0
- sglang/python/sglang/srt/batch_overlap/operations.py +213 -0
- sglang/python/sglang/srt/batch_overlap/operations_strategy.py +302 -0
- sglang/python/sglang/srt/batch_overlap/single_batch_overlap.py +144 -0
- sglang/python/sglang/srt/batch_overlap/two_batch_overlap.py +1082 -0
- sglang/python/sglang/srt/checkpoint_engine/__init__.py +9 -0
- sglang/python/sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +143 -0
- sglang/python/sglang/srt/checkpoint_engine/update.py +317 -0
- sglang/python/sglang/srt/compilation/__pycache__/compilation_config.cpython-311.pyc +0 -0
- sglang/python/sglang/srt/compilation/__pycache__/compile.cpython-311.pyc +0 -0
- sglang/python/sglang/srt/compilation/__pycache__/piecewise_context_manager.cpython-311.pyc +0 -0
- sglang/python/sglang/srt/compilation/backend.py +472 -0
- sglang/python/sglang/srt/compilation/compilation_config.py +45 -0
- sglang/python/sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/python/sglang/srt/compilation/compile.py +203 -0
- sglang/python/sglang/srt/compilation/compiler_interface.py +504 -0
- sglang/python/sglang/srt/compilation/cuda_piecewise_backend.py +206 -0
- sglang/python/sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/python/sglang/srt/compilation/fx_utils.py +83 -0
sglang/3rdparty/amd/profiling/PROFILING.md
ADDED
|
@@ -0,0 +1,425 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Profiling SGLang Infer System with AMD GPUs
|
| 2 |
+
This AppNote describes the SGLang profiling technical, code augment and running steps for systems with AMD Instinct GPUs, nevertheless the same procedure may work with Nvidia GPUs too.
|
| 3 |
+
Examples and steps are provided in detail, to facilitate easy reproduce and use to localize performance problem towards optimizations.
|
| 4 |
+
Two primary methods are covered:
|
| 5 |
+
- [RPD](https://github.com/ROCm/rocmProfileData.git)
|
| 6 |
+
- [PyTorch Profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html)
|
| 7 |
+
|
| 8 |
+
### Profiling SGLang Infer System with RPD Profiler
|
| 9 |
+
RPD profiler is a low-overhead cross-platform profiler. Therefore, the same RPD code augment not only works for profiling on ROCm/AMD GPUs, but also works for profiling on CUDA/Nvidia GPUs as well. To do RPD profiling on SGLang repository, please use scripts and patch files included in this directory and follow the steps below:
|
| 10 |
+
1. Install RPD with rpd.patch applied during installation using install_rpd.sh, both files are in this directory.
|
| 11 |
+
|
| 12 |
+
install_rpd.sh
|
| 13 |
+
|
| 14 |
+
```bash
|
| 15 |
+
# download and install RPD
|
| 16 |
+
apt update && apt install -y sqlite3 libsqlite3-dev libfmt-dev
|
| 17 |
+
|
| 18 |
+
# install rpd module
|
| 19 |
+
git clone https://github.com/ROCmSoftwarePlatform/rocmProfileData
|
| 20 |
+
cd rocmProfileData
|
| 21 |
+
git checkout 976899e9c6dbc6dd2bccf770818e4e44125590ac
|
| 22 |
+
git apply rpd.patch
|
| 23 |
+
make && make install
|
| 24 |
+
cd rocpd_python && python setup.py install && cd ..
|
| 25 |
+
cd rpd_tracer && make clean;make install && python setup.py install && cd ..
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
rpd.patch
|
| 29 |
+
|
| 30 |
+
```bash
|
| 31 |
+
diff --git a/rpd_tracer/Makefile b/rpd_tracer/Makefile
|
| 32 |
+
index e9d9feb..b2e9e1a 100644
|
| 33 |
+
--- a/rpd_tracer/Makefile
|
| 34 |
+
+++ b/rpd_tracer/Makefile
|
| 35 |
+
@@ -16,7 +16,7 @@ ifneq (,$(HIP_PATH))
|
| 36 |
+
$(info Building with roctracer)
|
| 37 |
+
RPD_LIBS += -L/opt/rocm/lib -lroctracer64 -lroctx64 -lamdhip64 -lrocm_smi64
|
| 38 |
+
RPD_INCLUDES += -I/opt/rocm/include -I/opt/rocm/include/roctracer -I/opt/rocm/include/hsa
|
| 39 |
+
- RPD_SRCS += RoctracerDataSource.cpp RocmSmiDataSource.cpp
|
| 40 |
+
+ RPD_SRCS += RoctracerDataSource.cpp
|
| 41 |
+
RPD_INCLUDES += -D__HIP_PLATFORM_AMD__
|
| 42 |
+
endif
|
| 43 |
+
```
|
| 44 |
+
2. Add loadTracer.sh file included in this directory to /sglang/python/sglang.
|
| 45 |
+
|
| 46 |
+
loadTracer.sh
|
| 47 |
+
|
| 48 |
+
```bash
|
| 49 |
+
#!/bin/bash
|
| 50 |
+
################################################################################
|
| 51 |
+
# Copyright (c) 2021 - 2023 Advanced Micro Devices, Inc. All rights reserved.
|
| 52 |
+
#
|
| 53 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 54 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 55 |
+
# in the Software without restriction, including without limitation the rights
|
| 56 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 57 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 58 |
+
# furnished to do so, subject to the following conditions:
|
| 59 |
+
#
|
| 60 |
+
# The above copyright notice and this permission notice shall be included in
|
| 61 |
+
# all copies or substantial portions of the Software.
|
| 62 |
+
#
|
| 63 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 64 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 65 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 66 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 67 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 68 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
| 69 |
+
# THE SOFTWARE.
|
| 70 |
+
################################################################################
|
| 71 |
+
OUTPUT_FILE="trace.rpd"
|
| 72 |
+
|
| 73 |
+
if [ "$1" = "-o" ] ; then
|
| 74 |
+
OUTPUT_FILE=$2
|
| 75 |
+
shift
|
| 76 |
+
shift
|
| 77 |
+
fi
|
| 78 |
+
|
| 79 |
+
if [ -e ${OUTPUT_FILE} ] ; then
|
| 80 |
+
rm ${OUTPUT_FILE}
|
| 81 |
+
fi
|
| 82 |
+
|
| 83 |
+
python3 -m rocpd.schema --create ${OUTPUT_FILE}
|
| 84 |
+
if [ $? != 0 ] ; then
|
| 85 |
+
echo "Error: Could not create rpd file. Please run 'python setup.py install' from the rocpd_python dir"
|
| 86 |
+
exit
|
| 87 |
+
fi
|
| 88 |
+
|
| 89 |
+
export RPDT_FILENAME=${OUTPUT_FILE}
|
| 90 |
+
export RPDT_AUTOSTART=0
|
| 91 |
+
LD_PRELOAD=librocm-smi_64:librpd_tracer.so "$@"
|
| 92 |
+
```
|
| 93 |
+
3. Apply patch (provided in this directory) with "git apply rpd_profile_server_enable.patch" if the main profiling purpose is to get info on gpu kernels as well as limited cpu activity info.
|
| 94 |
+
|
| 95 |
+
#### Common Notes 1
|
| 96 |
+
Please note that although we are doing TP=8 in the example, we purposely only log RPD profiling on 2 ranks in the patch file (i.e.tp_rank=0/1) for profiling/visualization convenience, as even Perfetto streaming mode can only load maximal 8GB json file for visualization. With 2 ranks logged in RPD profiling, we could still check whether there are issues among ranks (e.g. load imbalance issue, nccl issue), and at the same time, we could log relatively longer time duration before the json file generated from RPD file hits 8GB size.
|
| 97 |
+
|
| 98 |
+
rpd_profile_server_enable.patch
|
| 99 |
+
|
| 100 |
+
```bash
|
| 101 |
+
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
|
| 102 |
+
index 62d1ff9..9021c01 100644
|
| 103 |
+
--- a/python/sglang/srt/managers/scheduler.py
|
| 104 |
+
+++ b/python/sglang/srt/managers/scheduler.py
|
| 105 |
+
@@ -71,6 +71,8 @@ from sglang.srt.utils import (
|
| 106 |
+
suppress_other_loggers,
|
| 107 |
+
)
|
| 108 |
+
from sglang.utils import get_exception_traceback
|
| 109 |
+
+from rpdTracerControl import rpdTracerControl
|
| 110 |
+
+rpdTracerControl.skipCreate()
|
| 111 |
+
|
| 112 |
+
logger = logging.getLogger(__name__)
|
| 113 |
+
|
| 114 |
+
@@ -245,6 +247,7 @@ class Scheduler:
|
| 115 |
+
],
|
| 116 |
+
with_stack=True,
|
| 117 |
+
)
|
| 118 |
+
+ self.rpd = rpdTracerControl()
|
| 119 |
+
|
| 120 |
+
@torch.inference_mode()
|
| 121 |
+
def event_loop(self):
|
| 122 |
+
@@ -1027,15 +1030,24 @@ class Scheduler:
|
| 123 |
+
def start_profile(self) -> None:
|
| 124 |
+
if self.profiler is None:
|
| 125 |
+
raise RuntimeError("Profiler is not enabled.")
|
| 126 |
+
- self.profiler.start()
|
| 127 |
+
+ #self.profiler.start() #block pytorch profiler for rpd profiler enabling
|
| 128 |
+
+ if self.tp_rank == 0 or self.tp_rank == 1:
|
| 129 |
+
+ self.rpd.start()
|
| 130 |
+
+ self.rpd.rangePush("", "rpd profile range", "")
|
| 131 |
+
+ logger.info("rpd is enabled")
|
| 132 |
+
|
| 133 |
+
def stop_profile(self) -> None:
|
| 134 |
+
if self.profiler is None:
|
| 135 |
+
raise RuntimeError("Profiler is not enabled.")
|
| 136 |
+
- self.profiler.stop()
|
| 137 |
+
- self.profiler.export_chrome_trace(
|
| 138 |
+
- self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
|
| 139 |
+
- )
|
| 140 |
+
+ #self.profiler.stop()
|
| 141 |
+
+ #self.profiler.export_chrome_trace(
|
| 142 |
+
+ # self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
|
| 143 |
+
+ #)
|
| 144 |
+
+ if self.tp_rank ==0 or self.tp_rank ==1:
|
| 145 |
+
+ self.rpd.rangePop()
|
| 146 |
+
+ self.rpd.stop()
|
| 147 |
+
+ self.rpd.flush()
|
| 148 |
+
+ logger.info("rpd is done")
|
| 149 |
+
logger.info("Profiler is done")
|
| 150 |
+
```
|
| 151 |
+
|
| 152 |
+
#### Advanced Debugging with RPD Profiler
|
| 153 |
+
Sometimes, we want to use rpd profiler to capture more CPU and python activities in order to debug some challenging issues (e.g. root cause of load imbalance across gpu processes, root cause of bubbles, etc). Only in such cases, we need to apply patch "git apply rpd_profile_server_enable_wCPU_activities.patch", where 3 files are modified.
|
| 154 |
+
|
| 155 |
+
rpd_profile_server_enable_wCPU_activities.patch
|
| 156 |
+
|
| 157 |
+
```bash
|
| 158 |
+
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
|
| 159 |
+
index 62d1ff9..2edb427 100644
|
| 160 |
+
--- a/python/sglang/srt/managers/scheduler.py
|
| 161 |
+
+++ b/python/sglang/srt/managers/scheduler.py
|
| 162 |
+
@@ -71,6 +71,8 @@ from sglang.srt.utils import (
|
| 163 |
+
suppress_other_loggers,
|
| 164 |
+
)
|
| 165 |
+
from sglang.utils import get_exception_traceback
|
| 166 |
+
+from rpdTracerControl import rpdTracerControl
|
| 167 |
+
+rpdTracerControl.skipCreate()
|
| 168 |
+
|
| 169 |
+
logger = logging.getLogger(__name__)
|
| 170 |
+
|
| 171 |
+
@@ -245,6 +247,7 @@ class Scheduler:
|
| 172 |
+
],
|
| 173 |
+
with_stack=True,
|
| 174 |
+
)
|
| 175 |
+
+ self.rpd = rpdTracerControl()
|
| 176 |
+
|
| 177 |
+
@torch.inference_mode()
|
| 178 |
+
def event_loop(self):
|
| 179 |
+
@@ -1027,15 +1030,26 @@ class Scheduler:
|
| 180 |
+
def start_profile(self) -> None:
|
| 181 |
+
if self.profiler is None:
|
| 182 |
+
raise RuntimeError("Profiler is not enabled.")
|
| 183 |
+
- self.profiler.start()
|
| 184 |
+
+ #self.profiler.start()
|
| 185 |
+
+ logger.info("torch profiler is disabled")
|
| 186 |
+
+ if self.tp_rank == 0 or self.tp_rank == 1:
|
| 187 |
+
+ self.rpd.setPythonTrace(True)
|
| 188 |
+
+ self.rpd.start()
|
| 189 |
+
+ self.rpd.rangePush("", "scheduler", "")
|
| 190 |
+
+ logger.info("rpd is enabled inside scheduler profiling")
|
| 191 |
+
|
| 192 |
+
def stop_profile(self) -> None:
|
| 193 |
+
if self.profiler is None:
|
| 194 |
+
raise RuntimeError("Profiler is not enabled.")
|
| 195 |
+
- self.profiler.stop()
|
| 196 |
+
- self.profiler.export_chrome_trace(
|
| 197 |
+
- self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
|
| 198 |
+
- )
|
| 199 |
+
+ #self.profiler.stop()
|
| 200 |
+
+ #self.profiler.export_chrome_trace(
|
| 201 |
+
+ # self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
|
| 202 |
+
+ #)
|
| 203 |
+
+ if self.tp_rank ==0 or self.tp_rank ==1:
|
| 204 |
+
+ self.rpd.rangePop()
|
| 205 |
+
+ self.rpd.stop()
|
| 206 |
+
+ self.rpd.flush()
|
| 207 |
+
+ logger.info("rpd is done inside scheduler")
|
| 208 |
+
logger.info("Profiler is done")
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
|
| 212 |
+
index 2621ccd..181df85 100644
|
| 213 |
+
--- a/python/sglang/srt/managers/tokenizer_manager.py
|
| 214 |
+
+++ b/python/sglang/srt/managers/tokenizer_manager.py
|
| 215 |
+
@@ -58,6 +58,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams
|
| 216 |
+
from sglang.srt.server_args import PortArgs, ServerArgs
|
| 217 |
+
from sglang.srt.utils import is_generation_model, is_multimodal_model
|
| 218 |
+
|
| 219 |
+
+from rpdTracerControl import rpdTracerControl
|
| 220 |
+
+rpdTracerControl.skipCreate()
|
| 221 |
+
+
|
| 222 |
+
+
|
| 223 |
+
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
| 224 |
+
|
| 225 |
+
logger = logging.getLogger(__name__)
|
| 226 |
+
@@ -514,10 +518,20 @@ class TokenizerManager:
|
| 227 |
+
self.send_to_scheduler.send_pyobj(req)
|
| 228 |
+
|
| 229 |
+
def start_profile(self):
|
| 230 |
+
+ rpd = rpdTracerControl()
|
| 231 |
+
+ rpd.setPythonTrace(True)
|
| 232 |
+
+ rpd.start()
|
| 233 |
+
+ rpd.rangePush("", "tokenizer_manager", "")
|
| 234 |
+
+ logger.info("tokenizer_manager rpd profiling started!")
|
| 235 |
+
req = ProfileReq.START_PROFILE
|
| 236 |
+
self.send_to_scheduler.send_pyobj(req)
|
| 237 |
+
|
| 238 |
+
def stop_profile(self):
|
| 239 |
+
+ rpd = rpdTracerControl()
|
| 240 |
+
+ rpd.rangePop()
|
| 241 |
+
+ rpd.stop()
|
| 242 |
+
+ rpd.flush()
|
| 243 |
+
+ logger.info("rpd profiling is done inside tokenizer_manager!")
|
| 244 |
+
req = ProfileReq.STOP_PROFILE
|
| 245 |
+
self.send_to_scheduler.send_pyobj(req)
|
| 246 |
+
|
| 247 |
+
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
|
| 248 |
+
index 7111c93..2bd722c 100644
|
| 249 |
+
--- a/python/sglang/srt/server.py
|
| 250 |
+
+++ b/python/sglang/srt/server.py
|
| 251 |
+
@@ -30,6 +30,8 @@ import threading
|
| 252 |
+
import time
|
| 253 |
+
from http import HTTPStatus
|
| 254 |
+
from typing import Dict, List, Optional, Union
|
| 255 |
+
+from rpdTracerControl import rpdTracerControl
|
| 256 |
+
+rpdTracerControl.skipCreate()
|
| 257 |
+
|
| 258 |
+
# Fix a bug of Python threading
|
| 259 |
+
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
| 260 |
+
@@ -152,6 +154,11 @@ async def flush_cache():
|
| 261 |
+
@app.post("/start_profile")
|
| 262 |
+
async def start_profile():
|
| 263 |
+
"""Start profiling."""
|
| 264 |
+
+ rpd = rpdTracerControl()
|
| 265 |
+
+ rpd.setPythonTrace(True)
|
| 266 |
+
+ rpd.start()
|
| 267 |
+
+ rpd.rangePush("", "server rpd profile range", "")
|
| 268 |
+
+ logger.info("rpd profiling started in server.py!")
|
| 269 |
+
tokenizer_manager.start_profile()
|
| 270 |
+
return Response(
|
| 271 |
+
content="Start profiling.\n",
|
| 272 |
+
@@ -164,6 +171,11 @@ async def start_profile():
|
| 273 |
+
async def stop_profile():
|
| 274 |
+
"""Stop profiling."""
|
| 275 |
+
tokenizer_manager.stop_profile()
|
| 276 |
+
+ rpd = rpdTracerControl()
|
| 277 |
+
+ rpd.rangePop()
|
| 278 |
+
+ rpd.stop()
|
| 279 |
+
+ rpd.flush()
|
| 280 |
+
+ logger.info("rpd profiling is done in server.py!")
|
| 281 |
+
return Response(
|
| 282 |
+
content="Stop profiling. This will take some time.\n",
|
| 283 |
+
status_code=200,
|
| 284 |
+
```
|
| 285 |
+
|
| 286 |
+
4. As an example for grok1 profiling, we create a dummy_grok1 directory with config.json (see content below) inside this directory and copy this directory to the right path for "--model-path" if you want to use the example server.sh file provided.
|
| 287 |
+
```bash
|
| 288 |
+
cat ../dummy_grok1/config.json
|
| 289 |
+
{
|
| 290 |
+
"architectures": [
|
| 291 |
+
"Grok1ModelForCausalLM"
|
| 292 |
+
],
|
| 293 |
+
"embedding_multiplier_scale": 78.38367176906169,
|
| 294 |
+
"output_multiplier_scale": 0.5773502691896257,
|
| 295 |
+
"vocab_size": 131072,
|
| 296 |
+
"hidden_size": 6144,
|
| 297 |
+
"intermediate_size": 32768,
|
| 298 |
+
"max_position_embeddings": 8192,
|
| 299 |
+
"num_experts_per_tok": 2,
|
| 300 |
+
"num_local_experts": 8,
|
| 301 |
+
"num_attention_heads": 48,
|
| 302 |
+
"num_hidden_layers": 64,
|
| 303 |
+
"num_key_value_heads": 8,
|
| 304 |
+
"head_dim": 128,
|
| 305 |
+
"rms_norm_eps": 1e-05,
|
| 306 |
+
"rope_theta": 10000.0,
|
| 307 |
+
"model_type": "mixtral",
|
| 308 |
+
"torch_dtype": "bfloat16"
|
| 309 |
+
}
|
| 310 |
+
```
|
| 311 |
+
5. Launch server with rpd enabled script ./server.sh in one terminal inside the docker container.
|
| 312 |
+
|
| 313 |
+
#### Common Notes 2
|
| 314 |
+
- Remember to change model-path to the correct path
|
| 315 |
+
- loadTracer.sh is needed to conduct profiling
|
| 316 |
+
- SGLANG_TORCH_PROFILER_DIR is used for default torch profiler
|
| 317 |
+
- Do not use loadTracer.sh if you are using the torch profiler, simply use python3 -m sglang.launch_server.
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
server.sh
|
| 321 |
+
|
| 322 |
+
```bash
|
| 323 |
+
#!/bin/bash
|
| 324 |
+
|
| 325 |
+
# export SGLANG_TORCH_PROFILER_DIR=/data/sglang/
|
| 326 |
+
export SGLANG_TORCH_PROFILER_DIR=/sgl-workspace/sglang/profile/
|
| 327 |
+
|
| 328 |
+
# Get the current timestamp
|
| 329 |
+
TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
|
| 330 |
+
|
| 331 |
+
# Define the log file with a timestamp
|
| 332 |
+
LOGFILE="sglang_server_log_$TIMESTAMP.json"
|
| 333 |
+
|
| 334 |
+
# Run the Python command and save the output to the log file
|
| 335 |
+
loadTracer.sh python3 -m sglang.launch_server \
|
| 336 |
+
--model-path /sgl-workspace/sglang/dummy_grok1 \
|
| 337 |
+
--tokenizer-path Xenova/grok-1-tokenizer \
|
| 338 |
+
--load-format dummy \
|
| 339 |
+
--quantization fp8 \
|
| 340 |
+
--tp 8 \
|
| 341 |
+
--port 30000 \
|
| 342 |
+
--disable-radix-cache 2>&1 | tee "$LOGFILE"
|
| 343 |
+
```
|
| 344 |
+
6. Open another terminal for the same docker container, and run the rpd enabled ./client.sh after you see "The server is fired up and is ready to roll!" message from server side terminal.
|
| 345 |
+
|
| 346 |
+
#### Common Notes 3
|
| 347 |
+
- Use curl http://localhost:30000/start_profile & curl http://localhost:30000/stop_profile to control the start and end of profiling. Check sglang/python/sglang/srt/managers/scheduler.py for more details.
|
| 348 |
+
- Please don't use RPD profiler together with PyTorch profiler to avoid interference.
|
| 349 |
+
- The rocmProfileData/tools/rpd2tracing.py file is used to generate json file from RPD file.
|
| 350 |
+
|
| 351 |
+
client.sh
|
| 352 |
+
|
| 353 |
+
```bash
|
| 354 |
+
#!/bin/bash
|
| 355 |
+
|
| 356 |
+
# Start profiling via API
|
| 357 |
+
curl http://localhost:30000/start_profile -H "Content-Type: application/json"
|
| 358 |
+
|
| 359 |
+
# Benchmark serving using sglang with random dataset and tokenizer
|
| 360 |
+
# Define the log file with a timestamp
|
| 361 |
+
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
|
| 362 |
+
LOGFILE="sglang_client_log_$TIMESTAMP.json"
|
| 363 |
+
|
| 364 |
+
# Run the benchmark with specified parameters and save logs
|
| 365 |
+
python3 -m sglang.bench_serving \
|
| 366 |
+
--backend sglang \
|
| 367 |
+
--tokenizer Xenova/grok-1-tokenizer \
|
| 368 |
+
--dataset-name random \
|
| 369 |
+
--random-input 1024\
|
| 370 |
+
--random-output 1024 \
|
| 371 |
+
--num-prompts 120 \
|
| 372 |
+
--request-rate 8 \
|
| 373 |
+
--output-file online.jsonl 2>&1 | tee "$LOGFILE"
|
| 374 |
+
|
| 375 |
+
# Stop profiling via API
|
| 376 |
+
curl http://localhost:30000/stop_profile -H "Content-Type: application/json"
|
| 377 |
+
|
| 378 |
+
# Convert tracing file to csv & json
|
| 379 |
+
sqlite3 trace.rpd ".mode csv" ".header on" ".output trace.csv" "select * from top;" ".output stdout"
|
| 380 |
+
python3 ./rocmProfileData/tools/rpd2tracing.py trace.rpd trace.json
|
| 381 |
+
```
|
| 382 |
+
7. Follow [Perfetto docs](https://perfetto.dev/docs/visualization/large-traces) to visualize large json files. Try to adjust parameters so that the trace.json file size is less than 9GB.
|
| 383 |
+
|
| 384 |
+
### Profiling SGLang Infer System with PyTorch Profiler
|
| 385 |
+
|
| 386 |
+
Please use the steps as follows:
|
| 387 |
+
|
| 388 |
+
1. Apply the patch torch_profiler.patch. Note that you can modify "if self.tp_rank == 0" in the patch to allow more ranks be recorded in profiling.
|
| 389 |
+
|
| 390 |
+
torch_profiler.patch
|
| 391 |
+
```bash
|
| 392 |
+
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
|
| 393 |
+
index 62d1ff9..6ecd78c 100644
|
| 394 |
+
--- a/python/sglang/srt/managers/scheduler.py
|
| 395 |
+
+++ b/python/sglang/srt/managers/scheduler.py
|
| 396 |
+
@@ -240,7 +240,6 @@ class Scheduler:
|
| 397 |
+
)
|
| 398 |
+
self.profiler = torch.profiler.profile(
|
| 399 |
+
activities=[
|
| 400 |
+
- torch.profiler.ProfilerActivity.CPU,
|
| 401 |
+
torch.profiler.ProfilerActivity.CUDA,
|
| 402 |
+
],
|
| 403 |
+
with_stack=True,
|
| 404 |
+
@@ -1033,9 +1032,11 @@ class Scheduler:
|
| 405 |
+
if self.profiler is None:
|
| 406 |
+
raise RuntimeError("Profiler is not enabled.")
|
| 407 |
+
self.profiler.stop()
|
| 408 |
+
- self.profiler.export_chrome_trace(
|
| 409 |
+
- self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
|
| 410 |
+
- )
|
| 411 |
+
+ if self.tp_rank == 0:
|
| 412 |
+
+ with open(f"stats_repro_{int(time.time())}.txt", "w") as f:
|
| 413 |
+
+ print(self.profiler.key_averages(group_by_input_shape=True).table(sort_by="cuda_time_total", row_limit=-1), file=f)
|
| 414 |
+
+ print("Profiling stats done.")
|
| 415 |
+
+
|
| 416 |
+
logger.info("Profiler is done")
|
| 417 |
+
```
|
| 418 |
+
|
| 419 |
+
2. Create the model path directory and copy it to the right path for "--model-path" if you want to use the server.sh file provided.
|
| 420 |
+
|
| 421 |
+
3. Modify the included server.sh by removing "loadTracer.sh" before python command and launch script ./server.sh in one terminal inside the docker container.
|
| 422 |
+
|
| 423 |
+
4. Similar to step 6 in RPD profiling section, but remove the last 2 lines in client.sh, which converted rpd file into csv and json files. Run modified client.sh for PyTorch profiling.
|
| 424 |
+
-------
|
| 425 |
+
- [Torch Profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html)
|
sglang/3rdparty/amd/profiling/client.sh
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# Start profiling via API
|
| 4 |
+
curl http://localhost:30000/start_profile -H "Content-Type: application/json"
|
| 5 |
+
|
| 6 |
+
# Benchmark serving using sglang with random dataset and tokenizer
|
| 7 |
+
# Define the log file with a timestamp
|
| 8 |
+
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
|
| 9 |
+
LOGFILE="sglang_client_log_$TIMESTAMP.json"
|
| 10 |
+
|
| 11 |
+
# Run the benchmark with specified parameters and save logs
|
| 12 |
+
python3 -m sglang.bench_serving \
|
| 13 |
+
--backend sglang \
|
| 14 |
+
--tokenizer Xenova/grok-1-tokenizer \
|
| 15 |
+
--dataset-name random \
|
| 16 |
+
--random-input 1024\
|
| 17 |
+
--random-output 1024 \
|
| 18 |
+
--num-prompts 240 \
|
| 19 |
+
--request-rate 8 \
|
| 20 |
+
--output-file online.jsonl 2>&1 | tee "$LOGFILE"
|
| 21 |
+
|
| 22 |
+
# Stop profiling via API
|
| 23 |
+
curl http://localhost:30000/stop_profile -H "Content-Type: application/json"
|
| 24 |
+
|
| 25 |
+
# Convert tracing file to csv & json
|
| 26 |
+
sqlite3 trace.rpd ".mode csv" ".header on" ".output trace.csv" "select * from top;" ".output stdout"
|
| 27 |
+
python3 /sgl-workspace/rocmProfileData/tools/rpd2tracing.py trace.rpd trace.json
|
sglang/3rdparty/amd/profiling/install_rpd.sh
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# download and install RPD
|
| 2 |
+
apt update && apt install -y sqlite3 libsqlite3-dev libfmt-dev
|
| 3 |
+
|
| 4 |
+
# install rpd module
|
| 5 |
+
git clone https://github.com/ROCmSoftwarePlatform/rocmProfileData
|
| 6 |
+
cd rocmProfileData
|
| 7 |
+
git apply rpd.patch
|
| 8 |
+
make && make install
|
| 9 |
+
cd rocpd_python && python setup.py install && cd ..
|
| 10 |
+
cd rpd_tracer && make clean;make install && python setup.py install && cd ..
|
sglang/3rdparty/amd/profiling/loadTracer.sh
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
################################################################################
|
| 3 |
+
# Copyright (c) 2021 - 2023 Advanced Micro Devices, Inc. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
# in the Software without restriction, including without limitation the rights
|
| 8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
# furnished to do so, subject to the following conditions:
|
| 11 |
+
#
|
| 12 |
+
# The above copyright notice and this permission notice shall be included in
|
| 13 |
+
# all copies or substantial portions of the Software.
|
| 14 |
+
#
|
| 15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
| 21 |
+
# THE SOFTWARE.
|
| 22 |
+
################################################################################
|
| 23 |
+
OUTPUT_FILE="trace.rpd"
|
| 24 |
+
|
| 25 |
+
if [ "$1" = "-o" ] ; then
|
| 26 |
+
OUTPUT_FILE=$2
|
| 27 |
+
shift
|
| 28 |
+
shift
|
| 29 |
+
fi
|
| 30 |
+
|
| 31 |
+
if [ -e ${OUTPUT_FILE} ] ; then
|
| 32 |
+
rm ${OUTPUT_FILE}
|
| 33 |
+
fi
|
| 34 |
+
|
| 35 |
+
python3 -m rocpd.schema --create ${OUTPUT_FILE}
|
| 36 |
+
if [ $? != 0 ] ; then
|
| 37 |
+
echo "Error: Could not create rpd file. Please run 'python setup.py install' from the rocpd_python dir"
|
| 38 |
+
exit
|
| 39 |
+
fi
|
| 40 |
+
|
| 41 |
+
export RPDT_FILENAME=${OUTPUT_FILE}
|
| 42 |
+
export RPDT_AUTOSTART=0
|
| 43 |
+
LD_PRELOAD=librocm-smi_64:librpd_tracer.so "$@"
|
sglang/3rdparty/amd/profiling/rpd.patch
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
diff --git a/rpd_tracer/Makefile b/rpd_tracer/Makefile
|
| 2 |
+
index e9d9feb..b2e9e1a 100644
|
| 3 |
+
--- a/rpd_tracer/Makefile
|
| 4 |
+
+++ b/rpd_tracer/Makefile
|
| 5 |
+
@@ -16,7 +16,7 @@ ifneq (,$(HIP_PATH))
|
| 6 |
+
$(info Building with roctracer)
|
| 7 |
+
RPD_LIBS += -L/opt/rocm/lib -lroctracer64 -lroctx64 -lamdhip64 -lrocm_smi64
|
| 8 |
+
RPD_INCLUDES += -I/opt/rocm/include -I/opt/rocm/include/roctracer -I/opt/rocm/include/hsa
|
| 9 |
+
- RPD_SRCS += RoctracerDataSource.cpp RocmSmiDataSource.cpp
|
| 10 |
+
+ RPD_SRCS += RoctracerDataSource.cpp
|
| 11 |
+
RPD_INCLUDES += -D__HIP_PLATFORM_AMD__
|
| 12 |
+
endif
|
sglang/3rdparty/amd/profiling/rpd_profile_server_enable.patch
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
|
| 2 |
+
index 62d1ff9..9021c01 100644
|
| 3 |
+
--- a/python/sglang/srt/managers/scheduler.py
|
| 4 |
+
+++ b/python/sglang/srt/managers/scheduler.py
|
| 5 |
+
@@ -71,6 +71,8 @@ from sglang.srt.utils import (
|
| 6 |
+
suppress_other_loggers,
|
| 7 |
+
)
|
| 8 |
+
from sglang.utils import get_exception_traceback
|
| 9 |
+
+from rpdTracerControl import rpdTracerControl
|
| 10 |
+
+rpdTracerControl.skipCreate()
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
@@ -245,6 +247,7 @@ class Scheduler:
|
| 15 |
+
],
|
| 16 |
+
with_stack=True,
|
| 17 |
+
)
|
| 18 |
+
+ self.rpd = rpdTracerControl()
|
| 19 |
+
|
| 20 |
+
@torch.inference_mode()
|
| 21 |
+
def event_loop(self):
|
| 22 |
+
@@ -1027,15 +1030,24 @@ class Scheduler:
|
| 23 |
+
def start_profile(self) -> None:
|
| 24 |
+
if self.profiler is None:
|
| 25 |
+
raise RuntimeError("Profiler is not enabled.")
|
| 26 |
+
- self.profiler.start()
|
| 27 |
+
+ #self.profiler.start() #block pytorch profiler for rpd profiler enabling
|
| 28 |
+
+ if self.tp_rank == 0 or self.tp_rank == 1:
|
| 29 |
+
+ self.rpd.start()
|
| 30 |
+
+ self.rpd.rangePush("", "rpd profile range", "")
|
| 31 |
+
+ logger.info("rpd is enabled")
|
| 32 |
+
|
| 33 |
+
def stop_profile(self) -> None:
|
| 34 |
+
if self.profiler is None:
|
| 35 |
+
raise RuntimeError("Profiler is not enabled.")
|
| 36 |
+
- self.profiler.stop()
|
| 37 |
+
- self.profiler.export_chrome_trace(
|
| 38 |
+
- self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
|
| 39 |
+
- )
|
| 40 |
+
+ #self.profiler.stop()
|
| 41 |
+
+ #self.profiler.export_chrome_trace(
|
| 42 |
+
+ # self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
|
| 43 |
+
+ #)
|
| 44 |
+
+ if self.tp_rank ==0 or self.tp_rank ==1:
|
| 45 |
+
+ self.rpd.rangePop()
|
| 46 |
+
+ self.rpd.stop()
|
| 47 |
+
+ self.rpd.flush()
|
| 48 |
+
+ logger.info("rpd is done")
|
| 49 |
+
logger.info("Profiler is done")
|
sglang/3rdparty/amd/profiling/rpd_profile_server_enable_wCPU_activities.patch
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
|
| 2 |
+
index 62d1ff9..2edb427 100644
|
| 3 |
+
--- a/python/sglang/srt/managers/scheduler.py
|
| 4 |
+
+++ b/python/sglang/srt/managers/scheduler.py
|
| 5 |
+
@@ -71,6 +71,8 @@ from sglang.srt.utils import (
|
| 6 |
+
suppress_other_loggers,
|
| 7 |
+
)
|
| 8 |
+
from sglang.utils import get_exception_traceback
|
| 9 |
+
+from rpdTracerControl import rpdTracerControl
|
| 10 |
+
+rpdTracerControl.skipCreate()
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
@@ -245,6 +247,7 @@ class Scheduler:
|
| 15 |
+
],
|
| 16 |
+
with_stack=True,
|
| 17 |
+
)
|
| 18 |
+
+ self.rpd = rpdTracerControl()
|
| 19 |
+
|
| 20 |
+
@torch.inference_mode()
|
| 21 |
+
def event_loop(self):
|
| 22 |
+
@@ -1027,15 +1030,26 @@ class Scheduler:
|
| 23 |
+
def start_profile(self) -> None:
|
| 24 |
+
if self.profiler is None:
|
| 25 |
+
raise RuntimeError("Profiler is not enabled.")
|
| 26 |
+
- self.profiler.start()
|
| 27 |
+
+ #self.profiler.start()
|
| 28 |
+
+ logger.info("torch profiler is disabled")
|
| 29 |
+
+ if self.tp_rank == 0 or self.tp_rank == 1:
|
| 30 |
+
+ self.rpd.setPythonTrace(True)
|
| 31 |
+
+ self.rpd.start()
|
| 32 |
+
+ self.rpd.rangePush("", "scheduler", "")
|
| 33 |
+
+ logger.info("rpd is enabled inside scheduler profiling")
|
| 34 |
+
|
| 35 |
+
def stop_profile(self) -> None:
|
| 36 |
+
if self.profiler is None:
|
| 37 |
+
raise RuntimeError("Profiler is not enabled.")
|
| 38 |
+
- self.profiler.stop()
|
| 39 |
+
- self.profiler.export_chrome_trace(
|
| 40 |
+
- self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
|
| 41 |
+
- )
|
| 42 |
+
+ #self.profiler.stop()
|
| 43 |
+
+ #self.profiler.export_chrome_trace(
|
| 44 |
+
+ # self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
|
| 45 |
+
+ #)
|
| 46 |
+
+ if self.tp_rank ==0 or self.tp_rank ==1:
|
| 47 |
+
+ self.rpd.rangePop()
|
| 48 |
+
+ self.rpd.stop()
|
| 49 |
+
+ self.rpd.flush()
|
| 50 |
+
+ logger.info("rpd is done inside scheduler")
|
| 51 |
+
logger.info("Profiler is done")
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
|
| 55 |
+
index 2621ccd..181df85 100644
|
| 56 |
+
--- a/python/sglang/srt/managers/tokenizer_manager.py
|
| 57 |
+
+++ b/python/sglang/srt/managers/tokenizer_manager.py
|
| 58 |
+
@@ -58,6 +58,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams
|
| 59 |
+
from sglang.srt.server_args import PortArgs, ServerArgs
|
| 60 |
+
from sglang.srt.utils import is_generation_model, is_multimodal_model
|
| 61 |
+
|
| 62 |
+
+from rpdTracerControl import rpdTracerControl
|
| 63 |
+
+rpdTracerControl.skipCreate()
|
| 64 |
+
+
|
| 65 |
+
+
|
| 66 |
+
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
| 67 |
+
|
| 68 |
+
logger = logging.getLogger(__name__)
|
| 69 |
+
@@ -514,10 +518,20 @@ class TokenizerManager:
|
| 70 |
+
self.send_to_scheduler.send_pyobj(req)
|
| 71 |
+
|
| 72 |
+
def start_profile(self):
|
| 73 |
+
+ rpd = rpdTracerControl()
|
| 74 |
+
+ rpd.setPythonTrace(True)
|
| 75 |
+
+ rpd.start()
|
| 76 |
+
+ rpd.rangePush("", "tokenizer_manager", "")
|
| 77 |
+
+ logger.info("tokenizer_manager rpd profiling started!")
|
| 78 |
+
req = ProfileReq.START_PROFILE
|
| 79 |
+
self.send_to_scheduler.send_pyobj(req)
|
| 80 |
+
|
| 81 |
+
def stop_profile(self):
|
| 82 |
+
+ rpd = rpdTracerControl()
|
| 83 |
+
+ rpd.rangePop()
|
| 84 |
+
+ rpd.stop()
|
| 85 |
+
+ rpd.flush()
|
| 86 |
+
+ logger.info("rpd profiling is done inside tokenizer_manager!")
|
| 87 |
+
req = ProfileReq.STOP_PROFILE
|
| 88 |
+
self.send_to_scheduler.send_pyobj(req)
|
| 89 |
+
|
| 90 |
+
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
|
| 91 |
+
index 7111c93..2bd722c 100644
|
| 92 |
+
--- a/python/sglang/srt/server.py
|
| 93 |
+
+++ b/python/sglang/srt/server.py
|
| 94 |
+
@@ -30,6 +30,8 @@ import threading
|
| 95 |
+
import time
|
| 96 |
+
from http import HTTPStatus
|
| 97 |
+
from typing import Dict, List, Optional, Union
|
| 98 |
+
+from rpdTracerControl import rpdTracerControl
|
| 99 |
+
+rpdTracerControl.skipCreate()
|
| 100 |
+
|
| 101 |
+
# Fix a bug of Python threading
|
| 102 |
+
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
| 103 |
+
@@ -152,6 +154,11 @@ async def flush_cache():
|
| 104 |
+
@app.post("/start_profile")
|
| 105 |
+
async def start_profile():
|
| 106 |
+
"""Start profiling."""
|
| 107 |
+
+ rpd = rpdTracerControl()
|
| 108 |
+
+ rpd.setPythonTrace(True)
|
| 109 |
+
+ rpd.start()
|
| 110 |
+
+ rpd.rangePush("", "server rpd profile range", "")
|
| 111 |
+
+ logger.info("rpd profiling started in server.py!")
|
| 112 |
+
tokenizer_manager.start_profile()
|
| 113 |
+
return Response(
|
| 114 |
+
content="Start profiling.\n",
|
| 115 |
+
@@ -164,6 +171,11 @@ async def start_profile():
|
| 116 |
+
async def stop_profile():
|
| 117 |
+
"""Stop profiling."""
|
| 118 |
+
tokenizer_manager.stop_profile()
|
| 119 |
+
+ rpd = rpdTracerControl()
|
| 120 |
+
+ rpd.rangePop()
|
| 121 |
+
+ rpd.stop()
|
| 122 |
+
+ rpd.flush()
|
| 123 |
+
+ logger.info("rpd profiling is done in server.py!")
|
| 124 |
+
return Response(
|
| 125 |
+
content="Stop profiling. This will take some time.\n",
|
| 126 |
+
status_code=200,
|
sglang/3rdparty/amd/profiling/server.sh
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# export SGLANG_TORCH_PROFILER_DIR=/data/sglang/
|
| 4 |
+
export SGLANG_TORCH_PROFILER_DIR=/sgl-workspace/sglang/profile/
|
| 5 |
+
|
| 6 |
+
# Get the current timestamp
|
| 7 |
+
TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
|
| 8 |
+
|
| 9 |
+
# Define the log file with a timestamp
|
| 10 |
+
LOGFILE="sglang_server_log_$TIMESTAMP.json"
|
| 11 |
+
|
| 12 |
+
# Run the Python command and save the output to the log file
|
| 13 |
+
loadTracer.sh python3 -m sglang.launch_server \
|
| 14 |
+
--model-path /sgl-workspace/sglang/dummy_grok1 \
|
| 15 |
+
--tokenizer-path Xenova/grok-1-tokenizer \
|
| 16 |
+
--load-format dummy \
|
| 17 |
+
--quantization fp8 \
|
| 18 |
+
--tp 8 \
|
| 19 |
+
--port 30000 \
|
| 20 |
+
--disable-radix-cache 2>&1 | tee "$LOGFILE"
|
sglang/3rdparty/amd/profiling/torch_profiler.patch
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
|
| 2 |
+
index 62d1ff9..6ecd78c 100644
|
| 3 |
+
--- a/python/sglang/srt/managers/scheduler.py
|
| 4 |
+
+++ b/python/sglang/srt/managers/scheduler.py
|
| 5 |
+
@@ -240,7 +240,6 @@ class Scheduler:
|
| 6 |
+
)
|
| 7 |
+
self.profiler = torch.profiler.profile(
|
| 8 |
+
activities=[
|
| 9 |
+
- torch.profiler.ProfilerActivity.CPU,
|
| 10 |
+
torch.profiler.ProfilerActivity.CUDA,
|
| 11 |
+
],
|
| 12 |
+
with_stack=True,
|
| 13 |
+
@@ -1033,9 +1032,11 @@ class Scheduler:
|
| 14 |
+
if self.profiler is None:
|
| 15 |
+
raise RuntimeError("Profiler is not enabled.")
|
| 16 |
+
self.profiler.stop()
|
| 17 |
+
- self.profiler.export_chrome_trace(
|
| 18 |
+
- self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
|
| 19 |
+
- )
|
| 20 |
+
+ if self.tp_rank == 0:
|
| 21 |
+
+ with open(f"stats_repro_{int(time.time())}.txt", "w") as f:
|
| 22 |
+
+ print(self.profiler.key_averages(group_by_input_shape=True).table(sort_by="cuda_time_total", row_limit=-1), file=f)
|
| 23 |
+
+ print("Profiling stats done.")
|
| 24 |
+
+
|
| 25 |
+
logger.info("Profiler is done")
|
sglang/3rdparty/amd/sgl-kernel/CMakeLists_rocm.txt
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
cmake_minimum_required(VERSION 3.24 FATAL_ERROR)
|
| 2 |
+
project(sgl_kernel LANGUAGES CXX)
|
| 3 |
+
|
| 4 |
+
# Cmake
|
| 5 |
+
set(CMAKE_CXX_STANDARD 17)
|
| 6 |
+
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
| 7 |
+
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
| 8 |
+
set(CMAKE_SHARED_LIBRARY_PREFIX "")
|
| 9 |
+
|
| 10 |
+
set(CMAKE_COLOR_DIAGNOSTICS ON)
|
| 11 |
+
set(CMAKE_VERBOSE_MAKEFILE ON CACHE BOOL "ON")
|
| 12 |
+
|
| 13 |
+
# Python / Torch
|
| 14 |
+
find_package(Python COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED)
|
| 15 |
+
|
| 16 |
+
execute_process(
|
| 17 |
+
COMMAND ${Python_EXECUTABLE} -c "import torch; print(torch.utils.cmake_prefix_path)"
|
| 18 |
+
OUTPUT_VARIABLE TORCH_PY_PREFIX
|
| 19 |
+
OUTPUT_STRIP_TRAILING_WHITESPACE
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
set(Torch_DIR "${TORCH_PY_PREFIX}/Torch")
|
| 23 |
+
list(APPEND CMAKE_PREFIX_PATH "${TORCH_PY_PREFIX}/Torch")
|
| 24 |
+
find_package(Torch REQUIRED)
|
| 25 |
+
|
| 26 |
+
execute_process(
|
| 27 |
+
COMMAND ${Python_EXECUTABLE} -c "import torch; print(int(torch._C._GLIBCXX_USE_CXX11_ABI))"
|
| 28 |
+
OUTPUT_VARIABLE TORCH_CXX11_ABI
|
| 29 |
+
OUTPUT_STRIP_TRAILING_WHITESPACE
|
| 30 |
+
)
|
| 31 |
+
if(TORCH_CXX11_ABI STREQUAL "0")
|
| 32 |
+
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
|
| 33 |
+
else()
|
| 34 |
+
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=1)
|
| 35 |
+
endif()
|
| 36 |
+
|
| 37 |
+
# ROCm/HIP
|
| 38 |
+
enable_language(HIP)
|
| 39 |
+
find_package(hip REQUIRED CONFIG)
|
| 40 |
+
|
| 41 |
+
# Determine AMDGPU target from environment variable or default to gfx942
|
| 42 |
+
set(AMDGPU_TARGET_ENV "$ENV{AMDGPU_TARGET}")
|
| 43 |
+
|
| 44 |
+
if(AMDGPU_TARGET_ENV)
|
| 45 |
+
# Use environment variable if specified
|
| 46 |
+
set(AMDGPU_TARGETS "${AMDGPU_TARGET_ENV}")
|
| 47 |
+
message(STATUS "Using AMDGPU_TARGET from environment: ${AMDGPU_TARGETS}")
|
| 48 |
+
else()
|
| 49 |
+
# Default to gfx942 only
|
| 50 |
+
set(AMDGPU_TARGETS "gfx942")
|
| 51 |
+
message(STATUS "AMDGPU_TARGET not set, defaulting to gfx942")
|
| 52 |
+
endif()
|
| 53 |
+
|
| 54 |
+
# Set HIP architectures
|
| 55 |
+
set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS})
|
| 56 |
+
|
| 57 |
+
# FP8 macro selection
|
| 58 |
+
# Always define HIP_FP8_TYPE_FNUZ=1 (for gfx942 and host compilation)
|
| 59 |
+
# Additionally define HIP_FP8_TYPE_E4M3=1 when building for gfx950
|
| 60 |
+
# The existing utils.h logic will pick the right one based on architecture
|
| 61 |
+
set(SGL_FP8_MACROS "-DHIP_FP8_TYPE_FNUZ=1")
|
| 62 |
+
|
| 63 |
+
if(AMDGPU_TARGETS MATCHES "gfx950")
|
| 64 |
+
list(APPEND SGL_FP8_MACROS "-DHIP_FP8_TYPE_E4M3=1")
|
| 65 |
+
message(STATUS "Multi-arch build: Enabling both HIP_FP8_TYPE_FNUZ (gfx942) and HIP_FP8_TYPE_E4M3 (gfx950)")
|
| 66 |
+
elseif(AMDGPU_TARGETS MATCHES "gfx942")
|
| 67 |
+
message(STATUS "Single-arch build: Enabling HIP_FP8_TYPE_FNUZ for gfx942")
|
| 68 |
+
else()
|
| 69 |
+
message(FATAL_ERROR "Unsupported AMDGPU_TARGET '${AMDGPU_TARGETS}'. Expected 'gfx942' or 'gfx950' or both.")
|
| 70 |
+
endif()
|
| 71 |
+
|
| 72 |
+
# TopK dynamic smem bytes
|
| 73 |
+
# Dynamic shared-memory budget for the TopK kernels.
|
| 74 |
+
# - gfx942 (MI300/MI325): LDS is typically 64KB per workgroup -> keep dynamic smem <= ~48KB
|
| 75 |
+
# (leaves room for static shared allocations in the kernel).
|
| 76 |
+
# - gfx95x (MI350): LDS is larger (e.g. 160KB per CU) -> allow the original 128KB dynamic smem.
|
| 77 |
+
if(AMDGPU_TARGET_ONE STREQUAL "gfx942")
|
| 78 |
+
math(EXPR SGL_TOPK_DYNAMIC_SMEM_BYTES "48 * 1024")
|
| 79 |
+
else()
|
| 80 |
+
math(EXPR SGL_TOPK_DYNAMIC_SMEM_BYTES "32 * 1024 * 4")
|
| 81 |
+
endif()
|
| 82 |
+
|
| 83 |
+
set(SGL_TOPK_MACROS "-DSGL_TOPK_DYNAMIC_SMEM_BYTES=${SGL_TOPK_DYNAMIC_SMEM_BYTES}")
|
| 84 |
+
|
| 85 |
+
# Paths / includes
|
| 86 |
+
set(PROJ_ROOT ${CMAKE_CURRENT_LIST_DIR})
|
| 87 |
+
set(SGL_INCLUDE_DIRS
|
| 88 |
+
${PROJ_ROOT}/include
|
| 89 |
+
${PROJ_ROOT}/include/impl
|
| 90 |
+
${PROJ_ROOT}/csrc
|
| 91 |
+
${TORCH_INCLUDE_DIRS}
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
# Platform-specific library directory
|
| 95 |
+
set(PLAT_LIB_DIR "/usr/lib/x86_64-linux-gnu")
|
| 96 |
+
link_directories(${PLAT_LIB_DIR})
|
| 97 |
+
|
| 98 |
+
# Sources
|
| 99 |
+
set(SOURCES
|
| 100 |
+
${PROJ_ROOT}/csrc/allreduce/custom_all_reduce.hip
|
| 101 |
+
${PROJ_ROOT}/csrc/allreduce/deterministic_all_reduce.hip
|
| 102 |
+
${PROJ_ROOT}/csrc/allreduce/quick_all_reduce.hip
|
| 103 |
+
${PROJ_ROOT}/csrc/common_extension_rocm.cc
|
| 104 |
+
${PROJ_ROOT}/csrc/elementwise/activation.hip
|
| 105 |
+
${PROJ_ROOT}/csrc/elementwise/pos_enc.hip
|
| 106 |
+
${PROJ_ROOT}/csrc/elementwise/topk.hip
|
| 107 |
+
${PROJ_ROOT}/csrc/grammar/apply_token_bitmask_inplace_hip.hip
|
| 108 |
+
${PROJ_ROOT}/csrc/kvcacheio/transfer.hip
|
| 109 |
+
${PROJ_ROOT}/csrc/moe/moe_align_kernel.hip
|
| 110 |
+
${PROJ_ROOT}/csrc/moe/moe_topk_softmax_kernels.hip
|
| 111 |
+
${PROJ_ROOT}/csrc/moe/moe_topk_sigmoid_kernels.hip
|
| 112 |
+
${PROJ_ROOT}/csrc/speculative/eagle_utils.hip
|
| 113 |
+
)
|
| 114 |
+
set_source_files_properties(
|
| 115 |
+
${SOURCES}
|
| 116 |
+
PROPERTIES
|
| 117 |
+
LANGUAGE HIP
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
# Compile / Link flags
|
| 121 |
+
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:-O3>)
|
| 122 |
+
|
| 123 |
+
set(SGL_HIP_FLAGS
|
| 124 |
+
-DNDEBUG
|
| 125 |
+
-DOPERATOR_NAMESPACE=sgl_kernel
|
| 126 |
+
-O3
|
| 127 |
+
-std=c++17
|
| 128 |
+
-DENABLE_BF16
|
| 129 |
+
-DENABLE_FP8
|
| 130 |
+
${SGL_FP8_MACROS}
|
| 131 |
+
-Wno-pass-failed
|
| 132 |
+
-Wundefined-internal
|
| 133 |
+
${SGL_TOPK_MACROS}
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
# Python extension
|
| 137 |
+
Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
|
| 138 |
+
target_include_directories(common_ops PRIVATE ${SGL_INCLUDE_DIRS})
|
| 139 |
+
|
| 140 |
+
# Apply per-language flags
|
| 141 |
+
target_compile_options(common_ops PRIVATE
|
| 142 |
+
$<$<COMPILE_LANGUAGE:HIP>:${SGL_HIP_FLAGS}>
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
target_link_libraries(common_ops PRIVATE
|
| 146 |
+
${TORCH_LIBRARIES}
|
| 147 |
+
hip::device
|
| 148 |
+
hip::host
|
| 149 |
+
hiprtc
|
| 150 |
+
amdhip64
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
target_link_options(common_ops PRIVATE
|
| 154 |
+
"SHELL:-Wl,-rpath,'\$ORIGIN/../../torch/lib'"
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
install(TARGETS common_ops
|
| 158 |
+
LIBRARY DESTINATION sgl_kernel
|
| 159 |
+
)
|
sglang/3rdparty/amd/sgl-kernel/build_rocm.sh
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
ROCM_VERSION=$1
|
| 4 |
+
|
| 5 |
+
PYTHON_ROOT_PATH="/opt/venv/bin"
|
| 6 |
+
AMDGPU_TARGET="gfx942;gfx950"
|
| 7 |
+
|
| 8 |
+
echo "Python root path is: $PYTHON_ROOT_PATH"
|
| 9 |
+
|
| 10 |
+
# Get version from git tags
|
| 11 |
+
SGLANG_VERSION="v0.5.6" # Default version, will be overridden if git tags are found
|
| 12 |
+
|
| 13 |
+
# Fetch tags from origin to ensure we have the latest
|
| 14 |
+
if git fetch --tags origin; then
|
| 15 |
+
# Get the latest version tag sorted by version number (e.g., v0.5.7)
|
| 16 |
+
VERSION_FROM_TAG=$(git tag -l 'v[0-9]*' --sort=-v:refname | head -1)
|
| 17 |
+
if [ -n "$VERSION_FROM_TAG" ]; then
|
| 18 |
+
SGLANG_VERSION="$VERSION_FROM_TAG"
|
| 19 |
+
echo "Using SGLang version from git tags: $SGLANG_VERSION"
|
| 20 |
+
else
|
| 21 |
+
echo "Warning: No version tags found; using default $SGLANG_VERSION" >&2
|
| 22 |
+
fi
|
| 23 |
+
else
|
| 24 |
+
echo "Warning: Failed to fetch tags from origin; using default $SGLANG_VERSION" >&2
|
| 25 |
+
fi
|
| 26 |
+
|
| 27 |
+
# Default base tags (can be overridden by command line arguments)
|
| 28 |
+
DEFAULT_MI30X_BASE_TAG="${SGLANG_VERSION}-rocm700-mi30x"
|
| 29 |
+
DEFAULT_MI35X_BASE_TAG="${SGLANG_VERSION}-rocm700-mi35x"
|
| 30 |
+
|
| 31 |
+
# Parse command line arguments
|
| 32 |
+
MI30X_BASE_TAG="${DEFAULT_MI30X_BASE_TAG}"
|
| 33 |
+
MI35X_BASE_TAG="${DEFAULT_MI35X_BASE_TAG}"
|
| 34 |
+
|
| 35 |
+
# Detect GPU architecture from the Kubernetes runner hostname
|
| 36 |
+
HOSTNAME_VALUE=$(hostname)
|
| 37 |
+
GPU_ARCH="mi30x" # default
|
| 38 |
+
|
| 39 |
+
# Host names look like: linux-mi35x-gpu-1-xxxxx-runner-zzzzz
|
| 40 |
+
if [[ "${HOSTNAME_VALUE}" =~ ^linux-(mi[0-9]+[a-z]*)-gpu-[0-9]+ ]]; then
|
| 41 |
+
GPU_ARCH="${BASH_REMATCH[1]}"
|
| 42 |
+
echo "Detected GPU architecture from hostname: ${GPU_ARCH}"
|
| 43 |
+
else
|
| 44 |
+
echo "Warning: could not parse GPU architecture from '${HOSTNAME_VALUE}', defaulting to ${GPU_ARCH}"
|
| 45 |
+
fi
|
| 46 |
+
|
| 47 |
+
case "${GPU_ARCH}" in
|
| 48 |
+
mi35x)
|
| 49 |
+
echo "Runner uses ${GPU_ARCH}; will fetch mi35x image."
|
| 50 |
+
;;
|
| 51 |
+
mi30x|mi300|mi325)
|
| 52 |
+
echo "Runner uses ${GPU_ARCH}; will fetch mi30x image."
|
| 53 |
+
GPU_ARCH="mi30x"
|
| 54 |
+
;;
|
| 55 |
+
*)
|
| 56 |
+
echo "Runner architecture '${GPU_ARCH}' unrecognised; defaulting to mi30x image." >&2
|
| 57 |
+
GPU_ARCH="mi30x"
|
| 58 |
+
;;
|
| 59 |
+
esac
|
| 60 |
+
|
| 61 |
+
if [[ -f /etc/podinfo/gha-render-devices ]]; then
|
| 62 |
+
DEVICE_FLAG=$(cat /etc/podinfo/gha-render-devices)
|
| 63 |
+
else
|
| 64 |
+
DEVICE_FLAG="--device /dev/dri"
|
| 65 |
+
fi
|
| 66 |
+
|
| 67 |
+
# Find the latest image
|
| 68 |
+
find_latest_image() {
|
| 69 |
+
local gpu_arch=$1
|
| 70 |
+
local base_tag days_back image_tag
|
| 71 |
+
|
| 72 |
+
case "${gpu_arch}" in
|
| 73 |
+
mi30x) base_tag="${MI30X_BASE_TAG}" ;;
|
| 74 |
+
mi35x) base_tag="${MI35X_BASE_TAG}" ;;
|
| 75 |
+
*) echo "Error: unsupported GPU architecture '${gpu_arch}'" >&2; return 1 ;;
|
| 76 |
+
esac
|
| 77 |
+
|
| 78 |
+
for days_back in {0..6}; do
|
| 79 |
+
image_tag="${base_tag}-$(date -d "${days_back} days ago" +%Y%m%d)"
|
| 80 |
+
echo "Checking for image: rocm/sgl-dev:${image_tag}" >&2
|
| 81 |
+
if docker manifest inspect "rocm/sgl-dev:${image_tag}" >/dev/null 2>&1; then
|
| 82 |
+
echo "Found available image: rocm/sgl-dev:${image_tag}" >&2
|
| 83 |
+
echo "rocm/sgl-dev:${image_tag}"
|
| 84 |
+
return 0
|
| 85 |
+
fi
|
| 86 |
+
done
|
| 87 |
+
|
| 88 |
+
echo "Error: no ${gpu_arch} image found in the last 7 days for base ${base_tag}" >&2
|
| 89 |
+
echo "Using hard-coded fallback…" >&2
|
| 90 |
+
if [[ "${gpu_arch}" == "mi35x" ]]; then
|
| 91 |
+
echo "rocm/sgl-dev:v0.5.3-rocm700-mi35x-20251009"
|
| 92 |
+
else
|
| 93 |
+
echo "rocm/sgl-dev:v0.5.3-rocm700-mi30x-20251009"
|
| 94 |
+
fi
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
# Pull and run the latest image
|
| 98 |
+
IMAGE=$(find_latest_image "${GPU_ARCH}")
|
| 99 |
+
echo "Pulling Docker image: ${IMAGE}"
|
| 100 |
+
docker pull "${IMAGE}"
|
| 101 |
+
|
| 102 |
+
docker run --rm \
|
| 103 |
+
-v $(pwd):/sgl-kernel \
|
| 104 |
+
-e AMDGPU_TARGET="${AMDGPU_TARGET}" \
|
| 105 |
+
${IMAGE} \
|
| 106 |
+
bash -c "
|
| 107 |
+
# Install CMake (version >= 3.26) - Robust Installation
|
| 108 |
+
export CMAKE_VERSION_MAJOR=3.31
|
| 109 |
+
export CMAKE_VERSION_MINOR=1
|
| 110 |
+
echo \"Downloading CMake from: https://cmake.org/files/v\${CMAKE_VERSION_MAJOR}/cmake-\${CMAKE_VERSION_MAJOR}.\${CMAKE_VERSION_MINOR}-linux-x86_64.tar.gz\"
|
| 111 |
+
wget https://cmake.org/files/v\${CMAKE_VERSION_MAJOR}/cmake-\${CMAKE_VERSION_MAJOR}.\${CMAKE_VERSION_MINOR}-linux-x86_64.tar.gz
|
| 112 |
+
tar -xzf cmake-\${CMAKE_VERSION_MAJOR}.\${CMAKE_VERSION_MINOR}-linux-x86_64.tar.gz
|
| 113 |
+
mv cmake-\${CMAKE_VERSION_MAJOR}.\${CMAKE_VERSION_MINOR}-linux-x86_64 /opt/cmake
|
| 114 |
+
export PATH=/opt/cmake/bin:\$PATH
|
| 115 |
+
|
| 116 |
+
${PYTHON_ROOT_PATH}/pip install --no-cache-dir ninja setuptools wheel numpy uv scikit-build-core && \
|
| 117 |
+
|
| 118 |
+
cd /sgl-kernel && \
|
| 119 |
+
rm -rf CMakeLists.txt && mv CMakeLists_rocm.txt CMakeLists.txt && \
|
| 120 |
+
${PYTHON_ROOT_PATH}/python rocm_hipify.py && \
|
| 121 |
+
${PYTHON_ROOT_PATH}/python -m uv build --wheel -Cbuild-dir=build . --color=always --no-build-isolation && \
|
| 122 |
+
./rename_wheels_rocm.sh
|
| 123 |
+
"
|
sglang/3rdparty/amd/sgl-kernel/rename_wheels_rocm.sh
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -ex
|
| 3 |
+
|
| 4 |
+
WHEEL_DIR="dist"
|
| 5 |
+
|
| 6 |
+
wheel_files=($WHEEL_DIR/*.whl)
|
| 7 |
+
for wheel in "${wheel_files[@]}"; do
|
| 8 |
+
intermediate_wheel="${wheel/linux/manylinux2014}"
|
| 9 |
+
|
| 10 |
+
# Extract the current python version from the wheel name
|
| 11 |
+
if [[ $intermediate_wheel =~ -cp([0-9]+)- ]]; then
|
| 12 |
+
cp_version="${BASH_REMATCH[1]}"
|
| 13 |
+
else
|
| 14 |
+
echo "Could not extract Python version from wheel name: $intermediate_wheel"
|
| 15 |
+
continue
|
| 16 |
+
fi
|
| 17 |
+
|
| 18 |
+
# Detect ROCm version and add appropriate suffix
|
| 19 |
+
if ls /opt | grep -q "7.0"; then
|
| 20 |
+
new_wheel="${intermediate_wheel/-cp${cp_version}/+rocm700-cp${cp_version}}"
|
| 21 |
+
else
|
| 22 |
+
new_wheel="$intermediate_wheel"
|
| 23 |
+
fi
|
| 24 |
+
|
| 25 |
+
if [[ "$wheel" != "$new_wheel" ]]; then
|
| 26 |
+
echo "Renaming $wheel to $new_wheel"
|
| 27 |
+
mv -- "$wheel" "$new_wheel"
|
| 28 |
+
fi
|
| 29 |
+
done
|
| 30 |
+
echo "Wheel renaming completed."
|
sglang/3rdparty/amd/sgl-kernel/rocm_hipify.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch.utils.cpp_extension import CUDAExtension
|
| 5 |
+
|
| 6 |
+
root = Path(__file__).parent.resolve()
|
| 7 |
+
|
| 8 |
+
include_dirs = [
|
| 9 |
+
root / "include",
|
| 10 |
+
root / "include" / "impl",
|
| 11 |
+
root / "csrc",
|
| 12 |
+
]
|
| 13 |
+
|
| 14 |
+
sources = [
|
| 15 |
+
"csrc/allreduce/custom_all_reduce.hip",
|
| 16 |
+
"csrc/allreduce/deterministic_all_reduce.hip",
|
| 17 |
+
"csrc/allreduce/quick_all_reduce.cu",
|
| 18 |
+
"csrc/common_extension_rocm.cc",
|
| 19 |
+
"csrc/elementwise/activation.cu",
|
| 20 |
+
"csrc/elementwise/pos_enc.cu",
|
| 21 |
+
"csrc/elementwise/topk.cu",
|
| 22 |
+
"csrc/grammar/apply_token_bitmask_inplace_cuda.cu",
|
| 23 |
+
"csrc/kvcacheio/transfer.cu",
|
| 24 |
+
"csrc/moe/moe_align_kernel.cu",
|
| 25 |
+
"csrc/moe/moe_topk_softmax_kernels.cu",
|
| 26 |
+
"csrc/moe/moe_topk_sigmoid_kernels.cu",
|
| 27 |
+
"csrc/speculative/eagle_utils.cu",
|
| 28 |
+
]
|
| 29 |
+
|
| 30 |
+
libraries = ["hiprtc", "amdhip64", "c10", "torch", "torch_python"]
|
| 31 |
+
|
| 32 |
+
ext_modules = [
|
| 33 |
+
CUDAExtension(
|
| 34 |
+
name="sgl_kernel.common_ops",
|
| 35 |
+
sources=sources,
|
| 36 |
+
include_dirs=include_dirs,
|
| 37 |
+
libraries=libraries,
|
| 38 |
+
py_limited_api=False,
|
| 39 |
+
),
|
| 40 |
+
]
|
sglang/3rdparty/amd/tuning/TUNING.md
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Tuning SGLang Infer System with AMD GPUs
|
| 2 |
+
This AppNote describes the SGLang performance tuning technical, code harness and running steps for systems with AMD Instinct GPUs.
|
| 3 |
+
Harness code, examples and steps are provided in detail, to facilitate easy reproduce & use to tune performance towards workloads.
|
| 4 |
+
Three primary runtime areas are covered:
|
| 5 |
+
|
| 6 |
+
## 1. Triton Kernels
|
| 7 |
+
To maximize Triton kernel efficiency, several strategies can be employed:
|
| 8 |
+
|
| 9 |
+
### Key Environment Variables:
|
| 10 |
+
- **num_stages**: Adjusts the number of pipeline stages to optimize kernel efficiency based on the specific type of operations (e.g., General Matrix Multiplication - GEMM).
|
| 11 |
+
- **waves_per_eu**: Controls the usage of Vector General Purpose Registers (VGPR) to enhance occupancy, thereby improving latency or throughput.
|
| 12 |
+
- **BLOCK_M, BLOCK_N, BLOCK_K**: Tunable tile sizes that assist in balancing memory transfer and computational efficiency.
|
| 13 |
+
- **matrix_instr_nonkdim**: Optimizes the usage of Matrix-Fused Multiply-Add (MFMA) instructions for specific kernel types, such as Flash Attention.
|
| 14 |
+
- **OPTIMIZE_EPILOGUE**: An environment variable that can be set to `1` to enhance performance by eliminating the `convert_layout` operation in the kernel's epilogue.
|
| 15 |
+
```python
|
| 16 |
+
@triton.autotune(configs=[
|
| 17 |
+
triton.Config({'waves_per_eu': 1}, num_warps=4, num_stages=1),
|
| 18 |
+
triton.Config({'waves_per_eu': 1}, num_warps=8, num_stages=1),
|
| 19 |
+
triton.Config({'waves_per_eu': 1}, num_warps=16, num_stages=1),
|
| 20 |
+
triton.Config({'waves_per_eu': 2}, num_warps=4, num_stages=1),
|
| 21 |
+
triton.Config({'waves_per_eu': 2}, num_warps=8, num_stages=1),
|
| 22 |
+
triton.Config({'waves_per_eu': 2}, num_warps=16, num_stages=1),
|
| 23 |
+
triton.Config({'waves_per_eu': 4}, num_warps=4, num_stages=1),
|
| 24 |
+
triton.Config({'waves_per_eu': 4}, num_warps=8, num_stages=1),
|
| 25 |
+
triton.Config({'waves_per_eu': 4}, num_warps=16, num_stages=1),
|
| 26 |
+
], key=['BLOCK_N', 'NUM_TOKEN_BLKS'], use_cuda_graph=True)
|
| 27 |
+
@triton.jit
|
| 28 |
+
def _triton_kernel_function():
|
| 29 |
+
...
|
| 30 |
+
```
|
| 31 |
+
## 2. Torch Tunable Operations
|
| 32 |
+
**TunableOp** is a feature in PyTorch that allows for the definition and optimization of custom kernels with tunable parameters. This feature is particularly useful for enhancing the performance of kernels by experimenting with different configurations.
|
| 33 |
+
|
| 34 |
+
### Key Environment Variables:
|
| 35 |
+
1. **PYTORCH_TUNABLEOP_ENABLED**:
|
| 36 |
+
- Default: `0`
|
| 37 |
+
- Set to `1` to enable TunableOp.
|
| 38 |
+
|
| 39 |
+
2. **PYTORCH_TUNABLEOP_TUNING**:
|
| 40 |
+
- Default: `1`
|
| 41 |
+
- Set to `0` to disable tuning. If a tuned entry is not found, it will run the tuning step and record the entry when PYTORCH_TUNABLEOP_ENABLED is enabled.
|
| 42 |
+
|
| 43 |
+
3. **PYTORCH_TUNABLEOP_VERBOSE**:
|
| 44 |
+
- Default: `0`
|
| 45 |
+
- Set to `1` to enable verbose output for TunableOp.
|
| 46 |
+
|
| 47 |
+
### Usage Example:
|
| 48 |
+
To enable TunableOp and tuning, and optionally enable verbose mode, you can run the following command in your terminal:
|
| 49 |
+
|
| 50 |
+
```bash
|
| 51 |
+
#Tuning
|
| 52 |
+
PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=1 your_script.sh
|
| 53 |
+
|
| 54 |
+
#Inference with tuning op
|
| 55 |
+
PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=0 your_script.sh
|
| 56 |
+
|
| 57 |
+
#Print out the log
|
| 58 |
+
PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=0 PYTORCH_TUNABLEOP_VERBOSE=1 your_script.sh
|
| 59 |
+
|
| 60 |
+
```
|
| 61 |
+
## 3. Torch Compilation
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
The following are suggestions for optimizing matrix multiplication (GEMM) and convolution (conv) operations in PyTorch using Inductor, a part of the PyTorch compilation framework. The goal is to leverage Triton to achieve better performance.
|
| 65 |
+
|
| 66 |
+
To tune Triton kernels with GEMM and convolution ops (conv), use the `torch.compile` function with the max-autotune mode. This benchmarks a predefined list of Triton configurations and selects the fastest one for each shape.
|
| 67 |
+
|
| 68 |
+
### Key Configurations:
|
| 69 |
+
1. **Max Autotune**:
|
| 70 |
+
- Set `torch._inductor.config.max_autotune = True` or `TORCHINDUCTOR_MAX_AUTOTUNE=1`.
|
| 71 |
+
|
| 72 |
+
2. **Fine-Grained Control**:
|
| 73 |
+
- Enable GEMM tuning: `torch._inductor.config.max_autotune_gemm = True`.
|
| 74 |
+
- Enable tuning for pointwise/reduction ops: `torch._inductor.config.max_autotune.pointwise = True`.
|
| 75 |
+
|
| 76 |
+
3. **Backend Selection**:
|
| 77 |
+
- Use `torch._inductor.max_autotune_gemm_backends` to limit backends to TRITON for better performance.
|
| 78 |
+
|
| 79 |
+
4. **Freezing for Inference**:
|
| 80 |
+
- Use `torch._inductor.config.freezing=True` to enable constant folding optimizations.
|
| 81 |
+
|
| 82 |
+
5. **Debugging**:
|
| 83 |
+
- Set `TORCH_COMPILE_DEBUG=1` to extract Triton kernels generated by Inductor.
|
| 84 |
+
|
| 85 |
+
### Example Code Block:
|
| 86 |
+
```bash
|
| 87 |
+
#Gemm Tuning
|
| 88 |
+
TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 your_script.sh
|
| 89 |
+
|
| 90 |
+
#Specify your backend to TRITON for Gemm Tuning
|
| 91 |
+
TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS=TRITON your_script.sh
|
| 92 |
+
|
| 93 |
+
#Inference with large improvement on AMD GPU
|
| 94 |
+
TORCHINDUCTOR_FREEZING=1 your_script.sh
|
| 95 |
+
```
|
| 96 |
+
## 4. Fused MOE kernel
|
| 97 |
+
To maximize moe kernel efficiency, need to use below scripts to find out the best launch configuration
|
| 98 |
+
|
| 99 |
+
### Key parameters:
|
| 100 |
+
- **--model**: what moe model type to do tuning, it will automatically decide the size of d_model, model_intermediate_size, num_layers
|
| 101 |
+
- **--tp-size**: simulate the whole model run configuration to set the dimension size using tp correctly
|
| 102 |
+
- **--batch**: M dimension size of moe kernel, for prefill moe kernel the value is batch*input_len, for decode moe kernel the value is batch
|
| 103 |
+
- **--dtype**: computation type
|
| 104 |
+
|
| 105 |
+
```bash
|
| 106 |
+
#Tuning
|
| 107 |
+
#for example, we have one case like this "python3 -m sglang.bench_latency --model dummy_grok1/ --load-format dummy --tokenizer-path Xenova/grok-1-tokenizer --tp 8 --batch-size 32 --input 1024 --output 8 --attention-backend triton --sampling-backend pytorch --quantization fp8" to run, it defined batch-size 32 input length 1024 and output length 8, from "--batch" in moe view point, the prefill batch is 32*1024 = 32768, the decode batch is 32*1(only one output token generated in each run).
|
| 108 |
+
#so we can tune decode moe use below command
|
| 109 |
+
python benchmark_moe_rocm.py --model grok1 --tp-size 8 --dtype float8 --batch "32"
|
| 110 |
+
# and use this command to tune prefill moe
|
| 111 |
+
python benchmark_moe_rocm.py --model grok1 --tp-size 8 --dtype float8 --batch "32768"
|
| 112 |
+
```
|
| 113 |
+
|
| 114 |
+
## Reference
|
| 115 |
+
|
| 116 |
+
For more detailed information on tuning SGLang performance with AMD GPUs, please refer to the following link:
|
| 117 |
+
|
| 118 |
+
[ROCm Documentation: Triton Kernel Performance Optimization](https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/workload.html#triton-kernel-performance-optimization)
|
sglang/3rdparty/amd/tuning/benchmark_moe_rocm.py
ADDED
|
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
import triton
|
| 9 |
+
import triton.language as tl
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
from transformers import AutoConfig
|
| 12 |
+
|
| 13 |
+
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
| 14 |
+
fused_moe,
|
| 15 |
+
get_config_file_name,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def main(model, tp_size, dtype: str, batches):
|
| 22 |
+
method = fused_moe
|
| 23 |
+
|
| 24 |
+
for bs in batches:
|
| 25 |
+
run_grid(int(bs), model=model, method=method, tp_size=tp_size, dtype=dtype)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def prune_configs(M, N, K, configs):
|
| 29 |
+
pruned_configs = []
|
| 30 |
+
elemBytes_a = 1 # [DV Note] Hard-coded for float16 (2 bytes)
|
| 31 |
+
elemBytes_b = 1 # [DV Note] Hard-coded for float16 (2 bytes)
|
| 32 |
+
|
| 33 |
+
mfma = 16 if M < 32 or N < 32 else 32
|
| 34 |
+
|
| 35 |
+
# TODO (zhanglx): figure out the boundary between large and small gemms
|
| 36 |
+
large_gemm = False
|
| 37 |
+
if M >= 2048 and N >= 2048:
|
| 38 |
+
large_gemm = True
|
| 39 |
+
|
| 40 |
+
for config in configs:
|
| 41 |
+
BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
|
| 42 |
+
BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
|
| 43 |
+
BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
|
| 44 |
+
num_warps = config.get("num_warps")
|
| 45 |
+
matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
|
| 46 |
+
# kpack = config.get("kpack")
|
| 47 |
+
if matrix_instr_nonkdim > mfma:
|
| 48 |
+
continue
|
| 49 |
+
if mfma == 4 and BLOCK_SIZE_K < 64:
|
| 50 |
+
continue
|
| 51 |
+
# some layouts could not work properly in case
|
| 52 |
+
# number elements per thread is less 1
|
| 53 |
+
if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
|
| 54 |
+
continue
|
| 55 |
+
SPLIT_K = 1 # config.get("SPLIT_K")
|
| 56 |
+
GROUP_M = config.get("GROUP_SIZE_M")
|
| 57 |
+
if matrix_instr_nonkdim > BLOCK_SIZE_M or matrix_instr_nonkdim > BLOCK_SIZE_N:
|
| 58 |
+
continue
|
| 59 |
+
if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M:
|
| 60 |
+
continue
|
| 61 |
+
if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N:
|
| 62 |
+
continue
|
| 63 |
+
# Skip BLOCK_SIZE that is too large compare to M/N
|
| 64 |
+
# unless BLOCK_SIZE is already small enough
|
| 65 |
+
if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
|
| 66 |
+
continue
|
| 67 |
+
if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
|
| 68 |
+
continue
|
| 69 |
+
# skip large split_k when not necessary
|
| 70 |
+
if SPLIT_K != 1 and not need_split_k(M, N, K):
|
| 71 |
+
continue
|
| 72 |
+
# skip split_k that leads to EVEN_K = false
|
| 73 |
+
leap = SPLIT_K * BLOCK_SIZE_K
|
| 74 |
+
modv = K % leap
|
| 75 |
+
if modv != 0:
|
| 76 |
+
continue
|
| 77 |
+
# skip large GROUP_M
|
| 78 |
+
if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
|
| 79 |
+
continue
|
| 80 |
+
# out of shared memory resource
|
| 81 |
+
# TODO (zhanglx): This does not consider the LDS usage in the epilogue
|
| 82 |
+
LDS = (
|
| 83 |
+
BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a
|
| 84 |
+
+ BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
|
| 85 |
+
)
|
| 86 |
+
if LDS > 65536:
|
| 87 |
+
continue
|
| 88 |
+
# Skip small block sizes and num_warps for large gemm
|
| 89 |
+
# For fp16 and f8, we want to only use BLOCK_SIZE >= 64
|
| 90 |
+
if large_gemm:
|
| 91 |
+
if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
|
| 92 |
+
continue
|
| 93 |
+
if BLOCK_SIZE_K < 64:
|
| 94 |
+
continue
|
| 95 |
+
if num_warps < 4:
|
| 96 |
+
continue
|
| 97 |
+
|
| 98 |
+
pruned_configs.append(config)
|
| 99 |
+
|
| 100 |
+
return pruned_configs
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def union_of_list_of_dicts(l1, l2):
|
| 104 |
+
result = []
|
| 105 |
+
temp_list = l1.copy()
|
| 106 |
+
temp_list.extend(l2)
|
| 107 |
+
for myDict in temp_list:
|
| 108 |
+
if myDict not in result:
|
| 109 |
+
result.append(myDict)
|
| 110 |
+
|
| 111 |
+
return result
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def run_grid(bs, model, method, tp_size, dtype: str):
|
| 115 |
+
|
| 116 |
+
config = AutoConfig.from_pretrained(model)
|
| 117 |
+
|
| 118 |
+
top_k = config.num_experts_per_tok
|
| 119 |
+
d_model = config.hidden_size
|
| 120 |
+
model_intermediate_size = config.intermediate_size
|
| 121 |
+
num_layers = config.num_hidden_layers
|
| 122 |
+
hidden_states_dtype = config.torch_dtype
|
| 123 |
+
|
| 124 |
+
if config.num_experts_per_tok:
|
| 125 |
+
if config.architectures[0] == "Grok1ModelForCausalLM":
|
| 126 |
+
num_total_experts = config.num_experts
|
| 127 |
+
else:
|
| 128 |
+
num_total_experts = config.num_local_experts
|
| 129 |
+
else:
|
| 130 |
+
raise ValueError(f"Unsupported Mixtral model {model}")
|
| 131 |
+
|
| 132 |
+
# tp_size = 2
|
| 133 |
+
num_warmup_calls = 10
|
| 134 |
+
num_calls = 30
|
| 135 |
+
|
| 136 |
+
num_warmup_trials = 1
|
| 137 |
+
num_trials = 1
|
| 138 |
+
|
| 139 |
+
full_configs = []
|
| 140 |
+
|
| 141 |
+
block_m_range = [16, 32, 64, 128, 256]
|
| 142 |
+
block_n_range = [16, 32, 64, 128, 256]
|
| 143 |
+
block_k_range = [32, 64, 128, 256] # MUST >= 32
|
| 144 |
+
num_warps_range = [1, 2, 4, 8]
|
| 145 |
+
group_m_range = [1, 4, 8, 16, 32]
|
| 146 |
+
# For now we see better perf with num_stages=0 for all gemm configs we care
|
| 147 |
+
# But keep this explicit so that we do not forget we may need to set it to
|
| 148 |
+
# other values in the future
|
| 149 |
+
num_stage_range = [2]
|
| 150 |
+
waves_per_eu_range = [0, 1, 2, 4, 8]
|
| 151 |
+
# Remove 32 because of triton compiling error
|
| 152 |
+
matrix_instr_nonkdim_range = [16]
|
| 153 |
+
kpack_range = [1, 2]
|
| 154 |
+
|
| 155 |
+
for block_size_m in block_m_range:
|
| 156 |
+
for block_size_n in block_n_range:
|
| 157 |
+
for block_size_k in block_k_range:
|
| 158 |
+
for group_size_m in group_m_range:
|
| 159 |
+
for num_warps in num_warps_range:
|
| 160 |
+
for num_stages in num_stage_range:
|
| 161 |
+
for waves_per_eu in waves_per_eu_range:
|
| 162 |
+
for matrix_instr_nonkdim in matrix_instr_nonkdim_range:
|
| 163 |
+
for kpack in kpack_range:
|
| 164 |
+
full_configs.append(
|
| 165 |
+
{
|
| 166 |
+
"BLOCK_SIZE_M": block_size_m,
|
| 167 |
+
"BLOCK_SIZE_N": block_size_n,
|
| 168 |
+
"BLOCK_SIZE_K": block_size_k,
|
| 169 |
+
"GROUP_SIZE_M": group_size_m,
|
| 170 |
+
"num_warps": num_warps,
|
| 171 |
+
"num_stages": num_stages,
|
| 172 |
+
"waves_per_eu": waves_per_eu,
|
| 173 |
+
"matrix_instr_nonkdim": matrix_instr_nonkdim,
|
| 174 |
+
"kpack": kpack,
|
| 175 |
+
}
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
M1 = bs * 2
|
| 179 |
+
N1 = model_intermediate_size * 2 // tp_size
|
| 180 |
+
K1 = d_model
|
| 181 |
+
prune_configs_1 = prune_configs(M1, N1, K1, full_configs)
|
| 182 |
+
|
| 183 |
+
M2 = bs * 2
|
| 184 |
+
N2 = d_model
|
| 185 |
+
K2 = model_intermediate_size // tp_size
|
| 186 |
+
prune_configs_2 = prune_configs(M2, N2, K2, full_configs)
|
| 187 |
+
|
| 188 |
+
configs = union_of_list_of_dicts(prune_configs_1, prune_configs_2)
|
| 189 |
+
|
| 190 |
+
print(f"{bs=} || {len(full_configs)=} | {len(prune_configs_1)=} | \
|
| 191 |
+
{len(prune_configs_2)=} | {len(configs)=}")
|
| 192 |
+
|
| 193 |
+
best_config = None
|
| 194 |
+
best_time_us = 1e20
|
| 195 |
+
|
| 196 |
+
print(f"{tp_size=} {bs=}")
|
| 197 |
+
|
| 198 |
+
for config in tqdm(configs):
|
| 199 |
+
# warmup
|
| 200 |
+
try:
|
| 201 |
+
print(config)
|
| 202 |
+
for _ in range(num_warmup_trials):
|
| 203 |
+
run_timing(
|
| 204 |
+
num_calls=num_warmup_calls,
|
| 205 |
+
bs=bs,
|
| 206 |
+
d_model=d_model,
|
| 207 |
+
num_total_experts=num_total_experts,
|
| 208 |
+
top_k=top_k,
|
| 209 |
+
tp_size=tp_size,
|
| 210 |
+
model_intermediate_size=model_intermediate_size,
|
| 211 |
+
method=method,
|
| 212 |
+
config=config,
|
| 213 |
+
dtype=dtype,
|
| 214 |
+
hidden_states_dtype=hidden_states_dtype,
|
| 215 |
+
)
|
| 216 |
+
except triton.runtime.autotuner.OutOfResources:
|
| 217 |
+
continue
|
| 218 |
+
|
| 219 |
+
# trial
|
| 220 |
+
for _ in range(num_trials):
|
| 221 |
+
kernel_dur_ms = run_timing(
|
| 222 |
+
num_calls=num_calls,
|
| 223 |
+
bs=bs,
|
| 224 |
+
d_model=d_model,
|
| 225 |
+
num_total_experts=num_total_experts,
|
| 226 |
+
top_k=top_k,
|
| 227 |
+
tp_size=tp_size,
|
| 228 |
+
model_intermediate_size=model_intermediate_size,
|
| 229 |
+
method=method,
|
| 230 |
+
config=config,
|
| 231 |
+
dtype=dtype,
|
| 232 |
+
hidden_states_dtype=hidden_states_dtype,
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
kernel_dur_us = 1000 * kernel_dur_ms
|
| 236 |
+
model_dur_ms = kernel_dur_ms * num_layers
|
| 237 |
+
|
| 238 |
+
if kernel_dur_us < best_time_us:
|
| 239 |
+
best_config = config
|
| 240 |
+
best_time_us = kernel_dur_us
|
| 241 |
+
|
| 242 |
+
tqdm.write(
|
| 243 |
+
f"{kernel_dur_us=:.1f} {model_dur_ms=:.1f}"
|
| 244 |
+
f" {bs=} {tp_size=} {top_k=} {num_total_experts=} "
|
| 245 |
+
f"{d_model=} {model_intermediate_size=} {num_layers=}"
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
print("best_time_us", best_time_us)
|
| 249 |
+
print("best_config", best_config)
|
| 250 |
+
|
| 251 |
+
# holds Dict[str, Dict[str, int]]
|
| 252 |
+
filename = get_config_file_name(
|
| 253 |
+
num_total_experts,
|
| 254 |
+
model_intermediate_size // tp_size,
|
| 255 |
+
"float8" if dtype == "float8" else None,
|
| 256 |
+
)
|
| 257 |
+
print(f"writing config to file {filename}")
|
| 258 |
+
existing_content = {}
|
| 259 |
+
if os.path.exists(filename):
|
| 260 |
+
with open(filename, "r") as f:
|
| 261 |
+
existing_content = json.load(f)
|
| 262 |
+
existing_content[str(bs)] = best_config
|
| 263 |
+
with open(filename, "w") as f:
|
| 264 |
+
json.dump(existing_content, f, indent=4)
|
| 265 |
+
f.write("\n")
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def run_timing(
|
| 269 |
+
num_calls: int,
|
| 270 |
+
bs: int,
|
| 271 |
+
d_model: int,
|
| 272 |
+
num_total_experts: int,
|
| 273 |
+
top_k: int,
|
| 274 |
+
tp_size: int,
|
| 275 |
+
model_intermediate_size: int,
|
| 276 |
+
method,
|
| 277 |
+
config,
|
| 278 |
+
dtype: str,
|
| 279 |
+
hidden_states_dtype,
|
| 280 |
+
) -> float:
|
| 281 |
+
shard_intermediate_size = model_intermediate_size // tp_size
|
| 282 |
+
|
| 283 |
+
hidden_states = torch.rand(
|
| 284 |
+
(bs, d_model),
|
| 285 |
+
device="cuda:0",
|
| 286 |
+
dtype=hidden_states_dtype,
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
w1 = torch.rand(
|
| 290 |
+
(num_total_experts, 2 * shard_intermediate_size, d_model + padding_size),
|
| 291 |
+
device=hidden_states.device,
|
| 292 |
+
dtype=hidden_states.dtype,
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
w2 = torch.rand(
|
| 296 |
+
(num_total_experts, d_model, shard_intermediate_size + padding_size),
|
| 297 |
+
device=hidden_states.device,
|
| 298 |
+
dtype=hidden_states.dtype,
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
w1_scale = None
|
| 302 |
+
w2_scale = None
|
| 303 |
+
a1_scale = None
|
| 304 |
+
a2_scale = None
|
| 305 |
+
|
| 306 |
+
if dtype == "float8":
|
| 307 |
+
w1 = w1.to(torch.float8_e4m3fnuz)
|
| 308 |
+
w2 = w2.to(torch.float8_e4m3fnuz)
|
| 309 |
+
w1_scale = torch.ones(
|
| 310 |
+
num_total_experts, device=hidden_states.device, dtype=torch.float32
|
| 311 |
+
)
|
| 312 |
+
w2_scale = torch.ones(
|
| 313 |
+
num_total_experts, device=hidden_states.device, dtype=torch.float32
|
| 314 |
+
)
|
| 315 |
+
a1_scale = torch.ones(1, device=hidden_states.device, dtype=torch.float32)
|
| 316 |
+
a2_scale = torch.ones(1, device=hidden_states.device, dtype=torch.float32)
|
| 317 |
+
|
| 318 |
+
gating_output = F.softmax(
|
| 319 |
+
torch.rand(
|
| 320 |
+
(num_calls, bs, num_total_experts),
|
| 321 |
+
device=hidden_states.device,
|
| 322 |
+
dtype=torch.float32,
|
| 323 |
+
),
|
| 324 |
+
dim=-1,
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
##################################
|
| 328 |
+
|
| 329 |
+
start_event = torch.cuda.Event(enable_timing=True)
|
| 330 |
+
end_event = torch.cuda.Event(enable_timing=True)
|
| 331 |
+
|
| 332 |
+
start_event.record()
|
| 333 |
+
for i in range(num_calls):
|
| 334 |
+
hidden_states = method(
|
| 335 |
+
hidden_states=hidden_states,
|
| 336 |
+
w1=w1,
|
| 337 |
+
w2=w2,
|
| 338 |
+
w1_scale=w1_scale,
|
| 339 |
+
w2_scale=w2_scale,
|
| 340 |
+
a1_scale=a1_scale,
|
| 341 |
+
a2_scale=a2_scale,
|
| 342 |
+
gating_output=gating_output[0],
|
| 343 |
+
topk=top_k,
|
| 344 |
+
renormalize=True,
|
| 345 |
+
inplace=True,
|
| 346 |
+
override_config=config,
|
| 347 |
+
use_fp8=dtype == "float8",
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
end_event.record()
|
| 351 |
+
end_event.synchronize()
|
| 352 |
+
|
| 353 |
+
dur_ms = start_event.elapsed_time(end_event) / num_calls
|
| 354 |
+
return dur_ms
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
if __name__ == "__main__":
|
| 358 |
+
parser = argparse.ArgumentParser(
|
| 359 |
+
prog="benchmark_mixtral_moe",
|
| 360 |
+
description="Benchmark and tune the fused_moe kernel",
|
| 361 |
+
)
|
| 362 |
+
parser.add_argument(
|
| 363 |
+
"--dtype",
|
| 364 |
+
type=str,
|
| 365 |
+
default="auto",
|
| 366 |
+
choices=["float8", "float16", "bfloat16"],
|
| 367 |
+
help="Data type used for fused_moe kernel computations",
|
| 368 |
+
)
|
| 369 |
+
parser.add_argument("--model", type=str, default="hpcai-tech/grok-1")
|
| 370 |
+
|
| 371 |
+
parser.add_argument("--tp-size", type=int, default=2, help="Tensor paralleli size")
|
| 372 |
+
parser.add_argument("-b", "--batches", type=str)
|
| 373 |
+
|
| 374 |
+
args = parser.parse_args()
|
| 375 |
+
|
| 376 |
+
batches = args.batches.split(",")
|
| 377 |
+
|
| 378 |
+
sys.exit(main(args.model, args.tp_size, args.dtype, batches))
|
sglang/docs/supported_models/extending/modelscope.md
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Use Models From ModelScope
|
| 2 |
+
|
| 3 |
+
To use a model from [ModelScope](https://www.modelscope.cn), set the environment variable `SGLANG_USE_MODELSCOPE`.
|
| 4 |
+
|
| 5 |
+
```bash
|
| 6 |
+
export SGLANG_USE_MODELSCOPE=true
|
| 7 |
+
```
|
| 8 |
+
|
| 9 |
+
We take [Qwen2-7B-Instruct](https://www.modelscope.cn/models/qwen/qwen2-7b-instruct) as an example.
|
| 10 |
+
|
| 11 |
+
Launch the Server:
|
| 12 |
+
```bash
|
| 13 |
+
python -m sglang.launch_server --model-path qwen/Qwen2-7B-Instruct --port 30000
|
| 14 |
+
```
|
| 15 |
+
|
| 16 |
+
Or start it by docker:
|
| 17 |
+
|
| 18 |
+
```bash
|
| 19 |
+
docker run --gpus all \
|
| 20 |
+
-p 30000:30000 \
|
| 21 |
+
-v ~/.cache/modelscope:/root/.cache/modelscope \
|
| 22 |
+
--env "SGLANG_USE_MODELSCOPE=true" \
|
| 23 |
+
--ipc=host \
|
| 24 |
+
lmsysorg/sglang:latest \
|
| 25 |
+
python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --host 0.0.0.0 --port 30000
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
Note that modelscope uses a different cache directory than huggingface. You may need to set it manually to avoid running out of disk space.
|
sglang/docs/supported_models/extending/support_new_models.md
ADDED
|
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# How to Support New Models
|
| 2 |
+
|
| 3 |
+
This document explains how to add support for new language models and multimodal large language models (MLLMs) in
|
| 4 |
+
SGLang. It also covers how to test new models and register external implementations.
|
| 5 |
+
|
| 6 |
+
## How to Support a New Language Model
|
| 7 |
+
|
| 8 |
+
To support a new model in SGLang, you only need to add a single file under
|
| 9 |
+
the [SGLang Models Directory](https://github.com/sgl-project/sglang/tree/main/python/sglang/srt/models). You can learn
|
| 10 |
+
from existing model implementations and create a new file for your model. For most models, you should be able to find a
|
| 11 |
+
similar model to start with (e.g., starting from Llama). Also refer how
|
| 12 |
+
to [port a Model from vLLM to SGLang](#port-a-model-from-vllm-to-sglang)
|
| 13 |
+
|
| 14 |
+
## How to Support a New Multimodal Large Language Model
|
| 15 |
+
|
| 16 |
+
To support a new multimodal large language model (MLLM) in SGLang, there are several key components in addition to the
|
| 17 |
+
standard LLM support:
|
| 18 |
+
|
| 19 |
+
1. **Register your new model as multimodal**:
|
| 20 |
+
Extend `is_multimodal_model`
|
| 21 |
+
in [model_config.py](https://github.com/sgl-project/sglang/blob/0ab3f437aba729b348a683ab32b35b214456efc7/python/sglang/srt/configs/model_config.py#L561)
|
| 22 |
+
to return `True` for your model.
|
| 23 |
+
|
| 24 |
+
2. **Register a new chat-template**:
|
| 25 |
+
Only when your default chat-template is unable to accept images as input: Register a new chat template in [conversation.py](https://github.com/sgl-project/sglang/tree/main/python/sglang/srt/conversation.py) and the corresponding matching function.
|
| 26 |
+
|
| 27 |
+
3. **Multimodal Data Processor**:
|
| 28 |
+
Define a new `Processor` class that inherits from `BaseMultimodalProcessor` and register this processor as your
|
| 29 |
+
model’s dedicated processor.
|
| 30 |
+
See [multimodal_processor.py](https://github.com/sgl-project/sglang/tree/main/python/sglang/srt/multimodal/processors)
|
| 31 |
+
for more details.
|
| 32 |
+
|
| 33 |
+
4. **Handle Multimodal Tokens**:
|
| 34 |
+
Implement a `pad_input_ids` function for your new model. In this function, multimodal tokens in the prompt should be
|
| 35 |
+
expanded (if necessary) and padded with multimodal-data-hashes so that SGLang can recognize different multimodal data
|
| 36 |
+
with `RadixAttention`.
|
| 37 |
+
|
| 38 |
+
5. **Handle Image Feature Extraction**:
|
| 39 |
+
Implement a `get_image_feature` function for your new model, which extracts image features from raw image data and converts them into the embeddings used by the language model.
|
| 40 |
+
|
| 41 |
+
6. **Adapt to Vision Attention**:
|
| 42 |
+
Adapt the multi-headed `Attention` of ViT with SGLang’s `VisionAttention`.
|
| 43 |
+
|
| 44 |
+
You can refer to [Qwen2VL](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/qwen2_vl.py) or
|
| 45 |
+
other mllm implementations. These models demonstrate how to correctly handle both multimodal and textual inputs.
|
| 46 |
+
|
| 47 |
+
## Testing and Debugging
|
| 48 |
+
|
| 49 |
+
Please note all your testing and benchmarking results in PR description.
|
| 50 |
+
|
| 51 |
+
### Interactive Debugging
|
| 52 |
+
|
| 53 |
+
For interactive debugging, compare the outputs of Hugging Face/Transformers and SGLang. The following two commands
|
| 54 |
+
should give the same text output and very similar prefill logits:
|
| 55 |
+
|
| 56 |
+
- Get the reference output:
|
| 57 |
+
```bash
|
| 58 |
+
python3 scripts/playground/reference_hf.py --model-path [new model] --model-type {text,mllm}
|
| 59 |
+
```
|
| 60 |
+
- Get the SGLang output:
|
| 61 |
+
```bash
|
| 62 |
+
python3 -m sglang.bench_one_batch --correct --model [new model]
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
### Add the Model to the Test Suite
|
| 66 |
+
|
| 67 |
+
To ensure the new model is well maintained, add it to the test suite by including it in the `ALL_OTHER_MODELS` list in
|
| 68 |
+
the [test_generation_models.py](https://github.com/sgl-project/sglang/blob/main/test/srt/models/test_generation_models.py)
|
| 69 |
+
file, test the new model on your local machine and report the results on demonstrative benchmarks (GSM8K, MMLU, MMMU,
|
| 70 |
+
MMMU-Pro, etc.) in your PR. \\
|
| 71 |
+
For VLMs, also include a test in `test_vision_openai_server_{x}.py` (e.g. [test_vision_openai_server_a.py](https://github.com/sgl-project/sglang/blob/main/test/srt/test_vision_openai_server_a.py), [test_vision_openai_server_b.py](https://github.com/sgl-project/sglang/blob/main/test/srt/test_vision_openai_server_b.py)).
|
| 72 |
+
|
| 73 |
+
This is an example command to run to test a new model on your local machine:
|
| 74 |
+
|
| 75 |
+
```bash
|
| 76 |
+
ONLY_RUN=Qwen/Qwen2-1.5B python3 -m unittest test_generation_models.TestGenerationModels.test_others
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
### Benchmark
|
| 80 |
+
|
| 81 |
+
- **(Required) MMMU**: follow MMMU benchmark [README.md](https://github.com/sgl-project/sglang/blob/main/benchmark/mmmu/README.md) to get SGLang vs. HF Transformer accuracy comparison. The accuracy score from SGLang run should not be much lower than that from HF Transformer run. Similarly, follow https://docs.sglang.io/developer_guide/benchmark_and_profiling.html to get performance comparison: TTFT and throughput must meet or exceed baselines (e.g., HF Transformer).
|
| 82 |
+
- **(Optional) Other evals**: If you ran other evals, please note the results in PR description.
|
| 83 |
+
|
| 84 |
+
## Port a Model from vLLM to SGLang
|
| 85 |
+
|
| 86 |
+
The [vLLM Models Directory](https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models) is a valuable
|
| 87 |
+
resource, as vLLM covers many models. SGLang reuses vLLM’s interface and some layers, making it easier to port models
|
| 88 |
+
from vLLM to SGLang.
|
| 89 |
+
|
| 90 |
+
To port a model from vLLM to SGLang:
|
| 91 |
+
|
| 92 |
+
- Compare these two files for guidance:
|
| 93 |
+
- [SGLang Llama Implementation](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama.py)
|
| 94 |
+
- [vLLM Llama Implementation](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py)
|
| 95 |
+
- The major differences include:
|
| 96 |
+
- **Replace vLLM’s `Attention` with `RadixAttention`** (ensure you pass `layer_id` to `RadixAttention`).
|
| 97 |
+
- **Replace vLLM’s `LogitsProcessor` with SGLang’s `LogitsProcessor`.**
|
| 98 |
+
- **Replace the multi-headed `Attention` of ViT with SGLang’s `VisionAttention`.**
|
| 99 |
+
- **Replace other vLLM layers** (such as `RMSNorm`, `SiluAndMul`) with SGLang layers.
|
| 100 |
+
- **Remove `Sample`.**
|
| 101 |
+
- **Change the `forward()` functions** and add a `forward_batch()` method.
|
| 102 |
+
- **Add `EntryClass`** at the end.
|
| 103 |
+
- **Ensure that the new implementation uses only SGLang components** and does not rely on any vLLM components.
|
| 104 |
+
|
| 105 |
+
Note: make sure you add your new model to the supported models list in the supported models documentation.
|
| 106 |
+
|
| 107 |
+
## Registering an External Model Implementation
|
| 108 |
+
|
| 109 |
+
In addition to the methods above, you can register your new model with the `ModelRegistry` before launching the server.
|
| 110 |
+
This allows you to integrate your model without modifying the source code.
|
| 111 |
+
|
| 112 |
+
For example:
|
| 113 |
+
|
| 114 |
+
```python
|
| 115 |
+
from sglang.srt.models.registry import ModelRegistry
|
| 116 |
+
from sglang.srt.entrypoints.http_server import launch_server
|
| 117 |
+
|
| 118 |
+
# For a single model, add it to the registry:
|
| 119 |
+
ModelRegistry.models[model_name] = model_class
|
| 120 |
+
|
| 121 |
+
# For multiple models, you can imitate the import_model_classes() function:
|
| 122 |
+
from functools import lru_cache
|
| 123 |
+
|
| 124 |
+
@lru_cache()
|
| 125 |
+
def import_new_model_classes():
|
| 126 |
+
model_arch_name_to_cls = {}
|
| 127 |
+
# Populate model_arch_name_to_cls with your new model classes.
|
| 128 |
+
...
|
| 129 |
+
return model_arch_name_to_cls
|
| 130 |
+
|
| 131 |
+
ModelRegistry.models.update(import_new_model_classes())
|
| 132 |
+
|
| 133 |
+
# Launch the server with your server arguments:
|
| 134 |
+
launch_server(server_args)
|
| 135 |
+
```
|
| 136 |
+
|
| 137 |
+
## Example: Implementing and Serving a Llama Wrapper Model
|
| 138 |
+
|
| 139 |
+
Below is an introductory, step-by-step walkthrough on how to implement a new model end-to-end in SGLang and then run it via the [Offline Engine](https://github.com/sgl-project/sglang/blob/main/docs/basic_usage/offline_engine_api.ipynb).
|
| 140 |
+
|
| 141 |
+
### Implementing Our Model
|
| 142 |
+
|
| 143 |
+
To keep things simple, this new model will be a simple wrapper around [Llama 3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct), and our goal will be just to bias the output logits for each `forward` call by taking the square root of each individual logit.
|
| 144 |
+
|
| 145 |
+
Let's start by defining our model in a file called `llama_wrapper.py`.
|
| 146 |
+
The first step is to import the necessary libraries from SRT, which is SGLang's internal backend.
|
| 147 |
+
|
| 148 |
+
```python
|
| 149 |
+
# In the file `llama_wrapper.py`
|
| 150 |
+
|
| 151 |
+
import torch
|
| 152 |
+
from transformers import LlamaConfig
|
| 153 |
+
from typing import Optional
|
| 154 |
+
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
| 155 |
+
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
| 156 |
+
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
| 157 |
+
|
| 158 |
+
from sglang.srt.models.llama import LlamaForCausalLM
|
| 159 |
+
```
|
| 160 |
+
|
| 161 |
+
Next, we declare a new `class` for our model and have it inherit from `LlamaForCausalLM`, which allows our model to access `LlamaForCausalLM`'s predefined modules and layers, such as `LlamaAttention` and `LlamaMLP`.
|
| 162 |
+
Note that almost all model implementations take in `config` and `quant_config` as arguments for their `__init__` method; `config` and `quant_config` are passed in via [`model_loader/loader.py`](https://github.com/sgl-project/sglang/blob/bf72b80122fd888bf619d17b96fa3e323ab809fc/python/sglang/srt/model_loader/loader.py#L219).
|
| 163 |
+
Because we have inherited from `LlamaForCausalLM`, we can pass our parameters directly to its constructor, which will set the member variables for us.
|
| 164 |
+
|
| 165 |
+
```python
|
| 166 |
+
class LlamaWrapper(LlamaForCausalLM):
|
| 167 |
+
def __init__(
|
| 168 |
+
self,
|
| 169 |
+
config: LlamaConfig,
|
| 170 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 171 |
+
prefix: str = "",
|
| 172 |
+
) -> None:
|
| 173 |
+
super().__init__(config=config, quant_config=quant_config, prefix=prefix)
|
| 174 |
+
```
|
| 175 |
+
|
| 176 |
+
Now, we want to define the `forward` method, which is what will be called at inference time.
|
| 177 |
+
Note that the signature for `forward` is essentially the same for any model; you can take a look at the other models defined in the [`models` directory](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/) for references.
|
| 178 |
+
To see where exactly `forward` is called in the SGLang runtime's internals, take a look at [`forward_decode`](https://github.com/sgl-project/sglang/blob/bf72b80122fd888bf619d17b96fa3e323ab809fc/python/sglang/srt/model_executor/model_runner.py#L1705) and [`forward_extend`](https://github.com/sgl-project/sglang/blob/bf72b80122fd888bf619d17b96fa3e323ab809fc/python/sglang/srt/model_executor/model_runner.py#L1724) in the [`ModelRunner` class](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/model_executor/model_runner.py).
|
| 179 |
+
|
| 180 |
+
```python
|
| 181 |
+
@torch.no_grad()
|
| 182 |
+
def forward(
|
| 183 |
+
self,
|
| 184 |
+
input_ids: torch.Tensor,
|
| 185 |
+
positions: torch.Tensor,
|
| 186 |
+
forward_batch: ForwardBatch,
|
| 187 |
+
pp_proxy_tensors: Optional[PPProxyTensors] = None,
|
| 188 |
+
input_embeds: Optional[torch.Tensor] = None,
|
| 189 |
+
get_embedding: bool = False,
|
| 190 |
+
) -> LogitsProcessorOutput:
|
| 191 |
+
```
|
| 192 |
+
|
| 193 |
+
We now call the `__call__` method for `self.model` (which is a member variable that `LlamaForCausalLM` defines in its `__init__` method), which eventually calls `LlamaForCausalLM`'s `forward` method.
|
| 194 |
+
After that, we feed the `hidden_states` into our model's `LogitsProcessor` (again defined in `LlamaForCausalLM`).
|
| 195 |
+
|
| 196 |
+
```python
|
| 197 |
+
hidden_states = self.model(
|
| 198 |
+
input_ids,
|
| 199 |
+
positions,
|
| 200 |
+
forward_batch,
|
| 201 |
+
input_embeds,
|
| 202 |
+
pp_proxy_tensors=pp_proxy_tensors,
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
res: LogitsProcessorOutput = self.logits_processor(
|
| 206 |
+
input_ids,
|
| 207 |
+
hidden_states,
|
| 208 |
+
self.lm_head,
|
| 209 |
+
forward_batch,
|
| 210 |
+
)
|
| 211 |
+
```
|
| 212 |
+
|
| 213 |
+
After receiving the logits for the next token, we can finally perform our biasing step.
|
| 214 |
+
|
| 215 |
+
```python
|
| 216 |
+
orig_logits = res.next_token_logits
|
| 217 |
+
res.next_token_logits = torch.where(
|
| 218 |
+
orig_logits > 0,
|
| 219 |
+
orig_logits.sqrt(),
|
| 220 |
+
orig_logits
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
return res
|
| 224 |
+
```
|
| 225 |
+
|
| 226 |
+
Now, our `LlamaWrapper` model is created and ready to be served!
|
| 227 |
+
|
| 228 |
+
### Serving Our Model Via SGLang's Offline Engine
|
| 229 |
+
|
| 230 |
+
The next step of this walkthrough involves hosting our new model offline, so that it can be served locally and without an HTTP server.
|
| 231 |
+
|
| 232 |
+
First, create a new file called `run.py`.
|
| 233 |
+
Now, we must ensure that SGLang's `ModelRegistry` can find our model.
|
| 234 |
+
To do this, we first download the model's configuration and weights from Huggingface.
|
| 235 |
+
|
| 236 |
+
```python
|
| 237 |
+
# In the file `run.py`
|
| 238 |
+
|
| 239 |
+
import asyncio
|
| 240 |
+
from functools import lru_cache
|
| 241 |
+
from huggingface_hub import snapshot_download
|
| 242 |
+
from llama_wrapper import LlamaWrapper # Make sure to import our new model!
|
| 243 |
+
import sglang as sgl
|
| 244 |
+
from sglang.srt.models.registry import ModelRegistry
|
| 245 |
+
|
| 246 |
+
# Make sure to request access to this model on Huggingface, then export your
|
| 247 |
+
# `HF_TOKEN` to download the model snapshot
|
| 248 |
+
llama_dir = snapshot_download(
|
| 249 |
+
repo_id="meta-llama/Llama-3.1-8B-Instruct",
|
| 250 |
+
local_dir="./llama_ckpt",
|
| 251 |
+
)
|
| 252 |
+
```
|
| 253 |
+
|
| 254 |
+
Now that we have our model on disk, we want to point it to `LlamaWrapper` by changing the `architectures` field in `./llama_ckpt/config.json` to be `LlamaWrapper`.
|
| 255 |
+
That way, when we pass in the path of our model checkpoint to SGLang, it will know that we want to use "LlamaWrapper" instead of "LlamaForCausalLM" as our model.
|
| 256 |
+
|
| 257 |
+
```python
|
| 258 |
+
{
|
| 259 |
+
"architectures": [
|
| 260 |
+
# "LlamaForCausalLM"
|
| 261 |
+
"LlamaWrapper"
|
| 262 |
+
],
|
| 263 |
+
...
|
| 264 |
+
}
|
| 265 |
+
```
|
| 266 |
+
|
| 267 |
+
However, if we don't link our `LlamaWrapper` class to the "LlamaWrapper" registry keyword, then SGLang won't be able to find our model.
|
| 268 |
+
Thus, to register our `LlamaWrapper`, we want to follow the steps in the above section titled "Registering an External Model Implementation".
|
| 269 |
+
|
| 270 |
+
```python
|
| 271 |
+
@lru_cache()
|
| 272 |
+
def import_new_model_classes():
|
| 273 |
+
model_arch_name_to_cls = {"LlamaWrapper": LlamaWrapper}
|
| 274 |
+
return model_arch_name_to_cls
|
| 275 |
+
|
| 276 |
+
ModelRegistry.models.update(import_new_model_classes())
|
| 277 |
+
```
|
| 278 |
+
|
| 279 |
+
Lastly, when we create our `Engine`, we just pass in the path to the local model directory.
|
| 280 |
+
Then, our `LlamaWrapper` is ready to be served; for this walkthrough, we will use SGLang `Engine`'s non-streaming asynchronous generation endpoint.
|
| 281 |
+
|
| 282 |
+
```python
|
| 283 |
+
def main():
|
| 284 |
+
llm = sgl.Engine(model_path="./llama_ckpt")
|
| 285 |
+
sampling_params = {"temperature": 0.2, "top_k": 5}
|
| 286 |
+
prompts = [
|
| 287 |
+
"Write a short, neutral self-introduction for a fictional character. Hello, my name is",
|
| 288 |
+
"Provide a concise factual statement about France’s capital city. The capital of France is",
|
| 289 |
+
"Explain possible future trends in artificial intelligence. The future of AI is",
|
| 290 |
+
]
|
| 291 |
+
|
| 292 |
+
asyncio.run(run_llm(llm, sampling_params, prompts))
|
| 293 |
+
|
| 294 |
+
llm.shutdown()
|
| 295 |
+
|
| 296 |
+
async def run_llm(
|
| 297 |
+
llm,
|
| 298 |
+
sampling_params,
|
| 299 |
+
prompts,
|
| 300 |
+
) -> None:
|
| 301 |
+
outputs = await llm.async_generate(prompts, sampling_params)
|
| 302 |
+
|
| 303 |
+
for prompt, output in zip(prompts, outputs):
|
| 304 |
+
print(f"\nPrompt: {prompt}")
|
| 305 |
+
print(f"Generated text: {output['text']}")
|
| 306 |
+
|
| 307 |
+
if __name__ == "__main__":
|
| 308 |
+
main()
|
| 309 |
+
```
|
| 310 |
+
|
| 311 |
+
Now, when we call `python run.py`, we will get the outputs of our newly created model!
|
| 312 |
+
|
| 313 |
+
## Documentation
|
| 314 |
+
|
| 315 |
+
Add to table of supported models in [generative_models.md](../text_generation/generative_models.md) or [multimodal_language_models.md](../text_generation/multimodal_language_models.md)
|
| 316 |
+
|
| 317 |
+
---
|
| 318 |
+
|
| 319 |
+
By following these guidelines, you can add support for new language models and multimodal large language models in
|
| 320 |
+
SGLang and ensure they are thoroughly tested and easily integrated into the system.
|
sglang/docs/supported_models/retrieval_ranking/classify_models.md
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Classification API
|
| 2 |
+
|
| 3 |
+
This document describes the `/v1/classify` API endpoint implementation in SGLang, which is compatible with vLLM's classification API format.
|
| 4 |
+
|
| 5 |
+
## Overview
|
| 6 |
+
|
| 7 |
+
The classification API allows you to classify text inputs using classification models. This implementation follows the same format as vLLM's 0.7.0 classification API.
|
| 8 |
+
|
| 9 |
+
## API Endpoint
|
| 10 |
+
|
| 11 |
+
```
|
| 12 |
+
POST /v1/classify
|
| 13 |
+
```
|
| 14 |
+
|
| 15 |
+
## Request Format
|
| 16 |
+
|
| 17 |
+
```json
|
| 18 |
+
{
|
| 19 |
+
"model": "model_name",
|
| 20 |
+
"input": "text to classify"
|
| 21 |
+
}
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
### Parameters
|
| 25 |
+
|
| 26 |
+
- `model` (string, required): The name of the classification model to use
|
| 27 |
+
- `input` (string, required): The text to classify
|
| 28 |
+
- `user` (string, optional): User identifier for tracking
|
| 29 |
+
- `rid` (string, optional): Request ID for tracking
|
| 30 |
+
- `priority` (integer, optional): Request priority
|
| 31 |
+
|
| 32 |
+
## Response Format
|
| 33 |
+
|
| 34 |
+
```json
|
| 35 |
+
{
|
| 36 |
+
"id": "classify-9bf17f2847b046c7b2d5495f4b4f9682",
|
| 37 |
+
"object": "list",
|
| 38 |
+
"created": 1745383213,
|
| 39 |
+
"model": "jason9693/Qwen2.5-1.5B-apeach",
|
| 40 |
+
"data": [
|
| 41 |
+
{
|
| 42 |
+
"index": 0,
|
| 43 |
+
"label": "Default",
|
| 44 |
+
"probs": [0.565970778465271, 0.4340292513370514],
|
| 45 |
+
"num_classes": 2
|
| 46 |
+
}
|
| 47 |
+
],
|
| 48 |
+
"usage": {
|
| 49 |
+
"prompt_tokens": 10,
|
| 50 |
+
"total_tokens": 10,
|
| 51 |
+
"completion_tokens": 0,
|
| 52 |
+
"prompt_tokens_details": null
|
| 53 |
+
}
|
| 54 |
+
}
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
### Response Fields
|
| 58 |
+
|
| 59 |
+
- `id`: Unique identifier for the classification request
|
| 60 |
+
- `object`: Always "list"
|
| 61 |
+
- `created`: Unix timestamp when the request was created
|
| 62 |
+
- `model`: The model used for classification
|
| 63 |
+
- `data`: Array of classification results
|
| 64 |
+
- `index`: Index of the result
|
| 65 |
+
- `label`: Predicted class label
|
| 66 |
+
- `probs`: Array of probabilities for each class
|
| 67 |
+
- `num_classes`: Total number of classes
|
| 68 |
+
- `usage`: Token usage information
|
| 69 |
+
- `prompt_tokens`: Number of input tokens
|
| 70 |
+
- `total_tokens`: Total number of tokens
|
| 71 |
+
- `completion_tokens`: Number of completion tokens (always 0 for classification)
|
| 72 |
+
- `prompt_tokens_details`: Additional token details (optional)
|
| 73 |
+
|
| 74 |
+
## Example Usage
|
| 75 |
+
|
| 76 |
+
### Using curl
|
| 77 |
+
|
| 78 |
+
```bash
|
| 79 |
+
curl -v "http://127.0.0.1:8000/v1/classify" \
|
| 80 |
+
-H "Content-Type: application/json" \
|
| 81 |
+
-d '{
|
| 82 |
+
"model": "jason9693/Qwen2.5-1.5B-apeach",
|
| 83 |
+
"input": "Loved the new café—coffee was great."
|
| 84 |
+
}'
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
### Using Python
|
| 88 |
+
|
| 89 |
+
```python
|
| 90 |
+
import requests
|
| 91 |
+
import json
|
| 92 |
+
|
| 93 |
+
# Make classification request
|
| 94 |
+
response = requests.post(
|
| 95 |
+
"http://127.0.0.1:8000/v1/classify",
|
| 96 |
+
headers={"Content-Type": "application/json"},
|
| 97 |
+
json={
|
| 98 |
+
"model": "jason9693/Qwen2.5-1.5B-apeach",
|
| 99 |
+
"input": "Loved the new café—coffee was great."
|
| 100 |
+
}
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
# Parse response
|
| 104 |
+
result = response.json()
|
| 105 |
+
print(json.dumps(result, indent=2))
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
## Supported Models
|
| 109 |
+
|
| 110 |
+
The classification API works with any classification model supported by SGLang, including:
|
| 111 |
+
|
| 112 |
+
### Classification Models (Multi-class)
|
| 113 |
+
- `LlamaForSequenceClassification` - Multi-class classification
|
| 114 |
+
- `Qwen2ForSequenceClassification` - Multi-class classification
|
| 115 |
+
- `Qwen3ForSequenceClassification` - Multi-class classification
|
| 116 |
+
- `BertForSequenceClassification` - Multi-class classification
|
| 117 |
+
- `Gemma2ForSequenceClassification` - Multi-class classification
|
| 118 |
+
|
| 119 |
+
**Label Mapping**: The API automatically uses the `id2label` mapping from the model's `config.json` file to provide meaningful label names instead of generic class names. If `id2label` is not available, it falls back to `LABEL_0`, `LABEL_1`, etc., or `Class_0`, `Class_1` as a last resort.
|
| 120 |
+
|
| 121 |
+
### Reward Models (Single score)
|
| 122 |
+
- `InternLM2ForRewardModel` - Single reward score
|
| 123 |
+
- `Qwen2ForRewardModel` - Single reward score
|
| 124 |
+
- `LlamaForSequenceClassificationWithNormal_Weights` - Special reward model
|
| 125 |
+
|
| 126 |
+
**Note**: The `/classify` endpoint in SGLang was originally designed for reward models but now supports all non-generative models. Our `/v1/classify` endpoint provides a standardized vLLM-compatible interface for classification tasks.
|
| 127 |
+
|
| 128 |
+
## Error Handling
|
| 129 |
+
|
| 130 |
+
The API returns appropriate HTTP status codes and error messages:
|
| 131 |
+
|
| 132 |
+
- `400 Bad Request`: Invalid request format or missing required fields
|
| 133 |
+
- `500 Internal Server Error`: Server-side processing error
|
| 134 |
+
|
| 135 |
+
Error response format:
|
| 136 |
+
```json
|
| 137 |
+
{
|
| 138 |
+
"error": "Error message",
|
| 139 |
+
"type": "error_type",
|
| 140 |
+
"code": 400
|
| 141 |
+
}
|
| 142 |
+
```
|
| 143 |
+
|
| 144 |
+
## Implementation Details
|
| 145 |
+
|
| 146 |
+
The classification API is implemented using:
|
| 147 |
+
|
| 148 |
+
1. **Rust Model Gateway**: Handles routing and request/response models in `sgl-model-gateway/src/protocols/spec.rs`
|
| 149 |
+
2. **Python HTTP Server**: Implements the actual endpoint in `python/sglang/srt/entrypoints/http_server.py`
|
| 150 |
+
3. **Classification Service**: Handles the classification logic in `python/sglang/srt/entrypoints/openai/serving_classify.py`
|
| 151 |
+
|
| 152 |
+
## Testing
|
| 153 |
+
|
| 154 |
+
Use the provided test script to verify the implementation:
|
| 155 |
+
|
| 156 |
+
```bash
|
| 157 |
+
python test_classify_api.py
|
| 158 |
+
```
|
| 159 |
+
|
| 160 |
+
## Compatibility
|
| 161 |
+
|
| 162 |
+
This implementation is compatible with vLLM's classification API format, allowing seamless migration from vLLM to SGLang for classification tasks.
|
sglang/docs/supported_models/retrieval_ranking/embedding_models.md
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Embedding Models
|
| 2 |
+
|
| 3 |
+
SGLang provides robust support for embedding models by integrating efficient serving mechanisms with its flexible programming interface. This integration allows for streamlined handling of embedding tasks, facilitating faster and more accurate retrieval and semantic search operations. SGLang's architecture enables better resource utilization and reduced latency in embedding model deployment.
|
| 4 |
+
|
| 5 |
+
```{important}
|
| 6 |
+
Embedding models are executed with `--is-embedding` flag and some may require `--trust-remote-code`
|
| 7 |
+
```
|
| 8 |
+
|
| 9 |
+
## Quick Start
|
| 10 |
+
|
| 11 |
+
### Launch Server
|
| 12 |
+
|
| 13 |
+
```shell
|
| 14 |
+
python3 -m sglang.launch_server \
|
| 15 |
+
--model-path Qwen/Qwen3-Embedding-4B \
|
| 16 |
+
--is-embedding \
|
| 17 |
+
--host 0.0.0.0 \
|
| 18 |
+
--port 30000
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
### Client Request
|
| 22 |
+
|
| 23 |
+
```python
|
| 24 |
+
import requests
|
| 25 |
+
|
| 26 |
+
url = "http://127.0.0.1:30000"
|
| 27 |
+
|
| 28 |
+
payload = {
|
| 29 |
+
"model": "Qwen/Qwen3-Embedding-4B",
|
| 30 |
+
"input": "What is the capital of France?",
|
| 31 |
+
"encoding_format": "float"
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
response = requests.post(url + "/v1/embeddings", json=payload).json()
|
| 35 |
+
print("Embedding:", response["data"][0]["embedding"])
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
## Multimodal Embedding Example
|
| 41 |
+
|
| 42 |
+
For multimodal models like GME that support both text and images:
|
| 43 |
+
|
| 44 |
+
```shell
|
| 45 |
+
python3 -m sglang.launch_server \
|
| 46 |
+
--model-path Alibaba-NLP/gme-Qwen2-VL-2B-Instruct \
|
| 47 |
+
--is-embedding \
|
| 48 |
+
--chat-template gme-qwen2-vl \
|
| 49 |
+
--host 0.0.0.0 \
|
| 50 |
+
--port 30000
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
```python
|
| 54 |
+
import requests
|
| 55 |
+
|
| 56 |
+
url = "http://127.0.0.1:30000"
|
| 57 |
+
|
| 58 |
+
text_input = "Represent this image in embedding space."
|
| 59 |
+
image_path = "https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild/resolve/main/images/023.jpg"
|
| 60 |
+
|
| 61 |
+
payload = {
|
| 62 |
+
"model": "gme-qwen2-vl",
|
| 63 |
+
"input": [
|
| 64 |
+
{
|
| 65 |
+
"text": text_input
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"image": image_path
|
| 69 |
+
}
|
| 70 |
+
],
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
response = requests.post(url + "/v1/embeddings", json=payload).json()
|
| 74 |
+
|
| 75 |
+
print("Embeddings:", [x.get("embedding") for x in response.get("data", [])])
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
## Matryoshka Embedding Example
|
| 79 |
+
|
| 80 |
+
[Matryoshka Embeddings](https://sbert.net/examples/sentence_transformer/training/matryoshka/README.html#matryoshka-embeddings) or [Matryoshka Representation Learning (MRL)](https://arxiv.org/abs/2205.13147) is a technique used in training embedding models. It allows user to trade off between performance and cost.
|
| 81 |
+
|
| 82 |
+
### 1. Launch a Matryoshka‑capable model
|
| 83 |
+
|
| 84 |
+
If the model config already includes `matryoshka_dimensions` or `is_matryoshka` then no override is needed. Otherwise, you can use `--json-model-override-args` as below:
|
| 85 |
+
|
| 86 |
+
```shell
|
| 87 |
+
python3 -m sglang.launch_server \
|
| 88 |
+
--model-path Qwen/Qwen3-Embedding-0.6B \
|
| 89 |
+
--is-embedding \
|
| 90 |
+
--host 0.0.0.0 \
|
| 91 |
+
--port 30000 \
|
| 92 |
+
--json-model-override-args '{"matryoshka_dimensions": [128, 256, 512, 1024, 1536]}'
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
1. Setting `"is_matryoshka": true` allows truncating to any dimension. Otherwise, the server will validate that the specified dimension in the request is one of `matryoshka_dimensions`.
|
| 96 |
+
2. Omitting `dimensions` in a request returns the full vector.
|
| 97 |
+
|
| 98 |
+
### 2. Make requests with different output dimensions
|
| 99 |
+
|
| 100 |
+
```python
|
| 101 |
+
import requests
|
| 102 |
+
|
| 103 |
+
url = "http://127.0.0.1:30000"
|
| 104 |
+
|
| 105 |
+
# Request a truncated (Matryoshka) embedding by specifying a supported dimension.
|
| 106 |
+
payload = {
|
| 107 |
+
"model": "Qwen/Qwen3-Embedding-0.6B",
|
| 108 |
+
"input": "Explain diffusion models simply.",
|
| 109 |
+
"dimensions": 512 # change to 128 / 1024 / omit for full size
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
response = requests.post(url + "/v1/embeddings", json=payload).json()
|
| 113 |
+
print("Embedding:", response["data"][0]["embedding"])
|
| 114 |
+
```
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
## Supported Models
|
| 118 |
+
|
| 119 |
+
| Model Family | Example Model | Chat Template | Description |
|
| 120 |
+
| ------------------------------------------ | -------------------------------------- | ------------- | --------------------------------------------------------------------------- |
|
| 121 |
+
| **E5 (Llama/Mistral based)** | `intfloat/e5-mistral-7b-instruct` | N/A | High-quality text embeddings based on Mistral/Llama architectures |
|
| 122 |
+
| **GTE-Qwen2** | `Alibaba-NLP/gte-Qwen2-7B-instruct` | N/A | Alibaba's text embedding model with multilingual support |
|
| 123 |
+
| **Qwen3-Embedding** | `Qwen/Qwen3-Embedding-4B` | N/A | Latest Qwen3-based text embedding model for semantic representation |
|
| 124 |
+
| **BGE** | `BAAI/bge-large-en-v1.5` | N/A | BAAI's text embeddings (requires `attention-backend` triton/torch_native) |
|
| 125 |
+
| **GME (Multimodal)** | `Alibaba-NLP/gme-Qwen2-VL-2B-Instruct`| `gme-qwen2-vl`| Multimodal embedding for text and image cross-modal tasks |
|
| 126 |
+
| **CLIP** | `openai/clip-vit-large-patch14-336` | N/A | OpenAI's CLIP for image and text embeddings |
|
sglang/docs/supported_models/retrieval_ranking/rerank_models.md
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Rerank Models
|
| 2 |
+
|
| 3 |
+
SGLang offers comprehensive support for rerank models by incorporating optimized serving frameworks with a flexible programming interface. This setup enables efficient processing of cross-encoder reranking tasks, improving the accuracy and relevance of search result ordering. SGLang’s design ensures high throughput and low latency during reranker model deployment, making it ideal for semantic-based result refinement in large-scale retrieval systems.
|
| 4 |
+
|
| 5 |
+
```{important}
|
| 6 |
+
Rerank models in SGLang fall into two categories:
|
| 7 |
+
|
| 8 |
+
- **Cross-encoder rerank models**: run with `--is-embedding` (embedding runner).
|
| 9 |
+
- **Decoder-only rerank models**: run **without** `--is-embedding` and use next-token logprob scoring (yes/no).
|
| 10 |
+
- Text-only (e.g. Qwen3-Reranker)
|
| 11 |
+
- Multimodal (e.g. Qwen3-VL-Reranker): also supports image/video content
|
| 12 |
+
|
| 13 |
+
Some models may require `--trust-remote-code`.
|
| 14 |
+
```
|
| 15 |
+
|
| 16 |
+
## Supported rerank models
|
| 17 |
+
|
| 18 |
+
| Model Family (Rerank) | Example HuggingFace Identifier | Chat Template | Description |
|
| 19 |
+
|------------------------------------------------|--------------------------------------|---------------|----------------------------------------------------------------------------------------------------------------------------------|
|
| 20 |
+
| **BGE-Reranker (BgeRerankModel)** | `BAAI/bge-reranker-v2-m3` | N/A | Currently only support `attention-backend` `triton` and `torch_native`. High-performance cross-encoder reranker model from BAAI. Suitable for reranking search results based on semantic relevance. |
|
| 21 |
+
| **Qwen3-Reranker (decoder-only yes/no)** | `Qwen/Qwen3-Reranker-8B` | `examples/chat_template/qwen3_reranker.jinja` | Decoder-only reranker using next-token logprob scoring for labels (yes/no). Launch **without** `--is-embedding`. |
|
| 22 |
+
| **Qwen3-VL-Reranker (multimodal yes/no)** | `Qwen/Qwen3-VL-Reranker-2B` | `examples/chat_template/qwen3_vl_reranker.jinja` | Multimodal decoder-only reranker supporting text, images, and videos. Uses yes/no logprob scoring. Launch **without** `--is-embedding`. |
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
## Cross-Encoder Rerank (embedding runner)
|
| 26 |
+
|
| 27 |
+
### Launch Command
|
| 28 |
+
|
| 29 |
+
```shell
|
| 30 |
+
python3 -m sglang.launch_server \
|
| 31 |
+
--model-path BAAI/bge-reranker-v2-m3 \
|
| 32 |
+
--host 0.0.0.0 \
|
| 33 |
+
--disable-radix-cache \
|
| 34 |
+
--chunked-prefill-size -1 \
|
| 35 |
+
--attention-backend triton \
|
| 36 |
+
--is-embedding \
|
| 37 |
+
--port 30000
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
### Example Client Request
|
| 41 |
+
|
| 42 |
+
```python
|
| 43 |
+
import requests
|
| 44 |
+
|
| 45 |
+
url = "http://127.0.0.1:30000/v1/rerank"
|
| 46 |
+
|
| 47 |
+
payload = {
|
| 48 |
+
"model": "BAAI/bge-reranker-v2-m3",
|
| 49 |
+
"query": "what is panda?",
|
| 50 |
+
"documents": [
|
| 51 |
+
"hi",
|
| 52 |
+
"The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China."
|
| 53 |
+
],
|
| 54 |
+
"top_n": 1,
|
| 55 |
+
"return_documents": True
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
response = requests.post(url, json=payload)
|
| 59 |
+
response_json = response.json()
|
| 60 |
+
|
| 61 |
+
for item in response_json:
|
| 62 |
+
if item.get("document"):
|
| 63 |
+
print(f"Score: {item['score']:.2f} - Document: '{item['document']}'")
|
| 64 |
+
else:
|
| 65 |
+
print(f"Score: {item['score']:.2f} - Index: {item['index']}")
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
**Request Parameters:**
|
| 69 |
+
|
| 70 |
+
- `query` (required): The query text to rank documents against
|
| 71 |
+
- `documents` (required): List of documents to be ranked
|
| 72 |
+
- `model` (required): Model to use for reranking
|
| 73 |
+
- `top_n` (optional): Maximum number of documents to return. Defaults to returning all documents. If specified value is greater than the total number of documents, all documents will be returned.
|
| 74 |
+
- `return_documents` (optional): Whether to return documents in the response. Defaults to `True`.
|
| 75 |
+
|
| 76 |
+
## Qwen3-Reranker (decoder-only yes/no rerank)
|
| 77 |
+
|
| 78 |
+
### Launch Command
|
| 79 |
+
|
| 80 |
+
```shell
|
| 81 |
+
python3 -m sglang.launch_server \
|
| 82 |
+
--model-path Qwen/Qwen3-Reranker-0.6B \
|
| 83 |
+
--trust-remote-code \
|
| 84 |
+
--disable-radix-cache \
|
| 85 |
+
--host 0.0.0.0 \
|
| 86 |
+
--port 8001 \
|
| 87 |
+
--chat-template examples/chat_template/qwen3_reranker.jinja
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
```{note}
|
| 91 |
+
Qwen3-Reranker uses decoder-only logprob scoring (yes/no). Do NOT launch it with `--is-embedding`.
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
### Example Client Request (supports optional instruct, top_n, and return_documents)
|
| 95 |
+
|
| 96 |
+
```shell
|
| 97 |
+
curl -X POST http://127.0.0.1:8001/v1/rerank \
|
| 98 |
+
-H "Content-Type: application/json" \
|
| 99 |
+
-d '{
|
| 100 |
+
"model": "Qwen3-Reranker-0.6B",
|
| 101 |
+
"query": "法国首都是哪里?",
|
| 102 |
+
"documents": [
|
| 103 |
+
"法国的首都是巴黎。",
|
| 104 |
+
"德国的首都是柏林。",
|
| 105 |
+
"香蕉是黄色的水果。"
|
| 106 |
+
],
|
| 107 |
+
"instruct": "Given a web search query, retrieve relevant passages that answer the query.",
|
| 108 |
+
"top_n": 2,
|
| 109 |
+
"return_documents": true
|
| 110 |
+
}'
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
**Request Parameters:**
|
| 114 |
+
|
| 115 |
+
- `query` (required): The query text to rank documents against
|
| 116 |
+
- `documents` (required): List of documents to be ranked
|
| 117 |
+
- `model` (required): Model to use for reranking
|
| 118 |
+
- `instruct` (optional): Instruction text for the reranker
|
| 119 |
+
- `top_n` (optional): Maximum number of documents to return. Defaults to returning all documents. If specified value is greater than the total number of documents, all documents will be returned.
|
| 120 |
+
- `return_documents` (optional): Whether to return documents in the response. Defaults to `True`.
|
| 121 |
+
|
| 122 |
+
### Response Format
|
| 123 |
+
|
| 124 |
+
`/v1/rerank` returns a list of objects (sorted by descending score):
|
| 125 |
+
|
| 126 |
+
- `score`: float, higher means more relevant
|
| 127 |
+
- `document`: the original document string (only included when `return_documents` is `true`)
|
| 128 |
+
- `index`: the original index in the input `documents`
|
| 129 |
+
- `meta_info`: optional debug/usage info (may be present for some models)
|
| 130 |
+
|
| 131 |
+
The number of returned results is controlled by the `top_n` parameter. If `top_n` is not specified or is greater than the total number of documents, all documents are returned.
|
| 132 |
+
|
| 133 |
+
Example (with `return_documents: true`):
|
| 134 |
+
|
| 135 |
+
```json
|
| 136 |
+
[
|
| 137 |
+
{"score": 0.99, "document": "法国的首都是巴黎。", "index": 0},
|
| 138 |
+
{"score": 0.01, "document": "德国的首都是柏林。", "index": 1},
|
| 139 |
+
{"score": 0.00, "document": "香蕉是黄色的水果。", "index": 2}
|
| 140 |
+
]
|
| 141 |
+
```
|
| 142 |
+
|
| 143 |
+
Example (with `return_documents: false`):
|
| 144 |
+
|
| 145 |
+
```json
|
| 146 |
+
[
|
| 147 |
+
{"score": 0.99, "index": 0},
|
| 148 |
+
{"score": 0.01, "index": 1},
|
| 149 |
+
{"score": 0.00, "index": 2}
|
| 150 |
+
]
|
| 151 |
+
```
|
| 152 |
+
|
| 153 |
+
Example (with `top_n: 2`):
|
| 154 |
+
|
| 155 |
+
```json
|
| 156 |
+
[
|
| 157 |
+
{"score": 0.99, "document": "法国的首都是巴黎。", "index": 0},
|
| 158 |
+
{"score": 0.01, "document": "德国的首都是柏林。", "index": 1}
|
| 159 |
+
]
|
| 160 |
+
```
|
| 161 |
+
|
| 162 |
+
### Common Pitfalls
|
| 163 |
+
|
| 164 |
+
- If you launch Qwen3-Reranker with `--is-embedding`, `/v1/rerank` cannot compute yes/no logprob scores. Relaunch **without** `--is-embedding`.
|
| 165 |
+
- If you see a validation error like "score should be a valid number" and the backend returned a list, upgrade to a version that coerces `embedding[0]` into `score` for rerank responses.
|
| 166 |
+
|
| 167 |
+
## Qwen3-VL-Reranker (multimodal decoder-only rerank)
|
| 168 |
+
|
| 169 |
+
Qwen3-VL-Reranker extends the Qwen3-Reranker to support multimodal content, allowing reranking of documents containing text, images, and videos.
|
| 170 |
+
|
| 171 |
+
### Launch Command
|
| 172 |
+
|
| 173 |
+
```shell
|
| 174 |
+
python3 -m sglang.launch_server \
|
| 175 |
+
--model-path Qwen/Qwen3-VL-Reranker-2B \
|
| 176 |
+
--trust-remote-code \
|
| 177 |
+
--disable-radix-cache \
|
| 178 |
+
--host 0.0.0.0 \
|
| 179 |
+
--port 30000 \
|
| 180 |
+
--chat-template examples/chat_template/qwen3_vl_reranker.jinja
|
| 181 |
+
```
|
| 182 |
+
|
| 183 |
+
```{note}
|
| 184 |
+
Qwen3-VL-Reranker uses decoder-only logprob scoring (yes/no) like Qwen3-Reranker. Do NOT launch it with `--is-embedding`.
|
| 185 |
+
```
|
| 186 |
+
|
| 187 |
+
### Text-Only Reranking (backward compatible)
|
| 188 |
+
|
| 189 |
+
```python
|
| 190 |
+
import requests
|
| 191 |
+
|
| 192 |
+
url = "http://127.0.0.1:30000/v1/rerank"
|
| 193 |
+
|
| 194 |
+
payload = {
|
| 195 |
+
"model": "Qwen3-VL-Reranker-2B",
|
| 196 |
+
"query": "What is machine learning?",
|
| 197 |
+
"documents": [
|
| 198 |
+
"Machine learning is a branch of artificial intelligence that enables computers to learn from data.",
|
| 199 |
+
"The weather in Paris is usually mild with occasional rain.",
|
| 200 |
+
"Deep learning is a subset of machine learning using neural networks with many layers.",
|
| 201 |
+
],
|
| 202 |
+
"instruct": "Retrieve passages that answer the question.",
|
| 203 |
+
"return_documents": True
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
response = requests.post(url, json=payload)
|
| 207 |
+
results = response.json()
|
| 208 |
+
|
| 209 |
+
for item in results:
|
| 210 |
+
print(f"Score: {item['score']:.4f} - {item['document'][:60]}...")
|
| 211 |
+
```
|
| 212 |
+
|
| 213 |
+
### Image Reranking (text query, image/mixed documents)
|
| 214 |
+
|
| 215 |
+
```python
|
| 216 |
+
import requests
|
| 217 |
+
|
| 218 |
+
url = "http://127.0.0.1:30000/v1/rerank"
|
| 219 |
+
|
| 220 |
+
payload = {
|
| 221 |
+
"query": "A woman playing with her dog on a beach at sunset.",
|
| 222 |
+
"documents": [
|
| 223 |
+
# Document 1: Text description
|
| 224 |
+
"A woman shares a joyful moment with her golden retriever on a sun-drenched beach at sunset.",
|
| 225 |
+
# Document 2: Image URL
|
| 226 |
+
[
|
| 227 |
+
{
|
| 228 |
+
"type": "image_url",
|
| 229 |
+
"image_url": {
|
| 230 |
+
"url": "https://example.com/beach_dog.jpeg"
|
| 231 |
+
}
|
| 232 |
+
}
|
| 233 |
+
],
|
| 234 |
+
# Document 3: Text + Image (mixed)
|
| 235 |
+
[
|
| 236 |
+
{"type": "text", "text": "A joyful scene at the beach:"},
|
| 237 |
+
{
|
| 238 |
+
"type": "image_url",
|
| 239 |
+
"image_url": {
|
| 240 |
+
"url": "https://example.com/beach_dog.jpeg"
|
| 241 |
+
}
|
| 242 |
+
}
|
| 243 |
+
]
|
| 244 |
+
],
|
| 245 |
+
"instruct": "Retrieve images or text relevant to the user's query.",
|
| 246 |
+
"return_documents": False
|
| 247 |
+
}
|
| 248 |
+
|
| 249 |
+
response = requests.post(url, json=payload)
|
| 250 |
+
results = response.json()
|
| 251 |
+
|
| 252 |
+
for item in results:
|
| 253 |
+
print(f"Index: {item['index']}, Score: {item['score']:.4f}")
|
| 254 |
+
```
|
| 255 |
+
|
| 256 |
+
### Multimodal Query Reranking (query with image)
|
| 257 |
+
|
| 258 |
+
```python
|
| 259 |
+
import requests
|
| 260 |
+
|
| 261 |
+
url = "http://127.0.0.1:30000/v1/rerank"
|
| 262 |
+
|
| 263 |
+
payload = {
|
| 264 |
+
# Query with text and image
|
| 265 |
+
"query": [
|
| 266 |
+
{"type": "text", "text": "Find similar images to this:"},
|
| 267 |
+
{
|
| 268 |
+
"type": "image_url",
|
| 269 |
+
"image_url": {
|
| 270 |
+
"url": "https://example.com/reference_image.jpeg"
|
| 271 |
+
}
|
| 272 |
+
}
|
| 273 |
+
],
|
| 274 |
+
"documents": [
|
| 275 |
+
"A cat sleeping on a couch.",
|
| 276 |
+
"A woman and her dog enjoying the sunset at the beach.",
|
| 277 |
+
"A busy city street with cars and pedestrians.",
|
| 278 |
+
[
|
| 279 |
+
{
|
| 280 |
+
"type": "image_url",
|
| 281 |
+
"image_url": {
|
| 282 |
+
"url": "https://example.com/similar_image.jpeg"
|
| 283 |
+
}
|
| 284 |
+
}
|
| 285 |
+
]
|
| 286 |
+
],
|
| 287 |
+
"instruct": "Find images or descriptions similar to the query image."
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
response = requests.post(url, json=payload)
|
| 291 |
+
results = response.json()
|
| 292 |
+
|
| 293 |
+
for item in results:
|
| 294 |
+
print(f"Index: {item['index']}, Score: {item['score']:.4f}")
|
| 295 |
+
```
|
| 296 |
+
|
| 297 |
+
### Request Parameters (Multimodal)
|
| 298 |
+
|
| 299 |
+
- `query` (required): Can be a string (text-only) or a list of content parts:
|
| 300 |
+
- `{"type": "text", "text": "..."}` for text
|
| 301 |
+
- `{"type": "image_url", "image_url": {"url": "..."}}` for images
|
| 302 |
+
- `{"type": "video_url", "video_url": {"url": "..."}}` for videos
|
| 303 |
+
- `documents` (required): List where each document can be a string or list of content parts (same format as query)
|
| 304 |
+
- `instruct` (optional): Instruction text for the reranker
|
| 305 |
+
- `top_n` (optional): Maximum number of documents to return
|
| 306 |
+
- `return_documents` (optional): Whether to return documents in the response (default: `false`)
|
| 307 |
+
|
| 308 |
+
### Common Pitfalls
|
| 309 |
+
|
| 310 |
+
- Always use `--chat-template examples/chat_template/qwen3_vl_reranker.jinja` for Qwen3-VL-Reranker.
|
| 311 |
+
- Do NOT launch with `--is-embedding`.
|
| 312 |
+
- For best results, use `--disable-radix-cache` to avoid caching issues with multimodal content.
|
| 313 |
+
- **Note**: Currently only `Qwen3-VL-Reranker-2B` is tested and supported. The 8B model may have different behavior and is not guaranteed to work with this template.
|
sglang/docs/supported_models/specialized/index.rst
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Specialized Models
|
| 2 |
+
==================
|
| 3 |
+
|
| 4 |
+
Models for specialized tasks like reward modeling.
|
| 5 |
+
|
| 6 |
+
.. toctree::
|
| 7 |
+
:maxdepth: 1
|
| 8 |
+
|
| 9 |
+
reward_models.md
|
sglang/docs/supported_models/specialized/reward_models.md
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Reward Models
|
| 2 |
+
|
| 3 |
+
These models output a scalar reward score or classification result, often used in reinforcement learning or content moderation tasks.
|
| 4 |
+
|
| 5 |
+
```{important}
|
| 6 |
+
They are executed with `--is-embedding` and some may require `--trust-remote-code`.
|
| 7 |
+
```
|
| 8 |
+
|
| 9 |
+
## Example launch Command
|
| 10 |
+
|
| 11 |
+
```shell
|
| 12 |
+
python3 -m sglang.launch_server \
|
| 13 |
+
--model-path Qwen/Qwen2.5-Math-RM-72B \ # example HF/local path
|
| 14 |
+
--is-embedding \
|
| 15 |
+
--host 0.0.0.0 \
|
| 16 |
+
--tp-size=4 \ # set for tensor parallelism
|
| 17 |
+
--port 30000 \
|
| 18 |
+
```
|
| 19 |
+
|
| 20 |
+
## Supported models
|
| 21 |
+
|
| 22 |
+
| Model Family (Reward) | Example HuggingFace Identifier | Description |
|
| 23 |
+
|---------------------------------------------------------------------------|-----------------------------------------------------|---------------------------------------------------------------------------------|
|
| 24 |
+
| **Llama (3.1 Reward / `LlamaForSequenceClassification`)** | `Skywork/Skywork-Reward-Llama-3.1-8B-v0.2` | Reward model (preference classifier) based on Llama 3.1 (8B) for scoring and ranking responses for RLHF. |
|
| 25 |
+
| **Gemma 2 (27B Reward / `Gemma2ForSequenceClassification`)** | `Skywork/Skywork-Reward-Gemma-2-27B-v0.2` | Derived from Gemma‑2 (27B), this model provides human preference scoring for RLHF and multilingual tasks. |
|
| 26 |
+
| **InternLM 2 (Reward / `InternLM2ForRewardMode`)** | `internlm/internlm2-7b-reward` | InternLM 2 (7B)–based reward model used in alignment pipelines to guide outputs toward preferred behavior. |
|
| 27 |
+
| **Qwen2.5 (Reward - Math / `Qwen2ForRewardModel`)** | `Qwen/Qwen2.5-Math-RM-72B` | A 72B math-specialized RLHF reward model from the Qwen2.5 series, tuned for evaluating and refining responses. |
|
| 28 |
+
| **Qwen2.5 (Reward - Sequence / `Qwen2ForSequenceClassification`)** | `jason9693/Qwen2.5-1.5B-apeach` | A smaller Qwen2.5 variant used for sequence classification, offering an alternative RLHF scoring mechanism. |
|
sglang/docs/supported_models/text_generation/diffusion_language_models.md
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Diffusion Language Models
|
| 2 |
+
|
| 3 |
+
Diffusion language models have shown promise for non-autoregressive text generation with parallel decoding capabilities. Unlike auto-regressive language models, different diffusion language models require different decoding strategies.
|
| 4 |
+
|
| 5 |
+
## Example Launch Command
|
| 6 |
+
|
| 7 |
+
SGLang supports different DLLM algorithms such as `LowConfidence` and `JointThreshold`.
|
| 8 |
+
|
| 9 |
+
```shell
|
| 10 |
+
python3 -m sglang.launch_server \
|
| 11 |
+
--model-path inclusionAI/LLaDA2.0-mini \ # example HF/local path
|
| 12 |
+
--dllm-algorithm LowConfidence \
|
| 13 |
+
--dllm-algorithm-config ./config.yaml \ # Optional. Uses the algorithm's default if not set.
|
| 14 |
+
--host 0.0.0.0 \
|
| 15 |
+
--port 30000
|
| 16 |
+
```
|
| 17 |
+
|
| 18 |
+
## Example Configuration File
|
| 19 |
+
|
| 20 |
+
Depending on the algorithm selected, the configuration parameters vary.
|
| 21 |
+
|
| 22 |
+
LowConfidence Config:
|
| 23 |
+
|
| 24 |
+
```yaml
|
| 25 |
+
# Confidence threshold for accepting predicted tokens
|
| 26 |
+
# - Higher values: More conservative, better quality but slower
|
| 27 |
+
# - Lower values: More aggressive, faster but potentially lower quality
|
| 28 |
+
# Range: 0.0 - 1.0
|
| 29 |
+
threshold: 0.95
|
| 30 |
+
|
| 31 |
+
# Default: 32, for LLaDA2MoeModelLM
|
| 32 |
+
block_size: 32
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
JointThreshold Config:
|
| 36 |
+
|
| 37 |
+
```yaml
|
| 38 |
+
# Decoding threshold for Mask-to-Token (M2T) phase
|
| 39 |
+
# - Higher values: More conservative, better quality but slower
|
| 40 |
+
# - Lower values: More aggressive, faster but potentially lower quality
|
| 41 |
+
# Range: 0.0 - 1.0
|
| 42 |
+
threshold: 0.5
|
| 43 |
+
# Decoding threshold for Token-to-Token (T2T) phase
|
| 44 |
+
# Range: 0.0 - 1.0
|
| 45 |
+
# Setting to 0.0 allows full editing (recommended for most cases).
|
| 46 |
+
edit_threshold: 0.0
|
| 47 |
+
# Max extra T2T steps after all masks are removed. Prevents infinite loops.
|
| 48 |
+
max_post_edit_steps: 16
|
| 49 |
+
# 2-gram repetition penalty (default 0).
|
| 50 |
+
# An empirical value of 3 is often sufficient to mitigate most repetitions.
|
| 51 |
+
penalty_lambda: 0
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
## Example Client Code Snippet
|
| 55 |
+
|
| 56 |
+
Just like other supported models, diffusion language models can be used via the REST API or Python client.
|
| 57 |
+
|
| 58 |
+
Python client example for making a generation request to the launched server:
|
| 59 |
+
|
| 60 |
+
```python
|
| 61 |
+
import sglang as sgl
|
| 62 |
+
|
| 63 |
+
def main():
|
| 64 |
+
llm = sgl.Engine(model_path="inclusionAI/LLaDA2.0-mini",
|
| 65 |
+
dllm_algorithm="LowConfidence",
|
| 66 |
+
max_running_requests=1,
|
| 67 |
+
trust_remote_code=True)
|
| 68 |
+
|
| 69 |
+
prompts = [
|
| 70 |
+
"<role>SYSTEM</role>detailed thinking off<|role_end|><role>HUMAN</role> Write a brief introduction of the great wall <|role_end|><role>ASSISTANT</role>"
|
| 71 |
+
]
|
| 72 |
+
|
| 73 |
+
sampling_params = {
|
| 74 |
+
"temperature": 0,
|
| 75 |
+
"max_new_tokens": 1024,
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
outputs = llm.generate(prompts, sampling_params)
|
| 79 |
+
print(outputs)
|
| 80 |
+
|
| 81 |
+
if __name__ == '__main__':
|
| 82 |
+
main()
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
Curl example for making a generation request to the launched server:
|
| 86 |
+
|
| 87 |
+
```bash
|
| 88 |
+
curl -X POST "http://127.0.0.1:30000/generate" \
|
| 89 |
+
-H "Content-Type: application/json" \
|
| 90 |
+
-d '{
|
| 91 |
+
"text": [
|
| 92 |
+
"<role>SYSTEM</role>detailed thinking off<|role_end|><role>HUMAN</role> Write the number from 1 to 128 <|role_end|><role>ASSISTANT</role>",
|
| 93 |
+
"<role>SYSTEM</role>detailed thinking off<|role_end|><role>HUMAN</role> Write a brief introduction of the great wall <|role_end|><role>ASSISTANT</role>"
|
| 94 |
+
],
|
| 95 |
+
"stream": true,
|
| 96 |
+
"sampling_params": {
|
| 97 |
+
"temperature": 0,
|
| 98 |
+
"max_new_tokens": 1024
|
| 99 |
+
}
|
| 100 |
+
}'
|
| 101 |
+
```
|
| 102 |
+
|
| 103 |
+
## Supported Models
|
| 104 |
+
|
| 105 |
+
Below the supported models are summarized in a table.
|
| 106 |
+
|
| 107 |
+
| Model Family | Example Model | Description |
|
| 108 |
+
| -------------------------- | ---------------------------- | ---------------------------------------------------------------------------------------------------- |
|
| 109 |
+
| **LLaDA2.0 (mini, flash)** | `inclusionAI/LLaDA2.0-flash` | LLaDA2.0-flash is a diffusion language model featuring a 100B Mixture-of-Experts (MoE) architecture. |
|
| 110 |
+
| **SDAR (JetLM)** | `JetLM/SDAR-8B-Chat` | SDAR series diffusion language model (Chat), dense architecture. |
|
| 111 |
+
| **SDAR (JetLM)** | `JetLM/SDAR-30B-A3B-Chat` | SDAR series diffusion language model (Chat), MoE architecture. |
|
sglang/docs/supported_models/text_generation/generative_models.md
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Large Language Models
|
| 2 |
+
|
| 3 |
+
These models accept text input and produce text output (e.g., chat completions). They are primarily large language models (LLMs), some with mixture-of-experts (MoE) architectures for scaling.
|
| 4 |
+
|
| 5 |
+
## Example launch Command
|
| 6 |
+
|
| 7 |
+
```shell
|
| 8 |
+
python3 -m sglang.launch_server \
|
| 9 |
+
--model-path meta-llama/Llama-3.2-1B-Instruct \ # example HF/local path
|
| 10 |
+
--host 0.0.0.0 \
|
| 11 |
+
--port 30000 \
|
| 12 |
+
```
|
| 13 |
+
|
| 14 |
+
## Supported models
|
| 15 |
+
|
| 16 |
+
Below the supported models are summarized in a table.
|
| 17 |
+
|
| 18 |
+
If you are unsure if a specific architecture is implemented, you can search for it via GitHub. For example, to search for `Qwen3ForCausalLM`, use the expression:
|
| 19 |
+
|
| 20 |
+
```
|
| 21 |
+
repo:sgl-project/sglang path:/^python\/sglang\/srt\/models\// Qwen3ForCausalLM
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
in the GitHub search bar.
|
| 25 |
+
|
| 26 |
+
| Model Family (Variants) | Example HuggingFace Identifier | Description |
|
| 27 |
+
|-------------------------------------|--------------------------------------------------|----------------------------------------------------------------------------------------|
|
| 28 |
+
| **DeepSeek** (v1, v2, v3/R1) | `deepseek-ai/DeepSeek-R1` | Series of advanced reasoning-optimized models (including a 671B MoE) trained with reinforcement learning; top performance on complex reasoning, math, and code tasks. [SGLang provides Deepseek v3/R1 model-specific optimizations](../basic_usage/deepseek.md) and [Reasoning Parser](../advanced_features/separate_reasoning.ipynb)|
|
| 29 |
+
| **Kimi K2** (Thinking, Instruct) | `moonshotai/Kimi-K2-Instruct` | Moonshot AI's 1 trillion parameter MoE model (32B active) with 128K–256K context; state-of-the-art agentic intelligence with stable long-horizon agency across 200–300 sequential tool calls. Features MLA attention and native INT4 quantization. [See Reasoning Parser docs](../advanced_features/separate_reasoning.ipynb)|
|
| 30 |
+
| **Kimi Linear** (48B-A3B) | `moonshotai/Kimi-Linear-48B-A3B-Instruct` | Moonshot AI's hybrid linear attention model (48B total, 3B active) with 1M token context; features Kimi Delta Attention (KDA) for up to 6× faster decoding and 75% KV cache reduction vs full attention. |
|
| 31 |
+
| **GPT-OSS** | `openai/gpt-oss-20b`, `openai/gpt-oss-120b` | OpenAI’s latest GPT-OSS series for complex reasoning, agentic tasks, and versatile developer use cases.|
|
| 32 |
+
| **Qwen** (3.5, 3, 3MoE, 3Next, 2.5, 2 series) | `Qwen/Qwen3.5-397B-A17B`, `Qwen/Qwen3-0.6B`, `Qwen/Qwen3-30B-A3B` | Alibaba’s latest Qwen3 series for complex reasoning, language understanding, and generation tasks; Support for MoE variants along with previous generation 2.5, 2, etc. [SGLang provides Qwen3 specific reasoning parser](../advanced_features/separate_reasoning.ipynb)|
|
| 33 |
+
| **Llama** (2, 3.x, 4 series) | `meta-llama/Llama-4-Scout-17B-16E-Instruct` | Meta's open LLM series, spanning 7B to 400B parameters (Llama 2, 3, and new Llama 4) with well-recognized performance. [SGLang provides Llama-4 model-specific optimizations](../basic_usage/llama4.md) |
|
| 34 |
+
| **Mistral** (Mixtral, NeMo, Small3) | `mistralai/Mistral-7B-Instruct-v0.2` | Open 7B LLM by Mistral AI with strong performance; extended into MoE (“Mixtral”) and NeMo Megatron variants for larger scale. |
|
| 35 |
+
| **Gemma** (v1, v2, v3) | `google/gemma-3-1b-it` | Google’s family of efficient multilingual models (1B–27B); Gemma 3 offers a 128K context window, and its larger (4B+) variants support vision input. |
|
| 36 |
+
| **Phi** (Phi-1.5, Phi-2, Phi-3, Phi-4, Phi-MoE series) | `microsoft/Phi-4-multimodal-instruct`, `microsoft/Phi-3.5-MoE-instruct` | Microsoft’s Phi family of small models (1.3B–5.6B); Phi-4-multimodal (5.6B) processes text, images, and speech, Phi-4-mini is a high-accuracy text model and Phi-3.5-MoE is a mixture-of-experts model. |
|
| 37 |
+
| **MiniCPM** (v3, 4B) | `openbmb/MiniCPM3-4B` | OpenBMB’s series of compact LLMs for edge devices; MiniCPM 3 (4B) achieves GPT-3.5-level results in text tasks. |
|
| 38 |
+
| **OLMo** (2, 3) | `allenai/OLMo-3-1125-32B`, `allenai/OLMo-3-32B-Think`, `allenai/OLMo-2-1124-7B-Instruct` | Allen AI’s series of Open Language Models designed to enable the science of language models. |
|
| 39 |
+
| **OLMoE** (Open MoE) | `allenai/OLMoE-1B-7B-0924` | Allen AI’s open Mixture-of-Experts model (7B total, 1B active parameters) delivering state-of-the-art results with sparse expert activation. |
|
| 40 |
+
| **MiniMax-M2** (M2, M2.1, M2.5) | `MiniMaxAI/MiniMax-M2.5`, `MiniMaxAI/MiniMax-M2.1`, `MiniMaxAI/MiniMax-M2` | MiniMax's SOTA LLM for coding & agentic workflows. |
|
| 41 |
+
| **StableLM** (3B, 7B) | `stabilityai/stablelm-tuned-alpha-7b` | StabilityAI’s early open-source LLM (3B & 7B) for general text generation; a demonstration model with basic instruction-following ability. |
|
| 42 |
+
| **Command-(R,A)** (Cohere) | `CohereLabs/c4ai-command-r-v01`, `CohereLabs/c4ai-command-r7b-12-2024`, `CohereLabs/c4ai-command-a-03-2025` | Cohere’s open conversational LLM (Command series) optimized for long context, retrieval-augmented generation, and tool use. |
|
| 43 |
+
| **DBRX** (Databricks) | `databricks/dbrx-instruct` | Databricks’ 132B-parameter MoE model (36B active) trained on 12T tokens; competes with GPT-3.5 quality as a fully open foundation model. |
|
| 44 |
+
| **Grok** (xAI) | `xai-org/grok-1` | xAI’s grok-1 model known for vast size(314B parameters) and high quality; integrated in SGLang for high-performance inference. |
|
| 45 |
+
| **ChatGLM** (GLM-130B family) | `THUDM/chatglm2-6b` | Zhipu AI’s bilingual chat model (6B) excelling at Chinese-English dialogue; fine-tuned for conversational quality and alignment. |
|
| 46 |
+
| **InternLM 2** (7B, 20B) | `internlm/internlm2-7b` | Next-gen InternLM (7B and 20B) from SenseTime, offering strong reasoning and ultra-long context support (up to 200K tokens). |
|
| 47 |
+
| **ExaONE 3** (Korean-English) | `LGAI-EXAONE/EXAONE-3.5-7.8B-Instruct` | LG AI Research’s Korean-English model (7.8B) trained on 8T tokens; provides high-quality bilingual understanding and generation. |
|
| 48 |
+
| **Baichuan 2** (7B, 13B) | `baichuan-inc/Baichuan2-13B-Chat` | BaichuanAI’s second-generation Chinese-English LLM (7B/13B) with improved performance and an open commercial license. |
|
| 49 |
+
| **XVERSE** (MoE) | `xverse/XVERSE-MoE-A36B` | Yuanxiang’s open MoE LLM (XVERSE-MoE-A36B: 255B total, 36B active) supporting ~40 languages; delivers 100B+ dense-level performance via expert routing. |
|
| 50 |
+
| **SmolLM** (135M–1.7B) | `HuggingFaceTB/SmolLM-1.7B` | Hugging Face’s ultra-small LLM series (135M–1.7B params) offering surprisingly strong results, enabling advanced AI on mobile/edge devices. |
|
| 51 |
+
| **GLM-4** (Multilingual 9B) | `ZhipuAI/glm-4-9b-chat` | Zhipu’s GLM-4 series (up to 9B parameters) – open multilingual models with support for 1M-token context and even a 5.6B multimodal variant (Phi-4V). |
|
| 52 |
+
| **MiMo** (7B series) | `XiaomiMiMo/MiMo-7B-RL` | Xiaomi's reasoning-optimized model series, leverages Multiple-Token Prediction for faster inference. |
|
| 53 |
+
| **ERNIE-4.5** (4.5, 4.5MoE series) | `baidu/ERNIE-4.5-21B-A3B-PT` | Baidu's ERNIE-4.5 series which consists of MoE with 47B and 3B active parameters, with the largest model having 424B total parameters, as well as a 0.3B dense model. |
|
| 54 |
+
| **Arcee AFM-4.5B** | `arcee-ai/AFM-4.5B-Base` | Arcee's foundational model series for real world reliability and edge deployments. |
|
| 55 |
+
| **Persimmon** (8B) | `adept/persimmon-8b-chat` | Adept’s open 8B model with a 16K context window and fast inference; trained for broad usability and licensed under Apache 2.0. |
|
| 56 |
+
| **Solar** (10.7B) | `upstage/SOLAR-10.7B-Instruct-v1.0` | Upstage's 10.7B parameter model, optimized for instruction-following tasks. This architecture incorporates a depth-up scaling methodology, enhancing model performance. |
|
| 57 |
+
| **Tele FLM** (52B-1T) | `CofeAI/Tele-FLM` | BAAI & TeleAI's multilingual model, available in 52-billion and 1-trillion parameter variants. It is a decoder-only transformer trained on ~2T tokens |
|
| 58 |
+
| **Ling** (16.8B–290B) | `inclusionAI/Ling-lite`, `inclusionAI/Ling-plus` | InclusionAI’s open MoE models. Ling-Lite has 16.8B total / 2.75B active parameters, and Ling-Plus has 290B total / 28.8B active parameters. They are designed for high performance on NLP and complex reasoning tasks. |
|
| 59 |
+
| **Granite 3.0, 3.1** (IBM) | `ibm-granite/granite-3.1-8b-instruct` | IBM's open dense foundation models optimized for reasoning, code, and business AI use cases. Integrated with Red Hat and watsonx systems. |
|
| 60 |
+
| **Granite 3.0 MoE** (IBM) | `ibm-granite/granite-3.0-3b-a800m-instruct` | IBM’s Mixture-of-Experts models offering strong performance with cost-efficiency. MoE expert routing designed for enterprise deployment at scale. |
|
| 61 |
+
| **GPT-J** (6B) | `EleutherAI/gpt-j-6b` | EleutherAI's GPT-2-like causal language model (6B) trained on the [Pile](https://pile.eleuther.ai/) dataset. |
|
| 62 |
+
| **Orion** (14B) | `OrionStarAI/Orion-14B-Base` | A series of open-source multilingual large language models by OrionStarAI, pretrained on a 2.5T token multilingual corpus including Chinese, English, Japanese, Korean, etc, and it exhibits superior performance in these languages. |
|
| 63 |
+
| **Llama Nemotron Super** (v1, v1.5, NVIDIA) | `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, `nvidia/Llama-3_3-Nemotron-Super-49B-v1_5` | The [NVIDIA Nemotron](https://www.nvidia.com/en-us/ai-data-science/foundation-models/nemotron/) family of multimodal models provides state-of-the-art reasoning models specifically designed for enterprise-ready AI agents. |
|
| 64 |
+
| **Llama Nemotron Ultra** (v1, NVIDIA) | `nvidia/Llama-3_1-Nemotron-Ultra-253B-v1` | The [NVIDIA Nemotron](https://www.nvidia.com/en-us/ai-data-science/foundation-models/nemotron/) family of multimodal models provides state-of-the-art reasoning models specifically designed for enterprise-ready AI agents. |
|
| 65 |
+
| **NVIDIA Nemotron Nano 2.0** | `nvidia/NVIDIA-Nemotron-Nano-9B-v2` | The [NVIDIA Nemotron](https://www.nvidia.com/en-us/ai-data-science/foundation-models/nemotron/) family of multimodal models provides state-of-the-art reasoning models specifically designed for enterprise-ready AI agents. `Nemotron-Nano-9B-v2` is a hybrid Mamba-Transformer language model designed to increase throughput for reasoning workloads while achieving state-of-the-art accuracy compared to similarly-sized models. |
|
| 66 |
+
| **StarCoder2** (3B-15B) | `bigcode/starcoder2-7b` | StarCoder2 is a family of open large language models (LLMs) specialized for code generation and understanding. It is the successor to StarCoder, jointly developed by the BigCode project (a collaboration between Hugging Face, ServiceNow Research, and other contributors). |
|
| 67 |
+
| **Jet-Nemotron** | `jet-ai/Jet-Nemotron-2B` | Jet-Nemotron is a new family of hybrid-architecture language models that surpass state-of-the-art open-source full-attention language models, while achieving significant efficiency gains. |
|
| 68 |
+
| **Trinity** (Nano, Mini) | `arcee-ai/Trinity-Mini` | Arcee's foundational MoE Trinity family of models, open weights under Apache 2.0. |
|
| 69 |
+
| **Falcon-H1** (0.5B–34B) | `tiiuae/Falcon-H1-34B-Instruct` | TII's hybrid Mamba-Transformer architecture combining attention and state-space models for efficient long-context inference. |
|
| 70 |
+
| **Hunyuan-Large** (389B, MoE) | `tencent/Tencent-Hunyuan-Large` | Tencent's open-source MoE model with 389B total / 52B active parameters, featuring Cross-Layer Attention (CLA) for improved efficiency. |
|
| 71 |
+
| **IBM Granite 4.0 (Hybrid, Dense)** | `ibm-granite/granite-4.0-h-micro`, `ibm-granite/granite-4.0-micro` | IBM Granite 4.0 micro models: hybrid Mamba–MoE (`h-micro`) and dense (`micro`) variants. Enterprise-focused reasoning models |
|
| 72 |
+
| **Sarvam 2** (30B-A2B, 105B-A10B) | `sarvamai/sarvam-2` | Sarvam's Mixture-of-Experts models. The 105B variant uses MLA (Multi-head Latent Attention) and the 30B variant uses GQA, both with 128 routed experts. |
|
sglang/docs/supported_models/text_generation/index.rst
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Text Generation
|
| 2 |
+
===============
|
| 3 |
+
|
| 4 |
+
Models for generating text from text or multimodal inputs.
|
| 5 |
+
|
| 6 |
+
.. toctree::
|
| 7 |
+
:maxdepth: 1
|
| 8 |
+
|
| 9 |
+
generative_models.md
|
| 10 |
+
multimodal_language_models.md
|
| 11 |
+
diffusion_language_models.md
|
sglang/docs/supported_models/text_generation/multimodal_language_models.md
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Multimodal Language Models
|
| 2 |
+
|
| 3 |
+
These models accept multi-modal inputs (e.g., images and text) and generate text output. They augment language models with multimodal encoders.
|
| 4 |
+
|
| 5 |
+
## Example launch Command
|
| 6 |
+
|
| 7 |
+
```shell
|
| 8 |
+
python3 -m sglang.launch_server \
|
| 9 |
+
--model-path meta-llama/Llama-3.2-11B-Vision-Instruct \ # example HF/local path
|
| 10 |
+
--host 0.0.0.0 \
|
| 11 |
+
--port 30000 \
|
| 12 |
+
```
|
| 13 |
+
|
| 14 |
+
> See the [OpenAI APIs section](https://docs.sglang.io/basic_usage/openai_api_vision.html) for how to send multimodal requests.
|
| 15 |
+
|
| 16 |
+
## Supported models
|
| 17 |
+
|
| 18 |
+
Below the supported models are summarized in a table.
|
| 19 |
+
|
| 20 |
+
If you are unsure if a specific architecture is implemented, you can search for it via GitHub. For example, to search for `Qwen2_5_VLForConditionalGeneration`, use the expression:
|
| 21 |
+
|
| 22 |
+
```
|
| 23 |
+
repo:sgl-project/sglang path:/^python\/sglang\/srt\/models\// Qwen2_5_VLForConditionalGeneration
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
in the GitHub search bar.
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
| Model Family (Variants) | Example HuggingFace Identifier | Description | Notes |
|
| 30 |
+
|----------------------------|--------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------|
|
| 31 |
+
| **Qwen-VL** | `Qwen/Qwen3-VL-235B-A22B-Instruct` | Alibaba's vision-language extension of Qwen; for example, Qwen2.5-VL (7B and larger variants) can analyze and converse about image content. | |
|
| 32 |
+
| **DeepSeek-VL2** | `deepseek-ai/deepseek-vl2` | Vision-language variant of DeepSeek (with a dedicated image processor), enabling advanced multimodal reasoning on image and text inputs. | |
|
| 33 |
+
| **DeepSeek-OCR / OCR-2** | `deepseek-ai/DeepSeek-OCR-2` | OCR-focused DeepSeek models for document understanding and text extraction. | Use `--trust-remote-code`. |
|
| 34 |
+
| **Janus-Pro** (1B, 7B) | `deepseek-ai/Janus-Pro-7B` | DeepSeek's open-source multimodal model capable of both image understanding and generation. Janus-Pro employs a decoupled architecture for separate visual encoding paths, enhancing performance in both tasks. | |
|
| 35 |
+
| **MiniCPM-V / MiniCPM-o** | `openbmb/MiniCPM-V-2_6` | MiniCPM-V (2.6, ~8B) supports image inputs, and MiniCPM-o adds audio/video; these multimodal LLMs are optimized for end-side deployment on mobile/edge devices. | |
|
| 36 |
+
| **Llama 3.2 Vision** (11B) | `meta-llama/Llama-3.2-11B-Vision-Instruct` | Vision-enabled variant of Llama 3 (11B) that accepts image inputs for visual question answering and other multimodal tasks. | |
|
| 37 |
+
| **LLaVA** (v1.5 & v1.6) | *e.g.* `liuhaotian/llava-v1.5-13b` | Open vision-chat models that add an image encoder to LLaMA/Vicuna (e.g. LLaMA2 13B) for following multimodal instruction prompts. | |
|
| 38 |
+
| **LLaVA-NeXT** (8B, 72B) | `lmms-lab/llava-next-72b` | Improved LLaVA models (with an 8B Llama3 version and a 72B version) offering enhanced visual instruction-following and accuracy on multimodal benchmarks. | |
|
| 39 |
+
| **LLaVA-OneVision** | `lmms-lab/llava-onevision-qwen2-7b-ov` | Enhanced LLaVA variant integrating Qwen as the backbone; supports multiple images (and even video frames) as inputs via an OpenAI Vision API-compatible format. | |
|
| 40 |
+
| **Gemma 3 (Multimodal)** | `google/gemma-3-4b-it` | Gemma 3's larger models (4B, 12B, 27B) accept images (each image encoded as 256 tokens) alongside text in a combined 128K-token context. | |
|
| 41 |
+
| **Kimi-VL** (A3B) | `moonshotai/Kimi-VL-A3B-Instruct` | Kimi-VL is a multimodal model that can understand and generate text from images. | |
|
| 42 |
+
| **Mistral-Small-3.1-24B** | `mistralai/Mistral-Small-3.1-24B-Instruct-2503` | Mistral 3.1 is a multimodal model that can generate text from text or images input. It also supports tool calling and structured output. | |
|
| 43 |
+
| **Phi-4-multimodal-instruct** | `microsoft/Phi-4-multimodal-instruct` | Phi-4-multimodal-instruct is the multimodal variant of the Phi-4-mini model, enhanced with LoRA for improved multimodal capabilities. It supports text, vision and audio modalities in SGLang. | |
|
| 44 |
+
| **MiMo-VL** (7B) | `XiaomiMiMo/MiMo-VL-7B-RL` | Xiaomi's compact yet powerful vision-language model featuring a native resolution ViT encoder for fine-grained visual details, an MLP projector for cross-modal alignment, and the MiMo-7B language model optimized for complex reasoning tasks. | |
|
| 45 |
+
| **GLM-4.5V** (106B) / **GLM-4.1V**(9B) | `zai-org/GLM-4.5V` | GLM-4.5V and GLM-4.1V-Thinking: Towards Versatile Multimodal Reasoning with Scalable Reinforcement Learning | Use `--chat-template glm-4v` |
|
| 46 |
+
| **GLM-OCR** | `zai-org/GLM-OCR` | GLM-OCR: A fast and accurate general OCR model | |
|
| 47 |
+
| **DotsVLM** (General/OCR) | `rednote-hilab/dots.vlm1.inst` | RedNote's vision-language model built on a 1.2B vision encoder and DeepSeek V3 LLM, featuring NaViT vision encoder trained from scratch with dynamic resolution support and enhanced OCR capabilities through structured image data training. | |
|
| 48 |
+
| **DotsVLM-OCR** | `rednote-hilab/dots.ocr` | Specialized OCR variant of DotsVLM optimized for optical character recognition tasks with enhanced text extraction and document understanding capabilities. | Don't use `--trust-remote-code` |
|
| 49 |
+
| **NVILA** (8B, 15B, Lite-2B, Lite-8B, Lite-15B) | `Efficient-Large-Model/NVILA-8B` | `chatml` | NVILA explores the full stack efficiency of multi-modal design, achieving cheaper training, faster deployment and better performance. |
|
| 50 |
+
| **NVIDIA Nemotron Nano 2.0 VL** | `nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16` | NVIDIA Nemotron Nano v2 VL enables multi-image reasoning and video understanding, along with strong document intelligence, visual Q&A and summarization capabilities. It builds on Nemotron Nano V2, a hybrid Mamba-Transformer LLM, in order to achieve higher inference throughput in long document and video scenarios. | Use `--trust-remote-code`. You may need to adjust `--max-mamba-cache-size` [default is 512] to fit memory constraints. |
|
| 51 |
+
| **Ernie4.5-VL** | `baidu/ERNIE-4.5-VL-28B-A3B-PT` | Baidu's vision-language models(28B,424B). Support image and video comprehension, and also support thinking. | |
|
| 52 |
+
| **JetVLM** | | JetVLM is an vision-language model designed for high-performance multimodal understanding and generation tasks built upon Jet-Nemotron. | Coming soon |
|
| 53 |
+
| **Step3-VL** (10B) | `stepfun-ai/Step3-VL-10B` | StepFun's lightweight open-source 10B parameter VLM for multimodal intelligence, excelling in visual perception, complex reasoning, and human alignment. | |
|
| 54 |
+
| **Qwen3-Omni** | `Qwen/Qwen3-Omni-30B-A3B-Instruct` | Alibaba's omni-modal MoE model. Currently supports the **Thinker** component (multimodal understanding for text, images, audio, and video), while the **Talker** component (audio generation) is not yet supported. | |
|
| 55 |
+
|
| 56 |
+
## Video Input Support
|
| 57 |
+
|
| 58 |
+
SGLang supports video input for Vision-Language Models (VLMs), enabling temporal reasoning tasks such as video question answering, captioning, and holistic scene understanding. Video clips are decoded, key frames are sampled, and the resulting tensors are batched together with the text prompt, allowing multimodal inference to integrate visual and linguistic context.
|
| 59 |
+
|
| 60 |
+
| Model Family | Example Identifier | Video notes |
|
| 61 |
+
|--------------|--------------------|-------------|
|
| 62 |
+
| **Qwen-VL** (Qwen2-VL, Qwen2.5-VL, Qwen3-VL, Qwen3-Omni) | `Qwen/Qwen3-VL-235B-A22B-Instruct` | The processor gathers `video_data`, runs Qwen's frame sampler, and merges the resulting features with text tokens before inference. |
|
| 63 |
+
| **GLM-4v** (4.5V, 4.1V, MOE) | `zai-org/GLM-4.5V` | Video clips are read with Decord, converted to tensors, and passed to the model alongside metadata for rotary-position handling. |
|
| 64 |
+
| **NVILA** (Full & Lite) | `Efficient-Large-Model/NVILA-8B` | The runtime samples eight frames per clip and attaches them to the multimodal request when `video_data` is present. |
|
| 65 |
+
| **LLaVA video variants** (LLaVA-NeXT-Video, LLaVA-OneVision) | `lmms-lab/LLaVA-NeXT-Video-7B` | The processor routes video prompts to the LlavaVid video-enabled architecture, and the provided example shows how to query it with `sgl.video(...)` clips. |
|
| 66 |
+
| **NVIDIA Nemotron Nano 2.0 VL** | `nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16` | The processor samples at 2 FPS, at a max of 128 frames, as per model training. The model uses [EVS](../../python/sglang/srt/multimodal/evs/README.md), a pruning method that removes redundant tokens from video embeddings. By default `video_pruning_rate=0.7`. Change this by providing: `--json-model-override-args '{"video_pruning_rate": 0.0}'` to disable EVS, for example. |
|
| 67 |
+
| **JetVLM** | | The runtime samples eight frames per clip and attaches them to the multimodal request when `video_data` is present. |
|
| 68 |
+
|
| 69 |
+
Use `sgl.video(path, num_frames)` when building prompts to attach clips from your SGLang programs.
|
| 70 |
+
|
| 71 |
+
Example OpenAI-compatible request that sends a video clip:
|
| 72 |
+
|
| 73 |
+
```python
|
| 74 |
+
import requests
|
| 75 |
+
|
| 76 |
+
url = "http://localhost:30000/v1/chat/completions"
|
| 77 |
+
|
| 78 |
+
data = {
|
| 79 |
+
"model": "Qwen/Qwen3-VL-30B-A3B-Instruct",
|
| 80 |
+
"messages": [
|
| 81 |
+
{
|
| 82 |
+
"role": "user",
|
| 83 |
+
"content": [
|
| 84 |
+
{"type": "text", "text": "What’s happening in this video?"},
|
| 85 |
+
{
|
| 86 |
+
"type": "video_url",
|
| 87 |
+
"video_url": {
|
| 88 |
+
"url": "https://github.com/sgl-project/sgl-test-files/raw/refs/heads/main/videos/jobs_presenting_ipod.mp4"
|
| 89 |
+
},
|
| 90 |
+
},
|
| 91 |
+
],
|
| 92 |
+
}
|
| 93 |
+
],
|
| 94 |
+
"max_tokens": 300,
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
response = requests.post(url, json=data)
|
| 98 |
+
print(response.text)
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
## Usage Notes
|
| 102 |
+
|
| 103 |
+
### Performance Optimization
|
| 104 |
+
|
| 105 |
+
For multimodal models, you can use the `--keep-mm-feature-on-device` flag to optimize for latency at the cost of increased GPU memory usage:
|
| 106 |
+
|
| 107 |
+
- **Default behavior**: Multimodal feature tensors are moved to CPU after processing to save GPU memory
|
| 108 |
+
- **With `--keep-mm-feature-on-device`**: Feature tensors remain on GPU, reducing device-to-host copy overhead and improving latency, but consuming more GPU memory
|
| 109 |
+
|
| 110 |
+
Use this flag when you have sufficient GPU memory and want to minimize latency for multimodal inference.
|
| 111 |
+
|
| 112 |
+
### Multimodal Inputs Limitation
|
| 113 |
+
|
| 114 |
+
- **Use `--mm-process-config '{"image":{"max_pixels":1048576},"video":{"fps":3,"max_pixels":602112,"max_frames":60}}'`**: To set `image`, `video`, and `audio` input limits.
|
| 115 |
+
|
| 116 |
+
This can reduce GPU memory usage, improve inference speed, and help to avoid OOM, but may impact model performance, thus set a proper value based on your specific use case. Currently, only `qwen_vl` supports this config. Please refer to [qwen_vl processor](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/multimodal/processors/qwen_vl.py) for understanding the meaning of each parameter.
|
| 117 |
+
|
| 118 |
+
### Bidirectional Attention in Multimodal Model Serving
|
| 119 |
+
**Note for serving the Gemma-3 multimodal model**:
|
| 120 |
+
|
| 121 |
+
As mentioned in [Welcome Gemma 3: Google's all new multimodal, multilingual, long context open LLM
|
| 122 |
+
](https://huggingface.co/blog/gemma3#multimodality), Gemma-3 employs bidirectional attention between image tokens during the prefill phase. Currently, SGLang only supports bidirectional attention when using the Triton Attention Backend. Note, however, that SGLang's current bidirectional attention implementation is incompatible with both CUDA Graph and Chunked Prefill.
|
| 123 |
+
|
| 124 |
+
To enable bidirectional attention, you can use the `TritonAttnBackend` while disabling CUDA Graph and Chunked Prefill. Example launch command:
|
| 125 |
+
```shell
|
| 126 |
+
python -m sglang.launch_server \
|
| 127 |
+
--model-path google/gemma-3-4b-it \
|
| 128 |
+
--host 0.0.0.0 --port 30000 \
|
| 129 |
+
--enable-multimodal \
|
| 130 |
+
--dtype bfloat16 --triton-attention-reduce-in-fp32 \
|
| 131 |
+
--attention-backend triton \ # Use Triton attention backend
|
| 132 |
+
--disable-cuda-graph \ # Disable Cuda Graph
|
| 133 |
+
--chunked-prefill-size -1 # Disable Chunked Prefill
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
If higher serving performance is required and a certain degree of accuracy loss is acceptable, you may choose to use other attention backends, and you can also enable features like CUDA Graph and Chunked Prefill for better performance, but note that the model will fall back to using causal attention instead of bidirectional attention.
|
sglang/python/sglang/srt/__pycache__/constants.cpython-311.pyc
ADDED
|
Binary file (348 Bytes). View file
|
|
|
sglang/python/sglang/srt/__pycache__/environ.cpython-311.pyc
ADDED
|
Binary file (35.1 kB). View file
|
|
|
sglang/python/sglang/srt/batch_overlap/__pycache__/operations.cpython-311.pyc
ADDED
|
Binary file (12.4 kB). View file
|
|
|
sglang/python/sglang/srt/batch_overlap/__pycache__/operations_strategy.cpython-311.pyc
ADDED
|
Binary file (12.4 kB). View file
|
|
|
sglang/python/sglang/srt/batch_overlap/__pycache__/single_batch_overlap.cpython-311.pyc
ADDED
|
Binary file (5.8 kB). View file
|
|
|
sglang/python/sglang/srt/batch_overlap/__pycache__/two_batch_overlap.cpython-311.pyc
ADDED
|
Binary file (42.2 kB). View file
|
|
|
sglang/python/sglang/srt/batch_overlap/operations.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from contextlib import contextmanager
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Sequence, Union
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from sglang.srt.layers.dp_attention import set_dp_buffer_len
|
| 11 |
+
|
| 12 |
+
if TYPE_CHECKING:
|
| 13 |
+
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
| 14 |
+
|
| 15 |
+
_ENABLE_PROFILE = bool(int(os.environ.get("SGLANG_OPERATIONS_ENABLE_PROFILE", "0")))
|
| 16 |
+
|
| 17 |
+
if _ENABLE_PROFILE:
|
| 18 |
+
import nvtx
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def execute_operations(inputs, operations):
|
| 22 |
+
stages = _convert_operations_to_stages(operations)
|
| 23 |
+
executor = _StageExecutor("primary", stages, inputs=inputs)
|
| 24 |
+
for _ in range(executor.num_stages):
|
| 25 |
+
executor.next()
|
| 26 |
+
assert executor.done
|
| 27 |
+
return executor.output
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def execute_overlapped_operations(
|
| 31 |
+
inputs_arr: Sequence,
|
| 32 |
+
operations_arr: Sequence,
|
| 33 |
+
delta_stages: Sequence[int],
|
| 34 |
+
) -> Sequence:
|
| 35 |
+
# Make it explicit for clarity; if we need multi-batch overlap, this can be generalized
|
| 36 |
+
inputs_a, inputs_b = inputs_arr
|
| 37 |
+
operations_a, operations_b = operations_arr
|
| 38 |
+
delta_stage_a, delta_stage_b = delta_stages
|
| 39 |
+
assert delta_stage_a == 0
|
| 40 |
+
delta_stage = delta_stage_b
|
| 41 |
+
|
| 42 |
+
stages_a = _convert_operations_to_stages(operations_a)
|
| 43 |
+
stages_b = _convert_operations_to_stages(operations_b)
|
| 44 |
+
executor_a = _StageExecutor("a", stages_a, inputs=inputs_a)
|
| 45 |
+
executor_b = _StageExecutor("b", stages_b, inputs=inputs_b)
|
| 46 |
+
|
| 47 |
+
for _ in range(delta_stage):
|
| 48 |
+
executor_a.next()
|
| 49 |
+
|
| 50 |
+
for _ in range(executor_a.num_stages - delta_stage):
|
| 51 |
+
executor_a.next()
|
| 52 |
+
executor_b.next()
|
| 53 |
+
|
| 54 |
+
for _ in range(delta_stage):
|
| 55 |
+
executor_b.next()
|
| 56 |
+
|
| 57 |
+
assert executor_a.done and executor_b.done
|
| 58 |
+
return [executor_a.output, executor_b.output]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class YieldOperation:
|
| 62 |
+
pass
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@dataclass
|
| 66 |
+
class ExecutionOperation:
|
| 67 |
+
debug_name: str
|
| 68 |
+
fn: Callable
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
Operation = Union[YieldOperation, ExecutionOperation, Callable]
|
| 72 |
+
Stage = List[ExecutionOperation]
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class _StageExecutor:
|
| 76 |
+
def __init__(self, debug_name: str, stages: List[Stage], inputs: dict):
|
| 77 |
+
self._debug_name = debug_name
|
| 78 |
+
self._stages = stages
|
| 79 |
+
self._index = 0
|
| 80 |
+
self._stage_state = _StateDict()
|
| 81 |
+
self._stage_output = inputs
|
| 82 |
+
|
| 83 |
+
# handling DP attention
|
| 84 |
+
forward_batch: ForwardBatch = inputs["forward_batch"]
|
| 85 |
+
self._global_dp_buffer_len = forward_batch.global_dp_buffer_len
|
| 86 |
+
self._local_dp_buffer_len = forward_batch.tbo_padded_len
|
| 87 |
+
self._global_num_tokens = forward_batch.global_num_tokens_cpu
|
| 88 |
+
self._is_dp_max_padding = forward_batch.dp_padding_mode.is_max_len()
|
| 89 |
+
|
| 90 |
+
def next(self):
|
| 91 |
+
assert not self.done
|
| 92 |
+
|
| 93 |
+
stage = self._stages[self._index]
|
| 94 |
+
|
| 95 |
+
# TODO: We currently always call set_dp_buffer_len here because sub-batches
|
| 96 |
+
# may have different padded lengths. It can likely be removed after TBO slice &
|
| 97 |
+
# pad logic is refactored.
|
| 98 |
+
set_dp_buffer_len(
|
| 99 |
+
self._global_dp_buffer_len,
|
| 100 |
+
self._local_dp_buffer_len,
|
| 101 |
+
self._is_dp_max_padding,
|
| 102 |
+
self._global_num_tokens,
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
with _annotate_region(debug_name=f"{self._debug_name}{self._index}"):
|
| 106 |
+
for op in stage:
|
| 107 |
+
with _annotate_region(debug_name=op.debug_name):
|
| 108 |
+
self._stage_output = op.fn(
|
| 109 |
+
state=self._stage_state,
|
| 110 |
+
**(
|
| 111 |
+
self._stage_output if self._stage_output is not None else {}
|
| 112 |
+
),
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
self._index += 1
|
| 116 |
+
|
| 117 |
+
@property
|
| 118 |
+
def output(self):
|
| 119 |
+
assert self.done
|
| 120 |
+
return self._stage_output
|
| 121 |
+
|
| 122 |
+
@property
|
| 123 |
+
def done(self):
|
| 124 |
+
return self._index >= self.num_stages
|
| 125 |
+
|
| 126 |
+
@property
|
| 127 |
+
def num_stages(self):
|
| 128 |
+
return len(self._stages)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
@contextmanager
|
| 132 |
+
def _annotate_region(debug_name):
|
| 133 |
+
if _ENABLE_PROFILE:
|
| 134 |
+
with torch.autograd.profiler.record_function(debug_name):
|
| 135 |
+
with nvtx.annotate(debug_name):
|
| 136 |
+
yield
|
| 137 |
+
else:
|
| 138 |
+
yield
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class _StateDict:
|
| 142 |
+
def __init__(self):
|
| 143 |
+
self._data = {}
|
| 144 |
+
|
| 145 |
+
def __setattr__(self, key, value):
|
| 146 |
+
if key == "_data":
|
| 147 |
+
super().__setattr__(key, value)
|
| 148 |
+
return
|
| 149 |
+
assert (
|
| 150 |
+
key not in self._data
|
| 151 |
+
), f"`{key}` already exist, are you sure you want to override it?"
|
| 152 |
+
self._data[key] = value
|
| 153 |
+
|
| 154 |
+
def __getattr__(self, item):
|
| 155 |
+
return self._data[item]
|
| 156 |
+
|
| 157 |
+
def __delattr__(self, item):
|
| 158 |
+
del self._data[item]
|
| 159 |
+
|
| 160 |
+
def pop(self, item):
|
| 161 |
+
return self._data.pop(item)
|
| 162 |
+
|
| 163 |
+
def update(self, values: Dict[str, Any]):
|
| 164 |
+
for k, v in values.items():
|
| 165 |
+
setattr(self, k, v)
|
| 166 |
+
|
| 167 |
+
def get(self, item):
|
| 168 |
+
return self._data.get(item)
|
| 169 |
+
|
| 170 |
+
def clear(self, expect_keys: Sequence[str]):
|
| 171 |
+
if set(self._data.keys()) != set(expect_keys):
|
| 172 |
+
raise Exception(
|
| 173 |
+
f"Unexpected keys when clearing. This may indicate you do not release memory early enough but leave it until here. {list(self._data.keys())=} {expect_keys=}"
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
self._data.clear()
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def _convert_operations_to_stages(operations: List[Operation]) -> List[Stage]:
|
| 180 |
+
operations = _decorate_operations(operations)
|
| 181 |
+
operation_chunks = list(
|
| 182 |
+
_chunk_by_separator(operations, lambda op: isinstance(op, YieldOperation))
|
| 183 |
+
)
|
| 184 |
+
assert all(len(chunk) > 0 for chunk in operation_chunks)
|
| 185 |
+
return operation_chunks
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def _chunk_by_separator(
|
| 189 |
+
items: List[Any], is_separator: Callable[[Any], bool]
|
| 190 |
+
) -> Generator[List[Any], None, None]:
|
| 191 |
+
pending_items = []
|
| 192 |
+
for item in items:
|
| 193 |
+
if is_separator(item):
|
| 194 |
+
yield pending_items
|
| 195 |
+
pending_items = []
|
| 196 |
+
else:
|
| 197 |
+
pending_items.append(item)
|
| 198 |
+
if len(pending_items) > 0:
|
| 199 |
+
yield pending_items
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def _decorate_operations(operations: List[Operation], debug_name_prefix: str = ""):
|
| 203 |
+
return [_decorate_operation(op, debug_name_prefix) for op in operations]
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def _decorate_operation(operation: Operation, debug_name_prefix: str):
|
| 207 |
+
if isinstance(operation, YieldOperation):
|
| 208 |
+
return operation
|
| 209 |
+
return ExecutionOperation(
|
| 210 |
+
debug_name=debug_name_prefix
|
| 211 |
+
+ getattr(operation, "__name__", "unknown").replace("op_", ""),
|
| 212 |
+
fn=operation,
|
| 213 |
+
)
|
sglang/python/sglang/srt/batch_overlap/operations_strategy.py
ADDED
|
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import List, Optional
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from sglang.srt.batch_overlap import operations
|
| 7 |
+
from sglang.srt.batch_overlap.operations import Operation
|
| 8 |
+
from sglang.srt.layers.moe.token_dispatcher import DeepEPConfig
|
| 9 |
+
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
| 10 |
+
from sglang.srt.utils import is_hip
|
| 11 |
+
|
| 12 |
+
_is_hip = is_hip()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
|
| 16 |
+
class OperationsStrategy:
|
| 17 |
+
operations: List[Operation]
|
| 18 |
+
deep_gemm_num_sms: Optional[int] = None
|
| 19 |
+
tbo_delta_stages: Optional[int] = None
|
| 20 |
+
|
| 21 |
+
@classmethod
|
| 22 |
+
def concat(cls, items: List["OperationsStrategy"]) -> "OperationsStrategy":
|
| 23 |
+
return OperationsStrategy(
|
| 24 |
+
operations=[x for item in items for x in item.operations],
|
| 25 |
+
deep_gemm_num_sms=_assert_all_same(
|
| 26 |
+
[item.deep_gemm_num_sms for item in items]
|
| 27 |
+
),
|
| 28 |
+
tbo_delta_stages=_assert_all_same(
|
| 29 |
+
[item.tbo_delta_stages for item in items]
|
| 30 |
+
),
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
@staticmethod
|
| 34 |
+
def init_new_tbo(
|
| 35 |
+
layers: torch.nn.ModuleList,
|
| 36 |
+
forward_mode: ForwardMode,
|
| 37 |
+
) -> "OperationsStrategy":
|
| 38 |
+
layer_name = layers[0].__class__.__name__
|
| 39 |
+
if layer_name == "DeepseekV2DecoderLayer":
|
| 40 |
+
return OperationsStrategy.concat(
|
| 41 |
+
[
|
| 42 |
+
_compute_moe_deepseek_layer_operations_strategy_tbo(
|
| 43 |
+
layer, forward_mode
|
| 44 |
+
)
|
| 45 |
+
for layer in layers
|
| 46 |
+
]
|
| 47 |
+
)
|
| 48 |
+
elif layer_name == "Qwen3MoeDecoderLayer":
|
| 49 |
+
return OperationsStrategy.concat(
|
| 50 |
+
[
|
| 51 |
+
_compute_moe_qwen3_layer_operations_strategy_tbo(
|
| 52 |
+
layer, forward_mode
|
| 53 |
+
)
|
| 54 |
+
for layer in layers
|
| 55 |
+
]
|
| 56 |
+
)
|
| 57 |
+
elif layer_name == "MiMoV2DecoderLayer":
|
| 58 |
+
return OperationsStrategy.concat(
|
| 59 |
+
[
|
| 60 |
+
_compute_moe_mimov2_layer_operations_strategy_tbo(
|
| 61 |
+
layer, forward_mode
|
| 62 |
+
)
|
| 63 |
+
for layer in layers
|
| 64 |
+
]
|
| 65 |
+
)
|
| 66 |
+
else:
|
| 67 |
+
raise NotImplementedError
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _assert_all_same(items: List):
|
| 71 |
+
assert all(item == items[0] for item in items)
|
| 72 |
+
return items[0]
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# -------------------------------- Strategy for DeepSeek ---------------------------------------
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# TODO can refactor to make it more fancy if we have more complex strategies
|
| 79 |
+
def _compute_moe_deepseek_layer_operations_strategy_tbo(
|
| 80 |
+
layer: torch.nn.Module,
|
| 81 |
+
forward_mode: ForwardMode,
|
| 82 |
+
) -> OperationsStrategy:
|
| 83 |
+
assert layer.is_layer_sparse, "dense layer TBO not yet implemented"
|
| 84 |
+
if forward_mode == ForwardMode.EXTEND:
|
| 85 |
+
return _compute_moe_deepseek_blog_prefill(layer)
|
| 86 |
+
elif (
|
| 87 |
+
forward_mode == ForwardMode.DECODE or forward_mode == ForwardMode.TARGET_VERIFY
|
| 88 |
+
):
|
| 89 |
+
return _compute_moe_deepseek_blog_decode(layer)
|
| 90 |
+
else:
|
| 91 |
+
raise NotImplementedError(f"Unsupported {forward_mode=}")
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def _compute_moe_deepseek_blog_prefill(layer):
|
| 95 |
+
device_properties = torch.cuda.get_device_properties(device="cuda")
|
| 96 |
+
total_num_sms = device_properties.multi_processor_count
|
| 97 |
+
deep_gemm_num_sms = None
|
| 98 |
+
if not _is_hip:
|
| 99 |
+
deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms
|
| 100 |
+
|
| 101 |
+
return OperationsStrategy(
|
| 102 |
+
deep_gemm_num_sms=deep_gemm_num_sms,
|
| 103 |
+
tbo_delta_stages=0,
|
| 104 |
+
operations=[
|
| 105 |
+
layer.op_comm_prepare_attn,
|
| 106 |
+
layer.self_attn.op_prepare,
|
| 107 |
+
layer.self_attn.op_core,
|
| 108 |
+
layer.op_comm_prepare_mlp,
|
| 109 |
+
layer.mlp.op_gate,
|
| 110 |
+
layer.mlp.op_select_experts,
|
| 111 |
+
layer.mlp.op_dispatch_a,
|
| 112 |
+
operations.YieldOperation(),
|
| 113 |
+
layer.mlp.op_dispatch_b,
|
| 114 |
+
layer.mlp.op_experts,
|
| 115 |
+
layer.mlp.op_combine_a,
|
| 116 |
+
operations.YieldOperation(),
|
| 117 |
+
layer.mlp.op_shared_experts,
|
| 118 |
+
layer.mlp.op_combine_b,
|
| 119 |
+
layer.mlp.op_output,
|
| 120 |
+
layer.op_comm_postprocess_layer,
|
| 121 |
+
],
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def _compute_moe_deepseek_blog_decode(layer):
|
| 126 |
+
return OperationsStrategy(
|
| 127 |
+
deep_gemm_num_sms=None,
|
| 128 |
+
tbo_delta_stages=2,
|
| 129 |
+
operations=[
|
| 130 |
+
layer.op_comm_prepare_attn,
|
| 131 |
+
layer.self_attn.op_prepare,
|
| 132 |
+
operations.YieldOperation(),
|
| 133 |
+
layer.self_attn.op_core,
|
| 134 |
+
layer.op_comm_prepare_mlp,
|
| 135 |
+
layer.mlp.op_gate,
|
| 136 |
+
layer.mlp.op_select_experts,
|
| 137 |
+
operations.YieldOperation(),
|
| 138 |
+
layer.mlp.op_dispatch_a,
|
| 139 |
+
layer.mlp.op_shared_experts,
|
| 140 |
+
operations.YieldOperation(),
|
| 141 |
+
layer.mlp.op_dispatch_b,
|
| 142 |
+
layer.mlp.op_experts,
|
| 143 |
+
layer.mlp.op_combine_a,
|
| 144 |
+
operations.YieldOperation(),
|
| 145 |
+
layer.mlp.op_combine_b,
|
| 146 |
+
operations.YieldOperation(),
|
| 147 |
+
layer.mlp.op_output,
|
| 148 |
+
layer.op_comm_postprocess_layer,
|
| 149 |
+
],
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# -------------------------------- Strategy for Qwen3 ---------------------------------------
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
# TODO: unstable, current strategy is almost the same as DeepSeek, keep redundant code here for
|
| 157 |
+
# convenience to adjust strategy
|
| 158 |
+
def _compute_moe_qwen3_layer_operations_strategy_tbo(
|
| 159 |
+
layer: torch.nn.Module,
|
| 160 |
+
forward_mode: ForwardMode,
|
| 161 |
+
) -> OperationsStrategy:
|
| 162 |
+
assert layer.is_layer_sparse, "qwen3 moe only support sparse layers"
|
| 163 |
+
if forward_mode == ForwardMode.EXTEND:
|
| 164 |
+
return _compute_moe_qwen3_prefill(layer)
|
| 165 |
+
elif (
|
| 166 |
+
forward_mode == ForwardMode.DECODE or forward_mode == ForwardMode.TARGET_VERIFY
|
| 167 |
+
):
|
| 168 |
+
return _compute_moe_qwen3_decode(layer)
|
| 169 |
+
else:
|
| 170 |
+
raise NotImplementedError(f"Unsupported {forward_mode=}")
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def _compute_moe_qwen3_prefill(layer):
|
| 174 |
+
device_properties = torch.cuda.get_device_properties(device="cuda")
|
| 175 |
+
total_num_sms = device_properties.multi_processor_count
|
| 176 |
+
deep_gemm_num_sms = None
|
| 177 |
+
if not _is_hip:
|
| 178 |
+
deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms
|
| 179 |
+
|
| 180 |
+
return OperationsStrategy(
|
| 181 |
+
deep_gemm_num_sms=deep_gemm_num_sms,
|
| 182 |
+
tbo_delta_stages=0,
|
| 183 |
+
operations=[
|
| 184 |
+
layer.op_comm_prepare_attn,
|
| 185 |
+
layer.self_attn.op_prepare,
|
| 186 |
+
layer.self_attn.op_core,
|
| 187 |
+
layer.op_comm_prepare_mlp,
|
| 188 |
+
layer.mlp.op_gate,
|
| 189 |
+
layer.mlp.op_select_experts,
|
| 190 |
+
layer.mlp.op_dispatch_a,
|
| 191 |
+
operations.YieldOperation(),
|
| 192 |
+
layer.mlp.op_dispatch_b,
|
| 193 |
+
layer.mlp.op_experts,
|
| 194 |
+
layer.mlp.op_combine_a,
|
| 195 |
+
operations.YieldOperation(),
|
| 196 |
+
layer.mlp.op_combine_b,
|
| 197 |
+
layer.mlp.op_output,
|
| 198 |
+
layer.op_comm_postprocess_layer,
|
| 199 |
+
],
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def _compute_moe_qwen3_decode(layer):
|
| 204 |
+
return OperationsStrategy(
|
| 205 |
+
deep_gemm_num_sms=None,
|
| 206 |
+
tbo_delta_stages=2,
|
| 207 |
+
operations=[
|
| 208 |
+
layer.op_comm_prepare_attn,
|
| 209 |
+
layer.self_attn.op_prepare,
|
| 210 |
+
operations.YieldOperation(),
|
| 211 |
+
layer.self_attn.op_core,
|
| 212 |
+
layer.op_comm_prepare_mlp,
|
| 213 |
+
layer.mlp.op_gate,
|
| 214 |
+
layer.mlp.op_select_experts,
|
| 215 |
+
operations.YieldOperation(),
|
| 216 |
+
layer.mlp.op_dispatch_a,
|
| 217 |
+
operations.YieldOperation(),
|
| 218 |
+
layer.mlp.op_dispatch_b,
|
| 219 |
+
layer.mlp.op_experts,
|
| 220 |
+
layer.mlp.op_combine_a,
|
| 221 |
+
operations.YieldOperation(),
|
| 222 |
+
layer.mlp.op_combine_b,
|
| 223 |
+
layer.mlp.op_output,
|
| 224 |
+
layer.op_comm_postprocess_layer,
|
| 225 |
+
operations.YieldOperation(),
|
| 226 |
+
],
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
# -------------------------------- Strategy for MiMoV2DecoderLayer ---------------------------------------
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
# TODO: unstable; current strategy matches DeepSeek for the common operations (MiMoV2 has no op_shared_experts),
|
| 234 |
+
# so we keep this redundant code here for convenience when adjusting the strategy
|
| 235 |
+
def _compute_moe_mimov2_layer_operations_strategy_tbo(
|
| 236 |
+
layer: torch.nn.Module,
|
| 237 |
+
forward_mode: ForwardMode,
|
| 238 |
+
) -> OperationsStrategy:
|
| 239 |
+
assert layer.is_layer_sparse, "MiMoV2DecoderLayer moe only support sparse layers"
|
| 240 |
+
if forward_mode == ForwardMode.EXTEND:
|
| 241 |
+
return _compute_moe_mimov2_prefill(layer)
|
| 242 |
+
elif (
|
| 243 |
+
forward_mode == ForwardMode.DECODE or forward_mode == ForwardMode.TARGET_VERIFY
|
| 244 |
+
):
|
| 245 |
+
return _compute_moe_mimov2_decode(layer)
|
| 246 |
+
else:
|
| 247 |
+
raise NotImplementedError(f"Unsupported {forward_mode=}")
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def _compute_moe_mimov2_prefill(layer):
|
| 251 |
+
device_properties = torch.cuda.get_device_properties(device="cuda")
|
| 252 |
+
total_num_sms = device_properties.multi_processor_count
|
| 253 |
+
deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms
|
| 254 |
+
|
| 255 |
+
return OperationsStrategy(
|
| 256 |
+
deep_gemm_num_sms=deep_gemm_num_sms,
|
| 257 |
+
tbo_delta_stages=0,
|
| 258 |
+
operations=[
|
| 259 |
+
layer.op_comm_prepare_attn,
|
| 260 |
+
layer.self_attn.op_prepare,
|
| 261 |
+
layer.self_attn.op_core,
|
| 262 |
+
layer.op_comm_prepare_mlp,
|
| 263 |
+
layer.mlp.op_gate,
|
| 264 |
+
layer.mlp.op_select_experts,
|
| 265 |
+
layer.mlp.op_dispatch_a,
|
| 266 |
+
operations.YieldOperation(),
|
| 267 |
+
layer.mlp.op_dispatch_b,
|
| 268 |
+
layer.mlp.op_experts,
|
| 269 |
+
layer.mlp.op_combine_a,
|
| 270 |
+
operations.YieldOperation(),
|
| 271 |
+
layer.mlp.op_combine_b,
|
| 272 |
+
layer.mlp.op_output,
|
| 273 |
+
layer.op_comm_postprocess_layer,
|
| 274 |
+
],
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def _compute_moe_mimov2_decode(layer):
|
| 279 |
+
return OperationsStrategy(
|
| 280 |
+
deep_gemm_num_sms=None,
|
| 281 |
+
tbo_delta_stages=2,
|
| 282 |
+
operations=[
|
| 283 |
+
layer.op_comm_prepare_attn,
|
| 284 |
+
layer.self_attn.op_prepare,
|
| 285 |
+
operations.YieldOperation(),
|
| 286 |
+
layer.self_attn.op_core,
|
| 287 |
+
layer.op_comm_prepare_mlp,
|
| 288 |
+
layer.mlp.op_gate,
|
| 289 |
+
layer.mlp.op_select_experts,
|
| 290 |
+
operations.YieldOperation(),
|
| 291 |
+
layer.mlp.op_dispatch_a,
|
| 292 |
+
operations.YieldOperation(),
|
| 293 |
+
layer.mlp.op_dispatch_b,
|
| 294 |
+
layer.mlp.op_experts,
|
| 295 |
+
layer.mlp.op_combine_a,
|
| 296 |
+
operations.YieldOperation(),
|
| 297 |
+
layer.mlp.op_combine_b,
|
| 298 |
+
layer.mlp.op_output,
|
| 299 |
+
layer.op_comm_postprocess_layer,
|
| 300 |
+
operations.YieldOperation(),
|
| 301 |
+
],
|
| 302 |
+
)
|
sglang/python/sglang/srt/batch_overlap/single_batch_overlap.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 SGLang Team
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 7 |
+
#
|
| 8 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 11 |
+
# See the License for the specific language governing permissions and
|
| 12 |
+
# limitations under the License.
|
| 13 |
+
# ==============================================================================
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
from typing import Optional
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
from sglang.srt.environ import envs
|
| 23 |
+
from sglang.srt.layers.moe import get_moe_runner_backend
|
| 24 |
+
from sglang.srt.layers.moe.utils import is_sbo_enabled
|
| 25 |
+
from sglang.srt.utils import is_blackwell
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class SboFlags:
|
| 29 |
+
# TODO may have: "enable_dispatch_gateup_gemm_two_stream_overlap", ...
|
| 30 |
+
|
| 31 |
+
@classmethod
|
| 32 |
+
def enable_combine_down_gemm_two_stream_overlap(cls):
|
| 33 |
+
return (
|
| 34 |
+
is_sbo_enabled()
|
| 35 |
+
# currently only cutedsl backend supports it
|
| 36 |
+
and (
|
| 37 |
+
get_moe_runner_backend().is_flashinfer_cutedsl()
|
| 38 |
+
or (get_moe_runner_backend().is_deep_gemm() and not is_blackwell())
|
| 39 |
+
)
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
@classmethod
|
| 43 |
+
def enable_combine_shared_two_stream_overlap(cls):
|
| 44 |
+
return (
|
| 45 |
+
is_sbo_enabled()
|
| 46 |
+
and not cls.enable_dispatch_shared_one_stream_overlap()
|
| 47 |
+
and not envs.SGLANG_BLACKWELL_OVERLAP_SHARED_EXPERTS_OUTSIDE_SBO.get()
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
@classmethod
|
| 51 |
+
def enable_dispatch_shared_one_stream_overlap(cls):
|
| 52 |
+
return is_sbo_enabled() and not is_blackwell()
|
| 53 |
+
|
| 54 |
+
@classmethod
|
| 55 |
+
def fuse_shared_experts_inside_sbo(cls):
|
| 56 |
+
return (
|
| 57 |
+
cls.enable_combine_shared_two_stream_overlap()
|
| 58 |
+
or cls.enable_dispatch_shared_one_stream_overlap()
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@dataclass
|
| 63 |
+
class CombineOverlapArgs:
|
| 64 |
+
# this "overlap" flag means overlapping with down gemm, not the general two-stream overlap
|
| 65 |
+
overlap: bool
|
| 66 |
+
stream: torch.cuda.Stream
|
| 67 |
+
wait_event: torch.cuda.Event
|
| 68 |
+
num_sms: Optional[int] = None
|
| 69 |
+
signal: Optional[torch.Tensor] = None
|
| 70 |
+
block_m: Optional[int] = 64
|
| 71 |
+
threshold: Optional[int] = 0
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
@dataclass
|
| 75 |
+
class DownGemmOverlapArgs:
|
| 76 |
+
num_sms: int
|
| 77 |
+
signal: torch.Tensor
|
| 78 |
+
start_event: torch.cuda.Event
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def compute_overlap_args(dispatch_output, alt_stream):
|
| 82 |
+
if not (
|
| 83 |
+
SboFlags.enable_combine_down_gemm_two_stream_overlap()
|
| 84 |
+
or SboFlags.enable_combine_shared_two_stream_overlap()
|
| 85 |
+
):
|
| 86 |
+
return None, None, {}
|
| 87 |
+
|
| 88 |
+
hidden_states = dispatch_output.hidden_states
|
| 89 |
+
|
| 90 |
+
num_local_experts, num_tokens_static, hidden_dim = hidden_states.shape
|
| 91 |
+
|
| 92 |
+
total_num_sms = torch.cuda.get_device_properties(
|
| 93 |
+
device="cuda"
|
| 94 |
+
).multi_processor_count
|
| 95 |
+
|
| 96 |
+
if envs.SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS.is_set():
|
| 97 |
+
communicate_num_sms = envs.SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS.get()
|
| 98 |
+
else:
|
| 99 |
+
communicate_num_sms = 32 if is_blackwell() else 3
|
| 100 |
+
compute_num_sms = total_num_sms - communicate_num_sms
|
| 101 |
+
|
| 102 |
+
assert alt_stream is not None
|
| 103 |
+
combine_wait_event = torch.cuda.Event()
|
| 104 |
+
combine_overlap_args = CombineOverlapArgs(
|
| 105 |
+
overlap=False,
|
| 106 |
+
num_sms=communicate_num_sms,
|
| 107 |
+
stream=alt_stream,
|
| 108 |
+
wait_event=combine_wait_event,
|
| 109 |
+
)
|
| 110 |
+
meta_overlap_args = dict(
|
| 111 |
+
compute_num_sms=compute_num_sms,
|
| 112 |
+
)
|
| 113 |
+
down_gemm_overlap_args = None
|
| 114 |
+
|
| 115 |
+
if SboFlags.enable_combine_down_gemm_two_stream_overlap():
|
| 116 |
+
# TODO use zero_allocator to remove this `torch.zeros` call
|
| 117 |
+
# NOTE ours v2 use uint32 not int32 currently
|
| 118 |
+
if is_blackwell():
|
| 119 |
+
combine_signal = torch.zeros(
|
| 120 |
+
num_local_experts, dtype=torch.uint32, device=hidden_states.device
|
| 121 |
+
)
|
| 122 |
+
else:
|
| 123 |
+
MIN_BLOCK_M = 64
|
| 124 |
+
combine_signal_size = num_local_experts * (
|
| 125 |
+
(num_tokens_static + MIN_BLOCK_M - 1) // MIN_BLOCK_M
|
| 126 |
+
)
|
| 127 |
+
combine_signal = torch.zeros(
|
| 128 |
+
combine_signal_size, dtype=torch.int32, device=hidden_states.device
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
down_gemm_overlap_args = DownGemmOverlapArgs(
|
| 132 |
+
signal=combine_signal,
|
| 133 |
+
start_event=combine_wait_event,
|
| 134 |
+
num_sms=compute_num_sms,
|
| 135 |
+
)
|
| 136 |
+
combine_overlap_args.overlap = True
|
| 137 |
+
combine_overlap_args.signal = combine_signal
|
| 138 |
+
combine_overlap_args.threshold = compute_num_sms
|
| 139 |
+
else:
|
| 140 |
+
meta_overlap_args |= dict(
|
| 141 |
+
record_event_after_down=combine_wait_event,
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
return combine_overlap_args, down_gemm_overlap_args, meta_overlap_args
|
sglang/python/sglang/srt/batch_overlap/two_batch_overlap.py
ADDED
|
@@ -0,0 +1,1082 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import copy
|
| 4 |
+
import dataclasses
|
| 5 |
+
import logging
|
| 6 |
+
from dataclasses import replace
|
| 7 |
+
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from sglang.srt.batch_overlap.operations import (
|
| 12 |
+
execute_operations,
|
| 13 |
+
execute_overlapped_operations,
|
| 14 |
+
)
|
| 15 |
+
from sglang.srt.batch_overlap.operations_strategy import OperationsStrategy
|
| 16 |
+
from sglang.srt.layers import deep_gemm_wrapper
|
| 17 |
+
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
| 18 |
+
from sglang.srt.layers.communicator import (
|
| 19 |
+
CommunicateContext,
|
| 20 |
+
CommunicateSummableTensorPairFn,
|
| 21 |
+
ScatterMode,
|
| 22 |
+
)
|
| 23 |
+
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
| 24 |
+
from sglang.srt.layers.moe import (
|
| 25 |
+
get_deepep_mode,
|
| 26 |
+
get_moe_a2a_backend,
|
| 27 |
+
get_tbo_token_distribution_threshold,
|
| 28 |
+
is_tbo_enabled,
|
| 29 |
+
)
|
| 30 |
+
from sglang.srt.layers.moe.token_dispatcher import (
|
| 31 |
+
DeepEPDispatcher,
|
| 32 |
+
MooncakeEPDispatcher,
|
| 33 |
+
MoriEPDispatcher,
|
| 34 |
+
)
|
| 35 |
+
from sglang.srt.layers.moe.token_dispatcher.base import BaseDispatcher
|
| 36 |
+
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
| 37 |
+
from sglang.srt.model_executor.forward_batch_info import (
|
| 38 |
+
ForwardBatch,
|
| 39 |
+
ForwardMode,
|
| 40 |
+
compute_position,
|
| 41 |
+
)
|
| 42 |
+
from sglang.srt.server_args import get_global_server_args
|
| 43 |
+
from sglang.srt.speculative.spec_info import SpecInput
|
| 44 |
+
from sglang.srt.utils import BumpAllocator, empty_context, get_bool_env_var, is_hip
|
| 45 |
+
|
| 46 |
+
if TYPE_CHECKING:
|
| 47 |
+
from sglang.srt.batch_overlap.single_batch_overlap import CombineOverlapArgs
|
| 48 |
+
from sglang.srt.layers.moe.token_dispatcher import DispatchOutput
|
| 49 |
+
from sglang.srt.speculative.eagle_info import EagleVerifyInput
|
| 50 |
+
|
| 51 |
+
_is_hip = is_hip()
|
| 52 |
+
|
| 53 |
+
_tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
|
| 54 |
+
|
| 55 |
+
logger = logging.getLogger(__name__)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# -------------------------------- Compute Basic Info ---------------------------------------
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def get_token_num_per_seq(
|
| 62 |
+
forward_mode: ForwardMode,
|
| 63 |
+
spec_info: Optional[SpecInput] = None,
|
| 64 |
+
):
|
| 65 |
+
if forward_mode.is_target_verify():
|
| 66 |
+
return spec_info.draft_token_num
|
| 67 |
+
elif forward_mode.is_decode():
|
| 68 |
+
return 1
|
| 69 |
+
elif forward_mode.is_idle():
|
| 70 |
+
return 0
|
| 71 |
+
else:
|
| 72 |
+
# For extend, we should not use `token_num_per_seq`.
|
| 73 |
+
return None
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# TODO: may smartly disable TBO when batch size is too small b/c it will slow down
|
| 77 |
+
def compute_split_seq_index(
|
| 78 |
+
forward_mode: ForwardMode,
|
| 79 |
+
num_tokens: int,
|
| 80 |
+
extend_lens: Optional[Sequence[int]],
|
| 81 |
+
token_num_per_seq: Optional[int],
|
| 82 |
+
) -> Optional[int]:
|
| 83 |
+
if forward_mode == ForwardMode.EXTEND:
|
| 84 |
+
assert extend_lens is not None
|
| 85 |
+
return _split_extend_seqs(extend_lens)
|
| 86 |
+
elif forward_mode.is_target_verify() or forward_mode.is_decode():
|
| 87 |
+
assert token_num_per_seq is not None
|
| 88 |
+
return (num_tokens // token_num_per_seq) // 2
|
| 89 |
+
elif forward_mode.is_idle() or forward_mode.is_prebuilt():
|
| 90 |
+
assert num_tokens == 0
|
| 91 |
+
return 0
|
| 92 |
+
else:
|
| 93 |
+
raise NotImplementedError()
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def _is_two_chunk_split_enabled(extend_lens: Sequence[int]) -> bool:
|
| 97 |
+
if extend_lens is None:
|
| 98 |
+
return False
|
| 99 |
+
|
| 100 |
+
vanilla_split_seq_index = _split_array_by_balanced_sum(extend_lens)
|
| 101 |
+
left_sum = sum(extend_lens[:vanilla_split_seq_index])
|
| 102 |
+
overall_sum = sum(extend_lens)
|
| 103 |
+
threshold = get_tbo_token_distribution_threshold()
|
| 104 |
+
assert threshold <= 0.5, f"{threshold=}"
|
| 105 |
+
return left_sum < overall_sum * threshold or left_sum > overall_sum * (
|
| 106 |
+
1 - threshold
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def _split_extend_seqs(arr: Sequence[int]) -> int:
|
| 111 |
+
if _is_two_chunk_split_enabled(arr):
|
| 112 |
+
return _split_array_by_cum_less_than_half(arr)
|
| 113 |
+
|
| 114 |
+
return _split_array_by_balanced_sum(arr)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def _split_array_by_cum_less_than_half(arr: Sequence[int]) -> int:
|
| 118 |
+
left_sum = 0
|
| 119 |
+
overall_sum = sum(arr)
|
| 120 |
+
half_sum = overall_sum // 2
|
| 121 |
+
chosen_index = 0
|
| 122 |
+
|
| 123 |
+
for i in range(len(arr)):
|
| 124 |
+
left_sum += arr[i]
|
| 125 |
+
if left_sum > half_sum:
|
| 126 |
+
chosen_index = i
|
| 127 |
+
break
|
| 128 |
+
|
| 129 |
+
return chosen_index
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def _split_array_by_balanced_sum(arr: Sequence[int]) -> int:
|
| 133 |
+
overall_sum = sum(arr)
|
| 134 |
+
left_sum = 0
|
| 135 |
+
min_diff = float("inf")
|
| 136 |
+
best_index = 0
|
| 137 |
+
|
| 138 |
+
for i in range(1, len(arr)):
|
| 139 |
+
left_sum += arr[i - 1]
|
| 140 |
+
right_sum = overall_sum - left_sum
|
| 141 |
+
diff = abs(left_sum - right_sum)
|
| 142 |
+
if diff <= min_diff:
|
| 143 |
+
min_diff = diff
|
| 144 |
+
best_index = i
|
| 145 |
+
else:
|
| 146 |
+
break
|
| 147 |
+
|
| 148 |
+
return best_index
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def _update_device_and_sum_field_from_cpu_field(
|
| 152 |
+
batch: ForwardBatch, cpu_field: str, device_field: str, sum_field: str = None
|
| 153 |
+
):
|
| 154 |
+
cpu_value = getattr(batch, cpu_field, None)
|
| 155 |
+
old_device_value = getattr(batch, device_field, None)
|
| 156 |
+
if (
|
| 157 |
+
cpu_value is None
|
| 158 |
+
or old_device_value is None
|
| 159 |
+
or not (isinstance(cpu_value, torch.Tensor) or isinstance(cpu_value, list))
|
| 160 |
+
):
|
| 161 |
+
return
|
| 162 |
+
|
| 163 |
+
new_device_value = (
|
| 164 |
+
cpu_value
|
| 165 |
+
if isinstance(cpu_value, torch.Tensor)
|
| 166 |
+
else torch.tensor(cpu_value, dtype=old_device_value.dtype)
|
| 167 |
+
).to(device=get_global_server_args().device, non_blocking=True)
|
| 168 |
+
setattr(batch, device_field, new_device_value)
|
| 169 |
+
|
| 170 |
+
if sum_field is not None:
|
| 171 |
+
sum_value = (
|
| 172 |
+
cpu_value.sum().item()
|
| 173 |
+
if isinstance(cpu_value, torch.Tensor)
|
| 174 |
+
else sum(cpu_value)
|
| 175 |
+
)
|
| 176 |
+
setattr(batch, sum_field, sum_value)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def _compute_mask_offset(seq_index: int, spec_info: Optional[EagleVerifyInput]) -> int:
|
| 180 |
+
if seq_index == 0:
|
| 181 |
+
return 0
|
| 182 |
+
|
| 183 |
+
offset = 0
|
| 184 |
+
max_seq_len = min(seq_index, spec_info.seq_lens_cpu.shape[0])
|
| 185 |
+
for i in range(max_seq_len):
|
| 186 |
+
offset += (
|
| 187 |
+
spec_info.seq_lens_cpu[i] + spec_info.draft_token_num
|
| 188 |
+
) * spec_info.draft_token_num
|
| 189 |
+
return offset
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def split_spec_info(
|
| 193 |
+
spec_info: Optional[EagleVerifyInput],
|
| 194 |
+
start_seq_index: int,
|
| 195 |
+
end_seq_index: int,
|
| 196 |
+
start_token_index: int,
|
| 197 |
+
end_token_index: int,
|
| 198 |
+
):
|
| 199 |
+
if spec_info is None:
|
| 200 |
+
return None
|
| 201 |
+
if spec_info.draft_token is not None:
|
| 202 |
+
draft_token = spec_info.draft_token[start_token_index:end_token_index]
|
| 203 |
+
else:
|
| 204 |
+
draft_token = None
|
| 205 |
+
if spec_info.custom_mask is not None and spec_info.draft_token is not None:
|
| 206 |
+
custom_mask_start = _compute_mask_offset(start_seq_index, spec_info)
|
| 207 |
+
if end_seq_index == spec_info.seq_lens_cpu.shape[0]:
|
| 208 |
+
custom_mask_end = spec_info.custom_mask.shape[0]
|
| 209 |
+
else:
|
| 210 |
+
custom_mask_end = _compute_mask_offset(end_seq_index, spec_info)
|
| 211 |
+
|
| 212 |
+
if custom_mask_end > custom_mask_start:
|
| 213 |
+
custom_mask = spec_info.custom_mask[custom_mask_start:custom_mask_end]
|
| 214 |
+
else:
|
| 215 |
+
custom_mask = spec_info.custom_mask
|
| 216 |
+
else:
|
| 217 |
+
custom_mask = spec_info.custom_mask
|
| 218 |
+
if spec_info.positions is not None:
|
| 219 |
+
positions = spec_info.positions[start_token_index:end_token_index]
|
| 220 |
+
else:
|
| 221 |
+
positions = None
|
| 222 |
+
if spec_info.retrive_index is not None:
|
| 223 |
+
retrive_index = spec_info.retrive_index[start_seq_index:end_seq_index]
|
| 224 |
+
else:
|
| 225 |
+
retrive_index = None
|
| 226 |
+
if spec_info.retrive_next_token is not None:
|
| 227 |
+
retrive_next_token = spec_info.retrive_next_token[start_seq_index:end_seq_index]
|
| 228 |
+
else:
|
| 229 |
+
retrive_next_token = None
|
| 230 |
+
if spec_info.retrive_next_sibling is not None:
|
| 231 |
+
retrive_next_sibling = spec_info.retrive_next_sibling[
|
| 232 |
+
start_seq_index:end_seq_index
|
| 233 |
+
]
|
| 234 |
+
else:
|
| 235 |
+
retrive_next_sibling = None
|
| 236 |
+
if spec_info.retrive_cum_len is not None:
|
| 237 |
+
retrive_cum_len = spec_info.retrive_cum_len[start_seq_index:end_seq_index]
|
| 238 |
+
else:
|
| 239 |
+
retrive_cum_len = None
|
| 240 |
+
|
| 241 |
+
if spec_info.seq_lens_cpu is not None:
|
| 242 |
+
seq_lens_cpu = spec_info.seq_lens_cpu[start_seq_index:end_seq_index]
|
| 243 |
+
else:
|
| 244 |
+
seq_lens_cpu = None
|
| 245 |
+
if seq_lens_cpu is not None:
|
| 246 |
+
seq_lens_sum = seq_lens_cpu.sum()
|
| 247 |
+
else:
|
| 248 |
+
seq_lens_sum = None
|
| 249 |
+
output_spec_info = replace(
|
| 250 |
+
spec_info,
|
| 251 |
+
custom_mask=custom_mask,
|
| 252 |
+
draft_token=draft_token,
|
| 253 |
+
positions=positions,
|
| 254 |
+
retrive_index=retrive_index,
|
| 255 |
+
retrive_next_token=retrive_next_token,
|
| 256 |
+
retrive_next_sibling=retrive_next_sibling,
|
| 257 |
+
retrive_cum_len=retrive_cum_len,
|
| 258 |
+
seq_lens_cpu=seq_lens_cpu,
|
| 259 |
+
seq_lens_sum=seq_lens_sum,
|
| 260 |
+
)
|
| 261 |
+
return output_spec_info
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def compute_split_token_index(
|
| 265 |
+
split_seq_index: int,
|
| 266 |
+
forward_mode: "ForwardMode",
|
| 267 |
+
extend_seq_lens: Optional[Sequence[int]],
|
| 268 |
+
token_num_per_seq: Optional[int],
|
| 269 |
+
) -> int:
|
| 270 |
+
if forward_mode == ForwardMode.EXTEND:
|
| 271 |
+
assert extend_seq_lens is not None
|
| 272 |
+
if _is_two_chunk_split_enabled(extend_seq_lens):
|
| 273 |
+
return sum(extend_seq_lens) // 2
|
| 274 |
+
return sum(extend_seq_lens[:split_seq_index])
|
| 275 |
+
elif forward_mode.is_target_verify() or forward_mode.is_decode():
|
| 276 |
+
assert token_num_per_seq is not None
|
| 277 |
+
return split_seq_index * token_num_per_seq
|
| 278 |
+
elif forward_mode.is_idle():
|
| 279 |
+
assert split_seq_index == 0
|
| 280 |
+
return 0
|
| 281 |
+
else:
|
| 282 |
+
raise NotImplementedError
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def compute_split_indices_for_cuda_graph_replay(
|
| 286 |
+
forward_mode: ForwardMode,
|
| 287 |
+
cuda_graph_num_tokens: int,
|
| 288 |
+
spec_info: Optional[SpecInput],
|
| 289 |
+
):
|
| 290 |
+
forward_mode_for_tbo_split = (
|
| 291 |
+
forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE
|
| 292 |
+
)
|
| 293 |
+
token_num_per_seq = get_token_num_per_seq(
|
| 294 |
+
forward_mode=forward_mode, spec_info=spec_info
|
| 295 |
+
)
|
| 296 |
+
tbo_split_seq_index = compute_split_seq_index(
|
| 297 |
+
forward_mode=forward_mode_for_tbo_split,
|
| 298 |
+
num_tokens=cuda_graph_num_tokens,
|
| 299 |
+
extend_lens=None,
|
| 300 |
+
token_num_per_seq=token_num_per_seq,
|
| 301 |
+
)
|
| 302 |
+
tbo_split_token_index = compute_split_token_index(
|
| 303 |
+
split_seq_index=tbo_split_seq_index,
|
| 304 |
+
forward_mode=forward_mode_for_tbo_split,
|
| 305 |
+
extend_seq_lens=None,
|
| 306 |
+
token_num_per_seq=token_num_per_seq,
|
| 307 |
+
)
|
| 308 |
+
return tbo_split_seq_index, tbo_split_token_index
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
# -------------------------------- Preparation ---------------------------------------
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
class TboCudaGraphRunnerPlugin:
|
| 315 |
+
def __init__(self):
|
| 316 |
+
self._tbo_children_num_token_non_padded = torch.zeros((2,), dtype=torch.int32)
|
| 317 |
+
|
| 318 |
+
def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
|
| 319 |
+
if not is_tbo_enabled():
|
| 320 |
+
return
|
| 321 |
+
token_num_per_seq = get_token_num_per_seq(
|
| 322 |
+
forward_mode=batch.forward_mode, spec_info=batch.spec_info
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
batch.tbo_split_seq_index = compute_split_seq_index(
|
| 326 |
+
forward_mode=batch.forward_mode,
|
| 327 |
+
num_tokens=num_tokens,
|
| 328 |
+
extend_lens=None,
|
| 329 |
+
token_num_per_seq=token_num_per_seq,
|
| 330 |
+
)
|
| 331 |
+
# For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true
|
| 332 |
+
assert batch.tbo_split_seq_index is not None, f"{num_tokens=}"
|
| 333 |
+
|
| 334 |
+
self._tbo_children_num_token_non_padded[...] = (
|
| 335 |
+
TboForwardBatchPreparer.compute_tbo_children_num_token_non_padded(batch)
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
TboForwardBatchPreparer.prepare_raw(
|
| 339 |
+
batch,
|
| 340 |
+
tbo_children_num_token_non_padded=self._tbo_children_num_token_non_padded,
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
def replay_prepare(
|
| 344 |
+
self,
|
| 345 |
+
forward_mode: ForwardMode,
|
| 346 |
+
bs: int,
|
| 347 |
+
num_token_non_padded: int,
|
| 348 |
+
spec_info: Optional[SpecInput],
|
| 349 |
+
):
|
| 350 |
+
token_num_per_seq = get_token_num_per_seq(
|
| 351 |
+
forward_mode=forward_mode, spec_info=spec_info
|
| 352 |
+
)
|
| 353 |
+
tbo_split_seq_index, tbo_split_token_index = (
|
| 354 |
+
compute_split_indices_for_cuda_graph_replay(
|
| 355 |
+
forward_mode=forward_mode,
|
| 356 |
+
cuda_graph_num_tokens=bs * token_num_per_seq,
|
| 357 |
+
spec_info=spec_info,
|
| 358 |
+
)
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
self._tbo_children_num_token_non_padded[...] = (
|
| 362 |
+
TboForwardBatchPreparer.compute_tbo_children_num_token_non_padded_raw(
|
| 363 |
+
tbo_split_token_index=tbo_split_token_index,
|
| 364 |
+
num_token_non_padded=num_token_non_padded,
|
| 365 |
+
)
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
class TboDPAttentionPreparer:
|
| 370 |
+
def prepare_all_gather(
|
| 371 |
+
self,
|
| 372 |
+
local_batch: ScheduleBatch,
|
| 373 |
+
):
|
| 374 |
+
|
| 375 |
+
deepep_mode = get_deepep_mode()
|
| 376 |
+
enable_a2a_moe = not get_moe_a2a_backend().is_none()
|
| 377 |
+
enable_two_batch_overlap = is_tbo_enabled()
|
| 378 |
+
|
| 379 |
+
self.enable_two_batch_overlap = enable_two_batch_overlap
|
| 380 |
+
|
| 381 |
+
if local_batch is not None:
|
| 382 |
+
token_num_per_seq = get_token_num_per_seq(
|
| 383 |
+
forward_mode=local_batch.forward_mode, spec_info=local_batch.spec_info
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
if (
|
| 387 |
+
local_batch.forward_mode.is_target_verify()
|
| 388 |
+
or local_batch.forward_mode.is_decode()
|
| 389 |
+
):
|
| 390 |
+
num_tokens = local_batch.batch_size() * token_num_per_seq
|
| 391 |
+
elif local_batch.forward_mode.is_prebuilt():
|
| 392 |
+
num_tokens = 0
|
| 393 |
+
else:
|
| 394 |
+
num_tokens = local_batch.extend_num_tokens
|
| 395 |
+
self.local_tbo_split_seq_index = compute_split_seq_index(
|
| 396 |
+
forward_mode=local_batch.forward_mode,
|
| 397 |
+
num_tokens=num_tokens,
|
| 398 |
+
extend_lens=local_batch.extend_lens,
|
| 399 |
+
token_num_per_seq=token_num_per_seq,
|
| 400 |
+
)
|
| 401 |
+
resolved_deepep_mode = deepep_mode.resolve(local_batch.is_extend_in_batch)
|
| 402 |
+
local_can_run_tbo = (self.local_tbo_split_seq_index is not None) and not (
|
| 403 |
+
(
|
| 404 |
+
local_batch.forward_mode.is_extend()
|
| 405 |
+
and not local_batch.forward_mode.is_target_verify()
|
| 406 |
+
)
|
| 407 |
+
and enable_a2a_moe
|
| 408 |
+
and (resolved_deepep_mode.is_low_latency())
|
| 409 |
+
)
|
| 410 |
+
else:
|
| 411 |
+
self.local_tbo_split_seq_index = 0
|
| 412 |
+
local_can_run_tbo = True
|
| 413 |
+
|
| 414 |
+
local_forward_mode = self._compute_local_forward_mode(local_batch)
|
| 415 |
+
|
| 416 |
+
return local_can_run_tbo, local_forward_mode
|
| 417 |
+
|
| 418 |
+
def compute_output(self, partial_global_info):
|
| 419 |
+
# Perform only one Device-to-Host (D2H) memory copy
|
| 420 |
+
cpu_data = partial_global_info[:, :2].cpu()
|
| 421 |
+
local_can_run_tbo_aggregated = min(cpu_data[:, 0].tolist())
|
| 422 |
+
forward_modes = cpu_data[:, 1].tolist()
|
| 423 |
+
|
| 424 |
+
global_forward_mode, forward_mode_agree = self._compute_global_forward_mode(
|
| 425 |
+
forward_modes
|
| 426 |
+
)
|
| 427 |
+
|
| 428 |
+
can_run_tbo = (
|
| 429 |
+
self.enable_two_batch_overlap
|
| 430 |
+
and local_can_run_tbo_aggregated
|
| 431 |
+
and forward_mode_agree
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
tbo_split_seq_index = self.local_tbo_split_seq_index if can_run_tbo else None
|
| 435 |
+
global_forward_mode = global_forward_mode if can_run_tbo else None
|
| 436 |
+
return tbo_split_seq_index, global_forward_mode
|
| 437 |
+
|
| 438 |
+
@staticmethod
|
| 439 |
+
def _compute_local_forward_mode(local_batch):
|
| 440 |
+
return (
|
| 441 |
+
local_batch.forward_mode if local_batch is not None else ForwardMode.IDLE
|
| 442 |
+
).value
|
| 443 |
+
|
| 444 |
+
@staticmethod
|
| 445 |
+
def _compute_global_forward_mode(forward_modes):
|
| 446 |
+
forward_modes_excluding_idle_and_prebuilt = [
|
| 447 |
+
x
|
| 448 |
+
for x in forward_modes
|
| 449 |
+
if x != ForwardMode.IDLE.value and x != ForwardMode.PREBUILT.value
|
| 450 |
+
]
|
| 451 |
+
|
| 452 |
+
if not forward_modes_excluding_idle_and_prebuilt:
|
| 453 |
+
return ForwardMode.IDLE, False
|
| 454 |
+
|
| 455 |
+
forward_mode_agree = TboDPAttentionPreparer._is_all_same(
|
| 456 |
+
forward_modes_excluding_idle_and_prebuilt
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
global_forward_mode = (
|
| 460 |
+
ForwardMode(forward_modes_excluding_idle_and_prebuilt[0])
|
| 461 |
+
if forward_mode_agree
|
| 462 |
+
else None
|
| 463 |
+
)
|
| 464 |
+
return global_forward_mode, forward_mode_agree
|
| 465 |
+
|
| 466 |
+
@staticmethod
|
| 467 |
+
def _is_all_same(x):
|
| 468 |
+
return all(value == x[0] for value in x)
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
class TboForwardBatchPreparer:
|
| 472 |
+
@classmethod
|
| 473 |
+
def prepare(cls, batch: ForwardBatch, is_draft_worker: bool = False):
|
| 474 |
+
if batch.tbo_split_seq_index is None or is_draft_worker:
|
| 475 |
+
return
|
| 476 |
+
|
| 477 |
+
tbo_children_num_token_non_padded = (
|
| 478 |
+
cls.compute_tbo_children_num_token_non_padded(batch)
|
| 479 |
+
)
|
| 480 |
+
cls.prepare_raw(
|
| 481 |
+
batch, tbo_children_num_token_non_padded=tbo_children_num_token_non_padded
|
| 482 |
+
)
|
| 483 |
+
|
| 484 |
+
@classmethod
|
| 485 |
+
def prepare_raw(
|
| 486 |
+
cls, batch: ForwardBatch, tbo_children_num_token_non_padded: torch.Tensor
|
| 487 |
+
):
|
| 488 |
+
from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
|
| 489 |
+
|
| 490 |
+
tbo_split_token_index = cls._compute_split_token_index(batch)
|
| 491 |
+
|
| 492 |
+
is_enable_two_chunk = (
|
| 493 |
+
batch.forward_mode == ForwardMode.EXTEND
|
| 494 |
+
and _is_two_chunk_split_enabled(batch.extend_seq_lens_cpu)
|
| 495 |
+
)
|
| 496 |
+
|
| 497 |
+
if _tbo_debug:
|
| 498 |
+
logger.info(
|
| 499 |
+
f"TboForwardBatchPreparer.prepare "
|
| 500 |
+
f"is_enable_two_chunk={is_enable_two_chunk} "
|
| 501 |
+
f"tbo_split_seq_index={batch.tbo_split_seq_index} "
|
| 502 |
+
f"tbo_split_token_index={tbo_split_token_index} "
|
| 503 |
+
f"extend_seq_lens={batch.extend_seq_lens_cpu} "
|
| 504 |
+
f"bs={batch.batch_size} "
|
| 505 |
+
f"forward_mode={batch.forward_mode}"
|
| 506 |
+
)
|
| 507 |
+
|
| 508 |
+
assert isinstance(batch.attn_backend, TboAttnBackend)
|
| 509 |
+
attn_backend_child_a, attn_backend_child_b = batch.attn_backend.children
|
| 510 |
+
|
| 511 |
+
[out_num_token_non_padded_a, out_num_token_non_padded_b] = (
|
| 512 |
+
tbo_children_num_token_non_padded
|
| 513 |
+
)
|
| 514 |
+
|
| 515 |
+
child_a = cls.filter_batch(
|
| 516 |
+
batch,
|
| 517 |
+
start_token_index=0,
|
| 518 |
+
end_token_index=tbo_split_token_index,
|
| 519 |
+
start_seq_index=0,
|
| 520 |
+
end_seq_index=(
|
| 521 |
+
batch.tbo_split_seq_index + 1
|
| 522 |
+
if is_enable_two_chunk
|
| 523 |
+
else batch.tbo_split_seq_index
|
| 524 |
+
),
|
| 525 |
+
output_attn_backend=attn_backend_child_a,
|
| 526 |
+
out_num_token_non_padded=out_num_token_non_padded_a,
|
| 527 |
+
)
|
| 528 |
+
child_b = cls.filter_batch(
|
| 529 |
+
batch,
|
| 530 |
+
start_token_index=tbo_split_token_index,
|
| 531 |
+
end_token_index=batch.input_ids.shape[0],
|
| 532 |
+
start_seq_index=batch.tbo_split_seq_index,
|
| 533 |
+
end_seq_index=batch.batch_size,
|
| 534 |
+
output_attn_backend=attn_backend_child_b,
|
| 535 |
+
out_num_token_non_padded=out_num_token_non_padded_b,
|
| 536 |
+
)
|
| 537 |
+
|
| 538 |
+
if is_enable_two_chunk:
|
| 539 |
+
cls.derive_fields_related_to_seq_len_for_two_chunk(
|
| 540 |
+
batch,
|
| 541 |
+
child_a=child_a,
|
| 542 |
+
child_b=child_b,
|
| 543 |
+
tbo_split_seq_index=batch.tbo_split_seq_index,
|
| 544 |
+
)
|
| 545 |
+
|
| 546 |
+
assert batch.tbo_children is None
|
| 547 |
+
batch.tbo_children = [child_a, child_b]
|
| 548 |
+
|
| 549 |
+
@classmethod
|
| 550 |
+
def derive_fields_related_to_seq_len_for_two_chunk(
|
| 551 |
+
cls,
|
| 552 |
+
batch: ForwardBatch,
|
| 553 |
+
*,
|
| 554 |
+
child_a: ForwardBatch,
|
| 555 |
+
child_b: ForwardBatch,
|
| 556 |
+
tbo_split_seq_index: int,
|
| 557 |
+
):
|
| 558 |
+
extend_seq_lens_cpu = batch.extend_seq_lens_cpu
|
| 559 |
+
overall_seq_lens_sum = sum(extend_seq_lens_cpu)
|
| 560 |
+
half_seq_lens_sum = overall_seq_lens_sum // 2
|
| 561 |
+
left_last_seq_token_num = half_seq_lens_sum - sum(
|
| 562 |
+
extend_seq_lens_cpu[:tbo_split_seq_index]
|
| 563 |
+
)
|
| 564 |
+
right_first_seq_token_num = (
|
| 565 |
+
extend_seq_lens_cpu[tbo_split_seq_index] - left_last_seq_token_num
|
| 566 |
+
)
|
| 567 |
+
|
| 568 |
+
# making deepcopy to be extra safe
|
| 569 |
+
child_a.extend_seq_lens_cpu = copy.deepcopy(child_a.extend_seq_lens_cpu)
|
| 570 |
+
child_a.extend_seq_lens_cpu[-1] = left_last_seq_token_num
|
| 571 |
+
child_b.extend_seq_lens_cpu = copy.deepcopy(child_b.extend_seq_lens_cpu)
|
| 572 |
+
child_b.extend_seq_lens_cpu[0] = right_first_seq_token_num
|
| 573 |
+
for child in [child_a, child_b]:
|
| 574 |
+
_update_device_and_sum_field_from_cpu_field(
|
| 575 |
+
batch=child,
|
| 576 |
+
cpu_field="extend_seq_lens_cpu",
|
| 577 |
+
device_field="extend_seq_lens",
|
| 578 |
+
sum_field="extend_num_tokens",
|
| 579 |
+
)
|
| 580 |
+
|
| 581 |
+
assert (
|
| 582 |
+
child_a.extend_num_tokens == half_seq_lens_sum
|
| 583 |
+
), f"{child_a.extend_num_tokens=}, {half_seq_lens_sum=}"
|
| 584 |
+
|
| 585 |
+
child_a.seq_lens_cpu = copy.deepcopy(child_a.seq_lens_cpu)
|
| 586 |
+
child_a.seq_lens_cpu[-1] = (
|
| 587 |
+
child_a.extend_seq_lens_cpu[-1] + child_a.extend_prefix_lens_cpu[-1]
|
| 588 |
+
)
|
| 589 |
+
_update_device_and_sum_field_from_cpu_field(
|
| 590 |
+
batch=child_a,
|
| 591 |
+
cpu_field="seq_lens_cpu",
|
| 592 |
+
device_field="seq_lens",
|
| 593 |
+
sum_field="seq_lens_sum",
|
| 594 |
+
)
|
| 595 |
+
|
| 596 |
+
child_b.extend_prefix_lens_cpu = copy.deepcopy(child_b.extend_prefix_lens_cpu)
|
| 597 |
+
child_b.extend_prefix_lens_cpu[0] += left_last_seq_token_num
|
| 598 |
+
_update_device_and_sum_field_from_cpu_field(
|
| 599 |
+
batch=child_b,
|
| 600 |
+
cpu_field="extend_prefix_lens_cpu",
|
| 601 |
+
device_field="extend_prefix_lens",
|
| 602 |
+
sum_field=None,
|
| 603 |
+
)
|
| 604 |
+
_, child_b.extend_start_loc = compute_position(
|
| 605 |
+
get_global_server_args().attention_backend,
|
| 606 |
+
child_b.extend_prefix_lens,
|
| 607 |
+
child_b.extend_seq_lens,
|
| 608 |
+
child_b.extend_num_tokens,
|
| 609 |
+
)
|
| 610 |
+
|
| 611 |
+
@classmethod
|
| 612 |
+
def filter_batch(
|
| 613 |
+
cls,
|
| 614 |
+
batch: ForwardBatch,
|
| 615 |
+
*,
|
| 616 |
+
start_token_index: int,
|
| 617 |
+
end_token_index: int,
|
| 618 |
+
start_seq_index: int,
|
| 619 |
+
end_seq_index: int,
|
| 620 |
+
output_attn_backend: AttentionBackend,
|
| 621 |
+
out_num_token_non_padded: torch.Tensor,
|
| 622 |
+
):
|
| 623 |
+
assert (
|
| 624 |
+
end_token_index >= start_token_index
|
| 625 |
+
), f"{end_token_index=}, {start_token_index=}, batch={batch}"
|
| 626 |
+
num_tokens = batch.input_ids.shape[0]
|
| 627 |
+
num_seqs = batch.batch_size
|
| 628 |
+
|
| 629 |
+
output_dict = dict()
|
| 630 |
+
|
| 631 |
+
for key in [
|
| 632 |
+
"input_ids",
|
| 633 |
+
"positions",
|
| 634 |
+
"out_cache_loc",
|
| 635 |
+
]:
|
| 636 |
+
old_value = getattr(batch, key)
|
| 637 |
+
assert (
|
| 638 |
+
old_value.shape[0] == num_tokens
|
| 639 |
+
), f"{key=} {old_value=} {num_tokens=} {batch=}"
|
| 640 |
+
output_dict[key] = old_value[start_token_index:end_token_index]
|
| 641 |
+
|
| 642 |
+
attention_tp_size = get_attention_tp_size()
|
| 643 |
+
output_dict["tbo_padded_len"] = (
|
| 644 |
+
(end_token_index - start_token_index - 1) // attention_tp_size + 1
|
| 645 |
+
) * attention_tp_size
|
| 646 |
+
|
| 647 |
+
for key in [
|
| 648 |
+
"req_pool_indices",
|
| 649 |
+
"seq_lens",
|
| 650 |
+
"seq_lens_cpu",
|
| 651 |
+
"extend_seq_lens",
|
| 652 |
+
"extend_prefix_lens",
|
| 653 |
+
"extend_start_loc",
|
| 654 |
+
"extend_prefix_lens_cpu",
|
| 655 |
+
"extend_seq_lens_cpu",
|
| 656 |
+
"extend_logprob_start_lens_cpu",
|
| 657 |
+
"lora_ids",
|
| 658 |
+
"rids",
|
| 659 |
+
]:
|
| 660 |
+
old_value = getattr(batch, key)
|
| 661 |
+
if old_value is None:
|
| 662 |
+
continue
|
| 663 |
+
elif batch.forward_mode.is_target_verify() and (
|
| 664 |
+
key == "extend_seq_lens"
|
| 665 |
+
or key == "extend_prefix_lens"
|
| 666 |
+
or key == "extend_start_loc"
|
| 667 |
+
or key == "extend_prefix_lens_cpu"
|
| 668 |
+
or key == "extend_seq_lens_cpu"
|
| 669 |
+
or key == "extend_logprob_start_lens_cpu"
|
| 670 |
+
):
|
| 671 |
+
output_dict[key] = None
|
| 672 |
+
continue
|
| 673 |
+
assert (
|
| 674 |
+
len(old_value) == num_seqs
|
| 675 |
+
), f"{key=} {old_value=} {num_seqs=} {batch=}"
|
| 676 |
+
output_dict[key] = old_value[start_seq_index:end_seq_index]
|
| 677 |
+
|
| 678 |
+
spec_info = getattr(batch, "spec_info")
|
| 679 |
+
output_spec_info = split_spec_info(
|
| 680 |
+
spec_info=spec_info,
|
| 681 |
+
start_token_index=start_token_index,
|
| 682 |
+
end_token_index=end_token_index,
|
| 683 |
+
start_seq_index=start_seq_index,
|
| 684 |
+
end_seq_index=end_seq_index,
|
| 685 |
+
)
|
| 686 |
+
output_dict["spec_info"] = output_spec_info
|
| 687 |
+
for key in [
|
| 688 |
+
"forward_mode",
|
| 689 |
+
"is_extend_in_batch",
|
| 690 |
+
"all_extend_in_batch",
|
| 691 |
+
"return_logprob",
|
| 692 |
+
"req_to_token_pool",
|
| 693 |
+
"token_to_kv_pool",
|
| 694 |
+
"can_run_dp_cuda_graph",
|
| 695 |
+
"dp_padding_mode",
|
| 696 |
+
"global_forward_mode",
|
| 697 |
+
"is_prefill_only",
|
| 698 |
+
"spec_algorithm",
|
| 699 |
+
"capture_hidden_mode",
|
| 700 |
+
"padded_static_len",
|
| 701 |
+
"mrope_positions", # only used by qwen2-vl, thus not care
|
| 702 |
+
"split_index", # for split prefill
|
| 703 |
+
"orig_seq_lens", # only used by qwen-1m, thus not care
|
| 704 |
+
]:
|
| 705 |
+
output_dict[key] = getattr(batch, key)
|
| 706 |
+
if not batch.forward_mode.is_target_verify():
|
| 707 |
+
assert (
|
| 708 |
+
_compute_extend_num_tokens(batch.input_ids, batch.forward_mode)
|
| 709 |
+
== batch.extend_num_tokens
|
| 710 |
+
), f"{batch=}"
|
| 711 |
+
extend_num_tokens = _compute_extend_num_tokens(
|
| 712 |
+
output_dict["input_ids"], output_dict["forward_mode"]
|
| 713 |
+
)
|
| 714 |
+
|
| 715 |
+
# TODO improve, e.g. unify w/ `init_raw`
|
| 716 |
+
if (
|
| 717 |
+
get_global_server_args().moe_dense_tp_size == 1
|
| 718 |
+
and batch.global_dp_buffer_len is not None
|
| 719 |
+
):
|
| 720 |
+
sum_len = end_token_index - start_token_index
|
| 721 |
+
global_dp_buffer_len = sum_len
|
| 722 |
+
else:
|
| 723 |
+
global_dp_buffer_len = None
|
| 724 |
+
|
| 725 |
+
output_dict.update(
|
| 726 |
+
dict(
|
| 727 |
+
batch_size=end_seq_index - start_seq_index,
|
| 728 |
+
seq_lens_sum=(
|
| 729 |
+
output_dict["seq_lens_cpu"].sum()
|
| 730 |
+
if "seq_lens_cpu" in output_dict
|
| 731 |
+
else None
|
| 732 |
+
),
|
| 733 |
+
extend_num_tokens=extend_num_tokens,
|
| 734 |
+
attn_backend=output_attn_backend,
|
| 735 |
+
num_token_non_padded=out_num_token_non_padded,
|
| 736 |
+
# TODO: handle it when we need TBO + DeepSeek V3.2
|
| 737 |
+
num_token_non_padded_cpu=None,
|
| 738 |
+
tbo_split_seq_index=None,
|
| 739 |
+
tbo_parent_token_range=(start_token_index, end_token_index),
|
| 740 |
+
tbo_children=None,
|
| 741 |
+
original_global_num_tokens_cpu=None,
|
| 742 |
+
global_num_tokens_gpu=None,
|
| 743 |
+
global_num_tokens_cpu=None,
|
| 744 |
+
global_dp_buffer_len=global_dp_buffer_len,
|
| 745 |
+
global_num_tokens_for_logprob_gpu=None,
|
| 746 |
+
global_num_tokens_for_logprob_cpu=None,
|
| 747 |
+
sampling_info=None,
|
| 748 |
+
# For logits and logprobs post processing, thus we do not care
|
| 749 |
+
temp_scaled_logprobs=False,
|
| 750 |
+
temperature=None,
|
| 751 |
+
top_p_normalized_logprobs=False,
|
| 752 |
+
top_p=None,
|
| 753 |
+
mm_inputs=None,
|
| 754 |
+
top_logprobs_nums=None,
|
| 755 |
+
token_ids_logprobs=None,
|
| 756 |
+
next_token_logits_buffer=None,
|
| 757 |
+
return_hidden_states_before_norm=False,
|
| 758 |
+
)
|
| 759 |
+
)
|
| 760 |
+
|
| 761 |
+
errors = []
|
| 762 |
+
for field in dataclasses.fields(ForwardBatch):
|
| 763 |
+
if getattr(batch, field.name) is not None and field.name not in output_dict:
|
| 764 |
+
errors.append(
|
| 765 |
+
f"Field {field.name} has value, but is not yet supported (value={getattr(batch, field.name)} batch={batch})"
|
| 766 |
+
)
|
| 767 |
+
if len(errors) > 0:
|
| 768 |
+
raise Exception(f"{len(errors)} errors happen:\n" + "\n\n".join(errors))
|
| 769 |
+
|
| 770 |
+
return ForwardBatch(**output_dict)
|
| 771 |
+
|
| 772 |
+
@classmethod
|
| 773 |
+
def compute_tbo_children_num_token_non_padded(cls, batch: ForwardBatch):
|
| 774 |
+
return cls.compute_tbo_children_num_token_non_padded_raw(
|
| 775 |
+
tbo_split_token_index=cls._compute_split_token_index(batch),
|
| 776 |
+
num_token_non_padded=len(batch.input_ids),
|
| 777 |
+
)
|
| 778 |
+
|
| 779 |
+
@classmethod
|
| 780 |
+
def compute_tbo_children_num_token_non_padded_raw(
|
| 781 |
+
cls, tbo_split_token_index: int, num_token_non_padded: int
|
| 782 |
+
):
|
| 783 |
+
# TODO we may make padding on both sub-batches to make it slightly more balanced
|
| 784 |
+
value_a = min(tbo_split_token_index, num_token_non_padded)
|
| 785 |
+
value_b = max(0, num_token_non_padded - tbo_split_token_index)
|
| 786 |
+
return torch.tensor([value_a, value_b], dtype=torch.int32).to(
|
| 787 |
+
device=get_global_server_args().device, non_blocking=True
|
| 788 |
+
)
|
| 789 |
+
|
| 790 |
+
@classmethod
|
| 791 |
+
def _compute_split_token_index(cls, batch: ForwardBatch):
|
| 792 |
+
token_num_per_seq = get_token_num_per_seq(
|
| 793 |
+
forward_mode=batch.forward_mode, spec_info=batch.spec_info
|
| 794 |
+
)
|
| 795 |
+
return compute_split_token_index(
|
| 796 |
+
split_seq_index=batch.tbo_split_seq_index,
|
| 797 |
+
forward_mode=batch.forward_mode,
|
| 798 |
+
extend_seq_lens=batch.extend_seq_lens_cpu,
|
| 799 |
+
token_num_per_seq=token_num_per_seq,
|
| 800 |
+
)
|
| 801 |
+
|
| 802 |
+
|
| 803 |
+
def _compute_extend_num_tokens(input_ids, forward_mode: ForwardMode):
|
| 804 |
+
if (
|
| 805 |
+
forward_mode.is_decode()
|
| 806 |
+
or forward_mode.is_idle()
|
| 807 |
+
or forward_mode.is_target_verify()
|
| 808 |
+
):
|
| 809 |
+
return None
|
| 810 |
+
elif forward_mode.is_extend():
|
| 811 |
+
return input_ids.shape[0]
|
| 812 |
+
raise NotImplementedError
|
| 813 |
+
|
| 814 |
+
|
| 815 |
+
# -------------------------------- Execution ---------------------------------------
|
| 816 |
+
|
| 817 |
+
|
| 818 |
+
def model_forward_maybe_tbo(
|
| 819 |
+
layers,
|
| 820 |
+
enable_tbo: bool,
|
| 821 |
+
positions: torch.Tensor,
|
| 822 |
+
forward_batch: ForwardBatch,
|
| 823 |
+
hidden_states: torch.Tensor,
|
| 824 |
+
input_data_scatter_mode: ScatterMode,
|
| 825 |
+
residual: Optional[torch.Tensor],
|
| 826 |
+
zero_allocator: Optional[BumpAllocator] = None,
|
| 827 |
+
):
|
| 828 |
+
inputs = dict(
|
| 829 |
+
positions=positions,
|
| 830 |
+
hidden_states=hidden_states,
|
| 831 |
+
forward_batch=forward_batch,
|
| 832 |
+
residual=residual,
|
| 833 |
+
zero_allocator=zero_allocator,
|
| 834 |
+
)
|
| 835 |
+
layer_input_scatter_mode = layers[0].layer_scatter_modes.layer_input_mode
|
| 836 |
+
operations_strategy = OperationsStrategy.init_new_tbo(
|
| 837 |
+
layers, forward_batch.global_forward_mode
|
| 838 |
+
)
|
| 839 |
+
if enable_tbo:
|
| 840 |
+
return _model_forward_tbo(
|
| 841 |
+
inputs=inputs,
|
| 842 |
+
operations_strategy=operations_strategy,
|
| 843 |
+
input_data_scatter_mode=input_data_scatter_mode,
|
| 844 |
+
layer_input_scatter_mode=layer_input_scatter_mode,
|
| 845 |
+
)
|
| 846 |
+
else:
|
| 847 |
+
return _model_forward_non_tbo(inputs, operations_strategy)
|
| 848 |
+
|
| 849 |
+
|
| 850 |
+
def _model_forward_tbo(
|
| 851 |
+
inputs,
|
| 852 |
+
operations_strategy: OperationsStrategy,
|
| 853 |
+
input_data_scatter_mode: ScatterMode,
|
| 854 |
+
layer_input_scatter_mode: ScatterMode,
|
| 855 |
+
):
|
| 856 |
+
inputs_arr = _model_forward_tbo_split_inputs(
|
| 857 |
+
**inputs,
|
| 858 |
+
input_data_scatter_mode=input_data_scatter_mode,
|
| 859 |
+
layer_input_scatter_mode=layer_input_scatter_mode,
|
| 860 |
+
)
|
| 861 |
+
original_hidden_states_len = inputs["hidden_states"].shape[0]
|
| 862 |
+
del inputs
|
| 863 |
+
|
| 864 |
+
context = (
|
| 865 |
+
empty_context()
|
| 866 |
+
if _is_hip
|
| 867 |
+
else deep_gemm_wrapper.configure_deep_gemm_num_sms(
|
| 868 |
+
operations_strategy.deep_gemm_num_sms
|
| 869 |
+
)
|
| 870 |
+
)
|
| 871 |
+
|
| 872 |
+
with context:
|
| 873 |
+
outputs_arr = execute_overlapped_operations(
|
| 874 |
+
inputs_arr=inputs_arr,
|
| 875 |
+
operations_arr=[operations_strategy.operations] * 2,
|
| 876 |
+
delta_stages=[0, operations_strategy.tbo_delta_stages],
|
| 877 |
+
)
|
| 878 |
+
|
| 879 |
+
return _model_forward_tbo_merge_outputs(*outputs_arr, original_hidden_states_len)
|
| 880 |
+
|
| 881 |
+
|
| 882 |
+
def _model_forward_non_tbo(inputs, operations_strategy: OperationsStrategy):
|
| 883 |
+
outputs = execute_operations(inputs, operations_strategy.operations)
|
| 884 |
+
return outputs["hidden_states"], outputs["residual"]
|
| 885 |
+
|
| 886 |
+
|
| 887 |
+
def _model_forward_tbo_split_inputs(
|
| 888 |
+
hidden_states: torch.Tensor,
|
| 889 |
+
residual: torch.Tensor,
|
| 890 |
+
positions: torch.Tensor,
|
| 891 |
+
forward_batch: ForwardBatch,
|
| 892 |
+
zero_allocator: Optional[BumpAllocator],
|
| 893 |
+
input_data_scatter_mode: ScatterMode,
|
| 894 |
+
layer_input_scatter_mode: ScatterMode,
|
| 895 |
+
) -> List[Dict]:
|
| 896 |
+
tbo_splitter_scatter_mode = ScatterMode.TP_ATTN_FULL
|
| 897 |
+
context = CommunicateContext.init_new()
|
| 898 |
+
|
| 899 |
+
hidden_states, residual = CommunicateSummableTensorPairFn.execute(
|
| 900 |
+
hidden_states_input_mode=input_data_scatter_mode,
|
| 901 |
+
residual_input_mode=input_data_scatter_mode,
|
| 902 |
+
output_mode=tbo_splitter_scatter_mode,
|
| 903 |
+
hidden_states=hidden_states,
|
| 904 |
+
residual=residual,
|
| 905 |
+
forward_batch=forward_batch,
|
| 906 |
+
context=context,
|
| 907 |
+
)
|
| 908 |
+
|
| 909 |
+
inputs_arr = _model_forward_tbo_split_inputs_raw(
|
| 910 |
+
hidden_states=hidden_states,
|
| 911 |
+
residual=residual,
|
| 912 |
+
positions=positions,
|
| 913 |
+
forward_batch=forward_batch,
|
| 914 |
+
zero_allocator=zero_allocator,
|
| 915 |
+
)
|
| 916 |
+
|
| 917 |
+
def _post_transform(hidden_states, residual, forward_batch, **kwargs):
|
| 918 |
+
hidden_states, residual = CommunicateSummableTensorPairFn.execute(
|
| 919 |
+
hidden_states_input_mode=tbo_splitter_scatter_mode,
|
| 920 |
+
residual_input_mode=tbo_splitter_scatter_mode,
|
| 921 |
+
output_mode=layer_input_scatter_mode,
|
| 922 |
+
hidden_states=hidden_states,
|
| 923 |
+
residual=residual,
|
| 924 |
+
forward_batch=forward_batch,
|
| 925 |
+
context=context,
|
| 926 |
+
)
|
| 927 |
+
return dict(
|
| 928 |
+
hidden_states=hidden_states,
|
| 929 |
+
residual=residual,
|
| 930 |
+
forward_batch=forward_batch,
|
| 931 |
+
**kwargs,
|
| 932 |
+
)
|
| 933 |
+
|
| 934 |
+
return [_post_transform(**inputs) for inputs in inputs_arr]
|
| 935 |
+
|
| 936 |
+
|
| 937 |
+
def _model_forward_tbo_split_inputs_raw(
|
| 938 |
+
hidden_states: torch.Tensor,
|
| 939 |
+
residual: torch.Tensor,
|
| 940 |
+
positions: torch.Tensor,
|
| 941 |
+
forward_batch: ForwardBatch,
|
| 942 |
+
zero_allocator: Optional[BumpAllocator],
|
| 943 |
+
) -> List[Dict]:
|
| 944 |
+
return [
|
| 945 |
+
dict(
|
| 946 |
+
**_model_forward_filter_inputs(
|
| 947 |
+
hidden_states=hidden_states,
|
| 948 |
+
residual=residual,
|
| 949 |
+
positions=positions,
|
| 950 |
+
output_forward_batch=output_forward_batch,
|
| 951 |
+
tbo_subbatch_index=tbo_subbatch_index,
|
| 952 |
+
),
|
| 953 |
+
**(
|
| 954 |
+
dict(zero_allocator=zero_allocator)
|
| 955 |
+
if zero_allocator is not None
|
| 956 |
+
else {}
|
| 957 |
+
),
|
| 958 |
+
)
|
| 959 |
+
for tbo_subbatch_index, output_forward_batch in enumerate(
|
| 960 |
+
forward_batch.tbo_children
|
| 961 |
+
)
|
| 962 |
+
]
|
| 963 |
+
|
| 964 |
+
|
| 965 |
+
def _model_forward_filter_inputs(
|
| 966 |
+
hidden_states: torch.Tensor,
|
| 967 |
+
residual: torch.Tensor,
|
| 968 |
+
positions: torch.Tensor,
|
| 969 |
+
output_forward_batch: ForwardBatch,
|
| 970 |
+
tbo_subbatch_index: int,
|
| 971 |
+
) -> Dict:
|
| 972 |
+
token_slice = slice(*output_forward_batch.tbo_parent_token_range)
|
| 973 |
+
hidden_states = hidden_states[token_slice]
|
| 974 |
+
residual = None if residual is None else residual[token_slice]
|
| 975 |
+
positions = positions[token_slice]
|
| 976 |
+
|
| 977 |
+
assert output_forward_batch.tbo_padded_len is not None
|
| 978 |
+
padded_len = output_forward_batch.tbo_padded_len
|
| 979 |
+
|
| 980 |
+
def _pad(x):
|
| 981 |
+
nonlocal padded_len
|
| 982 |
+
if x is None:
|
| 983 |
+
return None
|
| 984 |
+
if x.shape[0] == padded_len:
|
| 985 |
+
return x
|
| 986 |
+
res = torch.zeros((padded_len, *x.shape[1:]), dtype=x.dtype, device=x.device)
|
| 987 |
+
res[: x.shape[0]] = x
|
| 988 |
+
return res
|
| 989 |
+
|
| 990 |
+
return dict(
|
| 991 |
+
hidden_states=_pad(hidden_states),
|
| 992 |
+
residual=_pad(residual),
|
| 993 |
+
positions=_pad(positions),
|
| 994 |
+
forward_batch=output_forward_batch,
|
| 995 |
+
tbo_subbatch_index=tbo_subbatch_index,
|
| 996 |
+
)
|
| 997 |
+
|
| 998 |
+
|
| 999 |
+
def _model_forward_tbo_merge_outputs(output_a, output_b, original_len):
|
| 1000 |
+
def _handle_key(name):
|
| 1001 |
+
value_a = output_a[name]
|
| 1002 |
+
value_b = output_b[name]
|
| 1003 |
+
assert (value_a is None) == (value_b is None)
|
| 1004 |
+
if value_a is None:
|
| 1005 |
+
return None
|
| 1006 |
+
s0, t0 = output_a["forward_batch"].tbo_parent_token_range
|
| 1007 |
+
s1, t1 = output_b["forward_batch"].tbo_parent_token_range
|
| 1008 |
+
res = torch.zeros(
|
| 1009 |
+
(original_len, *value_a.shape[1:]),
|
| 1010 |
+
dtype=value_a.dtype,
|
| 1011 |
+
device=value_a.device,
|
| 1012 |
+
)
|
| 1013 |
+
res[slice(s0, t0)] = value_a[: t0 - s0]
|
| 1014 |
+
res[slice(s1, t1)] = value_b[: t1 - s1]
|
| 1015 |
+
return res
|
| 1016 |
+
|
| 1017 |
+
return _handle_key("hidden_states"), _handle_key("residual")
|
| 1018 |
+
|
| 1019 |
+
|
| 1020 |
+
# -------------------------------- Utilities and wrappers ---------------------------------------
|
| 1021 |
+
|
| 1022 |
+
|
| 1023 |
+
class MaybeTboDeepEPDispatcher(BaseDispatcher):
|
| 1024 |
+
def __init__(self, **kwargs):
|
| 1025 |
+
super().__init__()
|
| 1026 |
+
num_inner_dispatchers = 2 if is_tbo_enabled() else 1
|
| 1027 |
+
if get_moe_a2a_backend().is_deepep():
|
| 1028 |
+
self._inners = [
|
| 1029 |
+
DeepEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)
|
| 1030 |
+
]
|
| 1031 |
+
elif get_moe_a2a_backend().is_mooncake():
|
| 1032 |
+
self._inners = [
|
| 1033 |
+
MooncakeEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)
|
| 1034 |
+
]
|
| 1035 |
+
elif get_moe_a2a_backend().is_mori():
|
| 1036 |
+
self._inners = [
|
| 1037 |
+
MoriEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)
|
| 1038 |
+
]
|
| 1039 |
+
|
| 1040 |
+
def _execute(self, name, tbo_subbatch_index: Optional[int] = None, **kwargs):
|
| 1041 |
+
return getattr(self._inners[tbo_subbatch_index or 0], name)(**kwargs)
|
| 1042 |
+
|
| 1043 |
+
def dispatch(self, **kwargs) -> DispatchOutput:
|
| 1044 |
+
return self._execute("dispatch", **kwargs)
|
| 1045 |
+
|
| 1046 |
+
def dispatch_a(self, **kwargs):
|
| 1047 |
+
return self._execute("dispatch_a", **kwargs)
|
| 1048 |
+
|
| 1049 |
+
def dispatch_b(self, **kwargs):
|
| 1050 |
+
return self._execute("dispatch_b", **kwargs)
|
| 1051 |
+
|
| 1052 |
+
def combine(self, **kwargs) -> torch.Tensor:
|
| 1053 |
+
return self._execute("combine", **kwargs)
|
| 1054 |
+
|
| 1055 |
+
def combine_a(self, **kwargs):
|
| 1056 |
+
return self._execute("combine_a", **kwargs)
|
| 1057 |
+
|
| 1058 |
+
def combine_b(self, **kwargs):
|
| 1059 |
+
return self._execute("combine_b", **kwargs)
|
| 1060 |
+
|
| 1061 |
+
def register_deepep_dispatch_hook(self, hook):
|
| 1062 |
+
handle_list = []
|
| 1063 |
+
for inner in self._inners:
|
| 1064 |
+
handle_list.append(inner.register_deepep_dispatch_hook(hook))
|
| 1065 |
+
return handle_list
|
| 1066 |
+
|
| 1067 |
+
def set_quant_config(self, quant_config: dict):
|
| 1068 |
+
super().set_quant_config(quant_config)
|
| 1069 |
+
for inner in self._inners:
|
| 1070 |
+
inner.set_quant_config(quant_config)
|
| 1071 |
+
|
| 1072 |
+
def set_overlap_args(
|
| 1073 |
+
self, combine_overlap_args: CombineOverlapArgs, meta_overlap_args: dict
|
| 1074 |
+
):
|
| 1075 |
+
super().set_overlap_args(combine_overlap_args, meta_overlap_args)
|
| 1076 |
+
for inner in self._inners:
|
| 1077 |
+
inner.set_overlap_args(combine_overlap_args, meta_overlap_args)
|
| 1078 |
+
|
| 1079 |
+
def clear_overlap_args(self):
|
| 1080 |
+
super().clear_overlap_args()
|
| 1081 |
+
for inner in self._inners:
|
| 1082 |
+
inner.clear_overlap_args()
|
sglang/python/sglang/srt/checkpoint_engine/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Checkpoint engine module for SGLang.
|
| 3 |
+
|
| 4 |
+
This module provides functionality for updating model weights via checkpoint engine.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from sglang.srt.checkpoint_engine.update import main
|
| 8 |
+
|
| 9 |
+
__all__ = ["main"]
|
sglang/python/sglang/srt/checkpoint_engine/checkpoint_engine_worker.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023-2024 SGLang Team
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 7 |
+
#
|
| 8 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 11 |
+
# See the License for the specific language governing permissions and
|
| 12 |
+
# limitations under the License.
|
| 13 |
+
# ==============================================================================
|
| 14 |
+
"""
|
| 15 |
+
Checkpoint-engine integration for SGLang.
|
| 16 |
+
This module provides weight update functionality via IPC for checkpoint-engine compatibility.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import logging
|
| 20 |
+
from typing import Callable, Dict, Optional
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
import zmq
|
| 24 |
+
|
| 25 |
+
try:
|
| 26 |
+
from checkpoint_engine.worker import update_weights_from_ipc
|
| 27 |
+
except ImportError:
|
| 28 |
+
raise ImportError(
|
| 29 |
+
"checkpoint-engine is not installed. "
|
| 30 |
+
"Please install it with: pip install sglang[checkpoint-engine]"
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
logger = logging.getLogger(__name__)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class SGLangCheckpointEngineWorkerExtension:
|
| 37 |
+
"""
|
| 38 |
+
Worker extension for SGLang to support checkpoint-engine IPC weight updates.
|
| 39 |
+
This class provides the interface needed for checkpoint-engine integration.
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
def __init__(self):
|
| 43 |
+
self._zmq_ctx: Optional[zmq.Context] = None
|
| 44 |
+
|
| 45 |
+
def get_device_uuid(self) -> str:
|
| 46 |
+
"""Get the UUID of current device."""
|
| 47 |
+
# We need to implement this to get the device UUID
|
| 48 |
+
# This will be overridden when integrated into SGLang's worker
|
| 49 |
+
raise NotImplementedError(
|
| 50 |
+
"This method should be overridden by SGLang integration"
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
def get_device_id(self) -> int:
|
| 54 |
+
"""Get the device ID."""
|
| 55 |
+
raise NotImplementedError(
|
| 56 |
+
"This method should be overridden by SGLang integration"
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
def get_model_loader(self) -> Callable:
|
| 60 |
+
"""Get the model weight loader function."""
|
| 61 |
+
raise NotImplementedError(
|
| 62 |
+
"This method should be overridden by SGLang integration"
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
def get_post_hook(self) -> Optional[Callable]:
|
| 66 |
+
"""Get the post-processing hook after weight loading."""
|
| 67 |
+
return None
|
| 68 |
+
|
| 69 |
+
def update_weights_from_ipc(self, zmq_handles: Dict[str, str]):
|
| 70 |
+
"""
|
| 71 |
+
Update weights from IPC communication.
|
| 72 |
+
Args:
|
| 73 |
+
zmq_handles: Dict mapping device UUID to ZMQ socket path
|
| 74 |
+
"""
|
| 75 |
+
if self._zmq_ctx is None:
|
| 76 |
+
self._zmq_ctx = zmq.Context()
|
| 77 |
+
device_uuid = self.get_device_uuid()
|
| 78 |
+
device_id = self.get_device_id()
|
| 79 |
+
if device_uuid not in zmq_handles:
|
| 80 |
+
raise ValueError(
|
| 81 |
+
f"Device UUID {device_uuid} not found in zmq_handles: {list(zmq_handles.keys())}"
|
| 82 |
+
)
|
| 83 |
+
update_weights_from_ipc(
|
| 84 |
+
self._zmq_ctx,
|
| 85 |
+
zmq_handles[device_uuid],
|
| 86 |
+
device_id=device_id,
|
| 87 |
+
run=self.get_model_loader(),
|
| 88 |
+
post_hook=self.get_post_hook(),
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class SGLangCheckpointEngineWorkerExtensionImpl(SGLangCheckpointEngineWorkerExtension):
|
| 93 |
+
"""
|
| 94 |
+
Implementation of SGLangCheckpointEngineWorkerExtension that integrates with SGLang's model runner.
|
| 95 |
+
This class provides the concrete implementation for checkpoint-engine IPC weight updates.
|
| 96 |
+
"""
|
| 97 |
+
|
| 98 |
+
def __init__(self, model_runner):
|
| 99 |
+
super().__init__()
|
| 100 |
+
self.model_runner = model_runner
|
| 101 |
+
|
| 102 |
+
def get_device_uuid(self) -> str:
|
| 103 |
+
"""Get the UUID of current device."""
|
| 104 |
+
# Get device UUID for current device
|
| 105 |
+
device_id = torch.cuda.current_device()
|
| 106 |
+
try:
|
| 107 |
+
return f"GPU-{torch.cuda.get_device_properties(device_id).uuid!s}"
|
| 108 |
+
except AssertionError as e:
|
| 109 |
+
raise ValueError(f"Failed to get GPU UUID for device {device_id}") from e
|
| 110 |
+
|
| 111 |
+
def get_device_id(self) -> int:
|
| 112 |
+
"""Get the device ID."""
|
| 113 |
+
return torch.cuda.current_device()
|
| 114 |
+
|
| 115 |
+
def get_model_loader(self) -> Callable:
|
| 116 |
+
"""Get the model weight loader function."""
|
| 117 |
+
return self.model_runner.model.load_weights
|
| 118 |
+
|
| 119 |
+
def get_post_hook(self) -> Optional[Callable]:
|
| 120 |
+
"""Get the post-processing hook after weight loading."""
|
| 121 |
+
|
| 122 |
+
def post_hook():
|
| 123 |
+
# Perform post-processing after weight loading similar to DefaultModelLoader
|
| 124 |
+
try:
|
| 125 |
+
from sglang.srt.model_loader.loader import device_loading_context
|
| 126 |
+
|
| 127 |
+
# Process quantization methods after loading weights
|
| 128 |
+
for _, module in self.model_runner.model.named_modules():
|
| 129 |
+
quant_method = getattr(module, "quant_method", None)
|
| 130 |
+
if quant_method is not None:
|
| 131 |
+
# Move parameters to device if needed for quantization processing
|
| 132 |
+
target_device = torch.device(
|
| 133 |
+
"cuda", torch.cuda.current_device()
|
| 134 |
+
)
|
| 135 |
+
with device_loading_context(module, target_device):
|
| 136 |
+
quant_method.process_weights_after_loading(module)
|
| 137 |
+
# Call model-specific post-loading hook if available
|
| 138 |
+
if hasattr(self.model_runner.model, "post_load_weights"):
|
| 139 |
+
self.model_runner.model.post_load_weights()
|
| 140 |
+
except Exception as e:
|
| 141 |
+
logger.warning(f"Post-hook processing failed: {e}")
|
| 142 |
+
|
| 143 |
+
return post_hook
|
sglang/python/sglang/srt/checkpoint_engine/update.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Usage:
|
| 3 |
+
1) Launch the server with wait-for-initial-weights option in one terminal:
|
| 4 |
+
python -m sglang.launch_server --model-path /workspace/Qwen/Qwen3-4B/ --tensor-parallel-size 2 --port 19730 --load-format dummy --checkpoint-engine-wait-weights-before-ready --mem-fraction-static 0.7
|
| 5 |
+
|
| 6 |
+
2) Torchrun this script in another terminal:
|
| 7 |
+
torchrun --nproc-per-node 2 update.py --update-method broadcast --checkpoint-path /workspace/Qwen/Qwen3-4B/ --inference-parallel-size 2
|
| 8 |
+
|
| 9 |
+
Or use the integrated entry point:
|
| 10 |
+
python -m sglang.srt.checkpoint_engine.update --update-method broadcast --checkpoint-path /workspace/Qwen/Qwen3-4B/ --inference-parallel-size 2
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
import json
|
| 15 |
+
import os
|
| 16 |
+
import pickle
|
| 17 |
+
import subprocess
|
| 18 |
+
import sys
|
| 19 |
+
import time
|
| 20 |
+
from collections import defaultdict
|
| 21 |
+
from collections.abc import Callable
|
| 22 |
+
from contextlib import contextmanager
|
| 23 |
+
from typing import Literal
|
| 24 |
+
|
| 25 |
+
import httpx
|
| 26 |
+
import torch
|
| 27 |
+
import torch.distributed as dist
|
| 28 |
+
from safetensors import safe_open
|
| 29 |
+
|
| 30 |
+
try:
|
| 31 |
+
from checkpoint_engine.ps import ParameterServer
|
| 32 |
+
from loguru import logger
|
| 33 |
+
except ImportError:
|
| 34 |
+
# Fallback for when checkpoint_engine is not available
|
| 35 |
+
ParameterServer = None
|
| 36 |
+
import logging
|
| 37 |
+
|
| 38 |
+
logger = logging.getLogger(__name__)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@contextmanager
|
| 42 |
+
def timer(msg: str):
|
| 43 |
+
start = time.perf_counter()
|
| 44 |
+
yield
|
| 45 |
+
end = time.perf_counter()
|
| 46 |
+
logger.info(f"{msg} duration: {end - start:.2f} seconds")
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def check_sglang_ready(
|
| 50 |
+
endpoint: str, inference_parallel_size: int, uds: str | None = None
|
| 51 |
+
):
|
| 52 |
+
rank = int(os.getenv("RANK", 0))
|
| 53 |
+
if rank != rank // inference_parallel_size * inference_parallel_size:
|
| 54 |
+
return
|
| 55 |
+
retry_num = 0
|
| 56 |
+
transport = None
|
| 57 |
+
if uds is not None:
|
| 58 |
+
transport = httpx.HTTPTransport(uds=uds)
|
| 59 |
+
with httpx.Client(transport=transport) as client:
|
| 60 |
+
while True:
|
| 61 |
+
try:
|
| 62 |
+
response = client.get(f"{endpoint}/ping", timeout=10)
|
| 63 |
+
response.raise_for_status()
|
| 64 |
+
break
|
| 65 |
+
except (httpx.ConnectError, httpx.HTTPStatusError) as e:
|
| 66 |
+
if retry_num % 10 == 0:
|
| 67 |
+
logger.warning(
|
| 68 |
+
f"fail to check sglang ready, retry {retry_num} times, error: {e}"
|
| 69 |
+
)
|
| 70 |
+
retry_num += 1
|
| 71 |
+
time.sleep(0.1)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def split_checkpoint_files(
|
| 75 |
+
checkpoint_path: str, rank: int, world_size: int
|
| 76 |
+
) -> list[str]:
|
| 77 |
+
checkpoint_files = [
|
| 78 |
+
os.path.join(checkpoint_path, f)
|
| 79 |
+
for f in filter(
|
| 80 |
+
lambda x: x.endswith(".safetensors"), os.listdir(checkpoint_path)
|
| 81 |
+
)
|
| 82 |
+
]
|
| 83 |
+
files_per_rank = (len(checkpoint_files) + world_size - 1) // world_size
|
| 84 |
+
return checkpoint_files[rank * files_per_rank : (rank + 1) * files_per_rank]
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def split_tensors(
|
| 88 |
+
checkpoint_path: str, rank: int, world_size: int
|
| 89 |
+
) -> dict[str, torch.Tensor]:
|
| 90 |
+
index_fn = os.path.join(checkpoint_path, "model.safetensors.index.json")
|
| 91 |
+
with open(index_fn) as f:
|
| 92 |
+
weight_map: dict[str, str] = json.load(f)["weight_map"]
|
| 93 |
+
weights_per_rank = (len(weight_map) + world_size - 1) // world_size
|
| 94 |
+
fn_tensors: dict[str, list[str]] = defaultdict(list)
|
| 95 |
+
weight_keys = list(weight_map.items())
|
| 96 |
+
for name, file in weight_keys[
|
| 97 |
+
rank * weights_per_rank : (rank + 1) * weights_per_rank
|
| 98 |
+
]:
|
| 99 |
+
fn_tensors[file].append(name)
|
| 100 |
+
named_tensors = {}
|
| 101 |
+
for file, names in fn_tensors.items():
|
| 102 |
+
with safe_open(os.path.join(checkpoint_path, file), framework="pt") as f:
|
| 103 |
+
for name in names:
|
| 104 |
+
named_tensors[name] = f.get_tensor(name)
|
| 105 |
+
return named_tensors
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def req_inference(
|
| 109 |
+
endpoint: str,
|
| 110 |
+
inference_parallel_size: int,
|
| 111 |
+
timeout: float = 300.0,
|
| 112 |
+
uds: str | None = None,
|
| 113 |
+
weight_version: str | None = None,
|
| 114 |
+
) -> Callable[[list[tuple[str, str]]], None]:
|
| 115 |
+
rank = int(os.getenv("RANK", 0))
|
| 116 |
+
src = rank // inference_parallel_size * inference_parallel_size
|
| 117 |
+
|
| 118 |
+
def req_func(socket_paths: list[tuple[str, str]]):
|
| 119 |
+
if rank == src:
|
| 120 |
+
with httpx.Client(transport=httpx.HTTPTransport(uds=uds)) as client:
|
| 121 |
+
resp = client.post(
|
| 122 |
+
f"{endpoint}/update_weights_from_ipc",
|
| 123 |
+
json={
|
| 124 |
+
"zmq_handles": dict(
|
| 125 |
+
socket_paths[src : src + inference_parallel_size]
|
| 126 |
+
),
|
| 127 |
+
"flush_cache": True,
|
| 128 |
+
"weight_version": weight_version,
|
| 129 |
+
},
|
| 130 |
+
timeout=timeout,
|
| 131 |
+
)
|
| 132 |
+
resp.raise_for_status()
|
| 133 |
+
|
| 134 |
+
return req_func
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def update_weights(
|
| 138 |
+
ps,
|
| 139 |
+
checkpoint_name: str,
|
| 140 |
+
checkpoint_files: list[str],
|
| 141 |
+
named_tensors: dict[str, torch.Tensor],
|
| 142 |
+
req_func: Callable[[list[tuple[str, str]]], None],
|
| 143 |
+
inference_parallel_size: int,
|
| 144 |
+
endpoint: str,
|
| 145 |
+
save_metas_file: str | None = None,
|
| 146 |
+
update_method: Literal["broadcast", "p2p", "all"] = "broadcast",
|
| 147 |
+
uds: str | None = None,
|
| 148 |
+
):
|
| 149 |
+
ps.register_checkpoint(
|
| 150 |
+
checkpoint_name, files=checkpoint_files, named_tensors=named_tensors
|
| 151 |
+
)
|
| 152 |
+
ps.init_process_group()
|
| 153 |
+
check_sglang_ready(endpoint, inference_parallel_size, uds)
|
| 154 |
+
dist.barrier()
|
| 155 |
+
with timer("Gather metas"):
|
| 156 |
+
ps.gather_metas(checkpoint_name)
|
| 157 |
+
if save_metas_file and int(os.getenv("RANK")) == 0:
|
| 158 |
+
with open(save_metas_file, "wb") as f:
|
| 159 |
+
pickle.dump(ps.get_metas(), f)
|
| 160 |
+
|
| 161 |
+
if update_method == "broadcast" or update_method == "all":
|
| 162 |
+
with timer("Update weights without setting ranks"):
|
| 163 |
+
ps.update(checkpoint_name, req_func)
|
| 164 |
+
|
| 165 |
+
if update_method == "p2p" or update_method == "all":
|
| 166 |
+
if update_method:
|
| 167 |
+
# sleep 2s to wait destroy process group
|
| 168 |
+
time.sleep(2)
|
| 169 |
+
with timer("Update weights with setting ranks"):
|
| 170 |
+
ps.update(
|
| 171 |
+
checkpoint_name, req_func, ranks=list(range(inference_parallel_size))
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def join(
|
| 176 |
+
ps: ParameterServer,
|
| 177 |
+
checkpoint_name: str,
|
| 178 |
+
load_metas_file: str,
|
| 179 |
+
req_func: Callable[[list[tuple[str, str]]], None],
|
| 180 |
+
inference_parallel_size: int,
|
| 181 |
+
endpoint: str,
|
| 182 |
+
uds: str | None = None,
|
| 183 |
+
):
|
| 184 |
+
assert load_metas_file, "load_metas_file is required"
|
| 185 |
+
with open(load_metas_file, "rb") as f:
|
| 186 |
+
metas = pickle.load(f)
|
| 187 |
+
ps.init_process_group()
|
| 188 |
+
check_sglang_ready(endpoint, inference_parallel_size, uds)
|
| 189 |
+
dist.barrier()
|
| 190 |
+
with timer("Gather metas before join"):
|
| 191 |
+
ps.gather_metas(checkpoint_name)
|
| 192 |
+
ps.load_metas(metas)
|
| 193 |
+
with timer(
|
| 194 |
+
f"Update weights with setting ranks as range(0, {inference_parallel_size}) by using p2p"
|
| 195 |
+
):
|
| 196 |
+
ps.update(checkpoint_name, req_func, ranks=list(range(inference_parallel_size)))
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def run_with_torchrun():
|
| 200 |
+
"""Run the update script with torchrun automatically."""
|
| 201 |
+
# Parse inference_parallel_size from command line arguments to determine nproc-per-node
|
| 202 |
+
inference_parallel_size = 8 # default
|
| 203 |
+
args = sys.argv[1:] # Skip the script name
|
| 204 |
+
|
| 205 |
+
# Look for --inference-parallel-size in arguments
|
| 206 |
+
for i, arg in enumerate(args):
|
| 207 |
+
if arg == "--inference-parallel-size" and i + 1 < len(args):
|
| 208 |
+
try:
|
| 209 |
+
inference_parallel_size = int(args[i + 1])
|
| 210 |
+
except ValueError:
|
| 211 |
+
pass
|
| 212 |
+
break
|
| 213 |
+
elif arg.startswith("--inference-parallel-size="):
|
| 214 |
+
try:
|
| 215 |
+
inference_parallel_size = int(arg.split("=", 1)[1])
|
| 216 |
+
except ValueError:
|
| 217 |
+
pass
|
| 218 |
+
break
|
| 219 |
+
|
| 220 |
+
# Build torchrun command
|
| 221 |
+
cmd = ["torchrun", f"--nproc-per-node={inference_parallel_size}", __file__] + args
|
| 222 |
+
|
| 223 |
+
print(f"Running: {' '.join(cmd)}", file=sys.stderr)
|
| 224 |
+
|
| 225 |
+
# Execute torchrun with the original script
|
| 226 |
+
try:
|
| 227 |
+
result = subprocess.run(cmd, check=False)
|
| 228 |
+
sys.exit(result.returncode)
|
| 229 |
+
except FileNotFoundError:
|
| 230 |
+
print(
|
| 231 |
+
"Error: torchrun command not found. Please ensure PyTorch is installed.",
|
| 232 |
+
file=sys.stderr,
|
| 233 |
+
)
|
| 234 |
+
sys.exit(1)
|
| 235 |
+
except KeyboardInterrupt:
|
| 236 |
+
print("\nInterrupted by user", file=sys.stderr)
|
| 237 |
+
sys.exit(130)
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def main():
|
| 241 |
+
# Check if we're running under torchrun or need to invoke it
|
| 242 |
+
if os.getenv("RANK") is None:
|
| 243 |
+
# Not running under torchrun, so invoke it
|
| 244 |
+
run_with_torchrun()
|
| 245 |
+
return
|
| 246 |
+
|
| 247 |
+
# Running under torchrun, proceed with normal execution
|
| 248 |
+
parser = argparse.ArgumentParser(description="Update weights example")
|
| 249 |
+
parser.add_argument("--checkpoint-path", type=str, default=None)
|
| 250 |
+
parser.add_argument("--save-metas-file", type=str, default=None)
|
| 251 |
+
parser.add_argument("--load-metas-file", type=str, default=None)
|
| 252 |
+
parser.add_argument("--sleep-time", type=int, default=0)
|
| 253 |
+
parser.add_argument("--endpoint", type=str, default="http://localhost:19730")
|
| 254 |
+
parser.add_argument("--inference-parallel-size", type=int, default=8)
|
| 255 |
+
parser.add_argument("--checkpoint-name", type=str, default="my-checkpoint-iter-0")
|
| 256 |
+
parser.add_argument("--update-method", type=str, default="broadcast")
|
| 257 |
+
parser.add_argument("--uds", type=str, default=None)
|
| 258 |
+
parser.add_argument("--weight-version", type=str, default=None)
|
| 259 |
+
args = parser.parse_args()
|
| 260 |
+
|
| 261 |
+
# Get rank and world_size from environment (set by torchrun)
|
| 262 |
+
rank = int(os.getenv("RANK", 0))
|
| 263 |
+
world_size = int(os.getenv("WORLD_SIZE", 1))
|
| 264 |
+
|
| 265 |
+
req_func = req_inference(
|
| 266 |
+
args.endpoint,
|
| 267 |
+
args.inference_parallel_size,
|
| 268 |
+
uds=args.uds,
|
| 269 |
+
weight_version=args.weight_version,
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
if ParameterServer is None:
|
| 273 |
+
print("Error: checkpoint_engine package not available", file=sys.stderr)
|
| 274 |
+
sys.exit(1)
|
| 275 |
+
|
| 276 |
+
ps = ParameterServer(auto_pg=True)
|
| 277 |
+
ps._p2p_store = None
|
| 278 |
+
if args.load_metas_file:
|
| 279 |
+
join(
|
| 280 |
+
ps,
|
| 281 |
+
args.checkpoint_name,
|
| 282 |
+
args.load_metas_file,
|
| 283 |
+
req_func,
|
| 284 |
+
args.inference_parallel_size,
|
| 285 |
+
args.endpoint,
|
| 286 |
+
args.uds,
|
| 287 |
+
)
|
| 288 |
+
else:
|
| 289 |
+
if args.checkpoint_path and os.path.exists(
|
| 290 |
+
os.path.join(args.checkpoint_path, "model.safetensors.index.json")
|
| 291 |
+
):
|
| 292 |
+
named_tensors = split_tensors(args.checkpoint_path, rank, world_size)
|
| 293 |
+
checkpoint_files = []
|
| 294 |
+
else:
|
| 295 |
+
checkpoint_files = (
|
| 296 |
+
split_checkpoint_files(args.checkpoint_path, rank, world_size)
|
| 297 |
+
if args.checkpoint_path
|
| 298 |
+
else []
|
| 299 |
+
)
|
| 300 |
+
named_tensors = {}
|
| 301 |
+
update_weights(
|
| 302 |
+
ps,
|
| 303 |
+
args.checkpoint_name,
|
| 304 |
+
checkpoint_files,
|
| 305 |
+
named_tensors,
|
| 306 |
+
req_func,
|
| 307 |
+
args.inference_parallel_size,
|
| 308 |
+
args.endpoint,
|
| 309 |
+
args.save_metas_file,
|
| 310 |
+
args.update_method,
|
| 311 |
+
args.uds,
|
| 312 |
+
)
|
| 313 |
+
time.sleep(args.sleep_time)
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
if __name__ == "__main__":
|
| 317 |
+
main()
|
sglang/python/sglang/srt/compilation/__pycache__/compilation_config.cpython-311.pyc
ADDED
|
Binary file (2.79 kB). View file
|
|
|
sglang/python/sglang/srt/compilation/__pycache__/compile.cpython-311.pyc
ADDED
|
Binary file (12 kB). View file
|
|
|
sglang/python/sglang/srt/compilation/__pycache__/piecewise_context_manager.cpython-311.pyc
ADDED
|
Binary file (5.35 kB). View file
|
|
|
sglang/python/sglang/srt/compilation/backend.py
ADDED
|
@@ -0,0 +1,472 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/backend.py
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
import ast
|
| 5 |
+
import dataclasses
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
import pprint
|
| 9 |
+
import time
|
| 10 |
+
from collections.abc import Sequence
|
| 11 |
+
from contextlib import contextmanager
|
| 12 |
+
from typing import Any, Callable, Optional
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.fx as fx
|
| 16 |
+
from torch._dispatch.python import enable_python_dispatcher
|
| 17 |
+
|
| 18 |
+
from sglang.srt.compilation.compilation_config import CompilationConfig
|
| 19 |
+
from sglang.srt.compilation.compilation_counter import compilation_counter
|
| 20 |
+
from sglang.srt.compilation.compiler_interface import EagerAdapter, InductorAdaptor
|
| 21 |
+
from sglang.srt.compilation.cuda_piecewise_backend import CUDAPiecewiseBackend
|
| 22 |
+
from sglang.srt.compilation.npu_piecewise_backend import NPUPiecewiseBackend
|
| 23 |
+
from sglang.srt.compilation.pass_manager import PostGradPassManager
|
| 24 |
+
from sglang.srt.utils.common import is_npu, rank0_log
|
| 25 |
+
|
| 26 |
+
logger = logging.getLogger(__name__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def make_compiler(config: CompilationConfig):
|
| 30 |
+
if config.compiler == "eager":
|
| 31 |
+
return EagerAdapter()
|
| 32 |
+
elif config.compiler == "inductor":
|
| 33 |
+
return InductorAdaptor()
|
| 34 |
+
else:
|
| 35 |
+
raise ValueError(f"Unknown compiler: {config.compiler}")
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def make_backend(
|
| 39 |
+
graph: fx.GraphModule,
|
| 40 |
+
compile_config: CompilationConfig,
|
| 41 |
+
inductor_config: dict[str, Any],
|
| 42 |
+
graph_pool: Any,
|
| 43 |
+
piecewise_compile_index: int,
|
| 44 |
+
total_piecewise_compiles: int,
|
| 45 |
+
sym_shape_indices: list[int],
|
| 46 |
+
compiled_graph_for_general_shape: Callable,
|
| 47 |
+
sglang_backend,
|
| 48 |
+
):
|
| 49 |
+
|
| 50 |
+
backend_cls = CUDAPiecewiseBackend if not is_npu() else NPUPiecewiseBackend
|
| 51 |
+
return backend_cls(
|
| 52 |
+
graph,
|
| 53 |
+
compile_config,
|
| 54 |
+
inductor_config,
|
| 55 |
+
graph_pool,
|
| 56 |
+
piecewise_compile_index,
|
| 57 |
+
total_piecewise_compiles,
|
| 58 |
+
sym_shape_indices,
|
| 59 |
+
compiled_graph_for_general_shape,
|
| 60 |
+
sglang_backend,
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class CompilerManager:
|
| 65 |
+
def __init__(
|
| 66 |
+
self,
|
| 67 |
+
config: CompilationConfig,
|
| 68 |
+
):
|
| 69 |
+
self.cache = dict()
|
| 70 |
+
self.is_cache_updated = False
|
| 71 |
+
self.compiler = make_compiler(config)
|
| 72 |
+
|
| 73 |
+
def compute_hash(self):
|
| 74 |
+
return self.compiler.compute_hash()
|
| 75 |
+
|
| 76 |
+
def initialize_cache(
|
| 77 |
+
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
|
| 78 |
+
):
|
| 79 |
+
self.disable_cache = disable_cache
|
| 80 |
+
self.cache_dir = cache_dir
|
| 81 |
+
self.cache_file_path = os.path.join(cache_dir, "sglang_compile_cache.py")
|
| 82 |
+
|
| 83 |
+
if not disable_cache and os.path.exists(self.cache_file_path):
|
| 84 |
+
with open(self.cache_file_path) as f:
|
| 85 |
+
self.cache = ast.literal_eval(f.read())
|
| 86 |
+
|
| 87 |
+
self.compiler.initialize_cache(
|
| 88 |
+
cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
def save_to_file(self):
|
| 92 |
+
if self.disable_cache or not self.is_cache_updated:
|
| 93 |
+
return
|
| 94 |
+
printer = pprint.PrettyPrinter(indent=4)
|
| 95 |
+
data = printer.pformat(self.cache)
|
| 96 |
+
with open(self.cache_file_path, "w") as f:
|
| 97 |
+
f.write(data)
|
| 98 |
+
|
| 99 |
+
def load(
|
| 100 |
+
self,
|
| 101 |
+
graph: fx.GraphModule,
|
| 102 |
+
example_inputs: list[Any],
|
| 103 |
+
graph_index: int,
|
| 104 |
+
runtime_shape: Optional[int] = None,
|
| 105 |
+
) -> Optional[Callable]:
|
| 106 |
+
handle = self.cache[(runtime_shape, graph_index, self.compiler.name)]
|
| 107 |
+
compiled_graph = self.compiler.load(
|
| 108 |
+
handle, graph, example_inputs, graph_index, runtime_shape
|
| 109 |
+
)
|
| 110 |
+
if runtime_shape is None:
|
| 111 |
+
logger.debug(
|
| 112 |
+
"Directly load the %s-th graph for dynamic shape from %s via "
|
| 113 |
+
"handle %s",
|
| 114 |
+
graph_index,
|
| 115 |
+
self.compiler.name,
|
| 116 |
+
handle,
|
| 117 |
+
)
|
| 118 |
+
else:
|
| 119 |
+
logger.debug(
|
| 120 |
+
"Directly load the %s-th graph for shape %s from %s via " "handle %s",
|
| 121 |
+
graph_index,
|
| 122 |
+
str(runtime_shape),
|
| 123 |
+
self.compiler.name,
|
| 124 |
+
handle,
|
| 125 |
+
)
|
| 126 |
+
return compiled_graph
|
| 127 |
+
|
| 128 |
+
def compile(
|
| 129 |
+
self,
|
| 130 |
+
graph: fx.GraphModule,
|
| 131 |
+
example_inputs,
|
| 132 |
+
inductor_config: dict[str, Any],
|
| 133 |
+
graph_index: int = 0,
|
| 134 |
+
num_graphs: int = 1,
|
| 135 |
+
runtime_shape: Optional[int] = None,
|
| 136 |
+
) -> Any:
|
| 137 |
+
if graph_index == 0:
|
| 138 |
+
# before compiling the first graph, record the start time
|
| 139 |
+
global compilation_start_time
|
| 140 |
+
compilation_start_time = time.time()
|
| 141 |
+
|
| 142 |
+
compilation_counter.num_backend_compilations += 1
|
| 143 |
+
|
| 144 |
+
compiled_graph = None
|
| 145 |
+
|
| 146 |
+
# TODO(Yuwei): support cache loading
|
| 147 |
+
|
| 148 |
+
# no compiler cached the graph, or the cache is disabled,
|
| 149 |
+
# we need to compile it
|
| 150 |
+
if isinstance(self.compiler, InductorAdaptor):
|
| 151 |
+
maybe_key = None
|
| 152 |
+
else:
|
| 153 |
+
maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
|
| 154 |
+
compiled_graph, handle = self.compiler.compile(
|
| 155 |
+
graph, example_inputs, inductor_config, runtime_shape, maybe_key
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
assert compiled_graph is not None, "Failed to compile the graph"
|
| 159 |
+
|
| 160 |
+
# store the artifact in the cache
|
| 161 |
+
if handle is not None:
|
| 162 |
+
self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle
|
| 163 |
+
compilation_counter.num_cache_entries_updated += 1
|
| 164 |
+
self.is_cache_updated = True
|
| 165 |
+
if graph_index == 0:
|
| 166 |
+
# adds some info logging for the first graph
|
| 167 |
+
if runtime_shape is None:
|
| 168 |
+
logger.info("Cache the graph for dynamic shape for later use")
|
| 169 |
+
else:
|
| 170 |
+
logger.info(
|
| 171 |
+
"Cache the graph of shape %s for later use", str(runtime_shape)
|
| 172 |
+
)
|
| 173 |
+
if runtime_shape is None:
|
| 174 |
+
logger.debug(
|
| 175 |
+
"Store the %s-th graph for dynamic shape from %s via " "handle %s",
|
| 176 |
+
graph_index,
|
| 177 |
+
self.compiler.name,
|
| 178 |
+
handle,
|
| 179 |
+
)
|
| 180 |
+
else:
|
| 181 |
+
logger.debug(
|
| 182 |
+
"Store the %s-th graph for shape %s from %s via handle %s",
|
| 183 |
+
graph_index,
|
| 184 |
+
str(runtime_shape),
|
| 185 |
+
self.compiler.name,
|
| 186 |
+
handle,
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
# after compiling the last graph, record the end time
|
| 190 |
+
if graph_index == num_graphs - 1:
|
| 191 |
+
now = time.time()
|
| 192 |
+
elapsed = now - compilation_start_time
|
| 193 |
+
if runtime_shape is None:
|
| 194 |
+
logger.info("Compiling a graph for dynamic shape takes %.2f s", elapsed)
|
| 195 |
+
else:
|
| 196 |
+
logger.info(
|
| 197 |
+
"Compiling a graph for shape %s takes %.2f s",
|
| 198 |
+
runtime_shape,
|
| 199 |
+
elapsed,
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
return compiled_graph
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
@dataclasses.dataclass
|
| 206 |
+
class SplitItem:
|
| 207 |
+
submod_name: str
|
| 208 |
+
graph_id: int
|
| 209 |
+
is_splitting_graph: bool
|
| 210 |
+
graph: fx.GraphModule
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def split_graph(
|
| 214 |
+
graph: fx.GraphModule, ops: list[str]
|
| 215 |
+
) -> tuple[fx.GraphModule, list[SplitItem]]:
|
| 216 |
+
# split graph by ops
|
| 217 |
+
subgraph_id = 0
|
| 218 |
+
node_to_subgraph_id = {}
|
| 219 |
+
split_op_graphs = []
|
| 220 |
+
for node in graph.graph.nodes:
|
| 221 |
+
if node.op in ("output", "placeholder"):
|
| 222 |
+
continue
|
| 223 |
+
if node.op == "call_function" and str(node.target) in ops:
|
| 224 |
+
subgraph_id += 1
|
| 225 |
+
node_to_subgraph_id[node] = subgraph_id
|
| 226 |
+
split_op_graphs.append(subgraph_id)
|
| 227 |
+
subgraph_id += 1
|
| 228 |
+
else:
|
| 229 |
+
node_to_subgraph_id[node] = subgraph_id
|
| 230 |
+
|
| 231 |
+
# `keep_original_order` is important!
|
| 232 |
+
# otherwise pytorch might reorder the nodes and
|
| 233 |
+
# the semantics of the graph will change when we
|
| 234 |
+
# have mutations in the graph
|
| 235 |
+
split_gm = torch.fx.passes.split_module.split_module(
|
| 236 |
+
graph, None, lambda node: node_to_subgraph_id[node], keep_original_order=True
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
outputs = []
|
| 240 |
+
|
| 241 |
+
names = [name for (name, module) in split_gm.named_modules()]
|
| 242 |
+
|
| 243 |
+
for name in names:
|
| 244 |
+
if "." in name or name == "":
|
| 245 |
+
# recursive child module or the root module
|
| 246 |
+
continue
|
| 247 |
+
|
| 248 |
+
module = getattr(split_gm, name)
|
| 249 |
+
|
| 250 |
+
graph_id = int(name.replace("submod_", ""))
|
| 251 |
+
outputs.append(SplitItem(name, graph_id, (graph_id in split_op_graphs), module))
|
| 252 |
+
|
| 253 |
+
# sort by intetger graph_id, rather than string name
|
| 254 |
+
outputs.sort(key=lambda x: x.graph_id)
|
| 255 |
+
|
| 256 |
+
return split_gm, outputs
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
# we share the global graph pool among all the backends
|
| 260 |
+
global_graph_pool = None
|
| 261 |
+
|
| 262 |
+
compilation_start_time = 0.0
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
| 266 |
+
def __init__(
|
| 267 |
+
self,
|
| 268 |
+
module: torch.fx.GraphModule,
|
| 269 |
+
compile_submod_names: list[str],
|
| 270 |
+
inductor_config: dict[str, Any],
|
| 271 |
+
graph_pool,
|
| 272 |
+
compile_config: CompilationConfig,
|
| 273 |
+
sglang_backend: "SGLangBackend",
|
| 274 |
+
):
|
| 275 |
+
super().__init__(module)
|
| 276 |
+
from torch._guards import detect_fake_mode
|
| 277 |
+
|
| 278 |
+
self.fake_mode = detect_fake_mode()
|
| 279 |
+
self.compile_submod_names = compile_submod_names
|
| 280 |
+
self.graph_pool = graph_pool
|
| 281 |
+
self.sglang_backend = sglang_backend
|
| 282 |
+
# When True, it annoyingly dumps the torch.fx.Graph on errors.
|
| 283 |
+
self.extra_traceback = False
|
| 284 |
+
self.inductor_config = inductor_config
|
| 285 |
+
self.compile_config = compile_config
|
| 286 |
+
|
| 287 |
+
def run(self, *args):
|
| 288 |
+
fake_args = [
|
| 289 |
+
self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
|
| 290 |
+
for t in args
|
| 291 |
+
]
|
| 292 |
+
with self.fake_mode, enable_python_dispatcher():
|
| 293 |
+
return super().run(*fake_args)
|
| 294 |
+
|
| 295 |
+
def call_module(
|
| 296 |
+
self,
|
| 297 |
+
target: torch.fx.node.Target,
|
| 298 |
+
args: tuple[torch.fx.node.Argument, ...],
|
| 299 |
+
kwargs: dict[str, Any],
|
| 300 |
+
) -> Any:
|
| 301 |
+
assert isinstance(target, str)
|
| 302 |
+
output = super().call_module(target, args, kwargs)
|
| 303 |
+
|
| 304 |
+
if target in self.compile_submod_names:
|
| 305 |
+
index = self.compile_submod_names.index(target)
|
| 306 |
+
submod = self.fetch_attr(target)
|
| 307 |
+
sym_shape_indices = [
|
| 308 |
+
i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
|
| 309 |
+
]
|
| 310 |
+
global compilation_start_time
|
| 311 |
+
compiled_graph_for_dynamic_shape = (
|
| 312 |
+
self.sglang_backend.compiler_manager.compile(
|
| 313 |
+
submod,
|
| 314 |
+
args,
|
| 315 |
+
self.inductor_config,
|
| 316 |
+
graph_index=index,
|
| 317 |
+
num_graphs=len(self.compile_submod_names),
|
| 318 |
+
runtime_shape=None,
|
| 319 |
+
)
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
self.module.__dict__[target] = make_backend(
|
| 323 |
+
submod,
|
| 324 |
+
self.compile_config,
|
| 325 |
+
self.inductor_config,
|
| 326 |
+
self.graph_pool,
|
| 327 |
+
index,
|
| 328 |
+
len(self.compile_submod_names),
|
| 329 |
+
sym_shape_indices,
|
| 330 |
+
compiled_graph_for_dynamic_shape,
|
| 331 |
+
self.sglang_backend,
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
compilation_counter.num_piecewise_capturable_graphs_seen += 1
|
| 335 |
+
|
| 336 |
+
return output
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
model_tag: str = "backbone"
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
@contextmanager
|
| 343 |
+
def set_model_tag(tag: str):
|
| 344 |
+
"""Context manager to set the model tag."""
|
| 345 |
+
global model_tag
|
| 346 |
+
assert (
|
| 347 |
+
tag != model_tag
|
| 348 |
+
), f"Model tag {tag} is the same as the current tag {model_tag}."
|
| 349 |
+
old_tag = model_tag
|
| 350 |
+
model_tag = tag
|
| 351 |
+
try:
|
| 352 |
+
yield
|
| 353 |
+
finally:
|
| 354 |
+
model_tag = old_tag
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
class SGLangBackend:
|
| 358 |
+
|
| 359 |
+
graph_pool: Any
|
| 360 |
+
_called: bool = False
|
| 361 |
+
# the graph we compiled
|
| 362 |
+
graph: fx.GraphModule
|
| 363 |
+
# the stiching graph module for all the piecewise graphs
|
| 364 |
+
split_gm: fx.GraphModule
|
| 365 |
+
piecewise_graphs: list[SplitItem]
|
| 366 |
+
returned_callable: Callable
|
| 367 |
+
# Inductor passes to run on the graph pre-defunctionalization
|
| 368 |
+
post_grad_passes: Sequence[Callable]
|
| 369 |
+
sym_tensor_indices: list[int]
|
| 370 |
+
input_buffers: list[torch.Tensor]
|
| 371 |
+
compiler_manager: CompilerManager
|
| 372 |
+
|
| 373 |
+
def __init__(
|
| 374 |
+
self,
|
| 375 |
+
config: CompilationConfig,
|
| 376 |
+
graph_pool: Any,
|
| 377 |
+
):
|
| 378 |
+
rank0_log(f"Initializing SGLangBackend")
|
| 379 |
+
assert graph_pool is not None
|
| 380 |
+
self.graph_pool = graph_pool
|
| 381 |
+
|
| 382 |
+
self.post_grad_pass_manager = PostGradPassManager()
|
| 383 |
+
self.sym_tensor_indices = []
|
| 384 |
+
self.input_buffers = []
|
| 385 |
+
|
| 386 |
+
self.compiler_manager = CompilerManager(config)
|
| 387 |
+
self.inductor_config = {
|
| 388 |
+
"enable_auto_functionalized_v2": False,
|
| 389 |
+
}
|
| 390 |
+
self.compile_config = config
|
| 391 |
+
|
| 392 |
+
def configure_post_pass(self):
|
| 393 |
+
self.post_grad_pass_manager.configure()
|
| 394 |
+
self.inductor_config["post_grad_custom_post_pass"] = self.post_grad_pass_manager
|
| 395 |
+
|
| 396 |
+
def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
|
| 397 |
+
rank0_log(f"SGLangBackend __call__")
|
| 398 |
+
base_cache_dir = os.path.expanduser(
|
| 399 |
+
os.getenv("SGLANG_CACHE_DIR", "~/.cache/sglang/")
|
| 400 |
+
)
|
| 401 |
+
|
| 402 |
+
cache_hash = self.compiler_manager.compute_hash()
|
| 403 |
+
cache_dir = os.path.join(
|
| 404 |
+
base_cache_dir,
|
| 405 |
+
"torch_compile_cache",
|
| 406 |
+
cache_hash,
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
os.makedirs(cache_dir, exist_ok=True)
|
| 410 |
+
rank = 0
|
| 411 |
+
dp_rank = 0
|
| 412 |
+
local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", model_tag)
|
| 413 |
+
os.makedirs(local_cache_dir, exist_ok=True)
|
| 414 |
+
self.compiler_manager.initialize_cache(
|
| 415 |
+
local_cache_dir, disable_cache=False, prefix=""
|
| 416 |
+
)
|
| 417 |
+
compilation_counter.num_graphs_seen += 1
|
| 418 |
+
|
| 419 |
+
assert not self._called, "SGLangBackend can only be called once"
|
| 420 |
+
|
| 421 |
+
self.graph = graph
|
| 422 |
+
self.configure_post_pass()
|
| 423 |
+
|
| 424 |
+
self.split_gm, self.piecewise_graphs = split_graph(
|
| 425 |
+
graph,
|
| 426 |
+
self.compile_config.split_ops,
|
| 427 |
+
)
|
| 428 |
+
from torch._dynamo.utils import lazy_format_graph_code
|
| 429 |
+
|
| 430 |
+
# depyf will hook lazy_format_graph_code and dump the graph
|
| 431 |
+
# for debugging, no need to print the graph here
|
| 432 |
+
lazy_format_graph_code("before split", self.graph)
|
| 433 |
+
lazy_format_graph_code("after split", self.split_gm)
|
| 434 |
+
|
| 435 |
+
compilation_counter.num_piecewise_graphs_seen += len(self.piecewise_graphs)
|
| 436 |
+
|
| 437 |
+
submod_names_to_compile = [
|
| 438 |
+
item.submod_name
|
| 439 |
+
for item in self.piecewise_graphs
|
| 440 |
+
if not item.is_splitting_graph
|
| 441 |
+
]
|
| 442 |
+
|
| 443 |
+
PiecewiseCompileInterpreter(
|
| 444 |
+
self.split_gm,
|
| 445 |
+
submod_names_to_compile,
|
| 446 |
+
self.inductor_config,
|
| 447 |
+
self.graph_pool,
|
| 448 |
+
self.compile_config,
|
| 449 |
+
self,
|
| 450 |
+
).run(*example_inputs)
|
| 451 |
+
|
| 452 |
+
rank = torch.distributed.get_rank()
|
| 453 |
+
|
| 454 |
+
if rank == 0:
|
| 455 |
+
graph_path = os.path.join(
|
| 456 |
+
local_cache_dir, f"computation_graph_{time.time()}.py"
|
| 457 |
+
)
|
| 458 |
+
if not os.path.exists(graph_path):
|
| 459 |
+
# code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa
|
| 460 |
+
# use `print_readable` because it can include submodules
|
| 461 |
+
src = (
|
| 462 |
+
"from __future__ import annotations\nimport torch\n"
|
| 463 |
+
+ self.split_gm.print_readable(print_output=False)
|
| 464 |
+
)
|
| 465 |
+
src = src.replace("<lambda>", "GraphModule")
|
| 466 |
+
with open(graph_path, "w") as f:
|
| 467 |
+
f.write(src)
|
| 468 |
+
|
| 469 |
+
rank0_log(f"Computation graph saved to {graph_path}")
|
| 470 |
+
|
| 471 |
+
self._called = True
|
| 472 |
+
return self.split_gm
|
sglang/python/sglang/srt/compilation/compilation_config.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compilation_config.py
|
| 2 |
+
|
| 3 |
+
from typing import Callable, List, Optional
|
| 4 |
+
|
| 5 |
+
SPLIT_OPS = []
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def register_split_op(op_name: Optional[str] = None):
|
| 9 |
+
def decorator(op_func: Callable):
|
| 10 |
+
name = op_name or op_func.__name__
|
| 11 |
+
SPLIT_OPS.append(f"sglang.{name}")
|
| 12 |
+
return op_func
|
| 13 |
+
|
| 14 |
+
return decorator
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# TODO(Yuwei): support better compile config support
|
| 18 |
+
class CompilationConfig:
|
| 19 |
+
def __init__(
|
| 20 |
+
self,
|
| 21 |
+
capture_sizes: List[int],
|
| 22 |
+
compiler: str = "eager",
|
| 23 |
+
enable_debug_mode: bool = False,
|
| 24 |
+
):
|
| 25 |
+
self.traced_files = set()
|
| 26 |
+
self.capture_sizes = capture_sizes
|
| 27 |
+
self.compiler = compiler
|
| 28 |
+
self.enable_debug_mode = enable_debug_mode
|
| 29 |
+
self.split_ops = []
|
| 30 |
+
self.split_ops.extend(SPLIT_OPS)
|
| 31 |
+
|
| 32 |
+
def add_split_op(self, op: str):
|
| 33 |
+
self.split_ops.append(op)
|
| 34 |
+
|
| 35 |
+
def add_traced_file(self, file_path: str):
|
| 36 |
+
self.traced_files.add(file_path)
|
| 37 |
+
|
| 38 |
+
def get_traced_files(self):
|
| 39 |
+
return self.traced_files
|
| 40 |
+
|
| 41 |
+
def get_capture_sizes(self):
|
| 42 |
+
return self.capture_sizes
|
| 43 |
+
|
| 44 |
+
def get_enable_debug_mode(self):
|
| 45 |
+
return self.enable_debug_mode
|
sglang/python/sglang/srt/compilation/compilation_counter.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compilation_counter.py
|
| 2 |
+
|
| 3 |
+
import copy
|
| 4 |
+
import dataclasses
|
| 5 |
+
from contextlib import contextmanager
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@dataclasses.dataclass
|
| 9 |
+
class CompilationCounter:
|
| 10 |
+
num_models_seen: int = 0
|
| 11 |
+
num_graphs_seen: int = 0
|
| 12 |
+
# including the splitting ops
|
| 13 |
+
num_piecewise_graphs_seen: int = 0
|
| 14 |
+
# not including the splitting ops
|
| 15 |
+
num_piecewise_capturable_graphs_seen: int = 0
|
| 16 |
+
num_backend_compilations: int = 0
|
| 17 |
+
# Number of gpu_model_runner attempts to trigger CUDAGraphs capture
|
| 18 |
+
num_gpu_runner_capture_triggers: int = 0
|
| 19 |
+
# Number of CUDAGraphs captured
|
| 20 |
+
num_cudagraph_captured: int = 0
|
| 21 |
+
# InductorAdapter.compile calls
|
| 22 |
+
num_inductor_compiles: int = 0
|
| 23 |
+
# EagerAdapter.compile calls
|
| 24 |
+
num_eager_compiles: int = 0
|
| 25 |
+
# The number of time vLLM's compiler cache entry was updated
|
| 26 |
+
num_cache_entries_updated: int = 0
|
| 27 |
+
# The number of standalone_compile compiled artifacts saved
|
| 28 |
+
num_compiled_artifacts_saved: int = 0
|
| 29 |
+
# Number of times a model was loaded with CompilationLevel.DYNAMO_AS_IS
|
| 30 |
+
dynamo_as_is_count: int = 0
|
| 31 |
+
|
| 32 |
+
def clone(self) -> "CompilationCounter":
|
| 33 |
+
return copy.deepcopy(self)
|
| 34 |
+
|
| 35 |
+
@contextmanager
|
| 36 |
+
def expect(self, **kwargs):
|
| 37 |
+
old = self.clone()
|
| 38 |
+
yield
|
| 39 |
+
for k, v in kwargs.items():
|
| 40 |
+
assert getattr(self, k) - getattr(old, k) == v, (
|
| 41 |
+
f"{k} not as expected, before it is {getattr(old, k)}"
|
| 42 |
+
f", after it is {getattr(self, k)}, "
|
| 43 |
+
f"expected diff is {v}"
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
compilation_counter = CompilationCounter()
|
sglang/python/sglang/srt/compilation/compile.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import types
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Any, Callable, Optional, Union
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from sglang.srt.compilation.compilation_config import CompilationConfig
|
| 12 |
+
from sglang.srt.compilation.piecewise_context_manager import is_in_piecewise_cuda_graph
|
| 13 |
+
from sglang.srt.utils.common import rank0_log
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
|
| 19 |
+
class IntermediateTensors:
|
| 20 |
+
"""For all pipeline stages except the last, we need to return the hidden
|
| 21 |
+
states and residuals to be sent to the next stage. This data structure
|
| 22 |
+
contains the hidden states and residuals for a request.
|
| 23 |
+
|
| 24 |
+
Each stage also needs to handle its own finished_sending and
|
| 25 |
+
finished_recving in case of kv transfer.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
tensors: dict[str, torch.Tensor]
|
| 29 |
+
# [req_ids]
|
| 30 |
+
finished_sending: Optional[set[str]] = None
|
| 31 |
+
finished_recving: Optional[set[str]] = None
|
| 32 |
+
|
| 33 |
+
def __init__(self, tensors):
|
| 34 |
+
# manually define this function, so that
|
| 35 |
+
# Dynamo knows `IntermediateTensors()` comes from this file.
|
| 36 |
+
# Otherwise, dataclass will generate this function by evaluating
|
| 37 |
+
# a string, and we will lose the information about the source file.
|
| 38 |
+
self.tensors = tensors
|
| 39 |
+
|
| 40 |
+
def __getitem__(self, key: Union[str, slice]):
|
| 41 |
+
if isinstance(key, str):
|
| 42 |
+
return self.tensors[key]
|
| 43 |
+
elif isinstance(key, slice):
|
| 44 |
+
return self.__class__({k: v[key] for k, v in self.tensors.items()})
|
| 45 |
+
|
| 46 |
+
def __setitem__(self, key: str, value: torch.Tensor):
|
| 47 |
+
self.tensors[key] = value
|
| 48 |
+
|
| 49 |
+
def items(self):
|
| 50 |
+
return self.tensors.items()
|
| 51 |
+
|
| 52 |
+
def __len__(self):
|
| 53 |
+
return len(self.tensors)
|
| 54 |
+
|
| 55 |
+
def __eq__(self, other: object):
|
| 56 |
+
return isinstance(other, self.__class__) and self
|
| 57 |
+
|
| 58 |
+
def __repr__(self) -> str:
|
| 59 |
+
return f"IntermediateTensors(tensors={self.tensors})"
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _normalize_dims(dims, ndim: int):
|
| 63 |
+
dims = [dims] if isinstance(dims, int) else list(dims)
|
| 64 |
+
return [d if d >= 0 else ndim + d for d in dims]
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class _MaybeIntermediateTensors:
|
| 68 |
+
"""Duck-typed check to support your IntermediateTensors without importing."""
|
| 69 |
+
|
| 70 |
+
def __init__(self, obj):
|
| 71 |
+
self.is_intermediate = hasattr(obj, "tensors") and isinstance(
|
| 72 |
+
getattr(obj, "tensors"), dict
|
| 73 |
+
)
|
| 74 |
+
self.obj = obj
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def _mark_dynamic_on_value(val, dims):
|
| 78 |
+
if isinstance(val, torch.Tensor):
|
| 79 |
+
torch._dynamo.maybe_mark_dynamic(val, _normalize_dims(dims, val.ndim))
|
| 80 |
+
else:
|
| 81 |
+
mit = _MaybeIntermediateTensors(val)
|
| 82 |
+
if mit.is_intermediate:
|
| 83 |
+
for t in mit.obj.tensors.values():
|
| 84 |
+
torch._dynamo.maybe_mark_dynamic(t, _normalize_dims(dims, t.ndim))
|
| 85 |
+
# else: ignore (None or non-tensor)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _infer_dynamic_arg_dims_from_annotations(forward_fn):
|
| 89 |
+
sig = inspect.signature(forward_fn)
|
| 90 |
+
dyn = {}
|
| 91 |
+
for name, p in sig.parameters.items():
|
| 92 |
+
ann = p.annotation
|
| 93 |
+
# Accept torch.Tensor / Optional[torch.Tensor] / your IntermediateTensors types by name
|
| 94 |
+
if (
|
| 95 |
+
ann is torch.Tensor
|
| 96 |
+
or getattr(getattr(ann, "__args__", [None])[0], "__name__", "") == "Tensor"
|
| 97 |
+
):
|
| 98 |
+
dyn[name] = 0
|
| 99 |
+
elif getattr(ann, "__name__", "") in ("IntermediateTensors",) or any(
|
| 100 |
+
getattr(a, "__name__", "") == "IntermediateTensors"
|
| 101 |
+
for a in getattr(ann, "__args__", [])
|
| 102 |
+
):
|
| 103 |
+
dyn[name] = 0
|
| 104 |
+
elif ann == "torch.Tensor" or ann == "Optional[torch.Tensor]":
|
| 105 |
+
# For future import annotations (e.g. from __future__ import annotations), the annotation is a string
|
| 106 |
+
dyn[name] = 0
|
| 107 |
+
if not dyn:
|
| 108 |
+
raise ValueError("No dynamic dims inferred; pass dynamic_arg_dims explicitly.")
|
| 109 |
+
return dyn
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def install_torch_compiled(
|
| 113 |
+
module: torch.nn.Module,
|
| 114 |
+
*,
|
| 115 |
+
dynamic_arg_dims: dict[str, Union[int, list[int]]] | None = None,
|
| 116 |
+
backend_factory: Optional[Callable[[torch.fx.GraphModule, list], Callable]] = None,
|
| 117 |
+
compile_config: CompilationConfig = None,
|
| 118 |
+
fullgraph: bool = True,
|
| 119 |
+
graph_pool: Any = None,
|
| 120 |
+
):
|
| 121 |
+
rank0_log(f"install_torch_compiled")
|
| 122 |
+
unbound_fwd = module.__class__.forward
|
| 123 |
+
if not callable(unbound_fwd):
|
| 124 |
+
raise TypeError("module.__class__.forward must be callable")
|
| 125 |
+
original_code = unbound_fwd.__code__
|
| 126 |
+
|
| 127 |
+
dyn_map = dynamic_arg_dims or _infer_dynamic_arg_dims_from_annotations(unbound_fwd)
|
| 128 |
+
|
| 129 |
+
if backend_factory is None:
|
| 130 |
+
from sglang.srt.compilation.backend import SGLangBackend
|
| 131 |
+
|
| 132 |
+
backend_factory = lambda gm, ex: SGLangBackend(compile_config, graph_pool)(
|
| 133 |
+
gm, ex
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
compiled_codes: list[type(original_code)] = []
|
| 137 |
+
state = {"compiled": False, "compiled_callable": None}
|
| 138 |
+
|
| 139 |
+
def bytecode_hook(old_code, new_code):
|
| 140 |
+
if old_code is not original_code:
|
| 141 |
+
return
|
| 142 |
+
frame = sys._getframe()
|
| 143 |
+
while frame and frame.f_back:
|
| 144 |
+
frame = frame.f_back
|
| 145 |
+
if (
|
| 146 |
+
frame.f_code.co_name == "_compile"
|
| 147 |
+
and os.path.basename(frame.f_code.co_filename) == "convert_frame.py"
|
| 148 |
+
):
|
| 149 |
+
break
|
| 150 |
+
try:
|
| 151 |
+
dynamo_frame = frame.f_locals["frame"]
|
| 152 |
+
except Exception:
|
| 153 |
+
return
|
| 154 |
+
if dynamo_frame.f_code is not old_code:
|
| 155 |
+
return
|
| 156 |
+
if dynamo_frame.f_locals.get("self") is not module:
|
| 157 |
+
return
|
| 158 |
+
compiled_codes.append(new_code)
|
| 159 |
+
|
| 160 |
+
torch._dynamo.convert_frame.register_bytecode_hook(bytecode_hook)
|
| 161 |
+
|
| 162 |
+
def _ensure_compiled(self, *args, **kwargs):
|
| 163 |
+
"""Compile on first use (with flag ON)."""
|
| 164 |
+
if state["compiled"]:
|
| 165 |
+
return
|
| 166 |
+
# Mark dynamic dims only when we are about to compile
|
| 167 |
+
sig = inspect.signature(unbound_fwd)
|
| 168 |
+
ba = sig.bind(self, *args, **kwargs)
|
| 169 |
+
ba.apply_defaults()
|
| 170 |
+
for name, dims in (dyn_map or {}).items():
|
| 171 |
+
if name in ba.arguments:
|
| 172 |
+
val = ba.arguments[name]
|
| 173 |
+
if val is not None:
|
| 174 |
+
_mark_dynamic_on_value(val, dims)
|
| 175 |
+
|
| 176 |
+
# Avoid cross-instance cache reuse
|
| 177 |
+
torch._dynamo.eval_frame.remove_from_cache(unbound_fwd.__code__)
|
| 178 |
+
|
| 179 |
+
bound = types.MethodType(unbound_fwd, self)
|
| 180 |
+
compiled_callable = torch.compile(
|
| 181 |
+
bound, fullgraph=fullgraph, backend=backend_factory
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
# Trigger Dynamo so bytecode hook can capture
|
| 185 |
+
compiled_callable(*args, **kwargs)
|
| 186 |
+
|
| 187 |
+
state["compiled"] = True
|
| 188 |
+
state["compiled_callable"] = compiled_callable
|
| 189 |
+
|
| 190 |
+
def trampoline(self, *args, **kwargs):
|
| 191 |
+
use_compiled = is_in_piecewise_cuda_graph()
|
| 192 |
+
if use_compiled:
|
| 193 |
+
if not state["compiled"]:
|
| 194 |
+
_ensure_compiled(self, *args, **kwargs)
|
| 195 |
+
|
| 196 |
+
compiled_callable = state["compiled_callable"]
|
| 197 |
+
return compiled_callable(*args, **kwargs)
|
| 198 |
+
else:
|
| 199 |
+
# Explicitly run the original uncompiled forward
|
| 200 |
+
return unbound_fwd(self, *args, **kwargs)
|
| 201 |
+
|
| 202 |
+
module.forward = types.MethodType(trampoline, module)
|
| 203 |
+
return module
|
sglang/python/sglang/srt/compilation/compiler_interface.py
ADDED
|
@@ -0,0 +1,504 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compiler_interface.py
|
| 2 |
+
|
| 3 |
+
import contextlib
|
| 4 |
+
import copy
|
| 5 |
+
import hashlib
|
| 6 |
+
import os
|
| 7 |
+
from contextlib import ExitStack
|
| 8 |
+
from typing import Any, Callable, Optional
|
| 9 |
+
from unittest.mock import patch
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch._inductor.compile_fx
|
| 13 |
+
import torch.fx as fx
|
| 14 |
+
|
| 15 |
+
from sglang.srt.compilation.compilation_counter import compilation_counter
|
| 16 |
+
from sglang.srt.compilation.inductor_pass import pass_context
|
| 17 |
+
from sglang.srt.utils.common import torch_release
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class CompilerInterface:
|
| 21 |
+
"""
|
| 22 |
+
The interface for a compiler that can be used by vLLM.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
# The name of the compiler, e.g. inductor.
|
| 26 |
+
# This is a class-level attribute.
|
| 27 |
+
name: str
|
| 28 |
+
|
| 29 |
+
def initialize_cache(
|
| 30 |
+
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
|
| 31 |
+
):
|
| 32 |
+
"""
|
| 33 |
+
when the vLLM process uses `cache_dir` as the cache directory,
|
| 34 |
+
the compiler should initialize itself with the cache directory,
|
| 35 |
+
e.g. by re-directing its own cache directory to a sub-directory.
|
| 36 |
+
|
| 37 |
+
prefix can be used in combination with cache_dir to figure out the base
|
| 38 |
+
cache directory, e.g. there're multiple parts of model being compiled,
|
| 39 |
+
but we want to share the same cache directory for all of them.
|
| 40 |
+
|
| 41 |
+
e.g.
|
| 42 |
+
cache_dir = "/path/to/dir/backbone", prefix = "backbone"
|
| 43 |
+
cache_dir = "/path/to/dir/eagle_head", prefix = "eagle_head"
|
| 44 |
+
"""
|
| 45 |
+
pass
|
| 46 |
+
|
| 47 |
+
def compute_hash(self) -> str:
|
| 48 |
+
"""
|
| 49 |
+
Gather all the relevant information from the vLLM config,
|
| 50 |
+
to compute a hash so that we can cache the compiled model.
|
| 51 |
+
|
| 52 |
+
See [`VllmConfig.compute_hash`][vllm.config.VllmConfig.compute_hash]
|
| 53 |
+
to check what information
|
| 54 |
+
is already considered by default. This function should only
|
| 55 |
+
consider the information that is specific to the compiler.
|
| 56 |
+
"""
|
| 57 |
+
return ""
|
| 58 |
+
|
| 59 |
+
def compile(
|
| 60 |
+
self,
|
| 61 |
+
graph: fx.GraphModule,
|
| 62 |
+
example_inputs: list[Any],
|
| 63 |
+
compiler_config: dict[str, Any],
|
| 64 |
+
runtime_shape: Optional[int] = None,
|
| 65 |
+
key: Optional[str] = None,
|
| 66 |
+
) -> tuple[Optional[Callable], Optional[Any]]:
|
| 67 |
+
"""
|
| 68 |
+
Compile the graph with the given example inputs and compiler config,
|
| 69 |
+
with a runtime shape. If the `runtime_shape` is None, it means
|
| 70 |
+
the `example_inputs` have a dynamic shape. Otherwise, the
|
| 71 |
+
`runtime_shape` specifies the shape of the inputs. Right now we only
|
| 72 |
+
support one variable shape for all inputs, which is the batchsize
|
| 73 |
+
(number of tokens) during inference.
|
| 74 |
+
|
| 75 |
+
Dynamo will make sure `graph(*example_inputs)` is valid.
|
| 76 |
+
|
| 77 |
+
The function should return a compiled callable function, as well as
|
| 78 |
+
a handle that can be used to directly load the compiled function.
|
| 79 |
+
|
| 80 |
+
The handle should be a plain Python object, preferably a string or a
|
| 81 |
+
file path for readability.
|
| 82 |
+
|
| 83 |
+
If the compiler doesn't support caching, it should return None for the
|
| 84 |
+
handle. If the compiler fails to compile the graph, it should return
|
| 85 |
+
None for the compiled function as well.
|
| 86 |
+
|
| 87 |
+
`key` is required for StandaloneInductorAdapter, it specifies where to
|
| 88 |
+
save the compiled artifact. The compiled artifact gets saved to
|
| 89 |
+
`cache_dir/key`.
|
| 90 |
+
"""
|
| 91 |
+
return None, None
|
| 92 |
+
|
| 93 |
+
def load(
|
| 94 |
+
self,
|
| 95 |
+
handle: Any,
|
| 96 |
+
graph: fx.GraphModule,
|
| 97 |
+
example_inputs: list[Any],
|
| 98 |
+
graph_index: int,
|
| 99 |
+
runtime_shape: Optional[int] = None,
|
| 100 |
+
) -> Callable:
|
| 101 |
+
"""
|
| 102 |
+
Load the compiled function from the handle.
|
| 103 |
+
Raises an error if the handle is invalid.
|
| 104 |
+
|
| 105 |
+
The handle is the second return value of the `compile` function.
|
| 106 |
+
"""
|
| 107 |
+
raise NotImplementedError("caching is not supported")
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def get_inductor_factors() -> list[Any]:
|
| 111 |
+
factors: list[Any] = []
|
| 112 |
+
# summarize system state
|
| 113 |
+
from torch._inductor.codecache import CacheBase
|
| 114 |
+
|
| 115 |
+
system_factors = CacheBase.get_system()
|
| 116 |
+
factors.append(system_factors)
|
| 117 |
+
|
| 118 |
+
# summarize pytorch state
|
| 119 |
+
from torch._inductor.codecache import torch_key
|
| 120 |
+
|
| 121 |
+
torch_factors = torch_key()
|
| 122 |
+
factors.append(torch_factors)
|
| 123 |
+
return factors
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class AlwaysHitShapeEnv:
|
| 127 |
+
"""
|
| 128 |
+
Why do we need this class:
|
| 129 |
+
|
| 130 |
+
For normal `torch.compile` usage, every compilation will have
|
| 131 |
+
one Dynamo bytecode compilation and one Inductor compilation.
|
| 132 |
+
The Inductor compilation happens under the context of the
|
| 133 |
+
Dynamo bytecode compilation, and that context is used to
|
| 134 |
+
determine the dynamic shape information, etc.
|
| 135 |
+
|
| 136 |
+
For our use case, we only run Dynamo bytecode compilation once,
|
| 137 |
+
and run Inductor compilation multiple times with different shapes
|
| 138 |
+
plus a general shape. The compilation for specific shapes happens
|
| 139 |
+
outside of the context of the Dynamo bytecode compilation. At that
|
| 140 |
+
time, we don't have shape environment to provide to Inductor, and
|
| 141 |
+
it will fail the Inductor code cache lookup.
|
| 142 |
+
|
| 143 |
+
By providing a dummy shape environment that always hits, we can
|
| 144 |
+
make the Inductor code cache lookup always hit, and we can
|
| 145 |
+
compile the graph for different shapes as needed.
|
| 146 |
+
|
| 147 |
+
The following dummy methods are obtained by trial-and-error
|
| 148 |
+
until it works.
|
| 149 |
+
"""
|
| 150 |
+
|
| 151 |
+
def __init__(self) -> None:
|
| 152 |
+
self.guards: list[Any] = []
|
| 153 |
+
|
| 154 |
+
def evaluate_guards_expression(self, *args, **kwargs):
|
| 155 |
+
return True
|
| 156 |
+
|
| 157 |
+
def get_pruned_guards(self, *args, **kwargs):
|
| 158 |
+
return []
|
| 159 |
+
|
| 160 |
+
def produce_guards_expression(self, *args, **kwargs):
|
| 161 |
+
return ""
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class InductorAdaptor(CompilerInterface):
|
| 165 |
+
"""
|
| 166 |
+
The adaptor for the Inductor compiler, version 2.5, 2.6, 2.7.
|
| 167 |
+
"""
|
| 168 |
+
|
| 169 |
+
name = "inductor"
|
| 170 |
+
|
| 171 |
+
def compute_hash(self) -> str:
|
| 172 |
+
factors = get_inductor_factors()
|
| 173 |
+
hash_str = hashlib.md5(
|
| 174 |
+
str(factors).encode(), usedforsecurity=False
|
| 175 |
+
).hexdigest()[:10]
|
| 176 |
+
return hash_str
|
| 177 |
+
|
| 178 |
+
def initialize_cache(
|
| 179 |
+
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
|
| 180 |
+
):
|
| 181 |
+
self.cache_dir = cache_dir
|
| 182 |
+
self.prefix = prefix
|
| 183 |
+
self.base_cache_dir = cache_dir[: -len(prefix)] if prefix else cache_dir
|
| 184 |
+
if disable_cache:
|
| 185 |
+
return
|
| 186 |
+
# redirect the cache directory to a sub-directory
|
| 187 |
+
# set flags so that Inductor and Triton store their cache
|
| 188 |
+
# in the cache_dir, then users only need to copy the cache_dir
|
| 189 |
+
# to another machine to reuse the cache.
|
| 190 |
+
inductor_cache = os.path.join(self.base_cache_dir, "inductor_cache")
|
| 191 |
+
os.makedirs(inductor_cache, exist_ok=True)
|
| 192 |
+
os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
|
| 193 |
+
triton_cache = os.path.join(self.base_cache_dir, "triton_cache")
|
| 194 |
+
os.makedirs(triton_cache, exist_ok=True)
|
| 195 |
+
os.environ["TRITON_CACHE_DIR"] = triton_cache
|
| 196 |
+
|
| 197 |
+
def compile(
|
| 198 |
+
self,
|
| 199 |
+
graph: fx.GraphModule,
|
| 200 |
+
example_inputs: list[Any],
|
| 201 |
+
compiler_config: dict[str, Any],
|
| 202 |
+
runtime_shape: Optional[int] = None,
|
| 203 |
+
key: Optional[str] = None,
|
| 204 |
+
) -> tuple[Optional[Callable], Optional[Any]]:
|
| 205 |
+
compilation_counter.num_inductor_compiles += 1
|
| 206 |
+
from torch._inductor.compile_fx import compile_fx
|
| 207 |
+
|
| 208 |
+
current_config = {}
|
| 209 |
+
if compiler_config is not None:
|
| 210 |
+
current_config.update(compiler_config)
|
| 211 |
+
|
| 212 |
+
# disable remote cache
|
| 213 |
+
current_config["fx_graph_cache"] = True
|
| 214 |
+
current_config["fx_graph_remote_cache"] = False
|
| 215 |
+
|
| 216 |
+
set_inductor_config(current_config, runtime_shape)
|
| 217 |
+
|
| 218 |
+
# inductor can inplace modify the graph, so we need to copy it
|
| 219 |
+
# see https://github.com/pytorch/pytorch/issues/138980
|
| 220 |
+
graph = copy.deepcopy(graph)
|
| 221 |
+
|
| 222 |
+
# it's the first time we compile this graph
|
| 223 |
+
# the assumption is that we don't have nested Inductor compilation.
|
| 224 |
+
# compiled_fx_graph_hash will only be called once, and we can hook
|
| 225 |
+
# it to get the hash of the compiled graph directly.
|
| 226 |
+
|
| 227 |
+
hash_str, file_path = None, None
|
| 228 |
+
from torch._inductor.codecache import FxGraphCache, compiled_fx_graph_hash
|
| 229 |
+
|
| 230 |
+
if torch_release[:2] == (2, 5):
|
| 231 |
+
original_load = FxGraphCache.load
|
| 232 |
+
original_load_name = "torch._inductor.codecache.FxGraphCache.load"
|
| 233 |
+
|
| 234 |
+
def hijack_load(*args, **kwargs):
|
| 235 |
+
inductor_compiled_graph = original_load(*args, **kwargs)
|
| 236 |
+
nonlocal file_path
|
| 237 |
+
compiled_fn = inductor_compiled_graph.current_callable
|
| 238 |
+
file_path = compiled_fn.__code__.co_filename # noqa
|
| 239 |
+
if not file_path.startswith(self.base_cache_dir):
|
| 240 |
+
# hooked in the align_inputs_from_check_idxs function
|
| 241 |
+
# in torch/_inductor/utils.py
|
| 242 |
+
for cell in compiled_fn.__closure__:
|
| 243 |
+
if not callable(cell.cell_contents):
|
| 244 |
+
continue
|
| 245 |
+
if cell.cell_contents.__code__.co_filename.startswith(
|
| 246 |
+
self.base_cache_dir
|
| 247 |
+
):
|
| 248 |
+
# this is the real file path compiled from Inductor
|
| 249 |
+
file_path = cell.cell_contents.__code__.co_filename
|
| 250 |
+
break
|
| 251 |
+
return inductor_compiled_graph
|
| 252 |
+
|
| 253 |
+
hijacked_compile_fx_inner = (
|
| 254 |
+
torch._inductor.compile_fx.compile_fx_inner
|
| 255 |
+
) # noqa
|
| 256 |
+
elif torch_release >= (2, 6):
|
| 257 |
+
# function renamed in 2.6
|
| 258 |
+
original_load_name = None
|
| 259 |
+
|
| 260 |
+
def hijacked_compile_fx_inner(*args, **kwargs):
|
| 261 |
+
output = torch._inductor.compile_fx.compile_fx_inner(*args, **kwargs)
|
| 262 |
+
nonlocal hash_str
|
| 263 |
+
inductor_compiled_graph = output
|
| 264 |
+
if inductor_compiled_graph is not None:
|
| 265 |
+
nonlocal file_path
|
| 266 |
+
compiled_fn = inductor_compiled_graph.current_callable
|
| 267 |
+
file_path = compiled_fn.__code__.co_filename # noqa
|
| 268 |
+
if not file_path.startswith(self.base_cache_dir):
|
| 269 |
+
# hooked in the align_inputs_from_check_idxs function
|
| 270 |
+
# in torch/_inductor/utils.py
|
| 271 |
+
for cell in compiled_fn.__closure__:
|
| 272 |
+
if not callable(cell.cell_contents):
|
| 273 |
+
continue
|
| 274 |
+
code = cell.cell_contents.__code__
|
| 275 |
+
if code.co_filename.startswith(self.base_cache_dir):
|
| 276 |
+
# this is the real file path
|
| 277 |
+
# compiled from Inductor
|
| 278 |
+
file_path = code.co_filename
|
| 279 |
+
break
|
| 280 |
+
hash_str = inductor_compiled_graph._fx_graph_cache_key
|
| 281 |
+
return output
|
| 282 |
+
|
| 283 |
+
def hijack_compiled_fx_graph_hash(*args, **kwargs):
|
| 284 |
+
out = compiled_fx_graph_hash(*args, **kwargs)
|
| 285 |
+
nonlocal hash_str
|
| 286 |
+
hash_str = out[0]
|
| 287 |
+
return out
|
| 288 |
+
|
| 289 |
+
def _check_can_cache(*args, **kwargs):
|
| 290 |
+
# no error means it can be cached.
|
| 291 |
+
# Inductor refuses to cache the graph outside of Dynamo
|
| 292 |
+
# tracing context, and also disables caching for graphs
|
| 293 |
+
# with high-order ops.
|
| 294 |
+
# For vLLM, in either case, we want to cache the graph.
|
| 295 |
+
# see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa
|
| 296 |
+
return
|
| 297 |
+
|
| 298 |
+
def _get_shape_env() -> AlwaysHitShapeEnv:
|
| 299 |
+
return AlwaysHitShapeEnv()
|
| 300 |
+
|
| 301 |
+
with ExitStack() as stack:
|
| 302 |
+
# hijack to get the compiled graph itself
|
| 303 |
+
if original_load_name is not None:
|
| 304 |
+
stack.enter_context(patch(original_load_name, hijack_load))
|
| 305 |
+
|
| 306 |
+
# for hijacking the hash of the compiled graph
|
| 307 |
+
stack.enter_context(
|
| 308 |
+
patch(
|
| 309 |
+
"torch._inductor.codecache.compiled_fx_graph_hash",
|
| 310 |
+
hijack_compiled_fx_graph_hash,
|
| 311 |
+
)
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
# for providing a dummy shape environment
|
| 315 |
+
stack.enter_context(
|
| 316 |
+
patch(
|
| 317 |
+
"torch._inductor.codecache.FxGraphCache._get_shape_env",
|
| 318 |
+
_get_shape_env,
|
| 319 |
+
)
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
|
| 323 |
+
|
| 324 |
+
# torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
|
| 325 |
+
if hasattr(AOTAutogradCache, "_get_shape_env"):
|
| 326 |
+
stack.enter_context(
|
| 327 |
+
patch(
|
| 328 |
+
"torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
|
| 329 |
+
_get_shape_env,
|
| 330 |
+
)
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
# for forcing the graph to be cached
|
| 334 |
+
stack.enter_context(
|
| 335 |
+
patch(
|
| 336 |
+
"torch._inductor.codecache.FxGraphCache._check_can_cache",
|
| 337 |
+
_check_can_cache,
|
| 338 |
+
)
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
# Dynamo metrics context, see method for more details.
|
| 342 |
+
stack.enter_context(self.metrics_context())
|
| 343 |
+
|
| 344 |
+
# Disable remote caching. When these are on, on remote cache-hit,
|
| 345 |
+
# the monkey-patched functions never actually get called.
|
| 346 |
+
# vLLM today assumes and requires the monkey-patched functions to
|
| 347 |
+
# get hit.
|
| 348 |
+
# TODO(zou3519): we're going to replace this all with
|
| 349 |
+
# standalone_compile sometime.
|
| 350 |
+
|
| 351 |
+
stack.enter_context(
|
| 352 |
+
torch._inductor.config.patch(fx_graph_remote_cache=False)
|
| 353 |
+
)
|
| 354 |
+
# InductorAdaptor (unfortunately) requires AOTAutogradCache
|
| 355 |
+
# to be turned off to run. It will fail to acquire the hash_str
|
| 356 |
+
# and error if not.
|
| 357 |
+
# StandaloneInductorAdaptor (PyTorch 2.8+) fixes this problem.
|
| 358 |
+
stack.enter_context(
|
| 359 |
+
torch._functorch.config.patch(enable_autograd_cache=False)
|
| 360 |
+
)
|
| 361 |
+
stack.enter_context(
|
| 362 |
+
torch._functorch.config.patch(enable_remote_autograd_cache=False)
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
with pass_context(runtime_shape):
|
| 366 |
+
compiled_graph = compile_fx(
|
| 367 |
+
graph,
|
| 368 |
+
example_inputs,
|
| 369 |
+
inner_compile=hijacked_compile_fx_inner,
|
| 370 |
+
config_patches=current_config,
|
| 371 |
+
)
|
| 372 |
+
return compiled_graph, (hash_str, file_path)
|
| 373 |
+
|
| 374 |
+
def load(
|
| 375 |
+
self,
|
| 376 |
+
handle: Any,
|
| 377 |
+
graph: fx.GraphModule,
|
| 378 |
+
example_inputs: list[Any],
|
| 379 |
+
graph_index: int,
|
| 380 |
+
runtime_shape: Optional[int] = None,
|
| 381 |
+
) -> Callable:
|
| 382 |
+
assert isinstance(handle, tuple)
|
| 383 |
+
assert isinstance(handle[0], str)
|
| 384 |
+
assert isinstance(handle[1], str)
|
| 385 |
+
hash_str = handle[0]
|
| 386 |
+
|
| 387 |
+
from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
|
| 388 |
+
from torch._inductor.codecache import FxGraphCache
|
| 389 |
+
|
| 390 |
+
with ExitStack() as exit_stack:
|
| 391 |
+
exit_stack.enter_context(
|
| 392 |
+
patch(
|
| 393 |
+
"torch._inductor.codecache.FxGraphCache._get_shape_env",
|
| 394 |
+
lambda *args, **kwargs: AlwaysHitShapeEnv(),
|
| 395 |
+
)
|
| 396 |
+
)
|
| 397 |
+
# torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
|
| 398 |
+
if hasattr(AOTAutogradCache, "_get_shape_env"):
|
| 399 |
+
exit_stack.enter_context(
|
| 400 |
+
patch(
|
| 401 |
+
"torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
|
| 402 |
+
lambda *args, **kwargs: AlwaysHitShapeEnv(),
|
| 403 |
+
)
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
# Dynamo metrics context, see method for more details.
|
| 407 |
+
exit_stack.enter_context(self.metrics_context())
|
| 408 |
+
|
| 409 |
+
if torch_release[:2] == (2, 5):
|
| 410 |
+
inductor_compiled_graph = FxGraphCache._lookup_graph(
|
| 411 |
+
hash_str, example_inputs, True, False
|
| 412 |
+
)
|
| 413 |
+
assert inductor_compiled_graph is not None, (
|
| 414 |
+
"Inductor cache lookup failed. Please remove"
|
| 415 |
+
f"the cache directory and try again." # noqa
|
| 416 |
+
)
|
| 417 |
+
elif torch_release >= (2, 6):
|
| 418 |
+
from torch._inductor.output_code import CompiledFxGraphConstantsWithGm
|
| 419 |
+
|
| 420 |
+
constants = CompiledFxGraphConstantsWithGm(graph)
|
| 421 |
+
inductor_compiled_graph, _ = FxGraphCache._lookup_graph(
|
| 422 |
+
hash_str, example_inputs, True, None, constants
|
| 423 |
+
)
|
| 424 |
+
assert inductor_compiled_graph is not None, (
|
| 425 |
+
"Inductor cache lookup failed. Please remove"
|
| 426 |
+
f"the cache directory and try again." # noqa
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
# Inductor calling convention (function signature):
|
| 430 |
+
# f(list) -> tuple
|
| 431 |
+
# Dynamo calling convention (function signature):
|
| 432 |
+
# f(*args) -> Any
|
| 433 |
+
|
| 434 |
+
# need to know if the graph returns a tuple
|
| 435 |
+
from torch._inductor.compile_fx import graph_returns_tuple
|
| 436 |
+
|
| 437 |
+
returns_tuple = graph_returns_tuple(graph)
|
| 438 |
+
|
| 439 |
+
# this is the callable we return to Dynamo to run
|
| 440 |
+
def compiled_graph(*args):
|
| 441 |
+
# convert args to list
|
| 442 |
+
list_args = list(args)
|
| 443 |
+
graph_output = inductor_compiled_graph(list_args)
|
| 444 |
+
# unpack the tuple if needed
|
| 445 |
+
if returns_tuple:
|
| 446 |
+
return graph_output
|
| 447 |
+
else:
|
| 448 |
+
return graph_output[0]
|
| 449 |
+
|
| 450 |
+
return compiled_graph
|
| 451 |
+
|
| 452 |
+
def metrics_context(self) -> contextlib.AbstractContextManager:
|
| 453 |
+
"""
|
| 454 |
+
This method returns the Dynamo metrics context (if it exists,
|
| 455 |
+
otherwise a null context). It is used by various compile components.
|
| 456 |
+
Present in torch>=2.6, it's used inside FxGraphCache in
|
| 457 |
+
torch==2.6 (but not after). It might also be used in various other
|
| 458 |
+
torch.compile internal functions.
|
| 459 |
+
|
| 460 |
+
Because it is re-entrant, we always set it (even if entering via Dynamo
|
| 461 |
+
and the context was already entered). We might want to revisit if it
|
| 462 |
+
should be set at a different level of compilation.
|
| 463 |
+
|
| 464 |
+
This is likely a bug in PyTorch: public APIs should not rely on
|
| 465 |
+
manually setting up internal contexts. But we also rely on non-public
|
| 466 |
+
APIs which might not provide these guarantees.
|
| 467 |
+
"""
|
| 468 |
+
import torch._dynamo.utils
|
| 469 |
+
|
| 470 |
+
return torch._dynamo.utils.get_metrics_context()
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
def set_inductor_config(config, runtime_shape):
|
| 474 |
+
if isinstance(runtime_shape, int):
|
| 475 |
+
# for a specific batchsize, tuning triton kernel parameters
|
| 476 |
+
# can be beneficial
|
| 477 |
+
config["max_autotune"] = True
|
| 478 |
+
config["coordinate_descent_tuning"] = True
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
class EagerAdapter(CompilerInterface):
|
| 482 |
+
name = "eager"
|
| 483 |
+
|
| 484 |
+
def compile(
|
| 485 |
+
self,
|
| 486 |
+
graph: fx.GraphModule,
|
| 487 |
+
example_inputs: list[Any],
|
| 488 |
+
compiler_config: dict[str, Any],
|
| 489 |
+
runtime_shape: Optional[int] = None,
|
| 490 |
+
key: Optional[str] = None,
|
| 491 |
+
num_graphs: int = 1,
|
| 492 |
+
) -> tuple[Optional[Callable], Optional[Any]]:
|
| 493 |
+
return graph, None
|
| 494 |
+
|
| 495 |
+
def load(
|
| 496 |
+
self,
|
| 497 |
+
handle: Any,
|
| 498 |
+
graph: fx.GraphModule,
|
| 499 |
+
example_inputs: list[Any],
|
| 500 |
+
graph_index: int,
|
| 501 |
+
runtime_shape: Optional[int] = None,
|
| 502 |
+
num_graphs: int = 1,
|
| 503 |
+
) -> Callable:
|
| 504 |
+
raise NotImplementedError("eager compilation is not supported")
|
sglang/python/sglang/srt/compilation/cuda_piecewise_backend.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/cuda_piecewise_backend.py
|
| 2 |
+
|
| 3 |
+
import dataclasses
|
| 4 |
+
import logging
|
| 5 |
+
from contextlib import ExitStack
|
| 6 |
+
from typing import Any, Callable, Optional
|
| 7 |
+
from unittest.mock import patch
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.fx as fx
|
| 11 |
+
|
| 12 |
+
from sglang.srt.compilation.compilation_config import CompilationConfig
|
| 13 |
+
from sglang.srt.compilation.compilation_counter import compilation_counter
|
| 14 |
+
from sglang.srt.compilation.piecewise_context_manager import (
|
| 15 |
+
get_pcg_capture_stream,
|
| 16 |
+
is_in_pcg_torch_compile,
|
| 17 |
+
)
|
| 18 |
+
from sglang.srt.compilation.weak_ref_tensor import weak_ref_tensors
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclasses.dataclass
|
| 24 |
+
class ConcreteSizeEntry:
|
| 25 |
+
runtime_shape: int
|
| 26 |
+
need_to_compile: bool # the size is in compile_sizes
|
| 27 |
+
use_cudagraph: bool # the size is in cudagraph_capture_sizes
|
| 28 |
+
|
| 29 |
+
compiled: bool = False
|
| 30 |
+
runnable: Callable = None # type: ignore
|
| 31 |
+
num_finished_warmup: int = 0
|
| 32 |
+
cudagraph: Optional[torch.cuda.CUDAGraph] = None
|
| 33 |
+
output: Optional[Any] = None
|
| 34 |
+
|
| 35 |
+
# for cudagraph debugging, track the input addresses
|
| 36 |
+
# during capture, and check if they are the same during replay
|
| 37 |
+
input_addresses: Optional[list[int]] = None
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class CUDAPiecewiseBackend:
|
| 41 |
+
|
| 42 |
+
def __init__(
|
| 43 |
+
self,
|
| 44 |
+
graph: fx.GraphModule,
|
| 45 |
+
compile_config: CompilationConfig,
|
| 46 |
+
inductor_config: dict[str, Any],
|
| 47 |
+
graph_pool: Any,
|
| 48 |
+
piecewise_compile_index: int,
|
| 49 |
+
total_piecewise_compiles: int,
|
| 50 |
+
sym_shape_indices: list[int],
|
| 51 |
+
compiled_graph_for_general_shape: Callable,
|
| 52 |
+
sglang_backend,
|
| 53 |
+
):
|
| 54 |
+
"""
|
| 55 |
+
The backend for piecewise compilation.
|
| 56 |
+
It mainly handles the compilation and cudagraph capturing.
|
| 57 |
+
|
| 58 |
+
We will compile `self.graph` once for the general shape,
|
| 59 |
+
and then compile for different shapes specified in
|
| 60 |
+
`compilation_config.compile_sizes`.
|
| 61 |
+
|
| 62 |
+
Independently, we will capture cudagraph for different shapes.
|
| 63 |
+
|
| 64 |
+
If a shape needs both compilation and cudagraph, we will
|
| 65 |
+
compile it first, and then capture cudagraph.
|
| 66 |
+
"""
|
| 67 |
+
self.graph = graph
|
| 68 |
+
self.inductor_config = inductor_config
|
| 69 |
+
self.graph_pool = graph_pool
|
| 70 |
+
self.piecewise_compile_index = piecewise_compile_index
|
| 71 |
+
self.total_piecewise_compiles = total_piecewise_compiles
|
| 72 |
+
self.sglang_backend = sglang_backend
|
| 73 |
+
|
| 74 |
+
self.is_first_graph = piecewise_compile_index == 0
|
| 75 |
+
self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1
|
| 76 |
+
|
| 77 |
+
self.compile_sizes: set[int] = set([])
|
| 78 |
+
self.compile_config = compile_config
|
| 79 |
+
self.cudagraph_capture_sizes: set[int] = set(compile_config.get_capture_sizes())
|
| 80 |
+
|
| 81 |
+
self.first_run_finished = False
|
| 82 |
+
|
| 83 |
+
self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa
|
| 84 |
+
|
| 85 |
+
self.sym_shape_indices = sym_shape_indices
|
| 86 |
+
|
| 87 |
+
# the entries for different shapes that we need to either
|
| 88 |
+
# compile or capture cudagraph
|
| 89 |
+
self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
|
| 90 |
+
|
| 91 |
+
# to_be_compiled_sizes tracks the remaining sizes to compile,
|
| 92 |
+
# and updates during the compilation process, so we need to copy it
|
| 93 |
+
self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
|
| 94 |
+
for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
|
| 95 |
+
self.concrete_size_entries[shape] = ConcreteSizeEntry(
|
| 96 |
+
runtime_shape=shape,
|
| 97 |
+
need_to_compile=shape in self.compile_sizes,
|
| 98 |
+
use_cudagraph=shape in self.cudagraph_capture_sizes,
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
def check_for_ending_compilation(self):
|
| 102 |
+
if self.is_last_graph and not self.to_be_compiled_sizes:
|
| 103 |
+
# no specific sizes to compile
|
| 104 |
+
# save the hash of the inductor graph for the next run
|
| 105 |
+
self.sglang_backend.compiler_manager.save_to_file()
|
| 106 |
+
|
| 107 |
+
def __call__(self, *args) -> Any:
|
| 108 |
+
if not self.first_run_finished:
|
| 109 |
+
self.first_run_finished = True
|
| 110 |
+
self.check_for_ending_compilation()
|
| 111 |
+
return self.compiled_graph_for_general_shape(*args)
|
| 112 |
+
|
| 113 |
+
if len(self.sym_shape_indices) == 0:
|
| 114 |
+
return self.compiled_graph_for_general_shape(*args)
|
| 115 |
+
|
| 116 |
+
runtime_shape = args[self.sym_shape_indices[0]]
|
| 117 |
+
if runtime_shape not in self.concrete_size_entries:
|
| 118 |
+
# we don't need to do anything for this shape
|
| 119 |
+
return self.compiled_graph_for_general_shape(*args)
|
| 120 |
+
|
| 121 |
+
entry = self.concrete_size_entries[runtime_shape]
|
| 122 |
+
|
| 123 |
+
if entry.runnable is None:
|
| 124 |
+
entry.runnable = self.compiled_graph_for_general_shape
|
| 125 |
+
|
| 126 |
+
if entry.need_to_compile and not entry.compiled:
|
| 127 |
+
entry.compiled = True
|
| 128 |
+
self.to_be_compiled_sizes.remove(runtime_shape)
|
| 129 |
+
# args are real arguments
|
| 130 |
+
entry.runnable = self.sglang_backend.compiler_manager.compile(
|
| 131 |
+
self.graph,
|
| 132 |
+
args,
|
| 133 |
+
self.inductor_config,
|
| 134 |
+
graph_index=self.piecewise_compile_index,
|
| 135 |
+
num_graphs=self.total_piecewise_compiles,
|
| 136 |
+
runtime_shape=runtime_shape,
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
# finished compilations for all required shapes
|
| 140 |
+
if self.is_last_graph and not self.to_be_compiled_sizes:
|
| 141 |
+
self.check_for_ending_compilation()
|
| 142 |
+
|
| 143 |
+
if is_in_pcg_torch_compile():
|
| 144 |
+
return entry.runnable(*args)
|
| 145 |
+
|
| 146 |
+
if entry.cudagraph is None:
|
| 147 |
+
if entry.num_finished_warmup < 1: # noqa
|
| 148 |
+
entry.num_finished_warmup += 1
|
| 149 |
+
return entry.runnable(*args)
|
| 150 |
+
|
| 151 |
+
if self.compile_config.get_enable_debug_mode():
|
| 152 |
+
input_addresses = [
|
| 153 |
+
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
| 154 |
+
]
|
| 155 |
+
entry.input_addresses = input_addresses
|
| 156 |
+
cudagraph = torch.cuda.CUDAGraph()
|
| 157 |
+
|
| 158 |
+
with ExitStack() as stack:
|
| 159 |
+
if not self.is_first_graph:
|
| 160 |
+
# during every model forward, we will capture
|
| 161 |
+
# many pieces of cudagraphs (roughly one per layer).
|
| 162 |
+
# running gc again and again across layers will
|
| 163 |
+
# make the cudagraph capture very slow.
|
| 164 |
+
# therefore, we only run gc for the first graph,
|
| 165 |
+
# and disable gc for the rest of the graphs.
|
| 166 |
+
stack.enter_context(patch("gc.collect", lambda: None))
|
| 167 |
+
stack.enter_context(patch("torch.cuda.empty_cache", lambda: None))
|
| 168 |
+
# mind-exploding: carefully manage the reference and memory.
|
| 169 |
+
stream = get_pcg_capture_stream()
|
| 170 |
+
assert (
|
| 171 |
+
stream is not None
|
| 172 |
+
), "PCG capture stream is not set, please check if runtime recompilation happened"
|
| 173 |
+
with torch.cuda.graph(cudagraph, pool=self.graph_pool, stream=stream):
|
| 174 |
+
# `output` is managed by pytorch's cudagraph pool
|
| 175 |
+
output = entry.runnable(*args)
|
| 176 |
+
if self.is_last_graph:
|
| 177 |
+
# by converting it to weak ref,
|
| 178 |
+
# the original `output` will immediately be released
|
| 179 |
+
# to save memory. It is only safe to do this for
|
| 180 |
+
# the last graph, because the output of the last graph
|
| 181 |
+
# will not be used by any other cuda graph.
|
| 182 |
+
output = weak_ref_tensors(output)
|
| 183 |
+
|
| 184 |
+
# here we always use weak ref for the output
|
| 185 |
+
# to save memory
|
| 186 |
+
entry.output = weak_ref_tensors(output)
|
| 187 |
+
entry.cudagraph = cudagraph
|
| 188 |
+
|
| 189 |
+
compilation_counter.num_cudagraph_captured += 1
|
| 190 |
+
|
| 191 |
+
# important: we need to return the output, rather than
|
| 192 |
+
# the weak ref of the output, so that pytorch can correctly
|
| 193 |
+
# manage the memory during cuda graph capture
|
| 194 |
+
return output
|
| 195 |
+
|
| 196 |
+
if self.compile_config.get_enable_debug_mode():
|
| 197 |
+
# check if the input addresses are the same
|
| 198 |
+
new_input_addresses = [
|
| 199 |
+
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
| 200 |
+
]
|
| 201 |
+
assert new_input_addresses == entry.input_addresses, (
|
| 202 |
+
"Input addresses for cudagraphs are different during replay."
|
| 203 |
+
f" Expected {entry.input_addresses}, got {new_input_addresses}"
|
| 204 |
+
)
|
| 205 |
+
entry.cudagraph.replay()
|
| 206 |
+
return entry.output
|
sglang/python/sglang/srt/compilation/fix_functionalization.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/fix_functionalization.py
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import operator
|
| 5 |
+
from collections.abc import Iterable
|
| 6 |
+
from typing import Optional, Union
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
| 10 |
+
|
| 11 |
+
from sglang.srt.compilation.fx_utils import is_func
|
| 12 |
+
from sglang.srt.compilation.inductor_pass import SGLangInductorPass
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class FixFunctionalizationPass(SGLangInductorPass):
|
| 18 |
+
"""
|
| 19 |
+
This pass defunctionalizes certain nodes to avoid redundant tensor copies.
|
| 20 |
+
After this pass, DCE (dead-code elimination) should never be run,
|
| 21 |
+
as de-functionalized nodes may appear as dead code.
|
| 22 |
+
|
| 23 |
+
To add new nodes to defunctionalize, add to the if-elif chain in __call__.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __call__(self, graph: torch.fx.Graph):
|
| 27 |
+
self.begin()
|
| 28 |
+
self.dump_graph(graph, "before_fix_functionalization")
|
| 29 |
+
|
| 30 |
+
self.nodes_to_remove: list[torch.fx.Node] = []
|
| 31 |
+
count = 0
|
| 32 |
+
for node in graph.nodes:
|
| 33 |
+
if not is_func(node, auto_functionalized):
|
| 34 |
+
continue # Avoid deep if-elif nesting
|
| 35 |
+
count += 1
|
| 36 |
+
|
| 37 |
+
self.dump_graph(graph, "before_fix_functionalization_cleanup")
|
| 38 |
+
|
| 39 |
+
# Remove the nodes all at once
|
| 40 |
+
count_removed = len(self.nodes_to_remove)
|
| 41 |
+
for node in self.nodes_to_remove:
|
| 42 |
+
graph.erase_node(node)
|
| 43 |
+
|
| 44 |
+
logger.debug(
|
| 45 |
+
"De-functionalized %s nodes, removed %s nodes", count, count_removed
|
| 46 |
+
)
|
| 47 |
+
self.dump_graph(graph, "after_fix_functionalization")
|
| 48 |
+
self.end_and_log()
|
| 49 |
+
|
| 50 |
+
def _remove(self, node_or_nodes: Union[torch.fx.Node, Iterable[torch.fx.Node]]):
|
| 51 |
+
"""
|
| 52 |
+
Stage a node (or nodes) for removal at the end of the pass.
|
| 53 |
+
"""
|
| 54 |
+
if isinstance(node_or_nodes, torch.fx.Node):
|
| 55 |
+
self.nodes_to_remove.append(node_or_nodes)
|
| 56 |
+
else:
|
| 57 |
+
self.nodes_to_remove.extend(node_or_nodes)
|
| 58 |
+
|
| 59 |
+
def defunctionalize(
|
| 60 |
+
self,
|
| 61 |
+
graph: torch.fx.Graph,
|
| 62 |
+
node: torch.fx.Node,
|
| 63 |
+
mutated_args: dict[int, Union[torch.fx.Node, str]],
|
| 64 |
+
args: Optional[tuple[Union[torch.fx.Node, str], ...]] = None,
|
| 65 |
+
):
|
| 66 |
+
"""
|
| 67 |
+
De-functionalize a node by replacing it with a call to the original.
|
| 68 |
+
It also replaces the getitem users with the mutated arguments.
|
| 69 |
+
See replace_users_with_mutated_args and insert_defunctionalized.
|
| 70 |
+
"""
|
| 71 |
+
self.replace_users_with_mutated_args(node, mutated_args)
|
| 72 |
+
self.insert_defunctionalized(graph, node, args=args)
|
| 73 |
+
self._remove(node)
|
| 74 |
+
|
| 75 |
+
def replace_users_with_mutated_args(
|
| 76 |
+
self, node: torch.fx.Node, mutated_args: dict[int, Union[torch.fx.Node, str]]
|
| 77 |
+
):
|
| 78 |
+
"""
|
| 79 |
+
Replace all getitem users of the auto-functionalized node with the
|
| 80 |
+
mutated arguments.
|
| 81 |
+
:param node: The auto-functionalized node
|
| 82 |
+
:param mutated_args: The mutated arguments, indexed by getitem index.
|
| 83 |
+
If the value of an arg is a string, `node.kwargs[arg]` is used.
|
| 84 |
+
"""
|
| 85 |
+
for idx, user in self.getitem_users(node).items():
|
| 86 |
+
arg = mutated_args[idx]
|
| 87 |
+
arg = node.kwargs[arg] if isinstance(arg, str) else arg
|
| 88 |
+
user.replace_all_uses_with(arg)
|
| 89 |
+
self._remove(user)
|
| 90 |
+
|
| 91 |
+
def getitem_users(self, node: torch.fx.Node) -> dict[int, torch.fx.Node]:
|
| 92 |
+
"""
|
| 93 |
+
Returns the operator.getitem users of the auto-functionalized node,
|
| 94 |
+
indexed by the index they are getting.
|
| 95 |
+
"""
|
| 96 |
+
users = {}
|
| 97 |
+
for user in node.users:
|
| 98 |
+
if is_func(user, operator.getitem):
|
| 99 |
+
idx = user.args[1]
|
| 100 |
+
users[idx] = user
|
| 101 |
+
return users
|
| 102 |
+
|
| 103 |
+
def insert_defunctionalized(
|
| 104 |
+
self,
|
| 105 |
+
graph: torch.fx.Graph,
|
| 106 |
+
node: torch.fx.Node,
|
| 107 |
+
args: Optional[tuple[Union[torch.fx.Node, str], ...]] = None,
|
| 108 |
+
):
|
| 109 |
+
"""
|
| 110 |
+
Insert a new defunctionalized node into the graph before node.
|
| 111 |
+
If one of the kwargs is 'out', provide args directly,
|
| 112 |
+
as node.kwargs cannot be used.
|
| 113 |
+
See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351
|
| 114 |
+
|
| 115 |
+
:param graph: Graph to insert the defunctionalized node into
|
| 116 |
+
:param node: The auto-functionalized node to defunctionalize
|
| 117 |
+
:param args: If we cannot use kwargs, specify args directly.
|
| 118 |
+
If an arg is a string, `node.kwargs[arg]` is used.
|
| 119 |
+
""" # noqa: E501
|
| 120 |
+
assert is_func(
|
| 121 |
+
node, auto_functionalized
|
| 122 |
+
), f"node must be auto-functionalized, is {node} instead"
|
| 123 |
+
|
| 124 |
+
# Create a new call to the original function
|
| 125 |
+
with graph.inserting_before(node):
|
| 126 |
+
function = node.args[0]
|
| 127 |
+
if args is None:
|
| 128 |
+
graph.call_function(function, kwargs=node.kwargs)
|
| 129 |
+
else:
|
| 130 |
+
# Args passed as strings refer to items in node.kwargs
|
| 131 |
+
args = tuple(
|
| 132 |
+
node.kwargs[arg] if isinstance(arg, str) else arg for arg in args
|
| 133 |
+
)
|
| 134 |
+
graph.call_function(function, args=args)
|
sglang/python/sglang/srt/compilation/fx_utils.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/fx_utils.py
|
| 2 |
+
|
| 3 |
+
import operator
|
| 4 |
+
from collections.abc import Iterable, Iterator
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
from torch import fx
|
| 8 |
+
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
| 9 |
+
from torch._ops import OpOverload
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def is_func(node: fx.Node, target) -> bool:
|
| 13 |
+
return node.op == "call_function" and node.target == target
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def is_auto_func(node: fx.Node, op: OpOverload) -> bool:
|
| 17 |
+
return is_func(node, auto_functionalized) and node.args[0] == op
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Returns the first specified node with the given op (if it exists)
|
| 21 |
+
def find_specified_fn_maybe(
|
| 22 |
+
nodes: Iterable[fx.Node], op: OpOverload
|
| 23 |
+
) -> Optional[fx.Node]:
|
| 24 |
+
for node in nodes:
|
| 25 |
+
if node.target == op:
|
| 26 |
+
return node
|
| 27 |
+
return None
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# Returns the first specified node with the given op
|
| 31 |
+
def find_specified_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
|
| 32 |
+
node = find_specified_fn_maybe(nodes, op)
|
| 33 |
+
assert node is not None, f"Could not find {op} in nodes {nodes}"
|
| 34 |
+
return node
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# Returns the first auto_functionalized node with the given op (if it exists)
|
| 38 |
+
def find_auto_fn_maybe(nodes: Iterable[fx.Node], op: OpOverload) -> Optional[fx.Node]:
|
| 39 |
+
for node in nodes:
|
| 40 |
+
if is_func(node, auto_functionalized) and node.args[0] == op: # noqa
|
| 41 |
+
return node
|
| 42 |
+
return None
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# Returns the first auto_functionalized node with the given op
|
| 46 |
+
def find_auto_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
|
| 47 |
+
node = find_auto_fn_maybe(nodes, op)
|
| 48 |
+
assert node is not None, f"Could not find {op} in nodes {nodes}"
|
| 49 |
+
return node
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# Returns the getitem node that extracts the idx-th element from node
|
| 53 |
+
# (if it exists)
|
| 54 |
+
def find_getitem_maybe(node: fx.Node, idx: int) -> Optional[fx.Node]:
|
| 55 |
+
for user in node.users:
|
| 56 |
+
if is_func(user, operator.getitem) and user.args[1] == idx:
|
| 57 |
+
return user
|
| 58 |
+
return None
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# Returns the getitem node that extracts the idx-th element from node
|
| 62 |
+
def find_getitem(node: fx.Node, idx: int) -> fx.Node:
|
| 63 |
+
ret = find_getitem_maybe(node, idx)
|
| 64 |
+
assert ret is not None, f"Could not find getitem {idx} in node {node}"
|
| 65 |
+
return ret
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# An auto-functionalization-aware utility for finding nodes with a specific op
|
| 69 |
+
def find_op_nodes(op: OpOverload, graph: fx.Graph) -> Iterator[fx.Node]:
|
| 70 |
+
if not op._schema.is_mutable:
|
| 71 |
+
yield from graph.find_nodes(op="call_function", target=op)
|
| 72 |
+
|
| 73 |
+
for n in graph.find_nodes(op="call_function", target=auto_functionalized):
|
| 74 |
+
if n.args[0] == op:
|
| 75 |
+
yield n
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# Asserts that the node only has one user and returns it
|
| 79 |
+
# Even if a node has only 1 user, it might share storage with another node,
|
| 80 |
+
# which might need to be taken into account.
|
| 81 |
+
def get_only_user(node: fx.Node) -> fx.Node:
|
| 82 |
+
assert len(node.users) == 1
|
| 83 |
+
return next(iter(node.users))
|