# Llama-3.1-8B-Instruct with TokenButler
This repository releases Llama-3.1-8B-Instruct augmented with a TokenButler sparse-attention predictor, packaged as a Hugging Face `trust_remote_code` model and loadable directly through the standard Transformers API.
TokenButler is a lightweight, query-aware token-importance predictor that works alongside a frozen large language model. At inference time it identifies the small subset of tokens in the KV cache that contribute meaningfully to each decoding step, enabling sparse attention with a fixed budget while preserving the full KV cache. The base model is left untouched; only the predictor parameters (a small fraction of the total) are added.
## Method
Let $G$ denote the producer frequency. Producer layers are the layers at indices $0, G, 2G, \dots$. Each producer reads the layer's hidden states $H \in \mathbb{R}^{B \times L \times E}$ and predicts low-dimensional importance queries for itself and the next $G - 1$ consumer layers:

$$Q_{\text{imp}} = f_\theta(\mathrm{LN}(H)) \in \mathbb{R}^{(B \cdot H) \times G \times L \times d'}$$

For each layer $\ell$, the real (post-RoPE) cached keys are projected into the same $d'$-dimensional space via a learned matrix $W_K^{(\ell)} \in \mathbb{R}^{D \times d'}$:

$$K_{\text{imp}}^{(\ell)} = K^{(\ell)} \, W_K^{(\ell)}$$

A consumer at layer $\ell$ uses slot $(\ell - 1) \bmod G$ of its producer's importance queries together with $K_{\text{imp}}^{(\ell)}$ to score every cached token and select a top-$x$ subset under a fixed token budget, augmented by a sink prefix and a recent local window. Only the predictor parameters are trained, by distilling the masked causal attention distribution of the frozen base model.
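As a rough illustration of the consumer-side selection, here is a minimal sketch, not the repository's implementation; the function name, mean-pooling over heads, and mask handling are assumptions made for clarity:

```python
import torch

def select_tokens(q_imp: torch.Tensor, k_imp: torch.Tensor,
                  budget: int, sink: int = 8, window: int = 128) -> torch.Tensor:
    """Return a boolean keep-mask over the L cached tokens of one layer.

    q_imp: (H, d')    importance query for the current step (one producer slot)
    k_imp: (H, L, d') cached keys projected by W_K
    """
    scores = torch.einsum("hd,hld->hl", q_imp, k_imp).mean(dim=0)  # pool heads -> (L,)
    L = scores.shape[0]
    keep = torch.zeros(L, dtype=torch.bool)
    keep[:sink] = True        # sink prefix is always kept
    keep[-window:] = True     # recent local window is always kept
    remaining = min(max(budget - int(keep.sum()), 0), int((~keep).sum()))
    if remaining > 0:
        scores = scores.masked_fill(keep, float("-inf"))  # exclude already-kept tokens
        keep[scores.topk(remaining).indices] = True       # top-x highest-scoring tokens
    return keep
```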
For the full method, training procedure, and evaluation details see the TokenButler paper (under review at COLM 2026).
## Configuration
| Field | Value |
|---|---|
| Base model | meta-llama/Llama-3.1-8B-Instruct |
| Producer frequency $G$ | 4 |
| Interaction dimension $d'$ | 16 |
| Predictor MLP hidden size | 512 |
| Sink tokens (`min_sparse_index`) | 8 |
| Local window (`sliding_window`) | 128 |
| Default token-sparsity policy | `fixed_50pc` |
The predictor adds approximately 0.15 GB on top of the 16 GB base model weights (about 1% parameter overhead).
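A back-of-the-envelope check of those numbers (the parameter count below is the approximate public Llama-3.1-8B figure, not taken from this repository):

```python
base_params = 8.03e9   # approximate Llama-3.1-8B parameter count
bytes_per_param = 2    # bfloat16 / float16
base_gb = base_params * bytes_per_param / 1e9       # ≈ 16 GB of weights
predictor_gb = 0.15                                 # stated predictor size
print(f"base ≈ {base_gb:.1f} GB, overhead ≈ {predictor_gb / base_gb:.1%}")  # ≈ 0.9%
```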
## Usage
Standard `AutoTokenizer` / `AutoModelForCausalLM` calls with `trust_remote_code=True` work out of the box. Loading in bfloat16 is recommended; the default float32 precision needs roughly twice the GPU memory.
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

model_name = "Alwahsh/Meta-Llama-3.1-8B-Instruct-Butler"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,  # recommended; the float32 default needs ~2x the memory
).to("cuda")

generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
out = generator(
    "Long-context inference is challenging because ",
    max_new_tokens=200, do_sample=True, top_p=0.9, temperature=0.6,
)
print(out[0]["generated_text"])
```
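Since the base model is instruction-tuned, chat-style prompts can also go through the tokenizer's chat template (standard Transformers API; the message content below is just an example):

```python
messages = [
    {"role": "user", "content": "Why is long-context inference challenging?"},
]
inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)
output = model.generate(
    inputs, max_new_tokens=200, do_sample=True, top_p=0.9, temperature=0.6
)
# Decode only the newly generated tokens, skipping the prompt.
print(tokenizer.decode(output[0][inputs.shape[-1]:], skip_special_tokens=True))
```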
## Memory requirements
| Precision | Weights | Practical GPU |
|---|---|---|
| bfloat16 / float16 | ≈ 16 GB | A100 40 GB, A6000, L4 24 GB, V100 32 GB |
| float32 (default if `torch_dtype` is not set) | ≈ 32 GB | A100 80 GB, multi-GPU |
On Google Colab, the free T4 instance has 16 GB of VRAM, which is too tight for an 8B model plus a KV cache. Use a Colab Pro / Pro+ runtime (A100, L4, or V100) and pass `torch_dtype=torch.bfloat16` as shown above (or `torch.float16` on pre-Ampere GPUs such as the V100, which lack native bfloat16 support).
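If you are unsure which GPU a runtime gives you, a quick check before loading (the 20 GB threshold here is a rough rule of thumb, not an exact requirement):

```python
import torch

props = torch.cuda.get_device_properties(0)
total_gb = props.total_memory / 1e9
print(f"{props.name}: {total_gb:.0f} GB VRAM")
if total_gb < 20:
    print("Warning: likely too little VRAM for 8B bfloat16 weights plus a KV cache")
```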
The default per-layer sparsity is 50% (every layer except layer 0 retains the top 50% of tokens, plus the sink prefix and the sliding window). To change the sparsity at runtime:
```python
def set_sparsity(model, sparsity: str):
    # Update the sparsity policy on every TokenButler attention module.
    for module in model.modules():
        if module.__class__.__name__.endswith("AttentionExperimental"):
            module.token_sparse_method = sparsity
            module.set_token_sparsity()
    return model

model = set_sparsity(model, "fixed_70pc")  # retain 30% of tokens per layer
```
The `fixed_<N>pc` policy is the supported strategy. Sink length and sliding window can be adjusted in the same way, by setting `module.min_sparse_index` and `module.sliding_window` on every `*AttentionExperimental` module, as sketched below.
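For example (attribute names from above; the values are illustrative, not recommendations):

```python
for module in model.modules():
    if module.__class__.__name__.endswith("AttentionExperimental"):
        module.min_sparse_index = 16  # longer always-kept sink prefix
        module.sliding_window = 256   # wider always-kept local window
```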
## Files
| File | Purpose |
|---|---|
| `config.json`, `generation_config.json` | Model and generation configuration |
| `model-*.safetensors`, `model.safetensors.index.json` | Sharded model weights (base + predictor) |
| `modeling_llama_butler.py` | Custom modeling code (loaded via `trust_remote_code=True`) |
| `tokenizer.json`, `tokenizer_config.json`, `special_tokens_map.json` | Tokenizer files inherited from the base model |
| `L3_8Bi_d16_i512_pf4.pt` | Standalone predictor checkpoint (not required for inference; provided for completeness) |
## Requirements
- `transformers >= 4.45`
- `torch` (bfloat16 support recommended)
## License
This release is governed by the Llama 3.1 Community License Agreement, inherited from the base model.