Llama-3.1-8B-Instruct with TokenButler

This repository releases Llama-3.1-8B-Instruct augmented with a TokenButler sparse-attention predictor, packaged as a Hugging Face trust_remote_code model and loadable directly with the standard Transformers API.

TokenButler is a lightweight, query-aware token-importance predictor that works alongside a frozen large language model. At inference time it identifies the small subset of tokens in the KV cache that contribute meaningfully to each decoding step, enabling sparse attention with a fixed budget while preserving the full KV cache. The base model is left untouched; only the predictor parameters (a small fraction of the total) are added.

Method

Let G denote the producer frequency. Producer layers are the layers at indices 0, G, 2G, …. Each producer reads its layer's hidden states H ∈ ℝ^{B×L×E} and predicts low-dimensional importance queries for itself and the next G − 1 consumer layers:

Q_imp = f_θ(LN(H)) ∈ ℝ^{(B·H)×G×L×d′}

(here the H in the shape denotes the number of attention heads: one importance query per head, per token, per layer slot). For each layer ℓ the real (post-RoPE) cached keys are projected into the same d′-dimensional space via a learned matrix W_K^{(ℓ)} ∈ ℝ^{D×d′}:

K_imp^{(ℓ)} = K^{(ℓ)} · W_K^{(ℓ)}

A consumer at layer ℓ uses slot (ℓ − 1) mod G of its producer's importance queries together with K_imp^{(ℓ)} to score every cached token and select a top-x subset under a fixed token budget, augmented by a sink prefix and a recent local window. Only the predictor parameters are trained, by distilling the masked causal attention distribution of the frozen base model.
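
A minimal PyTorch sketch of this selection step at a single decode step may help make the shapes concrete. Everything here (f_theta, W_k, the head count, the variable names) is an illustrative assumption, not this repo's actual API; the real implementation lives in modeling_llama_butler.py.

import torch

# Illustrative geometry (only d'=16 and G=4 come from this card's configuration).
B, Hh, L, E, D, d_prime, G = 1, 32, 1024, 4096, 128, 16, 4

hidden = torch.randn(B, 1, E)                     # current-step hidden state at a producer layer
f_theta = torch.nn.Sequential(                    # small predictor MLP (hidden size 512)
    torch.nn.LayerNorm(E),
    torch.nn.Linear(E, 512), torch.nn.SiLU(),
    torch.nn.Linear(512, Hh * G * d_prime),
)
q_imp = f_theta(hidden).view(B * Hh, G, d_prime)  # one importance query per head, per layer slot

W_k = torch.randn(D, d_prime)                     # learned projection W_K^(l) of consumer layer l
cached_keys = torch.randn(B * Hh, L, D)           # real post-RoPE keys in the KV cache
k_imp = cached_keys @ W_k                         # (B*Hh, L, d')

l = 2                                             # consumer layer offset; it reads slot (l-1) mod G
scores = (q_imp[:, (l - 1) % G, None, :] @ k_imp.transpose(-1, -2)).squeeze(1)  # (B*Hh, L)

budget = L // 2                                   # fixed_50pc: keep the top 50% of cached tokens
keep = torch.zeros(B * Hh, L, dtype=torch.bool)
keep.scatter_(1, scores.topk(budget, dim=1).indices, True)
keep[:, :8] = True                                # always keep the sink prefix (min_sparse_index = 8)
keep[:, -128:] = True                             # always keep the recent window (sliding_window = 128)
# Full attention is then computed only over positions where `keep` is True.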

For the full method, training procedure, and evaluation details see the TokenButler paper (under review at COLM 2026).

Configuration

| Field | Value |
|---|---|
| Base model | meta-llama/Llama-3.1-8B-Instruct |
| Producer frequency G | 4 |
| Interaction dimension d′ | 16 |
| Predictor MLP hidden size | 512 |
| Sink tokens (min_sparse_index) | 8 |
| Local window (sliding_window) | 128 |
| Default token-sparsity policy | fixed_50pc |

The predictor adds approximately 0.15 GB on top of the 16 GB base model weights (about 1% parameter overhead).
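
As a back-of-the-envelope check of those figures (a sketch: the ~8B parameter count and the ~1% overhead are the numbers quoted in this card, and bfloat16 weights at 2 bytes per parameter are assumed):

# Rough arithmetic behind the quoted memory figures.
base_params = 8.0e9                               # Llama-3.1-8B, ~8 billion parameters
bytes_per_param = 2                               # bfloat16
print(base_params * bytes_per_param / 1e9)        # ~16 GB of base weights
predictor_params = 0.01 * base_params             # ~1% parameter overhead
print(predictor_params * bytes_per_param / 1e9)   # ~0.16 GB, matching the ≈0.15 GB above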

Usage

Load the model with the standard AutoTokenizer / AutoModelForCausalLM API, passing trust_remote_code=True. Loading in bfloat16 is recommended; the default float32 precision needs roughly twice the GPU memory.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

model_name = "Alwahsh/Meta-Llama-3.1-8B-Instruct-Butler"

# trust_remote_code=True is required so Transformers loads the custom
# TokenButler attention classes from modeling_llama_butler.py.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,  # halves memory vs. the float32 default
).to("cuda")

generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
out = generator(
    "Long-context inference is challenging because ",
    max_new_tokens=200, do_sample=True, top_p=0.9, temperature=0.6,
)
print(out[0]["generated_text"])
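
If you prefer to skip the pipeline wrapper, the equivalent lower-level call uses the standard model.generate API (ordinary Transformers usage, nothing TokenButler-specific):

inputs = tokenizer("Long-context inference is challenging because ", return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=200, do_sample=True, top_p=0.9, temperature=0.6)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))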

Memory requirements

| Precision | Weights | Practical GPU |
|---|---|---|
| bfloat16 / float16 | ≈ 16 GB | A100 40 GB, A6000, L4 24 GB, V100 32 GB |
| float32 (default if torch_dtype is not set) | ≈ 32 GB | A100 80 GB, multi-GPU |

On Google Colab the free T4 instance has 16 GB of VRAM, which is too tight for an 8B model plus a KV cache. Use a Colab Pro / Pro+ runtime (A100, L4, or V100) and pass torch_dtype=torch.bfloat16 as shown above; on pre-Ampere GPUs such as the V100, which lack native bfloat16 support, use torch.float16 instead.
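
A quick runtime check before loading (a generic torch.cuda query, not part of this repo; the 20 GB threshold is a rough rule of thumb for 16 GB of weights plus KV-cache headroom):

import torch

props = torch.cuda.get_device_properties(0)
total_gb = props.total_memory / 1e9
print(f"{props.name}: {total_gb:.1f} GB VRAM")
if total_gb < 20:
    print("Warning: likely too little memory for this 8B model plus KV cache.")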

The default per-layer sparsity is 50% (every layer except layer 0 retains the top 50% of tokens, plus the sink prefix and the sliding window). To change the sparsity at runtime:

def set_sparsity(model, sparsity: str):
    # Update every TokenButler attention module in place; "fixed_<N>pc"
    # drops the lowest-scoring N% of tokens in each layer.
    for module in model.modules():
        if module.__class__.__name__.endswith("AttentionExperimental"):
            module.token_sparse_method = sparsity
            module.set_token_sparsity()
    return model

model = set_sparsity(model, "fixed_70pc")  # retain 30% of tokens per layer

The fixed_<N>pc policy is the supported sparsity strategy. Sink length and sliding window can be adjusted in the same way, by setting module.min_sparse_index and module.sliding_window on every *AttentionExperimental module, as sketched below.
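
For instance, mirroring set_sparsity above (the attribute names min_sparse_index and sliding_window are the ones this card documents; the helper itself and the example values are just a sketch):

def set_sink_and_window(model, min_sparse_index: int = 8, sliding_window: int = 128):
    # Same traversal as set_sparsity: update every TokenButler attention module in place.
    for module in model.modules():
        if module.__class__.__name__.endswith("AttentionExperimental"):
            module.min_sparse_index = min_sparse_index
            module.sliding_window = sliding_window
    return model

model = set_sink_and_window(model, min_sparse_index=16, sliding_window=256)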

Files

| File | Purpose |
|---|---|
| config.json, generation_config.json | Model and generation configuration |
| model-*.safetensors, model.safetensors.index.json | Sharded model weights (base + predictor) |
| modeling_llama_butler.py | Custom modeling code (loaded via trust_remote_code=True) |
| tokenizer.json, tokenizer_config.json, special_tokens_map.json | Tokenizer files inherited from the base model |
| L3_8Bi_d16_i512_pf4.pt | Standalone predictor checkpoint (not required for inference; provided for completeness) |

Requirements

  • transformers >= 4.45
  • torch with bfloat16 support recommended

License

This release is governed by the Llama 3.1 Community License Agreement, inherited from the base model.
