---
library_name: transformers
tags:
  - custom_generate
license: mit
pipeline_tag: text-generation
---

# LagKV Cache

## Introduction

*LagKV Cache diagram from the original paper*

LagKV is an efficient and robust KV cache compression algorithm. It uses the information of lagging tokens to compress the preceding ones, which significantly boosts compression performance with little computational overhead.
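To make the lag-relative idea concrete, here is a minimal NumPy sketch of the scoring scheme, assuming the core mechanism described above: the cache is split into fixed-size partitions, each partition's tokens are normalized by the min/max statistics of the *next* (lagging) partition, and the resulting spread is used as an importance score. The function names, the std-based score, and the always-keep rules for sink and recent tokens are illustrative simplifications, not the repository's actual implementation, which operates per attention head on real keys and values.

```python
import numpy as np

def lag_relative_scores(keys: np.ndarray, lag_size: int) -> np.ndarray:
    """Score each token by normalizing its partition against the
    min/max range of the next ("lag") partition. Hypothetical sketch."""
    n, _ = keys.shape
    n_parts = n // lag_size
    scores = np.zeros(n)
    for p in range(n_parts - 1):
        cur = keys[p * lag_size:(p + 1) * lag_size]
        lag = keys[(p + 1) * lag_size:(p + 2) * lag_size]
        lo, hi = lag.min(axis=0), lag.max(axis=0)
        # Normalize the current partition by the lag partition's range,
        # then use the spread across channels as an importance score.
        norm = (cur - lo) / (hi - lo + 1e-6)
        scores[p * lag_size:(p + 1) * lag_size] = norm.std(axis=1)
    # The last partition has no lag reference; keep recent tokens.
    scores[(n_parts - 1) * lag_size:] = np.inf
    return scores

def compress(keys: np.ndarray, lag_size: int = 4,
             lag_ratio: float = 0.5, sink_size: int = 2) -> np.ndarray:
    """Keep the first `sink_size` tokens plus the top-scoring fraction
    `lag_ratio` of the rest (in original order)."""
    scores = lag_relative_scores(keys, lag_size)
    n = keys.shape[0]
    keep = max(sink_size, int(n * lag_ratio))
    rest = np.argsort(-scores[sink_size:])[:keep - sink_size] + sink_size
    idx = np.concatenate([np.arange(sink_size), rest])
    return keys[np.sort(idx)]
```

With a 16-token toy cache and `lag_ratio=0.5`, `compress` returns 8 tokens: the sink tokens, the final (most recent) partition, and the highest-scoring tokens in between.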

Original Github

Details are in the following work:

LagKV: Lag-Relative Information of the KV Cache Tells Which Tokens Are Important

## Example usage

We can use the custom generation method in this repository like the base `generate` method from `transformers`:

```py
# requires `transformers>=4.52.0`
from transformers import AutoModelForCausalLM, AutoTokenizer

# Prepare the model, tokenizer, and model inputs
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", device_map="auto")
messages = [{"role": "user", "content": "Tell me a story about a cat."}]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

# Generate with the LagKV cache
gen_out = model.generate(
    # usual `generate` arguments
    **model_inputs,
    do_sample=False,
    max_new_tokens=100,
    return_dict_in_generate=True,
    # LagKV cache arguments (defaults: `lag_ratio=0.5, lag_size=128, lag_sink_size=16`)
    custom_generate="CMB-AI-LAB/lagkv_cache",
    trust_remote_code=True,
)
print(tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True))
assert "lagkvcache" in str(type(gen_out.past_key_values)).lower()
```