# kv-cache-eviction-mla/scripts/niah_harness.py
# GENOMA LABS / research
# Round 3a: RULER NIAH single-needle harness (logic validated)
# 5ea05a3
"""RULER NIAH (Needle-In-A-Haystack) harness for KV cache eviction benchmarks.
Implements the canonical single-needle NIAH task used to measure long-context
retrieval accuracy. Used as the quality probe for the H2O eviction sweep:
for each context length and budget, run N trials and report exact-match
accuracy on the magic-string needle.
Components:
    NIAHTrial      - one (haystack, needle, question) instance
    NIAHGenerator  - produces NIAHTrial at a given context length
    NIAHScorer     - case-insensitive match of the needle value in a model response
    run_niah_cell  - drives N trials + computes accuracy (returns NIAHCellResult)
    write_cell_csv - writes the per-trial results of one cell to CSV
The model integration is intentionally pluggable: the runner takes a callable
that maps prompt -> response. This lets the same harness drive HF transformers
generation, ollama API calls, or any other completion backend.
Reference:
Hsieh et al. 2024, "RULER: What's the Real Context Size of Your Long-Context
Language Models?" (arXiv:2404.06654). NIAH-1 single-needle variant.
"""
from __future__ import annotations
import csv
import hashlib
import random
import re
import string
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Optional
# ---------------------------------------------------------------------------
# Filler text. RULER's NIAH haystacks are Paul Graham essays (or repeated
# synthetic "noise" sentences); for a self-contained harness we instead
# sample, with a seeded RNG, from a fixed pool of plain-English sentences
# chosen to cover the same statistical properties the eviction policy faces:
#
# - Sentence and paragraph structure (so attention has parseable boundaries)
# - Some token-level repetition (so heavy-hitter scoring has variance)
# - No accidental needle-shaped tokens (we screen below)
# ---------------------------------------------------------------------------
# Pool of 20 complete plain-English sentences. NIAHGenerator draws from it
# with replacement (rng.choice) and joins the picks with single spaces, so
# sentence boundaries in the haystack always fall on ". " separators.
_FILLER_SENTENCES = [
    "The mountain ridge cast long shadows over the valley as the sun began to set behind the western peaks.",
    "Old rivers carved deep canyons into the limestone cliffs across many millions of years of patient flow.",
    "Travelers crossing the high pass often paused at the stone marker to read the names carved by earlier wanderers.",
    "The library held leather-bound volumes whose pages had grown brittle but whose ink still spoke clearly to readers.",
    "Sailors learned to read the sky before the instruments arrived, and many trusted the old patterns longer than the new.",
    "Cooks in the harbor kitchens prepared meals from whatever the morning catch yielded, adapting recipes day by day.",
    "Schoolchildren studied the maps with care, tracing rivers from the high country down to the coastal plains.",
    "Engineers built bridges that flexed with the wind rather than fighting it, and the bridges stood for generations.",
    "Farmers in the river valley rotated three crops across four fields, leaving one fallow each year for recovery.",
    "Astronomers tracked the slow drift of distant stars and wrote down patterns that took lifetimes to confirm.",
    "Musicians tuned their instruments by ear in the cold morning air and played until the sun warmed the wood.",
    "Doctors in the clinic kept careful notes on each patient, building a record that helped them spot patterns over time.",
    "Carpenters chose their timber by listening to the grain and tapping the planks to hear the resonance.",
    "Sheep grazed on the high meadows during summer and were brought down to the lower pastures before the snow.",
    "Blacksmiths heated the iron until it glowed and shaped each tool to fit the hand that would use it.",
    "Fishermen mended their nets in the late afternoon while the boats rocked gently against the wooden pier.",
    "Weavers passed the shuttle through the loom with steady hands and the rhythm carried them through long days.",
    "Potters shaped clay on slow wheels and let each piece dry for several days before the firing began.",
    "Beekeepers walked carefully through the orchards in spring and collected honey only when the hives ran full.",
    "Gardeners trained climbing vines along stone walls and harvested berries only when the morning dew had lifted.",
]
# Magic-string needle template. The scorer matches only the {magic} value
# (case-insensitive substring of the response), not this whole sentence, so
# surrounding words and punctuation in a model's answer are tolerated.
NEEDLE_TEMPLATE = "The secret password is {magic}. Remember this exact value."
# Question appended after the passage in every generated prompt.
QUESTION_TEMPLATE = "What is the secret password? Answer with just the password value."
def _make_magic_string(rng: random.Random, length: int = 12) -> str:
"""Generate a unique alphanumeric needle that won't collide with filler."""
return "".join(rng.choice(string.ascii_uppercase + string.digits) for _ in range(length))
# ---------------------------------------------------------------------------
# Trial generation
# ---------------------------------------------------------------------------
@dataclass
class NIAHTrial:
    """One single-needle NIAH instance, as built by NIAHGenerator.generate."""

    trial_id: str  # first 12 hex chars of an MD5 over (target_chars, position, magic)
    target_chars: int  # requested prompt length budget, in characters
    needle_position_frac: float  # 0.0 to 1.0; relative depth of the needle in the haystack
    magic: str  # the random needle value planted in the haystack
    haystack_chars: int  # length of the passage including the needle (excludes prompt framing)
    prompt: str  # full model input: instructions + passage + question
    expected: str  # the magic-string the model must reproduce
