diff --git a/sglang/3rdparty/amd/profiling/PROFILING.md b/sglang/3rdparty/amd/profiling/PROFILING.md new file mode 100644 index 0000000000000000000000000000000000000000..7e15ec844f2bf7743c641362826e957d6efd912b --- /dev/null +++ b/sglang/3rdparty/amd/profiling/PROFILING.md @@ -0,0 +1,425 @@ +## Profiling SGLang Infer System with AMD GPUs +This AppNote describes the SGLang profiling techniques, code augmentation, and running steps for systems with AMD Instinct GPUs; the same procedure may work with Nvidia GPUs too. +Examples and steps are provided in detail so that results are easy to reproduce and performance problems can be localized for optimization. +Two primary methods are covered: +- [RPD](https://github.com/ROCm/rocmProfileData.git) +- [PyTorch Profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html) + +### Profiling SGLang Infer System with RPD Profiler +RPD profiler is a low-overhead cross-platform profiler. Therefore, the same RPD code augmentation works for profiling on ROCm/AMD GPUs as well as on CUDA/Nvidia GPUs. To do RPD profiling on the SGLang repository, please use the scripts and patch files included in this directory and follow the steps below: +1. Install RPD with rpd.patch applied during installation using install_rpd.sh; both files are in this directory. + +install_rpd.sh + +```bash +# download and install RPD +apt update && apt install -y sqlite3 libsqlite3-dev libfmt-dev + +# install rpd module +git clone https://github.com/ROCmSoftwarePlatform/rocmProfileData +cd rocmProfileData +git checkout 976899e9c6dbc6dd2bccf770818e4e44125590ac +git apply rpd.patch +make && make install +cd rocpd_python && python setup.py install && cd .. +cd rpd_tracer && make clean;make install && python setup.py install && cd .. 
+``` + +rpd.patch + +```bash +diff --git a/rpd_tracer/Makefile b/rpd_tracer/Makefile +index e9d9feb..b2e9e1a 100644 +--- a/rpd_tracer/Makefile ++++ b/rpd_tracer/Makefile +@@ -16,7 +16,7 @@ ifneq (,$(HIP_PATH)) + $(info Building with roctracer) + RPD_LIBS += -L/opt/rocm/lib -lroctracer64 -lroctx64 -lamdhip64 -lrocm_smi64 + RPD_INCLUDES += -I/opt/rocm/include -I/opt/rocm/include/roctracer -I/opt/rocm/include/hsa +- RPD_SRCS += RoctracerDataSource.cpp RocmSmiDataSource.cpp ++ RPD_SRCS += RoctracerDataSource.cpp + RPD_INCLUDES += -D__HIP_PLATFORM_AMD__ + endif +``` +2. Add loadTracer.sh file included in this directory to /sglang/python/sglang. + +loadTracer.sh + +```bash +#!/bin/bash +################################################################################ +# Copyright (c) 2021 - 2023 Advanced Micro Devices, Inc. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+################################################################################ +OUTPUT_FILE="trace.rpd" + +if [ "$1" = "-o" ] ; then + OUTPUT_FILE=$2 + shift + shift +fi + +if [ -e ${OUTPUT_FILE} ] ; then + rm ${OUTPUT_FILE} +fi + +python3 -m rocpd.schema --create ${OUTPUT_FILE} +if [ $? != 0 ] ; then + echo "Error: Could not create rpd file. Please run 'python setup.py install' from the rocpd_python dir" + exit +fi + +export RPDT_FILENAME=${OUTPUT_FILE} +export RPDT_AUTOSTART=0 +LD_PRELOAD=librocm-smi_64:librpd_tracer.so "$@" +``` +3. Apply patch (provided in this directory) with "git apply rpd_profile_server_enable.patch" if the main profiling purpose is to get info on gpu kernels as well as limited cpu activity info. + +#### Common Notes 1 +Please note that although we are doing TP=8 in the example, we purposely only log RPD profiling on 2 ranks in the patch file (i.e., tp_rank=0/1) for profiling/visualization convenience, as even Perfetto streaming mode can only load a JSON file of at most 8GB for visualization. With 2 ranks logged in RPD profiling, we can still check whether there are issues among ranks (e.g. load imbalance issue, nccl issue), and at the same time, we can log a relatively longer time duration before the JSON file generated from the RPD file reaches 8GB. 
+ +rpd_profile_server_enable.patch + +```bash +diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py +index 62d1ff9..9021c01 100644 +--- a/python/sglang/srt/managers/scheduler.py ++++ b/python/sglang/srt/managers/scheduler.py +@@ -71,6 +71,8 @@ from sglang.srt.utils import ( + suppress_other_loggers, + ) + from sglang.utils import get_exception_traceback ++from rpdTracerControl import rpdTracerControl ++rpdTracerControl.skipCreate() + + logger = logging.getLogger(__name__) + +@@ -245,6 +247,7 @@ class Scheduler: + ], + with_stack=True, + ) ++ self.rpd = rpdTracerControl() + + @torch.inference_mode() + def event_loop(self): +@@ -1027,15 +1030,24 @@ class Scheduler: + def start_profile(self) -> None: + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") +- self.profiler.start() ++ #self.profiler.start() #block pytorch profiler for rpd profiler enabling ++ if self.tp_rank == 0 or self.tp_rank == 1: ++ self.rpd.start() ++ self.rpd.rangePush("", "rpd profile range", "") ++ logger.info("rpd is enabled") + + def stop_profile(self) -> None: + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") +- self.profiler.stop() +- self.profiler.export_chrome_trace( +- self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz" +- ) ++ #self.profiler.stop() ++ #self.profiler.export_chrome_trace( ++ # self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz" ++ #) ++ if self.tp_rank ==0 or self.tp_rank ==1: ++ self.rpd.rangePop() ++ self.rpd.stop() ++ self.rpd.flush() ++ logger.info("rpd is done") + logger.info("Profiler is done") +``` + +#### Advanced Debugging with RPD Profiler +Sometimes, we want to use rpd profiler to capture more CPU and python activities in order to debug some challenging issues (e.g. root cause of load imbalance across gpu processes, root cause of bubbles, etc). 
Only in such cases, we need to apply patch "git apply rpd_profile_server_enable_wCPU_activities.patch", where 3 files are modified. + +rpd_profile_server_enable_wCPU_activities.patch + +```bash +diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py +index 62d1ff9..2edb427 100644 +--- a/python/sglang/srt/managers/scheduler.py ++++ b/python/sglang/srt/managers/scheduler.py +@@ -71,6 +71,8 @@ from sglang.srt.utils import ( + suppress_other_loggers, + ) + from sglang.utils import get_exception_traceback ++from rpdTracerControl import rpdTracerControl ++rpdTracerControl.skipCreate() + + logger = logging.getLogger(__name__) + +@@ -245,6 +247,7 @@ class Scheduler: + ], + with_stack=True, + ) ++ self.rpd = rpdTracerControl() + + @torch.inference_mode() + def event_loop(self): +@@ -1027,15 +1030,26 @@ class Scheduler: + def start_profile(self) -> None: + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") +- self.profiler.start() ++ #self.profiler.start() ++ logger.info("torch profiler is disabled") ++ if self.tp_rank == 0 or self.tp_rank == 1: ++ self.rpd.setPythonTrace(True) ++ self.rpd.start() ++ self.rpd.rangePush("", "scheduler", "") ++ logger.info("rpd is enabled inside scheduler profiling") + + def stop_profile(self) -> None: + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") +- self.profiler.stop() +- self.profiler.export_chrome_trace( +- self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz" +- ) ++ #self.profiler.stop() ++ #self.profiler.export_chrome_trace( ++ # self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz" ++ #) ++ if self.tp_rank ==0 or self.tp_rank ==1: ++ self.rpd.rangePop() ++ self.rpd.stop() ++ self.rpd.flush() ++ logger.info("rpd is done inside scheduler") + logger.info("Profiler is done") + + +diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py +index 
2621ccd..181df85 100644 +--- a/python/sglang/srt/managers/tokenizer_manager.py ++++ b/python/sglang/srt/managers/tokenizer_manager.py +@@ -58,6 +58,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams + from sglang.srt.server_args import PortArgs, ServerArgs + from sglang.srt.utils import is_generation_model, is_multimodal_model + ++from rpdTracerControl import rpdTracerControl ++rpdTracerControl.skipCreate() ++ ++ + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + + logger = logging.getLogger(__name__) +@@ -514,10 +518,20 @@ class TokenizerManager: + self.send_to_scheduler.send_pyobj(req) + + def start_profile(self): ++ rpd = rpdTracerControl() ++ rpd.setPythonTrace(True) ++ rpd.start() ++ rpd.rangePush("", "tokenizer_manager", "") ++ logger.info("tokenizer_manager rpd profiling started!") + req = ProfileReq.START_PROFILE + self.send_to_scheduler.send_pyobj(req) + + def stop_profile(self): ++ rpd = rpdTracerControl() ++ rpd.rangePop() ++ rpd.stop() ++ rpd.flush() ++ logger.info("rpd profiling is done inside tokenizer_manager!") + req = ProfileReq.STOP_PROFILE + self.send_to_scheduler.send_pyobj(req) + +diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py +index 7111c93..2bd722c 100644 +--- a/python/sglang/srt/server.py ++++ b/python/sglang/srt/server.py +@@ -30,6 +30,8 @@ import threading + import time + from http import HTTPStatus + from typing import Dict, List, Optional, Union ++from rpdTracerControl import rpdTracerControl ++rpdTracerControl.skipCreate() + + # Fix a bug of Python threading + setattr(threading, "_register_atexit", lambda *args, **kwargs: None) +@@ -152,6 +154,11 @@ async def flush_cache(): + @app.post("/start_profile") + async def start_profile(): + """Start profiling.""" ++ rpd = rpdTracerControl() ++ rpd.setPythonTrace(True) ++ rpd.start() ++ rpd.rangePush("", "server rpd profile range", "") ++ logger.info("rpd profiling started in server.py!") + tokenizer_manager.start_profile() + return Response( + 
content="Start profiling.\n", +@@ -164,6 +171,11 @@ async def start_profile(): + async def stop_profile(): + """Stop profiling.""" + tokenizer_manager.stop_profile() ++ rpd = rpdTracerControl() ++ rpd.rangePop() ++ rpd.stop() ++ rpd.flush() ++ logger.info("rpd profiling is done in server.py!") + return Response( + content="Stop profiling. This will take some time.\n", + status_code=200, +``` + +4. As an example for grok1 profiling, we create a dummy_grok1 directory with config.json (see content below) inside this directory and copy this directory to the right path for "--model-path" if you want to use the example server.sh file provided. +```bash +cat ../dummy_grok1/config.json +{ + "architectures": [ + "Grok1ModelForCausalLM" + ], + "embedding_multiplier_scale": 78.38367176906169, + "output_multiplier_scale": 0.5773502691896257, + "vocab_size": 131072, + "hidden_size": 6144, + "intermediate_size": 32768, + "max_position_embeddings": 8192, + "num_experts_per_tok": 2, + "num_local_experts": 8, + "num_attention_heads": 48, + "num_hidden_layers": 64, + "num_key_value_heads": 8, + "head_dim": 128, + "rms_norm_eps": 1e-05, + "rope_theta": 10000.0, + "model_type": "mixtral", + "torch_dtype": "bfloat16" +} +``` +5. Launch server with rpd enabled script ./server.sh in one terminal inside the docker container. + +#### Common Notes 2 +- Remember to change model-path to the correct path +- loadTracer.sh is needed to conduct profiling +- SGLANG_TORCH_PROFILER_DIR is used for default torch profiler +- Do not use loadTracer.sh if you are using the torch profiler, simply use python3 -m sglang.launch_server. 
+ + +server.sh + +```bash +#!/bin/bash + +# export SGLANG_TORCH_PROFILER_DIR=/data/sglang/ +export SGLANG_TORCH_PROFILER_DIR=/sgl-workspace/sglang/profile/ + +# Get the current timestamp +TIMESTAMP=$(date +"%Y%m%d_%H%M%S") + +# Define the log file with a timestamp +LOGFILE="sglang_server_log_$TIMESTAMP.json" + +# Run the Python command and save the output to the log file +loadTracer.sh python3 -m sglang.launch_server \ + --model-path /sgl-workspace/sglang/dummy_grok1 \ + --tokenizer-path Xenova/grok-1-tokenizer \ + --load-format dummy \ + --quantization fp8 \ + --tp 8 \ + --port 30000 \ + --disable-radix-cache 2>&1 | tee "$LOGFILE" +``` +6. Open another terminal for the same docker container, and run the rpd enabled ./client.sh after you see "The server is fired up and is ready to roll!" message from server side terminal. + +#### Common Notes 3 +- Use curl http://localhost:30000/start_profile & curl http://localhost:30000/stop_profile to control the start and end of profiling. Check sglang/python/sglang/srt/managers/scheduler.py for more details. +- Please don't use RPD profiler together with PyTorch profiler to avoid interference. +- The rocmProfileData/tools/rpd2tracing.py file is used to generate json file from RPD file. 
+ +client.sh + +```bash +#!/bin/bash + +# Start profiling via API +curl http://localhost:30000/start_profile -H "Content-Type: application/json" + +# Benchmark serving using sglang with random dataset and tokenizer +# Define the log file with a timestamp +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +LOGFILE="sglang_client_log_$TIMESTAMP.json" + +# Run the benchmark with specified parameters and save logs +python3 -m sglang.bench_serving \ + --backend sglang \ + --tokenizer Xenova/grok-1-tokenizer \ + --dataset-name random \ + --random-input 1024\ + --random-output 1024 \ + --num-prompts 120 \ + --request-rate 8 \ + --output-file online.jsonl 2>&1 | tee "$LOGFILE" + +# Stop profiling via API +curl http://localhost:30000/stop_profile -H "Content-Type: application/json" + +# Convert tracing file to csv & json +sqlite3 trace.rpd ".mode csv" ".header on" ".output trace.csv" "select * from top;" ".output stdout" +python3 ./rocmProfileData/tools/rpd2tracing.py trace.rpd trace.json +``` +7. Follow [Perfetto docs](https://perfetto.dev/docs/visualization/large-traces) to visualize large json files. Try to adjust parameters so that the trace.json file size is less than 8GB, the maximum that Perfetto streaming mode can load (see Common Notes 1). + +### Profiling SGLang Infer System with PyTorch Profiler + +Please use the steps as follows: + +1. Apply the patch torch_profiler.patch. Note that you can modify "if self.tp_rank == 0" in the patch to allow more ranks to be recorded in profiling. 
+ +torch_profiler.patch +```bash +diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py +index 62d1ff9..6ecd78c 100644 +--- a/python/sglang/srt/managers/scheduler.py ++++ b/python/sglang/srt/managers/scheduler.py +@@ -240,7 +240,6 @@ class Scheduler: + ) + self.profiler = torch.profiler.profile( + activities=[ +- torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + with_stack=True, +@@ -1033,9 +1032,11 @@ class Scheduler: + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") + self.profiler.stop() +- self.profiler.export_chrome_trace( +- self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz" +- ) ++ if self.tp_rank == 0: ++ with open(f"stats_repro_{int(time.time())}.txt", "w") as f: ++ print(self.profiler.key_averages(group_by_input_shape=True).table(sort_by="cuda_time_total", row_limit=-1), file=f) ++ print("Profiling stats done.") ++ + logger.info("Profiler is done") +``` + +2. Create the model path directory and copy it to the right path for "--model-path" if you want to use the server.sh file provided. + +3. Modify the included server.sh by removing "loadTracer.sh" before python command and launch script ./server.sh in one terminal inside the docker container. + +4. Similar to step 6 in RPD profiling section, but remove the last 2 lines in client.sh, which converted rpd file into csv and json files. Run modified client.sh for PyTorch profiling. 
+------- +- [Torch Profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html) diff --git a/sglang/3rdparty/amd/profiling/client.sh b/sglang/3rdparty/amd/profiling/client.sh new file mode 100644 index 0000000000000000000000000000000000000000..150ea9f193f9fdf217a1448b6c5da2057dcd80d9 --- /dev/null +++ b/sglang/3rdparty/amd/profiling/client.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +# Start profiling via API +curl http://localhost:30000/start_profile -H "Content-Type: application/json" + +# Benchmark serving using sglang with random dataset and tokenizer +# Define the log file with a timestamp +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +LOGFILE="sglang_client_log_$TIMESTAMP.json" + +# Run the benchmark with specified parameters and save logs +python3 -m sglang.bench_serving \ + --backend sglang \ + --tokenizer Xenova/grok-1-tokenizer \ + --dataset-name random \ + --random-input 1024\ + --random-output 1024 \ + --num-prompts 240 \ + --request-rate 8 \ + --output-file online.jsonl 2>&1 | tee "$LOGFILE" + +# Stop profiling via API +curl http://localhost:30000/stop_profile -H "Content-Type: application/json" + +# Convert tracing file to csv & json +sqlite3 trace.rpd ".mode csv" ".header on" ".output trace.csv" "select * from top;" ".output stdout" +python3 /sgl-workspace/rocmProfileData/tools/rpd2tracing.py trace.rpd trace.json diff --git a/sglang/3rdparty/amd/profiling/install_rpd.sh b/sglang/3rdparty/amd/profiling/install_rpd.sh new file mode 100644 index 0000000000000000000000000000000000000000..d1b04b9889767d19d3ce39a6a1dcdb81d93cf478 --- /dev/null +++ b/sglang/3rdparty/amd/profiling/install_rpd.sh @@ -0,0 +1,10 @@ +# download and install RPD +apt update && apt install -y sqlite3 libsqlite3-dev libfmt-dev + +# install rpd module +git clone https://github.com/ROCmSoftwarePlatform/rocmProfileData +cd rocmProfileData +git apply rpd.patch +make && make install +cd rocpd_python && python setup.py install && cd .. 
+cd rpd_tracer && make clean;make install && python setup.py install && cd .. diff --git a/sglang/3rdparty/amd/profiling/loadTracer.sh b/sglang/3rdparty/amd/profiling/loadTracer.sh new file mode 100644 index 0000000000000000000000000000000000000000..8a95a335cb0e5ef0502e66f47ef54c368320017c --- /dev/null +++ b/sglang/3rdparty/amd/profiling/loadTracer.sh @@ -0,0 +1,43 @@ +#!/bin/bash +################################################################################ +# Copyright (c) 2021 - 2023 Advanced Micro Devices, Inc. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +################################################################################ +OUTPUT_FILE="trace.rpd" + +if [ "$1" = "-o" ] ; then + OUTPUT_FILE=$2 + shift + shift +fi + +if [ -e ${OUTPUT_FILE} ] ; then + rm ${OUTPUT_FILE} +fi + +python3 -m rocpd.schema --create ${OUTPUT_FILE} +if [ $? != 0 ] ; then + echo "Error: Could not create rpd file. 
Please run 'python setup.py install' from the rocpd_python dir" + exit +fi + +export RPDT_FILENAME=${OUTPUT_FILE} +export RPDT_AUTOSTART=0 +LD_PRELOAD=librocm-smi_64:librpd_tracer.so "$@" diff --git a/sglang/3rdparty/amd/profiling/rpd.patch b/sglang/3rdparty/amd/profiling/rpd.patch new file mode 100644 index 0000000000000000000000000000000000000000..87917654ac8451f3d37ae1384cc8e074c8a6e31c --- /dev/null +++ b/sglang/3rdparty/amd/profiling/rpd.patch @@ -0,0 +1,12 @@ +diff --git a/rpd_tracer/Makefile b/rpd_tracer/Makefile +index e9d9feb..b2e9e1a 100644 +--- a/rpd_tracer/Makefile ++++ b/rpd_tracer/Makefile +@@ -16,7 +16,7 @@ ifneq (,$(HIP_PATH)) + $(info Building with roctracer) + RPD_LIBS += -L/opt/rocm/lib -lroctracer64 -lroctx64 -lamdhip64 -lrocm_smi64 + RPD_INCLUDES += -I/opt/rocm/include -I/opt/rocm/include/roctracer -I/opt/rocm/include/hsa +- RPD_SRCS += RoctracerDataSource.cpp RocmSmiDataSource.cpp ++ RPD_SRCS += RoctracerDataSource.cpp + RPD_INCLUDES += -D__HIP_PLATFORM_AMD__ + endif diff --git a/sglang/3rdparty/amd/profiling/rpd_profile_server_enable.patch b/sglang/3rdparty/amd/profiling/rpd_profile_server_enable.patch new file mode 100644 index 0000000000000000000000000000000000000000..3cd3915334eeb62abb92e2a0e9ef7b881613c882 --- /dev/null +++ b/sglang/3rdparty/amd/profiling/rpd_profile_server_enable.patch @@ -0,0 +1,49 @@ +diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py +index 62d1ff9..9021c01 100644 +--- a/python/sglang/srt/managers/scheduler.py ++++ b/python/sglang/srt/managers/scheduler.py +@@ -71,6 +71,8 @@ from sglang.srt.utils import ( + suppress_other_loggers, + ) + from sglang.utils import get_exception_traceback ++from rpdTracerControl import rpdTracerControl ++rpdTracerControl.skipCreate() + + logger = logging.getLogger(__name__) + +@@ -245,6 +247,7 @@ class Scheduler: + ], + with_stack=True, + ) ++ self.rpd = rpdTracerControl() + + @torch.inference_mode() + def event_loop(self): +@@ -1027,15 +1030,24 
@@ class Scheduler: + def start_profile(self) -> None: + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") +- self.profiler.start() ++ #self.profiler.start() #block pytorch profiler for rpd profiler enabling ++ if self.tp_rank == 0 or self.tp_rank == 1: ++ self.rpd.start() ++ self.rpd.rangePush("", "rpd profile range", "") ++ logger.info("rpd is enabled") + + def stop_profile(self) -> None: + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") +- self.profiler.stop() +- self.profiler.export_chrome_trace( +- self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz" +- ) ++ #self.profiler.stop() ++ #self.profiler.export_chrome_trace( ++ # self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz" ++ #) ++ if self.tp_rank ==0 or self.tp_rank ==1: ++ self.rpd.rangePop() ++ self.rpd.stop() ++ self.rpd.flush() ++ logger.info("rpd is done") + logger.info("Profiler is done") diff --git a/sglang/3rdparty/amd/profiling/rpd_profile_server_enable_wCPU_activities.patch b/sglang/3rdparty/amd/profiling/rpd_profile_server_enable_wCPU_activities.patch new file mode 100644 index 0000000000000000000000000000000000000000..5416f4d571ae119f3cb24765984aaa75cd5d0e73 --- /dev/null +++ b/sglang/3rdparty/amd/profiling/rpd_profile_server_enable_wCPU_activities.patch @@ -0,0 +1,126 @@ +diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py +index 62d1ff9..2edb427 100644 +--- a/python/sglang/srt/managers/scheduler.py ++++ b/python/sglang/srt/managers/scheduler.py +@@ -71,6 +71,8 @@ from sglang.srt.utils import ( + suppress_other_loggers, + ) + from sglang.utils import get_exception_traceback ++from rpdTracerControl import rpdTracerControl ++rpdTracerControl.skipCreate() + + logger = logging.getLogger(__name__) + +@@ -245,6 +247,7 @@ class Scheduler: + ], + with_stack=True, + ) ++ self.rpd = rpdTracerControl() + + @torch.inference_mode() + def event_loop(self): +@@ -1027,15 
+1030,26 @@ class Scheduler: + def start_profile(self) -> None: + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") +- self.profiler.start() ++ #self.profiler.start() ++ logger.info("torch profiler is disabled") ++ if self.tp_rank == 0 or self.tp_rank == 1: ++ self.rpd.setPythonTrace(True) ++ self.rpd.start() ++ self.rpd.rangePush("", "scheduler", "") ++ logger.info("rpd is enabled inside scheduler profiling") + + def stop_profile(self) -> None: + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") +- self.profiler.stop() +- self.profiler.export_chrome_trace( +- self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz" +- ) ++ #self.profiler.stop() ++ #self.profiler.export_chrome_trace( ++ # self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz" ++ #) ++ if self.tp_rank ==0 or self.tp_rank ==1: ++ self.rpd.rangePop() ++ self.rpd.stop() ++ self.rpd.flush() ++ logger.info("rpd is done inside scheduler") + logger.info("Profiler is done") + + +diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py +index 2621ccd..181df85 100644 +--- a/python/sglang/srt/managers/tokenizer_manager.py ++++ b/python/sglang/srt/managers/tokenizer_manager.py +@@ -58,6 +58,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams + from sglang.srt.server_args import PortArgs, ServerArgs + from sglang.srt.utils import is_generation_model, is_multimodal_model + ++from rpdTracerControl import rpdTracerControl ++rpdTracerControl.skipCreate() ++ ++ + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + + logger = logging.getLogger(__name__) +@@ -514,10 +518,20 @@ class TokenizerManager: + self.send_to_scheduler.send_pyobj(req) + + def start_profile(self): ++ rpd = rpdTracerControl() ++ rpd.setPythonTrace(True) ++ rpd.start() ++ rpd.rangePush("", "tokenizer_manager", "") ++ logger.info("tokenizer_manager rpd profiling started!") + req = 
ProfileReq.START_PROFILE + self.send_to_scheduler.send_pyobj(req) + + def stop_profile(self): ++ rpd = rpdTracerControl() ++ rpd.rangePop() ++ rpd.stop() ++ rpd.flush() ++ logger.info("rpd profiling is done inside tokenizer_manager!") + req = ProfileReq.STOP_PROFILE + self.send_to_scheduler.send_pyobj(req) + +diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py +index 7111c93..2bd722c 100644 +--- a/python/sglang/srt/server.py ++++ b/python/sglang/srt/server.py +@@ -30,6 +30,8 @@ import threading + import time + from http import HTTPStatus + from typing import Dict, List, Optional, Union ++from rpdTracerControl import rpdTracerControl ++rpdTracerControl.skipCreate() + + # Fix a bug of Python threading + setattr(threading, "_register_atexit", lambda *args, **kwargs: None) +@@ -152,6 +154,11 @@ async def flush_cache(): + @app.post("/start_profile") + async def start_profile(): + """Start profiling.""" ++ rpd = rpdTracerControl() ++ rpd.setPythonTrace(True) ++ rpd.start() ++ rpd.rangePush("", "server rpd profile range", "") ++ logger.info("rpd profiling started in server.py!") + tokenizer_manager.start_profile() + return Response( + content="Start profiling.\n", +@@ -164,6 +171,11 @@ async def start_profile(): + async def stop_profile(): + """Stop profiling.""" + tokenizer_manager.stop_profile() ++ rpd = rpdTracerControl() ++ rpd.rangePop() ++ rpd.stop() ++ rpd.flush() ++ logger.info("rpd profiling is done in server.py!") + return Response( + content="Stop profiling. 
This will take some time.\n", + status_code=200, diff --git a/sglang/3rdparty/amd/profiling/server.sh b/sglang/3rdparty/amd/profiling/server.sh new file mode 100644 index 0000000000000000000000000000000000000000..f877e6c7acd40f48e6281f11be8c24c6ec6cf6aa --- /dev/null +++ b/sglang/3rdparty/amd/profiling/server.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# export SGLANG_TORCH_PROFILER_DIR=/data/sglang/ +export SGLANG_TORCH_PROFILER_DIR=/sgl-workspace/sglang/profile/ + +# Get the current timestamp +TIMESTAMP=$(date +"%Y%m%d_%H%M%S") + +# Define the log file with a timestamp +LOGFILE="sglang_server_log_$TIMESTAMP.json" + +# Run the Python command and save the output to the log file +loadTracer.sh python3 -m sglang.launch_server \ + --model-path /sgl-workspace/sglang/dummy_grok1 \ + --tokenizer-path Xenova/grok-1-tokenizer \ + --load-format dummy \ + --quantization fp8 \ + --tp 8 \ + --port 30000 \ + --disable-radix-cache 2>&1 | tee "$LOGFILE" diff --git a/sglang/3rdparty/amd/profiling/torch_profiler.patch b/sglang/3rdparty/amd/profiling/torch_profiler.patch new file mode 100644 index 0000000000000000000000000000000000000000..40f55740638a954ad4097520aa882c3351c2fbc7 --- /dev/null +++ b/sglang/3rdparty/amd/profiling/torch_profiler.patch @@ -0,0 +1,25 @@ +diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py +index 62d1ff9..6ecd78c 100644 +--- a/python/sglang/srt/managers/scheduler.py ++++ b/python/sglang/srt/managers/scheduler.py +@@ -240,7 +240,6 @@ class Scheduler: + ) + self.profiler = torch.profiler.profile( + activities=[ +- torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + with_stack=True, +@@ -1033,9 +1032,11 @@ class Scheduler: + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") + self.profiler.stop() +- self.profiler.export_chrome_trace( +- self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz" +- ) ++ if self.tp_rank == 0: ++ with 
open(f"stats_repro_{int(time.time())}.txt", "w") as f: ++ print(self.profiler.key_averages(group_by_input_shape=True).table(sort_by="cuda_time_total", row_limit=-1), file=f) ++ print("Profiling stats done.") ++ + logger.info("Profiler is done") diff --git a/sglang/3rdparty/amd/sgl-kernel/CMakeLists_rocm.txt b/sglang/3rdparty/amd/sgl-kernel/CMakeLists_rocm.txt new file mode 100644 index 0000000000000000000000000000000000000000..86ba249c1bb76f7bbb57a6583035890cbd642e1d --- /dev/null +++ b/sglang/3rdparty/amd/sgl-kernel/CMakeLists_rocm.txt @@ -0,0 +1,159 @@ +cmake_minimum_required(VERSION 3.24 FATAL_ERROR) +project(sgl_kernel LANGUAGES CXX) + +# Cmake +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_POSITION_INDEPENDENT_CODE ON) +set(CMAKE_SHARED_LIBRARY_PREFIX "") + +set(CMAKE_COLOR_DIAGNOSTICS ON) +set(CMAKE_VERBOSE_MAKEFILE ON CACHE BOOL "ON") + +# Python / Torch +find_package(Python COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED) + +execute_process( + COMMAND ${Python_EXECUTABLE} -c "import torch; print(torch.utils.cmake_prefix_path)" + OUTPUT_VARIABLE TORCH_PY_PREFIX + OUTPUT_STRIP_TRAILING_WHITESPACE +) + +set(Torch_DIR "${TORCH_PY_PREFIX}/Torch") +list(APPEND CMAKE_PREFIX_PATH "${TORCH_PY_PREFIX}/Torch") +find_package(Torch REQUIRED) + +execute_process( + COMMAND ${Python_EXECUTABLE} -c "import torch; print(int(torch._C._GLIBCXX_USE_CXX11_ABI))" + OUTPUT_VARIABLE TORCH_CXX11_ABI + OUTPUT_STRIP_TRAILING_WHITESPACE +) +if(TORCH_CXX11_ABI STREQUAL "0") + add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) +else() + add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=1) +endif() + +# ROCm/HIP +enable_language(HIP) +find_package(hip REQUIRED CONFIG) + +# Determine AMDGPU target from environment variable or default to gfx942 +set(AMDGPU_TARGET_ENV "$ENV{AMDGPU_TARGET}") + +if(AMDGPU_TARGET_ENV) + # Use environment variable if specified + set(AMDGPU_TARGETS "${AMDGPU_TARGET_ENV}") + message(STATUS "Using AMDGPU_TARGET 
from environment: ${AMDGPU_TARGETS}") +else() + # Default to gfx942 only + set(AMDGPU_TARGETS "gfx942") + message(STATUS "AMDGPU_TARGET not set, defaulting to gfx942") +endif() + +# Set HIP architectures +set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS}) + +# FP8 macro selection +# Always define HIP_FP8_TYPE_FNUZ=1 (for gfx942 and host compilation) +# Additionally define HIP_FP8_TYPE_E4M3=1 when building for gfx950 +# The existing utils.h logic will pick the right one based on architecture +set(SGL_FP8_MACROS "-DHIP_FP8_TYPE_FNUZ=1") + +if(AMDGPU_TARGETS MATCHES "gfx950") + list(APPEND SGL_FP8_MACROS "-DHIP_FP8_TYPE_E4M3=1") + message(STATUS "Multi-arch build: Enabling both HIP_FP8_TYPE_FNUZ (gfx942) and HIP_FP8_TYPE_E4M3 (gfx950)") +elseif(AMDGPU_TARGETS MATCHES "gfx942") + message(STATUS "Single-arch build: Enabling HIP_FP8_TYPE_FNUZ for gfx942") +else() + message(FATAL_ERROR "Unsupported AMDGPU_TARGET '${AMDGPU_TARGETS}'. Expected 'gfx942' or 'gfx950' or both.") +endif() + +# TopK dynamic smem bytes +# Dynamic shared-memory budget for the TopK kernels. +# - gfx942 (MI300/MI325): LDS is typically 64KB per workgroup -> keep dynamic smem <= ~48KB +# (leaves room for static shared allocations in the kernel). +# - gfx95x (MI350): LDS is larger (e.g. 160KB per CU) -> allow the original 128KB dynamic smem. 
+# SGL_TOPK_DYNAMIC_SMEM_BYTES is a single compile-time macro shared by every arch in a +# multi-arch build, so pick the conservative 48KB budget whenever gfx942 is among the +# requested targets (AMDGPU_TARGETS is the variable populated from the environment above). +if(AMDGPU_TARGETS MATCHES "gfx942") + math(EXPR SGL_TOPK_DYNAMIC_SMEM_BYTES "48 * 1024") +else() + math(EXPR SGL_TOPK_DYNAMIC_SMEM_BYTES "32 * 1024 * 4") +endif() + +set(SGL_TOPK_MACROS "-DSGL_TOPK_DYNAMIC_SMEM_BYTES=${SGL_TOPK_DYNAMIC_SMEM_BYTES}") + +# Paths / includes +set(PROJ_ROOT ${CMAKE_CURRENT_LIST_DIR}) +set(SGL_INCLUDE_DIRS + ${PROJ_ROOT}/include + ${PROJ_ROOT}/include/impl + ${PROJ_ROOT}/csrc + ${TORCH_INCLUDE_DIRS} +) + +# Platform-specific library directory +set(PLAT_LIB_DIR "/usr/lib/x86_64-linux-gnu") +link_directories(${PLAT_LIB_DIR}) + +# Sources +set(SOURCES +${PROJ_ROOT}/csrc/allreduce/custom_all_reduce.hip +${PROJ_ROOT}/csrc/allreduce/deterministic_all_reduce.hip +${PROJ_ROOT}/csrc/allreduce/quick_all_reduce.hip +${PROJ_ROOT}/csrc/common_extension_rocm.cc +${PROJ_ROOT}/csrc/elementwise/activation.hip +${PROJ_ROOT}/csrc/elementwise/pos_enc.hip +${PROJ_ROOT}/csrc/elementwise/topk.hip +${PROJ_ROOT}/csrc/grammar/apply_token_bitmask_inplace_hip.hip +${PROJ_ROOT}/csrc/kvcacheio/transfer.hip +${PROJ_ROOT}/csrc/moe/moe_align_kernel.hip +${PROJ_ROOT}/csrc/moe/moe_topk_softmax_kernels.hip +${PROJ_ROOT}/csrc/moe/moe_topk_sigmoid_kernels.hip +${PROJ_ROOT}/csrc/speculative/eagle_utils.hip +) +set_source_files_properties( + ${SOURCES} + PROPERTIES + LANGUAGE HIP +) + +# Compile / Link flags +add_compile_options($<$<COMPILE_LANGUAGE:CXX>:-O3>) + +set(SGL_HIP_FLAGS + -DNDEBUG + -DOPERATOR_NAMESPACE=sgl_kernel + -O3 + -std=c++17 + -DENABLE_BF16 + -DENABLE_FP8 + ${SGL_FP8_MACROS} + -Wno-pass-failed + -Wundefined-internal + ${SGL_TOPK_MACROS} +) + +# Python extension +Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES}) +target_include_directories(common_ops PRIVATE ${SGL_INCLUDE_DIRS}) + +# Apply per-language flags (SGL_HIP_FLAGS only reach sources compiled as HIP) +target_compile_options(common_ops PRIVATE + $<$<COMPILE_LANGUAGE:HIP>:${SGL_HIP_FLAGS}> +) + +target_link_libraries(common_ops PRIVATE + ${TORCH_LIBRARIES} + hip::device + hip::host + hiprtc + amdhip64 +) + +target_link_options(common_ops PRIVATE + 
"SHELL:-Wl,-rpath,'\$ORIGIN/../../torch/lib'" +) + +install(TARGETS common_ops + LIBRARY DESTINATION sgl_kernel +) diff --git a/sglang/3rdparty/amd/sgl-kernel/build_rocm.sh b/sglang/3rdparty/amd/sgl-kernel/build_rocm.sh new file mode 100644 index 0000000000000000000000000000000000000000..1022d8bb50f363fc4a1d5d3ba428f0cb8bbb435d --- /dev/null +++ b/sglang/3rdparty/amd/sgl-kernel/build_rocm.sh @@ -0,0 +1,123 @@ +#!/bin/bash +set -euo pipefail +ROCM_VERSION=$1 + +PYTHON_ROOT_PATH="/opt/venv/bin" +AMDGPU_TARGET="gfx942;gfx950" + +echo "Python root path is: $PYTHON_ROOT_PATH" + +# Get version from git tags +SGLANG_VERSION="v0.5.6" # Default version, will be overridden if git tags are found + +# Fetch tags from origin to ensure we have the latest +if git fetch --tags origin; then + # Get the latest version tag sorted by version number (e.g., v0.5.7) + VERSION_FROM_TAG=$(git tag -l 'v[0-9]*' --sort=-v:refname | head -1) + if [ -n "$VERSION_FROM_TAG" ]; then + SGLANG_VERSION="$VERSION_FROM_TAG" + echo "Using SGLang version from git tags: $SGLANG_VERSION" + else + echo "Warning: No version tags found; using default $SGLANG_VERSION" >&2 + fi +else + echo "Warning: Failed to fetch tags from origin; using default $SGLANG_VERSION" >&2 +fi + +# Default base tags (can be overridden by command line arguments) +DEFAULT_MI30X_BASE_TAG="${SGLANG_VERSION}-rocm700-mi30x" +DEFAULT_MI35X_BASE_TAG="${SGLANG_VERSION}-rocm700-mi35x" + +# Parse command line arguments +MI30X_BASE_TAG="${DEFAULT_MI30X_BASE_TAG}" +MI35X_BASE_TAG="${DEFAULT_MI35X_BASE_TAG}" + +# Detect GPU architecture from the Kubernetes runner hostname +HOSTNAME_VALUE=$(hostname) +GPU_ARCH="mi30x" # default + +# Host names look like: linux-mi35x-gpu-1-xxxxx-runner-zzzzz +if [[ "${HOSTNAME_VALUE}" =~ ^linux-(mi[0-9]+[a-z]*)-gpu-[0-9]+ ]]; then + GPU_ARCH="${BASH_REMATCH[1]}" + echo "Detected GPU architecture from hostname: ${GPU_ARCH}" +else + echo "Warning: could not parse GPU architecture from '${HOSTNAME_VALUE}', defaulting 
to ${GPU_ARCH}" +fi + +case "${GPU_ARCH}" in + mi35x) + echo "Runner uses ${GPU_ARCH}; will fetch mi35x image." + ;; + mi30x|mi300|mi325) + echo "Runner uses ${GPU_ARCH}; will fetch mi30x image." + GPU_ARCH="mi30x" + ;; + *) + echo "Runner architecture '${GPU_ARCH}' unrecognised; defaulting to mi30x image." >&2 + GPU_ARCH="mi30x" + ;; +esac + +if [[ -f /etc/podinfo/gha-render-devices ]]; then + DEVICE_FLAG=$(cat /etc/podinfo/gha-render-devices) +else + DEVICE_FLAG="--device /dev/dri" +fi + +# Find the latest image. +# Probes the registry for a dated tag, from today back to 6 days ago, and prints the +# first image that exists on stdout (progress messages go to stderr). If none exists, +# prints a pinned fallback image instead. Returns 1 only for an unsupported architecture. +find_latest_image() { + local gpu_arch=$1 + local base_tag days_back image_tag + + case "${gpu_arch}" in + mi30x) base_tag="${MI30X_BASE_TAG}" ;; + mi35x) base_tag="${MI35X_BASE_TAG}" ;; + *) echo "Error: unsupported GPU architecture '${gpu_arch}'" >&2; return 1 ;; + esac + + for days_back in {0..6}; do + image_tag="${base_tag}-$(date -d "${days_back} days ago" +%Y%m%d)" + echo "Checking for image: rocm/sgl-dev:${image_tag}" >&2 + if docker manifest inspect "rocm/sgl-dev:${image_tag}" >/dev/null 2>&1; then + echo "Found available image: rocm/sgl-dev:${image_tag}" >&2 + echo "rocm/sgl-dev:${image_tag}" + return 0 + fi + done + + echo "Error: no ${gpu_arch} image found in the last 7 days for base ${base_tag}" >&2 + echo "Using hard-coded fallback…" >&2 + if [[ "${gpu_arch}" == "mi35x" ]]; then + echo "rocm/sgl-dev:v0.5.3-rocm700-mi35x-20251009" + else + echo "rocm/sgl-dev:v0.5.3-rocm700-mi30x-20251009" + fi +} + +# Pull and run the latest image +IMAGE=$(find_latest_image "${GPU_ARCH}") +echo "Pulling Docker image: ${IMAGE}" +docker pull "${IMAGE}" + +# Quote the bind-mount source and the image name so a working directory containing +# spaces is not word-split into separate docker arguments. +docker run --rm \ + -v "$(pwd)":/sgl-kernel \ + -e AMDGPU_TARGET="${AMDGPU_TARGET}" \ + "${IMAGE}" \ + bash -c " + # Install CMake (version >= 3.26) - Robust Installation + export CMAKE_VERSION_MAJOR=3.31 + export CMAKE_VERSION_MINOR=1 + echo \"Downloading CMake from: https://cmake.org/files/v\${CMAKE_VERSION_MAJOR}/cmake-\${CMAKE_VERSION_MAJOR}.\${CMAKE_VERSION_MINOR}-linux-x86_64.tar.gz\" + wget 
https://cmake.org/files/v\${CMAKE_VERSION_MAJOR}/cmake-\${CMAKE_VERSION_MAJOR}.\${CMAKE_VERSION_MINOR}-linux-x86_64.tar.gz + tar -xzf cmake-\${CMAKE_VERSION_MAJOR}.\${CMAKE_VERSION_MINOR}-linux-x86_64.tar.gz + mv cmake-\${CMAKE_VERSION_MAJOR}.\${CMAKE_VERSION_MINOR}-linux-x86_64 /opt/cmake + export PATH=/opt/cmake/bin:\$PATH + + ${PYTHON_ROOT_PATH}/pip install --no-cache-dir ninja setuptools wheel numpy uv scikit-build-core && \ + + cd /sgl-kernel && \ + rm -rf CMakeLists.txt && mv CMakeLists_rocm.txt CMakeLists.txt && \ + ${PYTHON_ROOT_PATH}/python rocm_hipify.py && \ + ${PYTHON_ROOT_PATH}/python -m uv build --wheel -Cbuild-dir=build . --color=always --no-build-isolation && \ + ./rename_wheels_rocm.sh +" diff --git a/sglang/3rdparty/amd/sgl-kernel/rename_wheels_rocm.sh b/sglang/3rdparty/amd/sgl-kernel/rename_wheels_rocm.sh new file mode 100644 index 0000000000000000000000000000000000000000..691407a3e63f9db51ef9e81379bd8b5abe098e7a --- /dev/null +++ b/sglang/3rdparty/amd/sgl-kernel/rename_wheels_rocm.sh @@ -0,0 +1,30 @@ +#!/usr/bin/env bash +set -ex + +WHEEL_DIR="dist" + +wheel_files=($WHEEL_DIR/*.whl) +for wheel in "${wheel_files[@]}"; do + intermediate_wheel="${wheel/linux/manylinux2014}" + + # Extract the current python version from the wheel name + if [[ $intermediate_wheel =~ -cp([0-9]+)- ]]; then + cp_version="${BASH_REMATCH[1]}" + else + echo "Could not extract Python version from wheel name: $intermediate_wheel" + continue + fi + + # Detect ROCm version and add appropriate suffix + if ls /opt | grep -q "7.0"; then + new_wheel="${intermediate_wheel/-cp${cp_version}/+rocm700-cp${cp_version}}" + else + new_wheel="$intermediate_wheel" + fi + + if [[ "$wheel" != "$new_wheel" ]]; then + echo "Renaming $wheel to $new_wheel" + mv -- "$wheel" "$new_wheel" + fi +done +echo "Wheel renaming completed." 
diff --git a/sglang/3rdparty/amd/sgl-kernel/rocm_hipify.py b/sglang/3rdparty/amd/sgl-kernel/rocm_hipify.py new file mode 100644 index 0000000000000000000000000000000000000000..8373f741d1d60b971e9b2b4503e27835673d535c --- /dev/null +++ b/sglang/3rdparty/amd/sgl-kernel/rocm_hipify.py @@ -0,0 +1,40 @@ +from pathlib import Path + +import torch +from torch.utils.cpp_extension import CUDAExtension + +root = Path(__file__).parent.resolve() + +include_dirs = [ + root / "include", + root / "include" / "impl", + root / "csrc", +] + +sources = [ + "csrc/allreduce/custom_all_reduce.hip", + "csrc/allreduce/deterministic_all_reduce.hip", + "csrc/allreduce/quick_all_reduce.cu", + "csrc/common_extension_rocm.cc", + "csrc/elementwise/activation.cu", + "csrc/elementwise/pos_enc.cu", + "csrc/elementwise/topk.cu", + "csrc/grammar/apply_token_bitmask_inplace_cuda.cu", + "csrc/kvcacheio/transfer.cu", + "csrc/moe/moe_align_kernel.cu", + "csrc/moe/moe_topk_softmax_kernels.cu", + "csrc/moe/moe_topk_sigmoid_kernels.cu", + "csrc/speculative/eagle_utils.cu", +] + +libraries = ["hiprtc", "amdhip64", "c10", "torch", "torch_python"] + +ext_modules = [ + CUDAExtension( + name="sgl_kernel.common_ops", + sources=sources, + include_dirs=include_dirs, + libraries=libraries, + py_limited_api=False, + ), +] diff --git a/sglang/3rdparty/amd/tuning/TUNING.md b/sglang/3rdparty/amd/tuning/TUNING.md new file mode 100644 index 0000000000000000000000000000000000000000..a903bba03eca30efc8076cc878258ca6278290cf --- /dev/null +++ b/sglang/3rdparty/amd/tuning/TUNING.md @@ -0,0 +1,118 @@ +## Tuning SGLang Infer System with AMD GPUs +This AppNote describes the SGLang performance tuning techniques, code harness, and running steps for systems with AMD Instinct GPUs. +Harness code, examples, and steps are provided in detail to make it easy to reproduce the results and to tune performance for your workloads. +Four primary runtime areas are covered: + +## 1. 
Triton Kernels +To maximize Triton kernel efficiency, several strategies can be employed: + +### Key Environment Variables: +- **num_stages**: Adjusts the number of pipeline stages to optimize kernel efficiency based on the specific type of operations (e.g., General Matrix Multiplication - GEMM). +- **waves_per_eu**: Controls the usage of Vector General Purpose Registers (VGPR) to enhance occupancy, thereby improving latency or throughput. +- **BLOCK_M, BLOCK_N, BLOCK_K**: Tunable tile sizes that assist in balancing memory transfer and computational efficiency. +- **matrix_instr_nonkdim**: Optimizes the usage of Matrix-Fused Multiply-Add (MFMA) instructions for specific kernel types, such as Flash Attention. +- **OPTIMIZE_EPILOGUE**: An environment variable that can be set to `1` to enhance performance by eliminating the `convert_layout` operation in the kernel's epilogue. +```python +@triton.autotune(configs=[ + triton.Config({'waves_per_eu': 1}, num_warps=4, num_stages=1), + triton.Config({'waves_per_eu': 1}, num_warps=8, num_stages=1), + triton.Config({'waves_per_eu': 1}, num_warps=16, num_stages=1), + triton.Config({'waves_per_eu': 2}, num_warps=4, num_stages=1), + triton.Config({'waves_per_eu': 2}, num_warps=8, num_stages=1), + triton.Config({'waves_per_eu': 2}, num_warps=16, num_stages=1), + triton.Config({'waves_per_eu': 4}, num_warps=4, num_stages=1), + triton.Config({'waves_per_eu': 4}, num_warps=8, num_stages=1), + triton.Config({'waves_per_eu': 4}, num_warps=16, num_stages=1), + ], key=['BLOCK_N', 'NUM_TOKEN_BLKS'], use_cuda_graph=True) +@triton.jit +def _triton_kernel_function(): + ... +``` +## 2. Torch Tunable Operations +**TunableOp** is a feature in PyTorch that allows for the definition and optimization of custom kernels with tunable parameters. This feature is particularly useful for enhancing the performance of kernels by experimenting with different configurations. + +### Key Environment Variables: +1. 
**PYTORCH_TUNABLEOP_ENABLED**: + - Default: `0` + - Set to `1` to enable TunableOp. + +2. **PYTORCH_TUNABLEOP_TUNING**: + - Default: `1` + - Set to `0` to disable tuning. When tuning is left enabled and `PYTORCH_TUNABLEOP_ENABLED=1`, an operation with no tuned entry runs the tuning step and the result is recorded; with tuning disabled, only previously recorded entries are used. + +3. **PYTORCH_TUNABLEOP_VERBOSE**: + - Default: `0` + - Set to `1` to enable verbose output for TunableOp. + +### Usage Example: +To enable TunableOp and tuning, and optionally enable verbose mode, you can run the following commands in your terminal: + +```bash +#Tuning +PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=1 your_script.sh + +#Inference with tuning op +PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=0 your_script.sh + +#Print out the log +PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=0 PYTORCH_TUNABLEOP_VERBOSE=1 your_script.sh + +``` +## 3. Torch Compilation + + +The following are suggestions for optimizing matrix multiplication (GEMM) and convolution (conv) operations in PyTorch using Inductor, a part of the PyTorch compilation framework. The goal is to leverage Triton to achieve better performance. + +To tune Triton kernels with GEMM and convolution ops (conv), use the `torch.compile` function with the max-autotune mode. This benchmarks a predefined list of Triton configurations and selects the fastest one for each shape. + +### Key Configurations: +1. **Max Autotune**: + - Set `torch._inductor.config.max_autotune = True` or `TORCHINDUCTOR_MAX_AUTOTUNE=1`. + +2. **Fine-Grained Control**: + - Enable GEMM tuning: `torch._inductor.config.max_autotune_gemm = True`. + - Enable tuning for pointwise/reduction ops: `torch._inductor.config.max_autotune_pointwise = True`. + +3. **Backend Selection**: + - Use `torch._inductor.config.max_autotune_gemm_backends` to limit backends to TRITON for better performance. + +4. **Freezing for Inference**: + - Use `torch._inductor.config.freezing=True` to enable constant folding optimizations. + +5. 
**Debugging**: + - Set `TORCH_COMPILE_DEBUG=1` to extract Triton kernels generated by Inductor. + +### Example Code Block: +```bash +#Gemm Tuning +TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 your_script.sh + +#Specify your backend to TRITON for Gemm Tuning +TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS=TRITON your_script.sh + +#Inference with large improvement on AMD GPU +TORCHINDUCTOR_FREEZING=1 your_script.sh +``` +## 4. Fused MOE kernel +To maximize MoE kernel efficiency, use the script below to find the best launch configuration. + +### Key parameters: +- **--model**: the MoE model to tune for; the sizes of d_model, model_intermediate_size, and num_layers are determined automatically from it. +- **--tp-size**: the tensor-parallel size of the full model run, used to set the per-GPU dimension sizes correctly. +- **--batch**: the M dimension of the MoE kernel. For the prefill MoE kernel the value is batch*input_len; for the decode MoE kernel the value is batch. +- **--dtype**: the computation data type. + +```bash +#Tuning +# For example, suppose we run "python3 -m sglang.bench_latency --model dummy_grok1/ --load-format dummy --tokenizer-path Xenova/grok-1-tokenizer --tp 8 --batch-size 32 --input 1024 --output 8 --attention-backend triton --sampling-backend pytorch --quantization fp8". It uses batch size 32, input length 1024, and output length 8. From the MoE kernel's point of view ("--batch"), the prefill batch is 32*1024 = 32768 and the decode batch is 32*1 (only one output token is generated in each run). 
+#so we can tune decode moe use below command +python benchmark_moe_rocm.py --model grok1 --tp-size 8 --dtype float8 --batch "32" +# and use this command to tune prefill moe +python benchmark_moe_rocm.py --model grok1 --tp-size 8 --dtype float8 --batch "32768" +``` + +## Reference + +For more detailed information on tuning SGLang performance with AMD GPUs, please refer to the following link: + +[ROCm Documentation: Triton Kernel Performance Optimization](https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/workload.html#triton-kernel-performance-optimization) diff --git a/sglang/3rdparty/amd/tuning/benchmark_moe_rocm.py b/sglang/3rdparty/amd/tuning/benchmark_moe_rocm.py new file mode 100644 index 0000000000000000000000000000000000000000..131b25270ab874960c3f436d6af93e2addf77bc8 --- /dev/null +++ b/sglang/3rdparty/amd/tuning/benchmark_moe_rocm.py @@ -0,0 +1,378 @@ +import argparse +import json +import os +import sys + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl +from tqdm import tqdm +from transformers import AutoConfig + +from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( + fused_moe, + get_config_file_name, +) + +padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0 + + +def main(model, tp_size, dtype: str, batches): + """Run one tuning grid search (run_grid) per entry in *batches*, using fused_moe as the kernel under test.""" + method = fused_moe + + for bs in batches: + run_grid(int(bs), model=model, method=method, tp_size=tp_size, dtype=dtype) + + +def prune_configs(M, N, K, configs): + """Return the subset of *configs* that passes the size heuristics below for an M x N x K GEMM.""" + pruned_configs = [] + # Bytes per element used by the LDS estimate further down. The value is 1; + # NOTE(review): the original note claimed "float16 (2 bytes)", which does not match - confirm the intended dtype. + elemBytes_a = 1 # [DV Note] hard-coded to 1 byte per element + elemBytes_b = 1 # [DV Note] hard-coded to 1 byte per element + + mfma = 16 if M < 32 or N < 32 else 32 + + # TODO (zhanglx): figure out the boundary between large and small gemms + large_gemm = False + if M >= 2048 and N >= 2048: + large_gemm = True + + for config in configs: + BLOCK_SIZE_M = config.get("BLOCK_SIZE_M") + BLOCK_SIZE_N = config.get("BLOCK_SIZE_N") + BLOCK_SIZE_K = config.get("BLOCK_SIZE_K") + 
num_warps = config.get("num_warps") + matrix_instr_nonkdim = config.get("matrix_instr_nonkdim") + # kpack = config.get("kpack") + if matrix_instr_nonkdim > mfma: + continue + if mfma == 4 and BLOCK_SIZE_K < 64: + continue + # some layouts could not work properly in case + # number elements per thread is less 1 + if BLOCK_SIZE_M * BLOCK_SIZE_N < 64: + continue + SPLIT_K = 1 # config.get("SPLIT_K") + GROUP_M = config.get("GROUP_SIZE_M") + if matrix_instr_nonkdim > BLOCK_SIZE_M or matrix_instr_nonkdim > BLOCK_SIZE_N: + continue + if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M: + continue + if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N: + continue + # Skip BLOCK_SIZE that is too large compare to M/N + # unless BLOCK_SIZE is already small enough + if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16: + continue + if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16: + continue + # skip large split_k when not necessary + if SPLIT_K != 1 and not need_split_k(M, N, K): + continue + # skip split_k that leads to EVEN_K = false + leap = SPLIT_K * BLOCK_SIZE_K + modv = K % leap + if modv != 0: + continue + # skip large GROUP_M + if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1: + continue + # out of shared memory resource + # TODO (zhanglx): This does not consider the LDS usage in the epilogue + LDS = ( + BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b + ) + if LDS > 65536: + continue + # Skip small block sizes and num_warps for large gemm + # For fp16 and f8, we want to only use BLOCK_SIZE >= 64 + if large_gemm: + if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64: + continue + if BLOCK_SIZE_K < 64: + continue + if num_warps < 4: + continue + + pruned_configs.append(config) + + return pruned_configs + + +def union_of_list_of_dicts(l1, l2): + result = [] + temp_list = l1.copy() + temp_list.extend(l2) + for myDict in temp_list: + if myDict not in result: + result.append(myDict) + + return result + + +def 
run_grid(bs, model, method, tp_size, dtype: str): + + config = AutoConfig.from_pretrained(model) + + top_k = config.num_experts_per_tok + d_model = config.hidden_size + model_intermediate_size = config.intermediate_size + num_layers = config.num_hidden_layers + hidden_states_dtype = config.torch_dtype + + if config.num_experts_per_tok: + if config.architectures[0] == "Grok1ModelForCausalLM": + num_total_experts = config.num_experts + else: + num_total_experts = config.num_local_experts + else: + raise ValueError(f"Unsupported Mixtral model {model}") + + # tp_size = 2 + num_warmup_calls = 10 + num_calls = 30 + + num_warmup_trials = 1 + num_trials = 1 + + full_configs = [] + + block_m_range = [16, 32, 64, 128, 256] + block_n_range = [16, 32, 64, 128, 256] + block_k_range = [32, 64, 128, 256] # MUST >= 32 + num_warps_range = [1, 2, 4, 8] + group_m_range = [1, 4, 8, 16, 32] + # For now we see better perf with num_stages=0 for all gemm configs we care + # But keep this explicit so that we do not forget we may need to set it to + # other values in the future + num_stage_range = [2] + waves_per_eu_range = [0, 1, 2, 4, 8] + # Remove 32 because of triton compiling error + matrix_instr_nonkdim_range = [16] + kpack_range = [1, 2] + + for block_size_m in block_m_range: + for block_size_n in block_n_range: + for block_size_k in block_k_range: + for group_size_m in group_m_range: + for num_warps in num_warps_range: + for num_stages in num_stage_range: + for waves_per_eu in waves_per_eu_range: + for matrix_instr_nonkdim in matrix_instr_nonkdim_range: + for kpack in kpack_range: + full_configs.append( + { + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + "GROUP_SIZE_M": group_size_m, + "num_warps": num_warps, + "num_stages": num_stages, + "waves_per_eu": waves_per_eu, + "matrix_instr_nonkdim": matrix_instr_nonkdim, + "kpack": kpack, + } + ) + + M1 = bs * 2 + N1 = model_intermediate_size * 2 // tp_size + K1 = d_model + prune_configs_1 
= prune_configs(M1, N1, K1, full_configs) + + M2 = bs * 2 + N2 = d_model + K2 = model_intermediate_size // tp_size + prune_configs_2 = prune_configs(M2, N2, K2, full_configs) + + configs = union_of_list_of_dicts(prune_configs_1, prune_configs_2) + + print(f"{bs=} || {len(full_configs)=} | {len(prune_configs_1)=} | \ + {len(prune_configs_2)=} | {len(configs)=}") + + best_config = None + best_time_us = 1e20 + + print(f"{tp_size=} {bs=}") + + for config in tqdm(configs): + # warmup + try: + print(config) + for _ in range(num_warmup_trials): + run_timing( + num_calls=num_warmup_calls, + bs=bs, + d_model=d_model, + num_total_experts=num_total_experts, + top_k=top_k, + tp_size=tp_size, + model_intermediate_size=model_intermediate_size, + method=method, + config=config, + dtype=dtype, + hidden_states_dtype=hidden_states_dtype, + ) + except triton.runtime.autotuner.OutOfResources: + continue + + # trial + for _ in range(num_trials): + kernel_dur_ms = run_timing( + num_calls=num_calls, + bs=bs, + d_model=d_model, + num_total_experts=num_total_experts, + top_k=top_k, + tp_size=tp_size, + model_intermediate_size=model_intermediate_size, + method=method, + config=config, + dtype=dtype, + hidden_states_dtype=hidden_states_dtype, + ) + + kernel_dur_us = 1000 * kernel_dur_ms + model_dur_ms = kernel_dur_ms * num_layers + + if kernel_dur_us < best_time_us: + best_config = config + best_time_us = kernel_dur_us + + tqdm.write( + f"{kernel_dur_us=:.1f} {model_dur_ms=:.1f}" + f" {bs=} {tp_size=} {top_k=} {num_total_experts=} " + f"{d_model=} {model_intermediate_size=} {num_layers=}" + ) + + print("best_time_us", best_time_us) + print("best_config", best_config) + + # holds Dict[str, Dict[str, int]] + filename = get_config_file_name( + num_total_experts, + model_intermediate_size // tp_size, + "float8" if dtype == "float8" else None, + ) + print(f"writing config to file {filename}") + existing_content = {} + if os.path.exists(filename): + with open(filename, "r") as f: + 
existing_content = json.load(f) + existing_content[str(bs)] = best_config + with open(filename, "w") as f: + json.dump(existing_content, f, indent=4) + f.write("\n") + + +def run_timing( + num_calls: int, + bs: int, + d_model: int, + num_total_experts: int, + top_k: int, + tp_size: int, + model_intermediate_size: int, + method, + config, + dtype: str, + hidden_states_dtype, +) -> float: + shard_intermediate_size = model_intermediate_size // tp_size + + hidden_states = torch.rand( + (bs, d_model), + device="cuda:0", + dtype=hidden_states_dtype, + ) + + w1 = torch.rand( + (num_total_experts, 2 * shard_intermediate_size, d_model + padding_size), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + w2 = torch.rand( + (num_total_experts, d_model, shard_intermediate_size + padding_size), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + w1_scale = None + w2_scale = None + a1_scale = None + a2_scale = None + + if dtype == "float8": + w1 = w1.to(torch.float8_e4m3fnuz) + w2 = w2.to(torch.float8_e4m3fnuz) + w1_scale = torch.ones( + num_total_experts, device=hidden_states.device, dtype=torch.float32 + ) + w2_scale = torch.ones( + num_total_experts, device=hidden_states.device, dtype=torch.float32 + ) + a1_scale = torch.ones(1, device=hidden_states.device, dtype=torch.float32) + a2_scale = torch.ones(1, device=hidden_states.device, dtype=torch.float32) + + gating_output = F.softmax( + torch.rand( + (num_calls, bs, num_total_experts), + device=hidden_states.device, + dtype=torch.float32, + ), + dim=-1, + ) + + ################################## + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + for i in range(num_calls): + hidden_states = method( + hidden_states=hidden_states, + w1=w1, + w2=w2, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + gating_output=gating_output[0], + topk=top_k, + renormalize=True, + inplace=True, + 
override_config=config, + use_fp8=dtype == "float8", + ) + + end_event.record() + end_event.synchronize() + + dur_ms = start_event.elapsed_time(end_event) / num_calls + return dur_ms + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="benchmark_mixtral_moe", + description="Benchmark and tune the fused_moe kernel", + ) + parser.add_argument( + "--dtype", + type=str, + default="auto", + choices=["float8", "float16", "bfloat16"], + help="Data type used for fused_moe kernel computations", + ) + parser.add_argument("--model", type=str, default="hpcai-tech/grok-1") + + parser.add_argument("--tp-size", type=int, default=2, help="Tensor paralleli size") + parser.add_argument("-b", "--batches", type=str) + + args = parser.parse_args() + + batches = args.batches.split(",") + + sys.exit(main(args.model, args.tp_size, args.dtype, batches)) diff --git a/sglang/docs/supported_models/extending/modelscope.md b/sglang/docs/supported_models/extending/modelscope.md new file mode 100644 index 0000000000000000000000000000000000000000..4740c2770f9e9d910b8269e5431b604be9a3f740 --- /dev/null +++ b/sglang/docs/supported_models/extending/modelscope.md @@ -0,0 +1,28 @@ +# Use Models From ModelScope + +To use a model from [ModelScope](https://www.modelscope.cn), set the environment variable `SGLANG_USE_MODELSCOPE`. + +```bash +export SGLANG_USE_MODELSCOPE=true +``` + +We take [Qwen2-7B-Instruct](https://www.modelscope.cn/models/qwen/qwen2-7b-instruct) as an example. + +Launch the Server: +```bash +python -m sglang.launch_server --model-path qwen/Qwen2-7B-Instruct --port 30000 +``` + +Or start it by docker: + +```bash +docker run --gpus all \ + -p 30000:30000 \ + -v ~/.cache/modelscope:/root/.cache/modelscope \ + --env "SGLANG_USE_MODELSCOPE=true" \ + --ipc=host \ + lmsysorg/sglang:latest \ + python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --host 0.0.0.0 --port 30000 +``` + +Note that modelscope uses a different cache directory than huggingface. 
You may need to set it manually to avoid running out of disk space. diff --git a/sglang/docs/supported_models/extending/support_new_models.md b/sglang/docs/supported_models/extending/support_new_models.md new file mode 100644 index 0000000000000000000000000000000000000000..bc683a636530eca52b9795b9013df686c141bb1b --- /dev/null +++ b/sglang/docs/supported_models/extending/support_new_models.md @@ -0,0 +1,320 @@ +# How to Support New Models + +This document explains how to add support for new language models and multimodal large language models (MLLMs) in +SGLang. It also covers how to test new models and register external implementations. + +## How to Support a New Language Model + +To support a new model in SGLang, you only need to add a single file under +the [SGLang Models Directory](https://github.com/sgl-project/sglang/tree/main/python/sglang/srt/models). You can learn +from existing model implementations and create a new file for your model. For most models, you should be able to find a +similar model to start with (e.g., starting from Llama). Also refer to how +to [port a Model from vLLM to SGLang](#port-a-model-from-vllm-to-sglang). + +## How to Support a New Multimodal Large Language Model + +To support a new multimodal large language model (MLLM) in SGLang, there are several key components in addition to the +standard LLM support: + +1. **Register your new model as multimodal**: + Extend `is_multimodal_model` + in [model_config.py](https://github.com/sgl-project/sglang/blob/0ab3f437aba729b348a683ab32b35b214456efc7/python/sglang/srt/configs/model_config.py#L561) + to return `True` for your model. + +2. **Register a new chat-template**: + This is only needed when your default chat-template is unable to accept images as input: register a new chat template in [conversation.py](https://github.com/sgl-project/sglang/tree/main/python/sglang/srt/conversation.py) together with the corresponding matching function. + +3. 
**Multimodal Data Processor**: + Define a new `Processor` class that inherits from `BaseMultimodalProcessor` and register this processor as your + model’s dedicated processor. + See [multimodal_processor.py](https://github.com/sgl-project/sglang/tree/main/python/sglang/srt/multimodal/processors) + for more details. + +4. **Handle Multimodal Tokens**: + Implement a `pad_input_ids` function for your new model. In this function, multimodal tokens in the prompt should be + expanded (if necessary) and padded with multimodal-data-hashes so that SGLang can recognize different multimodal data + with `RadixAttention`. + +5. **Handle Image Feature Extraction**: + Implement a `get_image_feature` function for your new model, which extracts image features from raw image data and converts them into the embeddings used by the language model. + +6. **Adapt to Vision Attention**: + Adapt the multi-headed `Attention` of ViT with SGLang’s `VisionAttention`. + +You can refer to [Qwen2VL](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/qwen2_vl.py) or +other mllm implementations. These models demonstrate how to correctly handle both multimodal and textual inputs. + +## Testing and Debugging + +Please note all your testing and benchmarking results in PR description. + +### Interactive Debugging + +For interactive debugging, compare the outputs of Hugging Face/Transformers and SGLang. 
The following two commands +should give the same text output and very similar prefill logits: + +- Get the reference output: + ```bash + python3 scripts/playground/reference_hf.py --model-path [new model] --model-type {text,mllm} + ``` +- Get the SGLang output: + ```bash + python3 -m sglang.bench_one_batch --correct --model [new model] + ``` + +### Add the Model to the Test Suite + +To ensure the new model is well maintained, add it to the test suite by including it in the `ALL_OTHER_MODELS` list in +the [test_generation_models.py](https://github.com/sgl-project/sglang/blob/main/test/srt/models/test_generation_models.py) +file. Then test the new model on your local machine and report the results on demonstrative benchmarks (GSM8K, MMLU, MMMU, +MMMU-Pro, etc.) in your PR. +For VLMs, also include a test in `test_vision_openai_server_{x}.py` (e.g. [test_vision_openai_server_a.py](https://github.com/sgl-project/sglang/blob/main/test/srt/test_vision_openai_server_a.py), [test_vision_openai_server_b.py](https://github.com/sgl-project/sglang/blob/main/test/srt/test_vision_openai_server_b.py)). + +This is an example command to run to test a new model on your local machine: + +```bash +ONLY_RUN=Qwen/Qwen2-1.5B python3 -m unittest test_generation_models.TestGenerationModels.test_others +``` + +### Benchmark + +- **(Required) MMMU**: follow the MMMU benchmark [README.md](https://github.com/sgl-project/sglang/blob/main/benchmark/mmmu/README.md) to get an SGLang vs. HF Transformers accuracy comparison. The accuracy score from the SGLang run should not be much lower than that from the HF Transformers run. Similarly, follow https://docs.sglang.io/developer_guide/benchmark_and_profiling.html to get a performance comparison: TTFT and throughput must meet or exceed baselines (e.g., HF Transformers). +- **(Optional) Other evals**: If you ran other evals, please note the results in the PR description. 
+ +## Port a Model from vLLM to SGLang + +The [vLLM Models Directory](https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models) is a valuable +resource, as vLLM covers many models. SGLang reuses vLLM’s interface and some layers, making it easier to port models +from vLLM to SGLang. + +To port a model from vLLM to SGLang: + +- Compare these two files for guidance: + - [SGLang Llama Implementation](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama.py) + - [vLLM Llama Implementation](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py) +- The major differences include: + - **Replace vLLM’s `Attention` with `RadixAttention`** (ensure you pass `layer_id` to `RadixAttention`). + - **Replace vLLM’s `LogitsProcessor` with SGLang’s `LogitsProcessor`.** + - **Replace the multi-headed `Attention` of ViT with SGLang’s `VisionAttention`.** + - **Replace other vLLM layers** (such as `RMSNorm`, `SiluAndMul`) with SGLang layers. + - **Remove `Sample`.** + - **Change the `forward()` functions** so that they accept a `forward_batch` argument. + - **Add `EntryClass`** at the end. + - **Ensure that the new implementation uses only SGLang components** and does not rely on any vLLM components. + +Note: make sure you add your new model to the supported models list in the supported models documentation. + +## Registering an External Model Implementation + +In addition to the methods above, you can register your new model with the `ModelRegistry` before launching the server. +This allows you to integrate your model without modifying the source code.
+ +For example: + +```python +from sglang.srt.models.registry import ModelRegistry +from sglang.srt.entrypoints.http_server import launch_server + +# For a single model, add it to the registry: +ModelRegistry.models[model_name] = model_class + +# For multiple models, you can imitate the import_model_classes() function: +from functools import lru_cache + +@lru_cache() +def import_new_model_classes(): + model_arch_name_to_cls = {} + # Populate model_arch_name_to_cls with your new model classes. + ... + return model_arch_name_to_cls + +ModelRegistry.models.update(import_new_model_classes()) + +# Launch the server with your server arguments: +launch_server(server_args) +``` + +## Example: Implementing and Serving a Llama Wrapper Model + +Below is an introductory, step-by-step walkthrough on how to implement a new model end-to-end in SGLang and then run it via the [Offline Engine](https://github.com/sgl-project/sglang/blob/main/docs/basic_usage/offline_engine_api.ipynb). + +### Implementing Our Model + +To keep things simple, this new model will be a simple wrapper around [Llama 3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct), and our goal will be just to bias the output logits for each `forward` call by taking the square root of each positive logit (non-positive logits are left unchanged). + +Let's start by defining our model in a file called `llama_wrapper.py`. +The first step is to import the necessary libraries from SRT, which is SGLang's internal backend.
+ +```python +# In the file `llama_wrapper.py` + +import torch +from transformers import LlamaConfig +from typing import Optional +from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors + +from sglang.srt.models.llama import LlamaForCausalLM +``` + +Next, we declare a new `class` for our model and have it inherit from `LlamaForCausalLM`, which allows our model to access `LlamaForCausalLM`'s predefined modules and layers, such as `LlamaAttention` and `LlamaMLP`. +Note that almost all model implementations take in `config` and `quant_config` as arguments for their `__init__` method; `config` and `quant_config` are passed in via [`model_loader/loader.py`](https://github.com/sgl-project/sglang/blob/bf72b80122fd888bf619d17b96fa3e323ab809fc/python/sglang/srt/model_loader/loader.py#L219). +Because we have inherited from `LlamaForCausalLM`, we can pass our parameters directly to its constructor, which will set the member variables for us. + +```python +class LlamaWrapper(LlamaForCausalLM): + def __init__( + self, + config: LlamaConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config=config, quant_config=quant_config, prefix=prefix) +``` + +Now, we want to define the `forward` method, which is what will be called at inference time. +Note that the signature for `forward` is essentially the same for any model; you can take a look at the other models defined in the [`models` directory](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/) for references. 
+To see where exactly `forward` is called in the SGLang runtime's internals, take a look at [`forward_decode`](https://github.com/sgl-project/sglang/blob/bf72b80122fd888bf619d17b96fa3e323ab809fc/python/sglang/srt/model_executor/model_runner.py#L1705) and [`forward_extend`](https://github.com/sgl-project/sglang/blob/bf72b80122fd888bf619d17b96fa3e323ab809fc/python/sglang/srt/model_executor/model_runner.py#L1724) in the [`ModelRunner` class](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/model_executor/model_runner.py). + +```python + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + input_embeds: Optional[torch.Tensor] = None, + get_embedding: bool = False, + ) -> LogitsProcessorOutput: +``` + +We now call the `__call__` method for `self.model` (which is a member variable that `LlamaForCausalLM` defines in its `__init__` method), which eventually calls that inner model's own `forward` method. +After that, we feed the `hidden_states` into our model's `LogitsProcessor` (again defined in `LlamaForCausalLM`). + +```python + hidden_states = self.model( + input_ids, + positions, + forward_batch, + input_embeds, + pp_proxy_tensors=pp_proxy_tensors, + ) + + res: LogitsProcessorOutput = self.logits_processor( + input_ids, + hidden_states, + self.lm_head, + forward_batch, + ) +``` + +After receiving the logits for the next token, we can finally perform our biasing step. + +```python + orig_logits = res.next_token_logits + res.next_token_logits = torch.where( + orig_logits > 0, + orig_logits.sqrt(), + orig_logits + ) + + return res +``` + +Now, our `LlamaWrapper` model is created and ready to be served! + +### Serving Our Model Via SGLang's Offline Engine + +The next step of this walkthrough involves hosting our new model offline, so that it can be served locally and without an HTTP server.
+ +First, create a new file called `run.py`. +Now, we must ensure that SGLang's `ModelRegistry` can find our model. +To do this, we first download the model's configuration and weights from Huggingface. + +```python +# In the file `run.py` + +import asyncio +from functools import lru_cache +from huggingface_hub import snapshot_download +from llama_wrapper import LlamaWrapper # Make sure to import our new model! +import sglang as sgl +from sglang.srt.models.registry import ModelRegistry + +# Make sure to request access to this model on Huggingface, then export your +# `HF_TOKEN` to download the model snapshot +llama_dir = snapshot_download( + repo_id="meta-llama/Llama-3.1-8B-Instruct", + local_dir="./llama_ckpt", +) +``` + +Now that we have our model on disk, we want to point it to `LlamaWrapper` by changing the `architectures` field in `./llama_ckpt/config.json` to be `LlamaWrapper`. +That way, when we pass in the path of our model checkpoint to SGLang, it will know that we want to use "LlamaWrapper" instead of "LlamaForCausalLM" as our model. + +```python +{ + "architectures": [ + # "LlamaForCausalLM" + "LlamaWrapper" + ], + ... +} +``` + +However, if we don't link our `LlamaWrapper` class to the "LlamaWrapper" registry keyword, then SGLang won't be able to find our model. +Thus, to register our `LlamaWrapper`, we want to follow the steps in the above section titled "Registering an External Model Implementation". + +```python +@lru_cache() +def import_new_model_classes(): + model_arch_name_to_cls = {"LlamaWrapper": LlamaWrapper} + return model_arch_name_to_cls + +ModelRegistry.models.update(import_new_model_classes()) +``` + +Lastly, when we create our `Engine`, we just pass in the path to the local model directory. +Then, our `LlamaWrapper` is ready to be served; for this walkthrough, we will use SGLang `Engine`'s non-streaming asynchronous generation endpoint. 
+ +```python +def main(): + llm = sgl.Engine(model_path="./llama_ckpt") + sampling_params = {"temperature": 0.2, "top_k": 5} + prompts = [ + "Write a short, neutral self-introduction for a fictional character. Hello, my name is", + "Provide a concise factual statement about France’s capital city. The capital of France is", + "Explain possible future trends in artificial intelligence. The future of AI is", + ] + + asyncio.run(run_llm(llm, sampling_params, prompts)) + + llm.shutdown() + +async def run_llm( + llm, + sampling_params, + prompts, +) -> None: + outputs = await llm.async_generate(prompts, sampling_params) + + for prompt, output in zip(prompts, outputs): + print(f"\nPrompt: {prompt}") + print(f"Generated text: {output['text']}") + +if __name__ == "__main__": + main() +``` + +Now, when we call `python run.py`, we will get the outputs of our newly created model! + +## Documentation + +Add to table of supported models in [generative_models.md](../text_generation/generative_models.md) or [multimodal_language_models.md](../text_generation/multimodal_language_models.md) + +--- + +By following these guidelines, you can add support for new language models and multimodal large language models in +SGLang and ensure they are thoroughly tested and easily integrated into the system. diff --git a/sglang/docs/supported_models/retrieval_ranking/classify_models.md b/sglang/docs/supported_models/retrieval_ranking/classify_models.md new file mode 100644 index 0000000000000000000000000000000000000000..9b3c6a5914f09f13ab9ae73f2b988bb25309096f --- /dev/null +++ b/sglang/docs/supported_models/retrieval_ranking/classify_models.md @@ -0,0 +1,162 @@ +# Classification API + +This document describes the `/v1/classify` API endpoint implementation in SGLang, which is compatible with vLLM's classification API format. + +## Overview + +The classification API allows you to classify text inputs using classification models. 
This implementation follows the same format as vLLM's 0.7.0 classification API. + +## API Endpoint + +``` +POST /v1/classify +``` + +## Request Format + +```json +{ + "model": "model_name", + "input": "text to classify" +} +``` + +### Parameters + +- `model` (string, required): The name of the classification model to use +- `input` (string, required): The text to classify +- `user` (string, optional): User identifier for tracking +- `rid` (string, optional): Request ID for tracking +- `priority` (integer, optional): Request priority + +## Response Format + +```json +{ + "id": "classify-9bf17f2847b046c7b2d5495f4b4f9682", + "object": "list", + "created": 1745383213, + "model": "jason9693/Qwen2.5-1.5B-apeach", + "data": [ + { + "index": 0, + "label": "Default", + "probs": [0.565970778465271, 0.4340292513370514], + "num_classes": 2 + } + ], + "usage": { + "prompt_tokens": 10, + "total_tokens": 10, + "completion_tokens": 0, + "prompt_tokens_details": null + } +} +``` + +### Response Fields + +- `id`: Unique identifier for the classification request +- `object`: Always "list" +- `created`: Unix timestamp when the request was created +- `model`: The model used for classification +- `data`: Array of classification results + - `index`: Index of the result + - `label`: Predicted class label + - `probs`: Array of probabilities for each class + - `num_classes`: Total number of classes +- `usage`: Token usage information + - `prompt_tokens`: Number of input tokens + - `total_tokens`: Total number of tokens + - `completion_tokens`: Number of completion tokens (always 0 for classification) + - `prompt_tokens_details`: Additional token details (optional) + +## Example Usage + +### Using curl + +```bash +curl -v "http://127.0.0.1:8000/v1/classify" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "jason9693/Qwen2.5-1.5B-apeach", + "input": "Loved the new café—coffee was great." 
+ }' +``` + +### Using Python + +```python +import requests +import json + +# Make classification request +response = requests.post( + "http://127.0.0.1:8000/v1/classify", + headers={"Content-Type": "application/json"}, + json={ + "model": "jason9693/Qwen2.5-1.5B-apeach", + "input": "Loved the new café—coffee was great." + } +) + +# Parse response +result = response.json() +print(json.dumps(result, indent=2)) +``` + +## Supported Models + +The classification API works with any classification model supported by SGLang, including: + +### Classification Models (Multi-class) +- `LlamaForSequenceClassification` - Multi-class classification +- `Qwen2ForSequenceClassification` - Multi-class classification +- `Qwen3ForSequenceClassification` - Multi-class classification +- `BertForSequenceClassification` - Multi-class classification +- `Gemma2ForSequenceClassification` - Multi-class classification + +**Label Mapping**: The API automatically uses the `id2label` mapping from the model's `config.json` file to provide meaningful label names instead of generic class names. If `id2label` is not available, it falls back to `LABEL_0`, `LABEL_1`, etc., or `Class_0`, `Class_1` as a last resort. + +### Reward Models (Single score) +- `InternLM2ForRewardModel` - Single reward score +- `Qwen2ForRewardModel` - Single reward score +- `LlamaForSequenceClassificationWithNormal_Weights` - Special reward model + +**Note**: The `/classify` endpoint in SGLang was originally designed for reward models but now supports all non-generative models. Our `/v1/classify` endpoint provides a standardized vLLM-compatible interface for classification tasks. 
+ +## Error Handling + +The API returns appropriate HTTP status codes and error messages: + +- `400 Bad Request`: Invalid request format or missing required fields +- `500 Internal Server Error`: Server-side processing error + +Error response format: +```json +{ + "error": "Error message", + "type": "error_type", + "code": 400 +} +``` + +## Implementation Details + +The classification API is implemented using: + +1. **Rust Model Gateway**: Handles routing and request/response models in `sgl-model-gateway/src/protocols/spec.rs` +2. **Python HTTP Server**: Implements the actual endpoint in `python/sglang/srt/entrypoints/http_server.py` +3. **Classification Service**: Handles the classification logic in `python/sglang/srt/entrypoints/openai/serving_classify.py` + +## Testing + +Use the provided test script to verify the implementation: + +```bash +python test_classify_api.py +``` + +## Compatibility + +This implementation is compatible with vLLM's classification API format, allowing seamless migration from vLLM to SGLang for classification tasks. diff --git a/sglang/docs/supported_models/retrieval_ranking/embedding_models.md b/sglang/docs/supported_models/retrieval_ranking/embedding_models.md new file mode 100644 index 0000000000000000000000000000000000000000..906466ac5e6b00e2dccd9b29e4592f7449e606c3 --- /dev/null +++ b/sglang/docs/supported_models/retrieval_ranking/embedding_models.md @@ -0,0 +1,126 @@ +# Embedding Models + +SGLang provides robust support for embedding models by integrating efficient serving mechanisms with its flexible programming interface. This integration allows for streamlined handling of embedding tasks, facilitating faster and more accurate retrieval and semantic search operations. SGLang's architecture enables better resource utilization and reduced latency in embedding model deployment. 
+ +```{important} +Embedding models are executed with the `--is-embedding` flag, and some may require `--trust-remote-code`. +``` + +## Quick Start + +### Launch Server + +```shell +python3 -m sglang.launch_server \ + --model-path Qwen/Qwen3-Embedding-4B \ + --is-embedding \ + --host 0.0.0.0 \ + --port 30000 +``` + +### Client Request + +```python +import requests + +url = "http://127.0.0.1:30000" + +payload = { + "model": "Qwen/Qwen3-Embedding-4B", + "input": "What is the capital of France?", + "encoding_format": "float" +} + +response = requests.post(url + "/v1/embeddings", json=payload).json() +print("Embedding:", response["data"][0]["embedding"]) +``` + + + +## Multimodal Embedding Example + +For multimodal models like GME that support both text and images: + +```shell +python3 -m sglang.launch_server \ + --model-path Alibaba-NLP/gme-Qwen2-VL-2B-Instruct \ + --is-embedding \ + --chat-template gme-qwen2-vl \ + --host 0.0.0.0 \ + --port 30000 +``` + +```python +import requests + +url = "http://127.0.0.1:30000" + +text_input = "Represent this image in embedding space." +image_path = "https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild/resolve/main/images/023.jpg" + +payload = { + "model": "gme-qwen2-vl", + "input": [ + { + "text": text_input + }, + { + "image": image_path + } + ], +} + +response = requests.post(url + "/v1/embeddings", json=payload).json() + +print("Embeddings:", [x.get("embedding") for x in response.get("data", [])]) +``` + +## Matryoshka Embedding Example + +[Matryoshka Embeddings](https://sbert.net/examples/sentence_transformer/training/matryoshka/README.html#matryoshka-embeddings) or [Matryoshka Representation Learning (MRL)](https://arxiv.org/abs/2205.13147) is a technique used in training embedding models. It allows users to trade off between performance and cost. + +### 1. Launch a Matryoshka‑capable model + +If the model config already includes `matryoshka_dimensions` or `is_matryoshka` then no override is needed.
Otherwise, you can use `--json-model-override-args` as below: + +```shell +python3 -m sglang.launch_server \ + --model-path Qwen/Qwen3-Embedding-0.6B \ + --is-embedding \ + --host 0.0.0.0 \ + --port 30000 \ + --json-model-override-args '{"matryoshka_dimensions": [128, 256, 512, 1024, 1536]}' +``` + +1. Setting `"is_matryoshka": true` allows truncating to any dimension. Otherwise, the server will validate that the specified dimension in the request is one of `matryoshka_dimensions`. +2. Omitting `dimensions` in a request returns the full vector. + +### 2. Make requests with different output dimensions + +```python +import requests + +url = "http://127.0.0.1:30000" + +# Request a truncated (Matryoshka) embedding by specifying a supported dimension. +payload = { + "model": "Qwen/Qwen3-Embedding-0.6B", + "input": "Explain diffusion models simply.", + "dimensions": 512 # change to 128 / 1024 / omit for full size +} + +response = requests.post(url + "/v1/embeddings", json=payload).json() +print("Embedding:", response["data"][0]["embedding"]) +``` + + +## Supported Models + +| Model Family | Example Model | Chat Template | Description | +| ------------------------------------------ | -------------------------------------- | ------------- | --------------------------------------------------------------------------- | +| **E5 (Llama/Mistral based)** | `intfloat/e5-mistral-7b-instruct` | N/A | High-quality text embeddings based on Mistral/Llama architectures | +| **GTE-Qwen2** | `Alibaba-NLP/gte-Qwen2-7B-instruct` | N/A | Alibaba's text embedding model with multilingual support | +| **Qwen3-Embedding** | `Qwen/Qwen3-Embedding-4B` | N/A | Latest Qwen3-based text embedding model for semantic representation | +| **BGE** | `BAAI/bge-large-en-v1.5` | N/A | BAAI's text embeddings (requires `attention-backend` triton/torch_native) | +| **GME (Multimodal)** | `Alibaba-NLP/gme-Qwen2-VL-2B-Instruct`| `gme-qwen2-vl`| Multimodal embedding for text and image cross-modal tasks | +| 
**CLIP** | `openai/clip-vit-large-patch14-336` | N/A | OpenAI's CLIP for image and text embeddings | diff --git a/sglang/docs/supported_models/retrieval_ranking/rerank_models.md b/sglang/docs/supported_models/retrieval_ranking/rerank_models.md new file mode 100644 index 0000000000000000000000000000000000000000..bb989128a8ec271ec128b4cdbaabae7e20c5e6f3 --- /dev/null +++ b/sglang/docs/supported_models/retrieval_ranking/rerank_models.md @@ -0,0 +1,313 @@ +# Rerank Models + +SGLang offers comprehensive support for rerank models by incorporating optimized serving frameworks with a flexible programming interface. This setup enables efficient processing of cross-encoder reranking tasks, improving the accuracy and relevance of search result ordering. SGLang’s design ensures high throughput and low latency during reranker model deployment, making it ideal for semantic-based result refinement in large-scale retrieval systems. + +```{important} +Rerank models in SGLang fall into two categories: + +- **Cross-encoder rerank models**: run with `--is-embedding` (embedding runner). +- **Decoder-only rerank models**: run **without** `--is-embedding` and use next-token logprob scoring (yes/no). + - Text-only (e.g. Qwen3-Reranker) + - Multimodal (e.g. Qwen3-VL-Reranker): also supports image/video content + +Some models may require `--trust-remote-code`. +``` + +## Supported rerank models + +| Model Family (Rerank) | Example HuggingFace Identifier | Chat Template | Description | +|------------------------------------------------|--------------------------------------|---------------|----------------------------------------------------------------------------------------------------------------------------------| +| **BGE-Reranker (BgeRerankModel)** | `BAAI/bge-reranker-v2-m3` | N/A | Currently only support `attention-backend` `triton` and `torch_native`. High-performance cross-encoder reranker model from BAAI. Suitable for reranking search results based on semantic relevance. 
| +| **Qwen3-Reranker (decoder-only yes/no)** | `Qwen/Qwen3-Reranker-8B` | `examples/chat_template/qwen3_reranker.jinja` | Decoder-only reranker using next-token logprob scoring for labels (yes/no). Launch **without** `--is-embedding`. | +| **Qwen3-VL-Reranker (multimodal yes/no)** | `Qwen/Qwen3-VL-Reranker-2B` | `examples/chat_template/qwen3_vl_reranker.jinja` | Multimodal decoder-only reranker supporting text, images, and videos. Uses yes/no logprob scoring. Launch **without** `--is-embedding`. | + + +## Cross-Encoder Rerank (embedding runner) + +### Launch Command + +```shell +python3 -m sglang.launch_server \ + --model-path BAAI/bge-reranker-v2-m3 \ + --host 0.0.0.0 \ + --disable-radix-cache \ + --chunked-prefill-size -1 \ + --attention-backend triton \ + --is-embedding \ + --port 30000 +``` + +### Example Client Request + +```python +import requests + +url = "http://127.0.0.1:30000/v1/rerank" + +payload = { + "model": "BAAI/bge-reranker-v2-m3", + "query": "what is panda?", + "documents": [ + "hi", + "The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." + ], + "top_n": 1, + "return_documents": True +} + +response = requests.post(url, json=payload) +response_json = response.json() + +for item in response_json: + if item.get("document"): + print(f"Score: {item['score']:.2f} - Document: '{item['document']}'") + else: + print(f"Score: {item['score']:.2f} - Index: {item['index']}") +``` + +**Request Parameters:** + +- `query` (required): The query text to rank documents against +- `documents` (required): List of documents to be ranked +- `model` (required): Model to use for reranking +- `top_n` (optional): Maximum number of documents to return. Defaults to returning all documents. If specified value is greater than the total number of documents, all documents will be returned. +- `return_documents` (optional): Whether to return documents in the response. Defaults to `True`. 
+ +## Qwen3-Reranker (decoder-only yes/no rerank) + +### Launch Command + +```shell +python3 -m sglang.launch_server \ + --model-path Qwen/Qwen3-Reranker-0.6B \ + --trust-remote-code \ + --disable-radix-cache \ + --host 0.0.0.0 \ + --port 8001 \ + --chat-template examples/chat_template/qwen3_reranker.jinja +``` + +```{note} +Qwen3-Reranker uses decoder-only logprob scoring (yes/no). Do NOT launch it with `--is-embedding`. +``` + +### Example Client Request (supports optional instruct, top_n, and return_documents) + +```shell +curl -X POST http://127.0.0.1:8001/v1/rerank \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Qwen3-Reranker-0.6B", + "query": "法国首都是哪里?", + "documents": [ + "法国的首都是巴黎。", + "德国的首都是柏林。", + "香蕉是黄色的水果。" + ], + "instruct": "Given a web search query, retrieve relevant passages that answer the query.", + "top_n": 2, + "return_documents": true + }' +``` + +**Request Parameters:** + +- `query` (required): The query text to rank documents against +- `documents` (required): List of documents to be ranked +- `model` (required): Model to use for reranking +- `instruct` (optional): Instruction text for the reranker +- `top_n` (optional): Maximum number of documents to return. Defaults to returning all documents. If specified value is greater than the total number of documents, all documents will be returned. +- `return_documents` (optional): Whether to return documents in the response. Defaults to `True`. + +### Response Format + +`/v1/rerank` returns a list of objects (sorted by descending score): + +- `score`: float, higher means more relevant +- `document`: the original document string (only included when `return_documents` is `true`) +- `index`: the original index in the input `documents` +- `meta_info`: optional debug/usage info (may be present for some models) + +The number of returned results is controlled by the `top_n` parameter. If `top_n` is not specified or is greater than the total number of documents, all documents are returned. 
+ +Example (with `return_documents: true`): + +```json +[ + {"score": 0.99, "document": "法国的首都是巴黎。", "index": 0}, + {"score": 0.01, "document": "德国的首都是柏林。", "index": 1}, + {"score": 0.00, "document": "香蕉是黄色的水果。", "index": 2} +] +``` + +Example (with `return_documents: false`): + +```json +[ + {"score": 0.99, "index": 0}, + {"score": 0.01, "index": 1}, + {"score": 0.00, "index": 2} +] +``` + +Example (with `top_n: 2`): + +```json +[ + {"score": 0.99, "document": "法国的首都是巴黎。", "index": 0}, + {"score": 0.01, "document": "德国的首都是柏林。", "index": 1} +] +``` + +### Common Pitfalls + +- If you launch Qwen3-Reranker with `--is-embedding`, `/v1/rerank` cannot compute yes/no logprob scores. Relaunch **without** `--is-embedding`. +- If you see a validation error like "score should be a valid number" and the backend returned a list, upgrade to a version that coerces `embedding[0]` into `score` for rerank responses. + +## Qwen3-VL-Reranker (multimodal decoder-only rerank) + +Qwen3-VL-Reranker extends the Qwen3-Reranker to support multimodal content, allowing reranking of documents containing text, images, and videos. + +### Launch Command + +```shell +python3 -m sglang.launch_server \ + --model-path Qwen/Qwen3-VL-Reranker-2B \ + --trust-remote-code \ + --disable-radix-cache \ + --host 0.0.0.0 \ + --port 30000 \ + --chat-template examples/chat_template/qwen3_vl_reranker.jinja +``` + +```{note} +Qwen3-VL-Reranker uses decoder-only logprob scoring (yes/no) like Qwen3-Reranker. Do NOT launch it with `--is-embedding`. 
+``` + +### Text-Only Reranking (backward compatible) + +```python +import requests + +url = "http://127.0.0.1:30000/v1/rerank" + +payload = { + "model": "Qwen3-VL-Reranker-2B", + "query": "What is machine learning?", + "documents": [ + "Machine learning is a branch of artificial intelligence that enables computers to learn from data.", + "The weather in Paris is usually mild with occasional rain.", + "Deep learning is a subset of machine learning using neural networks with many layers.", + ], + "instruct": "Retrieve passages that answer the question.", + "return_documents": True +} + +response = requests.post(url, json=payload) +results = response.json() + +for item in results: + print(f"Score: {item['score']:.4f} - {item['document'][:60]}...") +``` + +### Image Reranking (text query, image/mixed documents) + +```python +import requests + +url = "http://127.0.0.1:30000/v1/rerank" + +payload = { + "query": "A woman playing with her dog on a beach at sunset.", + "documents": [ + # Document 1: Text description + "A woman shares a joyful moment with her golden retriever on a sun-drenched beach at sunset.", + # Document 2: Image URL + [ + { + "type": "image_url", + "image_url": { + "url": "https://example.com/beach_dog.jpeg" + } + } + ], + # Document 3: Text + Image (mixed) + [ + {"type": "text", "text": "A joyful scene at the beach:"}, + { + "type": "image_url", + "image_url": { + "url": "https://example.com/beach_dog.jpeg" + } + } + ] + ], + "instruct": "Retrieve images or text relevant to the user's query.", + "return_documents": False +} + +response = requests.post(url, json=payload) +results = response.json() + +for item in results: + print(f"Index: {item['index']}, Score: {item['score']:.4f}") +``` + +### Multimodal Query Reranking (query with image) + +```python +import requests + +url = "http://127.0.0.1:30000/v1/rerank" + +payload = { + # Query with text and image + "query": [ + {"type": "text", "text": "Find similar images to this:"}, + { + "type": 
"image_url", + "image_url": { + "url": "https://example.com/reference_image.jpeg" + } + } + ], + "documents": [ + "A cat sleeping on a couch.", + "A woman and her dog enjoying the sunset at the beach.", + "A busy city street with cars and pedestrians.", + [ + { + "type": "image_url", + "image_url": { + "url": "https://example.com/similar_image.jpeg" + } + } + ] + ], + "instruct": "Find images or descriptions similar to the query image." +} + +response = requests.post(url, json=payload) +results = response.json() + +for item in results: + print(f"Index: {item['index']}, Score: {item['score']:.4f}") +``` + +### Request Parameters (Multimodal) + +- `query` (required): Can be a string (text-only) or a list of content parts: + - `{"type": "text", "text": "..."}` for text + - `{"type": "image_url", "image_url": {"url": "..."}}` for images + - `{"type": "video_url", "video_url": {"url": "..."}}` for videos +- `documents` (required): List where each document can be a string or list of content parts (same format as query) +- `instruct` (optional): Instruction text for the reranker +- `top_n` (optional): Maximum number of documents to return +- `return_documents` (optional): Whether to return documents in the response (default: `true`) + +### Common Pitfalls + +- Always use `--chat-template examples/chat_template/qwen3_vl_reranker.jinja` for Qwen3-VL-Reranker. +- Do NOT launch with `--is-embedding`. +- For best results, use `--disable-radix-cache` to avoid caching issues with multimodal content. +- **Note**: Currently only `Qwen3-VL-Reranker-2B` is tested and supported. The 8B model may have different behavior and is not guaranteed to work with this template.
diff --git a/sglang/docs/supported_models/specialized/index.rst b/sglang/docs/supported_models/specialized/index.rst new file mode 100644 index 0000000000000000000000000000000000000000..40d108acb3df4db6e971b0e6a5b9cda9405edabe --- /dev/null +++ b/sglang/docs/supported_models/specialized/index.rst @@ -0,0 +1,9 @@ +Specialized Models +================== + +Models for specialized tasks like reward modeling. + +.. toctree:: + :maxdepth: 1 + + reward_models.md diff --git a/sglang/docs/supported_models/specialized/reward_models.md b/sglang/docs/supported_models/specialized/reward_models.md new file mode 100644 index 0000000000000000000000000000000000000000..ef4474637fad4ad11c9b1b8f32af02e76943149f --- /dev/null +++ b/sglang/docs/supported_models/specialized/reward_models.md @@ -0,0 +1,28 @@ +# Reward Models + +These models output a scalar reward score or classification result, often used in reinforcement learning or content moderation tasks. + +```{important} +They are executed with `--is-embedding` and some may require `--trust-remote-code`. +``` + +## Example launch Command + +```shell +python3 -m sglang.launch_server \ + --model-path Qwen/Qwen2.5-Math-RM-72B \ # example HF/local path + --is-embedding \ + --host 0.0.0.0 \ + --tp-size=4 \ # set for tensor parallelism + --port 30000 \ +``` + +## Supported models + +| Model Family (Reward) | Example HuggingFace Identifier | Description | +|---------------------------------------------------------------------------|-----------------------------------------------------|---------------------------------------------------------------------------------| +| **Llama (3.1 Reward / `LlamaForSequenceClassification`)** | `Skywork/Skywork-Reward-Llama-3.1-8B-v0.2` | Reward model (preference classifier) based on Llama 3.1 (8B) for scoring and ranking responses for RLHF. 
| +| **Gemma 2 (27B Reward / `Gemma2ForSequenceClassification`)** | `Skywork/Skywork-Reward-Gemma-2-27B-v0.2` | Derived from Gemma‑2 (27B), this model provides human preference scoring for RLHF and multilingual tasks. | +| **InternLM 2 (Reward / `InternLM2ForRewardModel`)** | `internlm/internlm2-7b-reward` | InternLM 2 (7B)–based reward model used in alignment pipelines to guide outputs toward preferred behavior. | +| **Qwen2.5 (Reward - Math / `Qwen2ForRewardModel`)** | `Qwen/Qwen2.5-Math-RM-72B` | A 72B math-specialized RLHF reward model from the Qwen2.5 series, tuned for evaluating and refining responses. | +| **Qwen2.5 (Reward - Sequence / `Qwen2ForSequenceClassification`)** | `jason9693/Qwen2.5-1.5B-apeach` | A smaller Qwen2.5 variant used for sequence classification, offering an alternative RLHF scoring mechanism. | diff --git a/sglang/docs/supported_models/text_generation/diffusion_language_models.md b/sglang/docs/supported_models/text_generation/diffusion_language_models.md new file mode 100644 index 0000000000000000000000000000000000000000..7dbb4828b6956ee079349086d87c19bea991423b --- /dev/null +++ b/sglang/docs/supported_models/text_generation/diffusion_language_models.md @@ -0,0 +1,111 @@ +# Diffusion Language Models + +Diffusion language models have shown promise for non-autoregressive text generation with parallel decoding capabilities. Unlike auto-regressive language models, different diffusion language models require different decoding strategies. + +## Example Launch Command + +SGLang supports different DLLM algorithms such as `LowConfidence` and `JointThreshold`. + +```shell +python3 -m sglang.launch_server \ + --model-path inclusionAI/LLaDA2.0-mini \ # example HF/local path + --dllm-algorithm LowConfidence \ + --dllm-algorithm-config ./config.yaml \ # Optional. Uses the algorithm's default if not set. + --host 0.0.0.0 \ + --port 30000 +``` + +## Example Configuration File + +Depending on the algorithm selected, the configuration parameters vary. 
+ +LowConfidence Config: + +```yaml +# Confidence threshold for accepting predicted tokens +# - Higher values: More conservative, better quality but slower +# - Lower values: More aggressive, faster but potentially lower quality +# Range: 0.0 - 1.0 +threshold: 0.95 + +# Default: 32, for LLaDA2MoeModelLM +block_size: 32 +``` + +JointThreshold Config: + +```yaml +# Decoding threshold for Mask-to-Token (M2T) phase +# - Higher values: More conservative, better quality but slower +# - Lower values: More aggressive, faster but potentially lower quality +# Range: 0.0 - 1.0 +threshold: 0.5 +# Decoding threshold for Token-to-Token (T2T) phase +# Range: 0.0 - 1.0 +# Setting to 0.0 allows full editing (recommended for most cases). +edit_threshold: 0.0 +# Max extra T2T steps after all masks are removed. Prevents infinite loops. +max_post_edit_steps: 16 +# 2-gram repetition penalty (default 0). +# An empirical value of 3 is often sufficient to mitigate most repetitions. +penalty_lambda: 0 +``` + +## Example Client Code Snippet + +Just like other supported models, diffusion language models can be used via the REST API or Python client. 
+ +Python client example for making a generation request to the launched server: + +```python +import sglang as sgl + +def main(): + llm = sgl.Engine(model_path="inclusionAI/LLaDA2.0-mini", + dllm_algorithm="LowConfidence", + max_running_requests=1, + trust_remote_code=True) + + prompts = [ + "SYSTEMdetailed thinking off<|role_end|>HUMAN Write a brief introduction of the great wall <|role_end|>ASSISTANT" + ] + + sampling_params = { + "temperature": 0, + "max_new_tokens": 1024, + } + + outputs = llm.generate(prompts, sampling_params) + print(outputs) + +if __name__ == '__main__': + main() +``` + +Curl example for making a generation request to the launched server: + +```bash +curl -X POST "http://127.0.0.1:30000/generate" \ + -H "Content-Type: application/json" \ + -d '{ + "text": [ + "SYSTEMdetailed thinking off<|role_end|>HUMAN Write the number from 1 to 128 <|role_end|>ASSISTANT", + "SYSTEMdetailed thinking off<|role_end|>HUMAN Write a brief introduction of the great wall <|role_end|>ASSISTANT" + ], + "stream": true, + "sampling_params": { + "temperature": 0, + "max_new_tokens": 1024 + } + }' +``` + +## Supported Models + +Below the supported models are summarized in a table. + +| Model Family | Example Model | Description | +| -------------------------- | ---------------------------- | ---------------------------------------------------------------------------------------------------- | +| **LLaDA2.0 (mini, flash)** | `inclusionAI/LLaDA2.0-flash` | LLaDA2.0-flash is a diffusion language model featuring a 100B Mixture-of-Experts (MoE) architecture. | +| **SDAR (JetLM)** | `JetLM/SDAR-8B-Chat` | SDAR series diffusion language model (Chat), dense architecture. | +| **SDAR (JetLM)** | `JetLM/SDAR-30B-A3B-Chat` | SDAR series diffusion language model (Chat), MoE architecture. 
| diff --git a/sglang/docs/supported_models/text_generation/generative_models.md b/sglang/docs/supported_models/text_generation/generative_models.md new file mode 100644 index 0000000000000000000000000000000000000000..f73aa200faed4e904f646faf0d43949eadb9167c --- /dev/null +++ b/sglang/docs/supported_models/text_generation/generative_models.md @@ -0,0 +1,72 @@ +# Large Language Models + +These models accept text input and produce text output (e.g., chat completions). They are primarily large language models (LLMs), some with mixture-of-experts (MoE) architectures for scaling. + +## Example launch Command + +```shell +python3 -m sglang.launch_server \ + --model-path meta-llama/Llama-3.2-1B-Instruct \ # example HF/local path + --host 0.0.0.0 \ + --port 30000 \ +``` + +## Supported models + +Below the supported models are summarized in a table. + +If you are unsure if a specific architecture is implemented, you can search for it via GitHub. For example, to search for `Qwen3ForCausalLM`, use the expression: + +``` +repo:sgl-project/sglang path:/^python\/sglang\/srt\/models\// Qwen3ForCausalLM +``` + +in the GitHub search bar. + +| Model Family (Variants) | Example HuggingFace Identifier | Description | +|-------------------------------------|--------------------------------------------------|----------------------------------------------------------------------------------------| +| **DeepSeek** (v1, v2, v3/R1) | `deepseek-ai/DeepSeek-R1` | Series of advanced reasoning-optimized models (including a 671B MoE) trained with reinforcement learning; top performance on complex reasoning, math, and code tasks. 
[SGLang provides Deepseek v3/R1 model-specific optimizations](../basic_usage/deepseek.md) and [Reasoning Parser](../advanced_features/separate_reasoning.ipynb)| +| **Kimi K2** (Thinking, Instruct) | `moonshotai/Kimi-K2-Instruct` | Moonshot AI's 1 trillion parameter MoE model (32B active) with 128K–256K context; state-of-the-art agentic intelligence with stable long-horizon agency across 200–300 sequential tool calls. Features MLA attention and native INT4 quantization. [See Reasoning Parser docs](../advanced_features/separate_reasoning.ipynb)| +| **Kimi Linear** (48B-A3B) | `moonshotai/Kimi-Linear-48B-A3B-Instruct` | Moonshot AI's hybrid linear attention model (48B total, 3B active) with 1M token context; features Kimi Delta Attention (KDA) for up to 6× faster decoding and 75% KV cache reduction vs full attention. | +| **GPT-OSS** | `openai/gpt-oss-20b`, `openai/gpt-oss-120b` | OpenAI’s latest GPT-OSS series for complex reasoning, agentic tasks, and versatile developer use cases.| +| **Qwen** (3.5, 3, 3MoE, 3Next, 2.5, 2 series) | `Qwen/Qwen3.5-397B-A17B`, `Qwen/Qwen3-0.6B`, `Qwen/Qwen3-30B-A3B` | Alibaba’s latest Qwen3 series for complex reasoning, language understanding, and generation tasks; Support for MoE variants along with previous generation 2.5, 2, etc. [SGLang provides Qwen3 specific reasoning parser](../advanced_features/separate_reasoning.ipynb)| +| **Llama** (2, 3.x, 4 series) | `meta-llama/Llama-4-Scout-17B-16E-Instruct` | Meta's open LLM series, spanning 7B to 400B parameters (Llama 2, 3, and new Llama 4) with well-recognized performance. [SGLang provides Llama-4 model-specific optimizations](../basic_usage/llama4.md) | +| **Mistral** (Mixtral, NeMo, Small3) | `mistralai/Mistral-7B-Instruct-v0.2` | Open 7B LLM by Mistral AI with strong performance; extended into MoE (“Mixtral”) and NeMo Megatron variants for larger scale. 
| +| **Gemma** (v1, v2, v3) | `google/gemma-3-1b-it` | Google’s family of efficient multilingual models (1B–27B); Gemma 3 offers a 128K context window, and its larger (4B+) variants support vision input. | +| **Phi** (Phi-1.5, Phi-2, Phi-3, Phi-4, Phi-MoE series) | `microsoft/Phi-4-multimodal-instruct`, `microsoft/Phi-3.5-MoE-instruct` | Microsoft’s Phi family of small models (1.3B–5.6B); Phi-4-multimodal (5.6B) processes text, images, and speech, Phi-4-mini is a high-accuracy text model and Phi-3.5-MoE is a mixture-of-experts model. | +| **MiniCPM** (v3, 4B) | `openbmb/MiniCPM3-4B` | OpenBMB’s series of compact LLMs for edge devices; MiniCPM 3 (4B) achieves GPT-3.5-level results in text tasks. | +| **OLMo** (2, 3) | `allenai/OLMo-3-1125-32B`, `allenai/OLMo-3-32B-Think`, `allenai/OLMo-2-1124-7B-Instruct` | Allen AI’s series of Open Language Models designed to enable the science of language models. | +| **OLMoE** (Open MoE) | `allenai/OLMoE-1B-7B-0924` | Allen AI’s open Mixture-of-Experts model (7B total, 1B active parameters) delivering state-of-the-art results with sparse expert activation. | +| **MiniMax-M2** (M2, M2.1, M2.5) | `MiniMaxAI/MiniMax-M2.5`, `MiniMaxAI/MiniMax-M2.1`, `MiniMaxAI/MiniMax-M2` | MiniMax's SOTA LLM for coding & agentic workflows. | +| **StableLM** (3B, 7B) | `stabilityai/stablelm-tuned-alpha-7b` | StabilityAI’s early open-source LLM (3B & 7B) for general text generation; a demonstration model with basic instruction-following ability. | +| **Command-(R,A)** (Cohere) | `CohereLabs/c4ai-command-r-v01`, `CohereLabs/c4ai-command-r7b-12-2024`, `CohereLabs/c4ai-command-a-03-2025` | Cohere’s open conversational LLM (Command series) optimized for long context, retrieval-augmented generation, and tool use. | +| **DBRX** (Databricks) | `databricks/dbrx-instruct` | Databricks’ 132B-parameter MoE model (36B active) trained on 12T tokens; competes with GPT-3.5 quality as a fully open foundation model. 
| +| **Grok** (xAI) | `xai-org/grok-1` | xAI’s grok-1 model known for vast size(314B parameters) and high quality; integrated in SGLang for high-performance inference. | +| **ChatGLM** (GLM-130B family) | `THUDM/chatglm2-6b` | Zhipu AI’s bilingual chat model (6B) excelling at Chinese-English dialogue; fine-tuned for conversational quality and alignment. | +| **InternLM 2** (7B, 20B) | `internlm/internlm2-7b` | Next-gen InternLM (7B and 20B) from SenseTime, offering strong reasoning and ultra-long context support (up to 200K tokens). | +| **ExaONE 3** (Korean-English) | `LGAI-EXAONE/EXAONE-3.5-7.8B-Instruct` | LG AI Research’s Korean-English model (7.8B) trained on 8T tokens; provides high-quality bilingual understanding and generation. | +| **Baichuan 2** (7B, 13B) | `baichuan-inc/Baichuan2-13B-Chat` | BaichuanAI’s second-generation Chinese-English LLM (7B/13B) with improved performance and an open commercial license. | +| **XVERSE** (MoE) | `xverse/XVERSE-MoE-A36B` | Yuanxiang’s open MoE LLM (XVERSE-MoE-A36B: 255B total, 36B active) supporting ~40 languages; delivers 100B+ dense-level performance via expert routing. | +| **SmolLM** (135M–1.7B) | `HuggingFaceTB/SmolLM-1.7B` | Hugging Face’s ultra-small LLM series (135M–1.7B params) offering surprisingly strong results, enabling advanced AI on mobile/edge devices. | +| **GLM-4** (Multilingual 9B) | `ZhipuAI/glm-4-9b-chat` | Zhipu’s GLM-4 series (up to 9B parameters) – open multilingual models with support for 1M-token context and even a 5.6B multimodal variant (Phi-4V). | +| **MiMo** (7B series) | `XiaomiMiMo/MiMo-7B-RL` | Xiaomi's reasoning-optimized model series, leverages Multiple-Token Prediction for faster inference. | +| **ERNIE-4.5** (4.5, 4.5MoE series) | `baidu/ERNIE-4.5-21B-A3B-PT` | Baidu's ERNIE-4.5 series which consists of MoE with 47B and 3B active parameters, with the largest model having 424B total parameters, as well as a 0.3B dense model. 
| +| **Arcee AFM-4.5B** | `arcee-ai/AFM-4.5B-Base` | Arcee's foundational model series for real world reliability and edge deployments. | +| **Persimmon** (8B) | `adept/persimmon-8b-chat` | Adept’s open 8B model with a 16K context window and fast inference; trained for broad usability and licensed under Apache 2.0. | +| **Solar** (10.7B) | `upstage/SOLAR-10.7B-Instruct-v1.0` | Upstage's 10.7B parameter model, optimized for instruction-following tasks. This architecture incorporates a depth-up scaling methodology, enhancing model performance. | +| **Tele FLM** (52B-1T) | `CofeAI/Tele-FLM` | BAAI & TeleAI's multilingual model, available in 52-billion and 1-trillion parameter variants. It is a decoder-only transformer trained on ~2T tokens | +| **Ling** (16.8B–290B) | `inclusionAI/Ling-lite`, `inclusionAI/Ling-plus` | InclusionAI’s open MoE models. Ling-Lite has 16.8B total / 2.75B active parameters, and Ling-Plus has 290B total / 28.8B active parameters. They are designed for high performance on NLP and complex reasoning tasks. | +| **Granite 3.0, 3.1** (IBM) | `ibm-granite/granite-3.1-8b-instruct` | IBM's open dense foundation models optimized for reasoning, code, and business AI use cases. Integrated with Red Hat and watsonx systems. | +| **Granite 3.0 MoE** (IBM) | `ibm-granite/granite-3.0-3b-a800m-instruct` | IBM’s Mixture-of-Experts models offering strong performance with cost-efficiency. MoE expert routing designed for enterprise deployment at scale. | +| **GPT-J** (6B) | `EleutherAI/gpt-j-6b` | EleutherAI's GPT-2-like causal language model (6B) trained on the [Pile](https://pile.eleuther.ai/) dataset. | +| **Orion** (14B) | `OrionStarAI/Orion-14B-Base` | A series of open-source multilingual large language models by OrionStarAI, pretrained on a 2.5T token multilingual corpus including Chinese, English, Japanese, Korean, etc, and it exhibits superior performance in these languages. 
| +| **Llama Nemotron Super** (v1, v1.5, NVIDIA) | `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, `nvidia/Llama-3_3-Nemotron-Super-49B-v1_5` | The [NVIDIA Nemotron](https://www.nvidia.com/en-us/ai-data-science/foundation-models/nemotron/) family of multimodal models provides state-of-the-art reasoning models specifically designed for enterprise-ready AI agents. | +| **Llama Nemotron Ultra** (v1, NVIDIA) | `nvidia/Llama-3_1-Nemotron-Ultra-253B-v1` | The [NVIDIA Nemotron](https://www.nvidia.com/en-us/ai-data-science/foundation-models/nemotron/) family of multimodal models provides state-of-the-art reasoning models specifically designed for enterprise-ready AI agents. | +| **NVIDIA Nemotron Nano 2.0** | `nvidia/NVIDIA-Nemotron-Nano-9B-v2` | The [NVIDIA Nemotron](https://www.nvidia.com/en-us/ai-data-science/foundation-models/nemotron/) family of multimodal models provides state-of-the-art reasoning models specifically designed for enterprise-ready AI agents. `Nemotron-Nano-9B-v2` is a hybrid Mamba-Transformer language model designed to increase throughput for reasoning workloads while achieving state-of-the-art accuracy compared to similarly-sized models. | +| **StarCoder2** (3B-15B) | `bigcode/starcoder2-7b` | StarCoder2 is a family of open large language models (LLMs) specialized for code generation and understanding. It is the successor to StarCoder, jointly developed by the BigCode project (a collaboration between Hugging Face, ServiceNow Research, and other contributors). | +| **Jet-Nemotron** | `jet-ai/Jet-Nemotron-2B` | Jet-Nemotron is a new family of hybrid-architecture language models that surpass state-of-the-art open-source full-attention language models, while achieving significant efficiency gains. | +| **Trinity** (Nano, Mini) | `arcee-ai/Trinity-Mini` | Arcee's foundational MoE Trinity family of models, open weights under Apache 2.0. 
| +| **Falcon-H1** (0.5B–34B) | `tiiuae/Falcon-H1-34B-Instruct` | TII's hybrid Mamba-Transformer architecture combining attention and state-space models for efficient long-context inference. | +| **Hunyuan-Large** (389B, MoE) | `tencent/Tencent-Hunyuan-Large` | Tencent's open-source MoE model with 389B total / 52B active parameters, featuring Cross-Layer Attention (CLA) for improved efficiency. | +| **IBM Granite 4.0 (Hybrid, Dense)** | `ibm-granite/granite-4.0-h-micro`, `ibm-granite/granite-4.0-micro` | IBM Granite 4.0 micro models: hybrid Mamba–MoE (`h-micro`) and dense (`micro`) variants. Enterprise-focused reasoning models | +| **Sarvam 2** (30B-A2B, 105B-A10B) | `sarvamai/sarvam-2` | Sarvam's Mixture-of-Experts models. The 105B variant uses MLA (Multi-head Latent Attention) and the 30B variant uses GQA, both with 128 routed experts. | diff --git a/sglang/docs/supported_models/text_generation/index.rst b/sglang/docs/supported_models/text_generation/index.rst new file mode 100644 index 0000000000000000000000000000000000000000..e315f83d1a057dc385de8006427d962f80bc81f6 --- /dev/null +++ b/sglang/docs/supported_models/text_generation/index.rst @@ -0,0 +1,11 @@ +Text Generation +=============== + +Models for generating text from text or multimodal inputs. + +.. toctree:: + :maxdepth: 1 + + generative_models.md + multimodal_language_models.md + diffusion_language_models.md diff --git a/sglang/docs/supported_models/text_generation/multimodal_language_models.md b/sglang/docs/supported_models/text_generation/multimodal_language_models.md new file mode 100644 index 0000000000000000000000000000000000000000..0dab3a28af2538cb00c953d78caa7ac6445f36b6 --- /dev/null +++ b/sglang/docs/supported_models/text_generation/multimodal_language_models.md @@ -0,0 +1,136 @@ +# Multimodal Language Models + +These models accept multi-modal inputs (e.g., images and text) and generate text output. They augment language models with multimodal encoders. 
+ +## Example launch Command + +```shell +python3 -m sglang.launch_server \ + --model-path meta-llama/Llama-3.2-11B-Vision-Instruct \ # example HF/local path + --host 0.0.0.0 \ + --port 30000 \ +``` + +> See the [OpenAI APIs section](https://docs.sglang.io/basic_usage/openai_api_vision.html) for how to send multimodal requests. + +## Supported models + +Below the supported models are summarized in a table. + +If you are unsure if a specific architecture is implemented, you can search for it via GitHub. For example, to search for `Qwen2_5_VLForConditionalGeneration`, use the expression: + +``` +repo:sgl-project/sglang path:/^python\/sglang\/srt\/models\// Qwen2_5_VLForConditionalGeneration +``` + +in the GitHub search bar. + + +| Model Family (Variants) | Example HuggingFace Identifier | Description | Notes | +|----------------------------|--------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------| +| **Qwen-VL** | `Qwen/Qwen3-VL-235B-A22B-Instruct` | Alibaba's vision-language extension of Qwen; for example, Qwen2.5-VL (7B and larger variants) can analyze and converse about image content. | | +| **DeepSeek-VL2** | `deepseek-ai/deepseek-vl2` | Vision-language variant of DeepSeek (with a dedicated image processor), enabling advanced multimodal reasoning on image and text inputs. | | +| **DeepSeek-OCR / OCR-2** | `deepseek-ai/DeepSeek-OCR-2` | OCR-focused DeepSeek models for document understanding and text extraction. | Use `--trust-remote-code`. | +| **Janus-Pro** (1B, 7B) | `deepseek-ai/Janus-Pro-7B` | DeepSeek's open-source multimodal model capable of both image understanding and generation. Janus-Pro employs a decoupled architecture for separate visual encoding paths, enhancing performance in both tasks. 
| | +| **MiniCPM-V / MiniCPM-o** | `openbmb/MiniCPM-V-2_6` | MiniCPM-V (2.6, ~8B) supports image inputs, and MiniCPM-o adds audio/video; these multimodal LLMs are optimized for end-side deployment on mobile/edge devices. | | +| **Llama 3.2 Vision** (11B) | `meta-llama/Llama-3.2-11B-Vision-Instruct` | Vision-enabled variant of Llama 3 (11B) that accepts image inputs for visual question answering and other multimodal tasks. | | +| **LLaVA** (v1.5 & v1.6) | *e.g.* `liuhaotian/llava-v1.5-13b` | Open vision-chat models that add an image encoder to LLaMA/Vicuna (e.g. LLaMA2 13B) for following multimodal instruction prompts. | | +| **LLaVA-NeXT** (8B, 72B) | `lmms-lab/llava-next-72b` | Improved LLaVA models (with an 8B Llama3 version and a 72B version) offering enhanced visual instruction-following and accuracy on multimodal benchmarks. | | +| **LLaVA-OneVision** | `lmms-lab/llava-onevision-qwen2-7b-ov` | Enhanced LLaVA variant integrating Qwen as the backbone; supports multiple images (and even video frames) as inputs via an OpenAI Vision API-compatible format. | | +| **Gemma 3 (Multimodal)** | `google/gemma-3-4b-it` | Gemma 3's larger models (4B, 12B, 27B) accept images (each image encoded as 256 tokens) alongside text in a combined 128K-token context. | | +| **Kimi-VL** (A3B) | `moonshotai/Kimi-VL-A3B-Instruct` | Kimi-VL is a multimodal model that can understand and generate text from images. | | +| **Mistral-Small-3.1-24B** | `mistralai/Mistral-Small-3.1-24B-Instruct-2503` | Mistral 3.1 is a multimodal model that can generate text from text or images input. It also supports tool calling and structured output. | | +| **Phi-4-multimodal-instruct** | `microsoft/Phi-4-multimodal-instruct` | Phi-4-multimodal-instruct is the multimodal variant of the Phi-4-mini model, enhanced with LoRA for improved multimodal capabilities. It supports text, vision and audio modalities in SGLang. 
| | +| **MiMo-VL** (7B) | `XiaomiMiMo/MiMo-VL-7B-RL` | Xiaomi's compact yet powerful vision-language model featuring a native resolution ViT encoder for fine-grained visual details, an MLP projector for cross-modal alignment, and the MiMo-7B language model optimized for complex reasoning tasks. | | +| **GLM-4.5V** (106B) / **GLM-4.1V** (9B) | `zai-org/GLM-4.5V` | GLM-4.5V and GLM-4.1V-Thinking: Towards Versatile Multimodal Reasoning with Scalable Reinforcement Learning | Use `--chat-template glm-4v` | +| **GLM-OCR** | `zai-org/GLM-OCR` | GLM-OCR: A fast and accurate general OCR model | | +| **DotsVLM** (General/OCR) | `rednote-hilab/dots.vlm1.inst` | RedNote's vision-language model built on a 1.2B vision encoder and DeepSeek V3 LLM, featuring NaViT vision encoder trained from scratch with dynamic resolution support and enhanced OCR capabilities through structured image data training. | | +| **DotsVLM-OCR** | `rednote-hilab/dots.ocr` | Specialized OCR variant of DotsVLM optimized for optical character recognition tasks with enhanced text extraction and document understanding capabilities. | Don't use `--trust-remote-code` | +| **NVILA** (8B, 15B, Lite-2B, Lite-8B, Lite-15B) | `Efficient-Large-Model/NVILA-8B` | NVILA explores the full stack efficiency of multi-modal design, achieving cheaper training, faster deployment and better performance. | Chat template: `chatml` | +| **NVIDIA Nemotron Nano 2.0 VL** | `nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16` | NVIDIA Nemotron Nano v2 VL enables multi-image reasoning and video understanding, along with strong document intelligence, visual Q&A and summarization capabilities. It builds on Nemotron Nano V2, a hybrid Mamba-Transformer LLM, in order to achieve higher inference throughput in long document and video scenarios. | Use `--trust-remote-code`. You may need to adjust `--max-mamba-cache-size` [default is 512] to fit memory constraints. | +| **Ernie4.5-VL** | `baidu/ERNIE-4.5-VL-28B-A3B-PT` | Baidu's vision-language models (28B, 424B). 
Support image and video comprehension, and also support thinking. | | +| **JetVLM** | | JetVLM is a vision-language model designed for high-performance multimodal understanding and generation tasks built upon Jet-Nemotron. | Coming soon | +| **Step3-VL** (10B) | `stepfun-ai/Step3-VL-10B` | StepFun's lightweight open-source 10B parameter VLM for multimodal intelligence, excelling in visual perception, complex reasoning, and human alignment. | | +| **Qwen3-Omni** | `Qwen/Qwen3-Omni-30B-A3B-Instruct` | Alibaba's omni-modal MoE model. Currently supports the **Thinker** component (multimodal understanding for text, images, audio, and video), while the **Talker** component (audio generation) is not yet supported. | | + +## Video Input Support + +SGLang supports video input for Vision-Language Models (VLMs), enabling temporal reasoning tasks such as video question answering, captioning, and holistic scene understanding. Video clips are decoded, key frames are sampled, and the resulting tensors are batched together with the text prompt, allowing multimodal inference to integrate visual and linguistic context. + +| Model Family | Example Identifier | Video notes | +|--------------|--------------------|-------------| +| **Qwen-VL** (Qwen2-VL, Qwen2.5-VL, Qwen3-VL, Qwen3-Omni) | `Qwen/Qwen3-VL-235B-A22B-Instruct` | The processor gathers `video_data`, runs Qwen's frame sampler, and merges the resulting features with text tokens before inference. | +| **GLM-4v** (4.5V, 4.1V, MOE) | `zai-org/GLM-4.5V` | Video clips are read with Decord, converted to tensors, and passed to the model alongside metadata for rotary-position handling. | +| **NVILA** (Full & Lite) | `Efficient-Large-Model/NVILA-8B` | The runtime samples eight frames per clip and attaches them to the multimodal request when `video_data` is present. 
| +| **LLaVA video variants** (LLaVA-NeXT-Video, LLaVA-OneVision) | `lmms-lab/LLaVA-NeXT-Video-7B` | The processor routes video prompts to the LlavaVid video-enabled architecture, and the provided example shows how to query it with `sgl.video(...)` clips. | +| **NVIDIA Nemotron Nano 2.0 VL** | `nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16` | The processor samples at 2 FPS, at a max of 128 frames, as per model training. The model uses [EVS](../../python/sglang/srt/multimodal/evs/README.md), a pruning method that removes redundant tokens from video embeddings. By default `video_pruning_rate=0.7`. Change this by providing: `--json-model-override-args '{"video_pruning_rate": 0.0}'` to disable EVS, for example. | +| **JetVLM** | | The runtime samples eight frames per clip and attaches them to the multimodal request when `video_data` is present. | + +Use `sgl.video(path, num_frames)` when building prompts to attach clips from your SGLang programs. + +Example OpenAI-compatible request that sends a video clip: + +```python +import requests + +url = "http://localhost:30000/v1/chat/completions" + +data = { + "model": "Qwen/Qwen3-VL-30B-A3B-Instruct", + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What’s happening in this video?"}, + { + "type": "video_url", + "video_url": { + "url": "https://github.com/sgl-project/sgl-test-files/raw/refs/heads/main/videos/jobs_presenting_ipod.mp4" + }, + }, + ], + } + ], + "max_tokens": 300, +} + +response = requests.post(url, json=data) +print(response.text) +``` + +## Usage Notes + +### Performance Optimization + +For multimodal models, you can use the `--keep-mm-feature-on-device` flag to optimize for latency at the cost of increased GPU memory usage: + +- **Default behavior**: Multimodal feature tensors are moved to CPU after processing to save GPU memory +- **With `--keep-mm-feature-on-device`**: Feature tensors remain on GPU, reducing device-to-host copy overhead and improving latency, but consuming more 
GPU memory + +Use this flag when you have sufficient GPU memory and want to minimize latency for multimodal inference. + +### Multimodal Inputs Limitation + +- **Use `--mm-process-config '{"image":{"max_pixels":1048576},"video":{"fps":3,"max_pixels":602112,"max_frames":60}}'`**: To set `image`, `video`, and `audio` input limits. + +This can reduce GPU memory usage, improve inference speed, and help to avoid OOM, but may impact model performance, thus set a proper value based on your specific use case. Currently, only `qwen_vl` supports this config. Please refer to [qwen_vl processor](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/multimodal/processors/qwen_vl.py) for understanding the meaning of each parameter. + +### Bidirectional Attention in Multimodal Model Serving +**Note for serving the Gemma-3 multimodal model**: + +As mentioned in [Welcome Gemma 3: Google's all new multimodal, multilingual, long context open LLM +](https://huggingface.co/blog/gemma3#multimodality), Gemma-3 employs bidirectional attention between image tokens during the prefill phase. Currently, SGLang only supports bidirectional attention when using the Triton Attention Backend. Note, however, that SGLang's current bidirectional attention implementation is incompatible with both CUDA Graph and Chunked Prefill. + +To enable bidirectional attention, you can use the `TritonAttnBackend` while disabling CUDA Graph and Chunked Prefill. 
Example launch command: +```shell +python -m sglang.launch_server \ + --model-path google/gemma-3-4b-it \ + --host 0.0.0.0 --port 30000 \ + --enable-multimodal \ + --dtype bfloat16 --triton-attention-reduce-in-fp32 \ + --attention-backend triton \ # Use Triton attention backend + --disable-cuda-graph \ # Disable Cuda Graph + --chunked-prefill-size -1 # Disable Chunked Prefill +``` + +If higher serving performance is required and a certain degree of accuracy loss is acceptable, you may choose to use other attention backends, and you can also enable features like CUDA Graph and Chunked Prefill for better performance, but note that the model will fall back to using causal attention instead of bidirectional attention. diff --git a/sglang/python/sglang/srt/__pycache__/constants.cpython-311.pyc b/sglang/python/sglang/srt/__pycache__/constants.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a760597764f0597c0ec6488ed7888857a5b4c9ae Binary files /dev/null and b/sglang/python/sglang/srt/__pycache__/constants.cpython-311.pyc differ diff --git a/sglang/python/sglang/srt/__pycache__/environ.cpython-311.pyc b/sglang/python/sglang/srt/__pycache__/environ.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3411d2f6324e7eacc1b05a64651d688114caf76 Binary files /dev/null and b/sglang/python/sglang/srt/__pycache__/environ.cpython-311.pyc differ diff --git a/sglang/python/sglang/srt/batch_overlap/__pycache__/operations.cpython-311.pyc b/sglang/python/sglang/srt/batch_overlap/__pycache__/operations.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..579d378a41f5735cf4717264f09bd75d1bcfb3e5 Binary files /dev/null and b/sglang/python/sglang/srt/batch_overlap/__pycache__/operations.cpython-311.pyc differ diff --git a/sglang/python/sglang/srt/batch_overlap/__pycache__/operations_strategy.cpython-311.pyc 
b/sglang/python/sglang/srt/batch_overlap/__pycache__/operations_strategy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01eddb87237558e906a9f1d9098cca529f4c16b9 Binary files /dev/null and b/sglang/python/sglang/srt/batch_overlap/__pycache__/operations_strategy.cpython-311.pyc differ diff --git a/sglang/python/sglang/srt/batch_overlap/__pycache__/single_batch_overlap.cpython-311.pyc b/sglang/python/sglang/srt/batch_overlap/__pycache__/single_batch_overlap.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d43beff6a24504cfb6af84246d810ea0b99226d Binary files /dev/null and b/sglang/python/sglang/srt/batch_overlap/__pycache__/single_batch_overlap.cpython-311.pyc differ diff --git a/sglang/python/sglang/srt/batch_overlap/__pycache__/two_batch_overlap.cpython-311.pyc b/sglang/python/sglang/srt/batch_overlap/__pycache__/two_batch_overlap.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..efcf37045c9f266728b3aaa91ab0f34ce121eb68 Binary files /dev/null and b/sglang/python/sglang/srt/batch_overlap/__pycache__/two_batch_overlap.cpython-311.pyc differ diff --git a/sglang/python/sglang/srt/batch_overlap/operations.py b/sglang/python/sglang/srt/batch_overlap/operations.py new file mode 100644 index 0000000000000000000000000000000000000000..3d61ac82f500d4d22e01fe34d58c499558dcf010 --- /dev/null +++ b/sglang/python/sglang/srt/batch_overlap/operations.py @@ -0,0 +1,213 @@ +from __future__ import annotations + +import os +from contextlib import contextmanager +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Sequence, Union + +import torch + +from sglang.srt.layers.dp_attention import set_dp_buffer_len + +if TYPE_CHECKING: + from sglang.srt.model_executor.forward_batch_info import ForwardBatch + +_ENABLE_PROFILE = bool(int(os.environ.get("SGLANG_OPERATIONS_ENABLE_PROFILE", "0"))) + +if _ENABLE_PROFILE: + import nvtx + + 
+def execute_operations(inputs, operations): + stages = _convert_operations_to_stages(operations) + executor = _StageExecutor("primary", stages, inputs=inputs) + for _ in range(executor.num_stages): + executor.next() + assert executor.done + return executor.output + + +def execute_overlapped_operations( + inputs_arr: Sequence, + operations_arr: Sequence, + delta_stages: Sequence[int], +) -> Sequence: + # Make it explicit for clarity; if we need multi-batch overlap, this can be generalized + inputs_a, inputs_b = inputs_arr + operations_a, operations_b = operations_arr + delta_stage_a, delta_stage_b = delta_stages + assert delta_stage_a == 0 + delta_stage = delta_stage_b + + stages_a = _convert_operations_to_stages(operations_a) + stages_b = _convert_operations_to_stages(operations_b) + executor_a = _StageExecutor("a", stages_a, inputs=inputs_a) + executor_b = _StageExecutor("b", stages_b, inputs=inputs_b) + + for _ in range(delta_stage): + executor_a.next() + + for _ in range(executor_a.num_stages - delta_stage): + executor_a.next() + executor_b.next() + + for _ in range(delta_stage): + executor_b.next() + + assert executor_a.done and executor_b.done + return [executor_a.output, executor_b.output] + + +class YieldOperation: + pass + + +@dataclass +class ExecutionOperation: + debug_name: str + fn: Callable + + +Operation = Union[YieldOperation, ExecutionOperation, Callable] +Stage = List[ExecutionOperation] + + +class _StageExecutor: + def __init__(self, debug_name: str, stages: List[Stage], inputs: dict): + self._debug_name = debug_name + self._stages = stages + self._index = 0 + self._stage_state = _StateDict() + self._stage_output = inputs + + # handling DP attention + forward_batch: ForwardBatch = inputs["forward_batch"] + self._global_dp_buffer_len = forward_batch.global_dp_buffer_len + self._local_dp_buffer_len = forward_batch.tbo_padded_len + self._global_num_tokens = forward_batch.global_num_tokens_cpu + self._is_dp_max_padding = 
forward_batch.dp_padding_mode.is_max_len() + + def next(self): + assert not self.done + + stage = self._stages[self._index] + + # TODO: We currently always call set_dp_buffer_len here because sub-batches + # may have different padded lengths. It can likely be removed after TBO slice & + # pad logic is refactored. + set_dp_buffer_len( + self._global_dp_buffer_len, + self._local_dp_buffer_len, + self._is_dp_max_padding, + self._global_num_tokens, + ) + + with _annotate_region(debug_name=f"{self._debug_name}{self._index}"): + for op in stage: + with _annotate_region(debug_name=op.debug_name): + self._stage_output = op.fn( + state=self._stage_state, + **( + self._stage_output if self._stage_output is not None else {} + ), + ) + + self._index += 1 + + @property + def output(self): + assert self.done + return self._stage_output + + @property + def done(self): + return self._index >= self.num_stages + + @property + def num_stages(self): + return len(self._stages) + + +@contextmanager +def _annotate_region(debug_name): + if _ENABLE_PROFILE: + with torch.autograd.profiler.record_function(debug_name): + with nvtx.annotate(debug_name): + yield + else: + yield + + +class _StateDict: + def __init__(self): + self._data = {} + + def __setattr__(self, key, value): + if key == "_data": + super().__setattr__(key, value) + return + assert ( + key not in self._data + ), f"`{key}` already exist, are you sure you want to override it?" + self._data[key] = value + + def __getattr__(self, item): + return self._data[item] + + def __delattr__(self, item): + del self._data[item] + + def pop(self, item): + return self._data.pop(item) + + def update(self, values: Dict[str, Any]): + for k, v in values.items(): + setattr(self, k, v) + + def get(self, item): + return self._data.get(item) + + def clear(self, expect_keys: Sequence[str]): + if set(self._data.keys()) != set(expect_keys): + raise Exception( + f"Unexpected keys when clearing. 
This may indicate you do not release memory early enough but leave it until here. {list(self._data.keys())=} {expect_keys=}" + ) + + self._data.clear() + + +def _convert_operations_to_stages(operations: List[Operation]) -> List[Stage]: + operations = _decorate_operations(operations) + operation_chunks = list( + _chunk_by_separator(operations, lambda op: isinstance(op, YieldOperation)) + ) + assert all(len(chunk) > 0 for chunk in operation_chunks) + return operation_chunks + + +def _chunk_by_separator( + items: List[Any], is_separator: Callable[[Any], bool] +) -> Generator[List[Any], None, None]: + pending_items = [] + for item in items: + if is_separator(item): + yield pending_items + pending_items = [] + else: + pending_items.append(item) + if len(pending_items) > 0: + yield pending_items + + +def _decorate_operations(operations: List[Operation], debug_name_prefix: str = ""): + return [_decorate_operation(op, debug_name_prefix) for op in operations] + + +def _decorate_operation(operation: Operation, debug_name_prefix: str): + if isinstance(operation, YieldOperation): + return operation + return ExecutionOperation( + debug_name=debug_name_prefix + + getattr(operation, "__name__", "unknown").replace("op_", ""), + fn=operation, + ) diff --git a/sglang/python/sglang/srt/batch_overlap/operations_strategy.py b/sglang/python/sglang/srt/batch_overlap/operations_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..d39ad838577d8d89e96d467be1875eff25c5434e --- /dev/null +++ b/sglang/python/sglang/srt/batch_overlap/operations_strategy.py @@ -0,0 +1,302 @@ +from dataclasses import dataclass +from typing import List, Optional + +import torch + +from sglang.srt.batch_overlap import operations +from sglang.srt.batch_overlap.operations import Operation +from sglang.srt.layers.moe.token_dispatcher import DeepEPConfig +from sglang.srt.model_executor.forward_batch_info import ForwardMode +from sglang.srt.utils import is_hip + +_is_hip = is_hip() + + 
+@dataclass +class OperationsStrategy: + operations: List[Operation] + deep_gemm_num_sms: Optional[int] = None + tbo_delta_stages: Optional[int] = None + + @classmethod + def concat(cls, items: List["OperationsStrategy"]) -> "OperationsStrategy": + return OperationsStrategy( + operations=[x for item in items for x in item.operations], + deep_gemm_num_sms=_assert_all_same( + [item.deep_gemm_num_sms for item in items] + ), + tbo_delta_stages=_assert_all_same( + [item.tbo_delta_stages for item in items] + ), + ) + + @staticmethod + def init_new_tbo( + layers: torch.nn.ModuleList, + forward_mode: ForwardMode, + ) -> "OperationsStrategy": + layer_name = layers[0].__class__.__name__ + if layer_name == "DeepseekV2DecoderLayer": + return OperationsStrategy.concat( + [ + _compute_moe_deepseek_layer_operations_strategy_tbo( + layer, forward_mode + ) + for layer in layers + ] + ) + elif layer_name == "Qwen3MoeDecoderLayer": + return OperationsStrategy.concat( + [ + _compute_moe_qwen3_layer_operations_strategy_tbo( + layer, forward_mode + ) + for layer in layers + ] + ) + elif layer_name == "MiMoV2DecoderLayer": + return OperationsStrategy.concat( + [ + _compute_moe_mimov2_layer_operations_strategy_tbo( + layer, forward_mode + ) + for layer in layers + ] + ) + else: + raise NotImplementedError + + +def _assert_all_same(items: List): + assert all(item == items[0] for item in items) + return items[0] + + +# -------------------------------- Strategy for DeepSeek --------------------------------------- + + +# TODO can refactor to make it more fancy if we have more complex strategies +def _compute_moe_deepseek_layer_operations_strategy_tbo( + layer: torch.nn.Module, + forward_mode: ForwardMode, +) -> OperationsStrategy: + assert layer.is_layer_sparse, "dense layer TBO not yet implemented" + if forward_mode == ForwardMode.EXTEND: + return _compute_moe_deepseek_blog_prefill(layer) + elif ( + forward_mode == ForwardMode.DECODE or forward_mode == ForwardMode.TARGET_VERIFY + ): + return 
_compute_moe_deepseek_blog_decode(layer) + else: + raise NotImplementedError(f"Unsupported {forward_mode=}") + + +def _compute_moe_deepseek_blog_prefill(layer): + device_properties = torch.cuda.get_device_properties(device="cuda") + total_num_sms = device_properties.multi_processor_count + deep_gemm_num_sms = None + if not _is_hip: + deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms + + return OperationsStrategy( + deep_gemm_num_sms=deep_gemm_num_sms, + tbo_delta_stages=0, + operations=[ + layer.op_comm_prepare_attn, + layer.self_attn.op_prepare, + layer.self_attn.op_core, + layer.op_comm_prepare_mlp, + layer.mlp.op_gate, + layer.mlp.op_select_experts, + layer.mlp.op_dispatch_a, + operations.YieldOperation(), + layer.mlp.op_dispatch_b, + layer.mlp.op_experts, + layer.mlp.op_combine_a, + operations.YieldOperation(), + layer.mlp.op_shared_experts, + layer.mlp.op_combine_b, + layer.mlp.op_output, + layer.op_comm_postprocess_layer, + ], + ) + + +def _compute_moe_deepseek_blog_decode(layer): + return OperationsStrategy( + deep_gemm_num_sms=None, + tbo_delta_stages=2, + operations=[ + layer.op_comm_prepare_attn, + layer.self_attn.op_prepare, + operations.YieldOperation(), + layer.self_attn.op_core, + layer.op_comm_prepare_mlp, + layer.mlp.op_gate, + layer.mlp.op_select_experts, + operations.YieldOperation(), + layer.mlp.op_dispatch_a, + layer.mlp.op_shared_experts, + operations.YieldOperation(), + layer.mlp.op_dispatch_b, + layer.mlp.op_experts, + layer.mlp.op_combine_a, + operations.YieldOperation(), + layer.mlp.op_combine_b, + operations.YieldOperation(), + layer.mlp.op_output, + layer.op_comm_postprocess_layer, + ], + ) + + +# -------------------------------- Strategy for Qwen3 --------------------------------------- + + +# TODO: unstable, current strategy is almost the same as DeepSeek, keep redundant code here for +# convenience to adjust strategy +def _compute_moe_qwen3_layer_operations_strategy_tbo( + layer: torch.nn.Module, + forward_mode: 
ForwardMode, +) -> OperationsStrategy: + assert layer.is_layer_sparse, "qwen3 moe only support sparse layers" + if forward_mode == ForwardMode.EXTEND: + return _compute_moe_qwen3_prefill(layer) + elif ( + forward_mode == ForwardMode.DECODE or forward_mode == ForwardMode.TARGET_VERIFY + ): + return _compute_moe_qwen3_decode(layer) + else: + raise NotImplementedError(f"Unsupported {forward_mode=}") + + +def _compute_moe_qwen3_prefill(layer): + device_properties = torch.cuda.get_device_properties(device="cuda") + total_num_sms = device_properties.multi_processor_count + deep_gemm_num_sms = None + if not _is_hip: + deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms + + return OperationsStrategy( + deep_gemm_num_sms=deep_gemm_num_sms, + tbo_delta_stages=0, + operations=[ + layer.op_comm_prepare_attn, + layer.self_attn.op_prepare, + layer.self_attn.op_core, + layer.op_comm_prepare_mlp, + layer.mlp.op_gate, + layer.mlp.op_select_experts, + layer.mlp.op_dispatch_a, + operations.YieldOperation(), + layer.mlp.op_dispatch_b, + layer.mlp.op_experts, + layer.mlp.op_combine_a, + operations.YieldOperation(), + layer.mlp.op_combine_b, + layer.mlp.op_output, + layer.op_comm_postprocess_layer, + ], + ) + + +def _compute_moe_qwen3_decode(layer): + return OperationsStrategy( + deep_gemm_num_sms=None, + tbo_delta_stages=2, + operations=[ + layer.op_comm_prepare_attn, + layer.self_attn.op_prepare, + operations.YieldOperation(), + layer.self_attn.op_core, + layer.op_comm_prepare_mlp, + layer.mlp.op_gate, + layer.mlp.op_select_experts, + operations.YieldOperation(), + layer.mlp.op_dispatch_a, + operations.YieldOperation(), + layer.mlp.op_dispatch_b, + layer.mlp.op_experts, + layer.mlp.op_combine_a, + operations.YieldOperation(), + layer.mlp.op_combine_b, + layer.mlp.op_output, + layer.op_comm_postprocess_layer, + operations.YieldOperation(), + ], + ) + + +# -------------------------------- Strategy for MiMoV2DecoderLayer --------------------------------------- + + +# 
# TODO: unstable; current strategy matches DeepSeek for the common operations (MiMoV2 has no op_shared_experts),
# so we keep this redundant code here for convenience when adjusting the strategy
def _compute_moe_mimov2_layer_operations_strategy_tbo(
    layer: torch.nn.Module,
    forward_mode: ForwardMode,
) -> OperationsStrategy:
    """Pick the MiMoV2 prefill or decode strategy for one sparse layer."""
    assert layer.is_layer_sparse, "MiMoV2DecoderLayer moe only support sparse layers"
    if forward_mode == ForwardMode.EXTEND:
        return _compute_moe_mimov2_prefill(layer)
    elif forward_mode in (ForwardMode.DECODE, ForwardMode.TARGET_VERIFY):
        return _compute_moe_mimov2_decode(layer)
    else:
        raise NotImplementedError(f"Unsupported {forward_mode=}")


def _compute_moe_mimov2_prefill(layer):
    """MiMoV2 prefill: three stages, no stage offset between the two batches."""
    device_properties = torch.cuda.get_device_properties(device="cuda")
    total_num_sms = device_properties.multi_processor_count
    # Same rule as the DeepSeek and Qwen3 prefill strategies in this module:
    # on HIP no DeepGEMM SM budget is derived from DeepEPConfig.
    deep_gemm_num_sms = None
    if not _is_hip:
        deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms

    return OperationsStrategy(
        deep_gemm_num_sms=deep_gemm_num_sms,
        tbo_delta_stages=0,
        operations=[
            layer.op_comm_prepare_attn,
            layer.self_attn.op_prepare,
            layer.self_attn.op_core,
            layer.op_comm_prepare_mlp,
            layer.mlp.op_gate,
            layer.mlp.op_select_experts,
            layer.mlp.op_dispatch_a,
            operations.YieldOperation(),
            layer.mlp.op_dispatch_b,
            layer.mlp.op_experts,
            layer.mlp.op_combine_a,
            operations.YieldOperation(),
            layer.mlp.op_combine_b,
            layer.mlp.op_output,
            layer.op_comm_postprocess_layer,
        ],
    )


def _compute_moe_mimov2_decode(layer):
    """MiMoV2 decode: five stages, second batch lags the first by two stages."""
    return OperationsStrategy(
        deep_gemm_num_sms=None,
        tbo_delta_stages=2,
        operations=[
            layer.op_comm_prepare_attn,
            layer.self_attn.op_prepare,
            operations.YieldOperation(),
            layer.self_attn.op_core,
            layer.op_comm_prepare_mlp,
            layer.mlp.op_gate,
            layer.mlp.op_select_experts,
            operations.YieldOperation(),
            layer.mlp.op_dispatch_a,
            operations.YieldOperation(),
            layer.mlp.op_dispatch_b,
            layer.mlp.op_experts,
            layer.mlp.op_combine_a,
            operations.YieldOperation(),
            layer.mlp.op_combine_b,
            layer.mlp.op_output,
            layer.op_comm_postprocess_layer,
            operations.YieldOperation(),
        ],
    )


# ---- definitions of sglang/srt/batch_overlap/single_batch_overlap.py ----


class SboFlags:
    """Predicates deciding which single-batch-overlap (SBO) variants are active."""

    # TODO may have: "enable_dispatch_gateup_gemm_two_stream_overlap", ...

    @classmethod
    def enable_combine_down_gemm_two_stream_overlap(cls):
        """True when combine may overlap with the down GEMM on a second stream."""
        return (
            is_sbo_enabled()
            # supported by the flashinfer cutedsl backend, and by deep_gemm when
            # not running on Blackwell
            and (
                get_moe_runner_backend().is_flashinfer_cutedsl()
                or (get_moe_runner_backend().is_deep_gemm() and not is_blackwell())
            )
        )

    @classmethod
    def enable_combine_shared_two_stream_overlap(cls):
        """True when combine may overlap with shared experts on a second stream."""
        return (
            is_sbo_enabled()
            and not cls.enable_dispatch_shared_one_stream_overlap()
            and not envs.SGLANG_BLACKWELL_OVERLAP_SHARED_EXPERTS_OUTSIDE_SBO.get()
        )

    @classmethod
    def enable_dispatch_shared_one_stream_overlap(cls):
        """True when SBO is enabled and the device is not Blackwell."""
        return is_sbo_enabled() and not is_blackwell()

    @classmethod
    def fuse_shared_experts_inside_sbo(cls):
        """True when either shared-expert overlap variant is active."""
        return (
            cls.enable_combine_shared_two_stream_overlap()
            or cls.enable_dispatch_shared_one_stream_overlap()
        )


@dataclass
class CombineOverlapArgs:
    """Arguments handed to the combine step when it runs on the alternate stream."""

    # this "overlap" flag means overlapping with down gemm, not the general two-stream overlap
    overlap: bool
    stream: torch.cuda.Stream
    wait_event: torch.cuda.Event
    num_sms: Optional[int] = None
    signal: Optional[torch.Tensor] = None
    block_m: Optional[int] = 64
    threshold: Optional[int] = 0


@dataclass
class DownGemmOverlapArgs:
    """Arguments handed to the down GEMM when it overlaps with combine."""

    num_sms: int
    signal: torch.Tensor
    start_event: torch.cuda.Event


def compute_overlap_args(dispatch_output, alt_stream):
    """Build the overlap arguments for one MoE forward pass.

    Returns:
        ``(combine_overlap_args, down_gemm_overlap_args, meta_overlap_args)``.
        When neither combine overlap variant is enabled this is
        ``(None, None, {})``. ``down_gemm_overlap_args`` is ``None`` unless the
        combine/down-GEMM overlap is enabled; otherwise ``meta_overlap_args``
        carries ``record_event_after_down`` instead.
    """
    if not (
        SboFlags.enable_combine_down_gemm_two_stream_overlap()
        or SboFlags.enable_combine_shared_two_stream_overlap()
    ):
        return None, None, {}

    hidden_states = dispatch_output.hidden_states

    # The unpacking also enforces that hidden_states is 3-dimensional.
    num_local_experts, num_tokens_static, _hidden_dim = hidden_states.shape

    total_num_sms = torch.cuda.get_device_properties(
        device="cuda"
    ).multi_processor_count

    # SMs reserved for communication: env override, else a per-architecture default.
    if envs.SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS.is_set():
        communicate_num_sms = envs.SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS.get()
    else:
        communicate_num_sms = 32 if is_blackwell() else 3
    compute_num_sms = total_num_sms - communicate_num_sms

    assert alt_stream is not None
    combine_wait_event = torch.cuda.Event()
    combine_overlap_args = CombineOverlapArgs(
        overlap=False,
        num_sms=communicate_num_sms,
        stream=alt_stream,
        wait_event=combine_wait_event,
    )
    meta_overlap_args = dict(
        compute_num_sms=compute_num_sms,
    )
    down_gemm_overlap_args = None

    if SboFlags.enable_combine_down_gemm_two_stream_overlap():
        # TODO use zero_allocator to remove this `torch.zeros` call
        # NOTE ours v2 use uint32 not int32 currently
        if is_blackwell():
            combine_signal = torch.zeros(
                num_local_experts, dtype=torch.uint32, device=hidden_states.device
            )
        else:
            # One signal slot per (expert, block of MIN_BLOCK_M tokens), rounded up.
            MIN_BLOCK_M = 64
            combine_signal_size = num_local_experts * (
                (num_tokens_static + MIN_BLOCK_M - 1) // MIN_BLOCK_M
            )
            combine_signal = torch.zeros(
                combine_signal_size, dtype=torch.int32, device=hidden_states.device
            )

        down_gemm_overlap_args = DownGemmOverlapArgs(
            signal=combine_signal,
            start_event=combine_wait_event,
            num_sms=compute_num_sms,
        )
        combine_overlap_args.overlap = True
        combine_overlap_args.signal = combine_signal
        combine_overlap_args.threshold = compute_num_sms
    else:
        meta_overlap_args |= dict(
            record_event_after_down=combine_wait_event,
        )

    return combine_overlap_args, down_gemm_overlap_args, meta_overlap_args
+from sglang.srt.layers import deep_gemm_wrapper +from sglang.srt.layers.attention.base_attn_backend import AttentionBackend +from sglang.srt.layers.communicator import ( + CommunicateContext, + CommunicateSummableTensorPairFn, + ScatterMode, +) +from sglang.srt.layers.dp_attention import get_attention_tp_size +from sglang.srt.layers.moe import ( + get_deepep_mode, + get_moe_a2a_backend, + get_tbo_token_distribution_threshold, + is_tbo_enabled, +) +from sglang.srt.layers.moe.token_dispatcher import ( + DeepEPDispatcher, + MooncakeEPDispatcher, + MoriEPDispatcher, +) +from sglang.srt.layers.moe.token_dispatcher.base import BaseDispatcher +from sglang.srt.managers.schedule_batch import ScheduleBatch +from sglang.srt.model_executor.forward_batch_info import ( + ForwardBatch, + ForwardMode, + compute_position, +) +from sglang.srt.server_args import get_global_server_args +from sglang.srt.speculative.spec_info import SpecInput +from sglang.srt.utils import BumpAllocator, empty_context, get_bool_env_var, is_hip + +if TYPE_CHECKING: + from sglang.srt.batch_overlap.single_batch_overlap import CombineOverlapArgs + from sglang.srt.layers.moe.token_dispatcher import DispatchOutput + from sglang.srt.speculative.eagle_info import EagleVerifyInput + +_is_hip = is_hip() + +_tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG") + +logger = logging.getLogger(__name__) + + +# -------------------------------- Compute Basic Info --------------------------------------- + + +def get_token_num_per_seq( + forward_mode: ForwardMode, + spec_info: Optional[SpecInput] = None, +): + if forward_mode.is_target_verify(): + return spec_info.draft_token_num + elif forward_mode.is_decode(): + return 1 + elif forward_mode.is_idle(): + return 0 + else: + # For extend, we should not use `token_num_per_seq`. 
+ return None + + +# TODO: may smartly disable TBO when batch size is too small b/c it will slow down +def compute_split_seq_index( + forward_mode: ForwardMode, + num_tokens: int, + extend_lens: Optional[Sequence[int]], + token_num_per_seq: Optional[int], +) -> Optional[int]: + if forward_mode == ForwardMode.EXTEND: + assert extend_lens is not None + return _split_extend_seqs(extend_lens) + elif forward_mode.is_target_verify() or forward_mode.is_decode(): + assert token_num_per_seq is not None + return (num_tokens // token_num_per_seq) // 2 + elif forward_mode.is_idle() or forward_mode.is_prebuilt(): + assert num_tokens == 0 + return 0 + else: + raise NotImplementedError() + + +def _is_two_chunk_split_enabled(extend_lens: Sequence[int]) -> bool: + if extend_lens is None: + return False + + vanilla_split_seq_index = _split_array_by_balanced_sum(extend_lens) + left_sum = sum(extend_lens[:vanilla_split_seq_index]) + overall_sum = sum(extend_lens) + threshold = get_tbo_token_distribution_threshold() + assert threshold <= 0.5, f"{threshold=}" + return left_sum < overall_sum * threshold or left_sum > overall_sum * ( + 1 - threshold + ) + + +def _split_extend_seqs(arr: Sequence[int]) -> int: + if _is_two_chunk_split_enabled(arr): + return _split_array_by_cum_less_than_half(arr) + + return _split_array_by_balanced_sum(arr) + + +def _split_array_by_cum_less_than_half(arr: Sequence[int]) -> int: + left_sum = 0 + overall_sum = sum(arr) + half_sum = overall_sum // 2 + chosen_index = 0 + + for i in range(len(arr)): + left_sum += arr[i] + if left_sum > half_sum: + chosen_index = i + break + + return chosen_index + + +def _split_array_by_balanced_sum(arr: Sequence[int]) -> int: + overall_sum = sum(arr) + left_sum = 0 + min_diff = float("inf") + best_index = 0 + + for i in range(1, len(arr)): + left_sum += arr[i - 1] + right_sum = overall_sum - left_sum + diff = abs(left_sum - right_sum) + if diff <= min_diff: + min_diff = diff + best_index = i + else: + break + + return 
best_index + + +def _update_device_and_sum_field_from_cpu_field( + batch: ForwardBatch, cpu_field: str, device_field: str, sum_field: str = None +): + cpu_value = getattr(batch, cpu_field, None) + old_device_value = getattr(batch, device_field, None) + if ( + cpu_value is None + or old_device_value is None + or not (isinstance(cpu_value, torch.Tensor) or isinstance(cpu_value, list)) + ): + return + + new_device_value = ( + cpu_value + if isinstance(cpu_value, torch.Tensor) + else torch.tensor(cpu_value, dtype=old_device_value.dtype) + ).to(device=get_global_server_args().device, non_blocking=True) + setattr(batch, device_field, new_device_value) + + if sum_field is not None: + sum_value = ( + cpu_value.sum().item() + if isinstance(cpu_value, torch.Tensor) + else sum(cpu_value) + ) + setattr(batch, sum_field, sum_value) + + +def _compute_mask_offset(seq_index: int, spec_info: Optional[EagleVerifyInput]) -> int: + if seq_index == 0: + return 0 + + offset = 0 + max_seq_len = min(seq_index, spec_info.seq_lens_cpu.shape[0]) + for i in range(max_seq_len): + offset += ( + spec_info.seq_lens_cpu[i] + spec_info.draft_token_num + ) * spec_info.draft_token_num + return offset + + +def split_spec_info( + spec_info: Optional[EagleVerifyInput], + start_seq_index: int, + end_seq_index: int, + start_token_index: int, + end_token_index: int, +): + if spec_info is None: + return None + if spec_info.draft_token is not None: + draft_token = spec_info.draft_token[start_token_index:end_token_index] + else: + draft_token = None + if spec_info.custom_mask is not None and spec_info.draft_token is not None: + custom_mask_start = _compute_mask_offset(start_seq_index, spec_info) + if end_seq_index == spec_info.seq_lens_cpu.shape[0]: + custom_mask_end = spec_info.custom_mask.shape[0] + else: + custom_mask_end = _compute_mask_offset(end_seq_index, spec_info) + + if custom_mask_end > custom_mask_start: + custom_mask = spec_info.custom_mask[custom_mask_start:custom_mask_end] + else: + custom_mask 
= spec_info.custom_mask + else: + custom_mask = spec_info.custom_mask + if spec_info.positions is not None: + positions = spec_info.positions[start_token_index:end_token_index] + else: + positions = None + if spec_info.retrive_index is not None: + retrive_index = spec_info.retrive_index[start_seq_index:end_seq_index] + else: + retrive_index = None + if spec_info.retrive_next_token is not None: + retrive_next_token = spec_info.retrive_next_token[start_seq_index:end_seq_index] + else: + retrive_next_token = None + if spec_info.retrive_next_sibling is not None: + retrive_next_sibling = spec_info.retrive_next_sibling[ + start_seq_index:end_seq_index + ] + else: + retrive_next_sibling = None + if spec_info.retrive_cum_len is not None: + retrive_cum_len = spec_info.retrive_cum_len[start_seq_index:end_seq_index] + else: + retrive_cum_len = None + + if spec_info.seq_lens_cpu is not None: + seq_lens_cpu = spec_info.seq_lens_cpu[start_seq_index:end_seq_index] + else: + seq_lens_cpu = None + if seq_lens_cpu is not None: + seq_lens_sum = seq_lens_cpu.sum() + else: + seq_lens_sum = None + output_spec_info = replace( + spec_info, + custom_mask=custom_mask, + draft_token=draft_token, + positions=positions, + retrive_index=retrive_index, + retrive_next_token=retrive_next_token, + retrive_next_sibling=retrive_next_sibling, + retrive_cum_len=retrive_cum_len, + seq_lens_cpu=seq_lens_cpu, + seq_lens_sum=seq_lens_sum, + ) + return output_spec_info + + +def compute_split_token_index( + split_seq_index: int, + forward_mode: "ForwardMode", + extend_seq_lens: Optional[Sequence[int]], + token_num_per_seq: Optional[int], +) -> int: + if forward_mode == ForwardMode.EXTEND: + assert extend_seq_lens is not None + if _is_two_chunk_split_enabled(extend_seq_lens): + return sum(extend_seq_lens) // 2 + return sum(extend_seq_lens[:split_seq_index]) + elif forward_mode.is_target_verify() or forward_mode.is_decode(): + assert token_num_per_seq is not None + return split_seq_index * token_num_per_seq + 
elif forward_mode.is_idle(): + assert split_seq_index == 0 + return 0 + else: + raise NotImplementedError + + +def compute_split_indices_for_cuda_graph_replay( + forward_mode: ForwardMode, + cuda_graph_num_tokens: int, + spec_info: Optional[SpecInput], +): + forward_mode_for_tbo_split = ( + forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE + ) + token_num_per_seq = get_token_num_per_seq( + forward_mode=forward_mode, spec_info=spec_info + ) + tbo_split_seq_index = compute_split_seq_index( + forward_mode=forward_mode_for_tbo_split, + num_tokens=cuda_graph_num_tokens, + extend_lens=None, + token_num_per_seq=token_num_per_seq, + ) + tbo_split_token_index = compute_split_token_index( + split_seq_index=tbo_split_seq_index, + forward_mode=forward_mode_for_tbo_split, + extend_seq_lens=None, + token_num_per_seq=token_num_per_seq, + ) + return tbo_split_seq_index, tbo_split_token_index + + +# -------------------------------- Preparation --------------------------------------- + + +class TboCudaGraphRunnerPlugin: + def __init__(self): + self._tbo_children_num_token_non_padded = torch.zeros((2,), dtype=torch.int32) + + def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int): + if not is_tbo_enabled(): + return + token_num_per_seq = get_token_num_per_seq( + forward_mode=batch.forward_mode, spec_info=batch.spec_info + ) + + batch.tbo_split_seq_index = compute_split_seq_index( + forward_mode=batch.forward_mode, + num_tokens=num_tokens, + extend_lens=None, + token_num_per_seq=token_num_per_seq, + ) + # For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true + assert batch.tbo_split_seq_index is not None, f"{num_tokens=}" + + self._tbo_children_num_token_non_padded[...] 
= ( + TboForwardBatchPreparer.compute_tbo_children_num_token_non_padded(batch) + ) + + TboForwardBatchPreparer.prepare_raw( + batch, + tbo_children_num_token_non_padded=self._tbo_children_num_token_non_padded, + ) + + def replay_prepare( + self, + forward_mode: ForwardMode, + bs: int, + num_token_non_padded: int, + spec_info: Optional[SpecInput], + ): + token_num_per_seq = get_token_num_per_seq( + forward_mode=forward_mode, spec_info=spec_info + ) + tbo_split_seq_index, tbo_split_token_index = ( + compute_split_indices_for_cuda_graph_replay( + forward_mode=forward_mode, + cuda_graph_num_tokens=bs * token_num_per_seq, + spec_info=spec_info, + ) + ) + + self._tbo_children_num_token_non_padded[...] = ( + TboForwardBatchPreparer.compute_tbo_children_num_token_non_padded_raw( + tbo_split_token_index=tbo_split_token_index, + num_token_non_padded=num_token_non_padded, + ) + ) + + +class TboDPAttentionPreparer: + def prepare_all_gather( + self, + local_batch: ScheduleBatch, + ): + + deepep_mode = get_deepep_mode() + enable_a2a_moe = not get_moe_a2a_backend().is_none() + enable_two_batch_overlap = is_tbo_enabled() + + self.enable_two_batch_overlap = enable_two_batch_overlap + + if local_batch is not None: + token_num_per_seq = get_token_num_per_seq( + forward_mode=local_batch.forward_mode, spec_info=local_batch.spec_info + ) + + if ( + local_batch.forward_mode.is_target_verify() + or local_batch.forward_mode.is_decode() + ): + num_tokens = local_batch.batch_size() * token_num_per_seq + elif local_batch.forward_mode.is_prebuilt(): + num_tokens = 0 + else: + num_tokens = local_batch.extend_num_tokens + self.local_tbo_split_seq_index = compute_split_seq_index( + forward_mode=local_batch.forward_mode, + num_tokens=num_tokens, + extend_lens=local_batch.extend_lens, + token_num_per_seq=token_num_per_seq, + ) + resolved_deepep_mode = deepep_mode.resolve(local_batch.is_extend_in_batch) + local_can_run_tbo = (self.local_tbo_split_seq_index is not None) and not ( + ( + 
local_batch.forward_mode.is_extend() + and not local_batch.forward_mode.is_target_verify() + ) + and enable_a2a_moe + and (resolved_deepep_mode.is_low_latency()) + ) + else: + self.local_tbo_split_seq_index = 0 + local_can_run_tbo = True + + local_forward_mode = self._compute_local_forward_mode(local_batch) + + return local_can_run_tbo, local_forward_mode + + def compute_output(self, partial_global_info): + # Perform only one Device-to-Host (D2H) memory copy + cpu_data = partial_global_info[:, :2].cpu() + local_can_run_tbo_aggregated = min(cpu_data[:, 0].tolist()) + forward_modes = cpu_data[:, 1].tolist() + + global_forward_mode, forward_mode_agree = self._compute_global_forward_mode( + forward_modes + ) + + can_run_tbo = ( + self.enable_two_batch_overlap + and local_can_run_tbo_aggregated + and forward_mode_agree + ) + + tbo_split_seq_index = self.local_tbo_split_seq_index if can_run_tbo else None + global_forward_mode = global_forward_mode if can_run_tbo else None + return tbo_split_seq_index, global_forward_mode + + @staticmethod + def _compute_local_forward_mode(local_batch): + return ( + local_batch.forward_mode if local_batch is not None else ForwardMode.IDLE + ).value + + @staticmethod + def _compute_global_forward_mode(forward_modes): + forward_modes_excluding_idle_and_prebuilt = [ + x + for x in forward_modes + if x != ForwardMode.IDLE.value and x != ForwardMode.PREBUILT.value + ] + + if not forward_modes_excluding_idle_and_prebuilt: + return ForwardMode.IDLE, False + + forward_mode_agree = TboDPAttentionPreparer._is_all_same( + forward_modes_excluding_idle_and_prebuilt + ) + + global_forward_mode = ( + ForwardMode(forward_modes_excluding_idle_and_prebuilt[0]) + if forward_mode_agree + else None + ) + return global_forward_mode, forward_mode_agree + + @staticmethod + def _is_all_same(x): + return all(value == x[0] for value in x) + + +class TboForwardBatchPreparer: + @classmethod + def prepare(cls, batch: ForwardBatch, is_draft_worker: bool = False): + if 
batch.tbo_split_seq_index is None or is_draft_worker: + return + + tbo_children_num_token_non_padded = ( + cls.compute_tbo_children_num_token_non_padded(batch) + ) + cls.prepare_raw( + batch, tbo_children_num_token_non_padded=tbo_children_num_token_non_padded + ) + + @classmethod + def prepare_raw( + cls, batch: ForwardBatch, tbo_children_num_token_non_padded: torch.Tensor + ): + from sglang.srt.layers.attention.tbo_backend import TboAttnBackend + + tbo_split_token_index = cls._compute_split_token_index(batch) + + is_enable_two_chunk = ( + batch.forward_mode == ForwardMode.EXTEND + and _is_two_chunk_split_enabled(batch.extend_seq_lens_cpu) + ) + + if _tbo_debug: + logger.info( + f"TboForwardBatchPreparer.prepare " + f"is_enable_two_chunk={is_enable_two_chunk} " + f"tbo_split_seq_index={batch.tbo_split_seq_index} " + f"tbo_split_token_index={tbo_split_token_index} " + f"extend_seq_lens={batch.extend_seq_lens_cpu} " + f"bs={batch.batch_size} " + f"forward_mode={batch.forward_mode}" + ) + + assert isinstance(batch.attn_backend, TboAttnBackend) + attn_backend_child_a, attn_backend_child_b = batch.attn_backend.children + + [out_num_token_non_padded_a, out_num_token_non_padded_b] = ( + tbo_children_num_token_non_padded + ) + + child_a = cls.filter_batch( + batch, + start_token_index=0, + end_token_index=tbo_split_token_index, + start_seq_index=0, + end_seq_index=( + batch.tbo_split_seq_index + 1 + if is_enable_two_chunk + else batch.tbo_split_seq_index + ), + output_attn_backend=attn_backend_child_a, + out_num_token_non_padded=out_num_token_non_padded_a, + ) + child_b = cls.filter_batch( + batch, + start_token_index=tbo_split_token_index, + end_token_index=batch.input_ids.shape[0], + start_seq_index=batch.tbo_split_seq_index, + end_seq_index=batch.batch_size, + output_attn_backend=attn_backend_child_b, + out_num_token_non_padded=out_num_token_non_padded_b, + ) + + if is_enable_two_chunk: + cls.derive_fields_related_to_seq_len_for_two_chunk( + batch, + child_a=child_a, + 
@classmethod
def derive_fields_related_to_seq_len_for_two_chunk(
    cls,
    batch: ForwardBatch,
    *,
    child_a: ForwardBatch,
    child_b: ForwardBatch,
    tbo_split_seq_index: int,
):
    """Fix up the children's length fields when one sequence is cut in two.

    The extend length of the sequence at ``tbo_split_seq_index`` is divided
    into two parts. The first part overwrites the last entry of
    ``child_a.extend_seq_lens_cpu`` and the second part overwrites the first
    entry of ``child_b.extend_seq_lens_cpu``. The split point is chosen so
    that ``child_a`` totals ``sum(batch.extend_seq_lens_cpu) // 2`` extend
    tokens (asserted below).

    Both children are modified in place; nothing is returned.

    Args:
        batch: Parent batch; only ``extend_seq_lens_cpu`` is read from it.
        child_a: First child; its extend/seq length fields are rewritten.
        child_b: Second child; its extend length, prefix length and
            ``extend_start_loc`` fields are rewritten.
        tbo_split_seq_index: Index (in the parent) of the sequence being cut.
    """
    extend_seq_lens_cpu = batch.extend_seq_lens_cpu
    overall_seq_lens_sum = sum(extend_seq_lens_cpu)
    # Token budget for child_a: half of all extend tokens, rounded down.
    half_seq_lens_sum = overall_seq_lens_sum // 2
    # Tokens of the cut sequence that go to child_a: the budget minus the
    # tokens of all sequences that precede the cut sequence.
    left_last_seq_token_num = half_seq_lens_sum - sum(
        extend_seq_lens_cpu[:tbo_split_seq_index]
    )
    # The rest of the cut sequence goes to child_b.
    right_first_seq_token_num = (
        extend_seq_lens_cpu[tbo_split_seq_index] - left_last_seq_token_num
    )

    # making deepcopy to be extra safe
    child_a.extend_seq_lens_cpu = copy.deepcopy(child_a.extend_seq_lens_cpu)
    child_a.extend_seq_lens_cpu[-1] = left_last_seq_token_num
    child_b.extend_seq_lens_cpu = copy.deepcopy(child_b.extend_seq_lens_cpu)
    child_b.extend_seq_lens_cpu[0] = right_first_seq_token_num
    # NOTE(review): the helper is defined elsewhere; presumably it refreshes
    # `device_field` and `sum_field` from the edited `cpu_field` -- confirm.
    for child in [child_a, child_b]:
        _update_device_and_sum_field_from_cpu_field(
            batch=child,
            cpu_field="extend_seq_lens_cpu",
            device_field="extend_seq_lens",
            sum_field="extend_num_tokens",
        )

    # Sanity check: child_a must hold exactly the half-budget computed above.
    assert (
        child_a.extend_num_tokens == half_seq_lens_sum
    ), f"{child_a.extend_num_tokens=}, {half_seq_lens_sum=}"

    # child_a's last sequence length becomes its prefix length plus the
    # shortened extend length.
    child_a.seq_lens_cpu = copy.deepcopy(child_a.seq_lens_cpu)
    child_a.seq_lens_cpu[-1] = (
        child_a.extend_seq_lens_cpu[-1] + child_a.extend_prefix_lens_cpu[-1]
    )
    _update_device_and_sum_field_from_cpu_field(
        batch=child_a,
        cpu_field="seq_lens_cpu",
        device_field="seq_lens",
        sum_field="seq_lens_sum",
    )

    # child_b's first sequence gains, as extra prefix, the tokens assigned to
    # child_a.
    child_b.extend_prefix_lens_cpu = copy.deepcopy(child_b.extend_prefix_lens_cpu)
    child_b.extend_prefix_lens_cpu[0] += left_last_seq_token_num
    _update_device_and_sum_field_from_cpu_field(
        batch=child_b,
        cpu_field="extend_prefix_lens_cpu",
        device_field="extend_prefix_lens",
        sum_field=None,
    )
    # Recompute child_b's start locations from its updated prefix/extend
    # lengths; the first element returned by compute_position is discarded.
    _, child_b.extend_start_loc = compute_position(
        get_global_server_args().attention_backend,
        child_b.extend_prefix_lens,
        child_b.extend_seq_lens,
        child_b.extend_num_tokens,
    )
end_token_index=end_token_index, + start_seq_index=start_seq_index, + end_seq_index=end_seq_index, + ) + output_dict["spec_info"] = output_spec_info + for key in [ + "forward_mode", + "is_extend_in_batch", + "all_extend_in_batch", + "return_logprob", + "req_to_token_pool", + "token_to_kv_pool", + "can_run_dp_cuda_graph", + "dp_padding_mode", + "global_forward_mode", + "is_prefill_only", + "spec_algorithm", + "capture_hidden_mode", + "padded_static_len", + "mrope_positions", # only used by qwen2-vl, thus not care + "split_index", # for split prefill + "orig_seq_lens", # only used by qwen-1m, thus not care + ]: + output_dict[key] = getattr(batch, key) + if not batch.forward_mode.is_target_verify(): + assert ( + _compute_extend_num_tokens(batch.input_ids, batch.forward_mode) + == batch.extend_num_tokens + ), f"{batch=}" + extend_num_tokens = _compute_extend_num_tokens( + output_dict["input_ids"], output_dict["forward_mode"] + ) + + # TODO improve, e.g. unify w/ `init_raw` + if ( + get_global_server_args().moe_dense_tp_size == 1 + and batch.global_dp_buffer_len is not None + ): + sum_len = end_token_index - start_token_index + global_dp_buffer_len = sum_len + else: + global_dp_buffer_len = None + + output_dict.update( + dict( + batch_size=end_seq_index - start_seq_index, + seq_lens_sum=( + output_dict["seq_lens_cpu"].sum() + if "seq_lens_cpu" in output_dict + else None + ), + extend_num_tokens=extend_num_tokens, + attn_backend=output_attn_backend, + num_token_non_padded=out_num_token_non_padded, + # TODO: handle it when we need TBO + DeepSeek V3.2 + num_token_non_padded_cpu=None, + tbo_split_seq_index=None, + tbo_parent_token_range=(start_token_index, end_token_index), + tbo_children=None, + original_global_num_tokens_cpu=None, + global_num_tokens_gpu=None, + global_num_tokens_cpu=None, + global_dp_buffer_len=global_dp_buffer_len, + global_num_tokens_for_logprob_gpu=None, + global_num_tokens_for_logprob_cpu=None, + sampling_info=None, + # For logits and logprobs post 
@classmethod
def compute_tbo_children_num_token_non_padded_raw(
    cls, tbo_split_token_index: int, num_token_non_padded: int
):
    """Split a non-padded token count between the two TBO children.

    Args:
        tbo_split_token_index: Token index at which the parent batch is cut.
        num_token_non_padded: Number of non-padded tokens in the parent.

    Returns:
        An ``int32`` tensor ``[count_a, count_b]`` placed on the device named
        by the global server args. The first child receives at most
        ``tbo_split_token_index`` tokens and the second child the remainder,
        never less than zero.
    """
    # TODO we may make padding on both sub-batches to make it slightly more balanced
    tokens_for_a = min(tbo_split_token_index, num_token_non_padded)
    tokens_for_b = max(0, num_token_non_padded - tbo_split_token_index)
    host_counts = torch.tensor([tokens_for_a, tokens_for_b], dtype=torch.int32)
    target_device = get_global_server_args().device
    return host_counts.to(device=target_device, non_blocking=True)
def model_forward_maybe_tbo(
    layers,
    enable_tbo: bool,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    hidden_states: torch.Tensor,
    input_data_scatter_mode: ScatterMode,
    residual: Optional[torch.Tensor],
    zero_allocator: Optional[BumpAllocator] = None,
):
    """Run ``layers`` either with two-batch overlap or as a single batch.

    Args:
        layers: Layers to execute; the first layer supplies the layer input
            scatter mode.
        enable_tbo: When truthy, use the two-batch-overlap path.
        positions: Position tensor passed through to the layers.
        forward_batch: Batch metadata; its ``global_forward_mode`` selects the
            operations strategy.
        hidden_states: Input hidden states.
        input_data_scatter_mode: Scatter mode of the incoming tensors (used by
            the overlap path only).
        residual: Optional residual tensor.
        zero_allocator: Optional allocator forwarded to the layers.

    Returns:
        Whatever the selected forward helper returns.
    """
    layer_inputs = {
        "positions": positions,
        "hidden_states": hidden_states,
        "forward_batch": forward_batch,
        "residual": residual,
        "zero_allocator": zero_allocator,
    }
    first_layer_input_mode = layers[0].layer_scatter_modes.layer_input_mode
    strategy = OperationsStrategy.init_new_tbo(
        layers, forward_batch.global_forward_mode
    )

    if not enable_tbo:
        return _model_forward_non_tbo(layer_inputs, strategy)

    return _model_forward_tbo(
        inputs=layer_inputs,
        operations_strategy=strategy,
        input_data_scatter_mode=input_data_scatter_mode,
        layer_input_scatter_mode=first_layer_input_mode,
    )
def _model_forward_tbo_split_inputs(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    zero_allocator: Optional[BumpAllocator],
    input_data_scatter_mode: ScatterMode,
    layer_input_scatter_mode: ScatterMode,
) -> List[Dict]:
    """Build the per-child input dicts for the two-batch-overlap forward.

    The tensors are first converted from ``input_data_scatter_mode`` to
    ``ScatterMode.TP_ATTN_FULL`` and split per child. Each child's
    ``hidden_states``/``residual`` pair is then converted to
    ``layer_input_scatter_mode``.

    Returns:
        One dict per child, holding ``hidden_states``, ``residual``,
        ``forward_batch`` and the other entries produced by the raw split.
    """
    split_mode = ScatterMode.TP_ATTN_FULL
    comm_context = CommunicateContext.init_new()

    # Step 1: bring the whole batch into the layout the splitter works on.
    hidden_states, residual = CommunicateSummableTensorPairFn.execute(
        hidden_states_input_mode=input_data_scatter_mode,
        residual_input_mode=input_data_scatter_mode,
        output_mode=split_mode,
        hidden_states=hidden_states,
        residual=residual,
        forward_batch=forward_batch,
        context=comm_context,
    )

    # Step 2: cut the batch into per-child inputs.
    raw_children_inputs = _model_forward_tbo_split_inputs_raw(
        hidden_states=hidden_states,
        residual=residual,
        positions=positions,
        forward_batch=forward_batch,
        zero_allocator=zero_allocator,
    )

    # Step 3: convert each child's tensors to the layout the layers expect.
    def _to_layer_input_mode(hidden_states, residual, forward_batch, **passthrough):
        converted_hidden, converted_residual = CommunicateSummableTensorPairFn.execute(
            hidden_states_input_mode=split_mode,
            residual_input_mode=split_mode,
            output_mode=layer_input_scatter_mode,
            hidden_states=hidden_states,
            residual=residual,
            forward_batch=forward_batch,
            context=comm_context,
        )
        return {
            "hidden_states": converted_hidden,
            "residual": converted_residual,
            "forward_batch": forward_batch,
            **passthrough,
        }

    children_inputs = []
    for child_inputs in raw_children_inputs:
        children_inputs.append(_to_layer_input_mode(**child_inputs))
    return children_inputs
class MaybeTboDeepEPDispatcher(BaseDispatcher):
    """Dispatcher facade holding one inner dispatcher per TBO sub-batch.

    Two inner dispatchers are created when ``is_tbo_enabled()`` is true,
    otherwise one. Per-call methods are routed to the inner dispatcher
    selected by the ``tbo_subbatch_index`` keyword argument; configuration
    methods are applied to this object and to every inner dispatcher.
    """

    def __init__(self, **kwargs):
        """Create the inner dispatchers for the configured MoE A2A backend.

        Args:
            **kwargs: Forwarded unchanged to every inner dispatcher.

        Raises:
            NotImplementedError: If the backend is not DeepEP, Mooncake or
                Mori. Previously this left ``_inners`` undefined and only
                failed later with an ``AttributeError``.
        """
        super().__init__()
        num_inner_dispatchers = 2 if is_tbo_enabled() else 1
        backend = get_moe_a2a_backend()
        if backend.is_deepep():
            dispatcher_cls = DeepEPDispatcher
        elif backend.is_mooncake():
            dispatcher_cls = MooncakeEPDispatcher
        elif backend.is_mori():
            dispatcher_cls = MoriEPDispatcher
        else:
            raise NotImplementedError(
                f"MaybeTboDeepEPDispatcher does not support MoE A2A backend {backend!r}"
            )
        self._inners = [dispatcher_cls(**kwargs) for _ in range(num_inner_dispatchers)]

    def _execute(self, name, tbo_subbatch_index: Optional[int] = None, **kwargs):
        """Call method ``name`` on the inner dispatcher of the given sub-batch.

        ``tbo_subbatch_index`` of ``None`` (or 0) selects the first inner
        dispatcher.
        """
        return getattr(self._inners[tbo_subbatch_index or 0], name)(**kwargs)

    def dispatch(self, **kwargs) -> DispatchOutput:
        """Route ``dispatch`` to the selected inner dispatcher."""
        return self._execute("dispatch", **kwargs)

    def dispatch_a(self, **kwargs):
        """Route ``dispatch_a`` to the selected inner dispatcher."""
        return self._execute("dispatch_a", **kwargs)

    def dispatch_b(self, **kwargs):
        """Route ``dispatch_b`` to the selected inner dispatcher."""
        return self._execute("dispatch_b", **kwargs)

    def combine(self, **kwargs) -> torch.Tensor:
        """Route ``combine`` to the selected inner dispatcher."""
        return self._execute("combine", **kwargs)

    def combine_a(self, **kwargs):
        """Route ``combine_a`` to the selected inner dispatcher."""
        return self._execute("combine_a", **kwargs)

    def combine_b(self, **kwargs):
        """Route ``combine_b`` to the selected inner dispatcher."""
        return self._execute("combine_b", **kwargs)

    def register_deepep_dispatch_hook(self, hook):
        """Register ``hook`` on every inner dispatcher; return their handles."""
        return [inner.register_deepep_dispatch_hook(hook) for inner in self._inners]

    def set_quant_config(self, quant_config: dict):
        """Apply ``quant_config`` to this dispatcher and every inner one."""
        super().set_quant_config(quant_config)
        for inner in self._inners:
            inner.set_quant_config(quant_config)

    def set_overlap_args(
        self, combine_overlap_args: CombineOverlapArgs, meta_overlap_args: dict
    ):
        """Apply the overlap arguments to this dispatcher and every inner one."""
        super().set_overlap_args(combine_overlap_args, meta_overlap_args)
        for inner in self._inners:
            inner.set_overlap_args(combine_overlap_args, meta_overlap_args)

    def clear_overlap_args(self):
        """Clear the overlap arguments on this dispatcher and every inner one."""
        super().clear_overlap_args()
        for inner in self._inners:
            inner.clear_overlap_args()
a/sglang/python/sglang/srt/checkpoint_engine/__init__.py b/sglang/python/sglang/srt/checkpoint_engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d8a77f905b1b5f686ab07b1cb8e84ca8441270b6 --- /dev/null +++ b/sglang/python/sglang/srt/checkpoint_engine/__init__.py @@ -0,0 +1,9 @@ +""" +Checkpoint engine module for SGLang. + +This module provides functionality for updating model weights via checkpoint engine. +""" + +from sglang.srt.checkpoint_engine.update import main + +__all__ = ["main"] diff --git a/sglang/python/sglang/srt/checkpoint_engine/checkpoint_engine_worker.py b/sglang/python/sglang/srt/checkpoint_engine/checkpoint_engine_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..6f11c7872540f3dd8235d9e04b61bed48df523f3 --- /dev/null +++ b/sglang/python/sglang/srt/checkpoint_engine/checkpoint_engine_worker.py @@ -0,0 +1,143 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Checkpoint-engine integration for SGLang. +This module provides weight update functionality via IPC for checkpoint-engine compatibility. +""" + +import logging +from typing import Callable, Dict, Optional + +import torch +import zmq + +try: + from checkpoint_engine.worker import update_weights_from_ipc +except ImportError: + raise ImportError( + "checkpoint-engine is not installed. 
class SGLangCheckpointEngineWorkerExtension:
    """
    Base interface for checkpoint-engine IPC weight updates in an SGLang worker.

    Subclasses supply the device identity and the weight loader; this class
    owns the ZMQ context and drives the actual IPC update.
    """

    def __init__(self):
        # Created lazily on the first update request.
        self._zmq_ctx: Optional[zmq.Context] = None

    def get_device_uuid(self) -> str:
        """Return the UUID of the current device (must be overridden)."""
        raise NotImplementedError(
            "This method should be overridden by SGLang integration"
        )

    def get_device_id(self) -> int:
        """Return the device ID (must be overridden)."""
        raise NotImplementedError(
            "This method should be overridden by SGLang integration"
        )

    def get_model_loader(self) -> Callable:
        """Return the callable that loads model weights (must be overridden)."""
        raise NotImplementedError(
            "This method should be overridden by SGLang integration"
        )

    def get_post_hook(self) -> Optional[Callable]:
        """Return the hook to run after weight loading; none by default."""
        return None

    def update_weights_from_ipc(self, zmq_handles: Dict[str, str]):
        """
        Receive new weights over IPC for the current device.

        Args:
            zmq_handles: Dict mapping device UUID to ZMQ socket path

        Raises:
            ValueError: If the current device's UUID is not a key of
                ``zmq_handles``.
        """
        if self._zmq_ctx is None:
            self._zmq_ctx = zmq.Context()

        uuid = self.get_device_uuid()
        local_device_id = self.get_device_id()
        if not (uuid in zmq_handles):
            known_uuids = list(zmq_handles.keys())
            raise ValueError(
                f"Device UUID {uuid} not found in zmq_handles: {known_uuids}"
            )

        socket_path = zmq_handles[uuid]
        loader = self.get_model_loader()
        hook = self.get_post_hook()
        # Resolves to the module-level checkpoint-engine function, not this method.
        update_weights_from_ipc(
            self._zmq_ctx,
            socket_path,
            device_id=local_device_id,
            run=loader,
            post_hook=hook,
        )
+ This class provides the concrete implementation for checkpoint-engine IPC weight updates. + """ + + def __init__(self, model_runner): + super().__init__() + self.model_runner = model_runner + + def get_device_uuid(self) -> str: + """Get the UUID of current device.""" + # Get device UUID for current device + device_id = torch.cuda.current_device() + try: + return f"GPU-{torch.cuda.get_device_properties(device_id).uuid!s}" + except AssertionError as e: + raise ValueError(f"Failed to get GPU UUID for device {device_id}") from e + + def get_device_id(self) -> int: + """Get the device ID.""" + return torch.cuda.current_device() + + def get_model_loader(self) -> Callable: + """Get the model weight loader function.""" + return self.model_runner.model.load_weights + + def get_post_hook(self) -> Optional[Callable]: + """Get the post-processing hook after weight loading.""" + + def post_hook(): + # Perform post-processing after weight loading similar to DefaultModelLoader + try: + from sglang.srt.model_loader.loader import device_loading_context + + # Process quantization methods after loading weights + for _, module in self.model_runner.model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + # Move parameters to device if needed for quantization processing + target_device = torch.device( + "cuda", torch.cuda.current_device() + ) + with device_loading_context(module, target_device): + quant_method.process_weights_after_loading(module) + # Call model-specific post-loading hook if available + if hasattr(self.model_runner.model, "post_load_weights"): + self.model_runner.model.post_load_weights() + except Exception as e: + logger.warning(f"Post-hook processing failed: {e}") + + return post_hook diff --git a/sglang/python/sglang/srt/checkpoint_engine/update.py b/sglang/python/sglang/srt/checkpoint_engine/update.py new file mode 100644 index 0000000000000000000000000000000000000000..93c8b4b6e4c19553ef8357b8e25d55099a4a885a 
@contextmanager
def timer(msg: str):
    """Log how long the wrapped ``with`` body took.

    The duration is logged at INFO level as ``"<msg> duration: <s> seconds"``
    after the body finishes normally. Nothing is logged if the body raises.
    """
    began_at = time.perf_counter()
    yield
    elapsed = time.perf_counter() - began_at
    logger.info(f"{msg} duration: {elapsed:.2f} seconds")
def split_checkpoint_files(
    checkpoint_path: str, rank: int, world_size: int
) -> list[str]:
    """Return the ``.safetensors`` files of a checkpoint assigned to ``rank``.

    The files are split into ``world_size`` contiguous shards of
    ``ceil(num_files / world_size)`` entries each; trailing ranks may receive
    fewer files or none.

    Args:
        checkpoint_path: Directory containing the checkpoint files.
        rank: Index of the calling process.
        world_size: Total number of processes sharing the checkpoint.

    Returns:
        Full paths of the files for ``rank``, in sorted order.
    """
    # os.listdir() returns entries in arbitrary order. Sort so that every
    # rank derives the same partition and no file is skipped or duplicated.
    checkpoint_files = sorted(
        os.path.join(checkpoint_path, name)
        for name in os.listdir(checkpoint_path)
        if name.endswith(".safetensors")
    )
    files_per_rank = (len(checkpoint_files) + world_size - 1) // world_size
    start = rank * files_per_rank
    return checkpoint_files[start : start + files_per_rank]
def update_weights(
    ps,
    checkpoint_name: str,
    checkpoint_files: list[str],
    named_tensors: dict[str, torch.Tensor],
    req_func: Callable[[list[tuple[str, str]]], None],
    inference_parallel_size: int,
    endpoint: str,
    save_metas_file: str | None = None,
    update_method: Literal["broadcast", "p2p", "all"] = "broadcast",
    uds: str | None = None,
):
    """Register a checkpoint with the parameter server and push it to SGLang.

    Args:
        ps: Parameter server object; its ``register_checkpoint``,
            ``init_process_group``, ``gather_metas``, ``get_metas`` and
            ``update`` methods are used.
        checkpoint_name: Name under which the checkpoint is registered.
        checkpoint_files: Checkpoint files owned by this rank.
        named_tensors: Extra tensors owned by this rank, keyed by name.
        req_func: Callback passed to ``ps.update``.
        inference_parallel_size: Number of inference ranks; used for the
            readiness check and as the rank list of the p2p update.
        endpoint: Base URL of the SGLang server.
        save_metas_file: If set, rank 0 pickles the gathered metas to it.
        update_method: ``"broadcast"`` updates without setting ranks,
            ``"p2p"`` updates with explicit ranks, ``"all"`` does both.
        uds: Optional Unix domain socket path for the readiness check.
    """
    ps.register_checkpoint(
        checkpoint_name, files=checkpoint_files, named_tensors=named_tensors
    )
    ps.init_process_group()
    check_sglang_ready(endpoint, inference_parallel_size, uds)
    dist.barrier()
    with timer("Gather metas"):
        ps.gather_metas(checkpoint_name)
    # Default RANK to 0, as the other helpers in this file do, so a missing
    # variable does not raise TypeError from int(None).
    if save_metas_file and int(os.getenv("RANK", 0)) == 0:
        with open(save_metas_file, "wb") as f:
            pickle.dump(ps.get_metas(), f)

    if update_method in ("broadcast", "all"):
        with timer("Update weights without setting ranks"):
            ps.update(checkpoint_name, req_func)

    if update_method in ("p2p", "all"):
        # sleep 2s to wait destroy process group
        # NOTE(review): this was guarded by `if update_method:`, which is
        # always true in this branch, so the sleep stays unconditional.
        time.sleep(2)
        with timer("Update weights with setting ranks"):
            ps.update(
                checkpoint_name, req_func, ranks=list(range(inference_parallel_size))
            )
using p2p" + ): + ps.update(checkpoint_name, req_func, ranks=list(range(inference_parallel_size))) + + +def run_with_torchrun(): + """Run the update script with torchrun automatically.""" + # Parse inference_parallel_size from command line arguments to determine nproc-per-node + inference_parallel_size = 8 # default + args = sys.argv[1:] # Skip the script name + + # Look for --inference-parallel-size in arguments + for i, arg in enumerate(args): + if arg == "--inference-parallel-size" and i + 1 < len(args): + try: + inference_parallel_size = int(args[i + 1]) + except ValueError: + pass + break + elif arg.startswith("--inference-parallel-size="): + try: + inference_parallel_size = int(arg.split("=", 1)[1]) + except ValueError: + pass + break + + # Build torchrun command + cmd = ["torchrun", f"--nproc-per-node={inference_parallel_size}", __file__] + args + + print(f"Running: {' '.join(cmd)}", file=sys.stderr) + + # Execute torchrun with the original script + try: + result = subprocess.run(cmd, check=False) + sys.exit(result.returncode) + except FileNotFoundError: + print( + "Error: torchrun command not found. 
Please ensure PyTorch is installed.", + file=sys.stderr, + ) + sys.exit(1) + except KeyboardInterrupt: + print("\nInterrupted by user", file=sys.stderr) + sys.exit(130) + + +def main(): + # Check if we're running under torchrun or need to invoke it + if os.getenv("RANK") is None: + # Not running under torchrun, so invoke it + run_with_torchrun() + return + + # Running under torchrun, proceed with normal execution + parser = argparse.ArgumentParser(description="Update weights example") + parser.add_argument("--checkpoint-path", type=str, default=None) + parser.add_argument("--save-metas-file", type=str, default=None) + parser.add_argument("--load-metas-file", type=str, default=None) + parser.add_argument("--sleep-time", type=int, default=0) + parser.add_argument("--endpoint", type=str, default="http://localhost:19730") + parser.add_argument("--inference-parallel-size", type=int, default=8) + parser.add_argument("--checkpoint-name", type=str, default="my-checkpoint-iter-0") + parser.add_argument("--update-method", type=str, default="broadcast") + parser.add_argument("--uds", type=str, default=None) + parser.add_argument("--weight-version", type=str, default=None) + args = parser.parse_args() + + # Get rank and world_size from environment (set by torchrun) + rank = int(os.getenv("RANK", 0)) + world_size = int(os.getenv("WORLD_SIZE", 1)) + + req_func = req_inference( + args.endpoint, + args.inference_parallel_size, + uds=args.uds, + weight_version=args.weight_version, + ) + + if ParameterServer is None: + print("Error: checkpoint_engine package not available", file=sys.stderr) + sys.exit(1) + + ps = ParameterServer(auto_pg=True) + ps._p2p_store = None + if args.load_metas_file: + join( + ps, + args.checkpoint_name, + args.load_metas_file, + req_func, + args.inference_parallel_size, + args.endpoint, + args.uds, + ) + else: + if args.checkpoint_path and os.path.exists( + os.path.join(args.checkpoint_path, "model.safetensors.index.json") + ): + named_tensors = 
split_tensors(args.checkpoint_path, rank, world_size) + checkpoint_files = [] + else: + checkpoint_files = ( + split_checkpoint_files(args.checkpoint_path, rank, world_size) + if args.checkpoint_path + else [] + ) + named_tensors = {} + update_weights( + ps, + args.checkpoint_name, + checkpoint_files, + named_tensors, + req_func, + args.inference_parallel_size, + args.endpoint, + args.save_metas_file, + args.update_method, + args.uds, + ) + time.sleep(args.sleep_time) + + +if __name__ == "__main__": + main() diff --git a/sglang/python/sglang/srt/compilation/__pycache__/compilation_config.cpython-311.pyc b/sglang/python/sglang/srt/compilation/__pycache__/compilation_config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20c7f53376a588252aa6e142a68c1317a04610bf Binary files /dev/null and b/sglang/python/sglang/srt/compilation/__pycache__/compilation_config.cpython-311.pyc differ diff --git a/sglang/python/sglang/srt/compilation/__pycache__/compile.cpython-311.pyc b/sglang/python/sglang/srt/compilation/__pycache__/compile.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed919433bed365d558c6d9e8bae4126c61840798 Binary files /dev/null and b/sglang/python/sglang/srt/compilation/__pycache__/compile.cpython-311.pyc differ diff --git a/sglang/python/sglang/srt/compilation/__pycache__/piecewise_context_manager.cpython-311.pyc b/sglang/python/sglang/srt/compilation/__pycache__/piecewise_context_manager.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45be632209579f32e88249ae7e34c8c0ab793329 Binary files /dev/null and b/sglang/python/sglang/srt/compilation/__pycache__/piecewise_context_manager.cpython-311.pyc differ diff --git a/sglang/python/sglang/srt/compilation/backend.py b/sglang/python/sglang/srt/compilation/backend.py new file mode 100644 index 0000000000000000000000000000000000000000..8af025707f558fe2834431acdff73b80b38944fb --- /dev/null +++ 
b/sglang/python/sglang/srt/compilation/backend.py @@ -0,0 +1,472 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/backend.py + + +import ast +import dataclasses +import logging +import os +import pprint +import time +from collections.abc import Sequence +from contextlib import contextmanager +from typing import Any, Callable, Optional + +import torch +import torch.fx as fx +from torch._dispatch.python import enable_python_dispatcher + +from sglang.srt.compilation.compilation_config import CompilationConfig +from sglang.srt.compilation.compilation_counter import compilation_counter +from sglang.srt.compilation.compiler_interface import EagerAdapter, InductorAdaptor +from sglang.srt.compilation.cuda_piecewise_backend import CUDAPiecewiseBackend +from sglang.srt.compilation.npu_piecewise_backend import NPUPiecewiseBackend +from sglang.srt.compilation.pass_manager import PostGradPassManager +from sglang.srt.utils.common import is_npu, rank0_log + +logger = logging.getLogger(__name__) + + +def make_compiler(config: CompilationConfig): + if config.compiler == "eager": + return EagerAdapter() + elif config.compiler == "inductor": + return InductorAdaptor() + else: + raise ValueError(f"Unknown compiler: {config.compiler}") + + +def make_backend( + graph: fx.GraphModule, + compile_config: CompilationConfig, + inductor_config: dict[str, Any], + graph_pool: Any, + piecewise_compile_index: int, + total_piecewise_compiles: int, + sym_shape_indices: list[int], + compiled_graph_for_general_shape: Callable, + sglang_backend, +): + + backend_cls = CUDAPiecewiseBackend if not is_npu() else NPUPiecewiseBackend + return backend_cls( + graph, + compile_config, + inductor_config, + graph_pool, + piecewise_compile_index, + total_piecewise_compiles, + sym_shape_indices, + compiled_graph_for_general_shape, + sglang_backend, + ) + + +class CompilerManager: + def __init__( + self, + config: CompilationConfig, + ): + self.cache = dict() + 
self.is_cache_updated = False + self.compiler = make_compiler(config) + + def compute_hash(self): + return self.compiler.compute_hash() + + def initialize_cache( + self, cache_dir: str, disable_cache: bool = False, prefix: str = "" + ): + self.disable_cache = disable_cache + self.cache_dir = cache_dir + self.cache_file_path = os.path.join(cache_dir, "sglang_compile_cache.py") + + if not disable_cache and os.path.exists(self.cache_file_path): + with open(self.cache_file_path) as f: + self.cache = ast.literal_eval(f.read()) + + self.compiler.initialize_cache( + cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix + ) + + def save_to_file(self): + if self.disable_cache or not self.is_cache_updated: + return + printer = pprint.PrettyPrinter(indent=4) + data = printer.pformat(self.cache) + with open(self.cache_file_path, "w") as f: + f.write(data) + + def load( + self, + graph: fx.GraphModule, + example_inputs: list[Any], + graph_index: int, + runtime_shape: Optional[int] = None, + ) -> Optional[Callable]: + handle = self.cache[(runtime_shape, graph_index, self.compiler.name)] + compiled_graph = self.compiler.load( + handle, graph, example_inputs, graph_index, runtime_shape + ) + if runtime_shape is None: + logger.debug( + "Directly load the %s-th graph for dynamic shape from %s via " + "handle %s", + graph_index, + self.compiler.name, + handle, + ) + else: + logger.debug( + "Directly load the %s-th graph for shape %s from %s via " "handle %s", + graph_index, + str(runtime_shape), + self.compiler.name, + handle, + ) + return compiled_graph + + def compile( + self, + graph: fx.GraphModule, + example_inputs, + inductor_config: dict[str, Any], + graph_index: int = 0, + num_graphs: int = 1, + runtime_shape: Optional[int] = None, + ) -> Any: + if graph_index == 0: + # before compiling the first graph, record the start time + global compilation_start_time + compilation_start_time = time.time() + + compilation_counter.num_backend_compilations += 1 + + compiled_graph 
@dataclasses.dataclass
class SplitItem:
    """One direct submodule of the module returned by ``split_graph``."""

    # Attribute name of the submodule on the split module ("submod_<id>").
    submod_name: str
    # Integer parsed from ``submod_name``; used for ordering.
    graph_id: int
    # True when this partition was opened by one of the splitting ops.
    is_splitting_graph: bool
    graph: fx.GraphModule


def split_graph(
    graph: fx.GraphModule, ops: list[str]
) -> tuple[fx.GraphModule, list[SplitItem]]:
    """Partition ``graph`` so every op listed in ``ops`` sits in its own submodule.

    Args:
        graph: The FX graph module to split.
        ops: ``str(node.target)`` values of the ``call_function`` nodes that
            act as split points.

    Returns:
        The stitched module produced by ``split_module`` and one ``SplitItem``
        per direct submodule, sorted by integer graph id.
    """
    # split graph by ops
    subgraph_id = 0
    node_to_subgraph_id = {}
    split_op_graphs = set()
    for node in graph.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue
        if node.op == "call_function" and str(node.target) in ops:
            # A splitting op gets a partition of its own: close the current
            # one, assign the op, then open a fresh one for what follows.
            subgraph_id += 1
            node_to_subgraph_id[node] = subgraph_id
            split_op_graphs.add(subgraph_id)
            subgraph_id += 1
        else:
            node_to_subgraph_id[node] = subgraph_id

    # `keep_original_order` is important!
    # otherwise pytorch might reorder the nodes and
    # the semantics of the graph will change when we
    # have mutations in the graph
    split_gm = torch.fx.passes.split_module.split_module(
        graph, None, lambda node: node_to_subgraph_id[node], keep_original_order=True
    )

    # Only direct children are partitions; the root module and nested
    # descendants are not.
    outputs = []
    for name, module in split_gm.named_children():
        graph_id = int(name.replace("submod_", ""))
        outputs.append(SplitItem(name, graph_id, graph_id in split_op_graphs, module))

    # sort by integer graph_id, rather than string name
    outputs.sort(key=lambda item: item.graph_id)

    return split_gm, outputs
+ self.extra_traceback = False + self.inductor_config = inductor_config + self.compile_config = compile_config + + def run(self, *args): + fake_args = [ + self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t + for t in args + ] + with self.fake_mode, enable_python_dispatcher(): + return super().run(*fake_args) + + def call_module( + self, + target: torch.fx.node.Target, + args: tuple[torch.fx.node.Argument, ...], + kwargs: dict[str, Any], + ) -> Any: + assert isinstance(target, str) + output = super().call_module(target, args, kwargs) + + if target in self.compile_submod_names: + index = self.compile_submod_names.index(target) + submod = self.fetch_attr(target) + sym_shape_indices = [ + i for i, x in enumerate(args) if isinstance(x, torch.SymInt) + ] + global compilation_start_time + compiled_graph_for_dynamic_shape = ( + self.sglang_backend.compiler_manager.compile( + submod, + args, + self.inductor_config, + graph_index=index, + num_graphs=len(self.compile_submod_names), + runtime_shape=None, + ) + ) + + self.module.__dict__[target] = make_backend( + submod, + self.compile_config, + self.inductor_config, + self.graph_pool, + index, + len(self.compile_submod_names), + sym_shape_indices, + compiled_graph_for_dynamic_shape, + self.sglang_backend, + ) + + compilation_counter.num_piecewise_capturable_graphs_seen += 1 + + return output + + +model_tag: str = "backbone" + + +@contextmanager +def set_model_tag(tag: str): + """Context manager to set the model tag.""" + global model_tag + assert ( + tag != model_tag + ), f"Model tag {tag} is the same as the current tag {model_tag}." 
class SGLangBackend:
    """Callable backend that splits a traced graph and compiles it piecewise.

    An instance is invoked once with the FX graph and its example inputs. It
    splits the graph at ``compile_config.split_ops``, runs
    ``PiecewiseCompileInterpreter`` over the non-splitting pieces, and returns
    the stitched ``split_gm``.
    """

    graph_pool: Any
    _called: bool = False
    # the graph we compiled
    graph: fx.GraphModule
    # the stitching graph module for all the piecewise graphs
    split_gm: fx.GraphModule
    piecewise_graphs: list[SplitItem]
    returned_callable: Callable
    # Inductor passes to run on the graph pre-defunctionalization
    post_grad_passes: Sequence[Callable]
    sym_tensor_indices: list[int]
    input_buffers: list[torch.Tensor]
    compiler_manager: CompilerManager

    def __init__(
        self,
        config: CompilationConfig,
        graph_pool: Any,
    ):
        """Store the config and graph pool and build the compiler manager.

        ``graph_pool`` must not be ``None``.
        """
        rank0_log("Initializing SGLangBackend")
        assert graph_pool is not None
        self.graph_pool = graph_pool

        self.post_grad_pass_manager = PostGradPassManager()
        self.sym_tensor_indices = []
        self.input_buffers = []

        self.compiler_manager = CompilerManager(config)
        self.inductor_config = {
            "enable_auto_functionalized_v2": False,
        }
        self.compile_config = config

    def configure_post_pass(self):
        """Configure the post-grad pass manager and register it in ``inductor_config``."""
        self.post_grad_pass_manager.configure()
        self.inductor_config["post_grad_custom_post_pass"] = self.post_grad_pass_manager

    def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
        """Split ``graph``, compile its pieces and return the stitched module.

        Side effects: creates the cache directories under ``SGLANG_CACHE_DIR``
        (default ``~/.cache/sglang/``), bumps ``compilation_counter`` and, on
        distributed rank 0, writes a readable dump of the split graph into the
        local cache directory.

        Raises:
            AssertionError: if this backend instance was already called.
        """
        rank0_log("SGLangBackend __call__")
        base_cache_dir = os.path.expanduser(
            os.getenv("SGLANG_CACHE_DIR", "~/.cache/sglang/")
        )

        cache_hash = self.compiler_manager.compute_hash()
        cache_dir = os.path.join(
            base_cache_dir,
            "torch_compile_cache",
            cache_hash,
        )

        os.makedirs(cache_dir, exist_ok=True)
        # The cache sub-directory name is currently fixed to rank 0 / dp rank 0.
        rank = 0
        dp_rank = 0
        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", model_tag)
        os.makedirs(local_cache_dir, exist_ok=True)
        self.compiler_manager.initialize_cache(
            local_cache_dir, disable_cache=False, prefix=""
        )
        compilation_counter.num_graphs_seen += 1

        assert not self._called, "SGLangBackend can only be called once"

        self.graph = graph
        self.configure_post_pass()

        self.split_gm, self.piecewise_graphs = split_graph(
            graph,
            self.compile_config.split_ops,
        )
        from torch._dynamo.utils import lazy_format_graph_code

        # depyf will hook lazy_format_graph_code and dump the graph
        # for debugging, no need to print the graph here
        lazy_format_graph_code("before split", self.graph)
        lazy_format_graph_code("after split", self.split_gm)

        compilation_counter.num_piecewise_graphs_seen += len(self.piecewise_graphs)

        submod_names_to_compile = [
            item.submod_name
            for item in self.piecewise_graphs
            if not item.is_splitting_graph
        ]

        PiecewiseCompileInterpreter(
            self.split_gm,
            submod_names_to_compile,
            self.inductor_config,
            self.graph_pool,
            self.compile_config,
            self,
        ).run(*example_inputs)

        rank = torch.distributed.get_rank()

        if rank == 0:
            graph_path = os.path.join(
                local_cache_dir, f"computation_graph_{time.time()}.py"
            )
            if not os.path.exists(graph_path):
                # code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa
                # use `print_readable` because it can include submodules
                src = (
                    "from __future__ import annotations\nimport torch\n"
                    + self.split_gm.print_readable(print_output=False)
                )
                # Rename the "<lambda>" placeholder so the dump is valid
                # Python. The search string must not be empty:
                # str.replace("", x) inserts x between every character.
                src = src.replace("<lambda>", "GraphModule")
                with open(graph_path, "w") as f:
                    f.write(src)

                # graph_path only exists on rank 0, so the message is built here.
                rank0_log(f"Computation graph saved to {graph_path}")

        self._called = True
        return self.split_gm
register_split_op(op_name: Optional[str] = None): + def decorator(op_func: Callable): + name = op_name or op_func.__name__ + SPLIT_OPS.append(f"sglang.{name}") + return op_func + + return decorator + + +# TODO(Yuwei): support better compile config support +class CompilationConfig: + def __init__( + self, + capture_sizes: List[int], + compiler: str = "eager", + enable_debug_mode: bool = False, + ): + self.traced_files = set() + self.capture_sizes = capture_sizes + self.compiler = compiler + self.enable_debug_mode = enable_debug_mode + self.split_ops = [] + self.split_ops.extend(SPLIT_OPS) + + def add_split_op(self, op: str): + self.split_ops.append(op) + + def add_traced_file(self, file_path: str): + self.traced_files.add(file_path) + + def get_traced_files(self): + return self.traced_files + + def get_capture_sizes(self): + return self.capture_sizes + + def get_enable_debug_mode(self): + return self.enable_debug_mode diff --git a/sglang/python/sglang/srt/compilation/compilation_counter.py b/sglang/python/sglang/srt/compilation/compilation_counter.py new file mode 100644 index 0000000000000000000000000000000000000000..e973f8f2fc7d366c89410f93caa8eeaca360482d --- /dev/null +++ b/sglang/python/sglang/srt/compilation/compilation_counter.py @@ -0,0 +1,47 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compilation_counter.py + +import copy +import dataclasses +from contextlib import contextmanager + + +@dataclasses.dataclass +class CompilationCounter: + num_models_seen: int = 0 + num_graphs_seen: int = 0 + # including the splitting ops + num_piecewise_graphs_seen: int = 0 + # not including the splitting ops + num_piecewise_capturable_graphs_seen: int = 0 + num_backend_compilations: int = 0 + # Number of gpu_model_runner attempts to trigger CUDAGraphs capture + num_gpu_runner_capture_triggers: int = 0 + # Number of CUDAGraphs captured + num_cudagraph_captured: int = 0 + # InductorAdapter.compile calls + num_inductor_compiles: int = 0 
@dataclass
class IntermediateTensors:
    """For all pipeline stages except the last, we need to return the hidden
    states and residuals to be sent to the next stage. This data structure
    contains the hidden states and residuals for a request.

    Each stage also needs to handle its own finished_sending and
    finished_recving in case of kv transfer.

    Indexing with a ``str`` returns the tensor stored under that name;
    indexing with a ``slice`` returns a new instance whose tensors are each
    sliced with it. Two instances are equal when they hold the same names and
    ``torch.equal`` tensors.
    """

    tensors: dict[str, torch.Tensor]
    # [req_ids]
    finished_sending: Optional[set[str]] = None
    finished_recving: Optional[set[str]] = None

    def __init__(self, tensors):
        # manually define this function, so that
        # Dynamo knows `IntermediateTensors()` comes from this file.
        # Otherwise, dataclass will generate this function by evaluating
        # a string, and we will lose the information about the source file.
        self.tensors = tensors

    def __getitem__(self, key: Union[str, slice]):
        if isinstance(key, str):
            return self.tensors[key]
        elif isinstance(key, slice):
            return self.__class__({k: v[key] for k, v in self.tensors.items()})
        # Any other key type falls through and yields None.

    def __setitem__(self, key: str, value: torch.Tensor):
        self.tensors[key] = value

    def items(self):
        return self.tensors.items()

    def __len__(self):
        return len(self.tensors)

    def __eq__(self, other: object) -> bool:
        """Return True only for an instance with the same names and equal tensors."""
        # Previously this returned `self`, so any two non-empty instances
        # compared truthy regardless of content.
        if not isinstance(other, self.__class__):
            return False
        if self.tensors.keys() != other.tensors.keys():
            return False
        # torch.equal gives a plain bool (same shape and same elements), which
        # avoids the ambiguous truth value of an element-wise `==`.
        return all(
            torch.equal(tensor, other.tensors[name])
            for name, tensor in self.tensors.items()
        )

    def __repr__(self) -> str:
        return f"IntermediateTensors(tensors={self.tensors})"
torch.Tensor / Optional[torch.Tensor] / your IntermediateTensors types by name + if ( + ann is torch.Tensor + or getattr(getattr(ann, "__args__", [None])[0], "__name__", "") == "Tensor" + ): + dyn[name] = 0 + elif getattr(ann, "__name__", "") in ("IntermediateTensors",) or any( + getattr(a, "__name__", "") == "IntermediateTensors" + for a in getattr(ann, "__args__", []) + ): + dyn[name] = 0 + elif ann == "torch.Tensor" or ann == "Optional[torch.Tensor]": + # For future import annotations (e.g. from __future__ import annotations), the annotation is a string + dyn[name] = 0 + if not dyn: + raise ValueError("No dynamic dims inferred; pass dynamic_arg_dims explicitly.") + return dyn + + +def install_torch_compiled( + module: torch.nn.Module, + *, + dynamic_arg_dims: dict[str, Union[int, list[int]]] | None = None, + backend_factory: Optional[Callable[[torch.fx.GraphModule, list], Callable]] = None, + compile_config: CompilationConfig = None, + fullgraph: bool = True, + graph_pool: Any = None, +): + rank0_log(f"install_torch_compiled") + unbound_fwd = module.__class__.forward + if not callable(unbound_fwd): + raise TypeError("module.__class__.forward must be callable") + original_code = unbound_fwd.__code__ + + dyn_map = dynamic_arg_dims or _infer_dynamic_arg_dims_from_annotations(unbound_fwd) + + if backend_factory is None: + from sglang.srt.compilation.backend import SGLangBackend + + backend_factory = lambda gm, ex: SGLangBackend(compile_config, graph_pool)( + gm, ex + ) + + compiled_codes: list[type(original_code)] = [] + state = {"compiled": False, "compiled_callable": None} + + def bytecode_hook(old_code, new_code): + if old_code is not original_code: + return + frame = sys._getframe() + while frame and frame.f_back: + frame = frame.f_back + if ( + frame.f_code.co_name == "_compile" + and os.path.basename(frame.f_code.co_filename) == "convert_frame.py" + ): + break + try: + dynamo_frame = frame.f_locals["frame"] + except Exception: + return + if dynamo_frame.f_code 
is not old_code: + return + if dynamo_frame.f_locals.get("self") is not module: + return + compiled_codes.append(new_code) + + torch._dynamo.convert_frame.register_bytecode_hook(bytecode_hook) + + def _ensure_compiled(self, *args, **kwargs): + """Compile on first use (with flag ON).""" + if state["compiled"]: + return + # Mark dynamic dims only when we are about to compile + sig = inspect.signature(unbound_fwd) + ba = sig.bind(self, *args, **kwargs) + ba.apply_defaults() + for name, dims in (dyn_map or {}).items(): + if name in ba.arguments: + val = ba.arguments[name] + if val is not None: + _mark_dynamic_on_value(val, dims) + + # Avoid cross-instance cache reuse + torch._dynamo.eval_frame.remove_from_cache(unbound_fwd.__code__) + + bound = types.MethodType(unbound_fwd, self) + compiled_callable = torch.compile( + bound, fullgraph=fullgraph, backend=backend_factory + ) + + # Trigger Dynamo so bytecode hook can capture + compiled_callable(*args, **kwargs) + + state["compiled"] = True + state["compiled_callable"] = compiled_callable + + def trampoline(self, *args, **kwargs): + use_compiled = is_in_piecewise_cuda_graph() + if use_compiled: + if not state["compiled"]: + _ensure_compiled(self, *args, **kwargs) + + compiled_callable = state["compiled_callable"] + return compiled_callable(*args, **kwargs) + else: + # Explicitly run the original uncompiled forward + return unbound_fwd(self, *args, **kwargs) + + module.forward = types.MethodType(trampoline, module) + return module diff --git a/sglang/python/sglang/srt/compilation/compiler_interface.py b/sglang/python/sglang/srt/compilation/compiler_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..df7f28b103744132623582921c3af2d0c3566576 --- /dev/null +++ b/sglang/python/sglang/srt/compilation/compiler_interface.py @@ -0,0 +1,504 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compiler_interface.py + +import contextlib +import copy +import hashlib +import 
os +from contextlib import ExitStack +from typing import Any, Callable, Optional +from unittest.mock import patch + +import torch +import torch._inductor.compile_fx +import torch.fx as fx + +from sglang.srt.compilation.compilation_counter import compilation_counter +from sglang.srt.compilation.inductor_pass import pass_context +from sglang.srt.utils.common import torch_release + + +class CompilerInterface: + """ + The interface for a compiler that can be used by vLLM. + """ + + # The name of the compiler, e.g. inductor. + # This is a class-level attribute. + name: str + + def initialize_cache( + self, cache_dir: str, disable_cache: bool = False, prefix: str = "" + ): + """ + when the vLLM process uses `cache_dir` as the cache directory, + the compiler should initialize itself with the cache directory, + e.g. by re-directing its own cache directory to a sub-directory. + + prefix can be used in combination with cache_dir to figure out the base + cache directory, e.g. there're multiple parts of model being compiled, + but we want to share the same cache directory for all of them. + + e.g. + cache_dir = "/path/to/dir/backbone", prefix = "backbone" + cache_dir = "/path/to/dir/eagle_head", prefix = "eagle_head" + """ + pass + + def compute_hash(self) -> str: + """ + Gather all the relevant information from the vLLM config, + to compute a hash so that we can cache the compiled model. + + See [`VllmConfig.compute_hash`][vllm.config.VllmConfig.compute_hash] + to check what information + is already considered by default. This function should only + consider the information that is specific to the compiler. + """ + return "" + + def compile( + self, + graph: fx.GraphModule, + example_inputs: list[Any], + compiler_config: dict[str, Any], + runtime_shape: Optional[int] = None, + key: Optional[str] = None, + ) -> tuple[Optional[Callable], Optional[Any]]: + """ + Compile the graph with the given example inputs and compiler config, + with a runtime shape. 
If the `runtime_shape` is None, it means + the `example_inputs` have a dynamic shape. Otherwise, the + `runtime_shape` specifies the shape of the inputs. Right now we only + support one variable shape for all inputs, which is the batchsize + (number of tokens) during inference. + + Dynamo will make sure `graph(*example_inputs)` is valid. + + The function should return a compiled callable function, as well as + a handle that can be used to directly load the compiled function. + + The handle should be a plain Python object, preferably a string or a + file path for readability. + + If the compiler doesn't support caching, it should return None for the + handle. If the compiler fails to compile the graph, it should return + None for the compiled function as well. + + `key` is required for StandaloneInductorAdapter, it specifies where to + save the compiled artifact. The compiled artifact gets saved to + `cache_dir/key`. + """ + return None, None + + def load( + self, + handle: Any, + graph: fx.GraphModule, + example_inputs: list[Any], + graph_index: int, + runtime_shape: Optional[int] = None, + ) -> Callable: + """ + Load the compiled function from the handle. + Raises an error if the handle is invalid. + + The handle is the second return value of the `compile` function. + """ + raise NotImplementedError("caching is not supported") + + +def get_inductor_factors() -> list[Any]: + factors: list[Any] = [] + # summarize system state + from torch._inductor.codecache import CacheBase + + system_factors = CacheBase.get_system() + factors.append(system_factors) + + # summarize pytorch state + from torch._inductor.codecache import torch_key + + torch_factors = torch_key() + factors.append(torch_factors) + return factors + + +class AlwaysHitShapeEnv: + """ + Why do we need this class: + + For normal `torch.compile` usage, every compilation will have + one Dynamo bytecode compilation and one Inductor compilation. 
def initialize_cache(
    self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
):
    """
    Record the cache location and, unless caching is disabled, redirect the
    Inductor and Triton caches underneath it.

    The base directory is `cache_dir` with the trailing `len(prefix)`
    characters cut off (or `cache_dir` itself when `prefix` is empty), so
    differently-prefixed parts of a model share one base directory.
    """
    if prefix:
        shared_root = cache_dir[: -len(prefix)]
    else:
        shared_root = cache_dir

    self.cache_dir = cache_dir
    self.prefix = prefix
    self.base_cache_dir = shared_root

    if not disable_cache:
        # Keep both caches inside the base directory: copying that single
        # directory to another machine is then enough to reuse the cache.
        redirects = (
            ("inductor_cache", "TORCHINDUCTOR_CACHE_DIR"),
            ("triton_cache", "TRITON_CACHE_DIR"),
        )
        for subdir, env_var in redirects:
            target = os.path.join(self.base_cache_dir, subdir)
            os.makedirs(target, exist_ok=True)
            os.environ[env_var] = target
+ + hash_str, file_path = None, None + from torch._inductor.codecache import FxGraphCache, compiled_fx_graph_hash + + if torch_release[:2] == (2, 5): + original_load = FxGraphCache.load + original_load_name = "torch._inductor.codecache.FxGraphCache.load" + + def hijack_load(*args, **kwargs): + inductor_compiled_graph = original_load(*args, **kwargs) + nonlocal file_path + compiled_fn = inductor_compiled_graph.current_callable + file_path = compiled_fn.__code__.co_filename # noqa + if not file_path.startswith(self.base_cache_dir): + # hooked in the align_inputs_from_check_idxs function + # in torch/_inductor/utils.py + for cell in compiled_fn.__closure__: + if not callable(cell.cell_contents): + continue + if cell.cell_contents.__code__.co_filename.startswith( + self.base_cache_dir + ): + # this is the real file path compiled from Inductor + file_path = cell.cell_contents.__code__.co_filename + break + return inductor_compiled_graph + + hijacked_compile_fx_inner = ( + torch._inductor.compile_fx.compile_fx_inner + ) # noqa + elif torch_release >= (2, 6): + # function renamed in 2.6 + original_load_name = None + + def hijacked_compile_fx_inner(*args, **kwargs): + output = torch._inductor.compile_fx.compile_fx_inner(*args, **kwargs) + nonlocal hash_str + inductor_compiled_graph = output + if inductor_compiled_graph is not None: + nonlocal file_path + compiled_fn = inductor_compiled_graph.current_callable + file_path = compiled_fn.__code__.co_filename # noqa + if not file_path.startswith(self.base_cache_dir): + # hooked in the align_inputs_from_check_idxs function + # in torch/_inductor/utils.py + for cell in compiled_fn.__closure__: + if not callable(cell.cell_contents): + continue + code = cell.cell_contents.__code__ + if code.co_filename.startswith(self.base_cache_dir): + # this is the real file path + # compiled from Inductor + file_path = code.co_filename + break + hash_str = inductor_compiled_graph._fx_graph_cache_key + return output + + def 
hijack_compiled_fx_graph_hash(*args, **kwargs): + out = compiled_fx_graph_hash(*args, **kwargs) + nonlocal hash_str + hash_str = out[0] + return out + + def _check_can_cache(*args, **kwargs): + # no error means it can be cached. + # Inductor refuses to cache the graph outside of Dynamo + # tracing context, and also disables caching for graphs + # with high-order ops. + # For vLLM, in either case, we want to cache the graph. + # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa + return + + def _get_shape_env() -> AlwaysHitShapeEnv: + return AlwaysHitShapeEnv() + + with ExitStack() as stack: + # hijack to get the compiled graph itself + if original_load_name is not None: + stack.enter_context(patch(original_load_name, hijack_load)) + + # for hijacking the hash of the compiled graph + stack.enter_context( + patch( + "torch._inductor.codecache.compiled_fx_graph_hash", + hijack_compiled_fx_graph_hash, + ) + ) + + # for providing a dummy shape environment + stack.enter_context( + patch( + "torch._inductor.codecache.FxGraphCache._get_shape_env", + _get_shape_env, + ) + ) + + from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache + + # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache + if hasattr(AOTAutogradCache, "_get_shape_env"): + stack.enter_context( + patch( + "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env", + _get_shape_env, + ) + ) + + # for forcing the graph to be cached + stack.enter_context( + patch( + "torch._inductor.codecache.FxGraphCache._check_can_cache", + _check_can_cache, + ) + ) + + # Dynamo metrics context, see method for more details. + stack.enter_context(self.metrics_context()) + + # Disable remote caching. When these are on, on remote cache-hit, + # the monkey-patched functions never actually get called. + # vLLM today assumes and requires the monkey-patched functions to + # get hit. 
def load(
    self,
    handle: Any,
    graph: fx.GraphModule,
    example_inputs: list[Any],
    graph_index: int,
    runtime_shape: Optional[int] = None,
) -> Callable:
    """
    Rebuild a runnable function from a handle produced by `compile`.

    :param handle: `(hash_str, file_path)` tuple of two strings; only the
        hash is used for the cache lookup.
    :param graph: the FX graph the artifact was compiled from. It supplies
        the constants for the lookup on torch >= 2.6 and decides whether the
        result is returned as a tuple.
    :param example_inputs: inputs forwarded to the Inductor cache lookup.
    :param graph_index: not used by this implementation.
    :param runtime_shape: not used by this implementation.
    :return: a callable using the Dynamo calling convention `f(*args)`.
    :raises AssertionError: if the handle is malformed or the Inductor cache
        lookup returns nothing.
    :raises NotImplementedError: if the installed torch release is older
        than 2.5, for which no lookup path exists.
    """
    assert isinstance(handle, tuple)
    assert isinstance(handle[0], str)
    assert isinstance(handle[1], str)
    hash_str = handle[0]

    from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
    from torch._inductor.codecache import FxGraphCache

    # The original message was built from two adjacent literals with no
    # separating space ("...Please removethe cache directory...").
    lookup_failed_msg = (
        "Inductor cache lookup failed. Please remove "
        "the cache directory and try again."
    )

    with ExitStack() as exit_stack:
        # Provide a shape environment whose guards always pass, so the
        # lookup works outside of a Dynamo tracing context.
        exit_stack.enter_context(
            patch(
                "torch._inductor.codecache.FxGraphCache._get_shape_env",
                lambda *args, **kwargs: AlwaysHitShapeEnv(),
            )
        )
        # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
        if hasattr(AOTAutogradCache, "_get_shape_env"):
            exit_stack.enter_context(
                patch(
                    "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
                    lambda *args, **kwargs: AlwaysHitShapeEnv(),
                )
            )

        # Dynamo metrics context, see method for more details.
        exit_stack.enter_context(self.metrics_context())

        if torch_release[:2] == (2, 5):
            inductor_compiled_graph = FxGraphCache._lookup_graph(
                hash_str, example_inputs, True, False
            )
        elif torch_release >= (2, 6):
            from torch._inductor.output_code import CompiledFxGraphConstantsWithGm

            constants = CompiledFxGraphConstantsWithGm(graph)
            inductor_compiled_graph, _ = FxGraphCache._lookup_graph(
                hash_str, example_inputs, True, None, constants
            )
        else:
            # Previously this case fell through with the variable unbound
            # and only failed (as a NameError) on the first call of the
            # returned function; fail here with a clear reason instead.
            raise NotImplementedError(
                "InductorAdaptor.load requires torch >= 2.5, "
                f"got torch release {torch_release}"
            )
        assert inductor_compiled_graph is not None, lookup_failed_msg

    # Inductor calling convention (function signature):
    # f(list) -> tuple
    # Dynamo calling convention (function signature):
    # f(*args) -> Any

    # need to know if the graph returns a tuple
    from torch._inductor.compile_fx import graph_returns_tuple

    returns_tuple = graph_returns_tuple(graph)

    # this is the callable we return to Dynamo to run
    def compiled_graph(*args):
        # convert args to list
        list_args = list(args)
        graph_output = inductor_compiled_graph(list_args)
        # unpack the tuple if needed
        if returns_tuple:
            return graph_output
        else:
            return graph_output[0]

    return compiled_graph
def set_inductor_config(config, runtime_shape):
    """
    Turn on Triton kernel tuning in `config` when compiling for one
    concrete size.

    `config` is modified in place. Nothing happens unless `runtime_shape` is
    an `int`, i.e. a specific batch size rather than a dynamic shape.
    """
    if not isinstance(runtime_shape, int):
        return
    # for a specific batchsize, tuning triton kernel parameters
    # can be beneficial
    for flag in ("max_autotune", "coordinate_descent_tuning"):
        config[flag] = True
class CUDAPiecewiseBackend:
    """
    Callable wrapper around one piece of a piecewise-compiled graph.

    Each call is dispatched on the runtime value of the first symbolic-shape
    argument. Values without a `ConcreteSizeEntry` run the general-shape
    compiled graph. Values with an entry are run eagerly once as warm-up,
    captured into a `torch.cuda.CUDAGraph` on the following call, and
    replayed on every call after that.
    """

    def __init__(
        self,
        graph: fx.GraphModule,
        compile_config: CompilationConfig,
        inductor_config: dict[str, Any],
        graph_pool: Any,
        piecewise_compile_index: int,
        total_piecewise_compiles: int,
        sym_shape_indices: list[int],
        compiled_graph_for_general_shape: Callable,
        sglang_backend,
    ):
        """
        The backend for piecewise compilation.
        It mainly handles the compilation and cudagraph capturing.

        We will compile `self.graph` once for the general shape,
        and then compile for different shapes specified in
        `compilation_config.compile_sizes`.

        Independently, we will capture cudagraph for different shapes.

        If a shape needs both compilation and cudagraph, we will
        compile it first, and then capture cudagraph.

        NOTE(review): `self.compile_sizes` is initialised to an empty set
        below and is not modified anywhere in this class, so as written only
        the cudagraph capture sizes produce entries -- confirm whether
        shape-specific compilation is meant to be enabled by a caller.

        :param graph: the FX graph of this piece.
        :param compile_config: supplies the capture sizes and debug flag.
        :param inductor_config: forwarded to the compiler manager for
            shape-specific compilation.
        :param graph_pool: memory pool handed to `torch.cuda.graph`.
        :param piecewise_compile_index: position of this piece.
        :param total_piecewise_compiles: total number of pieces.
        :param sym_shape_indices: indices into the call arguments that hold
            symbolic-shape values; only the first one is consulted.
        :param compiled_graph_for_general_shape: fallback callable.
        :param sglang_backend: owner exposing `compiler_manager`.
        """
        self.graph = graph
        self.inductor_config = inductor_config
        self.graph_pool = graph_pool
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles
        self.sglang_backend = sglang_backend

        # Position flags: the first piece is the only one allowed to run gc
        # during capture; the last piece may weak-ref its output right away.
        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1

        self.compile_sizes: set[int] = set([])
        self.compile_config = compile_config
        self.cudagraph_capture_sizes: set[int] = set(compile_config.get_capture_sizes())

        self.first_run_finished = False

        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape  # noqa

        self.sym_shape_indices = sym_shape_indices

        # the entries for different shapes that we need to either
        # compile or capture cudagraph
        self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}

        # to_be_compiled_sizes tracks the remaining sizes to compile,
        # and updates during the compilation process, so we need to copy it
        self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
        for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
            self.concrete_size_entries[shape] = ConcreteSizeEntry(
                runtime_shape=shape,
                need_to_compile=shape in self.compile_sizes,
                use_cudagraph=shape in self.cudagraph_capture_sizes,
            )

    def check_for_ending_compilation(self):
        """
        Ask the compiler manager to save its state once the last piece has
        no shape-specific compilations left. Does nothing otherwise.
        """
        if self.is_last_graph and not self.to_be_compiled_sizes:
            # no specific sizes to compile
            # save the hash of the inductor graph for the next run
            self.sglang_backend.compiler_manager.save_to_file()

    def __call__(self, *args) -> Any:
        """
        Run this piece for `args`.

        :return: the output of the general-shape graph, of the entry's
            runnable (warm-up, capture, or while inside the piecewise
            torch.compile phase), or the weak-referenced output stored at
            capture time when a captured graph is replayed.
        :raises AssertionError: if no capture stream is set at capture time,
            or, in debug mode, if tensor input addresses differ between
            capture and replay.
        """
        # The very first call always goes through the general-shape graph
        # and is the point where "nothing left to compile" is checked.
        if not self.first_run_finished:
            self.first_run_finished = True
            self.check_for_ending_compilation()
            return self.compiled_graph_for_general_shape(*args)

        # Without a symbolic shape there is nothing to specialise on.
        if len(self.sym_shape_indices) == 0:
            return self.compiled_graph_for_general_shape(*args)

        runtime_shape = args[self.sym_shape_indices[0]]
        if runtime_shape not in self.concrete_size_entries:
            # we don't need to do anything for this shape
            return self.compiled_graph_for_general_shape(*args)

        entry = self.concrete_size_entries[runtime_shape]

        if entry.runnable is None:
            entry.runnable = self.compiled_graph_for_general_shape

        if entry.need_to_compile and not entry.compiled:
            entry.compiled = True
            self.to_be_compiled_sizes.remove(runtime_shape)
            # args are real arguments
            entry.runnable = self.sglang_backend.compiler_manager.compile(
                self.graph,
                args,
                self.inductor_config,
                graph_index=self.piecewise_compile_index,
                num_graphs=self.total_piecewise_compiles,
                runtime_shape=runtime_shape,
            )

            # finished compilations for all required shapes
            if self.is_last_graph and not self.to_be_compiled_sizes:
                self.check_for_ending_compilation()

        # Inside the piecewise torch.compile phase: run directly, without
        # warm-up accounting, capture or replay.
        if is_in_pcg_torch_compile():
            return entry.runnable(*args)

        if entry.cudagraph is None:
            # One eager warm-up run per shape before capturing.
            if entry.num_finished_warmup < 1:  # noqa
                entry.num_finished_warmup += 1
                return entry.runnable(*args)

            if self.compile_config.get_enable_debug_mode():
                # Remember the capture-time tensor addresses for the replay
                # check further down.
                input_addresses = [
                    x.data_ptr() for x in args if isinstance(x, torch.Tensor)
                ]
                entry.input_addresses = input_addresses
            cudagraph = torch.cuda.CUDAGraph()

            with ExitStack() as stack:
                if not self.is_first_graph:
                    # during every model forward, we will capture
                    # many pieces of cudagraphs (roughly one per layer).
                    # running gc again and again across layers will
                    # make the cudagraph capture very slow.
                    # therefore, we only run gc for the first graph,
                    # and disable gc for the rest of the graphs.
                    stack.enter_context(patch("gc.collect", lambda: None))
                    stack.enter_context(patch("torch.cuda.empty_cache", lambda: None))
                # mind-exploding: carefully manage the reference and memory.
                stream = get_pcg_capture_stream()
                assert (
                    stream is not None
                ), "PCG capture stream is not set, please check if runtime recompilation happened"
                with torch.cuda.graph(cudagraph, pool=self.graph_pool, stream=stream):
                    # `output` is managed by pytorch's cudagraph pool
                    output = entry.runnable(*args)
                    if self.is_last_graph:
                        # by converting it to weak ref,
                        # the original `output` will immediately be released
                        # to save memory. It is only safe to do this for
                        # the last graph, because the output of the last graph
                        # will not be used by any other cuda graph.
                        output = weak_ref_tensors(output)

            # here we always use weak ref for the output
            # to save memory
            entry.output = weak_ref_tensors(output)
            entry.cudagraph = cudagraph

            compilation_counter.num_cudagraph_captured += 1

            # important: we need to return the output, rather than
            # the weak ref of the output, so that pytorch can correctly
            # manage the memory during cuda graph capture
            return output

        if self.compile_config.get_enable_debug_mode():
            # check if the input addresses are the same
            new_input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            assert new_input_addresses == entry.input_addresses, (
                "Input addresses for cudagraphs are different during replay."
                f" Expected {entry.input_addresses}, got {new_input_addresses}"
            )
        entry.cudagraph.replay()
        return entry.output
def __call__(self, graph: torch.fx.Graph):
    """
    Run the pass over `graph`.

    Dumps the graph before the pass, before cleanup and after the pass,
    counts the `auto_functionalized` call nodes, erases every node staged in
    `self.nodes_to_remove`, and logs both counts at debug level.

    The scan in this method only counts nodes; it stages nothing for
    removal itself.
    """
    self.begin()
    self.dump_graph(graph, "before_fix_functionalization")

    self.nodes_to_remove: list[torch.fx.Node] = []
    # Nodes that are not auto_functionalized calls are simply skipped.
    num_auto_functionalized = sum(
        1 for candidate in graph.nodes if is_func(candidate, auto_functionalized)
    )

    self.dump_graph(graph, "before_fix_functionalization_cleanup")

    # Erase everything that was staged, in one sweep at the end.
    staged = self.nodes_to_remove
    num_removed = len(staged)
    for stale in staged:
        graph.erase_node(stale)

    logger.debug(
        "De-functionalized %s nodes, removed %s nodes",
        num_auto_functionalized,
        num_removed,
    )
    self.dump_graph(graph, "after_fix_functionalization")
    self.end_and_log()
def getitem_users(self, node: torch.fx.Node) -> dict[int, torch.fx.Node]:
    """
    Map each extracted index to the `operator.getitem` user of `node` that
    extracts it.

    Users that are not `operator.getitem` calls are ignored. If several
    users extract the same index, the one iterated last is kept.
    """
    return {
        consumer.args[1]: consumer
        for consumer in node.users
        if is_func(consumer, operator.getitem)
    }
def find_specified_fn_maybe(
    nodes: Iterable[fx.Node], op: OpOverload
) -> Optional[fx.Node]:
    """
    Return the first node in `nodes` whose target equals `op`, or None when
    no node matches.
    """
    return next((candidate for candidate in nodes if candidate.target == op), None)
Iterable[fx.Node], op: OpOverload) -> Optional[fx.Node]: + for node in nodes: + if is_func(node, auto_functionalized) and node.args[0] == op: # noqa + return node + return None + + +# Returns the first auto_functionalized node with the given op +def find_auto_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node: + node = find_auto_fn_maybe(nodes, op) + assert node is not None, f"Could not find {op} in nodes {nodes}" + return node + + +# Returns the getitem node that extracts the idx-th element from node +# (if it exists) +def find_getitem_maybe(node: fx.Node, idx: int) -> Optional[fx.Node]: + for user in node.users: + if is_func(user, operator.getitem) and user.args[1] == idx: + return user + return None + + +# Returns the getitem node that extracts the idx-th element from node +def find_getitem(node: fx.Node, idx: int) -> fx.Node: + ret = find_getitem_maybe(node, idx) + assert ret is not None, f"Could not find getitem {idx} in node {node}" + return ret + + +# An auto-functionalization-aware utility for finding nodes with a specific op +def find_op_nodes(op: OpOverload, graph: fx.Graph) -> Iterator[fx.Node]: + if not op._schema.is_mutable: + yield from graph.find_nodes(op="call_function", target=op) + + for n in graph.find_nodes(op="call_function", target=auto_functionalized): + if n.args[0] == op: + yield n + + +# Asserts that the node only has one user and returns it +# Even if a node has only 1 user, it might share storage with another node, +# which might need to be taken into account. 
def get_only_user(node: fx.Node) -> fx.Node:
    """
    Return the single user of `node`, asserting that exactly one exists.

    Having one user does not rule out that `node` shares storage with
    another node; callers may need to account for that separately.
    """
    consumers = node.users
    assert len(consumers) == 1
    (sole_consumer,) = consumers
    return sole_consumer
class CallableInductorPass(InductorPass):
    """
    Adapter that turns a plain callable into an Inductor pass, supplying the
    UUID automatically.
    """

    def __init__(
        self, callable: Callable[[fx.Graph], None], uuid: Optional[Any] = None
    ):
        """
        :param callable: function applied to the graph when the pass runs.
        :param uuid: identifier to report from `uuid()`. When None, it is
            derived by passing `callable` to `hash_source`.
        """
        self.callable = callable
        if uuid is None:
            uuid = self.hash_source(callable)
        self._uuid = uuid

    def __call__(self, graph: torch.fx.Graph):
        """Apply the wrapped callable to `graph`; its result is discarded."""
        self.callable(graph)

    def uuid(self) -> Any:
        """Return the identifier chosen at construction time."""
        return self._uuid
int, + total_piecewise_compiles: int, + sym_shape_indices: list[int], + compiled_graph_for_general_shape: Callable, + sglang_backend, + ): + super().__init__( + graph, + compile_config, + inductor_config, + graph_pool, + piecewise_compile_index, + total_piecewise_compiles, + sym_shape_indices, + compiled_graph_for_general_shape, + sglang_backend, + ) + + def __call__(self, *args): + runtime_shape = args[self.sym_shape_indices[0]] + if runtime_shape not in self.concrete_size_entries: + # we don't need to do anything for this shape + return self.compiled_graph_for_general_shape(*args) + + entry = self.concrete_size_entries[runtime_shape] + + if entry.runnable is None: + entry.runnable = self.compiled_graph_for_general_shape + + if entry.cudagraph is None: + if entry.num_finished_warmup < 1: # noqa + entry.num_finished_warmup += 1 + return entry.runnable(*args) + + if self.compile_config.get_enable_debug_mode(): + input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + entry.input_addresses = input_addresses + npugraph = torch.npu.NPUGraph() + + with ExitStack() as stack: + if not self.is_first_graph: + # during every model forward, we will capture + # many pieces of cudagraphs (roughly one per layer). + # running gc again and again across layers will + # make the cudagraph capture very slow. + # therefore, we only run gc for the first graph, + # and disable gc for the rest of the graphs. + stack.enter_context(patch("gc.collect", lambda: None)) + stack.enter_context(patch("torch.npu.empty_cache", lambda: None)) + + # mind-exploding: carefully manage the reference and memory. + with torch.npu.graph(npugraph, pool=self.graph_pool): + # `output` is managed by pytorch's cudagraph pool + output = entry.runnable(*args) + if self.is_last_graph: + # by converting it to weak ref, + # the original `output` will immediately be released + # to save memory. 
It is only safe to do this for + # the last graph, because the output of the last graph + # will not be used by any other cuda graph. + output = weak_ref_tensors(output) + + # here we always use weak ref for the output + # to save memory + entry.output = weak_ref_tensors(output) + entry.cudagraph = npugraph + + compilation_counter.num_cudagraph_captured += 1 + + # important: we need to return the output, rather than + # the weak ref of the output, so that pytorch can correctly + # manage the memory during cuda graph capture + return output + + if self.compile_config.get_enable_debug_mode(): + # check if the input addresses are the same + new_input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + assert new_input_addresses == entry.input_addresses, ( + "Input addresses for cudagraphs are different during replay." + f" Expected {entry.input_addresses}, got {new_input_addresses}" + ) + entry.cudagraph.replay() + return entry.output diff --git a/sglang/python/sglang/srt/compilation/pass_manager.py b/sglang/python/sglang/srt/compilation/pass_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..9173976f1878b01035d069844989b50fbc442f8b --- /dev/null +++ b/sglang/python/sglang/srt/compilation/pass_manager.py @@ -0,0 +1,66 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/pass_manager.py + +import logging + +from torch import fx as fx + +from sglang.srt.compilation.fix_functionalization import FixFunctionalizationPass +from sglang.srt.compilation.inductor_pass import ( + CustomGraphPass, + InductorPass, + SGLangInductorPass, + get_pass_context, +) + +logger = logging.getLogger(__name__) + + +class PostGradPassManager(CustomGraphPass): + """ + The pass manager for post-grad passes. + It handles configuration, adding custom passes, and running passes. + It supports uuid for the Inductor code cache. That includes torch<2.6 + support using pickling (in .inductor_pass.CustomGraphPass). 
+ + The order of the post-grad post-passes is: + 1. passes (constructor parameter) + 2. default passes (NoopEliminationPass, FusionPass) + 3. config["post_grad_custom_post_pass"] (if it exists) + 4. fix_functionalization + This way, all passes operate on a functionalized graph. + """ + + def __init__(self): + self.passes: list[SGLangInductorPass] = [] + + def __call__(self, graph: fx.Graph): + shape = get_pass_context().runtime_shape + for pass_ in self.passes: + if pass_.is_applicable_for_shape(shape): + pass_(graph) + + # always run fix_functionalization last + self.fix_functionalization(graph) + + def configure( + self, + ): + self.pass_config = dict() + self.fix_functionalization = FixFunctionalizationPass() + + def add(self, pass_: InductorPass): + assert isinstance(pass_, InductorPass) + self.passes.append(pass_) + + def uuid(self): + """ + The PostGradPassManager is set as a custom pass in the Inductor and + affects compilation caching. Its uuid depends on the UUIDs of all + dependent passes and the pass config. See InductorPass for more info. 
+ """ + pass_manager_uuid = "fshdakhsa" + state = {"pass_config": pass_manager_uuid, "passes": []} + for pass_ in self.passes: + state["passes"].append(pass_.uuid()) + state["passes"].append(self.fix_functionalization.uuid()) + return InductorPass.hash_dict(state) diff --git a/sglang/python/sglang/srt/compilation/piecewise_context_manager.py b/sglang/python/sglang/srt/compilation/piecewise_context_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..20a08a9972b9e5f1e730aa5bea7d8621f72b816b --- /dev/null +++ b/sglang/python/sglang/srt/compilation/piecewise_context_manager.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +import logging +from contextlib import contextmanager +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, List, Optional + +import torch + +logger = logging.getLogger(__name__) + + +if TYPE_CHECKING: + from sglang.srt.model_executor.forward_batch_info import ForwardBatch + +_in_piecewise_cuda_graph = False +_in_pcg_torch_compile = False +_pcg_capture_stream = None + + +def is_in_piecewise_cuda_graph(): + return _in_piecewise_cuda_graph + + +def is_in_pcg_torch_compile(): + return _in_pcg_torch_compile + + +def get_pcg_capture_stream(): + return _pcg_capture_stream + + +@contextmanager +def enable_piecewise_cuda_graph_compile(): + global _in_pcg_torch_compile + _in_pcg_torch_compile = True + yield + _in_pcg_torch_compile = False + + +@contextmanager +def enable_piecewise_cuda_graph(): + global _in_piecewise_cuda_graph + _in_piecewise_cuda_graph = True + try: + yield + except Exception as e: + logger.error( + "Piecewise CUDA Graph failed with error: %s\n%s", + e, + PIECEWISE_CUDA_GRAPH_CAPTURE_FAILED_MSG, + ) + raise + finally: + _in_piecewise_cuda_graph = False + + +@contextmanager +def set_pcg_capture_stream(stream: torch.cuda.Stream): + global _pcg_capture_stream + _pcg_capture_stream = stream + yield + _pcg_capture_stream = None + + +@dataclass +class ForwardContext: + def __init__(self): 
+ self.forward_batch = None + self.attention_layers = None + self.quant_config = None + self.moe_layers = None + self.moe_fusions = None + + def set_forward_batch(self, forward_batch: ForwardBatch): + self.forward_batch = forward_batch + + def set_attention_layers(self, layers: List[Any]): + self.attention_layers = layers + + def set_quant_config(self, quant_config: Any): + self.quant_config = quant_config + + def set_moe_layers(self, layers: List[Any]): + self.moe_layers = layers + + def set_moe_fusions(self, fusions: List[Any]): + self.moe_fusions = fusions + + +_forward_context: Optional[ForwardContext] = None + + +def get_forward_context() -> Optional[ForwardContext]: + if _forward_context is None: + return None + return _forward_context + + +@contextmanager +def set_forward_context( + forward_batch: ForwardBatch, + attention_layers: List[Any], + quant_config: Any, + moe_layers: List[Any], + moe_fusions: List[Any], +): + global _forward_context + _forward_context = ForwardContext() + _forward_context.set_forward_batch(forward_batch) + _forward_context.set_attention_layers(attention_layers) + _forward_context.set_quant_config(quant_config) + _forward_context.set_moe_layers(moe_layers) + _forward_context.set_moe_fusions(moe_fusions) + try: + yield + finally: + _forward_context = None + + +PIECEWISE_CUDA_GRAPH_CAPTURE_FAILED_MSG = ( + "Piecewise CUDA Graph is enabled by default as an experimental feature.\n" + "To work around this error, add --disable-piecewise-cuda-graph to your launch command.\n" + "Please report this issue at https://github.com/sgl-project/sglang/issues/new/choose" +) diff --git a/sglang/python/sglang/srt/compilation/weak_ref_tensor.py b/sglang/python/sglang/srt/compilation/weak_ref_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..3564849c74075c45cc3fd5c43355fec2c603ab77 --- /dev/null +++ b/sglang/python/sglang/srt/compilation/weak_ref_tensor.py @@ -0,0 +1,28 @@ +from typing import Any, Union + +import torch + 
+from sglang.srt.utils.common import is_cuda, is_hip, is_npu + +if is_cuda() or is_hip(): + from sgl_kernel import weak_ref_tensor +elif is_npu(): + from torch_npu._C import _weak_ref_tensor as weak_ref_tensor +else: + raise NotImplementedError("weak_ref_tensor is implemented only for CUDA and NPU.") + + +def weak_ref_tensors( + tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]], +) -> Union[torch.Tensor, list[Any], tuple[Any], Any]: + """ + Convenience function to create weak references to tensors, + for single tensor, list of tensors or tuple of tensors. + """ + if isinstance(tensors, torch.Tensor): + return weak_ref_tensor(tensors) + if isinstance(tensors, list): + return [weak_ref_tensor(t) for t in tensors] + if isinstance(tensors, tuple): + return tuple(weak_ref_tensor(t) for t in tensors) + raise ValueError("Invalid type for tensors") diff --git a/sglang/python/sglang/srt/configs/__init__.py b/sglang/python/sglang/srt/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3a3b37f54c0cdba156bf2b36a4617dc2462bad6d --- /dev/null +++ b/sglang/python/sglang/srt/configs/__init__.py @@ -0,0 +1,64 @@ +from sglang.srt.configs.afmoe import AfmoeConfig +from sglang.srt.configs.bailing_hybrid import BailingHybridConfig +from sglang.srt.configs.chatglm import ChatGLMConfig +from sglang.srt.configs.dbrx import DbrxConfig +from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config +from sglang.srt.configs.dots_ocr import DotsOCRConfig +from sglang.srt.configs.dots_vlm import DotsVLMConfig +from sglang.srt.configs.exaone import ExaoneConfig +from sglang.srt.configs.falcon_h1 import FalconH1Config +from sglang.srt.configs.granitemoehybrid import GraniteMoeHybridConfig +from sglang.srt.configs.janus_pro import MultiModalityConfig +from sglang.srt.configs.jet_nemotron import JetNemotronConfig +from sglang.srt.configs.jet_vlm import JetVLMConfig +from sglang.srt.configs.kimi_k25 import KimiK25Config +from 
sglang.srt.configs.kimi_linear import KimiLinearConfig +from sglang.srt.configs.kimi_vl import KimiVLConfig +from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig +from sglang.srt.configs.lfm2 import Lfm2Config +from sglang.srt.configs.lfm2_moe import Lfm2MoeConfig +from sglang.srt.configs.longcat_flash import LongcatFlashConfig +from sglang.srt.configs.nano_nemotron_vl import NemotronH_Nano_VL_V2_Config +from sglang.srt.configs.nemotron_h import NemotronHConfig +from sglang.srt.configs.olmo3 import Olmo3Config +from sglang.srt.configs.qwen3_5 import Qwen3_5Config, Qwen3_5MoeConfig +from sglang.srt.configs.qwen3_next import Qwen3NextConfig +from sglang.srt.configs.step3_vl import ( + Step3TextConfig, + Step3VisionEncoderConfig, + Step3VLConfig, +) +from sglang.srt.configs.step3p5 import Step3p5Config + +__all__ = [ + "AfmoeConfig", + "BailingHybridConfig", + "ExaoneConfig", + "ChatGLMConfig", + "DbrxConfig", + "DeepseekVL2Config", + "LongcatFlashConfig", + "MultiModalityConfig", + "KimiVLConfig", + "MoonViTConfig", + "Step3VLConfig", + "Step3TextConfig", + "Step3VisionEncoderConfig", + "Olmo3Config", + "KimiLinearConfig", + "KimiK25Config", + "Qwen3NextConfig", + "Qwen3_5Config", + "Qwen3_5MoeConfig", + "DotsVLMConfig", + "DotsOCRConfig", + "FalconH1Config", + "GraniteMoeHybridConfig", + "Lfm2Config", + "Lfm2MoeConfig", + "NemotronHConfig", + "NemotronH_Nano_VL_V2_Config", + "JetNemotronConfig", + "JetVLMConfig", + "Step3p5Config", +] diff --git a/sglang/python/sglang/srt/configs/afmoe.py b/sglang/python/sglang/srt/configs/afmoe.py new file mode 100644 index 0000000000000000000000000000000000000000..eab477339bc04b9f72e46e23e3e19d43df4b0ba2 --- /dev/null +++ b/sglang/python/sglang/srt/configs/afmoe.py @@ -0,0 +1,102 @@ +from typing import List, Optional + +from transformers import PretrainedConfig + + +class AfmoeConfig(PretrainedConfig): + model_type = "afmoe" + + def __init__( + self, + vocab_size: int = 32000, + hidden_size: int = 4096, + 
intermediate_size: int = 11008, + moe_intermediate_size: int = 256, + num_hidden_layers: int = 32, + num_attention_heads: int = 32, + num_key_value_heads: Optional[int] = None, + head_dim: Optional[int] = None, + hidden_act: str = "silu", + max_position_embeddings: int = 131072, + initializer_range: float = 0.02, + rms_norm_eps: float = 1e-5, + use_cache: bool = True, + pad_token_id: Optional[int] = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + rope_theta: float = 10000.0, + rope_scaling: Optional[dict] = None, + attention_bias: bool = False, + attention_dropout: float = 0.0, + # MoE parameters + num_experts: Optional[int] = None, + num_experts_per_tok: Optional[int] = None, + num_shared_experts: int = 0, + num_dense_layers: int = 0, + # Routing parameters + score_func: str = "sigmoid", + route_norm: bool = True, + route_scale: float = 1.0, + n_group: int = 1, + topk_group: int = 1, + # Attention parameters + sliding_window: Optional[int] = None, + layer_types: Optional[List[str]] = None, + global_attn_every_n_layers: int = 4, + # muP scaling + mup_enabled: bool = False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.head_dim = ( + head_dim if head_dim is not None else hidden_size // num_attention_heads + ) + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + # MoE 
parameters + self.num_experts = num_experts + self.num_experts_per_tok = num_experts_per_tok + self.num_shared_experts = num_shared_experts + self.num_dense_layers = num_dense_layers + + # Routing parameters + self.score_func = score_func + self.route_norm = route_norm + self.route_scale = route_scale + self.n_group = n_group + self.topk_group = topk_group + + # Attention parameters + self.sliding_window = sliding_window + self.layer_types = layer_types + self.global_attn_every_n_layers = global_attn_every_n_layers + + # muP scaling + self.mup_enabled = mup_enabled + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/sglang/python/sglang/srt/configs/bailing_hybrid.py b/sglang/python/sglang/srt/configs/bailing_hybrid.py new file mode 100644 index 0000000000000000000000000000000000000000..40933d90a73cf2de2c16079028dc513ee27edf1e --- /dev/null +++ b/sglang/python/sglang/srt/configs/bailing_hybrid.py @@ -0,0 +1,188 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""BailingHybrid model configuration""" + +import enum + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape + +logger = logging.get_logger(__name__) + + +class HybridLayerType(enum.Enum): + full_attention = "attention" + linear_attention = "linear_attention" + + +class BailingHybridConfig(PretrainedConfig): + + model_type = "bailing_hybrid" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=157184, + hidden_size=2048, + intermediate_size=5120, + num_hidden_layers=20, + num_attention_heads=16, + num_key_value_heads=4, + hidden_act="silu", + use_qkv_bias=False, # bailing only + use_bias=False, # bailing only + rms_norm_eps=1e-06, + tie_word_embeddings=False, # PretrainedConfig key, here change default value. + embedding_dropout=0.0, + attention_dropout=0.0, + output_dropout=0.0, + initializer_range=0.02, + max_position_embeddings=32768, + rope_theta=600000.0, + use_cache=True, + max_window_layers=20, + rope_scaling=None, + pad_token_id=156892, + eos_token_id=156892, + num_experts=256, + num_shared_experts=1, + num_experts_per_tok=8, + n_group=8, + topk_group=4, + moe_intermediate_size=512, + first_k_dense_replace=1, + head_dim=128, + output_router_logits=False, + use_qk_norm=True, + num_nextn_predict_layers=0, + mtp_loss_scaling_factor=0, + moe_router_enable_expert_bias=True, + routed_scaling_factor=1.0, + layer_group_size=1, + group_norm_size=1, + linear_silu=False, + kv_lora_rank=512, + q_lora_rank=None, + qk_rope_head_dim=64, + v_head_dim=128, + qk_nope_head_dim=128, + rope_interleave=True, + **kwargs, + ): + self.num_hidden_layers = num_hidden_layers + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + 
self.use_qkv_bias = use_qkv_bias + self.use_bias = use_bias + self.rms_norm_eps = rms_norm_eps + self.embedding_dropout = embedding_dropout + self.attention_dropout = attention_dropout + self.output_dropout = output_dropout + self.num_nextn_predict_layers = num_nextn_predict_layers + self.mtp_loss_scaling_factor = mtp_loss_scaling_factor + self.initializer_range = initializer_range + self.max_position_embeddings = max_position_embeddings + self.rope_theta = rope_theta + self.use_cache = use_cache + self.max_window_layers = max_window_layers + self.head_dim = head_dim or self.hidden_size // self.num_attention_heads + self.rope_scaling = rope_scaling + self.use_qk_norm = use_qk_norm + self.moe_router_enable_expert_bias = moe_router_enable_expert_bias + self.routed_scaling_factor = routed_scaling_factor + + # MoE configs + self.num_experts = num_experts + self.num_shared_experts = num_shared_experts + self.num_experts_per_tok = num_experts_per_tok + self.n_group = n_group + self.topk_group = topk_group + self.moe_intermediate_size = moe_intermediate_size + self.first_k_dense_replace = first_k_dense_replace + self.output_router_logits = output_router_logits + + # Linear configs + self.layer_group_size = layer_group_size + self.group_norm_size = group_norm_size + self.linear_silu = linear_silu + self.num_linear_key_value_heads = num_attention_heads + # mla + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.rope_interleave = rope_interleave + self.for_nextn_model = False + super().__init__( + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + @property + def layers_block_type(self): + if self.for_nextn_model: + return [HybridLayerType.full_attention.value] + + layer_type_list = [] + + for l in 
range(self.num_hidden_layers): + if (l + 1) % self.layer_group_size == 0: + layer_type_list.append(HybridLayerType.full_attention.value) + else: + layer_type_list.append(HybridLayerType.linear_attention.value) + + return layer_type_list + + @property + def linear_layer_ids(self): + return [ + i + for i, type_value in enumerate(self.layers_block_type) + if type_value == HybridLayerType.linear_attention.value + ] + + @property + def full_attention_layer_ids(self): + return [ + i + for i, type_value in enumerate(self.layers_block_type) + if type_value == HybridLayerType.full_attention.value + ] + + @property + def mamba2_cache_params(self) -> Mamba2CacheParams: + from sglang.srt.layers.dp_attention import get_attention_tp_size + + shape = Mamba2StateShape.create( + tp_world_size=get_attention_tp_size(), + intermediate_size=0, + n_groups=0, + num_heads=self.num_linear_key_value_heads, + head_dim=self.head_dim, + state_size=self.head_dim, + conv_kernel=1, + ) + + return Mamba2CacheParams(shape=shape, layers=self.linear_layer_ids) diff --git a/sglang/python/sglang/srt/configs/chatglm.py b/sglang/python/sglang/srt/configs/chatglm.py new file mode 100644 index 0000000000000000000000000000000000000000..9370c218aab8405944c81b6ddf59f6d61d7c999f --- /dev/null +++ b/sglang/python/sglang/srt/configs/chatglm.py @@ -0,0 +1,78 @@ +# Adapted from +# https://github.com/THUDM/ChatGLM2-6B +# https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/chatglm.py + +# ChatGLM2 and ChatGLM3 share the same config. 
+# ChatGLM4 is officially supported by Huggingface +# transformers >= 4.46.0 is required +# https://huggingface.co/docs/transformers/en/model_doc/glm +from transformers import PretrainedConfig + + +class ChatGLMConfig(PretrainedConfig): + model_type = "chatglm" + attribute_map = { + "num_hidden_layers": "num_layers", + "n_head_kv": "multi_query_group_num", + } + + def __init__( + self, + num_layers=28, + padded_vocab_size=65024, + hidden_size=4096, + ffn_hidden_size=13696, + kv_channels=128, + num_attention_heads=32, + seq_length=2048, + hidden_dropout=0.0, + attention_dropout=0.0, + layernorm_epsilon=1e-5, + rmsnorm=True, + apply_residual_connection_post_layernorm=False, + post_layer_norm=True, + add_bias_linear=False, + add_qkv_bias=False, + interleaved_qkv=False, + bias_dropout_fusion=True, + multi_query_attention=False, + multi_query_group_num=1, + apply_query_key_layer_scaling=True, + attention_softmax_in_fp32=True, + fp32_residual_connection=False, + quantization_bit=0, + pre_seq_len=None, + prefix_projection=False, + **kwargs + ): + self.num_layers = num_layers + self.vocab_size = padded_vocab_size + self.padded_vocab_size = padded_vocab_size + self.hidden_size = hidden_size + self.ffn_hidden_size = ffn_hidden_size + self.kv_channels = kv_channels + self.num_attention_heads = num_attention_heads + self.seq_length = seq_length + # It is to be compatible with long lora. 
+ self.max_position_embeddings = seq_length + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.layernorm_epsilon = layernorm_epsilon + self.rmsnorm = rmsnorm + self.apply_residual_connection_post_layernorm = ( + apply_residual_connection_post_layernorm + ) + self.post_layer_norm = post_layer_norm + self.add_bias_linear = add_bias_linear + self.add_qkv_bias = add_qkv_bias + self.bias_dropout_fusion = bias_dropout_fusion + self.multi_query_attention = multi_query_attention + self.multi_query_group_num = multi_query_group_num + self.apply_query_key_layer_scaling = apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + self.fp32_residual_connection = fp32_residual_connection + self.quantization_bit = quantization_bit + self.pre_seq_len = pre_seq_len + self.prefix_projection = prefix_projection + self.interleaved_qkv = interleaved_qkv + super().__init__(**kwargs) diff --git a/sglang/python/sglang/srt/configs/dbrx.py b/sglang/python/sglang/srt/configs/dbrx.py new file mode 100644 index 0000000000000000000000000000000000000000..75ccbde944ea684689c656aaa75d3aeb1681c188 --- /dev/null +++ b/sglang/python/sglang/srt/configs/dbrx.py @@ -0,0 +1,279 @@ +# Adapted from +# https://huggingface.co/databricks/dbrx-base/blob/main/configuration_dbrx.py +# https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/dbrx.py +"""Dbrx configuration.""" + +from typing import Any, Optional + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {} # type: ignore + + +class DbrxAttentionConfig(PretrainedConfig): + """Configuration class for Dbrx Attention. + + [`DbrxAttention`] class. It is used to instantiate attention layers + according to the specified arguments, defining the layers architecture. 
+ + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + attn_pdrop (`float`, *optional*, defaults to 0.0): + The dropout probability for the attention layers. + clip_qkv (`float`, *optional*, defaults to None): + If not `None`, clip the queries, keys, and values in the attention layer to this value. + kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads. + rope_theta (float): The base frequency for rope. + """ + + def __init__( + self, + attn_pdrop: float = 0, + clip_qkv: Optional[float] = None, + kv_n_heads: int = 1, + rope_theta: float = 10000.0, + **kwargs: Any, + ): + super().__init__(**kwargs) + self.attn_pdrop = attn_pdrop + self.clip_qkv = clip_qkv + self.kv_n_heads = kv_n_heads + self.rope_theta = rope_theta + + for k in ["model_type"]: + if k in kwargs: + kwargs.pop(k) + if len(kwargs) != 0: + raise ValueError(f"Found unknown {kwargs=}") + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path: str, **kwargs: Any + ) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs + ) + + if config_dict.get("model_type") == "dbrx": + config_dict = config_dict["attn_config"] + + if ( + "model_type" in config_dict + and hasattr(cls, "model_type") + and config_dict["model_type"] != cls.model_type + ): + logger.warning( + "You are using a model of type %s to instantiate a model of " + "type %s. This is not supported for all configurations of " + "models and can yield errors.", + config_dict["model_type"], + cls.model_type, + ) + + return cls.from_dict(config_dict, **kwargs) + + +class DbrxFFNConfig(PretrainedConfig): + """Configuration class for Dbrx FFN. + + [`DbrxFFN`] class. 
It is used to instantiate feedforward layers according to + the specified arguments, defining the layers architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + ffn_act_fn (dict, optional): A dict specifying activation function for the FFN. + The dict should have a key 'name' with the value being the name of + the activation function along with any additional keyword arguments. + ffn_hidden_size (int, optional): The hidden size of the feedforward network. + moe_num_experts (int, optional): The number of experts in the mixture of experts layer. + moe_top_k (int, optional): The number of experts to use in the mixture of experts layer. + moe_jitter_eps (float, optional): The jitter epsilon for the mixture of experts layer. + moe_loss_weight (float, optional): The loss weight for the mixture of experts layer. + moe_normalize_expert_weights (float, optional): The normalization factor for the expert weights. + uniform_expert_assignment (bool, optional): Whether to use uniform expert assignment. + This should only be used for benchmarking purposes. 
class DbrxConfig(PretrainedConfig):
    """Configuration for a Dbrx model.

    Holds the arguments that define the architecture of a [`DbrxModel`].
    Inherits from [`PretrainedConfig`]; see that class for the options that
    are forwarded through ``**kwargs``.

    Args:
        d_model (`int`, *optional*, defaults to 2048):
            Dimensionality of the embeddings and hidden states.
        n_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer.
        n_layers (`int`, *optional*, defaults to 24):
            Number of hidden layers.
        max_seq_len (`int`, *optional*, defaults to 2048):
            The maximum sequence length of the model.
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the Dbrx model.
        resid_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout probability applied to the attention output before
            combining with residual.
        emb_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout probability for the embedding layer.
        attn_config (`DbrxAttentionConfig` or `dict`, *optional*):
            Attention configuration. ``None`` builds a default
            `DbrxAttentionConfig`; a `dict` is expanded into its keyword
            arguments; any other value is stored unchanged.
        ffn_config (`DbrxFFNConfig` or `dict`, *optional*):
            FFN configuration, handled the same way as ``attn_config`` but
            with `DbrxFFNConfig`.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values
            attentions (not used by all models).
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for
            initializing all weight matrices.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether or not the router logits should be returned by the model.
        router_aux_loss_coef (`float`, *optional*, defaults to 0.05):
            The aux loss factor for the total loss.

    Raises:
        ValueError: if a truthy ``tie_word_embeddings`` is passed in
            ``kwargs``.

    Example:
        ```python
        >>> from transformers import DbrxConfig, DbrxModel

        >>> # Initializing a Dbrx configuration
        >>> configuration = DbrxConfig()

        >>> # Initializing a model (with random weights) from the configuration
        >>> model = DbrxModel(configuration)

        >>> # Accessing the model configuration
        >>> configuration = model.config
        ```
    """

    model_type = "dbrx"
    # Generic attribute names mapped to the Dbrx-specific field that stores them.
    attribute_map = {
        "num_attention_heads": "n_heads",
        "hidden_size": "d_model",
        "num_hidden_layers": "n_layers",
        "max_position_embeddings": "max_seq_len",
    }

    def __init__(
        self,
        d_model: int = 2048,
        n_heads: int = 16,
        n_layers: int = 24,
        max_seq_len: int = 2048,
        vocab_size: int = 32000,
        resid_pdrop: float = 0.0,
        emb_pdrop: float = 0.0,
        attn_config: Optional[DbrxAttentionConfig] = None,
        ffn_config: Optional[DbrxFFNConfig] = None,
        use_cache: bool = True,
        initializer_range: float = 0.02,
        output_router_logits: bool = False,
        router_aux_loss_coef: float = 0.05,
        **kwargs: Any,
    ):
        def _as_config(value, config_cls):
            # None -> defaults, dict -> keyword arguments, anything else as-is.
            if value is None:
                return config_cls()
            if isinstance(value, dict):
                return config_cls(**value)
            return value

        self.attn_config = _as_config(attn_config, DbrxAttentionConfig)
        self.ffn_config = _as_config(ffn_config, DbrxFFNConfig)

        # Model dimensions.
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size

        # Regularisation, caching and initialisation.
        self.resid_pdrop = resid_pdrop
        self.emb_pdrop = emb_pdrop
        self.use_cache = use_cache
        self.initializer_range = initializer_range

        # Router outputs / auxiliary loss.
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef

        # Taken out of kwargs so it can be validated and then passed explicitly.
        tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
        if tie_word_embeddings:
            raise ValueError("tie_word_embeddings is not supported for Dbrx models.")

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
class DictOutput:
    """Mixin exposing an object's instance attributes through a dict-like API.

    Every operation works directly on the instance ``__dict__``, so reads,
    membership tests and writes by key are equivalent to attribute access.
    """

    def items(self):
        """Return a view of the ``(name, value)`` attribute pairs."""
        return vars(self).items()

    def keys(self):
        """Return a view of the attribute names."""
        return vars(self).keys()

    def __getitem__(self, item):
        # Raises KeyError (not AttributeError) when the attribute is missing.
        return vars(self)[item]

    def __contains__(self, key):
        return key in vars(self)

    def __setitem__(self, key, value):
        vars(self)[key] = value
def dynamic_preprocess(
    image, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=640, use_thumbnail=False
):
    """Resize ``image`` onto a tile grid and cut it into square tiles.

    Args:
        image: image object providing ``size``, ``resize`` and ``crop``.
        min_num: minimum number of tiles a candidate grid may contain.
        max_num: maximum number of tiles a candidate grid may contain.
        image_size: side length of every tile.
        use_thumbnail: when True and more than one tile was produced, append
            the whole image resized to ``image_size`` x ``image_size``.

    Returns:
        ``(tiles, grid)`` where ``tiles`` is the list of cropped tiles in
        row-major order (plus the optional thumbnail) and ``grid`` is the
        ``(columns, rows)`` pair chosen by ``find_closest_aspect_ratio``.
    """
    src_width, src_height = image.size
    aspect_ratio = src_width / src_height

    # Every (columns, rows) grid whose tile count lies in [min_num, max_num],
    # ordered by tile count.
    candidate_grids = set(
        (cols, rows)
        for side in range(min_num, max_num + 1)
        for cols in range(1, side + 1)
        for rows in range(1, side + 1)
        if min_num <= cols * rows <= max_num
    )
    candidate_grids = sorted(candidate_grids, key=lambda grid: grid[0] * grid[1])

    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, candidate_grids, src_width, src_height, image_size
    )

    # Pixel size of the resized image and the number of tiles it splits into.
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    resized_img = image.resize((target_width, target_height))
    tiles_per_row = target_width // image_size

    processed_images = []
    for index in range(blocks):
        row, col = divmod(index, tiles_per_row)
        box = (
            col * image_size,
            row * image_size,
            (col + 1) * image_size,
            (row + 1) * image_size,
        )
        processed_images.append(resized_img.crop(box))
    assert len(processed_images) == blocks

    if use_thumbnail and len(processed_images) != 1:
        processed_images.append(image.resize((image_size, image_size)))
    return processed_images, target_aspect_ratio
def format_messages_v2(self, messages: str, pil_images, max_req_input_len=-1):
    """Tokenize ``messages`` together with the images its placeholders refer to.

    Thin wrapper around ``tokenize_with_images`` that always requests BOS and
    EOS handling and enables cropping only when at most two images were
    supplied.

    Args:
        messages: prompt text; each occurrence of ``self.image_token`` stands
            for one image.
        pil_images: images matching the placeholders, in order. ``None`` is
            treated as "no images" so text-only prompts do not fail on
            slicing / ``len()``.
        max_req_input_len: accepted for interface compatibility; not used.

    Returns:
        A 6-tuple ``(input_ids, labels, images_list, images_seq_mask,
        images_spatial_crop, images_crop)``. ``labels`` is always an empty
        list. ``images_list`` and ``images_seq_mask`` are the corresponding
        values returned by ``tokenize_with_images`` expanded into Python lists
        by iterating over them; the remaining entries are passed through
        unchanged.
    """
    if pil_images is None:
        pil_images = []

    # Only as many images as there are placeholders are handed to the tokenizer.
    image_token_cnt = messages.count(self.image_token)
    (
        input_ids,
        images,
        images_crop,
        seq_mask,
        spatial_crop,
        _num_image_tokens,
        _image_shapes,
    ) = self.tokenize_with_images(
        messages,
        pil_images[:image_token_cnt],
        bos=True,
        eos=True,
        cropping=len(pil_images) <= 2,
    )

    masked_tokenized_data = []  # labels; never populated on this path
    return (
        input_ids,
        masked_tokenized_data,
        list(images),
        list(seq_mask),
        spatial_crop,
        images_crop,
    )
def encode(self, text: str, bos: bool = True, eos: bool = False):
    """Tokenize ``text`` without automatic special tokens.

    Args:
        text: the string to tokenize.
        bos: prepend ``self.bos_id`` when True.
        eos: append ``self.eos_id`` when True.

    Returns:
        The token ids produced by ``self.tokenizer.encode`` with the
        requested BOS / EOS ids attached.
    """
    token_ids = self.tokenizer.encode(text, add_special_tokens=False)

    if bos:
        token_ids = [self.bos_id, *token_ids]
    if eos:
        token_ids = [*token_ids, self.eos_id]

    return token_ids
def tokenize_with_images(
    self,
    conversation: str,
    images: List[Image.Image],
    bos: bool = True,
    eos: bool = True,
    cropping: bool = True,
):
    """Tokenize ``conversation`` and expand every image placeholder.

    The text is split on ``self.image_token``; each text piece is encoded and
    each image contributes a run of ``self.image_token_id`` tokens whose
    length depends on the base view and, when the image is tiled, on the tile
    grid.

    Args:
        conversation: prompt text containing one ``self.image_token`` per image.
        images: the images, in placeholder order.
        bos: prepend ``self.bos_id`` to the token sequence.
        eos: append ``self.eos_id`` before tensor conversion. The appended
            token is removed again before returning, so the returned ids never
            end with it. With ``eos=False`` nothing is appended or removed.
        cropping: allow tiling of images wider or taller than 640 pixels via
            ``dynamic_preprocess``.

    Returns:
        ``(input_ids, pixel_values, images_crop, images_seq_mask,
        images_spatial_crop, num_image_tokens, image_shapes)``:
        ``input_ids`` has a leading batch dimension of 1; ``images_seq_mask``
        is a 1-D bool tensor marking image-token positions;
        ``num_image_tokens`` and ``image_shapes`` are per-image Python lists.
        Without images, zero-filled placeholder tensors are returned for the
        three image outputs.
    """
    assert conversation.count(self.image_token) == len(images)
    text_splits = conversation.split(self.image_token)

    images_list = []
    images_crop_list = []
    images_seq_mask = []
    images_spatial_crop = []
    image_shapes = []
    num_image_tokens = []
    tokenized_str = []

    for text_sep, image in zip(text_splits, images):
        # Text that precedes this image.
        tokenized_sep = self.encode(text_sep, bos=False, eos=False)
        tokenized_str += tokenized_sep
        images_seq_mask += [False] * len(tokenized_sep)

        image_shapes.append(image.size)

        # Tiling applies only to images exceeding 640 px on some side, and
        # only when cropping is enabled; otherwise the grid is 1 x 1.
        if cropping and (image.size[0] > 640 or image.size[1] > 640):
            images_crop_raw, crop_ratio = dynamic_preprocess(
                image, image_size=IMAGE_SIZE
            )
        else:
            images_crop_raw, crop_ratio = [], [1, 1]

        # Global view: optionally resized, then padded to the base size with
        # the normalisation mean as fill colour.
        if self.image_size <= 640 and not cropping:
            image = image.resize((self.image_size, self.image_size))

        global_view = ImageOps.pad(
            image,
            (self.base_size, self.base_size),
            color=tuple(int(x * 255) for x in self.image_transform.mean),
        )
        images_list.append(self.image_transform(global_view))

        num_width_tiles, num_height_tiles = crop_ratio
        images_spatial_crop.append([num_width_tiles, num_height_tiles])

        has_local_tiles = num_width_tiles > 1 or num_height_tiles > 1
        if has_local_tiles:
            images_crop_list.extend(
                self.image_transform(tile) for tile in images_crop_raw
            )

        # Image tokens.
        num_queries = math.ceil(
            (self.image_size // self.patch_size) / self.downsample_ratio
        )
        num_queries_base = math.ceil(
            (self.base_size // self.patch_size) / self.downsample_ratio
        )

        if self.ocr2_mode:
            tokenized_image = []
            if has_local_tiles:
                tokenized_image += [self.image_token_id] * (
                    num_queries * num_width_tiles * num_queries * num_height_tiles
                )
            tokenized_image += [self.image_token_id] * (
                num_queries_base * num_queries_base
            )
            # One extra token for the view separator.
            tokenized_image += [self.image_token_id]
        else:
            # Each row carries one extra token, followed by one more token
            # after the base view.
            tokenized_image = (
                [self.image_token_id] * num_queries_base + [self.image_token_id]
            ) * num_queries_base
            tokenized_image += [self.image_token_id]
            if has_local_tiles:
                tokenized_image += (
                    [self.image_token_id] * (num_queries * num_width_tiles)
                    + [self.image_token_id]
                ) * (num_queries * num_height_tiles)

        tokenized_str += tokenized_image
        images_seq_mask += [True] * len(tokenized_image)
        num_image_tokens.append(len(tokenized_image))

    # Text after the last image (or the whole text when there are no images).
    tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False)
    tokenized_str += tokenized_sep
    images_seq_mask += [False] * len(tokenized_sep)

    if bos:
        tokenized_str = [self.bos_id] + tokenized_str
        images_seq_mask = [False] + images_seq_mask
    if eos:
        tokenized_str = tokenized_str + [self.eos_id]
        images_seq_mask = images_seq_mask + [False]

    assert len(tokenized_str) == len(
        images_seq_mask
    ), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"

    input_ids = torch.LongTensor(tokenized_str)
    images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)

    # Negative ids are replaced by the pad id.
    input_ids[input_ids < 0] = self.pad_id

    if eos:
        # Remove the EOS appended above. Doing this only when it was actually
        # appended keeps eos=False from failing the check or dropping a real
        # prompt token.
        assert input_ids[-1] == self.eos_id
        input_ids = input_ids[:-1]
        images_seq_mask = images_seq_mask[:-1]

    empty_crop = torch.zeros((1, 3, self.image_size, self.image_size)).unsqueeze(0)
    if len(images_list) == 0:
        pixel_values = torch.zeros((1, 3, self.base_size, self.base_size))
        images_spatial_crop = torch.zeros((1, 1), dtype=torch.long)
        images_crop = empty_crop
    else:
        pixel_values = torch.stack(images_list, dim=0)
        images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
        if images_crop_list:
            images_crop = torch.stack(images_crop_list, dim=0).unsqueeze(0)
        else:
            images_crop = empty_crop

    input_ids = input_ids.unsqueeze(0)
    return (
        input_ids,
        pixel_values,
        images_crop,
        images_seq_mask,
        images_spatial_crop,
        num_image_tokens,
        image_shapes,
    )
class MlpProjectorConfig(PretrainedConfig):
    """Configuration of the MLP projector.

    The constructor stores each of its arguments on the instance under the
    same name and forwards any remaining keyword arguments to
    ``PretrainedConfig``.
    """

    model_type = "mlp_projector"

    # Class-level defaults; all except ``token_pooling`` are overwritten per
    # instance by ``__init__``.
    projector_type: str = "downsample_mlp_gelu"
    input_dim: int = 1152
    n_embed: int = 2048
    depth: int = 2
    mlp_ratio: int = 1
    downsample_ratio: int = 2
    token_pooling: bool = False

    def __init__(
        self,
        projector_type: str = "downsample_mlp_gelu",
        input_dim: int = 1152,
        n_embed: int = 2048,
        depth: int = 2,
        mlp_ratio: int = 1,
        downsample_ratio: int = 2,
        **kwargs,
    ):
        # Projector kind and its input / output widths.
        self.projector_type, self.input_dim, self.n_embed = (
            projector_type,
            input_dim,
            n_embed,
        )
        # Layer count and the two ratio settings.
        self.depth, self.mlp_ratio, self.downsample_ratio = (
            depth,
            mlp_ratio,
            downsample_ratio,
        )

        super().__init__(**kwargs)
@register_customized_processor(processor_class=DeepseekOCRProcessor)
class DeepseekVLV2Config(PretrainedConfig):
    """Top-level configuration combining vision, projector and language configs.

    The three sub-configurations are built from the ``vision_config``,
    ``projector_config`` and ``language_config`` entries of ``kwargs`` (each
    defaulting to an empty mapping). The language configuration is stored as
    ``text_config``, and its ``vocab_size`` / ``hidden_size`` are mirrored on
    this object.
    """

    # Registered under the OCR model type rather than "deepseek_vl_v2".
    model_type = "deepseek-ocr"
    vision_config: VisionEncoderConfig
    projector_config: MlpProjectorConfig

    tile_tag: str = "2D"
    global_view_pos: str = "head"
    candidate_resolutions: tuple[tuple[int, int]] = ((384, 384),)
    customized_processor_type: type[Any] = DeepseekOCRProcessor

    def __init__(
        self,
        tile_tag: str = "tile_tag",
        global_view_pos: str = "head",
        candidate_resolutions: tuple[tuple[int, int]] = ((384, 384),),
        **kwargs,
    ):
        # NOTE(review): the constructor default for tile_tag ("tile_tag")
        # differs from the class-level default ("2D") -- confirm which is intended.
        super().__init__(**kwargs)

        # Sub-configurations, each expanded from its kwargs mapping.
        self.vision_config = VisionEncoderConfig(**kwargs.get("vision_config", {}))
        self.projector_config = MlpProjectorConfig(
            **kwargs.get("projector_config", {})
        )
        self.text_config = DeepseekV2Config(**kwargs.get("language_config", {}))

        self.tile_tag = tile_tag
        self.global_view_pos = global_view_pos
        self.candidate_resolutions = candidate_resolutions

        # Mirror the language model's sizes at the top level.
        self.vocab_size = self.text_config.vocab_size
        self.hidden_size = self.text_config.hidden_size
class DictOutput:
    """Base class that lets instances be used like a mapping of their attributes.

    ``obj[name]``, ``name in obj`` and ``obj[name] = value`` all operate on
    the instance ``__dict__``; ``keys()`` and ``items()`` return its views.
    """

    def items(self):
        attributes = vars(self)
        return attributes.items()

    def keys(self):
        attributes = vars(self)
        return attributes.keys()

    def __getitem__(self, item):
        # A missing name surfaces as KeyError, as with a plain dict.
        attributes = vars(self)
        return attributes[item]

    def __contains__(self, key):
        attributes = vars(self)
        return key in attributes

    def __setitem__(self, key, value):
        attributes = vars(self)
        attributes[key] = value
def __init__(
    self,
    tokenizer: LlamaTokenizerFast,
    candidate_resolutions: Tuple[Tuple[int, int]],
    patch_size: int,
    downsample_ratio: int,
    image_mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
    image_std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
    normalize: bool = True,
    image_token: str = "",
    pad_token: str = "<|▁pad▁|>",
    add_special_token: bool = False,
    sft_format: str = "deepseek",
    mask_prompt: bool = True,
    ignore_id: int = -100,
    **kwargs,
):
    """Store the image settings and prepare the tokenizer.

    The tokenizer is modified in place: its padding side is set to "left",
    a pad token is registered when it has none, the image token is added when
    it is not in the vocabulary, and the grounding and SFT special tokens are
    registered.

    Args:
        tokenizer: tokenizer to configure and keep.
        candidate_resolutions: candidate resolutions; the first value of the
            first entry becomes ``self.image_size``.
        patch_size: stored as ``self.patch_size``.
        downsample_ratio: stored as ``self.downsample_ratio``.
        image_mean: mean passed to ``ImageTransform``.
        image_std: std passed to ``ImageTransform``.
        normalize: whether ``ImageTransform`` normalises.
        image_token: text of the image placeholder token.
        pad_token: pad token registered when the tokenizer has none.
        add_special_token: stored as given.
        sft_format: stored as given.
        mask_prompt: stored as given.
        ignore_id: stored as given.
        **kwargs: forwarded to the parent constructor.
    """
    # NOTE(review): image_token defaults to an empty string here -- confirm
    # the intended placeholder text.

    # Image geometry and normalisation settings.
    self.candidate_resolutions = candidate_resolutions
    self.image_size = candidate_resolutions[0][0]
    self.patch_size = patch_size
    self.image_mean = image_mean
    self.image_std = image_std
    self.normalize = normalize
    self.downsample_ratio = downsample_ratio

    self.image_transform = ImageTransform(
        mean=image_mean, std=image_std, normalize=normalize
    )

    self.tokenizer = tokenizer
    # Must be set: the padding side makes a difference in batch inference.
    self.tokenizer.padding_side = "left"

    # Register the pad token so that 'tokenizer.pad_token' and
    # 'tokenizer.pad_token_id' are usable.
    if tokenizer.pad_token is None:
        self.tokenizer.add_special_tokens({"pad_token": pad_token})

    # Add the image token when the vocabulary does not know it yet.
    if self.tokenizer.vocab.get(image_token) is None:
        self.tokenizer.add_special_tokens(
            {"additional_special_tokens": [image_token]}
        )
    self.image_token_id = self.tokenizer.vocab.get(image_token)

    # Registered as two separate calls, in this order:
    # first the five grounding tokens, then the SFT role tokens.
    special_token_groups = (
        ["<|ref|>", "<|/ref|>", "<|det|>", "<|/det|>", "<|grounding|>"],
        ["<|User|>", "<|Assistant|>"],
    )
    for token_group in special_token_groups:
        self.tokenizer.add_special_tokens(
            {"additional_special_tokens": token_group}
        )

    self.image_token = image_token
    self.pad_token = pad_token
    self.add_special_token = add_special_token
    self.sft_format = sft_format
    self.mask_prompt = mask_prompt
    self.ignore_id = ignore_id

    super().__init__(tokenizer, **kwargs)
self.tokenizer.bos_token_id + + @property + def eos_id(self): + return self.tokenizer.eos_token_id + + @property + def pad_id(self): + return self.tokenizer.pad_token_id + + def encode(self, text: str, bos: bool = True, eos: bool = False): + t = self.tokenizer.encode(text, add_special_tokens=False) + + if bos: + t = [self.bos_id] + t + if eos: + t = t + [self.eos_id] + + return t + + def decode(self, t: List[int], **kwargs) -> str: + return self.tokenizer.decode(t, **kwargs) + + def process_one( + self, + prompt: str = None, + conversations: List[Dict[str, str]] = None, + images: List[Image.Image] = None, + apply_sft_format: bool = False, + inference_mode: bool = True, + system_prompt: str = "", + max_req_input_len: int = -1, + **kwargs, + ): + """ + + Args: + prompt (str): the formatted prompt; + conversations (List[Dict]): conversations with a list of messages; + images (List[ImageType]): the list of images; + apply_sft_format (bool): if prompt is not None, then apply the SFT format to prompt; + if conversations is not None, then it will always apply the SFT format to conversations; + inference_mode (bool): if True, then remove the last eos token; + system_prompt (str): the system prompt; + **kwargs: + + Returns: + outputs (BaseProcessorOutput): the output of the processor, + - input_ids (torch.LongTensor): [N + image tokens] + - target_ids (torch.LongTensor): [N + image tokens] + - images (torch.FloatTensor): [n_images, 3, H, W] + - image_id (int): the id of the image token + - num_image_tokens (List[int]): the number of image tokens + """ + + assert ( + prompt is None or conversations is None + ), "prompt and conversations cannot be used at the same time." 
+ + ( + tokenized_str, + masked_tokenized_str, + images_list, + images_seq_mask, + images_spatial_crop, + ) = self.format_messages_v2(conversations, images, max_req_input_len) + + assert ( + len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str) + ), ( + f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, " + f"imags_seq_mask's length {len(images_seq_mask)}, are not equal" + ) + + input_ids = torch.LongTensor(tokenized_str) + target_ids = torch.LongTensor(masked_tokenized_str) + images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool) + + # set input_ids < 0 | input_ids == self.image_token_id as ignore_id + target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = ( + self.ignore_id + ) + input_ids[input_ids < 0] = self.pad_id + + if inference_mode: + assert input_ids[-1] == self.eos_id + input_ids = input_ids[:-1] + target_ids = target_ids[:-1] + images_seq_mask = images_seq_mask[:-1] + + if len(images_list) == 0: + images = torch.zeros((1, 3, self.image_size, self.image_size)) + images_spatial_crop = torch.zeros((1, 2), dtype=torch.long) + else: + images = torch.stack(images_list, dim=0) + images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long) + + images_spatial_crop = torch.stack( + [images_spatial_crop], dim=0 + ) # stack the tensor to make it a batch of 1 + + prepare = VLChatProcessorOutput( + input_ids=input_ids, + target_ids=target_ids, + pixel_values=images, + images_seq_mask=images_seq_mask, + images_spatial_crop=images_spatial_crop, + ) + + return prepare + + def __call__( + self, + *, + prompt: str = None, + conversations: List[Dict[str, str]] = None, + images: List[Image.Image] = None, + apply_sft_format: bool = False, + inference_mode: bool = True, + system_prompt: str = "", + max_req_input_len: int = -1, + **kwargs, + ): + prepare = self.process_one( + prompt=prompt, + conversations=conversations, + images=images, + apply_sft_format=apply_sft_format, + 
inference_mode=inference_mode, + system_prompt=system_prompt, + max_req_input_len=max_req_input_len, + ) + + return prepare + + def find_all_indices(self, messages, target_value): + indices = [] + for index, item in enumerate(messages): + if item == target_value: + indices.append(index) + return indices + + def tokenize_with_images( + self, + conversation: str, + images: List[Image.Image], + bos: bool = True, + eos: bool = True, + cropping: bool = True, + max_req_input_len: int = -1, + ): + """Tokenize text with tags.""" + images_list, images_seq_mask, images_spatial_crop = [], [], [] + text_splits = conversation.split(self.image_token) + tokenized_str = [] + for text_sep, image in zip(text_splits, images): + """encode text_sep""" + tokenized_sep = self.encode(text_sep, bos=False, eos=False) + tokenized_str += tokenized_sep + images_seq_mask += [False] * len(tokenized_sep) + + """select best resolution for anyres""" + if cropping: + best_width, best_height = select_best_resolution( + image.size, self.candidate_resolutions + ) + else: + best_width, best_height = self.image_size, self.image_size + # print(image.size, (best_width, best_height)) # check the select_best_resolutions func + + """process the global view""" + global_view = ImageOps.pad( + image, + (self.image_size, self.image_size), + color=tuple(int(x * 255) for x in self.image_transform.mean), + ) + images_list.append(self.image_transform(global_view)) + + """process the local views""" + local_view = ImageOps.pad( + image, + (best_width, best_height), + color=tuple(int(x * 255) for x in self.image_transform.mean), + ) + for i in range(0, best_height, self.image_size): + for j in range(0, best_width, self.image_size): + images_list.append( + self.image_transform( + local_view.crop( + (j, i, j + self.image_size, i + self.image_size) + ) + ) + ) + + """record height / width crop num""" + num_width_tiles, num_height_tiles = ( + best_width // self.image_size, + best_height // self.image_size, + ) + 
images_spatial_crop.append([num_width_tiles, num_height_tiles]) + + """add image tokens""" + h = w = math.ceil( + (self.image_size // self.patch_size) / self.downsample_ratio + ) + # global views tokens h * (w + 1), 1 is for line separator + tokenized_image = [self.image_token_id] * h * (w + 1) + # add a separator between global and local views + tokenized_image += [self.image_token_id] + # local views tokens, (num_height_tiles * h) * (num_width_tiles * w + 1) + tokenized_image += ( + [self.image_token_id] + * (num_height_tiles * h) + * (num_width_tiles * w + 1) + ) + + tokenized_str += tokenized_image + images_seq_mask += [True] * len(tokenized_image) + # print(width_crop_num, height_crop_num, len(tokenized_image)) # test the correctness of the number of image-related tokens + + """process the last text split""" + tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False) + # deal with video, limit with request len + if max_req_input_len > -1: + if max_req_input_len < len(tokenized_sep) + len(tokenized_str) - 1: + rest = max_req_input_len - len(tokenized_sep) - 1 - 1024 + tokenized_str = tokenized_str[:rest] + images_seq_mask = images_seq_mask[:rest] + tokenized_str += tokenized_sep + images_seq_mask += [False] * len(tokenized_sep) + + """add the bos and eos tokens""" + if bos: + tokenized_str = [self.bos_id] + tokenized_str + images_seq_mask = [False] + images_seq_mask + if eos: + tokenized_str = tokenized_str + [self.eos_id] + images_seq_mask = images_seq_mask + [False] + + assert len(tokenized_str) == len( + images_seq_mask + ), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}" + + return tokenized_str, images_list, images_seq_mask, images_spatial_crop + + +class DeepseekVL2VisionEncoderConfig(PretrainedConfig): + model_type: str = "vision" + + model_name: str = "siglip_large_patch16_384" + image_size: int = 384 + patch_size: int = 16 + width: int = 1024 + layers: int 
class DeepseekVL2MlpProjectorConfig(PretrainedConfig):
    """Settings for the MLP projector of DeepSeek-VL2.

    The class-level values are the defaults; ``__init__`` overrides the six
    constructor-exposed fields per instance. ``token_pooling`` is only
    declared at class level.
    """

    model_type = "mlp_projector"
    projector_type: str = "downsample_mlp_gelu"
    input_dim: int = 1152
    n_embed: int = 2048
    depth: int = 2
    mlp_ratio: int = 1
    downsample_ratio: int = 2
    token_pooling: bool = False

    def __init__(
        self,
        projector_type: str = "downsample_mlp_gelu",
        input_dim: int = 1152,
        n_embed: int = 2048,
        depth: int = 2,
        mlp_ratio: int = 1,
        downsample_ratio: int = 2,
        **kwargs,
    ):
        """Store the projector fields, then let the base class take ``kwargs``."""
        own_fields = {
            "projector_type": projector_type,
            "input_dim": input_dim,
            "n_embed": n_embed,
            "depth": depth,
            "mlp_ratio": mlp_ratio,
            "downsample_ratio": downsample_ratio,
        }
        for field_name, field_value in own_fields.items():
            setattr(self, field_name, field_value)

        super().__init__(**kwargs)
num_hidden_layers=30, + num_attention_heads=32, + num_key_value_heads=32, + n_shared_experts=None, + n_routed_experts=None, + ep_size=1, + routed_scaling_factor=1.0, + kv_lora_rank=512, + q_lora_rank=1536, + qk_rope_head_dim=64, + v_head_dim=128, + qk_nope_head_dim=128, + topk_method="gready", + n_group=None, + topk_group=None, + num_experts_per_tok=None, + moe_layer_freq=1, + first_k_dense_replace=0, + norm_topk_prob=False, + scoring_func="softmax", + aux_loss_alpha=0.001, + seq_aux=True, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=100000, + eos_token_id=100001, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + use_mla=True, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.ep_size = ep_size + self.routed_scaling_factor = routed_scaling_factor + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.topk_method = topk_method + self.n_group = n_group + self.topk_group = topk_group + self.num_experts_per_tok = num_experts_per_tok + self.moe_layer_freq = moe_layer_freq + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.scoring_func = scoring_func + self.aux_loss_alpha = aux_loss_alpha + self.seq_aux = seq_aux + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + 
class DeepseekVL2Config(PretrainedConfig):
    """Top-level DeepSeek-VL2 configuration.

    Bundles the vision encoder, MLP projector and language model
    sub-configurations with the image tiling options.
    """

    model_type = "deepseek_vl_v2"
    vision_config: DeepseekVL2VisionEncoderConfig
    projector_config: DeepseekVL2MlpProjectorConfig
    language_config: DeepseekV2Config

    tile_tag: str = "2D"
    global_view_pos: str = "head"
    candidate_resolutions: Tuple[Tuple[int, int]] = ((384, 384),)

    def __init__(
        self,
        tile_tag: str = "tile_tag",
        global_view_pos: str = "head",
        candidate_resolutions: Tuple[Tuple[int, int]] = ((384, 384),),
        **kwargs,
    ):
        """Build the three sub-configs from ``kwargs`` and store the tile options.

        Args:
            tile_tag: tiling mode tag. NOTE(review): the constructor default
                ("tile_tag") differs from the class-level default ("2D"); it
                is kept as-is for compatibility -- confirm which is intended.
            global_view_pos: stored as given.
            candidate_resolutions: stored as given.
            **kwargs: forwarded to ``PretrainedConfig``. The keys
                ``vision_config``, ``projector_config`` and
                ``language_config`` may each be a mapping of constructor
                arguments, an already-built config instance, or
                absent/``None`` (defaults are used).
        """
        super().__init__(**kwargs)

        self.vision_config = self._as_config(
            kwargs.get("vision_config"), DeepseekVL2VisionEncoderConfig
        )
        self.projector_config = self._as_config(
            kwargs.get("projector_config"), DeepseekVL2MlpProjectorConfig
        )
        self.language_config = self._as_config(
            kwargs.get("language_config"), DeepseekV2Config
        )

        self.tile_tag = tile_tag
        self.global_view_pos = global_view_pos
        self.candidate_resolutions = candidate_resolutions
        self.architectures = ["DeepseekVL2ForCausalLM"]

    @staticmethod
    def _as_config(value, config_cls):
        """Return ``value`` if it already is a ``config_cls``, else build one from it."""
        if isinstance(value, config_cls):
            return value
        # None or a missing key falls back to the sub-config's own defaults.
        return config_cls(**(value or {}))
class DeviceConfig:
    """Validated device selection: a device type string, its ``torch.device``, and a GPU id."""

    device: Optional[torch.device]
    gpu_id: Optional[int]

    def __init__(self, device: str = "cuda", gpu_id: int = -1) -> None:
        """Record the device type and GPU id.

        Args:
            device: one of "cuda", "xpu", "hpu", "cpu", "npu", "musa".
            gpu_id: stored unchanged.

        Raises:
            RuntimeError: if ``device`` is not one of the accepted types.
        """
        accepted_types = ("cuda", "xpu", "hpu", "cpu", "npu", "musa")
        if device not in accepted_types:
            raise RuntimeError(f"Not supported device type: {device}")

        self.device_type = device
        self.device = torch.device(self.device_type)
        self.gpu_id = gpu_id
tokenizer=None, + video_processor=None, + chat_template=None, + **kwargs + ): + if video_processor is None: + video_processor = DummyVideoProcessor() + super().__init__( + image_processor, tokenizer, video_processor, chat_template=chat_template + ) + self.image_token = ( + "<|imgpad|>" + if not hasattr(tokenizer, "image_token") + else tokenizer.image_token + ) + self.image_token_id = ( + tokenizer.image_token_id + if getattr(tokenizer, "image_token_id", None) is not None + else tokenizer.convert_tokens_to_ids(self.image_token) + ) + + +AutoProcessor.register(DotsOCRConfig, DotsVLProcessor) diff --git a/sglang/python/sglang/srt/configs/dots_vlm.py b/sglang/python/sglang/srt/configs/dots_vlm.py new file mode 100644 index 0000000000000000000000000000000000000000..dc921582ccf8f6ece01902780aa5ed733a3cd9a8 --- /dev/null +++ b/sglang/python/sglang/srt/configs/dots_vlm.py @@ -0,0 +1,134 @@ +from transformers import AutoProcessor, PretrainedConfig +from transformers.processing_utils import ProcessingKwargs + +try: + from transformers import Qwen2_5_VLProcessor +except ImportError: + raise ImportError( + "Qwen2_5_VLProcessor can not be found. Please upgrade your transformers version." 
class DotsVisionConfig(PretrainedConfig):
    """Hyper-parameters of the Dots vision transformer ("dots_vit")."""

    model_type: str = "dots_vit"

    def __init__(
        self,
        embed_dim: int = 1536,  # vision encoder embed size
        hidden_size: int = 1536,  # after merger hidden size
        intermediate_size: int = 4224,
        num_hidden_layers: int = 42,
        num_attention_heads: int = 12,
        num_channels: int = 3,
        patch_size: int = 14,
        spatial_merge_size: int = 2,
        temporal_patch_size: int = 1,
        rms_norm_eps: float = 1e-5,
        use_bias: bool = False,
        attn_implementation="flash_attention_2",  # "eager","sdpa","flash_attention_2"
        initializer_range=0.02,
        init_merger_std=0.02,
        is_causal=False,  # ve causal forward
        post_norm=True,
        gradient_checkpointing=False,
        **kwargs,
    ):
        """Initialise the base config from ``kwargs``, then store every vision field."""
        super().__init__(**kwargs)

        vision_fields = {
            "embed_dim": embed_dim,
            "hidden_size": hidden_size,
            "intermediate_size": intermediate_size,
            "num_hidden_layers": num_hidden_layers,
            "num_attention_heads": num_attention_heads,
            "num_channels": num_channels,
            "patch_size": patch_size,
            "spatial_merge_size": spatial_merge_size,
            "temporal_patch_size": temporal_patch_size,
            "rms_norm_eps": rms_norm_eps,
            "use_bias": use_bias,
            "attn_implementation": attn_implementation,
            "initializer_range": initializer_range,
            "init_merger_std": init_merger_std,
            "is_causal": is_causal,
            "post_norm": post_norm,
            "gradient_checkpointing": gradient_checkpointing,
        }
        for field_name, field_value in vision_fields.items():
            setattr(self, field_name, field_value)
class DotsVLMProcessor(Qwen2_5_VLProcessor):
    r"""
    Constructs a DotsVLM processor which derives from Qwen2_5_VLProcessor, but overrides the image and video
    tokens and their ids. Its tokenizer class is declared as LlamaTokenizer / LlamaTokenizerFast.

    Args:
        image_processor (*optional*):
            Image processor forwarded to the parent processor.
        tokenizer (*optional*):
            Tokenizer forwarded to the parent processor. Token strings and ids are read from its attributes when
            present; otherwise built-in defaults are used and the ids are looked up by encoding the token.
        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
            in a chat into a tokenizable string.
    """

    attributes = ["image_processor", "tokenizer"]

    valid_kwargs = ["chat_template"]

    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")

    def __init__(
        self, image_processor=None, tokenizer=None, chat_template=None, **kwargs
    ):
        """Initialise the parent processor and resolve the Dots-specific tokens.

        Extra ``**kwargs`` are accepted but not forwarded.
        """
        super().__init__(image_processor, tokenizer, chat_template=chat_template)

        # Prefer the tokenizer's own token strings; fall back to the defaults.
        self.image_token = getattr(tokenizer, "image_token", "<|imgpad|>")
        self.video_token = getattr(tokenizer, "video_token", "<|video_pad|>")
        self.img_token = getattr(tokenizer, "img_token", "<|img|>")
        self.endofimg_token = getattr(tokenizer, "endofimg_token", "<|endofimg|>")

        self.image_token_id = self._resolve_token_id(
            tokenizer, "image_token_id", self.image_token
        )
        self.video_token_id = self._resolve_token_id(
            tokenizer, "video_token_id", self.video_token
        )

    @staticmethod
    def _resolve_token_id(tokenizer, attr_name, token):
        """Return ``tokenizer.<attr_name>`` if set, else the id obtained by encoding ``token``."""
        token_id = getattr(tokenizer, attr_name, None)
        # Compare against None explicitly: 0 is a valid token id.
        if token_id is not None:
            return token_id
        # Disable automatic special tokens so a prepended BOS id is never
        # mistaken for the token's own id.
        return tokenizer.encode(token, add_special_tokens=False)[0]
class ExaoneConfig(PretrainedConfig):
    r"""
    Configuration for an EXAONE model. It holds the arguments that define the model architecture and inherits
    from :class:`~transformers.PretrainedConfig`; see that class for the options it handles.

    Args:
        vocab_size (:obj:`int`, `optional`, defaults to 102400):
            Size of the vocabulary, i.e. the number of distinct token ids accepted as `inputs_ids`.
        max_position_embeddings (:obj:`int`, `optional`, defaults to 2048):
            Maximum sequence length the model may be used with.
        hidden_size (:obj:`int`, `optional`, defaults to 2048):
            Dimensionality of the hidden representations.
        num_layers (:obj:`int`, `optional`, defaults to 32):
            Number of hidden layers. Also exposed as `num_hidden_layers` through `attribute_map`.
        num_attention_heads (:obj:`int`, `optional`, defaults to 32):
            Number of attention heads per attention layer.
        num_key_value_heads (:obj:`int`, `optional`):
            Number of key/value heads for Grouped Query Attention. Equal to `num_attention_heads` gives Multi
            Head Attention, `1` gives Multi Query Attention, anything else GQA (see
            https://arxiv.org/pdf/2305.13245.pdf). When `None`, `num_attention_heads` is used.
        intermediate_size (:obj:`int`, `optional`, defaults to `hidden_size * 4`):
            Dimensionality of the feed-forward layer. A falsy value selects `hidden_size * 4`.
        activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"silu"`):
            Non-linear activation function.
        rope_theta (:obj:`float`, `optional`, defaults to 10000.0):
            Base period of the RoPE embeddings.
        rope_scaling (:obj:`Dict`, `optional`):
            Scaling configuration for the RoPE embeddings; stored as given. Keys commonly used:
            `rope_type` (one of 'default', 'linear', 'dynamic', 'yarn', 'longrope', 'llama3'), `factor`,
            `original_max_position_embeddings`, `attention_factor`, `beta_fast`, `beta_slow`, `short_factor`,
            `long_factor`, `low_freq_factor`, `high_freq_factor`.
        embed_dropout (:obj:`float`, `optional`, defaults to 0.0):
            Dropout probability for the embeddings.
        attention_dropout (:obj:`float`, `optional`, defaults to 0.0):
            Dropout ratio for the attention probabilities.
        layer_norm_epsilon (:obj:`float`, `optional`, defaults to 1e-5):
            Epsilon used by the layer normalization layers.
        initializer_range (:obj:`float`, `optional`, defaults to 0.02):
            Standard deviation of the truncated normal initializer for weight matrices.
        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether the model should return the last key/value attentions.
        bos_token_id (:obj:`int`, `optional`, defaults to 0):
            Beginning-of-stream token id.
        eos_token_id (:obj:`int`, `optional`, defaults to 2):
            End-of-stream token id.
        tie_word_embeddings (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether to tie the input and output word embeddings.

    Example::

        >>> configuration = ExaoneConfig()
        >>> configuration.intermediate_size
        8192
    """

    model_type = "exaone"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {"num_hidden_layers": "num_layers"}

    def __init__(
        self,
        vocab_size=102400,
        max_position_embeddings=2048,
        hidden_size=2048,
        num_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        intermediate_size=None,
        activation_function="silu",
        rope_theta=10000.0,
        rope_scaling=None,
        embed_dropout=0.0,
        attention_dropout=0.0,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        use_cache=True,
        bos_token_id=0,
        eos_token_id=2,
        tie_word_embeddings=True,
        **kwargs
    ):
        """Store the architecture fields and hand the token ids and ``kwargs`` to the base class."""
        # Model dimensions.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_layers

        # Unspecified KV heads means one per attention head.
        self.num_key_value_heads = (
            num_attention_heads if num_key_value_heads is None else num_key_value_heads
        )
        # A falsy intermediate size selects four times the hidden size.
        self.intermediate_size = (
            intermediate_size if intermediate_size else hidden_size * 4
        )

        # Activation, regularisation and initialisation.
        self.activation_function = activation_function
        self.embed_dropout = embed_dropout
        self.attention_dropout = attention_dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.use_cache = use_cache

        # Rotary position embedding settings.
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs
        )
+ The checkpoints are jointly trained by IBM, Princeton, and UIUC. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 128000): + Vocabulary size of the FalconH1 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`FalconH1Model`] + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the + model has a output word embedding layer. + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 14336): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. 
+ initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + num_logits_to_keep (`int` or `None`, *optional*, defaults to 1): + Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an + integer value, only last `num_logits_to_keep` logits will be calculated. Default is 1 because only the + logits of the last prompt token are needed for generation. For long sequences, the logits for the entire + sequence may use a lot of memory so, setting `num_logits_to_keep=1` will reduce memory footprint + significantly. + pad_token_id (`int`, *optional*, defaults to 0): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + max_position_embeddings (`int`, *optional*, defaults to 8192): + Max cached sequence length for the model + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mamba_d_ssm (`int`, *optional*, defaults to 1024): + The dimension of the SSM state space latents. + mamba_n_heads (`int`, *optional*, defaults to 128): + The number of mamba heads used in the v2 implementation. + mamba_d_head (`int`, *optional*, defaults to `"auto"`): + Head embedding dimension size + mamba_n_groups (`int`, *optional*, defaults to 1): + The number of the mamba groups used in the v2 implementation. 
+ mamba_d_state (`int`, *optional*, defaults to 256): + The dimension the mamba state space latents + mamba_d_conv (`int`, *optional*, defaults to 4): + The size of the mamba convolution kernel + mamba_expand (`int`, *optional*, defaults to 2): + Expanding factor (relative to hidden_size) used to determine the mamba intermediate size + mamba_chunk_size (`int`, *optional*, defaults to 256): + The chunks in which to break the sequence when doing prefill/training + mamba_conv_bias (`bool`, *optional*, defaults to `True`): + Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block. + mamba_proj_bias (`bool`, *optional*, defaults to `False`): + Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the mamba mixer block + mamba_norm_before_gate (`bool`, *optional*, defaults to `True`): + Whether to use RMSNorm before the gate in the Mamba block + mamba_rms_norm (`bool`, *optional*, defaults to `False`): + Whether to use RMSNorm instead of LayerNorm in the Mamba block + projectors_bias (`bool`, *optional*, defaults to `False`): + Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the attention block + rope_theta (`float`, *optional*, defaults to 100000.0): + The theta value used for the RoPE embeddings. + rope_scaling (`float`, *optional*): + The scaling value used for the RoPE embeddings. If `None`, no scaling is applied. + lm_head_multiplier (`float`, *optional*, defaults to 1.0): + The multiplier for the LM head. This is used to scale the output of the LM head. + embedding_multiplier (`float`, *optional*, defaults to 1.0): + The multiplier for the embedding layer. This is used to scale the output of the embedding layer. + mlp_multipliers (`list[float]`, *optional*): + The multipliers for the MLP layers. This is used to scale the output of the MLP layers. 
The first value is + the multiplier of gate layer, the second value is the multiplier of the down_proj layer. + key_multiplier (`float`, *optional*): + The multiplier for the key layer. This is used to scale the output of the key layer. + attention_out_multiplier (`float`, *optional*): + The multiplier for the attention output layer. This is used to scale the output of the attention output + attention_in_multiplier (`float`, *optional*): + The multiplier for the attention input layer. This is used to scale the output of the attention input layer. + ssm_multipliers (`list[float]`, *optional*): + The multipliers for the SSM layers. This is used to scale the output of the SSM layers. + ssm_in_multiplier (`float`, *optional*): + The multiplier for the SSM input layer. This is used to scale the output of the SSM input layer. + ssm_out_multiplier (`float`, *optional*): + The multiplier for the SSM output layer. This is used to scale the output of the SSM output layer. + """ + + model_type = "falcon_h1" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=128000, + tie_word_embeddings=False, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + num_logits_to_keep=1, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + max_position_embeddings=8192, + attention_dropout=0.0, + mamba_d_ssm=1024, + mamba_n_heads=128, + mamba_d_head="auto", + mamba_n_groups=1, + mamba_d_state=256, + mamba_d_conv=4, + mamba_expand=2, + mamba_chunk_size=256, + mamba_conv_bias=True, + mamba_proj_bias=False, + mamba_norm_before_gate=True, + mamba_rms_norm=False, + projectors_bias=False, + rope_theta=100000.0, + rope_scaling=None, + lm_head_multiplier=1.0, + embedding_multiplier=1.0, + mlp_multipliers=None, + key_multiplier=None, + attention_out_multiplier=None, + attention_in_multiplier=None, + 
ssm_multipliers=None, + ssm_in_multiplier=None, + ssm_out_multiplier=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.attention_dropout = attention_dropout + self.attention_bias = False + self.mlp_bias = False + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + + self.use_cache = use_cache + self.num_logits_to_keep = num_logits_to_keep + + self.rope_theta = rope_theta + self.rope_scaling = None + self.rope_scaling = rope_scaling + self.projectors_bias = projectors_bias + self.mamba_intermediate = mamba_intermediate = ( + mamba_expand * hidden_size if mamba_d_ssm is None else mamba_d_ssm + ) + + if mamba_intermediate % mamba_n_heads != 0: + raise ValueError("mamba_n_heads must divide mamba_expand * hidden_size") + + # for the mamba_v2, must satisfy the following + if mamba_d_head == "auto": + mamba_d_head = mamba_intermediate // mamba_n_heads + + if mamba_d_head * mamba_n_heads != mamba_intermediate: + raise ValueError( + "The dimensions for the Mamba head state do not match the model intermediate_size" + ) + + self.mamba_d_ssm = mamba_d_ssm + self.mamba_n_heads = mamba_n_heads + self.mamba_d_head = mamba_d_head + self.mamba_n_groups = mamba_n_groups + self.mamba_d_state = mamba_d_state + self.mamba_d_conv = mamba_d_conv + self.mamba_expand = mamba_expand + self.mamba_chunk_size = mamba_chunk_size + self.mamba_conv_bias = mamba_conv_bias + self.mamba_proj_bias = mamba_proj_bias + + self.mamba_norm_before_gate = mamba_norm_before_gate + self.mamba_rms_norm = mamba_rms_norm + + self.lm_head_multiplier = 
lm_head_multiplier + self.embedding_multiplier = embedding_multiplier + + if mlp_multipliers is not None: + self.mlp_multipliers = mlp_multipliers + else: + self.mlp_multipliers = [1.0, 1.0] + + if attention_out_multiplier is not None: + self.attention_out_multiplier = attention_out_multiplier + else: + self.attention_out_multiplier = 1.0 + + if attention_in_multiplier is not None: + self.attention_in_multiplier = attention_in_multiplier + else: + self.attention_in_multiplier = 1.0 + + if key_multiplier is not None: + self.key_multiplier = key_multiplier + else: + self.key_multiplier = 1.0 + + if ssm_multipliers is not None: + self.ssm_multipliers = ssm_multipliers + else: + self.ssm_multipliers = [1.0, 1.0, 1.0, 1.0, 1.0] + + if ssm_in_multiplier is not None: + self.ssm_in_multiplier = ssm_in_multiplier + else: + self.ssm_in_multiplier = 1.0 + + if ssm_out_multiplier is not None: + self.ssm_out_multiplier = ssm_out_multiplier + else: + self.ssm_out_multiplier = 1.0 + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + @property + def layers_block_type(self): + return ["falcon_h1" for i in range(self.num_hidden_layers)] + + @property + def full_attention_layer_ids(self): + # For Falcon-H1, we do have attention on all layers + return range(self.num_hidden_layers) + + @property + def linear_layer_ids(self): + # For Falcon-H1, we do have mamba on all layers + return range(self.num_hidden_layers) + + @property + def mamba2_cache_params(self): + from sglang.srt.layers.dp_attention import get_attention_tp_size + + shape = Mamba2StateShape.create( + tp_world_size=get_attention_tp_size(), + intermediate_size=self.mamba_intermediate, + n_groups=self.mamba_n_groups, + num_heads=self.mamba_n_heads, + head_dim=self.mamba_d_head, + state_size=self.mamba_d_state, + conv_kernel=self.mamba_d_conv, + ) + return Mamba2CacheParams( + shape=shape, 
class GraniteMoeHybridConfig(PretrainedConfig):
    r"""
    Configuration holder for a [`GraniteMoeHybridModel`]. The stored arguments define the model architecture of a
    GraniteMoeHybrid model, a hybrid architecture combining Mamba2 layers with attention layers, developed by IBM.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 100352):
            Vocabulary size of the GraniteMoeHybrid model. Defines the number of different tokens that can be
            represented by the `inputs_ids` passed when calling [`GraniteMoeHybridModel`]
        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
            Whether the model's input and output word embeddings should be tied. Note that this is only relevant if
            the model has an output word embedding layer.
        hidden_size (`int`, *optional*, defaults to 2048):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 8192):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 40):
            Number of hidden layers in the model.
        layer_types (`list[str]`, *optional*):
            List of layer types for each layer. Each element must be either "mamba" or "attention", and the list
            length must equal `num_hidden_layers`. If not provided, every 6th layer (1-based) is "attention" and all
            other layers are "mamba".
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer.
        num_key_value_heads (`int`, *optional*, defaults to 8):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        initializer_range (`float`, *optional*, defaults to 0.1):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.
        normalization_function (`str`, *optional*, defaults to `"rmsnorm"`):
            The normalization function to use. Currently only "rmsnorm" is supported.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 100256):
            The id of the padding token.
        bos_token_id (`int`, *optional*, defaults to 100257):
            The id of the "beginning-of-sequence" token.
        eos_token_id (`int`, *optional*, defaults to 100257):
            The id of the "end-of-sequence" token.
        max_position_embeddings (`int`, *optional*, defaults to 131072):
            Max cached sequence length for the model
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in attention layers.
        position_embedding_type (`str`, *optional*, defaults to `"nope"`):
            Type of position embedding. Can be "nope" (no position embedding) or "rope".
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The theta value used for the RoPE embeddings.
        rope_scaling (`dict`, *optional*):
            The scaling configuration for the RoPE embeddings. If `None`, no scaling is applied.
        mamba_d_state (`int`, *optional*, defaults to 128):
            The dimension of the mamba state space latents
        mamba_d_conv (`int`, *optional*, defaults to 4):
            The size of the mamba convolution kernel
        mamba_expand (`int`, *optional*, defaults to 2):
            Expanding factor (relative to hidden_size) used to determine the mamba intermediate size
        mamba_d_head (`int`, *optional*, defaults to 64):
            Head embedding dimension size for Mamba
        mamba_n_heads (`int`, *optional*, defaults to 64):
            The number of mamba heads
        mamba_n_groups (`int`, *optional*, defaults to 1):
            The number of the mamba groups
        mamba_chunk_size (`int`, *optional*, defaults to 256):
            The chunks in which to break the sequence when doing prefill/training
        mamba_conv_bias (`bool`, *optional*, defaults to `True`):
            Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block.
        mamba_proj_bias (`bool`, *optional*, defaults to `False`):
            Flag indicating whether or not to use bias in the input and output projections of the mamba mixer block
        embedding_multiplier (`float`, *optional*, defaults to 12.0):
            The multiplier for the embedding layer. This is used to scale the output of the embedding layer.
        logits_scaling (`float`, *optional*, defaults to 8.0):
            The scaling factor for the logits.
        attention_multiplier (`float`, *optional*, defaults to 0.015625):
            The multiplier for the attention layers.
        residual_multiplier (`float`, *optional*, defaults to 0.22):
            The multiplier for residual connections.
        num_local_experts (`int`, *optional*, defaults to 0):
            Number of local experts in MoE layers.
        num_experts_per_tok (`int`, *optional*, defaults to 0):
            Number of experts to use per token in MoE layers.
        shared_intermediate_size (`int`, *optional*, defaults to 8192):
            Intermediate size for shared experts.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether to output router logits.
        router_aux_loss_coef (`float`, *optional*, defaults to 0.01):
            Auxiliary loss coefficient for the router.
    """

    model_type = "granitemoehybrid"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=100352,
        tie_word_embeddings=True,
        hidden_size=2048,
        intermediate_size=8192,
        num_hidden_layers=40,
        layer_types=None,
        num_attention_heads=32,
        num_key_value_heads=8,
        hidden_act="silu",
        initializer_range=0.1,
        rms_norm_eps=1e-5,
        normalization_function="rmsnorm",
        use_cache=True,
        pad_token_id=100256,
        bos_token_id=100257,
        eos_token_id=100257,
        max_position_embeddings=131072,
        attention_dropout=0.0,
        attention_bias=False,
        position_embedding_type="nope",
        rope_theta=10000.0,
        rope_scaling=None,
        mamba_d_state=128,
        mamba_d_conv=4,
        mamba_expand=2,
        mamba_d_head=64,
        mamba_n_heads=64,
        mamba_n_groups=1,
        mamba_chunk_size=256,
        mamba_conv_bias=True,
        mamba_proj_bias=False,
        embedding_multiplier=12.0,
        logits_scaling=8.0,
        attention_multiplier=0.015625,
        residual_multiplier=0.22,
        num_local_experts=0,
        num_experts_per_tok=0,
        shared_intermediate_size=8192,
        output_router_logits=False,
        router_aux_loss_coef=0.01,
        **kwargs,
    ):
        """Record the architecture arguments, validating the layer layout and Mamba sizes.

        Raises:
            ValueError: If `layer_types` has the wrong length or an unknown entry, or if the Mamba
                head count / head dimension do not multiply out to `mamba_expand * hidden_size`.
        """
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers

        # Without an explicit layout, every 6th layer (1-based) is attention and
        # the remaining layers are mamba.
        if layer_types is None:
            layer_types = [
                ATTENTION if (position + 1) % 6 == 0 else MAMBA
                for position in range(num_hidden_layers)
            ]
        self.layer_types = layer_types

        # The layout must name exactly one known layer kind per hidden layer.
        if len(self.layer_types) != self.num_hidden_layers:
            raise ValueError(
                f"layer_types must have length equal to num_hidden_layers ({num_hidden_layers}), "
                f"but got {len(self.layer_types)}"
            )

        for layer_type in self.layer_types:
            if layer_type in (MAMBA, ATTENTION):
                continue
            raise ValueError(
                f"Each element in layer_types must be either '{MAMBA}' or '{ATTENTION}', "
                f"but got '{layer_type}'"
            )

        # Attention / normalization settings
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.normalization_function = normalization_function

        self.use_cache = use_cache
        self.max_position_embeddings = max_position_embeddings
        self.attention_dropout = attention_dropout
        self.attention_bias = attention_bias

        # Positional encoding settings
        self.position_embedding_type = position_embedding_type
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling

        # Mamba settings
        self.mamba_d_state = mamba_d_state
        self.mamba_d_conv = mamba_d_conv
        self.mamba_expand = mamba_expand
        self.mamba_d_head = mamba_d_head
        self.mamba_n_heads = mamba_n_heads
        self.mamba_n_groups = mamba_n_groups
        self.mamba_chunk_size = mamba_chunk_size
        self.mamba_conv_bias = mamba_conv_bias
        self.mamba_proj_bias = mamba_proj_bias

        # Derived size of the Mamba inner projection.
        self.mamba_intermediate_size = mamba_expand * hidden_size

        # The heads must tile the Mamba intermediate size exactly.
        heads_divide_evenly = self.mamba_intermediate_size % mamba_n_heads == 0
        if not heads_divide_evenly:
            raise ValueError(
                f"mamba_intermediate_size ({self.mamba_intermediate_size}) must be divisible by "
                f"mamba_n_heads ({mamba_n_heads})"
            )

        heads_fill_width = mamba_d_head * mamba_n_heads == self.mamba_intermediate_size
        if not heads_fill_width:
            raise ValueError(
                f"mamba_d_head ({mamba_d_head}) * mamba_n_heads ({mamba_n_heads}) must equal "
                f"mamba_intermediate_size ({self.mamba_intermediate_size})"
            )

        # Multipliers / scaling factors
        self.embedding_multiplier = embedding_multiplier
        self.logits_scaling = logits_scaling
        self.attention_multiplier = attention_multiplier
        self.residual_multiplier = residual_multiplier

        # Mixture-of-experts settings
        self.num_local_experts = num_local_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.shared_intermediate_size = shared_intermediate_size
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    @property
    def mamba_layer_ids(self):
        """List the indices of the layers whose type is mamba."""
        layer_count = self.num_hidden_layers
        return [idx for idx in range(layer_count) if self.layer_types[idx] == MAMBA]

    @property
    def attention_layer_ids(self):
        """List the indices of the layers whose type is attention."""
        layer_count = self.num_hidden_layers
        return [idx for idx in range(layer_count) if self.layer_types[idx] == ATTENTION]

    @property
    def full_attention_layer_ids(self):
        """Compatibility alias that returns `attention_layer_ids`."""
        return self.attention_layer_ids

    @property
    def mamba2_cache_params(self):
        """Build the `Mamba2CacheParams` covering this config's mamba layers."""
        # Imported at call time rather than module import time
        # (presumably to avoid an import cycle -- confirm).
        from sglang.srt.layers.dp_attention import get_attention_tp_size

        state_shape = Mamba2StateShape.create(
            tp_world_size=get_attention_tp_size(),
            intermediate_size=self.mamba_intermediate_size,
            n_groups=self.mamba_n_groups,
            num_heads=self.mamba_n_heads,
            head_dim=self.mamba_d_head,
            state_size=self.mamba_d_state,
            conv_kernel=self.mamba_d_conv,
        )
        return Mamba2CacheParams(shape=state_shape, layers=self.mamba_layer_ids)
# Modified from transformers.model.llama.configuration_llama.LlamaConfig
class InternLM2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate
    an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the InternLM2-7B.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 103168):
            Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`InternLM2Model`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 0):
            The id of the padding token.
        bos_token_id (`int`, *optional*, defaults to 1):
            The id of the "beginning-of-sequence" token.
        eos_token_id (`int`, *optional*, defaults to 2):
            The id of the "end-of-sequence" token.
        tie_word_embeddings(`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings
        bias (`bool`, *optional*, defaults to `True`):
            Stored on the config as `bias`. Presumably toggles bias terms in the model's linear layers -- confirm
            against the model implementation.
        rope_theta (`float`, *optional*, defaults to 10000):
            The theta value used for the RoPE embeddings.
        rope_scaling (`dict`, *optional*):
            RoPE scaling configuration. If given, it must be a dictionary with exactly two fields: `type`
            (`"linear"` or `"dynamic"`) and `factor` (a float or int `>= 1`). If `None`, no validation is performed.
        attn_implementation (`str`, *optional*, defaults to `"eager"`):
            Stored on the config as `attn_implementation`; `None` is replaced by `"eager"`.
    """

    model_type = "internlm2"
    _auto_class = "AutoConfig"

    def __init__(  # pylint: disable=W0102
        self,
        vocab_size=103168,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        bias=True,
        rope_theta=10000,
        rope_scaling=None,
        attn_implementation="eager",
        **kwargs,
    ):
        """Store the architecture arguments and validate `rope_scaling`.

        Raises:
            ValueError: If `rope_scaling` is not `None` and fails `_rope_scaling_validation`.
        """
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.bias = bias

        # Without an explicit KV head count, use one KV head per attention head.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads

        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self._rope_scaling_validation()

        self.attn_implementation = attn_implementation
        if self.attn_implementation is None:
            self.attn_implementation = "eager"
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    def _rope_scaling_validation(self):
        """
        Validate the `rope_scaling` configuration.

        A value of `None` is accepted as-is. Otherwise `self.rope_scaling` must be a dict with exactly two
        entries, whose `type` is `"linear"` or `"dynamic"` and whose `factor` is a float or int `>= 1`.
        The stored dictionary is only inspected, never modified.

        Raises:
            ValueError: If any of the conditions above is violated.
        """
        if self.rope_scaling is None:
            return

        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
            raise ValueError(
                "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
                f"got {self.rope_scaling}"
            )
        rope_scaling_type = self.rope_scaling.get("type", None)
        rope_scaling_factor = self.rope_scaling.get("factor", None)
        if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
            raise ValueError(
                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
            )
        if (
            rope_scaling_factor is None
            or not isinstance(rope_scaling_factor, (float, int))
            or rope_scaling_factor < 1.0
        ):
            raise ValueError(
                f"`rope_scaling`'s factor field must be a float|int >= 1, got {rope_scaling_factor=}, {type(rope_scaling_factor)=}"
            )
        # The former trailing int -> float conversion only rebound a local name and
        # never reached `self.rope_scaling`, so it was removed as dead code.
+ num_attention_heads (`int`, *optional*, defaults to 25): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 12800): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + qk_normalization (`bool`, *optional*, defaults to `True`): + Whether to normalize the queries and keys in the self-attention layers. + num_hidden_layers (`int`, *optional*, defaults to 48): + Number of hidden layers in the Transformer encoder. + use_flash_attn (`bool`, *optional*, defaults to `True`): + Whether to use flash attention mechanism. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-6): + The epsilon used by the layer normalization layers. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + drop_path_rate (`float`, *optional*, defaults to 0.0): + Dropout rate for stochastic depth. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 0.1): + A factor for layer scale. 
def __init__(
    self,
    num_channels=3,
    patch_size=14,
    image_size=224,
    qkv_bias=False,
    hidden_size=3200,
    num_attention_heads=25,
    intermediate_size=12800,
    qk_normalization=True,
    num_hidden_layers=48,
    use_flash_attn=True,
    hidden_act="gelu",
    layer_norm_eps=1e-6,
    dropout=0.0,
    drop_path_rate=0.0,
    attention_dropout=0.0,
    initializer_range=0.02,
    initializer_factor=0.1,
    **kwargs,
):
    """Store the vision-encoder hyperparameters on this config instance.

    Each named argument is saved unchanged as an attribute of the same name.
    Any additional ``**kwargs`` are forwarded to the parent initializer.
    """
    super().__init__(**kwargs)

    # Input geometry.
    self.num_channels = num_channels
    self.image_size = image_size
    self.patch_size = patch_size

    # Encoder dimensions.
    self.hidden_size = hidden_size
    self.intermediate_size = intermediate_size
    self.num_hidden_layers = num_hidden_layers
    self.num_attention_heads = num_attention_heads

    # Attention options.
    self.qkv_bias = qkv_bias
    self.qk_normalization = qk_normalization
    self.use_flash_attn = use_flash_attn
    self.attention_dropout = attention_dropout

    # Activation, normalization and regularization.
    self.hidden_act = hidden_act
    self.layer_norm_eps = layer_norm_eps
    self.dropout = dropout
    self.drop_path_rate = drop_path_rate

    # Weight initialization.
    self.initializer_range = initializer_range
    self.initializer_factor = initializer_factor
class InternVLChatConfig(PretrainedConfig):
    """Composite configuration pairing a vision config with a language-model config.

    The vision part is always an `InternVisionConfig`. The language-model
    config class is selected from the first entry of
    ``llm_config["architectures"]``.
    """

    model_type = "internvl_chat"
    is_composition = True

    def __init__(
        self,
        vision_config=None,
        llm_config=None,
        use_backbone_lora=0,
        use_llm_lora=0,
        pad2square=False,
        select_layer=-1,
        force_image_size=None,
        downsample_ratio=0.5,
        template=None,
        dynamic_image_size=False,
        use_thumbnail=False,
        ps_version="v1",
        min_dynamic_patch=1,
        max_dynamic_patch=6,
        **kwargs,
    ):
        """Build the nested vision and language-model configs.

        Args:
            vision_config: Keyword dictionary for `InternVisionConfig`. When
                `None`, a dictionary naming only the `InternVisionModel`
                architecture is used.
            llm_config: Keyword dictionary for the language-model config. Its
                ``"architectures"`` entry selects the config class. When
                `None`, the `InternLM2ForCausalLM` architecture is used.
            **kwargs: Forwarded to the parent initializer.

        The remaining arguments are stored unchanged as attributes of the
        same name.

        Raises:
            ValueError: If ``llm_config`` names no architecture, or one that
                is not handled here.
        """
        super().__init__(**kwargs)

        if vision_config is None:
            vision_config = {"architectures": ["InternVisionModel"]}
            logger.info(
                "vision_config is None. Initializing the InternVisionConfig with default values."
            )

        if llm_config is None:
            llm_config = {"architectures": ["InternLM2ForCausalLM"]}
            logger.info(
                "llm_config is None. Initializing the InternLM2Config with default values."
            )

        self.vision_config = InternVisionConfig(**vision_config)

        # Look the architecture up once. A missing or empty "architectures"
        # entry falls through to the ValueError below rather than failing with
        # an unrelated TypeError/IndexError on the subscript.
        architectures = llm_config.get("architectures") or [None]
        architecture = architectures[0]
        if architecture == "LlamaForCausalLM":
            self.llm_config = LlamaConfig(**llm_config)
        elif architecture == "InternLM2ForCausalLM":
            self.llm_config = InternLM2Config(**llm_config)
        elif architecture == "Qwen2ForCausalLM":
            self.llm_config = Qwen2Config(**llm_config)
        elif architecture == "Qwen3MoeForCausalLM":
            self.llm_config = Qwen3MoeConfig(**llm_config)
        elif architecture == "Qwen3ForCausalLM":
            self.llm_config = Qwen3Config(**llm_config)
        elif architecture == "GptOssForCausalLM":
            self.llm_config = GptOssConfig(**llm_config)
        else:
            raise ValueError("Unsupported architecture: {}".format(architecture))

        self.use_backbone_lora = use_backbone_lora
        self.use_llm_lora = use_llm_lora
        self.pad2square = pad2square
        self.select_layer = select_layer
        self.force_image_size = force_image_size
        self.downsample_ratio = downsample_ratio
        self.template = template
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail = use_thumbnail
        self.ps_version = ps_version  # pixel shuffle version
        self.min_dynamic_patch = min_dynamic_patch
        self.max_dynamic_patch = max_dynamic_patch

        self.hidden_size = self.llm_config.hidden_size
        # By default, we use tie_word_embeddings=False for models of all sizes.
        self.tie_word_embeddings = False
        self.llm_config.tie_word_embeddings = self.tie_word_embeddings

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].

        The nested configs are replaced by their own `to_dict()` output and
        `model_type` is taken from the class.

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
        """
        output = copy.deepcopy(self.__dict__)
        output["vision_config"] = self.vision_config.to_dict()
        output["llm_config"] = self.llm_config.to_dict()
        output["model_type"] = self.__class__.model_type
        # Re-read through attribute access so these keys are always present.
        for name in (
            "use_backbone_lora",
            "use_llm_lora",
            "select_layer",
            "force_image_size",
            "downsample_ratio",
            "template",
            "dynamic_image_size",
            "use_thumbnail",
            "ps_version",
            "min_dynamic_patch",
            "max_dynamic_patch",
        ):
            output[name] = getattr(self, name)

        return output
sp_model_kwargs: Optional[Dict[str, Any]] = None, +# add_bos_token=True, +# add_eos_token=False, +# decode_with_prefix_space=False, +# clean_up_tokenization_spaces=False, +# **kwargs, +# ): +# super().__init__( +# vocab_file=vocab_file, +# unk_token=unk_token, +# bos_token=bos_token, +# eos_token=eos_token, +# pad_token=pad_token, +# sp_model_kwargs=sp_model_kwargs, +# add_bos_token=add_bos_token, +# add_eos_token=add_eos_token, +# decode_with_prefix_space=decode_with_prefix_space, +# clean_up_tokenization_spaces=clean_up_tokenization_spaces, +# **kwargs, +# ) +# self._add_bos_token = add_bos_token +# self._add_eos_token = add_eos_token +# self.update_post_processor() +# self.vocab_file = vocab_file +# +# @property +# def can_save_slow_tokenizer(self) -> bool: +# return os.path.isfile(self.vocab_file) if self.vocab_file else False +# +# def update_post_processor(self): +# """ +# Updates the underlying post processor with the current `bos_token` and `eos_token`. +# """ +# bos = self.bos_token +# bos_token_id = self.bos_token_id +# if bos is None and self.add_bos_token: +# raise ValueError('add_bos_token = True but bos_token = None') +# +# eos = self.eos_token +# eos_token_id = self.eos_token_id +# if eos is None and self.add_eos_token: +# raise ValueError('add_eos_token = True but eos_token = None') +# +# single = f"{(bos + ':0 ') if self.add_bos_token else ''}$A:0{(' ' + eos + ':0') if self.add_eos_token else ''}" +# pair = f"{single}{(' ' + bos + ':1') if self.add_bos_token else ''} $B:1{(' ' + eos + ':1') if self.add_eos_token else ''}" +# +# special_tokens = [] +# if self.add_bos_token: +# special_tokens.append((bos, bos_token_id)) +# if self.add_eos_token: +# special_tokens.append((eos, eos_token_id)) +# self._tokenizer.post_processor = processors.TemplateProcessing( +# single=single, pair=pair, special_tokens=special_tokens +# ) +# +# @property +# def add_eos_token(self): +# return self._add_eos_token +# +# @property +# def add_bos_token(self): +# return 
self._add_bos_token +# +# @add_eos_token.setter +# def add_eos_token(self, value): +# self._add_eos_token = value +# self.update_post_processor() +# +# @add_bos_token.setter +# def add_bos_token(self, value): +# self._add_bos_token = value +# self.update_post_processor() +# +# def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: +# if not self.can_save_slow_tokenizer: +# raise ValueError( +# 'Your fast tokenizer does not have the necessary information to save the vocabulary for a slow ' +# 'tokenizer.' +# ) +# +# if not os.path.isdir(save_directory): +# logger.error(f'Vocabulary path ({save_directory}) should be a directory') +# return +# out_vocab_file = os.path.join( +# save_directory, (filename_prefix + '-' if filename_prefix else '') + VOCAB_FILES_NAMES['vocab_file'] +# ) +# +# if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): +# copyfile(self.vocab_file, out_vocab_file) +# +# return (out_vocab_file,) + + +# Modified from transformers.model.llama.tokenization_llama.LlamaTokenizer +class InternLM2Tokenizer(PreTrainedTokenizer): + """ + Construct a InternLM2 tokenizer. Based on byte-level Byte-Pair-Encoding. + + Args: + vocab_file (`str`): + Path to the vocabulary file. 
def __init__(
    self,
    vocab_file,
    unk_token="",
    bos_token="",
    eos_token="",
    pad_token="",
    sp_model_kwargs: Optional[Dict[str, Any]] = None,
    add_bos_token=True,
    add_eos_token=False,
    decode_with_prefix_space=False,
    clean_up_tokenization_spaces=False,
    **kwargs,
):
    """Load the SentencePiece model and initialize the base tokenizer.

    Args:
        vocab_file: Path handed to ``SentencePieceProcessor.Load``.
        unk_token, bos_token, eos_token, pad_token: Special tokens forwarded
            to the parent initializer.
        sp_model_kwargs: Keyword arguments for ``spm.SentencePieceProcessor``;
            `None` means no extra arguments.
        add_bos_token: Stored on the instance as ``add_bos_token``.
        add_eos_token: Stored on the instance as ``add_eos_token``.
        decode_with_prefix_space: Stored on the instance unchanged.
        clean_up_tokenization_spaces: Forwarded to the parent initializer.
        **kwargs: Forwarded to the parent initializer.
    """
    # Diagnostic only: goes through the module logger so that constructing a
    # tokenizer does not write to stdout.
    logger.debug("register succeed")
    self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
    self.vocab_file = vocab_file
    self.add_bos_token = add_bos_token
    self.add_eos_token = add_eos_token
    self.decode_with_prefix_space = decode_with_prefix_space
    # The SentencePiece model is loaded before the parent initializer runs.
    self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
    self.sp_model.Load(vocab_file)
    self._no_prefix_space_tokens = None
    super().__init__(
        bos_token=bos_token,
        eos_token=eos_token,
        unk_token=unk_token,
        pad_token=pad_token,
        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
        **kwargs,
    )
def convert_tokens_to_string(self, tokens):
    """Join a sequence of token strings into one decoded string.

    Runs of ordinary tokens are decoded with the SentencePiece model. Special
    tokens are copied through verbatim, with a single space inserted before a
    special token that does not directly follow another special token. The
    first character of the final text is dropped before returning.
    """
    pieces = []
    pending = []
    after_special = False
    for tok in tokens:
        # Ordinary tokens are buffered until the next special token or the end.
        if tok not in self.all_special_tokens:
            pending.append(tok)
            after_special = False
            continue

        # Special tokens must not be decoded by the SentencePiece model.
        if not after_special:
            pieces.append(" ")
        pieces.append(self.sp_model.decode(pending) + tok)
        pending = []
        after_special = True

    pieces.append(self.sp_model.decode(pending))
    text = self.clean_up_tokenization("".join(pieces))
    text = self._maybe_add_prefix_space(tokens=tokens, decoded=text)
    return text[1:]
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
    """Concatenate the id lists, adding BOS/EOS ids when enabled.

    The BOS id is prepended once when ``add_bos_token`` is truthy and the EOS
    id is appended once when ``add_eos_token`` is truthy. ``token_ids_1``,
    when given, is placed directly after ``token_ids_0``. A new list is
    returned.
    """
    ids = ([self.bos_token_id] if self.add_bos_token else []) + token_ids_0
    if token_ids_1 is not None:
        ids = ids + token_ids_1
    if not self.add_eos_token:
        return ids
    return ids + [self.eos_token_id]
class DictToObject(dict):
    """A dict whose entries are also set as instance attributes.

    Values that are themselves dicts are wrapped in `DictToObject` before
    being set as attributes, so nested keys can be reached with dotted
    access. The mapping items themselves keep the original values.
    """

    def __init__(self, dictionary):
        """Copy ``dictionary`` into this mapping and mirror it as attributes.

        Args:
            dictionary: Source mapping; keys must be valid ``setattr`` names.
        """
        # Zero-argument super(): the earlier `super(self)` passed an instance
        # where a type is required and raised TypeError on every construction.
        super().__init__(dictionary)

        for key, value in dictionary.items():
            if isinstance(value, dict):
                value = DictToObject(value)
            setattr(self, key, value)
class MultiModalityConfig(PretrainedConfig):
    """Top-level config that groups the per-component sub-configs.

    Each sub-config is built from the dictionary found under the matching
    keyword argument; a missing keyword yields that sub-config built with no
    arguments.
    """

    model_type = "multi_modality"
    vision_config: VisionConfig
    aligner_config: AlignerConfig

    gen_vision_config: GenVisionConfig
    gen_aligner_config: GenAlignerConfig
    gen_head_config: GenHeadConfig

    language_config: LlamaConfig

    def __init__(self, **kwargs):
        """Build every sub-config from ``kwargs``.

        ``language_config`` may be either a ready `LlamaConfig` instance,
        which is stored as-is, or a keyword dictionary for `LlamaConfig`.
        """
        super().__init__(**kwargs)

        # (keyword / attribute name, config class), built in this order.
        nested_specs = (
            ("vision_config", VisionConfig),
            ("aligner_config", AlignerConfig),
            ("gen_vision_config", GenVisionConfig),
            ("gen_aligner_config", GenAlignerConfig),
            ("gen_head_config", GenHeadConfig),
        )
        for name, config_cls in nested_specs:
            setattr(self, name, config_cls(**kwargs.get(name, {})))

        language_config = kwargs.get("language_config", {})
        if not isinstance(language_config, LlamaConfig):
            language_config = LlamaConfig(**language_config)
        self.language_config = language_config
def resize(self, pil_img: Image) -> np.ndarray:
    """Resize an image so its longer side is ``self.image_size``, then pad to square.

    Args:
        pil_img (PIL.Image): [H, W, 3] in PIL.Image in RGB

    Returns:
        x (np.ndarray): channels-first array; documented upstream as
            [3, self.image_size, self.image_size] (assumes `expand2square`
            pads the shorter side to the longer one, to be confirmed).

    Raises:
        ValueError: If the source image or the computed target size has a
            non-positive dimension.
    """
    width, height = pil_img.size
    # Validate before dividing: a zero-sized image would otherwise surface as
    # ZeroDivisionError instead of the intended ValueError.
    if width <= 0 or height <= 0:
        raise ValueError("Invalid size!")

    max_size = max(width, height)
    # Scale both sides by the same factor, but never below `min_size`.
    new_height = max(int(height / max_size * self.image_size), self.min_size)
    new_width = max(int(width / max_size * self.image_size), self.min_size)
    if new_height <= 0 or new_width <= 0:
        raise ValueError("Invalid size!")

    # PIL expects (width, height).
    pil_img = pil_img.resize(
        (new_width, new_height),
        resample=PIL.Image.Resampling.BICUBIC,
        reducing_gap=None,
    )

    pil_img = expand2square(pil_img, self.background_color)
    x = to_numpy_array(pil_img)

    # [H, W, 3] -> [3, H, W]
    x = np.transpose(x, (2, 0, 1))

    return x
class DictOutput:
    """Mixin giving an object dict-style access to its instance attributes.

    All operations read from or write to the instance ``__dict__``.
    """

    def items(self):
        """Return the (name, value) view of the instance attributes."""
        return vars(self).items()

    def keys(self):
        """Return the view of instance attribute names."""
        return vars(self).keys()

    def __getitem__(self, item):
        """Return the attribute stored under ``item``; KeyError when absent."""
        return vars(self)[item]

    def __contains__(self, key):
        """Report whether ``key`` is an instance attribute name."""
        return key in vars(self)

    def __setitem__(self, key, value):
        """Store ``value`` as the instance attribute ``key``."""
        vars(self)[key] = value
def add_image_token(
    self,
    image_indices: List[int],
    input_ids: torch.LongTensor,
):
    """Expand every image placeholder position into a full image-token span.

    For each index, the text before it is kept, followed by the image start
    id, ``self.num_image_tokens`` copies of the image id, and the image end
    id. The token at the index itself is kept in the text slice only when
    ``self.add_special_token`` is truthy.

    Args:
        image_indices (List[int]): [index_0, index_1, ..., index_j]
        input_ids (torch.LongTensor): [N]

    Returns:
        input_ids (torch.LongTensor): [N + image tokens]
        num_image_tokens (torch.IntTensor): [n_images]
    """
    segments = []
    cursor = 0
    for position in image_indices:
        # Text before the placeholder, optionally including the placeholder.
        stop = position + 1 if self.add_special_token else position
        segments.append(input_ids[cursor:stop])

        # Start marker, repeated image id, end marker.
        segments.append(self.image_start_id * torch.ones((1), dtype=torch.long))
        segments.append(
            self.image_id * torch.ones((self.num_image_tokens,), dtype=torch.long)
        )
        segments.append(self.image_end_id * torch.ones((1), dtype=torch.long))

        # Resume after the placeholder position.
        cursor = position + 1

    # Whatever text follows the last placeholder.
    segments.append(input_ids[cursor:])

    expanded_ids = torch.cat(segments, dim=0)
    per_image_counts = torch.IntTensor([self.num_image_tokens] * len(image_indices))

    return expanded_ids, per_image_counts
def batchify(
    self, prepare_list: List[VLChatProcessorOutput]
) -> BatchedVLChatProcessorOutput:
    """
    Collate single-sample processor outputs into one left-padded batch.

    Args:
        prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput.

    Returns:
        BatchedVLChatProcessorOutput: padded ids, attention mask, pixel
        values, and the two image masks, plus the per-sample `sft_format`
        values.
    """
    batch_size = len(prepare_list)
    image_counts = [len(item.num_image_tokens) for item in prepare_list]
    seq_lens = [len(item) for item in prepare_list]

    max_len = max(seq_lens)
    # Keep at least one image slot even when no sample has an image.
    max_images = max(1, max(image_counts))

    ids_batch = torch.full((batch_size, max_len), self.pad_id).long()  # FIXME
    attention_batch = torch.zeros((batch_size, max_len)).long()
    pixels_batch = torch.zeros(
        (batch_size, max_images, *self.image_processor.default_shape)
    ).float()
    seq_mask_batch = torch.zeros((batch_size, max_len)).bool()
    emb_mask_batch = torch.zeros(
        (batch_size, max_images, self.num_image_tokens)
    ).bool()

    formats = []
    for row, item in enumerate(prepare_list):
        ids = item.input_ids
        length = len(item)
        image_count = len(item.num_image_tokens)

        # left-padding: real tokens occupy the last `length` columns
        attention_batch[row, -length:] = 1
        ids_batch[row, -length:] = torch.LongTensor(ids)
        seq_mask_batch[row, -length:] = ids == self.image_id

        if image_count > 0:
            pixels_batch[row, :image_count] = item.pixel_values
            for slot, token_count in enumerate(item.num_image_tokens):
                emb_mask_batch[row, slot, :token_count] = True

        formats.append(item.sft_format)

    return BatchedVLChatProcessorOutput(
        input_ids=ids_batch,
        attention_mask=attention_batch,
        pixel_values=pixels_batch,
        images_seq_mask=seq_mask_batch,
        images_emb_mask=emb_mask_batch,
        sft_format=formats,
    )
= 14, + image_mean: Union[Tuple[float, float, float], List[float]] = ( + 0.48145466, + 0.4578275, + 0.40821073, + ), + image_std: Union[Tuple[float, float, float], List[float]] = ( + 0.26862954, + 0.26130258, + 0.27577711, + ), + rescale_factor: float = 1.0 / 255.0, + do_normalize: bool = True, + **kwargs, + ): + self.image_size = image_size + self.min_size = min_size + self.image_mean = image_mean + self.image_std = image_std + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + + super().__init__(**kwargs) + + +register_processor(MultiModalityConfig, VLChatProcessor) +register_image_processor(MultiModalityConfig, VLMImageProcessor) diff --git a/sglang/python/sglang/srt/configs/jet_nemotron.py b/sglang/python/sglang/srt/configs/jet_nemotron.py new file mode 100644 index 0000000000000000000000000000000000000000..1670da3b67f592df7c81c9ea89923afccf651df5 --- /dev/null +++ b/sglang/python/sglang/srt/configs/jet_nemotron.py @@ -0,0 +1,80 @@ +from dataclasses import dataclass +from typing import Any + +from transformers.configuration_utils import PretrainedConfig + +from sglang.srt.configs.mamba_utils import ( + Mamba2CacheParams, + Mamba2StateShape, + mamba2_state_dtype, +) + + +@dataclass +class JetBlockConfig: + mode: str + expand_v: float + num_heads: int + head_dim: int + norm_eps: str + conv_size: int + dconv_generator_reduction: int + dconv_implementation: str + + +class JetNemotronConfig(PretrainedConfig): + model_type: str = "jet_nemotron" + + efficient_attention_config: dict[str, dict[str, Any]] + hidden_act: str + hidden_size: int + initializer_range: float + intermediate_size: int + layer_types: list[str] + max_position_embeddings: int + num_attention_heads: int + num_key_value_heads: int + rms_norm_eps: float + rope_scaling: None + rope_theta: float + + @property + def full_attention_layer_ids(self) -> list[int]: + return [ + idx + for idx, layer_type in enumerate(self.layer_types) + if layer_type in ("attn", "swa") + ] + + @property + 
class JetVLMConfig(PretrainedConfig):
    """Composite configuration: a Jet-Nemotron text model plus a SigLIP vision tower.

    Args:
        text_config: Keyword mapping for ``JetNemotronConfig``, an already
            constructed config object, or ``None`` for the defaults.
        vision_config: Keyword mapping for ``SiglipVisionConfig``, an already
            constructed config object, or ``None`` for the defaults.
        image_token_id: Token id of the image placeholder; stored as ``-1``
            when not provided.
        video_token_id: Token id of the video placeholder; stored as ``-1``
            when not provided.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "jet_vlm"
    sub_configs = {
        "text_config": JetNemotronConfig,
        "vision_config": SiglipVisionConfig,
    }
    _auto_class = "AutoConfig"

    def __init__(
        self,
        *,
        text_config: dict[str, Any] | JetNemotronConfig | None = None,
        vision_config: dict[str, Any] | SiglipVisionConfig | None = None,
        image_token_id: int | None = None,
        video_token_id: int | None = None,
        **kwargs,
    ):
        # Accept ready-made config objects as well as mappings, matching the
        # other composite configs in this package. Anything that is not a
        # config object is still keyword-expanded exactly as before.
        if text_config is None:
            text_config = JetNemotronConfig()
        elif not isinstance(text_config, PretrainedConfig):
            text_config = JetNemotronConfig(**text_config)
        self.text_config = text_config

        if vision_config is None:
            vision_config = SiglipVisionConfig()
        elif not isinstance(vision_config, PretrainedConfig):
            vision_config = SiglipVisionConfig(**vision_config)
        self.vision_config = vision_config

        # -1 marks "no such placeholder token configured".
        self.image_token_id = image_token_id if image_token_id is not None else -1
        self.video_token_id = video_token_id if video_token_id is not None else -1

        super().__init__(**kwargs)

    @property
    def full_attention_layer_ids(self) -> list[int]:
        """Full-attention layer indices, delegated to the text config."""
        return self.text_config.full_attention_layer_ids

    @property
    def linear_layer_ids(self) -> list[int]:
        """Linear (Jet block) layer indices, delegated to the text config."""
        return self.text_config.linear_layer_ids

    @property
    def mamba2_cache_params(self) -> Mamba2CacheParams:
        """State-cache parameters, delegated to the text config."""
        return self.text_config.mamba2_cache_params
class KimiK25Config(PretrainedConfig):
    """K2-VL model configuration.

    K2-VL extends Kimi-VL with video support using video-chunks: groups of
    consecutive frames that are processed together with temporal pooling.

    Args:
        text_config: Text model configuration (DeepseekV3), given either as a
            config object or as a dict of its keyword arguments. ``None``
            selects the DeepseekV3 defaults.
        vision_config: Vision tower / projector configuration, given either as
            a ``KimiK25VisionConfig`` or as a dict of its keyword arguments.
            ``None`` selects the defaults.
        ignore_index: The ignore index for the loss function.
        media_placeholder_token_id: The token ID for media placeholders.
        pad_token_id: The token ID for padding.
        use_unified_vision_chunk: Stored unchanged on the config.
        video_placeholder: Placeholder string for videos; stored unchanged.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "kimi_k25"

    def __init__(
        self,
        text_config: dict | DeepseekV3Config | None = None,
        vision_config: dict | KimiK25VisionConfig | None = None,
        # Other parameters
        ignore_index: int = -100,
        media_placeholder_token_id: int = 163605,
        pad_token_id: int = 0,
        use_unified_vision_chunk: bool = False,
        video_placeholder: str = "<|kimi_k25_video_placeholder|>",
        **kwargs,
    ):
        # Normalize the text sub-config into a config object.
        if isinstance(text_config, dict):
            text_config = DeepseekV3Config(**text_config)
        elif text_config is None:
            text_config = DeepseekV3Config()

        # Normalize the vision sub-config into a config object.
        if isinstance(vision_config, dict):
            vision_config = KimiK25VisionConfig(**vision_config)
        elif vision_config is None:
            vision_config = KimiK25VisionConfig()

        self.vision_config = vision_config
        self.text_config = text_config

        # Remaining scalar options are stored as given.
        self.ignore_index = ignore_index
        self.media_placeholder_token_id = media_placeholder_token_id
        self.use_unified_vision_chunk = use_unified_vision_chunk
        self.video_placeholder = video_placeholder

        # Surface the text model's quantization config on the top-level config.
        text_quantization = getattr(self.text_config, "quantization_config", None)
        if text_quantization is not None:
            self.quantization_config = text_quantization

        super().__init__(pad_token_id=pad_token_id, **kwargs)

    @property
    def hidden_size(self) -> int:
        """Hidden size of the text model, exposed here for compatibility."""
        text_cfg = self.text_config
        return text_cfg.hidden_size

    @property
    def vocab_size(self) -> int:
        """Vocabulary size of the text model, exposed here for compatibility."""
        text_cfg = self.text_config
        return text_cfg.vocab_size
class KimiLinearConfig(PretrainedConfig):
    """Configuration for Kimi-Linear models.

    Besides the usual transformer hyper-parameters this stores MoE settings,
    MLA dimensions and ``linear_attn_config``, a dict that must carry the keys
    ``"kda_layers"`` and ``"full_attn_layers"`` when it is provided. Layer
    numbers listed in ``"kda_layers"`` are 1-based (see ``is_kda_layer``).
    """

    model_type = "kimi_linear"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        model_type="kimi_linear",
        vocab_size=163840,
        hidden_size=4096,
        head_dim=None,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        rope_theta=10000.0,
        rope_scaling=None,
        tie_word_embeddings=False,
        moe_intermediate_size: int | None = None,
        moe_renormalize: bool = True,
        moe_router_activation_func: str = "sigmoid",
        num_experts: int | None = None,
        num_experts_per_token: int | None = None,
        num_shared_experts: int = 0,
        routed_scaling_factor: float = 1.0,
        first_k_dense_replace: int = 0,
        moe_layer_freq: int = 1,
        use_grouped_topk: bool = True,
        num_expert_group: int = 1,
        topk_group: int = 1,
        q_lora_rank: int | None = None,
        kv_lora_rank: int | None = None,
        qk_nope_head_dim: int | None = None,
        qk_rope_head_dim: int | None = None,
        v_head_dim: int | None = None,
        mla_use_nope: bool | None = False,
        num_nextn_predict_layers: int = 0,
        linear_attn_config: dict | None = None,
        **kwargs,
    ):
        # --- core transformer dimensions ---
        self.model_type = model_type
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        if head_dim is None:
            # Derive the per-head width when it is not given explicitly.
            self.head_dim = hidden_size // num_attention_heads
        else:
            self.head_dim = head_dim
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # Backward compatibility: no explicit KV head count means one KV head
        # per attention head.
        self.num_key_value_heads = (
            num_attention_heads if num_key_value_heads is None else num_key_value_heads
        )
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling

        # --- MLA dimensions ---
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.mla_use_nope = mla_use_nope

        # --- MoE settings ---
        # The expert count is exposed under two attribute names.
        self.n_routed_experts = num_experts
        self.num_experts = num_experts
        self.num_experts_per_token = num_experts_per_token
        self.moe_renormalize = moe_renormalize
        self.num_shared_experts = num_shared_experts
        self.routed_scaling_factor = routed_scaling_factor
        self.moe_router_activation_func = moe_router_activation_func
        assert self.moe_router_activation_func in ("softmax", "sigmoid")
        self.moe_intermediate_size = moe_intermediate_size
        self.first_k_dense_replace = first_k_dense_replace
        self.moe_layer_freq = moe_layer_freq
        self.use_grouped_topk = use_grouped_topk
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group
        self.num_nextn_predict_layers = num_nextn_predict_layers

        # --- linear attention ---
        if linear_attn_config is not None:
            # Both layer lists are mandatory once a config dict is supplied.
            assert linear_attn_config["kda_layers"] is not None
            assert linear_attn_config["full_attn_layers"] is not None
        self.linear_attn_config = linear_attn_config

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    @property
    def is_mla(self):
        """True when any MLA dimension is set or ``mla_use_nope`` is ``True``."""
        mla_dims = (
            self.q_lora_rank,
            self.kv_lora_rank,
            self.qk_nope_head_dim,
            self.qk_rope_head_dim,
            self.v_head_dim,
        )
        if any(dim is not None for dim in mla_dims):
            return True
        return self.mla_use_nope is True

    @property
    def is_moe(self):
        """True when an expert count has been configured."""
        return self.num_experts is not None

    @property
    def is_linear_attn(self) -> bool:
        """False when there is no linear-attention config or its KDA list is empty."""
        cfg = self.linear_attn_config
        if cfg is None:
            return False
        if (
            isinstance(cfg, dict)
            and cfg["kda_layers"] is not None
            and len(cfg["kda_layers"]) == 0
        ):
            return False
        return True

    def is_kda_layer(self, layer_idx: int):
        """Tell whether the 0-based ``layer_idx`` is listed (1-based) in ``kda_layers``."""
        cfg = self.linear_attn_config
        if cfg is None:
            return False
        return (layer_idx + 1) in cfg["kda_layers"]

    @property
    def linear_layer_ids(self):
        """0-based indices of the KDA (linear attention) layers."""
        return [
            layer_id
            for layer_id in range(self.num_hidden_layers)
            if self.is_kda_layer(layer_id)
        ]

    @property
    def full_attention_layer_ids(self):
        """0-based indices of all layers that are not KDA layers."""
        return [
            layer_id
            for layer_id in range(self.num_hidden_layers)
            if not self.is_kda_layer(layer_id)
        ]

    @property
    def mamba2_cache_params(self) -> KimiLinearCacheParams:
        """Build the state-cache parameters for the KDA layers."""
        # Imported lazily, as in the other config modules.
        from sglang.srt.layers.dp_attention import get_attention_tp_size

        attn_cfg = self.linear_attn_config
        state_shape = KimiLinearStateShape.create(
            tp_world_size=get_attention_tp_size(),
            num_heads=attn_cfg["num_heads"],
            head_dim=attn_cfg["head_dim"],
            conv_kernel_size=attn_cfg["short_conv_kernel_size"],
        )
        return KimiLinearCacheParams(shape=state_shape, layers=self.linear_layer_ids)
class MoonViTConfig(PretrainedConfig):
    """Configuration of the MoonViT vision tower used by Kimi-VL.

    Args:
        patch_size: Patch size of the vision tower.
        init_pos_emb_height: Initial positional-embedding height.
        init_pos_emb_width: Initial positional-embedding width.
        num_attention_heads: Number of attention heads.
        num_hidden_layers: Number of transformer layers.
        hidden_size: Hidden size of the transformer.
        intermediate_size: Intermediate (FFN) size of the transformer.
        merge_kernel_size: Kernel size of the patch merger.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "moonvit"

    def __init__(
        self,
        patch_size: int = 14,
        init_pos_emb_height: int = 64,
        init_pos_emb_width: int = 64,
        num_attention_heads: int = 16,
        num_hidden_layers: int = 27,
        hidden_size: int = 1152,
        intermediate_size: int = 4304,
        merge_kernel_size: tuple[int, int] = (2, 2),
        **kwargs,
    ):
        # Base-class initialization runs first; the fields below are then
        # stored verbatim.
        super().__init__(**kwargs)

        # Patching and patch merging
        self.patch_size = patch_size
        self.merge_kernel_size = merge_kernel_size

        # Positional embedding grid
        self.init_pos_emb_height = init_pos_emb_height
        self.init_pos_emb_width = init_pos_emb_width

        # Transformer stack
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
class Lfm2Config(HFLfm2Config):
    """
    SGLang flavour of the HuggingFace LFM2 configuration.

    Adds the hybrid-model properties SGLang relies on: LFM2 interleaves full
    attention layers with ShortConv layers, and each kind needs its own cache.
    """

    @property
    def full_attention_layer_ids(self) -> List[int]:
        """Indices of the attention layers (these need a KV cache)."""
        attention_ids = []
        for layer_id, kind in enumerate(self.layer_types):
            if kind == "full_attention":
                attention_ids.append(layer_id)
        return attention_ids

    @property
    def linear_layer_ids(self) -> List[int]:
        """Indices of the conv layers (these need a conv state cache)."""
        conv_ids = []
        for layer_id, kind in enumerate(self.layer_types):
            if kind == "conv" or kind == "short_conv":
                conv_ids.append(layer_id)
        return conv_ids

    @property
    def mamba_chunk_size(self) -> int:
        """Chunk size reported to the Mamba2 backend; LFM2 does not chunk, so 1."""
        return 1

    @property
    def mamba2_cache_params(self) -> Optional[Mamba2CacheParams]:
        """
        Cache parameters used to initialize HybridReqToTokenPool.

        Returns ``None`` when the model has no conv layers. Otherwise the
        ShortConv layers are described through a ``Mamba2StateShape`` whose
        ``state_size`` is 0, i.e. only the conv state carries data.
        """
        from sglang.srt.layers.dp_attention import get_attention_tp_size

        layer_ids = self.linear_layer_ids
        if not layer_ids:
            return None

        width = self.hidden_size
        kernel = int(self.conv_L_cache)

        # get_attention_tp_size() needs the distributed state to be set up;
        # fall back to a single rank when it is not.
        try:
            world_size = get_attention_tp_size()
        except (AssertionError, RuntimeError):
            world_size = 1

        # num_heads is set to the TP size so the per-rank head split always
        # divides evenly; with state_size=0 the temporal state is empty anyway.
        state_shape = Mamba2StateShape.create(
            tp_world_size=world_size,
            intermediate_size=width,
            n_groups=1,  # ShortConv has no grouping
            num_heads=world_size,
            head_dim=width,  # the conv spans the full hidden dimension
            state_size=0,  # no SSM temporal state
            conv_kernel=kernel,
        )

        return Mamba2CacheParams(
            shape=state_shape,
            layers=layer_ids,
            dtype=mamba2_state_dtype(self),
        )
class Lfm2MoeConfig(PretrainedConfig):
    """
    Configuration for LFM2-MoE models (e.g., LiquidAI/LFM2-8B-A1B).

    The architecture is hybrid:
    - attention layers and ShortConv layers, as in dense LFM2
    - MoE (Mixture of Experts) FFN layers with sigmoid routing

    MoE specifics:
    - the first `num_dense_layers` layers use a dense MLP, the rest use MoE
    - routing is sigmoid (not softmax) with an expert_bias for load balancing
    - expert_bias is kept in fp32 for numerical stability

    Raises:
        ValueError: if `layer_types` is given and its length differs from
            `num_hidden_layers`.
    """

    model_type = "lfm2_moe"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size: int = 65536,
        hidden_size: int = 2048,
        intermediate_size: int = 7168,
        moe_intermediate_size: int = 1792,
        num_hidden_layers: int = 32,
        num_attention_heads: int = 32,
        num_key_value_heads: int = 8,
        max_position_embeddings: int = 128000,
        initializer_range: float = 0.02,
        norm_eps: float = 1e-5,
        use_cache: bool = True,
        pad_token_id: int = 0,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = True,
        rope_parameters: Optional[dict] = None,
        conv_bias: bool = False,
        conv_L_cache: int = 3,
        # MoE-specific parameters
        num_dense_layers: int = 2,
        num_experts: int = 32,
        num_experts_per_tok: int = 4,
        use_expert_bias: bool = True,
        routed_scaling_factor: float = 1.0,
        norm_topk_prob: bool = True,
        # Layer types
        layer_types: Optional[List[str]] = None,
        **kwargs,
    ):
        # Model dimensions
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.norm_eps = norm_eps
        self.use_cache = use_cache

        # ShortConv settings
        self.conv_bias = conv_bias
        self.conv_L_cache = conv_L_cache

        # Mixture-of-experts settings
        self.num_dense_layers = num_dense_layers
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.use_expert_bias = use_expert_bias
        self.routed_scaling_factor = routed_scaling_factor
        self.norm_topk_prob = norm_topk_prob

        # Per-layer kind (attention vs conv) and RoPE settings
        self.layer_types = layer_types
        self.rope_parameters = rope_parameters

        # A layer_types list, when present, must describe every layer.
        if layer_types is not None:
            if len(layer_types) != num_hidden_layers:
                raise ValueError(
                    f"layer_types length ({len(layer_types)}) must match "
                    f"num_hidden_layers ({num_hidden_layers})"
                )

        # The original config spells the tying flag "tie_embedding"; when that
        # key is present it takes precedence over tie_word_embeddings.
        if "tie_embedding" in kwargs:
            tie_word_embeddings = kwargs.pop("tie_embedding")

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    @property
    def full_attention_layer_ids(self) -> List[int]:
        """Indices of the attention layers; empty when layer_types is unset."""
        kinds = self.layer_types
        if kinds is None:
            return []
        return [idx for idx, kind in enumerate(kinds) if kind == "full_attention"]

    @property
    def linear_layer_ids(self) -> List[int]:
        """Indices of the conv layers; empty when layer_types is unset."""
        kinds = self.layer_types
        if kinds is None:
            return []
        return [
            idx
            for idx, kind in enumerate(kinds)
            if kind == "conv" or kind == "short_conv"
        ]

    @property
    def mamba_chunk_size(self) -> int:
        """Chunk size reported to the Mamba2 backend; LFM2 does not chunk, so 1."""
        return 1

    @property
    def mamba2_cache_params(self) -> Optional[Mamba2CacheParams]:
        """
        Cache parameters used to initialize HybridReqToTokenPool.

        Returns ``None`` when the model has no conv layers. The cache dtype is
        not passed explicitly, so ``Mamba2CacheParams`` uses its own default.
        """
        from sglang.srt.layers.dp_attention import get_attention_tp_size

        layer_ids = self.linear_layer_ids
        if not layer_ids:
            return None

        width = self.hidden_size
        # conv_L_cache holds the conv kernel size (e.g. 3).
        kernel = int(self.conv_L_cache)

        # Fall back to a single rank when the TP state is not initialized.
        try:
            world_size = get_attention_tp_size()
        except (AssertionError, RuntimeError):
            world_size = 1

        # num_heads equals the TP size so the head split always divides
        # evenly; state_size=0 leaves the temporal state empty.
        state_shape = Mamba2StateShape.create(
            tp_world_size=world_size,
            intermediate_size=width,
            n_groups=1,
            num_heads=world_size,
            head_dim=width,
            state_size=0,
            conv_kernel=kernel,
        )

        return Mamba2CacheParams(shape=state_shape, layers=layer_ids)
@dataclass
class LoadConfig:
    """Options controlling where model weights come from and how they are loaded.

    download_dir: Directory to download and load the weights, default to the
        default cache directory of huggingface.
    load_format: The format of the model weights to load:
        "auto" will try to load the weights in the safetensors format and
            fall back to the pytorch bin format if safetensors format is
            not available.
        "pt" will load the weights in the pytorch bin format.
        "safetensors" will load the weights in the safetensors format.
        "npcache" will load the weights in pytorch format and store
            a numpy cache to speed up the loading.
        "dummy" will initialize the weights with random values, which is
            mainly for profiling.
        "bitsandbytes" will load nf4 type weights.
        "flash_rl" will load weights with support for RL training
            with quantized models, enabling efficient weight reloading.
    model_loader_extra_config: Extra loader options, either a dict or a JSON
        string. After ``__post_init__`` this is always a parsed object; an
        unset/empty value becomes ``{}``.
    ignore_patterns: The list of patterns to ignore when loading the model.
        Default to "original/**/*" to avoid repeated loading of llama's
        checkpoints.
    decryption_key_file: If set, decrypts the output files with a password read
        from this file (after PBKDF2).
    decrypt_max_concurrency: The maximum number of concurrent processes to
        decrypt the safetensor files. -1 means no limit.
    modelopt_checkpoint_restore_path, modelopt_checkpoint_save_path,
    modelopt_export_path: ModelOpt-specific loading options. They are used to
        build ``modelopt_config`` when no ``modelopt_config`` is supplied.
    """

    load_format: Union[str, LoadFormat] = LoadFormat.AUTO
    download_dir: Optional[str] = None
    model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
    ignore_patterns: Optional[Union[List[str], str]] = None
    decryption_key_file: Optional[str] = None
    decrypt_max_concurrency: int = -1
    tp_rank: Optional[int] = None
    remote_instance_weight_loader_seed_instance_ip: Optional[str] = None
    remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None
    remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None
    remote_instance_weight_loader_backend: Optional[str] = None
    remote_instance_weight_loader_transfer_engine: Optional[Any] = None

    # ModelOpt-specific loading options
    modelopt_checkpoint_restore_path: Optional[str] = None
    modelopt_checkpoint_save_path: Optional[str] = None
    modelopt_export_path: Optional[str] = None

    # ModelOpt configuration object
    modelopt_config: Optional[ModelOptConfig] = None

    # QuantizedRL-specific options (for FlashRL-style quantization)
    rl_quant_profile: Optional[str] = (
        None  # Path to rollout quantization profile (e.g., /root/profile.7b.pt)
    )

    # For multi-layer MTP
    draft_model_idx: Optional[int] = None

    def __post_init__(self):
        """Normalize user-supplied fields once the dataclass is constructed."""
        # Always write the normalized value back. Previously only the JSON
        # string case was stored, so a ``None``/"" input stayed on the
        # instance even though a ``{}`` fallback had been computed.
        extra_config = self.model_loader_extra_config or {}
        if isinstance(extra_config, str):
            extra_config = orjson.loads(extra_config)
        self.model_loader_extra_config = extra_config

        self._verify_load_format()

        if self.ignore_patterns:
            logger.info(
                "Ignoring the following patterns when downloading weights: %s",
                self.ignore_patterns,
            )
        else:
            # ``None`` and empty values both fall back to the default pattern.
            self.ignore_patterns = ["original/**/*"]

        # Create ModelOptConfig if not provided
        if self.modelopt_config is None:
            self.modelopt_config = ModelOptConfig(
                checkpoint_restore_path=self.modelopt_checkpoint_restore_path,
                checkpoint_save_path=self.modelopt_checkpoint_save_path,
                export_path=self.modelopt_export_path,
            )

    def _verify_load_format(self) -> None:
        """Convert a string ``load_format`` into a ``LoadFormat`` member.

        Non-string values are left untouched. The string is lower-cased
        before lookup.

        Raises:
            ValueError: if the string is not a valid ``LoadFormat`` value, or
                if running on ROCm and the format is listed as unsupported.
        """
        if not isinstance(self.load_format, str):
            return

        load_format = self.load_format.lower()
        self.load_format = LoadFormat(load_format)

        # Entries here are lower-case format *values* (e.g. "gguf"), matching
        # ``load_format`` above. The list is currently empty.
        rocm_not_supported_load_format: List[str] = []
        if is_hip() and load_format in rocm_not_supported_load_format:
            # Compare values with values: ``LoadFormat.__members__`` yields
            # upper-case member names, which would never match the list.
            rocm_supported_load_format = [
                fmt.value
                for fmt in LoadFormat
                if fmt.value not in rocm_not_supported_load_format
            ]
            raise ValueError(
                f"load format '{load_format}' is not supported in ROCm. "
                f"Supported load formats are "
                f"{rocm_supported_load_format}"
            )
tie_word_embeddings=False, + rope_theta=10000000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + mla_scale_q_lora=True, + mla_scale_kv_lora=True, + torch_dtype="bfloat16", + params_dtype="bfloat16", + rounter_params_dtype="float32", + router_bias=False, + topk_method=None, + routed_scaling_factor=6.0, + zero_expert_num=256, + zero_expert_type="identity", + nextn_use_scmoe=False, + num_nextn_predict_layers=1, + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + torch_dtype=torch_dtype, + params_dtype=params_dtype, + rounter_params_dtype=rounter_params_dtype, + topk_method=topk_method, + router_bias=router_bias, + nextn_use_scmoe=nextn_use_scmoe, + num_nextn_predict_layers=num_nextn_predict_layers, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.num_hidden_layers = ( + num_hidden_layers if num_hidden_layers is not None else num_layers + ) + self.intermediate_size = ( + intermediate_size if intermediate_size is not None else ffn_hidden_size + ) + self.moe_intermediate_size = expert_ffn_hidden_size + self.num_attention_heads = num_attention_heads + self.ep_size = ep_size + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.n_routed_experts = n_routed_experts + self.moe_topk = moe_topk + self.norm_topk_prob = norm_topk_prob + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mla_scale_q_lora = mla_scale_q_lora + self.mla_scale_kv_lora = mla_scale_kv_lora + self.zero_expert_num = zero_expert_num 
def extra_groups_for_head_shards(ngroups: int, tp_size: int):
    """Return how many extra groups are needed so groups can follow head shards.

    Groups that do not split evenly across ``tp_size`` ranks are replicated;
    the return value is the number of groups to add on top of ``ngroups``.
    """
    # When ngroups % tp_size == 0 the groups already shard evenly, so nothing
    # is added. Otherwise the result is tp_size - ngroups (for ngroups == 1
    # that is exactly tp_size - 1 replicas).
    return 0 if ngroups % tp_size == 0 else tp_size - ngroups
+ + Priority (from highest to lowest): + 1. Environment variable SGLANG_MAMBA_SSM_DTYPE + 2. Config file (config.mamba_ssm_dtype or config.text_config.mamba_ssm_dtype) + 3. Default "float32" + + Args: + config: Optional config object (PretrainedConfig). If provided, will read + mamba_ssm_dtype from it. For VL models, reads from text_config. + + Returns: + Mamba2StateDType with conv and temporal dtypes + """ + dtype_map = { + "float32": torch.float32, + "bfloat16": torch.bfloat16, + "float16": torch.float16, + } + conv_dtype = dtype_map.get(envs.SGLANG_MAMBA_CONV_DTYPE.get(), torch.bfloat16) + + # Get SSM dtype: default -> config -> env var + ssm_dtype = torch.float32 # Step 1: Default value + + # Step 2: Try to read from config + if config is not None: + config_dtype = None + if hasattr(config, "text_config") and hasattr( + config.text_config, "mamba_ssm_dtype" + ): + # VL model: read from text_config + config_dtype = config.text_config.mamba_ssm_dtype + elif hasattr(config, "mamba_ssm_dtype"): + # Text model: read from root config + config_dtype = config.mamba_ssm_dtype + + if config_dtype is not None: + if config_dtype not in dtype_map: + logger.warning( + f"Invalid mamba_ssm_dtype '{config_dtype}' in config. " + f"Must be one of {list(dtype_map.keys())}. Using default 'float32'." + ) + else: + ssm_dtype = dtype_map[config_dtype] + + # Step 3: Check environment variable, if not None, override + env_ssm_dtype = envs.SGLANG_MAMBA_SSM_DTYPE.get() + if env_ssm_dtype is not None: + if env_ssm_dtype not in dtype_map: + logger.warning( + f"Invalid mamba_ssm_dtype '{env_ssm_dtype}' from environment variable. " + f"Must be one of {list(dtype_map.keys())}. Using default 'float32'." 
+ ) + else: + ssm_dtype = dtype_map[env_ssm_dtype] + + logger.debug(f"Mamba2 state dtype: conv_dtype={conv_dtype}, ssm_dtype={ssm_dtype}") + + return Mamba2StateDType(conv=conv_dtype, temporal=ssm_dtype) + + +@dataclass(kw_only=True, frozen=True) +class BaseLinearStateParams(ABC): + dtype: Mamba2StateDType = field(default_factory=lambda: mamba2_state_dtype(None)) + layers: list[int] + + @property + def mamba_cache_per_req(self) -> int: + conv_numel = int( + np.sum([np.prod(conv_shape) for conv_shape in self.shape.conv]) + ) + + ssm_numel = int(np.prod(self.shape.temporal)) + return ( + conv_numel * self.dtype.conv.itemsize + + ssm_numel * self.dtype.temporal.itemsize + ) * len(self.layers) + + +@dataclass(kw_only=True, frozen=True) +class Mamba2StateShape: + conv: list[tuple[int, int]] + temporal: tuple[int, int, int] + + intermediate_size: int + conv_dim: int + ssm_state_size: int + num_heads: int + head_dim: int + state_size: int + conv_kernel: int + + @staticmethod + def create( + *, + tp_world_size: int, + intermediate_size: int, + n_groups: int, + num_heads: int, + head_dim: int, + state_size: int, + conv_kernel: int, + ) -> "Mamba2StateShape": + # if n_groups is not divisible by world_size, need to extend the shards + # to ensure all groups needed by a head is sharded along with it + if n_groups % tp_world_size != 0: + extra_groups = extra_groups_for_head_shards(n_groups, tp_world_size) + n_groups += extra_groups + # heads and n_groups are TP-ed + conv_dim = intermediate_size + 2 * n_groups * state_size + + # contiguous along 'dim' axis + conv_state_shape = divide(conv_dim, tp_world_size), conv_kernel - 1 + + # These are not TP-ed as they depend on A, dt_bias, D + # - they are typically small + # e.g., QWen3-Next: (32, 128, 128) + temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, state_size) + return Mamba2StateShape( + conv=[conv_state_shape], + temporal=temporal_state_shape, + intermediate_size=intermediate_size, + conv_dim=conv_dim, + 
@dataclass(kw_only=True, frozen=True)
class KimiLinearStateShape:
    """Per-rank state shapes for a Kimi linear-attention layer.

    ``conv`` holds a single merged conv-state shape and ``temporal`` the
    temporal-state shape; the remaining fields echo the arguments given to
    ``create``.
    """

    conv: List[tuple[int, int]]
    temporal: tuple[int, int, int]

    num_heads: int
    head_dim: int
    num_k_heads: int
    head_k_dim: int
    conv_kernel: int
    num_spec: int

    @staticmethod
    def create(
        *,
        tp_world_size: int,
        num_heads: int,
        head_dim: int,
        num_k_heads: Optional[int] = None,
        head_k_dim: Optional[int] = None,
        conv_kernel_size: int = 4,
        num_spec: int = 0,
    ) -> "KimiLinearStateShape":
        """Build the state shapes for one tensor-parallel rank.

        ``num_k_heads`` and ``head_k_dim`` default to ``num_heads`` and
        ``head_dim`` when omitted. Sizes are split across ranks with
        ``divide``.
        """
        num_k_heads = num_heads if num_k_heads is None else num_k_heads
        head_k_dim = head_dim if head_k_dim is None else head_k_dim

        # Same order of ``divide`` calls as the projections are declared:
        # main projection, key projection, then heads.
        proj_per_rank = divide(num_heads * head_dim, tp_world_size)
        proj_k_per_rank = divide(num_k_heads * head_k_dim, tp_world_size)
        heads_per_rank = divide(num_heads, tp_world_size)

        # Conv state is stored as (kernel - 1, channels); the channel count
        # is the main projection plus twice the key projection.
        conv_window = conv_kernel_size - 1
        merged_conv_shape = (conv_window, proj_per_rank + proj_k_per_rank * 2)

        return KimiLinearStateShape(
            conv=[merged_conv_shape],
            temporal=(heads_per_rank, head_dim, head_dim),
            num_heads=num_heads,
            head_dim=head_dim,
            num_k_heads=num_k_heads,
            head_k_dim=head_k_dim,
            conv_kernel=conv_kernel_size,
            num_spec=num_spec,
        )
def is_deepseek_nsa(config: PretrainedConfig) -> bool:
    """Report whether ``config`` describes an NSA-enabled model.

    The check passes when the first entry of ``config.architectures`` is one
    of the architectures listed below and ``config.index_topk`` exists and is
    not ``None``.

    Args:
        config: Model config exposing ``architectures`` and, optionally,
            ``index_topk``.

    Returns:
        ``True`` if both conditions hold, ``False`` otherwise. This includes
        the case where ``architectures`` is ``None`` or empty.
    """
    architectures = config.architectures
    # An empty list has no primary architecture; indexing it would raise
    # IndexError, so treat it the same as ``None``.
    if not architectures:
        return False

    nsa_capable_architectures = (
        "DeepseekV3ForCausalLM",
        "DeepseekV32ForCausalLM",
        "DeepseekV3ForCausalLMNextN",
        "MistralLarge3ForCausalLM",
        "PixtralForConditionalGeneration",
        "GlmMoeDsaForCausalLM",
    )
    if architectures[0] not in nsa_capable_architectures:
        return False

    return getattr(config, "index_topk", None) is not None
def get_nsa_index_topk(config: PretrainedConfig) -> int:
    """Return ``config.index_topk`` for a config that passes ``is_deepseek_nsa``.

    Raises:
        AssertionError: if ``is_deepseek_nsa(config)`` is false.
    """
    assert is_deepseek_nsa(config)
    index_topk = config.index_topk
    return index_topk
get_generation_config( + self.model_path, + trust_remote_code=trust_remote_code, + revision=revision, + **kwargs, + ) + + # Set enable_multimodal + if enable_multimodal is None: + mm_disabled_models = [ + "Gemma3ForConditionalGeneration", + "Llama4ForConditionalGeneration", + "Step3VLForConditionalGeneration", + ] + if self.hf_config.architectures[0] in mm_disabled_models: + enable_multimodal = False + logger.info( + f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal." + ) + else: + enable_multimodal = True + + # Config draft model + self._config_draft_model() + + # Check model type + self.attention_chunk_size = getattr( + self.hf_text_config, "attention_chunk_size", None + ) + self.sliding_window_size = self._get_sliding_window_size() + self.is_generation = is_generation_model( + self.hf_config.architectures, is_embedding + ) + self.is_multimodal = enable_multimodal and is_multimodal_model( + self.hf_config.architectures + ) + self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model( + self.hf_config.architectures + ) + self.is_image_gen = enable_multimodal and is_image_gen_model( + self.hf_config.architectures + ) + self.is_audio_model = enable_multimodal and is_audio_model( + self.hf_config.architectures + ) + # TODO: requires further polishing + self.is_image_understandable_model = enable_multimodal and hasattr( + self.hf_config, "vision_config" + ) + self.is_audio_understandable_model = enable_multimodal and hasattr( + self.hf_config, "audio_config" + ) + + self.is_multimodal_chunked_prefill_supported = ( + enable_multimodal + and is_multimodal_chunked_prefill_supported(self.hf_config.architectures) + ) + self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures) + self.is_local_attention_model = is_local_attention_model( + self.hf_config.architectures + ) + self.is_piecewise_cuda_graph_disabled_model = ( + is_piecewise_cuda_graph_disabled_model(self.hf_config.architectures) + or 
is_deepseek_nsa(self.hf_text_config) + ) + self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) + + # Derive context length and model shapes + self._derive_context_length(context_length) + self._derive_model_shapes() + + # Update hybrid model + self._derive_hybrid_model() + + # Verify quantization + self._verify_quantization() + + self._verify_transformers_version() + + # Verify dual-chunk attention config + self._verify_dual_chunk_attention_config() + + # Cache attributes + self.hf_eos_token_id = self._get_hf_eos_token_id() + + # multimodal + self.image_token_id = getattr( + self.hf_config, "image_token_id", None + ) or getattr(self.hf_config, "image_token_index", None) + + self.hf_config.encoder_only = encoder_only + self.hf_config.language_only = language_only + + # matryoshka embeddings + self.matryoshka_dimensions = getattr( + self.hf_config, "matryoshka_dimensions", None + ) + self.is_matryoshka = self.matryoshka_dimensions or getattr( + self.hf_config, "is_matryoshka", False + ) + + @staticmethod + def from_server_args( + server_args: ServerArgs, + model_path: str = None, + model_revision: str = None, + is_draft_model: bool = False, + **kwargs, + ): + quantization = ( + server_args.speculative_draft_model_quantization + if is_draft_model + else server_args.quantization + ) + override_config_file = ( + server_args.decrypted_draft_config_file + if is_draft_model + else server_args.decrypted_config_file + ) + return ModelConfig( + model_path=model_path or server_args.model_path, + trust_remote_code=server_args.trust_remote_code, + revision=model_revision or server_args.revision, + context_length=server_args.context_length, + model_override_args=server_args.json_model_override_args, + is_embedding=server_args.is_embedding, + enable_multimodal=server_args.enable_multimodal, + dtype=server_args.dtype, + quantization=quantization, + model_impl=server_args.model_impl, + sampling_defaults=server_args.sampling_defaults, + 
quantize_and_serve=server_args.quantize_and_serve, + override_config_file=override_config_file, + is_multi_layer_eagle=server_args.enable_multi_layer_eagle, + language_only=server_args.language_only, + encoder_only=server_args.encoder_only, + is_draft_model=is_draft_model, + disable_hybrid_swa_memory=server_args.disable_hybrid_swa_memory, + **kwargs, + ) + + def _config_draft_model(self): + is_draft_model = self.is_draft_model + + if is_draft_model and self.hf_config.architectures[0] in [ + "DeepseekV3ForCausalLM", + "GlmMoeDsaForCausalLM", + ]: + self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN" + + if is_draft_model and self.hf_config.architectures[0] in [ + "Glm4MoeForCausalLM", + "Glm4MoeLiteForCausalLM", + ]: + self.hf_config.architectures[0] = "Glm4MoeForCausalLMNextN" + + if is_draft_model and self.hf_config.architectures[0] in [ + "GlmOcrForConditionalGeneration", + ]: + self.hf_config.architectures[0] = "GlmOcrForConditionalGenerationNextN" + + if ( + is_draft_model + and self.hf_config.architectures[0] == "LongcatFlashForCausalLM" + ): + self.hf_config.architectures[0] = "LongcatFlashForCausalLMNextN" + self.hf_config.num_hidden_layers = self.hf_config.num_nextn_predict_layers + + if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM": + self.hf_config.architectures[0] = "MiMoMTP" + if ( + is_draft_model + and self.hf_config.architectures[0] == "MiMoV2FlashForCausalLM" + ): + self.hf_config.architectures[0] = "MiMoV2MTP" + if is_draft_model and self.hf_config.architectures[0] == "Step3p5ForCausalLM": + self.hf_config.architectures[0] = "Step3p5MTP" + if is_draft_model and self.hf_config.architectures[0] in [ + "BailingMoeV2ForCausalLM", + "BailingMoeForCausalLM", + "BailingMoeV2_5ForCausalLM", + ]: + self.hf_config.architectures[0] = "BailingMoeForCausalLMNextN" + if ( + is_draft_model + and self.hf_config.architectures[0] == "Ernie4_5_MoeForCausalLM" + ): + self.hf_config.architectures[0] = "Ernie4_5_MoeForCausalLMMTP" + 
+ if is_draft_model and self.hf_config.architectures[0] == "Qwen3NextForCausalLM": + self.hf_config.architectures[0] = "Qwen3NextForCausalLMMTP" + self.hf_config.num_nextn_predict_layers = 1 + + if is_draft_model and self.hf_config.architectures[0] in [ + "Qwen3_5ForConditionalGeneration", + "Qwen3_5MoeForConditionalGeneration", + ]: + self.hf_config.architectures[0] = "Qwen3_5ForCausalLMMTP" + self.hf_config.num_nextn_predict_layers = 1 + + if is_draft_model and self.hf_config.architectures[0] == "ExaoneMoEForCausalLM": + self.hf_config.architectures[0] = "ExaoneMoEForCausalLMMTP" + self.hf_config.num_nextn_predict_layers = 1 + + if is_draft_model and self.hf_config.architectures[0] == "NemotronHForCausalLM": + self.hf_config.architectures[0] = "NemotronHForCausalLMMTP" + self.hf_config.num_nextn_predict_layers = 1 + + def _derive_hybrid_model(self): + # Use self.context_len after it has been initialized to prevent using context_len which may be None. + self.is_hybrid_swa = ( + is_hybrid_swa_model(self.hf_config.architectures) + and not self.disable_hybrid_swa_memory + ) + + if self.is_hybrid_swa: + self.swa_attention_layer_ids, self.full_attention_layer_ids = ( + get_hybrid_layer_ids( + self.hf_config.architectures, + self.hf_text_config, + ) + ) + + self.is_hybrid_swa_compress = self.hf_config.architectures[0] in [ + "MiMoV2FlashForCausalLM", + "MiMoV2MTP", + ] + + def _derive_context_length(self, context_length: int): + is_draft_model = self.is_draft_model + derived_context_len = get_context_length(self.hf_text_config) + + if context_length is not None: + if context_length > derived_context_len: + reason = "Target model's" if is_draft_model else "User-specified" + msg = ( + f"Warning: {reason} context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). " + f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config." 
+ ) + if ( + envs.SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN.get() + or is_in_ci() # FIXME: fix this special case + ): + logger.warning(msg) + self.context_len = context_length + if is_draft_model: + self.hf_text_config.max_position_embeddings = context_length + logger.warning( + f"Overriding the draft model's max_position_embeddings to {context_length}." + ) + else: + raise ValueError( + f"{msg} To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1" + ) + else: + self.context_len = context_length + else: + self.context_len = derived_context_len + + # Transfer context_len to HuggingFace config so models can access it + self.hf_config.context_len = self.context_len + + def _derive_model_shapes(self): + # Unify the config keys for hf_text_config + self.head_dim = getattr( + self.hf_text_config, + "head_dim", + self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads, + ) + self.v_head_dim = getattr( + self.hf_text_config, + "v_head_dim", + self.head_dim, + ) + + self.swa_head_dim = getattr( + self.hf_text_config, + "swa_head_dim", + self.head_dim, + ) + self.swa_v_head_dim = getattr( + self.hf_text_config, + "swa_v_head_dim", + self.v_head_dim, + ) + # FIXME: temporary special judge for MLA architecture + if ( + "DeepseekV2ForCausalLM" in self.hf_config.architectures + or "DeepseekV32ForCausalLM" in self.hf_config.architectures + or "DeepseekV3ForCausalLM" in self.hf_config.architectures + or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures + or "Glm4MoeLiteForCausalLM" in self.hf_config.architectures + or "GlmMoeDsaForCausalLM" in self.hf_config.architectures + or "LongcatFlashForCausalLM" in self.hf_config.architectures + or "LongcatFlashForCausalLMNextN" in self.hf_config.architectures + or "DotsVLMForCausalLM" in self.hf_config.architectures + or "MistralLarge3ForCausalLM" in self.hf_config.architectures + or "PixtralForConditionalGeneration" in self.hf_config.architectures + or 
"MistralLarge3ForCausalLMEagle" in self.hf_config.architectures + or "KimiK25ForConditionalGeneration" in self.hf_config.architectures + ): + self.head_dim = 256 + self.attention_arch = AttentionArch.MLA + self.kv_lora_rank = self.hf_text_config.kv_lora_rank + self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim + self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim + self.v_head_dim = self.hf_text_config.v_head_dim + self.index_head_dim = ( + get_nsa_index_head_dim(self.hf_text_config) + if is_deepseek_nsa(self.hf_text_config) + else None + ) + # Handle rope scaling + self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim) + # in transformers v5, rope_scaling is just rope_parameters for backward compatibility + rope_scaling = self.hf_text_config.rope_scaling + if rope_scaling: + # v5 uses "rope_type", v4 uses "type" + rope_type = ( + rope_scaling.get("rope_type") + or rope_scaling.get("type") + or "default" + ) + if rope_type != "default": + mscale_all_dim = rope_scaling.get("mscale_all_dim", False) + scaling_factor = rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + elif "MiniCPM3ForCausalLM" in self.hf_config.architectures: + self.head_dim = 128 + self.attention_arch = AttentionArch.MLA + self.kv_lora_rank = self.hf_config.kv_lora_rank + self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim + elif "DeepseekVL2ForCausalLM" in self.hf_config.architectures and getattr( + self.hf_text_config, "use_mla", True + ): + self.head_dim = 256 + self.attention_arch = AttentionArch.MLA + self.kv_lora_rank = self.hf_text_config.kv_lora_rank + self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim + self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim + elif "KimiVLForConditionalGeneration" in self.hf_config.architectures: + self.head_dim = 256 + self.attention_arch = AttentionArch.MLA + self.kv_lora_rank = self.hf_text_config.kv_lora_rank 
+ self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim + self.v_head_dim = self.hf_text_config.v_head_dim + self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim + elif "KimiLinearForCausalLM" in self.hf_config.architectures: + self.head_dim = 72 + self.attention_arch = AttentionArch.MLA + self.kv_lora_rank = self.hf_config.kv_lora_rank + self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim + self.v_head_dim = self.hf_config.v_head_dim + self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim + elif ( + "BailingMoeV2_5ForCausalLM" in self.hf_config.architectures + or "BailingMoeForCausalLMNextN" in self.hf_config.architectures + ): + self.head_dim = self.hf_text_config.head_dim + self.attention_arch = AttentionArch.MLA + self.kv_lora_rank = self.hf_text_config.kv_lora_rank + self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim + self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim + self.v_head_dim = self.hf_config.v_head_dim + # Handle rope scaling with yarn + self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim) + if self.hf_config.rope_scaling: + mscale_all_dim = self.hf_config.rope_scaling.get( + "mscale_all_dim", False + ) + scaling_factor = self.hf_config.rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + elif "SarvamMLAForCausalLM" in self.hf_config.architectures: + self.head_dim = ( + self.hf_config.qk_nope_head_dim + self.hf_config.qk_rope_head_dim + ) + self.attention_arch = AttentionArch.MLA + self.kv_lora_rank = self.hf_config.kv_lora_rank + self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim + self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim + self.v_head_dim = self.hf_config.v_head_dim + self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim) + if self.hf_config.rope_scaling: + mscale_all_dim = self.hf_config.rope_scaling.get( + "mscale_all_dim", False + ) + scaling_factor = 
self.hf_config.rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + else: + if ( + "MistralModel" in self.hf_config.architectures + or "MixtralForCausalLM" in self.hf_config.architectures + or "MistralForCausalLM" in self.hf_config.architectures + ): + if getattr(self, "head_dim", None) is None: + self.head_dim = ( + self.hf_config.hidden_size // self.hf_config.num_attention_heads + ) + # In transformers==4.52.3, the head_dim is null in MistralConfig + if ( + not hasattr(self.hf_text_config, "head_dim") + or self.hf_text_config.head_dim is None + ): + setattr(self.hf_text_config, "head_dim", self.head_dim) + + elif "BaichuanForCausalLM" in self.hf_config.architectures: + self.use_alibi = self.hf_config.hidden_size != 4096 + + self.attention_arch = AttentionArch.MHA + + self.num_attention_heads = self.hf_text_config.num_attention_heads + self.num_key_value_heads = getattr( + self.hf_text_config, "num_key_value_heads", None + ) + + # for Dbrx and MPT models + if self.hf_config.model_type in ["dbrx", "mpt"]: + self.num_key_value_heads = getattr( + self.hf_config.attn_config, "kv_n_heads", None + ) + + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + self.hidden_size = self.hf_text_config.hidden_size + self.num_hidden_layers = self.hf_text_config.num_hidden_layers + self.num_attention_layers = self.num_hidden_layers + if "LongcatFlashForCausalLM" in self.hf_config.architectures: + self.num_attention_layers = self.num_hidden_layers * 2 + if "IQuestLoopCoderForCausalLM" in self.hf_config.architectures: + loop_num = getattr(self.hf_text_config, "loop_num", 1) + self.num_attention_layers = int(self.num_hidden_layers * int(loop_num)) + if "WhisperForConditionalGeneration" in self.hf_config.architectures: + # Whisper has unique layer ID scheme: + # - Encoder self-attention: 0 to encoder_layers-1 (no KV cache) + # - Decoder self-attention: 
encoder_layers to encoder_layers+decoder_layers-1 (uses KV cache) + # - Decoder cross-attention: encoder_layers+decoder_layers to encoder_layers+2*decoder_layers-1 + # Even though cross-attention doesn't save KV cache, attention backend needs buffer to exist + encoder_layers = getattr(self.hf_text_config, "encoder_layers", 0) + decoder_layers = getattr( + self.hf_text_config, "decoder_layers", self.num_hidden_layers + ) + self.num_attention_layers = encoder_layers + 2 * decoder_layers + self.num_nextn_predict_layers = getattr( + self.hf_text_config, "num_nextn_predict_layers", None + ) + self.vocab_size = self.hf_text_config.vocab_size + + def get_total_num_attention_heads(self) -> int: + return self.num_attention_heads + + def get_num_attention_heads(self, tensor_parallel_size) -> int: + total_num_attention_heads = self.num_attention_heads + return max(1, total_num_attention_heads // tensor_parallel_size) + + # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289 + def get_total_num_kv_heads(self) -> int: + """Returns the total number of KV heads.""" + # For GPTBigCode & Falcon: + # NOTE: for falcon, when new_decoder_architecture is True, the + # multi_query flag is ignored and we use n_head_kv for the number of + # KV heads. + falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"] + new_decoder_arch_falcon = ( + self.hf_config.model_type in falcon_model_types + and getattr(self.hf_config, "new_decoder_architecture", False) + ) + if not new_decoder_arch_falcon and getattr( + self.hf_text_config, "multi_query", False + ): + # Multi-query attention, only one KV head. + # Currently, tensor parallelism is not supported in this case. 
+ return 1 + + # For DBRX and MPT + if self.hf_config.model_type in ["mpt"]: + if "kv_n_heads" in self.hf_config.attn_config: + return self.hf_config.attn_config["kv_n_heads"] + return self.hf_config.num_attention_heads + if self.hf_config.model_type in ["dbrx"]: + return getattr( + self.hf_config.attn_config, + "kv_n_heads", + self.hf_config.num_attention_heads, + ) + if self.hf_config.model_type in ["nemotron-nas"]: + nkvh = { + self.hf_config.num_attention_heads // block.attention.n_heads_in_group + for block in self.hf_config.block_configs + if not block.attention.no_op + } + if len(nkvh) == 0: + raise RuntimeError("Couldn't determine number of kv heads") + if len(nkvh) > 1: + raise ValueError( + "Variable GQA (VGQA) is not yet supported for nemotron-nas in sglang" + ) + return next(iter(nkvh)) + + attributes = [ + # For Falcon: + "n_head_kv", + "num_kv_heads", + # For LLaMA-2: + "num_key_value_heads", + # For ChatGLM: + "multi_query_group_num", + # For Step3 + "num_attention_groups", + ] + for attr in attributes: + num_kv_heads = getattr(self.hf_text_config, attr, None) + if num_kv_heads is not None: + return num_kv_heads + + # For non-grouped-query attention models, the number of KV heads is + # equal to the number of attention heads. + return self.hf_text_config.num_attention_heads + + def get_num_kv_heads(self, tensor_parallel_size) -> int: + """Returns the number of KV heads per GPU.""" + total_num_kv_heads = self.get_total_num_kv_heads() + # If tensor parallelism is used, we divide the number of KV heads by + # the tensor parallel size. We will replicate the KV heads in the + # case where the number of KV heads is smaller than the tensor + # parallel size so each GPU has at least one KV head. 
+ return max(1, total_num_kv_heads // tensor_parallel_size) + + def get_swa_num_kv_heads(self, tensor_parallel_size) -> int: + """Similar to get_num_kv_heads(), but for SWA.""" + if hasattr(self.hf_text_config, "swa_num_key_value_heads"): + total_num_kv_heads = self.hf_text_config.swa_num_key_value_heads + return max(1, total_num_kv_heads // tensor_parallel_size) + elif hasattr(self.hf_text_config, "attention_other_setting"): # For step3p5 + total_num_kv_heads = self.hf_text_config.attention_other_setting.get( + "num_attention_groups" + ) + return max(1, total_num_kv_heads // tensor_parallel_size) + else: + return self.get_num_kv_heads(tensor_parallel_size) + + # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py + def _parse_quant_hf_config(self): + quant_cfg = getattr(self.hf_config, "quantization_config", None) + if quant_cfg is not None and not isinstance(quant_cfg, dict): + quant_cfg = quant_cfg.to_dict() + if quant_cfg is not None: + # Identify modelopt quantization + if ( + "quant_method" not in quant_cfg + or quant_cfg["quant_method"] == "modelopt" + ): + parsed_cfg = self._parse_modelopt_quant_config( + {"quantization": quant_cfg} + ) + if parsed_cfg: + quant_cfg.update(parsed_cfg) + + if quant_cfg is None: + # compressed-tensors uses a "compression_config" key + quant_cfg = getattr(self.hf_config, "compression_config", None) + if quant_cfg is None: + # check if is modelopt or mixed-precision model -- Both of them don't have corresponding field + # in hf `config.json` but has a standalone `hf_quant_config.json` in the root directory + # example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main + # example: https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/tree/main + is_local = os.path.exists(self.model_path) + if not is_local: + # Conditional import based on SGLANG_USE_MODELSCOPE environment variable + if envs.SGLANG_USE_MODELSCOPE.get(): + + from modelscope import HubApi, model_file_download + + hf_api = 
HubApi() + else: + import huggingface_hub + from huggingface_hub import HfApi, hf_hub_download + + hf_api = HfApi() + try: + # In offline mode, skip file_exists check to avoid OfflineModeIsEnabled error + # Instead, directly try to download/read from cache with local_files_only + file_exists = False # Initialize to avoid UnboundLocalError + if not huggingface_hub.constants.HF_HUB_OFFLINE: + # Online mode: check if file exists before attempting download (optimization) + file_exists = retry( + lambda: hf_api.file_exists( + self.model_path, "hf_quant_config.json" + ), + max_retry=2, + initial_delay=1.0, + max_delay=5.0, + ) + if not file_exists: + # File doesn't exist on hub, no need to try downloading + return quant_cfg # None + + # Download (online mode) or read from cache (offline mode) + if envs.SGLANG_USE_MODELSCOPE.get(): + quant_config_file = model_file_download( + model_id=self.model_path, + file_path="hf_quant_config.json", + revision=self.revision, + ) + else: + quant_config_file = hf_hub_download( + repo_id=self.model_path, + filename="hf_quant_config.json", + revision=self.revision, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + ) + with open(quant_config_file) as f: + quant_config_dict = json.load(f) + quant_cfg = self._parse_modelopt_quant_config(quant_config_dict) + except huggingface_hub.errors.LocalEntryNotFoundError: + # Offline mode and file not in cache - this is normal for non-quantized models + logger.debug( + f"hf_quant_config.json not found in cache for {self.model_path} " + "(offline mode, normal for non-quantized models)" + ) + except huggingface_hub.errors.OfflineModeIsEnabled: + # Should not reach here after our changes, but keep for safety + logger.warning( + "Offline mode is enabled, skipping hf_quant_config.json check" + ) + except Exception as e: + logger.warning( + "Failed to load hf_quant_config.json for model %s: %s", + self.model_path, + e, + ) + elif os.path.exists(os.path.join(self.model_path, 
"hf_quant_config.json")): + quant_config_file = os.path.join( + self.model_path, "hf_quant_config.json" + ) + with open(quant_config_file) as f: + quant_config_dict = json.load(f) + quant_cfg = self._parse_modelopt_quant_config(quant_config_dict) + return quant_cfg + + def _find_quant_modelslim_config(self): + quant_config_file = Path(self.model_path, "quant_model_description.json") + quant_cfg = None + if quant_config_file.is_file(): + with open(quant_config_file) as f: + quant_cfg = json.load(f) + # This field is required for flagless model loading but is not present in + # modelslim model description, so we're adding it here manually. + quant_cfg["quant_method"] = "modelslim" + + return quant_cfg + + def _parse_modelopt_quant_config(self, quant_config_dict: dict) -> Optional[dict]: + """Parse ModelOpt quantization config and return the appropriate quant_method.""" + json_quant_configs = quant_config_dict["quantization"] + quant_algo = json_quant_configs.get("quant_algo", None) + + if quant_algo == "MIXED_PRECISION": + return {"quant_method": "w4afp8", "quant_algo": quant_algo} + elif quant_algo and ("FP4" in quant_algo or "NVFP4" in quant_algo): + return {"quant_method": "modelopt_fp4", "quant_algo": quant_algo} + elif quant_algo and "FP8" in quant_algo: + return {"quant_method": "modelopt_fp8", "quant_algo": quant_algo} + else: + return None + + def get_quantization_config_log_str(self) -> Optional[str]: + """ + Get a concise string representation of the quantization config for logging. + Returns something like "quant=fp8, fmt=e4m3" or "quant=gptq, bits=4". 
+ """ + try: + quant_cfg = self._parse_quant_hf_config() + if not quant_cfg: + return None + + quant_method = quant_cfg.get("quant_method", "quantized") + log_str = f"quant={quant_method}" + + # Append interesting fields if they exist + for field in ["bits", "quant_algo", "fmt"]: + if field in quant_cfg: + log_str += f", {field}={quant_cfg[field]}" + + return log_str + except Exception: + return None + + def _is_already_quantized(self) -> bool: + """Check if the model is already quantized based on config files.""" + # Check for quantization in hf_config (config.json) + if getattr(self.hf_config, "quantization_config", None) or getattr( + self.hf_config, "compression_config", None + ): + return True + + # Check for HuggingFace quantization config + from sglang.srt.utils import has_hf_quant_config + + return has_hf_quant_config(self.model_path) + + def _get_modelopt_quant_type(self) -> str: + """Extract ModelOpt quantization type from unified quantization flag.""" + if self.quantization == "modelopt_fp8": + return "fp8" + elif self.quantization == "modelopt_fp4": + return "nvfp4" + elif self.quantization == "modelopt": + # Auto-detect from model config + quant_cfg = self._parse_quant_hf_config() + if quant_cfg: + quant_method = quant_cfg.get("quant_method", "").lower() + if "fp4" in quant_method: + return "fp4" + elif "fp8" in quant_method: + return "fp8" + # Default to fp8 if can't detect + return "fp8" + else: + return "fp8" # Default fallback + + def _get_sliding_window_size(self) -> Optional[int]: + sliding_window_size = getattr(self.hf_text_config, "sliding_window_size", None) + if sliding_window_size is None: + sliding_window_size = getattr(self.hf_text_config, "sliding_window", None) + return sliding_window_size + + def _validate_quantize_and_serve_config(self): + """Validate quantize_and_serve configuration.""" + if not self.quantize_and_serve: + return + + # Check if ModelOpt quantization is specified + _MODELOPT_QUANTIZATION_METHODS = [ + "modelopt", + 
"modelopt_fp8", + "modelopt_fp4", + ] + modelopt_quantization_specified = ( + self.quantization in _MODELOPT_QUANTIZATION_METHODS + ) + + if not modelopt_quantization_specified: + raise ValueError( + "quantize_and_serve requires ModelOpt quantization (set with --quantization " + f"{{{', '.join(sorted(_MODELOPT_QUANTIZATION_METHODS))}}})" + ) + + # quantize_and_serve is disabled due to compatibility issues + raise NotImplementedError( + "quantize_and_serve functionality is currently disabled due to compatibility issues. " + "Please use the separate quantize-then-deploy workflow instead. " + "Step 1: Quantize and export model. " + "Step 2: Deploy the exported model." + ) + + # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py + def _verify_quantization(self) -> None: + supported_quantization = [*QUANTIZATION_METHODS] + rocm_supported_quantization = [ + "awq", + "gptq", + "fp8", + "compressed_tensors", + "compressed-tensors", + "fbgemm_fp8", + "w8a8_fp8", + "petit_nvfp4", + "quark", + "mxfp4", + "auto-round", + "quark_int4fp8_moe", + ] + optimized_quantization_methods = [ + "fp8", + "marlin", + "modelopt_fp8", + "modelopt_fp4", + "gptq_marlin_24", + "gptq_marlin", + "awq_marlin", + "fbgemm_fp8", + "compressed_tensors", + "compressed-tensors", + "experts_int8", + "w8a8_int8", + "w8a8_fp8", + "moe_wna16", + "qoq", + "w4afp8", + "petit_nvfp4", + "quark", + "modelslim", + ] + compatible_quantization_methods = { + "modelopt_fp8": ["modelopt"], + "modelopt_fp4": ["modelopt"], + "petit_nvfp4": ["modelopt"], + "w8a8_int8": ["compressed-tensors", "compressed_tensors"], + "w8a8_fp8": ["compressed-tensors", "compressed_tensors"], + } + if self.quantization is not None: + self.quantization = self.quantization.lower() + + # Parse quantization method from the HF and ModelSlim model config, if available. + # Only one function should return config, other should return None. 
+ cfg_list = [] + hf_config = self._parse_quant_hf_config() + modelslim_config = self._find_quant_modelslim_config() + quant_config = modelslim_config or hf_config + if quant_config is not None: + cfg_list.append(quant_config) + + # Filter out None values + cfg_list = [item for item in cfg_list if item is not None] + if len(cfg_list) > 1: + raise ValueError( + "Config list contains configs from 2 methods, must be only 1" + ) + quant_cfg = cfg_list[0] if cfg_list else None + + if quant_cfg is not None: + quant_method = quant_cfg.get( + "quant_method", "" if not self.quantization else self.quantization + ).lower() + + # Detect which checkpoint is it + for _, method in QUANTIZATION_METHODS.items(): + quantization_override = method.override_quantization_method( + quant_cfg, self.quantization + ) + if quantization_override: + quant_method = quantization_override + self.quantization = quantization_override + break + + # Verify quantization configurations. + if self.quantization is None: + self.quantization = quant_method + elif self.quantization != quant_method: + # Check if the CLI-specified quantization is compatible with HF config's quant_method + is_compatible = ( + self.quantization in compatible_quantization_methods + and quant_method + in compatible_quantization_methods[self.quantization] + ) + if is_compatible: + # Keep the CLI-specified quantization (e.g., modelopt_fp4) even if + # HF config says "modelopt" - they are compatible + logger.info( + f"Using CLI-specified quantization ({self.quantization}) which is " + f"compatible with HF config quant_method ({quant_method})." + ) + elif self.is_draft_model: + # Allow auto-detection of quantization from checkpoint for draft model + # only if the CLI quantization is not compatible + logger.info( + f"Draft model quantization ({quant_method}) differs from " + f"main model quantization ({self.quantization}). 
" + f"Using draft model's detected quantization: {quant_method}" + ) + self.quantization = quant_method + else: + raise ValueError( + "Quantization method specified in the model config " + f"({quant_method}) does not match the quantization " + f"method specified in the `quantization` argument " + f"({self.quantization})." + ) + + # Check if the scale_fmt is ue8m0, and warn user if deepgemm is enabled for non-ue8m0 models on blackwell + self.use_scale_ue8m0 = quant_cfg.get("scale_fmt", None) == "ue8m0" + from sglang.srt.layers import deep_gemm_wrapper + + if not self.use_scale_ue8m0 and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0: + logger.warning( + "DeepGemm is enabled but the scale_fmt of checkpoint is not ue8m0. This might cause accuracy degradation on Blackwell." + ) + + if self.quantization is not None: + if self.quantization not in supported_quantization: + raise ValueError( + f"Unknown quantization method: {self.quantization}. Must " + f"be one of {supported_quantization}." + ) + if is_hip() and self.quantization not in rocm_supported_quantization: + raise ValueError( + f"{self.quantization} quantization is currently not " + f"supported in ROCm." + ) + if self.quantization not in optimized_quantization_methods: + # Don't warn for MXFP4 on SM100 since it has optimized kernels + if not (self.quantization == "mxfp4" and is_sm100_supported()): + logger.warning( + "%s quantization is not fully " + "optimized yet. 
The speed can be slower than " + "non-quantized models.", + self.quantization, + ) + + def _verify_dual_chunk_attention_config(self) -> None: + if hasattr(self.hf_config, "dual_chunk_attention_config"): + # Try loading the sparse attention config + sparse_attn_config = get_sparse_attention_config(self.model_path) + if not sparse_attn_config: + return + self.hf_config.dual_chunk_attention_config["sparse_attention_config"] = ( + sparse_attn_config + ) + if ( + "sparse_attention_enabled" + not in self.hf_config.dual_chunk_attention_config + ): + self.hf_config.dual_chunk_attention_config[ + "sparse_attention_enabled" + ] = True + + def _verify_transformers_version(self): + import transformers + from packaging import version + + tf_version_str = getattr(transformers, "__version__", None) + if tf_version_str is None: + return + + vision_config = getattr(self.hf_config, "vision_config", None) + is_glm_46vmoe = "glm-4.6v" in self.model_path.lower() or ( + vision_config is not None + and getattr(vision_config, "model_type", None) == "glm4v_moe_vision" + # The vision config model type for GLM-4.5v is 'glm4v_moe', + # while for GLM-4.6v, it is 'glm4v_moe_vision'. + ) + needs_tf_v5 = is_glm_46vmoe + + tf_version = version.parse(tf_version_str) + required_version = version.parse("5.0.0dev0") + + if tf_version < required_version: + if needs_tf_v5: + raise ValueError( + f"Transformers version {tf_version_str} is not supported for model {self.model_path} " + f"or model type {self.hf_config.model_type}. " + "Please upgrade transformers to >= 5.0.0." + ) + elif not needs_tf_v5: + logger.warning( + f"Transformers version {tf_version_str} is used for model type {self.hf_config.model_type}. " + "If you experience issues related to RoPE parameters, " + "they may be due to incompatibilities between Transformers >=5.0.0 and some models. " + "You can try downgrading to transformers==4.57.1 as a workaround." 
+ ) + + def _get_hf_eos_token_id(self) -> Optional[Set[int]]: + eos_ids = getattr(self.hf_config, "eos_token_id", None) + if eos_ids is not None: + # it can be either int or list of int + eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids) + if eos_ids is None: + eos_ids = set() + if self.hf_generation_config: + generation_eos_ids = getattr( + self.hf_generation_config, "eos_token_id", None + ) + if generation_eos_ids: + generation_eos_ids = ( + {generation_eos_ids} + if isinstance(generation_eos_ids, int) + else set(generation_eos_ids) + ) + eos_ids = eos_ids | generation_eos_ids + return eos_ids + + def get_default_sampling_params(self) -> dict[str, Any]: + """ + Get default sampling parameters from the model's generation config. + + This method returns non-default sampling parameters from the model's + generation_config.json when sampling_defaults is set to "model". + + Returns: + A dictionary containing the non-default sampling parameters. + """ + if self.sampling_defaults != "model": + return {} + + if self.hf_generation_config is None: + return {} + + config = self.hf_generation_config.to_dict() + + available_params = [ + "repetition_penalty", + "temperature", + "top_k", + "top_p", + "min_p", + ] + + default_sampling_params = { + p: config.get(p) for p in available_params if config.get(p) is not None + } + + return default_sampling_params + + def _maybe_pull_model_tokenizer_from_remote(self) -> None: + """ + Pull the model config files to a temporary + directory in case of remote. + + Args: + model: The model name or path. + + """ + from sglang.srt.connector import create_remote_connector + from sglang.srt.utils import is_remote_url + + if is_remote_url(self.model_path): + logger.info("Pulling model configs from remote...") + # BaseConnector implements __del__() to clean up the local dir. + # Since config files need to exist all the time, so we DO NOT use + # with statement to avoid closing the client. 
+ client = create_remote_connector(self.model_path) + if is_remote_url(self.model_path): + client.pull_files(allow_pattern=["*config.json"]) + self.model_weights = self.model_path + self.model_path = client.get_local_dir() + + +# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py +_STR_DTYPE_TO_TORCH_DTYPE = { + "half": torch.float16, + "float16": torch.float16, + "float": torch.float32, + "float32": torch.float32, + "bfloat16": torch.bfloat16, +} + + +# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py +def _get_and_verify_dtype( + config: PretrainedConfig, + dtype: Union[str, torch.dtype], +) -> torch.dtype: + # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct + # because config.torch_dtype can be None. + if isinstance(config, dict): + config_dtype = config.get("dtype", None) or config.get("torch_dtype", None) + model_type = config.get("model_type", "") + else: + config_dtype = getattr(config, "dtype", None) + model_type = getattr(config, "model_type", "") + if isinstance(config_dtype, str): + config_dtype = _STR_DTYPE_TO_TORCH_DTYPE.get(config_dtype, None) + if config_dtype is None: + config_dtype = torch.float32 + + if isinstance(dtype, str): + dtype = dtype.lower() + if dtype == "auto": + if config_dtype == torch.float32: + if model_type.startswith("gemma"): + if model_type == "gemma": + gemma_version = "" + else: + gemma_version = model_type[5] + logger.info( + f"For Gemma {gemma_version}, we downcast float32 to bfloat16 instead " + "of float16 by default. Please specify `dtype` if you " + "want to use float16." + ) + torch_dtype = torch.bfloat16 + else: + # Following the common practice, we use float16 for float32 + # models. 
+ torch_dtype = torch.float16 + else: + torch_dtype = config_dtype + else: + if dtype not in _STR_DTYPE_TO_TORCH_DTYPE: + raise ValueError(f"Unknown dtype: {dtype}") + torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] + elif isinstance(dtype, torch.dtype): + torch_dtype = dtype + else: + raise ValueError(f"Unknown dtype: {dtype}") + + # Verify the dtype. + if torch_dtype != config_dtype: + if torch_dtype == torch.float32: + # Upcasting to float32 is allowed. + logger.info("Upcasting %s to %s.", config_dtype, torch_dtype) + pass + elif config_dtype == torch.float32: + # Downcasting from float32 to float16 or bfloat16 is allowed. + logger.info("Downcasting %s to %s.", config_dtype, torch_dtype) + pass + else: + # Casting between float16 and bfloat16 is allowed with a warning. + logger.warning("Casting %s to %s.", config_dtype, torch_dtype) + + return torch_dtype + + +def is_generation_model(model_architectures: List[str], is_embedding: bool = False): + # We have two ways to determine whether a model is a generative model. + # 1. Check the model architecture + # 2. 
check the `is_embedding` server args + + if ( + "LlamaEmbeddingModel" in model_architectures + or "MistralModel" in model_architectures + or "LlamaForSequenceClassification" in model_architectures + or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures + or "InternLM2ForRewardModel" in model_architectures + or "Qwen2ForRewardModel" in model_architectures + or "Qwen3ForRewardModel" in model_architectures + or "Qwen2ForSequenceClassification" in model_architectures + or "Qwen3ForSequenceClassification" in model_architectures + or "CLIPModel" in model_architectures + or "BertModel" in model_architectures + or "Contriever" in model_architectures + or "BertForSequenceClassification" in model_architectures + or "XLMRobertaModel" in model_architectures + or "XLMRobertaForSequenceClassification" in model_architectures + or "Gemma2ForSequenceClassification" in model_architectures + ): + return False + else: + return not is_embedding + + +multimodal_model_archs = [ + "CLIPModel", + "DeepseekVL2ForCausalLM", + "Ernie4_5_VLMoeForConditionalGeneration", + "Gemma3ForConditionalGeneration", + "Gemma3nForConditionalGeneration", + "Glm4vForConditionalGeneration", + "Glm4vMoeForConditionalGeneration", + "GlmOcrForConditionalGeneration", + "GlmAsrForConditionalGeneration", + "Grok1VForCausalLM", + "Grok1AForCausalLM", + "LlavaLlamaForCausalLM", + "Llama4ForConditionalGeneration", + "LlavaMistralForCausalLM", + "LlavaQwenForCausalLM", + "LlavaForConditionalGeneration", + "LlavaVidForCausalLM", + "LightOnOCRForConditionalGeneration", + "MiniCPMO", + "MiniCPMV", + "Mistral3ForConditionalGeneration", + "MultiModalityCausalLM", + "MllamaForConditionalGeneration", + "NemotronH_Nano_VL_V2", + "PixtralForConditionalGeneration", + "Qwen2AudioForConditionalGeneration", + "Qwen2VLForConditionalGeneration", + "Qwen2_5_VLForConditionalGeneration", + "Qwen3VLForConditionalGeneration", + "Qwen3VLMoeForConditionalGeneration", + "Qwen3_5ForConditionalGeneration", + 
"Qwen3_5MoeForConditionalGeneration", + "Qwen3OmniMoeForConditionalGeneration", + "KimiVLForConditionalGeneration", + "InternVLChatModel", + "InternS1ForConditionalGeneration", + "InternS1ProForConditionalGeneration", + "Phi4MMForCausalLM", + "WhisperForConditionalGeneration", + "Step3VLForConditionalGeneration", + "POINTSV15ChatModel", + "DotsVLMForCausalLM", + "DotsOCRForCausalLM", + "Sarashina2VisionForCausalLM", + "NVILAForConditionalGeneration", + "NVILALiteForConditionalGeneration", + "DeepseekOCRForCausalLM", + "JetVLMForConditionalGeneration", + "PaddleOCRVLForConditionalGeneration", + "MiDashengLMModel", + "StepVLForConditionalGeneration", + "KimiK25ForConditionalGeneration", +] + +piecewise_cuda_graph_disabled_model_archs = [ + "DeepseekV32ForCausalLM", + "Qwen3NextForCausalLM", + "GlmMoeDsaForCausalLM", + "BailingMoeV2_5ForCausalLM", + "LLaDAModelLM", +] + +if external_mm_model_arch := envs.SGLANG_EXTERNAL_MM_MODEL_ARCH.get(): + multimodal_model_archs.append(external_mm_model_arch) + + +def is_multimodal_model(model_architectures: List[str]): + if any( + multi_model_arch in model_architectures + for multi_model_arch in multimodal_model_archs + ): + return True + else: + return False + + +def is_multimodal_gen_model(model_architectures: List[str]): + return False + + +def is_image_gen_model(model_architectures: List[str]): + return False + + +def is_audio_model(model_architectures: List[str]): + models = [ + "WhisperForConditionalGeneration", + ] + return any(model in model_architectures for model in models) + + +def is_encoder_decoder_model(model_architectures: List[str]): + models = [ + "WhisperForConditionalGeneration", + "MllamaForConditionalGeneration", + ] + return any(model in model_architectures for model in models) + + +def is_local_attention_model(model_architectures: List[str]): + return "Llama4ForConditionalGeneration" in model_architectures + + +def is_multimodal_chunked_prefill_supported(model_architectures: List[str]): + """Check if chunked 
prefill is supported for a MultiModal model.""" + unsupported = [ + "Grok1VForCausalLM", + "Grok1AForCausalLM", + "LlavaLlamaForCausalLM", + "MllamaForConditionalGeneration", + "CLIPModel", + ] + if any(multi_model_arch in unsupported for multi_model_arch in model_architectures): + return False + else: + return True + + +def is_piecewise_cuda_graph_disabled_model(model_architectures: List[str]): + return any( + arch in piecewise_cuda_graph_disabled_model_archs + for arch in model_architectures + ) + + +def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +def is_hybrid_swa_model(model_architectures: List[str]): + + hybrid_swa_archs = { + "Llama4ForConditionalGeneration", + "GptOssForCausalLM", + "MiMoV2FlashForCausalLM", + "MiMoV2MTP", + "Step3p5ForCausalLM", + "Step3p5MTP", + } + return any(arch in hybrid_swa_archs for arch in model_architectures) + + +def get_hybrid_layer_ids( + model_architectures: List[str], + hf_text_config: PretrainedConfig, +): + num_hidden_layers = hf_text_config.num_hidden_layers + if "Llama4ForConditionalGeneration" in model_architectures: + swa_attention_layer_ids = [ + i for i in range(num_hidden_layers) if (i + 1) % 4 != 0 + ] + full_attention_layer_ids = [ + i for i in range(num_hidden_layers) if (i + 1) % 4 == 0 + ] + elif "GptOssForCausalLM" in model_architectures: + layer_types = getattr(hf_text_config, "layer_types", None) + swa_attention_layer_ids = [ + i for i, x in enumerate(layer_types) if x == "sliding_attention" + ] + full_attention_layer_ids = [ + i for i, x in enumerate(layer_types) if x == "full_attention" + ] + elif "MiMoV2FlashForCausalLM" in model_architectures: + hybrid_layer_pattern = getattr(hf_text_config, "hybrid_layer_pattern", None) + swa_attention_layer_ids = [ + i for i in range(num_hidden_layers) if hybrid_layer_pattern[i] == 1 + ] + full_attention_layer_ids = [ + i for i in range(num_hidden_layers) if 
hybrid_layer_pattern[i] == 0 + ] + elif "MiMoV2MTP" in model_architectures: + swa_attention_layer_ids = [0] + full_attention_layer_ids = [] + elif "Step3p5ForCausalLM" in model_architectures: + layer_types = hf_text_config.layer_types + swa_attention_layer_ids = [ + i + for i, x in enumerate(layer_types) + if x == "sliding_attention" and i < num_hidden_layers + ] + full_attention_layer_ids = [ + i + for i, x in enumerate(layer_types) + if x == "full_attention" and i < num_hidden_layers + ] + elif "Step3p5MTP" in model_architectures: + swa_attention_layer_ids = [0] + full_attention_layer_ids = [] + else: + swa_attention_layer_ids = None + full_attention_layer_ids = None + return swa_attention_layer_ids, full_attention_layer_ids diff --git a/sglang/python/sglang/srt/configs/modelopt_config.py b/sglang/python/sglang/srt/configs/modelopt_config.py new file mode 100644 index 0000000000000000000000000000000000000000..911b4ce0cd96272db0266f686965aeda3ee50eac --- /dev/null +++ b/sglang/python/sglang/srt/configs/modelopt_config.py @@ -0,0 +1,30 @@ +# Configuration for NVIDIA ModelOpt quantization integration +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class ModelOptConfig: + """Configuration for NVIDIA ModelOpt quantization operations. + + This configuration class holds parameters for ModelOpt quantization, + checkpoint management, and model export operations. 
@dataclass
class ModelOptConfig:
    """Settings bundle for NVIDIA ModelOpt quantization.

    Groups the options used for ModelOpt quantization, checkpoint
    save/restore, and export of the resulting model.

    Attributes:
        quant: Name of the quantization method/type (e.g., "fp8", "fp4").
        checkpoint_restore_path: Location a ModelOpt checkpoint is restored from.
        checkpoint_save_path: Location a ModelOpt checkpoint is saved to.
        export_path: Destination for the quantized model in HuggingFace format.
        quantize_and_serve: Whether quantization and serving happen in one step.
    """

    # Field order defines the positional order of the generated __init__.
    quant: Optional[str] = None
    checkpoint_restore_path: Optional[str] = None
    checkpoint_save_path: Optional[str] = None
    export_path: Optional[str] = None
    quantize_and_serve: bool = False

    def __post_init__(self):
        """Hook run after the generated ``__init__``; currently performs no validation."""
        return None
def float_triplet(seq: Any):
    """Unpack ``seq`` into exactly three floats and return them as a tuple.

    Args:
        seq: Any iterable expected to yield exactly three ``float`` values.

    Returns:
        A ``(a, b, c)`` tuple holding the three floats in their original order.

    Raises:
        ValueError: If ``seq`` does not contain exactly three items (raised by
            the tuple unpacking).
        AssertionError: If any of the three items is not a ``float``.
    """
    a, b, c = tuple(seq)
    # Raise explicitly instead of using ``assert``: assert statements are
    # removed under ``python -O``, which would silently disable this check.
    # AssertionError and its message are kept so existing callers see the
    # same failure as before.
    if not all(isinstance(value, float) for value in (a, b, c)):
        raise AssertionError("expected three floats")
    return a, b, c