# TPU
SGLang supports TPU inference through the SGLang-JAX backend, a JAX-based engine optimized for Google Cloud TPUs and designed for high-throughput, low-latency Large Language Model (LLM) serving.
For TPU-specific issues or feature requests, please visit the [sglang-jax GitHub issues page](https://github.com/sgl-project/sglang-jax/issues).
**NOTE:** SGLang TPU support is implemented via the SGLang-JAX backend, a dedicated JAX-based inference engine maintained as a separate repository at [https://github.com/sgl-project/sglang-jax](https://github.com/sgl-project/sglang-jax).
## System Requirements
### Supported TPU Hardware
| TPU Type | HBM Capacity | Availability |
|----------|-----------|--------------|
| TPU v6e | 32 GB | Google Cloud |
| TPU v7 | 96 GB per core | Google Cloud |
### Software Requirements
- **Python:** 3.12 or higher
- **JAX:** Latest version with TPU support
- **Environment:** Google Cloud TPU VM or compatible TPU runtime
- **Optional:** SkyPilot for simplified cloud deployment
## Feature Support Matrix
SGLang-JAX provides comprehensive TPU-optimized features for production LLM serving:
| Feature | Support Status | Description |
|---------|---------------|-------------|
| High-Throughput Continuous Batching | ✅ | Dynamic request batching for maximum TPU utilization |
| Radix Tree KV Cache | ✅ | Memory-efficient prefix sharing between requests |
| FlashAttention Backend | ✅ | TPU-optimized attention kernel for long sequences |
| Tensor Parallelism | ✅ | Distribute models across multiple TPU cores |
| Paged Attention | ✅ | Flexible KV cache management with paging |
| Speculative Decoding (EAGLE/EAGLE3) | ✅ | 20-40% throughput improvement for compatible models |
| Chunked Prefill | ✅ | Mixed prefill-decode batching |
| OpenAI-Compatible API | ✅ | Drop-in replacement for OpenAI API |
| Data Parallel Attention | 🚧 | In development - Attention computation with data parallelism |
| Quantization | 🚧 | In development - Model quantization for reduced memory usage |
| Multi-LoRA | 🚧 | In development - Serve multiple LoRA adapters simultaneously |
### Attention Backend Comparison
| Backend | Paged Attention | Spec Decoding | MLA | Sliding Window |
|---------|----------------|---------------|-----|----------------|
| FlashAttention (fa) | ✅ | ✅ | ❌ | ✅ |
| Native | ❌ | ❌ | ❌ | ❌ |
**NOTE:** The FlashAttention backend is recommended for production workloads because it is more memory-efficient and faster than the native backend.
## Optimized Model List
The following models have been tested and optimized for TPU deployment:
| Model Family | Performance Status |
|--------------|-------------------|
| [Qwen 3](https://huggingface.co/Qwen) | ⭐ Recommended for production |
| [Qwen 3 MoE](https://huggingface.co/Qwen) | ⭐ Best performance |
| [Qwen 2](https://huggingface.co/Qwen) | Needs improvement |
| [Qwen 2 MoE](https://huggingface.co/Qwen) | Needs improvement |
| [Qwen 1.5](https://huggingface.co/Qwen) | Needs improvement |
| [Llama/LLaMA](https://huggingface.co/meta-llama) | Needs improvement |
| [Grok-2](https://huggingface.co/xai-org) | Needs improvement |
| [Gemma 2](https://huggingface.co/google) | Verified on TPU |
| Bailing MoE | Needs improvement |
## Installation
### Method 1: Using PyPI (Recommended)
```bash
pip install sglang-jax
```
### Method 2: From Source
```bash
git clone https://github.com/sgl-project/sglang-jax
cd sglang-jax
uv venv --python 3.12 && source .venv/bin/activate
uv pip install -e "python[all]"
```
### Method 3: Using Docker
**NOTE:** Docker support for TPU is currently under development. Please use PyPI or source installation methods.
### Method 4: Cloud TPU with SkyPilot
[SkyPilot](https://github.com/skypilot-org/skypilot) provides simplified deployment on Google Cloud TPU:
1. Install SkyPilot and configure GCP access (see [SkyPilot documentation](https://skypilot.readthedocs.io/))
2. Create a SkyPilot configuration file:
<details>
<summary>SkyPilot YAML: <code>sglang-jax.sky.yaml</code></summary>
```yaml
# sglang-jax.sky.yaml
resources:
  accelerators: tpu-v6e-4
  accelerator_args:
    tpu_vm: True
    runtime_version: v2-alpha-tpuv6e
run: |
  git clone https://github.com/sgl-project/sglang-jax.git
  cd sglang-jax
  uv venv --python 3.12
  source .venv/bin/activate
  uv pip install -e "python[all]"
```
</details>
3. Launch your TPU cluster:
```bash
# Standard deployment
sky launch -c sglang-jax sglang-jax.sky.yaml --infra=gcp
# With spot instances for cost savings
sky launch -c sglang-jax sglang-jax.sky.yaml --infra=gcp --use-spot
```
## Launching the Serving Engine
### Basic Example: Qwen-7B
```bash
JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache python3 -u -m sgl_jax.launch_server \
--model-path Qwen/Qwen-7B-Chat \
--trust-remote-code \
--dist-init-addr=0.0.0.0:10011 \
--nnodes=1 \
--tp-size=4 \
--device=tpu \
--random-seed=3 \
--node-rank=0 \
--mem-fraction-static=0.8 \
--max-prefill-tokens=8192 \
--download-dir=/tmp \
--dtype=bfloat16 \
--skip-server-warmup \
--host 0.0.0.0 \
--port 30000
```
**Key Parameters Explained:**
1. `JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache` - Enables JIT compilation caching to accelerate server startup on subsequent runs
2. `--tp-size=4` - Tensor parallelism size; match this to your TPU core count (typically 1, 4, or 8)
3. `--device=tpu` - Specifies TPU device (this is the default for sglang-jax)
4. `--dtype=bfloat16` - Uses bfloat16 precision, which TPUs are optimized for
5. `--mem-fraction-static=0.8` - Allocates 80% of TPU HBM for static memory (adjustable from 0.2 to 0.9)
6. `--max-prefill-tokens=8192` - Maximum number of tokens processed in the prefill phase
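Once a server launched this way is listening on port 30000, a client can reach it through the OpenAI-compatible API. The following is a minimal sketch using only the Python standard library. It assumes the server exposes the conventional OpenAI route `/v1/chat/completions` and accepts the model path as the `model` field; neither detail is stated on this page, so check the sglang-jax documentation before relying on it. `build_chat_request` and `chat` are illustrative helper names, not part of sglang-jax.

```python
import json
import urllib.request

BASE_URL = "http://127.0.0.1:30000"  # host and port from the launch command


def build_chat_request(prompt, model="Qwen/Qwen-7B-Chat", max_tokens=64,
                       base_url=BASE_URL):
    """Build (but do not send) an OpenAI-style chat completion request."""
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
    }
    return urllib.request.Request(
        # Assumed route: the standard OpenAI chat completions path.
        url=f"{base_url}/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def chat(prompt, **kwargs):
    """Send the request and return the text of the first choice."""
    request = build_chat_request(prompt, **kwargs)
    with urllib.request.urlopen(request, timeout=60) as resp:
        body = json.load(resp)
    return body["choices"][0]["message"]["content"]


# Usage (requires a running server):
#   print(chat("Name one TPU generation."))
```

Keeping request construction separate from sending makes it easy to inspect the payload before pointing the client at a live server.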
### High-Performance Configuration: Qwen3-8B
For production workloads with optimal throughput:
```bash
python3 -u -m sgl_jax.launch_server \
--model-path Qwen/Qwen3-8B \
--trust-remote-code \
--tp-size=4 \
--device=tpu \
--mem-fraction-static=0.8 \
--chunked-prefill-size=2048 \
--dtype=bfloat16 \
--max-running-requests=256 \
--page-size=128 \
--attention-backend=fa
```
### Advanced: Speculative Decoding (EAGLE3)
Speculative decoding can improve throughput by 20-40% for compatible models:
```bash
python3 -u -m sgl_jax.launch_server \
--model-path Qwen/Qwen3-32B \
--trust-remote-code \
--device=tpu \
--tp-size=4 \
--mem-fraction-static=0.8 \
--max-prefill-tokens=4096 \
--attention-backend=fa \
--dtype=bfloat16 \
--port=30000 \
--host=0.0.0.0 \
--disable-overlap-schedule \
--speculative-algorithm=EAGLE3 \
--speculative-draft-model-path=AngelSlim/Qwen3-32B_eagle3 \
--page-size=64 \
--speculative-eagle-topk=1 \
--speculative-num-steps=3 \
--speculative-num-draft-tokens=4
```
**NOTE:** Speculative decoding is currently supported for Qwen3 and LLaMA model families. See the [Speculative Decoding documentation](https://github.com/sgl-project/sglang-jax/blob/main/docs/features/speculative_decoding.md) for detailed configuration guidance.
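To build intuition for how the draft settings relate to speedup, the sketch below estimates how many tokens one verification pass emits for a single draft chain (`--speculative-eagle-topk=1`). It is a textbook simplification, not the sglang-jax implementation: it assumes each drafted token is accepted independently with a fixed probability, and that `--speculative-num-steps` corresponds to the chain length. The acceptance rates in the loop are made-up inputs, and the result ignores the cost of running the draft model, so it is an upper bound on the gain rather than a throughput prediction.

```python
def expected_tokens_per_step(accept_rate, num_draft_steps):
    """Expected tokens emitted per verification pass for one draft chain.

    Simplifying assumption: each drafted token is accepted independently
    with probability `accept_rate`, and verification stops at the first
    rejection. The target model always contributes one token itself.
    """
    if not 0.0 <= accept_rate <= 1.0:
        raise ValueError("accept_rate must be in [0, 1]")
    # 1 + a + a^2 + ... + a^k: geometric series over the draft chain.
    return sum(accept_rate ** i for i in range(num_draft_steps + 1))


# Hypothetical acceptance rates with a chain of 3 drafted tokens.
for rate in (0.5, 0.7, 0.9):
    tokens = expected_tokens_per_step(rate, 3)
    print(f"accept_rate={rate}: {tokens:.3f} tokens per verification pass")
```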
### Multi-Node Distributed Serving
For large models requiring multiple TPU VMs:
```bash
# Node 0 (coordinator)
python3 -m sgl_jax.launch_server \
--model-path MODEL_PATH \
--dist-init-addr=NODE0_IP:10011 \
--nnodes=2 \
--node-rank=0 \
--tp-size=8 \
[other parameters...]
# Node 1 (worker)
python3 -m sgl_jax.launch_server \
--model-path MODEL_PATH \
--dist-init-addr=NODE0_IP:10011 \
--nnodes=2 \
--node-rank=1 \
--tp-size=8 \
[other parameters...]
```
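The two commands above differ only in `--node-rank`, so they can be generated rather than copied by hand. The helper below is a small sketch of that idea; `node_commands` is a hypothetical name, and it uses only the flags shown on this page. `MODEL_PATH` and `NODE0_IP` remain placeholders to be filled in.

```python
import shlex


def node_commands(model_path, coordinator_addr, nnodes, tp_size, extra_args=()):
    """Return one launch command (as an argv list) per node rank."""
    commands = []
    for rank in range(nnodes):
        argv = [
            "python3", "-m", "sgl_jax.launch_server",
            "--model-path", model_path,
            f"--dist-init-addr={coordinator_addr}",
            f"--nnodes={nnodes}",
            f"--node-rank={rank}",  # the only flag that differs per node
            f"--tp-size={tp_size}",
            *extra_args,
        ]
        commands.append(argv)
    return commands


# Print shell-ready commands for a two-node deployment.
for argv in node_commands("MODEL_PATH", "NODE0_IP:10011", nnodes=2, tp_size=8):
    print(shlex.join(argv))
```

Generating every command from one function keeps the shared flags identical across nodes, which is the part that is easy to get wrong when editing by hand.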
## Benchmarking with Requests
### Throughput Testing
Basic throughput benchmark:
```bash
python3 -m sgl_jax.bench_serving \
--backend sgl-jax \
--dataset-name random \
--num-prompts=100 \
--random-input=512 \
--random-output=128 \
--max-concurrency=8 \
--random-range-ratio=1 \
--warmup-requests=0
```
### Latency Testing
Measure single-batch latency:
```bash
python3 -m sgl_jax.bench_one_batch_server \
--base-url http://127.0.0.1:30000 \
--model-path Qwen/Qwen-7B-Chat \
--batch-size=32 \
--input-len=256 \
--output-len=32
```
### Comprehensive Benchmark Script
For systematic performance evaluation across different configurations:
```bash
#!/bin/bash
set -e
backend=${1:-sgl-jax}
num_prompts_per_concurrency=3
input_seq_lens=(1024 4096 8192)
output_seq_lens=(1 1024)
max_concurrencies=(8 16 32 64 128 256)
for input_seq_len in "${input_seq_lens[@]}"; do
  for output_seq_len in "${output_seq_lens[@]}"; do
    echo "======================================="
    echo "Testing ISL/OSL: $input_seq_len/$output_seq_len"
    echo "======================================="
    for max_concurrency in "${max_concurrencies[@]}"; do
      num_prompts=$((num_prompts_per_concurrency * max_concurrency))
      python3 -m sgl_jax.bench_serving \
        --backend ${backend} \
        --dataset-name random \
        --num-prompts ${num_prompts} \
        --random-input ${input_seq_len} \
        --random-output ${output_seq_len} \
        --max-concurrency ${max_concurrency} \
        --random-range-ratio 1 \
        --disable-ignore-eos \
        --warmup-requests 0
    done
  done
done
```
For detailed help on all benchmark parameters:
```bash
python3 -m sgl_jax.bench_serving --help
```
See the [Benchmark and Profiling Guide](https://github.com/sgl-project/sglang-jax/blob/main/docs/developer_guide/benchmark_and_profiling.md) for advanced benchmarking techniques and profiling with JAX Profiler.
## Performance Optimization
### Memory Optimization
**Reduce memory usage:**
- Lower `--mem-fraction-static` step by step (for example, 0.8 → 0.5 → 0.3)
- Decrease `--max-prefill-tokens` step by step (for example, 16384 → 8192 → 4096)
- Reduce `--max-running-requests`
**Handle OOM errors:**
- Start with conservative memory settings (`--mem-fraction-static=0.5`)
- Increase `--mem-fraction-static` gradually until you find the highest value that still runs without OOM
- Increase `--page-size` for better memory locality (1 → 16 → 64 → 128)
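A back-of-the-envelope calculation helps when choosing a starting value for `--mem-fraction-static`. The sketch below rests on several assumptions that are not stated on this page: that the static pool holds both model weights and the KV cache (as in upstream SGLang), that weights are sharded evenly across the tensor-parallel devices, and that each device has the 32 GB listed for TPU v6e. The function names are illustrative, and real usage will differ because of activations, compilation buffers, and runtime overhead.

```python
def static_pool_gb(hbm_per_device_gb, mem_fraction_static, tp_size):
    """Total HBM reserved for the static pool across all TP devices."""
    return hbm_per_device_gb * mem_fraction_static * tp_size


def weights_gb(num_params_billion, bytes_per_param=2):
    """Approximate weight footprint; bfloat16 uses 2 bytes per parameter."""
    return num_params_billion * bytes_per_param


def kv_cache_headroom_gb(hbm_per_device_gb, mem_fraction_static, tp_size,
                         num_params_billion):
    """Static pool minus weights: a rough upper bound on KV cache space."""
    pool = static_pool_gb(hbm_per_device_gb, mem_fraction_static, tp_size)
    return pool - weights_gb(num_params_billion)


# Example: an 8B-parameter bfloat16 model on four 32 GB devices.
for fraction in (0.8, 0.5, 0.3):
    headroom = kv_cache_headroom_gb(32, fraction, 4, 8)
    print(f"mem-fraction-static={fraction}: ~{headroom:.1f} GB for KV cache")
```

A negative result means the weights alone exceed the static pool at that fraction, so the configuration cannot load.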
### Throughput Optimization
To maximize tokens per second:
- Use FlashAttention backend: `--attention-backend=fa`
- Enable speculative decoding (EAGLE3) for Qwen3 models (20-40% improvement)
- Increase `--max-running-requests` to 256+
- Set `--mem-fraction-static` to 0.8+ (if memory allows)
- Use larger page sizes (64-128)
- Enable chunked prefill: `--chunked-prefill-size=2048`
### Latency Optimization
To minimize time-to-first-token (TTFT) and inter-token latency:
- Reduce `--page-size` to 1-4
- Lower `--max-running-requests` (16-32) for smaller batches
- Reduce `--chunked-prefill-size`
- Use conservative memory settings to avoid GC pauses
### TPU-Specific Optimizations
1. **JIT Compilation Cache:**
```bash
export JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache
```
Always set this environment variable to cache compiled kernels and accelerate server startup.
2. **Data Type Optimization:**
Use `--dtype=bfloat16` for TPU native optimization. TPUs are specifically designed for bfloat16 computations.
3. **Tensor Parallelism:**
Match `--tp-size` to your TPU core configuration (1, 4, or 8) for optimal model distribution.
4. **Attention Backend:**
Always use `--attention-backend=fa` (FlashAttention) for production workloads.
## Troubleshooting
### OOM (Out of Memory) Errors
If you encounter out-of-memory errors:
1. Reduce `--mem-fraction-static` from 0.8 to 0.5 or lower
2. Decrease `--max-prefill-tokens` from 8192 to 4096 or 2048
3. Lower `--max-running-requests` to reduce concurrent batch size
4. Increase `--page-size` for better memory layout efficiency
### Long Compilation Time
If the server takes too long to start:
1. Ensure `JAX_COMPILATION_CACHE_DIR` is properly set
2. Expect the first run to be slow, because it requires JIT compilation (this is normal)
3. Subsequent runs will be significantly faster with cached compilations
4. Consider using `--skip-server-warmup` to defer compilation until first request
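To confirm the first item, it can help to check whether the cache directory is configured and already populated before blaming a slow start on compilation. The helper below is a sketch that only inspects the environment variable and the directory on disk; `compilation_cache_status` is a hypothetical name, and the file count is a rough proxy for a warm cache, not something JAX reports itself.

```python
import os
from pathlib import Path


def compilation_cache_status(env=None):
    """Report whether JAX_COMPILATION_CACHE_DIR is set and how full it is."""
    env = os.environ if env is None else env
    cache_dir = env.get("JAX_COMPILATION_CACHE_DIR")
    if not cache_dir:
        return {"configured": False, "path": None, "entries": 0}
    path = Path(cache_dir)
    # Count files only; a missing or empty directory means a cold cache.
    if path.is_dir():
        entries = sum(1 for p in path.rglob("*") if p.is_file())
    else:
        entries = 0
    return {"configured": True, "path": str(path), "entries": entries}


print(compilation_cache_status())
```

If `configured` is `False`, or `entries` is `0` after a completed run, the server is recompiling on every start.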
### Low Throughput
If you're not achieving expected throughput:
1. Verify `--tp-size` matches your TPU core configuration
2. Check that `--attention-backend=fa` is enabled
3. Increase `--max-running-requests` to enable larger batch formation
4. Consider enabling speculative decoding for compatible models
5. Ensure memory settings allow for sufficient batch sizes
### Connection Issues
If clients cannot connect to the server:
1. Ensure the server was launched with `--host=0.0.0.0` so it accepts external connections (not just `127.0.0.1`)
2. Verify firewall rules allow traffic on the specified port (default: 30000)
3. Check that the server process is running: `curl http://localhost:30000/health`
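In scripts, the `curl` check in the last step can be turned into a readiness wait, which is useful because the server may spend a while compiling before it answers. The sketch below polls the `/health` URL shown above with the standard library; the function names, attempt count, and delay are illustrative choices, and treating HTTP 200 as "healthy" is an assumption about the endpoint.

```python
import time
import urllib.error
import urllib.request


def is_healthy(url, timeout=2.0):
    """Return True if a GET on `url` answers with HTTP 200."""
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            return resp.status == 200
    except (urllib.error.URLError, OSError):
        return False


def wait_until_healthy(url="http://localhost:30000/health", attempts=30,
                       delay=2.0, probe=is_healthy, sleep=time.sleep):
    """Poll the health endpoint until it responds or attempts run out."""
    for attempt in range(1, attempts + 1):
        if probe(url):
            return attempt  # number of probes it took
        sleep(delay)
    return None  # server never became reachable


# Usage (requires a running server):
#   print(wait_until_healthy())
```

The `probe` and `sleep` parameters exist so the loop can be exercised without a live server or real delays.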
## Advanced Features
### Speculative Decoding
SGLang-JAX supports EAGLE and EAGLE3 speculative decoding algorithms for Qwen3 and LLaMA model families. Speculative decoding can improve throughput by 20-40% without affecting output quality.
See the [Speculative Decoding documentation](https://github.com/sgl-project/sglang-jax/blob/main/docs/features/speculative_decoding.md) for detailed configuration and supported model combinations.
### Chunked Prefill
Enable mixed prefill-decode batching for better TPU utilization:
```bash
--chunked-prefill-size=2048 --enable-mixed-chunk
```
This allows the scheduler to mix prefill operations with decode operations in the same batch, improving overall throughput.
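As a conceptual illustration of what `--chunked-prefill-size` implies, the sketch below splits a long prompt into fixed-size token ranges, each of which could be scheduled as its own prefill step alongside decode work. It is not the sglang-jax scheduler; `prefill_chunks` is a hypothetical helper that shows only the chunk boundaries.

```python
def prefill_chunks(prompt_len, chunk_size=2048):
    """Split a prompt of `prompt_len` tokens into [start, end) chunk ranges."""
    if prompt_len < 0 or chunk_size <= 0:
        raise ValueError("prompt_len must be >= 0 and chunk_size must be > 0")
    return [(start, min(start + chunk_size, prompt_len))
            for start in range(0, prompt_len, chunk_size)]


print(prefill_chunks(5000))  # [(0, 2048), (2048, 4096), (4096, 5000)]
```

A smaller chunk size yields more, shorter prefill steps, which is why reducing it tends to help latency at some cost to throughput.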
### Custom Attention Backends
SGLang-JAX supports a plugin-based attention backend system. You can implement custom attention kernels optimized for specific use cases.
See the [Attention Backend documentation](https://github.com/sgl-project/sglang-jax/blob/main/docs/features/attention_backend.md) for implementation details.
### Environment Verification
Verify your TPU setup before deploying:
```bash
python -c "from sgl_jax import check_env; check_env.check_env()"
```
This command checks:
- Installed package versions
- TPU device availability and specifications
- System resources and configuration
- Compatibility of settings
## Contributing
We welcome contributions to improve TPU support in SGLang-JAX!
### Areas for Contribution
**Check the [Development Roadmap](https://github.com/sgl-project/sglang-jax/issues/190)** to see planned features and find opportunities to contribute new functionality.
Current contribution areas include:
- Performance optimizations for specific TPU generations
- Support for additional model architectures
- Documentation improvements and examples
- Bug reports and fixes
- Benchmark results and performance analysis
### How to Contribute
1. Visit the [sglang-jax repository](https://github.com/sgl-project/sglang-jax)
2. Read the [Contribution Guide](https://github.com/sgl-project/sglang-jax/blob/main/docs/developer_guide/contribution_guide.md)
3. Join the [SGL-JAX Slack community](https://sgl-fru7574.slack.com/archives/C09EBE5HT5X) for discussions
4. Report issues at [sglang-jax/issues](https://github.com/sgl-project/sglang-jax/issues)
### Testing on TPU
For contributors who need TPU access for testing:
- Refer to the [TPU Resources Guide](https://github.com/sgl-project/sglang-jax/blob/main/docs/developer_guide/tpu_resources_guide.md) for information on accessing TPU hardware
- Use SkyPilot with spot instances for cost-effective testing
- Follow the [Benchmark and Profiling Guide](https://github.com/sgl-project/sglang-jax/blob/main/docs/developer_guide/benchmark_and_profiling.md) for performance validation
## References
### Documentation
- [SGLang-JAX Repository](https://github.com/sgl-project/sglang-jax)
- [SGLang-JAX Installation Guide](https://github.com/sgl-project/sglang-jax/blob/main/docs/get_started/install.md)
- [Qwen Models Quick Start](https://github.com/sgl-project/sglang-jax/blob/main/docs/basic_usage/qwen.md)
- [Benchmark and Profiling Guide](https://github.com/sgl-project/sglang-jax/blob/main/docs/developer_guide/benchmark_and_profiling.md)
- [Speculative Decoding](https://github.com/sgl-project/sglang-jax/blob/main/docs/features/speculative_decoding.md)
### External Resources
- [JAX Documentation](https://jax.readthedocs.io/)
- [Google Cloud TPU Documentation](https://cloud.google.com/tpu/docs)
- [SkyPilot Documentation](https://skypilot.readthedocs.io/)