class NIAHGenerator:
    """Generates NIAH trials at a target context length (measured in chars).

    To map chars to tokens approximately: English text averages ~4 chars/token
    for most BPE tokenizers. So target_chars=20_000 ~= 5K tokens.

    All randomness (needle depth when not supplied, the magic string, and the
    filler sentence picks) comes from a single ``random.Random`` seeded in
    ``__init__``, so two generators built with the same seed produce the same
    trials for the same sequence of calls.
    """

    # Character allowance reserved for the fixed prompt framing around the
    # passage (instructions, "PASSAGE:"/"QUESTION:" headers, separators).
    _PROMPT_OVERHEAD_CHARS = 200

    def __init__(self, seed: int = 42):
        self.rng = random.Random(seed)

    def generate(self, target_chars: int, position_frac: Optional[float] = None) -> NIAHTrial:
        """Build one trial.

        The prompt length is approximate: filler is appended a whole sentence
        at a time against a fixed framing allowance, so the final prompt can
        land somewhat above or below ``target_chars``.

        Args:
            target_chars: approximate prompt length budget, in characters.
            position_frac: needle depth within the haystack, in [0, 1].
                None draws a depth uniformly from [0.05, 0.95].

        Returns:
            A NIAHTrial whose ``expected`` field is the planted magic string.

        Raises:
            ValueError: if ``position_frac`` is outside [0, 1] (or NaN), or if
                ``target_chars`` cannot fit the needle, question and framing.
        """
        if position_frac is None:
            position_frac = self.rng.uniform(0.05, 0.95)
        elif not 0.0 <= position_frac <= 1.0:
            # Reject out-of-range depths: a negative fraction would become a
            # negative slice index and silently plant the needle near the END
            # of the haystack instead of the start.
            raise ValueError(f"position_frac={position_frac} must be in [0, 1]")
        magic = _make_magic_string(self.rng)
        needle = NEEDLE_TEMPLATE.format(magic=magic)
        # Build haystack until target_chars reached, leaving room for the
        # needle, the question and the prompt framing.
        target_haystack = (
            target_chars - len(needle) - len(QUESTION_TEMPLATE) - self._PROMPT_OVERHEAD_CHARS
        )
        if target_haystack < 0:
            raise ValueError(f"target_chars={target_chars} too small for the template overhead")
        haystack_text = self._build_haystack(target_haystack, magic)
        haystack_with_needle = self._insert_needle(haystack_text, needle, position_frac)
        prompt = (
            "Read the following passage carefully. After the passage, you will be asked "
            "a question about a specific detail.\n\n"
            "PASSAGE:\n"
            f"{haystack_with_needle}\n\n"
            "QUESTION:\n"
            f"{QUESTION_TEMPLATE}"
        )
        # Stable ID for the trial: depends only on (target_chars, depth, magic).
        trial_id = hashlib.md5(f"{target_chars}:{position_frac}:{magic}".encode()).hexdigest()[:12]
        return NIAHTrial(
            trial_id=trial_id,
            target_chars=target_chars,
            needle_position_frac=position_frac,
            magic=magic,
            haystack_chars=len(haystack_with_needle),
            prompt=prompt,
            expected=magic,
        )

    def _build_haystack(self, target_haystack: int, magic: str) -> str:
        """Join randomly chosen filler sentences until ~target_haystack chars."""
        chunks = []
        used = 0
        while used < target_haystack:
            sentence = self.rng.choice(_FILLER_SENTENCES)
            # Screen: make sure no filler sentence accidentally contains the
            # magic string (vanishingly unlikely with 12 random alnum, but safe).
            if magic.lower() in sentence.lower():
                continue
            chunks.append(sentence)
            used += len(sentence) + 1  # +1 for the space we join with
        return " ".join(chunks)

    @staticmethod
    def _insert_needle(haystack_text: str, needle: str, position_frac: float) -> str:
        """Insert ``needle`` at the first word boundary at/after ``position_frac``."""
        insert_idx = int(len(haystack_text) * position_frac)
        # Snap forward to a word boundary (a space) so no word is split.
        while insert_idx < len(haystack_text) and haystack_text[insert_idx] != " ":
            insert_idx += 1
        head = haystack_text[:insert_idx]
        # When in range, haystack_text[insert_idx] is the boundary space itself;
        # skip it so the needle is flanked by exactly one space on each side
        # (keeping it left a double space right after every needle).
        tail = haystack_text[insert_idx + 1:]
        return " ".join(part for part in (head, needle, tail) if part)
