GRIFFIN: Effective Token Alignment for Faster Speculative Decoding

This repository contains the GRIFFIN model, a novel framework for accelerating inference in large language models, as described in the paper: GRIFFIN: Effective Token Alignment for Faster Speculative Decoding.

Abstract

Speculative decoding accelerates inference in large language models (LLMs) by generating multiple draft tokens simultaneously. However, existing methods often struggle with token misalignment between the training and decoding phases, limiting their performance. To address this, we propose GRIFFIN, a novel framework that incorporates a token-alignable training strategy and a token-alignable draft model to mitigate misalignment. The training strategy employs a loss-masking mechanism to exclude highly misaligned tokens during training, preventing them from negatively impacting the draft model's optimization. The token-alignable draft model introduces input tokens to correct inconsistencies in generated features. Experiments on LLaMA, Vicuna, Qwen, and Mixtral models demonstrate that GRIFFIN achieves an average acceptance-length improvement of over 8% and a speedup ratio exceeding 7%, outperforming current state-of-the-art speculative decoding methods. Our code and GRIFFIN's draft models are released publicly at https://github.com/hsj576/GRIFFIN.
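
The loss-masking idea from the training strategy can be illustrated with a minimal sketch (pure Python, not the authors' implementation): per-token losses for highly misaligned draft tokens are excluded from the objective so they do not distort the draft model's optimization. The function name, the alignment-score representation, and the threshold are all hypothetical.

```python
def masked_draft_loss(token_losses, alignment_scores, threshold=0.5):
    """Average per-token loss, excluding highly misaligned tokens.

    token_losses: per-token training losses of the draft model.
    alignment_scores: one score in [0, 1] per token; low means the token
    is misaligned with the target model (representation is illustrative).
    """
    kept = [loss for loss, score in zip(token_losses, alignment_scores)
            if score >= threshold]
    if not kept:  # every token masked: contribute no loss at all
        return 0.0
    return sum(kept) / len(kept)

# The second token (loss 8.0, alignment 0.1) is masked out,
# so only the well-aligned losses 2.0 and 1.0 are averaged.
print(masked_draft_loss([2.0, 8.0, 1.0], [0.9, 0.1, 0.8]))  # → 1.5
```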

Code

The official implementation and training details can be found on the GitHub repository: https://github.com/hsj576/GRIFFIN

Overview

GRIFFIN is a novel framework designed to address token misalignment in speculative decoding. This repository provides the implementation of GRIFFIN, including its token-alignable training strategy and token-alignable draft model.

  • GRIFFIN is:
    • 4.2x faster than vanilla decoding.
    • 1.3x faster than EAGLE-2.
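
For context on the metric behind these numbers, the verification step of speculative decoding can be sketched as a toy greedy loop (not GRIFFIN's actual implementation): the draft model proposes several tokens, the target model checks them, and the acceptance length is the number of leading draft tokens the target agrees with. Longer acceptance lengths mean fewer target-model forward passes and hence more speedup.

```python
def acceptance_length(draft_tokens, target_tokens):
    """Count how many leading draft tokens the target model would also
    emit (toy greedy verification; real systems also handle sampling)."""
    n = 0
    for drafted, verified in zip(draft_tokens, target_tokens):
        if drafted != verified:
            break  # first mismatch: the rest of the draft is discarded
        n += 1
    return n

# The draft proposes 5 tokens; the target agrees with the first 3,
# so 3 tokens are accepted in a single verification pass.
print(acceptance_length([4, 7, 9, 2, 5], [4, 7, 9, 8, 5]))  # → 3
```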

Acceleration demo of GRIFFIN for LLaMA3-8B on an RTX 4090 GPU:


GRIFFIN Weights

GRIFFIN draft-model weights on Hugging Face, by base model:

  • Vicuna-7B-v1.5: husj576/GRIFFIN-Vicuna-7B-v1.5
  • LLaMA2-Chat 7B: husj576/GRIFFIN-llama2-chat-7B
  • LLaMA2-Chat 13B: husj576/GRIFFIN-llama2-chat-13B
  • LLaMA3-Instruct 8B: husj576/GRIFFIN-llama3-instruct-8B
  • LLaMA3-Instruct 70B: husj576/GRIFFIN-llama3-instruct-70B
  • Qwen2-Instruct 7B: husj576/GRIFFIN-qwen2-instruct-7B

Inference (Sample Usage)

The provided inference code automatically allocates model weights across multiple GPUs, allowing you to run models that exceed the memory of a single GPU.

You can use the provided eagenerate function for accelerated generation, just like the generate method from Hugging Face Transformers. Here is an example:

from model.ea_model_griffin import EaModel
from fastchat.model import get_conversation_template
import torch

base_model_path = "Qwen/Qwen2-7B-Instruct"              # example base model
EAGLE_model_path = "husj576/GRIFFIN-qwen2-instruct-7B"  # matching GRIFFIN draft model

model = EaModel.from_pretrained(
    base_model_path=base_model_path,
    ea_model_path=EAGLE_model_path,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto",  # shard weights across available GPUs automatically
    total_token=-1,
)
model.eval()

your_message = "Hello"
# Choose the template matching your base model (e.g. "vicuna", "llama2", "llama3")
conv = get_conversation_template("llama3")
conv.append_message(conv.roles[0], your_message)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

input_ids = model.tokenizer([prompt]).input_ids
input_ids = torch.as_tensor(input_ids).cuda()
output_ids = model.eagenerate(input_ids, temperature=0.5, max_new_tokens=512)
output = model.tokenizer.decode(output_ids[0])
print(output)

Note: Vicuna, LLaMA2-Chat, and LLaMA3-Instruct are all chat models. You need to use the correct chat template for each; otherwise the model may produce abnormal output, which degrades GRIFFIN's performance.
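
One way to reduce template mistakes is a small helper that guesses the conversation-template name from the base-model path. This is purely illustrative: the helper and the substring-to-template mapping are assumptions, and the returned names should be checked against fastchat's template registry for your model.

```python
def pick_template(base_model_path):
    """Guess a fastchat conversation-template name from a base-model path.

    Illustrative only; verify the name against fastchat's registry
    before relying on it.
    """
    name = base_model_path.lower()
    if "vicuna" in name:
        return "vicuna"
    if "llama-3" in name or "llama3" in name:
        return "llama3"
    if "llama-2" in name or "llama2" in name:
        return "llama2"
    # Fall back to the raw path; get_conversation_template also accepts paths.
    return base_model_path

print(pick_template("meta-llama/Meta-Llama-3-8B-Instruct"))  # → llama3
```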

Citation

If you find our work helpful or inspiring, please feel free to cite it.

@misc{hu2025griffineffectivetokenalignment,
      title={GRIFFIN: Effective Token Alignment for Faster Speculative Decoding},
      author={Shijing Hu and Jingyang Li and Xingyu Xie and Zhihui Lu and Kim-Chuan Toh and Pan Zhou},
      year={2025},
      eprint={2502.11018},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2502.11018},
}