kv-cache-eviction-mla / scripts /run_control_niah.py
GENOMA LABS / research
Round 3b prep: control-validation driver (plain HF transformers, no AirLLM)
754890f
"""Round 3b: validate the NIAH harness against a small control model on plain
HuggingFace transformers (no AirLLM). Confirms the harness works end-to-end
on a real LLM before we point it at the much heavier Kimi K2.6 path.
Defaults to Qwen2.5-3B-Instruct (downloads ~6 GB on first run). For a smaller
footprint, pass `--model unsloth/Qwen2.5-3B-Instruct-unsloth-bnb-4bit` (~2 GB;
requires the `bitsandbytes` package).
Run:
python scripts/run_control_niah.py \
--model Qwen/Qwen2.5-3B-Instruct \
--context-lengths 2000 8000 \
--trials 10 \
--out-dir /tmp/niah_control
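Expected: Qwen2.5-3B should score >= 80% on the 2K-char (~500-token) cell.
The 8K-char cell is expected to score lower; that is the point — it shows the
harness picks up context-length quality decay. (The script itself only reports
PASS/WARNING against a 50% floor on the smallest-context cell.)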
"""
from __future__ import annotations
import argparse
import csv
import sys
import time
from pathlib import Path
# Local imports
sys.path.insert(0, str(Path(__file__).resolve().parent))
from niah_harness import run_niah_cell, write_cell_csv
def _resolve_torch_dtype(dtype: str):
    """Map a CLI dtype string to a ``torch.dtype`` (or pass ``"auto"`` through).

    Raises:
        ValueError: if *dtype* does not name a ``torch.dtype``.
    """
    import torch

    if dtype == "auto":
        # Forwarded verbatim to from_pretrained, which then picks the checkpoint's dtype.
        return "auto"
    aliases = {"bf16": "bfloat16", "fp16": "float16", "fp32": "float32"}
    resolved = getattr(torch, aliases.get(dtype, dtype), None)
    # getattr alone is not enough: torch.cuda, torch.nn, ... exist but are not dtypes.
    if not isinstance(resolved, torch.dtype):
        raise ValueError(
            f"unknown dtype {dtype!r}; expected a torch dtype name such as "
            f"'bfloat16', 'float16', 'float32' (or 'bf16'/'fp16'/'fp32'/'auto')"
        )
    return resolved


def make_hf_generate_fn(model_id: str, dtype: str = "bfloat16", device: str = "auto", max_new_tokens: int = 64):
    """Load a HF causal LM and wrap it in a greedy ``prompt -> completion`` callable.

    Args:
        model_id: Hugging Face model ID or local path, passed to ``from_pretrained``.
        dtype: Name of a ``torch`` dtype (``"bfloat16"``, ``"float16"``, ``"float32"``,
            or the aliases ``"bf16"``/``"fp16"``/``"fp32"``), or ``"auto"``.
        device: ``device_map`` forwarded to ``from_pretrained`` (e.g. ``"auto"``, ``"cpu"``).
        max_new_tokens: Maximum number of tokens generated per call.

    Returns:
        ``(model, tokenizer, generate_fn)``. ``generate_fn(prompt)`` returns only the
        newly generated text: prompt tokens are stripped and special tokens skipped.

    Raises:
        ValueError: if *dtype* is not a recognised torch dtype. This is checked before
            transformers is imported or any weights are downloaded, so a typo fails
            fast instead of silently loading in float16.
    """
    import torch

    torch_dtype = _resolve_torch_dtype(dtype)

    from transformers import AutoModelForCausalLM, AutoTokenizer

    print(f"[control] loading {model_id} ({dtype}) ...")
    t0 = time.time()
    # NOTE: trust_remote_code=True executes Python shipped inside the model repo;
    # only point --model at repositories you trust.
    tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        dtype=torch_dtype,
        device_map=device,
        trust_remote_code=True,
    )
    model.eval()
    n_params = sum(p.numel() for p in model.parameters())
    print(f"[control] loaded in {time.time()-t0:.1f}s, params: {n_params:,}")

    def generate(prompt: str) -> str:
        # NOTE(review): the prompt is tokenized as-is (no chat template). Presumably
        # niah_harness already formats it appropriately for Instruct models — confirm.
        inputs = tok(prompt, return_tensors="pt").to(model.device)
        prompt_len = inputs["input_ids"].shape[1]
        with torch.no_grad():
            out = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                # Greedy decoding. No temperature: it is ignored when not sampling,
                # and passing one only triggers a generation-config warning.
                do_sample=False,
                pad_token_id=tok.eos_token_id,
            )
        # The output row is prompt tokens followed by new tokens; decode only the latter.
        return tok.decode(out[0, prompt_len:], skip_special_tokens=True)

    return model, tok, generate
def main():
ap = argparse.ArgumentParser()
ap.add_argument("--model", default="Qwen/Qwen2.5-3B-Instruct",
help="HF model ID. Defaults to Qwen2.5-3B-Instruct.")
ap.add_argument("--dtype", default="bfloat16")
ap.add_argument("--device", default="auto")
ap.add_argument("--context-lengths", type=int, nargs="+", default=[2000, 8000])
ap.add_argument("--trials", type=int, default=10)
ap.add_argument("--max-new-tokens", type=int, default=64)
ap.add_argument("--out-dir", default="/tmp/niah_control")
ap.add_argument("--seed", type=int, default=42)
args = ap.parse_args()
out_dir = Path(args.out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
model, tok, gen = make_hf_generate_fn(
args.model, dtype=args.dtype, device=args.device, max_new_tokens=args.max_new_tokens
)
print(f"\n[control] sweep: {len(args.context_lengths)} cells, {args.trials} trials each")
summary = []
for ctx in args.context_lengths:
cell_id = f"ctx{ctx}_baseline"
print(f"\n[control] === cell {cell_id} ===")
result = run_niah_cell(
generate_fn=gen,
target_chars=ctx,
n_trials=args.trials,
seed=args.seed,
)
cell_csv = out_dir / f"{cell_id}.csv"
write_cell_csv(result, cell_csv)
print(f"[control] accuracy={result.accuracy:.2%} ({result.n_correct}/{result.n_trials}) "
f"elapsed={result.elapsed_s:.1f}s mean_response_chars={result.mean_response_chars:.0f}")
summary.append({
"cell_id": cell_id,
"ctx_chars": ctx,
"n_trials": result.n_trials,
"n_correct": result.n_correct,
"accuracy": round(result.accuracy, 4),
"elapsed_s": round(result.elapsed_s, 1),
})
summary_path = out_dir / "summary.csv"
if summary:
with open(summary_path, "w", newline="") as f:
w = csv.DictWriter(f, fieldnames=list(summary[0].keys()))
w.writeheader()
w.writerows(summary)
print(f"\n[control] summary -> {summary_path}")
# Sanity check: the harness should resolve the needle on at least the
# smallest context. If 2K accuracy is < 50%, something is wrong with
# either the harness OR the model is fundamentally broken at retrieval.
if summary:
smallest = min(summary, key=lambda s: s["ctx_chars"])
if smallest["accuracy"] < 0.5:
print(f"\n[control] WARNING: smallest-context accuracy {smallest['accuracy']:.2%} is < 50%.")
print("[control] either the harness is broken or the model can't do NIAH at all.")
print("[control] inspect a per-trial CSV to see what the model actually produced.")
else:
print(f"\n[control] PASS: smallest-context accuracy {smallest['accuracy']:.2%} >= 50%.")
print("[control] harness validated end-to-end on a real LLM.")
if __name__ == "__main__":
main()