# PD Disaggregation
## Why and What is PD Disaggregation?
Large Language Model (LLM) inference comprises two distinct phases: **Prefill** and **Decode**. The Prefill phase is computation-intensive, processing the entire input sequence, while the Decode phase is memory-intensive, managing the Key-Value (KV) cache for token generation. Traditionally, these phases are handled within a unified engine, where combined scheduling of prefill and decode batches introduces inefficiencies. To address these challenges, we introduce **Prefill and Decoding (PD) Disaggregation** in SGLang.
### Issues with Unified Scheduling
The conventional unified engine, which processes prefill and decode batches together, results in two significant problems:
1. **Prefill Interruption**: Incoming prefill batches frequently interrupt ongoing decode batches, causing substantial delays in token generation.
2. **DP Attention Imbalance**: In data-parallel (DP) attention, one DP worker may process a prefill batch while another handles a decode batch simultaneously, leading to increased decode latency.
PD Disaggregation resolves these by separating the two stages, enabling tailored optimizations for each.
For the design details, please refer to the [design document](https://docs.google.com/document/d/1rQXJwKd5b9b1aOzLh98mnyMhBMhlxXA5ATZTHoQrwvc/edit?tab=t.0).
Currently, Mooncake and NIXL are supported as transfer engines.
## Profiling in PD Disaggregation Mode
When you need to profile prefill or decode workers in PD disaggregation mode, please refer to the [Profile In PD Disaggregation Mode](https://docs.sglang.io/developer_guide/benchmark_and_profiling.html#profile-in-pd-disaggregation-mode) section in the Benchmark and Profiling guide. Due to torch profiler limitations, prefill and decode workers must be profiled separately using dedicated command-line options.
## Router Integration
For deploying PD disaggregation at scale with load balancing and fault tolerance, SGLang provides a router. The router can distribute requests between prefill and decode instances using various routing policies. For detailed information on setting up routing with PD disaggregation, including configuration options and deployment patterns, see the [SGLang Model Gateway (former Router)](../advanced_features/sgl_model_gateway.md#prefill-decode-disaggregation).
## Mooncake
### Requirements
```bash
uv pip install mooncake-transfer-engine
```
### Usage
### Llama Single Node
```bash
# Prefill server (each command runs in the foreground; use a separate terminal for each)
python -m sglang.launch_server \
  --model-path meta-llama/Llama-3.1-8B-Instruct \
  --disaggregation-mode prefill \
  --port 30000 \
  --disaggregation-ib-device mlx5_roce0
# Decode server, starting from GPU index 1
python -m sglang.launch_server \
  --model-path meta-llama/Llama-3.1-8B-Instruct \
  --disaggregation-mode decode \
  --port 30001 \
  --base-gpu-id 1 \
  --disaggregation-ib-device mlx5_roce0
# Router in front of the prefill and decode servers
python -m sglang_router.launch_router --pd-disaggregation --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000
```
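Once all three processes are up, a single request through the router is a quick way to confirm that prefill and decode are connected. The sketch below is illustrative only: it assumes the router on port 8000 forwards SGLang's native `/generate` endpoint, and the helper names `build_payload` and `smoke_test` are hypothetical.

```bash
# Hypothetical smoke test for the single-node setup above.
# Assumption: the router at ROUTER_URL forwards SGLang's native /generate endpoint.
ROUTER_URL="${ROUTER_URL:-http://127.0.0.1:8000}"

# Build a minimal JSON request body.
# The prompt must not contain double quotes or backslashes (no escaping is done here).
build_payload() {
  local prompt="$1" max_new_tokens="${2:-16}"
  printf '{"text": "%s", "sampling_params": {"max_new_tokens": %d, "temperature": 0}}' \
    "$prompt" "$max_new_tokens"
}

# Send one request through the router and print the raw response.
smoke_test() {
  curl -sS -X POST "${ROUTER_URL}/generate" \
    -H "Content-Type: application/json" \
    -d "$(build_payload "$1" "${2:-16}")"
}

# Usage, with the prefill server, decode server, and router all running:
#   smoke_test "The capital of France is"
```

The block only defines the helpers, so it is safe to source before the servers have started.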
### DeepSeek Multi-Node
```bash
# prefill 0
python -m sglang.launch_server \
--model-path deepseek-ai/DeepSeek-V3-0324 \
--disaggregation-ib-device ${device_name} \
--disaggregation-mode prefill \
--host ${local_ip} \
--port 30000 \
--trust-remote-code \
--dist-init-addr ${prefill_master_ip}:5000 \
--nnodes 2 \
--node-rank 0 \
--tp-size 16 \
--dp-size 8 \
--enable-dp-attention \
--moe-a2a-backend deepep \
--mem-fraction-static 0.8
# prefill 1
python -m sglang.launch_server \
--model-path deepseek-ai/DeepSeek-V3-0324 \
--disaggregation-ib-device ${device_name} \
--disaggregation-mode prefill \
--host ${local_ip} \
--port 30000 \
--trust-remote-code \
--dist-init-addr ${prefill_master_ip}:5000 \
--nnodes 2 \
--node-rank 1 \
--tp-size 16 \
--dp-size 8 \
--enable-dp-attention \
--moe-a2a-backend deepep \
--mem-fraction-static 0.8
# decode 0
python -m sglang.launch_server \
--model-path deepseek-ai/DeepSeek-V3-0324 \
--disaggregation-ib-device ${device_name} \
--disaggregation-mode decode \
--host ${local_ip} \
--port 30001 \
--trust-remote-code \
--dist-init-addr ${decode_master_ip}:5000 \
--nnodes 2 \
--node-rank 0 \
--tp-size 16 \
--dp-size 8 \
--enable-dp-attention \
--moe-a2a-backend deepep \
--mem-fraction-static 0.8 \
--max-running-requests 128
# decode 1
python -m sglang.launch_server \
--model-path deepseek-ai/DeepSeek-V3-0324 \
--disaggregation-ib-device ${device_name} \
--disaggregation-mode decode \
--host ${local_ip} \
--port 30001 \
--trust-remote-code \
--dist-init-addr ${decode_master_ip}:5000 \
--nnodes 2 \
--node-rank 1 \
--tp-size 16 \
--dp-size 8 \
--enable-dp-attention \
--moe-a2a-backend deepep \
--mem-fraction-static 0.8 \
--max-running-requests 128
```
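The four commands above differ only in the disaggregation mode, port, master address, node rank, and the extra `--max-running-requests` flag on decode nodes. As a sketch of how that could be parameterized, the hypothetical helper `pd_launch_cmd` below prints the command for a given role and rank. It uses only the flags shown above and expects `device_name`, `local_ip`, `prefill_master_ip`, and `decode_master_ip` to be set in the environment.

```bash
# Hypothetical helper: print the launch command for one node.
# Usage: pd_launch_cmd <prefill|decode> <node-rank>
pd_launch_cmd() {
  local role="$1" rank="$2" port master_ip
  local extra=()
  case "$role" in
    prefill) port=30000; master_ip="${prefill_master_ip}" ;;
    decode)  port=30001; master_ip="${decode_master_ip}"
             extra=(--max-running-requests 128) ;;   # decode nodes only, as above
    *) echo "unknown role: $role" >&2; return 1 ;;
  esac
  local cmd=(python -m sglang.launch_server
    --model-path deepseek-ai/DeepSeek-V3-0324
    --disaggregation-ib-device "${device_name}"
    --disaggregation-mode "$role"
    --host "${local_ip}" --port "$port"
    --trust-remote-code
    --dist-init-addr "${master_ip}:5000"
    --nnodes 2 --node-rank "$rank"
    --tp-size 16 --dp-size 8 --enable-dp-attention
    --moe-a2a-backend deepep
    --mem-fraction-static 0.8
    "${extra[@]}")
  # Print instead of executing so the command can be reviewed first.
  echo "${cmd[*]}"
}

# Example: on the second decode node, review the command before running it.
#   pd_launch_cmd decode 1
```

Printing rather than executing keeps the helper free of side effects; the output can be copied into a terminal or a job script on each node.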
### Advanced Configuration
PD Disaggregation with Mooncake supports the following environment variables for fine-grained control over system behavior.
#### NVLink Transport Configuration
To enable NVLink transport for KV cache transfers with the mooncake backend (recommended for NVL72 deployments), set the following environment variables. Note that auxiliary data transfer will still use TCP as a temporary workaround.
```bash
export SGLANG_MOONCAKE_CUSTOM_MEM_POOL=NVLINK
export MC_FORCE_MNNVL=True
```
The `SGLANG_MOONCAKE_CUSTOM_MEM_POOL` environment variable enables the custom memory pool. Supported values are `NVLINK` (or `True`), `BAREX`, and `INTRA_NODE_NVLINK`.
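A launch script can guard against typos in this variable before starting the server. The following is a minimal sketch with a hypothetical helper name, `check_mem_pool`; it assumes the values are matched exactly as spelled above, which may be stricter than what SGLang itself accepts.

```bash
# Hypothetical pre-flight check for SGLANG_MOONCAKE_CUSTOM_MEM_POOL.
# Assumption: values are compared exactly as written in the documentation.
check_mem_pool() {
  case "${1:-}" in
    NVLINK|True|BAREX|INTRA_NODE_NVLINK)
      return 0 ;;
    *)
      echo "unsupported SGLANG_MOONCAKE_CUSTOM_MEM_POOL value: '${1:-}'" >&2
      return 1 ;;
  esac
}

# Usage in a launch script:
#   check_mem_pool "${SGLANG_MOONCAKE_CUSTOM_MEM_POOL}" || exit 1
```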
#### Prefill Server Configuration
| Variable | Description | Default |
|:--------:|:-----------:|:--------:|
| **`SGLANG_DISAGGREGATION_THREAD_POOL_SIZE`** | Controls the total number of worker threads for KVCache transfer operations per TP rank | A dynamic value calculated as `int(0.75 * os.cpu_count()) // 8`, clamped to at least 4 and at most 12 to ensure efficiency and prevent thread race conditions |
| **`SGLANG_DISAGGREGATION_QUEUE_SIZE`** | Sets the number of parallel transfer queues. KVCache transfer requests from multiple decode instances are sharded into these queues so that they can share the threads and the transfer bandwidth at the same time. If it is set to `1`, requests are transferred one by one in first-come, first-served (FCFS) order | `4` |
| **`SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT`** | Timeout (seconds) for receiving destination KV indices during request initialization | `300` |
| **`SGLANG_DISAGGREGATION_BOOTSTRAP_ENTRY_CLEANUP_INTERVAL`** | Interval (seconds) between cleanups of bootstrap entries | `120` |
If a greater mean TTFT is acceptable, you can `export SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600` (10 minutes) to relax the timeout condition.
Please be aware that this setting will cause prefill instances to take a longer time to clean up the affected memory resources when a running decode node loses connection.
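To see what the default for `SGLANG_DISAGGREGATION_THREAD_POOL_SIZE` works out to on a given machine, the formula from the table can be reproduced in shell arithmetic. This is a sketch under the assumption that the bounds 4 and 12 are inclusive; the function name `default_transfer_threads` is hypothetical.

```bash
# Reproduce the documented default: int(0.75 * cpu_count) // 8, kept within 4..12.
# Assumption: both bounds are inclusive.
default_transfer_threads() {
  local cpus="$1"
  # Integer equivalent of int(0.75 * cpus) // 8 for non-negative cpus.
  local n=$(( cpus * 3 / 4 / 8 ))
  if (( n < 4 )); then n=4; fi
  if (( n > 12 )); then n=12; fi
  echo "$n"
}

default_transfer_threads 8     # small host: the lower bound applies, prints 4
default_transfer_threads 64    # 48 // 8, prints 6
default_transfer_threads 256   # 192 // 8 = 24, capped, prints 12
```

On a live host, `default_transfer_threads "$(nproc)"` gives the value for the current machine; setting the environment variable explicitly overrides it.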
#### Decode Server Configuration
| Variable | Description | Default |
|:--------:|:-----------:|:--------:|
| **`SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL`** | Interval (seconds) between health checks to prefill bootstrap servers | `5.0` |
| **`SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE`** | Consecutive heartbeat failures before marking prefill server offline | `2` |
| **`SGLANG_DISAGGREGATION_WAITING_TIMEOUT`** | Timeout (seconds) for receiving KV Cache after request initialization | `300` |
If a greater mean TTFT is acceptable, you can `export SGLANG_DISAGGREGATION_WAITING_TIMEOUT=600` (10 minutes) to relax the timeout condition.
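The two heartbeat variables together determine roughly how long a decode server takes to notice that a prefill server has gone away. As a back-of-the-envelope estimate only, assuming detection takes about one interval per consecutive failure and ignoring per-request network timeouts, the window can be computed as follows (the function name is hypothetical):

```bash
# Rough estimate of the time before a prefill server is marked offline.
# Assumption: detection time is about interval * consecutive failures;
# the real value also depends on how long each failed health check takes.
offline_detection_seconds() {
  local interval="${1:-5.0}" max_failure="${2:-2}"
  awk -v i="$interval" -v n="$max_failure" 'BEGIN { printf "%.1f\n", i * n }'
}

offline_detection_seconds          # defaults from the table: prints 10.0
offline_detection_seconds 2.0 3    # tighter interval, more retries: prints 6.0
```

Lowering `SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL` shortens this window at the cost of more frequent health checks against the prefill bootstrap servers.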
## NIXL
### Requirements
Install via pip.
```bash
pip install nixl
```
Alternatively, build from source, which may be required if you already have UCX installed.
```bash
git clone https://github.com/ai-dynamo/nixl.git
cd nixl
pip install . --config-settings=setup-args="-Ducx_path=/path/to/ucx"
```
### Usage
### Llama Single Node
```bash
# Prefill server (each command runs in the foreground; use a separate terminal for each)
python -m sglang.launch_server \
  --model-path meta-llama/Llama-3.1-8B-Instruct \
  --disaggregation-mode prefill \
  --port 30000 \
  --disaggregation-transfer-backend nixl
# Decode server, starting from GPU index 1
python -m sglang.launch_server \
  --model-path meta-llama/Llama-3.1-8B-Instruct \
  --disaggregation-mode decode \
  --port 30001 \
  --base-gpu-id 1 \
  --disaggregation-transfer-backend nixl
# Router in front of the prefill and decode servers
python -m sglang_router.launch_router --pd-disaggregation --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000
```
### DeepSeek Multi-Node
```bash
# prefill 0
python -m sglang.launch_server \
--model-path deepseek-ai/DeepSeek-V3-0324 \
--disaggregation-transfer-backend nixl \
--disaggregation-mode prefill \
--host ${local_ip} \
--port 30000 \
--trust-remote-code \
--dist-init-addr ${prefill_master_ip}:5000 \
--nnodes 2 \
--node-rank 0 \
--tp-size 16 \
--dp-size 8 \
--enable-dp-attention \
--moe-a2a-backend deepep \
--mem-fraction-static 0.8
# prefill 1
python -m sglang.launch_server \
--model-path deepseek-ai/DeepSeek-V3-0324 \
--disaggregation-transfer-backend nixl \
--disaggregation-mode prefill \
--host ${local_ip} \
--port 30000 \
--trust-remote-code \
--dist-init-addr ${prefill_master_ip}:5000 \
--nnodes 2 \
--node-rank 1 \
--tp-size 16 \
--dp-size 8 \
--enable-dp-attention \
--moe-a2a-backend deepep \
--mem-fraction-static 0.8
# decode 0
python -m sglang.launch_server \
--model-path deepseek-ai/DeepSeek-V3-0324 \
--disaggregation-transfer-backend nixl \
--disaggregation-mode decode \
--host ${local_ip} \
--port 30001 \
--trust-remote-code \
--dist-init-addr ${decode_master_ip}:5000 \
--nnodes 2 \
--node-rank 0 \
--tp-size 16 \
--dp-size 8 \
--enable-dp-attention \
--moe-a2a-backend deepep \
--mem-fraction-static 0.8 \
--max-running-requests 128
# decode 1
python -m sglang.launch_server \
--model-path deepseek-ai/DeepSeek-V3-0324 \
--disaggregation-transfer-backend nixl \
--disaggregation-mode decode \
--host ${local_ip} \
--port 30001 \
--trust-remote-code \
--dist-init-addr ${decode_master_ip}:5000 \
--nnodes 2 \
--node-rank 1 \
--tp-size 16 \
--dp-size 8 \
--enable-dp-attention \
--moe-a2a-backend deepep \
--mem-fraction-static 0.8 \
--max-running-requests 128
```
### Advanced Configuration
#### NIXL Backend Selection
By default, NIXL uses the **UCX** backend for KV cache transfers. You can select a different NIXL plugin backend depending on your infrastructure using the environment variable `SGLANG_DISAGGREGATION_NIXL_BACKEND`.
Example: `export SGLANG_DISAGGREGATION_NIXL_BACKEND=LIBFABRIC`
**Available backends:** UCX (default), LIBFABRIC, or any installed NIXL plugin.
Example usage:
```bash
export SGLANG_DISAGGREGATION_NIXL_BACKEND=LIBFABRIC
python -m sglang.launch_server \
--model-path meta-llama/Llama-3.1-8B-Instruct \
--disaggregation-mode prefill \
--disaggregation-transfer-backend nixl \
--port 30000
```
## ASCEND
### Usage
To use the Ascend backend, install [memfabric_hybrid](https://gitcode.com/Ascend/memfabric_hybrid) and set `ASCEND_MF_STORE_URL`:
```bash
pip install memfabric-hybrid==1.0.0
export ASCEND_MF_STORE_URL="tcp://xxx.xx.xxx.xxx:xxxx"
```
Alternatively, use the Mooncake backend; see the Mooncake section for more details.
```bash
export ENABLE_ASCEND_TRANSFER_WITH_MOONCAKE=true
```
`ASCEND_NPU_PHY_ID` needs to be set in the container environment:
```bash
export ASCEND_NPU_PHY_ID=xxx
```
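Because the store URL is entered by hand, a launch script may want to sanity-check its shape before starting the server. The sketch below assumes the `tcp://<host>:<port>` form suggested by the placeholder above; the helper name `is_valid_store_url` and the example address are hypothetical, and memfabric_hybrid may accept other forms.

```bash
# Hypothetical shape check for ASCEND_MF_STORE_URL.
# Assumption: the value has the form tcp://<host>:<port>, as the placeholder suggests.
is_valid_store_url() {
  [[ "$1" =~ ^tcp://[^:/]+:[0-9]+$ ]]
}

# Usage in a launch script:
#   is_valid_store_url "${ASCEND_MF_STORE_URL}" || { echo "bad ASCEND_MF_STORE_URL" >&2; exit 1; }
```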
### Llama Single Node
```bash
# Prefill server (each command runs in the foreground; use a separate terminal for each)
python -m sglang.launch_server \
  --model-path meta-llama/Llama-3.1-8B-Instruct \
  --disaggregation-mode prefill \
  --port 30000 \
  --disaggregation-transfer-backend ascend
# Decode server, starting from device index 1
python -m sglang.launch_server \
  --model-path meta-llama/Llama-3.1-8B-Instruct \
  --disaggregation-mode decode \
  --port 30001 \
  --base-gpu-id 1 \
  --disaggregation-transfer-backend ascend
# Router in front of the prefill and decode servers
python -m sglang_router.launch_router --pd-disaggregation --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000
```
### DeepSeek Multi-Node
```bash
# prefill 0
python -m sglang.launch_server \
--model-path deepseek-ai/DeepSeek-V3-0324 \
--disaggregation-transfer-backend ascend \
--disaggregation-mode prefill \
--host ${local_ip} \
--port 30000 \
--trust-remote-code \
--dist-init-addr ${prefill_master_ip}:5000 \
--nnodes 1 \
--node-rank 0 \
--tp-size 16
# decode 0
python -m sglang.launch_server \
--model-path deepseek-ai/DeepSeek-V3-0324 \
--disaggregation-transfer-backend ascend \
--disaggregation-mode decode \
--host ${local_ip} \
--port 30001 \
--trust-remote-code \
--dist-init-addr ${decode_master_ip}:5000 \
--nnodes 1 \
--node-rank 0 \
--tp-size 16
```