# Source: proactive-cache/examples/benchmark.py
# Commit b786614 by skhavin: "feat: initial release of proactive-cache v0.1.0"
# File size: 5.81 kB
"""
benchmark.py — Reproduce the core ProactiveCache results.
Runs perplexity evaluation at multiple KV budgets comparing:
- Full Attention (baseline)
- StreamingLLM
- ProactiveCache (ours)
Usage:
python examples/benchmark.py --model meta-llama/Llama-3.1-8B --dataset wikitext
python examples/benchmark.py --model meta-llama/Llama-3.1-8B --dataset pg19 --budgets 128 256 512
"""
import argparse
import time
import torch
import numpy as np
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from proactive_cache import ProactiveCache
from proactive_cache.eviction import evict
from proactive_cache.utils import to_tuple_kv, to_dynamic_cache
def parse_args():
    """Parse the benchmark's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``model``, ``dataset``,
        ``budgets``, ``num_docs``, ``seq_len`` and ``load_in_4bit``.
    """
    parser = argparse.ArgumentParser(description="ProactiveCache benchmark")
    parser.add_argument("--model", default="unsloth/meta-llama-3.1-8B-bnb-4bit")
    parser.add_argument("--dataset", choices=["wikitext", "pg19"], default="wikitext")
    parser.add_argument("--budgets", type=int, nargs="+", default=[128, 256, 512])
    parser.add_argument("--num-docs", type=int, default=20)
    parser.add_argument("--seq-len", type=int, default=1024)
    # 4-bit loading stays on by default. A lone ``store_true`` flag whose
    # default is already True can never be turned off, so it is paired with a
    # negative flag that writes to the same destination.
    parser.add_argument(
        "--load-in-4bit",
        dest="load_in_4bit",
        action="store_true",
        help="load the model with 4-bit quantization (default)",
    )
    parser.add_argument(
        "--no-load-in-4bit",
        dest="load_in_4bit",
        action="store_false",
        help="load the model without 4-bit quantization",
    )
    parser.set_defaults(load_in_4bit=True)
    return parser.parse_args()
def streaming_llm_indices(seq_len, budget):
    """Return the cache positions kept by a sink-plus-recent-window policy.

    Up to 4 leading "sink" positions are kept, and the rest of the budget is
    spent on the most recent positions.

    Args:
        seq_len: Number of positions currently in the cache.
        budget: Maximum number of positions to keep.

    Returns:
        A sorted list of distinct indices in ``range(seq_len)`` with at most
        ``budget`` entries. Empty when either argument is zero or negative.
    """
    # Clamp so that neither the sink nor the recent window can reference
    # positions outside the sequence (e.g. seq_len=2 with budget=8).
    budget = max(0, min(budget, seq_len))
    sink = min(4, budget)
    recent = budget - sink
    # After clamping, seq_len - recent >= sink, so the two ranges are disjoint
    # and already in ascending order.
    return list(range(sink)) + list(range(seq_len - recent, seq_len))
def eval_ppl(model, input_ids, past_kv, eval_start, device):
    """Compute teacher-forced perplexity over ``input_ids[:, eval_start:]``.

    The model is fed one token per step, starting from the token at position
    ``eval_start - 1`` and then each ground-truth target in turn, while the
    key/value cache returned by each call is passed to the next.

    Args:
        model: Callable accepting ``(tokens, past_key_values=..., use_cache=True)``
            and returning an object with ``logits`` and ``past_key_values``.
        input_ids: Token ids, indexed as ``(batch, position)``.
        past_kv: Cache handed to the first model call.
            NOTE(review): presumably covers positions before ``eval_start - 1``;
            confirm against the caller.
        eval_start: Index of the first target position.
        device: Unused here; kept so existing callers keep working.

    Returns:
        ``exp`` of the mean per-step cross-entropy as a float, or ``None`` when
        fewer than 5 target positions are available.
    """
    targets = input_ids[:, eval_start:]
    gen_len = targets.shape[1]
    if gen_len < 5:
        return None
    nlls = []
    next_token = input_ids[:, eval_start - 1:eval_start]
    # Pure evaluation: disable autograd so the growing cache does not keep a
    # graph alive across decode steps.
    with torch.no_grad():
        for i in range(gen_len):
            out = model(next_token, past_key_values=past_kv, use_cache=True)
            past_kv = out.past_key_values
            logits = out.logits[:, -1, :]
            nlls.append(torch.nn.functional.cross_entropy(logits, targets[:, i]).item())
            # Slice rather than unsqueeze(0) so the next input keeps the
            # (batch, 1) layout for any batch size.
            next_token = targets[:, i:i + 1]
    return float(np.exp(np.mean(nlls)))
