Per-Head Mixed-Precision KV Cache Compression

Calibrate once. Pack truly. Same quality.

Most KV cache quantization treats every attention head equally. This is wrong: some heads are 26x more sensitive to quantization than others. We measure this per-head sensitivity, allocate bits accordingly, and use a Triton kernel to truly pack 4-bit values, achieving better compression than uniform 8-bit with zero quality loss.


Key Finding

Simply storing 4-bit values in uint8 wastes the compression benefit entirely. True bit-packing via our Triton kernel is required to realize theoretical savings.

  • Naive uint8 storage: same memory as uniform 8-bit (2.0x), no benefit
  • Triton true packing: genuine 2.3x compression, real savings on actual GPU

Results

Figures: Memory vs Context; Compression by method.

| Model | Method | KV @ 8K | vs FP16 | vs 8-bit | Perplexity | Speed |
|---|---|---|---|---|---|---|
| Mistral-7B | FP16 Baseline | 1073 MB | 1.00x | n/a | 14.23 | 37.4 t/s |
| Mistral-7B | Uniform 8-bit | 537 MB | 2.00x | 1.00x | ~same | ~same |
| Mistral-7B | Naive Per-Head (uint8) | 537 MB | 2.00x | 1.00x | ~same | ~same |
| Mistral-7B | Triton True 4-bit (Ours) | 467 MB | 2.30x | 1.15x | 14.23 | 37.4 t/s |
| Llama-3-8B | FP16 Baseline | 1073 MB | 1.00x | n/a | 20.70 | 36.8 t/s |
| Llama-3-8B | Uniform 8-bit | 537 MB | 2.00x | 1.00x | ~same | ~same |
| Llama-3-8B | Naive Per-Head (uint8) | 537 MB | 2.00x | 1.00x | ~same | ~same |
| Llama-3-8B | Triton True 4-bit (Ours) | 526 MB | 2.04x | 1.02x | 20.70 | 36.8 t/s |

Long Context Results

Figures: Long Context memory; 32K Memory by method.

| Context | FP16 | Naive (uint8) | Triton True 4-bit |
|---|---|---|---|
| 8K | 1,074 MB | 537 MB (2.0x) | 467 MB (2.3x) |
| 16K | 2,147 MB | 1,074 MB (2.0x) | 933 MB (2.3x) |
| 32K | 4,295 MB | 2,147 MB (2.0x) | 1,866 MB (2.3x) |

Llama-3-8B FP16 runs out of memory at 32K context. Our Triton method fits.
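The FP16 numbers above follow directly from the cache shape. A quick sanity-check sketch, assuming the GQA configuration shared by both models (32 layers, 8 KV heads, head dim 128; these constants are an assumption, not stated in the table):

```python
# KV cache footprint: 2 tensors (K and V) x layers x kv_heads x head_dim
# x sequence length x bytes per value. Config values are assumed for
# Llama-3-8B / Mistral-7B (both use GQA with 8 KV heads, head dim 128).
def kv_cache_bytes(seq_len, layers=32, kv_heads=8, head_dim=128, bytes_per_value=2):
    return 2 * layers * kv_heads * head_dim * seq_len * bytes_per_value

for ctx in (8192, 16384, 32768):
    print(f"{ctx:>6} tokens: {kv_cache_bytes(ctx) / 1e6:,.0f} MB")
```

At 8K tokens this gives about 1,074 MB, matching the FP16 column; memory then scales linearly with context length.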


The Key Insight

Figure: Sensitivity Heatmap.

Each cell is one attention head; darker means more sensitive, i.e. it needs higher precision. The variance is massive: heads in the same layer need completely different treatment, yet uniform quantization ignores this entirely.


Why True Bit-Packing Matters

Naive implementations store 4-bit values in uint8, one full byte per value: 65,536 values take 65,536 bytes, the same footprint as 8-bit, so there is no additional benefit.

Our Triton kernel truly packs two 4-bit values per byte: 65,536 values take 32,768 bytes. Combined with the per-head bit allocation, this is what delivers the genuine 2.3x compression in actual GPU memory.

The Triton kernel is not just faster; it is the only way to realize the theoretical memory savings from 4-bit quantization.
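To make the difference concrete, here is nibble packing sketched in plain NumPy. This stand-in is illustrative only; the repository's `quant_cache_triton.py` performs the same transformation on-GPU in Triton:

```python
import numpy as np

def pack_nibbles(codes):
    # two 4-bit codes per byte: even indices go to the low nibble, odd to the high
    assert codes.dtype == np.uint8 and codes.size % 2 == 0
    return (codes[1::2] << 4) | codes[0::2]

def unpack_nibbles(packed):
    out = np.empty(packed.size * 2, dtype=np.uint8)
    out[0::2] = packed & 0x0F   # low nibble
    out[1::2] = packed >> 4     # high nibble
    return out

codes = np.arange(16, dtype=np.uint8)   # sixteen 4-bit codes
packed = pack_nibbles(codes)            # 16 values -> 8 bytes, half the storage
```

Storing `codes` directly in uint8 would cost 16 bytes; the packed form costs 8 and round-trips losslessly.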


How It Works

Step 1: Calibrate once (about 20 minutes)

Run 256 WikiText samples through the model. For each attention head, measure the reconstruction error at 4-bit and 8-bit precision. Save the optimal bit allocation to JSON.
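As a sketch of what that measurement could look like (NumPy stand-in on random data; the real `calibrate.py` operates on captured K/V activations, and the 0.05 threshold here is purely illustrative):

```python
import numpy as np

def reconstruction_error(x, bits):
    # absmax-style uniform quantization to 2^bits levels, then relative L2 error
    levels = 2 ** bits - 1
    scale = np.abs(x).max() / (levels / 2) + 1e-8
    codes = np.clip(np.round(x / scale + levels / 2), 0, levels)
    x_hat = (codes - levels / 2) * scale
    return float(np.linalg.norm(x - x_hat) / np.linalg.norm(x))

rng = np.random.default_rng(0)
head_kv = rng.standard_normal((1024, 128))   # stand-in for one head's K/V
err4 = reconstruction_error(head_kv, 4)
err8 = reconstruction_error(head_kv, 8)

# sensitive heads (large 4-bit error) get 8 bits; the rest get 4
bits = 8 if err4 > 0.05 else 4               # threshold is illustrative
```

The per-head `bits` decisions are then written to JSON so calibration never has to run again.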

Step 2: Compress at every inference

Load the saved bit allocation. Use the Triton kernel to truly pack 4-bit heads at two values per byte; store 8-bit heads at one byte per value.
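The resulting memory footprint is simple arithmetic over the allocation. A sketch with a hypothetical head-to-bits mix (the exact mix below is illustrative, not measured):

```python
# Hypothetical per-head allocation as the calibration step might save it
allocation = {0: 8, 1: 8, 2: 4, 3: 8, 4: 8, 5: 4, 6: 8, 7: 8}

def packed_bytes(n_values, bits):
    # with true packing, a 4-bit head costs half a byte per value
    return n_values * bits // 8

head_dim, seq_len = 128, 8192
n = head_dim * seq_len
total = sum(packed_bytes(n, bits) for bits in allocation.values())
fp16 = len(allocation) * n * 2                       # 2 bytes per FP16 value
print(f"compression vs FP16: {fp16 / total:.2f}x")   # lands near the 2.3x above
```

With naive uint8 storage every head would cost `n` bytes regardless of its allocation, capping compression at 2.0x; true packing is what opens the gap.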

Step 3: Results

  • 2.30x memory reduction on Mistral-7B vs 2.00x for naive and uniform methods
  • 2.04x memory reduction on Llama-3-8B
  • Zero perplexity degradation on both models
  • Same decode speed at 37 tokens per second
  • Triton kernel is 10 to 12 percent faster than naive PyTorch

Quick Start

git clone https://github.com/harshithsaiv/kv-cache-compression
cd kv-cache-compression
pip install -r requirements.txt

# Download Mistral (no approval needed)
hf download mistralai/Mistral-7B-Instruct-v0.3 --local-dir ./mistral-model

# Download Llama (requires HuggingFace approval)
hf download meta-llama/Meta-Llama-3-8B-Instruct --local-dir ./llama-model

# Run full pipeline
make run-mistral
make run-llama
make run-both

# Or step by step
make baseline       MODEL=mistral-7b
make calibrate      MODEL=mistral-7b
make integrate      MODEL=mistral-7b
make benchmark      MODEL=mistral-7b
make benchmark-long MODEL=mistral-7b
make visualize

Project Structure

kv-cache-compression/
├── kernel/
│   ├── quant_cache.py              naive uint8 implementation
│   └── quant_cache_triton.py       true Triton 4-bit bit-packing
├── scripts/
│   ├── baseline.py                 FP16 baseline measurements
│   ├── calibrate.py                per-head sensitivity calibration
│   ├── integrate.py                naive vs triton comparison
│   ├── benchmark.py                full 4-method benchmark suite
│   ├── benchmark_long_context.py   16K/32K context benchmarks
│   ├── visualize_results.py        benchmark graphs
│   ├── visualize_long_context.py   long context graphs
│   └── visualize_sensitivity.py    heatmap generation
├── examples/
│   ├── quick_start.py              10-line usage example
│   ├── run_mistral.sh              full Mistral pipeline
│   └── run_llama.sh                full Llama pipeline
├── results/
│   ├── mistral-7b/                 all mistral results and JSONs
│   └── llama-3-8b/                 all llama results and JSONs
├── figures/                        all generated graphs
├── requirements.txt                pip dependencies
├── Makefile                        one-command pipeline
└── README.md

Hardware and Environment

  • GPU: NVIDIA A100 SXM4 40GB
  • CUDA: 13.0
  • PyTorch: 2.7.0
  • Triton: 3.3.0
  • OS: Ubuntu 22.04

Limitations

  • Tested on 7-8B models only. Larger models need validation.
  • Calibration uses WikiText-2. Domain-specific data may improve results.
  • Integration is HuggingFace only. vLLM integration is planned.
  • Llama-3-8B compression is modest due to higher head sensitivity.

What's Next

  • vLLM PagedAttention integration
  • 32K and 128K context experiments
  • Llama-3-70B and Qwen-72B validation
  • Dynamic per-token bit allocation at decode time
  • ArXiv paper with full evaluation

License

MIT. Free to use, modify, and distribute.

Built in one week on an A100 SXM4 40GB. Questions, issues, and PRs welcome.
