# LeanLlama-8B
LeanLlama-8B is a memory-efficient variant of Meta Llama 3.1 8B Instruct that reduces KV cache memory usage during inference while preserving output quality. It is a drop-in replacement that loads and runs like any standard Hugging Face model.
## What changed
The base Llama 3.1 8B Instruct weights are unmodified. LeanLlama adds a small set of learned projection modules that compress the value representations stored in the KV cache at a subset of layers. During generation, these modules automatically compress cached values on-the-fly, reducing the memory footprint of long-context inference without requiring any changes to your generation code.
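The trained projection modules are internal to the model, but the underlying idea can be sketched as a low-rank down/up projection pair around the cached values. The shapes, rank, and random weights below are purely illustrative, not the actual trained modules:

```python
import numpy as np

rng = np.random.default_rng(0)
d_head, rank, seq = 128, 32, 16

# Hypothetical learned projections (random here; LeanLlama's are trained)
W_down = rng.standard_normal((d_head, rank)) / np.sqrt(d_head)
W_up = rng.standard_normal((rank, d_head)) / np.sqrt(rank)

V = rng.standard_normal((seq, d_head)).astype(np.float32)

# The KV cache stores the compressed values: seq x rank instead of seq x d_head
V_compressed = (V @ W_down).astype(np.float32)

# At attention time, values are expanded back before the weighted sum
V_restored = (V_compressed @ W_up).astype(np.float32)

print(V.nbytes, V_compressed.nbytes)  # 4x smaller at rank = d_head / 4
```

At rank 32 the cached tensor is a quarter of the original size; the quality table below suggests the trained modules recover most of the lost information.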
## Quality
Evaluated against the uncompressed baseline on standard benchmarks:
| Metric | Delta |
|---|---|
| Perplexity | +4.82% |
| Distinct-2 (lexical diversity) | -0.93% |
In practice, generation quality is nearly indistinguishable from the original model for typical instruction-following and conversational workloads.
## 128K context validation
Verified on a single NVIDIA A40 (45 GB) with a needle-in-haystack retrieval task at full context length:
| Metric | Result |
|---|---|
| Input tokens | 126,239 |
| Prefill | 132.5s (953 tok/s) |
| Generation | 64 tokens in 8.9s (7.2 tok/s) |
| Peak GPU memory | 37.97 GB |
| Needle retrieved | Yes |
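To see why cache compression matters at this scale, here is a back-of-envelope estimate of the uncompressed fp16 KV cache at the measured input length, assuming the standard Llama 3.1 8B attention config (32 layers, 8 KV heads, head dim 128):

```python
layers, kv_heads, head_dim = 32, 8, 128  # standard Llama 3.1 8B attention config
bytes_per_elem = 2                       # fp16 / bf16
seq_len = 126_239                        # input tokens from the table above

# K and V each store layers * kv_heads * head_dim elements per token
bytes_per_token = 2 * layers * kv_heads * head_dim * bytes_per_elem
total_gib = bytes_per_token * seq_len / 2**30

print(bytes_per_token)      # 131072 bytes = 128 KiB per token
print(round(total_gib, 1))  # ~15.4 GiB for the full uncompressed cache
```

Roughly 15 GiB of cache on top of ~16 GB of bf16 weights leaves little headroom on a 45 GB card, which is the gap the compressed cache targets.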
For long-context inference, use chunked prefill with `logits_to_keep=0` on intermediate chunks to avoid materializing the full logits tensor. See the usage example below.
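The savings from `logits_to_keep=0` are easy to estimate: without it, prefill would materialize a `seq_len x vocab_size` logits tensor (Llama 3.1's vocabulary has 128,256 tokens):

```python
seq_len = 126_239     # prefill length from the table above
vocab_size = 128_256  # Llama 3.1 vocabulary size
bytes_per_elem = 2    # fp16 / bf16 logits

full_logits_gib = seq_len * vocab_size * bytes_per_elem / 2**30
print(round(full_logits_gib, 1))  # tens of GiB that logits_to_keep=0 never allocates
```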
## Usage
Basic generation:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "miike-ai/LeanLlama-8B",
    trust_remote_code=True,
    dtype="auto",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("miike-ai/LeanLlama-8B")

inputs = tokenizer("What is the capital of France?", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
Long-context generation (chunked prefill):
```python
import torch

CHUNK = 4096
max_new_tokens = 64

# long_text: your long input string; model and tokenizer as loaded above
input_ids = tokenizer(long_text, return_tensors="pt").input_ids.to(model.device)
seq_len = input_ids.shape[1]

with torch.no_grad():
    past_kv = None
    for start in range(0, seq_len, CHUNK):
        end = min(start + CHUNK, seq_len)
        # Only the final chunk needs logits; skip them elsewhere to save memory
        keep = 1 if end == seq_len else 0
        out = model(
            input_ids=input_ids[:, start:end],
            past_key_values=past_kv,
            use_cache=True,
            logits_to_keep=keep,
        )
        past_kv = out.past_key_values

    # Generate greedily from the prefilled cache
    next_id = out.logits[:, -1:, :].argmax(dim=-1)
    generated = [next_id]
    for _ in range(max_new_tokens - 1):
        out = model(input_ids=next_id, past_key_values=past_kv, use_cache=True)
        past_kv = out.past_key_values
        next_id = out.logits[:, -1:, :].argmax(dim=-1)
        generated.append(next_id)

print(tokenizer.decode(torch.cat(generated, dim=-1)[0], skip_special_tokens=True))
```
No special configuration or post-processing is needed. The compression runs transparently inside the model's forward pass.
## Base model
- Architecture: Llama 3.1
- Parameters: 8B
- Source: meta-llama/Llama-3.1-8B-Instruct
- Context window: 128K tokens
- License: Llama 3.1 Community License