Improve model card: Add metadata, tags, and usage example
Browse filesThis PR significantly improves the model card by:
- Adding `pipeline_tag: text-generation` and `library_name: transformers` to the metadata, which enhances discoverability and indicates compatibility with the Hugging Face ecosystem.
- Including additional `tags` such as `speculative-decoding` and `inference-acceleration` for more granular filtering.
- Expanding the model description with an overview of `SpecDec++` derived from the paper's abstract, providing better context.
- Integrating a comprehensive and runnable Python code snippet directly from the paper's GitHub repository (`specdec_pp/sample.py`), guiding users on how to load and use the Acceptance Prediction Head for accelerated text generation.
The existing arXiv paper link and GitHub repository link are preserved. As no explicit project page URL was provided, it has not been included.
|
@@ -1,13 +1,84 @@
|
|
| 1 |
---
|
| 2 |
license: apache-2.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
---
|
| 4 |
|
| 5 |
-
|
| 6 |
|
| 7 |
-
|
| 8 |
|
|
|
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
```bibtex
|
| 13 |
@article{huang2024specdec++,
|
|
|
|
| 1 |
---
|
| 2 |
license: apache-2.0
|
| 3 |
+
pipeline_tag: text-generation
|
| 4 |
+
library_name: transformers
|
| 5 |
+
tags:
|
| 6 |
+
- speculative-decoding
|
| 7 |
+
- inference-acceleration
|
| 8 |
---
|
| 9 |
|
| 10 |
+
# SpecDec++: Boosting Speculative Decoding via Adaptive Candidate Lengths
|
| 11 |
|
| 12 |
+
Speculative decoding is a technique to significantly reduce the inference latency of large language models (LLMs) by utilizing a smaller and faster draft model. **SpecDec++** is an enhanced version of speculative decoding that adaptively determines the candidate length on the fly. It formulates this choice as a Markov Decision Process, theoretically showing that the optimal policy involves stopping speculation when the probability of rejection exceeds a threshold.
|
| 13 |
|
| 14 |
+
Motivated by this theory, SpecDec++ augments the draft model with a trained acceptance prediction head to predict the conditional acceptance probability of candidate tokens. This adaptive method achieves substantial speedups: 2.04x on the Alpaca dataset (7.2% improvement over baseline speculative decoding), 2.26x on GSM8K (9.4% improvement), and 2.23x on HumanEval (11.1% improvement).
|
| 15 |
|
| 16 |
+
This repository contains the **Acceptance Prediction Head for Llama-2-chat 7B and 70B model pair** trained with `weight_mismatch=6` and `resnet_num_layers=3`. It is recommended to be used with `stop_threshold=0.7`.
|
| 17 |
+
|
| 18 |
+
**Paper**: [SpecDec++: Boosting Speculative Decoding via Adaptive Candidate Lengths](https://arxiv.org/abs/2405.19715)
|
| 19 |
+
**Code**: [GitHub Repository](https://github.com/Kaffaljidhmah2/SpecDec_pp)
|
| 20 |
+
|
| 21 |
+
## Usage
|
| 22 |
+
|
| 23 |
+
To use this Acceptance Prediction Head for accelerated text generation with SpecDec++, you will need to integrate it with a base large language model (e.g., Llama-2-chat 7B) using the `EaModel` class provided in the original paper's repository.
|
| 24 |
+
|
| 25 |
+
First, clone the `SpecDec_pp` repository and install its dependencies:
|
| 26 |
+
|
| 27 |
+
```bash
|
| 28 |
+
git clone https://github.com/Kaffaljidhmah2/SpecDec_pp.git
|
| 29 |
+
cd SpecDec_pp
|
| 30 |
+
pip install -r requirements.txt
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
Then, you can use the following Python snippet, adapted from `specdec_pp/sample.py`, to perform accelerated generation:
|
| 34 |
+
|
| 35 |
+
```python
|
| 36 |
+
import torch
|
| 37 |
+
from transformers import AutoTokenizer
|
| 38 |
+
# EaModel is a custom class from the SpecDec_pp repository.
|
| 39 |
+
# Ensure the repository is cloned and its `specdec_pp` directory is accessible in your Python path.
|
| 40 |
+
from eagle.model.ea_model import EaModel
|
| 41 |
+
from fastchat.model import get_conversation_template
|
| 42 |
+
|
| 43 |
+
# Define the paths for your base Large Language Model and this Acceptance Prediction Head
|
| 44 |
+
# Replace with the actual model IDs or local paths
|
| 45 |
+
base_model_path = "meta-llama/Llama-2-7b-chat-hf" # Example: The base LLM to accelerate
|
| 46 |
+
ea_model_path = "hacky/acchead-llama2-chat-7bx70b" # This Acceptance Prediction Head checkpoint
|
| 47 |
+
|
| 48 |
+
# Load the EaModel, which integrates the base LLM and the acceptance prediction head
|
| 49 |
+
model = EaModel.from_pretrained(
|
| 50 |
+
base_model_path=base_model_path,
|
| 51 |
+
ea_model_path=ea_model_path,
|
| 52 |
+
torch_dtype=torch.float16, # Use appropriate precision (e.g., torch.float16 or torch.bfloat16)
|
| 53 |
+
low_cpu_mem_usage=True,
|
| 54 |
+
device_map="auto",
|
| 55 |
+
total_token=-1 # -1 enables adaptive candidate length as per SpecDec++
|
| 56 |
+
)
|
| 57 |
+
model.eval()
|
| 58 |
+
|
| 59 |
+
# Prepare your prompt using the correct chat template for the base model (e.g., for Llama-2-chat)
|
| 60 |
+
your_message = "What are the benefits of speculative decoding?"
|
| 61 |
+
conv = get_conversation_template("llama-2") # Use "vicuna" or "llama3" as needed for your base model
|
| 62 |
+
conv.append_message(conv.roles[0], your_message)
|
| 63 |
+
conv.append_message(conv.roles[1], None) # The assistant's response will be appended here
|
| 64 |
+
prompt = conv.get_prompt()
|
| 65 |
+
|
| 66 |
+
# Tokenize input and move to the appropriate device (e.g., GPU)
|
| 67 |
+
input_ids = model.tokenizer([prompt]).input_ids
|
| 68 |
+
input_ids = torch.as_tensor(input_ids).cuda() # Requires CUDA-enabled GPU
|
| 69 |
+
|
| 70 |
+
# Generate output using the `eagenerate` function for accelerated inference
|
| 71 |
+
with torch.no_grad():
|
| 72 |
+
output_ids = model.eagenerate(input_ids, temperature=0.7, max_new_tokens=256)
|
| 73 |
+
|
| 74 |
+
# Decode and print the generated text
|
| 75 |
+
output = model.tokenizer.decode(output_ids[0])
|
| 76 |
+
print(output)
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
## Citation
|
| 80 |
+
|
| 81 |
+
If you find this useful in your research, please consider citing our paper.
|
| 82 |
|
| 83 |
```bibtex
|
| 84 |
@article{huang2024specdec++,
|