---
license: mit
datasets:
  - Salesforce/wikitext
language:
  - en
metrics:
  - perplexity
base_model:
  - mistralai/Mistral-7B-Instruct-v0.3
  - meta-llama/Meta-Llama-3-8B-Instruct
tags:
  - quantization
  - kv-cache
  - llm-inference
  - cuda
  - triton
  - memory-efficient
  - mistral
  - llama
  - inference-optimization
  - 4-bit
  - mixed-precision
---

# Per-Head Mixed-Precision KV Cache Compression

Calibrate once. Truly pack. Same quality.

Most KV cache quantization treats every attention head equally. This is wrong: some heads are 26x more sensitive to quantization than others. We measure this, allocate bits per head, and use a Triton kernel to truly pack 4-bit values, achieving better compression than uniform 8-bit with zero quality loss.

---

## Key Finding

Simply storing 4-bit values in uint8 wastes the compression benefit entirely. True bit-packing via our Triton kernel is required to realize the theoretical savings.

- Naive uint8 storage: same memory as uniform 8-bit (2.0x), no benefit
- Triton true packing: genuine 2.3x compression, real savings in actual GPU memory

---

## Results

![Memory vs Context](memory_vs_context_4methods.png)
![Compression](compression_bar_4methods.png)

| Model | Method | KV cache @ 8K | vs FP16 | vs 8-bit | Perplexity | Speed |
|-------|--------|---------------|---------|----------|------------|-------|
| Mistral-7B | FP16 Baseline | 1073 MB | 1.00x | — | 14.23 | 37.4 t/s |
| Mistral-7B | Uniform 8-bit | 537 MB | 2.00x | 1.00x | ~same | ~same |
| Mistral-7B | Naive Per-Head (uint8) | 537 MB | 2.00x | 1.00x | ~same | ~same |
| Mistral-7B | **Triton True 4-bit (Ours)** | **467 MB** | **2.30x** | **1.15x** | **14.23** | **37.4 t/s** |
| Llama-3-8B | FP16 Baseline | 1073 MB | 1.00x | — | 20.70 | 36.8 t/s |
| Llama-3-8B | Uniform 8-bit | 537 MB | 2.00x | 1.00x | ~same | ~same |
| Llama-3-8B | Naive Per-Head (uint8) | 537 MB | 2.00x | 1.00x | ~same | ~same |
| Llama-3-8B | **Triton True 4-bit (Ours)** | **526 MB** | **2.04x** | **1.02x** | **20.70** | **36.8 t/s** |

---

## Long Context Results

![Long Context](long_context_4methods.png)
![32K Memory](memory_32k_4methods.png)

| Context | FP16 | Naive (uint8) | Triton True 4-bit |
|---------|------|---------------|-------------------|
| 8K | 1,074 MB | 537 MB (2.0x) | 467 MB (2.3x) |
| 16K | 2,147 MB | 1,074 MB (2.0x) | 933 MB (2.3x) |
| 32K | 4,295 MB | 2,147 MB (2.0x) | 1,866 MB (2.3x) |

Llama-3-8B in FP16 runs out of memory at 32K context. Our Triton method fits.

---

## The Key Insight

![Sensitivity Heatmap](mistral-7b_sensitivity_heatmap.png)

Each cell is one attention head; darker cells are more sensitive and need higher precision. The variance is massive: heads in the same layer need completely different treatment, and uniform quantization ignores this entirely.

---

## Why True Bit-Packing Matters

Naive implementations store each 4-bit value in a uint8, one full byte per value: 65,536 values take 65,536 bytes, the same footprint as 8-bit and no additional benefit.

Our Triton kernel truly packs two 4-bit values per byte: 65,536 values take 32,768 bytes. Combined with the per-head bit allocation, this is what delivers the genuine 2.3x compression in actual GPU memory.

The Triton kernel is not just faster; it is the only way to realize the theoretical memory savings of 4-bit quantization.
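To make the layout concrete, here is a minimal PyTorch sketch of nibble packing. It only illustrates the memory arithmetic and is not the repository's kernel (see `kernel/quant_cache_triton.py` for that); the names `pack_int4` and `unpack_int4` are ours, and the full-cache numbers in the comments assume the standard Mistral-7B GQA configuration.

```python
import torch

def pack_int4(codes: torch.Tensor) -> torch.Tensor:
    """Pack pairs of 4-bit codes (values 0..15) into single uint8 bytes.

    Assumes an even number of elements. Even-indexed codes go to the
    low nibble, odd-indexed codes to the high nibble.
    """
    codes = codes.contiguous().view(-1)
    low = codes[0::2]           # even-indexed codes -> low nibble
    high = codes[1::2]          # odd-indexed codes  -> high nibble
    return (high << 4) | low    # two 4-bit values per byte

def unpack_int4(packed: torch.Tensor) -> torch.Tensor:
    """Inverse of pack_int4: recover the flat sequence of 4-bit codes."""
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    out = torch.empty(packed.numel() * 2, dtype=torch.uint8,
                      device=packed.device)
    out[0::2] = low
    out[1::2] = high
    return out

# 65,536 4-bit codes: naive uint8 storage takes 65,536 bytes,
# packed storage takes 32,768 bytes -- half the footprint.
codes = torch.randint(0, 16, (65536,), dtype=torch.uint8)
packed = pack_int4(codes)
assert packed.numel() == 32768
assert torch.equal(unpack_int4(packed), codes)  # lossless round-trip

# Full-cache sanity check (assumption: Mistral-7B has 32 layers,
# 8 KV heads, head_dim 128). At 8K context in FP16:
# 2 (K and V) * 32 * 8 * 128 * 8192 * 2 bytes = 1,073,741,824 bytes,
# i.e. the 1073 MB FP16 row in the results table above.
```

The repository's Triton kernel presumably performs this same nibble arithmetic on-GPU, fused with quantize/dequantize, which is where its speedup over the naive PyTorch path comes from; the sketch above is just the layout.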
---

## How It Works

### Step 1: Calibrate once (~20 minutes)

Run 256 WikiText samples through the model. For each attention head, measure the reconstruction error at 4-bit and at 8-bit. Save the optimal bit allocation to JSON.

### Step 2: Compress at every inference

Load the bit allocation. Use the Triton kernel to truly pack 4-bit heads at two values per byte; keep sensitive heads at 8-bit precision.

### Step 3: Results

- 2.30x memory reduction on Mistral-7B, vs 2.00x for the naive and uniform methods
- 2.04x memory reduction on Llama-3-8B
- Zero perplexity degradation on both models
- Same decode speed, ~37 tokens per second
- The Triton kernel is 10 to 12 percent faster than the naive PyTorch path

---

## Quick Start

```bash
git clone https://github.com/harshithsaiv/kv-cache-compression
cd kv-cache-compression
pip install -r requirements.txt

# Download Mistral (no approval needed)
hf download mistralai/Mistral-7B-Instruct-v0.3 --local-dir ./mistral-model

# Download Llama (requires HuggingFace approval)
hf download meta-llama/Meta-Llama-3-8B-Instruct --local-dir ./llama-model

# Run the full pipeline
make run-mistral
make run-llama
make run-both

# Or step by step
make baseline MODEL=mistral-7b
make calibrate MODEL=mistral-7b
make integrate MODEL=mistral-7b
make benchmark MODEL=mistral-7b
make benchmark-long MODEL=mistral-7b
make visualize
```

---

## Project Structure

```
kv-cache-compression/
├── kernel/
│   ├── quant_cache.py              naive uint8 implementation
│   └── quant_cache_triton.py       true Triton 4-bit bit-packing
├── scripts/
│   ├── baseline.py                 FP16 baseline measurements
│   ├── calibrate.py                per-head sensitivity calibration
│   ├── integrate.py                naive vs Triton comparison
│   ├── benchmark.py                full 4-method benchmark suite
│   ├── benchmark_long_context.py   16K/32K context benchmarks
│   ├── visualize_results.py        benchmark graphs
│   ├── visualize_long_context.py   long-context graphs
│   └── visualize_sensitivity.py    heatmap generation
├── examples/
│   ├── quick_start.py              10-line usage example
│   ├── run_mistral.sh              full Mistral pipeline
│   └── run_llama.sh                full Llama pipeline
├── results/
│   ├── mistral-7b/                 all Mistral results and JSONs
│   └── llama-3-8b/                 all Llama results and JSONs
├── figures/                        all generated graphs
├── requirements.txt                pip dependencies
├── Makefile                        one-command pipeline
└── README.md
```

---

## Hardware and Environment

- GPU: NVIDIA A100 SXM4 40GB
- CUDA: 13.0
- PyTorch: 2.7.0
- Triton: 3.3.0
- OS: Ubuntu 22.04

---

## Limitations

- Tested on 7-8B models only; larger models need validation.
- Calibration uses WikiText-2; domain-specific data may improve results.
- Integration is HuggingFace-only; vLLM integration is planned.
- Llama-3-8B compression is modest (2.04x) because more of its heads are quantization-sensitive.

---

## What's Next

- vLLM PagedAttention integration
- 32K and 128K context experiments
- Llama-3-70B and Qwen-72B validation
- Dynamic per-token bit allocation at decode time
- arXiv paper with full evaluation

---

## License

MIT. Free to use, modify, and distribute.

Built in one week on an A100 SXM4 40GB. Questions, issues, and PRs welcome.