DBTrimKV is the dynamic-budget variant of TrimKV: a single global KV budget is shared across layers and heads and reallocated on the fly, with the retention-gate's final projection tied across layers.
This repository hosts the DBTrimKV retention-gate weights for Qwen/Qwen3-4B (32768-token training context, M = 128). The base-model weights are not included — they are loaded from Qwen/Qwen3-4B at runtime and the retention-gate weights from trimkv_weights.pth are overlaid on top.
For the full list of released checkpoints, training recipes, and benchmark scripts, see the GitHub repository: https://github.com/ngocbh/trimkv.
## Quick start
```python
import torch
from transformers import AutoTokenizer

from trimkv.cache_utils import PagedTrimKVCache
from trimkv.models.qwen3 import TrimKVQwen3ForCausalLM

# Load the Qwen/Qwen3-4B base weights and overlay the retention-gate
# weights from this repository (trimkv_weights.pth).
model = TrimKVQwen3ForCausalLM.from_pretrained(
    "ngocbh/DBTrimKV-Qwen3-4B-Math",
    torch_dtype=torch.bfloat16,
    load_trimkv_weights=True,
    download_from="huggingface",
    use_cache=True,
    device_map="cuda",
)
model.config._attn_implementation = "flash_attention_2"

tokenizer = AutoTokenizer.from_pretrained(
    model.config.base_model, use_fast=True, padding_side="left"
)

# KV cache sized to the training configuration (context 32768, M = 128).
past_key_values = PagedTrimKVCache(
    num_layers=model.config.num_hidden_layers,
    num_heads=model.config.num_key_value_heads,
    max_seq_len=32768,
    memory_size=128,
    num_blocks_ratio=1.0,
    buffer_size=32,
    strategy="fixed_budget",
    device="cuda",
)

# Use as a normal HF model: pass `past_key_values=past_key_values` to .generate.
```
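As a minimal usage sketch (the prompt and decoding settings are illustrative and not from the card; only the `past_key_values` pass-through is what the comment above prescribes):

```python
# Illustrative prompt and decoding settings; the only prescribed part is
# routing the KV cache into .generate via past_key_values.
inputs = tokenizer("Solve: 12 * 7 = ?", return_tensors="pt").to("cuda")
outputs = model.generate(
    **inputs,
    past_key_values=past_key_values,  # use the TrimKV cache during decoding
    max_new_tokens=256,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```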
See `examples/test_qwen3.py` in the GitHub repo for a full, runnable example.
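For reference, `load_trimkv_weights=True` above roughly amounts to fetching `trimkv_weights.pth` from this repository and overlaying it on the base model. A minimal sketch of that step, assuming the file is a plain `state_dict` whose keys match the retention-gate modules (that keying is an assumption, not something the card confirms):

```python
import torch
from huggingface_hub import hf_hub_download

# Fetch the retention-gate overlay hosted in this repository.
gate_path = hf_hub_download(
    repo_id="ngocbh/DBTrimKV-Qwen3-4B-Math",
    filename="trimkv_weights.pth",
)
gate_state = torch.load(gate_path, map_location="cpu")

# strict=False because the file holds only the gate tensors; everything else
# comes from the Qwen/Qwen3-4B base checkpoint loaded above.
missing, unexpected = model.load_state_dict(gate_state, strict=False)
```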
## Training details
- Base model: Qwen/Qwen3-4B
- Variant: DBTrimKV (`retention_gate=rg10`)
- Training dataset: open-r1/OpenR1-Math-220k
- Training memory size M: 128
- Training context length: 32768
- Loss: `fwkl_ntp`
- Attention impl: `rg_attn_flex`
## Citation
For the up-to-date BibTeX entry, see the GitHub repository.