# TPU
SGLang supports high-performance TPU inference through the SGLang-JAX backend, which is optimized specifically for Google Cloud TPUs. The JAX-based implementation targets high-throughput, low-latency Large Language Model (LLM) serving workloads on TPU hardware.
For TPU-specific issues or feature requests, please visit the [sglang-jax GitHub issues page](https://github.com/sgl-project/sglang-jax/issues).
**NOTE:** SGLang TPU support is implemented via the SGLang-JAX backend, a dedicated JAX-based inference engine maintained as a separate repository at [https://github.com/sgl-project/sglang-jax](https://github.com/sgl-project/sglang-jax).
## System Requirements
### Supported TPU Hardware
| TPU Type | HBM Memory | Availability |
|----------|-----------|--------------|
| TPU v6e | 32 GB per chip | Google Cloud |
| TPU v7 | 96 GB per core | Google Cloud |
### Software Requirements
- **Python:** 3.12 or higher
- **JAX:** Latest version with TPU support
- **Environment:** Google Cloud TPU VM or compatible TPU runtime
- **Optional:** SkyPilot for simplified cloud deployment
## Feature Support Matrix
SGLang-JAX provides comprehensive TPU-optimized features for production LLM serving:
| Feature | Support Status | Description |
|---------|---------------|-------------|
| High-Throughput Continuous Batching | ✅ | Dynamic request batching for maximum TPU utilization |
| Radix Tree KV Cache | ✅ | Memory-efficient prefix sharing between requests |
| FlashAttention Backend | ✅ | TPU-optimized attention kernel for long sequences |
| Tensor Parallelism | ✅ | Distributes models across multiple TPU cores |
| Paged Attention | ✅ | Flexible KV cache management with paging |
| Speculative Decoding (EAGLE/EAGLE3) | ✅ | 20-40% throughput improvement for compatible models |
| Chunked Prefill | ✅ | Mixed prefill-decode batching |
| OpenAI-Compatible API | ✅ | Drop-in replacement for the OpenAI API |
| Data Parallel Attention | 🚧 | In development: attention computation with data parallelism |
| Quantization | 🚧 | In development: model quantization for reduced memory usage |
| Multi-LoRA | 🚧 | In development: serving multiple LoRA adapters simultaneously |
### Attention Backend Comparison
| Backend | Paged Attention | Spec Decoding | MLA | Sliding Window |
|---------|----------------|---------------|-----|----------------|
| FlashAttention (fa) | ✅ | ✅ | ❌ | ✅ |
| Native | ❌ | ❌ | ❌ | ❌ |
**NOTE:** FlashAttention backend is recommended for production workloads due to superior memory efficiency and performance.
## Optimized Model List
The following models have been tested and optimized for TPU deployment:
| Model Family | Performance Status |
|--------------|-------------------|
| [Qwen 3](https://huggingface.co/Qwen) | Recommended for production |
| [Qwen 3 MoE](https://huggingface.co/Qwen) | Best performance |
| [Qwen 2](https://huggingface.co/Qwen) | Needs improvement |
| [Qwen 2 MoE](https://huggingface.co/Qwen) | Needs improvement |
| [Qwen 1.5](https://huggingface.co/Qwen) | Needs improvement |
| [Llama/LLaMA](https://huggingface.co/meta-llama) | Needs improvement |
| [Grok-2](https://huggingface.co/xai-org) | Needs improvement |
| [Gemma 2](https://huggingface.co/google) | Verified on TPU |
| Bailing MoE | Needs improvement |
## Installation
### Method 1: Using PyPI (Recommended)
```bash
pip install sglang-jax
```
### Method 2: From Source
```bash
git clone https://github.com/sgl-project/sglang-jax
cd sglang-jax
uv venv --python 3.12 && source .venv/bin/activate
uv pip install -e "python[all]"
```
### Method 3: Using Docker
**NOTE:** Docker support for TPU is currently under development. Please use PyPI or source installation methods.
### Method 4: Cloud TPU with SkyPilot
[SkyPilot](https://github.com/skypilot-org/skypilot) provides simplified deployment on Google Cloud TPU:
1. Install SkyPilot and configure GCP access (see [SkyPilot documentation](https://skypilot.readthedocs.io/))
2. Create a SkyPilot configuration file:
<details>
<summary>SkyPilot YAML: <code>sglang-jax.sky.yaml</code></summary>
```yaml
# sglang-jax.sky.yaml
resources:
  accelerators: tpu-v6e-4
  accelerator_args:
    tpu_vm: True
    runtime_version: v2-alpha-tpuv6e
run: |
  git clone https://github.com/sgl-project/sglang-jax.git
  cd sglang-jax
  uv venv --python 3.12
  source .venv/bin/activate
  uv pip install -e "python[all]"
```
</details>
3. Launch your TPU cluster:
```bash
# Standard deployment
sky launch -c sglang-jax sglang-jax.sky.yaml --infra=gcp
# With spot instances for cost savings
sky launch -c sglang-jax sglang-jax.sky.yaml --infra=gcp --use-spot
```
## Launching the Serving Engine
### Basic Example: Qwen-7B
```bash
JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache python3 -u -m sgl_jax.launch_server \
--model-path Qwen/Qwen-7B-Chat \
--trust-remote-code \
--dist-init-addr=0.0.0.0:10011 \
--nnodes=1 \
--tp-size=4 \
--device=tpu \
--random-seed=3 \
--node-rank=0 \
--mem-fraction-static=0.8 \
--max-prefill-tokens=8192 \
--download-dir=/tmp \
--dtype=bfloat16 \
--skip-server-warmup \
--host 0.0.0.0 \
--port 30000
```
**Key Parameters Explained:**
1. `JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache` - Enables JIT compilation caching to accelerate server startup on subsequent runs
2. `--tp-size=4` - Tensor parallelism size; match this to your TPU core count (typically 1, 4, or 8)
3. `--device=tpu` - Specifies TPU device (this is the default for sglang-jax)
4. `--dtype=bfloat16` - Uses bfloat16 precision, which TPUs are optimized for
5. `--mem-fraction-static=0.8` - Allocates 80% of TPU HBM for static memory (adjustable from 0.2 to 0.9)
6. `--max-prefill-tokens=8192` - Maximum number of tokens processed in the prefill phase
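As a rough worked example of `--mem-fraction-static`, the sketch below computes the static memory budget per device from the HBM sizes in the hardware table. It assumes the fraction is applied to each device's HBM independently and ignores runtime overheads; the helper name `static_budget_gb` is hypothetical.

```bash
#!/usr/bin/env bash
# Hypothetical helper: estimate the HBM reserved by --mem-fraction-static.
# Assumption: the fraction is applied to each device's HBM on its own.
static_budget_gb() {
  local hbm_gb=$1 fraction=$2
  # awk handles the floating-point multiply that bash arithmetic cannot.
  awk -v h="$hbm_gb" -v f="$fraction" 'BEGIN { printf "%.1f\n", h * f }'
}

# TPU v6e (32 GB HBM per chip) with the example's --mem-fraction-static=0.8
static_budget_gb 32 0.8   # prints 25.6
# A more conservative 0.5 leaves 16.0 GB for static allocations
static_budget_gb 32 0.5
```

Lowering the fraction shrinks this budget linearly, which is why it is the first knob to turn when troubleshooting OOM errors.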
### High-Performance Configuration: Qwen3-8B
For production workloads with optimal throughput:
```bash
python3 -u -m sgl_jax.launch_server \
--model-path Qwen/Qwen3-8B \
--trust-remote-code \
--tp-size=4 \
--device=tpu \
--mem-fraction-static=0.8 \
--chunked-prefill-size=2048 \
--dtype=bfloat16 \
--max-running-requests=256 \
--page-size=128 \
--attention-backend=fa
```
### Advanced: Speculative Decoding (EAGLE3)
Speculative decoding can improve throughput by 20-40% for compatible models:
```bash
python3 -u -m sgl_jax.launch_server \
--model-path Qwen/Qwen3-32B \
--trust-remote-code \
--device=tpu \
--tp-size=4 \
--mem-fraction-static=0.8 \
--max-prefill-tokens=4096 \
--attention-backend=fa \
--dtype=bfloat16 \
--port=30000 \
--host=0.0.0.0 \
--disable-overlap-schedule \
--speculative-algorithm=EAGLE3 \
--speculative-draft-model-path=AngelSlim/Qwen3-32B_eagle3 \
--page-size=64 \
--speculative-eagle-topk=1 \
--speculative-num-steps=3 \
--speculative-num-draft-tokens=4
```
**NOTE:** Speculative decoding is currently supported for Qwen3 and LLaMA model families. See the [Speculative Decoding documentation](https://github.com/sgl-project/sglang-jax/blob/main/docs/features/speculative_decoding.md) for detailed configuration guidance.
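In the EAGLE3 example above, `--speculative-eagle-topk=1` is paired with `--speculative-num-draft-tokens` equal to `--speculative-num-steps` plus one (3 steps, 4 draft tokens). Upstream SGLang adjusts the draft-token count to this value when top-k is 1; the sketch below assumes sglang-jax follows the same rule and checks a flag combination before launch. The function `check_eagle_args` is hypothetical, and the Speculative Decoding documentation remains the authoritative source for valid combinations.

```bash
#!/usr/bin/env bash
# Hypothetical pre-launch check for EAGLE/EAGLE3 flags. Assumes sglang-jax
# mirrors upstream SGLang: with top-k 1, draft tokens = steps + 1.
check_eagle_args() {
  local topk=$1 steps=$2 draft_tokens=$3
  if (( topk == 1 && draft_tokens != steps + 1 )); then
    echo "expected --speculative-num-draft-tokens=$((steps + 1)) for topk=1" >&2
    return 1
  fi
  echo "ok"
}

# Values from the EAGLE3 example: topk=1, steps=3, draft tokens=4
check_eagle_args 1 3 4
```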
### Multi-Node Distributed Serving
For large models requiring multiple TPU VMs:
```bash
# Node 0 (coordinator). Replace MODEL_PATH and NODE0_IP, and add the
# remaining flags from the single-node examples above as needed.
python3 -m sgl_jax.launch_server \
--model-path MODEL_PATH \
--dist-init-addr=NODE0_IP:10011 \
--nnodes=2 \
--node-rank=0 \
--tp-size=8

# Node 1 (worker): identical except for --node-rank
python3 -m sgl_jax.launch_server \
--model-path MODEL_PATH \
--dist-init-addr=NODE0_IP:10011 \
--nnodes=2 \
--node-rank=1 \
--tp-size=8
```
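Since the two commands differ only in `--node-rank`, the shared flags can be kept in one place. The sketch below builds the argument list for a given rank and prints the resulting command so it can be inspected before running it. `MODEL_PATH`, `NODE0_IP`, their example default values, and the `launch_args` function are placeholders to adapt; only flags that appear in this document are used.

```bash
#!/usr/bin/env bash
# Placeholders with example values: set these for your deployment.
MODEL_PATH=${MODEL_PATH:-Qwen/Qwen3-32B}
NODE0_IP=${NODE0_IP:-10.0.0.2}
NNODES=2

# Print the sgl_jax.launch_server arguments for one node, one per line.
launch_args() {
  local rank=$1
  printf '%s\n' \
    --model-path "$MODEL_PATH" \
    "--dist-init-addr=${NODE0_IP}:10011" \
    "--nnodes=${NNODES}" \
    "--node-rank=${rank}" \
    --tp-size=8 \
    --device=tpu \
    --dtype=bfloat16
}

# Dry run: show the command for node 1. To launch it for real on that VM:
#   mapfile -t args < <(launch_args 1); python3 -m sgl_jax.launch_server "${args[@]}"
echo python3 -m sgl_jax.launch_server $(launch_args 1)
```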
## Benchmarking with Requests
### Throughput Testing
Basic throughput benchmark:
```bash
python3 -m sgl_jax.bench_serving \
--backend sgl-jax \
--dataset-name random \
--num-prompts=100 \
--random-input=512 \
--random-output=128 \
--max-concurrency=8 \
--random-range-ratio=1 \
--warmup-requests=0
```
### Latency Testing
Measure single-batch latency:
```bash
python3 -m sgl_jax.bench_one_batch_server \
--base-url http://127.0.0.1:30000 \
--model-path Qwen/Qwen-7B-Chat \
--batch-size=32 \
--input-len=256 \
--output-len=32
```
### Comprehensive Benchmark Script
For systematic performance evaluation across different configurations:
```bash
#!/bin/bash
set -e
backend=${1:-sgl-jax}
num_prompts_per_concurrency=3
input_seq_lens=(1024 4096 8192)
output_seq_lens=(1 1024)
max_concurrencies=(8 16 32 64 128 256)
for input_seq_len in "${input_seq_lens[@]}"; do
  for output_seq_len in "${output_seq_lens[@]}"; do
    echo "======================================="
    echo "Testing ISL/OSL: $input_seq_len/$output_seq_len"
    echo "======================================="
    for max_concurrency in "${max_concurrencies[@]}"; do
      num_prompts=$((num_prompts_per_concurrency * max_concurrency))
      python3 -m sgl_jax.bench_serving \
        --backend "${backend}" \
        --dataset-name random \
        --num-prompts "${num_prompts}" \
        --random-input "${input_seq_len}" \
        --random-output "${output_seq_len}" \
        --max-concurrency "${max_concurrency}" \
        --random-range-ratio 1 \
        --disable-ignore-eos \
        --warmup-requests 0
    done
  done
done
```
For detailed help on all benchmark parameters:
```bash
python3 -m sgl_jax.bench_serving --help
```
See the [Benchmark and Profiling Guide](https://github.com/sgl-project/sglang-jax/blob/main/docs/developer_guide/benchmark_and_profiling.md) for advanced benchmarking techniques and profiling with JAX Profiler.
## Performance Optimization
### Memory Optimization
**Reduce memory usage:**
- Lower `--mem-fraction-static` (from 0.8 → 0.5 → 0.3)
- Decrease `--max-prefill-tokens` (from 16384 → 8192 → 4096)
- Reduce `--max-running-requests`
**Handle OOM errors:**
- Start with conservative memory settings (`--mem-fraction-static=0.5`)
- Gradually increase until you find the optimal balance
- Increase `--page-size` for better memory locality (1 → 16 → 64 → 128)
### Throughput Optimization
To maximize tokens per second:
- Use FlashAttention backend: `--attention-backend=fa`
- Enable speculative decoding (EAGLE3) for Qwen3 models (20-40% improvement)
- Increase `--max-running-requests` to 256+
- Set `--mem-fraction-static` to 0.8+ (if memory allows)
- Use larger page sizes (64-128)
- Enable chunked prefill: `--chunked-prefill-size=2048`
### Latency Optimization
To minimize time-to-first-token (TTFT) and inter-token latency:
- Reduce `--page-size` to 1-4
- Lower `--max-running-requests` (16-32) for smaller batches
- Reduce `--chunked-prefill-size`
- Use conservative memory settings to avoid GC pauses
### TPU-Specific Optimizations
1. **JIT Compilation Cache:**
```bash
export JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache
```
Always set this environment variable to cache compiled kernels and accelerate server startup.
2. **Data Type Optimization:**
Use `--dtype=bfloat16` for TPU native optimization. TPUs are specifically designed for bfloat16 computations.
3. **Tensor Parallelism:**
Match `--tp-size` to your TPU core configuration (1, 4, or 8) for optimal model distribution.
4. **Attention Backend:**
Always use `--attention-backend=fa` (FlashAttention) for production workloads.
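For item 3, the core count can be read off the slice name used in the SkyPilot configuration. This sketch assumes a TPU v6e slice name such as `tpu-v6e-4`, where the trailing number is the chip count and each v6e chip has a single TensorCore; other TPU generations may count differently, so treat the hypothetical helper `tp_size_for` as a convenience rather than a rule.

```bash
#!/usr/bin/env bash
# Hypothetical helper: derive --tp-size from a SkyPilot TPU accelerator name.
# Assumes v6e naming, where the suffix equals the number of single-core chips.
tp_size_for() {
  local accel=$1
  local count=${accel##*-}   # text after the last '-', e.g. "4"
  if [[ ! $count =~ ^[0-9]+$ ]]; then
    echo "cannot parse core count from '$accel'" >&2
    return 1
  fi
  echo "$count"
}

tp_size_for tpu-v6e-4   # prints 4, matching --tp-size=4 in the examples
```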
## Troubleshooting
### OOM (Out of Memory) Errors
If you encounter out-of-memory errors:
1. Reduce `--mem-fraction-static` from 0.8 to 0.5 or lower
2. Decrease `--max-prefill-tokens` from 8192 to 4096 or 2048
3. Lower `--max-running-requests` to reduce concurrent batch size
4. Increase `--page-size` for better memory layout efficiency
### Long Compilation Time
If the server takes too long to start:
1. Ensure `JAX_COMPILATION_CACHE_DIR` is properly set
2. Expect the first run to be slow: it performs JIT compilation, which is normal
3. Subsequent runs will be significantly faster with cached compilations
4. Consider using `--skip-server-warmup` to defer compilation until first request
### Low Throughput
If you're not achieving expected throughput:
1. Verify `--tp-size` matches your TPU core configuration
2. Check that `--attention-backend=fa` is enabled
3. Increase `--max-running-requests` to enable larger batch formation
4. Consider enabling speculative decoding for compatible models
5. Ensure memory settings allow for sufficient batch sizes
### Connection Issues
If clients cannot connect to the server:
1. Set `--host=0.0.0.0` to accept external connections; a server bound to `127.0.0.1` only accepts connections from the same machine
2. Verify firewall rules allow traffic on the specified port (default: 30000)
3. Check that the server process is running: `curl http://localhost:30000/health`
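When the server is started in the background (for example from a deployment script), first-run JIT compilation can delay readiness for a while. A minimal polling loop around the `/health` check above waits until the endpoint answers or a timeout expires; the function name `wait_for_server`, the default 600-second timeout, and the one-second interval are illustrative choices, not documented defaults.

```bash
#!/usr/bin/env bash
# Poll the /health endpoint until it returns HTTP 2xx or the timeout expires.
# Usage: wait_for_server [base_url] [timeout_seconds]
wait_for_server() {
  local url=${1:-http://localhost:30000} timeout=${2:-600}
  local deadline=$((SECONDS + timeout))
  while (( SECONDS < deadline )); do
    # -f makes curl fail on HTTP errors; -s silences progress output.
    if curl -sf --max-time 5 "$url/health" > /dev/null; then
      echo "server ready at $url"
      return 0
    fi
    sleep 1
  done
  echo "timed out after ${timeout}s waiting for $url" >&2
  return 1
}
```

A non-zero return lets a calling script stop before sending benchmark traffic to a server that never came up.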
## Advanced Features
### Speculative Decoding
SGLang-JAX supports EAGLE and EAGLE3 speculative decoding algorithms for Qwen3 and LLaMA model families. Speculative decoding can improve throughput by 20-40% without affecting output quality.
See the [Speculative Decoding documentation](https://github.com/sgl-project/sglang-jax/blob/main/docs/features/speculative_decoding.md) for detailed configuration and supported model combinations.
### Chunked Prefill
Enable mixed prefill-decode batching for better TPU utilization:
```bash
--chunked-prefill-size=2048 --enable-mixed-chunk
```
This allows the scheduler to mix prefill operations with decode operations in the same batch, improving overall throughput.
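As a worked example of the chunk size, a prompt split into fixed-size chunks needs ceil(prompt_len / chunk_size) prefill passes, so with `--chunked-prefill-size=2048` an 8192-token prompt is prefilled in four chunks. The hypothetical helper below only does that arithmetic; it assumes simple fixed-size chunking and ignores other requests sharing the batch.

```bash
#!/usr/bin/env bash
# Hypothetical helper: number of prefill chunks for one prompt, assuming
# fixed-size chunking by --chunked-prefill-size (ceiling division).
num_chunks() {
  local prompt_len=$1 chunk_size=$2
  echo $(( (prompt_len + chunk_size - 1) / chunk_size ))
}

num_chunks 8192 2048   # prints 4
num_chunks 5000 2048   # prints 3 (2048 + 2048 + 904 tokens)
```

Smaller chunks mean more passes per long prompt but leave room for decode tokens in each batch, which is the latency/throughput trade-off noted in the optimization sections.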
### Custom Attention Backends
SGLang-JAX supports a plugin-based attention backend system. You can implement custom attention kernels optimized for specific use cases.
See the [Attention Backend documentation](https://github.com/sgl-project/sglang-jax/blob/main/docs/features/attention_backend.md) for implementation details.
### Environment Verification
Verify your TPU setup before deploying:
```bash
python -c "from sgl_jax import check_env; check_env.check_env()"
```
This command checks:
- Installed package versions
- TPU device availability and specifications
- System resources and configuration
- Compatibility of settings
## Contributing
We welcome contributions to improve TPU support in SGLang-JAX!
### Areas for Contribution
**Check the [Development Roadmap](https://github.com/sgl-project/sglang-jax/issues/190)** to see planned features and find opportunities to contribute new functionality.
Current contribution areas include:
- Performance optimizations for specific TPU generations
- Support for additional model architectures
- Documentation improvements and examples
- Bug reports and fixes
- Benchmark results and performance analysis
### How to Contribute
1. Visit the [sglang-jax repository](https://github.com/sgl-project/sglang-jax)
2. Read the [Contribution Guide](https://github.com/sgl-project/sglang-jax/blob/main/docs/developer_guide/contribution_guide.md)
3. Join the [SGL-JAX Slack community](https://sgl-fru7574.slack.com/archives/C09EBE5HT5X) for discussions
4. Report issues at [sglang-jax/issues](https://github.com/sgl-project/sglang-jax/issues)
### Testing on TPU
For contributors who need TPU access for testing:
- Refer to the [TPU Resources Guide](https://github.com/sgl-project/sglang-jax/blob/main/docs/developer_guide/tpu_resources_guide.md) for information on accessing TPU hardware
- Use SkyPilot with spot instances for cost-effective testing
- Follow the [Benchmark and Profiling Guide](https://github.com/sgl-project/sglang-jax/blob/main/docs/developer_guide/benchmark_and_profiling.md) for performance validation
## References
### Documentation
- [SGLang-JAX Repository](https://github.com/sgl-project/sglang-jax)
- [SGLang-JAX Installation Guide](https://github.com/sgl-project/sglang-jax/blob/main/docs/get_started/install.md)
- [Qwen Models Quick Start](https://github.com/sgl-project/sglang-jax/blob/main/docs/basic_usage/qwen.md)
- [Benchmark and Profiling Guide](https://github.com/sgl-project/sglang-jax/blob/main/docs/developer_guide/benchmark_and_profiling.md)
- [Speculative Decoding](https://github.com/sgl-project/sglang-jax/blob/main/docs/features/speculative_decoding.md)
### External Resources
- [JAX Documentation](https://jax.readthedocs.io/)
- [Google Cloud TPU Documentation](https://cloud.google.com/tpu/docs)
- [SkyPilot Documentation](https://skypilot.readthedocs.io/)