def run_method(model, chunks, device, method, budget, prototypes):
    """Benchmark one cache policy over all chunks.

    For each chunk the context prefix is run through the model to fill the
    cache, the cache is optionally pruned according to ``method``, and the
    tail of the chunk is scored with ``eval_ppl``.

    Args:
        model: Causal LM called as ``model(ids, use_cache=True)``.
        chunks: Iterable of token-id tensors indexed as ``(batch, position)``.
        device: Device the chunks and index tensors are moved to.
        method: ``"full"`` (no pruning), ``"streaming"`` or ``"proactive"``.
            Any other value leaves the cache unpruned.
        budget: Number of cache positions to keep; ``None`` disables pruning.
        prototypes: Passed through to ``evict`` for the ``"proactive"`` method.

    Returns:
        Dict with ``"ppl"`` (mean perplexity over scored chunks, NaN when no
        chunk was scored), ``"vram_mb"`` (peak CUDA memory allocated, in units
        of 1e6 bytes) and ``"time_s"`` (elapsed wall-clock seconds).
    """
    ppls = []
    torch.cuda.reset_peak_memory_stats()
    # perf_counter is monotonic and higher resolution than time.time().
    t0 = time.perf_counter()
    for ids in tqdm(chunks, desc=f"{method} B={budget}"):
        ids = ids.to(device)
        seq_len = ids.shape[1]
        # Score the last quarter of the chunk, capped at 128 tokens.
        eval_start = seq_len - min(128, seq_len // 4)
        if eval_start < 20:
            # Too little context to be meaningful; skip this chunk.
            continue
        # Stop one token short: eval_ppl feeds position eval_start - 1 itself.
        ctx = ids[:, :eval_start - 1]
        # Nothing here needs gradients; keep prefill, pruning and scoring
        # under no_grad so autograd state does not inflate the VRAM figure.
        with torch.no_grad():
            out = model(ctx, use_cache=True)
            past_kv = out.past_key_values
            if method != "full" and budget is not None:
                ctx_len = ctx.shape[1]
                if method == "proactive":
                    past_kv = evict(past_kv, budget, prototypes, ctx_len, device)
                elif method == "streaming":
                    idx = streaming_llm_indices(ctx_len, min(budget, ctx_len))
                    idx_t = torch.tensor(idx, device=device)
                    kv_t = to_tuple_kv(past_kv)
                    # Dimension 2 is taken to be the sequence axis of each
                    # key/value tensor — TODO confirm against to_tuple_kv.
                    pruned = tuple(
                        (k.index_select(2, idx_t), v.index_select(2, idx_t))
                        for k, v in kv_t
                    )
                    past_kv = to_dynamic_cache(pruned)
            ppl = eval_ppl(model, ids, past_kv, eval_start, device)
        # eval_ppl signals "too few targets" with None; test for that
        # explicitly instead of relying on float truthiness.
        if ppl is not None:
            ppls.append(ppl)
    elapsed = time.perf_counter() - t0
    vram = torch.cuda.max_memory_allocated() / 1e6
    # np.mean([]) would emit a RuntimeWarning; report NaN explicitly instead.
    mean_ppl = float(np.mean(ppls)) if ppls else float("nan")
    return {"ppl": mean_ppl, "vram_mb": vram, "time_s": elapsed}
def main():
    """Run the benchmark end to end and print a results table.

    Loads the model and tokenizer, builds ProactiveCache prototypes, tokenizes
    the evaluation documents, then reports perplexity, peak VRAM and elapsed
    time for full attention and, per budget, StreamingLLM and ProactiveCache.

    Raises:
        SystemExit: If no CUDA device is available.
    """
    args = parse_args()
    # The model is pinned to "cuda" and run_method reads torch.cuda memory
    # statistics, so fail early with a clear message when no GPU is present.
    if not torch.cuda.is_available():
        raise SystemExit("benchmark.py requires a CUDA device")

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=args.load_in_4bit,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    ) if args.load_in_4bit else None
    print(f"Loading {args.model}...")
    # NOTE(review): trust_remote_code=True executes code shipped with the
    # model repository; only point --model at sources you trust.
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        quantization_config=bnb_config,
        device_map={"": "cuda"},
        trust_remote_code=True,
    )
    model.eval()
    device = next(model.parameters()).device

    # Profile and build prototypes
    prototypes = ProactiveCache.profile(
        model, tokenizer, corpus=args.dataset,
        num_docs=30, seq_len=512,
    )

    # Load chunks
    from datasets import load_dataset
    if args.dataset == "wikitext":
        raw = load_dataset("wikitext", "wikitext-103-v1", split="validation")
        n_rows = len(raw)
        texts = []
        # Each document is 10 consecutive rows joined together. Both ranges
        # are clamped to the split size so a large --num-docs cannot index
        # past the end of the dataset.
        for start in range(0, min(args.num_docs * 10, n_rows), 10):
            rows = raw.select(range(start, min(start + 10, n_rows)))
            texts.append(" ".join(r["text"] for r in rows if r["text"].strip()))
    else:
        raw = list(load_dataset("emozilla/pg19", split="test", streaming=True).take(args.num_docs))
        # Keep only the first 3000 characters of each book.
        texts = [r["text"][:3000] for r in raw]
    chunks = [
        tokenizer(t, return_tensors="pt", truncation=True, max_length=args.seq_len)["input_ids"]
        for t in texts[:args.num_docs]
    ]

    # Benchmark
    print(f"\n{'='*60}")
    print(f" Benchmark: {args.model} | {args.dataset.upper()}")
    print(f"{'='*60}")
    print(f"{'Method':<22} {'Budget':>6} {'PPL':>8} {'VRAM(MB)':>10} {'Time(s)':>8}")
    print("-" * 60)
    full = run_method(model, chunks, device, "full", None, None)
    print(f"{'Full Attention':<22} {'all':>6} {full['ppl']:>8.2f} {full['vram_mb']:>10.0f} {full['time_s']:>8.1f}")
    for budget in args.budgets:
        print("-" * 60)
        for method, label in [("streaming", "StreamingLLM"), ("proactive", "ProactiveCache")]:
            r = run_method(model, chunks, device, method, budget, prototypes)
            print(f"{label:<22} {budget:>6} {r['ppl']:>8.2f} {r['vram_mb']:>10.0f} {r['time_s']:>8.1f}")
# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()