# ---------------------------------------------------------------------------
# Scoring
# ---------------------------------------------------------------------------
class NIAHScorer:
    """Exact-match scoring with light normalization.

    A response is correct if and only if the expected magic string appears
    as a contiguous substring (case-insensitive). We don't require
    case-perfect match because tokenizers occasionally case-shift small
    portions of the needle; the magic string itself is uppercase + digits,
    so the case shift is benign. Because this is a substring test, extra
    text around the value (e.g. "The password is X.") is accepted.
    """

    @staticmethod
    def is_correct(response: str, expected: str) -> bool:
        """Return True iff ``expected`` occurs in ``response``, ignoring case.

        An empty ``expected`` is never counted as correct: the empty string is
        a substring of every response, so accepting it would score any output
        (including a refusal) as a hit.
        """
        if not expected:
            return False
        return expected.upper() in response.upper()

    @staticmethod
    def score(trials: list[NIAHTrial], responses: list[str]) -> dict:
        """Return aggregate metrics for paired trials and responses.

        Args:
            trials: trials whose ``expected`` values are checked.
            responses: model responses, aligned index-by-index with ``trials``.

        Returns:
            ``{"n_trials": int, "n_correct": int, "accuracy": float}``;
            accuracy is 0.0 when there are no trials.

        Raises:
            ValueError: if ``trials`` and ``responses`` differ in length.
        """
        if len(trials) != len(responses):
            # Explicit raise rather than assert: asserts vanish under
            # ``python -O`` and zip() would then silently truncate.
            raise ValueError("trials and responses length mismatch")
        n = len(trials)
        correct = sum(NIAHScorer.is_correct(r, t.expected) for r, t in zip(responses, trials))
        return {
            "n_trials": n,
            "n_correct": correct,
            "accuracy": correct / n if n else 0.0,
        }
