nielsr (HF Staff) committed
Commit 723e4ef · verified · 1 Parent(s): 8d2795f

Improve model card: Add metadata, tags, and usage example


This PR significantly improves the model card by:
- Adding `pipeline_tag: text-generation` and `library_name: transformers` to the metadata, which enhances discoverability and indicates compatibility with the Hugging Face ecosystem.
- Including additional `tags` such as `speculative-decoding` and `inference-acceleration` for more granular filtering.
- Expanding the model description with an overview of `SpecDec++` derived from the paper's abstract, providing better context.
- Integrating a comprehensive and runnable Python code snippet directly from the paper's GitHub repository (`specdec_pp/sample.py`), guiding users on how to load and use the Acceptance Prediction Head for accelerated text generation.

The existing arXiv paper link and GitHub repository link are preserved. As no explicit project page URL was provided, it has not been included.
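The core idea the updated model card documents — keep drafting candidate tokens until the predicted probability that at least one of them will be rejected exceeds a threshold (`stop_threshold`, recommended 0.7 for this checkpoint) — can be sketched as follows. This is an illustrative sketch only: `predict_rejection_prob` and `draft_next_token` are hypothetical stand-ins for the trained acceptance prediction head and the draft model, not the repository's actual API.

```python
# Illustrative sketch of SpecDec++'s adaptive stopping rule, NOT the
# SpecDec_pp repository's actual API: `predict_rejection_prob` stands in
# for the trained acceptance prediction head, and `draft_next_token` for
# the draft model's sampling step.

def predict_rejection_prob(draft_tokens):
    # Hypothetical stub: pretend each extra draft token adds 10% to the
    # probability that at least one candidate gets rejected.
    return len(draft_tokens) / 10

def propose_candidates(draft_next_token, stop_threshold=0.7, max_tokens=20):
    """Draft candidate tokens one at a time; stop speculating as soon as
    the predicted rejection probability exceeds stop_threshold."""
    candidates = []
    while len(candidates) < max_tokens:
        candidates.append(draft_next_token(candidates))
        if predict_rejection_prob(candidates) > stop_threshold:
            break
    return candidates

# Toy draft model that just emits the current position index.
tokens = propose_candidates(lambda ctx: len(ctx))
print(tokens)  # stops after 8 draft tokens: [0, 1, 2, 3, 4, 5, 6, 7]
```

The point of the adaptive rule is that the candidate length varies per step: easy continuations are drafted further, hard ones stop early, which is where the reported speedups over fixed-length speculative decoding come from.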

Files changed (1)
  1. README.md +74 -3
README.md CHANGED
@@ -1,13 +1,84 @@
  ---
  license: apache-2.0
  ---

- The Acceptance Prediction Head for Llama-2-chat 7B and 70B model pair trained with `weight_mismatch=6` and `resnet_num_layers=3`. It is recommended to be used with `stop_threshold=0.7`. See [arxiv: 2405.19715](https://arxiv.org/abs/2405.19715) for more details.

- Usage: [GitHub](https://github.com/Kaffaljidhmah2/SpecDec_pp)

- ### Citation

  ```bibtex
  @article{huang2024specdec++,
 
  ---
  license: apache-2.0
+ pipeline_tag: text-generation
+ library_name: transformers
+ tags:
+ - speculative-decoding
+ - inference-acceleration
  ---

+ # SpecDec++: Boosting Speculative Decoding via Adaptive Candidate Lengths

+ Speculative decoding is a technique to significantly reduce the inference latency of large language models (LLMs) by utilizing a smaller and faster draft model. **SpecDec++** is an enhanced version of speculative decoding that adaptively determines the candidate length on the fly. It formulates this choice as a Markov Decision Process, theoretically showing that the optimal policy is to stop speculation when the probability of rejection exceeds a threshold.

+ Motivated by this theory, SpecDec++ augments the draft model with a trained acceptance prediction head that predicts the conditional acceptance probability of candidate tokens. This adaptive method achieves substantial speedups: 2.04x on the Alpaca dataset (7.2% improvement over baseline speculative decoding), 2.26x on GSM8K (9.4% improvement), and 2.23x on HumanEval (11.1% improvement).

+ This repository contains the **Acceptance Prediction Head for the Llama-2-chat 7B and 70B model pair**, trained with `weight_mismatch=6` and `resnet_num_layers=3`. It is recommended to be used with `stop_threshold=0.7`.
+
+ **Paper**: [SpecDec++: Boosting Speculative Decoding via Adaptive Candidate Lengths](https://arxiv.org/abs/2405.19715)
+ **Code**: [GitHub Repository](https://github.com/Kaffaljidhmah2/SpecDec_pp)
+
+ ## Usage
+
+ To use this Acceptance Prediction Head for accelerated text generation with SpecDec++, integrate it with a base large language model (e.g., Llama-2-chat 7B) via the `EaModel` class provided in the paper's repository.
+
+ First, clone the `SpecDec_pp` repository and install its dependencies:
+
+ ```bash
+ git clone https://github.com/Kaffaljidhmah2/SpecDec_pp.git
+ cd SpecDec_pp
+ pip install -r requirements.txt
+ ```
+
+ Then use the following Python snippet, adapted from `specdec_pp/sample.py`, to perform accelerated generation:
+
+ ```python
+ import torch
+ # EaModel is a custom class from the SpecDec_pp repository.
+ # Ensure the repository is cloned and accessible in your Python path.
+ from eagle.model.ea_model import EaModel
+ from fastchat.model import get_conversation_template
+
+ # Paths for the base Large Language Model and this Acceptance Prediction Head.
+ # Replace with the actual model IDs or local paths.
+ base_model_path = "meta-llama/Llama-2-7b-chat-hf"  # the base LLM to accelerate
+ ea_model_path = "hacky/acchead-llama2-chat-7bx70b"  # this Acceptance Prediction Head checkpoint
+
+ # Load the EaModel, which combines the base LLM and the acceptance prediction head
+ model = EaModel.from_pretrained(
+     base_model_path=base_model_path,
+     ea_model_path=ea_model_path,
+     torch_dtype=torch.float16,  # use appropriate precision (e.g., torch.float16 or torch.bfloat16)
+     low_cpu_mem_usage=True,
+     device_map="auto",
+     total_token=-1,  # -1 enables adaptive candidate length as per SpecDec++
+ )
+ model.eval()
+
+ # Prepare the prompt using the correct chat template for the base model (e.g., Llama-2-chat)
+ your_message = "What are the benefits of speculative decoding?"
+ conv = get_conversation_template("llama-2")  # use "vicuna" or "llama3" as needed for your base model
+ conv.append_message(conv.roles[0], your_message)
+ conv.append_message(conv.roles[1], None)  # the assistant's response will be generated
+ prompt = conv.get_prompt()
+
+ # Tokenize the input and move it to the GPU
+ input_ids = model.tokenizer([prompt]).input_ids
+ input_ids = torch.as_tensor(input_ids).cuda()  # requires a CUDA-enabled GPU
+
+ # Generate output with the `eagenerate` method for accelerated inference
+ with torch.no_grad():
+     output_ids = model.eagenerate(input_ids, temperature=0.7, max_new_tokens=256)
+
+ # Decode and print the generated text
+ output = model.tokenizer.decode(output_ids[0])
+ print(output)
+ ```
+
+ ## Citation
+
+ If you find this useful in your research, please consider citing the paper.

  ```bibtex
  @article{huang2024specdec++,