---
license: apache-2.0
datasets:
- open-r1/OpenR1-Math-220k
base_model:
- Qwen/Qwen3-4B
tags:
- math
- trimkv
- KV
- Cache
- Compression
---

> TRIM-KV is a learnable key–value (KV) cache eviction strategy designed to improve the efficiency of large language models (LLMs) in long-horizon inference.

The core idea behind TRIM-KV is to learn the intrinsic importance of each key–value pair at creation time, which we call *token retention*, and then to decay this importance exponentially over time so that training mimics standard inference with eviction.

The retention score is query-agnostic and captures the long-term utility of tokens. This is different from attention scores, which are query-dependent: they capture the short-term utility for predicting the next token and are recomputed at every step, making them local, myopic, and highly dependent on the transient decoding state.
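
To make the idea concrete, here is a minimal, conceptual sketch in plain PyTorch (not the actual TRIM-KV implementation): each cached token carries a retention score assigned when it is created, the score is decayed exponentially with the token's age, and only the highest-scoring tokens are kept when the cache exceeds its budget. The names (`retention`, `decay_rate`, `memory_size`) are illustrative assumptions, not the library's API.

```python
# Conceptual sketch of retention-based eviction (not the actual TRIM-KV code).
import torch

def evict_by_decayed_retention(retention, creation_step, current_step,
                               decay_rate, memory_size):
    """Keep the `memory_size` tokens with the highest decayed retention.

    retention:     [seq_len] importance assigned when each KV pair was created
    creation_step: [seq_len] decoding step at which each token entered the cache
    """
    age = current_step - creation_step           # how long each token has lived
    decayed = retention * decay_rate ** age      # exponential decay over time
    k = min(memory_size, decayed.numel())
    keep = torch.topk(decayed, k=k).indices      # highest decayed retention wins
    return torch.sort(keep).values               # indices of tokens to retain
```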

<a href="https://arxiv.org/pdf/2512.03324"><img src="https://img.shields.io/badge/arxiv-2512.03324-red?style=for-the-badge"></a>

### Why TRIM-KV?

It's fast

<div align="center">
    <img width="1000" alt="teaser" src="https://github.com/ngocbh/trimkv/blob/main/assets/speed.png?raw=true"/>
</div>

It's smart

<div align="center">
    <img width="1000" alt="teaser" src="https://github.com/ngocbh/trimkv/blob/main/assets/performance.png?raw=true"/>
</div>


And it's interpretable

<div align="center">
    <img width="1000" alt="teaser" src="https://github.com/ngocbh/trimkv/blob/main/assets/eviction.png?raw=true"/>
</div>

<div align="center">
    <img width="1000" alt="teaser" src="https://github.com/ngocbh/trimkv/blob/main/assets/vis.png?raw=true"/>
</div>

---

## Getting Started

### Requirements

- Python 3.11 or higher (tested with 3.12)
- PyTorch 2.7.0 or higher (tested with 2.8.0)
- FlashAttention 2.7.2.post1 or higher (tested with 2.8.0)
- Transformers 4.57.1

```sh
pip install -r requirements.txt
```

These are the minimal requirements for training; additional dependencies may be needed to run specific experiments. We provide a full example of the environment used in our experiments in [`examples/env.yaml`](examples/env.yaml).

### Installation

Clone the repository and install it in editable mode:

```sh
git clone https://github.com/ngocbh/trimkv.git
cd trimkv
pip install -e .
```

---

## Quick Start

```python
import torch
from trimkv.models.qwen3 import TrimKVQwen3ForCausalLM
from trimkv.cache_utils import TrimKVCache
from transformers import AutoTokenizer

model_path = "<TrimKV model_path here>"
download_from = "huggingface"  # options: "wandb", "local", "huggingface"

model = TrimKVQwen3ForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    load_trimkv_weights=True,
    download_from=download_from,
    use_cache=True,
    device_map="cuda",
)

# Configure TRIM-KV settings
model.config._attn_implementation = "flash_attention_2"
model.config.compress_memory = True
model.config.memory_size = 512
model.config.buffer_size = 128

tokenizer = AutoTokenizer.from_pretrained(
    model.config.base_model,
    use_fast=True,
    padding_side="left",
)

# Use model.generate as usual.
# Note: TRIM-KV relies on TrimKVCache under the hood, so pass a TrimKVCache to model.generate.
```
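
As the note above says, generation should route the cache through `TrimKVCache`. The sketch below continues the snippet above and is hedged: the `TrimKVCache` constructor arguments are an assumption (shown here with no arguments); the end-to-end example linked below shows the exact invocation.

```python
# Hypothetical usage sketch; the TrimKVCache constructor signature is an assumption.
prompt = "Compute 12 * 13."
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

cache = TrimKVCache()  # assumed no-argument construction; see the linked example
outputs = model.generate(
    **inputs,
    past_key_values=cache,  # route the KV cache through TRIM-KV's eviction logic
    max_new_tokens=256,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```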

For a runnable end-to-end example, see [`examples/test_qwen3.py`](examples/test_qwen3.py).

## Released Models

| Base Model | TRIM-KV Checkpoint | Training Datasets | Training Context Len | Training $M$ |
|---|---|---|---|---|
| Qwen3-1.7B | [TrimKV-Qwen3-1.7B-Math](https://huggingface.co/ngocbh/TrimKV-Qwen3-1.7B-Math) | OpenR1-Math-220k | 16K | 512 |
| Qwen3-4B | [TrimKV-Qwen3-4B-Math](https://huggingface.co/ngocbh/TrimKV-Qwen3-4B-Math) | OpenR1-Math-220k | 16K | 512 |
| Qwen3-8B | [TrimKV-Qwen3-8B-Math](https://huggingface.co/ngocbh/TrimKV-Qwen3-8B-Math) | OpenR1-Math-220k | 16K | 512 |
| Qwen3-14B | [TrimKV-Qwen3-14B-Math](https://huggingface.co/ngocbh/TrimKV-Qwen3-14B-Math) | OpenR1-Math-220k | 16K | 512 |
| Qwen3-4B-Instruct-2507 | [TrimKV-Qwen3-4B-Instruct-2507](https://huggingface.co/ngocbh/TrimKV-Qwen3-4B-Instruct-2507) | Synth-Long, BookSum, Buddhi | 128K | 4096 |
| Phi-3-mini-128k-instruct | [TrimKV-Phi-3-mini-128k-instruct](https://huggingface.co/ngocbh/TrimKV-Phi-3-mini-128k-instruct) | LongAlpaca | 128K | 2048 |
| DeepSeek-R1-Distill-Llama-8B | [TrimKV-DeepSeek-R1-Distill-Llama-8B](https://huggingface.co/ngocbh/TrimKV-DeepSeek-R1-Distill-Llama-8B) | OpenR1-Math-220k | 32K | 512 |

---