""" benchmark.py — Reproduce the core ProactiveCache results. Runs perplexity evaluation at multiple KV budgets comparing: - Full Attention (baseline) - StreamingLLM - ProactiveCache (ours) Usage: python examples/benchmark.py --model meta-llama/Llama-3.1-8B --dataset wikitext python examples/benchmark.py --model meta-llama/Llama-3.1-8B --dataset pg19 --budgets 128 256 512 """ import argparse import time import torch import numpy as np from tqdm import tqdm from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig from proactive_cache import ProactiveCache from proactive_cache.eviction import evict from proactive_cache.utils import to_tuple_kv, to_dynamic_cache def parse_args(): parser = argparse.ArgumentParser(description="ProactiveCache benchmark") parser.add_argument("--model", default="unsloth/meta-llama-3.1-8B-bnb-4bit") parser.add_argument("--dataset", choices=["wikitext", "pg19"], default="wikitext") parser.add_argument("--budgets", type=int, nargs="+", default=[128, 256, 512]) parser.add_argument("--num-docs", type=int, default=20) parser.add_argument("--seq-len", type=int, default=1024) parser.add_argument("--load-in-4bit", action="store_true", default=True) return parser.parse_args() def streaming_llm_indices(seq_len, budget): sink = min(4, budget) recent = budget - sink sinks = list(range(sink)) recents = list(range(max(sink, seq_len - recent), seq_len)) return sorted(set(sinks + recents))[:budget] def eval_ppl(model, input_ids, past_kv, eval_start, device): targets = input_ids[:, eval_start:] gen_len = targets.shape[1] if gen_len < 5: return None nlls = [] next_token = input_ids[:, eval_start - 1:eval_start] for i in range(gen_len): out = model(next_token, past_key_values=past_kv, use_cache=True) past_kv = out.past_key_values logits = out.logits[:, -1, :] nll = torch.nn.functional.cross_entropy(logits, targets[:, i]).item() nlls.append(nll) next_token = targets[:, i].unsqueeze(0) return float(np.exp(np.mean(nlls))) def run_method(model, chunks, device, method, budget, prototypes): ppls = [] torch.cuda.reset_peak_memory_stats() t0 = time.time() for ids in tqdm(chunks, desc=f"{method} B={budget}"): ids = ids.to(device) seq_len = ids.shape[1] eval_start = seq_len - min(128, seq_len // 4) if eval_start < 20: continue ctx = ids[:, :eval_start - 1] with torch.no_grad(): out = model(ctx, use_cache=True) past_kv = out.past_key_values if method != "full" and budget is not None: ctx_len = ctx.shape[1] if method == "proactive": past_kv = evict(past_kv, budget, prototypes, ctx_len, device) elif method == "streaming": idx = streaming_llm_indices(ctx_len, min(budget, ctx_len)) idx_t = torch.tensor(idx, device=device) kv_t = to_tuple_kv(past_kv) pruned = tuple((k.index_select(2, idx_t), v.index_select(2, idx_t)) for k, v in kv_t) past_kv = to_dynamic_cache(pruned) ppl = eval_ppl(model, ids, past_kv, eval_start, device) if ppl: ppls.append(ppl) elapsed = time.time() - t0 vram = torch.cuda.max_memory_allocated() / 1e6 return {"ppl": float(np.mean(ppls)), "vram_mb": vram, "time_s": elapsed} def main(): args = parse_args() bnb_config = BitsAndBytesConfig( load_in_4bit=args.load_in_4bit, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, ) if args.load_in_4bit else None print(f"Loading {args.model}...") tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained( args.model, quantization_config=bnb_config, device_map={"": "cuda"}, trust_remote_code=True, ) model.eval() device = next(model.parameters()).device # Profile and build prototypes prototypes = ProactiveCache.profile( model, tokenizer, corpus=args.dataset, num_docs=30, seq_len=512, ) # Load chunks if args.dataset == "wikitext": from datasets import load_dataset raw = load_dataset("wikitext", "wikitext-103-v1", split="validation") texts = [" ".join(r["text"] for r in raw.select(range(i, i+10)) if r["text"].strip()) for i in range(0, args.num_docs * 10, 10)] else: from datasets import load_dataset raw = list(load_dataset("emozilla/pg19", split="test", streaming=True).take(args.num_docs)) texts = [r["text"][:3000] for r in raw] chunks = [] for t in texts[:args.num_docs]: ids = tokenizer(t, return_tensors="pt", truncation=True, max_length=args.seq_len)["input_ids"] chunks.append(ids) # Benchmark print(f"\n{'='*60}") print(f" Benchmark: {args.model} | {args.dataset.upper()}") print(f"{'='*60}") print(f"{'Method':<22} {'Budget':>6} {'PPL':>8} {'VRAM(MB)':>10} {'Time(s)':>8}") print("-" * 60) full = run_method(model, chunks, device, "full", None, None) print(f"{'Full Attention':<22} {'all':>6} {full['ppl']:>8.2f} {full['vram_mb']:>10.0f} {full['time_s']:>8.1f}") for budget in args.budgets: print("-" * 60) for method, label in [("streaming", "StreamingLLM"), ("proactive", "ProactiveCache")]: r = run_method(model, chunks, device, method, budget, prototypes) print(f"{label:<22} {budget:>6} {r['ppl']:>8.2f} {r['vram_mb']:>10.0f} {r['time_s']:>8.1f}") if __name__ == "__main__": main()