# ---------------------------------------------------------------------------
# Runner
# ---------------------------------------------------------------------------
@dataclass
class NIAHCellResult:
    """Aggregate and per-trial results for one context-length cell (built by run_niah_cell)."""

    target_chars: int  # context length budget every trial in the cell was generated at
    n_trials: int
    n_correct: int
    accuracy: float  # n_correct / n_trials (0.0 when n_trials == 0)
    mean_response_chars: float  # mean len(response) across trials (0.0 when n_trials == 0)
    elapsed_s: float  # elapsed seconds for the whole cell, including generate_fn calls
    per_trial: list[dict] = field(default_factory=list)  # one dict per trial; keys match write_cell_csv columns
def run_niah_cell(
    generate_fn: Callable[[str], str],
    target_chars: int,
    n_trials: int = 25,
    seed: int = 42,
) -> NIAHCellResult:
    """Run one (context_length) cell: N trials at the given target_chars.

    Args:
        generate_fn: callable str -> str. Takes the prompt, returns the model's
            response. The harness does not care how the response was produced.
        target_chars: context length budget passed to NIAHGenerator.generate.
        n_trials: number of trials. Needle depths are spread evenly over
            [0.05, 0.95]; a single trial sits at 0.05.
        seed: seed for the NIAHGenerator, so the generated trials are
            reproducible.

    Returns:
        A NIAHCellResult with per-trial details + aggregate metrics. With
        n_trials == 0, accuracy and mean_response_chars are 0.0.

    Raises:
        ValueError: if n_trials is negative.
    """
    if n_trials < 0:
        raise ValueError(f"n_trials={n_trials} must be >= 0")
    gen = NIAHGenerator(seed=seed)
    per_trial = []
    n_correct = 0
    total_response_chars = 0
    # perf_counter is monotonic; time.time() can jump if the system clock is
    # adjusted mid-run, which would corrupt elapsed_s.
    t_start = time.perf_counter()
    for i in range(n_trials):
        # Spread positions evenly through [0.05, 0.95] so every cell exercises
        # depth-position coverage (RULER convention).
        frac = 0.05 + (0.90 * i / max(1, n_trials - 1))
        trial = gen.generate(target_chars=target_chars, position_frac=frac)
        response = generate_fn(trial.prompt)
        correct = NIAHScorer.is_correct(response, trial.expected)
        n_correct += int(correct)
        total_response_chars += len(response)
        per_trial.append({
            "trial_id": trial.trial_id,
            "needle_position_frac": round(trial.needle_position_frac, 3),
            "haystack_chars": trial.haystack_chars,
            "magic": trial.magic,
            "response_chars": len(response),
            "correct": int(correct),
        })
    elapsed = time.perf_counter() - t_start
    return NIAHCellResult(
        target_chars=target_chars,
        n_trials=n_trials,
        n_correct=n_correct,
        accuracy=n_correct / n_trials if n_trials else 0.0,
        mean_response_chars=total_response_chars / n_trials if n_trials else 0.0,
        elapsed_s=elapsed,
        per_trial=per_trial,
    )
