---
base_model:
- Qwen/Qwen3-4B
datasets:
- open-r1/OpenR1-Math-220k
license: apache-2.0
pipeline_tag: text-generation
tags:
- math
- dbtrimkv
- trimkv
- kv-cache
- compression
---
> DBTrimKV is the dynamic-budget variant of TrimKV: a single global KV budget is shared across layers and heads and reallocated on the fly, and the final projection of the retention gate is tied across layers.

This repository hosts the **DBTrimKV** retention-gate weights for `Qwen/Qwen3-4B` (trained with a 32768-token context and memory size M = 128). The base-model weights are not included: they are loaded from `Qwen/Qwen3-4B` at runtime, and the retention-gate weights from `trimkv_weights.pth` are overlaid on top.

This model was introduced in the paper [Make Each Token Count: Towards Improving Long-Context Performance with KV Cache Eviction](https://huggingface.co/papers/2605.09649).

<a href="https://arxiv.org/pdf/2512.03324"><img src="https://img.shields.io/badge/arxiv-2512.03324-red?style=for-the-badge"></a>

For the full list of released checkpoints, training recipes, and benchmark scripts, see the GitHub repository: **https://github.com/ngocbh/trimkv**.
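
To illustrate what a shared global budget changes, the toy sketch below compares a fixed per-head budget with a single budget ranked across all (layer, head) pairs. It is not the `trimkv` implementation: the function names `fixed_budget_keep` and `global_budget_keep`, the dictionary of per-token scores, and plain top-k selection are hypothetical stand-ins, and the scores are hand-picked numbers rather than retention-gate outputs.

```python
import heapq
from typing import Dict, List, Tuple

# Hypothetical input format: (layer, head) -> per-token retention scores.
Scores = Dict[Tuple[int, int], List[float]]
Keep = Dict[Tuple[int, int], List[int]]


def fixed_budget_keep(scores: Scores, per_head: int) -> Keep:
    """Keep the `per_head` highest-scoring token indices in each head independently."""
    keep: Keep = {}
    for key, s in scores.items():
        top = heapq.nlargest(per_head, range(len(s)), key=s.__getitem__)
        keep[key] = sorted(top)
    return keep


def global_budget_keep(scores: Scores, total: int) -> Keep:
    """Rank every (layer, head, token) entry together and keep the `total` best."""
    entries = [
        (score, key, idx)
        for key, s in scores.items()
        for idx, score in enumerate(s)
    ]
    best = heapq.nlargest(total, entries, key=lambda e: e[0])
    keep: Keep = {key: [] for key in scores}
    for _, key, idx in best:
        keep[key].append(idx)
    return {key: sorted(idxs) for key, idxs in keep.items()}


# Toy scores: head (0, 0) holds useful tokens, head (0, 1) mostly does not.
scores = {
    (0, 0): [0.9, 0.8, 0.7, 0.6],
    (0, 1): [0.1, 0.2, 0.05, 0.3],
}
print(fixed_budget_keep(scores, per_head=2))  # 2 + 2 = 4 entries kept
print(global_budget_keep(scores, total=4))    # same 4-entry total, shared across heads
```

With the same total of four retained entries, the global budget gives all of them to the head whose tokens score highly and none to the low-scoring head, a reallocation that a fixed per-head budget cannot make.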
## Quick start
To use this model, please install the `trimkv` library from the [GitHub repo](https://github.com/ngocbh/trimkv).
```python
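import torch
from trimkv.models.qwen3 import TrimKVQwen3ForCausalLM
from trimkv.cache_utils import PagedTrimKVCache
from transformers import AutoTokenizer

# Load the Qwen/Qwen3-4B base weights and overlay the DBTrimKV retention-gate weights.
model = TrimKVQwen3ForCausalLM.from_pretrained(
    "ngocbh/DBTrimKV-Qwen3-4B-Math",
    torch_dtype=torch.bfloat16,
    load_trimkv_weights=True,
    download_from="huggingface",
    use_cache=True,
    device_map="cuda",
)
model.config._attn_implementation = "flash_attention_2"

# The tokenizer comes from the base model.
tokenizer = AutoTokenizer.from_pretrained(
    model.config.base_model, use_fast=True, padding_side="left"
)

# Memory-bounded KV cache; memory_size matches the training memory size M = 128.
past_key_values = PagedTrimKVCache(
    num_layers=model.config.num_hidden_layers,
    num_heads=model.config.num_key_value_heads,
    max_seq_len=32768,
    memory_size=128,
    num_blocks_ratio=1.0,
    buffer_size=32,
    strategy="fixed_budget",
    device="cuda",
)

# Use as a normal HF model and pass the cache to .generate (the prompt is only an example).
messages = [{"role": "user", "content": "What is the sum of the first 100 positive integers?"}]
inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt", return_dict=True
).to(model.device)
outputs = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=1024)
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))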
```
See [`examples/test_qwen3.py`](https://github.com/ngocbh/trimkv/blob/main/examples/test_qwen3.py) in the GitHub repo for a full runnable example.
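
For a rough sense of the memory saved, the arithmetic below compares a full 32768-token KV cache with a bounded one. It assumes, as a simplification, that each KV head retains `memory_size + buffer_size` = 128 + 32 tokens on average and ignores paging overhead (`num_blocks_ratio`). The shapes used (36 layers, 8 KV heads, head dimension 128, bfloat16) are assumed values for `Qwen/Qwen3-4B`; read the real ones from `model.config` before relying on the numbers. `kv_cache_bytes` is a hypothetical helper, not part of `trimkv`.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, tokens, bytes_per_elem=2):
    """Bytes for the keys and values of one sequence (the leading 2 is K plus V)."""
    return 2 * num_layers * num_kv_heads * head_dim * tokens * bytes_per_elem


# Assumed Qwen3-4B shapes; check model.config.num_hidden_layers, etc.
NUM_LAYERS = 36
NUM_KV_HEADS = 8
HEAD_DIM = 128

full = kv_cache_bytes(NUM_LAYERS, NUM_KV_HEADS, HEAD_DIM, tokens=32768)
bounded = kv_cache_bytes(NUM_LAYERS, NUM_KV_HEADS, HEAD_DIM, tokens=128 + 32)

print(f"full cache:    {full / 2**30:.2f} GiB")
print(f"bounded cache: {bounded / 2**20:.2f} MiB")
```

Because the cost is linear in the number of retained tokens, the ratio between the two is simply 32768 / 160 under these assumptions.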
## Training details
- **Base model**: `Qwen/Qwen3-4B`
- **Variant**: **DBTrimKV** (`retention_gate=rg10`)
- **Training dataset**: `open-r1/OpenR1-Math-220k`
- **Training memory size M**: `128`
- **Training context length**: `32768`
- **Loss**: `fwkl_ntp`
- **Attention impl**: `rg_attn_flex`
## Citation
```bibtex
@article{bui2025cache,
title={Cache what lasts: Token retention for memory-bounded kv cache in llms},
author={Bui, Ngoc and Sharma, Shubham and Lamba, Simran and Mishra, Saumitra and Ying, Rex},
journal={arXiv preprint arXiv:2512.03324},
year={2025}
}
@article{bui2025make,
title={Make Each Token Count: Towards Improving Long-Context Performance with KV Cache Eviction},
author={Bui, Ngoc and Nguyen, Hieu Trung and Cohan, Arman and Ying, Rex},
journal={arXiv preprint arXiv:2605.09649},
year={2026}
}
```