Instructions to use GenomaLabs-com/kv-cache-eviction-mla with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use GenomaLabs-com/kv-cache-eviction-mla with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("GenomaLabs-com/kv-cache-eviction-mla", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
File size: 5,417 Bytes
"""Round 3b: validate the NIAH harness against a small control model on plain
HuggingFace transformers (no AirLLM). Confirms the harness works end-to-end
on a real LLM before we point it at the much heavier Kimi K2.6 path.
Defaults to Qwen2.5-3B-Instruct (downloads ~6 GB on first run). For a smaller
footprint, pass `--model unsloth/Qwen2.5-3B-Instruct-unsloth-bnb-4bit` (~2 GB).
Run:
python scripts/run_control_niah.py \
--model Qwen/Qwen2.5-3B-Instruct \
--context-lengths 2000 8000 \
--trials 10 \
--out-dir /tmp/niah_control
Expected: Qwen2.5-3B should score >= 80% on the 2K-char (~500-token) cell.
The 8K cell will be lower; that's the point — proves the harness picks up
context-length quality decay.
"""
from __future__ import annotations
import argparse
import csv
import sys
import time
from pathlib import Path
# Local imports
sys.path.insert(0, str(Path(__file__).resolve().parent))
from niah_harness import run_niah_cell, write_cell_csv
def make_hf_generate_fn(model_id: str, dtype: str = "bfloat16", device: str = "auto", max_new_tokens: int = 64):
    """Load a HF causal LM and wrap it in a greedy ``prompt -> completion`` callable.

    Args:
        model_id: Hugging Face model ID or local path, loaded with
            ``trust_remote_code=True``.
        dtype: Name of a ``torch`` dtype (e.g. ``"bfloat16"``, ``"float16"``)
            or ``"auto"`` to let transformers choose from the checkpoint.
            Names that are not real ``torch.dtype`` objects fall back to
            ``torch.float16`` with a warning.
        device: Passed through to ``from_pretrained`` as ``device_map``.
        max_new_tokens: Maximum number of tokens generated per call.

    Returns:
        ``(model, tokenizer, generate_fn)``. ``generate_fn(prompt)`` returns
        only the newly generated text, decoded with special tokens stripped.
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    print(f"[control] loading {model_id} ({dtype}) ...")
    t0 = time.time()
    tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

    if dtype == "auto":
        # Let transformers pick the dtype from the checkpoint config instead of
        # silently forcing float16.
        torch_dtype = "auto"
    else:
        # hasattr(torch, name) alone is not enough: "cuda", "nn", "Tensor", ...
        # are torch attributes but not dtypes and must not reach from_pretrained.
        torch_dtype = getattr(torch, dtype, None)
        if not isinstance(torch_dtype, torch.dtype):
            print(f"[control] WARNING: {dtype!r} is not a torch dtype; falling back to float16")
            torch_dtype = torch.float16

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        dtype=torch_dtype,
        device_map=device,
        trust_remote_code=True,
    )
    model.eval()
    n_params = sum(p.numel() for p in model.parameters())
    print(f"[control] loaded in {time.time()-t0:.1f}s, params: {n_params:,}")

    def generate(prompt: str) -> str:
        # NOTE(review): the prompt is tokenized raw (no tok.apply_chat_template);
        # confirm whether niah_harness already formats prompts for instruct models.
        inputs = tok(prompt, return_tensors="pt").to(model.device)
        prompt_len = inputs["input_ids"].shape[1]
        with torch.no_grad():
            out = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                # Greedy decoding; temperature is not passed because it is
                # ignored when do_sample=False.
                do_sample=False,
                pad_token_id=tok.eos_token_id,
            )
        # generate() returns prompt + continuation; keep only the new tokens.
        new_tokens = out[0, prompt_len:]
        return tok.decode(new_tokens, skip_special_tokens=True)

    return model, tok, generate
def main():
    """Run the NIAH control sweep from the command line.

    Loads ``--model`` once, then runs ``run_niah_cell`` for every value in
    ``--context-lengths`` (haystack size in characters). It writes one
    per-trial CSV per cell plus ``summary.csv`` into ``--out-dir``. Finally it
    prints a PASS/WARNING verdict based on the smallest requested context.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", default="Qwen/Qwen2.5-3B-Instruct",
                    help="HF model ID. Defaults to Qwen2.5-3B-Instruct.")
    ap.add_argument("--dtype", default="bfloat16")
    ap.add_argument("--device", default="auto")
    ap.add_argument("--context-lengths", type=int, nargs="+", default=[2000, 8000],
                    help="Haystack sizes in characters (not tokens); one cell per value.")
    ap.add_argument("--trials", type=int, default=10)
    ap.add_argument("--max-new-tokens", type=int, default=64)
    ap.add_argument("--out-dir", default="/tmp/niah_control")
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Only the generate closure is used below; model/tokenizer stay referenced by it.
    _model, _tok, gen = make_hf_generate_fn(
        args.model, dtype=args.dtype, device=args.device, max_new_tokens=args.max_new_tokens
    )

    print(f"\n[control] sweep: {len(args.context_lengths)} cells, {args.trials} trials each")
    summary = []
    for ctx in args.context_lengths:
        cell_id = f"ctx{ctx}_baseline"
        print(f"\n[control] === cell {cell_id} ===")
        result = run_niah_cell(
            generate_fn=gen,
            target_chars=ctx,
            n_trials=args.trials,
            seed=args.seed,
        )
        cell_csv = out_dir / f"{cell_id}.csv"
        write_cell_csv(result, cell_csv)
        print(f"[control] accuracy={result.accuracy:.2%} ({result.n_correct}/{result.n_trials}) "
              f"elapsed={result.elapsed_s:.1f}s mean_response_chars={result.mean_response_chars:.0f}")
        summary.append({
            "cell_id": cell_id,
            "ctx_chars": ctx,
            "n_trials": result.n_trials,
            "n_correct": result.n_correct,
            "accuracy": round(result.accuracy, 4),
            "elapsed_s": round(result.elapsed_s, 1),
        })

    if not summary:
        return

    summary_path = out_dir / "summary.csv"
    with open(summary_path, "w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=list(summary[0].keys()))
        w.writeheader()
        w.writerows(summary)
    print(f"\n[control] summary -> {summary_path}")

    # Sanity check on the smallest requested context (2K chars with the default
    # sweep): the needle should be found there. Below 50% means either the
    # harness is broken or the model cannot do retrieval at all.
    smallest = min(summary, key=lambda s: s["ctx_chars"])
    acc = smallest["accuracy"]
    if acc < 0.5:
        print(f"\n[control] WARNING: smallest-context accuracy {acc:.2%} is < 50%.")
        print("[control] either the harness is broken or the model can't do NIAH at all.")
        print("[control] inspect a per-trial CSV to see what the model actually produced.")
    else:
        print(f"\n[control] PASS: smallest-context accuracy {acc:.2%} >= 50%.")
        print("[control] harness validated end-to-end on a real LLM.")
if __name__ == "__main__":
    # Script entry point; main() returns None, so the process exits with status 0
    # unless main() raises.
    sys.exit(main())
|