LeanLlama-8B

LeanLlama-8B is a memory-efficient variant of Meta Llama 3.1 8B Instruct that reduces KV cache memory usage during inference while preserving output quality. It is a drop-in replacement that loads and runs like any standard Hugging Face model.

What changed

The base Llama 3.1 8B Instruct weights are unmodified. LeanLlama adds a small set of learned projection modules that compress the value representations stored in the KV cache at a subset of layers. During generation, these modules automatically compress cached values on-the-fly, reducing the memory footprint of long-context inference without requiring any changes to your generation code.
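The idea can be illustrated with a minimal low-rank projection sketch. This is not LeanLlama's actual module (the real projections are trained, and the compression rank and which layers are compressed are not documented here); `RANK = 32` and the shapes below are assumptions for illustration only:

```python
import torch

HEAD_DIM, RANK = 128, 32  # RANK is a hypothetical compression size

# Hypothetical stand-ins for the learned projection modules.
down = torch.nn.Linear(HEAD_DIM, RANK, bias=False)   # applied before caching
up = torch.nn.Linear(RANK, HEAD_DIM, bias=False)     # applied when attention reads

v = torch.randn(1, 8, 256, HEAD_DIM)  # (batch, kv_heads, seq, head_dim)
cached = down(v)        # the cache stores the compressed values (4x smaller here)
restored = up(cached)   # decompressed on-the-fly inside the forward pass

print(cached.shape, restored.shape)
```

Only the compressed tensor lives in the cache, so the value half of the KV cache shrinks by roughly `HEAD_DIM / RANK` at the compressed layers.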

Quality

Evaluated against the uncompressed baseline on standard benchmarks:

| Metric | Delta |
|---|---|
| Perplexity | +4.82% |
| Distinct-2 (lexical diversity) | -0.93% |

In practice, generation quality is nearly indistinguishable from the original model for typical instruction-following and conversational workloads.
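For clarity, the deltas above are percent changes relative to the uncompressed baseline. A one-liner showing the convention (the perplexity values below are illustrative, not the measured numbers):

```python
def relative_delta(baseline: float, variant: float) -> float:
    """Percent change of a metric relative to the baseline value."""
    return (variant - baseline) / baseline * 100.0

# Illustrative: a baseline perplexity of 10.0 rising to 10.482 is +4.82%.
print(round(relative_delta(10.0, 10.482), 2))
```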

128K context validation

Verified on a single NVIDIA A40 (45 GB) with a needle-in-haystack retrieval task at full context length:

| Metric | Result |
|---|---|
| Input tokens | 126,239 |
| Prefill | 132.5 s (953 tok/s) |
| Generation | 64 tokens in 8.9 s (7.2 tok/s) |
| Peak GPU memory | 37.97 GB |
| Needle retrieved | Yes |

For long-context inference, use chunked prefill with logits_to_keep=1 so that only the final position's logits are computed for each chunk. (In the standard transformers forward signature, logits_to_keep=0 computes logits for every position, which is exactly the full-size tensor you want to avoid.) See the usage example below.
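To see why compressing the value cache matters at this scale, a back-of-the-envelope estimate of the uncompressed KV-cache size, using architecture values taken from the Llama 3.1 8B config (32 layers, 8 KV heads under GQA, head dim 128, 2-byte fp16/bf16 entries):

```python
# Architecture values from the Llama 3.1 8B config (GQA).
LAYERS, KV_HEADS, HEAD_DIM = 32, 8, 128
BYTES_PER_ELEM = 2  # fp16 / bf16

def kv_cache_bytes(seq_len: int) -> int:
    # Keys and values each store (kv_heads * head_dim) elements per layer per token.
    return 2 * LAYERS * KV_HEADS * HEAD_DIM * BYTES_PER_ELEM * seq_len

print(f"{kv_cache_bytes(126_239) / 2**30:.1f} GiB")  # ~15.4 GiB uncompressed
```

At the 126,239-token input used in the validation run above, the uncompressed cache alone is on the order of 15 GiB, a large share of the 37.97 GB peak.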

Usage

Basic generation:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "miike-ai/LeanLlama-8B",
    trust_remote_code=True,
    dtype="auto",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("miike-ai/LeanLlama-8B")

inputs = tokenizer("What is the capital of France?", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
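Since the base model is instruction-tuned, conversational prompts should follow the Llama 3.1 chat format; the standard route is `tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")`. As a rough, standalone illustration of the layout that template produces for a single user turn (sketched by hand from the published Llama 3.1 prompt format; use the tokenizer's template in real code):

```python
def llama31_prompt(user_message: str) -> str:
    """Hand-written sketch of the Llama 3.1 chat layout for one user turn."""
    return (
        "<|begin_of_text|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        f"{user_message}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )

print(llama31_prompt("What is the capital of France?"))
```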

Long-context generation (chunked prefill):

```python
import torch

CHUNK = 4096
max_new_tokens = 64

input_ids = tokenizer(long_text, return_tensors="pt").input_ids.to(model.device)
seq_len = input_ids.shape[1]

with torch.no_grad():
    # Prefill the KV cache chunk by chunk; logits_to_keep=1 materializes
    # only the last position's logits for each chunk instead of the full
    # (chunk_len, vocab) tensor.
    past_kv = None
    for start in range(0, seq_len, CHUNK):
        end = min(start + CHUNK, seq_len)
        out = model(
            input_ids=input_ids[:, start:end],
            past_key_values=past_kv,
            use_cache=True,
            logits_to_keep=1,
        )
        past_kv = out.past_key_values

    # Generate greedily from the prefilled cache
    next_id = out.logits[:, -1:, :].argmax(dim=-1)
    generated = [next_id]
    for _ in range(max_new_tokens - 1):
        out = model(input_ids=next_id, past_key_values=past_kv, use_cache=True)
        past_kv = out.past_key_values
        next_id = out.logits[:, -1:, :].argmax(dim=-1)
        generated.append(next_id)

print(tokenizer.decode(torch.cat(generated, dim=-1)[0], skip_special_tokens=True))
```

No special configuration or post-processing is needed. The compression runs transparently inside the model's forward pass.

Base model

  • Architecture: Llama 3.1
  • Parameters: 8B
  • Source: meta-llama/Llama-3.1-8B-Instruct
  • Context window: 128K tokens
  • License: Llama 3.1 Community License