# LeanMixtral-8x7B-38
Mixtral 8x7B with 256x KV cache compression on 38% of layers (12 out of 32), reducing KV cache memory for those layers while maintaining model quality.
## What is this?
This model applies learned linear autoencoders to compress the key-value cache in 12 of Mixtral's 32 transformer layers. Each KV vector (1024 dims) is compressed to just 4 dimensions via an encoder, then reconstructed through a 2-layer decoder (4 -> 128 -> 1024) with GELU activation. Compression happens per-token at inference time with no changes to the attention mechanism itself.
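The autoencoder described above can be sketched as a small torch module. This is a hypothetical reimplementation from the stated shapes; the released checkpoint's actual module and attribute names may differ:

```python
import torch
import torch.nn as nn

class KVAutoencoder(nn.Module):
    """Per-layer KV compressor: 1024-d vectors -> 4-d codes -> reconstruction.

    Illustrative reimplementation of the architecture described above
    (Linear(1024, 4) encoder, 4 -> 128 -> 1024 decoder with GELU).
    """
    def __init__(self, dim: int = 1024, code: int = 4, hidden: int = 128):
        super().__init__()
        self.encoder = nn.Linear(dim, code)       # 1024 -> 4 (256x compression)
        self.decoder = nn.Sequential(             # 4 -> 128 -> 1024
            nn.Linear(code, hidden),
            nn.GELU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, kv: torch.Tensor) -> torch.Tensor:
        # Compress then reconstruct each token's KV vector independently.
        return self.decoder(self.encoder(kv))

ae = KVAutoencoder()
x = torch.randn(2, 16, 1024)    # (batch, tokens, kv_dim)
code = ae.encoder(x)            # (2, 16, 4) -- what would actually be cached
print(code.shape, ae(x).shape)
```

Only the 4-d code needs to be stored; the reconstruction runs whenever the cached keys/values are consumed.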
Key results (measured on WikiText-2):
- Baseline perplexity: 7.60
- With compression: 10.66 (+3.06 points)
- Top-1 prediction agreement: ~76%
- Top-5 prediction agreement: ~99.6%
- Compressed layers: 3, 8, 13, 14, 16, 22, 23, 24, 25, 26, 27, 28
These layers were selected through a per-layer compressibility sweep that identified which layers have inherently low-dimensional KV subspaces and can tolerate aggressive compression with minimal impact on output quality.
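The agreement numbers above can be measured by running the baseline and compressed models on the same text and comparing next-token predictions. A minimal sketch under one plausible definition (compressed model's top-1 token falling inside the baseline's top-k set), with random logits standing in for real model outputs:

```python
import torch

def topk_agreement(base_logits, lean_logits, k=5):
    """Fraction of positions where the compressed model's top-1 token
    appears in the baseline model's top-k set.

    base_logits, lean_logits: (positions, vocab) tensors.
    """
    base_topk = base_logits.topk(k, dim=-1).indices        # (N, k)
    lean_top1 = lean_logits.argmax(dim=-1, keepdim=True)   # (N, 1)
    return (base_topk == lean_top1).any(dim=-1).float().mean().item()

torch.manual_seed(0)
base_logits = torch.randn(100, 32000)   # 100 positions, Mixtral vocab size
lean_logits = base_logits + 0.1 * torch.randn_like(base_logits)
print(topk_agreement(base_logits, lean_logits, k=1))
print(topk_agreement(base_logits, lean_logits, k=5))
```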
## Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "miike-ai/LeanMixtral-8x7B-38",
    trust_remote_code=True,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("miike-ai/LeanMixtral-8x7B-38")

inputs = tokenizer("The future of AI is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=100, do_sample=True, temperature=0.7)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
### With 4-bit quantization (fits on a single 24 GB GPU)

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    "miike-ai/LeanMixtral-8x7B-38",
    trust_remote_code=True,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    ),
)
```
## How it works
Each compressed layer has two small autoencoder modules attached to the self-attention block:
- K compressor: `Linear(1024, 4)` encoder + `Linear(4, 128) -> GELU -> Linear(128, 1024)` decoder
- V compressor: same architecture

After each forward pass, the model intercepts the KV cache and compresses the new tokens in the configured layers. The compressors use `_SafeLinear` (a custom `nn.Module`) instead of `nn.Linear` so that bitsandbytes quantization leaves them untouched.
The compressor weights add only ~12.6 MB to the model (~3.3M parameters across 12 layers).
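The interception step can be sketched as follows. This is a hedged illustration: `compress_new_kv` and the toy autoencoder are stand-ins, and the real hook inside `LeanMixtralForCausalLM` (including its `_SafeLinear` modules) will differ in detail:

```python
import torch
import torch.nn as nn

def compress_new_kv(k, v, k_ae, v_ae, prev_len):
    """Round-trip only the newly appended tokens through the autoencoders.

    k, v: (batch, kv_heads, seq, head_dim) tensors from one layer's cache.
    Tokens before prev_len were already compressed on earlier steps, so
    they are left untouched.
    """
    def roundtrip(x, ae):
        b, h, s, d = x.shape
        flat = x.transpose(1, 2).reshape(b, s, h * d)      # one 1024-d vec per token
        flat = torch.cat([flat[:, :prev_len],
                          ae(flat[:, prev_len:])], dim=1)  # old tokens untouched
        return flat.reshape(b, s, h, d).transpose(1, 2)
    return roundtrip(k, k_ae), roundtrip(v, v_ae)

# Toy 256x autoencoder standing in for a trained compressor.
ae = nn.Sequential(nn.Linear(1024, 4), nn.Linear(4, 128), nn.GELU(),
                   nn.Linear(128, 1024))
k = torch.randn(1, 8, 10, 128)          # 8 KV heads x 128 dims = 1024
with torch.no_grad():
    k2, v2 = compress_new_kv(k, k.clone(), ae, ae, prev_len=6)
print(k2.shape)                         # cache shape is unchanged
```

Because the attention mechanism itself is untouched, downstream layers simply consume the reconstructed keys/values in place of the originals.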
## Layer selection methodology
All 32 layers were individually tested with compression to measure per-layer perplexity impact:
| Category | Layers | Per-layer PPL impact |
|---|---|---|
| Safe (<0.1 pts) | 14, 22, 23, 24, 25, 26, 27, 28 | Negligible |
| Low (0.1-0.3 pts) | 3, 8, 13, 16 | Minor |
| Moderate (0.3-0.5 pts) | Most others | Noticeable |
| Critical (>1 pt) | 0, 1 | Never compressed |
Layers were added progressively in order of compressibility. At 12 layers the cumulative impact stays under ~3 PPL points; adding more causes errors to compound rapidly.
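The selection procedure can be sketched as a greedy sweep. `eval_fn`, `candidate_layers`, and `budget_ppl` are illustrative names for this sketch, not this repo's API:

```python
def select_layers(model, eval_fn, candidate_layers, budget_ppl=3.0):
    """Greedily add layers in order of per-layer PPL impact until the
    cumulative impact would exceed the budget.

    eval_fn(model, compressed=[...]) is assumed to return perplexity with
    the given layers compressed.
    """
    baseline = eval_fn(model, compressed=[])
    # 1) Per-layer sweep: cost of compressing each layer alone.
    cost = {L: eval_fn(model, compressed=[L]) - baseline for L in candidate_layers}
    # 2) Greedy accumulation, cheapest layers first.
    chosen = []
    for L in sorted(cost, key=cost.get):
        if eval_fn(model, compressed=chosen + [L]) - baseline > budget_ppl:
            break
        chosen.append(L)
    return chosen

# Demo with a fake evaluator whose per-layer costs simply add up.
costs = {0: 2.0, 1: 1.5, 3: 0.2, 8: 0.25, 14: 0.05}
def fake_eval(model, compressed):
    return 7.6 + sum(costs[L] for L in compressed)

print(select_layers(None, fake_eval, list(costs), budget_ppl=0.6))  # -> [14, 3, 8]
```

In practice (as the table above shows) the costs compound rather than add linearly, which is why the greedy loop re-evaluates the cumulative set instead of just summing per-layer costs.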
## Architecture details
- Base model: mistralai/Mixtral-8x7B-v0.1
- Model class:
LeanMixtralForCausalLM(extendsMixtralForCausalLM) - Compression ratio: 256x per compressed layer (1024 -> 4 dimensions)
- Compressed layers: 12/32 (38%)
- Compressor overhead: 12.6 MB (negligible vs 87 GB base model)
- Training data for compressors: WikiText-2 (6 epochs per layer)
- Precision: bf16 weights (safetensors format)
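As a sanity check, the compressor overhead follows directly from the per-layer shapes; the quoted ~12.6 MB is consistent with the compressor weights themselves being stored in fp32 (an assumption of this sketch, since the base weights are bf16):

```python
# Parameter count for one compressor (weights + biases).
enc = 1024 * 4 + 4                            # Linear(1024, 4) encoder
dec = (4 * 128 + 128) + (128 * 1024 + 1024)   # Linear(4, 128) + Linear(128, 1024)
per_layer = 2 * (enc + dec)                   # K and V compressors
total = 12 * per_layer                        # 12 compressed layers

print(total)                                  # ~3.3M parameters
print(total * 4 / 2**20)                      # ~12.5 MiB at 4 bytes/param
```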
## Limitations
- Compression adds a small perplexity penalty (~3 points on WikiText-2)
- Generation quality may degrade slightly on tasks requiring precise recall from compressed layers
- The `trust_remote_code=True` flag is required since this uses a custom model class
- Compressor weights were trained on WikiText-2; other domains may see different compression quality
## Citation
If you use this model, please cite the base model:
```bibtex
@article{jiang2024mixtral,
  title={Mixtral of Experts},
  author={Jiang, Albert Q and Sablayrolles, Alexandre and Roux, Antoine and Mensch, Arthur and Savary, Blanche and Bamford, Chris and Chaplot, Devendra Singh and Casas, Diego de las and Hanna, Emma Bou and Bressand, Florian and others},
  journal={arXiv preprint arXiv:2401.04088},
  year={2024}
}
```