---
license: apache-2.0
datasets:
- open-r1/OpenR1-Math-220k
base_model:
- Qwen/Qwen3-8B
tags:
- math
- trimkv
- KV
- Cache
- Compression
---

> TRIM-KV is an efficient, learnable key–value (KV) eviction strategy designed to improve the efficiency of large language models (LLMs) in long-horizon inference.

The core idea behind TRIM-KV is to learn the intrinsic importance of each key–value pair at the moment it is created, which we call *token retention*, and then to decay this importance exponentially over time, mimicking standard inference with eviction.

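
As a rough illustration of this mechanism, the toy cache below gives each token a scalar retention $\beta_i \in (0, 1)$ when it is written, scores it at step $t$ as $\beta_i^{\,t-i}$, and evicts the lowest-scoring token whenever a fixed budget is exceeded. This is a minimal sketch under those assumptions, not the TRIM-KV implementation: `RetentionCache` and its `budget` parameter are hypothetical, and the $\beta$ values are hand-picked rather than learned.

```python
from dataclasses import dataclass, field


@dataclass
class RetentionCache:
    """Toy KV cache that evicts by exponentially decayed retention (hypothetical sketch)."""

    budget: int  # maximum number of cached tokens (hypothetical parameter)
    positions: list[int] = field(default_factory=list)  # step at which each cached token was created
    betas: list[float] = field(default_factory=list)  # retention in (0, 1), fixed at creation

    def scores(self, t: int) -> list[float]:
        # Decayed importance of each cached token at step t: beta_i ** (t - i).
        return [b ** (t - p) for p, b in zip(self.positions, self.betas)]

    def add(self, t: int, beta: float) -> None:
        # Write the new token, then evict the lowest-scoring token if over budget.
        # The new token scores beta ** 0 = 1, so it is never evicted on arrival.
        self.positions.append(t)
        self.betas.append(beta)
        if len(self.positions) > self.budget:
            s = self.scores(t)
            victim = min(range(len(s)), key=s.__getitem__)
            del self.positions[victim]
            del self.betas[victim]


cache = RetentionCache(budget=3)
for t, beta in enumerate([0.99, 0.5, 0.9, 0.6, 0.95]):
    cache.add(t, beta)
print(cache.positions)  # [0, 2, 4]
```

Token 1 is evicted first even though it is recent, because its low retention (0.5) decays fastest; token 0 survives because its retention is close to 1.
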
The retention score is query-agnostic and captures the long-term utility of a token. Attention scores, by contrast, are query-dependent: they capture the short-term utility for predicting the next token and are recomputed at every step, which makes them local, myopic, and highly dependent on the transient decoding state.

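
The difference shows up in a toy NumPy example with three cached 2-D keys. An attention-based policy that evicts the key with the lowest softmax weight picks a different victim for each query, whereas a retention-based policy, assuming decayed scores of the form $\beta_i^{\,t-i}$ with hand-picked $\beta_i$, picks the same victim whatever the query is. The keys, queries, and $\beta$ values are illustrative, and both helper functions are hypothetical rather than part of TRIM-KV.

```python
import numpy as np

# Toy cache of three 2-D keys, created at steps 0, 1, 2 (illustrative values).
keys = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])
created = np.array([0, 1, 2])
betas = np.array([0.9, 0.6, 0.8])  # hypothetical retention scores, fixed at creation


def attention_victim(query: np.ndarray) -> int:
    """Evict the key with the lowest softmax attention weight for this query."""
    logits = keys @ query / np.sqrt(keys.shape[1])
    weights = np.exp(logits - logits.max())
    weights /= weights.sum()
    return int(np.argmin(weights))


def retention_victim(t: int) -> int:
    """Evict the key with the lowest decayed retention beta_i ** (t - created_i)."""
    return int(np.argmin(betas ** (t - created)))


print(attention_victim(np.array([1.0, 0.0])))   # 2
print(attention_victim(np.array([-1.0, 0.0])))  # 0
print(retention_victim(t=3))                    # 1, independent of any query
```
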
<a href="https://arxiv.org/pdf/2512.03324"><img src="https://img.shields.io/badge/arxiv-2512.03324-red?style=for-the-badge"></a>

### Why TRIM-KV?

It's fast

<div align="center">
<img width="1000" alt="teaser" src="https://github.com/ngocbh/trimkv/blob/main/assets/speed.png?raw=true"/>
</div>

It's smart

<div align="center">
<img width="1000" alt="teaser" src="https://github.com/ngocbh/trimkv/blob/main/assets/performance.png?raw=true"/>
</div>

And it's interpretable

<div align="center">
<img width="1000" alt="teaser" src="https://github.com/ngocbh/trimkv/blob/main/assets/eviction.png?raw=true"/>
</div>

<div align="center">
<img width="1000" alt="teaser" src="https://github.com/ngocbh/trimkv/blob/main/assets/vis.png?raw=true"/>
</div>

---

## Getting Started

### Requirements

- Python 3.11 or higher (tested with 3.12)
- PyTorch 2.7.0 or higher (tested with 2.8.0)
- FlashAttention 2.7.2.post1 or higher (tested with 2.8.0)
- Transformers 4.57.1

```sh
pip install -r requirements.txt
```

This is a minimal set of requirements for training. Additional dependencies may be needed to run specific experiments. A full example of the environment used in our experiments is provided in [`examples/env.yaml`](examples/env.yaml).

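
To compare an existing environment against the versions listed above, a small script like the following can help. It is a convenience sketch, not part of the repository. It assumes the PyPI distribution names `torch`, `flash-attn`, and `transformers`, compares only the leading numeric release part of each version (so `2.7.2` counts as satisfying `2.7.2.post1`), and treats Transformers as an exact pin because it is listed without "or higher".

```python
import re
import sys
from importlib import metadata

# Minimum versions from the requirements list (PyPI distribution name -> minimum).
MINIMUMS = {"torch": "2.7.0", "flash-attn": "2.7.2.post1"}
# Transformers is listed as an exact version.
PINNED = {"transformers": "4.57.1"}


def release_tuple(version: str) -> tuple[int, ...]:
    """Leading numeric release, e.g. '2.7.2.post1' -> (2, 7, 2), '2.8.0+cu128' -> (2, 8, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unparseable version: {version!r}")
    return tuple(int(part) for part in match.group().split("."))


def at_least(installed: str, minimum: str) -> bool:
    """Compare release tuples, padding with zeros so '2.8' equals '2.8.0'."""
    a, b = release_tuple(installed), release_tuple(minimum)
    width = max(len(a), len(b))
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))


def check_environment() -> list[str]:
    """Return human-readable problems (an empty list if everything matches)."""
    problems = []
    if sys.version_info < (3, 11):
        problems.append(f"Python >= 3.11 required, found {sys.version.split()[0]}")
    for dist, wanted in {**MINIMUMS, **PINNED}.items():
        try:
            found = metadata.version(dist)
        except metadata.PackageNotFoundError:
            problems.append(f"{dist} is not installed")
            continue
        if dist in PINNED:
            ok = release_tuple(found) == release_tuple(wanted)
        else:
            ok = at_least(found, wanted)
        if not ok:
            relation = "==" if dist in PINNED else ">="
            problems.append(f"{dist} {relation} {wanted} required, found {found}")
    return problems


if __name__ == "__main__":
    for line in check_environment() or ["environment matches the listed requirements"]:
        print(line)
```
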
### Installation

Clone the repository and install it in editable mode:

```sh
git clone https://github.com/ngocbh/trimkv.git
cd trimkv
pip install -e .
```

---

## Quick Start

```python
import torch
from transformers import AutoTokenizer

from trimkv.cache_utils import TrimKVCache
from trimkv.models.qwen3 import TrimKVQwen3ForCausalLM

model_path = "<TrimKV model_path here>"  # e.g., a checkpoint from the Released Models table
download_from = "huggingface"  # options: "wandb", "local", "huggingface"

model = TrimKVQwen3ForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    load_trimkv_weights=True,
    download_from=download_from,
    use_cache=True,
    device_map="cuda",
)

# Configure TRIM-KV settings
model.config._attn_implementation = "flash_attention_2"
model.config.compress_memory = True
model.config.memory_size = 512
model.config.buffer_size = 128

# The tokenizer is loaded from the base model recorded in the TRIM-KV config.
tokenizer = AutoTokenizer.from_pretrained(
    model.config.base_model,
    use_fast=True,
    padding_side="left",
)

# Use model.generate as usual. TRIM-KV uses TrimKVCache under the hood,
# so pass a TrimKVCache instance to model.generate (see examples/test_qwen3.py).
```

For a runnable end-to-end example, see [`examples/test_qwen3.py`](examples/test_qwen3.py).

## Released Models

| Base Model | TRIM-KV Checkpoint | Training Datasets | Training Context Length | Training $M$ |
|------------|--------------------|-------------------|-------------------------|--------------|
| Qwen3-1.7B | [TrimKV-Qwen3-1.7B-Math](https://huggingface.co/ngocbh/TrimKV-Qwen3-1.7B-Math) | OpenR1-Math-220k | 16K | 512 |
| Qwen3-4B | [TrimKV-Qwen3-4B-Math](https://huggingface.co/ngocbh/TrimKV-Qwen3-4B-Math) | OpenR1-Math-220k | 16K | 512 |
| Qwen3-8B | [TrimKV-Qwen3-8B-Math](https://huggingface.co/ngocbh/TrimKV-Qwen3-8B-Math) | OpenR1-Math-220k | 16K | 512 |
| Qwen3-14B | [TrimKV-Qwen3-14B-Math](https://huggingface.co/ngocbh/TrimKV-Qwen3-14B-Math) | OpenR1-Math-220k | 16K | 512 |
| Qwen3-4B-Instruct-2507 | [TrimKV-Qwen3-4B-Instruct-2507](https://huggingface.co/ngocbh/TrimKV-Qwen3-4B-Instruct-2507) | Synth-Long, BookSum, Buddhi | 128K | 4096 |
| Phi-3-mini-128k-instruct | [TrimKV-Phi-3-mini-128k-instruct](https://huggingface.co/ngocbh/TrimKV-Phi-3-mini-128k-instruct) | LongAlpaca | 128K | 2048 |
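
If you select checkpoints programmatically, the table can be kept as a small lookup. The `RELEASED` dictionary and the `checkpoint_for` helper below are hypothetical conveniences transcribed from the table, not part of the `trimkv` package. The returned repo ID should presumably work as `model_path` in the Quick Start snippet with `download_from="huggingface"`; the training $M$ is returned alongside for reference only.

```python
# Released TRIM-KV checkpoints, transcribed from the table above:
# base model -> (Hugging Face repo ID, training M).
RELEASED = {
    "Qwen3-1.7B": ("ngocbh/TrimKV-Qwen3-1.7B-Math", 512),
    "Qwen3-4B": ("ngocbh/TrimKV-Qwen3-4B-Math", 512),
    "Qwen3-8B": ("ngocbh/TrimKV-Qwen3-8B-Math", 512),
    "Qwen3-14B": ("ngocbh/TrimKV-Qwen3-14B-Math", 512),
    "Qwen3-4B-Instruct-2507": ("ngocbh/TrimKV-Qwen3-4B-Instruct-2507", 4096),
    "Phi-3-mini-128k-instruct": ("ngocbh/TrimKV-Phi-3-mini-128k-instruct", 2048),
}


def checkpoint_for(base_model: str) -> tuple[str, int]:
    """Return (repo_id, training_M) for a base model, or raise a helpful KeyError."""
    try:
        return RELEASED[base_model]
    except KeyError:
        options = ", ".join(sorted(RELEASED))
        raise KeyError(f"no released TRIM-KV checkpoint for {base_model!r}; options: {options}") from None


repo_id, training_m = checkpoint_for("Qwen3-8B")
print(repo_id, training_m)  # ngocbh/TrimKV-Qwen3-8B-Math 512
```
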

---