docs: standardize model card for public release
README.md
A speculative decoding draft head for [Qwen/Qwen2.5-14B-Instruct](https://huggingface.co/Qwen/Qwen2.5-14B-Instruct), trained using the [EAGLE3](https://arxiv.org/abs/2503.01840) method on Google Cloud TPU with the [SpecJAX](https://github.com/tails-mpt/SpecJAX) framework.

EAGLE3 draft heads accelerate autoregressive generation by proposing multiple tokens per step that a target model then verifies in parallel, typically achieving 2-3x throughput gains with no change in output quality.
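The propose-then-verify loop can be sketched in a few lines of Python. Everything below is a toy illustration: the two "models" are plain counting functions, and `speculative_step` is our own name for the sketch, not a SpecJAX or SGLang API.

```python
# Toy sketch of the propose-then-verify loop behind speculative decoding.
# Both "models" are stand-in functions mapping a context tuple to the next
# token; `speculative_step` is an illustrative name, not a real API.

def speculative_step(target, draft, ctx, k):
    """Draft k tokens, verify with the target, return the emitted tokens."""
    # Draft phase: the cheap model proposes k tokens autoregressively.
    proposal, d_ctx = [], list(ctx)
    for _ in range(k):
        tok = draft(tuple(d_ctx))
        proposal.append(tok)
        d_ctx.append(tok)

    # Verify phase: the target checks each proposed position. In a real
    # engine this is one batched forward pass, not k sequential calls.
    emitted, v_ctx = [], list(ctx)
    for tok in proposal:
        expected = target(tuple(v_ctx))
        if tok != expected:
            # First mismatch: emit the target's own token and stop, so the
            # output is identical to plain autoregressive decoding.
            emitted.append(expected)
            return emitted
        emitted.append(tok)
        v_ctx.append(tok)

    # Every draft token was accepted: the target grants one bonus token.
    emitted.append(target(tuple(v_ctx)))
    return emitted

# Deterministic toy models: the target counts upward; the draft agrees,
# so all 3 drafts are accepted and a bonus token is emitted.
target = lambda ctx: ctx[-1] + 1
print(speculative_step(target, target, (0,), k=3))  # [1, 2, 3, 4]
```

Because rejected positions fall back to the target's own token, the accepted output is always exactly what plain decoding would have produced; only the number of target forward passes changes.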
## Usage

### SGLang (GPU)

> **Note**: Qwen2.5 EAGLE3 support requires a small patch to SGLang (adding `set_eagle3_layers_to_capture()` to the Qwen2 model). See the [SpecJAX inference guide](https://github.com/tails-mpt/SpecJAX/tree/main/inference) for details.
```bash
python -m sglang.launch_server \
  --model-path Qwen/Qwen2.5-14B-Instruct \
  --speculative-algorithm EAGLE3 \
  --speculative-draft-model-path thoughtworks/Qwen2.5-14B-Instruct-Eagle3 \
  --speculative-eagle-topk 1 \
  --speculative-num-steps 3 \
  --speculative-num-draft-tokens 4 \
  --dtype bfloat16
```
### sglang-jax (TPU)

> **Note**: Requires the same Qwen2 EAGLE3 patch applied to sglang-jax. The sglang-jax EAGLE3 pipeline is functional but not yet performance-optimized.

```bash
python -m sgl_jax.launch_server \
  --model-path Qwen/Qwen2.5-14B-Instruct \
  --speculative-algorithm EAGLE3 \
  --speculative-draft-model-path thoughtworks/Qwen2.5-14B-Instruct-Eagle3 \
  --speculative-eagle-topk 1 \
  --speculative-num-steps 3 \
  --speculative-num-draft-tokens 4 \
  --tp-size 4 --dtype bfloat16
```
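As a rough consistency check on the speculative flags above, assuming a linear draft chain (`--speculative-eagle-topk 1`): each draft step contributes `topk` tokens, and the verifier needs one extra position for the bonus token. The helper below is our own illustration, not an SGLang option or API.

```python
# With a linear chain (topk = 1), each of the num_steps draft steps adds
# one token, and the verifier reserves one extra position for the bonus
# token, so num_draft_tokens = num_steps * topk + 1.

def draft_tokens_needed(num_steps: int, topk: int) -> int:
    return num_steps * topk + 1

print(draft_tokens_needed(num_steps=3, topk=1))  # 4, matching --speculative-num-draft-tokens
```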
### Python (SGLang client)

```python
import sglang as sgl

# The argument list below mirrors the server flags above; the full example
# was elided in this diff view, so treat the exact keywords as illustrative.
llm = sgl.LLM(
    model_path="Qwen/Qwen2.5-14B-Instruct",
    speculative_algorithm="EAGLE3",
    speculative_draft_model_path="thoughtworks/Qwen2.5-14B-Instruct-Eagle3",
    speculative_eagle_topk=1,
    speculative_num_steps=3,
    speculative_num_draft_tokens=4,
)
```
## Training

| Parameter | Value |
|-----------|-------|
| Framework | [SpecJAX](https://github.com/tails-mpt/SpecJAX) — pure JAX, no Flax/PyTorch |
| Hardware | Google Cloud TPU v4-32 (4 hosts x 4 chips, TP=4, DP=4) |
| Dataset | 54K mixed: ShareGPT (45%) + UltraChat-200K (35%) + Open-PerfectBlend (20%) |
| Epochs | 3 |
| Steps | 4,983 per epoch |
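The mixture percentages in the table translate into approximate per-source sample counts as follows (the counts are our own arithmetic from the stated fractions, not figures from the card):

```python
# Breakdown of the 54K-sample training mixture from the table above
# (counts rounded; exact per-source counts are not stated in the card).

TOTAL = 54_000
mix = {"ShareGPT": 0.45, "UltraChat-200K": 0.35, "Open-PerfectBlend": 0.20}
counts = {name: round(TOTAL * frac) for name, frac in mix.items()}
print(counts)
# {'ShareGPT': 24300, 'UltraChat-200K': 18900, 'Open-PerfectBlend': 10800}
```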
Token acceptance rates on generic instruction-following data (ShareGPT-style prompts):

| Metric | Acceptance rate |
|--------|-----------------|
| acc_5  | 52.1% |
| acc_6  | 50.7% |

*Measured on held-out evaluation data. Actual throughput gains depend on hardware, prompt distribution, and runtime version.*
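Per-depth acceptance rates translate into expected emitted tokens per target forward pass: a step always emits one guaranteed token, plus draft token i whenever every shallower draft token was also accepted. A minimal sketch of that calculation, using hypothetical placeholder rates (the shallower acc values are not shown in this diff):

```python
# E[tokens per step] = 1 + sum_i prod_{j <= i} acc_j: one guaranteed token,
# plus draft token i only if the whole prefix up to depth i was accepted.
# The rates below are HYPOTHETICAL placeholders, not this model's numbers.

def expected_tokens_per_step(acc):
    expected, alive = 1.0, 1.0
    for p in acc:
        alive *= p          # probability the accepted prefix reaches depth i
        expected += alive
    return expected

print(round(expected_tokens_per_step([0.8, 0.7, 0.6]), 3))  # 2.696
```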
## Model Architecture

The draft head is a single-layer transformer that operates on the target model's hidden states.
## License

This model is released under the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0), consistent with the base model's license.

## References