yuyijiong nielsr HF Staff committed on
Commit
4484ec8
·
1 Parent(s): 2ecd604

Add metadata, paper/code links, and sample usage (#1)

- Add metadata, paper/code links, and sample usage (8dea2d08154e9d8cce9b6698ede7da9016cac450)


Co-authored-by: Niels Rogge <nielsr@users.noreply.huggingface.co>

Files changed (1)
  1. README.md +88 -90
README.md CHANGED
@@ -1,90 +1,88 @@
- # Speculation head checkpoints
-
- Pre-trained **pipeline speculation head** weights. Each `.pt` file is a single checkpoint produced by training; pair it with the **same** base model architecture it was trained on (see `config["base_model_path"]` inside the file).
-
- For inference, evaluation, and training examples, see the official repo:
- **https://github.com/yuyijiong/speculative_pipeline_decoding**
-
- ## Filename format
-
- Files are named:
-
- ```text
- {model}_s{num_stages}_l{num_spec_layers}.pt
- ```
-
- | Part | Meaning |
- |------|---------|
- | `{model}` | Base model tag from training config (e.g. `Qwen3.5-4B`, `Qwen3.5-9B`) |
- | `s{...}` | `num_stages` — pipeline depth (number of target-model stages) |
- | `l{...}` | `num_spec_layers` — number of Transformer layers in the speculation module |
-
- Example: `Qwen3.5-9B_s16_l2.pt` → Qwen3.5-9B base, 16 stages, 2 spec layers.
-
- ## Checkpoint contents
-
- Each file is a PyTorch archive with two top-level keys:
-
- ```python
- {
- "state_dict": ..., # weights of the speculation module
- "config": { ... }, # hyperparameters and metadata
- }
- ```
-
- ### `config` fields (always present)
-
- | Field | Description |
- |-------|-------------|
- | `base_model_path` | Base model path recorded at training time (often a machine-local path; override at load time — see below) |
- | `hidden_size` | Hidden size (matches base model) |
- | `vocab_size` | Base model vocabulary size |
- | `draft_vocab_size` | Draft head output size (full vocab or draft subset) |
- | `num_stages` | Pipeline depth (same as `s` in filename) |
- | `num_spec_layers` | Speculation module depth (same as `l` in filename) |
- | `version` | Checkpoint format version (`10`) |
- | `trained_with_use_deepest` | Whether training used deepest-layer features |
- | `shallow_hidden_layer_indices` | Which base layers feed the speculation module |
-
- ### `config` fields (optional)
-
- | Field | Description |
- |-------|-------------|
- | `spec_init_from_base_layers` | Base layers used to initialize the spec module (if any) |
- | `draft_token_ids` | Draft vocabulary token ids (only when trained with a draft vocab subset) |
-
- ## Loading checkpoints
-
- `config["base_model_path"]` is often a **local path from the training machine** (e.g. `/share/models/Qwen3.5-4B`). On your machine, pass the correct Hugging Face id or local directory via `--base_model_path`; it **overrides** the path stored in the checkpoint:
-
- ```bash
- python pipeline_inference.py \
- --spec_head_ckpt /path/to/Qwen3.5-4B_s4_l2.pt \
- --base_model_path Qwen/Qwen3.5-4B
-
- python eval.py \
- --spec_head_ckpt /path/to/Qwen3.5-4B_s4_l2.pt \
- --base_model_path /your/local/Qwen3.5-4B \
- --data_dir eval_data \
- --output_dir ./eval_output
- ```
-
- If `--base_model_path` is omitted, the value from `config["base_model_path"]` is used as-is.
-
- More usage details: [speculative_pipeline_decoding](https://github.com/yuyijiong/speculative_pipeline_decoding).
-
- ## Citation
-
- If you use this repo, please cite our paper:
-
- ```bibtex
- @misc{yu2026speculativepipelinedecodinghigheraccruacy,
- title={Speculative Pipeline Decoding: Higher-Accruacy and Zero-Bubble Speculation via Pipeline Parallelism},
- author={Yijiong Yu and Huazheng Wang and Shuai Yuan and Ruilong Ren and Ji Pei},
- year={2026},
- eprint={2605.30852},
- archivePrefix={arXiv},
- primaryClass={cs.CL},
- url={https://arxiv.org/abs/2605.30852},
- }
- ```
 
+ ---
+ license: apache-2.0
+ library_name: transformers
+ pipeline_tag: text-generation
+ tags:
+ - speculative-decoding
+ - pipeline-parallelism
+ - llm-acceleration
+ ---
+
+ # Speculative Pipeline Decoding: Speculation Head Checkpoints
+
+ This repository contains pre-trained **pipeline speculation head** weights for the paper [Speculative Pipeline Decoding: Higher-Accruacy and Zero-Bubble Speculation via Pipeline Parallelism](https://huggingface.co/papers/2605.30852).
+
+ Speculative Pipeline Decoding (SPD) is a framework that unlocks the potential of pipeline parallelism for LLM decoding acceleration. By partitioning the target LLM into $n$ pipeline stages, SPD allows the model to process $n$ tokens in parallel, achieving higher acceptance rates and zero latency bubbles.
+
+ - **Paper:** [https://huggingface.co/papers/2605.30852](https://huggingface.co/papers/2605.30852)
+ - **Code:** [https://github.com/yuyijiong/speculative_pipeline_decoding](https://github.com/yuyijiong/speculative_pipeline_decoding)
+
+ ## Quick Start (Inference)
+
+ To run inference using these checkpoints, clone the official repository and use the provided `pipeline_inference.py` script. You must pair the speculation head with the corresponding base model it was trained on.
+
+ ```bash
+ python pipeline_inference.py \
+ --spec_head_ckpt /path/to/checkpoint.pt \
+ --base_model_path Qwen/Qwen3.5-4B \
+ --max_new_tokens 100 \
+ --temperature 0.0
+ ```
+
+ ## Checkpoint Information
+
+ Each `.pt` file is a single checkpoint produced by training. For more details on training and evaluation, see the [official repo](https://github.com/yuyijiong/speculative_pipeline_decoding).
+
+ ### Filename format
+
+ Files are named:
+ `{model}_s{num_stages}_l{num_spec_layers}.pt`
+
+ | Part | Meaning |
+ |------|---------|
+ | `{model}` | Base model tag (e.g. `Qwen3.5-4B`, `Qwen3.5-9B`) |
+ | `s{...}` | `num_stages` — pipeline depth (number of target-model stages) |
+ | `l{...}` | `num_spec_layers` — number of Transformer layers in the speculation module |
+
+ Example: `Qwen3.5-9B_s16_l2.pt` → Qwen3.5-9B base, 16 stages, 2 spec layers.
+
+ ### Checkpoint contents
+
+ Each file is a PyTorch archive with two top-level keys:
+
+ ```python
+ {
+ "state_dict": ..., # weights of the speculation module
+ "config": { ... }, # hyperparameters and metadata
+ }
+ ```
+
+ ### `config` fields (always present)
+
+ | Field | Description |
+ |-------|-------------|
+ | `base_model_path` | Base model path recorded at training time (can be overridden via `--base_model_path` at load time) |
+ | `hidden_size` | Hidden size (matches base model) |
+ | `vocab_size` | Base model vocabulary size |
+ | `draft_vocab_size` | Draft head output size (full vocab or draft subset) |
+ | `num_stages` | Pipeline depth (same as `s` in filename) |
+ | `num_spec_layers` | Speculation module depth (same as `l` in filename) |
+ | `version` | Checkpoint format version (`10`) |
+ | `trained_with_use_deepest` | Whether training used deepest-layer features |
+ | `shallow_hidden_layer_indices` | Which base layers feed the speculation module |
+
+ ## Citation
+
+ If you use this work, please cite our paper:
+
+ ```bibtex
+ @misc{yu2026speculativepipelinedecodinghigheraccruacy,
+ title={Speculative Pipeline Decoding: Higher-Accruacy and Zero-Bubble Speculation via Pipeline Parallelism},
+ author={Yijiong Yu and Huazheng Wang and Shuai Yuan and Ruilong Ren and Ji Pei},
+ year={2026},
+ eprint={2605.30852},
+ archivePrefix={arXiv},
+ primaryClass={cs.CL},
+ url={https://arxiv.org/abs/2605.30852},
+ }
+ ```
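
Note on the filename convention: the README states that `s` and `l` in the filename mirror `config["num_stages"]` and `config["num_spec_layers"]`. A minimal sketch of how that convention could be parsed and cross-checked is given below. The helper names `parse_checkpoint_name` and `config_mismatches` are hypothetical and not part of the official repository, and the `config` argument is assumed to be the plain dict stored under the checkpoint's `"config"` key.

```python
import os
import re
from typing import NamedTuple


class CheckpointName(NamedTuple):
    """Fields encoded in a `{model}_s{num_stages}_l{num_spec_layers}.pt` name."""
    model: str
    num_stages: int
    num_spec_layers: int


# Hypothetical pattern derived from the README's filename table.
_NAME_PATTERN = re.compile(r"^(?P<model>.+)_s(?P<stages>\d+)_l(?P<layers>\d+)\.pt$")


def parse_checkpoint_name(path: str) -> CheckpointName:
    """Split a checkpoint filename (with or without a directory) into its parts."""
    filename = os.path.basename(path)
    match = _NAME_PATTERN.match(filename)
    if match is None:
        raise ValueError(f"not a speculation head checkpoint name: {filename!r}")
    return CheckpointName(
        model=match["model"],
        num_stages=int(match["stages"]),
        num_spec_layers=int(match["layers"]),
    )


def config_mismatches(name: CheckpointName, config: dict) -> list[str]:
    """Return the config fields that disagree with the values in the filename."""
    expected = {
        "num_stages": name.num_stages,
        "num_spec_layers": name.num_spec_layers,
    }
    return [field for field, value in expected.items() if config.get(field) != value]


if __name__ == "__main__":
    name = parse_checkpoint_name("/path/to/Qwen3.5-9B_s16_l2.pt")
    print(name)
    # The dict below stands in for the "config" entry of a loaded checkpoint.
    print(config_mismatches(name, {"num_stages": 16, "num_spec_layers": 2}))
```

A check like this can catch a renamed or mismatched file before the base model is loaded; it does not validate the weights themselves.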