"""Round 3b: validate the NIAH harness against a small control model on plain HuggingFace transformers (no AirLLM). Confirms the harness works end-to-end on a real LLM before we point it at the much heavier Kimi K2.6 path. Defaults to Qwen2.5-3B-Instruct (downloads ~6 GB on first run). For a smaller fingerprint, pass `--model unsloth/Qwen2.5-3B-Instruct-unsloth-bnb-4bit` (~2 GB). Run: python scripts/run_control_niah.py \ --model Qwen/Qwen2.5-3B-Instruct \ --context-lengths 2000 8000 \ --trials 10 \ --out-dir /tmp/niah_control Expected: Qwen2.5-3B should score >= 80% on the 2K-char (~500-token) cell. The 8K cell will be lower; that's the point — proves the harness picks up context-length quality decay. """ from __future__ import annotations import argparse import csv import sys import time from pathlib import Path # Local imports sys.path.insert(0, str(Path(__file__).resolve().parent)) from niah_harness import run_niah_cell, write_cell_csv def make_hf_generate_fn(model_id: str, dtype: str = "bfloat16", device: str = "auto", max_new_tokens: int = 64): """Return (model, tokenizer, generate_fn) for a HF transformers model.""" import torch from transformers import AutoModelForCausalLM, AutoTokenizer print(f"[control] loading {model_id} ({dtype}) ...") t0 = time.time() tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) torch_dtype = getattr(torch, dtype) if hasattr(torch, dtype) else torch.float16 model = AutoModelForCausalLM.from_pretrained( model_id, dtype=torch_dtype, device_map=device, trust_remote_code=True, ) model.eval() print(f"[control] loaded in {time.time()-t0:.1f}s, params: {sum(p.numel() for p in model.parameters()):,}") def generate(prompt: str) -> str: inputs = tok(prompt, return_tensors="pt").to(model.device) prompt_len = inputs["input_ids"].shape[1] with torch.no_grad(): out = model.generate( **inputs, max_new_tokens=max_new_tokens, do_sample=False, temperature=0.0, pad_token_id=tok.eos_token_id, ) # Strip prompt tokens, decode just the new tokens new_tokens = out[0, prompt_len:] return tok.decode(new_tokens, skip_special_tokens=True) return model, tok, generate def main(): ap = argparse.ArgumentParser() ap.add_argument("--model", default="Qwen/Qwen2.5-3B-Instruct", help="HF model ID. Defaults to Qwen2.5-3B-Instruct.") ap.add_argument("--dtype", default="bfloat16") ap.add_argument("--device", default="auto") ap.add_argument("--context-lengths", type=int, nargs="+", default=[2000, 8000]) ap.add_argument("--trials", type=int, default=10) ap.add_argument("--max-new-tokens", type=int, default=64) ap.add_argument("--out-dir", default="/tmp/niah_control") ap.add_argument("--seed", type=int, default=42) args = ap.parse_args() out_dir = Path(args.out_dir) out_dir.mkdir(parents=True, exist_ok=True) model, tok, gen = make_hf_generate_fn( args.model, dtype=args.dtype, device=args.device, max_new_tokens=args.max_new_tokens ) print(f"\n[control] sweep: {len(args.context_lengths)} cells, {args.trials} trials each") summary = [] for ctx in args.context_lengths: cell_id = f"ctx{ctx}_baseline" print(f"\n[control] === cell {cell_id} ===") result = run_niah_cell( generate_fn=gen, target_chars=ctx, n_trials=args.trials, seed=args.seed, ) cell_csv = out_dir / f"{cell_id}.csv" write_cell_csv(result, cell_csv) print(f"[control] accuracy={result.accuracy:.2%} ({result.n_correct}/{result.n_trials}) " f"elapsed={result.elapsed_s:.1f}s mean_response_chars={result.mean_response_chars:.0f}") summary.append({ "cell_id": cell_id, "ctx_chars": ctx, "n_trials": result.n_trials, "n_correct": result.n_correct, "accuracy": round(result.accuracy, 4), "elapsed_s": round(result.elapsed_s, 1), }) summary_path = out_dir / "summary.csv" if summary: with open(summary_path, "w", newline="") as f: w = csv.DictWriter(f, fieldnames=list(summary[0].keys())) w.writeheader() w.writerows(summary) print(f"\n[control] summary -> {summary_path}") # Sanity check: the harness should resolve the needle on at least the # smallest context. If 2K accuracy is < 50%, something is wrong with # either the harness OR the model is fundamentally broken at retrieval. if summary: smallest = min(summary, key=lambda s: s["ctx_chars"]) if smallest["accuracy"] < 0.5: print(f"\n[control] WARNING: smallest-context accuracy {smallest['accuracy']:.2%} is < 50%.") print("[control] either the harness is broken or the model can't do NIAH at all.") print("[control] inspect a per-trial CSV to see what the model actually produced.") else: print(f"\n[control] PASS: smallest-context accuracy {smallest['accuracy']:.2%} >= 50%.") print("[control] harness validated end-to-end on a real LLM.") if __name__ == "__main__": main()