harshithsaiv committed on
Commit
abfc070
·
1 Parent(s): 0774ec2

feat: Update Readme

Files changed (1)
  1. README.md +27 -72
README.md CHANGED
@@ -1,53 +1,5 @@
- # Per-Head Mixed-Precision KV Cache Compression
-
- > Calibrate once. Compress smarter. Same quality.
-
- Most KV cache quantization treats every attention head equally.
- This is wrong. Some heads are **26x more sensitive** to quantization than others.
- We measure this, allocate bits accordingly, and get better compression than uniform 8-bit with zero quality loss.
-
- ---
-
- ## Results
-
- ![Memory vs Context](figures/memory_vs_context_both.png)
-
- ![Compression](figures/compression_bar_both.png)
-
- | Model | Method | Avg Bits | KV @ 8K | vs FP16 | vs 8-bit | Perplexity | Speed |
- |-------|--------|----------|---------|---------|---------|------------|-------|
- | Mistral-7B | FP16 Baseline | 16 | 1073 MB | 1.0x | - | 14.23 | 37.2 t/s |
- | Mistral-7B | Uniform 8-bit | 8 | 537 MB | 2.0x | 1.0x | ~same | ~same |
- | Mistral-7B | **Per-Head Mixed (Ours)** | **6.95** | **467 MB** | **2.3x** | **1.15x** | **14.23** | **37.2 t/s** |
- | Llama-3-8B | FP16 Baseline | 16 | 1073 MB | 1.0x | - | 20.7 | 36.7 t/s |
- | Llama-3-8B | Uniform 8-bit | 8 | 537 MB | 2.0x | 1.0x | ~same | ~same |
- | Llama-3-8B | **Per-Head Mixed (Ours)** | **7.84** | **526 MB** | **2.04x** | **1.02x** | **20.7** | **36.7 t/s** |
-
- ---
-
- ## Long Context Results
-
- ![Long Context](figures/long_context_both.png)
-
- ![OOM Story](figures/oom_story.png)
-
- | Context | FP16 (Mistral) | Ours (Mistral) | FP16 (Llama) | Ours (Llama) |
- |---------|---------------|----------------|--------------|--------------|
- | 8K | 1,074 MB | 467 MB | 1,074 MB | 526 MB |
- | 16K | 2,147 MB | 933 MB | 2,147 MB | 1,053 MB |
- | 32K | 4,295 MB | 1,866 MB | OOM | ~2,106 MB |
-
- Llama-3-8B FP16 runs out of memory at 32K context. Our method fits.
-
- ---
-
- ## The Key Insight
-
- ![Sensitivity Heatmap](figures/mistral-7b_sensitivity_heatmap.png)
-
- Each cell is one attention head. Darker means more sensitive, which means it needs higher precision.
- The variance is massive: heads in the same layer need completely different treatment.
- Uniform quantization ignores this entirely.

  ---

@@ -55,18 +7,21 @@ Uniform quantization ignores this entirely.

  **Step 1 - Calibrate (once, ~20 minutes)**

- Run 256 WikiText samples through the model. For each attention head measure reconstruction error at 4-bit and 8-bit. Save the optimal bit allocation to a JSON file (~1KB).

  **Step 2 - Compress (every inference)**

- Load the bit allocation. Quantize each head to its optimal precision. 4-bit heads use half the memory. 8-bit heads stay accurate.

  **Step 3 - Results**

- - 2.3x memory reduction on Mistral-7B
  - 2.04x memory reduction on Llama-3-8B
  - Zero perplexity degradation on both models
  - Same decode speed at 37 tokens/sec

  ---
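
Note on Step 1: the README text in this hunk says calibration measures per-head reconstruction error at 4-bit and 8-bit and saves a bit allocation to JSON, but the diff does not show how `calibrate.py` does it. The following is a minimal sketch of one way such a calibration could work, not the repository's implementation. The symmetric per-head quantizer, the relative squared-error metric, the `threshold` value, and all function names are assumptions made for illustration.

```python
import json
import numpy as np

def fake_quantize(x: np.ndarray, bits: int) -> np.ndarray:
    """Round-trip x through symmetric quantization with one scale per head."""
    qmax = 2 ** (bits - 1) - 1          # 7 for 4-bit, 127 for 8-bit
    scale = float(np.abs(x).max()) / qmax
    if scale == 0.0:
        return x.copy()
    q = np.clip(np.round(x / scale), -qmax, qmax)
    return q * scale

def relative_error(x: np.ndarray, bits: int) -> float:
    """Squared reconstruction error relative to the signal energy."""
    energy = float(np.sum(x ** 2))
    if energy == 0.0:
        return 0.0
    return float(np.sum((x - fake_quantize(x, bits)) ** 2)) / energy

def allocate_bits(kv: np.ndarray, threshold: float = 0.01) -> dict:
    """kv has shape (layers, heads, tokens, head_dim).

    The threshold is a hypothetical tolerance: a head keeps 4 bits only
    if its 4-bit error stays at or below it, otherwise it gets 8 bits.
    """
    allocation = {}
    for layer in range(kv.shape[0]):
        for head in range(kv.shape[1]):
            err4 = relative_error(kv[layer, head], 4)
            err8 = relative_error(kv[layer, head], 8)
            bits = 4 if err4 <= threshold else 8
            allocation[f"{layer}.{head}"] = {"bits": bits, "err4": err4, "err8": err8}
    return allocation

# Toy data: head 0 holds values that a 4-bit grid represents exactly,
# head 1 has a single outlier that forces a coarse 4-bit scale.
easy = np.tile(np.arange(-7.0, 8.0), 4).reshape(15, 4)
hard = np.ones((15, 4))
hard[0, 0] = 20.0
kv = np.stack([easy, hard])[None]       # shape (1, 2, 15, 4)

allocation = allocate_bits(kv)
print(json.dumps({key: entry["bits"] for key, entry in allocation.items()}))
```

On this toy input the first head is assigned 4 bits and the outlier-dominated head 8 bits. A real calibration would average the errors over the calibration samples rather than a single tensor.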
 
@@ -94,11 +49,11 @@ Run full pipeline:

  Run step by step:

- make baseline MODEL=mistral-7b
- make calibrate MODEL=mistral-7b
- make integrate MODEL=mistral-7b
- make benchmark MODEL=mistral-7b
- make benchmark-long MODEL=mistral-7b
  make visualize

  ---
@@ -107,12 +62,13 @@ Run step by step:

  kv-cache-compression/
  ├── kernel/
- │   └── quant_cache.py  mixed-precision quantize/dequantize
  ├── scripts/
  │   ├── baseline.py  FP16 baseline measurements
  │   ├── calibrate.py  per-head sensitivity calibration
- │   ├── integrate.py  quantized inference integration
- │   ├── benchmark.py  full benchmark suite
  │   ├── benchmark_long_context.py  16K/32K context benchmarks
  │   ├── visualize_results.py  benchmark graphs
  │   ├── visualize_long_context.py  long context graphs
@@ -122,8 +78,8 @@ Run step by step:
  │   ├── run_mistral.sh  full Mistral pipeline
  │   └── run_llama.sh  full Llama pipeline
  ├── results/
- │   ├── mistral-7b/  baseline, calibration, benchmark
- │   └── llama-3-8b/  baseline, calibration, benchmark
  ├── figures/  all generated graphs
  ├── requirements.txt  pip dependencies
  ├── Makefile  one-command pipeline
@@ -143,28 +99,27 @@ Run step by step:

  ## Limitations

- - Current 4-bit implementation stores values in uint8 which wastes half the space. True bit-packing via Triton kernel is in progress on the triton-kernel branch.
- - Calibration uses WikiText-2. Domain-specific calibration may improve results for specialized use cases.
- - Tested on 7-8B models only. Larger models need validation.
  - Integration is HuggingFace only. vLLM integration is planned.

- <!-- ---

  ## What's Next

- - True Triton 4-bit bit-packing kernel (triton-kernel branch)
  - vLLM PagedAttention integration
  - 32K and 128K context experiments
  - Llama-3-70B and Qwen-72B validation
  - Dynamic per-token bit allocation at decode time
- - ArXiv paper with full evaluation -->

- <!-- ---

- ## Citation

  @misc{kvcache-perhead-2026,
- title = {Per-Head Mixed-Precision KV Cache Compression},
  author = {Your Name},
  year = {2026},
  url = {https://github.com/YOURUSERNAME/kv-cache-compression}
 
+ This is why the Triton kernel is essential: not just faster, but the only way
+ to realize the theoretical memory savings from 4-bit quantization.

  ---
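
Note on the added intro lines: the saving they refer to comes from storing two 4-bit codes in each byte instead of one code per `uint8`. The sketch below shows that byte layout in NumPy purely for illustration. It is not the Triton kernel from `quant_cache_triton.py`, and the function names and the low-nibble-first ordering are assumptions.

```python
import numpy as np

def pack_int4(codes: np.ndarray) -> np.ndarray:
    """Pack unsigned 4-bit codes (0..15), stored one per uint8, two per byte."""
    codes = np.asarray(codes, dtype=np.uint8).ravel()
    if codes.size % 2 != 0:
        raise ValueError("expected an even number of 4-bit codes")
    if codes.max(initial=0) > 15:
        raise ValueError("codes must fit in 4 bits")
    low, high = codes[0::2], codes[1::2]
    # Even-indexed code goes in the low nibble, odd-indexed in the high nibble.
    return low | (high << 4)

def unpack_int4(packed: np.ndarray) -> np.ndarray:
    """Inverse of pack_int4: recover one code per uint8."""
    packed = np.asarray(packed, dtype=np.uint8)
    return np.stack([packed & 0x0F, packed >> 4], axis=-1).ravel()

codes = np.array([1, 15, 0, 7], dtype=np.uint8)
packed = pack_int4(codes)

print(codes.nbytes, "bytes unpacked ->", packed.nbytes, "bytes packed")
print("round trip ok:", np.array_equal(unpack_int4(packed), codes))
```

Without this step a 4-bit head occupies the same number of bytes as an 8-bit head, which matches the limitation the previous README listed for the uint8 implementation.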
 
  **Step 1 - Calibrate (once, ~20 minutes)**

+ Run 256 WikiText samples through the model. For each attention head measure
+ reconstruction error at 4-bit and 8-bit. Save optimal bit allocation to JSON (~1KB).

  **Step 2 - Compress (every inference)**

+ Load the bit allocation. Use Triton kernel to truly pack 4-bit heads (2 values per byte).
+ Keep 8-bit heads at full precision. Result: genuine memory reduction.

  **Step 3 - Results**

+ - 2.30x memory reduction on Mistral-7B (vs 2.00x for naive/uniform)
  - 2.04x memory reduction on Llama-3-8B
  - Zero perplexity degradation on both models
  - Same decode speed at 37 tokens/sec
+ - Triton kernel is 10-12% faster than naive PyTorch implementation

  ---
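
Note on the Step 3 figures: the 2.30x and 2.04x ratios are consistent with the average bit widths (6.95 and 7.84) reported in the Results table that this commit removes, and the 2.00x figure with every value occupying a full byte. The arithmetic below is a sketch under two assumptions that the diff does not state outright: every head is stored at either 4 or 8 bits, and quantization metadata such as scales is ignored. The function names are illustrative.

```python
FP16_BITS = 16

def compression_ratio(bits_per_value: float) -> float:
    """Memory reduction relative to an FP16 KV cache."""
    return FP16_BITS / bits_per_value

def four_bit_fraction(avg_bits: float) -> float:
    """Share of heads at 4-bit when every head is either 4-bit or 8-bit.

    avg = 4 * f + 8 * (1 - f)  =>  f = (8 - avg) / 4
    """
    return (8 - avg_bits) / 4

def stored_bits(avg_bits: float, packed: bool) -> float:
    """Bits actually occupied per value.

    Without packing, a 4-bit code still takes a whole uint8,
    so every head costs 8 bits regardless of its allocation.
    """
    return avg_bits if packed else 8.0

for model, avg_bits in [("Mistral-7B", 6.95), ("Llama-3-8B", 7.84)]:
    packed_ratio = compression_ratio(stored_bits(avg_bits, packed=True))
    naive_ratio = compression_ratio(stored_bits(avg_bits, packed=False))
    share = four_bit_fraction(avg_bits)
    print(f"{model}: packed {packed_ratio:.2f}x, unpacked {naive_ratio:.2f}x, "
          f"4-bit heads {share:.1%}")
```

Under these assumptions only about 4% of Llama-3-8B heads would sit at 4 bits, against roughly 26% for Mistral-7B, which fits the README's remark that Llama-3-8B compression is modest because its heads are more sensitive.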
 
  Run step by step:

+ make baseline MODEL=mistral-7b
+ make calibrate MODEL=mistral-7b
+ make integrate MODEL=mistral-7b
+ make benchmark MODEL=mistral-7b
+ make benchmark-long MODEL=mistral-7b
  make visualize

  ---
 
  kv-cache-compression/
  ├── kernel/
+ │   ├── quant_cache.py  naive uint8 implementation
+ │   └── quant_cache_triton.py  true Triton 4-bit bit-packing
  ├── scripts/
  │   ├── baseline.py  FP16 baseline measurements
  │   ├── calibrate.py  per-head sensitivity calibration
+ │   ├── integrate.py  naive vs triton comparison
+ │   ├── benchmark.py  full 4-method benchmark suite
  │   ├── benchmark_long_context.py  16K/32K context benchmarks
  │   ├── visualize_results.py  benchmark graphs
  │   ├── visualize_long_context.py  long context graphs
 
  │   ├── run_mistral.sh  full Mistral pipeline
  │   └── run_llama.sh  full Llama pipeline
  ├── results/
+ │   ├── mistral-7b/  all mistral results and JSONs
+ │   └── llama-3-8b/  all llama results and JSONs
  ├── figures/  all generated graphs
  ├── requirements.txt  pip dependencies
  ├── Makefile  one-command pipeline
 

  ## Limitations

+ - Tested on 7-8B models only. Larger models (70B+) need validation.
+ - Calibration uses WikiText-2. Domain-specific calibration may improve results.
  - Integration is HuggingFace only. vLLM integration is planned.
+ - Llama-3-8B Triton compression (2.04x) is modest due to high head sensitivity.

+ ---

  ## What's Next

  - vLLM PagedAttention integration
  - 32K and 128K context experiments
  - Llama-3-70B and Qwen-72B validation
  - Dynamic per-token bit allocation at decode time
+ - ArXiv paper with full evaluation

+ ---

+ <!-- ## Citation

  @misc{kvcache-perhead-2026,
+ title = {Per-Head Mixed-Precision KV Cache Compression with True Triton Bit-Packing},
  author = {Your Name},
  year = {2026},
  url = {https://github.com/YOURUSERNAME/kv-cache-compression}