def write_cell_csv(result: NIAHCellResult, out_path: Path) -> None:
    """Write one CSV row per trial of ``result`` to ``out_path``.

    Parent directories are created if missing; an existing file is
    overwritten. A ``str`` path is also accepted. The columns are the keys of
    the per-trial dicts built by run_niah_cell; a row carrying any other key
    makes csv.DictWriter raise ValueError.
    """
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # newline="" is required by the csv module; the encoding is pinned so the
    # output does not depend on the platform's locale default.
    with open(out_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(
            f,
            fieldnames=["trial_id", "needle_position_frac", "haystack_chars", "magic", "response_chars", "correct"],
        )
        writer.writeheader()
        writer.writerows(result.per_trial)
# ---------------------------------------------------------------------------
# Self-test (run this file directly to verify harness logic)
# ---------------------------------------------------------------------------
def _selftest_oracle_correct(prompt: str) -> str:
"""Mock 'perfect' model: extract the needle by regex and reproduce it."""
m = re.search(r"The secret password is ([A-Z0-9]+)\.", prompt)
return f"The password is {m.group(1)}." if m else "(no password found)"
def _selftest_oracle_wrong(prompt: str) -> str:
"""Mock 'broken' model: returns a generic response without the needle."""
return "I cannot find a specific password in the passage."
def _selftest_oracle_partial(prompt: str) -> str:
"""Mock 'lossy' model: 70% accuracy."""
m = re.search(r"The secret password is ([A-Z0-9]+)\.", prompt)
if m and random.random() < 0.7:
return f"Password: {m.group(1)}"
return "I'm not sure what the password is."
def selftest():
    """Check the generator, scorer, runner, and CSV writer against mock models.

    Runs three oracles (always right, always wrong, ~70% right) through
    run_niah_cell and asserts the resulting accuracies, then writes the
    partial-oracle cell to ``niah_selftest.csv`` in the system temp directory
    (tempfile.gettempdir(), typically /tmp) and checks its row count. Raises
    AssertionError on the first failed check.

    Side effect: reseeds the global ``random`` module (the partial oracle
    draws from it).
    """
    import tempfile  # local: only the self-test needs a scratch path

    print("[niah-selftest] generating one trial at 4000 chars...")
    gen = NIAHGenerator(seed=1)
    trial = gen.generate(target_chars=4000, position_frac=0.5)
    print(f" trial_id={trial.trial_id}")
    print(f" haystack_chars={trial.haystack_chars}")
    print(f" magic={trial.magic}")
    print(f" needle_position_frac={trial.needle_position_frac}")
    print(f" prompt[:200]={trial.prompt[:200]!r}")
    print(f" expected={trial.expected}")
    assert trial.expected in trial.prompt, "needle must appear in the prompt"
    assert NIAHScorer.is_correct(f"The password is {trial.expected}.", trial.expected)
    assert not NIAHScorer.is_correct("I don't know.", trial.expected)
    print("[niah-selftest] generator + scorer basic checks PASS")

    print("\n[niah-selftest] runner with oracle_correct (should be 100%)...")
    r = run_niah_cell(_selftest_oracle_correct, target_chars=4000, n_trials=25, seed=1)
    print(f" accuracy={r.accuracy:.2%} ({r.n_correct}/{r.n_trials}) elapsed={r.elapsed_s:.2f}s")
    assert r.accuracy == 1.0, f"oracle_correct should be 100%, got {r.accuracy:.2%}"

    print("\n[niah-selftest] runner with oracle_wrong (should be 0%)...")
    r = run_niah_cell(_selftest_oracle_wrong, target_chars=4000, n_trials=25, seed=1)
    print(f" accuracy={r.accuracy:.2%} ({r.n_correct}/{r.n_trials}) elapsed={r.elapsed_s:.2f}s")
    assert r.accuracy == 0.0, f"oracle_wrong should be 0%, got {r.accuracy:.2%}"

    print("\n[niah-selftest] runner with oracle_partial (should be ~70%)...")
    random.seed(42)  # the partial oracle draws from the global RNG
    r = run_niah_cell(_selftest_oracle_partial, target_chars=4000, n_trials=200, seed=1)
    print(f" accuracy={r.accuracy:.2%} ({r.n_correct}/{r.n_trials}) elapsed={r.elapsed_s:.2f}s")
    assert 0.55 < r.accuracy < 0.85, f"oracle_partial should be ~70%, got {r.accuracy:.2%}"

    print("\n[niah-selftest] writing per-trial CSV...")
    # Portable scratch location instead of a hard-coded /tmp path.
    out = Path(tempfile.gettempdir()) / "niah_selftest.csv"
    write_cell_csv(r, out)
    # Context manager closes the handle (the bare open() previously leaked it).
    with open(out, encoding="utf-8") as f:
        n_rows = sum(1 for _ in f) - 1  # exclude the header row
    print(f" wrote {n_rows} rows -> {out}")
    assert n_rows == 200, f"expected 200 CSV rows, got {n_rows}"
    print("\n[niah-selftest] ALL CHECKS PASS")
if __name__ == "__main__":
    # Executing this file directly runs the harness self-test.
    selftest()