ngocbh committed on
Commit
fe37a5f
·
verified ·
1 Parent(s): dea95bb

Upload DBTrimKV checkpoint (Qwen3-4B, OpenR1-Math-220k)

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -33
  2. README.md +73 -0
  3. config.json +87 -0
  4. trimkv_weights.pth +3 -0
.gitattributes CHANGED
@@ -1,35 +1,3 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
  *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  *.pth filter=lfs diff=lfs merge=lfs -text
2
+ *.bin filter=lfs diff=lfs merge=lfs -text
3
  *.safetensors filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
README.md ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ datasets:
4
+ - open-r1/OpenR1-Math-220k
5
+ base_model:
6
+ - Qwen/Qwen3-4B
7
+ tags:
8
+ - math
9
+ - dbtrimkv
10
+ - trimkv
11
+ - kv-cache
12
+ - compression
13
+ ---
14
+
15
+ > DBTrimKV is the dynamic-budget variant of TrimKV: a single global KV budget is shared across all layers and heads and reallocated on the fly, and the retention gate's final projection is tied across layers.
16
+
17
+ This repository hosts the **DBTrimKV** retention-gate weights for `Qwen/Qwen3-4B` (32768-token training context, M = 128). The base-model weights are not included: they are loaded from `Qwen/Qwen3-4B` at runtime, and the retention-gate weights from `trimkv_weights.pth` are then overlaid on top.
18
+
19
+ <a href="https://arxiv.org/pdf/2512.03324"><img src="https://img.shields.io/badge/arxiv-2512.03324-red?style=for-the-badge"></a>
20
+
21
+ For the full list of released checkpoints, training recipes, and benchmark scripts, see the GitHub repository: **https://github.com/ngocbh/trimkv**.
22
+
23
+ ## Quick start
24
+
25
+ ```python
26
+ import torch
27
+ from trimkv.models.qwen3 import TrimKVQwen3ForCausalLM
28
+ from trimkv.cache_utils import PagedTrimKVCache
29
+ from transformers import AutoTokenizer
30
+
31
+ model = TrimKVQwen3ForCausalLM.from_pretrained(
32
+ "ngocbh/DBTrimKV-Qwen3-4B-Math",
33
+ torch_dtype=torch.bfloat16,
34
+ load_trimkv_weights=True,
35
+ download_from="huggingface",
36
+ use_cache=True,
37
+ device_map="cuda",
38
+ )
39
+ model.config._attn_implementation = "flash_attention_2"
40
+
41
+ tokenizer = AutoTokenizer.from_pretrained(
42
+ model.config.base_model, use_fast=True, padding_side="left"
43
+ )
44
+
45
+ past_key_values = PagedTrimKVCache(
46
+ num_layers=model.config.num_hidden_layers,
47
+ num_heads=model.config.num_key_value_heads,
48
+ max_seq_len=32768,
49
+ memory_size=128,
50
+ num_blocks_ratio=1.0,
51
+ buffer_size=32,
52
+ strategy="fixed_budget",
53
+ device="cuda",
54
+ )
55
+
56
+ # Use as a normal HF model — pass `past_key_values=past_key_values` to `model.generate(...)`
57
+ ```
58
+
59
+ See [`examples/test_qwen3.py`](https://github.com/ngocbh/trimkv/blob/main/examples/test_qwen3.py) in the GitHub repo for a full runnable example.
60
+
61
+ ## Training details
62
+
63
+ - Base model: `Qwen/Qwen3-4B`
64
+ - Variant: **DBTrimKV** (`retention_gate=rg10`)
65
+ - Training dataset: open-r1/OpenR1-Math-220k
66
+ - Training memory size M: `128`
67
+ - Training context length: `32768`
68
+ - Loss: `fwkl_ntp`
69
+ - Attention impl: `rg_attn_flex`
70
+
71
+ ## Citation
72
+
73
+ For the up-to-date BibTeX entry, see the [GitHub repository](https://github.com/ngocbh/trimkv).
config.json ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alpha_threshold": 0.0,
3
+ "architectures": [
4
+ "TrimKVQwen3ForCausalLM"
5
+ ],
6
+ "attention_bias": false,
7
+ "attention_dropout": 0.0,
8
+ "attn_impl": "rg_attn_flex",
9
+ "base_loss": "fwkl_ntp",
10
+ "base_model": "Qwen/Qwen3-4B",
11
+ "bos_token_id": 151643,
12
+ "buffer_size": 128,
13
+ "compress_memory": true,
14
+ "compress_strategy": "alpha",
15
+ "dtype": "bfloat16",
16
+ "eos_token_id": 151645,
17
+ "floor_budget_ratio": 0.0,
18
+ "global_capacity": true,
19
+ "head_dim": 128,
20
+ "hidden_act": "silu",
21
+ "hidden_size": 2560,
22
+ "initializer_range": 0.02,
23
+ "intermediate_size": 9728,
24
+ "layer_types": [
25
+ "full_attention",
26
+ "full_attention",
27
+ "full_attention",
28
+ "full_attention",
29
+ "full_attention",
30
+ "full_attention",
31
+ "full_attention",
32
+ "full_attention",
33
+ "full_attention",
34
+ "full_attention",
35
+ "full_attention",
36
+ "full_attention",
37
+ "full_attention",
38
+ "full_attention",
39
+ "full_attention",
40
+ "full_attention",
41
+ "full_attention",
42
+ "full_attention",
43
+ "full_attention",
44
+ "full_attention",
45
+ "full_attention",
46
+ "full_attention",
47
+ "full_attention",
48
+ "full_attention",
49
+ "full_attention",
50
+ "full_attention",
51
+ "full_attention",
52
+ "full_attention",
53
+ "full_attention",
54
+ "full_attention",
55
+ "full_attention",
56
+ "full_attention",
57
+ "full_attention",
58
+ "full_attention",
59
+ "full_attention",
60
+ "full_attention"
61
+ ],
62
+ "logit_block_size": 16384,
63
+ "max_position_embeddings": 40960,
64
+ "max_seq_len": 32768,
65
+ "max_window_layers": 36,
66
+ "memory_size": 128.0,
67
+ "model_type": "qwen3",
68
+ "num_attention_heads": 32,
69
+ "num_hidden_layers": 36,
70
+ "num_key_value_heads": 8,
71
+ "retention_gate": "rg10",
72
+ "retention_gate_bias_init": 18.0,
73
+ "retention_gate_intermediate_size": 512,
74
+ "retention_weight": 1.0,
75
+ "rg_dropout": 0.0,
76
+ "rms_norm_eps": 1e-06,
77
+ "rope_scaling": null,
78
+ "rope_theta": 1000000,
79
+ "sliding_window": null,
80
+ "tie_retention_gate_layers": true,
81
+ "tie_word_embeddings": true,
82
+ "trainable_params": "self_attn.retention_gate",
83
+ "transformers_version": "4.57.1",
84
+ "use_cache": false,
85
+ "use_sliding_window": false,
86
+ "vocab_size": 151936
87
+ }
trimkv_weights.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b50190691b4b7621917777fe087ea8c9caae917cac07fd873f28b0285c2a70ef
3
+ size 113361053