---
license: mit
datasets:
- Salesforce/wikitext
language:
- en
metrics:
- perplexity
base_model:
- mistralai/Mistral-7B-Instruct-v0.3
- meta-llama/Meta-Llama-3-8B-Instruct
tags:
- quantization
- kv-cache
- llm-inference
- cuda
- triton
- memory-efficient
- mistral
- llama
- inference-optimization
- 4-bit
- mixed-precision
---

# Per-Head Mixed-Precision KV Cache Compression

Calibrate once. Truly pack. Same quality.

Most KV cache quantization treats every attention head equally. This is wrong.
Some heads are 26x more sensitive to quantization than others. We measure this,
allocate bits per head, and use a Triton kernel to truly pack 4-bit values,
achieving better compression than uniform 8-bit with zero quality loss.

---

## Key Finding

Simply storing 4-bit values in uint8 wastes the compression benefit entirely.
True bit-packing via our Triton kernel is required to realize the theoretical savings.

- Naive uint8 storage: same memory as uniform 8-bit (2.0x), so no benefit
- Triton true packing: genuine 2.3x compression, with real savings in actual GPU memory
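
The difference is easy to see in plain PyTorch. The snippet below is illustrative only, not the repository's kernel; it just compares the byte counts of naive uint8 storage versus true packing of the same 4-bit codes.

```python
import torch

n = 65_536                                     # number of 4-bit codes

# Naive storage: one 4-bit code per uint8 element, i.e. one full byte per value
naive = torch.randint(0, 16, (n,), dtype=torch.uint8)

# True packing: two 4-bit codes per byte
packed = (naive[1::2] << 4) | naive[0::2]

print(naive.numel() * naive.element_size())    # 65536 bytes -> no better than 8-bit
print(packed.numel() * packed.element_size())  # 32768 bytes -> the real 4-bit footprint
```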

---

## Results

| Model | Method | KV @ 8K | vs FP16 | vs 8-bit | Perplexity | Speed |
|-------|--------|---------|---------|----------|------------|-------|
| Mistral-7B | FP16 Baseline | 1073 MB | 1.00x | - | 14.23 | 37.4 t/s |
| Mistral-7B | Uniform 8-bit | 537 MB | 2.00x | 1.00x | ~same | ~same |
| Mistral-7B | Naive Per-Head (uint8) | 537 MB | 2.00x | 1.00x | ~same | ~same |
| Mistral-7B | **Triton True 4-bit (Ours)** | **467 MB** | **2.30x** | **1.15x** | **14.23** | **37.4 t/s** |
| Llama-3-8B | FP16 Baseline | 1073 MB | 1.00x | - | 20.70 | 36.8 t/s |
| Llama-3-8B | Uniform 8-bit | 537 MB | 2.00x | 1.00x | ~same | ~same |
| Llama-3-8B | Naive Per-Head (uint8) | 537 MB | 2.00x | 1.00x | ~same | ~same |
| Llama-3-8B | **Triton True 4-bit (Ours)** | **526 MB** | **2.04x** | **1.02x** | **20.70** | **36.8 t/s** |
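
For reference, the FP16 column can be reproduced from the public model configs: both Mistral-7B-Instruct-v0.3 and Llama-3-8B use 32 layers and grouped-query attention with 8 KV heads of dimension 128, so their caches are the same size. A quick sanity check (not a script from this repo):

```python
layers, kv_heads, head_dim, seq_len, fp16_bytes = 32, 8, 128, 8192, 2

kv_cache = 2 * layers * kv_heads * head_dim * seq_len * fp16_bytes   # 2 = keys + values
print(kv_cache / 1e6)   # ~1073.7 MB, matching the FP16 baseline rows above
```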

---

## Long Context Results

| Context | FP16 | Naive (uint8) | Triton True 4-bit |
|---------|------|---------------|-------------------|
| 8K | 1,074 MB | 537 MB (2.0x) | 467 MB (2.3x) |
| 16K | 2,147 MB | 1,074 MB (2.0x) | 933 MB (2.3x) |
| 32K | 4,295 MB | 2,147 MB (2.0x) | 1,866 MB (2.3x) |

Llama-3-8B FP16 runs out of memory at 32K context. Our Triton method fits.

---

## The Key Insight

In the per-head sensitivity heatmap, each cell is one attention head. Darker means more
sensitive, so it needs higher precision. The variance is massive. Heads in the same layer
need completely different treatment. Uniform quantization ignores this entirely.

---

## Why True Bit-Packing Matters

Naive implementations store 4-bit values in uint8: one full byte per value.
65536 values = 65536 bytes = same compression as 8-bit, no additional benefit.

Our Triton kernel truly packs two 4-bit values per byte.
65536 values = 32768 bytes = genuine 2.3x compression in actual GPU memory.

The Triton kernel is not just faster; it is the only way to realize
the theoretical memory savings from 4-bit quantization.
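
The packing itself is simple byte arithmetic; the Triton kernel's job is to do it in parallel directly on GPU memory. Below is a minimal PyTorch sketch of the idea for a single tensor. The function names are illustrative, not the repository's API, and the real kernel fuses per-head quantization and packing.

```python
import torch

def quantize_pack_4bit(x: torch.Tensor):
    """Asymmetric 4-bit quantization followed by true packing (illustrative sketch)."""
    flat = x.flatten().float()
    assert flat.numel() % 2 == 0, "pad to an even number of elements before packing"
    zero = flat.min()
    scale = (flat.max() - zero).clamp(min=1e-8) / 15.0
    codes = torch.round((flat - zero) / scale).clamp(0, 15).to(torch.uint8)
    packed = (codes[1::2] << 4) | codes[0::2]        # two 4-bit codes per byte
    return packed, scale, zero

def unpack_dequantize_4bit(packed, scale, zero, shape):
    lo = (packed & 0x0F).to(torch.float32)
    hi = (packed >> 4).to(torch.float32)
    codes = torch.stack((lo, hi), dim=-1).flatten()  # restore the original ordering
    return (codes * scale + zero).reshape(shape)

k = torch.randn(8, 128)                              # e.g. one head's keys for 8 tokens
packed, s, z = quantize_pack_4bit(k)
k_hat = unpack_dequantize_4bit(packed, s, z, k.shape)
print(packed.numel(), "bytes for", k.numel(), "values")   # 512 bytes for 1024 values
```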

---

## How It Works

**Step 1: Calibrate once (around 20 minutes)**

Run 256 WikiText samples through the model. For each attention head, measure the
reconstruction error at 4-bit and at 8-bit. Save the optimal bit allocation to JSON.
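
A sketch of the per-head decision Step 1 makes. The error metric, the threshold, and the JSON key format here are assumptions for illustration; the repository's `scripts/calibrate.py` may choose differently.

```python
import torch

def quant_error(x: torch.Tensor, bits: int) -> float:
    """Round-trip one head's cached K/V through n-bit quantization and return the MSE."""
    levels = 2 ** bits - 1
    zero = x.min()
    scale = (x.max() - zero).clamp(min=1e-8) / levels
    x_hat = torch.round((x - zero) / scale).clamp(0, levels) * scale + zero
    return (x - x_hat).pow(2).mean().item()

def choose_bits(x: torch.Tensor, slack: float = 2.0) -> int:
    """Keep a head at 4 bits only if its error stays within `slack`x of the 8-bit error
    (this threshold is a made-up stand-in for the repo's actual selection rule)."""
    return 4 if quant_error(x, 4) <= slack * quant_error(x, 8) else 8

# During calibration: allocation[f"{layer}.{head}"] = choose_bits(head_kv)
# and the resulting dict is dumped to JSON for reuse at inference time.
```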

**Step 2: Compress on every inference**

Load the bit allocation. Use the Triton kernel to truly pack 4-bit heads
at two values per byte. Keep 8-bit heads at full 8-bit precision, one byte per value.
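
A minimal sketch of how the allocation might be applied when one layer's keys and values are cached. Shapes, the allocation keys, and the helper names are assumptions for illustration; the real implementation lives in `kernel/quant_cache_triton.py`, with the packing done inside the Triton kernel.

```python
import torch

def quantize(t: torch.Tensor, bits: int):
    """Asymmetric n-bit quantization; returns uint8 codes plus (scale, zero)."""
    levels = 2 ** bits - 1
    zero = t.min()
    scale = (t.max() - zero).clamp(min=1e-8) / levels
    codes = torch.round((t - zero) / scale).clamp(0, levels).to(torch.uint8)
    return codes, scale, zero

def compress_layer(keys, values, layer, allocation):
    """keys/values: [num_kv_heads, seq_len, head_dim]; allocation[(layer, head)] is 4 or 8."""
    out = []
    for head in range(keys.shape[0]):
        bits = allocation[(layer, head)]
        for t in (keys[head], values[head]):
            codes, scale, zero = quantize(t, bits)
            if bits == 4:                          # true packing: two 4-bit codes per byte
                flat = codes.flatten()
                codes = (flat[1::2] << 4) | flat[0::2]
            out.append((codes, scale, zero))       # 8-bit heads stay at one byte per value
    return out
```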

**Step 3: Results**

- 2.30x memory reduction on Mistral-7B vs 2.00x for naive and uniform methods (see the back-of-envelope check below)
- 2.04x memory reduction on Llama-3-8B
- Zero perplexity degradation on both models
- Same decode speed, around 37 tokens per second
- The Triton kernel is 10 to 12 percent faster than the naive PyTorch implementation
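
As a rough check (it ignores the small per-head scale and zero-point overhead), the headline ratios imply how much of the cache ended up at 4 bits, which is also why the Llama-3-8B gain is smaller:

```python
def implied_4bit_fraction(compression_vs_fp16: float) -> float:
    avg_bits = 16 / compression_vs_fp16     # average bits per cached value
    return (8 - avg_bits) / (8 - 4)         # solve avg = 4*f + 8*(1 - f) for f

print(implied_4bit_fraction(2.30))  # ~0.26 -> roughly a quarter of Mistral's cache at 4 bits
print(implied_4bit_fraction(2.04))  # ~0.04 -> Llama-3 keeps almost everything at 8 bits
```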

---

## Quick Start

```bash
git clone https://github.com/harshithsaiv/kv-cache-compression
cd kv-cache-compression
pip install -r requirements.txt

# Download Mistral (no approval needed)
hf download mistralai/Mistral-7B-Instruct-v0.3 --local-dir ./mistral-model

# Download Llama (requires HuggingFace approval)
hf download meta-llama/Meta-Llama-3-8B-Instruct --local-dir ./llama-model

# Run full pipeline
make run-mistral
make run-llama
make run-both

# Or step by step
make baseline MODEL=mistral-7b
make calibrate MODEL=mistral-7b
make integrate MODEL=mistral-7b
make benchmark MODEL=mistral-7b
make benchmark-long MODEL=mistral-7b
make visualize
```

---

## Project Structure

```
kv-cache-compression/
├── kernel/
│   ├── quant_cache.py              # naive uint8 implementation
│   └── quant_cache_triton.py       # true Triton 4-bit bit-packing
├── scripts/
│   ├── baseline.py                 # FP16 baseline measurements
│   ├── calibrate.py                # per-head sensitivity calibration
│   ├── integrate.py                # naive vs Triton comparison
│   ├── benchmark.py                # full 4-method benchmark suite
│   ├── benchmark_long_context.py   # 16K/32K context benchmarks
│   ├── visualize_results.py        # benchmark graphs
│   ├── visualize_long_context.py   # long context graphs
│   └── visualize_sensitivity.py    # heatmap generation
├── examples/
│   ├── quick_start.py              # 10-line usage example
│   ├── run_mistral.sh              # full Mistral pipeline
│   └── run_llama.sh                # full Llama pipeline
├── results/
│   ├── mistral-7b/                 # all Mistral results and JSONs
│   └── llama-3-8b/                 # all Llama results and JSONs
├── figures/                        # all generated graphs
├── requirements.txt                # pip dependencies
├── Makefile                        # one-command pipeline
└── README.md
```

---

## Hardware and Environment

- GPU: NVIDIA A100 SXM4 40GB
- CUDA: 13.0
- PyTorch: 2.7.0
- Triton: 3.3.0
- OS: Ubuntu 22.04

---

## Limitations

- Tested on 7-8B models only. Larger models need validation.
- Calibration uses WikiText-2. Domain-specific data may improve results.
- Integration is HuggingFace Transformers only. vLLM integration is planned.
- Llama-3-8B compression is modest due to higher head sensitivity.

---

## What Is Next

- vLLM PagedAttention integration
- 32K and 128K context experiments
- Llama-3-70B and Qwen-72B validation
- Dynamic per-token bit allocation at decode time
- ArXiv paper with full evaluation

<!-- ---

## Citation

@misc{kvcache-perhead-2026,
  title  = {Per-Head Mixed-Precision KV Cache Compression with True Triton Bit-Packing},
  author = {Your Name},
  year   = {2026},
  url    = {https://github.com/YOURUSERNAME/kv-cache-compression}
}

--- -->

## License

MIT. Free to use, modify, and distribute.

Built in one week on an A100 SXM4 40GB. Questions, issues, and PRs welcome.