Per-Head Mixed-Precision KV Cache Compression
Calibrate once. Pack truly. Same quality.
Most KV cache quantization treats every attention head equally. That is a mistake: some heads are 26x more sensitive to quantization than others. We measure sensitivity per head, allocate bits per head, and use a Triton kernel to truly pack 4-bit values. The result is better compression than uniform 8-bit with no measured perplexity change.
Key Finding
Simply storing 4-bit values in uint8 wastes the compression benefit entirely. True bit-packing via our Triton kernel is required to realize the theoretical savings.
- Naive uint8 storage: same memory as uniform 8-bit (2.0x vs FP16), so no benefit
- Triton true packing: genuine 2.3x compression vs FP16, measured as real savings on an actual GPU
Results
| Model | Method | KV cache @ 8K tokens | vs FP16 | vs 8-bit | Perplexity | Decode speed |
|---|---|---|---|---|---|---|
| Mistral-7B | FP16 Baseline | 1073 MB | 1.00x | n/a | 14.23 | 37.4 t/s |
| Mistral-7B | Uniform 8-bit | 537 MB | 2.00x | 1.00x | ~same | ~same |
| Mistral-7B | Naive Per-Head (uint8) | 537 MB | 2.00x | 1.00x | ~same | ~same |
| Mistral-7B | Triton True 4-bit (Ours) | 467 MB | 2.30x | 1.15x | 14.23 | 37.4 t/s |
| Llama-3-8B | FP16 Baseline | 1073 MB | 1.00x | n/a | 20.70 | 36.8 t/s |
| Llama-3-8B | Uniform 8-bit | 537 MB | 2.00x | 1.00x | ~same | ~same |
| Llama-3-8B | Naive Per-Head (uint8) | 537 MB | 2.00x | 1.00x | ~same | ~same |
| Llama-3-8B | Triton True 4-bit (Ours) | 526 MB | 2.04x | 1.02x | 20.70 | 36.8 t/s |
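The compression ratios can be related to how many heads end up at 4-bit. A back-of-the-envelope sketch, assuming every head is stored at either 4 bits (0.25 of its FP16 size) or 8 bits (0.5 of its FP16 size) and ignoring per-head scale and zero-point overhead, which the measured numbers may include:

```python
def compression_vs_fp16(frac_4bit: float) -> float:
    """Overall KV compression vs FP16 when a fraction of heads is packed 4-bit
    and the rest are 8-bit (quantization metadata ignored)."""
    size_vs_fp16 = 0.25 * frac_4bit + 0.5 * (1.0 - frac_4bit)
    return 1.0 / size_vs_fp16


def implied_frac_4bit(ratio: float) -> float:
    """Invert compression_vs_fp16: fraction of heads at 4-bit for a given ratio."""
    # 1 / ratio = 0.5 - 0.25 * f  =>  f = (0.5 - 1 / ratio) / 0.25
    return (0.5 - 1.0 / ratio) / 0.25


for model, ratio in [("Mistral-7B", 2.30), ("Llama-3-8B", 2.04)]:
    print(f"{model}: ~{implied_frac_4bit(ratio):.0%} of heads at 4-bit")
```

Under these simplifying assumptions, roughly a quarter of Mistral-7B's heads and only about 4% of Llama-3-8B's heads land at 4-bit, which fits the higher head sensitivity noted for Llama-3-8B. Any metadata overhead would push the true fractions somewhat higher; the calibration JSON holds the actual allocation.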
Long Context Results
| Context | FP16 | Naive (uint8) | Triton True 4-bit |
|---|---|---|---|
| 8K | 1,074 MB | 537 MB (2.0x) | 467 MB (2.3x) |
| 16K | 2,147 MB | 1,074 MB (2.0x) | 933 MB (2.3x) |
| 32K | 4,295 MB | 2,147 MB (2.0x) | 1,866 MB (2.3x) |
Llama-3-8B FP16 runs out of memory at 32K context. Our Triton method fits.
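The FP16 and 8-bit columns follow from the standard KV cache size formula: 2 (K and V) × layers × KV heads × head dim × tokens × bytes per value. A minimal sketch, assuming the published Mistral-7B and Llama-3-8B configs (both use 32 layers, 8 KV heads via grouped-query attention, and head dim 128), a batch of one sequence, and decimal megabytes:

```python
def kv_cache_bytes(tokens: int, layers: int = 32, kv_heads: int = 8,
                   head_dim: int = 128, bytes_per_value: float = 2.0) -> float:
    """KV cache size for one sequence: K and V for every layer and KV head."""
    return 2 * layers * kv_heads * head_dim * tokens * bytes_per_value


for ctx in (8 * 1024, 16 * 1024, 32 * 1024):
    fp16_mb = kv_cache_bytes(ctx) / 1e6                       # 2 bytes per value
    int8_mb = kv_cache_bytes(ctx, bytes_per_value=1.0) / 1e6  # 1 byte per value
    print(f"{ctx // 1024}K: FP16 {fp16_mb:,.0f} MB, 8-bit {int8_mb:,.0f} MB")
```

This reproduces the FP16 and naive columns (1,074 / 2,147 / 4,295 MB and 537 / 1,074 / 2,147 MB). The Triton column has no single bytes-per-value figure because it depends on the per-head bit allocation.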
The Key Insight
In the per-head sensitivity heatmap (generated by `scripts/visualize_sensitivity.py`), each cell is one attention head. Darker means more sensitive, so that head needs higher precision. The variance is massive: heads in the same layer need completely different treatment, and uniform quantization ignores this entirely.
Why True Bit-Packing Matters
Naive implementations store each 4-bit value in a uint8, one full byte per value. 65536 values take 65536 bytes: the same footprint as 8-bit, with no additional benefit.
Our Triton kernel truly packs two 4-bit values per byte, so 65536 values take 32768 bytes, half the 8-bit footprint. Because only the 4-bit heads are packed and the remaining heads stay at 8-bit, the whole cache shrinks by 2.3x vs FP16 on Mistral-7B, measured in actual GPU memory.
The Triton kernel is not just faster: its bit-packing is what turns the theoretical 4-bit savings into real memory savings.
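A pure-Python sketch of the packing arithmetic, as a CPU illustration only (the Triton kernel in `kernel/quant_cache_triton.py` is not reproduced here). It assumes values are already quantized to unsigned 4-bit codes in 0..15, and it places the first value of each pair in the low nibble, which is an arbitrary choice:

```python
def pack_nibbles(codes: list[int]) -> bytearray:
    """Pack unsigned 4-bit codes (0..15) two per byte; pad an odd length with 0."""
    if any(not 0 <= c <= 15 for c in codes):
        raise ValueError("4-bit codes must be in [0, 15]")
    padded = codes + [0] * (len(codes) % 2)
    out = bytearray(len(padded) // 2)
    for i in range(0, len(padded), 2):
        # low nibble = even-indexed code, high nibble = odd-indexed code
        out[i // 2] = padded[i] | (padded[i + 1] << 4)
    return out


def unpack_nibbles(packed: bytes, n: int) -> list[int]:
    """Recover the first n 4-bit codes from packed bytes."""
    codes = []
    for b in packed:
        codes.append(b & 0x0F)
        codes.append(b >> 4)
    return codes[:n]


codes = [i % 16 for i in range(65536)]
naive = bytearray(codes)          # one byte per 4-bit value
packed = pack_nibbles(codes)      # two values per byte
print(len(naive), len(packed))    # 65536 32768
assert unpack_nibbles(packed, len(codes)) == codes
```

A GPU kernel does the same shift-and-OR per byte in parallel; the memory layout, not the language, is what halves the footprint relative to uint8 storage.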
How It Works
Step 1: Calibrate once (about 20 minutes)
Run 256 WikiText-2 samples through the model. For each attention head, measure the reconstruction error at 4-bit and at 8-bit, then save the optimal bit allocation to JSON.
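A minimal sketch of the per-head decision, under stated assumptions: asymmetric min-max quantization per head, mean-squared reconstruction error normalized by the head's signal power, and a hypothetical threshold `MAX_REL_ERR_4BIT` for accepting 4-bit. `scripts/calibrate.py` may use a different quantizer, metric, or selection rule, and the synthetic head values below stand in for the K/V activations collected from WikiText-2:

```python
import json
import random


def fake_quant(values: list[float], bits: int) -> list[float]:
    """Asymmetric min-max quantization to `bits`, followed by dequantization."""
    lo, hi = min(values), max(values)
    scale = (hi - lo) / ((1 << bits) - 1) or 1.0  # guard against constant heads
    return [round((v - lo) / scale) * scale + lo for v in values]


def rel_mse(values: list[float], bits: int) -> float:
    """Reconstruction MSE at `bits`, normalized by the head's mean squared value."""
    deq = fake_quant(values, bits)
    mse = sum((a - b) ** 2 for a, b in zip(values, deq)) / len(values)
    power = sum(v * v for v in values) / len(values) or 1.0
    return mse / power


MAX_REL_ERR_4BIT = 0.05  # hypothetical threshold for accepting 4-bit


def calibrate(heads: dict[str, list[float]]) -> dict[str, dict]:
    """Measure 4-bit and 8-bit error per head and pick 4-bit when it is safe."""
    report = {}
    for name, vals in heads.items():
        err4, err8 = rel_mse(vals, 4), rel_mse(vals, 8)
        bits = 4 if err4 <= MAX_REL_ERR_4BIT else 8
        report[name] = {"bits": bits, "err4": err4, "err8": err8}
    return report


random.seed(0)
heads = {  # hypothetical head names and synthetic values
    # well-behaved head: values spread evenly around zero
    "layer0.head0": [random.gauss(0.0, 1.0) for _ in range(4096)],
    # outlier-heavy head: one large value stretches the quantization range
    "layer0.head1": [random.gauss(0.0, 1.0) for _ in range(4095)] + [200.0],
}
report = calibrate(heads)
allocation = {name: r["bits"] for name, r in report.items()}
print(json.dumps(allocation, indent=2))  # what would be saved to the JSON file
```

The outlier head shows why sensitivity varies so much: a single large value widens the min-max range, so 16 levels leave most values with a large rounding error, while 8-bit still resolves them.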
Step 2: Compress at every inference
Load the bit allocation. Use the Triton kernel to truly pack the 4-bit heads at two values per byte, and store the 8-bit heads at one byte per value.
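A minimal CPU sketch of the per-head dispatch. It assumes the allocation JSON maps a head identifier to 4 or 8 (the real file's keys and schema are not shown in this README, so the names here are hypothetical) and uses asymmetric min-max quantization as a stand-in for whatever the repository's kernels do. It is not the Triton kernel; it only shows the storage layout each head ends up in: half a byte per value for 4-bit heads, one byte for 8-bit heads, plus, in this sketch, a scale and minimum per head.

```python
import json
from dataclasses import dataclass


@dataclass
class QuantHead:
    bits: int      # 4 or 8, taken from the allocation
    n: int         # number of values stored
    lo: float      # per-head minimum (zero point)
    scale: float   # per-head quantization step
    data: bytes    # packed 4-bit codes, or one byte per 8-bit code


def compress_head(values: list[float], bits: int) -> QuantHead:
    lo, hi = min(values), max(values)
    scale = (hi - lo) / ((1 << bits) - 1) or 1.0
    codes = [round((v - lo) / scale) for v in values]
    if bits == 4:
        codes += [0] * (len(codes) % 2)  # pad to an even count
        data = bytes(codes[i] | (codes[i + 1] << 4) for i in range(0, len(codes), 2))
    elif bits == 8:
        data = bytes(codes)
    else:
        raise ValueError(f"unsupported bit width {bits}")
    return QuantHead(bits, len(values), lo, scale, data)


def decompress_head(q: QuantHead) -> list[float]:
    if q.bits == 4:
        codes = [c for b in q.data for c in (b & 0x0F, b >> 4)][: q.n]
    else:
        codes = list(q.data)
    return [c * q.scale + q.lo for c in codes]


allocation = json.loads('{"layer0.head0": 4, "layer0.head1": 8}')  # hypothetical keys
cache = {name: [0.01 * i for i in range(1024)] for name in allocation}
compressed = {name: compress_head(vals, allocation[name]) for name, vals in cache.items()}
for name, q in compressed.items():
    print(name, q.bits, len(q.data))  # 4-bit head: 512 bytes, 8-bit head: 1024 bytes
```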
Step 3: Results
- 2.30x KV cache memory reduction on Mistral-7B, vs 2.00x for the naive and uniform 8-bit methods
- 2.04x KV cache memory reduction on Llama-3-8B
- No perplexity change on either model (14.23 and 20.70, identical to FP16)
- Same decode speed as FP16, about 37 tokens per second
- Triton kernel is 10 to 12 percent faster than the naive PyTorch implementation
Quick Start
git clone https://github.com/harshithsaiv/kv-cache-compression
cd kv-cache-compression
pip install -r requirements.txt
# Download Mistral (no approval needed)
hf download mistralai/Mistral-7B-Instruct-v0.3 --local-dir ./mistral-model
# Download Llama (requires HuggingFace approval)
hf download meta-llama/Meta-Llama-3-8B-Instruct --local-dir ./llama-model
# Run full pipeline
make run-mistral
make run-llama
make run-both
# Or step by step
make baseline MODEL=mistral-7b
make calibrate MODEL=mistral-7b
make integrate MODEL=mistral-7b
make benchmark MODEL=mistral-7b
make benchmark-long MODEL=mistral-7b
make visualize
Project Structure
kv-cache-compression/
├── kernel/
│   ├── quant_cache.py              naive uint8 implementation
│   └── quant_cache_triton.py       true Triton 4-bit bit-packing
├── scripts/
│   ├── baseline.py                 FP16 baseline measurements
│   ├── calibrate.py                per-head sensitivity calibration
│   ├── integrate.py                naive vs Triton comparison
│   ├── benchmark.py                full 4-method benchmark suite
│   ├── benchmark_long_context.py   16K/32K context benchmarks
│   ├── visualize_results.py        benchmark graphs
│   ├── visualize_long_context.py   long-context graphs
│   └── visualize_sensitivity.py    heatmap generation
├── examples/
│   ├── quick_start.py              10-line usage example
│   ├── run_mistral.sh              full Mistral pipeline
│   └── run_llama.sh                full Llama pipeline
├── results/
│   ├── mistral-7b/                 all Mistral results and JSONs
│   └── llama-3-8b/                 all Llama results and JSONs
├── figures/                        all generated graphs
├── requirements.txt                pip dependencies
├── Makefile                        one-command pipeline
└── README.md
Hardware and Environment
- GPU: NVIDIA A100 SXM4 40GB
- CUDA: 13.0
- PyTorch: 2.7.0
- Triton: 3.3.0
- OS: Ubuntu 22.04
Limitations
- Tested on 7-8B models only. Larger models need validation.
- Calibration uses WikiText-2. Domain-specific data may improve results.
- Integration is HuggingFace only. vLLM integration is planned.
- Llama-3-8B compression is modest (2.04x vs 2.30x on Mistral-7B) due to higher head sensitivity.
What Is Next
- vLLM PagedAttention integration
- 32K and 128K context experiments
- Llama-3-70B and Qwen-72B validation
- Dynamic per-token bit allocation at decode time
- arXiv paper with full evaluation
License
MIT. Free to use, modify, and distribute.
Built in one week on an A100 SXM4 40GB. Questions, issues, and PRs welcome.