Instructions to use GenomaLabs-com/kv-cache-eviction-mla with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use GenomaLabs-com/kv-cache-eviction-mla with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("GenomaLabs-com/kv-cache-eviction-mla", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| """RULER NIAH (Needle-In-A-Haystack) harness for KV cache eviction benchmarks. | |
| Implements the canonical single-needle NIAH task used to measure long-context | |
| retrieval accuracy. Used as the quality probe for the H2O eviction sweep: | |
| for each context length and budget, run N trials and report exact-match | |
| accuracy on the magic-string needle. | |
| Components: | |
| NIAHTrial - one (haystack, needle, question) instance | |
| NIAHGenerator - produces NIAHTrial at a given context length | |
| NIAHScorer - exact-match scoring against a model response | |
| NIAHRunner - drives N trials + computes accuracy | |
| The model integration is intentionally pluggable: the runner takes a callable | |
| that maps prompt -> response. This lets the same harness drive HF transformers | |
| generation, ollama API calls, or any other completion backend. | |
| Reference: | |
| Hsieh et al. 2024, "RULER: What's the Real Context Size of Your Long-Context | |
| Language Models?" (arXiv:2404.06654). NIAH-1 single-needle variant. | |
| """ | |
| from __future__ import annotations | |
| import csv | |
| import hashlib | |
| import random | |
| import re | |
| import string | |
| import time | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Callable, Optional | |
# ---------------------------------------------------------------------------
# Filler text. RULER draws its haystacks from real long-form text; this note
# originally cited PG-19 + arxiv abstracts as the source corpora
# (NOTE(review): confirm the exact corpora against the RULER paper).
# To keep this harness self-contained we instead use a fixed pool of
# plain-English filler sentences, sampled by NIAHGenerator from a seeded RNG
# and joined with single spaces. The pool is meant to expose the eviction
# policy to the same kinds of structure it meets in real text:
#
#   - Sentence structure (so attention has parseable boundaries)
#   - Some token-level repetition (so heavy-hitter scoring has variance)
#   - No accidental needle-shaped tokens (NIAHGenerator.generate screens each
#     sampled sentence against the magic string)
# ---------------------------------------------------------------------------
_FILLER_SENTENCES = [
    "The mountain ridge cast long shadows over the valley as the sun began to set behind the western peaks.",
    "Old rivers carved deep canyons into the limestone cliffs across many millions of years of patient flow.",
    "Travelers crossing the high pass often paused at the stone marker to read the names carved by earlier wanderers.",
    "The library held leather-bound volumes whose pages had grown brittle but whose ink still spoke clearly to readers.",
    "Sailors learned to read the sky before the instruments arrived, and many trusted the old patterns longer than the new.",
    "Cooks in the harbor kitchens prepared meals from whatever the morning catch yielded, adapting recipes day by day.",
    "Schoolchildren studied the maps with care, tracing rivers from the high country down to the coastal plains.",
    "Engineers built bridges that flexed with the wind rather than fighting it, and the bridges stood for generations.",
    "Farmers in the river valley rotated three crops across four fields, leaving one fallow each year for recovery.",
    "Astronomers tracked the slow drift of distant stars and wrote down patterns that took lifetimes to confirm.",
    "Musicians tuned their instruments by ear in the cold morning air and played until the sun warmed the wood.",
    "Doctors in the clinic kept careful notes on each patient, building a record that helped them spot patterns over time.",
    "Carpenters chose their timber by listening to the grain and tapping the planks to hear the resonance.",
    "Sheep grazed on the high meadows during summer and were brought down to the lower pastures before the snow.",
    "Blacksmiths heated the iron until it glowed and shaped each tool to fit the hand that would use it.",
    "Fishermen mended their nets in the late afternoon while the boats rocked gently against the wooden pier.",
    "Weavers passed the shuttle through the loom with steady hands and the rhythm carried them through long days.",
    "Potters shaped clay on slow wheels and let each piece dry for several days before the firing began.",
    "Beekeepers walked carefully through the orchards in spring and collected honey only when the hives ran full.",
    "Gardeners trained climbing vines along stone walls and harvested berries only when the morning dew had lifted.",
]
# Magic-string needle template. The scorer looks for the magic value as a
# case-insensitive substring of the response, so surrounding punctuation or
# extra words in the model's answer do not affect scoring.
NEEDLE_TEMPLATE = "The secret password is {magic}. Remember this exact value."
# Question appended after the passage in every prompt; its length is also
# reserved when the generator sizes the filler.
QUESTION_TEMPLATE = "What is the secret password? Answer with just the password value."
def _make_magic_string(rng: random.Random, length: int = 12) -> str:
    """Return ``length`` characters drawn from ``rng`` over A-Z and 0-9.

    The value is uppercase alphanumeric. This function does not itself check
    for collisions with the filler text; the generator screens filler
    sentences against the returned string.
    """
    alphabet = string.ascii_uppercase + string.digits
    picks = [rng.choice(alphabet) for _ in range(length)]
    return "".join(picks)
| # --------------------------------------------------------------------------- | |
| # Trial generation | |
| # --------------------------------------------------------------------------- | |
@dataclass
class NIAHTrial:
    """One needle-in-a-haystack instance: the full prompt plus its answer key.

    Declared as a dataclass so it can be built with keyword arguments, as
    ``NIAHGenerator.generate`` does. Without the decorator the class has no
    field-accepting ``__init__``, so that construction raises TypeError.
    """

    trial_id: str  # stable short ID for the trial (set by the generator)
    target_chars: int  # requested context size in characters
    needle_position_frac: float  # requested needle depth, 0.0 to 1.0
    magic: str  # the random magic string embedded in the needle
    haystack_chars: int  # length of the passage text including the needle
    prompt: str  # complete prompt handed to the model
    expected: str  # the magic-string the model must reproduce
class NIAHGenerator:
    """Generate NIAH trials at a target context length (measured in chars).

    To map chars to tokens approximately: English text averages ~4 chars/token
    for most BPE tokenizers, so target_chars=20_000 is roughly 5K tokens.

    All randomness comes from one private ``random.Random`` seeded in
    ``__init__``: the needle depth when not supplied, the magic string, and
    the filler sentences. Two generators with the same seed, driven by the
    same sequence of calls, therefore produce identical trials.
    """

    def __init__(self, seed: int = 42):
        self.rng = random.Random(seed)

    def generate(self, target_chars: int, position_frac: Optional[float] = None) -> NIAHTrial:
        """Build one trial.

        Args:
            target_chars: overall size budget in characters. The filler is
                sized to leave room for the needle, the question and a fixed
                200-char margin (presumably for the instruction wrapper).
            position_frac: requested needle depth in [0, 1], as a fraction of
                the filler length. ``None`` draws one uniformly from
                [0.05, 0.95]. The needle is placed at the first sentence
                boundary at or after that depth, so no filler sentence is split.

        Returns:
            An ``NIAHTrial`` whose ``expected`` field is the magic string and
            whose ``needle_position_frac`` records the requested depth.

        Raises:
            ValueError: if ``position_frac`` is outside [0, 1], or if
                ``target_chars`` is too small for the template overhead.
        """
        if position_frac is None:
            position_frac = self.rng.uniform(0.05, 0.95)
        elif not 0.0 <= position_frac <= 1.0:
            # Out-of-range (or NaN) depths used to index the haystack
            # backwards / past the end; reject them explicitly.
            raise ValueError(f"position_frac={position_frac} must be in [0, 1]")

        magic = _make_magic_string(self.rng)
        needle = NEEDLE_TEMPLATE.format(magic=magic)

        # Build haystack until target_chars reached, leaving room for the
        # needle, the question and a fixed 200-char margin.
        target_haystack = target_chars - len(needle) - len(QUESTION_TEMPLATE) - 200
        if target_haystack < 0:
            raise ValueError(f"target_chars={target_chars} too small for the template overhead")

        chunks = self._sample_filler(target_haystack, magic)
        haystack_with_needle = self._insert_needle(chunks, needle, position_frac)

        prompt = (
            "Read the following passage carefully. After the passage, you will be asked "
            "a question about a specific detail.\n\n"
            "PASSAGE:\n"
            f"{haystack_with_needle}\n\n"
            "QUESTION:\n"
            f"{QUESTION_TEMPLATE}"
        )

        # Stable ID for the trial (content hash, not a security primitive).
        trial_id = hashlib.md5(f"{target_chars}:{position_frac}:{magic}".encode()).hexdigest()[:12]

        return NIAHTrial(
            trial_id=trial_id,
            target_chars=target_chars,
            needle_position_frac=position_frac,
            magic=magic,
            haystack_chars=len(haystack_with_needle),
            prompt=prompt,
            expected=magic,
        )

    def _sample_filler(self, target_len: int, magic: str) -> list[str]:
        """Draw filler sentences until their space-joined length reaches ``target_len``."""
        chunks: list[str] = []
        used = 0
        while used < target_len:
            sentence = self.rng.choice(_FILLER_SENTENCES)
            # Screen: make sure no filler sentence accidentally contains the
            # magic string, which would plant a second, spurious needle
            # (vanishingly unlikely with 12 random alnum chars, but safe).
            if magic.lower() in sentence.lower():
                continue
            chunks.append(sentence)
            used += len(sentence) + 1  # +1 for the space we'll join with
        return chunks

    @staticmethod
    def _insert_needle(chunks: list[str], needle: str, position_frac: float) -> str:
        """Join ``chunks`` with single spaces, placing ``needle`` between sentences.

        The needle goes immediately before the first chunk whose start offset
        in the joined text is at or after ``position_frac`` of the joined
        length. If no such chunk exists, it is appended at the end. Inserting
        at a chunk boundary keeps every filler sentence intact and avoids the
        doubled spaces produced by splicing into the joined string.
        """
        joined_len = sum(len(c) for c in chunks) + max(0, len(chunks) - 1)
        target = int(joined_len * position_frac)
        k = 0
        offset = 0  # start offset of chunks[k] in the joined text
        while k < len(chunks) and offset < target:
            offset += len(chunks[k]) + 1
            k += 1
        return " ".join(chunks[:k] + [needle] + chunks[k:])
| # --------------------------------------------------------------------------- | |
| # Scoring | |
| # --------------------------------------------------------------------------- | |
class NIAHScorer:
    """Exact-match scoring with light normalization.

    A response is correct if and only if the expected magic string appears
    as a contiguous substring (case-insensitive). We don't require
    case-perfect match because tokenizers occasionally case-shift small
    portions of the needle; the magic string itself is uppercase + digits,
    so the case shift is benign.

    Both methods are static: they are called on the class
    (``NIAHScorer.is_correct(...)``) and also work on an instance.
    """

    @staticmethod
    def is_correct(response: str, expected: str) -> bool:
        """Return True if ``expected`` occurs in ``response``, ignoring case.

        An empty ``expected`` never matches. Otherwise the empty string would
        be a substring of every response and score as correct.
        """
        if not expected:
            return False
        return expected.upper() in response.upper()

    @staticmethod
    def score(trials: list[NIAHTrial], responses: list[str]) -> dict:
        """Return aggregate metrics for paired ``trials`` and ``responses``.

        Returns:
            ``{"n_trials", "n_correct", "accuracy"}``; accuracy is 0.0 when
            there are no trials.

        Raises:
            ValueError: if the two lists differ in length. This is an explicit
                check rather than ``assert``, because ``-O`` strips asserts
                and ``zip`` would then silently drop the unmatched tail.
        """
        if len(trials) != len(responses):
            raise ValueError(
                f"trials and responses length mismatch: {len(trials)} != {len(responses)}"
            )
        n = len(trials)
        correct = sum(NIAHScorer.is_correct(r, t.expected) for t, r in zip(trials, responses))
        return {
            "n_trials": n,
            "n_correct": correct,
            "accuracy": correct / n if n else 0.0,
        }
| # --------------------------------------------------------------------------- | |
| # Runner | |
| # --------------------------------------------------------------------------- | |
@dataclass
class NIAHCellResult:
    """Aggregate and per-trial results for one context-length cell.

    Declared as a dataclass so ``run_niah_cell`` can construct it with keyword
    arguments and so ``field(default_factory=list)`` gives each instance its
    own ``per_trial`` list. Without the decorator the constructor call fails
    and ``per_trial`` would be a bare ``Field`` object on the class.
    """

    target_chars: int  # context size every trial in the cell targeted
    n_trials: int
    n_correct: int
    accuracy: float  # fraction of trials answered correctly
    mean_response_chars: float
    elapsed_s: float  # seconds spent running the cell
    per_trial: list[dict] = field(default_factory=list)  # one dict row per trial
def run_niah_cell(
    generate_fn: Callable[[str], str],
    target_chars: int,
    n_trials: int = 25,
    seed: int = 42,
) -> NIAHCellResult:
    """Run one (context_length) cell: ``n_trials`` trials at ``target_chars``.

    Needle depths are spread evenly over [0.05, 0.95]; a single trial sits at
    0.05. Every trial comes from one ``NIAHGenerator(seed=seed)``, so the
    prompts in a cell are reproducible for a given seed.

    Args:
        generate_fn: callable str -> str. Takes the prompt and returns the
            model's response. The harness does not care how the response
            was produced.
        target_chars: context size passed to ``NIAHGenerator.generate``.
        n_trials: number of trials in the cell; 0 yields an empty result.
        seed: seed for the trial generator.

    Returns:
        A NIAHCellResult with aggregate metrics and one ``per_trial`` dict
        per trial. ``elapsed_s`` covers generation, ``generate_fn`` calls and
        scoring together.

    Raises:
        ValueError: if ``n_trials`` is negative, or if the generator rejects
            ``target_chars``.
    """
    if n_trials < 0:
        raise ValueError(f"n_trials must be >= 0, got {n_trials}")

    gen = NIAHGenerator(seed=seed)
    per_trial: list[dict] = []
    n_correct = 0
    total_response_chars = 0
    # perf_counter is monotonic; time.time() follows the system clock, which
    # can be adjusted mid-run and corrupt the elapsed measurement.
    t_start = time.perf_counter()
    for i in range(n_trials):
        # Spread positions evenly through [0.05, 0.95] so every cell exercises
        # depth-position coverage (RULER convention).
        frac = 0.05 + (0.90 * i / max(1, n_trials - 1))
        trial = gen.generate(target_chars=target_chars, position_frac=frac)
        response = generate_fn(trial.prompt)
        correct = NIAHScorer.is_correct(response, trial.expected)
        n_correct += int(correct)
        total_response_chars += len(response)
        per_trial.append({
            "trial_id": trial.trial_id,
            "needle_position_frac": round(trial.needle_position_frac, 3),
            "haystack_chars": trial.haystack_chars,
            "magic": trial.magic,
            "response_chars": len(response),
            "correct": int(correct),
        })
    elapsed = time.perf_counter() - t_start

    return NIAHCellResult(
        target_chars=target_chars,
        n_trials=n_trials,
        n_correct=n_correct,
        accuracy=n_correct / n_trials if n_trials else 0.0,
        mean_response_chars=total_response_chars / n_trials if n_trials else 0.0,
        elapsed_s=elapsed,
        per_trial=per_trial,
    )
def write_cell_csv(result: NIAHCellResult, out_path: Path) -> None:
    """Write ``result.per_trial`` to ``out_path`` as CSV, one row per trial.

    Missing parent directories are created and an existing file is
    overwritten. Column order: trial_id, needle_position_frac,
    haystack_chars, magic, response_chars, correct.
    """
    columns = [
        "trial_id",
        "needle_position_frac",
        "haystack_chars",
        "magic",
        "response_chars",
        "correct",
    ]
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", newline="") as handle:
        csv_writer = csv.DictWriter(handle, fieldnames=columns)
        csv_writer.writeheader()
        csv_writer.writerows(result.per_trial)
| # --------------------------------------------------------------------------- | |
| # Self-test (run this file directly to verify harness logic) | |
| # --------------------------------------------------------------------------- | |
def _selftest_oracle_correct(prompt: str) -> str:
    """Stand-in for a perfect model: find the needle by regex and echo its value."""
    found = re.search(r"The secret password is ([A-Z0-9]+)\.", prompt)
    if found is None:
        return "(no password found)"
    return f"The password is {found.group(1)}."
def _selftest_oracle_wrong(prompt: str) -> str:
    """Stand-in for a broken model: ignores ``prompt`` and never names the needle."""
    refusal = "I cannot find a specific password in the passage."
    return refusal
def _selftest_oracle_partial(prompt: str) -> str:
    """Stand-in for a lossy model that answers correctly about 70% of the time.

    It draws from the global ``random`` module, and only when the needle
    regex matches, so callers can make it reproducible by seeding ``random``.
    """
    found = re.search(r"The secret password is ([A-Z0-9]+)\.", prompt)
    if found is not None:
        if random.random() < 0.7:
            return f"Password: {found.group(1)}"
    return "I'm not sure what the password is."
def selftest():
    """Sanity-check the harness end to end with mock models.

    Checks, in order:
      1. A generated trial contains its needle and scores as expected.
      2. A regex oracle scores 100%.
      3. A model that never answers scores 0%.
      4. A ~70% mock lands in (55%, 85%) over 200 trials.
      5. The per-trial CSV has one row per trial.

    Any failed check raises AssertionError.

    Side effects: prints progress, reseeds the global ``random`` module (the
    partial oracle draws from it), and writes ``niah_selftest.csv`` into the
    platform temp directory.
    """
    import tempfile  # only the self-test needs a scratch location

    print("[niah-selftest] generating one trial at 4000 chars...")
    gen = NIAHGenerator(seed=1)
    trial = gen.generate(target_chars=4000, position_frac=0.5)
    print(f"  trial_id={trial.trial_id}")
    print(f"  haystack_chars={trial.haystack_chars}")
    print(f"  magic={trial.magic}")
    print(f"  needle_position_frac={trial.needle_position_frac}")
    print(f"  prompt[:200]={trial.prompt[:200]!r}")
    print(f"  expected={trial.expected}")
    assert trial.expected in trial.prompt, "needle must appear in the prompt"
    assert NIAHScorer.is_correct(f"The password is {trial.expected}.", trial.expected)
    assert not NIAHScorer.is_correct("I don't know.", trial.expected)
    print("[niah-selftest] generator + scorer basic checks PASS")

    print("\n[niah-selftest] runner with oracle_correct (should be 100%)...")
    r = run_niah_cell(_selftest_oracle_correct, target_chars=4000, n_trials=25, seed=1)
    print(f"  accuracy={r.accuracy:.2%} ({r.n_correct}/{r.n_trials}) elapsed={r.elapsed_s:.2f}s")
    assert r.accuracy == 1.0, f"oracle_correct should be 100%, got {r.accuracy:.2%}"

    print("\n[niah-selftest] runner with oracle_wrong (should be 0%)...")
    r = run_niah_cell(_selftest_oracle_wrong, target_chars=4000, n_trials=25, seed=1)
    print(f"  accuracy={r.accuracy:.2%} ({r.n_correct}/{r.n_trials}) elapsed={r.elapsed_s:.2f}s")
    assert r.accuracy == 0.0, f"oracle_wrong should be 0%, got {r.accuracy:.2%}"

    print("\n[niah-selftest] runner with oracle_partial (should be ~70%)...")
    random.seed(42)  # the partial oracle uses the global RNG
    r = run_niah_cell(_selftest_oracle_partial, target_chars=4000, n_trials=200, seed=1)
    print(f"  accuracy={r.accuracy:.2%} ({r.n_correct}/{r.n_trials}) elapsed={r.elapsed_s:.2f}s")
    assert 0.55 < r.accuracy < 0.85, f"oracle_partial should be ~70%, got {r.accuracy:.2%}"

    print("\n[niah-selftest] writing per-trial CSV...")
    # Platform temp dir instead of a hard-coded /tmp, so this also runs on Windows.
    out = Path(tempfile.gettempdir()) / "niah_selftest.csv"
    write_cell_csv(r, out)
    # Close the read handle deterministically instead of leaking it.
    with out.open() as fh:
        n_rows = sum(1 for _ in fh) - 1  # minus the header row
    print(f"  wrote {n_rows} rows -> {out}")
    assert n_rows == 200

    print("\n[niah-selftest] ALL CHECKS PASS")
# Running this module directly as a script executes the harness self-test.
if __name__ == "__main__":
    selftest()