---
license: apache-2.0
datasets:
- open-r1/OpenR1-Math-220k
base_model:
- Qwen/Qwen3-8B
tags:
- math
- trimkv
- KV
- Cache
- Compression
---

> TRIM-KV is a learnable key–value (KV) cache eviction strategy that improves the efficiency of large language models (LLMs) in long-horizon inference.

The core idea behind TRIM-KV is to learn the intrinsic importance of each key–value pair at creation time, which we call *token retention*, and then decay this importance exponentially over time so that training mimics standard inference with eviction.

The retention score is query-agnostic and captures the long-term utility of a token. This sets it apart from attention scores, which are query-dependent: they capture the short-term utility for predicting the next token and are recomputed at every step, making them local, myopic, and highly dependent on the transient decoding state.
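
For intuition, here is a minimal sketch of retention-based eviction. It is not the released implementation: the tensor names and the exponential-decay parameterization are assumptions for illustration only.

```python
import torch

def evict_by_retention(retention, birth_step, current_step, memory_size, decay_rate=0.01):
    """Rank KV entries by decayed retention and keep the top `memory_size`.

    retention:    (num_tokens,) importance learned when each KV pair was created
    birth_step:   (num_tokens,) decoding step at which each KV pair was created
    current_step: current decoding step

    NOTE: illustrative only; the actual scoring lives in `trimkv.cache_utils`.
    """
    age = (current_step - birth_step).float()
    # Importance decays exponentially with token age, mimicking eviction at inference.
    score = retention * torch.exp(-decay_rate * age)
    keep = torch.topk(score, k=min(memory_size, score.numel())).indices
    return torch.sort(keep).values  # keep indices in original token order
```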

### Why TRIM-KV?
23
+
24
+ It's fast
25
+
26
+ <div align="center">
27
+ <img width="1000" alt="teaser" src="https://github.com/ngocbh/trimkv/blob/main/assets/speed.png?raw=true"/>
28
+ </div>
29
+
30
+ It's smart
31
+
32
+ <div align="center">
33
+ <img width="1000" alt="teaser" src="https://github.com/ngocbh/trimkv/blob/main/assets/performance.png?raw=true"/>
34
+ </div>
35
+
36
+
37
+ And it's interpretable
38
+
39
+ <div align="center">
40
+ <img width="1000" alt="teaser" src="https://github.com/ngocbh/trimkv/blob/main/assets/eviction.png?raw=true"/>
41
+ </div>
42
+
43
+ <div align="center">
44
+ <img width="1000" alt="teaser" src="https://github.com/ngocbh/trimkv/blob/main/assets/vis.png?raw=true"/>
45
+ </div>

---

## Getting Started

### Requirements

- Python 3.11 or higher (tested with 3.12)
- PyTorch 2.7.0 or higher (tested with 2.8.0)
- FlashAttention 2.7.2.post1 or higher (tested with 2.8.0)
- Transformers 4.57.1

```sh
pip install -r requirements.txt
```

These are the minimal requirements for training; additional dependencies may be needed to run specific experiments. A full specification of the environment used in our experiments is provided in [`examples/env.yaml`](examples/env.yaml).

### Installation

Clone the repository and install it in editable mode:

```sh
git clone https://github.com/ngocbh/trimkv.git
cd trimkv
pip install -e .
```

---

## Quick Start

```python
import torch
from trimkv.models.qwen3 import TrimKVQwen3ForCausalLM
from trimkv.cache_utils import TrimKVCache
from transformers import AutoTokenizer

model_path = "<TrimKV model_path here>"
download_from = "huggingface"  # options: "wandb", "local", "huggingface"

model = TrimKVQwen3ForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    load_trimkv_weights=True,
    download_from=download_from,
    use_cache=True,
    device_map="cuda",
)

# Configure TRIM-KV settings
model.config._attn_implementation = "flash_attention_2"
model.config.compress_memory = True
model.config.memory_size = 512  # KV cache budget M under compression
model.config.buffer_size = 128  # size of the recent-token buffer

tokenizer = AutoTokenizer.from_pretrained(
    model.config.base_model,
    use_fast=True,
    padding_side="left",
)

# Use model.generate as usual.
# Note: TRIM-KV relies on TrimKVCache under the hood, so pass a TrimKVCache
# instance to model.generate.
```
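
Below is a hedged sketch of the final generation step. The `TrimKVCache` construction shown here is an assumption and may not match the actual signature; see [`examples/test_qwen3.py`](examples/test_qwen3.py) for the exact usage.

```python
# Sketch only: TrimKVCache constructor arguments are an assumption.
prompt = "Find the sum of the first 100 positive integers."
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

past_key_values = TrimKVCache()  # check the repo for any required arguments
outputs = model.generate(
    **inputs,
    max_new_tokens=256,
    past_key_values=past_key_values,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```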

For a runnable end-to-end example, see [`examples/test_qwen3.py`](examples/test_qwen3.py).

## Released Models

| Base Model | TRIM-KV Checkpoint | Training Datasets | Max Context Len | Training $M$ |
|---|---|---|---|---|
| Qwen3-1.7B | [TrimKV-Qwen3-1.7B-Math](https://huggingface.co/ngocbh/TrimKV-Qwen3-1.7B-Math) | OpenR1-Math-220k | 16K | 512 |
| Qwen3-4B | [TrimKV-Qwen3-4B-Math](https://huggingface.co/ngocbh/TrimKV-Qwen3-4B-Math) | OpenR1-Math-220k | 16K | 512 |
| Qwen3-8B | [TrimKV-Qwen3-8B-Math](https://huggingface.co/ngocbh/TrimKV-Qwen3-8B-Math) | OpenR1-Math-220k | 16K | 512 |
| Qwen3-14B | [TrimKV-Qwen3-14B-Math](https://huggingface.co/ngocbh/TrimKV-Qwen3-14B-Math) | OpenR1-Math-220k | 16K | 512 |
| Qwen3-4B-Instruct-2507 | [TrimKV-Qwen3-4B-Instruct-2507](https://huggingface.co/ngocbh/TrimKV-Qwen3-4B-Instruct-2507) | Synth-Long, BookSum, Buddhi 128K | | 4096 |
| Phi-3-mini-128k-instruct | [TrimKV-Phi-3-mini-128k-instruct](https://huggingface.co/ngocbh/TrimKV-Phi-3-mini-128k-instruct) | LongAlpaca | 128K | 2048 |
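
To try a released checkpoint, the Hugging Face repo id from the table is assumed to work directly as `model_path` in the Quick Start snippet above, for example:

```python
model_path = "ngocbh/TrimKV-Qwen3-8B-Math"  # assumed usable as-is with from_pretrained
```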

---