GENOMA LABS / research Claude Opus 4.7 (1M context) commited on
Commit
5ea05a3
·
1 Parent(s): 95298f5

Round 3a: RULER NIAH single-needle harness (logic validated)

Browse files

scripts/niah_harness.py implements the canonical RULER NIAH-1 task as a
pluggable harness for the upcoming H2O eviction quality sweep:

- NIAHGenerator: builds (haystack, needle, question) at a target context
length in chars, with position-fraction control for depth-coverage sweeps.
Uses 20 deterministic filler sentences (PG-19-style English structure)
and 12-char alphanumeric magic-string needles.
- NIAHScorer: case-insensitive exact-match scoring on the magic string.
- run_niah_cell: drives N trials at one context length, calls a pluggable
generate_fn(prompt) -> response callable. Same harness will drive HF
transformers + ollama + AirLLM Kimi backends without modification.
- write_cell_csv: per-trial output for analysis.

Self-test (no model required, runs in <1s):
- oracle_correct mock: 100% accuracy (25/25)
- oracle_wrong mock: 0% accuracy (0/25)
- oracle_partial 70% mock: 72.5% accuracy (145/200, within expected band)
- Per-trial CSV: 200 rows written cleanly

Ready for Round 3b (validation on Qwen2.5-7B control baseline) and Round 4
(full Kimi K2.6 sweep) once GS GPU headroom opens up post-Kimi-split.

Cross-ref: arXiv:2404.06654 (RULER paper) for the canonical task definition.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Files changed (1) hide show
  1. scripts/niah_harness.py +338 -0
