| 1 |
+
---
|
| 2 |
+
library_name: transformers
|
| 3 |
+
tags:
|
| 4 |
+
- mixtral
|
| 5 |
+
- kv-cache-compression
|
| 6 |
+
- inference-optimization
|
| 7 |
+
- memory-efficient
|
| 8 |
+
license: apache-2.0
|
| 9 |
+
base_model: mistralai/Mixtral-8x7B-v0.1
|
| 10 |
+
pipeline_tag: text-generation
|
| 11 |
+
---
|
| 12 |
+
|
# LeanMixtral-8x7B-38

Mixtral 8x7B with **256x KV cache compression** on 38% of layers (12 out of 32), reducing KV cache memory for those layers while maintaining model quality.

## What is this?

This model applies learned linear autoencoders to compress the key-value cache in 12 of Mixtral's 32 transformer layers. Each KV vector (1024 dims) is compressed to just 4 dimensions via a linear encoder, then reconstructed through a 2-layer decoder (4 -> 128 -> 1024) with GELU activation. Compression happens per token at inference time, with no changes to the attention mechanism itself.
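The encoder/decoder shapes described above can be sketched as a small PyTorch module. This is an illustrative reconstruction from the numbers in this card, not the model's actual class; the name `KVCompressor` is made up here.

```python
import torch
import torch.nn as nn

class KVCompressor(nn.Module):
    """Linear autoencoder over per-token KV vectors: 1024 -> 4 -> 128 -> 1024."""
    def __init__(self, kv_dim: int = 1024, latent_dim: int = 4, hidden_dim: int = 128):
        super().__init__()
        self.encoder = nn.Linear(kv_dim, latent_dim)  # 256x reduction
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, kv_dim),
        )

    def forward(self, kv: torch.Tensor) -> torch.Tensor:
        # Round-trip: attention later sees the reconstruction, not the original.
        return self.decoder(self.encoder(kv))

comp = KVCompressor()
kv = torch.randn(1, 16, 1024)  # (batch, new tokens, kv_dim)
recon = comp(kv)
print(recon.shape)             # torch.Size([1, 16, 1024])
```

Only the 4-dim latents need to be kept in the cache; the decoder runs when the entries are read back.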

**Key results (measured on WikiText-2):**
- Baseline perplexity: **7.60**
- With compression: **10.66** (+3.06 points)
- Top-1 prediction agreement: **~76%**
- Top-5 prediction agreement: **~99.6%**
- Compressed layers: 3, 8, 13, 14, 16, 22, 23, 24, 25, 26, 27, 28
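Since only 12 of the 32 layers are compressed, the overall cache saving follows from back-of-envelope arithmetic (assuming the 4-dim latents are stored at the same precision as the uncompressed cache entries):

```python
total_layers = 32
compressed_layers = 12
ratio = 256  # 1024 -> 4 dims per KV vector

# Fraction of the baseline KV cache that remains: 20 full layers
# plus 12 layers shrunk to 1/256 of their size.
remaining = (total_layers - compressed_layers + compressed_layers / ratio) / total_layers
print(f"{remaining:.1%} of baseline KV cache")  # 62.6% of baseline KV cache
```

In other words, roughly a 37% reduction in total KV cache memory at long context lengths.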

These layers were selected through a per-layer compressibility sweep that identified which layers have inherently low-dimensional KV subspaces and can tolerate aggressive compression with minimal impact on output quality.

## Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "miike-ai/LeanMixtral-8x7B-38",
    trust_remote_code=True,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("miike-ai/LeanMixtral-8x7B-38")

inputs = tokenizer("The future of AI is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=100, do_sample=True, temperature=0.7)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

### With 4-bit quantization (fits on a single 24GB GPU)

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

model = AutoModelForCausalLM.from_pretrained(
    "miike-ai/LeanMixtral-8x7B-38",
    trust_remote_code=True,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    ),
)
```

## How it works

Each compressed layer has two small autoencoder modules attached to the self-attention block:

- **K compressor**: `Linear(1024, 4)` encoder + `Linear(4, 128) -> GELU -> Linear(128, 1024)` decoder
- **V compressor**: same architecture

After each forward pass, the model intercepts the KV cache and compresses the new tokens' entries in the configured layers. The compressors use `_SafeLinear` (a custom `nn.Module`) instead of `nn.Linear` so that bitsandbytes quantization leaves them untouched.

The compressor weights add only **12.6 MB** to the model (168 parameter tensors across the 12 compressed layers).

## Layer selection methodology

All 32 layers were individually tested with compression to measure the per-layer perplexity impact:

| Category | Layers | Per-layer PPL impact |
|----------|--------|----------------------|
| Safe (<0.1 pts) | 14, 22, 23, 24, 25, 26, 27, 28 | Negligible |
| Low (0.1-0.3 pts) | 3, 8, 13, 16 | Minor |
| Moderate (0.3-0.5 pts) | Most others | Noticeable |
| Critical (>1 pt) | 0, 1 | Never compressed |

Layers were added progressively in order of compressibility. At 12 layers the cumulative impact stays under ~3 PPL points; adding more causes errors to compound rapidly.
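The progressive selection described above can be sketched as a greedy sweep. This is an illustrative skeleton, not the actual selection script: `eval_ppl` stands in for a real WikiText-2 perplexity evaluation, and the toy impacts below are made up.

```python
def select_layers(eval_ppl, n_layers=32, budget=3.0):
    """Greedily add layers to the compressed set, most compressible first,
    keeping only additions that stay within the cumulative PPL budget."""
    base = eval_ppl(frozenset())
    # Rank layers by their individual (solo-compression) perplexity impact.
    order = sorted(range(n_layers), key=lambda i: eval_ppl(frozenset([i])))
    chosen = set()
    for layer in order:
        trial = chosen | {layer}
        if eval_ppl(frozenset(trial)) - base <= budget:
            chosen = trial
    return sorted(chosen)

# Toy stand-in: pretend per-layer impacts are known and additive.
impacts = [1.2, 1.1] + [0.4] * 18 + [0.05] * 12  # first two critical, last 12 cheap
toy_ppl = lambda layers: 7.6 + sum(impacts[i] for i in layers)
print(select_layers(toy_ppl))
```

With the toy impacts, the sweep picks up all the cheap layers and stops before the critical early layers, mirroring the table above.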

## Architecture details

- **Base model**: [mistralai/Mixtral-8x7B-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)
- **Model class**: `LeanMixtralForCausalLM` (extends `MixtralForCausalLM`)
- **Compression ratio**: 256x per compressed layer (1024 -> 4 dimensions)
- **Compressed layers**: 12/32 (38%)
- **Compressor overhead**: 12.6 MB (negligible vs the ~87 GB base model)
- **Training data for compressors**: WikiText-2 (6 epochs per layer)
- **Precision**: bf16 weights (safetensors format)

## Limitations

- Compression adds a small perplexity penalty (~3 points on WikiText-2)
- Generation quality may degrade slightly on tasks requiring precise recall from compressed layers
- `trust_remote_code=True` is required, since this uses a custom model class
- The compressors were trained on WikiText-2; other domains may see different compression quality

## Citation

If you use this model, please cite the base model:

```bibtex
@article{jiang2024mixtral,
  title={Mixtral of Experts},
  author={Jiang, Albert Q and Sablayrolles, Alexandre and Roux, Antoine and Mensch, Arthur and Savary, Blanche and Bamford, Chris and Chaplot, Devendra Singh and Casas, Diego de las and Hanna, Emma Bou and Bressand, Florian and others},
  journal={arXiv preprint arXiv:2401.04088},
  year={2024}
}
```