GRIFFIN: Effective Token Alignment for Faster Speculative Decoding
This repository contains the GRIFFIN model, a novel framework for accelerating inference in large language models, as described in the paper *GRIFFIN: Effective Token Alignment for Faster Speculative Decoding*.
Abstract
Speculative decoding accelerates inference in large language models (LLMs) by generating multiple draft tokens simultaneously. However, existing methods often struggle with token misalignment between the training and decoding phases, limiting their performance. To address this, we propose GRIFFIN, a novel framework that incorporates a token-alignable training strategy and a token-alignable draft model to mitigate misalignment. The training strategy employs a loss masking mechanism to exclude highly misaligned tokens during training, preventing them from negatively impacting the draft model's optimization. The token-alignable draft model introduces input tokens to correct inconsistencies in generated features. Experiments on LLaMA, Vicuna, Qwen, and Mixtral models demonstrate that GRIFFIN achieves an average acceptance length improvement of over 8% and a speedup ratio exceeding 7%, outperforming current speculative decoding state-of-the-art methods. Our code and GRIFFIN's draft models are released publicly at https://github.com/hsj576/GRIFFIN.
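The loss masking mechanism mentioned above can be pictured as a masked cross-entropy: per-token draft losses are computed as usual, and tokens flagged as highly misaligned are zeroed out before averaging, so they do not pull the draft model's optimization off course. The sketch below is only an illustration of this idea, not GRIFFIN's actual implementation; the tensor shapes and the way the alignment mask is produced are assumptions.

```python
import torch
import torch.nn.functional as F

def masked_draft_loss(draft_logits: torch.Tensor,
                      target_ids: torch.Tensor,
                      align_mask: torch.Tensor) -> torch.Tensor:
    """Cross-entropy over draft tokens, skipping misaligned positions.

    draft_logits: (seq_len, vocab_size) logits from the draft model
    target_ids:   (seq_len,) target token ids from the base model
    align_mask:   (seq_len,) bool, False where a token is judged misaligned
    """
    per_token = F.cross_entropy(draft_logits, target_ids, reduction="none")
    per_token = per_token * align_mask.float()  # zero out misaligned tokens
    # Average only over the tokens that were kept.
    return per_token.sum() / align_mask.float().sum().clamp(min=1.0)
```

Dropping a badly misaligned token from the average keeps one noisy position from dominating the gradient of an otherwise well-aligned sequence.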
Code
The official implementation and training details can be found on the GitHub repository: https://github.com/hsj576/GRIFFIN
Overview
GRIFFIN is a novel framework designed to address token misalignment in speculative decoding. This repository provides the implementation of GRIFFIN, including its token-alignable training strategy and token-alignable draft model.
GRIFFIN is:
- 4.2x faster than vanilla decoding.
- 1.3x faster than EAGLE-2.
Acceleration demo of GRIFFIN for LLaMA3-8B on a single RTX 4090 GPU.
GRIFFIN Weights
| Base Model | GRIFFIN on Hugging Face | Base Model | GRIFFIN on Hugging Face |
|---|---|---|---|
| Vicuna-7B-v1.5 | husj576/GRIFFIN-Vicuna-7B-v1.5 | LLaMA2-Chat 7B | husj576/GRIFFIN-llama2-chat-7B |
| LLaMA3-Instruct 8B | husj576/GRIFFIN-llama3-instruct-8B | LLaMA2-Chat 13B | husj576/GRIFFIN-llama2-chat-13B |
| LLaMA3-Instruct 70B | husj576/GRIFFIN-llama3-instruct-70B | Qwen2-Instruct 7B | husj576/GRIFFIN-qwen2-instruct-7B |
Inference (Sample Usage)
The provided inference code automatically allocates model weights across multiple GPUs, allowing you to run models that exceed the memory of a single GPU.
You can use our `eagenerate` function for accelerated generation just like `generate` from Hugging Face Transformers. Here is an example:
```python
from model.ea_model_griffin import EaModel
from fastchat.model import get_conversation_template
import torch

base_model_path = "Qwen/Qwen2-7B-Instruct"              # example base model
EAGLE_model_path = "husj576/GRIFFIN-qwen2-instruct-7B"  # example GRIFFIN draft model

model = EaModel.from_pretrained(
    base_model_path=base_model_path,
    ea_model_path=EAGLE_model_path,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto",  # shard weights across available GPUs automatically
    total_token=-1,
)
model.eval()

your_message = "Hello"
# Pick the chat template that matches your base model ("vicuna", "llama2",
# "llama3", ...). The base model above is Qwen2, so adjust accordingly.
conv = get_conversation_template("llama3")
conv.append_message(conv.roles[0], your_message)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

input_ids = model.tokenizer([prompt]).input_ids
input_ids = torch.as_tensor(input_ids).cuda()
output_ids = model.eagenerate(input_ids, temperature=0.5, max_new_tokens=512)
output = model.tokenizer.decode(output_ids[0])
print(output)
```
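To sanity-check the reported speedups on your own hardware, you can time `eagenerate` against the base model's vanilla `generate` on the same prompt. The helper below is a generic wall-clock comparison we sketch here for convenience; the `model.base_model.generate` attribute in the commented usage is an assumption about the `EaModel` internals and should be checked against the repository.

```python
import time

def measure_speedup(baseline_fn, fast_fn, n_runs=3):
    """Return the wall-clock speedup of fast_fn over baseline_fn (best of n_runs)."""
    def best_time(fn):
        times = []
        for _ in range(n_runs):
            start = time.perf_counter()
            fn()  # run the generation (or any workload) once
            times.append(time.perf_counter() - start)
        return min(times)  # best-of-n reduces warm-up and scheduling noise
    return best_time(baseline_fn) / best_time(fast_fn)

# Hypothetical usage with the model from the example above (not run here):
# speedup = measure_speedup(
#     lambda: model.base_model.generate(input_ids, max_new_tokens=512),
#     lambda: model.eagenerate(input_ids, temperature=0.5, max_new_tokens=512),
# )
```

Taking the best of several runs rather than the mean avoids counting one-time costs such as CUDA kernel compilation against either method.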
Note: Vicuna, LLaMA2-Chat, and LLaMA3-Instruct are all chat models. You need to use the correct chat template; otherwise the model may produce abnormal output, which degrades GRIFFIN's performance.
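One way to avoid passing the wrong template is to derive the template name from the base model path. The mapping below is a sketch we add for illustration; the exact template names registered by fastchat vary between versions, so verify each name against the conversation-template registry of your installed `fschat` before relying on it.

```python
def chat_template_name(base_model_path: str) -> str:
    """Map a base model path to a fastchat conversation-template name.

    The template names below are assumptions; check them against the
    template registry of your installed fastchat version.
    """
    path = base_model_path.lower()
    if "vicuna" in path:
        return "vicuna_v1.1"
    if "llama-3" in path or "llama3" in path:
        return "llama3"
    if "llama-2" in path or "llama2" in path:
        return "llama-2"
    if "qwen" in path:
        return "qwen-7b-chat"
    raise ValueError(f"no known chat template for {base_model_path!r}")

# e.g.  conv = get_conversation_template(chat_template_name(base_model_path))
```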
Citation
If you find our work helpful or inspiring, please feel free to cite it.
```bibtex
@misc{hu2025griffineffectivetokenalignment,
      title={GRIFFIN: Effective Token Alignment for Faster Speculative Decoding},
      author={Shijing Hu and Jingyang Li and Xingyu Xie and Zhihui Lu and Kim-Chuan Toh and Pan Zhou},
      year={2025},
      eprint={2502.11018},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2502.11018},
}
```