scripts/niah_harness.py ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """RULER NIAH (Needle-In-A-Haystack) harness for KV cache eviction benchmarks.
2
+
3
+ Implements the canonical single-needle NIAH task used to measure long-context
4
+ retrieval accuracy. Used as the quality probe for the H2O eviction sweep:
5
+ for each context length and budget, run N trials and report exact-match
6
+ accuracy on the magic-string needle.
7
+
8
+ Components:
9
+ NIAHTrial - one (haystack, needle, question) instance
10
+ NIAHGenerator - produces NIAHTrial at a given context length
11
+ NIAHScorer - exact-match scoring against a model response
12
+ NIAHRunner - drives N trials + computes accuracy
13
+
14
+ The model integration is intentionally pluggable: the runner takes a callable
15
+ that maps prompt -> response. This lets the same harness drive HF transformers
16
+ generation, ollama API calls, or any other completion backend.
17
+
18
+ Reference:
19
+ Hsieh et al. 2024, "RULER: What's the Real Context Size of Your Long-Context
20
+ Language Models?" (arXiv:2404.06654). NIAH-1 single-needle variant.
21
+ """
22
+ from __future__ import annotations
23
+
24
+ import csv
25
+ import hashlib
26
+ import random
27
+ import re
28
+ import string
29
+ import time
30
+ from dataclasses import dataclass, field
31
+ from pathlib import Path
32
+ from typing import Callable, Optional
33
+
34
+
35
+ # ---------------------------------------------------------------------------
36
+ # Filler text. RULER uses essays from PG-19 + arxiv abstracts; for a
37
+ # self-contained harness we use a deterministic Lorem-style filler that
38
+ # covers the same statistical properties the eviction policy faces:
39
+ #
40
+ # - Sentence and paragraph structure (so attention has parseable boundaries)
41
+ # - Some token-level repetition (so heavy-hitter scoring has variance)
42
+ # - No accidental needle-shaped tokens (we screen below)
43
+ # ---------------------------------------------------------------------------
44
+
45
+ _FILLER_SENTENCES = [
46
+ "The mountain ridge cast long shadows over the valley as the sun began to set behind the western peaks.",
47
+ "Old rivers carved deep canyons into the limestone cliffs across many millions of years of patient flow.",
48
+ "Travelers crossing the high pass often paused at the stone marker to read the names carved by earlier wanderers.",
49
+ "The library held leather-bound volumes whose pages had grown brittle but whose ink still spoke clearly to readers.",
50
+ "Sailors learned to read the sky before the instruments arrived, and many trusted the old patterns longer than the new.",
51
+ "Cooks in the harbor kitchens prepared meals from whatever the morning catch yielded, adapting recipes day by day.",
52
+ "Schoolchildren studied the maps with care, tracing rivers from the high country down to the coastal plains.",
53
+ "Engineers built bridges that flexed with the wind rather than fighting it, and the bridges stood for generations.",
54
+ "Farmers in the river valley rotated three crops across four fields, leaving one fallow each year for recovery.",
55
+ "Astronomers tracked the slow drift of distant stars and wrote down patterns that took lifetimes to confirm.",
56
+ "Musicians tuned their instruments by ear in the cold morning air and played until the sun warmed the wood.",
57
+ "Doctors in the clinic kept careful notes on each patient, building a record that helped them spot patterns over time.",
58
+ "Carpenters chose their timber by listening to the grain and tapping the planks to hear the resonance.",
59
+ "Sheep grazed on the high meadows during summer and were brought down to the lower pastures before the snow.",
60
+ "Blacksmiths heated the iron until it glowed and shaped each tool to fit the hand that would use it.",
61
+ "Fishermen mended their nets in the late afternoon while the boats rocked gently against the wooden pier.",
62
+ "Weavers passed the shuttle through the loom with steady hands and the rhythm carried them through long days.",
63
+ "Potters shaped clay on slow wheels and let each piece dry for several days before the firing began.",
64
+ "Beekeepers walked carefully through the orchards in spring and collected honey only when the hives ran full.",
65
+ "Gardeners trained climbing vines along stone walls and harvested berries only when the morning dew had lifted.",
66
+ ]
67
+
68
+ # Magic-string needle template. The scorer matches this exact string
69
+ # (case-insensitive, allowing surrounding punctuation).
70
+ NEEDLE_TEMPLATE = "The secret password is {magic}. Remember this exact value."
71
+ QUESTION_TEMPLATE = "What is the secret password? Answer with just the password value."
72
+
73
+
74
+ def _make_magic_string(rng: random.Random, length: int = 12) -> str:
75
+ """Generate a unique alphanumeric needle that won't collide with filler."""
76
+ return "".join(rng.choice(string.ascii_uppercase + string.digits) for _ in range(length))
77
+
78
+
79
+ # ---------------------------------------------------------------------------
80
+ # Trial generation
81
+ # ---------------------------------------------------------------------------
82
+
83
+ @dataclass
84
+ class NIAHTrial:
85
+ trial_id: str
86
+ target_chars: int
87
+ needle_position_frac: float # 0.0 to 1.0
88
+ magic: str
89
+ haystack_chars: int
90
+ prompt: str
91
+ expected: str # the magic-string the model must reproduce
92
+
93
+
94
+ class NIAHGenerator:
95
+ """Generates NIAH trials at a target context length (measured in chars).
96
+
97
+ To map chars to tokens approximately: English text averages ~4 chars/token
98
+ for most BPE tokenizers. So target_chars=20_000 ~= 5K tokens.
99
+ """
100
+
101
+ def __init__(self, seed: int = 42):
102
+ self.rng = random.Random(seed)
103
+
104
+ def generate(self, target_chars: int, position_frac: Optional[float] = None) -> NIAHTrial:
105
+ """Build one trial. position_frac in [0,1]; None = random."""
106
+ if position_frac is None:
107
+ position_frac = self.rng.uniform(0.05, 0.95)
108
+
109
+ magic = _make_magic_string(self.rng)
110
+ needle = NEEDLE_TEMPLATE.format(magic=magic)
111
+
112
+ # Build haystack until target_chars reached, leaving room for the needle.
113
+ target_haystack = target_chars - len(needle) - len(QUESTION_TEMPLATE) - 200
114
+ if target_haystack < 0:
115
+ raise ValueError(f"target_chars={target_chars} too small for the template overhead")
116
+
117
+ chunks = []
118
+ used = 0
119
+ while used < target_haystack:
120
+ sentence = self.rng.choice(_FILLER_SENTENCES)
121
+ # screen — make sure no filler sentence accidentally contains the
122
+ # magic string (vanishingly unlikely with 12 random alnum, but safe)
123
+ if magic.lower() in sentence.lower():
124
+ continue
125
+ chunks.append(sentence)
126
+ used += len(sentence) + 1 # +1 for the space we'll join with
127
+
128
+ haystack_text = " ".join(chunks)
129
+
130
+ # Insert needle at position_frac
131
+ insert_idx = int(len(haystack_text) * position_frac)
132
+ # Snap to a word boundary
133
+ while insert_idx < len(haystack_text) and haystack_text[insert_idx] != " ":
134
+ insert_idx += 1
135
+ haystack_with_needle = (
136
+ haystack_text[:insert_idx]
137
+ + " "
138
+ + needle
139
+ + " "
140
+ + haystack_text[insert_idx:]
141
+ )
142
+
143
+ prompt = (
144
+ "Read the following passage carefully. After the passage, you will be asked "
145
+ "a question about a specific detail.\n\n"
146
+ "PASSAGE:\n"
147
+ f"{haystack_with_needle}\n\n"
148
+ "QUESTION:\n"
149
+ f"{QUESTION_TEMPLATE}"
150
+ )
151
+
152
+ # Stable ID for the trial
153
+ trial_id = hashlib.md5(f"{target_chars}:{position_frac}:{magic}".encode()).hexdigest()[:12]
154
+
155
+ return NIAHTrial(
156
+ trial_id=trial_id,
157
+ target_chars=target_chars,
158
+ needle_position_frac=position_frac,
159
+ magic=magic,
160
+ haystack_chars=len(haystack_with_needle),
161
+ prompt=prompt,
162
+ expected=magic,
163
+ )
164
+
165
+
166
+ # ---------------------------------------------------------------------------
167
+ # Scoring
168
+ # ---------------------------------------------------------------------------
169
+
170
+ class NIAHScorer:
171
+ """Exact-match scoring with light normalization.
172
+
173
+ A response is correct if and only if the expected magic string appears
174
+ as a contiguous substring (case-insensitive). We don't require
175
+ case-perfect match because tokenizers occasionally case-shift small
176
+ portions of the needle; the magic string itself is uppercase + digits,
177
+ so the case shift is benign.
178
+ """
179
+
180
+ @staticmethod
181
+ def is_correct(response: str, expected: str) -> bool:
182
+ return expected.upper() in response.upper()
183
+
184
+ @staticmethod
185
+ def score(trials: list[NIAHTrial], responses: list[str]) -> dict:
186
+ """Return aggregate metrics dict."""
187
+ assert len(trials) == len(responses), "trials and responses length mismatch"
188
+ n = len(trials)
189
+ correct = sum(NIAHScorer.is_correct(r, t.expected) for r, t in zip(responses, trials))
190
+ return {
191
+ "n_trials": n,
192
+ "n_correct": correct,
193
+ "accuracy": correct / n if n else 0.0,
194
+ }
195
+
196
+
197
+ # ---------------------------------------------------------------------------
198
+ # Runner
199
+ # ---------------------------------------------------------------------------
200
+
201
+ @dataclass
202
+ class NIAHCellResult:
203
+ target_chars: int
204
+ n_trials: int
205
+ n_correct: int
206
+ accuracy: float
207
+ mean_response_chars: float
208
+ elapsed_s: float
209
+ per_trial: list[dict] = field(default_factory=list)
210
+
211
+
212
+ def run_niah_cell(
213
+ generate_fn: Callable[[str], str],
214
+ target_chars: int,
215
+ n_trials: int = 25,
216
+ seed: int = 42,
217
+ ) -> NIAHCellResult:
218
+ """Run one (context_length) cell: N trials at the given target_chars.
219
+
220
+ generate_fn: callable str -> str. Takes the prompt, returns the model's
221
+ response. The harness does not care how the response was produced.
222
+
223
+ Returns a NIAHCellResult with per-trial details + aggregate metrics.
224
+ """
225
+ gen = NIAHGenerator(seed=seed)
226
+ per_trial = []
227
+ n_correct = 0
228
+ total_response_chars = 0
229
+ t_start = time.time()
230
+
231
+ for i in range(n_trials):
232
+ # Spread positions evenly through [0.05, 0.95] so every cell exercises
233
+ # depth-position coverage (RULER convention).
234
+ frac = 0.05 + (0.90 * i / max(1, n_trials - 1))
235
+ trial = gen.generate(target_chars=target_chars, position_frac=frac)
236
+ response = generate_fn(trial.prompt)
237
+ correct = NIAHScorer.is_correct(response, trial.expected)
238
+ n_correct += int(correct)
239
+ total_response_chars += len(response)
240
+
241
+ per_trial.append({
242
+ "trial_id": trial.trial_id,
243
+ "needle_position_frac": round(trial.needle_position_frac, 3),
244
+ "haystack_chars": trial.haystack_chars,
245
+ "magic": trial.magic,
246
+ "response_chars": len(response),
247
+ "correct": int(correct),
248
+ })
249
+
250
+ elapsed = time.time() - t_start
251
+ return NIAHCellResult(
252
+ target_chars=target_chars,
253
+ n_trials=n_trials,
254
+ n_correct=n_correct,
255
+ accuracy=n_correct / n_trials if n_trials else 0.0,
256
+ mean_response_chars=total_response_chars / n_trials if n_trials else 0.0,
257
+ elapsed_s=elapsed,
258
+ per_trial=per_trial,
259
+ )
260
+
261
+
262
+ def write_cell_csv(result: NIAHCellResult, out_path: Path) -> None:
263
+ out_path.parent.mkdir(parents=True, exist_ok=True)
264
+ with open(out_path, "w", newline="") as f:
265
+ writer = csv.DictWriter(
266
+ f,
267
+ fieldnames=["trial_id", "needle_position_frac", "haystack_chars", "magic", "response_chars", "correct"],
268
+ )
269
+ writer.writeheader()
270
+ writer.writerows(result.per_trial)
271
+
272
+
273
+ # ---------------------------------------------------------------------------
274
+ # Self-test (run this file directly to verify harness logic)
275
+ # ---------------------------------------------------------------------------
276
+
277
+ def _selftest_oracle_correct(prompt: str) -> str:
278
+ """Mock 'perfect' model: extract the needle by regex and reproduce it."""
279
+ m = re.search(r"The secret password is ([A-Z0-9]+)\.", prompt)
280
+ return f"The password is {m.group(1)}." if m else "(no password found)"
281
+
282
+
283
+ def _selftest_oracle_wrong(prompt: str) -> str:
284
+ """Mock 'broken' model: returns a generic response without the needle."""
285
+ return "I cannot find a specific password in the passage."
286
+
287
+
288
+ def _selftest_oracle_partial(prompt: str) -> str:
289
+ """Mock 'lossy' model: 70% accuracy."""
290
+ m = re.search(r"The secret password is ([A-Z0-9]+)\.", prompt)
291
+ if m and random.random() < 0.7:
292
+ return f"Password: {m.group(1)}"
293
+ return "I'm not sure what the password is."
294
+
295
+
296
+ def selftest():
297
+ print("[niah-selftest] generating one trial at 4000 chars...")
298
+ gen = NIAHGenerator(seed=1)
299
+ trial = gen.generate(target_chars=4000, position_frac=0.5)
300
+ print(f" trial_id={trial.trial_id}")
301
+ print(f" haystack_chars={trial.haystack_chars}")
302
+ print(f" magic={trial.magic}")
303
+ print(f" needle_position_frac={trial.needle_position_frac}")
304
+ print(f" prompt[:200]={trial.prompt[:200]!r}")
305
+ print(f" expected={trial.expected}")
306
+ assert trial.expected in trial.prompt, "needle must appear in the prompt"
307
+ assert NIAHScorer.is_correct(f"The password is {trial.expected}.", trial.expected)
308
+ assert not NIAHScorer.is_correct("I don't know.", trial.expected)
309
+ print("[niah-selftest] generator + scorer basic checks PASS")
310
+
311
+ print("\n[niah-selftest] runner with oracle_correct (should be 100%)...")
312
+ r = run_niah_cell(_selftest_oracle_correct, target_chars=4000, n_trials=25, seed=1)
313
+ print(f" accuracy={r.accuracy:.2%} ({r.n_correct}/{r.n_trials}) elapsed={r.elapsed_s:.2f}s")
314
+ assert r.accuracy == 1.0, f"oracle_correct should be 100%, got {r.accuracy:.2%}"
315
+
316
+ print("\n[niah-selftest] runner with oracle_wrong (should be 0%)...")
317
+ r = run_niah_cell(_selftest_oracle_wrong, target_chars=4000, n_trials=25, seed=1)
318
+ print(f" accuracy={r.accuracy:.2%} ({r.n_correct}/{r.n_trials}) elapsed={r.elapsed_s:.2f}s")
319
+ assert r.accuracy == 0.0, f"oracle_wrong should be 0%, got {r.accuracy:.2%}"
320
+
321
+ print("\n[niah-selftest] runner with oracle_partial (should be ~70%)...")
322
+ random.seed(42)
323
+ r = run_niah_cell(_selftest_oracle_partial, target_chars=4000, n_trials=200, seed=1)
324
+ print(f" accuracy={r.accuracy:.2%} ({r.n_correct}/{r.n_trials}) elapsed={r.elapsed_s:.2f}s")
325
+ assert 0.55 < r.accuracy < 0.85, f"oracle_partial should be ~70%, got {r.accuracy:.2%}"
326
+
327
+ print("\n[niah-selftest] writing per-trial CSV...")
328
+ out = Path("/tmp/niah_selftest.csv")
329
+ write_cell_csv(r, out)
330
+ n_rows = sum(1 for _ in open(out)) - 1
331
+ print(f" wrote {n_rows} rows -> {out}")
332
+ assert n_rows == 200
333
+
334
+ print("\n[niah-selftest] ALL CHECKS PASS")
335
+
336
+
337
+ if __name__ == "__main__":
338
+ selftest()